1# import glob
2import glob
3import os
4import tempfile
5from inspect import isfunction
6from zipfile import ZipFile, ZIP_DEFLATED
7
8import dill
9import logging
10
11from omegaml.backends.basemodel import BaseModelBackend
12
def ok(v, vtype):
    """Return True if ``v`` is an instance of ``vtype`` (shorthand for ``isinstance``).

    Args:
        v: the value to check
        vtype (type|tuple): a type or tuple of types, as accepted by isinstance

    Returns:
        bool: isinstance(v, vtype)
    """
    return isinstance(v, vtype)


logger = logging.getLogger(__name__)
16
17
class TFEstimatorModel(object):
    """
    A serializable/deserializable wrapper for a TF Estimator

    Usage:
        estimator = TFEstimatorModel(estimator_fn)
        estimator.fit(input_fn=input_fn)
        estimator.predict(input_fn=input_fn)

    The estimator_fn returns a tf.estimator.Estimator or subclass.
    """

    def __init__(self, estimator_fn, model=None, input_fn=None, model_dir=None, v1_compat=False):
        """

        Args:
            estimator_fn (func): the function to return a valid tf.estimator.Estimator instance. Called as
                fn(model_dir=)
            model (tf.Estimator): an existing e.g. pre-fitted Estimator instance, optional. If not specified,
                the model will be recreated by calling estimator_fn. If specified, the
                model's weights and parameters will be saved and reloaded so that a fitted
                model can be used without further training.
            input_fn (func|dict): the function to create the input_fn as fn(mode, X, Y, batch_size=n), where mode
                is either 'fit', 'evaluate', or 'predict'. If not provided, defaults to an input_fn
                that tries to infer the correct input_fn from the method and input arguments. If
                provided as a dict, must contain the 'fit', 'evaluate' and 'predict' keys where
                each value is a valid input_fn as fn(X, Y, batch_size=n).
            model_dir (str): the model directory to use. Defaults to whatever estimator_fn/Estimator instance sets
            v1_compat (bool): use tensorflow.compat.v1 to create the input functions. Use this when
                migrating tensorflow v1.x Estimator models that are not yet v2.x native.
        """
        self.estimator_fn = estimator_fn
        self._model_dir = model_dir
        self._estimator = model
        self._input_fn = input_fn
        self.v1_compat = v1_compat

    @property
    def model_dir(self):
        """The model_dir of the wrapped estimator (creating the estimator if needed)."""
        return self.estimator.model_dir

    @property
    def estimator(self):
        """The wrapped Estimator; created lazily as estimator_fn(model_dir=...) on first access."""
        if self._estimator is None:
            self._estimator = self.estimator_fn(model_dir=self._model_dir)
        return self._estimator

    def restore(self, model_dir):
        """
        Drop the current estimator so that it is recreated from model_dir on next access

        Args:
            model_dir (str): the directory to pass to estimator_fn(model_dir=)

        Returns:
            self
        """
        self._estimator = None
        self._model_dir = model_dir
        return self

    @staticmethod
    def _check_x_or_input_fn(X, input_fn):
        # exactly one of X or input_fn must be given; comparing against None
        # avoids ambiguous truthiness of arrays/Datasets
        assert (X is None) != (input_fn is None), "specify either X, Y or input_fn, not both"

    def make_input_fn(self, mode, X, Y=None, batch_size=1):
        """
        Return an input_fn for the input provided

        Args:
            mode (str): calling mode, either 'fit', 'predict' or 'evaluate'
            X (NDArray|Tensor|Dataset): features, or Dataset of (features, labels)
            Y (NDArray|Tensor|Dataset): labels, optional
            batch_size (int): batch size used when building a Dataset from a DataFrame

        Returns:
            the result of the user-supplied input_fn (see __init__) if one was
            given, otherwise a callable suitable as input_fn= for the Estimator methods

        Notes:
            X can be a Dataset of (features, labels), or just features. If X is
            just features, also provide a Dataset of just labels.

            If X, Y are a DataFrame and Series, Dataset.from_tensor_slices((dict(X), Y))
            is used to create the Dataset. If only X is provided as a DataFrame,
            only X is used to create the Dataset. If X is a dict or ndarray,
            tf.estimator.inputs.numpy_input_fn is used.

            If none of these options work, create your own input_fn and pass it
            to the .fit/.predict methods using the input_fn= kwarg
        """
        import pandas as pd
        import numpy as np

        if self.v1_compat:
            # https://www.tensorflow.org/guide/migrate
            import tensorflow.compat.v1 as tf
            tf.disable_v2_behavior()
        else:
            import tensorflow as tf

        if self._input_fn is not None:
            if isinstance(self._input_fn, dict):
                return self._input_fn[mode](X, Y=Y, batch_size=batch_size)
            return self._input_fn(mode, X, Y=Y, batch_size=batch_size)

        def input_fn():
            # if we have a dataset, use that
            if isinstance(X, tf.data.Dataset):
                if Y is None:
                    return X
                if isinstance(Y, tf.data.Dataset):
                    # Dataset.zip is a static method taking a tuple of datasets;
                    # X.zip(Y) would evaluate as Dataset.zip(Y) and drop X
                    return tf.data.Dataset.zip((X, Y))
                return X, Y
            # if we have a dataframe, create a dataset from it
            if ok(X, pd.DataFrame) and ok(Y, pd.Series):
                dataset = tf.data.Dataset.from_tensor_slices((dict(X), Y))
                return dataset.batch(batch_size)
            if ok(X, pd.DataFrame):
                dataset = tf.data.Dataset.from_tensor_slices(dict(X))
                return dataset.batch(batch_size)
            return X, Y

        if isinstance(X, (dict, np.ndarray)):
            # NOTE(review): batch_size is not forwarded to numpy_input_fn, and
            # tf.estimator.inputs is presumably only available via tensorflow.compat.v1
            # in TF 2.x -- confirm both are intended
            input_fn = tf.estimator.inputs.numpy_input_fn(x=X, y=Y, num_epochs=1, shuffle=False)
        return input_fn

    def fit(self, X=None, Y=None, input_fn=None, batch_size=100, **kwargs):
        """
        Train the estimator

        Args:
            X (Dataset|ndarray): features
            Y (Dataset|ndarray): labels, optional
            input_fn (func): input_fn for Estimator.train; specify instead of X, Y
            batch_size (int): batch size used when creating the input_fn from X, Y
            **kwargs: passed on to Estimator.train

        Returns:
            the result of Estimator.train

        Raises:
            AssertionError: if both or neither of X and input_fn are given
        """
        self._check_x_or_input_fn(X, input_fn)
        if input_fn is None:
            input_fn = self.make_input_fn('fit', X, Y, batch_size=batch_size)
        return self.estimator.train(input_fn=input_fn, **kwargs)

    def score(self, X=None, Y=None, input_fn=None, batch_size=100, **kwargs):
        """
        Evaluate the estimator

        Args:
            X (Dataset|ndarray): features
            Y (Dataset|ndarray): labels, optional
            input_fn (func): input_fn for Estimator.evaluate; specify instead of X, Y
            batch_size (int): batch size used when creating the input_fn from X, Y
            **kwargs: accepted for interface symmetry, currently not forwarded

        Returns:
            the result of Estimator.evaluate

        Raises:
            AssertionError: if both or neither of X and input_fn are given
        """
        self._check_x_or_input_fn(X, input_fn)
        if input_fn is None:
            # 'evaluate' is the documented mode (and dict key) for a custom input_fn
            input_fn = self.make_input_fn('evaluate', X, Y, batch_size=batch_size)
        return self.estimator.evaluate(input_fn=input_fn)

    def predict(self, X=None, Y=None, input_fn=None, batch_size=1, **kwargs):
        """
        Predict using the estimator

        Args:
            X (Dataset|ndarray): features
            Y (Dataset|ndarray): labels, optional
            input_fn (func): input_fn for Estimator.predict; specify instead of X, Y
            batch_size (int): batch size used when creating the input_fn from X, Y
            **kwargs: accepted for interface symmetry, currently not forwarded

        Returns:
            the result of Estimator.predict

        Raises:
            AssertionError: if both or neither of X and input_fn are given
        """
        self._check_x_or_input_fn(X, input_fn)
        if input_fn is None:
            input_fn = self.make_input_fn('predict', X, Y, batch_size=batch_size)
        return self.estimator.predict(input_fn=input_fn)
