CARTアルゴリズム
23065 ワード
CARTアルゴリズムは分類問題と回帰問題の両方に使用できます。CARTでは決定木が二分木であると仮定し、各内部ノードの判定結果は「はい」または「いいえ」のいずれかで、「はい」の場合は左の枝、「いいえ」の場合は右の枝に進みます。分類木としてのCARTアルゴリズムは、ジニ係数(Gini index)を用いて特徴の選択と分割を行います。
1. https://github.com/Dod-o/Statistical-Learning-Method_Code
import cv2
import time
import logging
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import os
import struct
# NOTE(review): presumably the number of MNIST digit classes (0-9); this
# constant is not referenced anywhere else in the visible code.
total_class = 10
def binaryzation(img):
    """Return an inverted binary uint8 copy of ``img``.

    The input is first cast to uint8. Pixels greater than 50 become 0 and all
    other pixels become 1. This is the same result as ``cv2.threshold`` with
    ``THRESH_BINARY_INV``, threshold 50 and max value 1. The caller's array is
    never modified.
    """
    pixels = img.astype(np.uint8)
    return np.where(pixels > 50, 0, 1).astype(np.uint8)
def binaryzation_features(trainset):
    """Binarize every flattened 28x28 image in ``trainset``.

    Each row is reshaped to 28x28, cast to uint8 and passed through
    ``binaryzation``. The results are stacked and flattened back to one row of
    784 values per image.
    """
    binarized = np.array([
        binaryzation(np.reshape(row, (28, 28)).astype(np.uint8))
        for row in trainset
    ])
    return binarized.reshape(-1, 784)
def load_mnist(path, kind='train'):
    """Load one MNIST split from IDX files and binarize its images.

    Args:
        path: directory containing ``<kind>-labels.idx1-ubyte`` and
            ``<kind>-images.idx3-ubyte``.
        kind: file-name prefix of the split, e.g. ``'train'`` or ``'t10k'``.

    Returns:
        ``(images, labels)``. ``images`` is the result of
        ``binaryzation_features`` applied to the raw pixels (one flattened
        image per row). ``labels`` is a 1-D uint8 array.

    Raises:
        ValueError: if either file has an unexpected magic number, or if the
            header counts disagree with each other or with the payload size.
    """
    labels_path = os.path.join(path, '%s-labels.idx1-ubyte' % kind)
    images_path = os.path.join(path, '%s-images.idx3-ubyte' % kind)

    with open(labels_path, 'rb') as lbpath:
        # Label header: big-endian uint32 magic number and item count (8 bytes).
        magic, n = struct.unpack('>II', lbpath.read(8))
        # The remainder of the file is one unsigned byte per label.
        labels = np.fromfile(lbpath, dtype=np.uint8)
    if magic != 2049:
        raise ValueError('%s: bad label-file magic number %d (expected 2049)'
                         % (labels_path, magic))
    if len(labels) != n:
        raise ValueError('%s: header declares %d labels but file holds %d'
                         % (labels_path, n, len(labels)))

    with open(images_path, 'rb') as imgpath:
        # Image header: magic number, image count, rows, cols (16 bytes).
        magic, num, rows, cols = struct.unpack('>IIII', imgpath.read(16))
        pixels = np.fromfile(imgpath, dtype=np.uint8)
    if magic != 2051:
        raise ValueError('%s: bad image-file magic number %d (expected 2051)'
                         % (images_path, magic))
    if num != n:
        raise ValueError('%s: %d images but %d labels in %s'
                         % (images_path, num, n, labels_path))
    if pixels.size != num * rows * cols:
        raise ValueError('%s: expected %d pixel bytes, found %d'
                         % (images_path, num * rows * cols, pixels.size))

    # Use the header's geometry rather than assuming 784 pixels per image.
    images = binaryzation_features(pixels.reshape(num, rows * cols))
    return images, labels
