『機械学習』第三章決定木学習ID 3アルゴリズムc++実現コード
4589 ワード
こんなにたくさんのstlを含むプログラムを書くのは久しぶりで、わざとset、map、vector、熟練した手を使っています.
記録しておきましょう.下手ですが.
3つのファイル:
テストデータ:data.txt
プログラムヘッダファイル:id 3.h
プログラムソースファイルid 3.cpp
記録しておきましょう.下手ですが.
3つのファイル:
テストデータ:data.txt
D1 Sunny Hot High Weak No
D2 Sunny Hot High Strong No
D3 Overcast Hot High Weak Yes
D4 Rain Mild High Weak Yes
D5 Rain Cool Normal Weak Yes
D6 Rain Cool Normal Strong No
D7 Overcast Cool Normal Strong Yes
D8 Sunny Mild High Weak No
D9 Sunny Cool Normal Weak Yes
D10 Rain Mild Normal Weak Yes
D11 Sunny Mild Normal Strong Yes
D12 Overcast Mild High Strong Yes
D13 Overcast Hot Normal Weak Yes
D14 Rain Mild High Strong No
プログラムヘッダファイル:id 3.h
#ifndef ID3_H
#define ID3_H
#include
#include
#include
#include
プログラムソースファイルid 3.cpp
#include "id3.h"
string DataTable[DataRow][DataColumn];
map str2int;
set S;
set Attributes;
string attrName[DataColumn]={"Day","Outlook","Temperature","Humidity","Wind","PlayTennis"};
string attrValue[DataColumn][DataRow]=
{
{},//D1,D2
{"Sunny","Overcast","Rain"},
{"Hot","Mild","Cool"},
{"High","Normal"},
{"Weak","Strong"},
{"No","Yes"}
};
int attrCount[DataColumn]={14,3,3,2,2,2};
double lg2(double n)
{
return log(n)/log(2);
}
void Init()
{
ifstream fin("data.txt");
for(int i=0;i<14;i++)
{
for(int j=0;j<6;j++)
{
fin>>DataTable[i][j];
}
}
fin.close();
for(int i=1;i<=5;i++)
{
str2int[attrName[i]]=i;
for(int j=0;j &s)
{
double yes=0,no=0,sum=s.size(),ans=0;
for(set::iterator it=s.begin();it!=s.end();it++)
{
string s=DataTable[*it][str2int["PlayTennis"]];
if(s=="Yes")
yes++;
else
no++;
}
if(no==0||yes==0)
return ans=0;
ans=-yes/sum*lg2(yes/sum)-no/sum*lg2(no/sum);
return ans;
}
double Gain(const set & example,int attrid)
{
int attrcount=attrCount[attrid];
double ans=Entropy(example);
double sum=example.size();
set * pset=new set[attrcount];
for(set::iterator it=example.begin();it!=example.end();it++)
{
pset[str2int[DataTable[*it][attrid]]].insert(*it);
}
for(int i=0;i & example,const set & attr)
{
double mx=0;
int k=-1;
for(set::iterator i=attr.begin();i!=attr.end();i++)
{
double ret=Gain(example,*i);
if(ret>mx)
{
mx=ret;
k=*i;
}
}
if(k==-1)
cout< example,set & attributes,Node * parent)
{
Node *now=new Node;// 。
now->parentNode=parent;
if(attributes.empty())// , , 。
return now;
/*
* example,
* childNode 。
*/
int yes=0,no=0,sum=example.size();
for(set::iterator it=example.begin();it!=example.end();it++)
{
string s=DataTable[*it][str2int["PlayTennis"]];
if(s=="Yes")
yes++;
else
no++;
}
if(yes==sum||yes==0)
{
now->value=yes/sum;
return now;
}
/* attributes */
int bestattrid=FindBestAttribute(example,attributes);
now->attrid=bestattrid;
attributes.erase(attributes.find(bestattrid));
/* exmple , */
vector< set > child=vector< set >(attrCount[bestattrid]);
for(set::iterator i=example.begin();i!=example.end();i++)
{
int id=str2int[DataTable[*i][bestattrid]];
child[id].insert(*i);
}
for(int i=0;ichildNode.push_back(ret);
}
return now;
}
int main()
{
Init();
Node * Root=Id3_solution(S,Attributes,NULL);
return 0;
}