張目飛,李 廷,蘇 鵬
(1 浪潮云信息技術股份公司 服務研發部,濟南 250000;2 山東浪潮新基建科技有限公司,濟南 250000)
隨著個人智能設備和圖像相關應用的普及,會產生大量的圖像數據,如何高效、合理地對這些圖像數據進行合理的分類是一項技術難題。在過去的幾年中,深度神經網絡(DNN)在計算機視覺和模式識別任務中,如:圖像分類、語義分割、對象檢測應用廣泛。卷積神經網絡中的卷積層能夠捕獲圖像的局部特征,以獲得與輸入維度相似的空間表示,使用全連接層和softmax 分類層生成概率表示,來達到分類效果[1]。He 等[2]提出了深度殘差網絡ResNet34,引入了殘差結構,可以更好地學習殘差信息,并在后續層中使用這些殘差信息,提高了圖像分類的性能,為深度學習領域帶來了新的思路和方法。
許多基于深度神經網絡,在網絡學習過程中添加注意力機制來獲得圖像中感興趣區域,通過選擇給定輸入的特征通道、區域來自動提取相關特征[3]。Woo 等[4]將注意力機 制模塊集成 到CNN中,提高網絡的特征表達能力,從而提高了圖像分類的準確率;Wang[5]提出了殘差注意網絡,殘差結構可以使網絡更好地學習圖像中的特征,通過添加注意力模塊來學習圖像中的局部區域特征;Park 等[6]提出了一種新的注意力機制,可以在空間和通道維度上同時進行特征加權,更加準確地捕捉到圖像中的重要信息;Xi 等[7]提出用殘差注意模塊進行特征提取,以增強分類任務中的關鍵特征,抑制無用的特征;Liang[8]提出將自下而上和自上而下的前饋注意力殘差模塊用于圖像分類。以上工作說明殘差結構和注意力機制都可以幫助模型更好地學習圖像特征,提高圖像分類的準確性。
隨著數據集規模的增大和類別的增多,訓練一個高準確率的分類模型變得越來越困難。傳統的數據增強方法對原始圖像進行幾何變換或者對圖像進行隨機擾動,雖然可以增加數據集的樣本量,提高分類模型的準確率,但是這些方法無法生成新的數據分布。而生成網絡是一種可以學習數據分布的生成模型,可以生成新的樣本,從而擴大數據集并且增加數據多樣性,從而可以提高分類模型的泛化性[9]。因此,本文提出一個深度殘差注意力生成網絡來生成圖像數據,對數據進行必要的數據增強,利用ResNet34 網絡進行圖像分類。
本文提出了一個深度殘差注意力生成網絡模型用于圖像數據增強,主要結構包括生成器、判別器和殘差注意力模塊。生成器包含4 個反卷積層(DConv)和3 個殘差注意力模型(SPAM),殘差注意力模型能夠對圖像的重點區域進行特別關注,以生成高質量的圖像,在生成器的最后一層使用Tanh 函數將數據映射到[-1,1]的區間內;判別器包括4 個卷積層(Conv),能夠提取圖像細節特征。深度殘差注意力生成網絡模型結構如圖1 所示。

圖1 深度殘差注意力生成網絡模型結構Fig.1 Deep residual attention generation network model
生成網絡由生成器和判別器組成。生成器將隨機向量Z作為輸入,學習真實數據分布p(x)從而合成逼真的圖像;判別器區分生成的圖像與真實的圖像,其輸出表示從真實分布p(x)提取樣本y的概率。生成網絡的最終目標是讓生成器生成和真實圖像相同的數據分布,而判別器無法判定圖像為真實圖像還是生成圖像,達到一個納什平衡。在生成器和判別器相互博弈的過程中,生成網絡的目標函數定義為公式(1):
其中,p(x)表示真實數據分布;p(z)表示生成數據分布;D(x)表示判別器運算;G(z)表示生成器運算。
本文隨機選取Z=100 維的隨機數據作為生成器的輸入,經過生成器生成圖像;判別器網絡的輸入為生成圖像和真實圖像,判別器網絡指導生成器合成圖像,鼓勵生成器捕捉更為精細的特征細節,使得生成器生成的圖像和真實圖像難以區分。
殘差注意力模型使具有相似特征的區域相互增強,以突出全局視野中的感興趣區域,殘差注意力模型如圖2 所示。通過sigmoid 函數可以得到一個[0,1]的系數,給每個通道或空間分配不同的權重,可以給每個特征圖分配不同的重要程度。

圖2 殘差注意力模型Fig.2 Residual attention model
本文設C × H × W為殘差注意力模型的輸入,C為特征圖的數量,H和W分別表示為圖像的高度和寬度;通過卷積和批量歸一化運算對輸入的特征進行處理,利用Sigmoid函數得到空間注意系數S;將輸入的特征圖和通過注意力模型得到的特征圖利用殘差結構進行融合,得到最終的殘差空間注意力特征表示,公式(2)和公式(3):
其中,X表示空間注意模型的輸入,Conv 表示卷積運算。
首先,對輸入圖像進行數據預處理,主要包括:將圖像裁剪為28×28 的大小,并進行隨機旋轉和對比度增強;其次,將預處理的數據送入到深度殘差注意力生成網絡中進行數據增強。深度殘差注意力生成網絡通過學習圖像不變性特征,合成高質量的數據,注意力機制對圖像的感興趣區域進行重點關注;生成器通過學習隨機數據來生成感興趣的圖像分布,判別器學習真實樣本的分布,辨別生成器生成的圖像;同時訓練生成器和判別器,促使兩者競爭,在理想情況下,生成器可以生成近似于真實的圖像數據,而判別器不能將真實圖像與生成圖像區分,從而達到納什平衡,達到數據增強的目的;最后,利用ResNet34 網絡對增強的圖像數據進行分類。
本文使用PyTorch 深度學習框架來訓練模型,GPU 為NVIDIA Tesla V100,顯存為32 GB。采用Adam 算法優化損失函數,采用小批量樣本的方式訓練深度學習模型,batch_size 設置為64,在訓練的過程中采用固定步長策略調整學習率,初始學習率設置為0.000 1,gamma 值為0.85,L2 正則化系數設置為0.000 1,迭代次數為50 000 次。
本文采用的數據集為MNIST 數據集和cirfar10數據集。MNIST 數據集一共有70 000張圖片,其中60 000 張作為訓練集,10 000 張作為測試集,每張圖片由28×28 的0~9 的手寫數字圖片組成;cirfar10數據集由60 000 張32×32 的彩色圖片組成,一共有十個類別,每個類別有6 000 張圖片,其中50 000 張圖片作為訓練集,10 000 張圖片作為測試集。
使用深度殘差注意力生成網絡分別對MNIST和cirfar10 數據集中的圖像進行圖像增強,使得圖像的特征更加多樣,對MNIST 數據集進行數據增強的效果如圖3 所示,對cirfar10 數據進行數據增強的效果如圖4 所示。

圖3 MNIST 數據集數據增強的效果Fig.3 Effect of data enhancement of MNIST dataset

圖4 cirfar10 數據集數據增強的效果Fig.4 Effect of data enhancement on the cirfar10 dataset
從圖3 和圖4 可以看出,使用深度殘差注意力生成網絡對MNIST 和cirfar10 數據集進行數據增強,具有很強的視覺可讀性,同時也具有較清晰的紋理特征,實現了數據增強,擴充了數據集。
為了驗證本文模型數據增強后的MNIST 以及cirfar10 數據在分類方面的效果,選擇 CNN、ResNet18、ResNet34、ResNet50 和ResNet101 作為分類網絡做對比實驗。第一組測試增強數據的分類準確率;第二組,測試原始數據的分類準確率;第三組,將增強數據和原始數據各拿出50%組成新的數據集進行測試,實驗結果見表1 和表2。

表1 MNIST 數據集分類準確率實驗結果(%)Tab.1 Experimental results of classification accuracy of MNIST dataset(%)
通過表1 和表2 可以看出,使用深度殘差注意力生成網絡進行數據增強能夠提高數據集的分類效果,證明本文提出的模型是切實有效的。利用本文模型進行數據增強的數據和原始數據相結合,在MNIST 數據集上達到了98.95% 的準確率,在cirfar10 數據集上達到了92.68%的準確率。

表2 cirfar10 數據集分類準確率實驗結果(%)Tab.2 Experimental results of classification accuracy(%)for the cirfar10 dataset
本文提出了一種深度殘差注意力生成網絡用于數據增強,從而提高分類的準確率。實驗結果證明,該模型在MNIST 數據集上獲得了98.95%的準確率,準確率提高了0.93 個百分點;在cirfar10 數據集上獲得了92.68%的準確率,準確率提高了0.65 個百分點。本文模型的提出,為數據增強提供了一種解決思路和方式。