一.前言
在PyG中,除了直接使用它自带的benchmark数据集外,用户还可以自定义数据集,其方式与Pytorch类似,需要继承数据集类。PyG中提供了两个数据集抽象类:
torch_geometric.data.Dataset
:用于构建大型数据集(非内存数据集);torch_geometric.data.InMemoryDataset
:用于构建内存数据集(小数据集),继承自Dataset
。
下面是对其的详细介绍。
二.内存数据集
2.1 创建说明
在PyG中要构建自己的内存数据集需要先继承InMemoryDataset
类,并实现如下方法:
raw_file_names()
:返回原始数据集的文件名列表,若self.raw_dir
中没有该列表中的文件,则会通过download()
进行下载;processed_file_names()
:返回process()
方法处理后的文件名列表,若self.processed_dir
中没有确实该列表中的文件,则需要通过process()
方法进行处理;download()
:下载原始数据集到self.raw_dir
中;process()
:处理原始数据集,并保存到processed_dir
中。
在前两个方法中,若只有单个文件,则直接返回文件字符串即可,不一定要返回list对象。
另外,上面的self.raw_dir
和self.processed_dir
其实是两个方法,其源码为:
# 加上@property,可以使得方法像属性一样被调用@propertydef raw_dir(self) -> str:return osp.join(self.root, 'raw')@propertydef processed_dir(self) -> str:return osp.join(self.root, 'processed')
从源码可以看出,self.raw_dir
和self.processed_dir
是给定保存路径root
下的原始数据文件夹和处理后的数据文件夹的路径。
2.2 创建演示
本文以SNAP数据集中的一个社交网络Facebook为例,来演示如何创建一个InMemoryDataset
数据集FaceBook
,该数据集包含4039个节点、88234条边。利用Gephi对该网络进行可视化如下:
根据3.1节中的说明,下面是自定义FaceBook
类的源码:
import osimport pandas as pdimport torchfrom torch_geometric.data import Datafrom torch_geometric.data import InMemoryDataset, download_url, extract_gzclass FaceBook(InMemoryDataset):url = "https://snap.stanford.edu/data/facebook_combined.txt.gz"def __init__(self,root,transform=None,pre_transform=None,pre_filter=None):super().__init__(root, transform, pre_transform, pre_filter)self.data, self.slices = torch.load(self.processed_paths[0])@propertydef raw_file_names(self):return ["facebook_combined.txt"]@propertydef processed_file_names(self):return "data.pt"def download(self):path = download_url(self.url, self.raw_dir)extract_gz(path, self.raw_dir)def process(self):# 加载原始数据文件path = os.path.join(self.raw_dir, "facebook_combined.txt")edges = pd.read_csv(path, header=None,delimiter=" ").values.reshape(2, -1)# 构建Data对象edge_index = torch.from_numpy(edges)g = Data(edge_index=edge_index, num_nodes=4039)data, slices = self.collate([g])torch.save((data, slices), self.processed_paths[0])if __name__ == "__main__":dataset = FaceBook(root="tmp")data = dataset[0]print(data.num_edges, data.num_nodes)# 88234 4039
需要注意的是
download
和process
只在第一次调用时会调用,之后会直接加载处理好的数据集。以上4个方法并不都是需要的,例如如果你本地已经有了数据集,就不需要重写download()
函数来下载原始数据集。
三.大型数据集
对于大型图数据集,需要继承Dataset
类,除了InMemoryDataset
中需要重写的4个方法外,还需重写如下方法:
len()
: 返回数据集中实例的数量;get()
:加载单个图的逻辑。
由于自定义大型数据集与InMemoryDataset
类似,具体演示略。
四.结语
参考资料:
creating your own datasetstorch_geometric.data
自定义数据集是一项重要的事情,尤其是当你本地有些数据需要转换为PyG中标准的图数据集的时候。