sklearn decision tree

Decision Tree


Load the module
from sklearn import tree

Initialize the model
<model> = tree.DecisionTreeClassifier()
Parameter list
DecisionTreeClassifier(criterion='gini', splitter='best', max_depth=None, min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features=None, random_state=None, max_leaf_nodes=None, class_weight=None, presort=False)
(This signature is from an older scikit-learn release; presort has been removed in newer versions.)
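Parameter notes
criterion: default is "gini". Use "gini" for the Gini impurity and "entropy" for the information gain.
 gini: fast to compute
 entropy: may give somewhat better splits than gini on some data, but is slower to compute (it evaluates a logarithm)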
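
To make the two criteria concrete, here is a minimal sketch of the textbook impurity formulas, computed from the class counts at a node. The function names gini and entropy are chosen for illustration only; this is not scikit-learn's internal implementation.

```python
import math

def gini(counts):
    """Gini impurity: 1 minus the sum of squared class proportions."""
    total = sum(counts)
    return 1.0 - sum((c / total) ** 2 for c in counts)

def entropy(counts):
    """Entropy in bits: sum of -p * log2(p) over the non-empty classes."""
    total = sum(counts)
    return sum(-(c / total) * math.log2(c / total) for c in counts if c > 0)

# A perfectly mixed two-class node is maximally impure under both measures
print(gini([50, 50]), entropy([50, 50]))
# A pure node has zero impurity under both measures
print(gini([100, 0]), entropy([100, 0]))
```

Both measures reach zero on a pure node and peak on an evenly mixed one, which is why they often choose similar splits.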

 

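Train the model
<model>.fit(input, output)

Check how well the model learned (score returns the mean accuracy)
<model>.score(input, output)

Predict the output for a given input
<model>.predict(input)

Show the feature importances; the higher the number, the more that column contributes to the decisions
print(<model>.feature_importances_)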
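
The calls above can be strung together as in the following sketch. It assumes the iris dataset bundled with scikit-learn as stand-in data, and random_state=0 is added here only to make repeated runs reproducible; neither is required by the API.

```python
from sklearn import tree
from sklearn.datasets import load_iris

# Stand-in data: 150 iris samples with 4 feature columns (an assumption for this sketch)
iris = load_iris()
X, y = iris.data, iris.target

# Initialize and train the model
model = tree.DecisionTreeClassifier(random_state=0)
model.fit(X, y)

# Mean accuracy on the data the model was trained on
train_accuracy = model.score(X, y)

# Predicted class labels for the first five rows
predicted = model.predict(X[:5])

# One importance value per feature column
importances = model.feature_importances_

print(train_accuracy)
print(predicted)
print(importances)
```

Scoring on the training data only shows how well the tree memorized it; a held-out test set would be needed to estimate accuracy on new data.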

 

 

Reference
scikit-learn uses an optimised version of the CART algorithm.
http://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html#


...........................................................................

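Visualizing the Decision Tree's decisions


Show the decision tree's decision process as a graph.
This requires two tools: export_graphviz (part of sklearn.tree) and dot (part of Graphviz).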

 

Install
On Ubuntu
#sudo apt-get install graphviz python-pydot

 

export_graphviz
Usage:
tree.export_graphviz(<model>, out_file=<outpath>, feature_names=<list of features>)
Example:
from sklearn import tree
from sklearn.datasets import load_iris

# Train a classifier on the iris dataset
iris = load_iris()
clf = tree.DecisionTreeClassifier()
model = clf.fit(iris.data, iris.target)

# Write the tree to tree.dot (export_graphviz returns None when out_file is set)
tree.export_graphviz(model, out_file='tree.dot')
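
As a variation on the example above, the sketch below passes out_file=None so that export_graphviz returns the dot source as a string instead of writing a file, and supplies feature_names so the nodes show column names rather than X[i]. Using the iris feature names and random_state=0 is an assumption made for illustration.

```python
from sklearn import tree
from sklearn.datasets import load_iris

iris = load_iris()
model = tree.DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)

# With out_file=None the dot source is returned as a string
dot_text = tree.export_graphviz(
    model,
    out_file=None,
    feature_names=iris.feature_names,
)

# Show the beginning of the generated dot source
print(dot_text[:200])
```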

Reference
http://scikit-learn.org/stable/modules/generated/sklearn.tree.export_graphviz.html#sklearn.tree.export_graphviz


pydot
View the dot format, or render it as an image. The example uses the dot command from Graphviz.
Example:
dot -Tpng tree.dot -o tree.png
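
The same dot command can also be run from Python. The following is a minimal sketch that shells out to the Graphviz dot binary through the standard library; render_dot is a hypothetical helper name, and the sketch assumes dot is on the PATH, returning False rather than failing when it is not.

```python
import os
import shutil
import subprocess

def render_dot(dot_path, png_path, dot_cmd="dot"):
    """Render a .dot file to PNG with Graphviz; return True if it was rendered."""
    # Skip quietly when Graphviz is not installed or the input file is missing
    if shutil.which(dot_cmd) is None or not os.path.exists(dot_path):
        return False
    # Equivalent to: dot -Tpng tree.dot -o tree.png
    subprocess.run([dot_cmd, "-Tpng", dot_path, "-o", png_path], check=True)
    return True

print(render_dot("tree.dot", "tree.png"))
```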


A brief explanation of the tree.dot format:
digraph Tree {
0 [label="X[13] <= 416.0000\ngini = 0.000270672616505\nsamples = 100", shape="box"] ;
1 [label="gini = 0.0000\nsamples = 7387\nvalue = [ 99. 0.]", shape="box"] ;
0 -> 1 ;
2 [label="gini = 0.0000\nsamples = 1\nvalue = [ 0. 1.]", shape="box"] ;
0 -> 2 ;
}
value = [class 0, class 1]
Splitting on X[13] <= 416 puts 99 samples in class 0 and 1 sample in class 1.
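
To read the node labels programmatically, a sketch like the following can split a label into its fields. parse_node_label is a hypothetical helper, and it assumes the line breaks inside a label appear as the two characters backslash and n, as export_graphviz writes them; the sample labels are copied from the listing above.

```python
def parse_node_label(label):
    """Split a node label from a .dot file into a dict of its fields."""
    fields = {}
    # Inside the .dot file a line break is the two characters backslash + n
    for part in label.split("\\n"):
        key, sep, value = part.partition(" = ")
        if sep:
            fields[key] = value
        else:
            # A line without " = " is the split test, e.g. "X[13] <= 416.0000"
            fields["split"] = part
    return fields

leaf = parse_node_label("gini = 0.0000\\nsamples = 7387\\nvalue = [ 99. 0.]")
root = parse_node_label("X[13] <= 416.0000\\ngini = 0.000270672616505\\nsamples = 100")

print(leaf)
print(root)
```

Only internal nodes carry a split test; leaf nodes have just the gini, samples, and value fields.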

 

......................................................................................

Example with a simple dataset

 

 

# vi dataset
class,packet,traffic
0,1,4
0,2,3
1,3,2
1,4,1

 

# vi train.py
import sys

import numpy as np
from sklearn import tree

# Path to the CSV dataset, given as the first command-line argument
filepath = sys.argv[1]

# The first line holds the column names; drop the leading class column
with open(filepath) as f:
    listhead = [name.strip() for name in f.readline().split(',')[1:]]

# Column 0 is the class label; the remaining columns are the features
dataset = np.loadtxt(filepath, delimiter=',', skiprows=1)
target = dataset[:, 0]
data = dataset[:, 1:]

clf = tree.DecisionTreeClassifier()
result = clf.fit(data, target)
tree.export_graphviz(result, out_file=filepath + '_tree.dot', feature_names=listhead)

# List the feature importances, highest first
# (the result may differ slightly from run to run)
importfeature = clf.feature_importances_
dictname = dict(zip(listhead, importfeature))
dicsorted = sorted(dictname.items(), key=lambda d: d[1], reverse=True)
for line in dicsorted:
    print(line)
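
For a quick check without creating files, the same steps can be sketched on an in-memory copy of the dataset. The csv_text variable and the two sample rows passed to predict are assumptions for illustration, and random_state=0 is added only so that repeated runs give the same tree.

```python
import io

import numpy as np
from sklearn import tree

# In-memory copy of the four-row dataset shown above
csv_text = """class,packet,traffic
0,1,4
0,2,3
1,3,2
1,4,1
"""

# Feature names from the header line, without the class column
header = csv_text.splitlines()[0].split(",")[1:]

# Column 0 is the class label; the remaining columns are the features
dataset = np.loadtxt(io.StringIO(csv_text), delimiter=",", skiprows=1)
target = dataset[:, 0]
data = dataset[:, 1:]

clf = tree.DecisionTreeClassifier(random_state=0)
clf.fit(data, target)

# Predict the class of two rows taken from the training data
predictions = clf.predict([[1, 4], [4, 1]])
print(predictions)

# Print each feature name with its importance
for name, value in zip(header, clf.feature_importances_):
    print(name, value)
```

Either column alone separates the two classes in this dataset, so the tree needs only one split and which feature gets the importance depends on the random tie-breaking.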

 

 

 

 

Published 2019-07-27 16:09:20, modified 2019-07-27 16:45:24
