论文解读(AdSPT)《Adversarial Soft Prompt Tuning for Cross-Domain Sentiment Analysis》
Note:[ wechat:Y466551 | 可加勿骚扰,付费咨询 ]
论文信息
论文标题:Adversarial Soft Prompt Tuning for Cross-Domain Sentiment Analysis
论文作者:Hui Wu、Xiaodong Shi
论文来源:2022 ACL
论文地址:download
论文代码:download
视屏讲解:click
1 介绍
动机:直接使用固定的预定义模板进行跨域研究,不能对不同域的 [MASK] 标记在不同域中的不同分布进行建模,因此没有充分利用提示调优技术。在本文中,提出了一种新的对抗性软提示调优方法(AdSPT)来更好地建模跨域情绪分析;
贡献:
-
- 在提示式调优中,本文采用软提示来学习领域知识嵌入,从而减轻了 [MASK] 位置的领域差异;
- 设计了一种新的对抗性训练策略来学习 [MASK] 位置的域不变表示;
- 在 Amazon 评论数据集上的实验表明,AdSPT 方法在单源域适应下、多源域适应 取得了重大改进;
2 相关
在 book review 和 videp review 中,预测高频词不一样,且 高频词不在 预定义输出 “good,bad” 中;
3 方法
模型框架:
3.1 提示输入
提示输入 $\boldsymbol{x}_{\text {prompt }}$:
$\boldsymbol{x}_{\text {prompt }}= {[\mathbf{e}(\text { "CLS }] "), \mathbf{e}(\boldsymbol{x}), \mathbf{h}_{0}, \ldots, \mathbf{h}_{k-1}, }\mathbf{e}(\text { "[MASK]") }, \mathbf{e}(\text { "[SEP]") })]$
注意:输入 $\boldsymbol{x}_{\text {prompt }}$ 不是一个 $\text{raw text}$ ,而是一个嵌入矩阵,$\text{nn.Embedding}$ 后的结果;
3.2 Encoder 输出
将 $\boldsymbol{x}_{\text {prompt }}$ 输入编码器,得到:
$\mathbf{h}_{[\mathrm{MASK}]}, \mathbf{s}_{[\mathrm{MASK}]}=\mathcal{M}\left(\boldsymbol{x}_{\text {prompt }}\right) $
其中,$\mathbf{h}_{[\text {MASK }]} \in \mathbb{R}^{h}$,$\mathbf{s}_{[\text {MASK }]} \in \mathbb{R}^{|\mathcal{V}|}$,$\mathrm{s}_{[\mathrm{MASK}]}= f\left(\mathbf{h}_{[\text {MASK }]}\right) $,$f$ 是 $\text{MLM head function}$;
3.3 情感分类
情感预测:
$\begin{aligned}p(y \mid \boldsymbol{x}) & =p\left(\mathcal{V}_{y}^{*} \leftarrow[\mathrm{MASK}] \mid \boldsymbol{x}_{\text {prompt }}\right) \\& =\frac{\exp \left(\mathbf{s}_{[\mathrm{MASK}]}\left(\mathcal{V}_{y}^{*}\right)\right)}{\sum_{y^{\prime} \in \mathcal{Y}} \exp \left(\mathbf{s}_{[\mathrm{MASK}]}\left(\mathcal{V}_{y^{\prime}}^{*}\right)\right)}\end{aligned}$
其中,$\mathcal{V}^{*} \in \{ \text{good,bad} \}$;
情感分类损失:
$\mathcal{L}_{\text {class }}\left(\mathcal{S} ; \theta_{\mathcal{M}, p, f}\right) =-\sum_{i=1}^{N} {\left[\log p\left(y_{i} \mid \boldsymbol{x}_{i}\right)^{\mathbb{I}\left\{\hat{y}_{i}=1\right\}}\right.} \left.+\log \left(1-p\left(y_{i} \mid \boldsymbol{x}_{i}\right)\right)^{\mathbb{I}\left\{\hat{y}_{i}=0\right\}}\right]$
3.4 域对抗性训练
设有 $\text{m}$ 个源域 ,源域、目标域的域标签分别为 $0 , 1$,$m$ 个域鉴别器 $\mathbf{g}=\left\{g_{l}\right\}_{l=1}^{m}$;
域预测:
$p(d \mid \boldsymbol{x})=\frac{\exp \left(g_{l}^{d}\left(\mathbf{h}_{[\mathrm{MASK}]}\right)\right)}{\sum_{d^{\prime} \in \mathcal{D}} \exp \left(g_{l}^{d^{\prime}}\left(\mathbf{h}_{[\mathrm{MASK}]}\right)\right)}$
域分类损失:
$\mathcal{L}_{\text {domain }}\left(\hat{\mathcal{S}}, \mathcal{T} ; \theta_{\mathcal{M}, p, \mathbf{g}}\right) =-\sum_{l=1}^{m} \sum_{i=1}^{N_{l}^{s}+N^{t}} {\left[\log p\left(d_{i} \mid \boldsymbol{x}_{i}\right)^{\mathbb{I}\left\{\hat{d}_{i}=1\right\}}\right.}\left.+\log \left(1-p\left(d_{i} \mid \boldsymbol{x}_{i}\right)\right)^{\mathbb{I}\left\{\hat{d}_{i}=0\right\}}\right]$
域对抗训练:
$\underset{\mathcal{M}, p}{\text{max}}\; \underset{\mathbf{g}}{\text{min}} \;\mathcal{L}_{\text {domain }}\left(\hat{\mathcal{S}}, \mathcal{T} ; \theta_{\mathcal{M}, p, \mathbf{g}}\right)$
3.5 训练目标
优化 $\text{PLM}$ $\mathcal{M}$ ,$\text{soft prompt embeddings}$ $p$ , $\text{MLM head function}$ $f$,$\text{domain discriminators }$ $\mathbf{g}$:
$\underset{\mathcal{M}, p, f}{\text{min}} \{ \lambda \mathcal{L}_{\text {class }}\left(\mathcal{S} ; \theta_{\mathcal{M}, p, f}\right) \left.-\underset{\mathbf{g}}{\text{min}} \mathcal{L}_{\text {domain }}\left(\hat{\mathcal{S}}, \mathcal{T} ; \theta_{\mathcal{M}, p, \mathbf{g}}\right)\right\}$
3.6 算法
如下:
4 实验
single-source domain adaptation on Amazon reviews
Results of multi-source domain adaptation on Amazon reviews
Ablation experiments
因上求缘,果上努力~~~~ 作者:图神经网络,转载请注明原文链接:https://www.cnblogs.com/BlairGrowing/p/17665254.html