Source code for omegaml.backends.basemodel

  1import shutil
  2from pathlib import Path
  3
  4import joblib
  5import smart_open
  6
  7from omegaml.backends.basecommon import BackendBaseCommon
  8from omegaml.util import reshaped
  9
 10
class BaseModelBackend(BackendBaseCommon):
    """
    OmegaML BaseModelBackend to be subclassed by other arbitrary backends

    This provides the abstract interface for any model backend to be implemented.
    Subclass to implement custom backends.

    Essentially a model backend:

    * provides methods to serialize and deserialize a machine learning model for a given ML framework
    * offers fit() and predict() methods to be called by the runtime
    * offers additional methods such as score(), partial_fit(), transform()

    Model backends are the middleware that connects the om.models API to specific frameworks. This class
    makes it simple to implement a model backend by offering a common syntax as well as a default
    implementation for get() and put().

    Methods to implement:
        # for model serialization (mandatory)
        @classmethod supports() - determine if backend supports given model instance
        _package_model() - serialize a model instance into a temporary file
        _extract_model() - deserialize the model from a file-like

        By default BaseModelBackend uses joblib.dump/joblib.load to store the model as serialized
        Python objects. If this is not sufficient or applicable to your type of models, override
        these methods.

        Both methods provide readily set up temporary file names so that all you have to do is actually
        save the model to the given output file and restore the model from the given input file,
        respectively. All other logic has already been implemented (see get_model and put_model methods).

        # for fitting and predicting (mandatory)
        fit()
        predict()

        # other methods (optional)
        fit_transform() - fit and return a transformed dataset
        partial_fit() - fit incrementally
        predict_proba() - predict probabilities
        score() - score fitted classifier vs. test dataset

    """
    # key/value written into kind_meta by put_model() to record the backend version
    _backend_version_tag = '_om_backend_version'
    _backend_version = '1'

    # Default hooks, overridable by subclasses. They are plain functions stored as class
    # attributes, so accessing them on an instance yields a bound method; the methods below
    # unwrap them via __func__ and pass their own first argument explicitly.
    # serializer(store, model, filename, **kwargs) => name of the file written by joblib.dump
    serializer = lambda store, model, filename, **kwargs: joblib.dump(model, filename)[0]
    # loader(store, infile, filename=None, **kwargs) => the deserialized model
    loader = lambda store, infile, filename=None, **kwargs: joblib.load(infile)
    # infer(model, **kwargs) => the callable used by predict()
    infer = lambda obj, **kwargs: getattr(obj, 'predict')
    # reshape(data, **kwargs) => data as passed to the inference callable
    reshape = lambda data, **kwargs: reshaped(data)
    # a type or tuple of types accepted by supports(); None means "supports nothing"
    types = None

    def __init__(self, model_store=None, data_store=None, tracking=None, **kwargs):
        """
        Args:
            model_store: the store used to read and write models, required
            data_store: the store used to read input data and write results, required
            tracking: optional tracking provider; only used when it has a truthy
                ``autotrack`` attribute
            **kwargs: accepted for compatibility, not used by this class
        """
        # NOTE(review): asserts are stripped under ``python -O``; kept as-is because callers
        # may rely on AssertionError being raised here
        assert model_store, "Need a model store"
        assert data_store, "Need a data store"
        self.model_store = model_store
        self.data_store = data_store
        self.tracking = tracking

    @classmethod
    def supports(cls, obj, name, **kwargs):
        """
        test if this backend supports this obj

        :param obj: the object to test
        :param name: the name of the object (not used by this implementation)
        :return: True if ``types`` is set and obj is an instance of it, else False
        """
        return isinstance(obj, cls.types) if cls.types else False

    @property
    def _call_handler(self):
        # the model store handles _pre and _post methods in self.perform()
        return self.model_store

    def get(self, name, uri=None, **kwargs):
        """
        retrieve a model

        :param name: the name of the object
        :param uri: optional, /path/to/file, defaults to meta.gridfile, may use /path/{key} as placeholder
          for the file's name
        :param version: the version of the object (not supported)
        :return: the model, as returned by get_model()

        .. versionadded:: NEXT
            uri specifies the filename from which to load the serialized model
        """
        # support new backend architecture while keeping back compatibility
        return self.get_model(name, uri=uri, **kwargs)

    def put(self, obj, name, uri=None, **kwargs):
        """
        store a model

        :param obj: the model object to be stored
        :param name: the name of the object
        :param uri: optional, /path/to/file, defaults to meta.gridfile, may use /path/{key} as placeholder
          for the file's name
        :param attributes: attributes for meta data
        :return: the result of put_model()

        .. versionadded:: NEXT
            uri specifies a target filename to store the serialized model
        """
        # support new backend architecture while keeping back compatibility
        return self.put_model(obj, name, uri=uri, **kwargs)

    def drop(self, name, force=False, version=-1, **kwargs):
        """
        drop a model by delegating to the model store's _drop()

        :param name: the name of the object
        :param force: passed through to the model store
        :param version: passed through to the model store
        :return: the result of model_store._drop()
        """
        return self.model_store._drop(name, force=force, version=version)

    def _package_model(self, model, key, tmpfn, serializer=None, **kwargs):
        """
        implement this method to serialize a model to the given tmpfn

        Args:
            model (object): the model object to serialize to a file
            key (str): the object store's key for this object
            tmpfn (str): the filename to store the serialized object to
            serializer (callable): optional, a callable as serializer(store, model, filename, **kwargs),
                defaults to self.serializer, using joblib.dump()
            **kwargs (dict): optional, keyword arguments passed to the serializer;
                ``key`` is added unless already given

        Returns:
            tmpfn or absolute path of serialized file

        .. versionchanged:: NEXT
            enable custom serializer
        """
        # __func__ is the unbound function; fall back to the attribute itself when it is
        # not a bound method (e.g. defined as a staticmethod or set on the instance)
        serializer = serializer or getattr(self.serializer, '__func__', self.serializer)
        kwargs.setdefault('key', key)
        # forward kwargs as documented, mirroring how _extract_model calls the loader
        tmpfn = serializer(self, model, tmpfn, **kwargs) or tmpfn
        return tmpfn

    def _extract_model(self, infile, key, tmpfn, loader=None, **kwargs):
        """
        implement this method to deserialize a model from the given infile

        Args:
            infile (filelike): this is a file-like object supporting read() and seek(). if
                deserializing from this does not work directly, use tmpfn
            key (str): the object store's key for this object
            tmpfn (str): the filename from which to extract the object
            loader (callable): optional, a callable as loader(store, infile, **kwargs),
                defaults to self.loader, using joblib.load()
            **kwargs (dict): optional, keyword arguments passed to the loader;
                ``filename`` (=tmpfn) and ``key`` are added unless already given

        Returns:
            model instance

        .. versionchanged:: NEXT
            enable custom loader
        """
        # __func__ is the unbound function; fall back to the attribute itself when it is
        # not a bound method (e.g. defined as a staticmethod or set on the instance)
        loader = loader or getattr(self.loader, '__func__', self.loader)
        kwargs.setdefault('filename', tmpfn)
        kwargs.setdefault('key', key)
        obj = loader(self, infile, **kwargs)
        return obj

    def _remove_path(self, path):
        """
        Remove a path, either a file or a directory

        Args:
            path (str): filename or path to remove. If path is a directory,
                it will be removed recursively, ignoring errors. If path is a
                file, it will be removed; a missing file is not an error.

        Returns:
            None
        """
        target = Path(path)
        if target.is_dir():
            shutil.rmtree(path, ignore_errors=True)
        else:
            target.unlink(missing_ok=True)

    def get_model(self, name, version=-1, uri=None, loader=None, **kwargs):
        """
        Retrieves a pre-stored model

        Args:
            name (str): the name of the model
            version (int): accepted for compatibility, not used by this implementation
            uri (str): optional, location of the serialized model; defaults to the
                metadata's uri. ``{key}`` is replaced by the object store key. If no
                uri is available the metadata's gridfile is read.
            loader (callable): optional, passed on to _extract_model()
            **kwargs: passed on to _extract_model()

        Returns:
            the model instance as returned by _extract_model()
        """
        meta = self.model_store.metadata(name)
        storekey = self.model_store.object_store_key(name, 'omm', hashed=True)
        uri = uri or meta.uri
        if uri:
            uri = str(uri).format(key=storekey)
            infile = smart_open.open(uri, 'rb')
        else:
            infile = meta.gridfile
        try:
            model = self._extract_model(infile, storekey,
                                        self._tmp_packagefn(self.model_store, storekey),
                                        loader=loader, **kwargs)
        finally:
            # release the file handle even if deserialization fails
            infile.close()
        return model

    def put_model(self, obj, name, attributes=None, _kind_version=None, uri=None, **kwargs):
        """
        Packages a model using _package_model() and stores the resulting file

        Args:
            obj (object): the model to store
            name (str): the name of the model
            attributes (dict): optional, attributes stored with the metadata
            _kind_version: accepted for compatibility, not used by this implementation
            uri (str): optional, target location passed to _store_to_file(); also
                recorded in the metadata as a string ('' if not given)
            **kwargs: passed on to _package_model()

        Returns:
            the result of saving the metadata created by model_store._make_metadata()
        """
        storekey = self.model_store.object_store_key(name, 'omm', hashed=True)
        tmpfn = self._tmp_packagefn(self.model_store, storekey)
        packagefname = self._package_model(obj, storekey, tmpfn, **kwargs) or tmpfn
        try:
            gridfile = self._store_to_file(self.model_store, packagefname, storekey, uri=uri)
        finally:
            # never leave the temporary package behind, even if storing fails
            self._remove_path(packagefname)
        kind_meta = {
            self._backend_version_tag: self._backend_version,
        }
        return self.model_store._make_metadata(
            name=name,
            prefix=self.model_store.prefix,
            bucket=self.model_store.bucket,
            kind=self.KIND,
            kind_meta=kind_meta,
            attributes=attributes,
            uri=str(uri or ''),
            gridfile=gridfile).save()

    def predict(
            self, modelname, Xname, rName=None, pure_python=True, **kwargs):
        """
        predict using data stored in Xname

        :param modelname: the name of the model object
        :param Xname: the name of the X data set
        :param rName: the name of the result data object or None
        :param pure_python: if True return a python object. If False return
            the result as produced by the model. Defaults to True to support
            any client.
        :param kwargs: passed to _resolve_input_data() and _prepare_result();
            this implementation does not pass them to the model's predict call
        :return: return the predicted outcome
        """
        model = self.model_store.get(modelname)
        data = self._resolve_input_data('predict', Xname, 'X', **kwargs)
        # __func__ is the unbound function, so the hooks receive the model / data as
        # their first argument instead of this backend; fall back to the attribute
        # itself when it is not a bound method
        infer = getattr(self.infer, '__func__', self.infer)
        reshape = getattr(self.reshape, '__func__', self.reshape)
        result = infer(model)(reshape(data))
        return self._prepare_result('predict', result, rName=rName,
                                    pure_python=pure_python, **kwargs)

    def _resolve_input_data(self, method, Xname, key, **kwargs):
        """
        load the input data set and log it if autotracking is enabled

        :param method: the name of the calling method, logged as the event
        :param Xname: the name of the data set in the data store
        :param key: the key under which the data is logged, e.g. 'X'
        :return: the data as returned by data_store.get()
        """
        data = self.data_store.get(Xname)
        meta = self.data_store.metadata(Xname)
        if self.tracking and getattr(self.tracking, 'autotrack', False):
            self.tracking.log_data(key, data, dataset=Xname, kind=meta.kind, event=method)
        return data

    def _prepare_result(self, method, result, rName=None, pure_python=False, **kwargs):
        """
        convert, store and log a result

        :param method: the name of the calling method, logged as the event
        :param result: the result object; must provide tolist() if pure_python is True
        :param rName: optional, if given the result is stored in the data store under
            this name and the resulting metadata is returned instead of the result
        :param pure_python: if True convert the result using result.tolist()
        :return: the result, or its metadata if rName was given
        """
        meta = None  # only set when the result is stored
        if pure_python:
            result = result.tolist()
        if rName:
            meta = self.data_store.put(result, rName)
            result = meta
        if self.tracking and getattr(self.tracking, 'autotrack', False):
            # test meta, not rName: a falsy rName such as '' never stores the result,
            # so there is no metadata to take the kind from
            kind = str(type(result)) if meta is None else meta.kind
            self.tracking.log_data('Y', result, dataset=rName, kind=kind,
                                   event=method)
        return result

    def predict_proba(
            self, modelname, Xname, rName=None, pure_python=True, **kwargs):
        """
        predict the probability using data stored in Xname

        Not implemented by the base backend, subclasses should override.

        :param modelname: the name of the model object
        :param Xname: the name of the X data set
        :param rName: the name of the result data object or None
        :param pure_python: if True return a python object. If False return
            a dataframe. Defaults to True to support any client.
        :param kwargs: kwargs passed to the model's predict_proba method
        :return: return the predicted outcome
        :raises NotImplementedError: always, in this base implementation
        """
        raise NotImplementedError

    def fit(self, modelname, Xname, Yname=None, pure_python=True, **kwargs):
        """
        fit the model with data

        Not implemented by the base backend, subclasses should override.

        :param modelname: the name of the model object
        :param Xname: the name of the X data set
        :param Yname: the name of the Y data set
        :param pure_python: if True return a python object. If False return
            a dataframe. Defaults to True to support any client.
        :param kwargs: kwargs passed to the model's fit method
        :return: return the meta data object of the model
        :raises NotImplementedError: always, in this base implementation
        """
        raise NotImplementedError

    def partial_fit(
            self, modelname, Xname, Yname=None, pure_python=True, **kwargs):
        """
        partially fit the model with data (online)

        Not implemented by the base backend, subclasses should override.

        :param modelname: the name of the model object
        :param Xname: the name of the X data set
        :param Yname: the name of the Y data set
        :param pure_python: if True return a python object. If False return
            a dataframe. Defaults to True to support any client.
        :param kwargs: kwargs passed to the model's partial_fit method
        :return: return the meta data object of the model
        :raises NotImplementedError: always, in this base implementation
        """
        raise NotImplementedError

    def fit_transform(
            self, modelname, Xname, Yname=None, rName=None, pure_python=True,
            **kwargs):
        """
        fit and transform using data

        Not implemented by the base backend, subclasses should override.

        :param modelname: the name of the model object
        :param Xname: the name of the X data set
        :param Yname: the name of the Y data set
        :param rName: the name of the transform's result data object or None
        :param pure_python: if True return a python object. If False return
            a dataframe. Defaults to True to support any client.
        :param kwargs: kwargs passed to the model's transform method
        :return: return the meta data object of the model
        :raises NotImplementedError: always, in this base implementation
        """
        raise NotImplementedError

    def transform(self, modelname, Xname, rName=None, **kwargs):
        """
        transform using data

        Not implemented by the base backend, subclasses should override.

        :param modelname: the name of the model object
        :param Xname: the name of the X data set
        :param rName: the name of the transform's result data object or None
        :param kwargs: kwargs passed to the model's transform method
        :return: return the transform data of the model
        :raises NotImplementedError: always, in this base implementation
        """
        raise NotImplementedError

    def score(
            self, modelname, Xname, Yname=None, rName=True, pure_python=True,
            **kwargs):
        """
        score using data

        Not implemented by the base backend, subclasses should override.

        :param modelname: the name of the model object
        :param Xname: the name of the X data set
        :param Yname: the name of the Y data set
        :param rName: the name of the score's result data object or None.
            NOTE(review): the default is True rather than None, unlike the
            other methods -- confirm this is intended
        :param pure_python: if True return a python object. If False return
            a dataframe. Defaults to True to support any client.
        :param kwargs: kwargs passed to the model's score method
        :return: return the score result
        :raises NotImplementedError: always, in this base implementation
        """
        raise NotImplementedError