pytorchプロジェクトをmatlabに移行
久しぶりにブログを書いたので、パスワードが見つからないと思った.
最近暇を見つけて小さなプロジェクトに参加して、耻ずかしくて、3つの小さなことしかしませんでした.
1.PyTorchに基づいて一連の単一画像超分解ニューラルネットワークを訓練した
PyTorchに基づいて,2〜10からの超分解能係数をもつ一連の単一画像超分解能ニューラルネットワークを訓練した.このセクションの実装はpytorch公式repoのSRルーチンを参照し、トレーニングプログラムは`./train`フォルダ.このプロジェクトは、高効率サブピクセルボリューム[1]に基づいて空間解像度向上操作を行い、訓練速度が極めて速い.[1] ["Shi W, Caballero J, Huszar F, et al. Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network[J]. 2016:1874-1883.](https://arxiv.org/abs/1609.05158)
2.訓練したモデルの重み値をMATLABファイルに転送する.
単純で乱暴で,異常に直接で,対応するボリューム層の重み値をすべて抽出すればよい.
抽出するときはpytorchのVariable形式をTensorに変換し、CPUモードに変換してnumpy配列に変換することに注意してください.
この一連のプロセスを統合すると、
具体的には、次のようになります.
3.ネットワークのtestプロセスをMATLABプラットフォームに移植し、テストコードを作成した.
ボリューム層とpixelshuffle層をmatlabで書き直しました.
pixelshuffleレイヤを復元する際にトラブルに遭遇し、pytorchのテストコードを振り返った.
`https://github.com/pytorch/pytorch/blob/master/test/test_nn.py `
理の構想を整理して、MATLABコードに書き換えます:
4.完全なエンジニアリングgithubリンク.
https://github.com/JiJingYu/super-resolution-by-subpixel-convolution
モデルの重み値はmatlabの重み値として保存され、matlabで直接`demoを実行します.m`ファイルで検証可能
転載先:https://www.cnblogs.com/nwpuxuezha/p/7834344.html
最近暇を見つけて小さなプロジェクトに参加して、耻ずかしくて、3つの小さなことしかしませんでした.
1.PyTorchに基づいて一連の単一画像超分解ニューラルネットワークを訓練した
PyTorchに基づいて,2〜10からの超分解能係数をもつ一連の単一画像超分解能ニューラルネットワークを訓練した.このセクションの実装はpytorch公式repoのSRルーチンを参照し、トレーニングプログラムは`./train`フォルダ.このプロジェクトは、高効率サブピクセルボリューム[1]に基づいて空間解像度向上操作を行い、訓練速度が極めて速い.[1] ["Shi W, Caballero J, Huszar F, et al. Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network[J]. 2016:1874-1883.](https://arxiv.org/abs/1609.05158)
2.訓練したモデルの重み値をMATLABファイルに転送する.
単純で乱暴で,異常に直接で,対応するボリューム層の重み値をすべて抽出すればよい.
抽出するときはpytorchのVariable形式をTensorに変換し、CPUモードに変換してnumpy配列に変換することに注意してください.
この一連のプロセスを統合すると、
Var.data.cpu().numpy()
具体的には、次のようになります.
1 from __future__ import print_function
2
3 import torch
4 import numpy as np
5 import scipy.io as sio
6
7 for i in [2, 3, 4, 5, 6, 7, 8, 9, 10]:
8
9 model_name = 'model_upscale_{}_epoch_101.pth'.format(i)
10 model = torch.load(model_name)
11 print(model._modules)
12
13 weight = dict()
14 weight['conv1_w'] = model._modules['conv1']._parameters['weight'].data.cpu().numpy()
15 weight['conv2_w'] = model._modules['conv2']._parameters['weight'].data.cpu().numpy()
16 weight['conv3_w'] = model._modules['conv3']._parameters['weight'].data.cpu().numpy()
17 weight['conv4_w'] = model._modules['conv4']._parameters['weight'].data.cpu().numpy()
18
19 weight['conv1_b'] = model._modules['conv1']._parameters['bias'].data.cpu().numpy()
20 weight['conv2_b'] = model._modules['conv2']._parameters['bias'].data.cpu().numpy()
21 weight['conv3_b'] = model._modules['conv3']._parameters['bias'].data.cpu().numpy()
22 weight['conv4_b'] = model._modules['conv4']._parameters['bias'].data.cpu().numpy()
23
24 sio.savemat('model_upscale_{}.mat'.format(i), mdict=weight)
3.ネットワークのtestプロセスをMATLABプラットフォームに移植し、テストコードを作成した.
ボリューム層とpixelshuffle層をmatlabで書き直しました.
pixelshuffleレイヤを復元する際にトラブルに遭遇し、pytorchのテストコードを振り返った.
`https://github.com/pytorch/pytorch/blob/master/test/test_nn.py `
# https://github.com/pytorch/pytorch/blob/master/test/test_nn.py
def _verify_pixel_shuffle(self, input, output, upscale_factor):
for c in range(output.size(1)):
for h in range(output.size(2)):
for w in range(output.size(3)):
height_idx = h // upscale_factor
weight_idx = w // upscale_factor
channel_idx = (upscale_factor * (h % upscale_factor)) + (w % upscale_factor) + \
(c * upscale_factor ** 2)
self.assertEqual(output[:, c, h, w], input[:, channel_idx, height_idx, weight_idx])
理の構想を整理して、MATLABコードに書き換えます:
1 function [ outputs ] = PixelShuffle( inputs, upscale_factor )
2 % PixelShuffle :
3 %
4 % input : N, upscale_factor ** 2, H, W
5 % output : N, 1, H*upscale_factor, W*upscale_factor
6
7 [N, ~, H, W] = size(inputs);
8 H_out = H*upscale_factor;
9 W_out = W*upscale_factor;
10 outputs = zeros([N, 1, H_out, W_out]);
11 for i = 1:N
12 for h = 1: H_out
13 for w = 1:W_out
14 height_idx = floor(h / upscale_factor+0.5);
15 weight_idx = floor(w / upscale_factor+0.5);
16 channel_idx = (upscale_factor * mod(h-1, upscale_factor)) + mod(w-1, upscale_factor)+1;
17 outputs(i, 1, h, w) = inputs(i, channel_idx, height_idx, weight_idx);
18 end
19 end
20 end
21 end
4.完全なエンジニアリングgithubリンク.
https://github.com/JiJingYu/super-resolution-by-subpixel-convolution
モデルの重み値はmatlabの重み値として保存され、matlabで直接`demoを実行します.m`ファイルで検証可能
転載先:https://www.cnblogs.com/nwpuxuezha/p/7834344.html