ITの隊長のブログ

ITの隊長のブログです。Rubyを使って仕事しています。最近も色々やっているお(^ω^ = ^ω^)

雑ログ

スポンサードリンク

決定木で、テストデータがどの葉に分類されるかを自動でやりたいとき、なんかないかなーと探してたら公式ドキュメントにあった。

scikit-learn.org

# decision_pathで全体の結果が返ってくる?
node_indicator = estimator.decision_path(X_test)


# テストデータの行番号
sample_id = 0

# 0行目のデータを決定木に入力したときに、通る順番のノード番号が入っている.
node_index = node_indicator.indices[node_indicator.indptr[sample_id]:node_indicator.indptr[sample_id + 1]]

# ということは、このリストの最後に入っているノード番号が、葉のノードを示していることになる.
last_idx = node_index[-1]

# ノード番号をkeyに、そこに到達する行番号を保持するようにする
last_node_d = {}
last_node_d[last_idx] = []
last_node_d[last_idx].append(sample_id)

last_node_d = {}
for sample_id in range(X_test.shape[0]):
    node_index = node_indicator.indices[node_indicator.indptr[sample_id]:node_indicator.indptr[sample_id + 1]]
    last_idx = node_index[-1]
    if not last_idx in last_node_d.keys():
        last_node_d[last_idx] = []

    last_node_d[last_idx].append(sample_id)

これで、あるノード番号に含まれるテストデータを抽出することができる.