Deep Learning (72): TensorFlow Cluster Training
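This post is a complete, lightly commented script for between-graph distributed training in TensorFlow 1.x: a parameter-server/worker cluster is described with tf.train.ClusterSpec, every process starts a tf.train.Server for its own task, and each worker builds a ResNet under tf.train.replica_device_setter and trains it through a tf.train.Supervisor-managed session.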
# -*- coding: utf-8 -*-
# Notes: batch size is fed at run time through a placeholder, and the
# bootstrap ("boostrap") loss is only switched on after a warm-up phase
# (see the training loop below).
import os
import tensorflow as tf
from input_data import Data_layer
import net
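# Data_layer (input pipeline) and net (model definitions) are project-local modules.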
num_class = 2
input_height = 256
input_width = 256
crop_height = 224
crop_width = 224
learning_rate = 0.01
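# Scalar placeholders: batch size, train/eval mode, the bootstrap-loss switch
# and the dropout probability are all chosen at session.run() time via feed_dict.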
tf.set_random_seed(123)
batch_size = tf.placeholder(tf.int32, [], 'batch_size')
tf.add_to_collection('batch_size', batch_size)
is_training = tf.placeholder(tf.bool, [])
is_boostrap = tf.placeholder(tf.bool, [])
drop_prob = tf.placeholder(tf.float32, [])
tf.add_to_collection('is_training', is_training)
def load_save_model(sess, saver, model_path, is_save):
    # Restore the model from model_path when is_save is False; otherwise save to it.
    if is_save is False:
        print("***********restore model from %s***************" % model_path)
        saver.restore(sess, model_path)
    else:
        saver.save(sess, model_path)
def train_cluster(train_volume_data, valid_volume_data, model_path):
    # Cluster membership and this process's role come from command-line flags.
    tf.app.flags.DEFINE_string("ps_hosts", "localhost:2222",
                               "Comma-separated list of hostname:port pairs")
    tf.app.flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224",
                               "Comma-separated list of hostname:port pairs")
    tf.app.flags.DEFINE_string("job_name", "", "Either 'ps' or 'worker'")
    tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")
    tf.app.flags.DEFINE_string("volumes", "", "volumes info")
    FLAGS = tf.app.flags.FLAGS
    ps_hosts = FLAGS.ps_hosts.split(",")
    print("ps_hosts:", ps_hosts)
    worker_hosts = FLAGS.worker_hosts.split(",")
    print("worker_hosts:", worker_hosts)
    cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
    print("FLAGS.task_index:", FLAGS.task_index)
    # Create and start a server for the local task.
    server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
    if FLAGS.job_name == "ps":
        # Parameter servers only host variables; they block here until the job ends.
        server.join()
    elif FLAGS.job_name == "worker":
        # replica_device_setter pins variables to the ps tasks and ops to this worker.
        with tf.device(tf.train.replica_device_setter(
                worker_device="/job:worker/task:%d" % FLAGS.task_index,
                cluster=cluster)):
            input_data = Data_layer(train_volume_data, valid_volume_data,
                                    batch_size=batch_size,
                                    image_height=input_height, image_width=input_width,
                                    crop_height=crop_height, crop_width=crop_width)
            images, labels = input_data.get_next_batch(is_training, num_class)
            # net_worker = net.resnet(images, labels, num_class, 18, is_training, drop_prob)
            net_worker = net.resnet256(images, labels, num_class, is_training, is_boostrap)
            saver = tf.train.Saver()
            init = tf.global_variables_initializer()
        # Task 0 is the chief: it runs init_op and owns checkpointing.
        sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0), init_op=init,
                                 saver=saver, global_step=net_worker['global_step'])
        with sv.prepare_or_wait_for_session(server.target) as session:
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(session, coord=coord)
            # threads = sv.start_queue_runners(session)
            # Assumes a checkpoint already exists at model_path.
            load_save_model(session, saver, model_path, False)
            try:
                for i in range(400000):
                    # Warm up without the bootstrap loss for the first 10000 steps.
                    if i < 10000:
                        train_dict = {batch_size: 32,
                                      drop_prob: 1,
                                      is_training: True,
                                      is_boostrap: False}
                    else:
                        train_dict = {batch_size: 32,
                                      drop_prob: 1,
                                      is_training: True,
                                      is_boostrap: True}
                    step, _ = session.run([net_worker['global_step'], net_worker['train_op']],
                                          feed_dict=train_dict)
                    if i % 500 == 0:
                        train_dict = {batch_size: 32,
                                      drop_prob: 1,
                                      is_training: True,
                                      is_boostrap: False}
                        entropy, train_acc = session.run(
                            [net_worker['cross_entropy'], net_worker['accuracy']],
                            feed_dict=train_dict)
                        print('***** {}:{},{} *****'.format(i, entropy, train_acc))
                    if i % 2000 == 0:
                        test_dict = {drop_prob: 1.0,
                                     is_training: False,
                                     batch_size: 256,
                                     is_boostrap: False}
                        acc = session.run(net_worker['accuracy'], feed_dict=test_dict)
                        print('*****local step {},valid step {}:accuracy {} *****'.format(i, step, acc))
                        if i > 3000:
                            print("**************save model***************")
                            load_save_model(session, saver, model_path, True)
            except Exception as e:
                coord.request_stop(e)
            finally:
                coord.request_stop()
                coord.join(threads)
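The article stops at the function definition, so here is a minimal, hypothetical entry point for running it. The filename train.py, the data paths, and the checkpoint path are placeholders of mine, not part of the original:

# Hypothetical launcher (assumed filename: train.py); paths below are placeholders.
if __name__ == '__main__':
    train_cluster(train_volume_data='data/train',
                  valid_volume_data='data/valid',
                  model_path='checkpoints/model.ckpt')

# One process per cluster task, e.g. all on a single machine:
#   python train.py --job_name=ps     --task_index=0
#   python train.py --job_name=worker --task_index=0   # chief
#   python train.py --job_name=worker --task_index=1

With the default flags this brings up one parameter server on port 2222 and two workers on ports 2223 and 2224. Note that later TensorFlow 1.x releases replaced tf.train.Supervisor with tf.train.MonitoredTrainingSession for this pattern, but the overall structure stays the same.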