Python

如何打印 RandomForestClassifier 的決策樹

  • October 6, 2014

最近,我注意到這裡sklearn.tree.export_graphviz記錄了一種方法。

但是,我不知道如何將其應用於RandomForestClassifier.

我嘗試了以下幼稚的代碼,但它不起作用,而且我不知道如何從 a 中獲取其中一棵樹RandomForestClassifier

print('Training...')
forest = RandomForestClassifier(n_estimators=100)
forest = forest.fit( train_data[0::,1::], train_data[0::,0] )


print('Predicting...')
output = forest.predict(test_data).astype(int)

if sys.version_info >= (3,0,0):
   predictions_file = open("myfirstforest.csv", 'w', newline='')
else:
   predictions_file = open("myfirstforest.csv", 'wb')


tree.export_graphviz(forest, out_file='tree.dot')

您的 RandomForest 創建了 100 棵樹,因此您無法一步打印這些。嘗試遍歷森林中的樹木並一一打印出來:

from sklearn import tree
i_tree = 0
for tree_in_forest in forest.estimators_:
   with open('tree_' + str(i_tree) + '.dot', 'w') as my_file:
       my_file = tree.export_graphviz(tree_in_forest, out_file = my_file)
   i_tree = i_tree + 1

如果你想知道樹的實際參數,如分裂屬性(特徵)、分裂值(閾值)、節點樣本(n_node_samples)等,可以print getmembers(tree_in_forest.tree_)在for循環中使用。要使用這些參數之一,例如。閾值,使用這個:tree_in_forest.tree_.threshold它返回一個列表。

引用自:https://stats.stackexchange.com/questions/118016

comments powered by Disqus