侍海峰,何良華,盧劍
(1.同濟(jì)大學(xué)計(jì)算機(jī)系,上海201804;2.北京大學(xué)第三醫(yī)院,北京100083)
生成式對(duì)抗網(wǎng)絡(luò)(Generative Adversarial Nets,GAN)[1]是一類基于對(duì)抗訓(xùn)練和深度神經(jīng)網(wǎng)絡(luò)的無監(jiān)督生成式模型,由一個(gè)生成器網(wǎng)絡(luò)和一個(gè)判別器網(wǎng)絡(luò)構(gòu)成,可以生成服從訓(xùn)練集分布的無限多的樣本。自2014年被提出后,迅速成為深度學(xué)習(xí)、人工智能領(lǐng)域的研究熱點(diǎn)之一,在圖像生成、圖像風(fēng)格變換、圖像超分辨率、視頻生成等領(lǐng)域應(yīng)用廣泛。
然而GAN存在訓(xùn)練不穩(wěn)定,容易對(duì)抗崩潰的問題。此外,生成分布還存在“模式丟失(mode dropping)”問題[2],只能生成訓(xùn)練集分布的一個(gè)子集,多樣性不足。WGAN-GP[3]改進(jìn)了原始GAN的目標(biāo)函數(shù),解決了訓(xùn)練不穩(wěn)定的問題。混合生成式對(duì)抗網(wǎng)絡(luò)(Mixture GAN,MGAN)[4]是一種集成模型,通過混合多個(gè)生成器的分布來改善模式丟失問題,增加生成樣本的多樣性。然而,MGAN的多個(gè)生成器的混合權(quán)重被設(shè)置為均等值,不適合類別不平衡且比例未知的數(shù)據(jù)集。
基于WGAN-GP和MGAN,本文提出模式分工型混合生成式對(duì)抗網(wǎng)絡(luò)(Mode-Splitting MGAN,MSMGAN),向MGAN的訓(xùn)練算法中加入了生成器混合權(quán)重學(xué)習(xí)環(huán)節(jié),提高了MGAN在類別不平衡數(shù)據(jù)集上的生成效果,能促使多個(gè)生成器分別學(xué)習(xí)訓(xùn)練集中不同的模式,即“模式分工”。此外,替換MGAN的原始GAN目標(biāo)函數(shù)為WGAN-GP的目標(biāo)函數(shù),使訓(xùn)練更穩(wěn)定。在由UTKFace[5]和Toronto Face Dataset[6]混合而成的多模態(tài)不平衡人臉圖像數(shù)據(jù)集上的實(shí)驗(yàn)表明,MSMGAN生成分布具有更低的Frechet Inception Distance(FID)[7],且支持按類別生成圖片。
原始生成式對(duì)抗網(wǎng)絡(luò)[1]結(jié)構(gòu)如圖1所示,由生成器和判別器兩個(gè)神經(jīng)網(wǎng)絡(luò)構(gòu)成。生成器將輸入的噪聲向量z映射為生成樣本xF(又稱假樣本);判別器接收生成的假樣本和來自訓(xùn)練集的真實(shí)樣本x,輸出樣本為真實(shí)樣本的概率。
噪聲向量z的各分量通常是相互獨(dú)立的高斯噪聲。生成器的訓(xùn)練目標(biāo)是優(yōu)化自身參數(shù),盡可能使判別器誤把假樣本判別成真樣本;判別器的訓(xùn)練目標(biāo)則是優(yōu)化自身參數(shù),盡可能準(zhǔn)確地區(qū)分真實(shí)樣本和假樣本。生成器和判別器訓(xùn)練的目標(biāo)函數(shù)(損失函數(shù))分別為:


圖1 生成式對(duì)抗網(wǎng)絡(luò)結(jié)構(gòu)圖
文獻(xiàn)[1]證明了在理想的條件下,對(duì)抗訓(xùn)練將達(dá)到納什均衡,生成器生成的樣本xF的分布Pg將和訓(xùn)練集真實(shí)分布Pr相同,判別器將無法區(qū)分其輸入樣本的來源,輸出恒定為0.5。此時(shí)的生成器即可用于生成能以假亂真的樣本。
原始GAN存在訓(xùn)練不穩(wěn)定,容易崩潰的問題。針對(duì)該問題,文獻(xiàn)[3]提出WGAN-GP模型,模型的判別器損失函數(shù)為:


