深度学习进阶笔记之八 TensorFlow与中文手写汉字识别

2018-11-17 20:27

UCLoud中国云三强: www.ucloud.cn

深度学习进阶笔记之八 | TensorFlow与中文手写汉字识别

引言

TensorFlow是Google基于DistBelief进行研发的第二代人工智能学习系统,被广泛用于语音识别或图像识别等多项机器深度学习领域。其命名来源于本身的运行原理。Tensor(张量)意味着N维数组,Flow(流)意味着基于数据流图的计算,TensorFlow代表着张量从图象的一端流动到另一端计算过程,是将复杂的数据结构传输至人工智能神经网中进行分析和处理的过程。

TensorFlow完全开源,任何人都可以使用。可在小到一部智能手机、大到数千台数据中心服务器的各种设备上运行。

『机器学习进阶笔记』系列将深入解析TensorFlow系统的技术实践,从零开始,由浅入深,与大家一起走上机器学习的进阶之路。 Goal

本文目标是利用TensorFlow做一个简单的图像分类器,在比较大的数据集上,尽可能高效地做图像相关处理,从Train,Validation到Inference,是一个比较基本的Example, 从一个基本的任务学习如果在TensorFlow下做高效地图像读取,基本的图像处理,整个项目很简单,但其中有一些trick,在实际项目当中有很大的好处, 比如坚决不要一次读入所有的 的数据到内存(尽管在Mnist这类级别的例子上经常出现)…

UCLoud中国云三强: www.ucloud.cn

刚开始看到是这篇blog里面的TensorFlow练习22: 手写汉字识别, 但是这篇文章只用了140训练与测试,试了下代码 很快,但是当扩展到所有的时,发现32g的内存都不够用,这才注意到原文中都是用numpy,会先把所有的数据放入到内存,但这个不必须的,无论在MXNet还是TensorFlow中都是不必 须的,MXNet使用的是DataIter,会在程序运行的过程中异步读取数据,TensorFlow也是这样的,TensorFlow封装了高级的api,用来做数据的读取,比如TFRecord,还有就是从filenames中读取, 来异步读取文件,然后做shuffle batch,再feed到模型的Graph中来做模型参数的更新。具体在tf如何做数据的读取可以看看reading data in tensorflow

这里我会拿到所有的数据集来做训练与测试,算作是对斗大的熊猫上面那篇文章的一个扩展。

Batch Generate

数据集来自于中科院自动化研究所,感谢分享精神!!!具体下载:

UCLoud中国云三强: www.ucloud.cn

wget

http://www.nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.1trn_gnt.zip wget

http://www.nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.1tst_gnt.zip 解压后发现是一些gnt文件,然后用了斗大的熊猫里面的代码,将所有文件都转化为对应label目录下的所有png的图片。(注意在HWDB1.1trn_gnt.zip解压后是alz文件,需要再次解压 我在mac没有找到合适的工具,windows上有alz的解压工具)。

import os

import numpy as np import struct

from PIL import Image

data_dir = '../data'

train_data_dir = os.path.join(data_dir, 'HWDB1.1trn_gnt') test_data_dir = os.path.join(data_dir, 'HWDB1.1tst_gnt')

def read_from_gnt_dir(gnt_dir=train_data_dir): def one_file(f): header_size = 10 while True:

UCLoud中国云三强: www.ucloud.cn

header = np.fromfile(f, dtype='uint8', count=header_size) if not header.size: break

sample_size = header[0] + (header[1]<<8) + (header[2]<<16) + (header[3]<<24)

tagcode = header[5] + (header[4]<<8) width = header[6] + (header[7]<<8) height = header[8] + (header[9]<<8) if header_size + width*height != sample_size: break

image = np.fromfile(f, dtype='uint8', count=width*height).reshape((height, width)) yield image, tagcode for file_name in os.listdir(gnt_dir): if file_name.endswith('.gnt'):

file_path = os.path.join(gnt_dir, file_name) with open(file_path, 'rb') as f: for image, tagcode in one_file(f): yield image, tagcode char_set = set()

for _, tagcode in read_from_gnt_dir(gnt_dir=train_data_dir):

tagcode_unicode = struct.pack('>H', tagcode).decode('gb2312') char_set.add(tagcode_unicode)

UCLoud中国云三强: www.ucloud.cn

char_list = list(char_set)

char_dict = dict(zip(sorted(char_list), range(len(char_list)))) print len(char_dict) import pickle

f = open('char_dict', 'wb') pickle.dump(char_dict, f) f.close()

train_counter = 0 test_counter = 0

for image, tagcode in read_from_gnt_dir(gnt_dir=train_data_dir): tagcode_unicode = struct.pack('>H', tagcode).decode('gb2312') im = Image.fromarray(image)

dir_name = '../data/train/' + '%0.5d'%char_dict[tagcode_unicode] if not os.path.exists(dir_name): os.mkdir(dir_name)

im.convert('RGB').save(dir_name+'/' + str(train_counter) + '.png') train_counter += 1

for image, tagcode in read_from_gnt_dir(gnt_dir=test_data_dir): tagcode_unicode = struct.pack('>H', tagcode).decode('gb2312') im = Image.fromarray(image)

dir_name = '../data/test/' + '%0.5d'%char_dict[tagcode_unicode] if not os.path.exists(dir_name):


深度学习进阶笔记之八 TensorFlow与中文手写汉字识别.doc 将本文的Word文档下载到电脑 下载失败或者文档不完整,请联系客服人员解决!

下一篇:2014年国家公务员面试辅导:我的海关不是梦

相关阅读
本类排行
× 注册会员免费下载(下载后可以自由复制和排版)

马上注册会员

注:下载文档有可能“只有目录或者内容不全”等情况,请下载之前注意辨别,如果您已付费且无法下载或内容有问题,请联系我们协助你处理。
微信: QQ: