【BPニューラルネットワーク】コード格納
16131 ワード
このプログラムはトレーニングセットに使用され、トレーニングセットを2つの部分に分割し、一部のトレーニング、一部のテストを行います.トレーニングセットのフォーマットは、答えを先に置いてからデータ(つまり28*28ピクチャ展開の784次元ベクトル)を入れるので、答えを先に読み込んでからデータを読み込んでください.注意:プログラムを実行する前に、データの最初の行(label,pixel 0,pixel 1…)を削除してください.そうしないとREになります.
#include
using namespace std;
const int A=784,B=28,C=10;
const double L=0.2;
class BP{
private:
const int IN,HN,ON;
double lambda;
bool isFirstTime;
struct neuron{
double I,O,theta;
};
vector InputNeurons;
vector HiddenNeurons;
vector OutputNeurons;
double WeightIH[A+1][B+1],WeightHO[B+1][C+1];//
double e[C+1];//e[i]=T[i]-OutputNeurons[i].O
double rand(const double &x,const double &y);// [x,y]
double f(const double &x);//Sigmoid
void setWeightRandomly();//
void FeedForward();//
// ,T ,
void BackPropagation(const vector<double> &T);
public:
BP(int a=0,int b=0,int c=0,double d=0.5):
IN(a),HN(b),ON(c),lambda(d),isFirstTime(true){
InputNeurons.resize(IN+1);// 1
HiddenNeurons.resize(HN+1);
OutputNeurons.resize(ON+1);
for(int i=1;i<=IN;++i)
InputNeurons[i]=(neuron){0,0,0};
for(int i=1;i<=HN;++i)
HiddenNeurons[i]=(neuron){0,0,0};
for(int i=1;i<=ON;++i)
OutputNeurons[i]=(neuron){0,0,0};
srand(time(0));//
}
//
void train(const vector<double> &data,const vector<double> &ans);
//
vector<double> test(const vector<double> &data);
// lambda
void setLambda(const double &x);
};
inline double BP::rand(const double &x,const double &y){
return (double)std::rand()*1.0/RAND_MAX*(y-x)+x;
}
inline double BP::f(const double &x){
return 1.0/(1+exp(-x));
}
inline void BP::setWeightRandomly(){// (-1,1)
int i,j;
for(i=1;i<=IN;++i)
for(j=1;j<=HN;++j)
WeightIH[i][j]=rand(-1,1);
for(i=1;i<=HN;++i)
for(j=1;j<=ON;++j)
WeightHO[i][j]=rand(-1,1);
for(i=1;i<=HN;++i)HiddenNeurons[i].theta=rand(0,1);
for(i=1;i<=ON;++i)OutputNeurons[i].theta=rand(0,1);
}
inline void BP::FeedForward(){//
int i,j;
for(j=1;j<=HN;++j){
neuron &p=HiddenNeurons[j];
for(i=1,p.I=0;i<=IN;++i)
p.I+=WeightIH[i][j]*InputNeurons[i].O;
p.O=f(p.I+=p.theta);
}
for(j=1;j<=ON;++j){
neuron &p=OutputNeurons[j];
for(i=1,p.I=0;i<=HN;++i)
p.I+=WeightHO[i][j]*HiddenNeurons[i].O;
p.O=f(p.I+=p.theta);
}
}
inline void BP::BackPropagation(const vector<double> &T){
int i,j,k;
for(i=1;i<=ON;++i)e[i]=T[i]-OutputNeurons[i].O;
for(k=1;k<=ON;++k)for(j=1;j<=HN;++j){
WeightHO[j][k]+=lambda*e[k]*HiddenNeurons[j].O;
}
for(k=1;k<=ON;++k) OutputNeurons[k].theta+=lambda*e[k];
for(j=1;j<=HN;++j){
double sum;
for(k=1,sum=0;k<=ON;++k)sum+=e[k]*WeightHO[j][k];
for(i=1;i<=IN;++i){
WeightIH[i][j]+=
lambda*HiddenNeurons[j].O*(1-HiddenNeurons[j].O)*InputNeurons[i].O*sum;
}
HiddenNeurons[j].theta+=lambda*HiddenNeurons[j].O*(1-HiddenNeurons[j].O)*sum;
}
}
inline void BP::train(const vector<double> &data,const vector<double> &ans){
int i;
for(i=1;i<=IN;++i)InputNeurons[i].O=data[i];
if(isFirstTime) setWeightRandomly(),isFirstTime=false;
FeedForward();
BackPropagation(ans);
}
inline vector<double> BP::test(const vector<double> &data){
int i;
for(i=1;i<=IN;++i)InputNeurons[i].O=data[i];//
FeedForward();
vector<double> ans;
ans.push_back(0);
for(i=1;i<=ON;++i) ans.push_back(OutputNeurons[i].O);
return ans;
}
inline void BP::setLambda(const double &x){lambda=x;}
inline int read(){
int x=0,ch=getchar();
while(!isdigit(ch)) ch=getchar();
while(isdigit(ch))x=x*10+ch-48,ch=getchar();
return x;
}
vector<double> create1(int pos){
vector<double> v(11,0);v[pos]=1;
return v;
}
int maxpos(const vector<double> &a){
double tmp=-1e9;
int i,pos;
for(i=1;iif(a[i]>tmp) tmp=a[i],pos=i;
return pos;
}
BP solver(A,B,C,L);
int main(){// !! BUG
int n,i,j,right;
vector<double> input,ans;
scanf("%d",&n);
for(i=1;i<=n;++i){
input.clear();input.push_back(0);
ans .clear();ans .push_back(0);
{int x;scanf("%d",&x);ans=create1(x);}
for(j=1;j<=A;++j){
double x;
scanf(",%lf",&x);
input.push_back(x/255.0);
}
solver.train(input,ans);
}
puts(" !");
scanf("%d",&n);
for(i=1,right=0;i<=n;++i){
input.clear();input.push_back(0);
ans .clear();ans .push_back(0);
int standardAns,calcAns;
scanf("%d",&standardAns);
for(j=1;j<=A;++j){double x;scanf(",%lf",&x);input.push_back(x);}
ans=solver.test(input);
calcAns=maxpos(ans);
if(calcAns==standardAns) right++;
printf("%d
",calcAns);
}
fclose(stdin);fclose(stdout);
return 0;
}