ITの隊長のブログ

ITの隊長のブログです。Pythonを使って仕事しています。最近は機械学習をさわりはじめたお(^ω^ = ^ω^)

【Python】kerasで保存したweightsをh5pyを使って取得する

スポンサードリンク

難しかった。(というかこのファイル構造よくわからん)

$ ls
model_weights.h5 # kerasで保存したファイル
$ python
# ...
>>> import h5py

>>> model_weights = h5py.File('./model_weights.h5', 'r')
>>> model_weights.keys()
KeysView(<Attributes of HDF5 object at 4383104920>) # (´・ω・`)?

>>> model_weights.attrs.keys()
KeysView(<Attributes of HDF5 object at 4383104920>) # (´;ω;`)?

>>> list(model_weights.attrs.keys())
['layer_names', 'backend', 'keras_version'] # (`・ω・’)!

>>> list(model_weights.attrs.get('layer_names'))
[b'dense_1', b'activation_1', b'dropout_1', b'dense_2', b'activation_2', b'dropout_2', b'dense_3', b'activation_3'] # (`・ω・’)

>>> list(model_weights.attrs.get('layer_names'))[0]
b'dense_1' # (´・ω・`)?

>>> # .....悩み中

>>> # !? そういえばattrs無しで試していない!

>>> list(model_weights)
['activation_1', 'activation_2', 'activation_3', 'dense_1', 'dense_2', 'dense_3', 'dropout_1', 'dropout_2'] # うぉおおお!!!

>>> list(model_weights)['dense']
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: list indices must be integers or slices, not str # (´・ω・`)

>>> list(model_weights.get('dense_1'))
['dense_1'] # !? なにこれ・・・?

>>> model_weights['dense_1'].keys()
KeysView(<HDF5 group "/dense_1" (1 members)>) # わけわかめ

>>> model_weights['dense_1'].get('dense_1')
<HDF5 group "/dense_1/dense_1" (2 members)> # !? なるほど...

>>> list(array.get('dense_1').keys())
['bias:0', 'kernel:0']

>>> array.get('dense_1').get('bias:0')
<HDF5 dataset "bias:0": shape (512,), type "<f4">

>>> array.get('dense_1').get('bias:0')[()]
array([ -5.14987158e-04,  -1.05651123e-02,  -6.37231674e-03,
# ...
         1.46674449e-02,  -1.39807556e-02], dtype=float32) # ヾ(*´∀`*)ノキャッキャ

>>> np_array = array.get('dense_1').get('bias:0')[()]
>>> np_array.shape
(512,)

ちなみにpipでインストールすることができます。

$ pip install h5py