Source code for omegaml.backends.basemodel

import shutil
from pathlib import Path

import joblib
import smart_open

from omegaml.backends.basecommon import BackendBaseCommon
from omegaml.util import reshaped


class BaseModelBackend(BackendBaseCommon):
    """
    OmegaML BaseModelBackend to be subclassed by other arbitrary backends

    This provides the abstract interface for any model backend to be implemented.
    Subclass to implement custom backends.

    Essentially a model backend:

     * provides methods to serialize and deserialize a machine learning model for a given ML framework
     * offers fit() and predict() methods to be called by the runtime
     * offers additional methods such as score(), partial_fit(), transform()

    Model backends are the middleware that connects the om.models API to specific frameworks. This class
    makes it simple to implement a model backend by offering a common syntax as well as a default implementation
    for get() and put().

    Methods to implement:
        # for model serialization (mandatory)
        @classmethod supports() - determine if backend supports given model instance
        _package_model() - serialize a model instance into a temporary file
        _extract_model() - deserialize the model from a file-like

        By default BaseModelBackend uses joblib.dump/load to store the model as serialized
        Python objects. If this is not sufficient or applicable to your type of models, override these
        methods.

        Both methods provide readily set up temporary file names so that all you have to do is actually
        save the model to the given output file and restore the model from the given input file, respectively.
        All other logic has already been implemented (see get_model and put_model methods).

        # for fitting and predicting (mandatory)
        fit()
        predict()

        # other methods (optional)
        fit_transform() - fit and return a transformed dataset
        partial_fit() - fit incrementally
        predict_proba() - predict probabilities
        score() - score a fitted classifier against a test dataset
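
    Example:
        A minimal sketch of a custom backend (illustrative only; ``MyModel``
        and its ``save()``/``load()`` methods are assumptions standing in for
        an arbitrary ML framework, not part of omegaml)::

            class MyModelBackend(BaseModelBackend):
                KIND = 'mymodel.custom'

                @classmethod
                def supports(cls, obj, name, **kwargs):
                    # claim only objects of the framework's model type
                    return isinstance(obj, MyModel)

                def _package_model(self, model, key, tmpfn, **kwargs):
                    # serialize the model to the prepared temporary file
                    model.save(tmpfn)
                    return tmpfn

                def _extract_model(self, infile, key, tmpfn, **kwargs):
                    # restore the model from the file-like object
                    with open(tmpfn, 'wb') as fout:
                        fout.write(infile.read())
                    return MyModel.load(tmpfn)

        fit() and predict() would be implemented analogously, loading data
        from self.data_store and the model from self.model_store.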
    """
    _backend_version_tag = '_om_backend_version'
    _backend_version = '1'

    # default (de)serialization callables; subclasses may override these, or a
    # custom serializer/loader can be passed to put() and get() per call
    serializer = lambda store, model, filename, **kwargs: joblib.dump(model, filename)[0]
    loader = lambda store, infile, filename=None, **kwargs: joblib.load(infile)

    def __init__(self, model_store=None, data_store=None, tracking=None, **kwargs):
        assert model_store, "Need a model store"
        assert data_store, "Need a data store"
        self.model_store = model_store
        self.data_store = data_store
        self.tracking = tracking

    @classmethod
    def supports(self, obj, name, **kwargs):
        """
        test if this backend supports this obj
        """
        return False

    @property
    def _call_handler(self):
        # the model store handles _pre and _post methods in self.perform()
        return self.model_store

    def get(self, name, uri=None, **kwargs):
        """
        retrieve a model

        :param name: the name of the object
        :param uri: optional, /path/to/file, defaults to meta.gridfile, may use /path/{key} as placeholder
           for the file's name
        :param version: the version of the object (not supported)

        .. versionadded:: NEXT
            uri specifies the filename from which to load the serialized model
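
        For example (an illustrative sketch, assuming this backend is
        registered with om.models and the model was previously stored with
        a matching uri)::

            import omegaml as om

            # load the serialized model from a local file instead of GridFS
            model = om.models.get('mymodel', uri='/tmp/models/{key}')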
        """
        # support new backend architecture while keeping back compatibility
        return self.get_model(name, uri=uri, **kwargs)

    def put(self, obj, name, uri=None, **kwargs):
        """
        store a model

        :param obj: the model object to be stored
        :param name: the name of the object
        :param uri: optional, /path/to/file, defaults to meta.gridfile, may use /path/{key} as placeholder
           for the file's name
        :param attributes: attributes for meta data

        .. versionadded:: NEXT
            uri specifies a target filename to store the serialized model
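
        For example (an illustrative sketch, assuming ``om`` is the omegaml
        API and ``clf`` is any model supported by this backend)::

            import omegaml as om

            # store the serialized model to a local path instead of GridFS
            om.models.put(clf, 'mymodel', uri='/tmp/models/{key}')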
        """
        # support new backend architecture while keeping back compatibility
        return self.put_model(obj, name, uri=uri, **kwargs)

    def drop(self, name, force=False, version=-1, **kwargs):
        return self.model_store._drop(name, force=force, version=version)

    def _package_model(self, model, key, tmpfn, serializer=None, **kwargs):
        """
        implement this method to serialize a model to the given tmpfn

        Args:
            model (object): the model object to serialize to a file
            key (str): the object store's key for this object
            tmpfn (str): the filename to store the serialized object to
            serializer (callable): optional, a callable as serializer(store, model, filename, **kwargs),
               defaults to self.serializer, using joblib.dump()
            **kwargs (dict): optional, keyword arguments passed to the serializer

        Returns:
            tmpfn or absolute path of serialized file

        .. versionchanged:: NEXT
            enable custom serializer
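
        For example (an illustrative sketch using pickle, with ``om`` and
        ``clf`` as in the put() example above; any callable with the
        serializer(store, model, filename, **kwargs) signature works)::

            import pickle

            def my_serializer(store, model, filename, **kwargs):
                # write the model to the prepared temporary file
                with open(filename, 'wb') as fout:
                    pickle.dump(model, fout)
                return filename

            om.models.put(clf, 'mymodel', serializer=my_serializer)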
        """
        serializer = serializer or getattr(self.serializer, '__func__')  # __func__ is the unbound method
        kwargs.setdefault('key', key)
        tmpfn = serializer(self, model, tmpfn, **kwargs) or tmpfn
        return tmpfn

    def _extract_model(self, infile, key, tmpfn, loader=None, **kwargs):
        """
        implement this method to deserialize a model from the given infile

        Args:
            infile (filelike): this is a file-like object supporting read() and seek(). if
                deserializing from this does not work directly, use tmpfn
            key (str): the object store's key for this object
            tmpfn (str): the filename from which to extract the object
            loader (callable): optional, a callable as loader(store, infile, filename=None, **kwargs),
               defaults to self.loader, using joblib.load()
            **kwargs (dict): optional, keyword arguments passed to the loader

        Returns:
            model instance

        .. versionchanged:: NEXT
            enable custom loader
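
        For example (an illustrative sketch using pickle, matching the
        serializer example in _package_model; any callable with the
        loader(store, infile, filename=None, **kwargs) signature works)::

            import pickle

            def my_loader(store, infile, filename=None, **kwargs):
                # read the model back from the file-like object
                return pickle.load(infile)

            model = om.models.get('mymodel', loader=my_loader)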
        """
        loader = loader or getattr(self.loader, '__func__')  # __func__ is the unbound method
        kwargs.setdefault('filename', tmpfn)
        kwargs.setdefault('key', key)
        obj = loader(self, infile, **kwargs)
        return obj

    def _remove_path(self, path):
        """
        Remove a path, either a file or a directory

        Args:
            path (str): filename or path to remove. If path is a directory,
               it will be removed recursively. If path is a file, it will be
               removed.

        Returns:
            None
        """
        if Path(path).is_dir():
            shutil.rmtree(path, ignore_errors=True)
        else:
            Path(path).unlink(missing_ok=True)

    def get_model(self, name, version=-1, uri=None, loader=None, **kwargs):
        """
        Retrieves a pre-stored model
        """
        meta = self.model_store.metadata(name)
        storekey = self.model_store.object_store_key(name, 'omm', hashed=True)
        uri = uri or meta.uri
        if uri:
            # resolve the {key} placeholder, then read the serialized model from the uri
            uri = str(uri).format(key=storekey)
            infile = smart_open.open(uri, 'rb')
        else:
            infile = meta.gridfile
        model = self._extract_model(infile, storekey,
                                    self._tmp_packagefn(self.model_store, storekey),
                                    loader=loader, **kwargs)
        infile.close()
        return model

    def put_model(self, obj, name, attributes=None, _kind_version=None, uri=None, **kwargs):
        """
        Packages a model using the backend's serializer (joblib by default)
        and stores it in GridFS or at the given uri
        """
        storekey = self.model_store.object_store_key(name, 'omm', hashed=True)
        tmpfn = self._tmp_packagefn(self.model_store, storekey)
        packagefname = self._package_model(obj, storekey, tmpfn, **kwargs) or tmpfn
        gridfile = self._store_to_file(self.model_store, packagefname, storekey, uri=uri)
        # the serialized temporary file is no longer needed once stored
        self._remove_path(packagefname)
        kind_meta = {
            self._backend_version_tag: self._backend_version,
        }
        return self.model_store._make_metadata(
            name=name,
            prefix=self.model_store.prefix,
            bucket=self.model_store.bucket,
            kind=self.KIND,
            kind_meta=kind_meta,
            attributes=attributes,
            uri=str(uri or ''),
            gridfile=gridfile).save()

    def predict(
            self, modelname, Xname, rName=None, pure_python=True, **kwargs):
        """
        predict using data stored in Xname

        :param modelname: the name of the model object
        :param Xname: the name of the X data set
        :param rName: the name of the result data object or None
        :param pure_python: if True return a python object. If False return
           a dataframe. Defaults to True to support any client.
        :param kwargs: kwargs passed to the model's predict method
        :return: return the predicted outcome
        """
        model = self.model_store.get(modelname)
        data = self._resolve_input_data('predict', Xname, 'X', **kwargs)
        if not hasattr(model, 'predict'):
            raise NotImplementedError
        result = model.predict(reshaped(data))
        return self._prepare_result('predict', result, rName=rName,
                                    pure_python=pure_python, **kwargs)

    def _resolve_input_data(self, method, Xname, key, **kwargs):
        data = self.data_store.get(Xname)
        meta = self.data_store.metadata(Xname)
        if self.tracking and getattr(self.tracking, 'autotrack', False):
            # log the input dataset when experiment autotracking is enabled
            self.tracking.log_data(key, data, dataset=Xname, kind=meta.kind, event=method)
        return data

    def _prepare_result(self, method, result, rName=None, pure_python=False, **kwargs):
        if pure_python:
            result = result.tolist()
        if rName:
            meta = self.data_store.put(result, rName)
            result = meta
        if self.tracking and getattr(self.tracking, 'autotrack', False):
            # log the result when experiment autotracking is enabled
            self.tracking.log_data('Y', result, dataset=rName, kind=str(type(result)) if rName is None else meta.kind,
                                   event=method)
        return result

    def predict_proba(
            self, modelname, Xname, rName=None, pure_python=True, **kwargs):
        """
        predict the probability using data stored in Xname

        :param modelname: the name of the model object
        :param Xname: the name of the X data set
        :param rName: the name of the result data object or None
        :param pure_python: if True return a python object. If False return
           a dataframe. Defaults to True to support any client.
        :param kwargs: kwargs passed to the model's predict_proba method
        :return: return the predicted outcome
        """
        raise NotImplementedError

    def fit(self, modelname, Xname, Yname=None, pure_python=True, **kwargs):
        """
        fit the model with data

        :param modelname: the name of the model object
        :param Xname: the name of the X data set
        :param Yname: the name of the Y data set
        :param pure_python: if True return a python object. If False return
           a dataframe. Defaults to True to support any client.
        :param kwargs: kwargs passed to the model's fit method
        :return: return the meta data object of the model
        """
        raise NotImplementedError

    def partial_fit(
            self, modelname, Xname, Yname=None, pure_python=True, **kwargs):
        """
        partially fit the model with data (online)

        :param modelname: the name of the model object
        :param Xname: the name of the X data set
        :param Yname: the name of the Y data set
        :param pure_python: if True return a python object. If False return
           a dataframe. Defaults to True to support any client.
        :param kwargs: kwargs passed to the model's partial_fit method
        :return: return the meta data object of the model
        """
        raise NotImplementedError

    def fit_transform(
            self, modelname, Xname, Yname=None, rName=None, pure_python=True,
            **kwargs):
        """
        fit and transform using data

        :param modelname: the name of the model object
        :param Xname: the name of the X data set
        :param Yname: the name of the Y data set
        :param rName: the name of the transform's result data object or None
        :param pure_python: if True return a python object. If False return
           a dataframe. Defaults to True to support any client.
        :param kwargs: kwargs passed to the model's transform method
        :return: return the meta data object of the model
        """
        raise NotImplementedError

    def transform(self, modelname, Xname, rName=None, **kwargs):
        """
        transform using data

        :param modelname: the name of the model object
        :param Xname: the name of the X data set
        :param rName: the name of the transform's result data object or None
        :param kwargs: kwargs passed to the model's transform method
        :return: return the transform data of the model
        """
        raise NotImplementedError

    def score(
            self, modelname, Xname, Yname=None, rName=True, pure_python=True,
            **kwargs):
        """
        score using data

        :param modelname: the name of the model object
        :param Xname: the name of the X data set
        :param Yname: the name of the Y data set
        :param rName: the name of the score's result data object or None
        :param pure_python: if True return a python object. If False return
           a dataframe. Defaults to True to support any client.
        :param kwargs: kwargs passed to the model's score method
        :return: return the score result
        """
        raise NotImplementedError