pytorchアヤメの分類を解決

7848 ワード

半年前にnumpyでアヤメの分類200行を書いた.各ステップの計算は手書きpythonでbpニューラルネットワークを構築します.アヤメの花の分類
今pytorchで簡単に書いて、pytorchの文法の解釈は前のpytorchを見て簡単なネットワークを構築してください
 1 import pandas as pd
 2 import torch.nn as nn
 3 import torch
 4 
 5 
 6 class MyNet(nn.Module):
 7     def __init__(self):
 8         super(MyNet, self).__init__()
 9         self.fc = nn.Sequential(
10             nn.Linear(4, 3),
11             nn.Sigmoid(),
12             nn.Linear(3, 3),
13             nn.Sigmoid(),
14             nn.Linear(3, 1),
15         )
16         self.mls = nn.MSELoss()
17         self.opt = torch.optim.Adam(params=self.parameters(), lr=0.001)
18 
19     def get_data(self):
20         inputs = []
21         labels = []
22         with open('flower.csv') as file:
23             df = pd.read_csv(file, header=None)
24             x = df.iloc[:, 0:4].values
25             y = df.iloc[:, 4].values
26             for i in range(len(x)):
27                 inputs.append(x[i])
28             for j in range(len(y)):
29                 a = []
30                 a.append(y[j])
31                 labels.append(a)
32 
33         return inputs, labels
34 
35     def forward(self, inputs):
36         out = self.fc(inputs)
37         return out
38 
39     def train(self, x, label):
40         out = self.forward(x)
41         loss = self.mls(out, label)
42         self.opt.zero_grad()
43         loss.backward()
44         self.opt.step()
45 
46     def test(self, x):
47         return self.fc(x)
48 
49 
50 if __name__ == '__main__':
51     net = MyNet()
52     inputs, labels = net.get_data()
53     for i in range(1000):
54         for index, input in enumerate(inputs):
55             #     .float()   ,           
56             input = torch.from_numpy(input).float()
57             label = torch.Tensor(labels[index])
58             net.train(input, label)
59     #       
60     c = torch.Tensor([[5.6, 2.7, 4.2, 1.3]])
61     print(net.test(c))

運転結果は0.5に近く正確で、単純にpytorchを練習すると、訓練セット、テストセットがありません.
1 tensor([[0.5392]], grad_fn=)

手書きを使わずに逆伝播と勾配が下がるのはどんなに幸せなことか~