# 決定ツリーpython実装
# 7242 ワード
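"""
@author kunji
@time 2020.5.23
@desc
Decision tree (ID3) on a small categorical sample data set.
Feature 1 takes the values 0, 1, 2
Feature 2 takes the values 0, 1
Feature 3 takes the values 0, 1, 2
Class label: C1, C2
pickle is imported, presumably to save and load the finished tree.
"""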
from matplotlib.font_manager import FontProperties
import matplotlib.pyplot as plt
from math import log
import operator
import pickle
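"""
Function: build the sample data set used to grow the decision tree.
Parameters:
None
Returns:
dataSet - list of samples; each row is three feature values followed by the class label
labels - list of feature names, one per feature column
"""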
def createDataSet():
    """Return the built-in sample data set and its feature-name list.

    Returns:
        dataSet - list of 8 samples; each row holds three integer feature
            values followed by the class label ('C1' or 'C2')
        labels - list with one name per feature column
    """
    # Each row: [feature 0, feature 1, feature 2, class label]
    dataSet = [
        [0, 0, 1, 'C1'],
        [1, 0, 0, 'C2'],
        [2, 1, 1, 'C1'],
        [0, 0, 0, 'C1'],
        [0, 1, 2, 'C1'],
        [2, 1, 1, 'C2'],
        [1, 0, 0, 'C2'],
        [1, 0, 2, 'C2'],
    ]
    # NOTE(review): all three feature names are the same single-space string
    # in this copy of the file, so the names cannot tell the columns apart.
    # Presumably the real names were lost - restore them before relying on
    # lookups by feature name.
    labels = [' ', ' ', ' ']
    return dataSet, labels
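"""
Function: compute the Shannon entropy (empirical entropy) of a data set.
Ent(D) = -SUM_k(p_k * log2(p_k))
Parameters:
dataSet - list of samples; the last element of each row is the class label
Returns:
shannonEnt - Shannon entropy (empirical entropy) of the class labels
"""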
def calcShannonEnt(dataSet):
    """Compute the Shannon entropy of the class labels in a data set.

    Ent(D) = -SUM_k(p_k * log2(p_k)), where p_k is the fraction of rows
    whose class label is k.

    Parameters:
        dataSet - list of samples; the last element of each row is taken
            as the class label
    Returns:
        shannonEnt - entropy in bits; 0.0 for an empty data set
    """
    numEntries = len(dataSet)
    # Count how many rows carry each class label.
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
    shannonEnt = 0.0
    for count in labelCounts.values():
        # Every stored count is at least 1, so prob > 0 and the log is defined.
        prob = float(count) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
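"""
Function: keep the rows whose feature equals a given value and drop that feature column.
Parameters:
dataSet - list of samples to split
axis - index of the feature column to split on
value - feature value a row must have to be kept
Returns:
retDataSet - the matching rows with column axis removed
"""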
def splitDataSet(dataSet, axis, value):
    """Select the rows whose feature `axis` equals `value`, minus that column.

    Parameters:
        dataSet - list of samples (each row a list or tuple)
        axis - index of the feature column to split on; expected to be
            non-negative (with axis == -1 the tail slice would be the
            whole row instead of an empty one)
        value - feature value a row must have to be kept
    Returns:
        retDataSet - new list of new row lists; the input rows are not
            modified
    """
    # list(...) lets tuple rows work too; the original .extend() call
    # required every row to be a list.
    return [
        list(featVec[:axis]) + list(featVec[axis + 1:])
        for featVec in dataSet
        if featVec[axis] == value
    ]
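"""
Function: choose the feature with the largest information gain.
Gain(D, g) = Ent(D) - SUM_v((|Dv| / |D|) * Ent(Dv))
Parameters:
dataSet - list of samples; the last element of each row is the class label
Returns:
bestFeature - index of the feature with the largest information gain (-1 if no feature has a positive gain)
"""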
def chooseBestFeatureToSplit(dataSet):
    """Pick the feature column with the largest information gain.

    Gain(D, g) = Ent(D) - SUM_v((|Dv| / |D|) * Ent(Dv)), where Dv is the
    subset of D whose feature g has the value v.

    Prints one line per feature with its index and gain.

    Parameters:
        dataSet - non-empty list of samples; the last element of each row
            is the class label
    Returns:
        bestFeature - index of the feature with the largest gain, or -1
            when no feature has a gain greater than 0
    """
    # The last column is the class label, not a feature.
    numFeatures = len(dataSet[0]) - 1
    numSamples = float(len(dataSet))
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        uniqueVals = set(example[i] for example in dataSet)
        # Conditional entropy: subset entropies weighted by subset size.
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / numSamples
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        print(" %d %.3f" % (i, infoGain))
        # Strict comparison: on a tie the earlier feature is kept.
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
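"""
Function: return the class label that occurs most often in classList.
Parameters:
classList - list of class labels
Returns:
sortedClassCount[0][0] - the most frequent class label
"""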
def majorityCnt(classList):
    """Return the class label that occurs most often in `classList`.

    Parameters:
        classList - non-empty list of class labels
    Returns:
        sortedClassCount[0][0] - the most frequent label; on a tie, the
            tied label that appears first in classList
    """
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    # Sort (label, count) pairs by count, largest first. The sort is stable
    # even with reverse=True, which is what fixes the tie-breaking order.
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
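"""
Function: build the decision tree recursively (ID3 algorithm).
Recursion stops when: 1. all samples belong to the same class, which becomes the leaf;
2. every feature has been used, in which case the majority class becomes the leaf.
Parameters:
dataSet - training samples; the last element of each row is the class label
labels - feature names, one per feature column of dataSet
featLabels - list that receives the name of each feature chosen for a split
Returns:
myTree - the decision tree as nested dicts
"""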
def createTree(dataSet, labels, featLabels):
    """Build a decision tree recursively (ID3).

    Recursion stops and a class label is returned as the leaf when:
      1. every sample has the same class label;
      2. no feature columns remain (majority class is used);
      3. no remaining feature has a positive information gain
         (majority class is used).

    Parameters:
        dataSet - non-empty list of samples; the last element of each row
            is the class label
        labels - feature names, one per feature column of dataSet; this
            list is not modified
        featLabels - list that receives, in order of selection, the name
            of every feature chosen for a split
    Returns:
        myTree - nested dicts of the form {feature name: {value: subtree}},
            or a bare class label when the node is a leaf
    """
    classList = [example[-1] for example in dataSet]
    # Stop 1: the node is pure.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Stop 2: only the class-label column is left.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    # Stop 3: chooseBestFeatureToSplit returns -1 when no feature has a
    # positive gain. Using -1 as a column index would address the
    # class-label column and split on the label itself.
    if bestFeat < 0:
        return majorityCnt(classList)
    bestFeatLabel = labels[bestFeat]
    featLabels.append(bestFeatLabel)
    myTree = {bestFeatLabel: {}}
    # Build a fresh name list for the children instead of deleting from
    # `labels` in place: the same list is shared by sibling branches and
    # by the caller, so in-place deletion lets names drift out of step
    # with the columns.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    uniqueVals = set(example[bestFeat] for example in dataSet)
    for value in uniqueVals:
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels, featLabels)
    return myTree
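"""
Function: count the leaf nodes of a decision tree.
Parameters:
myTree - decision tree as nested dicts
Returns:
numLeafs - number of leaf nodes in the tree
"""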
def getNumLeafs(myTree):
    """Count the leaf nodes of a decision tree.

    Parameters:
        myTree - non-empty dict whose first key maps to a dict of
            branches; a branch value that is a dict is a subtree of the
            same form, any other value is a leaf
    Returns:
        numLeafs - number of leaves under the first key of myTree
    """
    # Only the first key is inspected. In Python 3, dict.keys() cannot be
    # indexed, hence next(iter(...)).
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    numLeafs = 0
    for child in secondDict.values():
        # isinstance also accepts dict subclasses, which a comparison on
        # the type name would count as leaves.
        if isinstance(child, dict):
            numLeafs += getNumLeafs(child)
        else:
            numLeafs += 1
    return numLeafs
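"""
Function: compute the depth of a decision tree.
Parameters:
myTree - decision tree as nested dicts
Returns:
maxDepth - number of decision-node levels on the longest path of the tree
"""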
def getTreeDepth(myTree):
    """Compute the depth of a decision tree.

    Parameters:
        myTree - non-empty dict whose first key maps to a dict of
            branches; a branch value that is a dict is a subtree of the
            same form, any other value is a leaf
    Returns:
        maxDepth - number of decision-node levels on the longest path;
            1 when every branch is a leaf, 0 when there are no branches
    """
    # Only the first key is inspected. In Python 3, dict.keys() cannot be
    # indexed, hence next(iter(...)).
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    maxDepth = 0
    for child in secondDict.values():
        # isinstance also accepts dict subclasses, which a comparison on
        # the type name would treat as leaves.
        if isinstance(child, dict):
            thisDepth = 1 + getTreeDepth(child)
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth
"""
:
Parameters:
nodeTxt -
centerPt -
parentPt -
nodeType -
Returns:
None
"""
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
#
arrow_args = dict(arrowstyle="