600字范文,内容丰富有趣,生活中的好帮手!
600字范文 > PyG创建数据集

PyG创建数据集

时间:2020-01-12 00:58:45

相关推荐

PyG创建数据集

PyG创建数据集

使用自己的数据集,通过PyG封装的库来转变为Pytorch的数据集

文章目录

PyG创建数据集前言一、封装的库二、创建内存数据集1.一个例子三、创建较大的数据集1.例子代码如下:四、总结1.常见的问题:2.尝试思考:

前言

虽然 PyG 已经包含许多有用的数据集,但您可能希望使用自我记录或非公开可用的数据创建自己的数据集。

自己实现数据集很简单,可能只需要查看源代码以了解各种数据集是如何实现的。下面将简要介绍设置您自己的数据集所需的内容。

一、封装的库

PyG为数据集提供了两个抽象类:torch_geometric.data.Dataset 和torch_geometric.data.InMemoryDataset。InMemoryDataset继承自Dataset,如果整个数据集储存在CPU,则应该使用它。

按照 torchvision 约定,每个数据集都存在一个根文件夹,该文件夹指示数据集的存储位置。

我们将根文件夹分成两个文件夹:raw_dir,数据集下载到的位置,以及处理后的数据集保存的位置。

另外,每个数据集都可以传递一个transform、一个pre_transform和一个pre_filter函数,它们默认为None。

transform的功能在访问数据之前动态的转换数据对象(用来数据增强)。pre_transform的功能将数据对象保存到磁盘之前应用的转换(因此它最好用于只需要执行一次大量的预计算)。pre_filter的功能可以在保存之前手动过滤掉数据对象。用例可能涉及限制数据对象属于特定类(过滤筛选)。

二、创建内存数据集

为了创建一个 torch_geometric.data.InMemoryDataset,需要实现四个基本方法:

torch_geometric.data.InMemoryDataset.raw_file_names(): 为了跳过下载,需要找到 raw_dir 中的文件列表。torch_geometric.data.InMemoryDataset.processed_file_names():为了跳过处理,需要找到process_dir中的文件列表。torch_geometric.data.InMemoryDataset.download():将原始数据下载到 raw_dir。torch_geometric.data.InMemoryDataset.process(): 处理原始数据并将其保存到 processes_dir 中。

process()函数是真正起到主体作用的函数。在这里,我们需要读取并创建一个 Data 对象列表并将其保存到 processes_dir 中。因为保存一个巨大的 python 列表相当慢,我们在保存之前通过 torch_geometric.data.InMemoryDataset.collat​​e() 将列表整理成一个巨大的 Data 对象。整理后的数据对象将所有示例连接到一个大数据对象中,此外,还返回一个切片字典以从该对象重构单个示例。最后,我们需要在构造函数中将这两个对象加载到 self.data 和 self.slices 属性中。

1.一个例子

代码如下(示例):

import torchfrom torch_geometric.data import InMemoryDataset, download_urlclass MyOwnDataset(InMemoryDataset):## 继承torch_geometric.data.InMemoryDataset父类def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):super().__init__(root, transform, pre_transform, pre_filter)self.data, self.slices = torch.load(self.processed_paths[0]) ##在构造函数中将这两个对象加载到 self.data 和 self.slices 属性中。@propertydef raw_file_names(self): ##如果raw_file中有文件就会跳过下载return ['some_file_1', 'some_file_2', ...]@property ##如果processed_file就会跳过处理def processed_file_names(self):return ['data.pt']def download(self):# 下载到`self.raw_dir`。download_url(url, self.raw_dir)...def process(self):# 将数据读入巨大的“数据”列表。data_list = [...]if self.pre_filter is not None:data_list = [data for data in data_list if self.pre_filter(data)]if self.pre_transform is not None:data_list = [self.pre_transform(data) for data in data_list]data, slices = self.collate(data_list)torch.save((data, slices), self.processed_paths[0])

三、创建较大的数据集

如果你的内存比较小无法创建内存数据集,可以使用 torch_geometric.data.Dataset,它紧跟 torchvision 数据集的概念。它另外实现以下方法:

torch_geometric.data.Dataset.len():返回数据集中样本的数量。torch_geometric.data.Dataset.get():实现加载单个图形的逻辑。

在内部,torch_geometric.data.Dataset.getitem() 从 torch_geometric.data.Dataset.get() 获取数据对象,并可选择根据变换对其进行变换。

1.例子代码如下:

import os.path as ospimport torchfrom torch_geometric.data import Dataset, download_urlclass MyOwnDataset(Dataset):def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):super().__init__(root, transform, pre_transform, pre_filter)@propertydef raw_file_names(self):return ['some_file_1', 'some_file_2', ...]@propertydef processed_file_names(self):return ['data_1.pt', 'data_2.pt', ...]def download(self):path = download_url(url, self.raw_dir)...def process(self):idx = 0for raw_path in self.raw_paths:data = Data(...)if self.pre_filter is not None and not self.pre_filter(data):continueif self.pre_transform is not None:data = self.pre_transform(data)torch.save(data, osp.join(self.processed_dir, f'data_{idx}.pt'))idx += 1def len(self):return len(self.processed_file_names)def get(self, idx):data = torch.load(osp.join(self.processed_dir, f'data_{idx}.pt'))return data

在这里,每个图形数据对象都单独保存在 process() 中,并在 get() 中手动加载。

四、总结

1.常见的问题:

如何跳过 download() 和/或 process() 的执行?

你可以通过不重写download()和process()方法来跳过下载和/或处理。

## 比如,对比上述代码class MyOwnDataset(Dataset):def __init__(self, transform=None, pre_transform=None):super().__init__(None, transform, pre_transform)

我们真的需要使用这些数据集接口吗?

不!就像在常规 PyTorch 中一样,您不必使用数据集,例如,当您想要动态创建合成数据而不将它们显式保存到磁盘时。在这种情况下,只需传递一个包含 torch_geometric.data.Data 对象的常规 python 列表并将它们传递给 torch_geometric.loader.DataLoader:

from torch_geometric.data import Datafrom torch_geometric.loader import DataLoaderdata_list = [Data(...), ..., Data(...)]loader = DataLoader(data_list, batch_size=32)

2.尝试思考:

class MyDataset(InMemoryDataset):def __init__(self, root, data_list, transform=None):self.data_list = data_listsuper().__init__(root, transform)self.data, self.slices = torch.load(self.processed_paths[0])@propertydef processed_file_names(self):return 'data.pt'def process(self):torch.save(self.collate(self.data_list), self.processed_paths[0])

1.上述代码中self.processed_paths[0]输出的是什么?

2.collat​​e() 有什么作用?(将torch_geometric.data.Data 对象的 Python 列表整理为 InMemoryDataset 的内部存储格式。)

一般还是手动加载数据集,因为内存有限数据集可能非常大。

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。