Source code for omegaml.backends.basemodel

  1import joblib
  2import shutil
  3from omegaml.backends.basecommon import BackendBaseCommon
  4from omegaml.util import reshaped
  5from pathlib import Path
  6
  7
class BaseModelBackend(BackendBaseCommon):
    """
    OmegaML BaseModelBackend to be subclassed by other arbitrary backends

    This provides the abstract interface for any model backend to be implemented.
    Subclass to implement custom backends.

    Essentially a model backend:

    * provides methods to serialize and deserialize a machine learning model
      for a given ML framework
    * offers fit() and predict() methods to be called by the runtime
    * offers additional methods such as score(), partial_fit(), transform()

    Model backends are the middleware that connects the om.models API to
    specific frameworks. This class makes it simple to implement a model
    backend by offering a common syntax as well as a default implementation
    for get() and put().

    Methods to implement:
        # for model serialization (mandatory)
        @classmethod supports() - determine if backend supports given model instance
        _package_model() - serialize a model instance into a temporary file
        _extract_model() - deserialize the model from a file-like

        By default BaseModelBackend uses joblib.dump/joblib.load to store the
        model as serialized Python objects. If this is not sufficient or
        applicable to your type of models, override these methods.

        Both methods provide readily set up temporary file names so that all
        you have to do is actually save the model to the given output file and
        restore the model from the given input file, respectively. All other
        logic has already been implemented (see get_model and put_model methods).

        # for fitting and predicting (mandatory)
        fit()
        predict()

        # other methods (optional)
        fit_transform() - fit and return a transformed dataset
        partial_fit() - fit incrementally
        predict_proba() - predict probabilities
        score() - score fitted classifier vs test dataset

    Note:
        This class relies on ``self.KIND``, ``self._tmp_packagefn()`` and
        ``self._store_to_file()``, none of which are defined here; they are
        presumably provided by BackendBaseCommon or the concrete subclass
        (TODO confirm).
    """
    # key and value written into the metadata's kind_meta by put_model()
    _backend_version_tag = '_om_backend_version'
    _backend_version = '1'

    def __init__(self, model_store=None, data_store=None, tracking=None, **kwargs):
        """
        set up the backend

        :param model_store: the store used to persist models (required)
        :param data_store: the store used to read and write datasets (required)
        :param tracking: optional tracking object; when it has a truthy
           ``autotrack`` attribute, inputs and results are logged through it
        :param kwargs: accepted for compatibility, not used here
        """
        # NOTE(review): asserts are stripped under ``python -O``; kept as-is
        # because callers may depend on AssertionError being raised
        assert model_store, "Need a model store"
        assert data_store, "Need a data store"
        self.model_store = model_store
        self.data_store = data_store
        self.tracking = tracking

    @classmethod
    def supports(cls, obj, name, **kwargs):
        """
        test if this backend supports this obj

        The base implementation supports nothing; subclasses override this.
        """
        return False

    @property
    def _call_handler(self):
        # the model store handles _pre and _post methods in self.perform()
        return self.model_store

    def get(self, name, **kwargs):
        """
        retrieve a model

        :param name: the name of the object
        :param version: the version of the object (not supported)
        """
        # support new backend architecture while keeping back compatibility
        return self.get_model(name, **kwargs)

    def put(self, obj, name, **kwargs):
        """
        store a model

        :param obj: the model object to be stored
        :param name: the name of the object
        :param attributes: attributes for meta data
        """
        # support new backend architecture while keeping back compatibility
        return self.put_model(obj, name, **kwargs)

    def drop(self, name, force=False, version=-1, **kwargs):
        """
        drop a model by delegating to the model store

        :param name: the name of the object
        :param force: passed through to ``model_store._drop``
        :param version: passed through to ``model_store._drop``
        :param kwargs: accepted for compatibility, not forwarded
        :return: the result of ``model_store._drop``
        """
        return self.model_store._drop(name, force=force, version=version)

    def _package_model(self, model, key, tmpfn, **kwargs):
        """
        implement this method to serialize a model to the given tmpfn

        The default implementation writes the model with joblib.dump.

        Args:
            model: the model instance to serialize
            key: the store key of the model (unused by the default implementation)
            tmpfn: the name of the temporary file to write to
            **kwargs: unused by the default implementation

        Returns:
            tmpfn or absolute path of serialized file
        """
        with open(tmpfn, 'wb') as outf:
            joblib.dump(model, outf)
        return tmpfn

    def _extract_model(self, infile, key, tmpfn, **kwargs):
        """
        implement this method to deserialize a model from the given infile

        Args:
            infile: this is a file-like object supporting read() and seek(). if
                deserializing from this does not work directly, use tmpfn
            key: the store key of the model (unused by the default implementation)
            tmpfn: a temporary file name (unused by the default implementation)
            **kwargs: unused by the default implementation

        Returns:
            model instance
        """
        # SECURITY: joblib.load is pickle-based and can execute arbitrary code;
        # only load models from a trusted model store
        obj = joblib.load(infile)
        return obj

    def _remove_path(self, path):
        """
        Remove a path, either a file or a directory

        Args:
            path (str): filename or path to remove. If path is a directory,
                it will be removed recursively. If path is a file, it will be
                removed.

        Returns:
            None
        """
        if Path(path).is_dir():
            shutil.rmtree(path, ignore_errors=True)
        else:
            Path(path).unlink(missing_ok=True)

    def get_model(self, name, version=-1, **kwargs):
        """
        Retrieves a pre-stored model

        :param name: the name of the model object
        :param version: accepted for compatibility, not used here
        :param kwargs: passed on to ``_extract_model``
        :return: the model instance as returned by ``_extract_model``
        """
        meta = self.model_store.metadata(name)
        storekey = self.model_store.object_store_key(name, 'omm', hashed=True)
        model = self._extract_model(meta.gridfile, storekey,
                                    self._tmp_packagefn(self.model_store, storekey),
                                    **kwargs)
        return model

    def put_model(self, obj, name, attributes=None, _kind_version=None, **kwargs):
        """
        Packages a model using ``_package_model`` and stores it via ``_store_to_file``

        :param obj: the model object to be stored
        :param name: the name of the object
        :param attributes: attributes for meta data
        :param _kind_version: accepted for compatibility, not used here
        :param kwargs: passed on to ``_package_model``
        :return: the result of saving the metadata created by
           ``model_store._make_metadata``
        """
        storekey = self.model_store.object_store_key(name, 'omm', hashed=True)
        tmpfn = self._tmp_packagefn(self.model_store, storekey)
        # _package_model may return a different path than tmpfn, or nothing
        packagefname = self._package_model(obj, storekey, tmpfn, **kwargs) or tmpfn
        try:
            gridfile = self._store_to_file(self.model_store, packagefname, storekey)
        finally:
            # always clean up the temporary package, even when storing failed
            self._remove_path(packagefname)
        kind_meta = {
            self._backend_version_tag: self._backend_version,
        }
        return self.model_store._make_metadata(
            name=name,
            prefix=self.model_store.prefix,
            bucket=self.model_store.bucket,
            kind=self.KIND,
            kind_meta=kind_meta,
            attributes=attributes,
            gridfile=gridfile).save()

    def predict(
            self, modelname, Xname, rName=None, pure_python=True, **kwargs):
        """
        predict using data stored in Xname

        :param modelname: the name of the model object
        :param Xname: the name of the X data set
        :param rName: the name of the result data object or None
        :param pure_python: if True the result is converted using its
           ``tolist()`` method, if available. If False the model's output is
           returned as is. Defaults to True to support any client.
        :param kwargs: passed to the input resolution and result preparation
           helpers
        :return: return the predicted outcome, or the metadata of the stored
           result if rName is given
        :raises NotImplementedError: if the model has no predict method
        """
        model = self.model_store.get(modelname)
        data = self._resolve_input_data('predict', Xname, 'X', **kwargs)
        if not hasattr(model, 'predict'):
            raise NotImplementedError
        # reshaped() is a helper from omegaml.util; its exact effect is not
        # visible here
        result = model.predict(reshaped(data))
        return self._prepare_result('predict', result, rName=rName,
                                    pure_python=pure_python, **kwargs)

    def _resolve_input_data(self, method, Xname, key, **kwargs):
        """
        load the input dataset and, if autotracking is enabled, log it

        :param method: the name of the calling method, logged as the event
        :param Xname: the name of the data set in the data store
        :param key: the key under which the data is logged (e.g. 'X')
        :param kwargs: accepted for compatibility, not used here
        :return: the data as returned by ``data_store.get``
        """
        data = self.data_store.get(Xname)
        if self.tracking and getattr(self.tracking, 'autotrack', False):
            # the metadata is only needed for tracking, so look it up lazily
            meta = self.data_store.metadata(Xname)
            self.tracking.log_data(key, data, dataset=Xname, kind=meta.kind,
                                   event=method)
        return data

    def _prepare_result(self, method, result, rName=None, pure_python=False, **kwargs):
        """
        convert, optionally store and optionally track a result

        :param method: the name of the calling method, logged as the event
        :param result: the raw result, e.g. of a model's predict method
        :param rName: if truthy, the result is stored in the data store under
           this name and the resulting metadata is returned instead
        :param pure_python: if True and the result has a ``tolist()`` method,
           the result is converted with it
        :param kwargs: accepted for compatibility, not used here
        :return: the (converted) result, or the metadata returned by
           ``data_store.put`` if rName is given
        """
        # results without tolist() (e.g. plain lists) are passed through as is
        if pure_python and hasattr(result, 'tolist'):
            result = result.tolist()
        stored = bool(rName)
        meta = None
        if stored:
            meta = self.data_store.put(result, rName)
            result = meta
        if self.tracking and getattr(self.tracking, 'autotrack', False):
            # decide on the same condition that guarded the put() above, so
            # that meta is never referenced when nothing was stored
            kind = meta.kind if stored else str(type(result))
            self.tracking.log_data('Y', result, dataset=rName, kind=kind,
                                   event=method)
        return result

    def predict_proba(
            self, modelname, Xname, rName=None, pure_python=True, **kwargs):
        """
        predict the probability using data stored in Xname

        Not implemented by the base backend.

        :param modelname: the name of the model object
        :param Xname: the name of the X data set
        :param rName: the name of the result data object or None
        :param pure_python: if True return a python object. If False return
           a dataframe. Defaults to True to support any client.
        :param kwargs: kwargs passed to the model's predict method
        :return: return the predicted outcome
        :raises NotImplementedError: always
        """
        raise NotImplementedError

    def fit(self, modelname, Xname, Yname=None, pure_python=True, **kwargs):
        """
        fit the model with data

        Not implemented by the base backend.

        :param modelname: the name of the model object
        :param Xname: the name of the X data set
        :param Yname: the name of the Y data set
        :param pure_python: if True return a python object. If False return
           a dataframe. Defaults to True to support any client.
        :param kwargs: kwargs passed to the model's fit method
        :return: return the meta data object of the model
        :raises NotImplementedError: always
        """
        raise NotImplementedError

    def partial_fit(
            self, modelname, Xname, Yname=None, pure_python=True, **kwargs):
        """
        partially fit the model with data (online)

        Not implemented by the base backend.

        :param modelname: the name of the model object
        :param Xname: the name of the X data set
        :param Yname: the name of the Y data set
        :param pure_python: if True return a python object. If False return
           a dataframe. Defaults to True to support any client.
        :param kwargs: kwargs passed to the model's partial_fit method
        :return: return the meta data object of the model
        :raises NotImplementedError: always
        """
        raise NotImplementedError

    def fit_transform(
            self, modelname, Xname, Yname=None, rName=None, pure_python=True,
            **kwargs):
        """
        fit and transform using data

        Not implemented by the base backend.

        :param modelname: the name of the model object
        :param Xname: the name of the X data set
        :param Yname: the name of the Y data set
        :param rName: the name of the transform's result data object or None
        :param pure_python: if True return a python object. If False return
           a dataframe. Defaults to True to support any client.
        :param kwargs: kwargs passed to the model's transform method
        :return: return the meta data object of the model
        :raises NotImplementedError: always
        """
        raise NotImplementedError

    def transform(self, modelname, Xname, rName=None, **kwargs):
        """
        transform using data

        Not implemented by the base backend.

        :param modelname: the name of the model object
        :param Xname: the name of the X data set
        :param rName: the name of the transform's result data object or None
        :param kwargs: kwargs passed to the model's transform method
        :return: return the transform data of the model
        :raises NotImplementedError: always
        """
        raise NotImplementedError

    def score(
            self, modelname, Xname, Yname=None, rName=True, pure_python=True,
            **kwargs):
        """
        score using data

        Not implemented by the base backend.

        :param modelname: the name of the model object
        :param Xname: the name of the X data set
        :param Yname: the name of the Y data set
        :param rName: the name of the score's result data object or None
        :param pure_python: if True return a python object. If False return
           a dataframe. Defaults to True to support any client.
        :param kwargs: kwargs passed to the model's score method
        :return: return the score result
        :raises NotImplementedError: always
        """
        raise NotImplementedError