Source code for omegaml.backends.tensorflow.tfestimatormodel

  1# import glob
  2import glob
  3import logging
  4import os
  5import tempfile
  6from inspect import isfunction
  7from zipfile import ZipFile, ZIP_DEFLATED
  8
  9import dill
 10
 11from omegaml.backends.basemodel import BaseModelBackend
 12
 13ok = lambda v, vtype: isinstance(v, vtype)
 14
 15logger = logging.getLogger(__name__)
 16
 17
 18class TFEstimatorModel(object):
 19    """
 20    A serializable/deserizable wrapper for a TF Estimator
 21
 22    Usage:
 23        estimator = TFEstimatorModel(estimator_fn)
 24        estimator.fit(input_fn=input_fn)
 25        estimator.predict(input_fn=input_fn)
 26
 27        The estimator_fn returns a tf.estimator.Estimator or subclass.
 28
 29    .. versionchanged:: 0.18.0
 30        Only supported for tensorflow <= 2.15 and Python <= 3.11
 31
 32    .. deprecated:: 0.18.0
 33        Use an object helper or a serializer/loader combination instead.
 34    """
 35
 36    def __init__(self, estimator_fn, model=None, input_fn=None, model_dir=None, v1_compat=False):
 37        """
 38
 39        Args:
 40            estimator_fn (func): the function to return a valid tf.estimator.Estimator instance. Called as
 41                                 fn(model_dir=)
 42            model (tf.Estimator): an existing e.g. pre-fitted Estimator instance, optional. If not specified,
 43                                  the model will be recreated by calling estimator_fn. If specified, the
 44                                  model's weights and parameters will be saved and reloaded so that a fitted
 45                                  model can be used without further training.
 46            input_fn (func|dict): the function to create the input_fn as fn(mode, X, Y, batch_size=n), where mode
 47                                  is either 'fit', 'evaluate', or 'predict'. If not provide defaults to an input_fn
 48                                  that tries to infer the correct input_fn from the method and input arguments. If
 49                                  provided as a dict, must contain the 'fit', 'evaluate' and 'predict' keys where
 50                                  each value is a valid input_fn as fn(X, Y, batch_size=n).
 51            model_dir (str): the model directory to use. Defaults to whatever estimator_fn/Estimator instance sets
 52            v1_compat (bool): use tensorflow.compat.v1 to create create the input functions. Use this when
 53                 migrating tensorflow v1.x Estimator models that are not yet v2.x native yet.
 54        """
 55        self.estimator_fn = estimator_fn
 56        self._model_dir = model_dir
 57        self._estimator = model
 58        self._input_fn = input_fn
 59        self.v1_compat = v1_compat
 60
 61    @property
 62    def model_dir(self):
 63        return self.estimator.model_dir
 64
 65    @property
 66    def estimator(self):
 67        if self._estimator is None:
 68            self._estimator = self.estimator_fn(model_dir=self._model_dir)
 69        return self._estimator
 70
 71    def restore(self, model_dir):
 72        self._estimator = None
 73        self._model_dir = model_dir
 74        return self
 75
 76    def make_input_fn(self, mode, X, Y=None, batch_size=1):
 77        """
 78        Return a tf.data.Dataset from the input provided
 79
 80        Args:
 81            mode (str): calling mode, either 'fit', 'predict' or 'evaluate'
 82            X (NDArray|Tensor|Dataset): features, or Dataset of (features, labels)
 83            Y (NDArray|Tensor|Dataset): labels, optional
 84
 85        Notes:
 86            X can be a Dataset of (features, labels), or just features. If X is
 87            just features, also provide a Dataset of just labels.
 88
 89            If X, Y are NDArrays or Tensors, Dataset.from_tensor_slices((dict(X), Y))
 90            is used to create the Dataset. If only X is provided as a NDArray or Tensor,
 91            only X is used to create the Dataset.
 92
 93            If none of these options work, create your own input_fn and pass it
 94            to the .fit/.predict methods using the input_fn= kwarg
 95        """
 96        import pandas as pd
 97        import numpy as np
 98
 99        if self.v1_compat:
