TensorFlow¶
TensorFlow provides several types of models:
Native TensorFlow models
TensorFlow Keras models
Estimator models
SavedModel
omega|ml supports all model variants as trained SavedModels. Keras models and Estimator models can also be serialized to the cluster as Python instances and trained there. In addition, the runtime can execute arbitrary functions that build a model, train it, and save it as a SavedModel for subsequent consumption, e.g. via the model REST API.
Concepts¶
Keras models¶
Consider the following TensorFlow model (source Modelnet). This is a standard TF Keras model that uses MobileNetV2 as a frozen base for image classification and trains a new output layer.
(...)
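# pre-trained MobileNetV2 without its classification head, kept frozen
mobile_net = tf.keras.applications.MobileNetV2(input_shape=(192, 192, 3), include_top=False)
mobile_net.trainable = False
# new output layer: one unit per label; softmax yields the class
# probabilities that the loss below expects by default
model = tf.keras.Sequential([
    mobile_net,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(len(label_names), activation='softmax')])
model.compile(optimizer='adam',
              loss=tf.keras.losses.sparse_categorical_crossentropy,
              metrics=["accuracy"])
model.summary()
# ds is the training tf.data.Dataset prepared in the elided code above
model.fit(ds, epochs=1, steps_per_epoch=3)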
Store the model in omega|ml as follows:
om.models.put(model, 'tfkeras-flower')
Load and use the model for prediction as follows. This runs the prediction on the local computer and does not use omega|ml’s runtime cluster.
model_ = om.models.get('tfkeras-flower')
img = plt.imread('/path/to/image')
result = model_.predict(np.array([img]))
Using the runtime cluster is equally straightforward:
img = plt.imread('/path/to/image')
result = om.runtime.model('tfkeras-flower').predict(np.array([img]))
The REST API similarly provides prediction:
resp = requests.put(predict_url, json={
    'columns': ['x'],
    'data': [{'x': img.flatten().tolist()}],
    'shape': [192, 192, 3],
})
data = resp.json()
prediction = data['result']
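The request body above carries the image as a flat list of pixel values together with its original shape. The following is a minimal sketch of how such a body could be assembled and sanity-checked before sending. `build_predict_payload` and `unflatten` are hypothetical helpers written for illustration and are not part of omega|ml; that the server reshapes the data according to `shape` is an assumption based on the payload layout.

```python
import json
from math import prod


def build_predict_payload(flat_pixels, shape, column="x"):
    """Build a JSON-serializable body in the layout used by the request above."""
    expected = prod(shape)
    if len(flat_pixels) != expected:
        raise ValueError(
            f"expected {expected} values for shape {shape}, got {len(flat_pixels)}")
    return {
        "columns": [column],
        "data": [{column: list(flat_pixels)}],
        "shape": list(shape),
    }


def unflatten(flat, shape):
    """Rebuild nested lists from a flat list (what np.reshape does for arrays)."""
    if len(shape) == 1:
        return list(flat)
    step = len(flat) // shape[0]
    return [unflatten(flat[i * step:(i + 1) * step], shape[1:])
            for i in range(shape[0])]


# A tiny 2x2 RGB "image" stands in for the 192x192x3 input used above.
pixels = [float(i) for i in range(2 * 2 * 3)]
payload = build_predict_payload(pixels, (2, 2, 3))
body = json.dumps(payload)  # what requests sends when given json=payload
restored = unflatten(json.loads(body)["data"][0]["x"], payload["shape"])
```

Checking the value count against the shape on the client side catches a mismatched image size before the request is made, rather than as a reshape error on the server.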
tf.data.Dataset¶
Estimator models support tf.data.Dataset by means of virtual datasets. Virtual datasets are Python functions stored by om.datasets. On accessing a virtual dataset, the function is executed and its result is returned. Thus, for Estimator models, a virtual dataset should be used to return a tf.data.Dataset.
om.datasets supports storing tf.train.Example records, from which a tf.data.Dataset can easily be constructed.
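To make the virtual dataset idea concrete, here is a toy sketch in plain Python. `ToyDatasetStore` is a hypothetical stand-in and not the om.datasets API; it only illustrates that a stored function is executed on access and that its return value is what the caller receives. In an actual setup the function would build and return a tf.data.Dataset instead of a generator of lists.

```python
class ToyDatasetStore:
    """Hypothetical stand-in for a dataset store; not the om.datasets API."""

    def __init__(self):
        self._objects = {}

    def put(self, obj, name):
        # Store plain data or a function under a name.
        self._objects[name] = obj

    def get(self, name, **kwargs):
        obj = self._objects[name]
        # A virtual dataset is a function: run it on access, return its result.
        return obj(**kwargs) if callable(obj) else obj


def training_batches(batch_size=2):
    """Yield (features, labels) batches; a real virtual dataset
    would return a tf.data.Dataset here."""
    features = [[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]]
    labels = [1, 1, 0, 0]
    for i in range(0, len(features), batch_size):
        yield features[i:i + batch_size], labels[i:i + batch_size]


store = ToyDatasetStore()
store.put(training_batches, "train-virtual")
# Accessing the name executes the function; the data is produced on demand.
batches = list(store.get("train-virtual", batch_size=2))
```

Because only the function is stored, the data is produced fresh on every access, which suits input pipelines that must be rebuilt on the machine where training runs.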