許仁杰,劉寶弟,張凱,劉偉鋒
(中國石油大學(華東)海洋與空間信息學院,青島 266580)
元學習(meta learning)是一種“學習如何學習”的機器學習算法。模型無關的元學習(Model Agnostic Meta Learning,MAML)算法[1]通過元任務集合中的數據學習可以快速適應某些目標數據任務的初始模型[2-5],可以使用數量有限、有標記的任務樣本進行訓練,且在訓練時可使用不同的模型,因而被廣泛應用于解決各個領域的問題[6-14]。
雖然MAML 在解決回歸、分類和強化學習等問題中都有較好的表現,但是其計算復雜度高、過擬合、梯度下降速度慢等問題還有待解決,因此研究者們從多個角度對MAML 進行了改進,包括簡化運算、改進損失函數和運算流程等。針對MAML 的計算復雜度太高與梯度更新方法計算過于復雜的問題,有學者提出了只用一階導數對二階導數進行逼近的元參數優化的一階MAML、Reptile[10]方法;這一類方法雖然可以簡化MAML 的計算復雜度,但通常是以準確率的降低作為代價。針對MAML 在某些情況下會產生過擬合或者無法有效訓練等問題,有學者通過信息論、高斯過程等方法提出了更緊的損失函數[6-7,12],或者通過設置參數使MAML 對過去學到的元知識進行遺忘[9];但這種方式通常是針對某一類具體問題而進行的,往往有它自身的局限性。針對MAML 下降速度過慢的問題,有學者從一個函數空間中構造更好的權函數[8,15]或者提出了更有效的梯度更新方式[16-18];但這一類改進方法會在提升準確率的同時帶來更大的計算量。
盡管上述方法都對MAML 進行了有效的改進,但是由于通過前兩種方法改進MAML 在訓練過程中認為每個樣本對于元知識的影響都是一樣的,無法很好地根據不同的任務對損失函數進行調整,也不能根據抽取樣本是否能很好地體現該任務的性質而改變樣本對整體的影響,所以在學習過程中依然可能會產生訓練速度低、過擬合或準確率較低等問題;而后一種方法的計算量較大。為了解決這些問題,本文通過概率方法構造出了一個更好的權函數來提高MAML 的訓練速度以及準確率。與文獻[8]中從再生核希爾伯特空間搜索損失函數不同,本文提出了一種更輕量、更便于計算的權函數,對每個任務損失函數進行加權,用來表示不同的任務在訓練過程中的重要程度。具體地,本文認為隨機抽取的任務近似符合一個高斯分布,越靠近這個高斯分布的期望的任務在元參數更新過程中占據更加重要的地位;相反地,越遠離高斯分布期望的任務所占的權重應該越小。添加這個權函數的MAML 可以在更快逼近任務分布的期望的同時避免一些小概率出現的任務對網絡訓練造成更大的影響,從而在提升訓練速度的同時增加模型的準確度,訓練好的元參數也能更適用于高概率出現的任務。
將本文方法與基礎的MAML 方法在Omniglot 與Mini-ImageNet 數據集上進行小樣本圖像分類實驗,結果表明在大多數情況下,本文方法的準確率都高于傳統的MAML。
本文主要工作包括:從高斯隨機過程的角度提出了一種與迭代相關的MAML 解釋方法,并根據這種解釋方法通過貝葉斯分析提出了一種加權的MAML——BW-MAML,最后通過實驗驗證了BW-MAML 的有效性。
本文工作的基礎是HB-MAML(Model-Agnostic Meta-Learning as Hierarchical Bayesian)[3]以及加權元學習[8]。文獻[7]從貝葉斯分析的角度出發,將元學習的過程描述為一個高斯隨機過程,并以此提出了一個正則化項;而文獻[8]從泛函分析的角度,認為常用的平方誤差與Hinge 誤差在原空間的核函數都能構成再生核希爾伯特空間,并在這個空間中選取最優的損失函數。通過文獻[8]中方法可以找到下降速度更快的損失函數,但該損失函數是通過抽取的樣本獲得的,所以根據抽取任務的不同會使損失函數產生較大的波動,進而使優化難度偏高?;谏鲜鑫墨I的成果,本文從貝葉斯分析[19]與高斯隨機過程[18,20]的角度在線性函數空間中找到一個更便于計算的最優損失函數,使更重要任務的損失在損失函數中占更大的權重,通過優化這個損失函數可以使元參數更容易向最優解進行梯度下降。

最終在外循環中使用隨機梯度下降方法通過迭代求得其最小值。在這個過程中,將每個任務的損失相加作為整個模型的損失,旨在求得在迭代一次后對每個任務的損失都最小的θ。所以將

作為外循環的迭代方法。
MAML 在訓練時,首先從任務分布中抽取一些任務,使用一個內循環針對每個任務的參數根據損失函數進行梯度下降;然后根據更新過參數的任務損失,使用一個外循環對元參數進行梯度下降,以獲得一個最適合全部任務的元參數。在這個過程中,MAML 將所有任務視為是同等重要的。
基于貝葉斯權函數的模型無關元學習就是在MAML 的元梯度更新方法上進行改進,在本文中,根據不同任務在訓練中重要性不同,在外循環的元梯度下降時求MAML 中每個任務損失的加權和,從而能使元參數更快地進行訓練。本文采用由貝葉斯分析推導而來的損失函數,因此本文將這種改進算法稱為基于貝葉斯權函數的模型無關元學習(Bayes-Weighted Model-Agnostic Meta-Learning,BW-MAML)算法。
接下來介紹貝葉斯分析角度設置的損失函數及權函數的推導過程。
高斯隨機過程[18,20]是機器學習中常用的方法之一,在實踐中可以對機器學習的梯度下降過程視為一串隨機的概率事件進行分析。而對于其中的一個隨機事件,與文獻[7]中的推導類似,根據貝葉斯分析將上文中的損失函數(3)重寫成一個概率形式:

元學習的損失函數最小的問題就轉化為一個令負log 概率最小的問題,也就是找到一個元參數,使在各個任務中經過一次或幾次梯度下降后的任務參數屬于該任務的概率最高。
基于損失函數(4),可以得到如下推斷:如果使用抽取的訓練任務以元參數為基礎進行訓練,在理想情況下,第n個元參數θ()n會在數次迭代后達到一個對該任務最優的點,記為,本文認為所有的都是對的逼近,而且由于噪聲的存在,一般認為:

由于抽取的任務隨機,并且都屬于同一個任務分布P(T),所以這些任務都獨立同分布,即它們都擁有同樣的統計學規律。根據一般性假設,在本文中認為這些任務在任務空間中都符合高斯分布[18,20],使用一個邊界似然函數來表示一步元參數更新的條件概率:



為元參數的更新方式,而不是簡單地把各個樣本看作是均勻分布。

又由于每個θi符合一個高斯分布,所以任意幾個的值的分布也應該符合一個同期望的高斯分布,所以把這個公式的右側進行歸一化作為本文算法的權函數就可以得到最終的元迭代格式:

通過將添加這個權函數的元參數更新方式替代原本的元參數更新方式,可以對優化元參數貢獻更大的損失進行強調,對出現概率較小的損失則通過較小的權函數降低其對整個迭代過程的影響。因此BW-MAML 可以降低整個梯度下降過程的隨機性,并且使終點更加趨近于所有分布的平均值,以獲得一個更重視高概率出現的任務,一定程度上忽略小概率出現任務對元參數產生的影響。訓練時算法的偽代碼如算法1 所示。
算法1 BW-MAML 的訓練過程。
輸入 任務分布p(T)步長α,β;
輸出 優化后的參數θ。

如圖1 所示,BW-MAML 等價于將MAML 通過幾個任務的參數求得下一步的元參數的過程改為通過估計元參數的期望,并將得到的期望作為下一步的元參數開始下一次迭代。

