1import shutil
2from pathlib import Path
3
4import joblib
5import smart_open
6
7from omegaml.backends.basecommon import BackendBaseCommon
8from omegaml.util import reshaped
9
10
11class BaseModelBackend(BackendBaseCommon):
12 """
13 OmegaML BaseModelBackend to be subclassed by other arbitrary backends
14
15 This provides the abstract interface for any model backend to be implemented
16 Subclass to implement custom backends.
17
18 Essentially a model backend:
19
20 * provides methods to serialize and deserialize a machine learning model for a given ML framework
21 * offers fit() and predict() methods to be called by the runtime
22 * offers additional methods such as score(), partial_fit(), transform()
23
24 Model backends are the middleware that connects the om.models API to specific frameworks. This class
25 makes it simple to implement a model backend by offering a common syntax as well as a default implementation
26 for get() and put().
27
28 Methods to implement:
29 # for model serialization (mandatory)
30 @classmethod supports() - determine if backend supports given model instance
31 _package_model() - serialize a model instance into a temporary file
32 _extract_model() - deserialize the model from a file-like
33
    By default BaseModelBackend uses joblib.dump/load to store the model as serialized
    Python objects. If this is not sufficient or applicable to your type of models, override these
36 methods.
37
38 Both methods provide readily set up temporary file names so that all you have to do is actually
39 save the model to the given output file and restore the model from the given input file, respectively.
40 All other logic has already been implemented (see get_model and put_model methods).
41
42 # for fitting and predicting (mandatory)
43 fit()
44 predict()
45
46 # other methods (optional)
47 fit_transform() - fit and return a transformed dataset
48 partial_fit() - fit incrementally
49 predict_proba() - predict probabilities
    score() - score fitted classifier vs test dataset
51
52 """
53 _backend_version_tag = '_om_backend_version'
54 _backend_version = '1'
55
56 serializer = lambda store, model, filename, **kwargs: joblib.dump(model, filename)[0]
57 loader = lambda store, infile, filename=None, **kwargs: joblib.load(infile)
58 infer = lambda obj, **kwargs: getattr(obj, 'predict')
59 reshape = lambda data, **kwargs: reshaped(data)
60 types = None
61
62 def __init__(self, model_store=None, data_store=None, tracking=None, **kwargs):
63 assert model_store, "Need a model store"
64 assert data_store, "Need a data store"
65 self.model_store = model_store
66 self.data_store = data_store
67 self.tracking = tracking
68
69 @classmethod
70 def supports(self, obj, name, **kwargs):
71 """
72 test if this backend supports this obj
73 """
74 return isinstance(obj, self.types) if self.types else False
75
76 @property
77 def _call_handler(self):
78 # the model store handles _pre and _post methods in self.perform()
79 return self.model_store
80
81 def get(self, name, uri=None, **kwargs):
82 """
83 retrieve a model
84
85 :param name: the name of the object
86 :param uri: optional, /path/to/file, defaults to meta.gridfile, may use /path/{key} as placeholder
87 for the file's name
88 :param version: the version of the object (not supported)
89
90 .. versionadded: 0.18
91 uri specifies a target filename to store the serialized model
92 """
93 # support new backend architecture while keeping back compatibility
94 return self.get_model(name, uri=uri, **kwargs)
95
96 def put(self, obj, name, uri=None, **kwargs):
97 """
98 store a model
99
100 :param obj: the model object to be stored
101 :param name: the name of the object
102 :param uri: optional, /path/to/file, defaults to meta.gridfile, may use /path/{key} as placeholder
103 for the file's name
104 :param attributes: attributes for meta data
105
106 .. versionadded: 0.18
107 local specifies a local filename to store the serialized model
108 """
109 # support new backend architecture while keeping back compatibility
110 return self.put_model(obj, name, uri=uri, **kwargs)
111
112 def drop(self, name, force=False, version=-1, **kwargs):
113 return self.model_store._drop(name, force=force, version=version)
114
115 def _package_model(self, model, key, tmpfn, serializer=None, **kwargs):
116 """
117 implement this method to serialize a model to the given tmpfn
118
119 Args:
120 model (object): the model object to serialize to a file
121 key (str): the object store's key for this object
122 tmpfn (str): the filename to store the serialized object to
123 serializer (callable): optional, a callable as serializer(store, model, filename, **kwargs),
124 defaults to self.serializer, using joblib.dump()
125 **kwargs (dict): optional, keyword arguments passed to the serializer
126
127 Returns:
128 tmpfn or absolute path of serialized file
129
130 .. versionchanged:: 0.18
131 enable custom serializer
132 """
133 serializer = serializer or getattr(self.serializer, '__func__') # __func__ is the unbound method
134 kwargs.setdefault('key', key)
135 tmpfn = serializer(self, model, tmpfn) or tmpfn
136 return tmpfn
137
138 def _extract_model(self, infile, key, tmpfn, loader=None, **kwargs):
139 """
140 implement this method to deserialize a model from the given infile
141
142 Args:
143 infile (filelike): this is a file-like object supporting read() and seek(). if
144 deserializing from this does not work directly, use tmpfn
145 key (str): the object store's key for this object
146 tmpfn (str): the filename from which to extract the object
147 loader (callable): optional, a callable as loader(store, filename, **kwargs),
148 defaults to self.loader, using joblib.load()
149 **kwargs (dict): optional, keyword arguments passed to the loader
150
151 Returns:
152 model instance
153
154 .. versionchanged:: 0.18
155 enable custom loader
156 """
157 loader = loader or getattr(self.loader, '__func__') # __func__ is the unbound method
158 kwargs.setdefault('filename', tmpfn)
159 kwargs.setdefault('key', key)
160 obj = loader(self, infile, **kwargs)
161 return obj
162
163 def _remove_path(self, path):
164 """
165 Remove a path, either a file or a directory
166
167 Args:
168 path (str): filename or path to remove. If path is a directory,
169 it will be removed recursively. If path is a file, it will be
170 removed.
171
172 Returns:
173 None
174 """
175 if Path(path).is_dir():
176 shutil.rmtree(path, ignore_errors=True)
177 else:
178 Path(path).unlink(missing_ok=True)
179
180 def get_model(self, name, version=-1, uri=None, loader=None, **kwargs):
181 """
182 Retrieves a pre-stored model
183 """
184 meta = self.model_store.metadata(name)
185 storekey = self.model_store.object_store_key(name, 'omm', hashed=True)
186 uri = uri or meta.uri
187 if uri:
188 uri = str(uri).format(key=storekey)
189 infile = smart_open.open(uri, 'rb')
190 else:
191 infile = meta.gridfile
192 model = self._extract_model(infile, storekey,
193 self._tmp_packagefn(self.model_store, storekey),
194 loader=loader, **kwargs)
195 infile.close()
196 return model
197
198 def put_model(self, obj, name, attributes=None, _kind_version=None, uri=None, **kwargs):
199 """
200 Packages a model using joblib and stores in GridFS
201 """
202 storekey = self.model_store.object_store_key(name, 'omm', hashed=True)
203 tmpfn = self._tmp_packagefn(self.model_store, storekey)
204 packagefname = self._package_model(obj, storekey, tmpfn, **kwargs) or tmpfn
205 gridfile = self._store_to_file(self.model_store, packagefname, storekey, uri=uri)
206 self._remove_path(packagefname)
207 kind_meta = {
208 self._backend_version_tag: self._backend_version,
209 }
210 return self.model_store._make_metadata(
211 name=name,
212 prefix=self.model_store.prefix,
213 bucket=self.model_store.bucket,
214 kind=self.KIND,
215 kind_meta=kind_meta,
216 attributes=attributes,
217 uri=str(uri or ''),
218 gridfile=gridfile).save()
219
[docs]
220 def predict(
221 self, modelname, Xname, rName=None, pure_python=True, **kwargs):
222 """
223 predict using data stored in Xname
224
225 :param modelname: the name of the model object
226 :param Xname: the name of the X data set
227 :param rName: the name of the result data object or None
228 :param pure_python: if True return a python object. If False return
229 a dataframe. Defaults to True to support any client.
230 :param kwargs: kwargs passed to the model's predict method
231 :return: return the predicted outcome
232 """
233 model = self.model_store.get(modelname)
234 data = self._resolve_input_data('predict', Xname, 'X', **kwargs)
235 infer = getattr(self.infer, '__func__') # __func__ is the unbound method
236 reshape = getattr(self.reshape, '__func__')
237 result = infer(model)(reshape(data))
238 return self._prepare_result('predict', result, rName=rName,
239 pure_python=pure_python, **kwargs)
240
241 def _resolve_input_data(self, method, Xname, key, **kwargs):
242 data = self.data_store.get(Xname)
243 meta = self.data_store.metadata(Xname)
244 if self.tracking and getattr(self.tracking, 'autotrack', False):
245 self.tracking.log_data(key, data, dataset=Xname, kind=meta.kind, event=method)
246 return data
247
248 def _prepare_result(self, method, result, rName=None, pure_python=False, **kwargs):
249 if pure_python:
250 result = result.tolist()
251 if rName:
252 meta = self.data_store.put(result, rName)
253 result = meta
254 if self.tracking and getattr(self.tracking, 'autotrack', False):
255 self.tracking.log_data('Y', result, dataset=rName, kind=str(type(result)) if rName is None else meta.kind,
256 event=method)
257 return result
258
259 def predict_proba(
260 self, modelname, Xname, rName=None, pure_python=True, **kwargs):
261 """
262 predict the probability using data stored in Xname
263
264 :param modelname: the name of the model object
265 :param Xname: the name of the X data set
266 :param rName: the name of the result data object or None
267 :param pure_python: if True return a python object. If False return
268 a dataframe. Defaults to True to support any client.
269 :param kwargs: kwargs passed to the model's predict method
270 :return: return the predicted outcome
271 """
272 raise NotImplementedError
273
[docs]
274 def fit(self, modelname, Xname, Yname=None, pure_python=True, **kwargs):
275 """
276 fit the model with data
277
278 :param modelname: the name of the model object
279 :param Xname: the name of the X data set
280 :param Yname: the name of the Y data set
281 :param pure_python: if True return a python object. If False return
282 a dataframe. Defaults to True to support any client.
283 :param kwargs: kwargs passed to the model's predict method
284 :return: return the meta data object of the model
285 """
286 raise NotImplementedError
287
288 def partial_fit(
289 self, modelname, Xname, Yname=None, pure_python=True, **kwargs):
290 """
291 partially fit the model with data (online)
292
293 :param modelname: the name of the model object
294 :param Xname: the name of the X data set
295 :param Yname: the name of the Y data set
296 :param pure_python: if True return a python object. If False return
297 a dataframe. Defaults to True to support any client.
298 :param kwargs: kwargs passed to the model's predict method
299 :return: return the meta data object of the model
300 """
301
302 raise NotImplementedError
303
304 def fit_transform(
305 self, modelname, Xname, Yname=None, rName=None, pure_python=True,
306 **kwargs):
307 """
308 fit and transform using data
309
310 :param modelname: the name of the model object
311 :param Xname: the name of the X data set
312 :param Yname: the name of the Y data set
313 :param rName: the name of the transforms's result data object or None
314 :param pure_python: if True return a python object. If False return
315 a dataframe. Defaults to True to support any client.
316 :param kwargs: kwargs passed to the model's transform method
317 :return: return the meta data object of the model
318 """
319 raise NotImplementedError
320
332
333 def score(
334 self, modelname, Xname, Yname=None, rName=True, pure_python=True,
335 **kwargs):
336 """
337 score using data
338
339 :param modelname: the name of the model object
340 :param Xname: the name of the X data set
341 :param Yname: the name of the Y data set
342 :param rName: the name of the transforms's result data object or None
343 :param pure_python: if True return a python object. If False return
344 a dataframe. Defaults to True to support any client.
345 :param kwargs: kwargs passed to the model's predict method
346 :return: return the score result
347 """
348 raise NotImplementedError