SMOTE(オーバーサンプリングアルゴリズム)
5378 ワード
#-*-coding:utf-8-*-
# SMOTE oversampling for imbalanced datasets
import numpy as np
from sklearn.neighbors import NearestNeighbors
import pandas as pd
def smote(data, tag_label='tag_1', amount_personal=0, std_rate=1, k=5, method='mean'):
    """Oversample the rarest class of ``data`` with SMOTE-style synthetic rows.

    The label in ``data[tag_label]`` with the fewest rows is the minority
    class. Each synthetic row starts from a randomly chosen minority sample
    and its ``k`` nearest minority neighbours. Every column, including
    ``tag_label``, is used as a kNN feature, so all columns must be numeric.

    A column with more than 10 distinct minority values is treated as
    continuous; any other column is treated as categorical.

    * Continuous columns:
      - ``method='mean'`` takes the mean over the neighbourhood.
      - ``method='random'`` interpolates between the sample and one randomly
        chosen other neighbour.
    * Categorical columns, which include ``tag_label`` itself: the most
      frequent value in the neighbourhood.

    Args:
        data: DataFrame holding the features and the class column.
        tag_label: name of the class column.
        amount_personal: if > 0, the number of synthetic rows to create.
        std_rate: used when ``amount_personal`` is not > 0. Rows are created
            until the minority class has about ``majority_count / std_rate``
            rows (real + synthetic).
        k: neighbours per minority sample, clamped to the minority size.
        method: ``'mean'`` or ``'random'`` (see above).

    Returns:
        ``data`` itself if it has fewer than two classes or all classes have
        the same size. Otherwise, a new DataFrame with a fresh RangeIndex
        containing every original row followed by the synthetic rows.

    Raises:
        ValueError: if ``method`` is neither ``'mean'`` nor ``'random'``.
    """
    if method not in ('mean', 'random'):
        raise ValueError("method must be 'mean' or 'random', got %r" % (method,))

    cnt = data[tag_label].groupby(data[tag_label]).count()
    # With fewer than two classes, or equally sized classes, the minority and
    # majority labels coincide and there is nothing to rebalance.
    if len(cnt) < 2 or cnt.max() <= cnt.min():
        print(' smote ')
        return data

    # idxmin() returns the first label with the smallest count.
    less_data = np.array(data[data[tag_label] == cnt.idxmin()])
    n_less = len(less_data)

    if amount_personal > 0:
        amount = amount_personal
    else:
        # The real minority rows are kept, so only the shortfall up to
        # majority / std_rate has to be synthesised.
        amount = max(int(cnt.max() / std_rate) - n_less, 0)
    if amount <= 0:
        return data.reset_index(drop=True)

    # Query the fitted samples themselves. Row i lists the indices of
    # sample i's nearest minority samples, normally beginning with i itself.
    knn = NearestNeighbors(n_neighbors=min(k, n_less)).fit(less_data)
    neighbours = knn.kneighbors(less_data, return_distance=False)

    continuous, categorical = _split_columns(less_data)

    rows = []
    while len(rows) < amount:
        pool = np.random.randint(n_less)
        rows.append(_synthetic_row(less_data[pool], less_data[neighbours[pool]],
                                   continuous, categorical, method))

    # Each row keeps the original column order, so data.columns lines up.
    synthetic = pd.DataFrame(np.vstack(rows), columns=data.columns).infer_objects()
    return pd.concat([data, synthetic], ignore_index=True)


def _split_columns(samples, max_categories=10):
    """Return (continuous, categorical) column indices of 2-D array ``samples``.

    A column is continuous when it holds more than ``max_categories`` distinct
    values (NaN counts as one value); otherwise it is categorical.
    """
    continuous, categorical = [], []
    for i, column in enumerate(samples.T):
        if pd.Series(column).nunique(dropna=False) > max_categories:
            continuous.append(i)
        else:
            categorical.append(i)
    return continuous, categorical


def _synthetic_row(sample, neighbor_group, continuous, categorical, method):
    """Build one synthetic row from ``sample`` and its neighbourhood rows.

    Values are written at their original column positions, so the row lines
    up with the source columns.
    """
    row = np.empty(len(sample), dtype=object)
    if method == 'mean':
        row[continuous] = neighbor_group[:, continuous].mean(axis=0)
    else:
        base = sample[continuous]
        # neighbor_group[0] is normally ``sample`` itself (distance 0), so
        # interpolate towards one of the *other* neighbours when available.
        if len(neighbor_group) > 1:
            partner = neighbor_group[np.random.randint(1, len(neighbor_group))][continuous]
        else:
            partner = base
        row[continuous] = base + np.random.rand() * (partner - base)
    for i in categorical:
        modes = pd.Series(neighbor_group[:, i]).mode()
        # mode() drops NaN; an all-NaN neighbourhood yields no mode.
        row[i] = modes.iloc[0] if len(modes) else np.nan
    return row
参考記事:http://shataowei.com/2017/12/01/python開発:特徴工程コードテンプレート-一/