164
165
class TFEstimatorModelBackend(BaseModelBackend):
    """
    Model backend for TFEstimatorModel instances

    A model is stored as a zip archive containing the dill-serialized
    TFEstimatorModel ('modelobj.dill') plus the files of its model_dir.
    """
    KIND = 'tfestimator.model'

    @classmethod
    def supports(cls, obj, name, **kwargs):
        """Return True if obj is a TFEstimatorModel."""
        return isinstance(obj, TFEstimatorModel)

    def _package_model(self, model, key, tmpfn, **kwargs):
        """
        Package the model and the contents of its model_dir into a zip file

        Args:
            model (TFEstimatorModel): the model to package
            key (str): the storage key (unused here)
            tmpfn (str): temporary file name; its basename is used for the zip
                file inside self.model_store.tmppath

        Returns:
            str: path of the zip file created
        """
        model_dir = model.model_dir
        fname = os.path.basename(tmpfn)
        zipfname = os.path.join(self.model_store.tmppath, fname)
        # recurse into sub-directories: ZipFile.write() on a directory only adds
        # the directory entry, not its contents. Escape model_dir so that glob
        # metacharacters in the path are matched literally.
        pattern = os.path.join(glob.escape(model_dir), '**', '*')
        with ZipFile(zipfname, 'w', compression=ZIP_DEFLATED) as zipf:
            zipf.writestr('modelobj.dill', dill.dumps(model))
            for part in sorted(glob.glob(pattern, recursive=True)):
                arcname = os.path.relpath(part, model_dir)
                if arcname == 'modelobj.dill':
                    # ignore pre-existing model
                    continue
                zipf.write(part, arcname)
        return zipfname

    def _extract_model(self, infile, key, tmpfn, **kwargs):
        """
        Restore a model packaged by _package_model

        Args:
            infile: file-like object with the zip contents (read fully)
            key (str): the storage key (unused here)
            tmpfn (str): temporary file name the zip is written to

        Returns:
            TFEstimatorModel: the model, restored to use the extracted
            directory as its model_dir
        """
        lpath = tempfile.mkdtemp()
        with open(tmpfn, 'wb') as pkgf:
            pkgf.write(infile.read())
        with ZipFile(tmpfn) as zipf:
            zipf.extractall(lpath)
        # SECURITY: dill.load can execute arbitrary code -- only load archives
        # from a trusted model store
        with open(os.path.join(lpath, 'modelobj.dill'), 'rb') as fin:
            model = dill.load(fin)
        model.restore(lpath)
        return model

    def fit(self, modelname, Xname, Yname=None, pure_python=True, **kwargs):
        """
        Fit the stored model and store it again

        If the dataset stored as Xname is a function and no Yname is given,
        it is used as the model's input_fn.

        Returns:
            the metadata returned by model_store.put
        """
        model = self.model_store.get(modelname)
        X = self.data_store.get(Xname)
        Y = self.data_store.get(Yname) if Yname else None
        if isfunction(X) and Y is None:
            # X is a stored input_fn rather than data
            model.fit(input_fn=X)
        else:
            model.fit(X, Y)
        meta = self.model_store.put(model, modelname)
        return meta

    def predict(
            self, modelname, Xname, rName=None, pure_python=True, **kwargs):
        """
        Predict using the stored model

        If the resolved input X is a function, it is used as the model's input_fn.

        Returns:
            the predictions as a DataFrame, as processed by self._prepare_result
        """
        import pandas as pd
        model = self.model_store.get(modelname)
        X = self._resolve_input_data('predict', Xname, 'X', **kwargs)
        if isfunction(X):
            predictions = model.predict(input_fn=X)
        else:
            predictions = model.predict(X)
        result = pd.DataFrame(list(predictions))
        return self._prepare_result('predict', result, rName=rName, pure_python=pure_python, **kwargs)

    def score(
            self, modelname, Xname, Yname=None, rName=True, pure_python=True,
            **kwargs):
        """
        Score (evaluate) the stored model

        If the dataset stored as Xname is a function and no Yname is given,
        it is used as the model's input_fn.

        Returns:
            the evaluation result (as a pd.Series if not pure_python), or the
            result of data_store.put if rName is not None
        """
        import pandas as pd
        model = self.model_store.get(modelname)
        X = self.data_store.get(Xname)
        Y = self.data_store.get(Yname) if Yname else None
        if isfunction(X) and Y is None:
            # X is a stored input_fn rather than data -- evaluate, do not train
            result = model.score(input_fn=X)
        else:
            result = model.score(X, Y)
        if not pure_python:
            result = pd.Series(result)
        # NOTE(review): rName defaults to True, which is then passed to
        # data_store.put as the dataset name -- confirm this is intended
        if rName is not None:
            result = self.data_store.put(result, rName)
        return result