Decision Tree
Import the module
from sklearn import tree
Initialize the model
<model> = tree.DecisionTreeClassifier()
Parameter list (as of the scikit-learn version used here; newer releases add parameters and have removed presort)
DecisionTreeClassifier(criterion='gini', splitter='best', max_depth=None, min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features=None, random_state=None, max_leaf_nodes=None, class_weight=None, presort=False)
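As a small illustration of overriding a few of the defaults above, here is a sketch on scikit-learn's bundled iris data. It assumes a scikit-learn release recent enough (0.21 or later) to have get_depth() and get_n_leaves(); the chosen values (entropy, depth 2, seed 0) are arbitrary.

```python
from sklearn import tree
from sklearn.datasets import load_iris

iris = load_iris()

# Override a few defaults from the list above: use entropy instead of gini,
# cap the depth at 2, and fix random_state so repeated runs build the same tree.
clf = tree.DecisionTreeClassifier(criterion='entropy', max_depth=2, random_state=0)
clf.fit(iris.data, iris.target)

print(clf.get_depth())                # never more than max_depth
print(clf.get_n_leaves())             # a depth-2 binary tree has at most 4 leaves
print(clf.get_params()['criterion'])  # the parameter values the model was built with
```

Limiting max_depth is the simplest way to keep the tree small enough to read once it is exported as a graph.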
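Parameter notes
criterion: (default "gini") "gini" for the Gini impurity and "entropy" for the information gain.
gini: faster to compute
entropy: sometimes gives slightly better splits than gini, but is slower to compute because it needs a logarithm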
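To make the two criteria concrete, here is a minimal pure-Python sketch of both impurity measures applied to the class proportions in one node. It only illustrates the formulas (Gini = 1 - sum p_i^2, entropy = -sum p_i log2 p_i); it is not scikit-learn's internal code.

```python
import math

def gini(proportions):
    """Gini impurity: 1 - sum(p_i^2)."""
    return 1.0 - sum(p * p for p in proportions)

def entropy(proportions):
    """Shannon entropy in bits: -sum(p_i * log2(p_i)), skipping p_i == 0."""
    return sum(-p * math.log2(p) for p in proportions if p > 0)

# A 50/50 node is maximally impure, a 90/10 node less so, a pure node not at all.
for dist in ([0.5, 0.5], [0.9, 0.1], [1.0, 0.0]):
    print(dist, round(gini(dist), 4), round(entropy(dist), 4))
```

Both measures are 0 for a pure node and largest for an even mix; entropy's extra cost is the log2 call per class.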
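Train the model
<model>.fit(input, output)
Check how well the model fits (returns the mean accuracy on the given data)
<model>.score(input, output)
Predict output from input
<model>.predict(input)
Show feature importances; a higher value means the feature (column) contributed more to the tree's splits
print(<model>.feature_importances_)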
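The four calls above can be tried end to end on a tiny dataset. This sketch reuses the four rows of the toy dataset from later in this post (columns packet, traffic); random_state=0 is an added assumption so the result is repeatable.

```python
from sklearn import tree

# The four rows of the toy dataset used later in this post (columns: packet, traffic).
X = [[1, 4], [2, 3], [3, 2], [4, 1]]
y = [0, 0, 1, 1]

clf = tree.DecisionTreeClassifier(random_state=0)
clf.fit(X, y)                          # learn the tree

acc = clf.score(X, y)                  # mean accuracy on (X, y)
pred = clf.predict([[1, 4], [4, 1]])   # predicted class for two inputs
imp = clf.feature_importances_         # one value per column, sums to 1

print(acc, pred, imp)
```

Note that score() on the training data is optimistic (an unrestricted tree can memorize it); for a fair estimate, score on data the model has not seen.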
Reference:
scikit-learn uses an optimised version of the CART algorithm.
http://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html#.
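CART grows a binary tree by choosing, at each node, the test "feature <= threshold" that most reduces impurity. Below is a simplified sketch of that search for a single feature using Gini impurity and midpoint thresholds; the helper names are hypothetical, and scikit-learn's real (Cython) splitter searches all features and handles many more details.

```python
def gini_of(labels):
    """Gini impurity of a list of class labels."""
    n = len(labels)
    if n == 0:
        return 0.0
    counts = {}
    for c in labels:
        counts[c] = counts.get(c, 0) + 1
    return 1.0 - sum((k / n) ** 2 for k in counts.values())

def best_split(values, labels):
    """Return (threshold, weighted_gini) of the best split 'x <= t' on one feature.

    Returns (None, inf) if no split is possible (all values identical).
    """
    pairs = sorted(zip(values, labels))
    n = len(pairs)
    best = (None, float('inf'))
    for i in range(1, n):
        lo, hi = pairs[i - 1][0], pairs[i][0]
        if lo == hi:
            continue  # cannot split between identical values
        t = (lo + hi) / 2  # midpoint candidate threshold
        left = [c for _, c in pairs[:i]]    # values <= t
        right = [c for _, c in pairs[i:]]   # values > t
        # impurity of the children, weighted by how many samples each receives
        score = (len(left) * gini_of(left) + len(right) * gini_of(right)) / n
        if score < best[1]:
            best = (t, score)
    return best

print(best_split([1, 2, 3, 4], [0, 0, 1, 1]))
```

On the packet column of the toy dataset, the threshold 2.5 leaves both children pure, so its weighted Gini is 0.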
………………………………………………………………..
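Visualizing the Decision Tree
Show the decision tree's decision process as a graph.
This needs two tools: the export_graphviz function (in sklearn.tree) and the dot program (from Graphviz).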
Install
On Ubuntu:
#sudo apt-get install graphviz python-pydot
export_graphviz
Usage:
tree.export_graphviz(<model>, out_file=<outpath>, feature_names=<list of feature names>)
ex:
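from sklearn.datasets import load_iris
from sklearn import tree
iris = load_iris()
clf = tree.DecisionTreeClassifier()
model = clf.fit(iris.data, iris.target)  # fit() returns the fitted classifier itself
# write the tree to tree.dot; export_graphviz returns None when out_file is a path
tree.export_graphviz(model, out_file='tree.dot', feature_names=iris.feature_names)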
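If you only need the dot source in memory, export_graphviz returns it as a string when out_file=None. This sketch also labels classes and colors nodes (class_names, filled), and shows export_text, which prints the same splits as plain text without Graphviz; it assumes scikit-learn 0.21 or later for export_text, and max_depth=2 is only there to keep the output short.

```python
from sklearn import tree
from sklearn.datasets import load_iris

iris = load_iris()
clf = tree.DecisionTreeClassifier(max_depth=2, random_state=0).fit(iris.data, iris.target)

# With out_file=None, export_graphviz returns the dot source as a string
# instead of writing a file.
dot_source = tree.export_graphviz(
    clf,
    out_file=None,
    feature_names=iris.feature_names,
    class_names=list(iris.target_names),  # label leaves with class names
    filled=True,                           # color nodes by majority class
)
print(dot_source.splitlines()[0])

# export_text gives a plain-text view of the same splits, no Graphviz needed.
text = tree.export_text(clf, feature_names=iris.feature_names)
print(text)
```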
Reference:
http://scikit-learn.org/stable/modules/generated/sklearn.tree.export_graphviz.html#sklearn.tree.export_graphviz
dot (the Graphviz command-line tool; pydot is a Python interface to it)
View the .dot file, or render it as an image
ex:
dot -Tpng tree.dot -o tree.png
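To run the same rendering step from a Python script, one option is to call the dot binary with subprocess. This is a sketch with a hypothetical helper name, render_dot; it assumes Graphviz's dot is on PATH and simply skips rendering when it is not.

```python
import shutil
import subprocess

def render_dot(dot_path, png_path):
    """Run the equivalent of `dot -Tpng <dot_path> -o <png_path>`.

    Hypothetical helper: returns the command list, and only executes it
    if the Graphviz `dot` program is found on PATH.
    """
    cmd = ['dot', '-Tpng', dot_path, '-o', png_path]
    if shutil.which('dot') is None:
        print('Graphviz "dot" not found on PATH; skipping rendering')
    else:
        subprocess.run(cmd, check=True)  # raises if dot reports an error
    return cmd
```

For example, render_dot('tree.dot', 'tree.png') after export_graphviz has written tree.dot.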
A brief explanation of the tree.dot format:
digraph Tree {
0 [label="X[13] <= 416.0000\ngini = 0.000270672616505\nsamples = 100", shape="box"] ;
1 [label="gini = 0.0000\nsamples = 7387\nvalue = [ 99. 0.]", shape="box"] ;
0 -> 1 ;
2 [label="gini = 0.0000\nsamples = 1\nvalue = [ 0. 1.]", shape="box"] ;
0 -> 2 ;
}
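value = [number of class 0 samples, number of class 1 samples]
Node 0 tests X[13] (feature index 13, counting from 0) <= 416. Samples that satisfy the test go to node 1 (value = [99, 0], all class 0); the rest go to node 2 (value = [0, 1], one class 1 sample). So this single split separates the 99 class-0 samples from the 1 class-1 sample.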
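The same structure the dot file shows can be read directly from a fitted model's tree_ attribute (node_count, feature, threshold, children_left, children_right, n_node_samples). This sketch fits the toy dataset from later in this post and follows the "go left when x[feature] <= threshold" rule by hand; leaf_for is a hypothetical helper, checked against scikit-learn's own apply().

```python
from sklearn import tree

X = [[1, 4], [2, 3], [3, 2], [4, 1]]
y = [0, 0, 1, 1]
clf = tree.DecisionTreeClassifier(random_state=0).fit(X, y)
t = clf.tree_

def leaf_for(sample):
    """Follow the rule the dot file encodes: go to the left child when
    sample[feature] <= threshold, otherwise to the right child."""
    node = 0
    while t.children_left[node] != -1:   # -1 marks a leaf
        if sample[t.feature[node]] <= t.threshold[node]:
            node = t.children_left[node]
        else:
            node = t.children_right[node]
    return node

# Print each node roughly the way export_graphviz labels it.
for node in range(t.node_count):
    if t.children_left[node] == -1:
        print(node, 'leaf, samples =', t.n_node_samples[node])
    else:
        print(node, f'X[{t.feature[node]}] <= {t.threshold[node]}',
              '-> left', t.children_left[node], 'right', t.children_right[node])
```

Which column the root uses may vary with the seed here, since both columns split the toy data equally well; the threshold is 2.5 either way.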
…………………………………………………………………………..
Example on a simple dataset
#vi dataset
class,packet,traffic
0,1,4
0,2,3
1,3,2
1,4,1
#vi train.py
import sys
import numpy as np
from sklearn import tree

filepath = sys.argv[1]
with open(filepath) as f:
    # first line is the header, e.g. class,packet,traffic
    listhead = f.readline().strip().split(',')[1:]  # feature names (drop the class column)
    dataset = np.loadtxt(f, delimiter=',')          # remaining lines: numeric rows
target = dataset[:, 0]   # first column: class label
data = dataset[:, 1:]    # other columns: features

clf = tree.DecisionTreeClassifier()
clf.fit(data, target)
tree.export_graphviz(clf, out_file=filepath + '_tree.dot', feature_names=listhead)

### list feature importances (they can differ slightly between runs: ties between
### equally good splits are broken randomly unless random_state is set)
importfeature = clf.feature_importances_
dictname = dict(zip(listhead, importfeature))
dicsorted = sorted(dictname.items(), key=lambda d: d[1], reverse=True)
for line in dicsorted:
    print(line)
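On this toy data the importances really do change between runs: packet and traffic both separate the two classes perfectly, and scikit-learn permutes the features randomly at each split, so either one may be picked. A sketch of making the ranking repeatable by fixing random_state (seed 0 is an arbitrary choice), with the sort done via numpy.argsort instead of a dict:

```python
import numpy as np
from sklearn import tree

# Same toy data as the dataset file above: columns packet, traffic.
names = ['packet', 'traffic']
X = np.array([[1, 4], [2, 3], [3, 2], [4, 1]], dtype=float)
y = np.array([0, 0, 1, 1])

def importances(seed):
    clf = tree.DecisionTreeClassifier(random_state=seed).fit(X, y)
    return clf.feature_importances_

# Both columns split the classes perfectly, so which one the tree uses
# depends on the random feature order; a fixed seed makes it repeatable.
first = importances(0)
second = importances(0)

order = np.argsort(first)[::-1]   # column indices, most important first
for i in order:
    print(names[i], first[i])
```

With a fixed seed the two fits agree exactly; whichever column is chosen gets importance 1.0 and the other 0.0.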