1import shutil
2from pathlib import Path
3
4import joblib
5import smart_open
6
7from omegaml.backends.basecommon import BackendBaseCommon
8from omegaml.util import reshaped
9
10
[docs]
11class BaseModelBackend(BackendBaseCommon):
12 """
13 OmegaML BaseModelBackend to be subclassed by other arbitrary backends
14
15 This provides the abstract interface for any model backend to be implemented
16 Subclass to implement custom backends.
17
18 Essentially a model backend:
19
20 * provides methods to serialize and deserialize a machine learning model for a given ML framework
21 * offers fit() and predict() methods to be called by the runtime
22 * offers additional methods such as score(), partial_fit(), transform()
23
24 Model backends are the middleware that connects the om.models API to specific frameworks. This class
25 makes it simple to implement a model backend by offering a common syntax as well as a default implementation
26 for get() and put().
27
28 Methods to implement:
29 # for model serialization (mandatory)
30 @classmethod supports() - determine if backend supports given model instance
31 _package_model() - serialize a model instance into a temporary file
32 _extract_model() - deserialize the model from a file-like
33
    By default BaseModelBackend uses joblib.dump/load to store the model as serialized
    Python objects. If this is not sufficient or applicable to your type of models, override these
    methods.
37
38 Both methods provide readily set up temporary file names so that all you have to do is actually
39 save the model to the given output file and restore the model from the given input file, respectively.
40 All other logic has already been implemented (see get_model and put_model methods).
41
42 # for fitting and predicting (mandatory)
43 fit()
44 predict()
45
46 # other methods (optional)
47 fit_transform() - fit and return a transformed dataset
48 partial_fit() - fit incrementally
49 predict_proba() - predict probabilities
        score() - score fitted classifier vs test dataset
51
52 """
53 _backend_version_tag = '_om_backend_version'
54 _backend_version = '1'
55
56 serializer = lambda store, model, filename, **kwargs: joblib.dump(model, filename)[0]
57 loader = lambda store, infile, filename=None, **kwargs: joblib.load(infile)
58
59 def __init__(self, model_store=None, data_store=None, tracking=None, **kwargs):
60 assert model_store, "Need a model store"
61 assert data_store, "Need a data store"
62 self.model_store = model_store
63 self.data_store = data_store
64 self.tracking = tracking
65
66 @classmethod
67 def supports(self, obj, name, **kwargs):
68 """
69 test if this backend supports this obj
70 """
71 return False
72
73 @property
74 def _call_handler(self):
75 # the model store handles _pre and _post methods in self.perform()
76 return self.model_store
77
78 def get(self, name, uri=None, **kwargs):
79 """
80 retrieve a model
81
82 :param name: the name of the object
83 :param uri: optional, /path/to/file, defaults to meta.gridfile, may use /path/{key} as placeholder
84 for the file's name
85 :param version: the version of the object (not supported)
86
87 .. versionadded: NEXT
88 uri specifies a target filename to store the serialized model
89 """
90 # support new backend architecture while keeping back compatibility
91 return self.get_model(name, uri=uri, **kwargs)
92
93 def put(self, obj, name, uri=None, **kwargs):
94 """
95 store a model
96
97 :param obj: the model object to be stored
98 :param name: the name of the object
99 :param uri: optional, /path/to/file, defaults to meta.gridfile, may use /path/{key} as placeholder
100 for the file's name
101 :param attributes: attributes for meta data
102
103 .. versionadded: NEXT
104 local specifies a local filename to store the serialized model
105 """
106 # support new backend architecture while keeping back compatibility
107 return self.put_model(obj, name, uri=uri, **kwargs)
108
109 def drop(self, name, force=False, version=-1, **kwargs):
110 return self.model_store._drop(name, force=force, version=version)
111
112 def _package_model(self, model, key, tmpfn, serializer=None, **kwargs):
113 """
114 implement this method to serialize a model to the given tmpfn
115
116 Args:
117 model (object): the model object to serialize to a file
118 key (str): the object store's key for this object
119 tmpfn (str): the filename to store the serialized object to
120 serializer (callable): optional, a callable as serializer(store, model, filename, **kwargs),
121 defaults to self.serializer, using joblib.dump()
122 **kwargs (dict): optional, keyword arguments passed to the serializer
123
124 Returns:
125 tmpfn or absolute path of serialized file
126
127 .. versionchanged:: NEXT
128 enable custom serializer
129 """
130 serializer = serializer or getattr(self.serializer, '__func__') # __func__ is the unbound method
131 kwargs.setdefault('key', key)
132 tmpfn = serializer(self, model, tmpfn) or tmpfn
133 return tmpfn
134
135 def _extract_model(self, infile, key, tmpfn, loader=None, **kwargs):
136 """
137 implement this method to deserialize a model from the given infile
138
139 Args:
140 infile (filelike): this is a file-like object supporting read() and seek(). if
141 deserializing from this does not work directly, use tmpfn
142 key (str): the object store's key for this object
143 tmpfn (str): the filename from which to extract the object
144 loader (callable): optional, a callable as loader(store, filename, **kwargs),
145 defaults to self.loader, using joblib.load()
146 **kwargs (dict): optional, keyword arguments passed to the loader
147
148 Returns:
149 model instance
150
151 .. versionchanged:: NEXT
152 enable custom loader
153 """
154 loader = loader or getattr(self.loader, '__func__') # __func__ is the unbound method
155 kwargs.setdefault('filename', tmpfn)
156 kwargs.setdefault('key', key)
157 obj = loader(self, infile, **kwargs)
158 return obj
159
160 def _remove_path(self, path):
161 """
162 Remove a path, either a file or a directory
163
164 Args:
165 path (str): filename or path to remove. If path is a directory,
166 it will be removed recursively. If path is a file, it will be
167 removed.
168
169 Returns:
170 None
171 """
172 if Path(path).is_dir():
173 shutil.rmtree(path, ignore_errors=True)
174 else:
175 Path(path).unlink(missing_ok=True)
176
177 def get_model(self, name, version=-1, uri=None, loader=None, **kwargs):
178 """
179 Retrieves a pre-stored model
180 """
181 meta = self.model_store.metadata(name)
182 storekey = self.model_store.object_store_key(name, 'omm', hashed=True)
183 uri = uri or meta.uri
184 if uri:
185 uri = str(uri).format(key=storekey)
186 infile = smart_open.open(uri, 'rb')
187 else:
188 infile = meta.gridfile
189 model = self._extract_model(infile, storekey,
190 self._tmp_packagefn(self.model_store, storekey),
191 loader=loader, **kwargs)
192 infile.close()
193 return model
194
195 def put_model(self, obj, name, attributes=None, _kind_version=None, uri=None, **kwargs):
196 """
197 Packages a model using joblib and stores in GridFS
198 """
199 storekey = self.model_store.object_store_key(name, 'omm', hashed=True)
200 tmpfn = self._tmp_packagefn(self.model_store, storekey)
201 packagefname = self._package_model(obj, storekey, tmpfn, **kwargs) or tmpfn
202 gridfile = self._store_to_file(self.model_store, packagefname, storekey, uri=uri)
203 self._remove_path(packagefname)
204 kind_meta = {
205 self._backend_version_tag: self._backend_version,
206 }
207 return self.model_store._make_metadata(
208 name=name,
209 prefix=self.model_store.prefix,
210 bucket=self.model_store.bucket,
211 kind=self.KIND,
212 kind_meta=kind_meta,
213 attributes=attributes,
214 uri=str(uri or ''),
215 gridfile=gridfile).save()
216
[docs]
217 def predict(
218 self, modelname, Xname, rName=None, pure_python=True, **kwargs):
219 """
220 predict using data stored in Xname
221
222 :param modelname: the name of the model object
223 :param Xname: the name of the X data set
224 :param rName: the name of the result data object or None
225 :param pure_python: if True return a python object. If False return
226 a dataframe. Defaults to True to support any client.
227 :param kwargs: kwargs passed to the model's predict method
228 :return: return the predicted outcome
229 """
230 model = self.model_store.get(modelname)
231 data = self._resolve_input_data('predict', Xname, 'X', **kwargs)
232 if not hasattr(model, 'predict'):
233 raise NotImplementedError
234 result = model.predict(reshaped(data))
235 return self._prepare_result('predict', result, rName=rName,
236 pure_python=pure_python, **kwargs)
237
238 def _resolve_input_data(self, method, Xname, key, **kwargs):
239 data = self.data_store.get(Xname)
240 meta = self.data_store.metadata(Xname)
241 if self.tracking and getattr(self.tracking, 'autotrack', False):
242 self.tracking.log_data(key, data, dataset=Xname, kind=meta.kind, event=method)
243 return data
244
245 def _prepare_result(self, method, result, rName=None, pure_python=False, **kwargs):
246 if pure_python:
247 result = result.tolist()
248 if rName:
249 meta = self.data_store.put(result, rName)
250 result = meta
251 if self.tracking and getattr(self.tracking, 'autotrack', False):
252 self.tracking.log_data('Y', result, dataset=rName, kind=str(type(result)) if rName is None else meta.kind,
253 event=method)
254 return result
255
256 def predict_proba(
257 self, modelname, Xname, rName=None, pure_python=True, **kwargs):
258 """
259 predict the probability using data stored in Xname
260
261 :param modelname: the name of the model object
262 :param Xname: the name of the X data set
263 :param rName: the name of the result data object or None
264 :param pure_python: if True return a python object. If False return
265 a dataframe. Defaults to True to support any client.
266 :param kwargs: kwargs passed to the model's predict method
267 :return: return the predicted outcome
268 """
269 raise NotImplementedError
270
[docs]
271 def fit(self, modelname, Xname, Yname=None, pure_python=True, **kwargs):
272 """
273 fit the model with data
274
275 :param modelname: the name of the model object
276 :param Xname: the name of the X data set
277 :param Yname: the name of the Y data set
278 :param pure_python: if True return a python object. If False return
279 a dataframe. Defaults to True to support any client.
280 :param kwargs: kwargs passed to the model's predict method
281 :return: return the meta data object of the model
282 """
283 raise NotImplementedError
284
285 def partial_fit(
286 self, modelname, Xname, Yname=None, pure_python=True, **kwargs):
287 """
288 partially fit the model with data (online)
289
290 :param modelname: the name of the model object
291 :param Xname: the name of the X data set
292 :param Yname: the name of the Y data set
293 :param pure_python: if True return a python object. If False return
294 a dataframe. Defaults to True to support any client.
295 :param kwargs: kwargs passed to the model's predict method
296 :return: return the meta data object of the model
297 """
298
299 raise NotImplementedError
300
301 def fit_transform(
302 self, modelname, Xname, Yname=None, rName=None, pure_python=True,
303 **kwargs):
304 """
305 fit and transform using data
306
307 :param modelname: the name of the model object
308 :param Xname: the name of the X data set
309 :param Yname: the name of the Y data set
310 :param rName: the name of the transforms's result data object or None
311 :param pure_python: if True return a python object. If False return
312 a dataframe. Defaults to True to support any client.
313 :param kwargs: kwargs passed to the model's transform method
314 :return: return the meta data object of the model
315 """
316 raise NotImplementedError
317
329
330 def score(
331 self, modelname, Xname, Yname=None, rName=True, pure_python=True,
332 **kwargs):
333 """
334 score using data
335
336 :param modelname: the name of the model object
337 :param Xname: the name of the X data set
338 :param Yname: the name of the Y data set
339 :param rName: the name of the transforms's result data object or None
340 :param pure_python: if True return a python object. If False return
341 a dataframe. Defaults to True to support any client.
342 :param kwargs: kwargs passed to the model's predict method
343 :return: return the score result
344 """
345 raise NotImplementedError