-
Notifications
You must be signed in to change notification settings - Fork 116
Open
Description
背景介绍
PyTorch的 Dataset class定义
我们可以发现,PyTorch要求Dataset必须提供 __len__ 接口和 __getitem__接口,这就要求 数据集是已知长度的,并且是可以被随机访问的。
这里与TensorFlow不同,TensorFlow的Dataset是可以从一个generator创建的,generator只要求用户实现 __next__接口即可,并不要求 __len__ 接口和 __getitem__ 接口。
因此,我们需要提出一种新的思路。
简单的做法
- worker从master那里拿到一个task
- worker 使用 recordio_reader提供的接口,把该task包含的record都读到内存中
- records是一个即知道长度,又可以随机访问的数组,我们可以从这个数组中创建一个 RecordDataset
- RecordDataset中,每一个record都是string类型的,用户需要提供一个feed函数把string类型转换为数值类型。我们发现,这个feed函数实际上就是 PyTorch中的 Transform
- 最后我们从TransformedDataset中创建一个Dataloader,然后做batch, shuffle等,开始训练
伪代码
while True:
task = get_task()
records = [record for r in reader.read_records(task)]
RecordDataset = create_dataset(records)
TransfromedDataSet = dataset_fn(RecordDataset)
dataloader = DataLoader(TransfromedDataSet, shuffle=true, batch_size=32)
for batch in dataloader:
self.ps_client.pull_dense_parameters()
loss = forward(batch)
loss.backward()
with torch.no_grad():
grads = [param.grad.numpy() for param in model.params()]
self.ps_client.push_gradients(grads)Metadata
Metadata
Assignees
Labels
No labels