深度学习进阶笔记之八 TensorFlow与中文手写汉字识别(4)

2018-11-17 20:27

UCLoud中国云三强: www.ucloud.cn

def validation():

# it should be fixed by using placeholder with epoch num in train stage sess = tf.Session()

file_labels = get_imagesfile(FLAGS.test_data_dir) test_size = len(file_labels) print test_size

val_batch_size = FLAGS.val_batch_size test_steps = test_size / val_batch_size print test_steps

# images, labels, coord, threads= batch_data(file_labels, sess) images = tf.placeholder(dtype=tf.float32, shape=[None, 64, 64, 1]) labels = tf.placeholder(dtype=tf.int32, shape=[None,3755]) # read batch images from file_labels # images_batch = np.zeros([128,64,64,1]) # labels_batch = np.zeros([128,3755]) # labels_batch[0][20] = 1 #

endpoints = network(images, labels) saver = tf.train.Saver()

ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir) if ckpt:

UCLoud中国云三强: www.ucloud.cn

saver.restore(sess, ckpt)

# logger.info(\ # logger.info('Start validation') final_predict_val = [] final_predict_index = [] groundtruth = [] for i in range(test_steps): start = i* val_batch_size end = (i+1)*val_batch_size images_batch = [] labels_batch = [] labels_max_batch = []

logger.info('=======start validation on {0}/{1} batch========='.format(i, test_steps)) for j in range(start,end):

image_path = file_labels[j][0]

temp_image = Image.open(image_path).convert('L') temp_image = temp_image.resize((FLAGS.image_size, FLAGS.image_size),Image.ANTIALIAS) temp_label = np.zeros([3755]) label = int(file_labels[j][1]) # print label

UCLoud中国云三强: www.ucloud.cn

temp_label[label] = 1

# print \ labels_batch.append(temp_label)

# print \ images_batch.append(np.asarray(temp_image)/255.0) labels_max_batch.append(label) # print images_batch

images_batch = np.array(images_batch).reshape([-1, 64, 64, 1]) labels_batch = np.array(labels_batch) batch_predict_val, batch_predict_index = sess.run([endpoints['predict_val_top3'],

endpoints['predict_index_top3']], feed_dict={images:images_batch, labels:labels_batch}) logger.info('=======validation on {0}/{1} batch end========='.format(i, test_steps))

final_predict_val += batch_predict_val.tolist() final_predict_index += batch_predict_index.tolist() groundtruth += labels_max_batch sess.close()

return final_predict_val, final_predict_index, groundtruth

UCLoud中国云三强: www.ucloud.cn

在训练20w个step之后,大概能达到在测试集上能够达到:

相信如果在网络设计上多花点时间能够在一定程度上提升accuracy和top 3 accuracy.有兴趣的小伙伴们可以玩玩这个数据集。

Inference

def inference(image):

temp_image = Image.open(image).convert('L') temp_image = temp_image.resize((FLAGS.image_size, FLAGS.image_size),Image.ANTIALIAS) sess = tf.Session()

logger.info('========start inference============') images = tf.placeholder(dtype=tf.float32, shape=[None, 64, 64, 1]) endpoints = network(images) saver = tf.train.Saver()

ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir) if ckpt:

saver.restore(sess, ckpt)

UCLoud中国云三强: www.ucloud.cn

predict_val, predict_index =

sess.run([endpoints['predict_val_top3'],endpoints['predict_index_top3']], feed_dict={images:temp_image}) sess.close()

return final_predict_val, final_predict_index 运气挺好,随便找了张图片就能准确识别出来

Summary

综上,就是利用tensorflow做中文手写识别的全部,从如何使用tensorflow内部的queue来有效读入数据,到如何设计network, 到如何做train,validation,inference,珍格格流程比较清晰, 美中不足的是,原本打算是在训练过程中,来对测试集做评估,但是在使用queue读test_data_dir下的filenames,和train本身的好像有点问题,不过应该是可以解决的,我这里就pass了。另外可能 还有一些可以改善的地方,比如感觉可以把batch data one hot的部分写入到network,这样,减缓在validation时内存会因为onehot的sparse开销比较大。

感觉这个中文手写汉字数据集价值很大,后面感觉会有好多可以玩的,比如

UCLoud中国云三强: www.ucloud.cn

?

可以参考项亮大神的这篇文章端到端的OCR:验证码识别做定长的字符识别和不定长的字符识别,定长的基本原理是说,可以把输出扩展为k个输出, 每个值表示对应的字符label,这样cnn模型在feature extract之后就可以自己去识别对应字符而无需人工切割;而LSTM+CTC来解决不定长的验证码,类似于将音频解码为汉字

?

近期GAN特别火,感觉可以考虑用这个数据来做某个字的生成,和text2img那个项目text-to-image

这部分的代码都在我的github上tensorflow-101,有遇到相关功能,想参考代码的可以去上面找找,没准就能解决你们遇到的一些小问题.

了解更多内容/购买云服务器,请浏览UCloud云计算官网


深度学习进阶笔记之八 TensorFlow与中文手写汉字识别(4).doc 将本文的Word文档下载到电脑 下载失败或者文档不完整,请联系客服人员解决!

下一篇:2014年国家公务员面试辅导:我的海关不是梦

相关阅读
本类排行
× 注册会员免费下载(下载后可以自由复制和排版)

马上注册会员

注:下载文档有可能“只有目录或者内容不全”等情况,请下载之前注意辨别,如果您已付费且无法下载或内容有问题,请联系我们协助你处理。
微信: QQ: