哈工大讯飞联合实验室发布飞鹰智能文本校对系统1.0

哈工大讯飞联合实验室(HFL)发布飞鹰智能文本校对系统(简称:飞鹰校对)1.0。系统针对中文文本的校对需求,提供拼写纠错、语法纠错、标点纠错及敏感词检测等功能,现已开放通用领域以及司法、教育等专用领域的智能文本校对服务。欢迎大家体验。PC端请访问:http://check.hfl-rc.com/

赋予 SQL AI 能力 SQLFlow

SQLFlow

Build Status Coverage Status GoDoc License Go Report Card

What is SQLFlow

SQLFlow is a bridge that connects a SQL engine, e.g. MySQL, Hive or MaxCompute, with TensorFlow, XGBoost and other machine learning toolkits. SQLFlow extends the SQL syntax to enable model training, prediction and model explanation.

目标检测YOLOv3

简介

YOLOv3 是由 Joseph Redmon 和 Ali Farhadi 提出的单阶段检测器, 该检测器与达到同样精度的传统目标检测方法相比,推断速度能达到接近两倍.

PaddlePaddle Models

PaddlePaddle Models

Documentation Status License

PaddlePaddle 提供了丰富的计算单元,使得用户可以采用模块化的方法解决各种学习问题。在此Repo中,我们展示了如何用 PaddlePaddle来解决常见的机器学习任务,提供若干种不同的易学易用的神经网络模型。PaddlePaddle用户可领取免费Tesla V100在线算力资源,高效训练模型,每日登陆即送12小时连续五天运行再加送48小时前往使用免费算力

Peak Labs

Peak Labs从零开始为Magi项目创建了整个技术堆栈。从自然语言理解到Web规模的搜索引擎,这些经过考验的产品可随时帮助您增强业务并为您的客户带来Magi体验。

Desition Tree

目录

1. 什么是决策树

1.1 决策树的基本思想

其实用一下图片能更好的理解LR模型和决策树模型算法的根本区别,我们可以思考一下一个决策问题:是否去相亲,一个女孩的母亲要给这个女海介绍对象。

image

大家都看得很明白了吧!LR模型是一股脑儿的把所有特征塞入学习,而决策树更像是编程语言中的if-else一样,去做条件判断,这就是根本性的区别。

1.2 “树”的成长过程

决策树基于“树”结构进行决策的,这时我们就要面临两个问题 :

  • “树”怎么长。
  • 这颗“树”长到什么时候停。

弄懂了这两个问题,那么这个模型就已经建立起来了,决策树的总体流程是“分而治之”的思想,一是自根至叶的递归过程,一是在每个中间节点寻找一个“划分”属性,相当于就是一个特征属性了。接下来我们来逐个解决以上两个问题。

这颗“树”长到什么时候停

  • 当前结点包含的样本全属于同一类别,无需划分;例如:样本当中都是决定去相亲的,属于同一类别,就是不管特征如何改变都不会影响结果,这种就不需要划分了。
  • 当前属性集为空,或是所有样本在所有属性上取值相同,无法划分;例如:所有的样本特征都是一样的,就造成无法划分了,训练集太单一。
  • 当前结点包含的样本集合为空,不能划分。

1.3 “树”怎么长

在生活当中,我们都会碰到很多需要做出决策的地方,例如:吃饭地点、数码产品购买、旅游地区等,你会发现在这些选择当中都是依赖于大部分人做出的选择,也就是跟随大众的选择。其实在决策树当中也是一样的,当大部分的样本都是同一类的时候,那么就已经做出了决策。

我们可以把大众的选择抽象化,这就引入了一个概念就是纯度,想想也是如此,大众选择就意味着纯度越高。好,在深入一点,就涉及到一句话:信息熵越低,纯度越高。我相信大家或多或少都听说过“熵”这个概念,信息熵通俗来说就是用来度量包含的“信息量”,如果样本的属性都是一样的,就会让人觉得这包含的信息很单一,没有差异化,相反样本的属性都不一样,那么包含的信息量就很多了。

一到这里就头疼了,因为马上要引入信息熵的公式,其实也很简单:

