ITの隊長のブログ

ITの隊長のブログです。Rubyを使って仕事しています。最近も色々やっているお(^ω^ = ^ω^)

Tensorflow2系のサブクラスモデルの構造は保存できない

スポンサードリンク

import tensorflow as tf


class MyModel(tf.keras.Model):
  def __init__(self, output_node):
    super(MyModel, self).__init__()
    self.d1 = tf.keras.layers.Dense(128, activation='relu')
    self.d2 = tf.keras.layers.Dense(output_node, activation='softmax')
  
  def call(self, x):
    x = self.d1(x)
    return self.d2(x)

こんな簡単なモデルを用意します。

model = MyModel(output_node=10)
model.compile(loss='sparse_categorical_crossentropy', optimizer=tf.keras.optimizers.RMSprop())
model.build(input_shape=(32, 10))
model.summary()

モデルをビルドして構造確認するとこんな感じ

Model: "my_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                multiple                  1408      
_________________________________________________________________
dense_1 (Dense)              multiple                  1290      
=================================================================
Total params: 2,698
Trainable params: 2,698
Non-trainable params: 0
_________________________________________________________________

Keras触っているならなんら普通な流れ.

問題はここから。

tf.keras.experimental.export_saved_model(model, 'path_to_saved_model', serving_only=True)
# W0506 10:36:19.523535 140296071452544 saved_model.py:124] Skipped saving model JSON, subclassed model does not have get_config() defined.

ドキュメント確認すると↑のコードで保存ができるらしい。が、しかし

model2 = tf.keras.experimental.load_from_saved_model('path_to_saved_model')
# NotFoundError: path_to_saved_model/assets/saved_model.json; No such file or directory

???

前のコードをよくみると W0506 10:36:19.523535 140296071452544 saved_model.py:124] Skipped saving model JSON, subclassed model does not have get_config() defined. と、書いてあった。なるほど。。。

で、ここから下記対応したが駄目だった

  • サブクラスで、 get_config を実装する

  • pickle で保存できないか確認してみる

    • TypeError: can't pickle weakref objects なるほど。。。

というわけでうまくいかんかったので、 model.summary() の結果を保存することにした。。。

これいつか対応してくれるとうれしいな

保存できる方法について

サブクラスを使わずに、KerasのSequentialと、Functional APIを使えばいける。

Sequential

model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(10,), batch_size=32),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(loss='sparse_categorical_crossentropy', optimizer=tf.keras.optimizers.RMSprop())
model.build(input_shape=(32, 10))
tf.keras.experimental.export_saved_model(model, 'path_to_saved_model', serving_only=True)

Functional API

inputs = tf.keras.Input(shape=(32, 10))
x = tf.keras.layers.Dense(128, activation='relu')(inputs)
predictions = tf.keras.layers.Dense(10, activation='softmax')(x)
model = tf.keras.Model(inputs=inputs, outputs=predictions)

model.compile(loss='sparse_categorical_crossentropy', optimizer=tf.keras.optimizers.RMSprop())
model.build(input_shape=(32, 10))
tf.keras.experimental.export_saved_model(model, 'path_to_saved_model', serving_only=True)