1# import glob
2import glob
3import logging
4import os
5import tempfile
6from inspect import isfunction
7from zipfile import ZipFile, ZIP_DEFLATED
8
9import dill
10
11from omegaml.backends.basemodel import BaseModelBackend
12
def ok(v, vtype):
    """Return True if ``v`` is an instance of ``vtype``.

    Args:
        v: the object to check
        vtype (type|tuple): a type, or tuple of types, as accepted by isinstance

    Returns:
        bool: the result of ``isinstance(v, vtype)``
    """
    return isinstance(v, vtype)


# module-level logger, named after this module
logger = logging.getLogger(__name__)
16
17
class TFEstimatorModel(object):
    """
    A serializable/deserializable wrapper for a TF Estimator

    Usage:
        estimator = TFEstimatorModel(estimator_fn)
        estimator.fit(input_fn=input_fn)
        estimator.predict(input_fn=input_fn)

    The estimator_fn returns a tf.estimator.Estimator or subclass.

    .. versionchanged:: 0.18.0
        Only supported for tensorflow <= 2.15 and Python <= 3.11

    .. deprecated:: 0.18.0
        Use an object helper or a serializer/loader combination instead.
    """

    # .score() requests mode 'score', while a dict-style input_fn is documented
    # with an 'evaluate' key; accept either spelling when looking up the dict
    _INPUT_FN_MODE_ALIASES = {'score': 'evaluate', 'evaluate': 'score'}

    def __init__(self, estimator_fn, model=None, input_fn=None, model_dir=None, v1_compat=False):
        """

        Args:
            estimator_fn (func): the function to return a valid tf.estimator.Estimator instance. Called as
                fn(model_dir=)
            model (tf.Estimator): an existing e.g. pre-fitted Estimator instance, optional. If not specified,
                the model will be recreated by calling estimator_fn. If specified, the
                model's weights and parameters will be saved and reloaded so that a fitted
                model can be used without further training.
            input_fn (func|dict): the function to create the input_fn as fn(mode, X, Y=, batch_size=n), where
                mode is 'fit', 'score' or 'predict', i.e. the name of the method called. If not
                provided, defaults to an input_fn that tries to infer the correct input_fn from
                the input arguments. If provided as a dict, it is looked up by mode and must
                contain the 'fit', 'predict' and 'evaluate' (or 'score') keys, where each value
                is called as fn(X, Y=, batch_size=n).
            model_dir (str): the model directory to use. Defaults to whatever estimator_fn/Estimator instance sets
            v1_compat (bool): use tensorflow.compat.v1 to create the input functions. Use this when
                migrating tensorflow v1.x Estimator models that are not v2.x native yet.
        """
        self.estimator_fn = estimator_fn
        self._model_dir = model_dir
        self._estimator = model
        self._input_fn = input_fn
        self.v1_compat = v1_compat

    @property
    def model_dir(self):
        """The model_dir as reported by the (lazily created) estimator"""
        return self.estimator.model_dir

    @property
    def estimator(self):
        """The wrapped estimator, created by estimator_fn(model_dir=) on first access"""
        if self._estimator is None:
            self._estimator = self.estimator_fn(model_dir=self._model_dir)
        return self._estimator

    def restore(self, model_dir):
        """
        Point the wrapper to another model_dir

        Drops the current estimator instance so that the next access to
        .estimator recreates it by calling estimator_fn(model_dir=model_dir).

        Args:
            model_dir (str): the directory to pass to estimator_fn

        Returns:
            self
        """
        self._estimator = None
        self._model_dir = model_dir
        return self

    def _lookup_input_fn(self, mode):
        """Return the entry for mode from a dict-style input_fn, honoring the score/evaluate alias"""
        candidates = self._input_fn
        if mode not in candidates:
            alias = self._INPUT_FN_MODE_ALIASES.get(mode)
            if alias in candidates:
                return candidates[alias]
        # raises KeyError for an unknown mode
        return candidates[mode]

    @staticmethod
    def _check_inputs(X, input_fn):
        """Raise AssertionError unless exactly one of X or input_fn is given"""
        only_input_fn = (X is None) and (input_fn is not None)
        only_data = (X is not None) and (input_fn is None)
        if not (only_input_fn or only_data):
            # raised explicitly, an assert statement is stripped when running with -O
            raise AssertionError("specify either X, Y or input_fn, not both")

    def make_input_fn(self, mode, X, Y=None, batch_size=1):
        """
        Return an input_fn for the estimator, built from the input provided

        Args:
            mode (str): calling mode, either 'fit', 'predict' or 'score'
            X (NDArray|Tensor|Dataset): features, or Dataset of (features, labels)
            Y (NDArray|Tensor|Dataset): labels, optional
            batch_size (int): the batch size applied to DataFrame input, and
                passed on to a user-provided input_fn

        Returns:
            the result of the user-provided input_fn if one was set on this
            instance, otherwise a function to be passed as input_fn=

        Notes:
            X can be a Dataset of (features, labels), or just features. If X is
            just features, also provide a Dataset of just labels.

            If X is a DataFrame and Y a Series, Dataset.from_tensor_slices((dict(X), Y))
            is used to create the Dataset. If only X is provided as a DataFrame,
            only X is used to create the Dataset. If X is a dict or ndarray, the
            numpy_input_fn is used.

            If none of these options work, create your own input_fn and pass it
            to the .fit/.predict methods using the input_fn= kwarg
        """
        if self.v1_compat:
            # https://www.tensorflow.org/guide/migrate
            # switch off v2 behavior up front, also when a custom input_fn is used
            import tensorflow.compat.v1 as tf
            tf.disable_v2_behavior()

        if self._input_fn is not None:
            # user-provided input_fn, either one per mode or a single dispatcher
            if isinstance(self._input_fn, dict):
                return self._lookup_input_fn(mode)(X, Y=Y, batch_size=batch_size)
            return self._input_fn(mode, X, Y=Y, batch_size=batch_size)

        import numpy as np
        import pandas as pd

        if not self.v1_compat:
            import tensorflow as tf

        if isinstance(X, (dict, np.ndarray)):
            # NOTE(review): batch_size is not forwarded to numpy_input_fn, and
            # tf.estimator.inputs looks like a v1-only API -- confirm both are intended
            return tf.estimator.inputs.numpy_input_fn(x=X, y=Y, num_epochs=1, shuffle=False)

        def input_fn():
            # if we have a dataset, use that
            if isinstance(X, tf.data.Dataset):
                if Y is None:
                    return X
                if isinstance(Y, tf.data.Dataset):
                    # Dataset.zip is a static method expecting the datasets as a
                    # tuple; X.zip(Y) would zip Y alone and drop the features
                    return tf.data.Dataset.zip((X, Y))
                return X, Y
            # if we have a dataframe, create a dataset from it
            if isinstance(X, pd.DataFrame):
                if isinstance(Y, pd.Series):
                    tensors = (dict(X), Y)
                else:
                    tensors = dict(X)
                return tf.data.Dataset.from_tensor_slices(tensors).batch(batch_size)
            # anything else is passed through as is
            return X, Y

        return input_fn

    def fit(self, X=None, Y=None, input_fn=None, batch_size=100, **kwargs):
        """
        Train the estimator

        Args:
            X (Dataset|ndarray): features
            Y (Dataset|ndarray): labels, optional
            input_fn (func): the input_fn to use instead of X, Y
            batch_size (int): batch size passed to make_input_fn
            **kwargs: passed on to estimator.train()

        Returns:
            the result of estimator.train()

        Raises:
            AssertionError: if neither or both of X and input_fn are given
        """
        self._check_inputs(X, input_fn)
        if input_fn is None:
            input_fn = self.make_input_fn('fit', X, Y, batch_size=batch_size)
        return self.estimator.train(input_fn=input_fn, **kwargs)

    def score(self, X=None, Y=None, input_fn=None, batch_size=100, **kwargs):
        """
        Evaluate the estimator

        Args:
            X (Dataset|ndarray): features
            Y (Dataset|ndarray): labels, optional
            input_fn (func): the input_fn to use instead of X, Y
            batch_size (int): batch size passed to make_input_fn
            **kwargs: accepted, currently not passed on to the estimator

        Returns:
            the result of estimator.evaluate()

        Raises:
            AssertionError: if neither or both of X and input_fn are given
        """
        self._check_inputs(X, input_fn)
        if input_fn is None:
            input_fn = self.make_input_fn('score', X, Y, batch_size=batch_size)
        return self.estimator.evaluate(input_fn=input_fn)

    def predict(self, X=None, Y=None, input_fn=None, batch_size=1, **kwargs):
        """
        Predict using the estimator

        Args:
            X (Dataset|ndarray): features
            Y (Dataset|ndarray): labels, optional
            input_fn (func): the input_fn to use instead of X, Y
            batch_size (int): batch size passed to make_input_fn
            **kwargs: accepted, currently not passed on to the estimator

        Returns:
            the result of estimator.predict()

        Raises:
            AssertionError: if neither or both of X and input_fn are given
        """
        self._check_inputs(X, input_fn)
        if input_fn is None:
            input_fn = self.make_input_fn('predict', X, Y, batch_size=batch_size)
        return self.estimator.predict(input_fn=input_fn)
