一.概述
PyG是面向图数据的,它同时支持同构图(homogeneous graphs)和异构图(heterogeneous)。同构图指只包含一种类型的节点和边的图(下图左)。而异构图指包含两种及以上类型的节点和边的图(下图右)。
在PyG中,同构图被描述为torch_geometric.data.Data
类的实例,而异构图被描述为torch_geometric.data.HeteroData
的实例。
本文主要介绍PyG关于同构图的的相关操作,操作环境为:
pytorch = 1.10.1cuda = 11.3torch_geometric = 2.0.4
二.基本图操作
2.1 图的创建
同构图是用Data
类是进行描述的,因此首先查看其初始化函数的参数列表:
def __init__(self, x: OptTensor = None, edge_index: OptTensor = None,edge_attr: OptTensor = None, y: OptTensor = None,pos: OptTensor = None, **kwargs):
对应的参数说明为:
Data
类的初始化函数中参数默认值都为None
,这意味着没有哪个参数是必要的,在实际使用时需要根据待构造图的实际情况来传入相应的属性。
2.2 常用的图属性与方法
在PyG中,对于一个Data
对象其包含众多属性和方法,这里列举一下常用的,更详细的请参见官网Data部分。
2.3 演示示例
首先创建一个包含5个顶点、12条边的无向图。需要注意的是,在edge_index
中边都有有方向的,即从源节点到目标节点。若要创建从节点vvv到节点uuu的无向边,则需要在edge_index
中传入两条相应的边,即(u,v), (v,u)
。
import torchimport torch_geometric.data as datafrom torch_geometric.utils import to_networkximport matplotlib.pyplot as pltimport networkx as nxedge_index = torch.LongTensor([[0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 4, 4],[1, 2, 4, 0, 2, 1, 0, 3, 2, 4, 3, 0]])x = torch.ones(5, 2)g = data.Data(edge_index=edge_index, x=x)print(g)"""Data(edge_index=[2, 12], x=[5, 2])"""# 转换为nextworkx格式的图并可视化g = to_networkx(g)nx.draw(g, with_labels=g.nodes)plt.show()
创建的图可视化结果为:
对上述创建的Data
对象应用2.2节介绍的部分方法实例代码如下:
print(g.num_nodes, g.num_edges)# 5 12print(g.keys)# ['x', 'edge_index']print(g.num_node_features)# 2print(g.is_undirected())# Trueprint(g.has_isolated_nodes())# False
若要将自己创建的图实例保存到本地磁盘或从本地磁盘加载保存的图数据,可以使用torch.save()
和torch.load()
:
torch.save([g], "temp/data.pt")g = torch.load("temp/data.pt")print(g)# [Data(edge_index=[2, 12], x=[5, 2])]
三.进阶图操作
在torch_geometric.utils
模块中包含了许多对图数据的高级操作方法,下面将对其中最常用的方法进行介绍。
3.1 度的计算
通过degree(index, num_nodes=None)
方法可以计算图中节点的度,其中:
index
:edge_index
中的两个维度中任意一个num_nodes
:节点的数量,可选参数
示例代码:
print(degree(g.edge_index[0]))# tensor([3., 2., 3., 2., 2.])print(degree(g.edge_index[1]))# tensor([3., 2., 3., 2., 2.])
3.2 自环的添加与删除
自环指节点指向自身的边。在utils
中处理自环的方法包括:
contains_self_loops(edge_index)
:判断图中节点是否包含自环。remove_self_loops(edge_index)
:删除图中所有的自环。add_self_loops(edge_index)
:为图中的节点添加自环,对于有自环的节点,它会再为该节点添加一个自环。add_remaining_self_loops
:为图中还没有自环的节点添加自环。
示例代码:
print(contains_self_loops(g.edge_index))# Falseedge_index, _ = add_self_loops(g.edge_index)print(edge_index)"""tensor([[0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 4, 4, 0, 1, 2, 3, 4],[1, 2, 4, 0, 2, 1, 0, 3, 2, 4, 3, 0, 0, 1, 2, 3, 4]])"""edge_index, _ = add_remaining_self_loops(edge_index)print(edge_index)"""没有添加新的自环tensor([[0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 4, 4, 0, 1, 2, 3, 4],[1, 2, 4, 0, 2, 1, 0, 3, 2, 4, 3, 0, 0, 1, 2, 3, 4]])"""edge_index, _ = remove_self_loops(edge_index)print(edge_index)"""tensor([[0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 4, 4],[1, 2, 4, 0, 2, 1, 0, 3, 2, 4, 3, 0]])"""
3.2 子图提取
utils
中提供了若干方法用来在图中提取子图。
subgraph(subset, edge_index)
:根据给定的图节点集合subset
来抽取图中包含这些节点的子图。k_hop_subgraph(node_idx, num_hops, edge_index)
:提取给定节点集node_idx
能经过num_hops
跳到达的所有节点组成的子图(包括node_idx
本身)。
sub_graph
方法示例代码:
def draw(edge_index):graph = data.Data(edge_index=edge_index)graph = to_networkx(graph)print(graph.nodes)nx.draw(graph, with_labels=graph.nodes)plt.show()edge_index, _ = subgraph(subset=torch.LongTensor([0, 1, 2]), edge_index=g.edge_index)draw(edge_index)
提取的子图可视化如下所示:
k_hop_subgraph
方法的示例代码如下所示:
g = k_hop_subgraph(node_idx=[0], num_hops=1, edge_index=g.edge_index)print(g)"""(tensor([0, 1, 2, 4]), tensor([[0, 0, 0, 1, 1, 2, 2, 4],[1, 2, 4, 0, 2, 1, 0, 0]]), tensor([0]), tensor([ True, True, True, True, True, True, True, False, False, False,False, True]))"""
从上图可以看出,该方法返回一个4元组,元组的4个元素依次为:子图的节点集、子图的边集、用来查询的节点集(中心节点集)、指示原始图g
中的边是否在子图中的布尔数组。我们取子图的边集进行可视化结果如下:
3.4 转换为无向图
通过to_undirected(edge_index)
可以将一个图转换为无向图:
edge_index = torch.LongTensor([[0, 0], [1, 2]])edge_index = to_undirected(edge_index)print(edge_index)"""tensor([[0, 0, 1, 2],[1, 2, 0, 0]])"""
结语
参考资料:
torch_geometric.datatorch_geometric.utils
本文主要介绍了PyG中对单个图的相关操作方法,从上面的操作可以看出对于PyG对图结构的操作其实就是在操作edge_index
(该属性本来就用来在PyG中保存图的结构信息)。