WGAN-GP訓(xùn)練穩(wěn)定性和生成分布的多樣性均優(yōu)于原始GAN,本文提出的MS-MGAN模型亦采用該損失函數(shù)和訓(xùn)練方式。
為了解決GAN存在的“模式丟失”問題,增加生成樣本的多樣性,一些基于集成模型思路的集成類GAN模型被提出,Mixture GAN(MGAN)便是其中的典型。MGAN由K個(gè)生成器網(wǎng)絡(luò)、一個(gè)判別器網(wǎng)絡(luò)D和一個(gè)分類器C構(gòu)成。分類器預(yù)測(cè)生成樣本來源于哪一個(gè)生成器,判別器預(yù)測(cè)生成樣本來源于真實(shí)分布還是生成分布。MGAN的生成器、判別器和分類器進(jìn)行如下的最小-最大博弈:

可見生成器的目標(biāo)有兩部分,既含有原始GAN目標(biāo)函數(shù)中對(duì)抗判別器的項(xiàng),又包含迎合分類器分類的項(xiàng)。后者含有超參數(shù)β,用于平衡目標(biāo)函數(shù)中二者的比例。MGAN中各生成器G1,G2,…,Gk的混合權(quán)重被π1,π2,…,πK設(shè)定為1/K,即K個(gè)生成器的分布均勻混合。MGAN的目標(biāo)函數(shù)能直接迫使多個(gè)生成器生成不同模式的樣本,以便于分類器區(qū)分,適合學(xué)習(xí)由若干個(gè)良好分離的分布等概率混合而成的分布。
基于WGAN-GP和MGAN模型,本文提出模式分工混合生成對(duì)抗網(wǎng)絡(luò)(MS-MGAN),以更好地學(xué)習(xí)和生成類別不均衡的數(shù)據(jù)分布。原始GAN和WGAN-GP模型都只具有一個(gè)生成器,讓單個(gè)生成器網(wǎng)絡(luò)學(xué)習(xí)復(fù)雜的多模態(tài)圖像數(shù)據(jù)分布是比較困難的,易導(dǎo)致生成的圖像質(zhì)量欠佳。MGAN采用多個(gè)生成器分工學(xué)習(xí)復(fù)雜的數(shù)據(jù)分布,其實(shí)驗(yàn)表明[4],算法提高了生成樣本的質(zhì)量和多樣性,但是其超參數(shù)β對(duì)數(shù)據(jù)集比較敏感,需要精心調(diào)節(jié),增加了算法的調(diào)參難度。此外,MGAN的多個(gè)生成器的混合權(quán)重π1,π2,…,πK被設(shè)定為均勻分布,但現(xiàn)實(shí)中的數(shù)據(jù)集往往各類別(模式)的占比不均勻,導(dǎo)致MGAN的各生成器會(huì)出現(xiàn)不合理的分工,影響生成質(zhì)量。
本文提出的MS-MGAN舍棄了MGAN中的分類器,從而從模型中去除了敏感的超參數(shù) β。采用WGAN-GP的訓(xùn)練目標(biāo)代替了MGAN中使用的原始GAN目標(biāo)函數(shù),提高了模型的訓(xùn)練穩(wěn)定性。此外,增加了多個(gè)生成器混合權(quán)重的學(xué)習(xí)環(huán)節(jié),能根據(jù)訓(xùn)練分布中不同模式樣本數(shù)量的占比分配各生成器對(duì)應(yīng)得權(quán)重,使得各生成器合理分工學(xué)習(xí)訓(xùn)練集中不同得模式,即使沒有額外分類器的促使作用。MS-MGAN的判別器損失函數(shù)同WGAN-GP的判別器損失函數(shù),即式(3)。而生成器損失函數(shù)則被修改為判別器對(duì)多個(gè)生成器樣本評(píng)價(jià)的加權(quán)值:

MS-MGAN的訓(xùn)練算法在WGAN-GP的基礎(chǔ)上增加了對(duì)π1,π2,…,πk的梯度下降法更新過程。在一次GAN訓(xùn)練迭代中,除了原有的①固定生成器網(wǎng)絡(luò),訓(xùn)練判別器;②固定判別器網(wǎng)絡(luò),訓(xùn)練生成器;這兩個(gè)步驟以外,加入混合權(quán)重學(xué)習(xí)環(huán)節(jié)③固定判別器網(wǎng)絡(luò)和各生成器網(wǎng)絡(luò),并從每一個(gè)生成器各采樣一個(gè)mini-batch的生成樣本,計(jì)算判別器對(duì)各生成器分布的期望評(píng)價(jià)并作為常數(shù)帶入(6)式,再將生成器的損失函數(shù)對(duì)π1,π2,…,πK,求梯度,更新混合權(quán)重。每一個(gè)迭代中進(jìn)行上述3個(gè)步驟的計(jì)算,可以使多個(gè)生成器按訓(xùn)練集中不同模式的比重合理分工。
為了測(cè)試MS-MGAN在類別不均衡數(shù)據(jù)集上的表現(xiàn),將UTKFace和Toronto Face Dataset(TFD)數(shù)據(jù)集中的人臉圖像混合成一個(gè)訓(xùn)練集。UTKFace數(shù)據(jù)集提供了23708張分辨率為200×200的剪裁并對(duì)齊了的彩色人臉圖像,TFD包含102236張分辨率96×96的灰度人臉圖像。由于UTKFace中的圖像數(shù)量較少,將每一張人臉圖像都水平翻轉(zhuǎn),以將數(shù)據(jù)集圖像數(shù)量倍增至47416。TFD數(shù)據(jù)集中的圖像的灰度通道則被復(fù)制為3通道彩色圖像。所有人臉圖像均縮放至64×64分辨率。因此可知不平衡混合人臉數(shù)據(jù)集中UTKFace和TFD這兩種模式的比例為47416:102236=0.3168:0.6832。
本文MS-MGAN實(shí)驗(yàn)程序基于WGAN-GP的官方開源代碼修改而成,生成器網(wǎng)絡(luò)和判別器網(wǎng)絡(luò)均選用廣泛使用的類DCGAN[8]的網(wǎng)絡(luò)結(jié)構(gòu)。實(shí)驗(yàn)使用NVIDIA GTX 1080Ti GPU和TensorFlow 1.12進(jìn)行訓(xùn)練,操作系統(tǒng)為Ubuntu 18.04。
實(shí)驗(yàn)評(píng)價(jià)除了采用直接觀察生成圖像的定性方法外,還采用廣泛使用的定量指標(biāo)Frechet Inception Distance(FID)。FID使用Inception[9]模型的中間編碼層的特征向量,對(duì)訓(xùn)練集真實(shí)圖像和生成器合成的圖像的Inception編碼層特征分別回歸成多元高斯分布,然后計(jì)算這兩個(gè)多元高斯分布之間的Frechet距離,計(jì)算公式如下:

其中mr,mF分別是真實(shí)圖像和生成圖像輸入Inception模型得到的編碼層向量的均值,Cr和CF分別是協(xié)方差。FID值越低,生成分布就更接近真實(shí)分布。
實(shí)驗(yàn)考察了單生成器的WGAN-GP模型、具有2個(gè)生成器的MGAN模型和具有2個(gè)生成器的MSMGAN模型。為了公平比較,單生成器的WGAN-GP模型的生成器規(guī)模等比例放大到MGAN和MS-MGAN生成器的2倍。三個(gè)模型的實(shí)驗(yàn)的批大小均為64,訓(xùn)練迭代200000次(每次迭代中生成器被訓(xùn)練一次,判別器被訓(xùn)練5次)。每個(gè)WGAN-GP模型和MGAN模型訓(xùn)練一次花費(fèi)GPU約26小時(shí),MS-MGAN由于增加了混合權(quán)重學(xué)習(xí)環(huán)節(jié),訓(xùn)練一次花費(fèi)27小時(shí)左右。為了避免神經(jīng)網(wǎng)絡(luò)類方法固有的隨機(jī)誤差,每個(gè)模型均隨機(jī)初始化訓(xùn)練了4次。訓(xùn)練過程中,每10000次迭代,采樣50000個(gè)生成樣本,計(jì)算FID值。每個(gè)模型在一次訓(xùn)練過程中達(dá)到的最小FID值的平均值、標(biāo)準(zhǔn)差如表1所示。表中生成器參數(shù)量是指模型包含的所有生成器網(wǎng)絡(luò)參數(shù)的總和。所有模型都只有一個(gè)相同結(jié)構(gòu)的判別器,判別器參數(shù)量均為4.317M。

表1 不同模型生成質(zhì)量(FID)對(duì)比表
從表1可以看出,WGAN-GP、MGAN和MSMGAN的FID值依次降低,表明后二者生成分布質(zhì)量均優(yōu)于WGAN-GP,即使單生成器模型的生成器尺寸已經(jīng)翻倍。支持生成器混合權(quán)重學(xué)習(xí)的MS-MGAN取得了最低的FID值,表明MS-MGAN更能適應(yīng)現(xiàn)實(shí)中更為普遍的類別不均衡數(shù)據(jù)集。
圖2為MS-MGAN模型在200000次GAN訓(xùn)練迭代中,兩個(gè)生成器的混合權(quán)重的變化趨勢(shì)。橫軸為迭代次數(shù),為展現(xiàn)訓(xùn)練早期的曲線變化,采用對(duì)數(shù)坐標(biāo)。可見訓(xùn)練初期兩個(gè)生成器的混合權(quán)重波動(dòng)較為劇烈,因?yàn)榇藭r(shí)模型剛初始化,見到的訓(xùn)練集樣本不多,生成分布和判別器判別都不太準(zhǔn)確。但是迭代次數(shù)超過10000次后,混合權(quán)重已經(jīng)基本穩(wěn)定在真實(shí)值附近小幅波動(dòng)了。因此,本文提出的MS-MGAN可以快速學(xué)習(xí)出合理的生成器混合權(quán)重,其實(shí)可以在后95%的訓(xùn)練迭代中的固定混合權(quán)重不再學(xué)習(xí),以加快訓(xùn)練速度。

圖2 MS-MGAN生成器混合權(quán)重變化圖
表2為MS-MGAN在4次隨機(jī)初始化實(shí)驗(yàn)中學(xué)習(xí)到的兩個(gè)生成器的混合權(quán)重。由于兩個(gè)生成器符號(hào)具有輪換對(duì)稱性,我們約定訓(xùn)練結(jié)束后混合權(quán)重較小的生成器為G1,對(duì)應(yīng)權(quán)重為π1。混合權(quán)重較大的則為G2,對(duì)應(yīng)權(quán)重為π2。可見4次實(shí)驗(yàn)中MS-GAN均基本準(zhǔn)確地向兩個(gè)生成器分配了符合訓(xùn)練集兩種模式占比(0.3168:0.6832)的混合權(quán)重。因此,本文提出的混合權(quán)重學(xué)習(xí)算法是穩(wěn)定而精確的。

表2 MS-MGAN混合權(quán)重學(xué)習(xí)結(jié)果表
從MS-MGAN和MGAN的兩個(gè)生成器隨機(jī)采樣的一部分樣本如圖3所示,前2行(紅線上方)分別是MS-MGAN的兩個(gè)生成器生成的樣本,后兩行(紅線下方)分別是MGAN的兩個(gè)生成器生成的樣本。通過比較可以發(fā)現(xiàn),MS-MGAN的兩個(gè)生成器G1,G2分別學(xué)習(xí)生成了混合人臉數(shù)據(jù)集中UTKFace(彩色)和TFD(灰度)這兩種模式的樣本,而MGAN的生成器G2負(fù)責(zé)生成TFD的人臉,G1既負(fù)責(zé)生成一部分UTKFace的彩色人臉,又負(fù)責(zé)生成一部分TFD人臉。這是因?yàn)镸GAN的兩個(gè)生成器的混合權(quán)重被固定為0.5,而訓(xùn)練集中兩種模式的比例分別為0.3168:0.6832,導(dǎo)致必須有一個(gè)生成器負(fù)責(zé)兩種模式樣本的生成,才能使生成分布服從真實(shí)分布。MS-GAN由于具有混合權(quán)重學(xué)習(xí)環(huán)節(jié),兩個(gè)生成器的混合權(quán)重能快速收斂至訓(xùn)練集中兩種模式的混合比例,使生成器更合理的分工,從而可以生成比MGAN具有更少瑕疵、失真的人臉樣本,并可以通過選擇不同的生成器以實(shí)現(xiàn)按類別采樣,增加了采樣的可操控性。

圖3 MS-MGAN和MGAN隨機(jī)生成樣本圖
本文提出了模式分工型混合生成式對(duì)抗網(wǎng)絡(luò)(MS-MGAN),向MGAN訓(xùn)練算法中加入混合比例學(xué)習(xí)環(huán)節(jié),在類別分布不平衡數(shù)據(jù)集上的生成質(zhì)量?jī)?yōu)于MGAN,且各生成器能分工學(xué)習(xí)訓(xùn)練集中不同的模式,從而支持按類別采樣生成。