Source code for omegaml.backends.tensorflow.tfkeras

 1import os
 2
 3from omegaml.backends.keras import KerasBackend
 4from omegaml.util import temp_filename
 5
 6
class TensorflowKerasBackend(KerasBackend):
    """
    Backend for ``tf.keras`` Sequential/Model instances (kind ``tfkeras.h5``).

    .. versionchanged:: 0.18.0
        Only supported for tensorflow <= 2.15 and Python <= 3.11

    .. deprecated:: 0.18.0
        Use an object helper or a serializer/loader combination instead.
    """
    KIND = 'tfkeras.h5'

    @classmethod
    def supports(cls, obj, name, **kwargs):
        """
        Check whether this backend handles the given object.

        Args:
            obj: the object to be stored
            name: the name under which the object is stored (not used here)
            **kwargs: if ``as_savedmodel`` is truthy this backend declines

        Returns:
            bool: True if obj is a ``tf.keras`` Sequential or Model instance
            and ``as_savedmodel`` was not requested
        """
        # imported lazily so that this module can be loaded without tensorflow
        import tensorflow as tf
        keras_types = (tf.keras.models.Sequential, tf.keras.models.Model)
        return isinstance(obj, keras_types) and not kwargs.get('as_savedmodel')

    def _save_model(self, model, fn):
        """
        Save the model to file ``fn`` using ``keras.models.save_model``.

        Args:
            model: the tf.keras model to save
            fn: the target filename
        """
        # override to implement model saving
        import tensorflow as tf
        from tensorflow import keras
        if tf.executing_eagerly():
            # in eager mode the optimizer weights are re-created first,
            # see _fix_model_for_saving
            self._fix_model_for_saving(model)
        keras.models.save_model(model, fn)

    def _fix_model_for_saving(self, model):
        """
        Replace the optimizer's weights by explicitly named ``tf.Variable`` objects.

        Each weight is re-created as ``variable<i>`` inside a name scope named
        after the optimizer class. Models without an optimizer are left
        untouched.

        Args:
            model: the tf.keras model about to be saved
        """
        # NOTE(review): the original comment here read "# see" -- the reference
        # it pointed to is not present in this source
        import tensorflow as tf
        from tensorflow.python.keras import backend as K
        optimizer = getattr(model, 'optimizer', None)
        if optimizer is None:
            # nothing to fix, e.g. the model has no optimizer attached
            return
        with K.name_scope(optimizer.__class__.__name__):
            try:
                for i, var in enumerate(optimizer.weights):
                    name = 'variable{}'.format(i)
                    optimizer.weights[i] = tf.Variable(var, name=name)
            except NotImplementedError:
                # deliberate best-effort: if this is not implemented for the
                # optimizer in use, the model is saved as is
                pass

    def _extract_model(self, infile, key, tmpfn, **kwargs):
        """
        Load a model from a file-like object.

        The contents of ``infile`` are written to ``tmpfn``, which is then
        loaded using ``keras.models.load_model``.

        Args:
            infile: readable binary file-like object holding the saved model
            key: not used here
            tmpfn: path of the file to write the model contents to
            **kwargs: not used here

        Returns:
            the model as returned by ``keras.models.load_model``
        """
        # override to implement model loading
        from tensorflow import keras
        with open(tmpfn, 'wb') as pkgfn:
            pkgfn.write(infile.read())
        return keras.models.load_model(tmpfn)

    def fit(self, modelname, Xname, Yname=None, pure_python=True, tpu_specs=None, **kwargs):
        """
        Fit the model, on a TPU if requested, otherwise via the parent backend.

        TPU fitting is attempted if ``tpu_specs`` is passed or set in the
        model metadata's attributes. If the TPU fit raises an exception, a
        warning is logged and the parent class' ``fit`` is used instead.

        Args:
            modelname: name of the model in the model store
            Xname: name of the X data
            Yname: name of the Y data, defaults to None
            pure_python: passed on to the parent class' ``fit``
            tpu_specs: the TPU specs, defaults to the ``tpu_specs`` attribute
                of the model's metadata
            **kwargs: passed on to ``_fit_tpu`` or the parent class' ``fit``

        Returns:
            the result of ``_fit_tpu`` or of the parent class' ``fit``
        """
        meta = self.model_store.metadata(modelname)
        tpu_specs = tpu_specs or meta.attributes.get('tpu_specs')
        if tpu_specs:
            try:
                return self._fit_tpu(modelname, Xname, Yname=Yname, tpu_specs=tpu_specs, **kwargs)
            except Exception:
                # only catch Exception so that KeyboardInterrupt and SystemExit
                # still propagate; keep the traceback for diagnosis
                import logging
                logger = logging.getLogger(__name__)
                logger.warning('Error in _fit_tpu, reverting to fit on CPU', exc_info=True)
        return super().fit(modelname, Xname, Yname=Yname, pure_python=pure_python,
                           **kwargs)

    def _fit_tpu(self, modelname, Xname, Yname=None, tpu_specs=None, **kwargs):
        """
        Fit the model on a TPU and store the fitted model.

        The model is converted to a TPU model, fitted, and the resulting
        weights are transferred back to the original model through a
        temporary weights file. The model is then stored under ``modelname``.

        Args:
            modelname: name of the model
            Xname: name of the X data in the data store
            Yname: name of the Y data in the data store
            tpu_specs: the TPU address, defaults to the ``COLAB_TPU_ADDR``
                environment variable; ``grpc://`` is prepended if missing
            **kwargs: not used here

        Returns:
            the result of ``self.put`` for the fitted model

        Raises:
            AssertionError: if tensorflow is not a 1.x version or no TPU
                device is specified
        """
        import tensorflow as tf
        # adopted from https://www.dlology.com/blog/how-to-train-keras-model-x20-times-faster-with-tpu-for-free/
        # This address identifies the TPU we'll use when configuring TensorFlow.
        # FIXME this will fail in tf 2.0, see https://github.com/tensorflow/tensorflow/issues/24412#issuecomment-491980177
        # raise explicitly instead of using assert, which is removed by python -O
        if not tf.__version__.startswith('1.'):
            raise AssertionError("TPU only supported on tf < 2.0")
        tpu_device = tpu_specs or os.environ.get('COLAB_TPU_ADDR', '')
        if not tpu_device:
            raise AssertionError("there is no TPU device")
        if tpu_device.startswith('grpc://'):
            tpu_worker = tpu_device
        else:
            tpu_worker = 'grpc://' + tpu_device
        tf.logging.set_verbosity(tf.logging.INFO)
        model = self.get_model(modelname)
        tpu_model = tf.contrib.tpu.keras_to_tpu_model(
            model,
            strategy=tf.contrib.tpu.TPUDistributionStrategy(
                tf.contrib.cluster_resolver.TPUClusterResolver(tpu_worker)))
        X = self.data_store.get(Xname)
        # NOTE(review): Yname may be None here -- confirm data_store.get(None) is valid
        Y = self.data_store.get(Yname)
        tpu_model.fit(X, Y)
        # transfer the fitted weights back to the original model
        # NOTE(review): the temporary file is not removed here -- confirm
        # whether temp_filename() arranges for cleanup
        fn = temp_filename()
        tpu_model.save_weights(fn, overwrite=True)
        model.load_weights(fn)
        meta = self.put(model, modelname)
        return meta