pytorchトレーニングモデルの注意事項
1.torch tensorでGPU上で演算してデータセットを生成し、データ生成を加速するデータセットがオンライン生成を必要とする場合(すなわちdataloaderで計算してfeatureとlabelを生成する)、データ量が比較的大きく、マトリクス演算に関連する場合、torch tensorで計算することができる.マトリックスをGPU上に置くと計算が速い. datasetを構築する際にcuda tensorで計算する場合、dataloaderを作成してマルチスレッドでデータをロードする場合は、import multiprocessing as mp,mpを加えることに注意してください.set_start_method(‘spawn’)ではないとマルチスレッドロードに問題が発生します.
2.入力データセットにnanなどの汚れたデータがあるかどうかを確認することに注意してください.これはnanを訓練するのに簡単です.データセットがnumpy形式である、np.any(np.isnan(x))は、nanがあるか否かを判断する.すなわち、 データセットならtorch.tensorフォーマットはtorch.any(torch.isnan(x))はnan があるかどうかを判断する
3.訓練lossにnanが現れる可能性のある原因
1)学習率が高すぎる,2)loss関数が不適切である,3)データそのものが存在するかどうか,Nan 4)targetそのものがloss関数で計算できるはずである.
注意点:
データセット自体の計算でnanを生成する torch.acos(x)の場合、xの値が-1または1に近いとnanの現象が発生し、ここではtorchなどのxの範囲制限を推奨する.clip(x, -1+1e-6, 1-1e-6) の計算における除算は、分母が0であるかどうかに注意しなければならない.分母が0であるとnanの現象も発生し,1/(x+1 e-6) のように極小数を加えることで解決できる.計算に求逆演算がある場合、torch.inv(x)マトリクスが可逆的(満ランク)であることを確認し、xが満ランクでないことを心配する場合は対角アレイx+torchを加えることができる.eye(x.shape[-1]) 先にacosを求めることに出会って、それから逆数を求めて、それから逆を求めて、xに対して以下の操作を行うことができます:x=torch.clip(x, -1+1e-6, 1-1e-6) , torch.acos(x)
target自体がloss関数に入力されると、lossの計算中にnanが生成され、すなわちtarget自体がloss関数によって計算されないは、例えば、クロスエントロピー損失などのloss関数を手動で書く、-torchに関する.log(x)は、ここでxが0にならないことを保証するため、-torchのような1 e-6を加えることができる.log(x+1e-6) 例えばsigmoid活性化関数のtargetは0 より大きいべきである.
4.クロスエントロピーを計算する場合、pred結果にmaskがある場合は、次のクロスエントロピー計算を行うことができます.
5.データが不均衡な場合はfocal lossまたはlossで重み調整が可能
Focal loss
6.勾配が累積した場合、累積回数で除算することを忘れないでください.そうしないと、勾配が大きすぎて訓練異常を引き起こすことになります.
学習pytorch
# cuda feature , dataloader , 。 pytorch 。
import multiprocessing as mp
mp.set_start_method('spawn')
2.入力データセットにnanなどの汚れたデータがあるかどうかを確認することに注意してください.これはnanを訓練するのに簡単です.
assert not np.any(np.isnan(x)),'nan exists!'
assert not torch.any(torch.isnan(x)),'nan exists!'
3.訓練lossにnanが現れる可能性のある原因
1)学習率が高すぎる,2)loss関数が不適切である,3)データそのものが存在するかどうか,Nan 4)targetそのものがloss関数で計算できるはずである.
注意点:
データセット自体の計算でnanを生成する
x = torch.clip(x, -1+1e-6, 1-1e-6)
y = torch.acos(x)
z = 1/(y+1e-6)
w = torch.inv(z)
target自体がloss関数に入力されると、lossの計算中にnanが生成され、すなわちtarget自体がloss関数によって計算されない
4.クロスエントロピーを計算する場合、pred結果にmaskがある場合は、次のクロスエントロピー計算を行うことができます.
"""
pytorch
mask 1: ,0:
- check_mask_valid(mask): mask 0, 0, , loss
- maskNLLLoss(pred, target, mask): mask , mask 0, tensor(0., requres_grad=True)
- calAcc(pred, target, mask): mask ACC, mask 0, 1。( )
"""
def check_mask_valid(mask):
# mask 0, 0, target
if torch.sum(mask.float()).item() == 0:
return False
else:
return True
def maskNLLLoss(pred, target, mask):
# loss tensor, train loss.backward()
if not check_mask_valid(mask):
return torch.tensor(0., requres_grad=True)# mask 0, ,loss 0, tensor grad, backward
target1hot = torch.nn.functional.one_hot(target.long(), pred.shape[1]).permuate(0,3,1,2).float()# B*C*H*W
crossEntropy = -torch.log((target1hot * pred).sum(dim=1) + 1e-6) * mask.float()# B*C*H*W
loss = crossEntropy.sum()/ mask.float().sum()
return loss
def calAcc(pred, target, mask):
# acc tensor , acc backward(),
if not check_mask_valid(mask):
return 1
pred_ = torch.argmax(pred, dim=1)
equal = torch.eq(pred_.float(), target.float()).float()
masked_equal = equal * mask.float()
acc = torch.sum(masked_equal) / torch.sum(mask.float())
return acc.item()
5.データが不均衡な場合はfocal lossまたはlossで重み調整が可能
Focal loss
6.勾配が累積した場合、累積回数で除算することを忘れないでください.そうしないと、勾配が大きすぎて訓練異常を引き起こすことになります.
"""
acc_freq, loss , loss/acc_freq, optimize.step()
if acc_freq > 0:#
loss = loss/acc_freq #
"""
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, SubsetRandomSampler
import torch.optim as optim
val_split = 0.2
acc_freq = 4
#
dataset = MyDataset()
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(val_split * dataset_size))
if shuffle_dataset:
np.random.seed(random_seed)
np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)
train_loader = DataLoader(dataset, batch_size=4, sampler=train_sampler, num_workers=2)
#
lr = 1e-4
weight_decay = 0.2
Model = Model()
optimizer = torch.optim.Adam(model.parameters(), weight_decay=weight_decay, lr=lr)
#
def train():
for epoch in range(total_epoch):
for batch, (features, labels, mask) in enmuerate(train_loader):
pred = Model(features) #
loss = loss_fn(pred, labels, mask) # loss
acc = cal_acc(pred, labels, mask) # acc
if acc_freq > 0:#
loss = loss/acc_freq #
loss.backward()#
if (batch+1) % acc_freq == 0:
optimizer.step() #
optimizer.zero_grad() # ( optimizer.step(), zero_grad())
# update loss and acc
# save model
学習pytorch