[python for ML] Decision tree
This post walks through two files: tree.py (building the tree) and treePlotter.py (drawing it).

tree.py
from math import log
# function to calculate the Shannon entropy of a data set
## each row is an example whose last column is the class label
def calEntropy(dataset):
    # number of examples in the data set
    numData = len(dataset)
    # dictionary mapping each class label to its count
    labelCount = {}
    # fill key-value pairs in labelCount
    for sam in dataset:
        # class label of example sam
        label = sam[-1]
        # check whether label is already a key of labelCount
        if label not in labelCount:
            # if not, create it
            labelCount[label] = 0
        # increment the count for this label
        labelCount[label] += 1
    # initialize entropy
    ent = 0.0
    for x in labelCount:
        # probability of each class
        prob = float(labelCount[x])/numData
        # accumulate -p*log2(p), the expected information of each class
        ent -= prob*log(prob, 2)
    # return the final entropy
    return ent
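calEntropy implements the Shannon entropy H = -sum_i p_i*log2(p_i), where p_i is the fraction of examples in class i. As a quick sanity check (a hypothetical snippet, not part of tree.py): a two-class set with equal frequencies should give exactly 1 bit.

# hypothetical sanity check: two classes, 50/50 split -> entropy = 1.0
demo = [[1, "yes"], [0, "yes"], [1, "no"], [0, "no"]]
print(calEntropy(demo))   # 1.0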
# function to create a toy data set
def createData():
    # data set: two binary features plus a class label
    dataset = [[1, 1, "yes"],
               [1, 1, "yes"],
               [1, 0, "no"],
               [0, 1, "no"],
               [0, 1, "no"]]
    # names of the features
    labels = ["no surfacing", "flippers"]
    # return the data set and the feature names
    return dataset, labels
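Tying the two functions together (a usage sketch assuming the definitions above): the toy set has 2 "yes" and 3 "no" labels, so its entropy is -0.4*log2(0.4) - 0.6*log2(0.6), about 0.971 bits.

mydata, mylabels = createData()
print(calEntropy(mydata))   # about 0.9710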
# function to split the data set
## axis: index of the feature used for splitting
## value: the feature value to select on
def splitData(dataset, axis, value):
    # new list to hold the split data set
    setDataset = []
    # examine each example
    for x in dataset:
        # check whether the value of feature axis equals value
        if x[axis] == value:
            # if so, remove feature axis from this example
            ## take the features before axis
            reducedX = x[:axis]
            ## append the features after axis
            reducedX.extend(x[axis+1:])
            # add this reduced example to the new data set
            setDataset.append(reducedX)
    # return the examples matching value, with feature axis removed
    return setDataset
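A usage sketch on the toy data: selecting the examples whose feature 0 equals a given value drops that column and keeps the rest.

mydata, mylabels = createData()
print(splitData(mydata, 0, 1))   # [[1, 'yes'], [1, 'yes'], [0, 'no']]
print(splitData(mydata, 0, 0))   # [[1, 'no'], [1, 'no']]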
# function to choose the best feature to split the data on
def chooseFeaturetoSplit(dataset):
    # number of examples
    numSam = len(dataset)
    # number of features (the last column is the class label)
    numVar = len(dataset[0]) - 1
    # base entropy of the full data set
    primEntropy = calEntropy(dataset)
    # initialize the best feature and the best information gain
    bestFeature = -1
    bestInforGain = 0.0
    # evaluate each feature in turn
    for ivar in range(numVar):
        # set comprehension collecting the unique values of this feature
        unival = {example[ivar] for example in dataset}
        # initialize the post-split entropy
        newEntropy = 0.0
        # for each unique value of this feature
        for j in unival:
            # sub data set in which this feature equals j
            subData = splitData(dataset, ivar, j)
            # probability (weight) of the sub data set
            prob = len(subData)/float(numSam)
            # accumulate the weighted mean entropy after splitting on this feature
            newEntropy += prob*calEntropy(subData)
        # information gain of this feature
        InforGain = primEntropy - newEntropy
        # check whether this feature beats the best gain so far
        if InforGain > bestInforGain:
            # if so, record this feature
            bestFeature = ivar
            # and its information gain
            bestInforGain = InforGain
    # return the best feature
    return bestFeature
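The quantity being maximized is the information gain Gain(D, A) = H(D) - sum_v (|D_v|/|D|)*H(D_v), where D_v is the subset of D in which feature A takes value v. On the toy data, feature 0 ("no surfacing") yields a gain of about 0.420 bits versus about 0.171 bits for feature 1, so the function returns 0 (a sketch assuming the definitions above).

mydata, mylabels = createData()
print(chooseFeaturetoSplit(mydata))   # 0, i.e. "no surfacing"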
import operator
# function to classify by majority vote
## classList is a list of class labels
def majorCnt(classList):
    # dictionary mapping each class label to its count
    classCount = {}
    # for each class label in classList
    for vote in classList:
        # check whether it is already a key of classCount
        # if not, create it
        if vote not in classCount:
            classCount[vote] = 0
        # increment the count for this label
        classCount[vote] += 1
    # sort the dictionary by value (decreasing)
    ## dict.items() returns the key-value pairs of the dict
    ## operator.itemgetter(1) selects the item at index 1, i.e. the count
    ## reverse=True means decreasing; reverse=False means increasing
    sortCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    # return the majority vote, i.e. the final class
    return sortCount[0][0]
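A usage sketch: with three "no" votes against two "yes" votes, the majority class is "no".

print(majorCnt(["yes", "no", "no", "yes", "no"]))   # no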
# function to create the tree recursively
## labels: names of the features in the data set
def createTree(dataset, labels):
    # list of the class labels of all examples
    classList = [example[-1] for example in dataset]
    # stop splitting if all examples in this node belong to the same class
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # stop splitting if there are no more features to use
    if len(dataset[0]) == 1:
        return majorCnt(classList)
    # choose the best feature to split the data on
    bestFeat = chooseFeaturetoSplit(dataset)
    # name of the best feature
    bestLabel = labels[bestFeat]
    # dictionary holding the tree rooted at this node
    mytree = {bestLabel: {}}
    # delete the best feature from the list of remaining labels
    del(labels[bestFeat])
    # set (not a list) of the unique values of the best feature
    featureValue = {example[bestFeat] for example in dataset}
    # split the data and grow the tree one level deeper
    for val in featureValue:
        # copy the remaining labels so recursion does not modify the caller's list
        subLabel = labels[:]
        # build a subtree for each unique value of the best feature
        mytree[bestLabel][val] = createTree(splitData(dataset, bestFeat, val), subLabel)
    # return the tree
    return mytree
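Putting it all together on the toy data (a sketch assuming the functions above; note that createTree deletes entries from labels, so pass a copy if you still need the original list). The key order inside the nested dictionaries may vary.

mydata, mylabels = createData()
mytree = createTree(mydata, mylabels[:])
print(mytree)
# {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}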
treePlotter.py
import matplotlib.pyplot as plt
# dictionaries holding the box properties used for annotations
## sawtooth: rectangular box with a jagged (sawtooth) edge
## fc: face color of the box
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
## round4: oval (rounded) box
leafNode = dict(boxstyle="round4", fc="w")
# the arrow style ("<-" puts the arrow head at the text end)
arrow_args = dict(arrowstyle="<-")
# plot a node, i.e. add an annotation to the tree figure
## nodeTxt: text of the annotation
## centerPt: position of the child node (where the text is drawn)
## parentPt: position of the parent node (where the arrow starts)
## nodeType: box properties of the node
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    # annotate() (a pyplot function, or a method of the Axes class) draws text with an arrow connecting two points on the plot
    ## axes fraction: (0,0) is the lower left of the axes and (1,1) the upper right (the usual coordinate convention)
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords="axes fraction",
                            xytext=centerPt, textcoords="axes fraction",
                            ## bbox draws a box around the text
                            ## arrowprops: properties of the arrow
                            ## va, ha: vertical / horizontal alignment
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
def createPlot():
    # create a new figure: the canvas to draw on
    fig = plt.figure(1, facecolor="white")
    # clear the figure
    fig.clf()
    # add an axes to the figure
    ## 111: nrows, ncols, index of the subplot
    ## frameon=False: do not draw the axes frame, leaving the background blank
    createPlot.ax1 = plt.subplot(111, frameon=False)
    # draw a decision (internal) node
    plotNode("a decision node", (0.5, 0.1), (0.1, 0.5), decisionNode)
    # draw a leaf node
    plotNode("a leaf node", (0.8, 0.1), (0.3, 0.8), leafNode)
    # show the figure
    plt.show()
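Running the module directly (assuming an interactive matplotlib backend) should open a figure containing the two annotated boxes connected to their parent points by arrows.

if __name__ == "__main__":
    createPlot()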