1.1 The Message Passing Paradigm
In the figure accompanying the original post (not reproduced here), green boxes denote node attributes and blue boxes denote edge attributes. A node updates its state by aggregating the messages passed in along each of its incoming edges.
The message passing paradigm can be expressed formally as follows:
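Following the formulation in the DGL user guide, let \(\phi\) be the message function, \(\rho\) the reduce (aggregation) function, and \(\psi\) the update function:

$$m_e^{(t+1)} = \phi\left(x_v^{(t)}, x_u^{(t)}, w_e^{(t)}\right), \quad (u, v, e) \in \mathcal{E}$$

$$x_v^{(t+1)} = \psi\left(x_v^{(t)}, \rho\left(\left\{ m_e^{(t+1)} : (u, v, e) \in \mathcal{E} \right\}\right)\right)$$

Here \(x_v\) is the feature of node \(v\) and \(w_e\) the feature of edge \(e\): a message is computed on every edge, the messages on each node's incoming edges are reduced, and the result is used to update that node's state.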
1.2 Built-in Message Passing APIs
For details, see: dgl.function - DGL 0.6.1 documentation
Message functions
A message function takes a single argument, edges (an EdgeBatch instance), which internally represents a batch of edges during message passing. Its src, dst, and data attributes give access to the source node features, the destination node features, and the edge features, respectively.
For example, suppose we want to add the source node feature hu to the destination node feature hv and store the result in the edge feature he. There are two ways to do this.
# Method 1: use a DGL built-in message function
dgl.function.u_add_v('hu', 'hv', 'he')
# Method 2: a user-defined message function
def message_func(edges):
    return {'he': edges.src['hu'] + edges.dst['hv']}
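Either version can be triggered with apply_edges() (introduced below). A minimal runnable sketch, assuming a toy graph with random hu/hv node features:

import dgl
import torch
import dgl.function as fn

g = dgl.graph(([0, 1, 2], [1, 2, 0]))   # 3 nodes, 3 edges
g.ndata['hu'] = torch.randn(3, 4)
g.ndata['hv'] = torch.randn(3, 4)

g.apply_edges(fn.u_add_v('hu', 'hv', 'he'))   # built-in message function
he_builtin = g.edata['he']
g.apply_edges(message_func)                   # user-defined message function
assert torch.allclose(he_builtin, g.edata['he'])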
Aggregation (reduce) functions
An aggregation function takes a single argument, nodes (a NodeBatch instance), which internally represents a batch of nodes during message passing. Its mailbox attribute holds the messages the nodes have received.
For example, to sum-aggregate the messages m received by each node and assign the result to the node feature h:
# Method 1: use a DGL built-in reduce function
dgl.function.sum('m', 'h')
# Method 2: a user-defined reduce function
import torch
def reduce_func(nodes):
    return {'h': torch.sum(nodes.mailbox['m'], dim=1)}
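The two reducers are interchangeable inside update_all() (covered next). A minimal sketch, reusing the toy graph above and copying hu onto each edge as the message m:

g.update_all(fn.copy_u('hu', 'm'), fn.sum('m', 'h'))   # built-in reducer
h_builtin = g.ndata['h']
g.update_all(fn.copy_u('hu', 'm'), reduce_func)        # user-defined reducer
assert torch.allclose(h_builtin, g.ndata['h'])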
apply_edges and update_all
When no message passing is involved, apply_edges() can be called on its own for edge-wise computation. It takes a message function as its argument and by default updates all edges. For example:
import dgl.function as fn
graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
For message passing proper, use update_all(). Its arguments are a message function, an aggregation function, and an update function. The update function is optional; DGL recommends not passing one and instead applying it outside update_all() with ordinary tensor operations.
For example: multiply the source node feature ft by the edge feature a to generate the message m, sum all incoming messages to update the node feature ft, then multiply ft by 2 to obtain the final result final_ft. The formal expression and the implementing code follow:
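Reconstructing the formal expression from this description (it matches the example in the DGL user guide):

$$\text{final\_ft}_i = 2 \cdot \sum_{j \in \mathcal{N}(i)} \left( \text{ft}_j \cdot a_{ij} \right)$$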
def update_all_example(graph):
    # store the intermediate result in graph.ndata['ft']
    graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                     fn.sum('m', 'ft'))
    # apply the update function outside update_all
    final_ft = graph.ndata['ft'] * 2
    return final_ft
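A quick way to exercise it on the toy graph, assuming per-edge scalar weights a that broadcast over the feature dimension:

g.ndata['ft'] = torch.randn(3, 4)
g.edata['a'] = torch.rand(3, 1)
final_ft = update_all_example(g)   # shape: (3, 4)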
1.3 Using Edge Weights
It is sometimes necessary to weight messages with edge weights before aggregation, as in GAT and some GCN variants. The recipe is:
Store the weights as an edge feature.
Multiply the source node feature by that edge feature in the message function. In the code below, eweight is used as the edge weight:
import dgl.function as fn
# assume eweight is a tensor of shape (E, *), where E is the number of edges
graph.edata['a'] = eweight
graph.update_all(fn.u_mul_e('ft', 'a', 'm'), fn.sum('m', 'ft'))
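A self-contained sketch under the same assumption; here random values stand in for model-produced weights such as attention scores:

import dgl
import torch
import dgl.function as fn

graph = dgl.graph(([0, 1, 2], [1, 2, 0]))
graph.ndata['ft'] = torch.randn(3, 4)
eweight = torch.rand(graph.number_of_edges(), 1)   # (E, 1), broadcasts over features
graph.edata['a'] = eweight
graph.update_all(fn.u_mul_e('ft', 'a', 'm'), fn.sum('m', 'ft'))
print(graph.ndata['ft'].shape)   # torch.Size([3, 4])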
1.4 Message Passing on Heterogeneous Graphs
Message passing on a heterogeneous graph consists of two parts: 1) computing and aggregating messages for each relation, and 2) combining, at each node, the aggregation results from the different relations. DGL's interface for this is multi_update_all(), which takes two arguments:
Argument 1: a dict whose keys are relations and whose values are the update_all() arguments for that relation.
Argument 2: a string specifying how the aggregation results from different relations are combined. The following example (the body of a module's forward method) illustrates both:
import dgl.function as fn
# This snippet is the body of a heterograph module's forward method:
# self.weight maps each edge type to a linear layer, and feat_dict maps
# each node type to its input features.
funcs = {}
for c_etype in G.canonical_etypes:
    srctype, etype, dsttype = c_etype
    Wh = self.weight[etype](feat_dict[srctype])
    # store it in the graph for message passing
    G.nodes[srctype].data['Wh_%s' % etype] = Wh
    # specify a (message_func, reduce_func) pair per relation;
    # note that every relation writes to the same destination feature 'h',
    # i.e. aggregation happens per relation type
    funcs[etype] = (fn.copy_u('Wh_%s' % etype, 'm'), fn.mean('m', 'h'))
# sum the per-relation aggregation results
G.multi_update_all(funcs, 'sum')
# return the updated node feature dictionary
return {ntype : G.nodes[ntype].data['h'] for ntype in G.ntypes}
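Stripped of the module context (self.weight, feat_dict), multi_update_all() can be exercised directly. A minimal sketch on a hypothetical two-relation graph:

import dgl
import torch
import dgl.function as fn

hg = dgl.heterograph({
    ('user', 'follows', 'user'): ([0, 1], [1, 2]),
    ('user', 'plays', 'game'):   ([0, 1], [0, 0]),
})
hg.nodes['user'].data['h'] = torch.randn(3, 4)

funcs = {
    'follows': (fn.copy_u('h', 'm'), fn.mean('m', 'h_agg')),
    'plays':   (fn.copy_u('h', 'm'), fn.mean('m', 'h_agg')),
}
hg.multi_update_all(funcs, 'sum')   # 'sum' combines results across relations
print(hg.nodes['user'].data['h_agg'].shape)   # torch.Size([3, 4])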
With this picture of the message passing machinery, the forward method of SAGEConv reads directly:
"""Torch Module for GraphSAGE layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
import torch
from torch import nn
from torch.nn import functional as F
from .... import function as fn
from ....utils import expand_as_pair, check_eq_shape
class SAGEConv(nn.Module):
def __init__(self,
in_feats,
out_feats,
aggregator_type,
feat_drop=0.,
bias=True,
norm=None,
activation=None):
super(SAGEConv, self).__init__()
self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
self._out_feats = out_feats
self._aggre_type = aggregator_type
self.norm = norm
self.feat_drop = nn.Dropout(feat_drop)
self.activation = activation
# aggregator type: mean/pool/lstm/gcn
if aggregator_type == 'pool':
self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
if aggregator_type == 'lstm':
self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
if aggregator_type != 'gcn':
self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
self.reset_parameters()
def reset_parameters(self):
r"""
Description
-----------
Reinitialize learnable parameters.
Note
----
The linear weights :math:`W^{(l)}` are initialized using Glorot uniform initialization.
The LSTM module uses Xavier initialization for its weights.
"""
gain = nn.init.calculate_gain('relu')
if self._aggre_type == 'pool':
nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
if self._aggre_type == 'lstm':
self.lstm.reset_parameters()
if self._aggre_type != 'gcn':
nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)
def _lstm_reducer(self, nodes):
"""LSTM reducer
NOTE(zihao): lstm reducer with default schedule (degree bucketing)
is slow, we could accelerate this with degree padding in the future.
"""
m = nodes.mailbox['m'] # (B, L, D)
batch_size = m.shape[0]
h = (m.new_zeros((1, batch_size, self._in_src_feats)),
m.new_zeros((1, batch_size, self._in_src_feats)))
_, (rst, _) = self.lstm(m, h)
return {'neigh': rst.squeeze(0)}
def forward(self, graph, feat, edge_weight=None):
r"""
Description
-----------
Compute GraphSAGE layer.
Parameters
----------
graph : DGLGraph
The graph.
feat : torch.Tensor or pair of torch.Tensor
If a torch.Tensor is given, it represents the input feature of shape
:math:`(N, D_{in})`
where :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
If a pair of torch.Tensor is given, the pair must contain two tensors of shape
:math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
edge_weight : torch.Tensor, optional
Optional tensor on the edge. If given, the convolution will weight
with regard to the message.
Returns
-------
torch.Tensor
The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
is size of output feature.
"""
with graph.local_scope():
if isinstance(feat, tuple):
feat_src = self.feat_drop(feat[0])
feat_dst = self.feat_drop(feat[1])
else:
feat_src = feat_dst = self.feat_drop(feat)
if graph.is_block:
feat_dst = feat_src[:graph.number_of_dst_nodes()]
aggregate_fn = fn.copy_src('h', 'm')
if edge_weight is not None:
assert edge_weight.shape[0] == graph.number_of_edges()
graph.edata['_edge_weight'] = edge_weight
aggregate_fn = fn.u_mul_e('h', '_edge_weight', 'm')
h_self = feat_dst
# Handle the case of graphs without edges
if graph.number_of_edges() == 0:
graph.dstdata['neigh'] = torch.zeros(
                    feat_dst.shape[0], self._in_src_feats).to(feat_dst)  # .to(feat_dst) matches its dtype and device
if self._aggre_type == 'mean':
graph.srcdata['h'] = feat_src
graph.update_all(aggregate_fn, fn.mean('m', 'neigh'))
h_neigh = graph.dstdata['neigh']
elif self._aggre_type == 'gcn':
check_eq_shape(feat)
graph.srcdata['h'] = feat_src
graph.dstdata['h'] = feat_dst # same as above if homogeneous
graph.update_all(aggregate_fn, fn.sum('m', 'neigh'))
# divide by in-degrees
                degs = graph.in_degrees().to(feat_dst)  # in_degrees() returns a 1-D tensor
                h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)  # unsqueeze(-1) enables broadcasting; +1 counts the node itself
elif self._aggre_type == 'pool':
graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
graph.update_all(aggregate_fn, fn.max('m', 'neigh'))
h_neigh = graph.dstdata['neigh']
elif self._aggre_type == 'lstm':
graph.srcdata['h'] = feat_src
graph.update_all(aggregate_fn, self._lstm_reducer)
h_neigh = graph.dstdata['neigh']
else:
raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))
# GraphSAGE GCN does not require fc_self.
if self._aggre_type == 'gcn':
rst = self.fc_neigh(h_neigh)
else:
rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
# activation
if self.activation is not None:
rst = self.activation(rst)
# normalization
if self.norm is not None:
rst = self.norm(rst)
return rst
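For completeness, a minimal usage sketch of the packaged layer (assuming dgl.nn.SAGEConv, which is how the source file above ships):

import dgl
import torch
from dgl.nn import SAGEConv

g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
feat = torch.randn(4, 10)
conv = SAGEConv(in_feats=10, out_feats=5, aggregator_type='mean')
out = conv(g, feat)   # invokes forward(graph, feat)
print(out.shape)      # torch.Size([4, 5])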
DGL official documentation: User Guide - DGL 0.6.1 documentation
https://docs.dgl.ai/guide/index.html