GME Model Reproduction
For the original paper, see *Learning Graph Meta Embeddings for Cold-Start Ads in Click-Through Rate Prediction*.
The core idea of the paper: build a graph from item similarity, generate an ID embedding for a new ad as a weighted aggregation of the information from its neighboring old ads, and use this generated embedding in place of the original (randomly initialized) ID embedding.
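For context: the neighbor columns in the training files used below are precomputed. Purely as an illustration of the graph-building step, here is a minimal sketch of top-K neighbor selection by attribute overlap; `select_neighbors` and its inputs are hypothetical names, and the paper's actual similarity measure may differ:

```python
import numpy as np

def select_neighbors(new_ad_attrs, old_ads_attrs, max_n_ngb=5):
    # new_ad_attrs: set of attribute IDs of the new ad;
    # old_ads_attrs: list of attribute-ID sets, one per old ad.
    # Overlap size is used as a crude similarity proxy.
    sims = np.array([len(new_ad_attrs & attrs) for attrs in old_ads_attrs])
    return np.argsort(-sims)[:max_n_ngb]   # indices of the most similar old ads

# e.g. select_neighbors({1, 5, 9}, [{1, 5}, {2}, {1, 5, 9}], max_n_ngb=2) -> [2, 0]
```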
The authors released TF1 code; following the paper and that source, I wrote a PyTorch version. The two core parts are shown below: the GME module, which generates the current node's representation from its neighbors, and the meta-training loop. I reproduce GME-A, with DeepFM as the base model.
```python
import torch
import torch.nn as nn

# GME-A: generate the new ad's ID embedding from its neighbors (GAT-style attention).
class GME_A(nn.Module):
    # ... __init__ omitted (see the sketch below for the parameters it must define)
    def forward(self, embeddings, idx):
        # Columns [0, 16) hold the target ad's own features
        # (one-hot slots followed by flattened multi-hot slots); the rest are neighbor features.
        self_one_hot_index = idx[:, :self.args['n_one_hot_slot']]
        self_mul_hot_index = idx[:, self.args['n_one_hot_slot']:16]
        self_mul_hot_index = self_mul_hot_index.reshape(idx.shape[0], -1, self.args['max_len_per_slot'])
        ngb_index = idx[:, 16:].reshape(idx.shape[0], self.args['max_n_ngb'], -1)
        ngb_one_hot_index = ngb_index[:, :, :self.args['n_one_hot_slot']]
        ngb_mul_hot_index = ngb_index[:, :, self.args['n_one_hot_slot']:]
        ngb_mul_hot_index = ngb_mul_hot_index.reshape(ngb_mul_hot_index.shape[0], self.args['max_n_ngb'], self.args['n_mul_hot_slot'], -1)
        # Masked embedding lookups: padding IDs contribute zero vectors.
        self_one_hot_emb = get_masked_one_hot_emb(embeddings, self_one_hot_index)
        self_mul_hot_emb = get_masked_mul_hot_emb(embeddings, self_mul_hot_index)
        ngb_one_hot_emb = get_masked_one_hot_emb_ngb(embeddings, ngb_one_hot_index)
        ngb_mul_hot_emb = get_masked_mul_hot_emb_ngb(embeddings, ngb_mul_hot_index)
        ngb_emb = torch.cat((ngb_one_hot_emb, ngb_mul_hot_emb), dim=2)    # (B, max_n_ngb, n_slots, k)
        self_emb = torch.cat((self_one_hot_emb, self_mul_hot_emb), dim=1)  # (B, n_slots, k)
        # Keep only the attribute slots used for attention.
        self_attr_emb = self_emb[:, self.args['attr_idx'], :]
        ngb_attr_emb = ngb_emb[:, :, self.args['attr_idx'], :]
        self_attr_emb = self_attr_emb.reshape(self_attr_emb.shape[0], -1)
        ngb_attr_emb_ori = ngb_attr_emb.reshape(ngb_attr_emb.shape[0], ngb_attr_emb.shape[1], -1)
        # The target node attends over its neighbors plus itself (max_n_ngb + 1 nodes).
        self_attr_emb_exp = self_attr_emb.unsqueeze(1)
        self_attr_emb_tile = torch.tile(self_attr_emb_exp, (1, self.args['max_n_ngb'] + 1, 1))
        ngb_attr_emb = torch.cat((ngb_attr_emb_ori, self_attr_emb_exp), dim=1)
        self_attr_emb_2d = self_attr_emb_tile.reshape(-1, len(self.args['attr_idx']) * self.args['k'])
        ngb_attr_emb_2d = ngb_attr_emb.reshape(-1, len(self.args['attr_idx']) * self.args['k'])
        # GAT attention: W_gat projects node features, a_gat scores each (self, neighbor) pair.
        temp_self = torch.matmul(self_attr_emb_2d, self.W_gat)
        temp_ngb = torch.matmul(ngb_attr_emb_2d, self.W_gat)
        wgt = self.leaky_relu(torch.matmul(torch.cat((temp_self, temp_ngb), dim=1), self.a_gat))
        wgt = wgt.reshape(-1, self.args['max_n_ngb'] + 1, 1)   # (B, max_n_ngb + 1, 1)
        nlz_wgt = self.softmax1(wgt)                           # normalize over the nodes
        temp_ngb_re = temp_ngb.reshape(-1, self.args['max_n_ngb'] + 1, self.args['att_dim'])
        up_attr_emb = self.elu(torch.mul(temp_ngb_re, nlz_wgt).sum(1))  # weighted aggregation
        # Map the aggregated representation into the ID-embedding space.
        pred_emb = self.args['gamma'] * self.tanh(torch.matmul(up_attr_emb, self.W_meta))
        pred_emb = pred_emb.reshape(-1, 1, self.args['k'])
        # return pred_emb, wgt, nlz_wgt
        return pred_emb
```
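The class above omits its `__init__` and the four masked-lookup helpers. Below is a minimal sketch of plausible versions, consistent with the shapes used in `forward`; it assumes `embeddings` is the raw `(n_features, k)` weight matrix, that feature ID 0 is padding, and that Xavier initialization is used — all my assumptions, not the author's exact code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GME_A(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        in_dim = len(args['attr_idx']) * args['k']   # flattened attribute embedding size
        # W_gat projects node features, a_gat scores node pairs,
        # W_meta maps the aggregated vector back to the ID-embedding space.
        self.W_gat = nn.Parameter(nn.init.xavier_uniform_(torch.empty(in_dim, args['att_dim'])))
        self.a_gat = nn.Parameter(nn.init.xavier_uniform_(torch.empty(2 * args['att_dim'], 1)))
        self.W_meta = nn.Parameter(nn.init.xavier_uniform_(torch.empty(args['att_dim'], args['k'])))
        self.leaky_relu = nn.LeakyReLU()
        self.softmax1 = nn.Softmax(dim=1)   # normalizes over the max_n_ngb + 1 nodes
        self.elu = nn.ELU()
        self.tanh = nn.Tanh()

def get_masked_one_hot_emb(embeddings, index):
    # index: (B, n_slot); zero (padding) IDs contribute zero vectors
    mask = (index > 0).float().unsqueeze(-1)
    return F.embedding(index, embeddings) * mask               # (B, n_slot, k)

def get_masked_mul_hot_emb(embeddings, index):
    # index: (B, n_slot, max_len); sum-pool each multi-hot slot
    mask = (index > 0).float().unsqueeze(-1)
    return (F.embedding(index, embeddings) * mask).sum(dim=2)  # (B, n_slot, k)

def get_masked_one_hot_emb_ngb(embeddings, index):
    # index: (B, max_n_ngb, n_slot)
    mask = (index > 0).float().unsqueeze(-1)
    return F.embedding(index, embeddings) * mask               # (B, max_n_ngb, n_slot, k)

def get_masked_mul_hot_emb_ngb(embeddings, index):
    # index: (B, max_n_ngb, n_slot, max_len)
    mask = (index > 0).float().unsqueeze(-1)
    return (F.embedding(index, embeddings) * mask).sum(dim=3)  # (B, max_n_ngb, n_slot, k)
```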
```python
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# MovieDataset, args and device are defined elsewhere in the repo.
def train_gme_model(base_model, gme):
    # Support set (a) and query set (b); shuffle=False keeps the two loaders
    # aligned so each batch pair refers to the same ads.
    traina = pd.read_csv('./data/ml-1m/train_oneshot_a_w_ngb.csv', header=None).to_numpy()
    traina_dataset = MovieDataset(traina[:, 1:], traina[:, :1])
    traina_loader = DataLoader(traina_dataset, batch_size=args['meta_batch_size'], shuffle=False)
    trainb = pd.read_csv('./data/ml-1m/train_oneshot_b_w_ngb.csv', header=None).to_numpy()
    trainb_dataset = MovieDataset(trainb[:, 1:], trainb[:, :1])
    trainb_loader = DataLoader(trainb_dataset, batch_size=args['meta_batch_size'], shuffle=False)
    loss_func = nn.BCELoss(reduction='mean')
    gme.train()
    optimizer = torch.optim.Adam(gme.parameters(), lr=args['meta_eta'])
    loss_sum = 0.0
    steps = 0
    for (idxa, ya), (idxb, yb) in zip(traina_loader, trainb_loader):
        idxa, ya, idxb, yb = idxa.to(device).long(), ya.to(device).float(), idxb.to(device).long(), yb.to(device).float()
        # Generate the cold-start ID embedding and score the support batch with it.
        one_hot_mid_emb = gme(base_model.embeddings, idxa)
        pred_a = base_model.use_ngb_emb(idxa, one_hot_mid_emb)  # base-model forward with the generated embedding swapped in
        loss_a = loss_func(pred_a, ya)
        # Inner update: one gradient step on the generated embedding only.
        # (Without create_graph=True this is a first-order approximation of the meta-gradient.)
        grad_a = torch.autograd.grad(loss_a, one_hot_mid_emb, retain_graph=True)
        one_hot_mid_emb = one_hot_mid_emb - args['cold_eta'] * grad_a[0]
        # Outer objective: the updated embedding should also fit the query batch.
        pred_b = base_model.use_ngb_emb(idxb, one_hot_mid_emb)
        loss_b = loss_func(pred_b, yb)
        loss = loss_a * args['alpha'] + loss_b * (1 - args['alpha'])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_sum += loss.item()
        steps += 1
    print(f'meta-training loss: {loss_sum / steps: .4f}')
```
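The excerpt stops at meta-training; at serving time, the generated embedding would be written into the embedding table for each cold-start ad. The helper below is hypothetical (`init_cold_start_embeddings` and the slot-0 ad-ID layout are my assumptions, not shown in the original code), included only to make the end-to-end flow concrete:

```python
@torch.no_grad()
def init_cold_start_embeddings(base_model, gme, cold_loader, ad_id_slot=0):
    # Hypothetical deployment step: replace each cold-start ad's randomly
    # initialized ID embedding with the GME-generated one. Assumes the ad ID
    # occupies one-hot slot 0 and base_model.embeddings is the (n_features, k)
    # weight matrix shared with the base model.
    gme.eval()
    for idx, _ in cold_loader:
        idx = idx.to(device).long()
        pred_emb = gme(base_model.embeddings, idx)   # (B, 1, k)
        ad_ids = idx[:, ad_id_slot]
        base_model.embeddings.data[ad_ids] = pred_emb.squeeze(1)
```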
The dataset is MovieLens-1M. RndEmb is the baseline whose new-ad ID embeddings are randomly initialized rather than generated.
Results from the paper:
| Model | AUC | Loss |
|---|---|---|
| RndEmb | 0.7143 | 0.6462 |
| GME-A | 0.7206 | 0.6449 |
My reproduction:
| Model | AUC | Loss |
|---|---|---|
| RndEmb | 0.7148 | 0.6609 |
| GME-A | 0.7202 | 0.6498 |
The gain from GME-A is clearly visible in my run as well: +0.0054 AUC over RndEmb, close to the paper's +0.0063.