徐銘美,方 睿,羅 鳴,雷 蕾
(成都信息工程大學,四川 成都610225)
天氣與人們的生活息息相關,但由于其變化莫測的特性,長期以來都有許多專家和學者投身于氣候和天氣預測的研究當中。目前,國內外對降雨預測的研究大多都局限于衛星云圖的數據。但由于傳統的天氣預測方法存在很多弊端,如需要的數據量大,計算時間久,代價大,專業性高等,越來越多學者開始關注運用其它的自然天氣數據來進行預測研究。其中,一些科研人員提出了利用云朵圖片進行分類和檢測識別,如張飛等提出的基于深度卷積神經網絡的云分類算法[1],從而能夠快速簡單的獲取有價值的天氣信息。
隨著人工智能的不斷發展,在氣候研究方面也取得了較好的成果。如張敏靖等提出的基于對抗和遷移學習的災害天氣衛星云圖分類[2],以及李冰潔等提出的氣象衛星系統的云圖自動分類識別研究[3],均運用了深度學習算法對天氣進行預測。其中,降雨作為最為常見的天氣之一,也成為了天氣預測的重點。如高利峰提出的降雨量預測方法研究[4]。但是,由于目前關于云朵圖像的標準公開數據集較少,所以在進行天氣預測時,面臨著數據集缺乏的問題。大部分的深度學習算法都需要大量的相關數據進行訓練,才能得到較好的準確度。因此,目前利用深度學習框架對降雨云圖像進行分類存在著以下兩個問題:淺層的卷積神經網絡不能充分地提取降雨云圖像的特征信息;降雨云圖像數據樣本小導致深層卷積網絡在訓練過程中容易過擬合。
針對本次實驗的研究內容,本文采用小樣本學習的方法,即僅利用少量樣本就可以訓練得到不錯的效果,從而解決了上述問題,并得到一種更加簡單,便捷,代價小,專業性要求較低的天氣預測方法。本文提出的元基線改進模型Distillation-Meta-Baseline(后簡稱為D-MB模型)的實驗結果也表明,其能夠在少樣本數據的條件下,實現較好的分類效果。從而為天氣預測提供了一種有效地,新穎的,實時性強的輔助決策方法。
本文所采用的自建降雨云圖像數據由世界氣象組織(World Meteorological Organization,簡稱WMO)提供,其對應的相關降雨信息也均來自WMO官網。由于世界氣象組織所提供的數據包含了所有類型的云圖像,所以本文需要對其進行二次整合,從而得到一個較為標準的降雨云圖像數據集。同時,為了避免數據樣本數量分布及內容上的差異而導致的一些問題,本文首先對圖像進行增強和歸一化處理,通過隨機旋轉,偏移等操作提高樣本數量。并且依照世界氣象組織提供的降雨云分類標準,以及不同的降雨云類型將數據集分為6類:高層避光云,高積堡狀云,積雨鬢狀云,積雨禿狀云,雨層云,以及鉤卷云。如表1所示。

表1 降雨云數據信息表
每類圖像20張,數據集的總樣本數量為120張,部分降雨云樣本如圖1所示。

