TensorflowベースのUnetネットワーク実装


論文のURL:http://www.arxiv.org/pdf/1505.04597.pdf
(図1:TensorflowベースのUnetネットワーク実装)
U-Netはセグメンテーション分野の古典的な名作なので、ぜひ試してみてください。前置きはこのくらいにして、さっそくコードを示します。
import tensorflow as tf


def convolutional(input_data, filters_shape, trainable, name, downsample=False, activate=True, bn=True):
    """Build one 2-D convolution layer with optional batch norm and leaky ReLU.

    Args:
        input_data: Input tensor passed to ``tf.nn.conv2d``. Only axes 1 and 2
            are padded/strided, so a channels-last (NHWC) layout is assumed.
        filters_shape: Kernel shape
            ``[kernel_h, kernel_w, in_channels, out_channels]``.
        trainable: Forwarded as ``training`` to batch normalization. Despite
            its name it does NOT freeze the layer's variables; they are always
            created with ``trainable=True``.
        name: Variable scope under which the layer's variables are created.
        downsample: If true, pad the input symmetrically and convolve with
            stride 2; otherwise use stride 1 with ``'SAME'`` padding.
        activate: If truthy, apply ``tf.nn.leaky_relu`` with ``alpha=0.1``.
        bn: If true, apply batch normalization (no bias is created);
            otherwise add a zero-initialized bias.

    Returns:
        The layer's output tensor.

    Note:
        Per the ``tf.layers.batch_normalization`` documentation, the
        moving-average updates are registered in ``tf.GraphKeys.UPDATE_OPS``
        when ``training`` is true; the caller's train op must run them.
    """
    with tf.variable_scope(name):
        if downsample:
            # Explicit symmetric padding of k // 2 on each side, followed by a
            # 'VALID' stride-2 convolution, halves the spatial size (rounding
            # up) independently of how 'SAME' would distribute its padding.
            pad_h = filters_shape[0] // 2
            pad_w = filters_shape[1] // 2
            paddings = tf.constant([[0, 0], [pad_h, pad_h], [pad_w, pad_w], [0, 0]])
            input_data = tf.pad(input_data, paddings, 'CONSTANT')
            strides = (1, 2, 2, 1)
            padding = 'VALID'
        else:
            strides = (1, 1, 1, 1)
            padding = "SAME"

        # NOTE(review): the variable is always trainable; the `trainable`
        # argument only selects the batch-norm mode below.
        weight = tf.get_variable(name='weight', dtype=tf.float32, trainable=True,
                                 shape=filters_shape, initializer=tf.random_normal_initializer(stddev=0.01))
        conv = tf.nn.conv2d(input=input_data, filter=weight, strides=strides, padding=padding)

        if bn:
            conv = tf.layers.batch_normalization(conv, beta_initializer=tf.zeros_initializer(),
                                                 gamma_initializer=tf.ones_initializer(),
                                                 moving_mean_initializer=tf.zeros_initializer(),
                                                 moving_variance_initializer=tf.ones_initializer(), training=trainable)
        else:
            # Batch norm already provides an offset (beta), so a bias is only
            # needed when it is disabled.
            bias = tf.get_variable(name='bias', shape=[filters_shape[-1]], trainable=True,
                                   dtype=tf.float32, initializer=tf.constant_initializer(0.0))
            conv = tf.nn.bias_add(conv, bias)

        # Truthiness check: `activate is True` silently skipped the activation
        # for truthy non-bool flags such as 1 or numpy.bool_(True).
        if activate:
            conv = tf.nn.leaky_relu(conv, alpha=0.1)

    return conv


