Skip to content
Draft
83 changes: 8 additions & 75 deletions src/uproot/_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -889,7 +889,7 @@ def keys_for_buffer_keys(self, buffer_keys: frozenset[str]) -> frozenset[str]:
keys: set[str] = set()
for buffer_key in buffer_keys:
# Identify form key
form_key, attribute = buffer_key.rsplit("-", maxsplit=1)
form_key, attribute = buffer_key.replace("@.", "<root>.").rsplit("-", maxsplit=1)
# Identify key from form_key
keys.add(self._form_key_to_key[form_key])
return frozenset(keys)
Expand Down Expand Up @@ -954,6 +954,7 @@ def __call__(self, form: Form) -> tuple[Form, TrivialFormMappingInfo]:
class UprootReadMixin:
base_form: Form
expected_form: Form
behavior = {}
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure how behaviors should be handled here, so this `behavior = {}` default probably needs updating.

form_mapping_info: ImplementsFormMappingInfo
common_keys: frozenset[str]
interp_options: dict[str, Any]
Expand Down Expand Up @@ -1026,83 +1027,15 @@ def read_tree(
assert tree.source # we must be reading something here
return out, tree.source.performance_counters

def mock(self) -> AwkArray:
    """Return a high-level typetracer array built from ``expected_form``.

    The array is created with ``typetracer_from_form`` and carries the
    behavior stored on ``form_mapping_info``.
    """
    ak = uproot.extras.awkward()
    build_typetracer = ak.typetracer.typetracer_from_form
    return build_typetracer(
        self.expected_form,
        highlevel=True,
        behavior=self.form_mapping_info.behavior,
    )

def mock_empty(self, backend) -> AwkArray:
    """Return a zero-length high-level array shaped like ``expected_form``.

    The length-zero layout produced by the form is converted to ``backend``
    and wrapped with the behavior stored on ``form_mapping_info``.
    """
    ak = uproot.extras.awkward()
    convert = ak.to_backend
    empty_layout = self.expected_form.length_zero_array(highlevel=False)
    return convert(
        empty_layout,
        backend=backend,
        highlevel=True,
        behavior=self.form_mapping_info.behavior,
    )

def prepare_for_projection(self) -> tuple[AwkArray, TypeTracerReport, dict]:
    """Build the typetracer, its report, and the state used for projection.

    Returns a 3-tuple ``(meta, report, state)``:

    * ``meta`` and ``report`` come from ``typetracer_with_report`` applied to
      ``expected_form`` with the behavior and ``buffer_key`` of
      ``form_mapping_info``;
    * ``state`` is a dict with the traced form structure under ``"trace"``
      and ``form_mapping_info`` itself under ``"form_info"``.
    """
    ak = uproot.extras.awkward()
    dak = uproot.extras.dask_awkward()

    # A form mapping may remap the base form into a new form; the remapped
    # form is what gets traced and queried for structural information.
    meta, report = ak.typetracer.typetracer_with_report(
        self.expected_form,
        highlevel=True,
        behavior=self.form_mapping_info.behavior,
        buffer_key=self.form_mapping_info.buffer_key,
    )

    form_structure = dak.lib.utils.trace_form_structure(
        self.expected_form,
        buffer_key=self.form_mapping_info.buffer_key,
    )
    state = {
        "trace": form_structure,
        "form_info": self.form_mapping_info,
    }
    return meta, report, state

def project(self: T, *, report: TypeTracerReport, state: dict) -> T:
    """Project onto the columns that ``necessary_columns`` reports.

    ``report`` and ``state`` are forwarded unchanged to
    ``necessary_columns``; its result is handed to ``project_keys``.
    """
    required_keys = self.necessary_columns(report=report, state=state)
    projected = self.project_keys(required_keys)
    return projected

def necessary_columns(
self, *, report: TypeTracerReport, state: dict
) -> frozenset[str]:
## Read from stash
# Form hierarchy information
form_key_to_parent_form_key: dict = state["trace"][
"form_key_to_parent_form_key"
]
# Buffer hierarchy information
form_key_to_buffer_keys: dict = state["trace"]["form_key_to_buffer_keys"]
# Restructured form information
form_info = state["form_info"]

# Require the data of metadata buffers above shape-only requests
dask_awkward = uproot.extras.dask_awkward()
data_buffers = {
*report.data_touched,
*dask_awkward.lib.utils.buffer_keys_required_to_compute_shapes(
form_info.parse_buffer_key,
report.shape_touched,
form_key_to_parent_form_key,
form_key_to_buffer_keys,
),
}
# Read-only property exposing ``self.expected_form`` under the name ``form``.
@property
def form(self):
return self.expected_form

# Determine which TTree keys need to be read
return form_info.keys_for_buffer_keys(data_buffers) & frozenset(
# Project onto the keys that back the requested buffer keys: ``columns`` is
# translated by ``form_mapping_info.keys_for_buffer_keys``, intersected with
# ``self.common_keys``, and the resulting key set is passed to
# ``project_keys``, whose result is returned.
# NOTE(review): ``columns`` is presumably a frozenset of buffer-key strings,
# matching the ``keys_for_buffer_keys`` annotation — confirm against the caller.
def project(self, columns) -> T:
keys = self.form_mapping_info.keys_for_buffer_keys(columns) & frozenset(
self.common_keys
)
return self.project_keys(keys)

# Hook that performs the actual projection onto ``keys``; this implementation
# always raises NotImplementedError.
# NOTE(review): presumably overridden by the concrete classes that use this
# mixin — confirm.
def project_keys(self: T, keys: frozenset[str]) -> T:
raise NotImplementedError
Expand Down