1import joblib
2import shutil
3from omegaml.backends.basecommon import BackendBaseCommon
4from omegaml.util import reshaped
5from pathlib import Path
6
7
class BaseModelBackend(BackendBaseCommon):
    """
    OmegaML BaseModelBackend to be subclassed by other arbitrary backends

    This provides the abstract interface for any model backend to be implemented.
    Subclass to implement custom backends.

    Essentially a model backend:

    * provides methods to serialize and deserialize a machine learning model for a given ML framework
    * offers fit() and predict() methods to be called by the runtime
    * offers additional methods such as score(), partial_fit(), transform()

    Model backends are the middleware that connects the om.models API to specific frameworks. This class
    makes it simple to implement a model backend by offering a common syntax as well as a default implementation
    for get() and put().

    Methods to implement:
        # for model serialization (mandatory)
        @classmethod supports() - determine if backend supports given model instance
        _package_model() - serialize a model instance into a temporary file
        _extract_model() - deserialize the model from a file-like

        By default BaseModelBackend uses joblib.dump/load to store the model as serialized
        Python objects. If this is not sufficient or applicable to your type of models, override these
        methods.

        Both methods provide readily set up temporary file names so that all you have to do is actually
        save the model to the given output file and restore the model from the given input file, respectively.
        All other logic has already been implemented (see get_model and put_model methods).

        # for fitting and predicting (mandatory)
        fit()
        predict()

        # other methods (optional)
        fit_transform() - fit and return a transformed dataset
        partial_fit() - fit incrementally
        predict_proba() - predict probabilities
        score() - score fitted classifier vs test dataset

    """
    # key/value pair written into each stored model's kind_meta by put_model(),
    # recording which version of this backend's storage format produced it
    _backend_version_tag = '_om_backend_version'
    _backend_version = '1'

    def __init__(self, model_store=None, data_store=None, tracking=None, **kwargs):
        """
        :param model_store: the store models are read from and written to (required)
        :param data_store: the store input and result datasets are read from and
            written to (required)
        :param tracking: optional tracker; when it has a truthy ``autotrack``
            attribute, data passing through _resolve_input_data() and
            _prepare_result() is logged to it
        :param kwargs: accepted for signature compatibility, unused here
        """
        # NOTE(review): assert-based checks are skipped under `python -O`
        assert model_store, "Need a model store"
        assert data_store, "Need a data store"
        self.model_store = model_store
        self.data_store = data_store
        self.tracking = tracking

    @classmethod
    def supports(cls, obj, name, **kwargs):
        """
        test if this backend supports this obj

        The base implementation supports nothing; subclasses override this.

        :param obj: the object to be stored
        :param name: the name the object will be stored under
        :return: True if this backend can handle obj, False otherwise
        """
        return False

    @property
    def _call_handler(self):
        # the model store handles _pre and _post methods in self.perform()
        return self.model_store

    def get(self, name, **kwargs):
        """
        retrieve a model

        :param name: the name of the object
        :param version: the version of the object (not supported)
        :return: the model instance as returned by get_model()
        """
        # support new backend architecture while keeping back compatibility
        return self.get_model(name, **kwargs)

    def put(self, obj, name, **kwargs):
        """
        store a model

        :param obj: the model object to be stored
        :param name: the name of the object
        :param attributes: attributes for meta data
        :return: the result of put_model()
        """
        # support new backend architecture while keeping back compatibility
        return self.put_model(obj, name, **kwargs)

    def drop(self, name, force=False, version=-1, **kwargs):
        """
        drop a model by delegating to the model store's _drop()

        :param name: the name of the object
        :param force: passed through to model_store._drop()
        :param version: passed through to model_store._drop()
        :param kwargs: ignored
        :return: the result of model_store._drop()
        """
        return self.model_store._drop(name, force=force, version=version)

    def _package_model(self, model, key, tmpfn, **kwargs):
        """
        implement this method to serialize a model to the given tmpfn

        The default implementation writes the model to tmpfn using joblib.dump.

        Args:
            model: the model instance to serialize
            key: the object store key the package will be stored under
            tmpfn: a temporary file name to write the serialized model to
            **kwargs: unused by the default implementation

        Returns:
            tmpfn or absolute path of serialized file (file or directory);
            put_model() removes this path after storing it
        """
        with open(tmpfn, 'wb') as outf:
            joblib.dump(model, outf)
        return tmpfn

    def _extract_model(self, infile, key, tmpfn, **kwargs):
        """
        implement this method to deserialize a model from the given infile

        The default implementation reads infile with joblib.load and does not
        use tmpfn.

        Args:
            infile: this is a file-like object supporting read() and seek(). if
               deserializing from this does not work directly, use tmpfn
            key: the object store key the package was stored under
            tmpfn: a temporary file name available for deserialization
            **kwargs: unused by the default implementation

        Returns:
            model instance
        """
        # SECURITY: joblib.load unpickles arbitrary objects; only load packages
        # from a trusted model store
        obj = joblib.load(infile)
        return obj

    def _remove_path(self, path):
        """
        Remove a path, either a file or a directory

        Args:
            path (str): filename or path to remove. If path is a directory,
               it will be removed recursively. If path is a file, it will be
               removed. A missing path is ignored.

        Returns:
            None
        """
        target = Path(path)
        if target.is_dir():
            shutil.rmtree(path, ignore_errors=True)
        else:
            target.unlink(missing_ok=True)

    def get_model(self, name, version=-1, **kwargs):
        """
        Retrieves a pre-stored model

        Reads the stored package referenced by the model's metadata gridfile
        and deserializes it via _extract_model().

        :param name: the name of the model
        :param version: accepted for interface compatibility; not used here
        :param kwargs: passed to _extract_model()
        :return: the deserialized model instance
        """
        meta = self.model_store.metadata(name)
        storekey = self.model_store.object_store_key(name, 'omm', hashed=True)
        tmpfn = self._tmp_packagefn(self.model_store, storekey)
        return self._extract_model(meta.gridfile, storekey, tmpfn, **kwargs)

    def put_model(self, obj, name, attributes=None, _kind_version=None, **kwargs):
        """
        Packages a model using joblib and stores in GridFS

        The model is serialized with _package_model() into a local temporary
        path, copied to the model store, and the local copy is removed again,
        even if packaging or storing fails.

        :param obj: the model instance to store
        :param name: the name of the model
        :param attributes: attributes for the meta data
        :param _kind_version: accepted for interface compatibility; not used here
        :param kwargs: passed to _package_model()
        :return: the result of saving the model's metadata
        """
        storekey = self.model_store.object_store_key(name, 'omm', hashed=True)
        tmpfn = self._tmp_packagefn(self.model_store, storekey)
        # if packaging fails before returning a path, clean up tmpfn itself
        packagefname = tmpfn
        try:
            # _package_model may return a different path (e.g. a directory);
            # a falsy return means it wrote to tmpfn
            packagefname = self._package_model(obj, storekey, tmpfn, **kwargs) or tmpfn
            gridfile = self._store_to_file(self.model_store, packagefname, storekey)
        finally:
            # never leave the local package behind, even when storing fails
            self._remove_path(packagefname)
        kind_meta = {
            self._backend_version_tag: self._backend_version,
        }
        # NOTE(review): self.KIND is not defined in this class; presumably
        # provided by each concrete backend subclass — confirm
        return self.model_store._make_metadata(
            name=name,
            prefix=self.model_store.prefix,
            bucket=self.model_store.bucket,
            kind=self.KIND,
            kind_meta=kind_meta,
            attributes=attributes,
            gridfile=gridfile).save()

    def predict(
            self, modelname, Xname, rName=None, pure_python=True, **kwargs):
        """
        predict using data stored in Xname

        :param modelname: the name of the model object
        :param Xname: the name of the X data set
        :param rName: the name of the result data object or None
        :param pure_python: if True return a python object. If False return
          a dataframe. Defaults to True to support any client.
        :param kwargs: passed to _resolve_input_data() and _prepare_result();
          they are not forwarded to the model's predict method
        :return: return the predicted outcome
        :raises NotImplementedError: if the model has no predict method
        """
        model = self.model_store.get(modelname)
        # fail fast, before loading data or emitting a tracking event
        if not hasattr(model, 'predict'):
            raise NotImplementedError(
                'model {} does not support predict'.format(modelname))
        data = self._resolve_input_data('predict', Xname, 'X', **kwargs)
        result = model.predict(reshaped(data))
        return self._prepare_result('predict', result, rName=rName,
                                    pure_python=pure_python, **kwargs)

    def _resolve_input_data(self, method, Xname, key, **kwargs):
        """
        load the dataset Xname from the data store, logging it if autotracking

        :param method: the calling method's name, logged as the tracking event
        :param Xname: the name of the dataset to load
        :param key: the key to log the data under (e.g. 'X')
        :param kwargs: unused
        :return: the loaded data
        """
        data = self.data_store.get(Xname)
        if self.tracking and getattr(self.tracking, 'autotrack', False):
            # metadata is only needed for tracking, so look it up only then
            meta = self.data_store.metadata(Xname)
            self.tracking.log_data(key, data, dataset=Xname, kind=meta.kind, event=method)
        return data

    def _prepare_result(self, method, result, rName=None, pure_python=False, **kwargs):
        """
        convert and optionally store a method's result, logging it if autotracking

        :param method: the calling method's name, logged as the tracking event
        :param result: the raw result returned by the model
        :param rName: if truthy, the result is stored in the data store under
          this name and the returned value is the store's metadata
        :param pure_python: if True, results providing tolist() (e.g. arrays)
          are converted to python lists; other results are returned unchanged
        :param kwargs: unused
        :return: the (converted) result, or the stored metadata if rName is set
        """
        if pure_python and hasattr(result, 'tolist'):
            result = result.tolist()
        meta = None
        if rName:
            meta = self.data_store.put(result, rName)
            result = meta
        if self.tracking and getattr(self.tracking, 'autotrack', False):
            # key the kind on whether a result was actually stored; checking
            # `rName is None` here read an unbound `meta` for falsy rName values
            kind = meta.kind if meta is not None else str(type(result))
            self.tracking.log_data('Y', result, dataset=rName, kind=kind,
                                   event=method)
        return result

    def predict_proba(
            self, modelname, Xname, rName=None, pure_python=True, **kwargs):
        """
        predict the probability using data stored in Xname

        Not implemented in the base class; subclasses override this.

        :param modelname: the name of the model object
        :param Xname: the name of the X data set
        :param rName: the name of the result data object or None
        :param pure_python: if True return a python object. If False return
          a dataframe. Defaults to True to support any client.
        :param kwargs: kwargs passed to the model's predict_proba method
        :return: return the predicted outcome
        """
        raise NotImplementedError

    def fit(self, modelname, Xname, Yname=None, pure_python=True, **kwargs):
        """
        fit the model with data

        Not implemented in the base class; subclasses override this.

        :param modelname: the name of the model object
        :param Xname: the name of the X data set
        :param Yname: the name of the Y data set
        :param pure_python: if True return a python object. If False return
          a dataframe. Defaults to True to support any client.
        :param kwargs: kwargs passed to the model's fit method
        :return: return the meta data object of the model
        """
        raise NotImplementedError

    def partial_fit(
            self, modelname, Xname, Yname=None, pure_python=True, **kwargs):
        """
        partially fit the model with data (online)

        Not implemented in the base class; subclasses override this.

        :param modelname: the name of the model object
        :param Xname: the name of the X data set
        :param Yname: the name of the Y data set
        :param pure_python: if True return a python object. If False return
          a dataframe. Defaults to True to support any client.
        :param kwargs: kwargs passed to the model's partial_fit method
        :return: return the meta data object of the model
        """
        raise NotImplementedError

    def fit_transform(
            self, modelname, Xname, Yname=None, rName=None, pure_python=True,
            **kwargs):
        """
        fit and transform using data

        Not implemented in the base class; subclasses override this.

        :param modelname: the name of the model object
        :param Xname: the name of the X data set
        :param Yname: the name of the Y data set
        :param rName: the name of the transform's result data object or None
        :param pure_python: if True return a python object. If False return
          a dataframe. Defaults to True to support any client.
        :param kwargs: kwargs passed to the model's transform method
        :return: return the meta data object of the model
        """
        raise NotImplementedError

    def score(
            self, modelname, Xname, Yname=None, rName=True, pure_python=True,
            **kwargs):
        """
        score using data

        Not implemented in the base class; subclasses override this.

        :param modelname: the name of the model object
        :param Xname: the name of the X data set
        :param Yname: the name of the Y data set
        :param rName: the name of the score's result data object or None
        :param pure_python: if True return a python object. If False return
          a dataframe. Defaults to True to support any client.
        :param kwargs: kwargs passed to the model's score method
        :return: return the score result
        """
        raise NotImplementedError