徐 哲,耿 杰,蔣 雯,張 卓,曾慶捷
(西北工業大學電子信息學院,西安710072)
圖像分類作為計算機視覺領域最基礎的任務之一,主要通過提取原始圖像的特征并根據特征學習進行分類[1]。傳統的特征提取方法主要是對圖像的顏色、紋理、局部特征等圖像表層特征進行處理實現的,例如尺度不變特征變換法[2],方向梯度法[3]以及局部二值法[4]等。但是這些特征都是人工設計的特征,很大程度上靠人類對識別目標的先驗知識進行設計,具有一定的局限性。隨著大數據時代的到來,基于深度學習的圖像分類方法具有對大量復雜數據進行處理和表征的能力,能夠有效學習目標的特征信息,從而提高圖像分類的精度[5-8]。
深度學習以大數據驅動的方式進行訓練,對標簽數據依賴性較強,而在現實應用中往往難以獲取大量的標簽數據。當樣本數量不足時,深度卷積網絡模型容易過擬合,導致分類性能較差。生成對抗網絡(Generative Adversarial Networks,GAN)[9]具有強大的數據生成能力,采用博弈對抗的方式,既訓練正常的樣本,也能對抗學習達到納什均衡,從而完成網絡的訓練。這樣GAN 在訓練時既能夠生成樣本,又能夠提高特征提取能力,可以用來解決小樣本條件下網絡過擬合的問題。但GAN 網絡還存在穩定性差和依賴標簽數據的問題,不能直接應用于分類任務中。
針對GAN 存在的問題,有不少學者從網絡框架和理論模型兩個角度對GAN 進行了改進。從網絡框架角度,Radford 等[10]提出了深度卷積生成對抗網絡(Deep Convolutional GAN,DCGAN),將卷積神經網絡應用到生成對抗網絡中,提高了GAN 訓練的穩定性;Shaham 等[11]提出了單圖像生成對抗網絡(SinGAN),運用一個多尺度金字塔結構的全卷積網絡,能夠學習到不同尺度的圖像塊分布;Karnewar 等[12]提出了多尺度梯度生成對抗網絡(Multi-Scale Gradients GAN,MSG-GAN),通過從判別器到生成器的梯度流向多個尺度來解決訓練不穩定的問題。從理論模型角度,Arjovsky 等[13]提出Wasserstein 生成對抗網絡(Wasserstein GAN,WGAN),使用Earth-Mover 距離代替JS(Jensen-Shannon)散度來計算生成樣本分布與真實樣本分布之間的距離,緩解了GAN 訓練不穩定和梯度消失的問題;Goodfellow 等[14]提出了自注意力生成對抗網絡(Self-Attention Generative Adversarial Networks,SAGAN),通過引入自注意力機制來增大深度卷積網絡的感受野,從而更好地獲取圖像的全局信息。
為進一步提高圖像分類的準確率,解決GAN 訓練穩定性差的問題,本文提出一種聯合訓練生成對抗網絡(Co-Training Generative Adversarial Networks,CT-GAN)的半監督分類方法,設計兩個判別器進行聯合訓練,以消除單個判別器存在的分布誤差問題,同時利用大量無標簽數據和少量標簽數據進行半監督學習,設計新的監督損失和無監督損失以優化網絡模型,能夠學習到泛化能力較強、性能更好的模型,在一定程度上減小網絡對標簽數據的依賴,提高網絡的分類準確率。
生成對抗網絡是由Goodfellow 等[9]在2014年提出的無監督生成模型,由一個生成器(Generator)和一個判別器(Discriminator)構成。生成器依據樣本的數據分布來生成盡可能逼真的偽數據,判別器用于判別輸入數據是真實數據還是生成器生成的偽數據,生成器和判別器經過博弈對抗達到納什均衡,此時生成的數據能夠擬合真實樣本的數據分布。GAN 的網絡結構如圖1所示。

圖1 GAN 的網絡結構Fig.1 Structure of GAN
生成器G和判別器D通常可以由卷積神經網絡或者函數表示,G輸入隨機噪聲z用于生成偽數據G(z),D對輸入的真實數據x和偽數據G(z)判別真偽,輸出其屬于真實樣本的概率。生成器G和判別器D通過損失函數相互博弈對抗進行訓練,其優化過程是極大極小博弈的過程,目標函數為:

其中:x表示真實數據,Pdata為x的數據分布,z表示服從標準正態分布的隨機噪聲,Pz為z的數據分布,G(z)表示生成器生成的偽數據,D(·)表示判別器判別輸入樣本來自真實樣本的概率。對于判別器D而言,其希望判別的準確率越高,即希望D(x) 越接近1,D(G(z)) 越接近0,此時V(D,G)取極大值。對于生成器G而言,生成的能力越強,生成的數據分布越接近真實的數據分布,即希望D(G(z))越接近1 越好,此時V(D,G)取極小值。
當V(D,G)取到極大極小值時,生成對抗網絡達到納什均衡,此時生成的數據能夠擬合真實數據分布。
半監督生成對抗網絡(Semi-Supervised Learning with Generative Adversarial Networks,SGAN)[15]是由Odena 提出的半監督生成模型,其對原始GAN 網絡進行改進,引入半監督學習,將標簽數據和無標簽數據共同輸入到判別器中進行訓練,并輸出K+1 維帶有類別信息的分類結果。SGAN 的網絡結構如圖2 所示。
在SGAN 中,隨機噪聲z通過生成器生成的偽數據G(z)與K類標簽數據xl和無標簽數據xu共同輸入到判別器中進行訓練,在判別器的最后一層使用softmax 非線性分類器,最終輸出K+1維分類結果{l1,l2,…,lK+1},其中前K維輸出代表對應類的置信度,第K+1 維代表判定為“偽”的置信度。

圖2 SGAN 的網絡結構Fig.2 Structure of SGAN
SGAN 采用了半監督訓練方式,利用少量標簽數據和大量無標簽數據同時進行網絡訓練,從而提高半監督分類的準確率。但有研究表明,SGAN 仍存在訓練不穩定的問題[16],主要表現在訓練過程中可能出現梯度消失,導致網絡不收斂的問題。這一問題的原因是SGAN 在訓練過程中,單個判別器可能存在較大的分布誤差,從而造成梯度消失,判別器網絡不收斂。其中,分布誤差是指判別器對樣本類別預測時的概率分布誤差。一般情況下,判別器預測樣本類別的分布誤差都可以通過訓練迭代,逐漸消除其對網絡訓練的影響。但當出現較大的分布誤差時,判別器網絡會對樣本產生較大的誤判,造成梯度消失,使得判別器網絡不收斂,影響其分類性能。
為進一步提高圖像分類的準確率,解決SGAN 訓練不穩定的問題,本文提出一種聯合訓練生成對抗網絡(Co-training GAN,CT-GAN)的半監督分類方法,CT-GAN 的網絡結構如圖3所示。

圖3 CT-GAN 的網絡結構Fig.3 Structure of CT-GAN
在CT-GAN 中,采用了兩個判別器D1,D2進行聯合訓練,能夠有效提升網絡訓練穩定性的同時提高圖像分類的準確率。判別器D1,D2共享同一個生成器G,同時兩個判別器的網絡結構和初始參數設為相同。不同的是,將標簽數據和無標簽數據的順序打亂后分別輸入到判別器D1,D2中,即保證在訓練過程中兩個判別器是動態變化的。CT-GAN 采用兩個判別器進行聯合訓練,在訓練過程計算損失函數時,取兩個判別器損失的平均值,以消除單個判別器存在的分布誤差。同時在訓練過程中,兩個判別器不僅僅輸出K+1維分類結果,還設置了一個置信度閾值,如果生成數據的置信度高于該閾值,則賦予其偽標簽并加入到初始標簽數據集中,在訓練過程中就能夠擴充數據集,加快網絡收斂。
對于CT-GAN 的生成器G而言,G的能力越強,生成的圖像越接近真實圖像,即希望D(G(z))越接近1 越好,此時V(D,G)取極小值。由此可得到生成器的損失為:

同時為了讓生成器生成的數據分布更接近真實數據的統計分布,采用特征匹配[17]的方法對生成器的損失進行約束,定義特征匹配損失為:

其中:fj(·)表示判別器Dj在全連接層前的最后一層輸出的特征值。這樣,CT-GAN 生成器的總損失為:

對于CT-GAN 的判別器損失函數,采取監督損失和無監督損失相結合的方式給出。對于判別器的監督損失,需要加入標簽信息,因此以交叉熵的形式定義如下:

其中:yi表示第i維標簽,Dj(xi)表示判別器Dj判別標簽數據的標簽結果為第i維的概率。
對于無監督損失,CT-GAN 需要判別無標簽數據的類別標簽。考慮到兩個判別器聯合訓練的情況,CT-GAN 判別器的無監督損失定義如下:

其中:yi′表示判別器前一次迭代時判別無標簽數據的類別為第i維,Dj(xi)表示判別器Dj判別標簽數據的標簽結果為第i維的概率。
由式(5)和式(6)可得CT-GAN 判別器的總損失函數為:

由CT-GAN 生成器總損失函數和判別器總損失函數相加,可以得到CT-GAN 整體的損失函數如下:

CT-GAN 網絡的聯合訓練示意如圖4 所示。對于生成器生成的偽數據而言,判別器只需判斷其真偽,不判別其類別,所以在此聯合訓練中暫不考慮偽數據的輸入,只考慮標簽數據和無標簽數據的輸入。

圖4 網絡聯合訓練示意圖Fig.4 Schematic of co-training method
在CT-GAN 中,為保證判別器D1,D2訓練時是動態變化的,首先將標簽數據和無標簽數據的順序打亂得到標簽樣本L1,L2和無標簽樣本U1,U2,分別輸入到判別器D1,D2中進行聯合訓練。以判別器D1為例,訓練過程按照以下步驟進行訓練:
(1)利用標簽樣本L1訓練判別器D1。標簽樣本L1輸入到判別器D1中,輸出L1分類結果,計算判別器的監督損失以訓練判別器D1;
(2)利用判別器D1來預測無標簽樣本U1的標簽。判別器D1將前一次迭代得到的U1分類結果轉化為獨熱向量并認為是當前無標簽樣本U1的標簽,與當前得到的U1分類結果共同計算判別器的無監督損失,從而不斷優化預測無標簽樣本U1的標簽;
(3)利用無標簽樣本U1擴充標簽樣本L2。設置一個置信度閾值,對每次迭代得到的無標簽樣本U1的分類結果進行置信度判斷,如果大于該置信度閾值,則賦予其偽標簽并加入到對應的標簽樣本L2中繼續訓練,這樣在訓練過程中就可以擴充數據集,加快網絡收斂。

CT-GAN 模型通過判別器D1,D2的聯合訓練,一方面可以消除單個判別器存在的分布誤差,提高判別器訓練的穩定性;另一方面,利用無標簽數據在訓練時擴充標簽數據集,能夠加快網絡收斂。因此,CT-GAN 模型能夠充分利用少量標簽數據的標簽信息和大量無標簽數據的分布信息來獲取整個樣本的特征分布,從而進一步提高網絡識別的精度。
本文實驗所使用的數據集為CIFAR-10 和SVHN 數據集,其中CIFAR10 數據集是一個包含10 個類別32×32 的彩色圖像數據集,共計60 000 張圖像,其中40 000 張作為訓練集,20 000張作為測試集,即每個類別有4 000 張訓練樣本和2 000 張測試樣本。SVHN 數據集是一個真實街景數字數據集,包含10 個類別32×32 的彩色圖像,共計99 289 張圖片,其中73 257 張作為訓練集,26 032 張作為測試集。
數據集中的每張圖像均包含一個類別信息,即均為標簽數據。為滿足本文實驗要求,對訓練集中圖像進行預處理,按一定的比例隨機去除部分標簽數據的類別信息,得到無標簽數據,CIFAR-10 數據集和SVHN 數據集的預處理方案如表1 所示。其中CIFAR-10 數據集在各類別標簽數量分別為10,100,250,500,1 000 和2 000 時分別進行實驗,SVHN 數據集在各類別標簽數量分別為100 和1 000 時分別進行實驗,以研究不同數量標簽數據對網絡的影響。

表1 CIFAR10數據集各類別標簽數據的數量及所占比例Tab.1 Amount and proportion of labeled data in each category of the CIFAR10 data
本實驗采用一個RTX 2080Ti 的GPU 進行訓練,共訓練200 個epochs,且設置batch size 為128,即每個epoch 迭代313 次。設置初始學習率為0.000 2,并在迭代50 000 次和90 000 次時分別衰減為原來的1/10。采用Adam 優化算法對網絡進行優化,其中一階動量設為0.5,二階動量設為0.999。模型采用基于PyTorch 的深度學習框架實現。
在CIFAR-10 數據集上,CT-GAN 模型的生成器框架和判別器框架分別如圖5(a)和(b)所示。生成器的輸入為(128,100)的隨機噪聲,首先通過(100,8 192)的全連接層得到(128,8 192)的張量,經過維度轉換得到維度為(128,128,8,8)的圖像,經過兩次上采樣操作和三次步長為1的3×3 卷積核的卷積操作后得到維度為(128,3,32,32)的圖像,其中每次完成卷積操作后都使用批歸一化(Batch Normalization)操作并加入ReLU 激活函數。最后一層通過Tanh 激活函數輸出生成數據G(z)。
判別器的輸入為128 張大小為32×32 的3通道RGB 彩色圖像,其維度為(128,3,32,32),經過四次步長為2 的3×3 卷積核的卷積操作,最終輸出圖像維度為(128,128,2,2),其中每次完成卷積操作后都加入LeakyReLU 激活函數和Dropout 操作以防止過擬合,而除了首次卷積不使用批歸一化外,其余卷積操作后都使用批歸一化。將卷積輸出圖像進行維度轉換得到維度為(128,512)的張量,通過(512,10)的全連接層和softmax 分類器得到分類結果,同時通過(512,1)的全連接層和Sigmoid 分類器判別真偽。