def upsample(input_data, name, method="deconv"):
    """Double the spatial resolution of a feature map.

    Args:
        input_data: Input tensor; axes 1 and 2 are treated as height and
            width. The ``"deconv"`` method additionally needs a statically
            known size for the last (channel) axis.
        name: Variable scope used by the ``"resize"`` method.
            NOTE(review): the ``"deconv"`` method ignores ``name`` and lets
            ``tf.layers`` pick the layer name; this is kept as-is so existing
            variable names (and checkpoints) do not change.
        method: ``"resize"`` for nearest-neighbour resizing (channel count
            unchanged) or ``"deconv"`` for a stride-2 transposed convolution
            that also halves the channel count.

    Returns:
        The upsampled tensor.

    Raises:
        ValueError: If ``method`` is not ``"resize"`` or ``"deconv"``.
    """
    # Validate explicitly: an `assert` is stripped under `python -O`, after
    # which an unknown method would surface as an UnboundLocalError.
    if method not in ("resize", "deconv"):
        raise ValueError("method must be 'resize' or 'deconv', got %r" % (method,))

    if method == "resize":
        with tf.variable_scope(name):
            input_shape = tf.shape(input_data)
            return tf.image.resize_nearest_neighbor(input_data, (input_shape[1] * 2, input_shape[2] * 2))

    # method == "deconv"
    # replace resize_nearest_neighbor with conv2d_transpose To support TensorRT optimization
    num_filters = input_data.shape.as_list()[-1]
    return tf.layers.conv2d_transpose(input_data, num_filters // 2, kernel_size=4, padding='same',
                                      strides=(2, 2), kernel_initializer=tf.random_normal_initializer())


def Unet(images, filters=8, name='unet', trainable=True):
    """Build a U-Net style encoder/decoder and return its single-channel output.

    The encoder has seven stages (``C1`` .. ``C7``). Stage 1 applies two
    stride-1 convolutions; every later stage applies a stride-2 convolution
    followed by two stride-1 convolutions, doubling the channel count. The
    decoder repeatedly upsamples with a transposed convolution, concatenates
    the matching encoder output and applies two convolutions.

    Args:
        images: Input image batch, channels-last. The channel count is read
            from the static shape; 3 is assumed when it is unknown.
            Height and width should be multiples of 64: there are six
            stride-2 stages, and the decoder's skip concatenations need the
            upsampled maps to match the encoder maps exactly.
        filters: Channel count of the first stage; stage ``k`` uses
            ``filters * 2 ** (k - 1)`` channels.
        name: Outer variable scope for the whole network.
        trainable: Forwarded to every ``convolutional`` layer, where it
            selects the batch-normalization mode. Pass ``False`` at inference
            time. Defaults to ``True``, the previously hard-coded value.

    Returns:
        The output of the final convolution: one channel, no batch norm and
        no activation.
    """
    with tf.variable_scope(name):
        # Derive the input depth instead of hard-coding RGB so that e.g.
        # grayscale inputs work; fall back to 3 when it is not known statically.
        in_channels = tf.convert_to_tensor(images).shape.as_list()[-1]
        if in_channels is None:
            in_channels = 3

        endpoints = {}
        conv = convolutional(images, [3, 3, in_channels, filters], trainable=trainable, name='conv1')
        conv = convolutional(conv, [3, 3, filters, filters], trainable=trainable, name='conv2')
        endpoints['C1'] = conv

        # Encoder stages 2..7. Scope names conv3..conv20 and endpoint keys
        # C2..C7 are kept exactly as before so variable names are unchanged.
        width = filters
        for stage in range(1, 7):
            first = 3 * stage
            conv = convolutional(conv, [3, 3, width, width], trainable=trainable,
                                 name='conv%d' % first, downsample=True)
            conv = convolutional(conv, [3, 3, width, width * 2], trainable=trainable,
                                 name='conv%d' % (first + 1))
            width *= 2
            conv = convolutional(conv, [3, 3, width, width], trainable=trainable,
                                 name='conv%d' % (first + 2))
            endpoints['C%d' % (stage + 1)] = conv

        # Decoder: walk back up from C7, fusing with C6 .. C1.
        for i in range(7, 1, -1):
            with tf.variable_scope('Ronghe%d' % i):
                # The deconv upsample halves the channels, so the concatenated
                # tensor has twice the channels of the skip connection.
                uplayer = upsample(conv, 'deconv%d' % (8 - i), method="deconv")
                concat = tf.concat([endpoints['C%d' % (i - 1)], uplayer], axis=-1)
                dim = concat.get_shape().as_list()[-1]
                conv = convolutional(concat, [3, 3, dim, dim // 2], trainable=trainable, name='conv1')
                conv = convolutional(conv, [3, 3, dim // 2, dim // 2], trainable=trainable, name='conv2')
        # `dim` still holds the channel count of the last concatenation.
        out = convolutional(conv, [3, 3, dim // 2, 1], trainable=trainable, name='out', activate=False, bn=False)

    return out