Pk表示的是:当前样本集合D中第k类样本所占的比例为Pk。

信息增益

废话不多说直接上公式:

image

看不懂的先不管,简单一句话就是:划分前的信息熵–划分后的信息熵。表示的是向纯度方向迈出的“步长”。

好了,有了前面的知识,我们就可以开始“树”的生长了。

1.3.1 ID3算法

解释:在根节点处计算信息熵,然后根据属性依次划分并计算其节点的信息熵,用根节点信息熵–属性节点的信息熵=信息增益,根据信息增益进行降序排列,排在前面的就是第一个划分属性,其后依次类推,这就得到了决策树的形状,也就是怎么“长”了。

如果不理解的,可以查看我分享的图片示例,结合我说的,包你看懂:

  1. 第一张图.jpg
  2. 第二张图.jpg
  3. 第三张图.jpg
  4. 第四张图.jpg

不过,信息增益有一个问题:对可取值数目较多的属性有所偏好,例如:考虑将“编号”作为一个属性。为了解决这个问题,引出了另一个 算法C4.5。

1.3.2 C4.5

为了解决信息增益的问题,引入一个信息增益率:

其中:

属性a的可能取值数目越多(即V越大),则IV(a)的值通常就越大。信息增益比本质: 是在信息增益的基础之上乘上一个惩罚参数。特征个数较多时,惩罚参数较小;特征个数较少时,惩罚参数较大。不过有一个缺点:

  • 缺点:信息增益率偏向取值较少的特征。

使用信息增益率:基于以上缺点,并不是直接选择信息增益率最大的特征,而是现在候选特征中找出信息增益高于平均水平的特征,然后在这些特征中再选择信息增益率最高的特征。

1.3.3 CART算法

数学家真实聪明,想到了另外一个表示纯度的方法,叫做基尼指数(讨厌的公式):

image

表示在样本集合中一个随机选中的样本被分错的概率。举例来说,现在一个袋子里有3种颜色的球若干个,伸手进去掏出2个球,颜色不一样的概率,这下明白了吧。Gini(D)越小,数据集D的纯度越高。

举个例子

假设现在有特征 “学历”,此特征有三个特征取值: “本科”,“硕士”, “博士”,

当使用“学历”这个特征对样本集合D进行划分时,划分值分别有三个,因而有三种划分的可能集合,划分后的子集如下:

1.划分点: “本科”,划分后的子集合 : {本科},{硕士,博士}

2.划分点: “硕士”,划分后的子集合 : {硕士},{本科,博士}

3.划分点: “硕士”,划分后的子集合 : {博士},{本科,硕士}}

对于上述的每一种划分,都可以计算出基于 划分特征= 某个特征值 将样本集合D划分为两个子集的纯度:

因而对于一个具有多个取值(超过2个)的特征,需要计算以每一个取值作为划分点,对样本D划分之后子集的纯度Gini(D,Ai),(其中Ai 表示特征A的可能取值)

然后从所有的可能划分的Gini(D,Ai)中找出Gini指数最小的划分,这个划分的划分点,便是使用特征A对样本集合D进行划分的最佳划分点。到此就可以长成一棵“大树”了。

1.3.4 三种不同的决策树

  • ID3:取值多的属性,更容易使数据更纯,其信息增益更大。

    训练得到的是一棵庞大且深度浅的树:不合理。

  • C4.5:采用信息增益率替代信息增益。

  • CART:以基尼系数替代熵,最小化不纯度,而不是最大化信息增益。

2. 树形结构为什么不需要归一化?

因为数值缩放不影响分裂点位置,对树模型的结构不造成影响。
按照特征值进行排序的,排序的顺序不变,那么所属的分支以及分裂点就不会有不同。而且,树模型是不能进行梯度下降的,因为构建树模型(回归树)寻找最优点时是通过寻找最优分裂点完成的,因此树模型是阶跃的,阶跃点是不可导的,并且求导没意义,也就不需要归一化。

既然树形结构(如决策树、RF)不需要归一化,那为何非树形结构比如Adaboost、SVM、LR、Knn、KMeans之类则需要归一化。

对于线性模型,特征值差别很大时,运用梯度下降的时候,损失等高线是椭圆形,需要进行多次迭代才能到达最优点。
但是如果进行了归一化,那么等高线就是圆形的,促使SGD往原点迭代,从而导致需要的迭代次数较少。

3. 分类决策树和回归决策树的区别

Classification And Regression Tree(CART)是决策树的一种,CART算法既可以用于创建分类树(Classification Tree),也可以用于创建回归树(Regression Tree),两者在建树的过程稍有差异。

回归树

CART回归树是假设树为二叉树,通过不断将特征进行分裂。比如当前树结点是基于第j个特征值进行分裂的,设该特征值小于s的样本划分为左子树,大于s的样本划分为右子树。

而CART回归树实质上就是在该特征维度对样本空间进行划分,而这种空间划分的优化是一种NP难问题,因此,在决策树模型中是使用启发式方法解决。典型CART回归树产生的目标函数为:

因此,当我们为了求解最优的切分特征j和最优的切分点s,就转化为求解这么一个目标函数:

所以我们只要遍历所有特征的的所有切分点,就能找到最优的切分特征和切分点。最终得到一棵回归树。

参考文章:经典算法详解–CART分类决策树、回归树和模型树

4. 决策树如何剪枝

决策树的剪枝基本策略有 预剪枝 (Pre-Pruning) 和 后剪枝 (Post-Pruning)。

  • 预剪枝:其中的核心思想就是,在每一次实际对结点进行进一步划分之前,先采用验证集的数据来验证如果划分是否能提高划分的准确性。如果不能,就把结点标记为叶结点并退出进一步划分;如果可以就继续递归生成节点。
  • 后剪枝:后剪枝则是先从训练集生成一颗完整的决策树,然后自底向上地对非叶结点进行考察,若将该结点对应的子树替换为叶结点能带来泛化性能提升,则将该子树替换为叶结点。

参考文章:决策树及决策树生成与剪枝

5. 代码实现

GitHub:https://github.com/NLP-LOVE/ML-NLP/blob/master/Machine%20Learning/3.Desition%20Tree/DecisionTree.ipynb


作者:@mantchs

GitHub:https://github.com/NLP-LOVE/ML-NLP

欢迎大家加入讨论!共同完善此项目!群号:【541954936】NLP面试学习群

2019年度机器学习49个顶级工程汇总

2019年度机器学习49个顶级工程汇总

过去一年中,我们比较了近22000个机器学习开源工程,并筛选了49个顶级项目(筛选率0.22%)。

其中包括以下6个分类

  • 计算机视觉(1~5)
  • 强化学习(6~13)
  • NLP(14~20)
  • GAN(21~26)
  • Neural Network(27~35)
  • Toolkit(36~49)

我们花了很大的精力筛选这个list,并小心的选择出2018年1月到12月间最好的工程。为了保证名单质量,Mybridge AI协同考虑了流行度、参与度、发布时间等多重因素。

计算机视觉
1、Detectron:facebook发布的目标检测工具【18913 star on Github】
项目地址:
https://github.com/facebookresearch/Detectron?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

image.png

2、Openpost:多人实时特征点检测工具【11052 stars on GitHub】
项目地址:
https://github.com/CMU-Perceptual-Computing-Lab/openpose?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

image.png

3、DensePost:2维人体图片转3维的实时映射方法。【4165 stars on Github】
项目地址:
https://github.com/facebookresearch/Densepose?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

image.png

4、Maskrcnn-benchmark:(Pytorch)语义分割与目标检测工具包。【3888 stars on Github】
项目地址:
https://github.com/facebookresearch/maskrcnn-benchmark?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

image.png

5、SNIPER:多尺度目标检测算法。【1963 stars on Github】
项目地址:
https://github.com/mahyarnajibi/SNIPER?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

image.png

强化学习
6、Psychlab:Psychlab实验范例。【5595 stars on Github】
项目地址:
https://github.com/deepmind/lab/tree/master/game_scripts/levels/contributed/psychlab?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

