[data] Add load_from_uris
#55554
Conversation
Signed-off-by: Matthew Owen <[email protected]>
Code Review

This pull request introduces a new `load_from_uris` method to the `Dataset` class, intended to load data from a column of URIs. My review focuses on several critical issues in the implementation: the use of `take_all()` can lead to driver OOM for large datasets, and the core logic for applying the `decode_fn` is incorrect due to a type mismatch with `Dataset.map()`. There are also several other medium- to high-severity issues related to potential runtime errors, style, and hardcoded values. I've provided suggestions to address these points.
# print(f"sampled_size: {sampled_size}") | ||
# print(f"in_memory_size_estimate: {in_memory_size_estimate}") | ||
# print(f"repartitioning to {num_partitions} partitions") | ||
return self.repartition(num_partitions).map(decode_fn) |
The `decode_fn` is being passed to `Dataset.map()`, which expects a function that operates on individual rows (i.e., `Dict[str, Any]`). However, the type hint and apparent intent of `decode_fn` are to operate on a file-like object (`io.RawIOBase`). This will cause a `TypeError` at runtime.

You likely need to use `map_batches()` and, within the provided function, open and read the file for each URI in the batch, then apply the `decode_fn`.
+1
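For illustration, a minimal sketch of what the `map_batches()` approach could look like. This is not the PR's code: the helper name, the use of `pyarrow.fs` to resolve URIs, and the assumption that `decode_fn` yields row dicts (rather than whole batches, as the `Iterator[DataBatch]` hint suggests) are all assumptions for the sketch.

```python
# Hypothetical sketch only, not the PR's implementation. Assumes decode_fn
# takes a file-like object and yields row dicts, and that pyarrow can resolve
# the URI schemes appearing in the column (s3://, local paths, etc.).
import pandas as pd
import pyarrow.fs as pafs


def make_uri_batch_decoder(uri_column, decode_fn):
    def _decode_uri_batch(batch: pd.DataFrame) -> pd.DataFrame:
        rows = []
        for uri in batch[uri_column]:
            fs, path = pafs.FileSystem.from_uri(uri)  # resolve filesystem + path
            with fs.open_input_stream(path) as f:      # file-like object for decode_fn
                rows.extend(decode_fn(f))
        return pd.DataFrame(rows)

    return _decode_uri_batch


# Usage (hypothetical):
# ds.map_batches(make_uri_batch_decoder("uri", decode_fn), batch_format="pandas")
```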
# Step 2: Sample paths to estimate total in memory size after reading.
sampled_paths = [
    row[uri_column] for row in random.sample(paths, RANDOM_SAMPLE_SIZE)
`random.sample(population, k)` will raise a `ValueError` if `k > len(population)`. This can happen here if the number of paths is less than `RANDOM_SAMPLE_SIZE`. You should guard against this by taking the minimum of the number of paths and the sample size.
Suggested change:
- row[uri_column] for row in random.sample(paths, RANDOM_SAMPLE_SIZE)
+ row[uri_column] for row in random.sample(paths, min(len(paths), RANDOM_SAMPLE_SIZE))
RANDOM_SAMPLE_SIZE = 10
import math
import random
def default_decode_fn(f: io.RawIOBase) -> Iterator[DataBatch]:
    return f
The default implementation of `decode_fn` returns the file-like object `f` itself. However, the type hint for `decode_fn` specifies a return type of `Iterator[DataBatch]`. This mismatch should be corrected. The default behavior should probably be to read the file and return it as a binary chunk in a `DataBatch`.
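As a hedged sketch of a type-consistent default: read the whole file and yield it as a single-row batch. The `"bytes"` column name and the dict-of-columns batch shape are assumptions for illustration, not something settled in this PR (the simplified dict annotation below stands in for `DataBatch`).

```python
import io
from typing import Dict, Iterator, List


def default_decode_fn(f: io.RawIOBase) -> Iterator[Dict[str, List[bytes]]]:
    # Read the whole file and yield it as a one-row, single-column batch.
    # The "bytes" column name is an assumption for illustration only.
    yield {"bytes": [f.read()]}
```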
num_partitions = math.ceil(
    in_memory_size_estimate / (1024 * 1024 * 1024)
)  # (Each task reads 1 GiB)
The number of partitions is calculated using a hardcoded 1 GiB block size. It's better to use `DataContext.get_current().target_max_block_size` to respect the configured block size for the dataset, making the behavior more consistent and configurable.
Suggested change:
- num_partitions = math.ceil(
-     in_memory_size_estimate / (1024 * 1024 * 1024)
- )  # (Each task reads 1 GiB)
+ target_max_block_size = DataContext.get_current().target_max_block_size or (
+     1024 * 1024 * 1024
+ )
+ num_partitions = math.ceil(
+     in_memory_size_estimate / target_max_block_size
+ )
For consistency with the `Datasource` implementation, agree `target_max_block_size` might be better here.
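One small note on the suggested snippet above: it assumes `DataContext` is already in scope in this module; if it isn't, an import such as `from ray.data.context import DataContext` would be needed.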
# print(f"num_rows: {len(paths)}") | ||
# print(f"total_sampled_size: {total_sampled_size}") | ||
# print(f"sampled_size: {sampled_size}") | ||
# print(f"in_memory_size_estimate: {in_memory_size_estimate}") | ||
# print(f"repartitioning to {num_partitions} partitions") |
self,
uri_column: str,
*,
decode_fn: Optional[Callable[[io.RawIOBase], Iterator[DataBatch]]] = None,
To avoid blocking this PR on discussion about `decode_fn`, I think it might be better to leave it out for now.
Think this is fine
import math
import random
Move imports to top of file?
total_sampled_size = (
    ray.data.read_binary_files(sampled_paths)
    .map(lambda r: {"size_bytes": len(r["bytes"])})
    .sum("size_bytes")
)
Can we use `FileMetadataProvider` to get the file sizes? It'll be much cheaper than actually reading the data.
And actually, if we're just listing all the file sizes with `FileMetadataProvider`, I think we can just take the mean.
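Roughly what a metadata-only estimate could look like. This sketch uses pyarrow's `get_file_info` directly rather than Ray's `FileMetadataProvider`, purely to illustrate taking the mean of sampled file sizes without reading any bytes; the actual provider-based code may look different.

```python
# Illustration only: estimate the mean file size from filesystem metadata,
# without downloading file contents.
import pyarrow.fs as pafs


def estimate_mean_file_size(sampled_uris):
    sizes = []
    for uri in sampled_uris:
        fs, path = pafs.FileSystem.from_uri(uri)
        sizes.append(fs.get_file_info(path).size)  # metadata lookup only
    return sum(sizes) / len(sizes) if sizes else 0
```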
"""Load binary data from a column of URIs. | ||
|
||
Args: | ||
uri_column: The name of the column containing the URIs. |
Please be sure to document what the returned column type is and what the new column is named.
Going to go with a different approach using the expressions API.
Why are these changes needed?
WIP, will add more description later.
Related issue number
Checks

- I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
- I've run scripts/format.sh to lint the changes in this PR.
- I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in doc/source/tune/api/ under the corresponding .rst file.