圖5 CT-GAN 模型的生成器和判別器框架Fig.5 Structure of generator and discriminator in CTGAN
在CIFAR-10 數據集上的實驗首先按照4.1節中的數據集預處理方案,對數據集中的圖像按一定的比例去除部分標簽數據的標簽信息,構成無標簽數據。在各類別標簽數量分別為10,100,250,500,1 000 和2 000 時分別進行實驗,以研究不同數量的標簽數據下CT-GAN 模型的性能。如圖6 和圖7 給出了在各類標簽數據數量分別為10,100,250,500,1 000 和2 000 時的CT-GAN 判別器和生成器損失變化曲線。
分析圖6 可知,在不同數量的標簽數據下,CT-GAN 的判別器損失在一定迭代次數后都達到了穩定,標簽數據越少,損失趨于穩定需要迭代的次數也越少。這是因為當標簽數量越少時,整個數據所含的類別信息也就越少,判別器可以學習的信息也相應減少,導致損失收斂速度加快。雖然不同標簽數量下損失收斂所需的迭代次數不同,但是其損失收斂值大致相同。這說明標簽數量對CT-GAN 的判別器的訓練影響很小,在一定程度上CT-GAN 模型能夠減小對標簽數據的依賴。分析圖7 可知,在不同數量的標簽數據下,CT-GAN 的生成器損失值逐漸減小并收斂到較低水平。
為了驗證本文方法的有效性,利用CIFAR-10 數據集對比了不同數量標簽數據下CT-GAN模型與相關的深度網絡模型的分類效果,其分類準確率如表2 所示。實驗在不同條件下分別進行了20 次重復實驗,計算平均精度和方差。

圖6 CT-GAN 判別器損失變化曲線Fig.6 Discriminator loss of CT-GAN

圖7 CT-GAN 生成器損失變化曲線Fig.7 Generator loss of CT-GAN
分析表2 的分類準確率可知,本文提出的CT-GAN 模型在CIFAR-10 數據集上的分類精度更高,在不同數量的標簽數據下的分類精度都有不同程度的提升,在標簽數據數量僅為10 時,就可以達到47.6%的分類精度,相比SGAN 模型提高了6.5%,這說明CT-GAN 模型能夠有效提升在標簽數據極少情況下的分類準確率,在一定程度上解決了GAN 網絡在小樣本條件下的過擬合問題。

表2 CIFAR-10 數據集上不同數量標簽樣本的半監督分類精度Tab.2 Using different number of labeled data when semi-supervised training on CIFAR-10(%)
為更好地說明本文所提算法的有效性,在SVHN 數據集上進行實驗。按照4.1 節中的數據集預處理方案,在各類別標簽數量分別為100和1 000 時分別進行實驗,以研究不同數量的標簽數據下CT-GAN 模型的性能。表3 為SVHN數據集上不同數量標簽樣本的半監督分類精度。實驗在不同條件下分別進行了20 次重復實驗,計算平均精度和方差。
分析表3 可知,本文所提方法CT-GAN 模型在SVHN 數據集上的分類性能優異,在不同數量的標簽數據下的分類精度都達到了較高水平,特別是當標簽樣本數量僅為100 時,即少量標簽樣本的情況下,達到了77.7%,相較于其他算法分別高38.33%,21.40%,6.34%和13.85%,進一步說明CT-GAN 模型能夠在少量標簽樣本條件下有效提升網絡的分類精度。同時,CT-GAN 在不同標簽樣本數量下的分類精度誤差都在0.1%左右,相較于其他對比方法,本文所提模型訓練更加穩定。

表3 SVHN 數據集上不同數量標簽樣本的半監督分類精度Tab.3 Classification accuracy of different number of labeled data on SVHN(%)
本文提出了一種基于聯合訓練生成對抗網絡(CT-GAN)的半監督分類方法,通過兩個判別器的聯合訓練來消除單個判別器存在的分布誤差,同時利用無標簽數據來擴充標簽數據集,可以有效提升半監督分類的精度。實驗結果表明,在少量標簽樣本條件下,CT-GAN 模型能夠有效提升圖像分類精度,在一定程度上降低了GAN 網絡對標簽數據的依賴。此外,在不同數量的標簽數據下,CT-GAN 模型都取得了較好的分類效果,多種情況下的分類準確率相比其他方法都有一定程度提升,說明了本文模型的有效性。