文獻標(biāo)識碼: A
DOI:10.16157/j.issn.0258-7998.182249
中文引用格式: 黃睿,陸許明,鄔依林. 基于TensorFlow深度學(xué)習(xí)手寫體數(shù)字識別及應(yīng)用[J].電子技術(shù)應(yīng)用,2018,44(10):6-10.
英文引用格式: Huang Rui,Lu Xuming,Wu Yilin. Handwriting digital recognition and application based on TensorFlow deep learning[J]. Application of Electronic Technique,2018,44(10):6-10.
0 引言
隨著科技的發(fā)展,人工智能識別技術(shù)已廣泛應(yīng)用于各個領(lǐng)域,同時也推動著計算機的應(yīng)用朝著智能化發(fā)展。一方面,以深度學(xué)習(xí)、神經(jīng)網(wǎng)絡(luò)為代表的人工智能模型獲國內(nèi)外學(xué)者的廣泛關(guān)注;另一方面,人工智能與機器學(xué)習(xí)的系統(tǒng)開源,構(gòu)建了開放的技術(shù)平臺,促進人工智能研究的開發(fā)。本文基于TensorFlow深度學(xué)習(xí)框架,構(gòu)建Softmax、CNN模型,并完成手寫體數(shù)字的識別。
LECUN Y等提出了一種LeNet-5的多層神經(jīng)網(wǎng)絡(luò)用于識別0~9的手寫體數(shù)字,該研究模型通過反向傳播(Back Propagation,BP)算法進行學(xué)習(xí),建立起CNN應(yīng)用的最早模型[1-2]。隨著人工智能圖像識別的出現(xiàn),CNN成為研究熱點,近年來主要運用于圖像分類[3]、目標(biāo)檢測[4]、目標(biāo)跟蹤[5]、文本識別[6]等方面,其中AlexNet[7]、GoogleNet[8]和ResNet[9]等算法取得了較大的成功。
本文基于Google第二代人工智能開源平臺TensorFlow,結(jié)合深度學(xué)習(xí)框架,對Softmax回歸算法和CNN模型進行對比驗證,最后對訓(xùn)練的模型基于Android平臺下進行應(yīng)用。
1 TensorFlow簡介
2015年11月9日,Google發(fā)布并開源第二代人工智能學(xué)習(xí)系統(tǒng)TensorFlow[10]。Tensor表示張量(由N維數(shù)組成),F(xiàn)low(流)表示基于數(shù)據(jù)流圖的計算,TensorFlow表示為將張量從流圖的一端流動到另一端的計算。TensorFlow支持短期記憶網(wǎng)絡(luò)(Long Short Term Memory Networks,LSTMN)、循環(huán)神經(jīng)網(wǎng)絡(luò)(Recurrent Neural Networks,RNN)和卷積神經(jīng)網(wǎng)絡(luò)(CNN)等深度神經(jīng)網(wǎng)絡(luò)模型。TensorFlow的基本架構(gòu)如圖1所示。
由圖1可知,TensorFlow的基本架構(gòu)可分為前端和后端。前端:基于支持多語言的編程環(huán)境,通過調(diào)用系統(tǒng)API來訪問后端的編程模型。后端:提供運行環(huán)境,由分布式運行環(huán)境、內(nèi)核、網(wǎng)絡(luò)層和設(shè)備層組成。
2 Softmax回歸
Softmax回歸算法能將二分類的Logistic回歸問題擴展至多分類。假設(shè)回歸模型的樣本由K個類組成,共有m個,則訓(xùn)練集可由式(1)表示:
式中,x(i)∈R(n+1),y(i)∈{1,2,…,K},n+1為特征向量x的維度。對于給定的輸入值x,輸出的K個估計概率由式(2)表示:
對參數(shù)θ1,θ2,…,θk進行梯度下降,得到Softmax回歸模型,在TensorFlow中的實現(xiàn)如圖2所示。
對圖2進行矩陣表達,可得式(5):
將測試集數(shù)據(jù)代入式(5),并計算所屬類別的概率,則概率最大的類別即為預(yù)測結(jié)果。
3 CNN
卷積神經(jīng)網(wǎng)絡(luò)(CNN)是一種前饋神經(jīng)網(wǎng)絡(luò),通常包含數(shù)據(jù)輸入層、卷積計算層、ReLU激勵層、池化層、全連接層等,是由卷積運算來代替?zhèn)鹘y(tǒng)矩陣乘法運算的神經(jīng)網(wǎng)絡(luò)。CNN常用于圖像的數(shù)據(jù)處理,常用的LenNet-5神經(jīng)網(wǎng)絡(luò)模型圖如圖3所示。
該模型由2個卷積層、2個抽樣層(池化層)、3個全連接層組成。
3.1 卷積層
卷積層是通過一個可調(diào)參數(shù)的卷積核與上一層特征圖進行滑動卷積運算,再加上一個偏置量得到一個凈輸出,然后調(diào)用激活函數(shù)得出卷積結(jié)果,通過對全圖的滑動卷積運算輸出新的特征圖,如式(6)~式(7)所示:
3.2 抽樣層
抽樣層是將輸入的特征圖用n×n窗口劃分成多個不重疊的區(qū)域,然后對每個區(qū)域計算出最大值或者均值,使圖像縮小了n倍,最后加上偏置量通過激活函數(shù)得到抽樣數(shù)據(jù)。其中,最大值法、均值法及輸出函數(shù)如式(8)~式(10)所示:
3.3 全連接輸出層
全連接層則是通過提取的特征參數(shù)對原始圖片進行分類。常用的分類方法如式(11)所示:
4 實驗分析
本文基于TensorFlow深度學(xué)習(xí)框架,數(shù)據(jù)源使用MNIST數(shù)據(jù)集,分別采用Softmax回歸算法和CNN深度學(xué)習(xí)進行模型訓(xùn)練,然后將訓(xùn)練的模型進行對比驗證,并在Android平臺上進行應(yīng)用。
4.1 MNIST數(shù)據(jù)集
MNIST數(shù)據(jù)集包含60 000行的訓(xùn)練數(shù)據(jù)集(train-images-idx3)和10 000行的測試數(shù)據(jù)集(test-images-idx3)。每個樣本都有唯一對應(yīng)的標(biāo)簽(label),用于描述樣本的數(shù)字,每張圖片包含28×28個像素點,如圖4所示。
由圖4可知,每一幅樣本圖片由28×28個像素點組成,可由一個長度為784的向量表示。MNIST的訓(xùn)練數(shù)據(jù)集可轉(zhuǎn)換成[60 000,784]的張量,其中,第一個維度數(shù)據(jù)用于表示索引圖片,第二個維度數(shù)據(jù)用于表示每張圖片的像素點。而樣本對應(yīng)的標(biāo)簽(label)是介于0到9的數(shù)字,可由獨熱編碼(one-hot Encoding)進行表示。一個獨熱編碼除了某一位數(shù)字是1以外,其余維度數(shù)字都是0,如標(biāo)簽0表示為[1,0,0,0,0,0,0,0,0,0],所以,樣本標(biāo)簽為[60 000,10]的張量。
4.2 Softmax模型實現(xiàn)
根據(jù)式(5),可以將Softmax模型分解為矩陣基本運算和Softmax調(diào)用,該模型實現(xiàn)方式如下:(1)使用符號變量創(chuàng)建可交互的操作單元;(2)創(chuàng)建權(quán)重值和偏量;(3)根據(jù)式(5),實現(xiàn)Softmax回歸。
4.3 CNN模型實現(xiàn)
結(jié)合LenNet-5神經(jīng)網(wǎng)絡(luò)模型,基于TensorFlow深度學(xué)習(xí)模型實現(xiàn)方式如下:
(1)初始化權(quán)重和偏置;
(2)創(chuàng)建卷積和池化模板;
(3)進行兩次的卷積、池化;
(4)進行全連接輸出;
(5)Softmax回歸。
4.4 評估指標(biāo)
采用常用的成本函數(shù)“交叉熵”(cross-entropy),如式(12)所示:
4.5 模型檢驗
預(yù)測結(jié)果檢驗方法如下:
(1)將訓(xùn)練后的模型進行保存;
(2)輸入測試樣本進行標(biāo)簽預(yù)測;
(3)調(diào)用tf.argmax函數(shù)獲取預(yù)測的標(biāo)簽值;
(4)與實際標(biāo)簽值進行匹配,最后計算識別率。
根據(jù)上述步驟,分別采用Softmax模型和卷積神經(jīng)網(wǎng)絡(luò)對手寫數(shù)字0~9的識別數(shù)量、識別率分別如圖5、表1所示。
根據(jù)表1的模型預(yù)測結(jié)果可知,Softmax模型對數(shù)字1的識別率為97.9%,識別率最高。對數(shù)字3和數(shù)字8的識別率相對較小,分別是84.9%、87.7%。Softmax模型對手寫體數(shù)字0~9的整體識別率達91.57%。
結(jié)合圖5和表1可知,基于CNN模型的整體識別率高于Softmax模型,其中對數(shù)字3的識別率提高了14.7%,對數(shù)字1的識別率只提高了1.7%?;谏疃葘W(xué)習(xí)CNN模型對手寫體數(shù)字0~9的整體識別率高達99.17%,比Softmax模型整體提高了7.6%。
4.6 模型應(yīng)用
通過模型的對比驗證可知,基于深度學(xué)習(xí)CNN的識別率優(yōu)于Softmax模型?,F(xiàn)將訓(xùn)練好的模型移植到Android平臺,進行跨平臺應(yīng)用,實現(xiàn)方式如下。
(1)UI設(shè)計
用1個Bitmap控件顯示用戶手寫觸屏的軌跡,并將2個Button控件分別用于數(shù)字識別和清屏。
(2)TensorFlow引用
首先編譯需要調(diào)用的TensorFlow的jar包和so文件。其次將訓(xùn)練好的模型(.pb)導(dǎo)入到Android工程。
(3)接口實現(xiàn)
①接口定義及初始化:
inferenceInterface.initializeTensorFlow(getAssets(), MODEL_FILE);
②接口的調(diào)用:
inferenceInterface.fillNodeFloat(INPUT_NODE, new int[]{1, HEIGHT, WIDTH, CHANNEL}, inputs);
③獲取預(yù)測結(jié)果
inferenceInterface.readNodeFloat(OUTPUT_NODE, outputs);
通過上述步驟即可完成基于Android平臺環(huán)境的搭建及應(yīng)用,首先利用Android的觸屏功能捕獲并記錄手寫軌跡,手寫完成后單擊識別按鈕,系統(tǒng)將調(diào)用模型進行識別,并將識別結(jié)果輸出到用戶界面。識別完成后,單擊清除按鈕,循環(huán)上述操作步驟可進行手寫數(shù)字的再次識別,部分手寫數(shù)字的識別效果如圖6所示。
由圖6可知,在Android平臺上完成了基于TensorFlow深度學(xué)習(xí)手寫數(shù)字的識別,并且采用CNN的訓(xùn)練模型有較好的識別效果,實現(xiàn)了TensorFlow訓(xùn)練模型跨平臺的應(yīng)用。
5 結(jié)論
本文基于TensorFlow深度學(xué)習(xí)框架,采用Softmax回歸和CNN等算法進行手寫體數(shù)字訓(xùn)練,并將模型移植到Android平臺,進行跨平臺應(yīng)用。實驗數(shù)據(jù)表明,基于Softmax回歸模型的識別率為91.57%,基于CNN模型的識別率高達99.17%。表明基于深度學(xué)習(xí)的手寫體數(shù)字識別在人工智能識別方面具有一定的參考意義。
參考文獻
[1] HUBEL D H,WIESEL T N.Receptive fields and functional architecture of monkey striate cortex[J].Journal of Physiology,1968,195(1):215-243.
[2] LECUN Y,BOTTOU L,BENGIO Y,et al.Gradient-based learning applied to document recognition[J].Proceedings of the IEEE,1998,86(11):2278-2324.
[3] ZEILER M D,F(xiàn)ERGUS R.Visualizing and understanding convolutional networks[J].arXiv:1311.2901[cs.CV].
[4] 賀康建,周冬明,聶仁燦,等.基于局部控制核的彩色圖像目標(biāo)檢測方法[J].電子技術(shù)應(yīng)用,2016,42(12):89-92.
[5] LI H,LI Y,PORIKLI F.DeepTrack:learning discriminative feature representations online for robust visual tracking[J].IEEE Transactions on Image Processing,2015,25(4):1834-1848.
[6] GOODFELLOW I J,BULATOV Y,IBARZ J,et al.Multi-digit number recognition from street view imagery using deep convolutional neural networks[J].arXiv:1312.6082[cs.CV].
[7] KRIZHEVSKY A,SUTSKEVER I,HINTON G E.ImageNet classifycation with deep convolutional neural networks[C].International Conference on Neural Information Processing Systems.Curran Associates Inc.,2012:1097-1105.
[8] SZEGEDY C,LIU W,JIA Y,et al.Going deeper with convoluteons[C].IEEE Conference on Computer Vision and Pattern Recognition.IEEE,2015:1-9.
[9] HE K,ZHANG X,REN S,et al.Deep residual learning for image recognition[J].arXiv:1512.03385[cs.CV].
[10] ABADI M,AGARWAL A,BARHAM P,et al.TensorFlow:large-scale machine learning on heterogeneous distributed systems[J].arXiv:1603.04467[cs.DC].
作者信息:
黃 睿,陸許明,鄔依林
(廣東第二師范學(xué)院 計算機科學(xué)系,廣東 廣州510303)