class TreeNode(object):
    """Node of the binary CART tree.

    Internal nodes store the split column (``attr_index``), the value it is
    compared against (``attr``) and two children. ``left_child`` is followed
    when ``sample[attr_index] == attr`` and ``right_child`` otherwise (see
    ``predict_one``). Leaves store only ``label``.
    """

    def __init__(self, **kwargs):
        """Populate the node from keyword arguments.

        Recognised keys are attr_index, attr, label, left_child and
        right_child. Any recognised key that is not supplied is set to None.
        Unknown keys are ignored.
        """
        for field_name in ('attr_index', 'attr', 'label',
                           'left_child', 'right_child'):
            setattr(self, field_name, kwargs.get(field_name))
def gini_train_set(train_label):
    """Return the Gini index of a numpy label array.

    Computes ``sum_k p_k * (1 - p_k)`` over the distinct labels, where ``p_k``
    is the fraction of samples carrying label k. An empty array yields 0.0.
    """
    total = len(train_label)
    proportions = (float(len(train_label[train_label == value])) / total
                   for value in set(train_label))
    return sum((p * (1 - p) for p in proportions), 0.0)
def gini_feature(train_feature, train_label):
    """Find the best "feature == v" / "feature != v" split for one feature.

    For each distinct value v of ``train_feature`` the samples are split into
    D1 (feature == v) and D2 (feature != v). The weighted Gini index
    ``|D1|/|D| * Gini(D1) + |D2|/|D| * Gini(D2)`` is then evaluated. The
    smallest value wins, and on ties the value visited first in ``set``
    iteration order is kept.

    Returns:
        ``(min_gini, best_value)``, or ``(inf, 0)`` when ``train_feature`` is
        empty.
    """
    best_gini, best_value = float('inf'), 0
    total = len(train_feature)
    for value in set(train_feature):
        in_split = train_feature == value
        weight_in = float(len(train_feature[in_split])) / total
        weight_out = 1 - weight_in
        candidate = (weight_in * gini_train_set(train_label[in_split])
                     + weight_out * gini_train_set(train_label[train_feature != value]))
        if candidate < best_gini:
            best_gini, best_value = candidate, value
    return best_gini, best_value
def get_best_index(train_set, train_label, feature_indexes):
    """Choose the feature and value whose binary split has the lowest Gini index.

    Args:
        train_set: 2-D numpy array with one sample per row.
        train_label: 1-D numpy array of labels aligned with ``train_set``.
        feature_indexes: iterable of column indexes allowed for splitting.

    Returns:
        ``(feature_index, feature_value)`` for the best split. Ties go to the
        lowest column index. ``(0, 0)`` is returned when no allowed index is a
        valid column.
    """
    n_columns = len(train_set[0])
    # Build the membership set once. Testing `i in feature_indexes` against a
    # list for every column costs O(columns * len(feature_indexes)).
    allowed = set(feature_indexes)
    min_gini = float('inf')
    feature_index = 0
    return_feature_value = 0
    # Scan columns in ascending order so the lowest index wins ties.
    for i in range(n_columns):
        if i not in allowed:
            continue
        gini, feature_value = gini_feature(train_set[:, i], train_label)
        if gini < min_gini:
            min_gini = gini
            feature_index = i
            return_feature_value = feature_value
    return feature_index, return_feature_value
def divide_train_set(train_set, train_label, feature_index, feature_value):
    """Split samples by whether column ``feature_index`` equals ``feature_value``.

    Args:
        train_set: 2-D array-like with one sample per row.
        train_label: 1-D array-like of labels aligned with the rows.
        feature_index: column to test.
        feature_value: value that sends a row to the left ("yes") side.

    Returns:
        ``(left, right, left_label, right_label)`` as numpy arrays. Rows whose
        feature equals ``feature_value`` go left and all others go right. The
        original row order is kept on each side.
    """
    train_set = np.asarray(train_set)
    train_label = np.asarray(train_label)
    if len(train_set) == 0:
        # Nothing to split: return four empty arrays.
        return np.array([]), np.array([]), np.array([]), np.array([])
    # One vectorised comparison replaces the per-row Python loop and the
    # list-of-rows -> np.array rebuild.
    goes_left = train_set[:, feature_index] == feature_value
    goes_right = ~goes_left
    return (train_set[goes_left], train_set[goes_right],
            train_label[goes_left], train_label[goes_right])
n = 0  # legacy node counter; nothing in build_tree updates it any more


def _majority_label(train_label):
    """Return the most frequent value of the non-empty numpy array ``train_label``.

    Ties are broken by ``set`` iteration order (the first maximum wins).
    """
    return max(set(train_label),
               key=lambda value: len(train_label[train_label == value]))


def build_tree(train_set, train_label, feature_indexes):
    """Recursively grow a binary CART classification tree.

    Args:
        train_set: 2-D numpy array with one sample per row.
        train_label: 1-D numpy array of labels aligned with ``train_set``.
        feature_indexes: sequence of column indexes still available for
            splitting. It is not modified.

    Returns:
        The root ``TreeNode`` of the (sub)tree, or None when ``train_label``
        is empty.

    Stopping rules, in order:
        * all labels identical -> leaf with that label;
        * no samples -> None;
        * no features left, or Gini index of the labels below 0.1 -> leaf
          with the majority label.
    """
    label_values = set(train_label)
    if len(label_values) == 1:
        # Pure node: every sample carries the same label.
        return TreeNode(label=train_label[0])
    if len(label_values) == 0:
        return None
    if len(feature_indexes) == 0 or gini_train_set(train_label) < 0.1:
        # Pre-pruning: no features left, or the node is already nearly pure.
        return TreeNode(label=_majority_label(train_label))

    feature_index, feature_value = get_best_index(train_set, train_label, feature_indexes)
    left, right, left_label, right_label = divide_train_set(
        train_set, train_label, feature_index, feature_value)

    # Give both children a private copy of the remaining features. Removing
    # the chosen feature from the shared list in place would also hide every
    # feature used inside the left subtree from the right subtree, and would
    # empty the caller's list.
    remaining = list(feature_indexes)
    remaining.remove(feature_index)

    # If the split sends every sample to one side, the other child would be
    # None, which predict_one maps to the non-class value 10. Use the parent's
    # majority label for that side instead.
    if len(left_label) == 0:
        left_branch = TreeNode(label=_majority_label(train_label))
    else:
        left_branch = build_tree(left, left_label, remaining)
    if len(right_label) == 0:
        right_branch = TreeNode(label=_majority_label(train_label))
    else:
        right_branch = build_tree(right, right_label, remaining)

    return TreeNode(left_child=left_branch,
                    right_child=right_branch,
                    attr_index=feature_index,
                    attr=feature_value)
def predict_one(node, test):
    """Route one sample from ``node`` down to a leaf and return its label.

    At each internal node the sample goes left when
    ``test[attr_index] == attr`` and right otherwise. The sentinel 10 is
    returned if the walk reaches a missing (None) node.
    """
    current = node
    while True:
        if current is None:
            return 10
        if current.label is not None:
            return current.label
        take_left = test[current.attr_index] == current.attr
        current = current.left_child if take_left else current.right_child
def predict(tree, test_set):
    """Predict a label for every sample in ``test_set`` using ``predict_one``.

    Returns a Python list in the same order as ``test_set``.
    """
    return [predict_one(tree, sample) for sample in test_set]
if __name__ == '__main__':
    print("CART")
    print("Start read data...")
    t1 = time.time()
    train_features, train_labels = load_mnist("data", kind='train')
    # Only the first 1000 samples of each split are used, which keeps the
    # pure-Python tree construction quick.
    train_features = train_features[0:1000]
    train_labels = train_labels[0:1000]
    test_features, test_labels = load_mnist("data", kind='t10k')
    test_features = test_features[0:1000]
    test_labels = test_labels[0:1000]
    t2 = time.time()
    print(" :" + str((t2 - t1)))

    print('Start training...')
    # Every column is a candidate split; derive the count from the data
    # instead of hard-coding 784.
    tree = build_tree(train_features, train_labels,
                      list(range(train_features.shape[1])))
    t3 = time.time()
    print(" :" + str((t3 - t2)))

    print('Start predicting...')
    # predict returns a plain Python list with one label per test row.
    test_predict = predict(tree, test_features)
    t4 = time.time()
    print(" :" + str((t4 - t3)))

    correct = sum(1 for predicted, actual in zip(test_predict, test_labels)
                  if predicted is not None and predicted == actual)
    score = float(correct) / float(len(test_predict))
    print("The accruacy score is %f" % score)
参照リンク: 1. https://github.com/Dod-o/Statistical-Learning-Method_Code