Source code for omegaml.backends.basemodel

  1import shutil
  2from pathlib import Path
  3
  4import joblib
  5import smart_open
  6
  7from omegaml.backends.basecommon import BackendBaseCommon
  8from omegaml.util import reshaped
  9
 10
 11class BaseModelBackend(BackendBaseCommon):
 12    """
 13    OmegaML BaseModelBackend to be subclassed by other arbitrary backends
 14
 15    This provides the abstract interface for any model backend to be implemented
 16    Subclass to implement custom backends.
 17
 18    Essentially a model backend:
 19
 20     * provides methods to serialize and deserialize a machine learning model for a given ML framework
 21     * offers fit() and predict() methods to be called by the runtime
 22     * offers additional methods such as score(), partial_fit(), transform()
 23
 24    Model backends are the middleware that connects the om.models API to specific frameworks. This class
 25    makes it simple to implement a model backend by offering a common syntax as well as a default implementation
 26    for get() and put().
 27
 28    Methods to implement:
 29        # for model serialization (mandatory)
 30        @classmethod supports() - determine if backend supports given model instance
 31        _package_model() - serialize a model instance into a temporary file
 32        _extract_model() - deserialize the model from a file-like
 33
 34        By default BaseModelBackend uses joblib.dumps/loads to store the model as serialized
 35        Python objects. If this is not sufficient or applicable to your type models, override these
 36        methods.
 37
 38        Both methods provide readily set up temporary file names so that all you have to do is actually
 39        save the model to the given output file and restore the model from the given input file, respectively.
 40        All other logic has already been implemented (see get_model and put_model methods).
 41
 42        # for fitting and predicting (mandatory)
 43        fit()
 44        predict()
 45
 46        # other methods (optional)
 47        fit_transform() - fit and return a transformed dataset
 48        partial_fit() - fit incrementally
 49        predict_proba() - predict probabilities
 50        score() - score fitted classifier vv test dataset
 51
 52    """
 53    _backend_version_tag = '_om_backend_version'
 54    _backend_version = '1'
 55
 56    serializer = lambda store, model, filename, **kwargs: joblib.dump(model, filename)[0]
 57    loader = lambda store, infile, filename=None, **kwargs: joblib.load(infile)
 58    infer = lambda obj, **kwargs: getattr(obj, 'predict')
 59    reshape = lambda data, **kwargs: reshaped(data)
 60    types = None
 61
 62    def __init__(self, model_store=None, data_store=None, tracking=None, **kwargs):
 63        assert model_store, "Need a model store"
 64        assert data_store, "Need a data store"
 65        self.model_store = model_store
 66        self.data_store = data_store
 67        self.tracking = tracking
 68
 69    @classmethod
 70    def supports(self, obj, name, **kwargs):
 71        """
 72        test if this backend supports this obj
 73        """
 74        return isinstance(obj, self.types) if self.types else False
 75
 76    @property
 77    def _call_handler(self):
 78        # the model store handles _pre and _post methods in self.perform()
 79        return self.model_store
 80
 81    def get(self, name, uri=None, **kwargs):
 82        """
 83        retrieve a model
 84
 85        :param name: the name of the object
 86        :param uri: optional, /path/to/file, defaults to meta.gridfile, may use /path/{key} as placeholder
 87           for the file's name
 88        :param version: the version of the object (not supported)
 89
 90        .. versionadded: 0.18
 91            uri specifies a target filename to store the serialized model
 92        """
 93        # support new backend architecture while keeping back compatibility
 94        return self.get_model(name, uri=uri, **kwargs)
 95
 96    def put(self, obj, name, uri=None, **kwargs):
 97        """
 98        store a model
 99
100        :param obj: the model object to be stored
101        :param name: the name of the object
102        :param uri: optional, /path/to/file, defaults to meta.gridfile, may use /path/{key} as placeholder
103           for the file's name
104        :param attributes: attributes for meta data
105
106        .. versionadded: 0.18
107            local specifies a local filename to store the serialized model
108        """
109        # support new backend architecture while keeping back compatibility
110        return self.put_model(obj, name, uri=uri, **kwargs)
111
112    def drop(self, name, force=False, version=-1, **kwargs):
113        return self.model_store._drop(name, force=force, version=version)
114
115    def _package_model(self, model, key, tmpfn, serializer=None, **kwargs):
116        """
117        implement this method to serialize a model to the given tmpfn
118
119        Args:
120            model (object): the model object to serialize to a file
121            key (str): the object store's key for this object
122            tmpfn (str): the filename to store the serialized object to
123            serializer (callable): optional, a callable as serializer(store, model, filename, **kwargs),
124               defaults to self.serializer, using joblib.dump()
125            **kwargs (dict): optional, keyword arguments passed to the serializer
126
127        Returns:
128            tmpfn or absolute path of serialized file
129
130        .. versionchanged:: 0.18
131            enable custom serializer
132        """
133        serializer = serializer or getattr(self.serializer, '__func__')  # __func__ is the unbound method
134        kwargs.setdefault('key', key)
135        tmpfn = serializer(self, model, tmpfn) or tmpfn
136        return tmpfn
137
138    def _extract_model(self, infile, key, tmpfn, loader=None, **kwargs):
139        """
140        implement this method to deserialize a model from the given infile
141
142        Args:
143            infile (filelike): this is a file-like object supporting read() and seek(). if
144                deserializing from this does not work directly, use tmpfn
145            key (str): the object store's key for this object
146            tmpfn (str): the filename from which to extract the object
147            loader (callable): optional, a callable as loader(store, filename, **kwargs),
148               defaults to self.loader, using joblib.load()
149            **kwargs (dict): optional, keyword arguments passed to the loader
150
151        Returns:
152            model instance
153
154        .. versionchanged:: 0.18
155            enable custom loader
156        """
157        loader = loader or getattr(self.loader, '__func__')  # __func__ is the unbound method
158        kwargs.setdefault('filename', tmpfn)
159        kwargs.setdefault('key', key)
160        obj = loader(self, infile, **kwargs)
161        return obj
162
163    def _remove_path(self, path):
164        """
165        Remove a path, either a file or a directory
166
167        Args:
168            path (str): filename or path to remove. If path is a directory,
169            it will be removed recursively. If path is a file, it will be
170            removed.
171
172        Returns:
173            None
174        """
175        if Path(path).is_dir():
176            shutil.rmtree(path, ignore_errors=True)
177        else:
178            Path(path).unlink(missing_ok=True)
179
180    def get_model(self, name, version=-1, uri=None, loader=None, **kwargs):
181        """
182        Retrieves a pre-stored model
183        """
184        meta = self.model_store.metadata(name)
185        storekey = self.model_store.object_store_key(name, 'omm', hashed=True)
186        uri = uri or meta.uri
187        if uri:
188            uri = str(uri).format(key=storekey)
189            infile = smart_open.open(uri, 'rb')
190        else:
191            infile = meta.gridfile
192        model = self._extract_model(infile, storekey,
193                                    self._tmp_packagefn(self.model_store, storekey),
194                                    loader=loader, **kwargs)
195        infile.close()
196        return model
197
198    def put_model(self, obj, name, attributes=None, _kind_version=None, uri=None, **kwargs):
199        """
200        Packages a model using joblib and stores in GridFS
201        """
202        storekey = self.model_store.object_store_key(name, 'omm', hashed=True)
203        tmpfn = self._tmp_packagefn(self.model_store, storekey)
204        packagefname = self._package_model(obj, storekey, tmpfn, **kwargs) or tmpfn
205        gridfile = self._store_to_file(self.model_store, packagefname, storekey, uri=uri)
206        self._remove_path(packagefname)
207        kind_meta = {
208            self._backend_version_tag: self._backend_version,
209        }
210        return self.model_store._make_metadata(
211            name=name,
212            prefix=self.model_store.prefix,
213            bucket=self.model_store.bucket,
214            kind=self.KIND,
215            kind_meta=kind_meta,
216            attributes=attributes,
217            uri=str(uri or ''),
218            gridfile=gridfile).save()
219
[docs] 220 def predict( 221 self, modelname, Xname, rName=None, pure_python=True, **kwargs): 222 """ 223 predict using data stored in Xname 224 225 :param modelname: the name of the model object 226 :param Xname: the name of the X data set 227 :param rName: the name of the result data object or None 228 :param pure_python: if True return a python object. If False return 229 a dataframe. Defaults to True to support any client. 230 :param kwargs: kwargs passed to the model's predict method 231 :return: return the predicted outcome 232 """ 233 model = self.model_store.get(modelname) 234 data = self._resolve_input_data('predict', Xname, 'X', **kwargs) 235 infer = getattr(self.infer, '__func__') # __func__ is the unbound method 236 reshape = getattr(self.reshape, '__func__') 237 result = infer(model)(reshape(data)) 238 return self._prepare_result('predict', result, rName=rName, 239 pure_python=pure_python, **kwargs)
240 241 def _resolve_input_data(self, method, Xname, key, **kwargs): 242 data = self.data_store.get(Xname) 243 meta = self.data_store.metadata(Xname) 244 if self.tracking and getattr(self.tracking, 'autotrack', False): 245 self.tracking.log_data(key, data, dataset=Xname, kind=meta.kind, event=method) 246 return data 247 248 def _prepare_result(self, method, result, rName=None, pure_python=False, **kwargs): 249 if pure_python: 250 result = result.tolist() 251 if rName: 252 meta = self.data_store.put(result, rName) 253 result = meta 254 if self.tracking and getattr(self.tracking, 'autotrack', False): 255 self.tracking.log_data('Y', result, dataset=rName, kind=str(type(result)) if rName is None else meta.kind, 256 event=method) 257 return result 258 259 def predict_proba( 260 self, modelname, Xname, rName=None, pure_python=True, **kwargs): 261 """ 262 predict the probability using data stored in Xname 263 264 :param modelname: the name of the model object 265 :param Xname: the name of the X data set 266 :param rName: the name of the result data object or None 267 :param pure_python: if True return a python object. If False return 268 a dataframe. Defaults to True to support any client. 269 :param kwargs: kwargs passed to the model's predict method 270 :return: return the predicted outcome 271 """ 272 raise NotImplementedError 273
[docs] 274 def fit(self, modelname, Xname, Yname=None, pure_python=True, **kwargs): 275 """ 276 fit the model with data 277 278 :param modelname: the name of the model object 279 :param Xname: the name of the X data set 280 :param Yname: the name of the Y data set 281 :param pure_python: if True return a python object. If False return 282 a dataframe. Defaults to True to support any client. 283 :param kwargs: kwargs passed to the model's predict method 284 :return: return the meta data object of the model 285 """ 286 raise NotImplementedError
287 288 def partial_fit( 289 self, modelname, Xname, Yname=None, pure_python=True, **kwargs): 290 """ 291 partially fit the model with data (online) 292 293 :param modelname: the name of the model object 294 :param Xname: the name of the X data set 295 :param Yname: the name of the Y data set 296 :param pure_python: if True return a python object. If False return 297 a dataframe. Defaults to True to support any client. 298 :param kwargs: kwargs passed to the model's predict method 299 :return: return the meta data object of the model 300 """ 301 302 raise NotImplementedError 303 304 def fit_transform( 305 self, modelname, Xname, Yname=None, rName=None, pure_python=True, 306 **kwargs): 307 """ 308 fit and transform using data 309 310 :param modelname: the name of the model object 311 :param Xname: the name of the X data set 312 :param Yname: the name of the Y data set 313 :param rName: the name of the transforms's result data object or None 314 :param pure_python: if True return a python object. If False return 315 a dataframe. Defaults to True to support any client. 316 :param kwargs: kwargs passed to the model's transform method 317 :return: return the meta data object of the model 318 """ 319 raise NotImplementedError 320
[docs] 321 def transform(self, modelname, Xname, rName=None, **kwargs): 322 """ 323 transform using data 324 325 :param modelname: the name of the model object 326 :param Xname: the name of the X data set 327 :param rName: the name of the transforms's result data object or None 328 :param kwargs: kwargs passed to the model's transform method 329 :return: return the transform data of the model 330 """ 331 raise NotImplementedError
332 333 def score( 334 self, modelname, Xname, Yname=None, rName=True, pure_python=True, 335 **kwargs): 336 """ 337 score using data 338 339 :param modelname: the name of the model object 340 :param Xname: the name of the X data set 341 :param Yname: the name of the Y data set 342 :param rName: the name of the transforms's result data object or None 343 :param pure_python: if True return a python object. If False return 344 a dataframe. Defaults to True to support any client. 345 :param kwargs: kwargs passed to the model's predict method 346 :return: return the score result 347 """ 348 raise NotImplementedError