PyTorch Android deployment demo: pitfalls when using my own custom-trained model


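This post records a small pitfall I hit when adding my own model to the GitHub demo project. The model is a VGG network with the number of output classes changed to 40. When the app tried to load it, it crashed with: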
2021-01-26 19:02:42.191 19212-19370/org.pytorch.demo E/AndroidRuntime: FATAL EXCEPTION: ModuleActivity
    Process: org.pytorch.demo, PID: 19212
    com.facebook.jni.CppException: 
    
    aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> (Tensor):
    Expected at most 12 arguments but found 13 positional arguments.
    :
    /home/xutengfei/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py(419): _conv_forward
    /home/xutengfei/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py(423): forward
    /home/xutengfei/.local/lib/python3.8/site-packages/torch/nn/modules/module.py(709): _slow_forward
    /home/xutengfei/.local/lib/python3.8/site-packages/torch/nn/modules/module.py(725): _call_impl
    /home/xutengfei/.local/lib/python3.8/site-packages/torch/nn/modules/container.py(117): forward
    /home/xutengfei/.local/lib/python3.8/site-packages/torch/nn/modules/module.py(709): _slow_forward
    /home/xutengfei/.local/lib/python3.8/site-packages/torch/nn/modules/module.py(725): _call_impl
    /home/xutengfei/garbage_classify/mycode/vgg.py(42): forward
    /home/xutengfei/.local/lib/python3.8/site-packages/torch/nn/modules/module.py(709): _slow_forward
    /home/xutengfei/.local/lib/python3.8/site-packages/torch/nn/modules/module.py(725): _call_impl
    /home/xutengfei/garbage_classify/mycode/myVGG.py(25): forward
    /home/xutengfei/.local/lib/python3.8/site-packages/torch/nn/modules/module.py(709): _slow_forward
    /home/xutengfei/.local/lib/python3.8/site-packages/torch/nn/modules/module.py(725): _call_impl
    /home/xutengfei/.local/lib/python3.8/site-packages/torch/jit/_trace.py(934): trace_module
    /home/xutengfei/.local/lib/python3.8/site-packages/torch/jit/_trace.py(733): trace
    /home/xutengfei/garbage_classify/mycode/deployment_script.py(32): <module>
    Serialized   File "code/__torch__/torch/nn/modules/conv.py", line 10
        input: Tensor) -> Tensor:
        _0 = self.bias
        input0 = torch._convolution(input, self.weight, _0, [1, 1], [1, 1], [1, 1], False, [0, 0], 1, False, False, True, True)
                 ~~~~~~~~~~~~~~~~~~ <--- HERE
        return input0
    
        at org.pytorch.NativePeer.initHybrid(Native Method)
        at org.pytorch.NativePeer.<init>(NativePeer.java:24)
        at org.pytorch.Module.load(Module.java:23)
        at org.pytorch.demo.vision.ImageClassificationActivity.analyzeImage(ImageClassificationActivity.java:166)
        at org.pytorch.demo.vision.ImageClassificationActivity.analyzeImage(ImageClassificationActivity.java:31)
        at org.pytorch.demo.vision.AbstractCameraXActivity.lambda$setupCameraX$2$AbstractCameraXActivity(AbstractCameraXActivity.java:90)
        at org.pytorch.demo.vision.-$$Lambda$AbstractCameraXActivity$t0OjLr-l_M0-_0_dUqVE4yqEYnE.analyze(Unknown Source:2)
        at androidx.camera.core.ImageAnalysisAbstractAnalyzer.analyzeImage(ImageAnalysisAbstractAnalyzer.java:57)
        at androidx.camera.core.ImageAnalysisNonBlockingAnalyzer$1.run(ImageAnalysisNonBlockingAnalyzer.java:135)
        at android.os.Handler.handleCallback(Handler.java:900)
        at android.os.Handler.dispatchMessage(Handler.java:103)
        at android.os.Looper.loop(Looper.java:219)
        at android.os.HandlerThread.run(HandlerThread.java:67)

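The key line in the error message is "Expected at most 12 arguments but found 13 positional arguments." Comparing the arguments one by one shows that there is one extra True. The schema known to the runtime on the device has 12 parameters:
aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> (Tensor)
while the serialized model calls it with 13 arguments:
input, self.weight, _0, [1, 1], [1, 1], [1, 1], False, [0, 0], 1, False, False, True, True
The first 12 line up with input through cudnn_enabled; the final True has no matching parameter.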
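The comparison above can also be done mechanically. Below is a minimal stdlib-only Python sketch that splits the schema and the serialized call at top-level commas (so the commas inside `[1, 1]` are not treated as separators) and reports the call arguments that have no matching parameter. The helper names `split_top_level`, `schema_params` and `extra_call_args` are hypothetical, not part of PyTorch, and the parser assumes a simple schema like the one above (no default values such as `int groups=1`).

```python
from typing import List


def split_top_level(s: str) -> List[str]:
    """Split on commas that are not nested inside [] or ()."""
    parts: List[str] = []
    depth = 0
    current: List[str] = []
    for ch in s:
        if ch in "[(":
            depth += 1
        elif ch in "])":
            depth -= 1
        if ch == "," and depth == 0:
            parts.append("".join(current).strip())
            current = []
        else:
            current.append(ch)
    tail = "".join(current).strip()
    if tail:
        parts.append(tail)
    return parts


def schema_params(schema: str) -> List[str]:
    """Return parameter names from a schema like 'aten::op(Type a, Type b) -> (T)'."""
    inner = schema[schema.index("(") + 1 : schema.index(") ->")]
    # Each parameter looks like 'Type name'; keep only the name.
    return [p.split()[-1] for p in split_top_level(inner)]


def extra_call_args(schema: str, call: str) -> List[str]:
    """Return the call arguments beyond the number of schema parameters."""
    inner = call[call.index("(") + 1 : call.rindex(")")]
    args = split_top_level(inner)
    return args[len(schema_params(schema)):]


# Schema and call copied from the error log above.
SCHEMA = (
    "aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, "
    "int[] padding, int[] dilation, bool transposed, int[] output_padding, "
    "int groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> (Tensor)"
)
CALL = (
    "torch._convolution(input, self.weight, _0, [1, 1], [1, 1], [1, 1], False, "
    "[0, 0], 1, False, False, True, True)"
)

if __name__ == "__main__":
    print(len(schema_params(SCHEMA)))       # number of parameters the runtime expects
    print(extra_call_args(SCHEMA, CALL))    # arguments with no matching parameter
```

Here the extra argument is most likely the `allow_tf32` flag that newer PyTorch releases added to `_convolution`, which an older runtime does not recognize; treat that identification as an assumption rather than something the log itself states.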

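Searching online for this error, I suspected a version mismatch, and sure enough I found the answer in a GitHub issue. The model was exported with a newer desktop PyTorch than the pytorch_android library bundled with the demo, so the runtime on the phone did not know about the extra argument. The fix is simply to update pytorch_android in build.gradle to the latest version (I used 1.7.0):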
dependencies {
    implementation 'org.pytorch:pytorch_android:1.7.0'
    implementation 'org.pytorch:pytorch_android_torchvision:1.7.0'
}
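To catch this kind of mismatch before building the app, one option is to compare the PyTorch version used for export (`torch.__version__` in the environment that runs the deployment script) with the pytorch_android version in build.gradle. The following stdlib-only Python sketch assumes the rule of thumb that the Android runtime should be at least as new as the exporting PyTorch; `parse_version`, `gradle_pytorch_android_version` and `runtime_is_new_enough` are hypothetical helpers, and the version strings in the example are illustrative only.

```python
import re
from typing import Optional, Tuple


def parse_version(v: str) -> Tuple[int, int, int]:
    """Turn '1.7.0', '1.7.1+cu101' or '1.8' into (major, minor, patch).

    Local suffixes such as '+cu101' and pre-release tags are ignored."""
    m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", v.strip())
    if m is None:
        raise ValueError(f"unrecognized version: {v!r}")
    major, minor, patch = m.groups()
    return int(major), int(minor), int(patch or 0)


def gradle_pytorch_android_version(line: str) -> Optional[str]:
    """Extract the pytorch_android version from a build.gradle dependency line.

    Returns None for other artifacts (e.g. pytorch_android_torchvision)."""
    m = re.search(r"org\.pytorch:pytorch_android:([0-9][\w.+-]*)", line)
    return m.group(1) if m else None


def runtime_is_new_enough(export_version: str, android_version: str) -> bool:
    """Rule-of-thumb check (an assumption, not an official compatibility rule):
    the Android runtime should not be older than the exporting PyTorch."""
    return parse_version(android_version) >= parse_version(export_version)


if __name__ == "__main__":
    gradle_line = "implementation 'org.pytorch:pytorch_android:1.7.0'"
    android = gradle_pytorch_android_version(gradle_line)
    export = "1.7.0"  # illustrative; in practice use torch.__version__
    if android is not None:
        print(android, runtime_is_new_enough(export, android))
```

The comparison drops local suffixes like `+cu101` on purpose, since the CUDA build tag of the desktop install does not matter to the mobile runtime.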