train関数
1917 ワード
各エポックの訓練後に、検証データでモデルを評価する
def get_acc(output, label):
    """Return the fraction of samples in a batch whose predicted class equals ``label``.

    Args:
        output: score tensor. The predicted class is the argmax over
            dimension 1, and dimension 0 is treated as the batch size.
            (Presumably shaped ``(batch, num_classes)`` -- confirm at call sites.)
        label: tensor of target class indices, compared element-wise with the
            predictions.

    Returns:
        The number of matches divided by the batch size, as a Python float.

    Raises:
        ZeroDivisionError: if the batch is empty (``output.shape[0] == 0``).
    """
    total = output.shape[0]
    pred_label = output.argmax(dim=1)
    # .item() already converts the 0-d result to a Python number without
    # tracking gradients, so the legacy ``.data`` hop is unnecessary.
    correct = (pred_label == label).sum().item()
    return correct / total
def train(net, train_data, valid_data, num_epochs, optimizer, criterion, *, device=None):
    """Train ``net`` for ``num_epochs`` epochs, evaluating on ``valid_data`` after each one.

    After every epoch this prints the per-batch-averaged training loss and
    accuracy, the validation loss and accuracy when ``valid_data`` is given,
    and the epoch's wall-clock time. The total time is printed at the end.

    Args:
        net: the ``torch.nn.Module`` to train; it is moved to ``device``.
        train_data: iterable of ``(inputs, labels)`` batches that also supports
            ``len()`` (e.g. a ``DataLoader``).
        valid_data: same kind of iterable used for evaluation, or ``None`` to
            skip evaluation.
        num_epochs: number of passes over ``train_data``.
        optimizer: optimizer over ``net``'s parameters; stepped once per batch.
        criterion: loss function called as ``criterion(output, labels)``.
        device: device to run on. Defaults to CUDA when available, otherwise CPU.

    Note:
        Loss and accuracy are averaged per batch (the sum of batch means
        divided by the number of batches), so a smaller final batch weighs as
        much as a full one. PyTorch recommends moving a model to its device
        before constructing its optimizer; callers should do so as well.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = net.to(device)

    time0 = time.time()
    for epoch in range(num_epochs):
        time1 = time.time()
        train_loss, train_acc = _train_one_epoch(net, train_data, optimizer, criterion, device)
        if valid_data is not None:
            valid_loss, valid_acc = _evaluate(net, valid_data, criterion, device)
            elapsed = time.time() - time1
            print('epoch: {}, Train Loss: {},Train Acc: {},Valid Loss: {},Valid Acc: {},UseTime: {}s'.format(
                epoch,
                train_loss / len(train_data),
                train_acc / len(train_data),
                valid_loss / len(valid_data),
                valid_acc / len(valid_data),
                elapsed,
            ))
        else:
            elapsed = time.time() - time1
            print('epoch: {}, Train Loss: {},Train Acc: {},UseTime: {}s'.format(
                epoch,
                train_loss / len(train_data),
                train_acc / len(train_data),
                elapsed,
            ))
    print('TotalTime:{}s'.format(time.time() - time0))


def _train_one_epoch(net, data, optimizer, criterion, device):
    """Run one optimization pass over ``data``; return (summed batch losses, summed batch accuracies)."""
    net.train()
    total_loss = 0.0
    total_acc = 0.0
    for im, label in data:
        # Move the batch to the same device as the model (works on CPU-only hosts too).
        im = im.to(device)
        label = label.to(device)
        output = net(im)
        loss = criterion(output, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        total_acc += get_acc(output, label)
    return total_loss, total_acc


def _evaluate(net, data, criterion, device):
    """Score ``net`` on ``data`` without building autograd graphs; return (summed batch losses, summed batch accuracies)."""
    net.eval()
    total_loss = 0.0
    total_acc = 0.0
    # No gradients are needed for evaluation; disabling them saves memory and time.
    with torch.no_grad():
        for im, label in data:
            im = im.to(device)
            label = label.to(device)
            output = net(im)
            loss = criterion(output, label)
            total_loss += loss.item()
            total_acc += get_acc(output, label)
    return total_loss, total_acc
# Run 20 epochs; the test split is passed as the per-epoch validation set.
train(
    net,
    train_data,
    valid_data=test_data,
    num_epochs=20,
    optimizer=optimizer,
    criterion=criterion,
)