1import os
2
3from omegaml.backends.keras import KerasBackend
4from omegaml.util import temp_filename
5
6
[docs]
class TensorflowKerasBackend(KerasBackend):
    """
    Backend for ``tf.keras`` Sequential and Model objects (KIND ``tfkeras.h5``)

    .. versionchanged:: 0.18.0
        Only supported for tensorflow <= 2.15 and Python <= 3.11

    .. deprecated:: 0.18.0
        Use an object helper or a serializer/loader combination instead.
    """
    KIND = 'tfkeras.h5'

    @classmethod
    def supports(cls, obj, name, **kwargs):
        """
        Check whether this backend handles the given object

        Args:
            obj: the object to be stored
            name: the name of the object (not used by this check)
            **kwargs: if ``as_savedmodel`` is truthy this backend declines
                the object

        Returns:
            True if obj is a tf.keras Sequential or Model instance and
            ``as_savedmodel`` was not requested, False otherwise
        """
        # tensorflow is imported lazily so the module can be loaded without it
        import tensorflow as tf
        if kwargs.get('as_savedmodel'):
            return False
        keras_types = (tf.keras.models.Sequential, tf.keras.models.Model)
        return isinstance(obj, keras_types)

    def _save_model(self, model, fn):
        """
        Save the model to file ``fn`` using ``keras.models.save_model``

        When tensorflow executes eagerly, the optimizer weights are first
        re-created by :meth:`_fix_model_for_saving`.

        Args:
            model: the tf.keras model to save
            fn: the target filename
        """
        import tensorflow as tf
        from tensorflow import keras
        if tf.executing_eagerly():
            self._fix_model_for_saving(model)
        keras.models.save_model(model, fn)

    def _fix_model_for_saving(self, model):
        """
        Replace each optimizer weight by a ``tf.Variable`` named ``variable<i>``

        The variables are created inside a name scope named after the
        optimizer's class. Models without an optimizer are left untouched.

        Args:
            model: the tf.keras model about to be saved
        """
        import tensorflow as tf
        from tensorflow.python.keras import backend as K
        optimizer = getattr(model, 'optimizer', None)
        if optimizer is None:
            # no optimizer attached, there are no optimizer weights to rename
            return
        with K.name_scope(optimizer.__class__.__name__):
            try:
                for i, var in enumerate(optimizer.weights):
                    name = 'variable{}'.format(i)
                    optimizer.weights[i] = tf.Variable(var, name=name)
            except NotImplementedError:
                # deliberate best effort: the model is saved with its
                # optimizer weights as they are
                pass

    def _extract_model(self, infile, key, tmpfn, **kwargs):
        """
        Load a model from a stored file object

        The contents of ``infile`` are written to ``tmpfn``, which is then
        loaded by ``keras.models.load_model``.

        Args:
            infile: readable binary file-like object holding the saved model
            key: not used by this implementation
            tmpfn: path of the temporary file to write the model to
            **kwargs: not used by this implementation

        Returns:
            the model as returned by ``keras.models.load_model``
        """
        from tensorflow import keras
        with open(tmpfn, 'wb') as fout:
            fout.write(infile.read())
        return keras.models.load_model(tmpfn)

    def fit(self, modelname, Xname, Yname=None, pure_python=True, tpu_specs=None, **kwargs):
        """
        Fit the stored model, on a TPU if one is specified

        If ``tpu_specs`` is given, or the model's metadata attributes contain
        ``tpu_specs``, fitting is first attempted by :meth:`_fit_tpu`. If that
        raises an Exception, a warning is logged and the parent class' fit is
        used instead.

        Args:
            modelname: name of the model in the model store
            Xname: name of the X dataset
            Yname: name of the Y dataset, optional
            pure_python: passed on to the parent class' fit
            tpu_specs: the TPU to use, defaults to the ``tpu_specs`` entry of
                the model's metadata attributes
            **kwargs: passed on to :meth:`_fit_tpu` and the parent class' fit

        Returns:
            the result of :meth:`_fit_tpu` if it succeeded, else the result
            of the parent class' fit
        """
        meta = self.model_store.metadata(modelname)
        tpu_specs = tpu_specs or meta.attributes.get('tpu_specs')
        if tpu_specs:
            try:
                result = self._fit_tpu(modelname, Xname, Yname=Yname, tpu_specs=tpu_specs, **kwargs)
            except Exception:
                # only Exception is caught so that KeyboardInterrupt and
                # SystemExit propagate; exc_info records the actual cause
                import logging
                logger = logging.getLogger(__name__)
                logger.warning('Error in _fit_tpu, reverting to fit on CPU', exc_info=True)
            else:
                return result
        result = super(TensorflowKerasBackend, self).fit(modelname, Xname, Yname=Yname, pure_python=pure_python,
                                                         **kwargs)
        return result

    def _fit_tpu(self, modelname, Xname, Yname=None, tpu_specs=None, **kwargs):
        """
        Fit the stored model on a TPU and store the fitted model

        Args:
            modelname: name of the model in the model store
            Xname: name of the X dataset in the data store
            Yname: name of the Y dataset in the data store
            tpu_specs: address of the TPU, with or without the ``grpc://``
                prefix; defaults to the ``COLAB_TPU_ADDR`` environment variable
            **kwargs: accepted for compatibility, not used

        Returns:
            the result of ``self.put`` for the fitted model

        Raises:
            AssertionError: if tensorflow is not version 1.x, or no TPU
                device is specified
        """
        import tensorflow as tf
        # adopted from https://www.dlology.com/blog/how-to-train-keras-model-x20-times-faster-with-tpu-for-free/
        # This address identifies the TPU we'll use when configuring TensorFlow.
        # FIXME this will fail in tf 2.0, see https://github.com/tensorflow/tensorflow/issues/24412#issuecomment-491980177
        # raised explicitly (not via assert) so the checks survive python -O;
        # the exception type is kept for backward compatibility
        if not tf.__version__.startswith('1.'):
            raise AssertionError("TPU only supported on tf < 2.0")
        tpu_device = tpu_specs or os.environ.get('COLAB_TPU_ADDR', '')
        if not tpu_device:
            raise AssertionError("there is no TPU device")
        if tpu_device.startswith('grpc://'):
            tpu_worker = tpu_device
        else:
            tpu_worker = 'grpc://' + tpu_device
        tf.logging.set_verbosity(tf.logging.INFO)
        model = self.get_model(modelname)
        resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu_worker)
        strategy = tf.contrib.tpu.TPUDistributionStrategy(resolver)
        tpu_model = tf.contrib.tpu.keras_to_tpu_model(model, strategy=strategy)
        X = self.data_store.get(Xname)
        # NOTE(review): Yname defaults to None and is passed to the data
        # store as is -- confirm data_store.get handles None
        Y = self.data_store.get(Yname)
        tpu_model.fit(X, Y)
        # transfer the fitted weights back to the original model via a file
        # NOTE(review): the temporary file is not removed here -- confirm
        # whether temp_filename() takes care of cleanup
        fn = temp_filename()
        tpu_model.save_weights(fn, overwrite=True)
        model.load_weights(fn)
        meta = self.put(model, modelname)
        return meta