pytorchを解決するRuntimeError:CUDA out of memory.そしてlossが超大きな原因になった

9219 ワード

最初のepochの訓練を終えて最初のepochの検証に移行したとき、走って間もなくpytorchはメモリ不足を報告しました.次に、pytorchを使用したマルチGPUパラレルソリューションを使用して、インターネットでいくつかのソリューションを調べました.
    if t.cuda.device_count() > 1:
         model = nn.DataParallel(model)
    if opt.use_gpu: model.cuda()

しかし、その結果、batchsizeの検証をいくつか多く走った.間違いを報告し続けます.次にpytorch forumで与えられたスキームを試みた.

        for ii, data in enumerate(dataloader):
            input, label = data
            val_input1 = Variable(input, volatile=True)
            val_input = val_input1.float()
            val_label = Variable(label.type(t.LongTensor), volatile=True)
            if opt.use_gpu:
                val_input = val_input.cuda()
                val_label = val_label.cuda()
            score = model(val_input)
            loss = criterion(score,val_label)
            if ii == 0:
                loss_mean = loss.item() #    .item()  python number
            loss_mean = 0.1 * loss.item() + 0.9 * loss_mean
            '''
              torch.cuda.empty_cache()        
            '''
            t.cuda.empty_cache()        #            

しかし、発見はまだだめで、爆発は爆発を続けなければならない.しかし、私が使っているのはP 100のグラフィックスカードですね.2枚が並行して32 Gのディスプレイメモリです.なんと2枚とも揚げてしまいました.最終的に解決策が見つかりました.
'''
  with torch.no_grad,              。         model.eval    pytorch       val      。
'''
    with t.no_grad():
        for ii, data in enumerate(dataloader):
            input, label = data
            val_input1 = Variable(input, volatile=True)
            val_input = val_input1.float()
            val_label = Variable(label.type(t.LongTensor), volatile=True)
            if opt.use_gpu:
                val_input = val_input.cuda()
                val_label = val_label.cuda()
            score = model(val_input)
            loss = criterion(score,val_label)
            if ii == 0:
                loss_mean = loss.item() #    .item()  python number
            loss_mean = 0.1 * loss.item() + 0.9 * loss_mean
            t.cuda.empty_cache()        #            

行うno_grad後、やっとepoch全体を走ることに成功した.しかし、プログラムの中で私が出力したlossは100000に達した.後で尋ねたところ、あなたのtargetがlabelではないとき.lossをあなたのH*Wで割る必要があります.具体的には、
loss = criterion(score,target) / (128 * 128)

具体的な原因は数学的に完全に理解されていない.よく理解してから補足する