决策树(下)

阿里云2000元红包!本站用户参与享受九折优惠!

用Matplotlib绘制树形图

Matplotlib annotations

Matplotlib提供了一个非常有用的注解工具annotations,它可以在数据图形上添加文本注解。

import matplotlib.pyplot as plt
# 中文显示配置
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# 定义文本框和箭头格式
decisionNode = dict(boxstyle='sawtooth', fc='0.8')
leafNode = dict(boxstyle='round4', fc='0.8')
arrow_args = dict(arrowstyle = '<-')
def createPlot():
    fig = plt.figure(1, facecolor='black')
    fig.clf()
    createPlot.ax1 = plt.subplot(111, frameon=False)
    plotNode('决策节点', (0.5, 0.1), (0.1, 0.5), decisionNode)
    plotNode('叶节点', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy = parentPt, xycoords='axes fraction',\
                           xytext = centerPt, textcoords='axes fraction',\
                           va='center', ha='center', bbox=nodeType, arrowprops=arrow_args)

执行createPlot()得到如下图形。


注解

构造注解树

绘制树需要知道有多少个叶节点,以便确定x轴的长度,还需要知道树有多少层,以便确定y轴的高度。

# 获取叶节点的个数
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]   # 根节点
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':    # 测试节点的类型是否为dict
            numLeafs += getNumLeafs(secondDict[key])    # 递归遍历
        else:
            numLeafs += 1
    return numLeafs
# 获取树的层树
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth
# 预存树信息。
def retrieveTree(i):
    listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                  {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                  ]
    return listOfTrees[I]
# 在父子节点间填充文本信息
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)
    
def plotTree(myTree, parentPt, nodeTxt):
    # 计算宽高,决定x轴和y轴的长度
    numLeafs = getNumLeafs(myTree)  
    depth = getTreeDepth(myTree)
    
    firstStr = list(myTree.keys())[0]     # 节点文本
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict': # 看节点类型是否为字典,不是则为叶节点
            plotTree(secondDict[key],cntrPt,str(key))        # 递归
        else:   
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
def createPlot(inTree):
    fig = plt.figure(1, facecolor='black')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;
    plotTree(inTree, (0.5,1.0), '')
    plt.show()

执行结果


决策树

测试和存储分类器

测试算法:使用决策树执行分类

# 使用决策树的分类器
def classify(inputTree, featLabels, testVec):
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel

用之前预存树信息的retrieveTree函数获得一棵树,然后进行分类。

myTree = retrieveTree(0)
classify(myTree, labels, [1,1])

结果为yes。分类正确。

决策树的存储

构建决策树是很耗时的任务,为了节省时间,最好是直接用已构建好的决策树。为了解决这个问题,需要用模块pickle存储决策树。

# 使用pickle存储决策树
import pickle
def storeTree(inputTree, filename):
    fw = open(filename, 'w')
    pickle.dump(inputTree, fw)
    fw.close()
    
def grabTree(filename):
    fr = open(filename)
    return pickle.load(fr)

小结

开始处理数据集时,首先需要计算熵,然后寻找最优方案划分数据集,直到数据集中的所有数据属于同一分类。

ID3算法可以用于划分标称型数据集。

示例中的例子表明决策树可能会产生过多的数据集划分,从而产生过度匹配的问题。可以通过裁剪树,合并相邻的无法产生大量信息增益的叶节点,消除过度匹配问题。

还有其他决策树的构造算法,例如,C4.5CART

https://www.jianshu.com/p/ba922c7ce727

Python量化投资网携手4326手游为资深游戏玩家推荐:《《纯白魔女》:【转自投石姬】系统介绍-仇恨机制说明

「点点赞赏,手留余香」

    还没有人赞赏,快来当第一个赞赏的人吧!
0 条回复 A 作者 M 管理员
    所有的伟大,都源于一个勇敢的开始!
欢迎您,新朋友,感谢参与互动!欢迎您 {{author}},您在本站有{{commentsCount}}条评论