100            # https://www.tensorflow.org/guide/migrate
101            import tensorflow.compat.v1 as tf
102            tf.disable_v2_behavior()
103        else:
104            import tensorflow as tf
105
106        if self._input_fn is not None:
107            if isinstance(self._input_fn, dict):
108                return self._input_fn[mode](X, Y=Y, batch_size=batch_size)
109            else:
110                return self._input_fn(mode, X, Y=Y, batch_size=batch_size)
111
112        def input_fn():
113            # if we have a dataset, use that
114            if isinstance(X, tf.data.Dataset):
115                if Y is None:
116                    return X
117                elif isinstance(Y, tf.data.Dataset):
118                    return X.zip(Y)
119                else:
120                    return X, Y
121            # if we have a dataframe, create a dataset from it
122            if ok(X, pd.DataFrame) and ok(Y, pd.Series):
123                dataset = tf.data.Dataset.from_tensor_slices((dict(X), Y))
124                result = dataset.batch(batch_size)
125            elif ok(X, pd.DataFrame):
126                dataset = tf.data.Dataset.from_tensor_slices(dict(X))
127                result = dataset.batch(batch_size)
128            else:
129                result = X, Y
130            return result
131
132        if isinstance(X, (dict, np.ndarray)):
133            input_fn = tf.estimator.inputs.numpy_input_fn(x=X, y=Y, num_epochs=1, shuffle=False)
134        return input_fn
135
136    def fit(self, X=None, Y=None, input_fn=None, batch_size=100, **kwargs):
137        """
138        Args:
139           X (Dataset|ndarray): features
140           Y (Dataset|ndarray): labels, optional
141        """
142        assert (ok(X, object) or ok(input_fn, object)), "specify either X, Y or input_fn, not both"
143        if input_fn is None:
144            input_fn = self.make_input_fn('fit', X, Y, batch_size=batch_size)
145        return self.estimator.train(input_fn=input_fn, **kwargs)
146
147    def score(self, X=None, Y=None, input_fn=None, batch_size=100, **kwargs):
148        """
149        Args:
150           X (Dataset|ndarray): features
151           Y (Dataset|ndarray): labels, optional
152        """
153        assert (ok(X, object) or ok(input_fn, object)), "specify either X, Y or input_fn, not both"
154        if input_fn is None:
155            input_fn = self.make_input_fn('score', X, Y, batch_size=batch_size)
156        return self.estimator.evaluate(input_fn=input_fn)
157
158    def predict(self, X=None, Y=None, input_fn=None, batch_size=1, **kwargs):
159        """
160        Args:
161           X (Dataset|ndarray): features
162           Y (Dataset|ndarray): labels, optional
163        """
164        options1 = (X is None) and (input_fn is not None)
165        options2 = (X is not None) and (input_fn is None)
166        assert options1 or options2, "specify either X, Y or input_fn, not both"
167        if input_fn is None:
168            input_fn = self.make_input_fn('predict', X, Y, batch_size=batch_size)
169        return self.estimator.predict(input_fn=input_fn)
170
171
class TFEstimatorModelBackend(BaseModelBackend):
    """
    Model backend storing and running TFEstimatorModel instances

    A model is packaged as a zip file holding the dill-serialized
    TFEstimatorModel ('modelobj.dill') plus the files of its model_dir.
    """
    KIND = 'tfestimator.model'

    @classmethod
    def supports(cls, obj, name, **kwargs):
        """Return True if *obj* is a TFEstimatorModel."""
        return isinstance(obj, TFEstimatorModel)

    def _package_model(self, model, key, tmpfn, **kwargs):
        """
        Zip the serialized model object and the contents of its model_dir

        Returns:
            path of the zip file, created in self.model_store.tmppath
        """
        model_dir = model.model_dir
        fname = os.path.basename(tmpfn)
        zipfname = os.path.join(self.model_store.tmppath, fname)
        with ZipFile(zipfname, 'w', compression=ZIP_DEFLATED) as zipf:
            zipf.writestr('modelobj.dill', dill.dumps(model))
            # walk recursively: ZipFile.write() on a directory only stores an empty
            # directory entry, which would drop the contents of any subdirectories
            for root, _dirs, files in os.walk(model_dir):
                for filename in files:
                    path = os.path.join(root, filename)
                    arcname = os.path.relpath(path, model_dir)
                    if arcname == 'modelobj.dill':
                        # ignore pre-existing model
                        continue
                    zipf.write(path, arcname)
        return zipfname

    def _extract_model(self, infile, key, tmpfn, **kwargs):
        """
        Unpack a packaged model into a new temporary directory and load it

        The model is restored with that directory as its model_dir, so the
        directory is intentionally not removed here.
        """
        lpath = tempfile.mkdtemp()
        with open(tmpfn, 'wb') as pkgf:
            pkgf.write(infile.read())
        with ZipFile(tmpfn) as zipf:
            zipf.extractall(lpath)
        # NOTE(review): dill.load executes arbitrary code; only load packages from a trusted store
        with open(os.path.join(lpath, 'modelobj.dill'), 'rb') as fin:
            model = dill.load(fin)
        model.restore(lpath)
        return model

    def fit(self, modelname, Xname, Yname=None, pure_python=True, **kwargs):
        """
        Fit the stored model and save it back to the model store

        If the X dataset resolves to a function and no Y is given, it is used
        as the model's input_fn.

        Returns:
            the metadata returned by model_store.put
        """
        model = self.model_store.get(modelname)
        X = self.data_store.get(Xname)
        Y = self.data_store.get(Yname) if Yname else None
        if isfunction(X) and Y is None:
            # support an input_fn stored in place of data
            model.fit(input_fn=X)
        else:
            model.fit(X, Y)
        meta = self.model_store.put(model, modelname)
        return meta

    def predict(
            self, modelname, Xname, rName=None, pure_python=True, **kwargs):
        """
        Predict using the stored model

        Returns:
            the predictions as a DataFrame, passed through _prepare_result
        """
        import pandas as pd
        model = self.model_store.get(modelname)
        X = self._resolve_input_data('predict', Xname, 'X', **kwargs)
        if isfunction(X):
            result = pd.DataFrame(v for v in model.predict(input_fn=X))
        else:
            result = pd.DataFrame(v for v in model.predict(X))
        return self._prepare_result('predict', result, rName=rName, pure_python=pure_python, **kwargs)

    def score(
            self, modelname, Xname, Yname=None, rName=True, pure_python=True,
            **kwargs):
        """
        Evaluate the stored model

        If the X dataset resolves to a function and no Y is given, it is used
        as the model's input_fn.

        Returns:
            the evaluation result (a pd.Series if pure_python is False), or the
            result of data_store.put(result, rName) if rName is not None
        """
        # NOTE(review): rName defaults to True, so by default the result is stored via
        # data_store.put(result, True); predict defaults to None -- confirm this is intended
        import pandas as pd
        model = self.model_store.get(modelname)
        X = self.data_store.get(Xname)
        Y = self.data_store.get(Yname) if Yname else None
        if isfunction(X) and Y is None:
            # support an input_fn stored in place of data; evaluate, do not retrain
            result = model.score(input_fn=X)
        else:
            result = model.score(X, Y)
        if not pure_python:
            result = pd.Series(result)
        if rName is not None:
            result = self.data_store.put(result, rName)
        return result