
Conversation

@NeoLegends NeoLegends commented Jan 16, 2025

Closes #1634

@NeoLegends NeoLegends self-assigned this Jan 16, 2025
@NeoLegends NeoLegends requested review from a team and albertz as code owners January 16, 2025 13:30
@NeoLegends NeoLegends changed the title from "Dataset: implement global sharding option" to "Dataset: implement global dataset_distribution option" Jan 16, 2025
@albertz

This comment was marked as outdated.

@albertz albertz marked this pull request as draft January 19, 2025 19:28
@NeoLegends
Member Author

@albertz Do you think this needs a test around the config processing?

"""
from returnn.config import get_global_config

config = get_global_config(raise_exception=False)
Member

I don't like that we access the global config here. I know this follows similar code as _get_default_random_seed_offset but I also don't like it there. Why is this needed? This should come from outside, or not? Specifically at the place where we call init_dataset. E.g. in the __main__. There we also call Dataset.kwargs_update_from_config.

Also, the code is wrong. Distributed training is only one possible source which defines/influences the shard index and num shards. But there are other reasons, for example the MultiProcDataset, or PyTorch DataLoader num_workers.

@NeoLegends NeoLegends (Member Author) Feb 10, 2025

MultiProcDataset

This is already setting the num_shards and shard index for its children. The code was always designed such that it would only look at the global config if no value was already set. But I agree, now it's better.

PyTorch DataLoader num_workers

I think this is actually not that trivial to implement, because the torch engine is already given initialized datasets and it's difficult to change the sharding config after a dataset has been initialized. So factoring in the PyTorch num_workers would need to happen during dataset initialization, which mixes PyTorch code with data initialization code a bit, and I feel that is going to be a bit messy. Do you know a good way? Maybe this is fine after all.

Member Author

I think we cannot achieve both

But e.g. there could be a test with PyTorch DataLoader num_workers=2 which checks that all data from the dataset was properly covered.

and

I know this follows similar code as _get_default_random_seed_offset but I also don't like it there. Why is this needed? This should come from outside, or not?

because of

I think this is actually not that trivial to implement because the torch engine already is given initialized datasets and it's difficult to change the sharding config after having initialized a dataset

However, I think it's worth it to have proper support for torch_dataloader_opts = {"num_workers": n} with n > 1, because this makes it much simpler for the end user to have multi-process data loading, and this feature can replace MultiProcDataset for simple use cases. So I think I need to revert the changes where num_shards and shard_index are set from the outside, and rather fetch them from inside the dataset when they are needed. At that point the torch worker info is also available, which means we can take it into account properly.
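Roughly what I have in mind, as a sketch (torch.utils.data.get_worker_info is the actual PyTorch API; the helper name and how it plugs into the dataset are just illustrative):

import torch.utils.data

def _effective_sharding(base_shard_index: int = 0, base_num_shards: int = 1):
    # Combine whatever sharding is already set (e.g. from distributed training or a
    # parent MultiProcDataset) with the DataLoader worker info, lazily at iteration time.
    worker_info = torch.utils.data.get_worker_info()  # None outside of DataLoader worker processes
    if worker_info is None:
        return base_shard_index, base_num_shards
    return (
        base_shard_index * worker_info.num_workers + worker_info.id,
        base_num_shards * worker_info.num_workers,
    )

Then torch_dataloader_opts = {"num_workers": 2} in the config should just work, without needing MultiProcDataset for the simple cases.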

Member

I think this is actually not that trivial to implement because the torch engine already is given initialized datasets and it's difficult to change the sharding config after having initialized a dataset.

Why difficult? Maybe we just need a clean dataset API for that, some setter function set_num_shards_and_shard_idx or so. And then in ReturnnDatasetIterDataPipe.reset or so we just need to call that.
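As a sketch of what I mean (just the shape of the API; the exact name and any extra checks are up for discussion):

def set_num_shards_and_shard_idx(self, num_shards: int, shard_index: int):
    """Update the sharding of an already-constructed dataset, e.g. called from the
    PyTorch data pipeline once the DataLoader worker info is known."""
    assert 0 <= shard_index < num_shards
    self.num_shards = num_shards
    self.shard_index = shard_index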

Member Author

Hm, I was originally not a fan of the mutability of these properties, but it seems ok now.

Member

I think we cannot avoid such API like set_num_shards_and_shard_idx because of how the PyTorch data pipeline works.

@albertz albertz commented Feb 6, 2025

@albertz Do you think this needs a test around the config processing?

I'm not exactly sure what you mean by that.

But e.g. there could be a test with PyTorch DataLoader num_workers=2 which checks that all data from the dataset was properly covered.

@NeoLegends NeoLegends marked this pull request as ready for review March 5, 2025 10:41
@NeoLegends NeoLegends force-pushed the moritz-shard-mgpu branch from 6e5f6d4 to 1c96291 May 13, 2025 09:22
@NeoLegends NeoLegends force-pushed the moritz-shard-mgpu branch from 1c96291 to 3b28635 May 13, 2025 09:22
@Icemole Icemole (Collaborator) left a comment

Just a few minor comments from my side. The functionality looks good!

Comment on lines +178 to +180
assert 0 <= shard_index < num_shards
self.num_shards = num_shards
self.shard_index = shard_index
Collaborator

Suggested change
assert 0 <= shard_index < num_shards
self.num_shards = num_shards
self.shard_index = shard_index
self.set_shard_idx_and_num_shards(shard_index, num_shards)

Slightly cleaner and reuses code, but it's also fine as is. Your choice.

Member Author

That function has additional asserts that rely on correct initialization. Not for now.

Member

That function has additional asserts that rely on correct initialization. Not for now.

I don't understand the comment. What correct initialization? What do you mean by "not for now"?


Comment on lines -147 to +155
sub_dataset = {**self.dataset, "_num_shards": self.num_workers, "_shard_index": i}
sub_dataset = {
    **self.dataset,
    "num_shards": self.num_workers * self.num_shards,
    "shard_index": (self.shard_index * self.num_workers) + i,
}
Collaborator

I might not know the context of this comment. Isn't sharding already allowed as per your change here?

@albertz

This comment was marked as resolved.

@NeoLegends

This comment was marked as resolved.

@albertz albertz commented Jul 17, 2025

Sorry for introducing the small conflict, but my change should fix #1678 and #1737 already, and shouldn't really cause any issues when merging with this PR.

@albertz albertz commented Jul 17, 2025

Btw, also see #1738. Not sure if this is relevant here.

@albertz albertz commented Jul 17, 2025

Can you summarize what this PR does? I will also try to write some summaries here myself, but please edit your main description of the PR to cover that as well.

@albertz albertz commented Jul 17, 2025

(Summary) Added feature: when torch.utils.data.DataLoader is used with num_workers>1, this will set the sharding accordingly. (This is independent of the newly introduced global dataset_distribution option.)

Btw, some questions regarding this:

Just to confirm: this is independent of the newly introduced global dataset_distribution option?

What happens when this is used together with distributed training? Will it set num_shards = distrib_world_size * dataloader_num_workers then?

Is the order of seqs you get from the DataLoader deterministic?

Will it always be complete? E.g. if one worker returns more seqs than the other (e.g. total num seqs is 11, and 2 workers), will the DataLoader only finish once all the workers have finished?

def test_dataset_sharding():
    from returnn.datasets.audio import OggZipDataset

    with create_ogg_zip_txt_only_dataset_mult_seqs_opts(num_seqs=10) as dataset_opts:
Member

I think the test would be a bit nicer if the num_seqs is uneven, not divisible by num_shards.

Suggested change
    with create_ogg_zip_txt_only_dataset_mult_seqs_opts(num_seqs=10) as dataset_opts:
    with create_ogg_zip_txt_only_dataset_mult_seqs_opts(num_seqs=11) as dataset_opts:
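Sketch of the kind of coverage check I have in mind (dataset API calls written from memory, assuming dataset_opts is a plain options dict plus the shard_index/num_shards kwargs from this PR, so details may differ):

def _collect_seq_tags(dataset_opts, shard_index=0, num_shards=1):
    from returnn.datasets.basic import init_dataset

    dataset = init_dataset({**dataset_opts, "shard_index": shard_index, "num_shards": num_shards})
    dataset.init_seq_order(epoch=1)
    tags = []
    seq_idx = 0
    while dataset.is_less_than_num_seqs(seq_idx):
        dataset.load_seqs(seq_idx, seq_idx + 1)
        tags.append(dataset.get_tag(seq_idx))
        seq_idx += 1
    return tags

def _check_shards_cover_all_seqs(dataset_opts, num_shards=2):
    full = _collect_seq_tags(dataset_opts)
    per_shard = [_collect_seq_tags(dataset_opts, i, num_shards) for i in range(num_shards)]
    # with num_seqs=11 and 2 shards, the shards get 6 and 5 seqs, but together
    # they should cover every seq exactly once
    assert sorted(tag for shard in per_shard for tag in shard) == sorted(full)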

Comment on lines +55 to +57
self.dataset.set_shard_idx_and_num_shards(
    self.dataset.shard_index + worker_info.id, self.dataset.num_shards * worker_info.num_workers
)
Member

I don't understand. Why does this consider the existing shard_index/num_shards? I would expect that you overwrite those here.

Suggested change
self.dataset.set_shard_idx_and_num_shards(
    self.dataset.shard_index + worker_info.id, self.dataset.num_shards * worker_info.num_workers
)
self.dataset.set_shard_idx_and_num_shards(worker_info.id, worker_info.num_workers)

Member

Or do you expect that the existing shard_index/num_shards were set with the distributed rank/size, and this here additionally adds further sharding for the worker id/num_workers?

But then this code written in this way is very confusing... I think this should be done differently somehow. Not sure how...

Member

We could also do the distributed logic directly here. We pass on the rank/size info anyway to the subproc children via the env var _RETURNN_TORCH_DISTRIBUTED_INIT_INFO.

The question is how to handle dataset_distribution. We could also just check the global config here at this point (even though I don't like accessing the global config too much... maybe I'll get a better idea). Then all the logic about what sharding options to set (or whether to set them at all) would be here in one place, together with the dataloader num workers.
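Roughly, as a sketch (how exactly we get the distributed rank/size at this point is left open here):

import torch.utils.data
from returnn.config import get_global_config

def _wanted_sharding(distrib_rank: int = 0, distrib_size: int = 1):
    # Decide shard_index/num_shards for this data-loading process in one place.
    # distrib_rank/distrib_size come from the distributed setup (0/1 without distributed training).
    config = get_global_config(raise_exception=False)
    if config is not None and config.value("dataset_distribution", "random_seed_offset") == "shard":
        shard_index, num_shards = distrib_rank, distrib_size
    else:
        shard_index, num_shards = 0, 1
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is not None:  # we are inside a DataLoader worker process
        shard_index = shard_index * worker_info.num_workers + worker_info.id
        num_shards *= worker_info.num_workers
    return shard_index, num_shards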

Comment on lines +120 to +121
num_shards: int = 1,
shard_index: int = 0,
Member

I don't understand: Why are those public (not prefixed with _)? These are never supposed to be set by the user. Those are either internally set via DataLoader num_workers, or via parent dataset logic like MultiProcDataset, or via distributed training logic somehow, or so. But never directly by the user.

        state = {attr: getattr(self, attr) for attr in ["epoch", "zpad"]}
        return Dataset._create_from_reduce, (self.__class__, kwargs, state)

    def set_shard_idx_and_num_shards(self, shard_index: int, num_shards: int):
Member

Small inconsistency: idx vs index.

assert dd_cfg in ["random_seed_offset", "shard"]
shard_index, num_shards = Dataset._get_sharding_rank_and_size(config) if dd_cfg == "shard" else (0, 1)
set_or_remove("num_shards", num_shards)
set_or_remove("shard_index", shard_index)
Member

I'm not sure whether it is a good idea to do this logic in kwargs_update_from_config. You don't want to apply this just for every dataset. E.g. for the dev/test/eval datasets, we would not want this logic, at least not right now.

I think this should come from the outside, not from within.
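For illustration, the "from the outside" variant could look roughly like this at the init_dataset call site (a sketch only, assuming init_dataset accepts extra kwargs to merge into the user-specified dataset dict, and that we only do this for the train dataset):

from returnn.datasets.basic import init_dataset

def init_train_dataset_from_config(config, distrib_rank: int = 0, distrib_size: int = 1):
    # The caller (__main__/engine) decides the sharding and passes it in,
    # instead of the dataset reading the global config itself.
    extra = {}
    if config.value("dataset_distribution", "random_seed_offset") == "shard":
        extra = {"shard_index": distrib_rank, "num_shards": distrib_size}
    return init_dataset(config.typed_value("train"), extra_kwargs=extra)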

# RETURNN will set sharding info on the dataset if the global config is set.
# If it's not set, however, we need to respect the existing `distrib_shard_files` property
# for backwards compatibility and load the sharding info ourselves.
return CachedDataset2._get_sharding_rank_and_size(config)
Member

Suggested change
return CachedDataset2._get_sharding_rank_and_size(config)
return cls._get_sharding_rank_and_size(config)

?
Or:

Suggested change
return CachedDataset2._get_sharding_rank_and_size(config)
return Dataset._get_sharding_rank_and_size(config)

?
But CachedDataset2 doesn't really make sense?

@albertz albertz commented Jul 17, 2025

Regarding _get_random_seed_for_epoch: shouldn't it also consider num_shards/shard_index? Or only in the case of dataset_distribution == "shard"?

Or not, because random_seed_offset already covers this part? (But I find it a bit inconsistent that epoch/partition_epoch is handled here but shard_index/num_shards elsewhere...)

@albertz albertz commented Jul 17, 2025

(Summary) New global config option dataset_distribution, which can be either set to "random_seed_offset" (default) or "shard". This is for distributed training. "shard" will enable sharding for the dataset, so on N GPUs, processing one full epoch will only go through the data once, unlike with "random_seed_offset", where one full epoch sees all the data N times (each worker with different random seed).
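In config terms, using this would look something like the following (illustrative sketch; torch_distributed and torch_dataloader_opts are only shown as a typical setup around it):

torch_distributed = {}  # enable multi-GPU distributed training
dataset_distribution = "shard"  # each GPU goes through a disjoint shard of the data per epoch
torch_dataloader_opts = {"num_workers": 2}  # optional multi-process loading; worker sharding is set up automatically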

)
self.dataset.set_shard_idx_and_num_shards(
    self.dataset.shard_index + worker_info.id, self.dataset.num_shards * worker_info.num_workers
)
Member

Instead of having all this logic here, maybe we should rather move it to ReturnnDatasetIterDataPipe.reset?

Comment on lines +175 to 179
log.print_deprecation_warning(
    f"{self.__class__.__name__}' `distrib_shard_files` config option is set. "
    "Use global config option `dataset_distribution` instead "
    "for the same behavior across more types of datasets."
)
Member

I'm not sure it's really necessary to mark this as deprecated.

Suggested change
log.print_deprecation_warning(
    f"{self.__class__.__name__}' `distrib_shard_files` config option is set. "
    "Use global config option `dataset_distribution` instead "
    "for the same behavior across more types of datasets."
)

Comment on lines +155 to +157
:param distrib_shard_files: deprecated. Replaced by global config option ``dataset_distribution="shard"``.
    Set to true to shard the data across worker processes in distributed training scenarios.
Member

I don't think it's necessary to mark this as deprecated.

Suggested change
:param distrib_shard_files: deprecated. Replaced by global config option ``dataset_distribution="shard"``.
    Set to true to shard the data across worker processes in distributed training scenarios.
:param distrib_shard_files: deprecated. set to true to shard the data across worker processes in
    distributed training scenarios
