我们学完决策树算法原理,来通过Python代码实战一下。
首先,导入数据集合各种用到的包。并且显示鸢尾花数据集信息。
from sklearn.datasets import load_iris #导入鸢尾花数据集
from sklearn import tree #导入决策树模块
iris = load_iris() #加载数据集
print(iris.DESCR) #显示数据集信息
用到的是鸢尾花数据集。
然后看一下鸢尾花的四个特征数据:
iris.data#数据
查看鸢尾花的分类:
iris.target#类别
0、1、2分别表示3种鸢尾花。
引入决策树算法并训练模型:
clf = tree.DecisionTreeClassifier()
#clf = tree.DecisionTreeClassifier(max_depth=2)#设置最深的树的层数
clf.fit(iris.data, iris.target)#训练决策树
从下图可以看出决策树的默认参数。
接下来进行预测(建议划分数据集操作):
clf.predict(iris.data)#预测结果
预测结果如上所示。
接下来绘制决策树。
tree.export_graphviz(clf, out_file='tree.dot')#画出决策树
在jupyter notebook中显示决策树:
%matplotlib inline
from IPython.display import Image
Image('tree.png')
# X[3] petal width
#values【 , , 】三种的数量
如果想要绘制出上图的决策树,需要做一些额外操作。
https://graphviz.gitlab.io/_pages/Download/Download_windows.html
下载后正常安装,c盘programx86文件夹中,bin目录加入环境变量中
cmd进入图片目录
cd C:\Users\Administrator\代码
dot -T png tree.dot -o tree.png
其实就是将.dot文件转为了.png格式的图片。
这就是Python实现决策树的内容,你学会了么?