Source code for omegaml.backends.basemodel

  1import joblib
  2import shutil
  3from omegaml.backends.basecommon import BackendBaseCommon
  4from omegaml.util import reshaped
  5from pathlib import Path
  6
  7
  8class BaseModelBackend(BackendBaseCommon):
  9    """
 10    OmegaML BaseModelBackend to be subclassed by other arbitrary backends
 11
 12    This provides the abstract interface for any model backend to be implemented
 13    Subclass to implement custom backends.
 14
 15    Essentially a model backend:
 16
 17     * provides methods to serialize and deserialize a machine learning model for a given ML framework
 18     * offers fit() and predict() methods to be called by the runtime
 19     * offers additional methods such as score(), partial_fit(), transform()
 20
 21    Model backends are the middleware that connects the om.models API to specific frameworks. This class
 22    makes it simple to implement a model backend by offering a common syntax as well as a default implementation
 23    for get() and put().
 24
 25    Methods to implement:
 26        # for model serialization (mandatory)
 27        @classmethod supports() - determine if backend supports given model instance
 28        _package_model() - serialize a model instance into a temporary file
 29        _extract_model() - deserialize the model from a file-like
 30
 31        By default BaseModelBackend uses joblib.dumps/loads to store the model as serialized
 32        Python objects. If this is not sufficient or applicable to your type models, override these
 33        methods.
 34
 35        Both methods provide readily set up temporary file names so that all you have to do is actually
 36        save the model to the given output file and restore the model from the given input file, respectively.
 37        All other logic has already been implemented (see get_model and put_model methods).
 38
 39        # for fitting and predicting (mandatory)
 40        fit()
 41        predict()
 42
 43        # other methods (optional)
 44        fit_transform() - fit and return a transformed dataset
 45        partial_fit() - fit incrementally
 46        predict_proba() - predict probabilities
 47        score() - score fitted classifier vv test dataset
 48
 49    """
 50    _backend_version_tag = '_om_backend_version'
 51    _backend_version = '1'
 52    
 53    def __init__(self, model_store=None, data_store=None, tracking=None, **kwargs):
 54        assert model_store, "Need a model store"
 55        assert data_store, "Need a data store"
 56        self.model_store = model_store
 57        self.data_store = data_store
 58        self.tracking = tracking
 59
 60    @classmethod
 61    def supports(self, obj, name, **kwargs):
 62        """
 63        test if this backend supports this obj
 64        """
 65        return False
 66
 67    @property
 68    def _call_handler(self):
 69        # the model store handles _pre and _post methods in self.perform()
 70        return self.model_store
 71
 72    def get(self, name, **kwargs):
 73        """
 74        retrieve a model
 75
 76        :param name: the name of the object
 77        :param version: the version of the object (not supported)
 78        """
 79        # support new backend architecture while keeping back compatibility
 80        return self.get_model(name, **kwargs)
 81
 82    def put(self, obj, name, **kwargs):
 83        """
 84        store a model
 85
 86        :param obj: the model object to be stored
 87        :param name: the name of the object
 88        :param attributes: attributes for meta data
 89        """
 90        # support new backend architecture while keeping back compatibility
 91        return self.put_model(obj, name, **kwargs)
 92
 93    def drop(self, name, force=False, version=-1, **kwargs):
 94        return self.model_store._drop(name, force=force, version=version)
 95
 96    def _package_model(self, model, key, tmpfn, **kwargs):
 97        """
 98        implement this method to serialize a model to the given tmpfn
 99
100        Args:
101            model:
102            key:
103            tmpfn:
104            **kwargs:
105
106        Returns:
107            tmpfn or absolute path of serialized file
108        """
109        with open(tmpfn, 'wb') as outf:
110            joblib.dump(model, outf)
111        return tmpfn
112
113    def _extract_model(self, infile, key, tmpfn, **kwargs):
114        """
115        implement this method to deserialize a model from the given infile
116
117        Args:
118            infile: this is a file-like object supporting read() and seek(). if
119                deserializing from this does not work directly, use tmpfn
120            key:
121            tmpfn:
122            **kwargs:
123
124        Returns:
125            model instance
126        """
127        obj = joblib.load(infile)
128        return obj
129
130    def _remove_path(self, path):
131        """
132        Remove a path, either a file or a directory
133
134        Args:
135            path (str): filename or path to remove. If path is a directory,
136            it will be removed recursively. If path is a file, it will be
137            removed.
138
139        Returns:
140            None
141        """
142        if Path(path).is_dir():
143            shutil.rmtree(path, ignore_errors=True)
144        else:
145            Path(path).unlink(missing_ok=True)
146
147    def get_model(self, name, version=-1, **kwargs):
148        """
149        Retrieves a pre-stored model
150        """
151        meta = self.model_store.metadata(name)
152        storekey = self.model_store.object_store_key(name, 'omm', hashed=True)
153        model = self._extract_model(meta.gridfile, storekey,
154                                    self._tmp_packagefn(self.model_store, storekey), **kwargs)
155        return model
156
157    def put_model(self, obj, name, attributes=None, _kind_version=None, **kwargs):
158        """
159        Packages a model using joblib and stores in GridFS
160        """
161        storekey = self.model_store.object_store_key(name, 'omm', hashed=True)
162        tmpfn = self._tmp_packagefn(self.model_store, storekey)
163        packagefname = self._package_model(obj, storekey, tmpfn, **kwargs) or tmpfn
164        gridfile = self._store_to_file(self.model_store, packagefname, storekey)
165        self._remove_path(packagefname)
166        kind_meta = {
167            self._backend_version_tag: self._backend_version,
168        }
169        return self.model_store._make_metadata(
170            name=name,
171            prefix=self.model_store.prefix,
172            bucket=self.model_store.bucket,
173            kind=self.KIND,
174            kind_meta=kind_meta,
175            attributes=attributes,
176            gridfile=gridfile).save()
177
[docs] 178 def predict( 179 self, modelname, Xname, rName=None, pure_python=True, **kwargs): 180 """ 181 predict using data stored in Xname 182 183 :param modelname: the name of the model object 184 :param Xname: the name of the X data set 185 :param rName: the name of the result data object or None 186 :param pure_python: if True return a python object. If False return 187 a dataframe. Defaults to True to support any client. 188 :param kwargs: kwargs passed to the model's predict method 189 :return: return the predicted outcome 190 """ 191 model = self.model_store.get(modelname) 192 data = self._resolve_input_data('predict', Xname, 'X', **kwargs) 193 if not hasattr(model, 'predict'): 194 raise NotImplementedError 195 result = model.predict(reshaped(data)) 196 return self._prepare_result('predict', result, rName=rName, 197 pure_python=pure_python, **kwargs)
198 199 def _resolve_input_data(self, method, Xname, key, **kwargs): 200 data = self.data_store.get(Xname) 201 meta = self.data_store.metadata(Xname) 202 if self.tracking and getattr(self.tracking, 'autotrack', False): 203 self.tracking.log_data(key, data, dataset=Xname, kind=meta.kind, event=method) 204 return data 205 206 def _prepare_result(self, method, result, rName=None, pure_python=False, **kwargs): 207 if pure_python: 208 result = result.tolist() 209 if rName: 210 meta = self.data_store.put(result, rName) 211 result = meta 212 if self.tracking and getattr(self.tracking, 'autotrack', False): 213 self.tracking.log_data('Y', result, dataset=rName, kind=str(type(result)) if rName is None else meta.kind, 214 event=method) 215 return result 216 217 def predict_proba( 218 self, modelname, Xname, rName=None, pure_python=True, **kwargs): 219 """ 220 predict the probability using data stored in Xname 221 222 :param modelname: the name of the model object 223 :param Xname: the name of the X data set 224 :param rName: the name of the result data object or None 225 :param pure_python: if True return a python object. If False return 226 a dataframe. Defaults to True to support any client. 227 :param kwargs: kwargs passed to the model's predict method 228 :return: return the predicted outcome 229 """ 230 raise NotImplementedError 231
[docs] 232 def fit(self, modelname, Xname, Yname=None, pure_python=True, **kwargs): 233 """ 234 fit the model with data 235 236 :param modelname: the name of the model object 237 :param Xname: the name of the X data set 238 :param Yname: the name of the Y data set 239 :param pure_python: if True return a python object. If False return 240 a dataframe. Defaults to True to support any client. 241 :param kwargs: kwargs passed to the model's predict method 242 :return: return the meta data object of the model 243 """ 244 raise NotImplementedError
245 246 def partial_fit( 247 self, modelname, Xname, Yname=None, pure_python=True, **kwargs): 248 """ 249 partially fit the model with data (online) 250 251 :param modelname: the name of the model object 252 :param Xname: the name of the X data set 253 :param Yname: the name of the Y data set 254 :param pure_python: if True return a python object. If False return 255 a dataframe. Defaults to True to support any client. 256 :param kwargs: kwargs passed to the model's predict method 257 :return: return the meta data object of the model 258 """ 259 260 raise NotImplementedError 261 262 def fit_transform( 263 self, modelname, Xname, Yname=None, rName=None, pure_python=True, 264 **kwargs): 265 """ 266 fit and transform using data 267 268 :param modelname: the name of the model object 269 :param Xname: the name of the X data set 270 :param Yname: the name of the Y data set 271 :param rName: the name of the transforms's result data object or None 272 :param pure_python: if True return a python object. If False return 273 a dataframe. Defaults to True to support any client. 274 :param kwargs: kwargs passed to the model's transform method 275 :return: return the meta data object of the model 276 """ 277 raise NotImplementedError 278
[docs] 279 def transform(self, modelname, Xname, rName=None, **kwargs): 280 """ 281 transform using data 282 283 :param modelname: the name of the model object 284 :param Xname: the name of the X data set 285 :param rName: the name of the transforms's result data object or None 286 :param kwargs: kwargs passed to the model's transform method 287 :return: return the transform data of the model 288 """ 289 raise NotImplementedError
290 291 def score( 292 self, modelname, Xname, Yname=None, rName=True, pure_python=True, 293 **kwargs): 294 """ 295 score using data 296 297 :param modelname: the name of the model object 298 :param Xname: the name of the X data set 299 :param Yname: the name of the Y data set 300 :param rName: the name of the transforms's result data object or None 301 :param pure_python: if True return a python object. If False return 302 a dataframe. Defaults to True to support any client. 303 :param kwargs: kwargs passed to the model's predict method 304 :return: return the score result 305 """ 306 raise NotImplementedError