# Decision tree: Python implementation (決定ツリーpython実装)
#
# Source article length: 7242 words (7242 ワード)

"""
@author kunji
@time 2020.5.23
@desc       

          :
  :0    ,1    ,2    
  :0   ,1   
  :0    ,1    ,2    
  :C1,C2
pickle           ,        
"""
from matplotlib.font_manager import FontProperties
import matplotlib.pyplot as plt
from math import log
import operator
import pickle

"""
    :       
Parameters:
    None
Returns:
    dataSet -    
    labels -     
"""


def createDataSet():
    #    
    dataSet = [[0, 0,  1, 'C1'],
               [1, 0,  0, 'C2'],
               [2, 1,  1, 'C1'],
               [0, 0,  0, 'C1'],
               [0, 1,  2, 'C1'],
               [2, 1,  1, 'C2'],
               [1, 0,  0, 'C2'],
               [1, 0,  2, 'C2']]


    #     
    labels = ['  ', '  ', '  ']
    #           
    return dataSet, labels


"""
    :           (   )
        Ent(D) = -SUM(kp*Log2(kp))
Parameters:
    dataSet -    

Returns:
    shannonEnt -    (   )

"""


def calcShannonEnt(dataSet):
    #         
    numEntires = len(dataSet)
    #       (Label)     “  ”
    labelCounts = {}
    #            
    for featVec in dataSet:
        #     (Label)  
        currentLabel = featVec[-1]
        #     (Label)           ,    
        if currentLabel not in labelCounts.keys():
            #          ,  currentLabel  0
            labelCounts[currentLabel] = 0
        # Label  
        labelCounts[currentLabel] += 1
    #    (   )
    shannonEnt = 0.0
    #      
    for key in labelCounts:
        #      (Label)   
        prob = float(labelCounts[key]) / numEntires
        #       
        shannonEnt -= prob * log(prob, 2)
    #      (   )
    return shannonEnt


"""
    :           
Parameters:
    dataSet -        
    axis -         
    values -          

Returns:
    None

"""


def splitDataSet(dataSet, axis, value):
    #           
    retDataSet = []
    #          
    for featVec in dataSet:
        if featVec[axis] == value:
            #   axis  
            reducedFeatVec = featVec[:axis]
            #                
            # extend()                         (           )。
            reducedFeatVec.extend(featVec[axis + 1:])
            #        
            retDataSet.append(reducedFeatVec)
            #          
    return retDataSet


"""
    :      
        Gain(D,g) = Ent(D) - SUM(|Dv|/|D|)*Ent(Dv)
Parameters:
    dataSet -    

Returns:
    bestFeature -        (  )      

"""


def chooseBestFeatureToSplit(dataSet):
    #     
    numFeatures = len(dataSet[0]) - 1
    #          
    baseEntropy = calcShannonEnt(dataSet)
    #     
    bestInfoGain = 0.0
    #         
    bestFeature = -1
    #       
    for i in range(numFeatures):
        #   dataSet  i       featList     (     )
        featList = [example[i] for example in dataSet]
        #   set  {},      ,         
        #          python                 
        uniqueVals = set(featList)
        #      
        newEntropy = 0.0
        #       
        for value in uniqueVals:
            # subDataSet      
            subDataSet = splitDataSet(dataSet, i, value)
            #        
            prob = len(subDataSet) / float(len(dataSet))
            #            
            newEntropy += prob * calcShannonEnt(subDataSet)
        #     
        infoGain = baseEntropy - newEntropy
        #            
        print(" %d       %.3f" % (i, infoGain))
        #       
        if (infoGain > bestInfoGain):
            #       ,         
            bestInfoGain = infoGain
            #                
            bestFeature = i
    #                
    return bestFeature


"""
    :  classList          (   )
                    
Parameters:
    classList -      

Returns:
    sortedClassCount[0][0] -          (   )

"""


def majorityCnt(classList):
    classCount = {}
    #   classList          
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    #           
    # operator.itemgetter(1)      1   
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    #   classList          
    return sortedClassCount[0][0]


"""
    :     (ID3  )
                 :1、          ,       
                        2、                  ,      ,               
Parameters:
    dataSet -      
    labels -       
    featLabels -            

Returns:
    myTree -    


"""


def createTree(dataSet, labels, featLabels):
    #      (:C1 or C2)
    classList = [example[-1] for example in dataSet]
    #                
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    #                     
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    #       
    bestFeat = chooseBestFeatureToSplit(dataSet)
    #        
    bestFeatLabel = labels[bestFeat]
    featLabels.append(bestFeatLabel)
    #             
    myTree = {bestFeatLabel: {}}
    #            
    del (labels[bestFeat])
    #                  
    featValues = [example[bestFeat] for example in dataSet]
    #         
    uniqueVals = set(featValues)
    #     ,     
    for value in uniqueVals:
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), labels, featLabels)
    return myTree


"""
    :            
Parameters:
    myTree -    

Returns:
    numLeafs -            

"""


def getNumLeafs(myTree):
    #      
    numLeafs = 0
    # python3 myTree.keys()    dict_keys,  list,     
    # myTree.keys()[0]         ,    list(myTree.keys())[0]
    # next()             next(iterator[, default])
    firstStr = next(iter(myTree))
    #        
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        #           ,      ,          
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs


"""
    :        
Parameters:
    myTree -    

Returns:
    maxDepth -       

"""


def getTreeDepth(myTree):
    #         
    maxDepth = 0
    # python3 myTree.keys()    dict_keys,  list,     
    # myTree.keys()[0]         ,    list(myTree.keys())[0]
    # next()             next(iterator[, default])
    firstStr = next(iter(myTree))
    #        
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        #           ,      ,          
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        #       
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    #         
    return maxDepth


"""
    :    
Parameters:
    nodeTxt -    
    centerPt -     
    parentPt -        
    nodeType -     

Returns:
    None

"""


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    #       
    arrow_args = dict(arrowstyle="