圖1 降雨云圖像
小樣本學習與傳統的監督學習不同,它的目標是使模型通過訓練學會學習,而不是著重于讓機器識別訓練集中的圖片并泛化到測試集上。與此同時,隨著小樣本學習領域的快速發展,元學習作為小樣本學習中最主要的一類方法,也涌現出了大量的新算法。常見的元學習框架可以大致分為3類:基于記憶的元學習方法,基于優化的元學習方法,以及基于距離度量的元學習方法。
基于記憶的元學習方法即在原本的元學習框架中添加記憶機制,使得模型能夠對學習到的知識進行總結提取,并輔助后面的學習任務。Ravi和Larochelle在2017年便提出將LSTM與元學習相結合的優化算法[9]。
基于優化的元學習方法是指通過元學習模型在支持集上進行優化操作。基于優化的元學習方法有很多,其中包括了選擇初始化參數,采用不同的梯度更新算法等。例如Finn等人在MAML[10]中就提出通過為每個訓練任務提供一個更好的初始化參數,從而達到更佳的訓練效果。
基于距離的元學習方法通過度量特征之間的距離來進行網絡的訓練。在距離度量的元學習模型中,通過計算比較詢問集和支持集樣本特征之間的距離來實現分類。如Vinyals等人在2017年發表的Matching Network[11]就提出了一種快速學習樣本間度量方式的框架。
盡管上述的元學習算法已經取得了許多優異的成績,但是近年來一些對預訓練分類器性能的研究,如Gidaris和Komodakis[12]提出的余弦度量分類器訓練方法,以及Yinbo Chen[7]等人基于元基線提出的優化方法,表現出的效果更優于之前的幾種元學習方法,尤其是在面對跨域問題時。
因此,本文提出的D-MB模型將分類器基線和元學習的優點相結合,并引入知識蒸餾的思想,使得模型性能優于以往的方法,并將其應用于降雨云圖像的分類和天氣預測。
基于小樣本的D-MB模型訓練主要分為兩個階段:分類器訓練階段和元學習階段。對于分類器訓練階段,需要使用大量帶標簽的基類(Cbase)數據訓練出一個分類器,從而為后面的元學習模型提供性能優異的特征提取器(或稱為編碼器encoder)。然后,將新類(Cnovel)圖像數據輸入到到元學習框架中進行訓練和學習。其中,值得注意的是本文采用的基類數據來自于公開數據集cifar100,新類數據為本文自建的降雨云數據集。
3.2.1 分類器模型
傳統的分類器模型是通過使用大量的數據獨立訓練而成的,其訓練結果的好壞通常都是通過與數據標簽進行對比得到的。但是數據標簽包含的信息量往往較少,只能反映出結果的對錯。所以,有學者提出在訓練分類器模型時,引入一個預訓練好的復雜模型(或稱為教師模型)來進行輔助,此時被訓練的分類器模型稱為學生模型。具體而言,即使用教師模型中的softmax層輸出來作為另一種“標簽”,Hinton[13]將其稱為soft target,與學生模型的輸出進行比較,從而獲得更加豐富的反饋信息。種訓練的過程就被稱作為“知識蒸餾”。值得注意的是,如果在實驗過程中soft target的數值方差太大,則引入教師模型的意義就不大了,所以在這里需要引入溫度參數T來控制教師模型對學生模型的影響,具體可見式(1)。

(1)
其中zi表示分類器模型中softmax層的輸出,zj表示其它模型的輸出,qi表示zi與zj之前的關聯度。溫度參數的值是根據具體實驗的要求進行人為設置的,常設為1。另外,T如果太大了,會導致正確項的數值與錯誤項的數據差距太小,無法區分出哪個是正確的選項;T如果太小了,模型在“蒸餾”過程中會弱化soft target的作用,從而失去了蒸餾的意義。
根據Hinton[13]的研究表明,可以根據自身的實驗需求進行教師模型的選擇,不一定要是一個復雜的網絡模型。最后,會得到一個網絡層數更淺,運行更快,但準確度堪比同類型復雜網絡的分類器模型。分類器蒸餾訓練的具體流程如圖2所示。
3.2.2 元學習模型

圖2 分類器蒸餾實驗流程
在進行元學習模型訓練之前,需要先將蒸餾過后的分類器網絡去掉全連接層,并將其作為元學習模型中的encoder。同時,將Cnovel的數據劃分為支持集(support set)和詢問集(query set)。
元學習的主要特點是以task作為基本單位進行網絡訓練,即將整個網絡的訓練過程分為多個小任務進行。在每個task中,需要在支持集上的N個類各抽取K張的降雨云圖像(即N-way K-shot)輸入到編碼器fθ中,從而提取出各類數據的特征,同時在詢問集中也要抽出一定數量的圖片進行特征提取。然后分別計算出詢問集數據與支持集中各類數據之間的相似度,最后將計算結果與詢問集中抽取的數據標簽進行對比,計算出loss。其中,相似度的計算可以選用L2或者COS來度量兩者之間的距離。具體的元學習模型框架如圖3所示。
3.2.3 損失函數

圖3 元學習模型算法流程
a) 分類器模型的損失函數
本文的蒸餾實驗采用標準交叉熵作為分類器訓練的損失函數,其loss通常包括了兩個部分:一個是學生模型與教師模型輸出之間的loss1,另一個是學生模型與數據標簽之間的loss2。具體的損失函數見式(2)
loss=loss1+loss2
(2)
其中,loss1的具體計算可見式(3)

(3)
loss2的具體計算可見式(4)

(4)

b) 元學習模型的損失函數
由于整個元基線模型的損失函數是由每個訓練任務的損失一起構成的,所以需要計算每個任務的損失。首先,在支持集中計算N個類的質心,這些質心定義在式(5)中。

(5)
然后,用式(6)計算定義的查詢集中每個樣本的預測概率分布。

(6)
其中,S為支持集,Sc為在c類別的選取的數據,wc為計算出的特征平均值(類中心),fθ為編碼器函數,x為輸入的圖像樣本數據,p為計算出的余弦相似度。
損失是由p和查詢集中樣本的標簽計算的交叉熵損失,具體可見式(7)

(7)
值得注意的是,將每個任務都視為訓練過程中的數據點,每個batch可能包含多個任務,并計算平均損失。
本文的實驗環境為Linux操作系統,采用英偉達(NVDIA)顯卡,CUDA10.0,Pytorch版本為3.7,顯卡內存為12G。
實驗主要分為了三部分:運用知識蒸餾的思想訓練分類器模型,并與未經蒸餾訓練的模型進行對比;構建一個基于小樣本的元學習分類模型實現降雨預測,并對比不同深度的分類器網絡對整個元學習模型準確度的影響;與目前主流的元學習模型進行效果對比。
通過運用知識蒸餾的思想,本文選擇ResNet110作為預訓練的教師模型。同時考慮到不同網絡深度的encoder對D-MB模型分類效果的影響,分別選擇ResNet12,ResNet18,ResNet34,ResNet50作為學生模型,在cifar100數據集上進行訓練對比。
本小節的實驗主要分為兩個部分。首先,對上面所提及的4個學生網絡分別進行了隨機初始化獨立訓練。然后,再對同樣的4個學生網絡進行蒸餾實驗。訓練的基本設置為迭代200次,批處理數量為128,學習率為0.1,權重遞減1e-4,學習動量為0.9,優化器選用Adam。兩次的具體結果如下表2所示。

表2 分類器模型結果對比
從上表可以觀察到,隨著網絡的加深,學生模型的準確度越來越高。且通過蒸餾實驗訓練出的模型準確度均高于獨立訓練的模型。由此可見,運用知識蒸餾的思想可以明顯提高分類器網絡的性能。
本次實驗采用的數據集為自建的降雨云數據。為了體現D-MB模型在跨域分類方面的有效性,本文將該數據集(共6個類,120張)中的4個類劃分為支持集(80張),剩下的2個類劃分為詢問集(40張),然后分別將shot數設置為1和5進行訓練和測試。
在進行D-MB實驗之前,需要將蒸餾后的學生網絡去掉全連接層作為元基線模型的encoder。在接下來的實驗中,分別采用ResNet12,ResNet18,ResNet34,ResNet50作為元基線模型的encoder進行實驗對比,挑選出性能最佳的主干網絡模型。訓練的基本設置為迭代20次,每個task的batch為4,學習率為0.001,權重遞減1e-4,優化器選用Adam。具體的模型準確度測試結果如表4所示。

表4 D-MB模型實驗結果
由上表觀察可得,ResNet12在降雨云數據集上的分類效果最佳,因此本文采用ResNet12作為元基線模型的encoder。D-MB模型訓練精度如圖4和圖5所示。

圖4 ResNet12的1shot訓練準確度
從圖中可以看出,D-MB模型的1shot和5shot的訓練精度分別可以達到57%和74.78%。

圖5 ResNet12的5shot訓練準確度
目前,元學習模型方法大致可以分為基于記憶的元學習方法,基于優化的元學習方法,以及基于距離度量的元學習方法這3類。在本小節中,本文主要選擇兩種應用較為廣泛且效果得到了學界認可的元學習模型——Prototype Network和Matching Network。在進行對比實驗時,對于Prototype Network和Matching Network模型訓練的基本設置均為迭代80次,采用隨機梯度下降法,學習率為0.001。最后的實驗對比結果如下表5所示。

表5 元學習模型實驗結果對比
從表中可以看出,本文提出的D-MB模型相比于Prototype Network和Matching Network模型在自建的降雨云數據上表現出了更高的預測準確度。
本文根據世界氣象組織提供的資料,建立了一個新的降雨云圖像數據集,其中包括了6類降雨云,各類20張,共120張圖片。并在此基礎上提出了一種基于小樣本學習的降雨云分類模型(D-MB模型),進行天氣預測。整個模型分為了兩個部分:分類器模型和元學習模型。其中,在分類器模型的訓練過程中引入知識蒸餾的思想,使得相比于傳統獨立訓練出的分類模型準確率要高。利用降雨云數據集訓練出的元學習模型的1shot和5shot測試精度高達54.2%和70.2%。相比于目前常見的元學習分類模型擁有更好的跨域性和更高的準確性。通過各類降雨云對應的降雨信息,在一定程度上可以實現實時有效地天氣預測。