600字范文,内容丰富有趣,生活中的好帮手!
600字范文 > 【PyG 教程】PyG 自定义构造 GNN

【PyG 教程】PyG 自定义构造 GNN

时间:2018-10-09 03:37:46

相关推荐

【PyG 教程】PyG 自定义构造 GNN

文章作者:梦家

个人站点:dreamhomes.top

原文地址:https://dreamhomes.github.io/posts/1115.html

公众号ID:DreamHub

基于 PyG 构造消息传递网络

图上的卷积操作主要包含两部分:节点消息传递与消息聚集。假设 xi(k−1)∈RF\mathbf{x}_i^{(k-1)} \in \mathbb{R}^{F}xi(k−1)​∈RF 表示k−1k-1k−1层节点的特征,ej,i∈RD\mathbf{e}_{j, i} \in \mathbb{R}^{D}ej,i​∈RD表示节点jjj到节点iii 的边的特征。那么消息传递的图神经网络可以表示为:

xi(k)=γ(k)(xi(k−1),□j∈N(i)ϕ(k)(xi(k−1),xj(k−1),ej,i))\mathbf{x}_{i}^{(k)}=\gamma^{(k)}\left(\mathbf{x}_{i}^{(k-1)}, \square_{j \in \mathcal{N}(i)} \phi^{(k)}\left(\mathbf{x}_{i}^{(k-1)}, \mathbf{x}_{j}^{(k-1)}, \mathbf{e}_{j, i}\right)\right)xi(k)​=γ(k)(xi(k−1)​,□j∈N(i)​ϕ(k)(xi(k−1)​,xj(k−1)​,ej,i​))

其中□\square□ 表示可微分的排列不变函数,e.g. sum,mean,max。λ\lambdaλ和γ\gammaγ 表示可微分的函数,e.g. MLPs。

PyG 中torch_geometric.nn.MessagePassing提供一系列的消息传递方法来自动处理消息传播过程。

接下来以构造经典的kipf提出的GCNGCNGCN为例。

实现 GCN 层

GCN层的数学定义如下:

xi(k)=∑j∈N(i)∪{i}1deg⁡(i)⋅deg⁡(j)⋅(Θ⋅xj(k−1))\mathbf{x}_{i}^{(k)}=\sum_{j \in \mathcal{N}(i) \cup\{i\}} \frac{1}{\sqrt{\operatorname{deg}(i)} \cdot \sqrt{\operatorname{deg}(j)}} \cdot\left(\boldsymbol{\Theta} \cdot \mathbf{x}_{j}^{(k-1)}\right)xi(k)​=j∈N(i)∪{i}∑​deg(i)​⋅deg(j)​1​⋅(Θ⋅xj(k−1)​)

由上式可知节点特征首先经过Θ\ThetaΘ 进行特征变换,然后根据度进行归一化然后求和。计算公式可以拆分为以下几步:

邻接矩阵 AAA 添加自环。节点特征矩阵的线性变换。计算归一化系数。归一化节点特征。ϕ\phiϕ邻居节点求和。(add聚集)得到最后的节点嵌入向量。γ\gammaγ

以上过程基于 PyG 的实现如下:

import torchfrom torch_geometric.nn import MessagePassingfrom torch_geometric.utils import add_self_loops, degreeclass GCNConv(MessagePassing):def __init__(self, in_channels, out_channels):super(GCNConv, self).__init__(aggr='add') # "Add" aggregation.self.lin = torch.nn.Linear(in_channels, out_channels)def forward(self, x, edge_index):# x has shape [N, in_channels]# edge_index has shape [2, E]# Step 1: Add self-loops to the adjacency matrix.edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))# Step 2: Linearly transform node feature matrix.x = self.lin(x)# Step 3: Compute normalizationrow, col = edge_indexdeg = degree(row, x.size(0), dtype=x.dtype)deg_inv_sqrt = deg.pow(-0.5)norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]# Step 4-6: Start propagating messages.return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x,norm=norm)def message(self, x_j, norm):# x_j has shape [E, out_channels]# Step 4: Normalize node features.return norm.view(-1, 1) * x_jdef update(self, aggr_out):# aggr_out has shape [N, out_channels]# Step 6: Return new node embeddings.return aggr_out

定义好卷积层后即可调用卷积层进行堆叠:

conv = GCNConv(16, 32)x = conv(x, edge_index)

实现边卷积

这种方式个人用得较少,简单记录下。对于点云数据的卷积定义为:

xi(k)=max⁡j∈N(i)hΘ(xi(k−1),xj(k−1)−xi(k−1))\mathbf{x}_{i}^{(k)}=\max _{j \in \mathcal{N}(i)} h_{\Theta}\left(\mathbf{x}_{i}^{(k-1)}, \mathbf{x}_{j}^{(k-1)}-\mathbf{x}_{i}^{(k-1)}\right)xi(k)​=j∈N(i)max​hΘ​(xi(k−1)​,xj(k−1)​−xi(k−1)​)

基于 PyG 实现方式如下:

import torchfrom torch.nn import Sequential as Seq, Linear, ReLUfrom torch_geometric.nn import MessagePassingclass EdgeConv(MessagePassing):def __init__(self, in_channels, out_channels):super(EdgeConv, self).__init__(aggr='max') # "Max" aggregation.self.mlp = Seq(Linear(2 * in_channels, out_channels),ReLU(),Linear(out_channels, out_channels))def forward(self, x, edge_index):# x has shape [N, in_channels]# edge_index has shape [2, E]return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)def message(self, x_i, x_j):# x_i has shape [E, in_channels]# x_j has shape [E, in_channels]tmp = torch.cat([x_i, x_j - x_i], dim=1) # tmp has shape [E, 2 * in_channels]return self.mlp(tmp)def update(self, aggr_out):# aggr_out has shape [N, out_channels]return aggr_out

from torch_geometric.nn import knn_graphclass DynamicEdgeConv(EdgeConv):def __init__(self, in_channels, out_channels, k=6):super(DynamicEdgeConv, self).__init__(in_channels, out_channels)self.k = kdef forward(self, x, batch=None):edge_index = knn_graph(x, self.k, batch, loop=False, flow=self.flow)return super(DynamicEdgeConv, self).forward(x, edge_index)

conv = DynamicEdgeConv(3, 128, k=6)x = conv(pos, batch)

了解以上内容即可知道如何自定义 GNN 计算方式。

更多内容参考官网教程:https://pytorch-geometric.readthedocs.io/en/latest/index.html

联系作者

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。