UCloud, one of China's top three cloud providers: www.ucloud.cn
Train
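The train function first restores from an existing checkpoint, if there is one, and reads the step from it so that training can resume quickly. The training loop itself fetches one mini-batch and updates the parameters on every iteration, evaluates on a training batch every eval_steps steps, and saves a checkpoint every save_steps steps.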
import os
import time

import tensorflow as tf

# FLAGS, logger, get_imagesfile, batch_data and network are defined elsewhere in this deck.


def train():
    sess = tf.Session()
    file_labels = get_imagesfile(FLAGS.train_data_dir)
    images, labels, coord, threads = batch_data(file_labels, sess)
    endpoints = network(images, labels)
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    # tf.train.SummaryWriter was renamed tf.summary.FileWriter in TensorFlow 1.0.
    train_writer = tf.train.SummaryWriter('./log' + '/train', sess.graph)
    test_writer = tf.train.SummaryWriter('./log' + '/val')
    start_step = 0
    if FLAGS.restore:
        ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if ckpt:
            saver.restore(sess, ckpt)
            # (log wording reconstructed; the original string was lost)
            logger.info('restore from the checkpoint {0}'.format(ckpt))
            # Checkpoints are saved as 'my-model-<step>', so the suffix is the step.
            start_step += int(ckpt.split('-')[-1])
    logger.info(':::Training Start:::')
    try:
        while not coord.should_stop():
            # logger.info('step {0} start'.format(i))
            start_time = time.time()
            # One training step: update the parameters and fetch loss, summary and step.
            _, loss_val, train_summary, step = sess.run(
                [endpoints['train_op'], endpoints['loss'],
                 endpoints['merged_summary_op'], endpoints['global_step']])
            train_writer.add_summary(train_summary, step)
            end_time = time.time()
            # (log wording reconstructed; the original string was lost)
            logger.info('the step {0} takes {1} loss {2}'.format(
                step, end_time - start_time, loss_val))
            if step > FLAGS.max_steps:
                break
            if step % FLAGS.eval_steps == 1:
                # Evaluate one more batch drawn from the training queue.
                accuracy_val, test_summary, step = sess.run(
                    [endpoints['accuracy'], endpoints['merged_summary_op'],
                     endpoints['global_step']])
                test_writer.add_summary(test_summary, step)
                logger.info('===============Eval a batch in Train data=======================')
                logger.info('the step {0} accuracy {1}'.format(step, accuracy_val))
                logger.info('===============Eval a batch in Train data=======================')
            if step % FLAGS.save_steps == 1:
                logger.info('Save the ckpt of {0}'.format(step))
                saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'my-model'),
                           global_step=endpoints['global_step'])
    except tf.errors.OutOfRangeError:
        # The input queue is exhausted: save a final checkpoint.
        logger.info('==================Train Finished================')
        saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'my-model'),
                   global_step=endpoints['global_step'])
    finally:
        # Stop the queue-runner threads and release the session.
        coord.request_stop()
        coord.join(threads)
        sess.close()
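The two pieces of step bookkeeping in train can be tried out without TensorFlow. The sketch below is only an illustration: the helper names step_from_checkpoint and is_due and the path './checkpoint/my-model-3001' are made up for the example. It relies on the fact that saver.save with global_step appends '-<step>' to the file name.

```python
def step_from_checkpoint(ckpt_path):
    """Recover the global step from a checkpoint path ending in '-<step>'."""
    return int(ckpt_path.split('-')[-1])


def is_due(step, every):
    """Mirror the `step % N == 1` test that train() uses for eval and save."""
    return step % every == 1


# Hypothetical path, shaped like what saver.save(..., 'my-model', global_step=...) writes.
start_step = step_from_checkpoint('./checkpoint/my-model-3001')

# With N = 5 the test fires on steps 1, 6, 11, ...
eval_due = [s for s in range(1, 12) if is_due(s, 5)]

print(start_step, eval_due)  # 3001 [1, 6, 11]
```

One consequence of comparing against 1 rather than 0: the first evaluation and the first checkpoint happen right after step 1, and a setting of N = 1 never fires, because step % 1 is always 0.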
Graph
Loss and Accuracy
Validation
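After training, I wanted to evaluate the finished model on the test set. I first tried reusing batch_data with the epoch count of slice_input_producer set to 1, but that setup cannot be shared with train: it fails because the epoch variable has no initial value (train passes None as the epoch). So I wrote the shuffle-batch logic myself: the test-set images and labels are passed into the network through feed_dict, the model's outputs are collected, and the relevant metrics are computed from them: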
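A minimal, framework-free sketch of this shuffle-and-batch idea is shown here. It is not the author's implementation: shuffled_batches, evaluate and predict_fn are hypothetical names, and predict_fn stands in for a call such as sess.run with a feed_dict. In the real setting the batches would be image arrays rather than the toy integers used below.

```python
import random


def shuffled_batches(images, labels, batch_size, seed=None):
    """Yield (image_batch, label_batch) pairs that cover the data exactly once."""
    if len(images) != len(labels):
        raise ValueError('images and labels must have the same length')
    order = list(range(len(images)))
    random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        idx = order[start:start + batch_size]
        yield [images[i] for i in idx], [labels[i] for i in idx]


def evaluate(predict_fn, images, labels, batch_size=2, seed=0):
    """Accumulate top-1 accuracy over one pass of the test set."""
    correct = 0
    for image_batch, label_batch in shuffled_batches(images, labels, batch_size, seed):
        # Assumption: in the TensorFlow version this is where the batch would be
        # fed to the network through feed_dict and the predictions fetched.
        predictions = predict_fn(image_batch)
        correct += sum(int(p == y) for p, y in zip(predictions, label_batch))
    return correct / float(len(labels))


# Toy stand-in for the model: "images" are integers, the prediction is their parity.
toy_images = [0, 1, 2, 3, 4]
toy_labels = [0, 1, 0, 1, 1]  # the last label disagrees with the toy model on purpose
accuracy = evaluate(lambda batch: [x % 2 for x in batch], toy_images, toy_labels)
print(accuracy)  # 0.8
```

Because the batches are built in plain Python, no input queue or epoch counter is involved, which sidesteps the uninitialized-epoch problem described above.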