171
class TFEstimatorModelBackend(BaseModelBackend):
    """
    Model backend for TFEstimatorModel instances

    Models are packaged as a zip archive that contains the dill-serialized
    TFEstimatorModel (modelobj.dill) along with the top-level entries of the
    estimator's model_dir.
    """
    KIND = 'tfestimator.model'

    @classmethod
    def supports(cls, obj, name, **kwargs):
        """Return True if obj is a TFEstimatorModel"""
        return isinstance(obj, TFEstimatorModel)

    def _package_model(self, model, key, tmpfn, **kwargs):
        """
        Write the model and its model_dir to a zip archive

        Args:
            model (TFEstimatorModel): the model to package
            key (str): not used here
            tmpfn (str): only its basename is used, as the archive's file name

        Returns:
            str: path of the zip archive, located in model_store.tmppath
        """
        model_dir = model.model_dir
        fname = os.path.basename(tmpfn)
        zipfname = os.path.join(self.model_store.tmppath, fname)
        # get relevant parts of model_dir
        with ZipFile(zipfname, 'w', compression=ZIP_DEFLATED) as zipf:
            zipf.writestr('modelobj.dill', dill.dumps(model))
            # NOTE(review): only top-level entries are added; contents of
            # sub-directories are not written recursively -- confirm this is intended
            for part in glob.glob(os.path.join(model_dir, '*')):
                arcname = os.path.basename(part)
                if arcname == 'modelobj.dill':
                    # ignore pre-existing model
                    continue
                zipf.write(part, arcname)
        return zipfname

    def _extract_model(self, infile, key, tmpfn, **kwargs):
        """
        Restore a model from its zip archive

        Args:
            infile (file-like): the packaged model, read in full
            key (str): not used here
            tmpfn (str): path of the local file the archive is written to

        Returns:
            TFEstimatorModel: the model, restored to use the extracted
            directory as its model_dir
        """
        # the directory is kept, it becomes the restored model's model_dir
        lpath = tempfile.mkdtemp()
        with open(tmpfn, 'wb') as pkgf:
            pkgf.write(infile.read())
        with ZipFile(tmpfn) as zipf:
            zipf.extractall(lpath)
        with open(os.path.join(lpath, 'modelobj.dill'), 'rb') as fin:
            # SECURITY: dill.load executes arbitrary code on load, only
            # extract models from a trusted model store
            model = dill.load(fin)
        model.restore(lpath)
        return model

    def fit(self, modelname, Xname, Yname=None, pure_python=True, **kwargs):
        """
        Fit the stored model and save it back to the model store

        Args:
            modelname (str): name of the model in the model store
            Xname (str): name of the features in the data store. If this
                resolves to a function and there is no Y, it is used as input_fn
            Yname (str): name of the labels in the data store, optional
            pure_python (bool): not used here
            **kwargs: not used here

        Returns:
            the result of model_store.put()
        """
        model = self.model_store.get(modelname)
        X = self.data_store.get(Xname)
        Y = self.data_store.get(Yname) if Yname else None
        if isfunction(X) and Y is None:
            # support input_fn stored as a function
            model.fit(input_fn=X)
        else:
            model.fit(X, Y)
        meta = self.model_store.put(model, modelname)
        return meta

    def predict(
            self, modelname, Xname, rName=None, pure_python=True, **kwargs):
        """
        Predict using the stored model

        Args:
            modelname (str): name of the model in the model store
            Xname (str): the input data, resolved by _resolve_input_data. If
                this resolves to a function it is used as input_fn
            rName (str): passed on to _prepare_result
            pure_python (bool): passed on to _prepare_result
            **kwargs: passed on to _resolve_input_data and _prepare_result

        Returns:
            the result of _prepare_result() applied to a DataFrame of the
            model's predictions
        """
        import pandas as pd
        model = self.model_store.get(modelname)
        X = self._resolve_input_data('predict', Xname, 'X', **kwargs)
        if isfunction(X):
            predictions = model.predict(input_fn=X)
        else:
            predictions = model.predict(X)
        result = pd.DataFrame(list(predictions))
        return self._prepare_result('predict', result, rName=rName, pure_python=pure_python, **kwargs)

    def score(
            self, modelname, Xname, Yname=None, rName=True, pure_python=True,
            **kwargs):
        """
        Score the stored model

        Args:
            modelname (str): name of the model in the model store
            Xname (str): name of the features in the data store. If this
                resolves to a function and there is no Y, it is used as input_fn
            Yname (str): name of the labels in the data store, optional
            rName (str): if not None, the result is stored in the data store
                under this name and the result of data_store.put() is returned
            pure_python (bool): if False, the result is converted to a pd.Series
            **kwargs: not used here

        Returns:
            the model's score, or the result of data_store.put() if rName is
            not None
        """
        import pandas as pd
        model = self.model_store.get(modelname)
        X = self.data_store.get(Xname)
        # same as in fit(), only look up labels if a name was given
        Y = self.data_store.get(Yname) if Yname else None
        if isfunction(X) and Y is None:
            # support input_fn stored as a function; scoring must not train the model
            result = model.score(input_fn=X)
        else:
            result = model.score(X, Y)
        if not pure_python:
            result = pd.Series(result)
        # NOTE(review): the rName=True default means the result is stored under
        # the name True unless callers pass rName explicitly -- confirm intent
        if rName is not None:
            result = self.data_store.put(result, rName)
        return result