吳鵬翔,李凡長
(蘇州大學(xué)計算機科學(xué)與技術(shù)學(xué)院,江蘇蘇州 215006)
隨著計算設(shè)備并行計算性能的大幅提升,以及近年來深度神經(jīng)網(wǎng)絡(luò)在各個領(lǐng)域不斷取得重大突破,由深度神經(jīng)網(wǎng)絡(luò)模型衍生而來的多個機器學(xué)習(xí)新領(lǐng)域逐漸成型,如強化學(xué)習(xí)、深度監(jiān)督學(xué)習(xí)等[1-2]。在大量訓(xùn)練數(shù)據(jù)的加持下,深度神經(jīng)網(wǎng)絡(luò)技術(shù)已經(jīng)在機器翻譯、機器人控制、大數(shù)據(jù)分析、智能推送、模式識別等方面得到了廣泛應(yīng)用[3-4]。深度學(xué)習(xí)在完成這些任務(wù)時需要在大量數(shù)據(jù)上進行訓(xùn)練才能擬合出一個好的結(jié)果,一旦需要被識別物體類別不在訓(xùn)練集中,便無法進行正確識別。但是在實際的許多任務(wù)中,要求在少量數(shù)據(jù)上進行快速學(xué)習(xí)和適應(yīng)[5]。
元學(xué)習(xí)的提出為上述問題提供了一個解決方案,其目的是解決傳統(tǒng)神經(jīng)網(wǎng)絡(luò)模型泛化能力不足、對新種類任務(wù)適應(yīng)性較差的問題。快速學(xué)習(xí)的能力是人類區(qū)別于人工智能的一個關(guān)鍵特征[6],人類能夠有效地利用以前的知識和經(jīng)驗來快速學(xué)習(xí)新的技能。元學(xué)習(xí)的訓(xùn)練和測試可類比為人類在掌握一些基本技能后快速學(xué)習(xí)并適應(yīng)新的任務(wù)[7]。例如:人類可以根據(jù)一張從未見過的動物的照片辨認出該動物,而不是需要大量該動物的照片才能辨認。人類在幼兒階段掌握的對世界的大量基礎(chǔ)知識和對行為模式的認知基礎(chǔ)便對應(yīng)元學(xué)習(xí)中的“元”概念[8-9]。元學(xué)習(xí)的最終目標是實現(xiàn)擁有類似人類學(xué)習(xí)能力的強人工智能,這在當前階段體現(xiàn)為對新數(shù)據(jù)集的快速適應(yīng)以得到較高的準確度,因此,目前元學(xué)習(xí)目標主要表現(xiàn)為提高泛化性能、獲取好的初始參數(shù),以及通過少量計算和新訓(xùn)練數(shù)據(jù)即可在模型上實現(xiàn)和海量訓(xùn)練數(shù)據(jù)一樣的識別準確度[10]。受當前計算資源與算法能力限制,元學(xué)習(xí)往往以小樣本學(xué)習(xí)以及對新任務(wù)的快速適應(yīng)作為切入點,因此,當前研究也多以在小樣本數(shù)據(jù)集上的識別準確率作為實驗衡量標準[11]。
基于度量的元學(xué)習(xí)方法是一種可行的元學(xué)習(xí)方法。KOCH等于2015年提出了一種用于解決單樣本學(xué)習(xí)圖像分類問題的方法:孿生網(wǎng)絡(luò)(Siamese network)[12],通過訓(xùn)練集學(xué)習(xí)一個卷積孿生網(wǎng)絡(luò),利用該網(wǎng)絡(luò)計算待測試圖像與所有單標注樣本的相似度,相似度最高的單標注樣本所對應(yīng)的類別即是待測試圖像的類別。VINYALS 于2016 年提出了匹配網(wǎng)絡(luò)模型[13],其主要創(chuàng)新體現(xiàn)在建模過程和訓(xùn)練過程。對于建模過程的創(chuàng)新,該文通過設(shè)計基于記憶和注意力機制的匹配網(wǎng)絡(luò),使得模型能夠?qū)⑴c訓(xùn)練的樣本進行快速學(xué)習(xí)。對于訓(xùn)練過程的創(chuàng)新,該文基于傳統(tǒng)機器學(xué)習(xí)的一個重要原則,即訓(xùn)練和測試應(yīng)在同樣條件下進行,提出在訓(xùn)練時每次僅使用每一類任務(wù)的少量樣本參與網(wǎng)絡(luò)的訓(xùn)練,與測試過程保持一致。SNELL 于2017 年提出了原型網(wǎng)絡(luò)[14],該網(wǎng)絡(luò)模型基于一個基本假設(shè),即在數(shù)據(jù)集中,對于每種不同的類型都存在一個原型點。數(shù)據(jù)集中距離該原型點越近的樣本,其標簽與該原型點對應(yīng)的標簽相同的概率就越大。文獻[15]提出了由嵌入模塊和關(guān)系模塊組成的關(guān)系網(wǎng)絡(luò),其中嵌入模塊用于提取輸入圖像的特征,關(guān)系模塊用于得到輸入特征的相似度。
傳統(tǒng)基于度量的元學(xué)習(xí)算法采用卷積神經(jīng)網(wǎng)絡(luò)(Convolutional Neural Network,CNN)提取特征,但是元學(xué)習(xí)問題中的某些樣本圖片特征不僅具有平移對稱性[16],而且還具有旋轉(zhuǎn)對稱性和鏡像對稱性[17],但是CNN 只具有平移不變性,不存在對后兩者的不變性,這就使得傳統(tǒng)的元學(xué)習(xí)算法不能利用具有對稱性的特征。常用的解決方法是數(shù)據(jù)增強[18],即對樣本進行隨機變換。此類方法雖然在一定程度上增強了泛化性,但是并不能保留局部對稱性[19],更不能保證在每一層卷積上的等變性。群等變卷積神經(jīng)網(wǎng)絡(luò)(Group equivariant CNN,G-CNN)則能較好地解決這一問題[20],其不僅具有平移不變性,而且還具有旋轉(zhuǎn)和鏡像不變性。
為有效利用樣本圖片的局部旋轉(zhuǎn)對稱性和鏡像對稱性,提高特征提取能力,本文提出一種基于G-CNN 的度量元學(xué)習(xí)算法。通過由群等變卷積構(gòu)成的4 層映射網(wǎng)絡(luò)學(xué)習(xí)一個合適的度量空間,根據(jù)查詢集中樣本離原型點的距離完成分類。
元學(xué)習(xí)的目標是跨任務(wù)的泛化。考慮一個任務(wù)分布P(T),即該模型所適配的數(shù)據(jù)的全體,目的是使這個模型可以適應(yīng)這個任務(wù)分布P(T)。與傳統(tǒng)機器學(xué)習(xí)不同,元學(xué)習(xí)不是根據(jù)每個樣本來優(yōu)化,而是根據(jù)元任務(wù)來優(yōu)化。每個元任務(wù)包含一個支持集和一個對應(yīng)的查詢集。在n-wayk-shot 元學(xué)習(xí)問題中,對于每個元任務(wù)定義支持集S和查詢集Q,支持集和查詢集中包含n個類別的樣本,支持集中每類樣本只存在k個,查詢集中每類樣本個數(shù)不定,支持集S定義如式(1)所示:

