Pytorch読み出しパラメータエラーRuntimeError:cuda runtime error(10):invalid device ordinal

4936 ワード

久しぶりにブログを出しましたが、今日Pytorchのパラメータ読み込み中に珍しいバグに遭遇しました.RuntimeError: cuda runtime error (10) : invalid device ordinal at torch/csrc/cuda/Module.cpp:87
長い間解決策を探していなかったが、StackOverflowやpytorchのissuesでも似たような問題に遭遇した人はいなかった.最後は自分で出陣するしかなく、徐々にデバッグしていくうちに底から問題が見えてきたので、ここで他の人に助けてほしいことを記録します.
-
1.問題シーンサーバでモデルを訓練し、保存したパラメータをローカルマシンのloadに持って行って結果分析を行ったとき、このエラーが発生しました.Loadパラメータの時、cudaが見つからないことをヒントにしました.詳細ポイントのエラーメッセージは次のとおりです.
THCudaCheck FAIL file=torch/csrc/cuda/Module.cpp line=87 error=10 : invalid device ordinal
Traceback (most recent call last):
  File "/home/sw/Shin/Codes/DL4SS_Keras/Torch_multi/main_run_multi_selfSS_subeval.py", line 557, in 
    main()
  File "/home/sw/Shin/Codes/DL4SS_Keras/Torch_multi/main_run_multi_selfSS_subeval.py", line 482, in main
    mix_hidden_layer_3d.load_state_dict(torch.load('params/param_mix101_dbag1nosum_WSJ0_hidden3d_190'))
  File "/usr/local/lib/python2.7/dist-packages/torch/serialization.py", line 231, in load
    return _load(f, map_location, pickle_module)
  File "/usr/local/lib/python2.7/dist-packages/torch/serialization.py", line 379, in _load
    result = unpickler.load()
  File "/usr/local/lib/python2.7/dist-packages/torch/serialization.py", line 350, in persistent_load
    data_type(size), location)
  File "/usr/local/lib/python2.7/dist-packages/torch/serialization.py", line 85, in default_restore_location
    result = fn(storage, location)
  File "/usr/local/lib/python2.7/dist-packages/torch/serialization.py", line 67, in _cuda_deserialize
    return obj.cuda(device_id)
  File "/usr/local/lib/python2.7/dist-packages/torch/_utils.py", line 58, in _cuda
    with torch.cuda.device(device):
  File "/usr/local/lib/python2.7/dist-packages/torch/cuda/__init__.py", line 128, in __enter__
    torch._C._cuda_setDevice(self.idx)
RuntimeError: cuda runtime error (10) : invalid device ordinal at torch/csrc/cuda/Module.cpp:87

私のローカルマシンCUDAに問題があったのかと思っていましたが、以前保存していたパラメータを変えれば、意外にもこの問題はなく、普通にworkできます.分析すると、この読み込むパラメータ自体に問題があるはずです.
下部を見て、問題が見つかりました.元はPytorchでパラメータ保存時に元のパラメータ位置に関するlocationを登録します.例えば、サーバー上のGPU 1トレーニングでは、このlocationがGPU 1である可能性が高いです.デスクトップにGPUが1つしかない場合、つまりGPU 0の場合、このパラメータが持ち込まれたLocation情報はデスクトップと互換性がなく、cuda deviceが見つからないという問題が発生します.
2.ソリューションはLoadパラメータの時、あなたの現在のマシン上のGPU状態に基づいてmapを行います.例えば、従来のGPU 1は、torchにおいてGPU 0に変換することができる.load関数にはパラメータがあります.
load(f, map_location=None, pickle_module=pickle)

私の実験では、最終的にはこう言いました.
att_speech_layer.load_state_dict(torch.load('params—xxxxx',map_location={'cuda:1':'cuda:0'}))

解決しました.もちろんあなたのパソコンに合わせてこのmapを調整しなければなりません.