Dataset: implement global dataset_distribution option #1676
base: master
Conversation
Force-pushed from 9a10fef to 924045e.
@albertz Do you think this needs a test around the config processing?
returnn/datasets/basic.py (Outdated)

"""
from returnn.config import get_global_config

config = get_global_config(raise_exception=False)
I don't like that we access the global config here. I know this follows similar code as _get_default_random_seed_offset, but I also don't like it there. Why is this needed? This should come from outside, or not? Specifically at the place where we call init_dataset, e.g. in the __main__. There we also call Dataset.kwargs_update_from_config.
Also, the code is wrong. Distributed training is only one possible source which defines/influences the shard index and num shards. But there are other sources, for example MultiProcDataset, or PyTorch DataLoader num_workers.
> MultiProcDataset

This is already setting the num_shards and shard index for its children. The code was always designed in such a way that it would only look at the global config in case there was no value already set. But I agree, now it's better.

> PyTorch DataLoader num_workers

I think this is actually not that trivial to implement, because the torch engine is already given initialized datasets and it's difficult to change the sharding config after having initialized a dataset. So factoring in the PyTorch num_workers needs to be done during the initial dataset initialization, which mixes PyTorch code with data initialization code a bit, and I feel that is going to be a bit messy. Do you know a good way? Maybe this is fine after all.
I think we cannot achieve both

> But e.g. there could be a test with PyTorch DataLoader num_workers=2 which checks that all data from the dataset was properly covered.

and

> I know this follows similar code as _get_default_random_seed_offset but I also don't like it there. Why is this needed? This should come from outside, or not?

because of

> I think this is actually not that trivial to implement because the torch engine already is given initialized datasets and it's difficult to change the sharding config after having initialized a dataset.

However, I think it's worth it to have proper support for torch_dataloader_opts = {"num_workers": n} with n > 1, because this makes it much simpler for the end user to have multi-process data loading, and this feature can replace MultiProcDataset for simple use cases. So I think I need to revert the changes where num_shards and shard_index are set from the outside, and instead fetch them from inside the dataset when they are needed. At that point the torch worker info is also available, which means that we can take it into account properly.
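To illustrate that idea (this is only a minimal sketch, not the code in this PR; the helper name is hypothetical): the dataset could derive its effective shard lazily, combining an externally set base sharding with the PyTorch worker info that is only available at iteration time.

```python
import torch.utils.data


def effective_shard(base_shard_index: int = 0, base_num_shards: int = 1):
    """Hypothetical helper: split the externally configured shard further across
    the PyTorch DataLoader workers. Worker info only exists inside a worker process."""
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is None:  # num_workers=0, data loading runs in the main process
        return base_shard_index, base_num_shards
    return (
        base_shard_index * worker_info.num_workers + worker_info.id,
        base_num_shards * worker_info.num_workers,
    )
```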
> I think this is actually not that trivial to implement because the torch engine already is given initialized datasets and it's difficult to change the sharding config after having initialized a dataset.

Why difficult? We maybe just need a clean dataset API for that, some setter function set_num_shards_and_shard_idx or so. And then in ReturnnDatasetIterDataPipe.reset or so we just need to call that.
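A rough sketch of what such an API could look like (the setter name follows the comment above; the attribute names and the simplified pipe class are assumptions for illustration, not actual RETURNN code):

```python
class Dataset:
    def __init__(self):
        self.num_shards = 1
        self.shard_index = 0

    def set_num_shards_and_shard_idx(self, num_shards: int, shard_index: int):
        """Reconfigure sharding after construction, before the next epoch is initialized."""
        assert 0 <= shard_index < num_shards
        self.num_shards = num_shards
        self.shard_index = shard_index


class ReturnnDatasetIterDataPipe:
    """Simplified stand-in for the real iter pipe, just to show where reset() would call the setter."""

    def __init__(self, dataset: Dataset):
        self._dataset = dataset

    def reset(self):
        import torch.utils.data

        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            self._dataset.set_num_shards_and_shard_idx(worker_info.num_workers, worker_info.id)
```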
Hm, I was originally not a fan of the mutability of these properties, but it seems ok now.
> I think we cannot avoid such API like set_num_shards_and_shard_idx because of how the PyTorch data pipeline works.

I'm not exactly sure what you mean by that. But e.g. there could be a test with PyTorch DataLoader num_workers=2 which checks that all data from the dataset was properly covered.
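A sketch of such a coverage test, assuming hypothetical helpers for building a small dataset and wrapping it in a DataLoader the way the torch engine does (none of these helpers are existing RETURNN test utilities):

```python
def test_dataloader_num_workers_covers_all_seqs():
    dataset = make_small_dataset(num_seqs=11)  # hypothetical helper with known seq tags
    loader = make_returnn_torch_dataloader(dataset, num_workers=2)  # hypothetical wrapper

    seen = []
    for batch in loader:
        seen.extend(batch["seq_tag"])

    # Each sequence must be returned exactly once across the two workers,
    # even though 11 seqs cannot be split evenly.
    assert sorted(seen) == sorted(dataset.all_seq_tags())  # hypothetical accessor
```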
Force-pushed from 8dca377 to 7d753f5.
Force-pushed from 6e5f6d4 to 1c96291.
Force-pushed from 1c96291 to 3b28635.
Just a few minor comments from my side. The functionality looks good!
assert 0 <= shard_index < num_shards
self.num_shards = num_shards
self.shard_index = shard_index
Suggested change:
- assert 0 <= shard_index < num_shards
- self.num_shards = num_shards
- self.shard_index = shard_index
+ self.set_shard_idx_and_num_shards(shard_index, num_shards)
Slightly cleaner and reuses code, but it's also fine as is. Your choice.
That function has additional asserts that rely on correct initialization. Not for now.
> That function has additional asserts that rely on correct initialization. Not for now.

I don't understand the comment. What correct initialization? What do you mean by "not for now"?
# RETURNN will set sharding info on the dataset if the global config is set.
# If it's not set, however, we need to respect the existing `distrib_shard_files` property
# for backwards compatibility and load the sharding info ourselves.
return CachedDataset2._get_sharding_rank_and_size(config)
Shouldn't you call the local _get_rank_and_size as per the previous change in https://github.com/rwth-i6/returnn/pull/1676/files#diff-44adfde5339aa94bf0770b09138330d5ea06d6b8e3f3b975bf270058f8c0c4baL188?
- sub_dataset = {**self.dataset, "_num_shards": self.num_workers, "_shard_index": i}
+ sub_dataset = {
+     **self.dataset,
+     "num_shards": self.num_workers * self.num_shards,
+     "shard_index": (self.shard_index * self.num_workers) + i,
+ }
I might not know the context of this comment. Isn't sharding already allowed as per your change here?
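For illustration, the arithmetic in the new code above maps nested sharding onto flat global shards, e.g. (values are made up for the example):

```python
# Say this MultiProcDataset instance was itself assigned shard 1 of 2 (e.g. by distributed
# training) and runs 3 workers: its children then cover global shards 3, 4, 5 out of 6.
num_workers, num_shards, shard_index = 3, 2, 1
for i in range(num_workers):
    print((shard_index * num_workers) + i, num_workers * num_shards)  # -> 3 6, 4 6, 5 6
```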
Btw, also see #1738. Not sure if this is relevant here.
Can you summarize what this PR does? I will also try to write some summaries here myself, but please edit your main description of the PR to cover that as well.
(Summary) Added feature: when …

Btw, some questions regarding this:

Just to confirm: this is independent of the newly introduced global dataset_distribution option?

What happens when this is used together with distributed training? Will it set num_shards = distrib_world_size * dataloader_num_workers then?

Is the order of seqs you get from the … ?

Will it always be complete? E.g. if one worker returns more seqs than the other (e.g. total num seqs is 11, and 2 workers), will the … ?
def test_dataset_sharding():
    from returnn.datasets.audio import OggZipDataset

    with create_ogg_zip_txt_only_dataset_mult_seqs_opts(num_seqs=10) as dataset_opts:
I think the test would be a bit nicer if num_seqs were uneven, i.e. not divisible by num_shards.

Suggested change:
- with create_ogg_zip_txt_only_dataset_mult_seqs_opts(num_seqs=10) as dataset_opts:
+ with create_ogg_zip_txt_only_dataset_mult_seqs_opts(num_seqs=11) as dataset_opts:
self.dataset.set_shard_idx_and_num_shards(
    self.dataset.shard_index + worker_info.id, self.dataset.num_shards * worker_info.num_workers
)
I don't understand. Why does this consider the existing shard_index/num_shards? I would expect that you overwrite those here.

Suggested change:
- self.dataset.set_shard_idx_and_num_shards(
-     self.dataset.shard_index + worker_info.id, self.dataset.num_shards * worker_info.num_workers
- )
+ self.dataset.set_shard_idx_and_num_shards(worker_info.id, worker_info.num_workers)
Or do you expect that the existing shard_index/num_shards were set with the distributed rank/size, and this here additionally adds further sharding for worker_id/num_workers?

But the code written in this way is very confusing... I think this should be done differently somehow. Not sure how...
We could also do the distributed logic directly here. We pass on the rank/size info anyway to the subproc children via the env var _RETURNN_TORCH_DISTRIBUTED_INIT_INFO. The question is how to handle dataset_distribution. We could also just check the global config here at this point (even though I don't like accessing the global config too much... maybe I'll get a better idea). Then all the logic about what sharding options to set (or whether to set them at all) is in this one place, together with the dataloader num workers.
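A hedged sketch of what that centralized decision could look like (the function name and the exact config access are assumptions; for simplicity, rank/size are taken from torch.distributed here rather than from the env var):

```python
import torch.distributed as dist
import torch.utils.data


def derive_sharding(config):
    """Hypothetical: decide shard_index/num_shards in one place, combining the
    dataset_distribution option, distributed training and DataLoader workers."""
    if config.value("dataset_distribution", "random_seed_offset") == "shard" and dist.is_initialized():
        rank, size = dist.get_rank(), dist.get_world_size()
    else:
        rank, size = 0, 1
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is not None:  # further split each rank's shard across DataLoader workers
        return rank * worker_info.num_workers + worker_info.id, size * worker_info.num_workers
    return rank, size
```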
num_shards: int = 1,
shard_index: int = 0,
I don't understand: Why are those public (not prefixed with _)? These are never supposed to be set by the user. Those are either internally set via DataLoader num_workers, or via parent dataset logic like MultiProcDataset, or via distributed training logic somehow, or so. But never directly by the user.
state = {attr: getattr(self, attr) for attr in ["epoch", "zpad"]}
return Dataset._create_from_reduce, (self.__class__, kwargs, state)

def set_shard_idx_and_num_shards(self, shard_index: int, num_shards: int):
Small inconsistency: idx vs index.
assert dd_cfg in ["random_seed_offset", "shard"]
shard_index, num_shards = Dataset._get_sharding_rank_and_size(config) if dd_cfg == "shard" else (0, 1)
set_or_remove("num_shards", num_shards)
set_or_remove("shard_index", shard_index)
I'm not sure whether it is a good idea to do this logic in kwargs_update_from_config. You don't want to apply this to every dataset: e.g. for the dev/test/eval datasets, we would not want this logic, at least not right now.

I think this should come from the outside, not from within.
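For illustration, doing it from the outside could look roughly like this, applied only to the train dataset (the function and the get_sharding_rank_and_size helper are made up for the sketch):

```python
def init_train_dataset_sharding(config, train_dataset):
    # Only the training data gets sharded; dev/eval datasets are left untouched.
    if config.value("dataset_distribution", "random_seed_offset") == "shard":
        shard_index, num_shards = get_sharding_rank_and_size(config)  # hypothetical, e.g. from distributed info
        train_dataset.set_shard_idx_and_num_shards(shard_index, num_shards)
```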
# RETURNN will set sharding info on the dataset if the global config is set.
# If it's not set, however, we need to respect the existing `distrib_shard_files` property
# for backwards compatibility and load the sharding info ourselves.
return CachedDataset2._get_sharding_rank_and_size(config)
Suggested change:
- return CachedDataset2._get_sharding_rank_and_size(config)
+ return cls._get_sharding_rank_and_size(config)

?

Or:

- return CachedDataset2._get_sharding_rank_and_size(config)
+ return Dataset._get_sharding_rank_and_size(config)

?

But CachedDataset2 doesn't really make sense?
The … Or not, because …
(Summary) New global config option dataset_distribution …
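For reference, usage in a config would then presumably look like this (the set of allowed values is taken from the code quoted above, not confirmed elsewhere in this thread):

```python
# RETURNN config sketch: shard the training data across workers instead of only
# varying the random seed offset per worker.
dataset_distribution = "shard"  # assumed alternative/default: "random_seed_offset"
```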
)
self.dataset.set_shard_idx_and_num_shards(
    self.dataset.shard_index + worker_info.id, self.dataset.num_shards * worker_info.num_workers
)
Instead of having all this logic here, maybe we should rather move it to ReturnnDatasetIterDataPipe.reset though?
log.print_deprecation_warning(
    f"{self.__class__.__name__}' `distrib_shard_files` config option is set. "
    "Use global config option `dataset_distribution` instead "
    "for the same behavior across more types of datasets."
)
I'm not sure it is really necessary to mark this as deprecated.

Suggested change (remove these lines):
- log.print_deprecation_warning(
-     f"{self.__class__.__name__}' `distrib_shard_files` config option is set. "
-     "Use global config option `dataset_distribution` instead "
-     "for the same behavior across more types of datasets."
- )
:param distrib_shard_files: deprecated. Replaced by global config option ``dataset_distribution="shard"``.
    Set to true to shard the data across worker processes in distributed training scenarios.
I don't think it is necessary to mark this as deprecated.

Suggested change:
- :param distrib_shard_files: deprecated. Replaced by global config option ``dataset_distribution="shard"``.
-     Set to true to shard the data across worker processes in distributed training scenarios.
+ :param distrib_shard_files: deprecated. set to true to shard the data across worker processes in
+     distributed training scenarios
Closes #1634