Fully Convolutional Network FCN-TensorFlow: A Detailed Code Analysis


FCN-TensorFlow full code on GitHub: https://github.com/EternityZY/FCN-TensorFlow.git
This post walks through all of the code and adds detailed comments.
Notes:
  • Download the VGG-19 model and the training set as the code requires; downloading them at run time is slow.
  • MODEL_URL = 'http://www.vlfeat.org/matconvnet/models/beta16/imagenet-vgg-verydeep-19.mat'
  • DATA_URL = 'http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip'
  • The code has been modified to run on TensorFlow 1.4 and above.
  • To train the model, just run python FCN.py.
  • Lower the learning rate to 1e-5 or smaller; otherwise the loss hovers around 3 (flags can be overridden on the command line, as shown below).
  • The debug flag can be set during training to record extra information such as activations, gradients, and variables.
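
Since tf.app.run() parses the flags defined at the top of FCN.py, these settings can be changed from the command line instead of editing the source. A typical training invocation (using only flags that are defined in the code below) might look like:

    python FCN.py --mode=train --learning_rate=1e-5 --batch_size=2 --debug=False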

  • FCN.py
    # coding=utf-8
    from __future__ import print_function
    import tensorflow as tf
    import numpy as np
    
    import TensorflowUtils as utils
    import read_MITSceneParsingData as scene_parsing
    import datetime
    import BatchDatsetReader as dataset
    from six.moves import xrange
    
    # command-line flags
    FLAGS = tf.flags.FLAGS
    tf.flags.DEFINE_integer("batch_size", "2", "batch size for training")
    tf.flags.DEFINE_string("logs_dir", "logs/", "path to logs directory")
    tf.flags.DEFINE_string("data_dir", "Data_zoo/MIT_SceneParsing/", "path to dataset")
    tf.flags.DEFINE_float("learning_rate", "1e-6", "Learning rate for Adam Optimizer")
    tf.flags.DEFINE_string("model_dir", "Model_zoo/", "Path to vgg model mat")
    tf.flags.DEFINE_bool('debug', "True", "Debug mode: True/ False")
    tf.flags.DEFINE_string('mode', "train", "Mode train/ test/ visualize")
    
    MODEL_URL = 'http://www.vlfeat.org/matconvnet/models/beta16/imagenet-vgg-verydeep-19.mat'
    
    MAX_ITERATION = 20000        # maximum number of training iterations
    NUM_OF_CLASSESS = 151                # number of classes (151 for ADE20K scene parsing)
    IMAGE_SIZE = 224                    # input image size: 224 x 224
    fine_tuning = False
    
    # Build the VGG-19 network; weights holds the pretrained parameters, image is the input tensor
    def vgg_net(weights, image):
        # VGG-19 layer names, in order
        layers = (
            'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',
    
            'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
    
            'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3',
            'relu3_3', 'conv3_4', 'relu3_4', 'pool3',
    
            'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3',
            'relu4_3', 'conv4_4', 'relu4_4', 'pool4',
    
            'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3',
            'relu5_3', 'conv5_4', 'relu5_4'
        )
    
        net = {}
        current = image     # running input to the next layer
        for i, name in enumerate(layers):
            kind = name[:4]
            if kind == 'conv':
                kernels, bias = weights[i][0][0][0][0]
                # matconvnet: weights are [width, height, in_channels, out_channels]
                # tensorflow: weights are [height, width, in_channels, out_channels]
                kernels = utils.get_variable(np.transpose(kernels, (1, 0, 2, 3)), name=name + "_w")     # conv1_1_w
                bias = utils.get_variable(bias.reshape(-1), name=name + "_b")       # conv1_1_b
                current = utils.conv2d_basic(current, kernels, bias)        # convolve and update current
            elif kind == 'relu':
                current = tf.nn.relu(current, name=name)    # relu1_1
                if FLAGS.debug:     # if the debug flag is set
                    utils.add_activation_summary(current)       # record an activation summary
            elif kind == 'pool':
                # All 5 pooling stages in VGG-19 have stride 2, each halving the feature size.
                # Only pool1-pool4 appear in this layer list (pool5 is applied in inference());
                # average pooling is used here in place of max pooling.
                # pool1 output: 1/2 of the input size
                # pool2 output: 1/4
                # pool3 output: 1/8
                # pool4 output: 1/16
                current = utils.avg_pool_2x2(current)
            net[name] = current     # store each layer's output in net for later reuse (skip connections)
    
        return net
    
    
    # Inference network; image is the input, keep_prob is the dropout keep probability
    def inference(image, keep_prob):
        """
        Semantic segmentation network definition
        :param image: input image. Should have values in range 0-255
        :param keep_prob:
        :return:
        """
        # Load the pretrained VGG-19 model
        print("setting up vgg initialized conv layers ...")
        # model_dir is Model_zoo/
        # MODEL_URL is the VGG-19 download address
        model_data = utils.get_model_data(FLAGS.model_dir, MODEL_URL)       # read the VGG-19 model data

        mean = model_data['normalization'][0][0][0]                         # normalization mean image
        mean_pixel = np.mean(mean, axis=(0, 1))                             # per-channel RGB mean

        weights = np.squeeze(model_data['layers'])                          # squeeze the singleton dimensions out of the VGG layer weights

        processed_image = utils.process_image(image, mean_pixel)            # subtract the mean pixel

        with tf.variable_scope("inference"):                                # variable scope "inference"
            image_net = vgg_net(weights, processed_image)                   # build the VGG net; image_net holds every layer's output
            conv_final_layer = image_net["conv5_3"]                         # last convolutional feature map
    
            pool5 = utils.max_pool_2x2(conv_final_layer)                    # pool5: feature map is now 1/32 of the input size
    
            W6 = utils.weight_variable([7, 7, 512, 4096], name="W6")        # layer 6 (fully-convolutional fc6): weights and bias
            b6 = utils.bias_variable([4096], name="b6")
            conv6 = utils.conv2d_basic(pool5, W6, b6)
            relu6 = tf.nn.relu(conv6, name="relu6")
            if FLAGS.debug:
                utils.add_activation_summary(relu6)
            relu_dropout6 = tf.nn.dropout(relu6, keep_prob=keep_prob)
    
            W7 = utils.weight_variable([1, 1, 4096, 4096], name="W7")       # layer 7: 1x1 convolution
            b7 = utils.bias_variable([4096], name="b7")
            conv7 = utils.conv2d_basic(relu_dropout6, W7, b7)
            relu7 = tf.nn.relu(conv7, name="relu7")
            if FLAGS.debug:
                utils.add_activation_summary(relu7)
            relu_dropout7 = tf.nn.dropout(relu7, keep_prob=keep_prob)
    
            W8 = utils.weight_variable([1, 1, 4096, NUM_OF_CLASSESS], name="W8")
            b8 = utils.bias_variable([NUM_OF_CLASSESS], name="b8")
            conv8 = utils.conv2d_basic(relu_dropout7, W8, b8)               # layer 8: 1x1 conv producing NUM_OF_CLASSESS (151) score channels
            # annotation_pred1 = tf.argmax(conv8, dimension=3, name="prediction1")
    
            # now to upscale to actual image size
            deconv_shape1 = image_net["pool4"].get_shape()                  # shape of pool4 (1/16 of the input), [b,h,w,c]
            # Transposed-conv weights are [H, W, out_channels, in_channels]; out matches pool4, in matches conv8
            # Upsample 2x: kernel_size = 4, stride = 2
            W_t1 = utils.weight_variable([4, 4, deconv_shape1[3].value, NUM_OF_CLASSESS], name="W_t1")
            b_t1 = utils.bias_variable([deconv_shape1[3].value], name="b_t1")
            # Transposed convolution on conv8, doubling its size so the output matches pool4
            conv_t1 = utils.conv2d_transpose_strided(conv8, W_t1, b_t1, output_shape=tf.shape(image_net["pool4"]))
            fuse_1 = tf.add(conv_t1, image_net["pool4"], name="fuse_1")     # first skip connection: element-wise add with pool4
    
            # pool3 is 1/8 of the input size
            deconv_shape2 = image_net["pool3"].get_shape()
            # Output channels match pool3, input channels match fuse_1 (same as pool4)
            W_t2 = utils.weight_variable([4, 4, deconv_shape2[3].value, deconv_shape1[3].value], name="W_t2")
            b_t2 = utils.bias_variable([deconv_shape2[3].value], name="b_t2")
            # Upsample fuse_1 by 2x so the output matches pool3
            conv_t2 = utils.conv2d_transpose_strided(fuse_1, W_t2, b_t2, output_shape=tf.shape(image_net["pool3"]))
            # Second skip connection: deconv(fuse_1) + pool3
            fuse_2 = tf.add(conv_t2, image_net["pool3"], name="fuse_2")
    
            shape = tf.shape(image)     # shape of the original input image
            # Final output shape: [b, H, W, NUM_OF_CLASSESS]
            deconv_shape3 = tf.stack([shape[0], shape[1], shape[2], NUM_OF_CLASSESS])
            # Upsample 8x: kernel_size = 16, stride = 8; out channels = NUM_OF_CLASSESS, in channels match pool3
            W_t3 = utils.weight_variable([16, 16, NUM_OF_CLASSESS, deconv_shape2[3].value], name="W_t3")
            b_t3 = utils.bias_variable([NUM_OF_CLASSESS], name="b_t3")
            # Upsample fuse_2 by 8x back to the full input resolution [b, H, W, NUM_OF_CLASSESS]
            conv_t3 = utils.conv2d_transpose_strided(fuse_2, W_t3, b_t3, output_shape=deconv_shape3, stride=8)

            # conv_t3 now has the same spatial size as the input image, with one score channel per class.
            # argmax over dimension 3 (the channel axis) picks, for every pixel, the class with the
            # highest score among the NUM_OF_CLASSESS channels, giving a prediction of shape [b,h,w]
            annotation_pred = tf.argmax(conv_t3, dimension=3, name="prediction")
        # Expand back to [b,h,w,c] with c=1 to match the annotation format; conv_t3 holds the raw per-class logits
        return tf.expand_dims(annotation_pred, dim=3), conv_t3
    
    
    def train(loss_val, var_list):
        """
    
        :param loss_val:      
        :param var_list:        
        :return:
        """
        optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
        grads = optimizer.compute_gradients(loss_val, var_list=var_list)
        if FLAGS.debug:
            # print(len(var_list))
            for grad, var in grads:
                utils.add_gradient_summary(grad, var)
        return optimizer.apply_gradients(grads)     # apply the gradients
    
    
    def main(argv=None):
        # dropout keep probability
        keep_probability = tf.placeholder(tf.float32, name="keep_probabilty")
        # input image placeholder
        image = tf.placeholder(tf.float32, shape=[None, IMAGE_SIZE, IMAGE_SIZE, 3], name="input_image")
        # ground-truth annotation placeholder
        annotation = tf.placeholder(tf.int32, shape=[None, IMAGE_SIZE, IMAGE_SIZE, 1], name="annotation")

        # Run inference on a batch: pred_annotation is [b,h,w,c=1], logits is [b,h,w,c=151]
        pred_annotation, logits = inference(image, keep_probability)
        tf.summary.image("input_image", image, max_outputs=2)
        tf.summary.image("ground_truth", tf.cast(annotation, tf.uint8), max_outputs=2)
        tf.summary.image("pred_annotation", tf.cast(pred_annotation, tf.uint8), max_outputs=2)
        # Sparse softmax cross-entropy between logits [b,h,w,c=151] and integer labels [b,h,w] (annotation with its channel axis squeezed away)
        loss = tf.reduce_mean((tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                                              labels=tf.squeeze(annotation, squeeze_dims=[3]),
                                                                              name="entropy")))
        tf.summary.scalar("entropy", loss)
    
        # collect all trainable variables
        trainable_var = tf.trainable_variables()
        if FLAGS.debug:
            for var in trainable_var:
                utils.add_to_regularization_and_summary(var)
    
        # build the training op over all trainable variables
        train_op = train(loss, trainable_var)
    
        print("Setting up summary op...")
        # merge all summaries
        summary_op = tf.summary.merge_all()
    
        print("Setting up image reader...")
        # data_dir = Data_zoo/MIT_SceneParsing/
        # training: [{'image': image path, 'annotation': label path, 'filename': name}, ...] (same structure for validation)
        train_records, valid_records = scene_parsing.read_dataset(FLAGS.data_dir)
        print(len(train_records))   # number of training / validation records
        print(len(valid_records))
    
        print("Setting up dataset reader")
        image_options = {'resize': True, 'resize_size': IMAGE_SIZE}
        if FLAGS.mode == 'train':
            # the training reader is only needed in train mode
            train_dataset_reader = dataset.BatchDatset(train_records, image_options)
        validation_dataset_reader = dataset.BatchDatset(valid_records, image_options)
    
        sess = tf.Session()
    
        print("Setting up Saver...")
        saver = tf.train.Saver()
        summary_writer = tf.summary.FileWriter(FLAGS.logs_dir, sess.graph)
    
        sess.run(tf.global_variables_initializer())
        # checkpoints live in logs/
        if fine_tuning:
            ckpt = tf.train.get_checkpoint_state(FLAGS.logs_dir)    # look for an existing checkpoint
            if ckpt and ckpt.model_checkpoint_path:                 # if one exists, restore the session from it
                saver.restore(sess, ckpt.model_checkpoint_path)
                print("Model restored...")
    
        if FLAGS.mode == "train":
            for itr in range(MAX_ITERATION):
                # fetch the next training batch
                train_images, train_annotations = train_dataset_reader.next_batch(FLAGS.batch_size)
                feed_dict = {image: train_images, annotation: train_annotations, keep_probability: 0.85}
    
                # run one training step
                sess.run(train_op, feed_dict=feed_dict)
    
                if itr % 10 == 0:
                    # every 10 iterations, log the training loss
                    train_loss, summary_str = sess.run([loss, summary_op], feed_dict=feed_dict)
                    print("Step: %d, Train_loss:%g" % (itr, train_loss))
                    summary_writer.add_summary(summary_str, itr)
    
                if itr % 500 == 0:
                    # every 500 iterations, evaluate on a validation batch
                    valid_images, valid_annotations = validation_dataset_reader.next_batch(FLAGS.batch_size)
                    valid_loss = sess.run(loss, feed_dict={image: valid_images, annotation: valid_annotations,
                                                           keep_probability: 1.0})
                    print("%s ---> Validation_loss: %g" % (datetime.datetime.now(), valid_loss))
                    # save a checkpoint
                    saver.save(sess, FLAGS.logs_dir + "model.ckpt", itr)
    
        elif FLAGS.mode == "visualize":
            # fetch a random validation batch
            valid_images, valid_annotations = validation_dataset_reader.get_random_batch(FLAGS.batch_size)
            # run the prediction on the validation batch
            pred = sess.run(pred_annotation, feed_dict={image: valid_images, annotation: valid_annotations,
                                                        keep_probability: 1.0})
            valid_annotations = np.squeeze(valid_annotations, axis=3)
            pred = np.squeeze(pred, axis=3)
    
            for itr in range(FLAGS.batch_size):
                utils.save_image(valid_images[itr].astype(np.uint8), FLAGS.logs_dir, name="inp_" + str(5+itr))
                utils.save_image(valid_annotations[itr].astype(np.uint8), FLAGS.logs_dir, name="gt_" + str(5+itr))
                utils.save_image(pred[itr].astype(np.uint8), FLAGS.logs_dir, name="pred_" + str(5+itr))
                print("Saved image: %d" % itr)
    
    
    if __name__ == "__main__":
        tf.app.run()
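
The upsampling path in inference() is easiest to follow as plain size arithmetic. The following standalone sketch (not part of FCN.py, just a sanity check) traces how the three transposed convolutions bring a 224x224 input back to full resolution:

    IMAGE_SIZE = 224
    pool3 = IMAGE_SIZE // 8     # 28, target of fuse_2
    pool4 = IMAGE_SIZE // 16    # 14, target of fuse_1
    pool5 = IMAGE_SIZE // 32    # 7; conv6/conv7/conv8 keep this size (SAME-padded convs)
    up1 = pool5 * 2             # 14 = conv_t1 (stride 2), matches pool4
    up2 = up1 * 2               # 28 = conv_t2 (stride 2), matches pool3
    up3 = up2 * 8               # 224 = conv_t3 (stride 8), matches the input image
    assert up3 == IMAGE_SIZE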
    
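The loss in main() pairs logits of shape [b,h,w,151] with integer labels of shape [b,h,w] (the annotation placeholder with its channel axis squeezed away). A minimal sketch on dummy data, just to confirm the expected shapes and the initial loss scale:

    import numpy as np
    import tensorflow as tf

    logits = tf.constant(np.random.randn(1, 4, 4, 151).astype(np.float32))
    labels = tf.constant(np.random.randint(0, 151, size=(1, 4, 4)).astype(np.int32))
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels))
    with tf.Session() as sess:
        print(sess.run(loss))   # a scalar; random logits give roughly log(151), about 5.0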

  • read_MITSceneParsingData.py
    # coding=utf-8
    __author__ = 'charlie'
    import numpy as np
    import os
    import random
    from six.moves import cPickle as pickle
    from tensorflow.python.platform import gfile
    import glob
    
    import TensorflowUtils as utils
    
    # DATA_URL = 'http://sceneparsing.csail.mit.edu/data/ADEChallengeData2016.zip'
    DATA_URL = 'http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip'
    
    
    def read_dataset(data_dir):
        # data_dir = Data_zoo/MIT_SceneParsing/
        pickle_filename = "MITSceneParsing.pickle"
        # pickle file path: Data_zoo/MIT_SceneParsing/MITSceneParsing.pickle
        pickle_filepath = os.path.join(data_dir, pickle_filename)
        if not os.path.exists(pickle_filepath):
            utils.maybe_download_and_extract(data_dir, DATA_URL, is_zipfile=True)       # download and unzip the dataset if needed
            SceneParsing_folder = os.path.splitext(DATA_URL.split("/")[-1])[0]          # ADEChallengeData2016
            # result = {'training':   [{'image': image path, 'annotation': label path, 'filename': name}, ...],
            #           'validation': [{'image': image path, 'annotation': label path, 'filename': name}, ...]}
            result = create_image_lists(os.path.join(data_dir, SceneParsing_folder))    # Data_zoo/MIT_SceneParsing/ADEChallengeData2016
            print("Pickling ...")      # serialize the record lists with pickle
            with open(pickle_filepath, 'wb') as f:
                pickle.dump(result, f, pickle.HIGHEST_PROTOCOL)
        else:
            print("Found pickle file!")
    
        with open(pickle_filepath, 'rb') as f:      # read the pickle file back
            result = pickle.load(f)                 # deserialize the record lists
            training_records = result['training']
            validation_records = result['validation']
            del result
        # training: [{'image': image path, 'annotation': label path, 'filename': name}, ...]
        return training_records, validation_records
    
    
    def create_image_lists(image_dir):
        """
    
        :param image_dir:   Data_zoo / MIT_SceneParsing / ADEChallengeData2016
        :return:
        """
        if not gfile.Exists(image_dir):
            print("Image directory '" + image_dir + "' not found.")
            return None
        directories = ['training', 'validation']
        image_list = {}     # one record list per split: {'training': [], 'validation': []}

        for directory in directories:       # iterate over the two splits
            file_list = []
            image_list[directory] = []
            # Data_zoo/MIT_SceneParsing/ADEChallengeData2016/images/training/*.jpg
            file_glob = os.path.join(image_dir, "images", directory, '*.' + 'jpg')
            # collect the full path of every jpg, e.g. Data_zoo/MIT_SceneParsing/ADEChallengeData2016/images/training/hi.jpg
            file_list.extend(glob.glob(file_glob))

            if not file_list:   # no images found
                print('No files found')
            else:
                for f in file_list:     # f is the full path of one image
                    # bare file name without directory or extension, e.g. hi
                    filename = os.path.splitext(f.split("/")[-1])[0]
                    # Data_zoo/MIT_SceneParsing/ADEChallengeData2016/annotations/training/*.png
                    annotation_file = os.path.join(image_dir, "annotations", directory, filename + '.png')
                    if os.path.exists(annotation_file):     # keep only images with a matching annotation
                        # image: image path, annotation: label path, filename: bare name
                        record = {'image': f, 'annotation': annotation_file, 'filename': filename}
                        # image_list = {'training':   [{'image': ..., 'annotation': ..., 'filename': ...}, ...],
                        #               'validation': [{'image': ..., 'annotation': ..., 'filename': ...}, ...]}
                        image_list[directory].append(record)
                    else:
                        print("Annotation file not found for %s - Skipping" % filename)
            # shuffle the records within each split
            random.shuffle(image_list[directory])
            no_of_images = len(image_list[directory])   # number of records in this split
            print('No. of %s files: %d' % (directory, no_of_images))
    
        return image_list
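
A quick usage sketch for this module (paths follow the FCN.py defaults; the first call downloads and unpacks ADEChallengeData2016, later calls hit the pickle cache):

    import read_MITSceneParsingData as scene_parsing

    train_records, valid_records = scene_parsing.read_dataset("Data_zoo/MIT_SceneParsing/")
    print(len(train_records), len(valid_records))
    print(train_records[0])   # {'image': '...jpg', 'annotation': '...png', 'filename': '...'}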
    

  • TensorflowUtils.py
    # coding=utf-8
    __author__ = 'Charlie'
    # Utils used with the tensorflow implementation
    import tensorflow as tf
    import numpy as np
    import scipy.misc as misc
    import os, sys
    from six.moves import urllib
    import tarfile
    import zipfile
    import scipy.io
    
    
    # Load the pretrained VGG model data
    def get_model_data(dir_path, model_url):
        # dir_path is Model_zoo/
        # model_url is the VGG-19 download address
        maybe_download_and_extract(dir_path, model_url)     # download the model first if it is not already present
        filename = model_url.split("/")[-1]                 # file name = last segment of the URL
        filepath = os.path.join(dir_path, filename)         # full path: dir_path/filename
        if not os.path.exists(filepath):                    # abort if the file is still missing
            raise IOError("VGG Model not found!")
        data = scipy.io.loadmat(filepath)                   # load the VGG .mat file with scipy.io
        return data
    
    
    def maybe_download_and_extract(dir_path, url_name, is_tarfile=False, is_zipfile=False):
        # dir_path is e.g. Model_zoo/
        # url_name is the download address (e.g. the VGG-19 URL)
        if not os.path.exists(dir_path):        # create the target directory if it does not exist
            os.makedirs(dir_path)
        filename = url_name.split('/')[-1]      # file name = last segment of the URL
        filepath = os.path.join(dir_path, filename)     # full local path = dir_path/filename
        if not os.path.exists(filepath):         # download only if the file is not already there
            def _progress(count, block_size, total_size):       # progress callback for urlretrieve
                sys.stdout.write(
                    '\r>> Downloading %s %.1f%%' % (filename, float(count * block_size) / float(total_size) * 100.0))
                sys.stdout.flush()

            filepath, _ = urllib.request.urlretrieve(url_name, filepath, reporthook=_progress)    # download the URL to filepath
            print()
            statinfo = os.stat(filepath)
            print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
            if is_tarfile:          # extract if it is a tar archive
                tarfile.open(filepath, 'r:gz').extractall(dir_path)
            elif is_zipfile:        # extract if it is a zip archive
                with zipfile.ZipFile(filepath) as zf:
                    zip_dir = zf.namelist()[0]
                    zf.extractall(dir_path)
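
A short usage sketch for the two helpers above (MODEL_URL is the address defined in FCN.py; the 'normalization' and 'layers' keys are the same ones read back in inference()):

    import numpy as np
    import TensorflowUtils as utils

    MODEL_URL = 'http://www.vlfeat.org/matconvnet/models/beta16/imagenet-vgg-verydeep-19.mat'
    model_data = utils.get_model_data("Model_zoo/", MODEL_URL)   # downloads on the first call
    mean_pixel = np.mean(model_data['normalization'][0][0][0], axis=(0, 1))
    weights = np.squeeze(model_data['layers'])
    print(mean_pixel)      # per-channel RGB mean used by process_image
    print(weights.shape)   # one entry per VGG-19 layer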
    
    

  • BatchDatsetReader.py
    # coding=utf-8
    """
    Code ideas from https://github.com/Newmu/dcgan and tensorflow mnist dataset reader
    """
    import numpy as np
    import scipy.misc as misc
    
    
    class BatchDatset:
        files = []
        images = []
        annotations = []
        image_options = {}
        batch_offset = 0
        epochs_completed = 0
    
        def __init__(self, records_list, image_options={}):
            """
            Initialize a generic file reader with batching for a list of files
            :param records_list: list of file records to read -
            sample record: {'image': f, 'annotation': annotation_file, 'filename': filename}
            :param image_options: A dictionary of options for modifying the output image
            Available options:
            resize = True/ False
            resize_size = #size of output image - does bilinear resize
            color=True/False
            """
            print("Initializing Batch Dataset Reader...")
            print(image_options)
            self.files = records_list       # list of file records
            self.image_options = image_options  # image options, e.g. resize to 224 x 224
            self._read_images()
    
        def _read_images(self):
            self.__channels = True
            # read every image listed in self.files;
            # grayscale images are expanded to 3 RGB channels inside _transform
            self.images = np.array([self._transform(filename['image']) for filename in self.files])
            self.__channels = False

            # read every annotation listed in self.files;
            # annotations are single-channel, so a channel axis is added explicitly
            self.annotations = np.array(
                [np.expand_dims(self._transform(filename['annotation']), axis=3) for filename in self.files])
            print(self.images.shape)
            print(self.annotations.shape)
    
        def _transform(self, filename):
            # read one image from disk
            image = misc.imread(filename)
            if self.__channels and len(image.shape) < 3:  # make sure images are of shape(h,w,3)
                # replicate a grayscale image into 3 identical channels, stacking on the
                # last axis so the result is (h, w, 3) rather than (3, h, w)
                image = np.stack([image] * 3, axis=2)

            if self.image_options.get("resize", False) and self.image_options["resize"]:
                resize_size = int(self.image_options["resize_size"])
                # resize to resize_size x resize_size with nearest-neighbor interpolation
                resize_image = misc.imresize(image,
                                             [resize_size, resize_size], interp='nearest')
            else:
                resize_image = image

            return np.array(resize_image)       # return the (possibly resized) image as an array
    
        def get_records(self):
            """
            Return all images and annotations.
            """
            return self.images, self.annotations

        def reset_batch_offset(self, offset=0):
            """
            Reset the batch offset.
            :param offset: new offset value
            """
            self.batch_offset = offset
    
        def next_batch(self, batch_size):
            # fetch the next batch
            start = self.batch_offset
            # advance the offset by batch_size for the next call
            self.batch_offset += batch_size
            # self.images.shape = (num_images, h, w, c)
            if self.batch_offset > self.images.shape[0]:      # past the end of the data: this epoch is finished
                # Finished epoch
                self.epochs_completed += 1      # increment the epoch counter
                print("****************** Epochs completed: " + str(self.epochs_completed) + "******************")
                # Shuffle the data
                perm = np.arange(self.images.shape[0])      # indices 0 .. num_images-1
                np.random.shuffle(perm)         # shuffle the indices
                self.images = self.images[perm]     # reorder images and annotations consistently
                self.annotations = self.annotations[perm]
                # Start next epoch
                start = 0           # restart from the beginning of the shuffled data
                self.batch_offset = batch_size  # offset after the first batch of the new epoch

            end = self.batch_offset             # the batch spans [start, end)
            return self.images[start:end], self.annotations[start:end]      # return the batch
    
        def get_random_batch(self, batch_size):
            # sample batch_size random indices (with replacement) and return those images and annotations
            indexes = np.random.randint(0, self.images.shape[0], size=[batch_size]).tolist()
            return self.images[indexes], self.annotations[indexes]
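
Finally, a usage sketch tying the readers together (records come from read_dataset() as above; the shapes assume the 224 x 224 resize used in FCN.py, and note that BatchDatset loads the whole split into memory up front):

    import BatchDatsetReader as dataset
    import read_MITSceneParsingData as scene_parsing

    train_records, _ = scene_parsing.read_dataset("Data_zoo/MIT_SceneParsing/")
    reader = dataset.BatchDatset(train_records, {'resize': True, 'resize_size': 224})
    images, annotations = reader.next_batch(2)
    print(images.shape)        # (2, 224, 224, 3)
    print(annotations.shape)   # (2, 224, 224, 1)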