image.png

7、ELF:一个灵活、轻量、可扩展的游戏研究平台。【2406 stars on Github】
项目地址:
https://github.com/pytorch/elf?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

8、TRFL:(TensorFlow)强化学习agent工具包。【2312 stars on Github】
项目地址:
https://github.com/deepmind/trfl?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

9、Horizon:首个用于大规模需求的开源强化学习平台。【1703 stars on Github】
项目地址:
https://github.com/facebookresearch/Horizon?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

10、Chess-alpha-zero:国际象棋强化学习项目(基于AlphaGo Zero方法)。【1307 stars on Github】
项目地址:
https://github.com/Zeta36/chess-alpha-zero?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

11、Dm_control:DeepMind工具包。【1231 stars on Github】
项目地址:
https://github.com/deepmind/dm_control?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

12、MAMEToolkit:基于强化学习的电子游戏python库。【437 stars on Github】
项目地址:
https://github.com/M-J-Murray/MAMEToolkit?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

13、Reaver:模块化的深度强化学习框架(星际争霸2)。【355 stars on Github】
项目地址:
https://github.com/inoryy/reaver?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

NLP
14、Bert:BERT的TensorFlow代码,以及预训练模型。【11703 stars on Github】
项目地址:
https://github.com/google-research/bert?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

15、Pytext:基于Pytorch的神经语言模型框架。【4466 stars on Github】
项目地址:
https://github.com/facebookresearch/pytext?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

16、Bert-as-service:BERT模型的网络服务版本
项目地址:
https://github.com/hanxiao/bert-as-service?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

17、UnsupervisedMT:基于Phrase的无监督机器翻译方法。【1068 stars on Github】
项目地址:
https://github.com/facebookresearch/UnsupervisedMT?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

18、DecaNLP:NLP十项全能工具,多任务模型。【1648 stars on Github】
项目地址:
https://github.com/salesforce/decaNLP?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

19、NLP-architect:来自英特尔AI实验室的python工具包,包含了当前NLP领域的多种最佳模型。【1751 stars on Github】
项目地址:
https://github.com/NervanaSystems/nlp-architect?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

20、Gluon-nlp:NLP工具包。【1263 stars on Github】
项目地址:
https://github.com/dmlc/gluon-nlp?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

GAN
21、DeOldify:一个基于深度学习的图像补全工具包。【5060 stars on Github】
项目地址:
https://github.com/jantic/DeOldify?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

22、Progressive_growing_of_gans:GAN的变种实现,提高生产质量、稳定性以及多样性。
项目地址:
https://github.com/tkarras/progressive_growing_of_gans?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

23、MUNIT:多模态无监督图像翻译。【1339 stars on Github】
项目地址:
https://github.com/NVlabs/MUNIT?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

24、Transparent_latent_gan:使用监督学习来解释GAN的隐空间信息。【1337 stars on Github】
项目地址:
https://github.com/SummitKwan/transparent_latent_gan?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

25、Gandissect:基于Pytorch的可视化以及理解GAN的神经元信息。【1065 stars on Github】
项目地址:
https://github.com/CSAILVision/gandissect?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

26、GANimation:单张图片的表情变换。【869 stars on Github】
项目地址:
https://github.com/albertpumarola/GANimation?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more
神经网络
27、Fastai:加速神经网络训练过程,并提高准确率。【11597 stars on Github】
项目地址:
https://github.com/fastai/fastai?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

28、DeepCreamPy:图像修复。【7046 stars on Github】
项目地址:
https://github.com/deeppomf/DeepCreamPy?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

29、Augmentor v0.2、图像增强工具包。【2805 stars on Github】
项目地址:
https://github.com/mdbloice/Augmentor?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

30、Graph_nets:Tensorflow的图网络构建工具。【2723 stars on Github】
项目地址:
https://github.com/deepmind/graph_nets?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

31、Textgenrnn:使用预训练字符级RNN生成文本。【1900 stars on Github】
项目地址:
https://github.com/minimaxir/textgenrnn?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

32、Person-blocker:图像中自动删除人像。【1806 stars on Github】
项目地址:
https://github.com/minimaxir/person-blocker?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

