尹 靜,李唯唯,楊德紅,閆 河
(重慶理工大學 計算機科學與工程學院,重慶 400054)
分類受限玻爾茲曼機(classification restricted boltzmann machine,ClassRBM)[1]是基于能量函數的無向圖模型,它是一個自帶標簽的隨機神經網絡模型,用于解決分類問題.ClassRBM提出以后,受到了研究者的廣泛關注[2-9],并將模型應用于各類應用中.[2]在分類受限玻爾茲曼機的網絡模型中增加線性變換,用于機器學習中自動識別不變模式;[3]提出一種判別學習方法訓練ClassRBM,其基本思想不再使用對數似然函數作為訓練目標,而使用均方誤差作為目標函數,采用隨機梯度下降法進行優化;[4]在ClassRBM的標簽層增加先驗知識,提高不同類別之間的信息共享,減少分類冗余,從而提高了模型的分類性能;[5]提出一種使用ClassRBM構建可理解的信用評分模型用于風險評估;[6]使用ClassRBM預測乳腺癌復發及在疾病再現中發現相關癥狀;[7]指出如何正確設置ClassRBM的參數是研究者面臨的主要問題之一,他們介紹了幾種啟發式技術來解決該問題;[8]利用ClassRBM對肝癌進行診斷;[9]直接利用ClassRBM對乳腺X光圖像進行分類,其分類準確度比其他基于統計特征提權的醫學圖像分類方法要高.
ClassRBM的標簽層采用一個神經元代表一個類別.因此,標簽層的神經元個數與數據的類別數一致.標簽層神經元總是稀疏的,而且每個神經元僅能為模型參數提供很少的信息,這可能會導致過擬合[1].為了解決該問題,本文提出了一種基于分類受限玻爾茲曼機的改進模型,用K個神經元表示一個類別,目的是為模型參數提供更多的信息,從而提高模型的分類性能.提出的模型在標準數字手寫體數據集MNIST、字母手寫體數據集OCR letters和文檔數據集20Newsgroups上進行了實驗,實驗結果表明我們的改進模型比ClassRBM有更好的分類性能;且改進模型在MNIST數據集上優于Random Forest,在OCR letters數據集上優于模型K-NN,在20Newsgroups數據集上優于模型SVM、Random Forest、RBM+NNet.
分類受限玻爾茲曼機可以看作是一個具有三層結構的隨機神經網絡模型.第一層是可見層,由|V|個神經元組成用以表示輸入數據v;第二層是隱層,由|H|個神經元組成用以表示數據的表達h;第三層是標簽層,代表輸入數據的標簽y,其中y∈{1,2,…,C}.其網絡結構如圖1所示.可見層與隱層之間的全連接權重用W表示,標簽層和隱層之間的全連接權重用U表示,每層各神經元之間沒有連接.

圖1 基本的ClassRBM模型Fig.1 Classification restricted boltzmann machine
為了表述簡潔,本文僅考慮模型采用二值單元的情況,當然也可以采用高斯單元、多項式單元、可矯正線性單元等[12].帶有標簽的二值ClassRBM的聯合概率分布有如下形式:

(1)
其中,Ζ(θ)=∑y,v,he-E(y,v,h|θ),也稱配分函數,以確保聯合概率分布是有效的.E(y,v,h|θ)是能量函數,表示為:
(2)
其中,θ是實數型參數bi、cj、Wij、Utj和dt的集合.vi、hj∈{0,1},當且僅當標簽為t時,yt=1,其他時候均為0.i∈{1,2,…,|V|},j∈{1,2,…,|H|}和t∈{1,2,…,C}.Wij是神經元vi和hj之間的連接權重,Utj是神經元yt和hj之間的連接權重,bi是第i個可見神經元的偏置,cj是第j個隱層神經元的偏置,而dt是第t個標簽層神經元的偏置.
對于分類任務,需要計算后驗概率p(yt|v),該條件概率有如下形式:
(3)
其中,softplus(x)=log(1+ex),t*代表輸入數據的標簽,t*∈{1,2,…,C}.

(4)
梯度的第一項比較容易計算,但第二項由于配分函數Ζ(θ)的存在,其計算復雜度很高.為了避免計算的復雜性,目前有多種算法對梯度進行近似計算,如:CD算法[13]、PCD算法[14]、PT算法[15]等.其中,CD算法是完成ClassRBM訓練的常用算法.
ClassRBM模型中,W學到有標簽信息的數據特征.執行分類任務時,ClassRBM通過U進行類別區分,從而確定數據的標簽.因此,U是控制不同類別信息非常重要的參數.ClassRBM模型的標簽層僅使用一個神經元表示某個具體類別,神經元總是稀疏的,而且單個神經元攜帶數據的類別信息是有限的,會影響分類效果.
使用K個神經元表示某個具體類別,增加神經元攜帶的類別信息,從而提高分類精度.我們建立了一個除標簽部分以外,其他與ClassRBM結構一樣的分類模型(K-Classification Restricted Boltzmann Machine,K-ClassRBM).標簽部分使用CK個神經元,每類使用連續的K個神經元.其網絡結構如圖2所示.如果數據的類別是t類,則神經元y1t,y2t,…,yKt取值1,剩余其他神經元取值0,t∈{1,2,…,C}.同樣,W是可見層和隱層之間神經元的連接權重,U是標簽層和隱層之間神經元的連接權重.

圖2 含CK個標簽神經元的ClassRBM模型(K-ClassRBM)Fig.2 Classification restricted boltzmann machine which has CK neurons in the label layer
帶有標簽的二值K-ClassRBM模型的能量函數有如下形式:
(5)

(6)
由于hj∈{0,1},故可得到hj=1的條件概率為:
(7)

(8)
給定隱層數據表達,類別t神經元對應的條件概率為:
(9)
當執行分類任務時,計算后驗概率p(yt|v)推斷數據的類標:
(10)

θ=θ+εΔθ
(11)
其中,ε是學習率.模型參數梯度具體更新公式:
(12)
(13)
(14)
(15)
(16)

下面是K-ClassRBM的具體訓練步驟:


參數更新:W=W+εΔW,U=U+εΔU,c=c+εΔc,
b=b+εΔb,d=d+εΔd,
不斷執行positive、negative階段,以及參數更新直到滿足訓練結束條件.
改進模型在分類模型的基礎上,增加了標簽層的神經元數量,使神經元攜帶更多的類別信息.從3.2節的各個公式來看,改進模型的計算公式與分類模型的計算公式有一些區別,這些區別在于每類用K個標簽神經元來標識.為了能更好地分析增加的神經元對模型參數的影響和對最終分類性能的改善,我們以參數U為例介紹改進模型參數的變化.ClassRBM的梯度有如下形式:
(17)
K-ClassRBM的梯度為:
(18)


為了驗證K-ClassRBM通過在ClassRBM的標簽層中增加一定數量的神經元后,分類性能有一定提升,我們在手寫體字符識別和文檔分類兩個應用中進行實驗,結果表明K-ClassRBM能提升ClassRBM的分類性能.
論文中實驗的參數設置:K-ClassRBM的學習率ε、隱層數采用[10,11]相同的設置,即W、U初始值的取值范圍為[-m-0.5,m-0.5],其中m是W或U中行數和列數中的最大值,且取值滿足均勻分布的隨機數,所有偏置初值均為0,隱層的個數同樣在[500~6000]之間.在手寫體字符識別和文檔分類實驗中,數據集分成訓練數據集、驗證數據集和測試數據集,訓練階段通過驗證數據集觀察模型在訓練數據集上的學習情況,最終在數據測試集上進行測試.
我們在兩個手寫體字符數據集上驗證K-ClassRBM的分類性能.標準的數字手寫體數據集MNIST和字母手寫體數據集OCR.這兩個數據集均包含四個部分:訓練數據集及它們的標簽,測試數據集及它們的標簽.標準MNIST數據集是手寫的0到9的數字圖片組成黑底白字的標準數據集.為了比較字符識別性能,實驗不但采用了與[10,11]一樣的參數設置,同樣也將原始訓練數據分成兩部分:50000條數據組成訓練集、10000條數據組成驗證集;同時使用標準的測試數據10000條進行測試.字母手寫體數據集OCR是手寫的a到z的字母圖片組成黑底白字的字符集.其中,原始訓練數據中的32152條數據作為訓練數據集,10000條數據組成驗證集,測試數據集由10000條數據組成.
4.1.1 MNIST數據集上的識別性能比較
我們比較多個模型在MNIST數據集上的識別性能.分類結果見表1.由于模型參數的初值是隨機的,和訓練過程中的隨機采樣,為了保證實驗結果的有效性,表格中K-ClassRBM模型的實驗數據是經過10次實驗的平均結果;ClassRBM和Random Forest的實驗結果來自其他論文的實驗數據[10,11].
表1 各模型在MNIST的分類錯誤率
Table 1 Classification error rates for the MNIST data set

