gluonネットワークの可視化


1. print(net) を利用する。ネットワーク Net の各層の具体的なパラメータが表示されます。2. mx.viz.plot_network(...).view() を利用する。こちらはネットワーク構造を PDF の図として出力できますが、plot_network は Symbol を受け取るため、ネットワークをシンボリック(sym)な形に変換する必要があります。では、どうすればよいでしょうか?
例を挙げます。
import mxnet as mx
from mxnet.gluon import nn


class Net(nn.HybridBlock):
    """Two-stream network applied to two (rgb, tdf) input pairs.

    Each pair is processed by ``CNNBlock``: the ``rgb`` input goes through a
    Conv2D -> MaxPool2D -> Conv2D stack, the ``tdf`` input through a
    Conv3D -> MaxPool3D -> Conv3D stack, and both results are flattened,
    concatenated and passed through three dense "fusion" layers.

    ``hybrid_forward`` runs ``CNNBlock`` on both pairs using the very same
    layer objects (so the weights are shared between the pairs), concatenates
    the two feature vectors and maps them to a 2-unit output.
    """

    def __init__(self, **kwargs):
        super(Net, self).__init__(**kwargs)
        with self.name_scope():
            # 2-D convolution stream (used for the ``rgb`` inputs).
            self.rgb_conv1 = nn.Conv2D(channels=128, kernel_size=3, strides=1, padding=1, activation='relu')
            self.rgb_conv2 = nn.Conv2D(channels=128, kernel_size=3, strides=1, padding=1, activation='relu')

            # 3-D convolution stream (used for the ``tdf`` inputs).
            self.tdf_conv1 = nn.Conv3D(channels=128, kernel_size=3, strides=1, padding=1, activation='relu')
            self.tdf_conv2 = nn.Conv3D(channels=128, kernel_size=3, strides=1, padding=1, activation='relu')

            # Dense layers fusing the two flattened streams.
            self.dense_fusion1 = nn.Dense(1024, activation='relu')
            self.dense_fusion2 = nn.Dense(2048, activation='relu')
            self.dense_fusion3 = nn.Dense(512, activation='relu')

            # Prediction head applied to the concatenated pair features.
            self.dense_prediction1 = nn.Dense(2048, activation='relu')
            self.dense_prediction2 = nn.Dense(2)

            self.maxpool2D = nn.MaxPool2D(pool_size=3, strides=2)
            self.maxpool3D = nn.MaxPool3D(pool_size=3, strides=2)

            # Created once here (instead of on every forward call) so that it
            # is registered as a child block and named under this scope.
            self.flatten = nn.Flatten()

    def CNNBlock(self, F, rgb, tdf):
        """Compute the fused feature vector for one (rgb, tdf) pair.

        Parameters
        ----------
        F : module
            Operator namespace used for ``concat`` (``mx.nd`` or ``mx.sym``).
        rgb : NDArray or Symbol
            Input to the 2-D convolution stream.
        tdf : NDArray or Symbol
            Input to the 3-D convolution stream.

        Returns
        -------
        NDArray or Symbol
            Output of ``dense_fusion3`` (512 units).
        """
        rgb_conv = self.rgb_conv1(rgb)
        rgb_conv = self.maxpool2D(rgb_conv)
        rgb_conv = self.rgb_conv2(rgb_conv)

        tdf_conv = self.tdf_conv1(tdf)
        tdf_conv = self.maxpool3D(tdf_conv)
        tdf_conv = self.tdf_conv2(tdf_conv)

        rgb_conv = self.flatten(rgb_conv)
        tdf_conv = self.flatten(tdf_conv)

        # Join the two flattened streams along the feature axis.
        fc = F.concat(rgb_conv, tdf_conv, dim=1)
        fc = self.dense_fusion1(fc)
        fc = self.dense_fusion2(fc)
        fc = self.dense_fusion3(fc)

        return fc

    def hybrid_forward(self, F, rgb1, tdf1, rgb2, tdf2):
        """Run both (rgb, tdf) pairs through ``CNNBlock`` and predict.

        The two 512-unit feature vectors are concatenated along ``dim=1`` and
        passed through ``dense_prediction1`` and ``dense_prediction2``; the
        result has 2 output units.
        """
        out1 = self.CNNBlock(F, rgb1, tdf1)
        out2 = self.CNNBlock(F, rgb2, tdf2)

        out = F.concat(out1, out2, dim=1)
        out = self.dense_prediction2(self.dense_prediction1(out))
        return out

    def getFeature(self, img, depth):
        """Return the fused ``CNNBlock`` feature for a single (img, depth) pair.

        ``CNNBlock`` needs the operator namespace ``F`` as its first argument;
        it is chosen here from the type of ``img`` so that the method works
        with both Symbol and NDArray inputs.
        """
        F = mx.sym if isinstance(img, mx.sym.Symbol) else mx.nd
        return self.CNNBlock(F, img, depth)

net = Net()

# Option 1: print the block to get its textual, per-layer description.
print(net)

# Option 2: call the block on symbolic placeholders (one per forward input)
# to obtain a Symbol, then hand that Symbol to plot_network and open the
# rendered graph in a viewer.
_input_names = ("data1", "data2", "data3", "data4")
_graph = net(*[mx.sym.var(_name) for _name in _input_names])
mx.viz.plot_network(_graph).view()