Using the GPU in PyTorch

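  • Initialize the device
    import torch

    # opt holds the parsed command-line options (e.g. an argparse.Namespace)
    if torch.cuda.is_available():
        if opt.gpuid is None:  # no GPU id given: default to GPU 0
            opt.gpuid = 0
        opt.device = torch.device("cuda:%d" % opt.gpuid)
    else:
        opt.device = torch.device("cpu")
        opt.gpuid = -1  # -1 means "no GPU"
        print("CUDA is not available, falling back to CPU.")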
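The same device-selection logic can be wrapped in a small function when there is no `opt` object. This is a minimal sketch: the function name `select_device` and its `gpu_id` parameter are hypothetical. The range check against `torch.cuda.device_count()` is an addition that is not in the original snippet.

```python
import torch


def select_device(gpu_id=None):
    """Return a torch.device. `gpu_id` is a hypothetical parameter; None means GPU 0."""
    if torch.cuda.is_available():
        if gpu_id is None:
            gpu_id = 0
        # Added check (not in the original): reject ids that do not exist on this machine
        if not 0 <= gpu_id < torch.cuda.device_count():
            raise ValueError("invalid GPU id %d" % gpu_id)
        return torch.device("cuda:%d" % gpu_id)
    print("CUDA is not available, falling back to CPU.")
    return torch.device("cpu")


device = select_device()
print(device)  # e.g. cuda:0 on a GPU machine, cpu otherwise
```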
    
  • Move the criterion, the model, and the tensors to the GPU
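    # model = Seq2SeqModel(opt)
    model = model.to(opt.device)  # nn.Module.to() moves all parameters and buffers in place
    criterion = torch.nn.CrossEntropyLoss().to(opt.device)
    # src: a LongTensor
    src = src.to(opt.device)  # Tensor.to() returns a new tensor instead of moving src, so reassign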
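Here is a self-contained sketch of the same steps. It assumes a toy model in place of the post's `Seq2SeqModel`: the `nn.Sequential` model, the sizes, and the random `src`/`tgt` batches are all hypothetical. The point is that the model, the criterion, and every input tensor must be on the same device before the forward pass.

```python
import torch
import torch.nn as nn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-in for Seq2SeqModel: embeds token ids and
# predicts one of `vocab_size` classes for every position.
vocab_size, emb_dim = 10, 8
model = nn.Sequential(nn.Embedding(vocab_size, emb_dim), nn.Linear(emb_dim, vocab_size))
model = model.to(device)                      # moves parameters in place, returns the module
criterion = nn.CrossEntropyLoss().to(device)

# src and tgt: LongTensors of indices, created on the CPU
src = torch.randint(0, vocab_size, (4, 5))    # (batch, seq_len)
tgt = torch.randint(0, vocab_size, (4, 5))
src, tgt = src.to(device), tgt.to(device)     # Tensor.to() returns new tensors, so reassign

logits = model(src)                           # (batch, seq_len, vocab_size)
loss = criterion(logits.reshape(-1, vocab_size), tgt.reshape(-1))
print(loss.device, loss.item())
```

Note the asymmetry. `model.to(device)` moves the module in place, so reassigning it is harmless. `src.to(device)` leaves `src` where it was and returns a new tensor, so the result must be assigned. `CrossEntropyLoss` has no trainable parameters. Moving it mainly matters when it is given a class `weight` tensor, which it stores as a buffer.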
    
  • Copy data back from the GPU to the CPU
    src = src.cpu()
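A call to `.cpu()` is usually followed by a conversion to NumPy or to a Python number. The sketch below assumes the tensor came out of a computation that tracks gradients. The values are made up for illustration.

```python
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

x = torch.tensor([1.0, 2.0, 3.0], device=device, requires_grad=True)
y = x * 2                        # lives on `device` and is part of the autograd graph

# .numpy() only works on CPU tensors that do not require grad,
# so detach from the graph first, then copy to host memory.
y_np = y.detach().cpu().numpy()
print(y_np)                      # [2. 4. 6.]

# For a one-element tensor (e.g. a loss value), .item() returns a Python number.
total = y.sum().item()
print(total)                     # 12.0
```

Calling `.numpy()` directly on a CUDA tensor, or on one that requires grad, raises an error. For that reason, `detach().cpu()` is the usual pattern.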