33、Deepvariant:DNA序列数据的分析工具
项目地址:
https://github.com/google/deepvariant?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

34、Video-nolocal-net:non-local神经网络的视频分类方法。【1049 stars on Github】
项目地址:
https://github.com/facebookresearch/video-nonlocal-net?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

35、Ann-visualizer:神经网络可视化工具。【922 stars on Github】
项目地址:
https://github.com/Prodicode/ann-visualizer?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

工具包
36、Tfjs:一个基于JS的ML模型训练部署工具包。【10268 stars on Github】
项目地址:
https://github.com/tensorflow/tfjs?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

37:Dopamine:快速的强化学习研究框架。【7142 stars on Github】
项目地址:
https://github.com/google/dopamine?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

38、Lime:分类器解释工具包。【5173 stars on Github】
项目地址:
https://github.com/marcotcr/lime?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

39、Autokeras:自动机器学习的开源软件库。【4520 stars on Github】
项目地址:
https://github.com/jhfjhfj1/autokeras?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

40、Shap:神经网络解释工具。【3496 stars on Github】
项目地址:
https://github.com/slundberg/shap?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

41、MMdnn:模型适配器。【3021 stars on Github】
项目地址:
https://github.com/Microsoft/MMdnn?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

42、Mlflow:机器学习生命周期管理。【3013 stars on Github】
项目地址:
https://github.com/mlflow/mlflow?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

43、Mace:面向移动计算平台的深度学习推断框架。【2979 stars on Github】
项目地址:
https://github.com/XiaoMi/mace?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

44、PySyft:关注安全性的深度学习库。【2595 stars on Github】
项目地址:
https://github.com/OpenMined/PySyft?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

45、Adanet:AutoML计算库。【2293 stars on Github】
项目地址:
https://github.com/tensorflow/adanet?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

46、Tencent-ml-images:最大的多标签图像数据库。【2094 stars on Github】
项目地址:
https://github.com/Tencent/tencent-ml-images?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

47、Donkeycar、开源的软硬件自动驾驶平台。【1207 stars on Github】

项目地址:
https://github.com/autorope/donkeycar?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

48、PocketFlow:自动模型压缩框架。【1677 stars on Github】
项目地址:
https://github.com/Tencent/PocketFlow?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

49、DALI:深度学习应用的优化工具包以及数据处理扩展引擎。【1013 stars on Github】
项目地址:
https://github.com/NVIDIA/dali?utm_source=mybridge&utm_medium=blog&utm_campaign=read_more

传送门

TensorFlow 2.0 中文手写字识别(汉字OCR)

image.png

TensorFlow 2.0 中文手写字识别(汉字OCR)

  • 搜索空间空前巨大,我们使用的数据集1.0版本汉字就多大3755个,如果加上1.1版本一起,总共汉字可以分为多达7599+个类别!这比10个阿拉伯字母识别难度大很多!
  • 数据集处理挑战更大,相比于mnist和fasionmnist来说,汉字手写字体识别数据集非常少,而且仅有的数据集数据预处理难度非常大,非常不直观,但是,千万别吓到,相信你看完本教程一定会收货满满!
  • 汉字识别更考验选手的建模能力,还在分类花?分类猫和狗?随便搭建的几层在搜索空间巨大的汉字手写识别里根本不work!你现在是不是想用很深的网络跃跃欲试?更深的网络在这个任务上可能根本不可行!!看完本教程我们就可以一探究竟!总之一句话,模型太简单和太复杂都不好,甚至会发散!(想亲身体验模型训练发散抓狂的可以来尝试一下!)。

数据准备

在开始之前,先介绍一下本项目所采用的数据信息。我们的数据全部来自于CASIA的开源中文手写字数据集,该数据集分为两部分:

  • CASIA-HWDB:离线的HWDB,我们仅仅使用1.0-1.2,这是单字的数据集,2.0-2.2是整张文本的数据集,我们暂时不用,单字里面包含了约7185个汉字以及171个英文字母、数字、标点符号等;
  • CASIA-OLHWDB:在线的HWDB,格式一样,包含了约7185个汉字以及171个英文字母、数字、标点符号等,我们不用。

