袁培森 吳茂盛 翟肇裕 楊承林 徐煥良
(1.南京農業大學信息科學技術學院, 南京 210095; 2.馬德里理工大學技術工程和電信系統高級學院, 馬德里 28040)
表型(Phenotype)研究核心是獲取高質量的性狀數據,進而對基因型和環境互作效應(Genotype-by-Environment) 進行分析[1-2],表型組學近年來發展迅猛,已成為分子育種和農業應用中的重要技術支撐[3-4]。然而,植物表型數據的獲取需搭建實驗環境,并需昂貴的數據采集工具,具有周期長、代價高昂等特點[1,5-6]。當前,以大數據為基礎的深度學習正在成為表型數據分析的有力工具[7-8],深度學習相關算法的有效性在很大程度上取決于標記樣本的數量,因此限制了其在小樣本量環境中的應用[9]。數據的非均衡性是生物表型數據具有挑戰性的問題[10-13]。
為了提升非均衡數據分析的性能和質量,文獻[14-15]提出了數據生成的方法。然而,過采樣技術SMOTE[15]、ADASYN[16]等對于處理經典學習系統中的類不平衡有效,但是此類方法生成的數據不能直接應用于深度學習系統[17]。近年來,生成式對抗網絡(Generative adversarial networks,GAN)[18]的出現為計算機視覺應用提供了新的技術和手段,GAN采用零和博弈與對抗訓練的思想生成高質量的樣本,具有比傳統機器學習算法更強大的特征學習和特征表達能力[19],是一種基于深度學習的學習模型,可以用于海量數據的智能生成,已經廣泛用于圖像、文本、語音、語言等領域[20-21]。
有學者提出將GAN網絡技術用于生物學等領域的數據生成問題[9,22-25],結果顯示生成數據的質量有顯著提高。目前,記錄約8萬種真菌、近1 500種野生蘑菇種類的圖像數據集,這對種類繁多和分布非均衡的菌類識別和分類具有重要的生態意義[26-28]。
本文提出基于生成對抗網絡的菌菇表型數據生成方法(Mushroom phenotypic based on generative adversarial network, MPGAN)。以菌菇表型為研究對象,在特定目標域上訓練GAN網絡,作為GAN發生器網絡的輸入給出潛在模型,以期生成可控制和高質量的蘑菇圖像。
GAN[18]的核心思想來源于博弈論的納什均衡,它設定雙方分別為生成器和判別器,生成器的目的是盡量學習真實的數據分布,而判別器的目的是盡量正確判別輸入數據是來自真實數據還是來自生成器。GAN中的生成器和判別器需要不斷優化,各自提高生成能力和判別能力,其學習優化過程就是尋找二者之間的一個納什均衡[29]。
GAN系統一般框架如圖1所示,系統結構主要包括:生成器(用于生成虛擬圖像),它通過接收隨機噪聲z,通過這個噪聲生成網絡G(z)。判別器是負責判斷圖像真假,輸入圖像x,輸出對該圖像的判別結果D(x)。

圖1 一般的GAN框架Fig.1 Framework of GAN
首先,在給定生成器G的情況下,最優化判別器D。采用基于Sigmoid的二分類模型的訓練方式,判別器D的訓練是最小化交叉熵的過程,其損失函數表示為
(1)
式中x——采樣于真實數據分布Pdata(x)
z——采樣于先驗分布Pz(z),例如高斯噪聲分布
E(·)——計算期望值
式(1)中判別器的訓練數據集來源于真實數據集分布Pdata(x)(標注為1) 和生成器數據分布Pg(x)(標注為0)。
給定生成器G,最小化式(1)得到最優解。對于任意的非零實數m和n,且實數值y∈[0,1],表達式為
Φ=-mlgy-nlg(1-y)
(2)