其中:xi表示樣本的D維向量表示;yi表示樣本對應(yīng)的標簽;n表示樣本類別總數(shù)。查詢集Q取自數(shù)據(jù)集中和支持集S同類別但不同的樣本,不帶標簽。圖1給出了5-way 1-shot 元學(xué)習(xí)問題中訓(xùn)練時所使用的的支持集和查詢集示例。

圖1 5-way 1-shot 元學(xué)習(xí)問題中元訓(xùn)練使用的支持集和查詢集示例Fig.1 Example of support set and query set using in meta-training for 5-way 1-shot meta-learning problems
在訓(xùn)練階段,從P(T)的訓(xùn)練數(shù)據(jù)集上采樣訓(xùn)練元任務(wù),通過元任務(wù)對損失函數(shù)進行最小化,從而優(yōu)化模型參數(shù)。在訓(xùn)練結(jié)束后,從同取自P(T)未參與訓(xùn)練的測試數(shù)據(jù)集(測試集中的樣本和訓(xùn)練集中的樣本類別不同)中采樣測試元任務(wù),對訓(xùn)練好的模型進行測試。
盡管現(xiàn)階段的神經(jīng)網(wǎng)絡(luò)研究缺少理論支撐,但是大量經(jīng)驗證據(jù)表明,卷積權(quán)值共享和網(wǎng)絡(luò)深度對于神經(jīng)網(wǎng)絡(luò)的效果起到了重要作用[21-22]。卷積權(quán)值共享的有效性依賴于其在多數(shù)感知任務(wù)中都具有平移不變性,即預(yù)測標簽的函數(shù)和數(shù)據(jù)分布對于平移變換都近似于不變。由于平移不變性,共享權(quán)重的卷積核可以從圖像的局部區(qū)域提取特征,并且參數(shù)量遠少于全連接網(wǎng)絡(luò)[23],同時能夠?qū)W到更多有效的變換信息[24-25]。卷積層可以有效地應(yīng)用于深度網(wǎng)絡(luò)中,因為這種網(wǎng)絡(luò)中的所有層都具有平移不變性:將圖片平移后再送入若干卷積層得到的結(jié)果,與將原圖直接送入相同卷積層再對特征圖進行平移所得到的結(jié)果相同[26]。因此,為提高特征提取能力,本文使用G-CNN 來構(gòu)建映射網(wǎng)絡(luò),使映射網(wǎng)絡(luò)對具有旋轉(zhuǎn)對稱的特征和鏡像對稱的特征也能保持不變性。映射網(wǎng)絡(luò)使用4 層G-CNN 構(gòu)建,每層由卷積核、batch-norm、relu 激活函數(shù)和最大池化層組成。
對于輸入的2 維圖片,卷積是不斷平移卷積核和特征圖做點積運算的過程,以群G上的函數(shù)代替平移就得到了群卷積,如式(2)所示:

