Skip to content

How to use PyTorch dataloader #2222

@QiJune

Description

@QiJune

背景介绍

PyTorch的 Dataset class定义

我们可以发现,PyTorch要求Dataset必须提供 __len__ 接口和 __getitem__接口,这就要求 数据集是已知长度的,并且是可以被随机访问的。

这里与TensorFlow不同,TensorFlow的Dataset是可以从一个generator创建的,generator只要求用户实现 __next__接口即可,并不要求 __len__ 接口和 __getitem__ 接口。

因此,我们需要提出一种新的思路。

简单的做法

  1. worker从master那里拿到一个task
  2. worker 使用 recordio_reader提供的接口,把该task包含的record都读到内存中
  3. records是一个即知道长度,又可以随机访问的数组,我们可以从这个数组中创建一个 RecordDataset
  4. RecordDataset中,每一个record都是string类型的,用户需要提供一个feed函数把string类型转换为数值类型。我们发现,这个feed函数实际上就是 PyTorch中的 Transform
  5. 最后我们从TransformedDataset中创建一个Dataloader,然后做batch, shuffle等,开始训练

伪代码

while True:
    task = get_task()
    records = [record for r in reader.read_records(task)]
    RecordDataset = create_dataset(records)
    TransfromedDataSet = dataset_fn(RecordDataset)
    dataloader = DataLoader(TransfromedDataSet, shuffle=true, batch_size=32)
    for batch in dataloader:
        self.ps_client.pull_dense_parameters()
        loss = forward(batch)
        loss.backward()
        with torch.no_grad():
            grads = [param.grad.numpy() for param in model.params()]
            self.ps_client.push_gradients(grads)

Metadata

Metadata

Labels

No labels
No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions