600字范文,内容丰富有趣,生活中的好帮手!
600字范文 > PyG教程(4):自定义数据集

PyG教程(4):自定义数据集

时间:2020-02-08 07:02:55

相关推荐

PyG教程(4):自定义数据集

一.前言

在PyG中,除了直接使用它自带的benchmark数据集外,用户还可以自定义数据集,其方式与Pytorch类似,需要继承数据集类。PyG中提供了两个数据集抽象类:

torch_geometric.data.Dataset:用于构建大型数据集(非内存数据集);torch_geometric.data.InMemoryDataset:用于构建内存数据集(小数据集),继承自Dataset

下面是对其的详细介绍。

二.内存数据集

2.1 创建说明

在PyG中要构建自己的内存数据集需要先继承InMemoryDataset类,并实现如下方法:

raw_file_names():返回原始数据集的文件名列表,若self.raw_dir中没有该列表中的文件,则会通过download()进行下载;processed_file_names():返回process()方法处理后的文件名列表,若self.processed_dir中没有确实该列表中的文件,则需要通过process()方法进行处理;download():下载原始数据集到self.raw_dir中;process():处理原始数据集,并保存到processed_dir中。

在前两个方法中,若只有单个文件,则直接返回文件字符串即可,不一定要返回list对象。

另外,上面的self.raw_dirself.processed_dir其实是两个方法,其源码为:

# 加上@property,可以使得方法像属性一样被调用@propertydef raw_dir(self) -> str:return osp.join(self.root, 'raw')@propertydef processed_dir(self) -> str:return osp.join(self.root, 'processed')

从源码可以看出,self.raw_dirself.processed_dir是给定保存路径root下的原始数据文件夹和处理后的数据文件夹的路径。

2.2 创建演示

本文以SNAP数据集中的一个社交网络Facebook为例,来演示如何创建一个InMemoryDataset数据集FaceBook,该数据集包含4039个节点、88234条边。利用Gephi对该网络进行可视化如下:

根据3.1节中的说明,下面是自定义FaceBook类的源码:

import osimport pandas as pdimport torchfrom torch_geometric.data import Datafrom torch_geometric.data import InMemoryDataset, download_url, extract_gzclass FaceBook(InMemoryDataset):url = "https://snap.stanford.edu/data/facebook_combined.txt.gz"def __init__(self,root,transform=None,pre_transform=None,pre_filter=None):super().__init__(root, transform, pre_transform, pre_filter)self.data, self.slices = torch.load(self.processed_paths[0])@propertydef raw_file_names(self):return ["facebook_combined.txt"]@propertydef processed_file_names(self):return "data.pt"def download(self):path = download_url(self.url, self.raw_dir)extract_gz(path, self.raw_dir)def process(self):# 加载原始数据文件path = os.path.join(self.raw_dir, "facebook_combined.txt")edges = pd.read_csv(path, header=None,delimiter=" ").values.reshape(2, -1)# 构建Data对象edge_index = torch.from_numpy(edges)g = Data(edge_index=edge_index, num_nodes=4039)data, slices = self.collate([g])torch.save((data, slices), self.processed_paths[0])if __name__ == "__main__":dataset = FaceBook(root="tmp")data = dataset[0]print(data.num_edges, data.num_nodes)# 88234 4039

需要注意的是

downloadprocess只在第一次调用时会调用,之后会直接加载处理好的数据集。以上4个方法并不都是需要的,例如如果你本地已经有了数据集,就不需要重写download()函数来下载原始数据集。

三.大型数据集

对于大型图数据集,需要继承Dataset类,除了InMemoryDataset中需要重写的4个方法外,还需重写如下方法:

len(): 返回数据集中实例的数量;get():加载单个图的逻辑。

由于自定义大型数据集与InMemoryDataset类似,具体演示略。

四.结语

参考资料:

creating your own datasetstorch_geometric.data

自定义数据集是一项重要的事情,尤其是当你本地有些数据需要转换为PyG中标准的图数据集的时候。

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。