Source code for omegaml.backends.tensorflow.tfestimatormodel

  1# import glob
  2import glob
  3import os
  4import tempfile
  5from inspect import isfunction
  6from zipfile import ZipFile, ZIP_DEFLATED
  7
  8import dill
  9import logging
 10
 11from omegaml.backends.basemodel import BaseModelBackend
 12
 13ok = lambda v, vtype: isinstance(v, vtype)
 14
 15logger = logging.getLogger(__name__)
 16
 17
 18class TFEstimatorModel(object):
 19    """
 20    A serializable/deserizable wrapper for a TF Estimator
 21
 22    Usage:
 23        estimator = TFEstimatorModel(estimator_fn)
 24        estimator.fit(input_fn=input_fn)
 25        estimator.predict(input_fn=input_fn)
 26
 27        The estimator_fn returns a tf.estimator.Estimator or subclass.
 28    """
 29
 30    def __init__(self, estimator_fn, model=None, input_fn=None, model_dir=None, v1_compat=False):
 31        """
 32
 33        Args:
 34            estimator_fn (func): the function to return a valid tf.estimator.Estimator instance. Called as
 35                                 fn(model_dir=)
 36            model (tf.Estimator): an existing e.g. pre-fitted Estimator instance, optional. If not specified,
 37                                  the model will be recreated by calling estimator_fn. If specified, the
 38                                  model's weights and parameters will be saved and reloaded so that a fitted
 39                                  model can be used without further training.
 40            input_fn (func|dict): the function to create the input_fn as fn(mode, X, Y, batch_size=n), where mode
 41                                  is either 'fit', 'evaluate', or 'predict'. If not provide defaults to an input_fn
 42                                  that tries to infer the correct input_fn from the method and input arguments. If
 43                                  provided as a dict, must contain the 'fit', 'evaluate' and 'predict' keys where
 44                                  each value is a valid input_fn as fn(X, Y, batch_size=n).
 45            model_dir (str): the model directory to use. Defaults to whatever estimator_fn/Estimator instance sets
 46            v1_compat (bool): use tensorflow.compat.v1 to create create the input functions. Use this when
 47                 migrating tensorflow v1.x Estimator models that are not yet v2.x native yet.
 48        """
 49        self.estimator_fn = estimator_fn
 50        self._model_dir = model_dir
 51        self._estimator = model
 52        self._input_fn = input_fn
 53        self.v1_compat = v1_compat
 54
 55    @property
 56    def model_dir(self):
 57        return self.estimator.model_dir
 58
 59    @property
 60    def estimator(self):
 61        if self._estimator is None:
 62            self._estimator = self.estimator_fn(model_dir=self._model_dir)
 63        return self._estimator
 64
 65    def restore(self, model_dir):
 66        self._estimator = None
 67        self._model_dir = model_dir
 68        return self
 69
 70    def make_input_fn(self, mode, X, Y=None, batch_size=1):
 71        """
 72        Return a tf.data.Dataset from the input provided
 73
 74        Args:
 75            mode (str): calling mode, either 'fit', 'predict' or 'evaluate'
 76            X (NDArray|Tensor|Dataset): features, or Dataset of (features, labels)
 77            Y (NDArray|Tensor|Dataset): labels, optional
 78
 79        Notes:
 80            X can be a Dataset of (features, labels), or just features. If X is
 81            just features, also provide a Dataset of just labels.
 82
 83            If X, Y are NDArrays or Tensors, Dataset.from_tensor_slices((dict(X), Y))
 84            is used to create the Dataset. If only X is provided as a NDArray or Tensor,
 85            only X is used to create the Dataset.
 86
 87            If none of these options work, create your own input_fn and pass it
 88            to the .fit/.predict methods using the input_fn= kwarg
 89        """
 90        import pandas as pd
 91        import numpy as np
 92
 93        if self.v1_compat:
 94            # https://www.tensorflow.org/guide/migrate
 95            import tensorflow.compat.v1 as tf
 96            tf.disable_v2_behavior()
 97        else:
 98            import tensorflow as tf
 99
