文獻(xiàn)標(biāo)識(shí)碼: A
DOI:10.16157/j.issn.0258-7998.182249
中文引用格式: 黃睿,陸許明,鄔依林. 基于TensorFlow深度學(xué)習(xí)手寫體數(shù)字識(shí)別及應(yīng)用[J].電子技術(shù)應(yīng)用,2018,44(10):6-10.
英文引用格式: Huang Rui,Lu Xuming,Wu Yilin. Handwriting digital recognition and application based on TensorFlow deep learning[J]. Application of Electronic Technique,2018,44(10):6-10.
0 引言
隨著科技的發(fā)展,人工智能識(shí)別技術(shù)已廣泛應(yīng)用于各個(gè)領(lǐng)域,同時(shí)也推動(dòng)著計(jì)算機(jī)的應(yīng)用朝著智能化發(fā)展。一方面,以深度學(xué)習(xí)、神經(jīng)網(wǎng)絡(luò)為代表的人工智能模型獲國內(nèi)外學(xué)者的廣泛關(guān)注;另一方面,人工智能與機(jī)器學(xué)習(xí)的系統(tǒng)開源,構(gòu)建了開放的技術(shù)平臺(tái),促進(jìn)人工智能研究的開發(fā)。本文基于TensorFlow深度學(xué)習(xí)框架,構(gòu)建Softmax、CNN模型,并完成手寫體數(shù)字的識(shí)別。
LECUN Y等提出了一種LeNet-5的多層神經(jīng)網(wǎng)絡(luò)用于識(shí)別0~9的手寫體數(shù)字,該研究模型通過反向傳播(Back Propagation,BP)算法進(jìn)行學(xué)習(xí),建立起CNN應(yīng)用的最早模型[1-2]。隨著人工智能圖像識(shí)別的出現(xiàn),CNN成為研究熱點(diǎn),近年來主要運(yùn)用于圖像分類[3]、目標(biāo)檢測(cè)[4]、目標(biāo)跟蹤[5]、文本識(shí)別[6]等方面,其中AlexNet[7]、GoogleNet[8]和ResNet[9]等算法取得了較大的成功。
本文基于Google第二代人工智能開源平臺(tái)TensorFlow,結(jié)合深度學(xué)習(xí)框架,對(duì)Softmax回歸算法和CNN模型進(jìn)行對(duì)比驗(yàn)證,最后對(duì)訓(xùn)練的模型基于Android平臺(tái)下進(jìn)行應(yīng)用。
1 TensorFlow簡介
2015年11月9日,Google發(fā)布并開源第二代人工智能學(xué)習(xí)系統(tǒng)TensorFlow[10]。Tensor表示張量(由N維數(shù)組成),F(xiàn)low(流)表示基于數(shù)據(jù)流圖的計(jì)算,TensorFlow表示為將張量從流圖的一端流動(dòng)到另一端的計(jì)算。TensorFlow支持短期記憶網(wǎng)絡(luò)(Long Short Term Memory Networks,LSTMN)、循環(huán)神經(jīng)網(wǎng)絡(luò)(Recurrent Neural Networks,RNN)和卷積神經(jīng)網(wǎng)絡(luò)(CNN)等深度神經(jīng)網(wǎng)絡(luò)模型。TensorFlow的基本架構(gòu)如圖1所示。
由圖1可知,TensorFlow的基本架構(gòu)可分為前端和后端。前端:基于支持多語言的編程環(huán)境,通過調(diào)用系統(tǒng)API來訪問后端的編程模型。后端:提供運(yùn)行環(huán)境,由分布式運(yùn)行環(huán)境、內(nèi)核、網(wǎng)絡(luò)層和設(shè)備層組成。
2 Softmax回歸
Softmax回歸算法能將二分類的Logistic回歸問題擴(kuò)展至多分類。假設(shè)回歸模型的樣本由K個(gè)類組成,共有m個(gè),則訓(xùn)練集可由式(1)表示:
式中,x(i)∈R(n+1),y(i)∈{1,2,…,K},n+1為特征向量x的維度。對(duì)于給定的輸入值x,輸出的K個(gè)估計(jì)概率由式(2)表示:
對(duì)參數(shù)θ1,θ2,…,θk進(jìn)行梯度下降,得到Softmax回歸模型,在TensorFlow中的實(shí)現(xiàn)如圖2所示。
對(duì)圖2進(jìn)行矩陣表達(dá),可得式(5):
將測(cè)試集數(shù)據(jù)代入式(5),并計(jì)算所屬類別的概率,則概率最大的類別即為預(yù)測(cè)結(jié)果。
3 CNN
卷積神經(jīng)網(wǎng)絡(luò)(CNN)是一種前饋神經(jīng)網(wǎng)絡(luò),通常包含數(shù)據(jù)輸入層、卷積計(jì)算層、ReLU激勵(lì)層、池化層、全連接層等,是由卷積運(yùn)算來代替?zhèn)鹘y(tǒng)矩陣乘法運(yùn)算的神經(jīng)網(wǎng)絡(luò)。CNN常用于圖像的數(shù)據(jù)處理,常用的LenNet-5神經(jīng)網(wǎng)絡(luò)模型圖如圖3所示。
該模型由2個(gè)卷積層、2個(gè)抽樣層(池化層)、3個(gè)全連接層組成。
3.1 卷積層
卷積層是通過一個(gè)可調(diào)參數(shù)的卷積核與上一層特征圖進(jìn)行滑動(dòng)卷積運(yùn)算,再加上一個(gè)偏置量得到一個(gè)凈輸出,然后調(diào)用激活函數(shù)得出卷積結(jié)果,通過對(duì)全圖的滑動(dòng)卷積運(yùn)算輸出新的特征圖,如式(6)~式(7)所示:
3.2 抽樣層
抽樣層是將輸入的特征圖用n×n窗口劃分成多個(gè)不重疊的區(qū)域,然后對(duì)每個(gè)區(qū)域計(jì)算出最大值或者均值,使圖像縮小了n倍,最后加上偏置量通過激活函數(shù)得到抽樣數(shù)據(jù)。其中,最大值法、均值法及輸出函數(shù)如式(8)~式(10)所示:
3.3 全連接輸出層
全連接層則是通過提取的特征參數(shù)對(duì)原始圖片進(jìn)行分類。常用的分類方法如式(11)所示:
4 實(shí)驗(yàn)分析
本文基于TensorFlow深度學(xué)習(xí)框架,數(shù)據(jù)源使用MNIST數(shù)據(jù)集,分別采用Softmax回歸算法和CNN深度學(xué)習(xí)進(jìn)行模型訓(xùn)練,然后將訓(xùn)練的模型進(jìn)行對(duì)比驗(yàn)證,并在Android平臺(tái)上進(jìn)行應(yīng)用。
4.1 MNIST數(shù)據(jù)集
MNIST數(shù)據(jù)集包含60 000行的訓(xùn)練數(shù)據(jù)集(train-images-idx3)和10 000行的測(cè)試數(shù)據(jù)集(test-images-idx3)。每個(gè)樣本都有唯一對(duì)應(yīng)的標(biāo)簽(label),用于描述樣本的數(shù)字,每張圖片包含28×28個(gè)像素點(diǎn),如圖4所示。
由圖4可知,每一幅樣本圖片由28×28個(gè)像素點(diǎn)組成,可由一個(gè)長度為784的向量表示。MNIST的訓(xùn)練數(shù)據(jù)集可轉(zhuǎn)換成[60 000,784]的張量,其中,第一個(gè)維度數(shù)據(jù)用于表示索引圖片,第二個(gè)維度數(shù)據(jù)用于表示每張圖片的像素點(diǎn)。而樣本對(duì)應(yīng)的標(biāo)簽(label)是介于0到9的數(shù)字,可由獨(dú)熱編碼(one-hot Encoding)進(jìn)行表示。一個(gè)獨(dú)熱編碼除了某一位數(shù)字是1以外,其余維度數(shù)字都是0,如標(biāo)簽0表示為[1,0,0,0,0,0,0,0,0,0],所以,樣本標(biāo)簽為[60 000,10]的張量。
4.2 Softmax模型實(shí)現(xiàn)
根據(jù)式(5),可以將Softmax模型分解為矩陣基本運(yùn)算和Softmax調(diào)用,該模型實(shí)現(xiàn)方式如下:(1)使用符號(hào)變量創(chuàng)建可交互的操作單元;(2)創(chuàng)建權(quán)重值和偏量;(3)根據(jù)式(5),實(shí)現(xiàn)Softmax回歸。
4.3 CNN模型實(shí)現(xiàn)
結(jié)合LenNet-5神經(jīng)網(wǎng)絡(luò)模型,基于TensorFlow深度學(xué)習(xí)模型實(shí)現(xiàn)方式如下:
(1)初始化權(quán)重和偏置;
(2)創(chuàng)建卷積和池化模板;
(3)進(jìn)行兩次的卷積、池化;
(4)進(jìn)行全連接輸出;
(5)Softmax回歸。
4.4 評(píng)估指標(biāo)
采用常用的成本函數(shù)“交叉熵”(cross-entropy),如式(12)所示:
4.5 模型檢驗(yàn)
預(yù)測(cè)結(jié)果檢驗(yàn)方法如下:
(1)將訓(xùn)練后的模型進(jìn)行保存;
(2)輸入測(cè)試樣本進(jìn)行標(biāo)簽預(yù)測(cè);
(3)調(diào)用tf.argmax函數(shù)獲取預(yù)測(cè)的標(biāo)簽值;
(4)與實(shí)際標(biāo)簽值進(jìn)行匹配,最后計(jì)算識(shí)別率。
根據(jù)上述步驟,分別采用Softmax模型和卷積神經(jīng)網(wǎng)絡(luò)對(duì)手寫數(shù)字0~9的識(shí)別數(shù)量、識(shí)別率分別如圖5、表1所示。
根據(jù)表1的模型預(yù)測(cè)結(jié)果可知,Softmax模型對(duì)數(shù)字1的識(shí)別率為97.9%,識(shí)別率最高。對(duì)數(shù)字3和數(shù)字8的識(shí)別率相對(duì)較小,分別是84.9%、87.7%。Softmax模型對(duì)手寫體數(shù)字0~9的整體識(shí)別率達(dá)91.57%。
結(jié)合圖5和表1可知,基于CNN模型的整體識(shí)別率高于Softmax模型,其中對(duì)數(shù)字3的識(shí)別率提高了14.7%,對(duì)數(shù)字1的識(shí)別率只提高了1.7%。基于深度學(xué)習(xí)CNN模型對(duì)手寫體數(shù)字0~9的整體識(shí)別率高達(dá)99.17%,比Softmax模型整體提高了7.6%。
4.6 模型應(yīng)用
通過模型的對(duì)比驗(yàn)證可知,基于深度學(xué)習(xí)CNN的識(shí)別率優(yōu)于Softmax模型?,F(xiàn)將訓(xùn)練好的模型移植到Android平臺(tái),進(jìn)行跨平臺(tái)應(yīng)用,實(shí)現(xiàn)方式如下。
(1)UI設(shè)計(jì)
用1個(gè)Bitmap控件顯示用戶手寫觸屏的軌跡,并將2個(gè)Button控件分別用于數(shù)字識(shí)別和清屏。
(2)TensorFlow引用
首先編譯需要調(diào)用的TensorFlow的jar包和so文件。其次將訓(xùn)練好的模型(.pb)導(dǎo)入到Android工程。
(3)接口實(shí)現(xiàn)
①接口定義及初始化:
inferenceInterface.initializeTensorFlow(getAssets(), MODEL_FILE);
②接口的調(diào)用:
inferenceInterface.fillNodeFloat(INPUT_NODE, new int[]{1, HEIGHT, WIDTH, CHANNEL}, inputs);
③獲取預(yù)測(cè)結(jié)果
inferenceInterface.readNodeFloat(OUTPUT_NODE, outputs);
通過上述步驟即可完成基于Android平臺(tái)環(huán)境的搭建及應(yīng)用,首先利用Android的觸屏功能捕獲并記錄手寫軌跡,手寫完成后單擊識(shí)別按鈕,系統(tǒng)將調(diào)用模型進(jìn)行識(shí)別,并將識(shí)別結(jié)果輸出到用戶界面。識(shí)別完成后,單擊清除按鈕,循環(huán)上述操作步驟可進(jìn)行手寫數(shù)字的再次識(shí)別,部分手寫數(shù)字的識(shí)別效果如圖6所示。
由圖6可知,在Android平臺(tái)上完成了基于TensorFlow深度學(xué)習(xí)手寫數(shù)字的識(shí)別,并且采用CNN的訓(xùn)練模型有較好的識(shí)別效果,實(shí)現(xiàn)了TensorFlow訓(xùn)練模型跨平臺(tái)的應(yīng)用。
5 結(jié)論
本文基于TensorFlow深度學(xué)習(xí)框架,采用Softmax回歸和CNN等算法進(jìn)行手寫體數(shù)字訓(xùn)練,并將模型移植到Android平臺(tái),進(jìn)行跨平臺(tái)應(yīng)用。實(shí)驗(yàn)數(shù)據(jù)表明,基于Softmax回歸模型的識(shí)別率為91.57%,基于CNN模型的識(shí)別率高達(dá)99.17%。表明基于深度學(xué)習(xí)的手寫體數(shù)字識(shí)別在人工智能識(shí)別方面具有一定的參考意義。
參考文獻(xiàn)
[1] HUBEL D H,WIESEL T N.Receptive fields and functional architecture of monkey striate cortex[J].Journal of Physiology,1968,195(1):215-243.
[2] LECUN Y,BOTTOU L,BENGIO Y,et al.Gradient-based learning applied to document recognition[J].Proceedings of the IEEE,1998,86(11):2278-2324.
[3] ZEILER M D,F(xiàn)ERGUS R.Visualizing and understanding convolutional networks[J].arXiv:1311.2901[cs.CV].
[4] 賀康建,周冬明,聶仁燦,等.基于局部控制核的彩色圖像目標(biāo)檢測(cè)方法[J].電子技術(shù)應(yīng)用,2016,42(12):89-92.
[5] LI H,LI Y,PORIKLI F.DeepTrack:learning discriminative feature representations online for robust visual tracking[J].IEEE Transactions on Image Processing,2015,25(4):1834-1848.
[6] GOODFELLOW I J,BULATOV Y,IBARZ J,et al.Multi-digit number recognition from street view imagery using deep convolutional neural networks[J].arXiv:1312.6082[cs.CV].
[7] KRIZHEVSKY A,SUTSKEVER I,HINTON G E.ImageNet classifycation with deep convolutional neural networks[C].International Conference on Neural Information Processing Systems.Curran Associates Inc.,2012:1097-1105.
[8] SZEGEDY C,LIU W,JIA Y,et al.Going deeper with convoluteons[C].IEEE Conference on Computer Vision and Pattern Recognition.IEEE,2015:1-9.
[9] HE K,ZHANG X,REN S,et al.Deep residual learning for image recognition[J].arXiv:1512.03385[cs.CV].
[10] ABADI M,AGARWAL A,BARHAM P,et al.TensorFlow:large-scale machine learning on heterogeneous distributed systems[J].arXiv:1603.04467[cs.DC].
作者信息:
黃 睿,陸許明,鄔依林
(廣東第二師范學(xué)院 計(jì)算機(jī)科學(xué)系,廣東 廣州510303)