其中:Z2是2 維圖片上的整數(shù)平移群;群運算是加法(n,m)+(p,q)=(n+p,m+q);f是輸入的特征圖;φ是卷積核。f和φ都是Z2上的函數(shù),只適用于群卷積的第1 層,但由于卷積輸出的結(jié)果是離散群G上的函數(shù),因此第1 層后的卷積如式(3)所示:

其中:輸入的特征圖f是群G上的函數(shù)。
令h=uh,等變性證明如式(4)所示:

映射網(wǎng)絡(luò)中的非線性單元包括激活函數(shù),可以將非線性單元看作一個映射:v:R →R,非線性單元作用于特征圖f可以視為一系列操作算子的組合,如式(5)所示:

因此,使用非線性單元處理特征圖后依然能保持等變性。
池化可以分解為不帶步長的池化和下采樣[27]兩部分。對于不帶步長的池化,定義池化操作為P,作用于特征圖f的最大池化如式(6)所示(平均池化同理):

其中:gU是G的子群U上的一個g變換。在G-CNN中,下采樣表示在G的子群H上下采樣。例如:對輸入2 維圖片做步長為2 的最大池化,等價于先進行不帶步長的池化,再在Z2的子群H={(2i,2j)|(i,j)∈Z2}上進行下采樣。
對于具有90°旋轉(zhuǎn)對稱特征的圖片,群G使用p4 群;對于具有90°旋轉(zhuǎn)對稱和鏡像對稱的特征,群G使用p4m 群[28]。p4 群的群元定義如式(7)所示:

其中:0≤r<4,r=0 表示無旋 轉(zhuǎn),r=1 表示旋轉(zhuǎn)90°;(u,v)∈Z2,表示在二維平面上的水平和垂直移動。群運算為矩陣乘法。對于輸入的特征圖上的某點(x,y),p4 群作用于點(x,y)的運算如式(8)所示:

其中:m=0 或1,1 表示鏡像翻轉(zhuǎn),其余定義與p4 群相同,群運算為矩陣乘法。作用于輸入特征圖上某點(x,y)的運算如式(10)所示:

當群G使用p4群時,第1層的G-CNN 是Z2-p4 卷積層,操作如圖2 所示,依次將卷積核旋轉(zhuǎn)90°,得到4 組卷積核,分別與輸入圖片做卷積,得到4 組映射特征。第一層后面的G-CNN 是p4-p4 卷積,操作如圖3所示,對于前層輸入的4 組映射特征,卷積核依次旋轉(zhuǎn)90°得到4 組卷積核,然后每組卷積核依次和輸入的4 組特征做卷積,將得到的結(jié)果求和得到輸出特征。使用p4m 群構(gòu)建映射網(wǎng)絡(luò)時,卷積核需要額外進行鏡像翻轉(zhuǎn),因此,卷積核的數(shù)目是8 組,得到的輸出特征也是8 組,操作與使用p4 群類似。

圖2 Z2-p4 卷積層示意圖Fig.2 Schematic diagram of Z2-p4 convolution layer

圖3 p4-p4 卷積層示意圖Fig.3 Schematic diagram of p4-p4 convolution layer
本文算法基于以下基本假設(shè):存在一個空間,在這個空間中,屬于相同類別的樣本距離近,不同類別的樣本距離遠,這樣就可以通過簡單度量函數(shù)進行分類。本文算法是通過學(xué)習(xí)一個映射網(wǎng)絡(luò)將樣本映射到合適的度量空間,然后通過簡單度量方法完成分類。在n-wayk-shot 元學(xué)習(xí)問題中,對于每個元任務(wù),支持集中每類有k個樣本,支持集經(jīng)過映射網(wǎng)絡(luò)映射到度量空間后,每一類就有k個表示,取每類k個表示的均值作為該類在度量空間中的代表。每個類在度量空間的代表稱為該類的原型點cj,計算公式如式(11)所示:

其中:k表示支持集中每類樣本的個數(shù);fθ表示映射網(wǎng)絡(luò);(xi,yj)表示輸入的樣本和對應(yīng)的標簽。查詢集經(jīng)過同樣的映射網(wǎng)絡(luò)映射到度量空間中,利用距離計算函數(shù)d來計算查詢集中待分類樣本到每類原型點的距離,再利用softmax 函數(shù)計算屬于每個類的概率,如式(12)所示:

最后,使用交叉熵作為損失函數(shù),如式(13)所示:

通過Adam 優(yōu)化器來最小化損失函數(shù),從而優(yōu)化映射網(wǎng)絡(luò)的參數(shù),不斷從訓(xùn)練集中抽取樣本組成元任務(wù)來訓(xùn)練模型,直到得到一個能很好地將訓(xùn)練樣本映射到合適度量空間的模型。
本文提出的基于群等變卷積的度量元學(xué)習(xí)算法(Metric Meta-learning algorithm Based on Group Equivariant Convolution,MMBOGEC)如算法1 所示。
算法1MMBOGEC
輸入訓(xùn)練集D={(x1,y1),(x2,y2),…,(xN,yN)}
輸出模型在測試集上的分類準確率
1)在訓(xùn)練集中隨機選取n個類,對于選取的每個類,取k個樣本組成支持集,取Nq個樣本組成查詢集。
2)通過映射網(wǎng)絡(luò)得到支持集樣本在度量空間中的表示,取每個類所有樣本在度量空間中特征表示的均值作為該類的原型點。
3)利用同樣的映射網(wǎng)絡(luò)得到查詢集樣本在度量空間中的表示,利用距離計算公式計算查詢集樣本在度量空間中的表示到每個類原型點的距離,利用softmax 函數(shù)計算屬于每個類的概率,將概率最大的類別作為預(yù)測類別。
4)使用交叉熵作為損失函數(shù)更新?lián)p失J。
5)使用Adam 優(yōu)化器最小化損失J來更新網(wǎng)絡(luò)參數(shù)。
6)重復(fù)步驟1~步驟5,直到損失J不再下降。
7)在測試集中生成若干個元任務(wù),每個元任務(wù)隨機選取n個類,對于選取的每個類,取k個樣本組成支持集,取Nq個樣本組成查詢集,將這些元任務(wù)輸入訓(xùn)練好的模型,得到分類準確率,最后將分類準確率的平均值作為輸出結(jié)果。
本文在常用的小樣本數(shù)據(jù)集miniImageNet 和Omniglot 上進行實驗。
miniImageNet 數(shù)據(jù)集包含60 000 張彩色圖片,分為100 個類,每個類600 張。首先將所有圖片處理成84 像素×84 像素大小,將其中的64 類作為訓(xùn)練集,16 類作為驗證集,剩下的20 類作為測試集。本文使用64 類來訓(xùn)練模型,驗證集僅僅用來判斷模型泛化性的好壞,不參與模型的參數(shù)優(yōu)化。
輸入的樣本圖片經(jīng)過映射網(wǎng)絡(luò)得到其在度量空間中的特征表示,映射網(wǎng)絡(luò)包含4 層由G-CNN 構(gòu)成的卷積,每一層使用64個3×3卷積核,包含batch-norm、relu 激活函數(shù)以及3×3 的最大池化層。最后將得到的特征表示展開成一維向量,利用距離計算函數(shù)計算其到各個原型點的距離,將距離最近的類別作為預(yù)測標簽。以交叉熵作為損失函數(shù),不添加正則項損失,學(xué)習(xí)率設(shè)置為10-3,使用Adam 優(yōu)化器對網(wǎng)絡(luò)參數(shù)進行優(yōu)化。
針對miniImageNet 數(shù)據(jù)集常用的有兩種訓(xùn)練方法,分別是5-way 1-shot 和5-way 5-shot。5-way 1-shot訓(xùn)練方法先任意地從訓(xùn)練集中選5 個類別,每個類別包含1 個樣本,總計5 個樣本作為支持集,再從前面5 類中每類選取若干個不同的樣本(本文實驗中設(shè)置為15 個)作為查詢集,使模型根據(jù)支持集來分類查詢集。5-way 5-shot 訓(xùn)練方法將支持集每類選取樣本數(shù)改為5,其余和前面一致。當驗證集上的驗證損失不再下降時,停止訓(xùn)練模型,在測試集上驗證模型的效果,測試方法和訓(xùn)練方法保持一致,測試使用隨機產(chǎn)生的600 個元任務(wù),以平均準確率作為評估指標。
4.1.1 不同距離計算公式對實驗結(jié)果的影響
不同距離的度量公式會對算法的實驗結(jié)果產(chǎn)生影響,本文使用常用的4 種距離計算公式進行測試,分別是歐式距離、余弦距離、切比雪夫距離和曼哈頓距離,測試結(jié)果對比如表1 所示。可以看出,在miniImageNet 數(shù)據(jù)集5-way 1-shot 和5-way 5-shot 方法中,歐氏距離作為距離計算公式最有效,其次是曼哈頓距離,切比雪夫距離最差。

表1 使用不同距離計算公式的實驗結(jié)果對比Table 1 Comparison of experimental results using different distance calculation formulas %
4.1.2 消融實驗
為驗證本文算法的有效性,分別使用p4 群、p4m群和普通CNN 構(gòu)建映射網(wǎng)絡(luò)行實驗,對比實驗結(jié)果如表2 所示。可以看出:不使用群等變卷積的方法,實驗結(jié)果最差;使用p4 群的方法,實驗結(jié)果優(yōu)于使用普通CNN 的方法,表明在本實驗中,具有旋轉(zhuǎn)不變性的方法比不具有旋轉(zhuǎn)不變性的方法更有效;使用p4m 群的方法,實驗效果最好,表明利用旋轉(zhuǎn)不變性和鏡像對稱不變性能有效提高元學(xué)習(xí)的自適應(yīng)性。

表2 消融實驗結(jié)果對比Table 2 Comparison of ablation experimental results %
4.1.3 G-CNN 層數(shù)對實驗結(jié)果的影響
為進一步驗證群等變卷積的有效性,在部分卷積層上使用群等變卷積進行實驗,實驗結(jié)果如表3所示,其中第1 列表示使用群等變卷積的卷積層,如1 表示僅在第1 層使用,其余層使用普通CNN。可以看出,在5-way 1-shot 和5-way 5-shot 的實驗中,僅僅在單層中使用群等變卷積,不論是在哪一層使用,實驗結(jié)果都相差不大,表明僅在某一層具有等變性不能很好地提升元學(xué)習(xí)的自適應(yīng)性。隨著使用群等變卷積層數(shù)的增加,實驗效果隨之提升,完整的4 層群等變卷積網(wǎng)絡(luò)效果最好,表明整個網(wǎng)絡(luò)都具有等變性才能更好地適用于元學(xué)習(xí)問題。