(3)
D(x)代表x來源于真實數據而非生成數據的概率。當輸入數據采樣自真實數據x時,D的目標是使得輸出概率D(x)趨近于1,而當輸入來自生成數據G(z)時,D的目標是正確判斷數據來源,使得D(G(z))趨近于0,同時G的目標是使得其趨近于1。生成器G損失函數可表示為
OG(θG)=-OD(θD,θG)
(4)
其優化問題是一個極值問題,GAN的目標函數可以描述為
min(G)max(D){f(D,G)=Ex~Pdata(x)lgD(x)+Ez~Pz(z)lg(1-D(G(z)))}
(5)
GAN模型需要訓練模型D最大化判別數據來源于真實數據或者偽數據分布G(z)的準確率,同時,需要訓練模型G最小化lg(1-D(G(z)))。
GAN學習優化的方法為:先固定生成器G,優化判別器D,使得D的判別準確率最大化;然后固定判別器D,優化生成器G,使得D的判別準確率最小化。當且僅當Pdata=Pg時達到全局最優解。
MPGAN系統的框架如圖2所示,蘑菇圖像的生成過程為:生成器G(z)使用截斷到一定范圍內的隨機正態分布數據作為輸入,輸入到卷積網絡(Convolutional neural network, CNN),最后輸出生成圖像數據。判別器D(x)根據真實圖像數據和生成圖像數據輸出判別結果,并對神經網絡的所有參數進行反向更新操作。

圖2 蘑菇表型數據生成的MPGAN框架Fig.2 MPGAN framework for mushroom phenotypic data generation

圖3 生成器神經網絡框架Fig.3 Neural network framework of generator
2.1.1生成器
生成器卷積神經網絡結構的作用是通過輸入隨機數據生成128×128×3的圖像,128表示像素數,3表示RGB的通道數。圖3是生成器的框架。
生成器采用8層的卷積神經網絡,首先是Input數據輸入層,第2層是全連接層(Fully connected, FC),然后是連續5個反卷積層(Deconvolution, DeConv),其中分為DC反卷積層、BN批歸一化層(Batch normalization,BN)和激活函數,批歸一化層是對于同一批次數據按照給定的系數進行規范化處理,以防止梯度彌散,最后是Output數據輸出層。生成器的反卷積層如圖4所示,各層具體描述如下:
(1)FC全連接層設計輸入為生成100個圖像的隨機數據,經過全連接層的8 192個神經元處理以及形狀重塑后變為4×4×512大小的數據,再經過批歸一化層及ReLU激活函數后將結果輸出到下一層。
(2)生成器中包括5個反卷積層,卷積核的移動步長為2,卷積核尺寸為5×5,1~4層的每一層經過批歸一化層及ReLU激活函數后將結果輸出到下一層,其中:
第1層輸入數據為4×4×512。反卷積層的卷積核數為256個,經過反卷積后得到的數據為8×8×256。
第2層輸入數據為8×8×256。反卷積層的卷積核數為128個,經過反卷積后得到的數據為16×16×128。
第3層輸入數據為16×16×128。反卷積層的卷積核數為64個,經過反卷積后得到的數據為32×32×64。
第4層輸入數據為32×32×64。反卷積層的卷積核數為32個,經過反卷積后得到的數據為64×64×32。

圖4 生成器的反卷積層Fig.4 Deconvolution layer of generator
第5層輸入數據為64×64×32。反卷積層的卷積核數為3個。輸入數據經過反卷積后得到的數據為128×128×3,再經過批歸一化層及tanh激活函數后將結果輸出到下一層。tanh函數表達式為
(6)
式中a——參數
不使用傳統的Sigmod函數進行Output輸出層,而是直接將上一層輸入結果輸出。生成器網絡參數如表1所示。

表1 生成器網絡參數Tab.1 Summary of generator network parameters

圖5 判別器神經網絡框架Fig.5 Neural network framework of discriminator
2.1.2判別器
判別器的作用是盡量擬合樣本之間的Wasserstein距離,從而將分類任務轉換成回歸任務。判別器采用7層的卷積神經網絡,首先是Input數據入層,接著是連續4個卷積層(Convolution,Conv),其中分為卷積層、歸一化層和激活函數,然后是全連接層FC,最后是數據輸出層Output。判別器的架構如圖5所示。
判別器的Conv卷積層設計如圖6所示。判別器共有4個卷積層,卷積核的移動步長為2,卷積核尺寸為5×5,經過歸一化層及Leaky ReLU激活函數后將結果輸出到下一層。
第1層輸入數據為128×128×3。卷積層的卷積核數為64個,經過卷積后得到的數據為64×64×64。
第2層輸入數據為64×64×64。卷積層的卷積核數為128個,經過卷積后得到的數據為32×32×128。

圖6 判別器的卷積層操作Fig.6 Convolution layer of discriminator
第3層輸入數據為32×32×128。卷積層的卷積核數為256個,經過卷積后得到的數據為16×16×256。
第4層輸入數據為16×16×256。卷積層的卷積核數為512個,經過卷積后得到的數據為8×8×512。
FC全連接層設計的輸入數據為8×8×512,經過全連接層處理以及形狀重塑后變為大小為1的蘑菇圖像,并將結果輸出。判別器的網絡參數如表2所示。

表2 判別器網絡參數Tab.2 Summary of discriminator network parameters
2.2.1Wasserstein距離
MPGAN系統采用帶有梯度懲罰的Wasserstein距離[30],Wasserstein距離[9,31-32]又叫推土機(Earth-mover,EM)距離,定義為
(7)
式中Pr——真實數據分布
Pg——生成數據分布
r——真實樣本
y——生成樣本
γ——聯合分布
∏(Pr,Pg)——Pr和Pg組合起來的所有可能的聯合分布的集合
對于每個可能的聯合分布γ而言,采樣(x,y)~γ得到一個真實樣本x和一個生成樣本y,并計算這對樣本之間的距離‖x-y‖,計算該聯合分布γ下樣本對距離的期望值E(x,y)~γ(‖x-y‖)。Wasserstein距離定義為在所有可能的聯合分布中能夠對這個期望值的下界[31]。
2.2.2系統損失函數
設定fw代表判別器網絡,根據Lipschitz連續性條件的要求,該判別器網絡含參數w,并且參數w不超過某個范圍,根據式(7)定義的Wasserstein距離,MPGAN系統判別器的目的是近似擬合Wasserstein距離,因此判別器的損失函數可以表示為
LD=Ex~Pg(fw(x))-Ex~Pr(fw(x))
(8)
MPGAN系統生成器的目的是近似地最小化Wasserstein距離,即最小化式(8),因此生成器的損失函數可以表示為
LG=Ex~Pr(fw(x))-Ex~Pg(fw(x))
(9)
GULRAJANI等[30]提出的帶有梯度懲罰的Wasserstein距離來滿足Lipschitz連續性。當生成數據分布Pg接近真實數據分布Pr時,Lipschitz連續性可表示為
‖D(Pg)-D(Pr)‖≤K‖Pg-Pr‖
(10)
式(10)可轉換為
(11)
式中Pc——生成數據分布與真實數據分布的差值
K——整數常量
先對真假樣本的數據分布進行隨機差值采樣,即產生一對真假樣本Xr和Xg,采樣公式為
X=ξXr+(1-ξ)Xg
(12)
式中ξ——[0,1]區間的隨機數
(13)
式中λ——調節梯度懲罰項大小的參數
K為使得Lipschitz連續性條件成立的常量,設定K為1,MPGAN系統的判別器損失函數式(9)和梯度懲罰項式(13),損失函數可表示為

(14)
根據GAN網絡的框架和優化過程,MPGAN系統的訓練過程如圖7所示。

圖7 MPGAN系統的訓練過程Fig.7 Training procedure of MPGAN system
圖7中的訓練過程描述如下:
(1)采用方差為0.02的截斷正態分布初始化網絡中的權值參數W和卷積核初始化網絡的偏置值b,初始化學習率η,即每次參數更新幅度。在訓練過程中,參數更新向著損失函數梯度下降的方向,表示為
Wn+1=Wn-ηΔ
(15)
式中Δ——梯度,即損失函數的導數
(2)采用區間為[-1,1]的均勻分布初始化隨機噪聲。
(3)采用數據集中隨機獲取批次大小的訓練樣本,并在輸入隊列中進行數據預處理。
(4)將步驟(2)中生成的隨機噪聲輸入到生成器網絡,生成虛擬圖像數據,將生成的虛擬圖像數據輸入判別器,得到生成圖像判別結果;將步驟(3)中獲取的訓練樣本使用批歸一化操作輸入判別器,得到真實圖像判別結果;計算判別器損失并反向更新判別器參數。
(5)計算梯度懲罰項,為判別器損失施加懲罰,然后使用優化器反向更新判別器參數,使用梯度懲罰項,替換原來的權重截斷策略。
(6)判斷是否達到指定判別器優化次數,即每優化一次生成器時優化N次判別器,若是則進入步驟(7),若否則重新進入步驟(3)。其中N由用戶設定。
(7)將步驟(2)中生成的隨機噪聲輸入到生成器網絡,計算生成器損失并使用優化器反向更新判別器參數。
(8)判斷是否達到指定迭代次數,即是否遍歷完全部樣本,若是則進入步驟(9),否則重新進入步驟(2)。
(9)判斷是否達到EPOCH次數,EPOCH為總共訓練的輪次,若是則結束,否則重新進入步驟(2)。
實驗平臺為Windows 10系統,16 GB內存,256 GB SSD,1 TB HD,Intel QuadCore i7-8700, 4.2 GHz, Nvidia GTX 1070,8 GB。算法采用Tensorflow V1.1 GPU框架[33]和Python 3.6實現。
采用兩類數據集:開源蘑菇數據集Fungi[28],選擇了其中375幅圖像;私有數據集,共138幅圖像。圖像預處理方法包括隨機翻轉、隨機亮度變換、隨機對比度變換和圖像歸一化,前面幾種預處理方法主要是為了增加樣本數量,而圖像歸一化是為了降低幾何變換帶來的影響。
圖8為開源數據集Fungi蘑菇示例圖像,該數據集環境噪聲大且背景復雜,背景中有草地、林地、樹葉、木塊等多種干擾物。

圖8 開源數據集示例Fig.8 Examples of public dataset
私有蘑菇數據集采用鳳尾菇作為對象,該數據集采用黑色作為背景,背景噪聲小,且蘑菇形狀不同,適合菌菇表型圖像生成。圖9為私有蘑菇數據集的示例圖像。

圖9 私有蘑菇數據集示例Fig.9 Examples of private dataset
MPGAN系統默認使用Adam優化器[34],優化器超參數β1=0.5、β2=0.9、ε=1×10-8,學習率η默認為0.000 3,判別器優化次數N=5。
3.2.1生成器參數設置
由于生成器的輸出層直接將前一層的值作為輸入,最后激活函數選擇tanh激活函數,該激活函數可以將輸出層的輸出約束到區間[-1,1]。
為了保證數據分布的一致性,并防止反向傳播權值更新時發生梯度彌散并加速收斂,采用批歸一化(Local response normalization),對同一批次數據按照給定的系數進行規范化處理。其處理步驟如下:
(1)沿通道計算同一批次內所有圖像的均值μB,計算式為
(16)
(17)
(3)對圖像做歸一化處理,計算式為
(18)
ω——防止方差為0的參數
(4)加入縮放變量γ和平移變量φ,得出結果
yi=γi+φ≡BNγ,φ(xi)
(19)
式中yi——加入縮放變量γ和平移變量φ處理結果
3.2.2判別器參數設置
選擇Leaky ReLU激活函數作為判別器激活函數,確保梯度更新整個圖像。Leaky ReLU激活函數表達式為
(20)
式中α——(1,+∞)區間內的參數
MPGAN系統生成式對抗網絡模型的梯度懲罰策略采用層歸一化函數(Layer normalization,LN)。
在學習率η為0.000 3時,使用開源數據集和私有數據集作為訓練數據集,MPGAN系統的Wasserstein距離與EPOCH的關系如圖10所示。

圖10 Wasserstein距離收斂曲線Fig.10 Wasserstein distance convergence curves
由圖10a可知,在開源數據集,EPOCH大于2 000后逐漸開始學習到真實圖像的數據分布,在EPOCH達到10 000后逐漸趨于穩定,在這個階段數據集本身噪聲較大導致模型的學習能力有所下降,所以模型學習的特征被背景所干擾,并且在曲線尾部的振蕩程度明顯增大,此時減小學習率η可以使模型訓練更加穩定。
由圖10b可知,Wasserstein距離在EPOCH達到2 000后不斷收斂,在10 000左右有小幅振蕩,EPOCH在超過35 000之后,振蕩幅度減小,模型比較穩定。
由圖10可知,不同數據集訓練的EPOCH次數不同,開源數據集的噪聲較大,模型不容易收斂,并且相似度衡量指標Wasserstein距離在EPOCH為12 000時開始穩定在一個較高的程度;私有數據集上的噪聲較小,當在該數據集,模型收斂更加快速,Wasserstein距離在EPOCH大于35 000時開始逐漸收斂穩定。
基于開源數據集的學習率與EPOCH關系如圖11所示。從圖11可看出,提高學習率η時,模型的收斂速度有明顯的提升并在EPOCH為1 000后逐漸穩定,但是隨著學習率的提高,收斂的振蕩程度也在加大,因此可以在訓練初期使用較大的學習率提高初始收斂速度,然后逐漸減小學習率保證訓練過程穩定。由于在私有數據集上的結果類似,因此僅報告了開源數據集上的測試結果。

圖11 基于開源數據集的學習率與EPOCH關系Fig.11 Learning rate and EPOCH relationship based on open source dataset
首先,系統測試了數據中的scalpturatum口蘑,EPOCH為1 000時,學習率η為0.000 1~0.000 5生成圖像如圖12所示。圖12a為原始圖像,從圖12b可看出,學習率η為0.000 3時,生成的菌菇圖像相對較好。

圖12 不同學習率的菌菇圖像生成結果對比Fig.12 Mushroom image generation results comparison at different learning rates
當學習率η為0.000 3時,在開源數據集和私有數據集上,測試了系統菌菇圖像生成結果,生成圖像尺寸設置為64像素×64像素,結果分別如圖13和圖14所示。圖13為EPOCH為15 000時,開源數據集上的生成結果。圖13b的生成圖像能夠清晰地顯示出原始菌菇的表型特征。
圖14為EPOCH為50 000時,私有數據集上的生成結果。圖14b的生成圖像能夠清晰地顯示出原始菌菇的表型特征。

圖13 基于開源數據集上的蘑菇生成圖像Fig.13 Illustration of generating Fungi images based on public dataset

圖14 基于私有數據集上的蘑菇生成圖像Fig.14 Illustration of generating Fungi images based on private dataset
對比圖13b和圖14b可以看出,圖14b質量優于圖13b,表明高質量的菌菇訓練數據對圖菌菇表型圖像的生成有重要影響。
(1)研究了菌菇表型數據生成技術,設計了用于菌菇表型數據生成的生成式對抗網絡結構。使用Wasserstein距離和帶有梯度懲罰的損失函數。
(2)利用開源數據和私有數據集進行了測試,結果表明,數據集噪聲越小越好,噪聲越小則損失越容易收斂,否則背景和主體目標發生混淆時,損失會在一個較大程度上振蕩。
(3)測試了學習率η、EPOCH與Wasserstein距離關系,系統生成的菌菇表型數據可為后期菌菇數據分類與識別提供大數據基礎,為解決菌菇分類的數據非均衡、長尾分布等問題提供研究基礎。