深度学习进阶笔记之八 TensorFlow与中文手写汉字识别(2)

2018-11-17 20:27

UCLoud中国云三强: www.ucloud.cn

os.mkdir(dir_name)

im.convert('RGB').save(dir_name+'/' + str(test_counter) + '.png') test_counter += 1

处理好的数据,放到了云盘,大家可以直接在我的云盘来下载处理好的数据集HWDB1. 这里说明下,char_dict是汉字和对应的数字label的记录。

得到数据集后,就要考虑如何读取了,一次用numpy读入内存在很多小数据集上是可以行的,但是在稍微大点的数据集上内存就成了瓶颈,但是不要害怕,TensorFlow有自己的方法:

def batch_data(file_labels,sess, batch_size=128): image_list = [file_label[0] for file_label in file_labels] label_list = [int(file_label[1]) for file_label in file_labels] print 'tag2 {0}'.format(len(image_list))

images_tensor = tf.convert_to_tensor(image_list, dtype=tf.string) labels_tensor = tf.convert_to_tensor(label_list, dtype=tf.int64)

input_queue = tf.train.slice_input_producer([images_tensor, labels_tensor])

labels = input_queue[1]

images_content = tf.read_file(input_queue[0])

# images = tf.image.decode_png(images_content, channels=1)

images = tf.image.convert_image_dtype(tf.image.decode_png(images_content, channels=1), tf.float32)

UCLoud中国云三强: www.ucloud.cn

# images = images / 256 images = pre_process(images) # print images.get_shape() # one hot

labels = tf.one_hot(labels, 3755)

image_batch, label_batch = tf.train.shuffle_batch([images, labels], batch_size=batch_size, capacity=50000,min_after_dequeue=10000) # print 'image_batch', image_batch.get_shape()

coord = tf.train.Coordinator()

threads = tf.train.start_queue_runners(sess=sess, coord=coord) return image_batch, label_batch, coord, threads

简单介绍下,首先你需要得到所有的图像的path和对应的label的列表,利用tf.convert_to_tensor转换为对应的tensor, 利用tf.train.slice_input_producer将image_list ,label_list做一个slice处理,然后做图像的读取、预处理,以及label的one_hot表示,然后就是传到tf.train.shuffle_batch产生一个个shuffle batch,这些就可以feed到你的 模型。 slice_input_producer和shuffle_batch这类操作内部都是基于queue,是一种异步的处理方式,会在设备中开辟一段空间用作cache,不同的进程会分别一直往cache中塞数据 和取数据,保证内存或显存的占用以及每一个mini-batch不需要等待,直接可以从cache中获取。

UCLoud中国云三强: www.ucloud.cn

Data Augmentation

由于图像场景不复杂,只是做了一些基本的处理,包括图像翻转,改变下亮度等等,这些在TensorFlow里面有现成的api,所以尽量使用TensorFlow来做相关的处理:

def pre_process(images):

if FLAGS.random_flip_up_down:

images = tf.image.random_flip_up_down(images) if FLAGS.random_flip_left_right:

images = tf.image.random_flip_left_right(images) if FLAGS.random_brightness:

images = tf.image.random_brightness(images, max_delta=0.3) if FLAGS.random_contrast:

UCLoud中国云三强: www.ucloud.cn

images = tf.image.random_contrast(images, 0.8, 1.2)

new_size = tf.constant([FLAGS.image_size,FLAGS.image_size], dtype=tf.int32) images = tf.image.resize_images(images, new_size) return images

Build Graph

这里很简单的构造了一个两个卷积+一个全连接层的网络,没有做什么更深的设计,感觉意义不大,设计了一个dict,用来返回后面要用的所有op,还有就是为了方便再训练中查看loss和accuracy, 没有什么特别的,很容易理解, labels 为None时 方便做inference。

def network(images, labels=None): endpoints = {}

conv_1 = slim.conv2d(images, 32, [3,3],1, padding='SAME')

max_pool_1 = slim.max_pool2d(conv_1, [2,2],[2,2], padding='SAME') conv_2 = slim.conv2d(max_pool_1, 64, [3,3],padding='SAME') max_pool_2 = slim.max_pool2d(conv_2, [2,2],[2,2], padding='SAME') flatten = slim.flatten(max_pool_2)

out = slim.fully_connected(flatten,3755, activation_fn=None) global_step = tf.Variable(initial_value=0) if labels is not None:

loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(out, labels)) train_op = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss, global_step=global_step)

UCLoud中国云三强: www.ucloud.cn

accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(out, 1), tf.argmax(labels, 1)), tf.float32)) tf.summary.scalar('loss', loss)

tf.summary.scalar('accuracy', accuracy)

merged_summary_op = tf.summary.merge_all() output_score = tf.nn.softmax(out)

predict_val_top3, predict_index_top3 = tf.nn.top_k(output_score, k=3)

endpoints['global_step'] = global_step if labels is not None:

endpoints['labels'] = labels endpoints['train_op'] = train_op endpoints['loss'] = loss

endpoints['accuracy'] = accuracy

endpoints['merged_summary_op'] = merged_summary_op endpoints['output_score'] = output_score endpoints['predict_val_top3'] = predict_val_top3 endpoints['predict_index_top3'] = predict_index_top3 return endpoints


深度学习进阶笔记之八 TensorFlow与中文手写汉字识别(2).doc 将本文的Word文档下载到电脑 下载失败或者文档不完整,请联系客服人员解决!

下一篇:2014年国家公务员面试辅导:我的海关不是梦

相关阅读
本类排行
× 注册会员免费下载(下载后可以自由复制和排版)

马上注册会员

注:下载文档有可能“只有目录或者内容不全”等情况,请下载之前注意辨别,如果您已付费且无法下载或内容有问题,请联系我们协助你处理。
微信: QQ: