train関数

1917 ワード

各エポックの訓練が終わるたびに検証データで評価する
def get_acc(output, label):
    """Return the fraction of samples in one batch that are classified correctly.

    Args:
        output: batch of per-class scores; the prediction for each row is the
            index of its largest value along dimension 1.
        label: target class index for each row of ``output``.

    Returns:
        Number of correct predictions divided by ``output.shape[0]``, as a float.
    """
    predicted = output.max(dim=1)[1]
    n_correct = (predicted == label).sum().item()
    return n_correct / output.shape[0]

def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
    """Train ``net`` for ``num_epochs`` epochs, optionally validating after each.

    After every epoch one summary line is printed with the accumulated training
    loss/accuracy divided by ``len(train_data)`` (the batch count when this is a
    DataLoader — assumed). When ``valid_data`` is given, the same validation
    figures are printed too, plus the wall-clock time of the epoch. The total
    elapsed time is printed at the end.

    Args:
        net: model to train. It is moved to the GPU when CUDA is available,
            otherwise it stays on the CPU.
        train_data: iterable of ``(inputs, labels)`` batches supporting ``len()``.
        valid_data: iterable shaped like ``train_data``, or ``None`` to skip
            validation.
        num_epochs: number of passes over ``train_data``.
        optimizer: optimizer stepping ``net``'s parameters.
        criterion: loss function called as ``criterion(output, label)``.
    """
    # Choose the device once so the model and every batch always live on the
    # same device; this also lets the function run on CPU-only machines.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = net.to(device)
    time0 = time.time()
    for epoch in range(num_epochs):
        train_loss = 0
        train_acc = 0
        net = net.train()
        time1 = time.time()
        for im, label in train_data:
            im = im.to(device)
            label = label.to(device)
            output = net(im)
            loss = criterion(output, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            train_acc += get_acc(output, label)
        if valid_data is not None:
            valid_loss = 0
            valid_acc = 0
            net.eval()
            # Validation never calls backward(), so skip building autograd graphs.
            with torch.no_grad():
                for im, label in valid_data:
                    im = im.to(device)
                    label = label.to(device)
                    output = net(im)
                    loss = criterion(output, label)
                    valid_loss += loss.item()
                    valid_acc += get_acc(output, label)
            time3 = time.time()
            print('epoch: {}, Train Loss: {},Train Acc: {},Valid Loss: {},Valid Acc: {},UseTime: {}s'.format(epoch,train_loss / len(train_data),train_acc / len(train_data),valid_loss / len(valid_data),valid_acc / len(valid_data),time3 - time1))
        else:
            time2 = time.time()
            print('epoch: {}, Train Loss: {},Train Acc: {},UseTime: {}s'.format(epoch, train_loss / len(train_data), train_acc / len(train_data),time2-time1))
    time5 = time.time()
    print('TotalTime:{}s'.format(time5-time0))
# Run 20 epochs of training, evaluating on test_data after each epoch.
train(
    net,
    train_data,
    valid_data=test_data,
    num_epochs=20,
    optimizer=optimizer,
    criterion=criterion,
)