













摘 要:PCA數據降維技術廣泛應用于數據降維和數據的特征提取,可以很大程度上降低算法的計算復雜度,提升程序運行效率。文章將MNIST原始數據集和對原始數據集進行PCA降維處理之后的數據集作為樣本,分別采用K-鄰近算法、決策樹ID3算法、SVC分類模型,以及選取不同分類算法作為基礎分類器的集成學習方法,實現手寫數字識別。在對MNIST數據集進行PCA降維前后,以及不同分類算法和模型執行結果的時間復雜度與預測準確率進行比對與分析,進一步強化與優化手寫數字識別準確率等各項指標。
關鍵詞:PCA降維;MNIST手寫數字識別;K-鄰近算法;決策樹;SVC分類模型;集成學習
中圖分類號:TP391.4;TP181 文獻標識碼:A 文章編號:2096-4706(2024)16-0064-05
Optimization of MNIST Handwritten Digit Recognition Based on PCA Dimensionality Reduction
Abstract: PCA data dimensionality reduction technology is widely used in data dimensionality reduction and feature extraction, which can greatly reduce the computational complexity of algorithms and improve program efficiency. This paper takes the MNIST original dataset and the dataset after PCA dimensionality reduction as samples, and uses K-Nearest Neighbor algorithm, Decision Tree ID3 algorithm, SVC classification model, as well aGA+K2yAEPOy2eiDF3uhRaA==s Ensemble Learning methods that select different classification algorithms as basic classifiers to achieve handwritten digit recognition. It compares and analyzes the time complexity and prediction accurLm6ZfPkSb1C60t9ylZPjEw==acy of different classification algorithms and models before and after PCA dimensionality reduction on the MNIST dataset, further enhances and optimizes various indicators such as handwritten digit recognition accuracy.
Keywords: PCA dimensionality reduction; MNIST handwritten digit recognition; K-Nearest Neighbor algorithm; Decision Tree; SVC classification model; Ensemble Learning
0 引 言
MNIST[1]經典手寫數字數據集由訓練集、測試集、訓練結果集和測試結果集四個文件子集構成,四個子集存放在mnist.npz數組壓縮文件中。該數據集的構建目標是讓機器學習[2]運用其分類預測方法,達到識別手寫數字[3]的目的;常用的機器學習分類算法有:K-鄰近算法[4]、決策樹ID3算法[5]、SVC分類模型[6]和集成學習[7]等方法;數據降維是一種將高維數據轉換為地位數據的技術,同時盡量保留原始數據的重要信息及數據的變化,以期達到降低計算資源,提升算法效率的效果。主成分分析PCA降維技術[2]通過減少數據集維度,具有降低算法計算復雜度的優勢,運用保留數據主要變化模式、去除噪聲和不重要特征的數據處理工作原理。不僅高效地實現了數據降維,還能夠保留數據集核心信息。本文以MNIST數據集和經過PCA降維的數據集為樣本,分別采用機器學習的K-鄰近算法、決策樹、SVC分類模型等多種分類預測算法,實現手寫數字識別,同時進行多種算法的計算復雜度和預測準確率比對。
1 數據預處理
mnist.npz數據集文件由6萬個手寫數字訓練樣本和1萬個測試樣本組成,1張28×28灰度圖像構成1個樣本,訓練樣本和測試樣本分別對應的數字標簽結果數據集也包含在內。為了簡化數據,同時進一步降低算法計算量,需要對數據進行歸一化處理和降維預處理。
1.1 歸一化處理與數據格式轉換
為了分類器能夠識別輸入數據集格式,將數據集進行歸一化處理和格式轉換。具體流程如下:
1)加載mnist.npz數據集壓縮包。
2)提取壓縮包中的數組、獲取訓練集、測試集以及分類數字標簽。
3)將訓練集、測試集數組非零元素歸一化為1。
4)數據格式轉換,將28×28矩陣轉換成1×784列的矩陣。
處理結果如圖1和圖2所示。
1.2 數據降維
本文采用PCA降維技術[8]對mnist.npz數據集進行數據降維。PCA降維運用正交變化,將原始數據轉化到一組線性不相關成份上的原理,對數據集進行降維。PCA降維技術能夠顯著降低分類訓練和預測的計算量,但是會造成分類預測精度降低。
PCA降維技術有一個非常重要的降維參數n_components,是用來指定降維后的特征值,通常代表期望將原始數據壓縮成的特征數。本文通過將降維參數n_components通過反復測試,確定將歸一化和格式轉換后的數據集采用PCA降維技術由784個特征值,降維到n_components值為36,來減少數據量,進而降低算法計算復雜度。手寫數字數據訓練集和測試集進行PCA降維代碼如下:
PCA降維之前,mnist.npz數據集中的手寫數字是人眼可以識別的,通過PCA降維技術對mnist.npz訓練集和測試集進行降維之后,手寫數字為人眼不可以識別,部分手寫數字降維之前和降維之后顯示如圖3和圖4所示。
2 手寫數字識別實現及分類預測
2.1 k-鄰近算法實現手寫數字識別
K-鄰近算法[4]是一種易于掌握且十分有效的機器學習算法。采用測量不同特征值之間的距離進行分類。根據歐氏距離公式,計算兩個向量點之間的距離。數據集特征值為n的兩個向量a(x11,x12,…,x1n)和b(x21,x22,…,x2n)之間的歐氏距離為:
將各個點之間的距離計算完成之后,按照從小到大的順序將所有距離排序,然后選取前k個最小距離,再求出前k個距離對應的每個分類標簽,預測出的結果值是出現頻率最高的分類標簽。精度高、對異常值不敏感,以及空間復雜度較高是k-鄰近算法的主要特點,該模型實現識別手寫數字流程如下:
1)收集及解析數據集。
2)計算測試集中的每一個當前點與訓練集中點的距離。
3)將距離按照遞增順序進行排序。
4)選取距離最小的前k個點。
5)計算前k個點類別標簽的頻率。
6)將頻率最高的類別標簽確定為該點的預測分類。
具體流程圖如圖5所示。
按照圖5流程,采用K-鄰近算法分類器對樣本數據集的60 000個訓練集向量和10 000個測試集向量進行距離計算和測試,原始數據集進行PCA降維前,每個距離計算包括784個維度浮點運算;采用PCA降維之后,每個距離計算包括36個維度浮點運算。運用PCA-K-鄰近算法[8]進行測試,當k設置為3時,PCA降維前錯誤率為3.91%,準確率為96.09%,但是識別時長非常高,為10 876.904秒,執行效率并不高,但識別準確率較高;PCA降維后錯誤率降為3.57%,準確率為96.46%,識別時長大大降低,僅僅花費了238.77秒。相比較于降維前,PCA降維后的時間復雜度和識別準確率都有優化,尤其在時間復雜度方面,識別花費時長大大降低,PCA降維技術表現非常優異。
2.2 模型創建與代碼實現手寫數字識別
采用決策樹ID3算法、SVC分類模型算法和集成學習技術,將mnist.npz數據集進行PCA降維前和PCA降維后,分別實現手寫數字識別,將實驗結果進行比對,給出最優方案。
2.2.1 決策樹ID3算法構造決策樹模型
決策樹ID3算法[9]是一種其數據形式非常易于理解,計算復雜度較低。決策樹算法常見的有二分法和ID3等算法。鑒于ID3算法的優勢以及對數據集的規則要求,本文采用ID3算法實現手寫數字識別。其數據結構滿足ID3算法的數據結構要求。決策樹ID3算法對數據集的每次劃分會選取一個特征屬性作為參考屬性,這個參考屬性的確定基于計算信息增益的原則,此度量方式稱為熵,熵期望值的計算公式為:
其中n為分類的數目,p(xi)為選擇該分類的概率。決策樹ID3算法優勢在于能夠簡單且快速遍歷整個數據集,循環計算熵值和數據集子集的劃分。
本文運用Python語言skleanrn庫的tree模塊提供的分類樹DecisionTreeClassifier()方法,將mnist.npz文件中60 000個原始訓練集或者降維后訓練集作為樣本訓練集,采用分類器方法DecisionTreeClassifier()提供的fit()方法進行訓練,構造手寫數字識別決策樹模型;運用決策樹模型提供的score()方法分別對mnist.npz文件中10 000個原始測試集向量或者降維后測試集進行測試,分別得出識別時間復雜度和預測準確率。
構造決策樹識別模型的具體方法如下:
1)調用Python語言sklean庫的分類樹tree.DecisionTreeClassifier()方法生成決策樹分類器。
2)運用決策樹分類器fit()方法對樣本訓練集進行訓練,生成決策樹識別模型。
3)調用決策樹識別模型的score()方法對樣本測試集進行測試。
4)求出預測準確率和時間復雜度。
2.2.2 SVC算法構建識別模型
SVC(Support Vector Classification)支持向量機分類[10]是一種二分類算法模型,在模式識別中表現優異,能夠很好地解決二分類或者多分類問題。運用核函數,將低維數據映射到高維特征空間,在高維特征空間中尋找出超平面方法,將不同類別之間的間隔最大化,實現分類目的。
本文采用Python語言skleanrn庫中svm模塊提供的支持向量分類器函數svc()構造支持向量機;同樣運用支持向量機svc()函數提供的fit()方法,將mnist.npz文件中60 000個原始訓練集或者降維后訓練集作為樣本訓練集,構造手寫數字識別支持向量機模型;運用支持向量機模型提供的score()方法得到預測指標。
2.2.3 集成學習技術構建識別模型
集成學習[11](ensemble learning)是一個分類器集成器,將多個機器學習算法作為基礎分類模型,之后采用投票法、學習法等構建策略,將選取的基礎分類器結合起來,用以完成機器學習任務。本文的集成學習,將SVC分類模型和決策樹ID3算法為基礎分類器,以投票法作為構建策略實現手寫數字識別,如圖6所示。
2.2.4 流程與Python代碼實現
分類模型識別手寫數字實現流程如下:
1)數據集收集、歸一化與格式轉換處理。
2)對數據集進行PCA降維。
3)ID3算法構造決策樹。
4)決策樹對訓練集(PCA降維前、降維后)進行訓練。
5)SVC模式算法對訓練集(PCA降維前、降維后)進行訓練。
6)集成學習對訓練集(PCA降維前、降維后)進行訓練,以決策樹與SVC分類模型為基礎分類器。
7)分別預測測試集分類結果。
8)將預測分類結果與測試集實際分類標簽進行比對,獲得預測錯誤率。
部分Python核心代碼如下:
上述代碼中,采用決策樹ID3算法對數據集PCA降維前后、數據集PCA降維之后采用SVC分類模式算法,以及以數據集PCA_決策樹ID3算法和PCA_SVC分類算法作為基礎分類器實現的集成學習技術,進行數據集訓練;在集成學習代碼中,權重的分配通過反復測試,采用PCA_決策樹對PCA_SVC分類算法位1.8比1比例進行分配,以期得到最優識別效果。
2.3 結果分析與比對
圖7為通過K-鄰近算法PCA降維前后手寫數字識別結果和運行結果。
總結不同算法所用訓練時長和預測準確率如表1所示。
表1給出了8種不同機器學習分類算法識別手寫數字的運行結果,從表1可以看出,K-鄰近算法降維前訓練時長為10 876.90秒,是其余算法訓練時長的42倍左右,將耗費大量資源,因此將其作為異常值在剔除,繪制不同算法識別手寫數字運行結果對比圖,如圖8所示。
表1和圖8中,K-鄰近算法PCA降維前后的預測準率最高,分別達到了96.3%和96.5%,但是降維前訓練時長遠遠超出預期,將K-鄰近算法排除;而PCA-K-鄰近算法預測結果不僅在預測準確率上保持了降維前的高預測率,而且在算法時間復雜度上得到大幅度降低;通常,SVC分類模型在中小型數據集上的表現和神經網絡一樣突出,從表1和圖8可以看出,SVC算法降維前后在時間復雜度和預測準確率都有較好的預期結果值,PCA-SVC算法相對于SVC算法在預測準確率方面不分上下,在時間復雜度上表現更為優異;降維前后的決策樹ID3算法對數據集進行降維之后,并沒有提升預測準確率,反而降低了預測效率,但是兩者時間復雜度相對于其他算法大幅度降低,執行效率非常高,訓練時間均在10秒以內;數據顯示,集成學習預測準確率在降維前后的表現沒有特別優勢,表現一般。
3 結 論
本文分別采用K-鄰近算法、SVC分類模型、決策樹ID3算法以及集成學習算法將mnist.npz原始數據集進行歸一化、矩陣轉換等處理結果作為樣本,實現手寫數字識別;然后以采用PCA算法對其進行降維處理之后的數據集作為樣本,再次分別運用上述四種機器學習分類方法實現識別。對各種算法執行結果的時間復雜度和預測準確率進行對比,篩選相對較優算法。從實驗數據可以看出,其中PCA-K-鄰近算法和SVC分類模型PCA降維前后在預測準確率表現非常優異,均達到93%以上,K-鄰近算法在預測準確率方面甚至高達96%以上;在執行效率方面,決策樹ID3算法表現非常優異,訓練時長相對于其他算法降低了大約38倍到57倍之間,而且預測準確率也接近90%。從運行結果的比對可以看出,不同算法各有優勢,程序運行結果數據具有較高參考價值。
參考文獻:
[1] 張貫航.基于MNIST數據集的激活函數比較研究 [J].軟件,2023(9):165-168.
[2] Peter Harrington.機器學習實戰 [M].李銳,李鵬,曲亞東,王斌,譯.北京:人民郵電出版社,2013.
[3] 黃明春,田秀云,謝玉萍,等.基于人工智能的手寫數字識別方法研究 [J].機電工程技術,2023(4):185-189.
[4] 辛英.基于k-鄰近算法的手寫識別系統的研究與實現 [J].電子設計工程,2018(7):27-30.
[5] 趙力衡.基于決策樹的手寫數字識別的應用研究 [J].軟件,2018(3):90-94.
[6] 曹啼.C-SVC、ν-SVC與LSSVC三種支持向量分類機的對比研究 [D].上海:華東師范大學,2019.
[7] 王衡軍.機器學習 [M].北京:清華大學出版社,2020.
[8] 楊濟萍.基于主成分降維模型的手寫數字識別研究 [J].網絡安全技術與應用,2021(3):31-32.
[9] 張桂杰,王小燦,邢維康,等.基于決策樹分類算法的心理測評模型研究 [J].吉林師范大學學報,2023(4):123-130.
[10] 李雅琴.SVM在手寫數字識別中的應用研究 [D].武漢:華中師范大學.
[11] 符新偉,王舒可.基于集成機器學習的手寫數字識別技術研究 [J].中阿科技論壇:中英文,2022(11):124-128.