Source code for omegaml.backends.tensorflow.tfkeras

 1import os
 2
 3from omegaml.backends.keras import KerasBackend
 4from omegaml.util import temp_filename
 5
 6
class TensorflowKerasBackend(KerasBackend):
    """Model backend for tf.keras ``Sequential`` and ``Model`` objects.

    Models are stored as ``KIND = 'tfkeras.h5'``; saving and loading are
    delegated to ``tensorflow.keras.models.save_model`` / ``load_model``,
    everything else is inherited from ``KerasBackend``. ``fit`` can optionally
    train on a TPU (TensorFlow 1.x only) and falls back to the inherited
    ``fit`` if that fails.
    """
    KIND = 'tfkeras.h5'

    @classmethod
    def supports(cls, obj, name, **kwargs):
        """Return True if *obj* is a tf.keras ``Sequential`` or ``Model``.

        Returns False whenever ``as_savedmodel`` is truthy in *kwargs*
        (presumably so a SavedModel-based backend handles those -- confirm).
        """
        import tensorflow as tf
        tfSequential = tf.keras.models.Sequential
        tfModel = tf.keras.models.Model
        return isinstance(obj, (tfSequential, tfModel)) and not kwargs.get('as_savedmodel')

    def _save_model(self, model, fn):
        """Save *model* to the file path *fn* with tf.keras (overrides base)."""
        import tensorflow as tf
        from tensorflow import keras
        if tf.executing_eagerly():
            # in eager mode the optimizer weights are re-wrapped before saving
            self._fix_model_for_saving(model)
        keras.models.save_model(model, fn)

    def _fix_model_for_saving(self, model):
        """Re-wrap each optimizer weight as a ``tf.Variable`` named ``variable<i>``.

        The replacement happens in place on ``model.optimizer.weights`` inside
        a name scope named after the optimizer class. Optimizers raising
        ``NotImplementedError`` during this are left unchanged.

        NOTE(review): the original comment referenced an external link that
        is missing; presumably this works around optimizer weight naming
        issues when saving in eager mode -- confirm.

        Models without an optimizer (e.g. not compiled) are left untouched;
        previously these raised ``AttributeError`` on ``None.weights``.
        """
        import tensorflow as tf
        from tensorflow.python.keras import backend as K
        optimizer = getattr(model, 'optimizer', None)
        if optimizer is None:
            # nothing to fix for a model that has no optimizer
            return
        with K.name_scope(optimizer.__class__.__name__):
            try:
                for i, var in enumerate(optimizer.weights):
                    name = 'variable{}'.format(i)
                    optimizer.weights[i] = tf.Variable(var, name=name)
            except NotImplementedError:
                # best effort: some optimizers do not support this
                pass

    def _extract_model(self, infile, key, tmpfn):
        """Write the stored model bytes to *tmpfn* and load it with tf.keras.

        :param infile: readable binary file-like object holding the saved model
        :param key: not used here (kept for the overridden signature)
        :param tmpfn: path of the temporary file written before loading
        :return: the object returned by ``tensorflow.keras.models.load_model``
        """
        import shutil
        from tensorflow import keras
        with open(tmpfn, 'wb') as pkgfn:
            # stream in chunks instead of reading the whole model into memory
            shutil.copyfileobj(infile, pkgfn)
        return keras.models.load_model(tmpfn)

    def fit(self, modelname, Xname, Yname=None, pure_python=True, tpu_specs=None, **kwargs):
        """Fit the stored model *modelname* on the datasets *Xname* / *Yname*.

        If TPU specs are given (as argument, or via the ``tpu_specs``
        attribute of the model's metadata) a TPU fit is attempted first; on
        any error a warning with the traceback is logged and the inherited
        ``KerasBackend.fit`` is used instead.

        :param modelname: name of the model in the model store
        :param Xname: name of the X dataset in the data store
        :param Yname: optional name of the Y dataset
        :param pure_python: passed on to the inherited ``fit``
        :param tpu_specs: TPU address, with or without ``grpc://`` prefix
        :return: the result of the TPU fit, or of the inherited ``fit``
        """
        import logging
        if not tpu_specs:
            meta = self.model_store.metadata(modelname)
            # guard against missing metadata instead of failing on None.attributes
            if meta is not None:
                tpu_specs = meta.attributes.get('tpu_specs')
        if tpu_specs:
            try:
                result = self._fit_tpu(modelname, Xname, Yname=Yname, tpu_specs=tpu_specs, **kwargs)
            except Exception:
                # do not swallow KeyboardInterrupt/SystemExit; keep the traceback
                logger = logging.getLogger(__name__)
                logger.warning('Error in _fit_tpu, reverting to fit on CPU', exc_info=True)
            else:
                return result
        result = super(TensorflowKerasBackend, self).fit(modelname, Xname, Yname=Yname, pure_python=pure_python,
                                                         **kwargs)
        return result

    def _fit_tpu(self, modelname, Xname, Yname=None, tpu_specs=None, **kwargs):
        """Train *modelname* on a TPU (TensorFlow 1.x only) and store the result.

        The model is converted with ``tf.contrib.tpu.keras_to_tpu_model`` and
        fitted on the X/Y datasets. The trained weights are then copied back
        into the original model through a temporary weights file, which is
        removed afterwards. Finally the model is saved with ``self.put``.
        Extra *kwargs* are accepted but not forwarded to ``tpu_model.fit``.

        :param tpu_specs: TPU address; falls back to the ``COLAB_TPU_ADDR``
            environment variable. ``grpc://`` is prepended if missing.
        :return: the metadata returned by ``self.put``
        :raises RuntimeError: if TensorFlow is not 1.x or no TPU address is set
        """
        import tensorflow as tf
        # adopted from https://www.dlology.com/blog/how-to-train-keras-model-x20-times-faster-with-tpu-for-free/
        # FIXME this will fail in tf 2.0, see https://github.com/tensorflow/tensorflow/issues/24412#issuecomment-491980177
        # explicit raises instead of assert, which is stripped under python -O
        if not tf.__version__.startswith('1.'):
            raise RuntimeError("TPU only supported on tf < 2.0")
        # This address identifies the TPU we'll use when configuring TensorFlow.
        tpu_device = tpu_specs or os.environ.get('COLAB_TPU_ADDR', '')
        if not tpu_device:
            raise RuntimeError("there is no TPU device")
        if tpu_device.startswith('grpc://'):
            tpu_worker = tpu_device
        else:
            tpu_worker = 'grpc://' + tpu_device
        tf.logging.set_verbosity(tf.logging.INFO)
        model = self.get_model(modelname)
        tpu_model = tf.contrib.tpu.keras_to_tpu_model(
            model,
            strategy=tf.contrib.tpu.TPUDistributionStrategy(
                tf.contrib.cluster_resolver.TPUClusterResolver(tpu_worker)))
        X = self.data_store.get(Xname)
        Y = self.data_store.get(Yname)
        tpu_model.fit(X, Y)
        fn = temp_filename()
        try:
            tpu_model.save_weights(fn, overwrite=True)
            model.load_weights(fn)
        finally:
            # remove the intermediate weights file; a failed cleanup must not
            # turn a successful TPU fit into a CPU fallback
            try:
                if os.path.exists(fn):
                    os.remove(fn)
            except OSError:
                pass
        meta = self.put(model, modelname)
        return meta