アルゴリズム(二)決定ツリー

3586 ワード

axisを使用する目的は、最適インデックスを削除し、valueと比較して分類サブセットを作成することであり、valueを使用する目的は分類サブセットです.
各ツリーは最適インデックスとvalueで分類されたサブセットです
インデックスとvalueの転送の目的は、データセットを分解することです.
背景:決定ツリーアルゴリズムは条件分類に用いられ、多種の情景モードで異なる結果を生成する.シナリオに対する予測方法の1つ:情報利得を用いて情報利得を決定する公式:データセット全体のカテゴリ-plog(p,2)和-各シナリオにおける異なるカテゴリの-plog(p,2)和にこのシナリオを乗じたこのカテゴリの重み決定ツリー構築方式:ツリールートはすべての配列であり、幹は情報利得を利用し、最大情報利得のインデックス値である.このインデックスというすべての属性を利用して、このインデックス値と各属性を利用して配列を分離し、データを利用して幹の主要関数を構築します:香りの濃いエントロピーを求めて、最大の情報利得を求めて、決定木を構築して、配列を分離します
コードは次のとおりです.
import operator
import math
# a = {'a':0,'b':3,'c':1}
# b = sorted(a.items(),key=operator.itemgetter(1),reverse=True)
# print(b[0][0])
# 1      
# 2            
# 3           
# 4            

#    
dataset=[[0, 0, 1, 0, 'no'],
        [0, 0, 1, 1, 'no'],
        [0, 1, 1, 1, 'yes'],
        [0, 1, 0, 0, 'yes'],
        [0, 0, 1, 0, 'no'],
        [1, 0, 1, 0, 'no'],
        [1, 0, 1, 1, 'no'],
        [1, 1, 0, 1, 'yes'],
        [1, 0, 0, 2, 'yes'],
        [1, 0, 0, 2, 'yes'],
        [2, 0, 0, 2, 'yes'],
        [2, 0, 0, 1, 'yes'],
        [2, 1, 1, 1, 'yes'],
        [2, 1, 1, 2, 'yes'],
        [2, 0, 1, 0, 'no']]

#      ,        ,             ,     ,    
def computeExp(dataset):
    #      
    result = 0
    #       
    num_class = list(set([data[-1] for data in dataset]))
    len_dict = {}
    for data in dataset:
        if data[-1] not in len_dict.keys():
            len_dict[data[-1]] = 0
        len_dict[data[-1]] += 1
    for i in range(len(num_class)):
        port = len_dict[num_class[i]]/len(dataset)
        result -= port*math.log(port,2)
    return result

def splitList(dataset,axis,label):
    data_dict = []
    for data in dataset:
        if data[axis] == label:
            data_dict.append(data[:axis] + data[axis + 1:])
    return data_dict

#            
def increment(dataset,axis):
    num_label = list(set([data[axis] for data in dataset]))
    base_exp = computeExp(dataset)
    #      
    data_exp = 0
    for label in num_label:
        data_dict = splitList(dataset,axis,label)
        port = len(data_dict)/len(dataset)
        data_exp += port*computeExp(data_dict)
    return base_exp - data_exp

def getBestFeatrue(dataset):
    #       
    num_featrue = len(dataset[0][:-1])
    #       
    max_exp = 0
    #             
    max_index = -1

    for i in range(num_featrue):
        new_exp = increment(dataset, i) if increment(dataset, i) > max_exp else max_exp
        if new_exp != max_exp:
            max_index = i
            max_exp = new_exp
    return max_index

def getResult(classes):
    compare_dict = {}
    for label in classes:
        if label not in compare_dict.keys():
            compare_dict[label] = 0
        compare_dict[label] += 1
    return sorted(compare_dict.items(), key=operator.itemgetter(1), reverse=True)[0][0]

def createTree(dataset):
    bestfeatrue = getBestFeatrue(dataset)
    if len(list(set([data[-1] for data in dataset])))==1:
        return dataset[0][-1]
    if(len(dataset[0])==1):
        return getResult(dataset)
    mytree = {bestfeatrue:{}}
    num_class = list(set([data[bestfeatrue] for data in dataset]))
    for label in num_class:
        if label not in mytree[bestfeatrue].keys():
            mytree[bestfeatrue][label] = {}
        new_dataset = splitList(dataset,bestfeatrue,label)
        mytree[bestfeatrue][label] = createTree(new_dataset)
    return mytree

print(createTree(dataset))