Python implementation of the ID3 and C4.5 decision trees

The test file has the following format; save it as isFish.csv:
no surfacing,flippers,isFish
1,1,yes
1,1,yes
1,0,no
0,1,no
0,1,no

The code is below. The underlying theory is explained in detail at https://blog.csdn.net/lwc5411117/article/details/102514421.
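
For reference, ID3 chooses the split feature A that maximizes the information gain Gain(D, A) = H(D) - Σ_v (|D_v|/|D|) * H(D_v), where H(D) = -Σ_k p_k * log2(p_k) is the Shannon entropy of the class labels and D_v is the subset of D in which feature A takes value v. C4.5 instead maximizes the gain ratio Gain(D, A) / SplitInfo(D, A), with SplitInfo(D, A) = -Σ_v (|D_v|/|D|) * log2(|D_v|/|D|), which penalizes features with many distinct values.
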
#!/usr/bin/python
# -*- coding: UTF-8 -*-

import numpy as np
import pandas as pd
import operator

class DecisionTree:
    def loadData(self, path):
        """
        Import data from file
        @ return dataSet: data set from file
        """
        dataSet = pd.read_csv(path, delimiter=',')
        labelSet = list(dataSet.columns.values)
        dataSet = dataSet.values
        return dataSet, labelSet

    def calcShannonEnt(self, dataSet):
        """
        Calculate the entropy of data set
        @ param dataSet: data set
        @ return shannonEnt: shannon entropy
        """
        numEntries = len(dataSet)
        labelCounts = {}
        for featVec in dataSet:
            # sample type
            currentLabel = featVec[-1]
            # if currentLabel is not in labelCounts, initialize its count to 0
            if currentLabel not in labelCounts:
                labelCounts[currentLabel] = 0
            labelCounts[currentLabel] += 1

        shannonEnt = 0.0
        for key in labelCounts:
            prob = float(labelCounts[key]) / numEntries
            shannonEnt -= prob * np.log2(prob)

        return shannonEnt
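
    # Worked example (assuming the 5-row isFish.csv above): the class labels
    # are [yes, yes, no, no, no], so the entropy is
    # H = -(2/5)*log2(2/5) - (3/5)*log2(3/5) ≈ 0.971.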

    def splitDataSet(self, dataSet, axis, value):
        """
        divide data set, and get the value which apply one of the features
        @ param dataSet: data set
        @ param axis: features to divide data set
        @ param value: get list which apply the feature
        """
        retDataSet = []
        for featVec in dataSet:
            # get data with same feature, and remove current feature with axis in feature vector
            if featVec[axis] == value:
                reduceFeatVec = list(featVec[:axis])
                reduceFeatVec.extend(featVec[axis+1:])
                retDataSet.append(reduceFeatVec)

        return retDataSet
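
    # Worked example (assuming the 5-row isFish.csv above):
    # splitDataSet(dataSet, 0, 1) keeps the three rows where "no surfacing" == 1
    # and drops that column, yielding [[1, 'yes'], [1, 'yes'], [0, 'no']].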

    def chooseBestFeatureID3(self, dataSet):
        """
        select best feature for dividing with ID3
        @ param dataSet: data set
        @ return bestFeature: best feature for dividing
        """
        numFeature = len(dataSet[0]) - 1
        baseEntropy = self.calcShannonEnt(dataSet)
        bestInfoGain = 0.0
        bestFeature = -1
        for i in range(numFeature):
            # collect all values of the ith feature
            featureList = [example[i] for example in dataSet]
            # remove duplicate values
            uniqueVals = set(featureList)
            newEntropy = 0.0
            for value in uniqueVals:
                subDataSet = self.splitDataSet(dataSet, i, value)
                # weight of the subset where the ith feature equals "value"
                prob = len(subDataSet) / float(len(dataSet))
                # conditional entropy: weighted entropy of the subset
                newEntropy += prob * self.calcShannonEnt(subDataSet)
            infoGain = baseEntropy - newEntropy

            if infoGain > bestInfoGain:
                bestInfoGain = infoGain
                bestFeature = i
        return bestFeature
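
    # Worked example (assuming the 5-row isFish.csv above): the information
    # gains are about 0.420 for "no surfacing" and 0.171 for "flippers",
    # so this method should return 0.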

    def chooseBestFeatureC45(self, dataSet):
        """
        select best feature for dividing with C4.5
        @ param dataSet: data set
        @ return bestFeature: best feature for dividing
        """
        numFeature = len(dataSet[0]) - 1
        baseEntropy = self.calcShannonEnt(dataSet)
        bestInfoGainRatio = 0.0
        bestFeature = -1
        for i in range(numFeature):
            # collect all values of the ith feature
            featureList = [example[i] for example in dataSet]
            # remove duplicate values
            uniqueVals = set(featureList)
            newEntropy = 0.0
            splitInfo = 0.0
            for value in uniqueVals:
                subDataSet = self.splitDataSet(dataSet, i, value)
                # weight of the subset where the ith feature equals "value"
                prob = len(subDataSet) / float(len(dataSet))
                # conditional entropy: weighted entropy of the subset
                newEntropy += prob * self.calcShannonEnt(subDataSet)
                # split information (intrinsic value) of the ith feature
                splitInfo -= prob * np.log2(prob)
            infoGain = baseEntropy - newEntropy

            # a feature with a single unique value has zero split information;
            # skip it to avoid dividing by zero
            if splitInfo == 0.0:
                continue
            infoGainRatio = infoGain / splitInfo
            if infoGainRatio > bestInfoGainRatio:
                bestInfoGainRatio = infoGainRatio
                bestFeature = i
        return bestFeature
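
    # Worked example (assuming the 5-row isFish.csv above): the gain ratios
    # are about 0.420 / 0.971 ≈ 0.433 for "no surfacing" and
    # 0.171 / 0.722 ≈ 0.237 for "flippers", so this method should also return 0.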

    def majorityCnt(self, classList):
        """
        recurrent to build decision tree
        @ param classList: class list
        @ return sortedClassCount[0][0]: class appeared max count
        """
        classCount = {}
        for vote in classList:
            if vote not in classCount.keys():
                classCount[vote] = 0
            classCount += 1

        # sort class based on count
        sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
        # return the class which appears with max count
        return sortedClassCount[0][0]
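
    # Example: majorityCnt(['yes', 'no', 'no']) should return 'no'.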

    def createTree(self, dataSet, labels, method):
        """
        construct decision tree
        @ param dataSet: data set
        @ param labels: label set
        @ return myTree: decision tree
        """
        classList = [example[-1] for example in dataSet]
        # stop when all samples belong to the same class
        if classList.count(classList[0]) == len(classList):
            return classList[0]
        # all features have been consumed: return the class with max count
        if len(dataSet[0]) == 1:
            return self.majorityCnt(classList)

        # get best feature for dividing
        if method == "ID3":
            bestFeat = self.chooseBestFeatureID3(dataSet)
        elif method == "C4.5":
            bestFeat = self.chooseBestFeatureC45(dataSet)
        else:
            bestFeat = self.chooseBestFeatureID3(dataSet)

        bestFeatLabel = labels[bestFeat]
        myTree = {bestFeatLabel:{}}
        del(labels[bestFeat])
        featValues = [example[bestFeat] for example in dataSet]
        uniqueValues = set(featValues)
        for value in uniqueValues:
            subLabels = labels[:]
            # recurrent to build decision tree
            myTree[bestFeatLabel][value] = self.createTree(self.splitDataSet(dataSet, bestFeat, value), subLabels, method)
        return myTree

if __name__ == '__main__':
    decTree = DecisionTree()
    dataSet, labelSet = decTree.loadData('isFish.csv')
    tree = decTree.createTree(dataSet, labelSet, "ID3")
    print(tree)
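
Running the script against the sample isFish.csv should print a nested dict of roughly this shape (the 0/1 keys come from pandas parsing the feature columns as integers):

{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}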

Reference: https://juejin.im/post/5aa503b4518825555d46e1d8