深度学习进阶笔记之八 TensorFlow与中文手写汉字识别(3)

2018-11-17 20:27

UCLoud中国云三强: www.ucloud.cn

Train

train函数包括从已有checkpoint中restore,得到step,快速恢复训练过程,训练主要是每一次得到mini-batch,更新参数,每隔eval_steps后做一次train batch的eval,每隔save_steps 后保存一次checkpoint。

def train():

sess = tf.Session()

file_labels = get_imagesfile(FLAGS.train_data_dir)

images, labels, coord, threads = batch_data(file_labels, sess) endpoints = network(images, labels) saver = tf.train.Saver()

sess.run(tf.global_variables_initializer())

train_writer = tf.train.SummaryWriter('./log' + '/train',sess.graph) test_writer = tf.train.SummaryWriter('./log' + '/val') start_step = 0 if FLAGS.restore:

ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir) if ckpt:

saver.restore(sess, ckpt)

print \ start_step += int(ckpt.split('-')[-1]) logger.info(':::Training Start:::')

UCLoud中国云三强: www.ucloud.cn

try:

while not coord.should_stop(): # logger.info('step {0} start'.format(i)) start_time = time.time()

_, loss_val, train_summary, step = sess.run([endpoints['train_op'], endpoints['loss'], endpoints['merged_summary_op'], endpoints['global_step']]) train_writer.add_summary(train_summary, step) end_time = time.time()

logger.info(\end_time-start_time, loss_val))

if step > FLAGS.max_steps: break

# logger.info(\end_time-start_time, loss_val))

if step % FLAGS.eval_steps == 1:

accuracy_val,test_summary, step = sess.run([endpoints['accuracy'], endpoints['merged_summary_op'], endpoints['global_step']]) test_writer.add_summary(test_summary, step)

logger.info('===============Eval a batch in Train data=======================')

logger.info( 'the step {0} accuracy {1}'.format(step, accuracy_val))

UCLoud中国云三强: www.ucloud.cn

logger.info('===============Eval a batch in Train data=======================') if step % FLAGS.save_steps == 1:

logger.info('Save the ckpt of {0}'.format(step))

saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'my-model'), global_step=endpoints['global_step']) except tf.errors.OutOfRangeError:

# print \ logger.info('==================Train Finished================')

saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'my-model'), global_step=endpoints['global_step']) finally:

coord.request_stop() coord.join(threads) sess.close()

UCLoud中国云三强: www.ucloud.cn

Graph

UCLoud中国云三强: www.ucloud.cn

Loss and Accuracy

Validation

训练完成之后,想对完成的模型在测试数据集上做一个评估,这里我也曾经尝试利用batch_data,将slice_input_producer中epoch设置为1,来做相关的工作,但是发现这里无法和train 共用,会出现epoch无初始化值的问题(train中传epoch为None),所以这里自己写了shuffle batch的逻辑,将测试集的images和labels通过feed_dict传进到网络,得到模型的输出, 然后做相关指标的计算:


深度学习进阶笔记之八 TensorFlow与中文手写汉字识别(3).doc 将本文的Word文档下载到电脑 下载失败或者文档不完整,请联系客服人员解决!

下一篇:2014年国家公务员面试辅导:我的海关不是梦

相关阅读
本类排行
× 注册会员免费下载(下载后可以自由复制和排版)

马上注册会员

注:下载文档有可能“只有目录或者内容不全”等情况,请下载之前注意辨别,如果您已付费且无法下载或内容有问题,请联系我们协助你处理。
微信: QQ: