Pytochは分類器の正確度を計算することを実現します。(全分類とサブ分類)
分類器の平均精度計算:
correct = torch.zeros(1).squeeze().cuda()
total = torch.zeros(1).squeeze().cuda()
for i, (images, labels) in enumerate(train_loader):
images = Variable(images.cuda())
labels = Variable(labels.cuda())
output = model(images)
prediction = torch.argmax(output, 1)
correct += (prediction == labels).sum().float()
total += len(labels)
acc_str = 'Accuracy: %f'%((correct/total).cpu().detach().data.numpy())
分類器の各サブクラスの正確度の計算:
correct = list(0. for i in range(args.class_num))
total = list(0. for i in range(args.class_num))
for i, (images, labels) in enumerate(train_loader):
images = Variable(images.cuda())
labels = Variable(labels.cuda())
output = model(images)
prediction = torch.argmax(output, 1)
res = prediction == labels
for label_idx in range(len(labels)):
label_single = label[label_idx]
correct[label_single] += res[label_idx].item()
total[label_single] += 1
acc_str = 'Accuracy: %f'%(sum(correct)/sum(total))
for acc_idx in range(len(train_class_correct)):
try:
acc = correct[acc_idx]/total[acc_idx]
except:
acc = 0
finally:
acc_str += '\tclassID:%d\tacc:%f\t'%(acc_idx+1, acc)
以上のPytouchは分類器の正確度(総分類とサブ分類)を計算することを実現しました。小編集は皆さんに共有した内容の全部です。参考にしてもらいたいです。皆さんもよろしくお願いします。