ModelParametersErrorK?ClassRBM =0.005H=6000K=21.97%K=32.28%K=42.79%K=53.06%K=63.45%ClassRBM =0.005 H=60003.39%RandomForest2.94%
從表1的結果來看,ClassRBM僅用1位神經元標識數據類標,錯誤率為3.39%.K-ClassRBM用數量不同的神經元來標識數據類標,得到的分類錯誤率是不同的.其中,用兩位神經元來標識數據類標得到了最小的錯誤率1.97%,隨著K的不斷增加,錯誤率逐漸增加,當K=6時,K-ClassRBM的錯誤率高于ClassRBM.因此,用于標識類標的神經元過多并不利于分類.當K超過4時,K-ClassRBM的分類效果會低于Random Forest的分類效果.
4.1.2 OCR數據集上的識別性能比較
我們在字母手寫體數據集OCR上進行了識別性能比較實驗.結果見表2.同樣因為參數初始值的隨機性及訓練過程中的隨機采樣,我們重復10次實驗,表格中的結果是10次實驗的均值.K-ClassRBM模型的K值不同,得到的分類錯誤率也不同.其中,用三位神經元來標識數據類標得到了最小的錯誤率12.5%,隨著K的不斷增加,錯誤率逐漸增加,當K=8時,K-ClassRBM的錯誤率高于ClassRBM,但錯誤率低于K-NN.
表2 各模型在OCR的分類錯誤率
Table 2 Classification error rates for the OCR data set

ModelParametersErrorK?ClassRBM =0.005H=2000K=213.20%K=312.50%K=413.50%K=514.00%K=614.40%K=715.01%K=815.60%ClassRBM =0.005 H=200015.05%K?NN18.92%
我們也評估K-ClassRBM模型在文檔分類方面的性能.為了對比文檔分類的效果,采用了與ClassRBM模型一樣的數據集20Newsgroups,這個數據集包含了不同時期收集的新聞文檔,因此數據集更能反映實際的應用.20Newsgroups數據集由11269條訓練數據及標簽,7505條測試數據及標簽組成,共20類新聞.訓練時,將原始訓練數據分成兩部分:9578條數據組成訓練集和1691條數據組成驗證集,同時使用標準的測試數據7505條進行測試.由于數據集包含的詞比較多,我們僅選擇了出現頻率最大的5000個詞作為輸入數據的維度.分類效果見表3.同樣,由于模型的參數初值是隨機的,而且訓練過程的隨機采樣,為了保證實驗結果的有效性,表格中K-ClassRBM模型的實驗數據是經過10次實驗的平均結果.ClassRBM、SVM、RBM+NNet和Random Forest的實驗結果直接應用其他論文的實驗結果[10,11].
表3 各模型在20Newsgroups的分類錯誤率
Table 3 Classification error rates for the 20Newgroups data set

ModelParametersErrorK?ClassRBM =0.0005H=1000K=222.3%K=321.3%K=422.8%K=524.2%K=626.7%ClassRBM =0.0005 H=100024.9%SVM32.8%RBM+NNet26.8%RandomForest29.0%
表3的結果來看,當使用1位神經元標識數據類標時,ClassRBM在20 Newsgroups數據集上的錯誤率為24.9%.K-ClassRBM使用不同數量的神經元來標識數據類標,得到了不同的分類錯誤率.其中,用三位神經元來標識數據類標得到了最小的錯誤率21.3%,當K=6時,K-ClassRBM的錯誤率高于ClassRBM,但低于SVM、RBM+NNet和Random Forest.
前面的實驗證實了K-ClassRBM在不同數據集上的分類性能要優于ClassRBM.這部分從收斂速率和學到的數據特征來討論K-ClassRBM和ClassRBM的差異.實驗均在MNIST數據集上完成,參數設置與前面的實驗相同.

圖3 收斂速率比較圖Fig.3 Comparisons of convergence rate
在相同硬件配置下運行整個MNIST數據集,模型收斂速率的比較結果見圖3.其中橫坐標表示迭代次數,縱坐標表示分類錯誤率.曲線分別顯示了ClassRBM(K=1)和K-ClassRBM(K=2,3,4)在不同迭代次數下,分類錯誤率的降低情況.顯然K-ClassRBM的收斂速率快于ClassRBM.
圖4顯示了K-ClassRBM和ClassRBM學到的數據特征.由于隱層神經元個數為6000,為了更直觀的顯示數據特征,我們以隱層中的三個神經元為準,顯示模型學到的數據特征.

圖4 模型學到的數據特征Fig.4 Filters learned by models
圖4左邊是ClassRBM學到的數據特征;右邊是K-ClassRBM學到的數據特征(從上到下對應K=2,3,4).從圖4顯示的數據特征來看,K-ClassRBM 為區分數據類別學到了比ClassRBM更具體的特征.
對于分類問題,ClassRBM可以獲得較好的分類效果.但模型標簽層每個類別僅有一個神經元,它提供的模型參數信息太少,可能導致過擬合.K-ClassRBM在ClassRBM的標簽層中用K位神經元來標識某一個類別,通過利用K位神經元為模型提供更多參數信息,以便更好的分類.為了驗證改進模型的性能,分別在數據集MNIST、OCR和20Newsgroups上做了實驗.實驗結果表明K-ClassRBM能對ClassRBM的分類性能有一定的提升.將來需要在更多數據集或應用中進行運用K-ClassRBM,特別是在連續數據集中的應用.