ITの隊長のブログ

ITの隊長のブログです。Rubyを使って仕事しています。最近も色々やっているお(^ω^ = ^ω^)

Tensorflow2でKerasみたいな保存の仕方すると死ぬ

スポンサードリンク

タイトルはてきとーにつけたので正しくはない

Colaboratoryで遊んでいるとき、他notebookで保存したモデルを読み込みたかった.

import tensorflow as tf


model = None  # 学習済みのモデルを想定
# Model is the full model w/o custom layers
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss=tf.keras.losses.CategoricalCrossentropy(),
              metrics=['accuracy', tf.keras.metrics.AUC(name='auc')])

model.save('/content/drive/My Drive/model.h5')

保存に成功したので、別notebookでロードする.

model = tf.keras.models.load_model('/content/drive/My Drive/model.h5')

ValueError: Unknown entry in loss dictionary: class_name. Only expected following keys: ['xxx']

??????

どゆことと、調べたら下記issueがヒット

github.com

i find the reason, in my code: loss_object=tf.losses.SparseCategoricalCrossentropy() model.complie(loss=loss_object, optimizer="sgd")

it raise the error.

the i change my code to model.complie(loss="sparse_categorical_crossentropy", optimizer="sgd")

it is ok

なんだと・・・・

いや、それはそれでいいんだけど、カスタムlossとかの場合はどうなるんじゃろ???

それはさておき、issue出した人はPRも出していた

github.com

これをコピペしてロードしたら動いた

しかしどうしようかな。。。