import os

from omegaml.backends.keras import KerasBackend
from omegaml.util import temp_filename
6
class TensorflowKerasBackend(KerasBackend):
    """Model backend for tf.keras models, stored as HDF5 (.h5).

    Extends :class:`KerasBackend` to handle models built with ``tf.keras``
    (Sequential or functional Model). ``fit`` optionally trains on a TPU
    when ``tpu_specs`` is passed or stored in the model's metadata
    attributes, falling back to the default (CPU/GPU) fit on any error.
    """
    KIND = 'tfkeras.h5'

    @classmethod
    def supports(cls, obj, name, **kwargs):
        """Return True if obj is a tf.keras Sequential or Model.

        Models requested as SavedModel (``as_savedmodel=True``) are
        excluded so a different backend can claim them.

        :param obj: the object to check
        :param name: the name the object is stored as (unused here)
        :return: bool
        """
        # fixed: classmethod first argument is conventionally cls, not self
        import tensorflow as tf
        tfSequential = tf.keras.models.Sequential
        tfModel = tf.keras.models.Model
        return isinstance(obj, (tfSequential, tfModel)) and not kwargs.get('as_savedmodel')

    def _save_model(self, model, fn):
        # override to implement model saving
        import tensorflow as tf
        from tensorflow import keras
        if tf.executing_eagerly():
            # in eager mode the optimizer weights may not be named
            # variables, which breaks HDF5 serialization -- fix first
            self._fix_model_for_saving(model)
        keras.models.save_model(model, fn)

    def _fix_model_for_saving(self, model):
        # rewrap the optimizer's weights as named tf.Variables inside the
        # optimizer's name scope so keras.models.save_model can serialize
        # them when executing eagerly (original source comment was truncated)
        import tensorflow as tf
        from tensorflow.python.keras import backend as K
        with K.name_scope(model.optimizer.__class__.__name__):
            try:
                for i, var in enumerate(model.optimizer.weights):
                    name = 'variable{}'.format(i)
                    model.optimizer.weights[i] = tf.Variable(var, name=name)
            except NotImplementedError:
                # some optimizers do not implement .weights; nothing to fix
                pass

    def _extract_model(self, infile, key, tmpfn):
        # override to implement model loading: write the stored blob to a
        # temporary file, then let keras load the model from that file
        from tensorflow import keras
        with open(tmpfn, 'wb') as pkgfn:
            pkgfn.write(infile.read())
        return keras.models.load_model(tmpfn)

    def fit(self, modelname, Xname, Yname=None, pure_python=True, tpu_specs=None, **kwargs):
        """Fit the stored model, using a TPU if tpu_specs are available.

        tpu_specs is taken from the argument or, if not given, from the
        model metadata's ``tpu_specs`` attribute. If TPU fitting fails for
        any reason the error is logged and fitting reverts to the default
        backend implementation.

        :param modelname: the name of the stored model
        :param Xname: the name of the X dataset
        :param Yname: the name of the Y dataset (optional)
        :param pure_python: passed through to the base backend's fit
        :param tpu_specs: TPU address/specs; enables TPU fitting if truthy
        :return: the fit result metadata
        """
        meta = self.model_store.metadata(modelname)
        tpu_specs = tpu_specs or meta.attributes.get('tpu_specs')
        if tpu_specs:
            try:
                result = self._fit_tpu(modelname, Xname, Yname=Yname, tpu_specs=tpu_specs, **kwargs)
            except Exception:
                # fixed: narrowed from a bare except (which also swallowed
                # KeyboardInterrupt/SystemExit); logger.exception keeps the
                # traceback so the TPU failure cause is not lost
                import logging
                logger = logging.getLogger(__name__)
                logger.exception('Error in _fit_tpu, reverting to fit on CPU')
            else:
                return result
        result = super(TensorflowKerasBackend, self).fit(modelname, Xname, Yname=Yname, pure_python=pure_python,
                                                         **kwargs)
        return result

    def _fit_tpu(self, modelname, Xname, Yname=None, tpu_specs=None, **kwargs):
        """Fit the model on a TPU (TensorFlow 1.x only).

        The trained TPU model's weights are copied back into the local
        model via a temporary weights file, then the model is stored.

        :raises AssertionError: if not on tf 1.x or no TPU device is known
        """
        import tensorflow as tf
        # adopted from https://www.dlology.com/blog/how-to-train-keras-model-x20-times-faster-with-tpu-for-free/
        # This address identifies the TPU we'll use when configuring TensorFlow.
        # FIXME this will fail in tf 2.0, see https://github.com/tensorflow/tensorflow/issues/24412#issuecomment-491980177
        assert tf.__version__.startswith('1.'), "TPU only supported on tf < 2.0"
        tpu_device = tpu_specs or os.environ.get('COLAB_TPU_ADDR', '')
        assert tpu_device, "there is no TPU device"
        # TPU workers are addressed via grpc; add the scheme if missing
        if tpu_device.startswith('grpc://'):
            tpu_worker = tpu_device
        else:
            tpu_worker = 'grpc://' + tpu_device
        tf.logging.set_verbosity(tf.logging.INFO)
        model = self.get_model(modelname)
        tpu_model = tf.contrib.tpu.keras_to_tpu_model(
            model,
            strategy=tf.contrib.tpu.TPUDistributionStrategy(
                tf.contrib.cluster_resolver.TPUClusterResolver(tpu_worker)))
        X = self.data_store.get(Xname)
        Y = self.data_store.get(Yname)
        tpu_model.fit(X, Y)
        # transfer the TPU-trained weights back to the local model and store
        fn = temp_filename()
        tpu_model.save_weights(fn, overwrite=True)
        model.load_weights(fn)
        meta = self.put(model, modelname)
        return meta