蔣帥 何良華


關(guān)鍵詞:深度學(xué)習(xí);小樣本;關(guān)系三元組抽取;對(duì)抗訓(xùn)練
1 概述
深度神經(jīng)網(wǎng)絡(luò),尤其是預(yù)訓(xùn)練語(yǔ)言模型(如BERT) 的使用,使得關(guān)系抽取任務(wù)獲得了極大的性能提升。然而,現(xiàn)有的很多方法往往依賴大量的標(biāo)注數(shù)據(jù),且很難解決數(shù)據(jù)的長(zhǎng)尾分布問(wèn)題,如關(guān)系分類問(wèn)題,訓(xùn)練數(shù)據(jù)集中不同類別的樣本數(shù)量不同,某些關(guān)系類別樣本非常少,對(duì)于這些關(guān)系難以準(zhǔn)確分類。又如,由于醫(yī)學(xué)中的罕見(jiàn)疾病病人數(shù)量非常少,難以獲得大量數(shù)據(jù)樣本。
為了解決這些問(wèn)題,一些基于Few-Shot的實(shí)體識(shí)別和關(guān)系分類數(shù)據(jù)集和算法被提出。小樣本學(xué)習(xí)結(jié)合原型網(wǎng)絡(luò),能夠很好地學(xué)習(xí)類別不斷變化情況下模型的泛化能力。對(duì)于目標(biāo)領(lǐng)域樣本較少的情況,通常可以借助另一較大的數(shù)據(jù)集(源域)學(xué)習(xí)通用知識(shí),然后應(yīng)用到目標(biāo)域上。此外,在源域上訓(xùn)練時(shí)結(jié)合領(lǐng)域判別對(duì)抗訓(xùn)練,可以學(xué)習(xí)源域和目標(biāo)域共性特征,能有效提升模型的泛化能力。結(jié)合這些方法和思想,本文的主要工作如下:
1) 結(jié)合域判別對(duì)抗訓(xùn)練在Wiki域的數(shù)據(jù)集上預(yù)訓(xùn)練模型,為小樣本醫(yī)學(xué)關(guān)系三元組抽取學(xué)習(xí)共性的知識(shí);
2) 提出一個(gè)基于殘差連接的原型網(wǎng)絡(luò)模塊應(yīng)用于小樣本醫(yī)學(xué)關(guān)系三元組抽取問(wèn)題,取得了很好的抽取效果。
2 相關(guān)工作
2.1 小樣本關(guān)系三元組抽取
小樣本學(xué)習(xí)(Few-shot Learning) 模型大致分為三類:基于模型的方式、基于度量的方式和基于優(yōu)化的方式。基于模型的方式致力于改進(jìn)模型的結(jié)構(gòu),使得在少量樣本上快速更新模型的參數(shù),基于度量的方式(如原型網(wǎng)絡(luò)[1-3]) 通過(guò)度量測(cè)試樣本和支持集樣本的距離完成分類,基于優(yōu)化的方式則致力于改進(jìn)參數(shù)的優(yōu)化方法。當(dāng)前,在小樣本關(guān)系三元組抽取領(lǐng)域,大多數(shù)算法都采用基于度量的方式解決分類問(wèn)題。
Haiyang Yu 等人[4]在2020 年提出了MPE(Multi-Prototype Embedding) 模型,應(yīng)用于小樣本關(guān)系三元組抽取,先采用序列標(biāo)記的方式抽取實(shí)體,然后根據(jù)support樣本學(xué)習(xí)實(shí)體原型和句子原型用于表征關(guān)系原型,關(guān)系分類準(zhǔn)確度非常高。然而,這種方式使得整體關(guān)系三元組的抽取效果強(qiáng)依賴于實(shí)體抽取結(jié)果,盡管能夠取得較高的關(guān)系分類性能,但是由于實(shí)體抽取結(jié)果較差,導(dǎo)致最終關(guān)系三元組抽取結(jié)果不理想。
Xin Cong等人[5]在2022年提出了RelATE模型,基于原型學(xué)習(xí)和注意力網(wǎng)絡(luò)先對(duì)關(guān)系進(jìn)行分類,然后在每個(gè)關(guān)系下學(xué)習(xí)實(shí)體的START標(biāo)記原型和END標(biāo)記原型,即識(shí)別出實(shí)體的第一個(gè)token和最后一個(gè)token 位置,從而抽取出實(shí)體,然后和之前的關(guān)系組合成關(guān)系三元組。此算法相對(duì)于MPE,采用關(guān)系指導(dǎo)的方式,一方面避免了實(shí)體對(duì)后續(xù)模塊的影響,另一方面將實(shí)體和關(guān)系進(jìn)行了一定的語(yǔ)義結(jié)合。同時(shí),實(shí)踐證明,基于原型網(wǎng)絡(luò)抽取實(shí)體比直接采用序列標(biāo)記方式在小樣本情況下效果更好。但是相對(duì)地,關(guān)系分類的準(zhǔn)確度有一定的下降。
2.2 跨域小樣本學(xué)習(xí)
對(duì)于當(dāng)前域(稱為目標(biāo)域)樣本較少的情況,通常可以借助于另一個(gè)較大的域(稱為源域)學(xué)習(xí)一些共性的知識(shí),然后將其應(yīng)用于目標(biāo)域以提升性能。此外,域判別對(duì)抗訓(xùn)練能夠?qū)W習(xí)不同域的共性特征,有助于提升模型的泛化能力,如文獻(xiàn)[6]中的Proto-ADV(BERT)網(wǎng)絡(luò),基于原型學(xué)習(xí)和對(duì)抗訓(xùn)練的方式提升小樣本關(guān)系抽取醫(yī)學(xué)域適應(yīng)模型的性能。
3 小樣本醫(yī)學(xué)關(guān)系三元組抽取模型
小樣本醫(yī)學(xué)關(guān)系三元組(Few-Shot Bio Triple Ex?traction) 的總體模型如圖1所示,包含4個(gè)主要模塊:編碼器模塊、關(guān)系分類模塊、實(shí)體識(shí)別模塊以及域判別對(duì)抗訓(xùn)練模塊。本文主要介紹當(dāng)前所做的醫(yī)學(xué)域適應(yīng)工作,用于解決小樣本醫(yī)學(xué)關(guān)系三元組抽取問(wèn)題,包括一個(gè)殘差連接的實(shí)體標(biāo)記原型網(wǎng)絡(luò)以及域判別對(duì)抗訓(xùn)練模塊,其余模塊包括關(guān)系分類模塊和實(shí)體識(shí)別模塊,見(jiàn)圖2和圖3。
3.1 問(wèn)題描述
依據(jù)傳統(tǒng)的Few-shot任務(wù)設(shè)定,將小樣本關(guān)系三元組抽取問(wèn)題定義為NwayKshot問(wèn)題,其中N對(duì)應(yīng)每次分類時(shí)關(guān)系的類別數(shù),K表示每個(gè)關(guān)系類別Support 樣本的數(shù)量。對(duì)于每個(gè)Query語(yǔ)句,關(guān)系分類問(wèn)題即對(duì)N個(gè)類別進(jìn)行分類;而實(shí)體識(shí)別問(wèn)題,使用傳統(tǒng)的實(shí)體標(biāo)記(BH、IH、BT、IT、O分別對(duì)應(yīng)頭實(shí)體第一個(gè)to?ken、頭實(shí)體其他token、尾實(shí)體第一個(gè)token、尾實(shí)體其他token、非實(shí)體token) 對(duì)句子進(jìn)行序列標(biāo)記,實(shí)體識(shí)別即建模為標(biāo)記預(yù)測(cè)問(wèn)題。
3.2 殘差連接的實(shí)體標(biāo)記原型網(wǎng)絡(luò)
對(duì)于每個(gè)query語(yǔ)句經(jīng)過(guò)encoder層,獲得每個(gè)to?ken的特征表示Qori(t1,t2,…,tn),其中n 為句子的長(zhǎng)度,類似的每個(gè)support語(yǔ)句經(jīng)過(guò)encoder層得到Sori(t1,t2,…,tn)。對(duì)于support樣本,由于每個(gè)token的實(shí)體標(biāo)記已知,本文對(duì)同一關(guān)系類別的K 個(gè)樣本的同類實(shí)體標(biāo)記對(duì)應(yīng)的token特征向量做平均池化,得到實(shí)體標(biāo)記原型Tori(N,5,D)的特征矩陣,其中“5”是實(shí)體標(biāo)記的類別共5 類,D 是特征向量長(zhǎng)度。Qori(t1,t2,…,tn)和Tori(N, 5, D)基于注意力機(jī)制得到Attention之后的特征表示Qatt(t1,t2,…tn) 和Tatt(N, 5, D),則實(shí)體識(shí)別模塊的輸入:query的token特征矩陣Q(t1,t2,..tn)= Qori(t1,t2,…,tn) || Qatt(t1,t2,…tn),實(shí)體標(biāo)記原型T(N,5,D)= Tori(N, 5, D) ||Tatt(N, 5, D),其中||表示拼接操作。
3.3 域判別對(duì)抗訓(xùn)練
域判別對(duì)抗訓(xùn)練模塊的總體流程如圖4所示。從源域(Wiki) 和目標(biāo)域(Bio) 中分別選取M個(gè)樣本構(gòu)造無(wú)標(biāo)記樣本集合W和B,每個(gè)batch分別從兩個(gè)域選取m個(gè)樣本,經(jīng)過(guò)encoder層對(duì)句子進(jìn)行編碼,選取CLS作為句子表征,2m個(gè)特征構(gòu)成特征矩陣E(2m,D),經(jīng)過(guò)FFN 層得到預(yù)測(cè)結(jié)果Y(2)=W2*ReLU(W1E(2m,D)+B1)+B2,其中D 表示特征向量長(zhǎng)度,“2”表示有2個(gè)域。最終,域判別對(duì)抗訓(xùn)練模塊的損失:
4 實(shí)驗(yàn)與結(jié)果分析
4.1 實(shí)驗(yàn)數(shù)據(jù)集
本文實(shí)驗(yàn)均基于Fewrel 2.0 da(domain adaption) 數(shù)據(jù)集(詳見(jiàn)文獻(xiàn)[6]) 。該數(shù)據(jù)集包含來(lái)自Wikipedia 語(yǔ)料庫(kù)和Wikidata知識(shí)庫(kù)采集的80種關(guān)系每種關(guān)系包含700個(gè)樣本,以及從PubMed數(shù)據(jù)庫(kù)采集的10種醫(yī)學(xué)關(guān)系,每種關(guān)系包含100個(gè)醫(yī)學(xué)樣本。
4.2 實(shí)驗(yàn)設(shè)置
所有實(shí)驗(yàn)基于Python3.8 和Pytorch1.7 框架,在NVIDIA GEFORCE 3090 GPU上進(jìn)行訓(xùn)練和測(cè)試。實(shí)驗(yàn)隨機(jī)選取Wiki域的50種關(guān)系訓(xùn)練模型(35 000個(gè)樣本),隨機(jī)選取PubMed域的3種關(guān)系作為驗(yàn)證數(shù)據(jù)集(300個(gè)樣本),余下7種關(guān)系作為測(cè)試數(shù)據(jù)集(700個(gè)樣本)。所有實(shí)驗(yàn)均在Wiki域迭代訓(xùn)練10 000次,batch?size固定為1,每次迭代以5way5shot方式采樣support 樣本和query樣本,在驗(yàn)證集上測(cè)試3way3shot關(guān)系三元組抽取的F1-score,保存取得最優(yōu)結(jié)果時(shí)的模型,然后在測(cè)試集上分別以3way-3shot 和7way-7shot 方式隨機(jī)采樣1 000次,計(jì)算關(guān)系三元組預(yù)測(cè)結(jié)果的pre?cision、recall、F1-score的均值。
對(duì)比實(shí)驗(yàn)設(shè)置如下:
1) Rel+EntTag ProtoNet:關(guān)系采用ProtoNet,實(shí)體采用BIO標(biāo)記原型(即不區(qū)分關(guān)系);
2) Rel+RGEntTag ProtoNet:關(guān)系采用ProtoNet,實(shí)體采用關(guān)系指導(dǎo)的實(shí)體BIO標(biāo)記原型;
3) RelATE:文獻(xiàn)[3]中的方法;
4) FSBTE,本文方法;
5) FSBTE-Adv,減去域判別對(duì)抗訓(xùn)練模塊;
6) FSBTE-Adv-Roberta_Bio,在5)的基礎(chǔ)上進(jìn)一步將Roberta-Bio語(yǔ)言模型[7]替換為Bert語(yǔ)言模型,作文Encoder模塊;
7) FSBTE-Adv-Roberta_Bio-Ori Feature,在6)的基礎(chǔ)上進(jìn)一步減去殘差連接模塊中Bert得到的token 特征表示,僅根據(jù)關(guān)系模塊中attention得到token特征表示計(jì)算實(shí)體標(biāo)記原型;
8) FSBTE-Adv-Roberta_Bio-Att Feature,在6)的基礎(chǔ)上進(jìn)一步減去殘差連接模塊中關(guān)系模塊的atten?tion得到token特征表示,僅根據(jù)Bert得到的token表示計(jì)算實(shí)體標(biāo)記原型。
4.3 實(shí)驗(yàn)結(jié)果
3way3shot和7way7shot醫(yī)學(xué)關(guān)系三元組抽取實(shí)驗(yàn)結(jié)果如表1 和表2 所示。根據(jù)Rel+EntTag ProtNet、Rel+RGEntTag ProtoNet和RelATE三組實(shí)驗(yàn)結(jié)果可以看出,采用關(guān)系指導(dǎo)的方式將實(shí)體根據(jù)關(guān)系區(qū)分學(xué)習(xí)實(shí)體的原型表示,可以顯著地提升關(guān)系三元組抽取的性能;根據(jù)FSBTE和FSBTE-Adv兩組對(duì)比實(shí)驗(yàn)可以看出,域判別對(duì)抗訓(xùn)練方式對(duì)于提升模型的泛化能力依然是十分有效的手段;根據(jù)最后三組對(duì)比實(shí)驗(yàn)可以看出,本文提出的殘差連接模塊極大地提升了Wiki域適應(yīng)到醫(yī)學(xué)域泛化性能,表明其對(duì)于小樣本醫(yī)學(xué)關(guān)系三元組抽取問(wèn)題的有效性。
5 總結(jié)
本文提出了一個(gè)基于殘差連接的原型網(wǎng)絡(luò)模塊,應(yīng)用于小樣本醫(yī)學(xué)關(guān)系三元組抽取,同時(shí)結(jié)合域判別對(duì)抗訓(xùn)練,提升了網(wǎng)絡(luò)域適應(yīng)能力。多組對(duì)比實(shí)驗(yàn)證明了本文方法的有效性。