其实你下载1.0的train和test差不多已经够了,可以直接运行 dataset/get_hwdb_1.0_1.1.sh 下载。原始数据下载链接点击这里. 由于原始数据过于复杂,我们使用一个类来封装数据读取过程,这是我们展示的效果:
image.png

看到这么密密麻麻的文字相信连人类都…. 开始头疼了,这些复杂的文字能够通过一个神经网络来识别出来??答案是肯定的…. 不有得感叹一下神经网络的强大。。上面的部分文字识别出来的结果是这样的:

image.png

关于数据的处理部分,从服务器下载到的原始数据是 trn_gnt.zip 解压之后是 gnt.alz, 需要再次解压得到一个包含 gnt文件的文件夹。里面每一个gnt文件都包含了若干个汉字及其标注。直接处理比较麻烦,也不方便抽取出图片再进行操作,虽然转为图片存入文件夹比较直观,但是不适合批量读取和训练, 后面我们统一转为tfrecord进行训练。

更新: 实际上,由于单个汉字图片其实很小,差不多也就最大80x80的大小,这个大小不适合转成图片保存到本地,因此我们将hwdb原始的二进制保存为tfrecord。同时也方便后面训练,可以直接从tfrecord读取图片进行训练。

image.png

训练过程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
def train():
all_characters = load_characters()
num_classes = len(all_characters)
logging.info('all characters: {}'.format(num_classes))
train_dataset = load_ds()
train_dataset = train_dataset.shuffle(100).map(preprocess).batch(32).repeat()

val_ds = load_val_ds()
val_ds = val_ds.shuffle(100).map(preprocess).batch(32).repeat()

for data in train_dataset.take(2):
print(data)

# init model
model = build_net_003((64, 64, 1), num_classes)
model.summary()
logging.info('model loaded.')

start_epoch = 0
latest_ckpt = tf.train.latest_checkpoint(os.path.dirname(ckpt_path))
if latest_ckpt:
start_epoch = int(latest_ckpt.split('-')[1].split('.')[0])
model.load_weights(latest_ckpt)
logging.info('model resumed from: {}, start at epoch: {}'.format(latest_ckpt, start_epoch))
else:
logging.info('passing resume since weights not there. training from scratch')

if use_keras_fit:
model.compile(
optimizer=tf.keras.optimizers.Adam(),
loss=tf.keras.losses.SparseCategoricalCrossentropy(),
metrics=['accuracy'])
callbacks = [
tf.keras.callbacks.ModelCheckpoint(ckpt_path,
save_weights_only=True,
verbose=1,
period=500)
]
try:
model.fit(
train_dataset,
validation_data=val_ds,
validation_steps=1000,
epochs=15000,
steps_per_epoch=1024,
callbacks=callbacks)
except KeyboardInterrupt:
model.save_weights(ckpt_path.format(epoch=0))
logging.info('keras model saved.')
model.save_weights(ckpt_path.format(epoch=0))
model.save(os.path.join(os.path.dirname(ckpt_path), 'cn_ocr.h5'))

大家在以后编写训练代码的时候其实可以保持这个好的习惯。

OK,整个模型训练起来之后,可以在短时间内达到95%的准确率:

image.png

image.png

总结

通过本教程,我们完成了使用tensorflow 2.0全新的API搭建一个中文汉字手写识别系统。模型基本能够实现我们想要的功能。要知道,这个模型可是在搜索空间多大3755的类别当中准确的找到最相似的类别!!通过本实验,我们有几点心得:

  • 神经网络不仅仅是在学习,它具有一定的想象力!!比如它的一些看着很像的字:拜-佯, 扮-捞,笨-苯…. 这些字如果手写出来,连人都比较难以辨认!!但是大家要知道这些字在类别上并不是相领的!也就是说,模型具有一定的联想能力!
  • 不管问题多复杂,要敢于动手、善于动手。
Your browser is out-of-date!

Update your browser to view this website correctly. Update my browser now

×