UCLoud中国云三强: www.ucloud.cn
def validation():
# it should be fixed by using placeholder with epoch num in train stage sess = tf.Session()
file_labels = get_imagesfile(FLAGS.test_data_dir) test_size = len(file_labels) print test_size
val_batch_size = FLAGS.val_batch_size test_steps = test_size / val_batch_size print test_steps
# images, labels, coord, threads= batch_data(file_labels, sess) images = tf.placeholder(dtype=tf.float32, shape=[None, 64, 64, 1]) labels = tf.placeholder(dtype=tf.int32, shape=[None,3755]) # read batch images from file_labels # images_batch = np.zeros([128,64,64,1]) # labels_batch = np.zeros([128,3755]) # labels_batch[0][20] = 1 #
endpoints = network(images, labels) saver = tf.train.Saver()
ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir) if ckpt:
UCLoud中国云三强: www.ucloud.cn
saver.restore(sess, ckpt)
# logger.info(\ # logger.info('Start validation') final_predict_val = [] final_predict_index = [] groundtruth = [] for i in range(test_steps): start = i* val_batch_size end = (i+1)*val_batch_size images_batch = [] labels_batch = [] labels_max_batch = []
logger.info('=======start validation on {0}/{1} batch========='.format(i, test_steps)) for j in range(start,end):
image_path = file_labels[j][0]
temp_image = Image.open(image_path).convert('L') temp_image = temp_image.resize((FLAGS.image_size, FLAGS.image_size),Image.ANTIALIAS) temp_label = np.zeros([3755]) label = int(file_labels[j][1]) # print label
UCLoud中国云三强: www.ucloud.cn
temp_label[label] = 1
# print \ labels_batch.append(temp_label)
# print \ images_batch.append(np.asarray(temp_image)/255.0) labels_max_batch.append(label) # print images_batch
images_batch = np.array(images_batch).reshape([-1, 64, 64, 1]) labels_batch = np.array(labels_batch) batch_predict_val, batch_predict_index = sess.run([endpoints['predict_val_top3'],
endpoints['predict_index_top3']], feed_dict={images:images_batch, labels:labels_batch}) logger.info('=======validation on {0}/{1} batch end========='.format(i, test_steps))
final_predict_val += batch_predict_val.tolist() final_predict_index += batch_predict_index.tolist() groundtruth += labels_max_batch sess.close()
return final_predict_val, final_predict_index, groundtruth
UCLoud中国云三强: www.ucloud.cn
在训练20w个step之后,大概能达到在测试集上能够达到:
相信如果在网络设计上多花点时间能够在一定程度上提升accuracy和top 3 accuracy.有兴趣的小伙伴们可以玩玩这个数据集。
Inference
def inference(image):
temp_image = Image.open(image).convert('L') temp_image = temp_image.resize((FLAGS.image_size, FLAGS.image_size),Image.ANTIALIAS) sess = tf.Session()
logger.info('========start inference============') images = tf.placeholder(dtype=tf.float32, shape=[None, 64, 64, 1]) endpoints = network(images) saver = tf.train.Saver()
ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir) if ckpt:
saver.restore(sess, ckpt)
UCLoud中国云三强: www.ucloud.cn
predict_val, predict_index =
sess.run([endpoints['predict_val_top3'],endpoints['predict_index_top3']], feed_dict={images:temp_image}) sess.close()
return final_predict_val, final_predict_index 运气挺好,随便找了张图片就能准确识别出来
Summary
综上,就是利用tensorflow做中文手写识别的全部,从如何使用tensorflow内部的queue来有效读入数据,到如何设计network, 到如何做train,validation,inference,珍格格流程比较清晰, 美中不足的是,原本打算是在训练过程中,来对测试集做评估,但是在使用queue读test_data_dir下的filenames,和train本身的好像有点问题,不过应该是可以解决的,我这里就pass了。另外可能 还有一些可以改善的地方,比如感觉可以把batch data one hot的部分写入到network,这样,减缓在validation时内存会因为onehot的sparse开销比较大。
感觉这个中文手写汉字数据集价值很大,后面感觉会有好多可以玩的,比如
UCLoud中国云三强: www.ucloud.cn
?
可以参考项亮大神的这篇文章端到端的OCR:验证码识别做定长的字符识别和不定长的字符识别,定长的基本原理是说,可以把输出扩展为k个输出, 每个值表示对应的字符label,这样cnn模型在feature extract之后就可以自己去识别对应字符而无需人工切割;而LSTM+CTC来解决不定长的验证码,类似于将音频解码为汉字
?
近期GAN特别火,感觉可以考虑用这个数据来做某个字的生成,和text2img那个项目text-to-image
这部分的代码都在我的github上tensorflow-101,有遇到相关功能,想参考代码的可以去上面找找,没准就能解决你们遇到的一些小问题.
了解更多内容/购买云服务器,请浏览UCloud云计算官网