表3 在不同卷積層使用G-CNN 的實驗結(jié)果對比Table 3 Comparison of experimental results using G-CNN in different convolutional layers %
4.1.4 與4 層元學(xué)習(xí)算法的實驗結(jié)果對比
將本文算法與傳統(tǒng)4 層元學(xué)習(xí)算法進行對比,實驗結(jié)果如表4 所示(加粗數(shù)據(jù)表示最優(yōu)數(shù)據(jù))。可以看出,無論是5-way 1-shot 還是5-way 5-shot,本文算法性能都優(yōu)于傳統(tǒng)4 層元學(xué)習(xí)算法。

表4 不同算法在miniImageNet數(shù)據(jù)集上的實驗結(jié)果對比Table 4 Comparison of experimental results of different algorithms on miniImageNet dataset %
Omniglot數(shù)據(jù)集包含50種不同語言,共計1 623種手寫字符,每種字符包含20個樣本,每個樣本由不同人書寫。本文將樣本圖片大小統(tǒng)一為28 像素×28 像素,使用其中的1 028 類作為訓(xùn)練集,423 類作為測試集,剩下的作為驗證集。
輸入的樣本圖片經(jīng)過映射網(wǎng)絡(luò)得到其在度量空間中的特征表示,映射網(wǎng)絡(luò)包含4 層由G-CNN 構(gòu)成的卷積層,每層使用64 個3×3 卷積核、batch-norm、relu 激活函數(shù)以及3×3 的最大池化層。在度量空間中使用歐氏距離計算查詢集到原型點的距離,將距離最短的原型點對應(yīng)的標簽作為預(yù)測標簽,以交叉熵作為損失函數(shù),不添加正則項損失,學(xué)習(xí)率設(shè)置為10-3,使用Adam 優(yōu)化器對網(wǎng)絡(luò)參數(shù)進行優(yōu)化。
Omniglot 數(shù)據(jù)集常用的有4 種訓(xùn)練方法,分別是5-way 1-shot、5-way 5-shot、20-way 1-shot 和20-way 5-shot,測試時同樣使用對應(yīng)的方法。測試使用隨機產(chǎn)生的1 000 個元任務(wù),以平均準確率作為最后的結(jié)果。
本文算法與傳統(tǒng)4 層元學(xué)習(xí)算法在Omniglot 數(shù)據(jù)集上實驗結(jié)果對比如表5 所示(加粗數(shù)據(jù)表示最優(yōu)數(shù)據(jù)),可以看出,在5-way 1-shot、5-way 5-shot 實驗中,本文算法性能都優(yōu)于其他算法。

表5 不同算法在Omniglot 數(shù)據(jù)集上的實驗結(jié)果對比Table 5 Comparison of experimental results of different algorithms on Omniglot dataset %
本文算法針對n-wayk-shot 元學(xué)習(xí)問題,對于每個元任務(wù),需要n類支持集樣本,每類樣本包含k個實例,對q個支持集樣本進行分類,因此每個元任務(wù)的平均復(fù)雜度為O(n×k×q)。
MMBOGEC 算法與傳統(tǒng)4 層元學(xué)習(xí)算法的參數(shù)量對比如表6 所示(加粗數(shù)據(jù)表示最優(yōu)數(shù)據(jù))。可以看出,MMBOGEC 算法參數(shù)量只比原型網(wǎng)絡(luò)算法多,而少于其他4 種算法。

表6 不同算法的參數(shù)量對比Table 6 Comparison of the number of parameters of different algorithms
針對傳統(tǒng)機器學(xué)習(xí)的自適應(yīng)性問題,本文提出一種基于群等變卷積的度量元學(xué)習(xí)算法,使用群等變卷積神經(jīng)網(wǎng)絡(luò)構(gòu)建映射網(wǎng)絡(luò),充分利用樣本圖片的局部旋轉(zhuǎn)對稱性和鏡像對稱性,將樣本圖片映射到合適的度量空間,根據(jù)所提取特征到每個類原型點的距離遠近來實現(xiàn)分類。在Omniglot 數(shù)據(jù)集和miniImageNet數(shù)據(jù)集上的實驗結(jié)果表明,該算法對于元學(xué)習(xí)問題的學(xué)習(xí)性能優(yōu)于傳統(tǒng)4 層元學(xué)習(xí)算法。下一步將對本文算法進行改進,探索更有效的特征映射網(wǎng)絡(luò)和特征距離比較方法。