异构图中节点的分类/回归
异构图中节点的分类/回归
导入包
import os

import dgl
import dgl.nn as dglnn
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
创建一个异构图
设置这个图中的节点个数和边的个数
# Sizes of the synthetic heterogeneous graph. Every value is made up because
# no real dataset is loaded in this script.

# Number of nodes of each node type.
n_users = 100     # 'user' nodes
n_jobspre = 500   # 'jobpre' nodes
n_college = 100   # 'college' nodes
n_skills = 500    # 'skill' nodes

# Number of edges per relation. Each relation and its reverse share a count.
n_uj = n_ju = 3000                  # user <-> jobpre
n_cu = n_uc = 100                   # user <-> college
n_us = n_su = n_js = n_sj = 4000    # user <-> skill and jobpre <-> skill

# Width of the randomly initialised input feature vector of every node.
n_hetero_features = 20
# Number of distinct user labels (assumed number of job categories).
n_user_classes = 20
设置每条边对应的头节点和尾节点(如果有数据集,直接导入即可);这里因为没有数据集,所以随机生成。
# Randomly drawn edge endpoints. For every relation, element i of the *_src
# array and element i of the *_dst array together describe one edge.
# Duplicated (src, dst) pairs are possible because the draws are independent.

# user -> jobpre edges
uj_src = np.random.randint(low=0, high=n_users, size=n_uj)      # user ids
uj_dst = np.random.randint(low=0, high=n_jobspre, size=n_uj)    # jobpre ids

# user -> college edges (n_cu is the edge count used here)
uc_src = np.random.randint(low=0, high=n_users, size=n_cu)      # user ids
uc_dst = np.random.randint(low=0, high=n_college, size=n_cu)    # college ids

# user -> skill edges
us_src = np.random.randint(low=0, high=n_users, size=n_us)      # user ids
us_dst = np.random.randint(low=0, high=n_skills, size=n_us)     # skill ids

# jobpre -> skill edges
js_src = np.random.randint(low=0, high=n_jobspre, size=n_js)    # jobpre ids
js_dst = np.random.randint(low=0, high=n_skills, size=n_js)     # skill ids
利用dgl构建异构图
# Map every (source type, relation name, destination type) triple to its
# (source ids, destination ids) pair. Each relation is also registered in the
# opposite direction by swapping the two id arrays.
graph_data = {
    # user <-> jobpre
    ('user', 'uj', 'jobpre'): (uj_src, uj_dst),
    ('jobpre', 'ju', 'user'): (uj_dst, uj_src),
    # user <-> college
    ('user', 'uc', 'college'): (uc_src, uc_dst),
    ('college', 'cu', 'user'): (uc_dst, uc_src),
    # user <-> skill
    ('user', 'us', 'skill'): (us_src, us_dst),
    ('skill', 'su', 'user'): (us_dst, us_src),
    # jobpre <-> skill
    ('jobpre', 'js', 'skill'): (js_src, js_dst),
    ('skill', 'sj', 'jobpre'): (js_dst, js_src),
}
hetero_graph = dgl.heterograph(graph_data)
初始化每个节点的嵌入
# Attach a random n_hetero_features-wide 'feature' vector to every node.
# The graph has several node types, so the features are passed as a dict
# keyed by node type.
hetero_graph.ndata['feature'] = {
    'user': torch.randn(n_users, n_hetero_features),
    'jobpre': torch.randn(n_jobspre, n_hetero_features),
    'college': torch.randn(n_college, n_hetero_features),
    'skill': torch.randn(n_skills, n_hetero_features),
}

# Only 'user' nodes are labelled: one random class id in [0, n_user_classes).
hetero_graph.nodes['user'].data['label'] = torch.randint(
    0, n_user_classes, (n_users,)
)

# Boolean mask choosing the users that contribute to the training loss;
# each user is selected independently with probability 0.6.
hetero_graph.nodes['user'].data['train_mask'] = torch.zeros(
    n_users, dtype=torch.bool
).bernoulli(0.6)
定义一个两层的异构图卷积模型(RGCN)
class RGCN(nn.Module):
    """Two stacked heterogeneous graph-convolution layers.

    Each layer is a ``dglnn.HeteroGraphConv`` holding one ``dglnn.GraphConv``
    per relation name; results arriving at the same node type from different
    relations are combined with ``aggregate='sum'``.

    Args:
        in_feats: Width of the input node features.
        hid_feats: Width of the features produced by the first layer.
        out_feats: Width of the features produced by the second layer.
        rel_names: Iterable of relation (edge type) names; it is iterated
            once for each of the two layers.
    """

    def __init__(self, in_feats, hid_feats, out_feats, rel_names):
        super().__init__()

        # First layer: in_feats -> hid_feats, one GraphConv per relation.
        first_layer_mods = {}
        for rel in rel_names:
            first_layer_mods[rel] = dglnn.GraphConv(in_feats, hid_feats)
        self.conv1 = dglnn.HeteroGraphConv(first_layer_mods, aggregate='sum')

        # Second layer: hid_feats -> out_feats, one GraphConv per relation.
        second_layer_mods = {}
        for rel in rel_names:
            second_layer_mods[rel] = dglnn.GraphConv(hid_feats, out_feats)
        self.conv2 = dglnn.HeteroGraphConv(second_layer_mods, aggregate='sum')

    def forward(self, graph, inputs):
        """Run both layers with a ReLU in between.

        Args:
            graph: Graph handed unchanged to both convolution layers.
            inputs: Dictionary of node features (the caller in this script
                keys it by node type).

        Returns:
            The output of the second convolution layer; the caller in this
            script indexes it by node type.
        """
        hidden = self.conv1(graph, inputs)
        # Apply the non-linearity to every entry of the first layer's output.
        activated = {}
        for key, feat in hidden.items():
            activated[key] = F.relu(feat)
        return self.conv2(graph, activated)
定义模型并进行训练
# Build the model: input width n_hetero_features, hidden width 20 and one
# output logit per user class. One GraphConv is created for every edge type
# of the graph (hetero_graph.etypes).
model = RGCN(n_hetero_features, 20, n_user_classes, hetero_graph.etypes)

# Input features of every node type, plus the user labels and training mask.
user_feats = hetero_graph.nodes['user'].data['feature']        # user features
jobpre_feats = hetero_graph.nodes['jobpre'].data['feature']    # jobpre features
college_feats = hetero_graph.nodes['college'].data['feature']  # college features
skill_feats = hetero_graph.nodes['skill'].data['feature']      # skill features
labels = hetero_graph.nodes['user'].data['label']
train_mask = hetero_graph.nodes['user'].data['train_mask']
node_features = {
    'user': user_feats,
    'jobpre': jobpre_feats,
    'college': college_feats,
    'skill': skill_feats,
}

opt = torch.optim.Adam(model.parameters())

# torch.save does not create missing parent directories, so make sure the
# checkpoint directory exists before the first save.
os.makedirs("./model", exist_ok=True)

for epoch in range(5):
    model.train()

    # One forward pass over the whole graph. The original code ran the model
    # twice per epoch on identical inputs; the single result is reused here.
    h_dict = model(hetero_graph, node_features)
    logits = h_dict['user']        # per-user class scores used for the loss
    h_user = h_dict['user']        # user output of this epoch
    # Output row of the user with index 2; it has n_user_classes entries.
    print(h_user[2])
    h_jobpre = h_dict['jobpre']    # jobpre output of this epoch
    h_college = h_dict['college']  # college output of this epoch
    h_skill = h_dict['skill']      # skill output of this epoch

    # Cross-entropy loss restricted to the users selected by train_mask.
    loss = F.cross_entropy(logits[train_mask], labels[train_mask])
    # Validation accuracy is omitted in this example.

    # Back-propagate and update the parameters.
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item())

    # Save one checkpoint per epoch: ./model/main01_<epoch>.
    torch.save(model.state_dict(), f"./model/main01_{epoch}")
使用(测试)模型
# Re-create the network with the same constructor arguments and restore the
# parameters stored in the checkpoint file "./model/main01_4".
model = RGCN(n_hetero_features, 20, n_user_classes, hetero_graph.etypes)
saved_state = torch.load("./model/main01_4")
model.load_state_dict(saved_state)

# NOTE: the values printed below are the input user features stored on the
# graph, not an output of the restored model.
user_feats = hetero_graph.nodes['user'].data['feature']
print(user_feats)
print(user_feats[0])
结果截图:
posted on 2023-05-01 09:48 monster-little 阅读(122) 评论(0) 编辑 收藏 举报