PGL图神经网络学习总结
DAY1
- 图的概念,非结构化数据,描述复杂系统的语言。
- 常见的图:社交网络、推荐系统、化学分子结构
- 应用:分为点任务——欺诈识别、用户画像,边任务——推荐系统,图任务——社区发现
- 算法:图游走类算法——deepwalk、node2vec,图神经网络——GCN、GAT,图嵌入
- 图构建:转移矩阵、边列表
DAY2
- DeepWalk的主要原理是通过随机游走生成节点路径,然后将其作为词向量模型SkipGram的输入来学习节点表示。SkipGram和CBow的区别在于是从嵌入词推断上下文还是上下文推断中心词,SkipGram采用负采样的方式评估损失。
- Node2Vec和RandomWalk的区别在于使用q和p倾向于探索局部还是深度
DAY3
对于给定的节点,DeepWalk会等概率的选取下一个相邻节点加入路径,直至达到最大路径长度,或者没有下一个节点可选。
实现代码:
%%writefile userdef_graph.py from pgl.graph import Graph import numpy as np class UserDefGraph(Graph): def random_walk(self, nodes, walk_len): """ 输入:nodes - 当前节点id list (batch_size,) walk_len - 最大路径长度 int 输出:以当前节点为起点得到的路径 list (batch_size, walk_len) 用到的函数 1. self.successor(nodes) 描述:获取当前节点的下一个相邻节点id列表 输入:nodes - list (batch_size,) 输出:succ_nodes - list of list ((num_successors_i,) for i in range(batch_size)) 2. self.outdegree(nodes) 描述:获取当前节点的出度 输入:nodes - list (batch_size,) 输出:out_degrees - list (batch_size,) """ walks = [[node] for node in nodes] walks_ids = np.arange(0, len(nodes)) cur_nodes = np.array(nodes) for l in range(walk_len): """选取有下一个节点的路径继续采样,否则结束""" outdegree = self.outdegree(cur_nodes) walk_mask = (outdegree != 0) if not np.any(walk_mask): break cur_nodes = cur_nodes[walk_mask] walks_ids = walks_ids[walk_mask] outdegree = outdegree[walk_mask] ###################################### # 请在此补充代码采样出下一个节点 succ_nodes = self.successor(cur_nodes) next_nodes = [ np.random.choice(node) for node in succ_nodes] for i, next_node in zip(walks_ids,next_nodes): walks[i].append(next_node) ###################################### cur_nodes = np.array(next_nodes) return walks
NOTE:在得到节点路径后,node2vec会使用SkipGram模型学习节点表示,给定中心节点,预测局部路径中还有哪些节点。模型中用了negative sampling来降低计算量。
实现代码:
%%writefile userdef_model.py import paddle.fluid.layers as l def userdef_loss(embed_src, weight_pos, weight_negs): """ 输入:embed_src - 中心节点向量 list (batch_size, 1, embed_size) weight_pos - 标签节点向量 list (batch_size, 1, embed_size) weight_negs - 负样本节点向量 list (batch_size, neg_num, embed_size) 输出:loss - 正负样本的交叉熵 float """ """ pos_logits = l.matmul( embed_src, weight_pos, transpose_y=True) # [batch_size, 1, 1] neg_logits = l.matmul( embed_src, weight_negs, transpose_y=True) # [batch_size, 1, neg_num] """ ################################## # 请在这里实现SkipGram的loss计算过程 pos_logits = l.matmul( embed_src, weight_pos, transpose_y=True) # [batch_size, 1, 1] neg_logits = l.matmul( embed_src, weight_negs, transpose_y=True) # [batch_size, 1, neg_num] ones_label = pos_logits * 0. + 1. ones_label.stop_gradient = True pos_loss = l.sigmoid_cross_entropy_with_logits(pos_logits, ones_label) zeros_label = neg_logits * 0. zeros_label.stop_gradient = True neg_loss = l.sigmoid_cross_entropy_with_logits(neg_logits, zeros_label) loss = (l.reduce_mean(pos_loss) + l.reduce_mean(neg_loss)) / 2 ################################## return loss
Node2Vec采样算法
Node2Vec会根据与上个节点的距离按不同概率采样得到当前节点的下一个节点。
实现代码:
%%writefile userdef_sample.py import numpy as np def node2vec_sample(succ, prev_succ, prev_node, p, q): """ 输入:succ - 当前节点的下一个相邻节点id列表 list (num_neighbors,) prev_succ - 前一个节点的下一个相邻节点id列表 list (num_neighbors,) prev_node - 前一个节点id int p - 控制回到上一节点的概率 float q - 控制偏向DFS还是BFS float 输出:下一个节点id int """ ################################## # 请在此实现node2vec的节点采样函数 probs = [] prob_sum = 0. for succ_ in succ: if succ_ == prev_node: prob = 1. / p elif succ_ in prev_succ: prob = 1. else: prob = 1.0 / q probs.append(prob) prob_sum += prob random_number = np.random.rand() * prob_sum for i, succ_ in enumerate(succ): random_number -= probs[i] if random_number <= 0: sampled_succ = succ_ break ################################## return sampled_succ
DAY4 GAT模型
DAY5 GraphSage的采样聚合
GraphSage主要解决图空间太大不够内存计算的问题,采取了类似mini-batch的方式。
- 假设我们要利用中心节点的k阶邻居信息,则在聚合的时候,需要从第k阶邻居传递信息到k-1阶邻居,并依次传递到中心节点。
- 采样的过程刚好与此相反,在构造第t轮训练的Mini-batch时,我们从中心节点出发,在前序节点集合中采样NtN_tNt个邻居节点加入采样集合。
- 接着将邻居节点作为新的中心节点继续进行第t-1轮训练的节点采样,以此类推。
- 最后将采样到的节点和边一起构造得到子图。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· 没有源码,如何修改代码逻辑?
· 分享4款.NET开源、免费、实用的商城系统
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
· 上周热点回顾(2.24-3.2)