決定木で、テストデータがどの葉に分類されるかを自動でやりたいとき、なんかないかなーと探してたら公式ドキュメントにあった。
# decision_pathで全体の結果が返ってくる? node_indicator = estimator.decision_path(X_test) # テストデータの行番号 sample_id = 0 # 0行目のデータを決定木に入力したときに、通る順番のノード番号が入っている. node_index = node_indicator.indices[node_indicator.indptr[sample_id]:node_indicator.indptr[sample_id + 1]] # ということは、このリストの最後に入っているノード番号が、葉のノードを示していることになる. last_idx = node_index[-1] # ノード番号をkeyに、そこに到達する行番号を保持するようにする last_node_d = {} last_node_d[last_idx] = [] last_node_d[last_idx].append(sample_id) last_node_d = {} for sample_id in range(X_test.shape[0]): node_index = node_indicator.indices[node_indicator.indptr[sample_id]:node_indicator.indptr[sample_id + 1]] last_idx = node_index[-1] if not last_idx in last_node_d.keys(): last_node_d[last_idx] = [] last_node_d[last_idx].append(sample_id)
これで、あるノード番号に含まれるテストデータを抽出することができる.