100        if self._input_fn is not None:
101            if isinstance(self._input_fn, dict):
102                return self._input_fn[mode](X, Y=Y, batch_size=batch_size)
103            else:
104                return self._input_fn(mode, X, Y=Y, batch_size=batch_size)
105
106        def input_fn():
107            # if we have a dataset, use that
108            if isinstance(X, tf.data.Dataset):
109                if Y is None:
110                    return X
111                elif isinstance(Y, tf.data.Dataset):
112                    return X.zip(Y)
113                else:
114                    return X, Y
115            # if we have a dataframe, create a dataset from it
116            if ok(X, pd.DataFrame) and ok(Y, pd.Series):
117                dataset = tf.data.Dataset.from_tensor_slices((dict(X), Y))
118                result = dataset.batch(batch_size)
119            elif ok(X, pd.DataFrame):
120                dataset = tf.data.Dataset.from_tensor_slices(dict(X))
121                result = dataset.batch(batch_size)
122            else:
123                result = X, Y
124            return result
125
126        if isinstance(X, (dict, np.ndarray)):
127            input_fn = tf.estimator.inputs.numpy_input_fn(x=X, y=Y, num_epochs=1, shuffle=False)
128        return input_fn
129
130    def fit(self, X=None, Y=None, input_fn=None, batch_size=100, **kwargs):
131        """
132        Args:
133           X (Dataset|ndarray): features
134           Y (Dataset|ndarray): labels, optional
135        """
136        assert (ok(X, object) or ok(input_fn, object)), "specify either X, Y or input_fn, not both"
137        if input_fn is None:
138            input_fn = self.make_input_fn('fit', X, Y, batch_size=batch_size)
139        return self.estimator.train(input_fn=input_fn, **kwargs)
140
141    def score(self, X=None, Y=None, input_fn=None, batch_size=100, **kwargs):
142        """
143        Args:
144           X (Dataset|ndarray): features
145           Y (Dataset|ndarray): labels, optional
146        """
147        assert (ok(X, object) or ok(input_fn, object)), "specify either X, Y or input_fn, not both"
148        if input_fn is None:
149            input_fn = self.make_input_fn('score', X, Y, batch_size=batch_size)
150        return self.estimator.evaluate(input_fn=input_fn)
151
152    def predict(self, X=None, Y=None, input_fn=None, batch_size=1, **kwargs):
153        """
154        Args:
155           X (Dataset|ndarray): features
156           Y (Dataset|ndarray): labels, optional
157        """
158        options1 = (X is None) and (input_fn is not None)
159        options2 = (X is not None) and (input_fn is None)
160        assert options1 or options2, "specify either X, Y or input_fn, not both"
161        if input_fn is None:
162            input_fn = self.make_input_fn('predict', X, Y, batch_size=batch_size)
163        return self.estimator.predict(input_fn=input_fn)
164
165
class TFEstimatorModelBackend(BaseModelBackend):
    """Model backend that stores and runs TFEstimatorModel instances."""
    KIND = 'tfestimator.model'

    @classmethod
    def supports(cls, obj, name, **kwargs):
        """Return True if obj is a TFEstimatorModel."""
        return isinstance(obj, TFEstimatorModel)

    def _package_model(self, model, key, tmpfn, **kwargs):
        """
        Zip the dill-serialized model and its model_dir contents

        Args:
            model (TFEstimatorModel): the model to package
            key (str): not used here
            tmpfn (str): only its basename is used, as the zip file's name

        Returns:
            the path of the zip file, created in self.model_store.tmppath
        """
        model_dir = model.model_dir
        fname = os.path.basename(tmpfn)
        zipfname = os.path.join(self.model_store.tmppath, fname)
        # get relevant parts of model_dir
        # NOTE(review): only top-level entries are added; ZipFile.write does not
        # recurse into sub-directories -- confirm this is intended
        with ZipFile(zipfname, 'w', compression=ZIP_DEFLATED) as zipf:
            zipf.writestr('modelobj.dill', dill.dumps(model))
            for part in glob.glob(os.path.join(model_dir, '*')):
                arcname = os.path.basename(part)
                if arcname == 'modelobj.dill':
                    # ignore pre-existing model
                    continue
                zipf.write(part, arcname)
        return zipfname

    def _extract_model(self, infile, key, tmpfn, **kwargs):
        """
        Restore a model from a package created by _package_model

        Args:
            infile (file-like): the zipped package, read in full
            key (str): not used here
            tmpfn (str): path the package is written to before extraction

        Returns:
            the unpickled model, restored to a new temporary directory that
            holds the extracted package contents
        """
        lpath = tempfile.mkdtemp()
        with open(tmpfn, 'wb') as pkgf:
            pkgf.write(infile.read())
        with ZipFile(tmpfn) as zipf:
            zipf.extractall(lpath)
        # SECURITY: dill.load can execute arbitrary code; only load packages
        # that come from a trusted model store
        with open(os.path.join(lpath, 'modelobj.dill'), 'rb') as fin:
            model = dill.load(fin)
        model.restore(lpath)
        return model

    def fit(self, modelname, Xname, Yname=None, pure_python=True, **kwargs):
        """
        Fit the stored model and store it back under the same name

        Args:
            modelname (str): name of the model in the model store
            Xname (str): name of the features, or of an input function, in the data store
            Yname (str): name of the labels in the data store, optional
            pure_python (bool): not used here

        Returns:
            the result of self.model_store.put() for the fitted model
        """
        model = self.model_store.get(modelname)
        X = self.data_store.get(Xname)
        Y = self.data_store.get(Yname) if Yname else None
        if isfunction(X) and Y is None:
            # a stored function is used as the input_fn
            model.fit(input_fn=X)
        else:
            model.fit(X, Y)
        meta = self.model_store.put(model, modelname)
        return meta

    def predict(
            self, modelname, Xname, rName=None, pure_python=True, **kwargs):
        """
        Predict with the stored model

        Args:
            modelname (str): name of the model in the model store
            Xname (str): input data reference, resolved by self._resolve_input_data
            rName (str): passed on to self._prepare_result, optional
            pure_python (bool): passed on to self._prepare_result

        Returns:
            the result of self._prepare_result() applied to a DataFrame of
            the model's predictions
        """
        import pandas as pd
        model = self.model_store.get(modelname)
        X = self._resolve_input_data('predict', Xname, 'X', **kwargs)
        if isfunction(X):
            # a function is used as the input_fn
            predictions = model.predict(input_fn=X)
        else:
            predictions = model.predict(X)
        result = pd.DataFrame(v for v in predictions)
        return self._prepare_result('predict', result, rName=rName, pure_python=pure_python, **kwargs)

    def score(
            self, modelname, Xname, Yname=None, rName=True, pure_python=True,
            **kwargs):
        """
        Score (evaluate) the stored model

        Args:
            modelname (str): name of the model in the model store
            Xname (str): name of the features, or of an input function, in the data store
            Yname (str): name of the labels in the data store, optional
            rName (str): if not None, the result is put into the data store
                 under this name and the result of that put() is returned
            pure_python (bool): if False, the score is converted to a pd.Series

        Returns:
            the model's score, or the result of self.data_store.put() if
            rName is not None
        """
        import pandas as pd
        model = self.model_store.get(modelname)
        X = self.data_store.get(Xname)
        # same optional-label handling as in fit()
        Y = self.data_store.get(Yname) if Yname else None
        if isfunction(X) and Y is None:
            # a stored function is used as the input_fn; score, do not re-train
            result = model.score(input_fn=X)
        else:
            result = model.score(X, Y)
        if not pure_python:
            result = pd.Series(result)
        # NOTE(review): with the default rName=True the result is stored under
        # the name True -- looks unintended, confirm against callers
        if rName is not None:
            result = self.data_store.put(result, rName)
        return result