Source code for omegaml.backends.basemodel

  1import shutil
  2from pathlib import Path
  3
  4import joblib
  5import smart_open
  6
  7from omegaml.backends.basecommon import BackendBaseCommon
  8from omegaml.util import reshaped
  9
 10
class BaseModelBackend(BackendBaseCommon):
    """
    OmegaML BaseModelBackend to be subclassed by other arbitrary backends

    This provides the abstract interface for any model backend to be implemented
    Subclass to implement custom backends.

    Essentially a model backend:

    * provides methods to serialize and deserialize a machine learning model for a given ML framework
    * offers fit() and predict() methods to be called by the runtime
    * offers additional methods such as score(), partial_fit(), transform()

    Model backends are the middleware that connects the om.models API to specific frameworks. This class
    makes it simple to implement a model backend by offering a common syntax as well as a default implementation
    for get() and put().

    Methods to implement:
        # for model serialization (mandatory)
        @classmethod supports() - determine if backend supports given model instance
        _package_model() - serialize a model instance into a temporary file
        _extract_model() - deserialize the model from a file-like

        By default BaseModelBackend uses joblib.dump/load to store the model as serialized
        Python objects. If this is not sufficient or applicable to your type models, override these
        methods.

        Both methods provide readily set up temporary file names so that all you have to do is actually
        save the model to the given output file and restore the model from the given input file, respectively.
        All other logic has already been implemented (see get_model and put_model methods).

        # for fitting and predicting (mandatory)
        fit()
        predict()

        # other methods (optional)
        fit_transform() - fit and return a transformed dataset
        partial_fit() - fit incrementally
        predict_proba() - predict probabilities
        score() - score fitted classifier vs test dataset

    """
    _backend_version_tag = '_om_backend_version'
    _backend_version = '1'

    # default (de)serialization callables; the first argument receives the backend instance.
    # NOTE(review): joblib.load() unpickles arbitrary objects - only load models from trusted stores.
    serializer = lambda store, model, filename, **kwargs: joblib.dump(model, filename)[0]
    loader = lambda store, infile, filename=None, **kwargs: joblib.load(infile)

    def __init__(self, model_store=None, data_store=None, tracking=None, **kwargs):
        """
        Set up the backend

        :param model_store: the store holding models, required
        :param data_store: the store holding datasets, required
        :param tracking: optional tracking object; if it has a truthy ``autotrack``
            attribute, inputs and results are logged through its ``log_data()``
        :param kwargs: accepted for compatibility, not used here
        """
        # kept as asserts so callers that expect AssertionError keep working
        assert model_store, "Need a model store"
        assert data_store, "Need a data store"
        self.model_store = model_store
        self.data_store = data_store
        self.tracking = tracking

    @classmethod
    def supports(cls, obj, name, **kwargs):
        """
        test if this backend supports this obj

        The base implementation supports nothing and always returns False.
        """
        return False

    @property
    def _call_handler(self):
        # the model store handles _pre and _post methods in self.perform()
        return self.model_store

    def get(self, name, uri=None, **kwargs):
        """
        retrieve a model

        :param name: the name of the object
        :param uri: optional, /path/to/file, defaults to meta.gridfile, may use /path/{key} as placeholder
            for the file's name
        :param version: the version of the object (not supported)

        .. versionadded:: NEXT
            uri specifies a source filename from which to load the serialized model
        """
        # support new backend architecture while keeping back compatibility
        return self.get_model(name, uri=uri, **kwargs)

    def put(self, obj, name, uri=None, **kwargs):
        """
        store a model

        :param obj: the model object to be stored
        :param name: the name of the object
        :param uri: optional, /path/to/file, defaults to meta.gridfile, may use /path/{key} as placeholder
            for the file's name
        :param attributes: attributes for meta data

        .. versionadded:: NEXT
            uri specifies a target filename to store the serialized model
        """
        # support new backend architecture while keeping back compatibility
        return self.put_model(obj, name, uri=uri, **kwargs)

    def drop(self, name, force=False, version=-1, **kwargs):
        """
        drop a model by delegating to the model store's _drop()

        :param name: the name of the object
        :param force: passed on to the model store
        :param version: passed on to the model store
        :param kwargs: accepted for compatibility, not passed on
        """
        return self.model_store._drop(name, force=force, version=version)

    def _package_model(self, model, key, tmpfn, serializer=None, **kwargs):
        """
        implement this method to serialize a model to the given tmpfn

        Args:
            model (object): the model object to serialize to a file
            key (str): the object store's key for this object
            tmpfn (str): the filename to store the serialized object to
            serializer (callable): optional, a callable as serializer(store, model, filename, **kwargs),
                defaults to self.serializer, using joblib.dump()
            **kwargs (dict): optional, keyword arguments passed to the serializer

        Returns:
            tmpfn or absolute path of serialized file

        .. versionchanged:: NEXT
            enable custom serializer
        """
        # __func__ is the unbound method; fall back to the attribute itself so that
        # serializers provided as staticmethod or plain callables work as well
        serializer = serializer or getattr(self.serializer, '__func__', self.serializer)
        kwargs.setdefault('key', key)
        # forward kwargs as documented, consistent with _extract_model()
        return serializer(self, model, tmpfn, **kwargs) or tmpfn

    def _extract_model(self, infile, key, tmpfn, loader=None, **kwargs):
        """
        implement this method to deserialize a model from the given infile

        Args:
            infile (filelike): this is a file-like object supporting read() and seek(). if
                deserializing from this does not work directly, use tmpfn
            key (str): the object store's key for this object
            tmpfn (str): the filename from which to extract the object
            loader (callable): optional, a callable as loader(store, infile, **kwargs),
                defaults to self.loader, using joblib.load()
            **kwargs (dict): optional, keyword arguments passed to the loader

        Returns:
            model instance

        .. versionchanged:: NEXT
            enable custom loader
        """
        # __func__ is the unbound method; fall back to the attribute itself so that
        # loaders provided as staticmethod or plain callables work as well
        loader = loader or getattr(self.loader, '__func__', self.loader)
        kwargs.setdefault('filename', tmpfn)
        kwargs.setdefault('key', key)
        return loader(self, infile, **kwargs)

    def _remove_path(self, path):
        """
        Remove a path, either a file or a directory

        Args:
            path (str): filename or path to remove. If path is a directory,
                it will be removed recursively. If path is a file, it will be
                removed.

        Returns:
            None
        """
        target = Path(path)
        if target.is_dir():
            shutil.rmtree(path, ignore_errors=True)
        else:
            target.unlink(missing_ok=True)

    def get_model(self, name, version=-1, uri=None, loader=None, **kwargs):
        """
        Retrieves a pre-stored model

        :param name: the name of the object
        :param version: accepted for compatibility, not used by this implementation
        :param uri: optional, location to read from instead of meta.uri / meta.gridfile,
            may use {key} as placeholder for the object store key
        :param loader: optional, custom loader passed to _extract_model()
        :param kwargs: passed to _extract_model()
        :return: the deserialized model
        """
        meta = self.model_store.metadata(name)
        storekey = self.model_store.object_store_key(name, 'omm', hashed=True)
        uri = uri or meta.uri
        if uri:
            uri = str(uri).format(key=storekey)
            infile = smart_open.open(uri, 'rb')
        else:
            infile = meta.gridfile
        try:
            model = self._extract_model(infile, storekey,
                                        self._tmp_packagefn(self.model_store, storekey),
                                        loader=loader, **kwargs)
        finally:
            # always release the file handle, even if deserialization fails
            infile.close()
        return model

    def put_model(self, obj, name, attributes=None, _kind_version=None, uri=None, **kwargs):
        """
        Packages a model using the serializer (joblib by default) and stores it

        :param obj: the model object to be stored
        :param name: the name of the object
        :param attributes: attributes for meta data
        :param _kind_version: accepted for compatibility, not used by this implementation
        :param uri: optional, target location passed to _store_to_file()
        :param kwargs: passed to _package_model()
        :return: the saved metadata object
        """
        storekey = self.model_store.object_store_key(name, 'omm', hashed=True)
        tmpfn = self._tmp_packagefn(self.model_store, storekey)
        packagefname = tmpfn
        try:
            packagefname = self._package_model(obj, storekey, tmpfn, **kwargs) or tmpfn
            gridfile = self._store_to_file(self.model_store, packagefname, storekey, uri=uri)
        finally:
            # never leave the temporary package behind, even if packaging or storing fails
            self._remove_path(packagefname)
        kind_meta = {
            self._backend_version_tag: self._backend_version,
        }
        return self.model_store._make_metadata(
            name=name,
            prefix=self.model_store.prefix,
            bucket=self.model_store.bucket,
            kind=self.KIND,
            kind_meta=kind_meta,
            attributes=attributes,
            uri=str(uri or ''),
            gridfile=gridfile).save()

    def predict(
            self, modelname, Xname, rName=None, pure_python=True, **kwargs):
        """
        predict using data stored in Xname

        :param modelname: the name of the model object
        :param Xname: the name of the X data set
        :param rName: the name of the result data object or None
        :param pure_python: if True return a python object. If False return
            a dataframe. Defaults to True to support any client.
        :param kwargs: kwargs passed on to input resolution and result preparation
        :return: return the predicted outcome
        :raises NotImplementedError: if the model has no predict method
        """
        model = self.model_store.get(modelname)
        # fail fast, before loading (and possibly tracking) the input data
        if not hasattr(model, 'predict'):
            raise NotImplementedError
        data = self._resolve_input_data('predict', Xname, 'X', **kwargs)
        result = model.predict(reshaped(data))
        return self._prepare_result('predict', result, rName=rName,
                                    pure_python=pure_python, **kwargs)

    def _resolve_input_data(self, method, Xname, key, **kwargs):
        """
        load the dataset Xname from the data store, logging it if autotracking is on

        :param method: the name of the calling method, logged as the event
        :param Xname: the name of the data set
        :param key: the key under which the data is logged, e.g. 'X'
        :return: the data as returned by the data store
        """
        data = self.data_store.get(Xname)
        meta = self.data_store.metadata(Xname)
        if self.tracking and getattr(self.tracking, 'autotrack', False):
            self.tracking.log_data(key, data, dataset=Xname, kind=meta.kind, event=method)
        return data

    def _prepare_result(self, method, result, rName=None, pure_python=False, **kwargs):
        """
        convert, optionally store and optionally track a result

        :param method: the name of the calling method, logged as the event
        :param result: the raw result
        :param rName: if given, the result is stored under this name and the
            metadata of the stored object is returned instead
        :param pure_python: if True convert the result using its tolist() method
        :return: the result, or the metadata of the stored result
        """
        # results that are already plain python (no tolist) are passed through as is
        if pure_python and hasattr(result, 'tolist'):
            result = result.tolist()
        meta = None
        if rName:
            meta = self.data_store.put(result, rName)
            result = meta
        if self.tracking and getattr(self.tracking, 'autotrack', False):
            # only a stored result has metadata, otherwise report the python type
            kind = meta.kind if meta is not None else str(type(result))
            self.tracking.log_data('Y', result, dataset=rName, kind=kind,
                                   event=method)
        return result

    def predict_proba(
            self, modelname, Xname, rName=None, pure_python=True, **kwargs):
        """
        predict the probability using data stored in Xname

        :param modelname: the name of the model object
        :param Xname: the name of the X data set
        :param rName: the name of the result data object or None
        :param pure_python: if True return a python object. If False return
            a dataframe. Defaults to True to support any client.
        :param kwargs: kwargs passed to the model's predict_proba method
        :return: return the predicted outcome
        """
        raise NotImplementedError

    def fit(self, modelname, Xname, Yname=None, pure_python=True, **kwargs):
        """
        fit the model with data

        :param modelname: the name of the model object
        :param Xname: the name of the X data set
        :param Yname: the name of the Y data set
        :param pure_python: if True return a python object. If False return
            a dataframe. Defaults to True to support any client.
        :param kwargs: kwargs passed to the model's fit method
        :return: return the meta data object of the model
        """
        raise NotImplementedError

    def partial_fit(
            self, modelname, Xname, Yname=None, pure_python=True, **kwargs):
        """
        partially fit the model with data (online)

        :param modelname: the name of the model object
        :param Xname: the name of the X data set
        :param Yname: the name of the Y data set
        :param pure_python: if True return a python object. If False return
            a dataframe. Defaults to True to support any client.
        :param kwargs: kwargs passed to the model's partial_fit method
        :return: return the meta data object of the model
        """
        raise NotImplementedError

    def fit_transform(
            self, modelname, Xname, Yname=None, rName=None, pure_python=True,
            **kwargs):
        """
        fit and transform using data

        :param modelname: the name of the model object
        :param Xname: the name of the X data set
        :param Yname: the name of the Y data set
        :param rName: the name of the transforms's result data object or None
        :param pure_python: if True return a python object. If False return
            a dataframe. Defaults to True to support any client.
        :param kwargs: kwargs passed to the model's transform method
        :return: return the meta data object of the model
        """
        raise NotImplementedError

    def transform(self, modelname, Xname, rName=None, **kwargs):
        """
        transform using data

        :param modelname: the name of the model object
        :param Xname: the name of the X data set
        :param rName: the name of the transforms's result data object or None
        :param kwargs: kwargs passed to the model's transform method
        :return: return the transform data of the model
        """
        raise NotImplementedError

    def score(
            self, modelname, Xname, Yname=None, rName=True, pure_python=True,
            **kwargs):
        """
        score using data

        :param modelname: the name of the model object
        :param Xname: the name of the X data set
        :param Yname: the name of the Y data set
        :param rName: the name of the score's result data object or None
        :param pure_python: if True return a python object. If False return
            a dataframe. Defaults to True to support any client.
        :param kwargs: kwargs passed to the model's score method
        :return: return the score result
        """
        raise NotImplementedError