『機械学習』第三章決定木学習ID 3アルゴリズムc++実現コード

4589 ワード

こんなにたくさんのstlを含むプログラムを書くのは久しぶりで、わざとset、map、vector、熟練した手を使っています.
記録しておきましょう.下手ですが.
3つのファイル:
テストデータ:data.txt
D1    Sunny        Hot    High        Weak    No
D2    Sunny        Hot    High        Strong    No
D3    Overcast    Hot    High        Weak    Yes
D4    Rain        Mild    High        Weak    Yes
D5    Rain        Cool    Normal        Weak    Yes
D6    Rain        Cool    Normal        Strong    No
D7    Overcast    Cool    Normal        Strong    Yes
D8    Sunny        Mild    High        Weak    No
D9    Sunny        Cool    Normal        Weak    Yes
D10    Rain        Mild    Normal        Weak    Yes
D11    Sunny        Mild    Normal        Strong    Yes
D12    Overcast    Mild    High        Strong    Yes
D13    Overcast    Hot    Normal        Weak    Yes
D14    Rain        Mild    High        Strong    No

プログラムヘッダファイル:id 3.h
#ifndef ID3_H
#define ID3_H
#include
#include
#include
#include
#include
#include
using namespace std;
const int DataRow=14;
const int DataColumn=6;
struct Node
{
	double value;//    yes   。
	int attrid;
	Node * parentNode;
	vector childNode;
};
#endif

プログラムソースファイルid 3.cpp
#include "id3.h"

string DataTable[DataRow][DataColumn];
map str2int;
set S;
set Attributes;
string attrName[DataColumn]={"Day","Outlook","Temperature","Humidity","Wind","PlayTennis"};
string attrValue[DataColumn][DataRow]=
{
	{},//D1,D2       
	{"Sunny","Overcast","Rain"},
	{"Hot","Mild","Cool"},
	{"High","Normal"},
	{"Weak","Strong"},
	{"No","Yes"}
};
int attrCount[DataColumn]={14,3,3,2,2,2};
double lg2(double n)
{
	return log(n)/log(2);
}
void Init()
{
	ifstream fin("data.txt");
	for(int i=0;i<14;i++)
	{
	  for(int j=0;j<6;j++)
	  {
		  fin>>DataTable[i][j];
	  }
	}
	fin.close();
	for(int i=1;i<=5;i++)
	{
		str2int[attrName[i]]=i;
		for(int j=0;j &s)
{
	double yes=0,no=0,sum=s.size(),ans=0;
	for(set::iterator it=s.begin();it!=s.end();it++)
	{
		string s=DataTable[*it][str2int["PlayTennis"]];
		if(s=="Yes")
		  yes++;
		else
		  no++;
	}
	if(no==0||yes==0)
	  return ans=0;
	ans=-yes/sum*lg2(yes/sum)-no/sum*lg2(no/sum);
	return ans;
}
double Gain(const set & example,int attrid)
{
	int attrcount=attrCount[attrid];
	double ans=Entropy(example);
	double sum=example.size();
	set * pset=new set[attrcount];
	for(set::iterator it=example.begin();it!=example.end();it++)
	{
		pset[str2int[DataTable[*it][attrid]]].insert(*it);
	}
	for(int i=0;i & example,const set & attr)
{
	double mx=0;
	int k=-1;
	for(set::iterator i=attr.begin();i!=attr.end();i++)
	{
		double ret=Gain(example,*i);
		if(ret>mx)
		{
			mx=ret;
			k=*i;
		}
	}
	if(k==-1)
	  cout< example,set & attributes,Node * parent)
{
	Node *now=new Node;//     。
	now->parentNode=parent;
	if(attributes.empty())//           ,   ,   。
	  return now;

	/*
	 *     example,                         
	 *          childNode  。
	 */
	int yes=0,no=0,sum=example.size();
	for(set::iterator it=example.begin();it!=example.end();it++)
	{
		string s=DataTable[*it][str2int["PlayTennis"]];
		if(s=="Yes")
		  yes++;
		else
		  no++;
	}
	if(yes==sum||yes==0)
	{
		now->value=yes/sum;
		return now;
	}
	

	/*                 attributes     */
	int bestattrid=FindBestAttribute(example,attributes);
	now->attrid=bestattrid;
	attributes.erase(attributes.find(bestattrid));
	
	/* exmple                  ,          */
	vector< set > child=vector< set >(attrCount[bestattrid]);
	for(set::iterator i=example.begin();i!=example.end();i++)
	{
		int id=str2int[DataTable[*i][bestattrid]];
		child[id].insert(*i);
	}
	for(int i=0;ichildNode.push_back(ret);
	}
	return now;
}

int main()
{
	Init();
	Node * Root=Id3_solution(S,Attributes,NULL);
	return 0;
}