圖1 一階BW-MAML原理Fig.1 Principle of first-order BW-MAML
BW-MAML 與基礎MAML 算法的不同點體現在算法1 中的第8)行,簡單來講,傳統的MAML 算法直接將幾個任務的損失相加,而本文算法在計算任務損失函數的加權和的同時使用高斯分布的權函數而不是均勻分布,使元參數能更快、更準確地逼近最優解。
本文在Mini-ImageNet 數據集[21]與Omniglot 數據集[22]上進行了小樣本圖像分類實驗,對BW-MAML 的有效性和實用性進行驗證。
Omniglot 是一個手寫字母數據集,包含50 個不同字母的1 623 個不同手寫字符,在處理數據集時將其分成了包含30個字母的“背景”集和包含20 個字母的“評估”集;Mini-ImageNet 數據集是元學習和小樣本學習中常用的數據集之一,它包含100 類共60 000 幅彩色圖片,每類中含有600 個樣本,每幅圖片的規格為84×84。
為了驗證BW-MAML 在較小數據集上的性能,在Omniglot 數據集上測試了一階MAML(First-Order MAML,FOMAML)與BW-MAML 的5-way 1-shot、5-way 5-shot、20-way 1-shot 以及20-way 5-shot 的小樣本分類對比實驗,其中,NwayK-shot 意味著在任務中包含N個類,而每個類中包含K個樣本。在網絡選擇上,本文采用了一個使用3×3 卷積核的四層卷積神經網絡(Convolutional Neural Network,CNN)作為其內容網絡。在訓練過程中,每次從訓練集中隨機抽取6 個訓練任務,然后對內容網絡按照一階BW-MAML、一階MAML等不同算法針對每個任務每次進行5 次梯度下降,總共進行60 000 次迭代。對于超參數,與MAML 相同,本文選擇任務參數學習率α=0.1,元參數學習率β=0.001,元參數的訓練使用Adam[23]作為優化器。本文將準確率定義為測試集中預測正確的數量與總量的比值,表1 中的準確率是10 組準確率的平均值。從表1 可以看出,在Omniglot 數據集上,1-way 5-shot 與5-way 5-shot 時BW-MAML 和MAML 的準確率接近,20-way 1-shot 與20-way 5-shot 時,BW-MAML 相對MAML 的準確率平均提升了0.199 個百分點。

表1 兩種算法在Omniglot上的準確率對比 單位:%Tab.1 Accuracy comparison of two algorithms on Omniglot unit:%
在較大的數據集Mini-ImageNet 上進行實驗時,本文將Mini-ImageNet 隨機分為不相交的訓練集與測試集,并將訓練集依次傳入對網絡進行訓練。與上一組實驗類似,本文在Mini-ImageNet 上進行了一階、二階MAML 與一階、二階BWMAML 與其他元學習算法的5-way 1-shot、5-way 5-shot 的小樣本分類對比實驗,除每次訓練迭代100 000 次以外,其他超參數與在Omniglot 上的實驗一致。實驗結果如表2 所示,可以看出,在Mini-ImageNet 上BW-MAML 的各項準確率都比MAML 更高。通過使用權函數對損失的重要性進行區分,BW-MAML 比MAML 的平均準確率提高了0.907 個百分點,可見本文的方法無論是在Omniglot 還是在Mini-ImageNet 這樣略大的數據集上都表現得更好。

表2 Mini-ImageNet上的準確率對比 單位:%Tab.2 Accuracy comparison on Mini-ImageNet unit:%
為了驗證每次抽取的不同任務數對模型的影響,在Mini-ImageNet 中使用5-way 1-shot 的一階BW-MAML 并進行60 000 次迭代,每隔500 步使用100 個測試任務對模型效果進行評估,然后選取了準確率變化較明顯的訓練時期(前段)以使結果更為明顯,其他參數設置與之前的實驗相同。從第n=500,1 000,1 500,2 000,2 500 步與訓練完成后最終的準確率探究了每次抽取4 個、6 個與8 個任務對BW-MAML 訓練速度的影響,結果如表3 所示。從表3 可以看出,BW-MAML 在收斂速度方面的效果也優于MAML,在訓練進行2 500 步后,6 任務時BW-MAML 的準確率是最高的,且比同樣6 任務的MAML 準確率提高了1.9 個百分點。但在訓練完成后,6 任務的最終的準確率介于8 任務和4 任務的準確率之間??梢婋m然最終的準確率和每次訓練所用的任務數存在正比例關系,但在2 500 步內,BW-MAML 在6 任務情況下的訓練速度最快。

表3 針對不同任務數在Mini-ImageNet上的準確率對比Tab.3 Contrast experiment for different task numbers on Mini-ImageNet
由于MAML 在選擇任務上具有隨機性,而在實際使用這些任務進行訓練時并沒有考慮每個任務對元參數的影響。在本文中通過理論推導并論證了一種新的貝葉斯加權的MAML,然后通過實驗驗證了這個方法在兩個數據集上的實用性,并通過一個對比實驗檢驗了超參數(任務數)的選擇,這證明本文提出的方案確實提升了實驗的準確率,本文的方法可以提升在較為符合高斯分布的數據集上的準確率。在常用的數據集中BW-MAML 比MAML 的準確率更高。但還有很多新的思路亟待嘗試,比如先選擇一種更好的損失基函數,然后再對這組基函數求出最優的權系數;或者先通過一些方法求出樣本大概的分布情況,然后在這個基礎上再進行加權;再或者直接通過高斯過程設計出新的結構以取代梯度下降等。