import tensorflow as tf
class MyModel(tf.keras.Model):
def __init__(self, output_node):
super(MyModel, self).__init__()
self.d1 = tf.keras.layers.Dense(128, activation='relu')
self.d2 = tf.keras.layers.Dense(output_node, activation='softmax')
def call(self, x):
x = self.d1(x)
return self.d2(x)
こんな簡単なモデルを用意します。
model = MyModel(output_node=10)
model.compile(loss='sparse_categorical_crossentropy', optimizer=tf.keras.optimizers.RMSprop())
model.build(input_shape=(32, 10))
model.summary()
モデルをビルドして構造確認するとこんな感じ
Model: "my_model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense (Dense) multiple 1408
_________________________________________________________________
dense_1 (Dense) multiple 1290
=================================================================
Total params: 2,698
Trainable params: 2,698
Non-trainable params: 0
_________________________________________________________________
Keras触っているならなんら普通な流れ.
問題はここから。
tf.keras.experimental.export_saved_model(model, 'path_to_saved_model', serving_only=True)
ドキュメント確認すると↑のコードで保存ができるらしい。が、しかし
model2 = tf.keras.experimental.load_from_saved_model('path_to_saved_model')
???
前のコードをよくみると W0506 10:36:19.523535 140296071452544 saved_model.py:124] Skipped saving model JSON, subclassed model does not have get_config() defined.
と、書いてあった。なるほど。。。
で、ここから下記対応したが駄目だった
サブクラスで、 get_config
を実装する
NotImplementedError
??? コード読んでみたけどわからんかった
- コードをちゃんと読んでみた
pickle
で保存できないか確認してみる
TypeError: can't pickle weakref objects
なるほど。。。
というわけでうまくいかんかったので、 model.summary()
の結果を保存することにした。。。
これいつか対応してくれるとうれしいな
保存できる方法について
サブクラスを使わずに、KerasのSequential
と、Functional APIを使えばいける。
Sequential
model = tf.keras.Sequential([
tf.keras.layers.Input(shape=(10,), batch_size=32),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(loss='sparse_categorical_crossentropy', optimizer=tf.keras.optimizers.RMSprop())
model.build(input_shape=(32, 10))
tf.keras.experimental.export_saved_model(model, 'path_to_saved_model', serving_only=True)
Functional API
inputs = tf.keras.Input(shape=(32, 10))
x = tf.keras.layers.Dense(128, activation='relu')(inputs)
predictions = tf.keras.layers.Dense(10, activation='softmax')(x)
model = tf.keras.Model(inputs=inputs, outputs=predictions)
model.compile(loss='sparse_categorical_crossentropy', optimizer=tf.keras.optimizers.RMSprop())
model.build(input_shape=(32, 10))
tf.keras.experimental.export_saved_model(model, 'path_to_saved_model', serving_only=True)