Skip to content
This repository was archived by the owner on Oct 24, 2024. It is now read-only.

Commit 6a946ff

Browse files
Enable mypy (#23)
* re-enable mypy * ignored untyped imports * draft implementation of a TreeNode class which stores children in a dict * separate path-like access out into mixin * pseudocode for node getter * basic idea for a path-like object which inherits from pathlib * pass type checking * implement attach * consolidate tree classes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * passes some basic family tree tests * frozen children * passes all basic family tree tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * copied iterators code over from anytree * get nodes with path-like syntax * relative path method * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * set and get node methods * copy anytree iterators * add anytree license * change iterator import * copy anytree's string renderer * renderer * refactored treenode to use .get * black * updated datatree tests to match new path API * moved io tests to their own file * reimplemented getitem in terms of .get * reimplemented setitem in terms of .update * remove anytree dependency * from_dict constructor * string representation of tree * fixed tree diff * fixed io * removed cheeky print statements * fixed isomorphism checking * fixed map_over_subtree * removed now-unneeded utils.py compatibility functions * fixed tests for mapped dataset api methods * updated API docs * reimplement __setitem__ in terms of _set * fixed bug by ensuring name of child node is changed to match key it is stored under * updated docs * added whats-new, and put all changes from this PR in it * added summary of previous versions * remove outdated ._add_child method * fix some of the easier typing errors * generic typevar for tree in TreeNode * datatree.py almost passes type checking * ignore remaining typing errors for now * fix / ignore last few typing errors * remove spurious type check 
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent f492787 commit 6a946ff

File tree

6 files changed

+160
-192
lines changed

6 files changed

+160
-192
lines changed

.pre-commit-config.yaml

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,22 +31,22 @@ repos:
3131
# hooks:
3232
# - id: velin
3333
# args: ["--write", "--compact"]
34-
# - repo: https://github.com/pre-commit/mirrors-mypy
35-
# rev: v0.910
36-
# hooks:
37-
# - id: mypy
38-
# # Copied from setup.cfg
39-
# exclude: "properties|asv_bench"
40-
# additional_dependencies: [
41-
# # Type stubs
42-
# types-python-dateutil,
43-
# types-pkg_resources,
44-
# types-PyYAML,
45-
# types-pytz,
46-
# # Dependencies that are typed
47-
# numpy,
48-
# typing-extensions==3.10.0.0,
49-
# ]
34+
- repo: https://github.com/pre-commit/mirrors-mypy
35+
rev: v0.910
36+
hooks:
37+
- id: mypy
38+
# Copied from setup.cfg
39+
exclude: "properties|asv_bench|docs"
40+
additional_dependencies: [
41+
# Type stubs
42+
types-python-dateutil,
43+
types-pkg_resources,
44+
types-PyYAML,
45+
types-pytz,
46+
# Dependencies that are typed
47+
numpy,
48+
typing-extensions==3.10.0.0,
49+
]
5050
# run this occasionally, ref discussion https://github.com/pydata/xarray/pull/3194
5151
# - repo: https://github.com/asottile/pyupgrade
5252
# rev: v1.22.1

datatree/datatree.py

Lines changed: 38 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,21 @@
11
from __future__ import annotations
22

3+
from collections import OrderedDict
34
from typing import (
45
TYPE_CHECKING,
56
Any,
67
Callable,
7-
Hashable,
8+
Generic,
89
Iterable,
910
Mapping,
1011
MutableMapping,
12+
Optional,
1113
Tuple,
1214
Union,
1315
)
1416

15-
from xarray import DataArray, Dataset, merge
16-
from xarray.core import dtypes, utils
17+
from xarray import DataArray, Dataset
18+
from xarray.core import utils
1719
from xarray.core.variable import Variable
1820

1921
from .formatting import tree_repr
@@ -24,7 +26,7 @@
2426
MappedDataWithCoords,
2527
)
2628
from .render import RenderTree
27-
from .treenode import NodePath, TreeNode
29+
from .treenode import NodePath, Tree, TreeNode
2830

2931
if TYPE_CHECKING:
3032
from xarray.core.merge import CoercibleValue
@@ -51,6 +53,7 @@ class DataTree(
5153
MappedDatasetMethodsMixin,
5254
MappedDataWithCoords,
5355
DataTreeArithmeticMixin,
56+
Generic[Tree],
5457
):
5558
"""
5659
A tree-like hierarchical collection of xarray objects.
@@ -74,12 +77,14 @@ class DataTree(
7477

7578
# TODO .loc, __contains__, __iter__, __array__, __len__
7679

77-
_name: str | None
78-
_ds: Dataset | None
80+
_name: Optional[str]
81+
_parent: Optional[Tree]
82+
_children: OrderedDict[str, Tree]
83+
_ds: Dataset
7984

8085
def __init__(
8186
self,
82-
data: Dataset | DataArray = None,
87+
data: Optional[Dataset | DataArray] = None,
8388
parent: DataTree = None,
8489
children: Mapping[str, DataTree] = None,
8590
name: str = None,
@@ -109,9 +114,9 @@ def __init__(
109114
"""
110115

111116
super().__init__(children=children)
112-
self._name = name
117+
self.name = name
113118
self.parent = parent
114-
self.ds = data
119+
self.ds = data # type: ignore[assignment]
115120

116121
@property
117122
def name(self) -> str | None:
@@ -122,8 +127,13 @@ def name(self) -> str | None:
122127
def name(self, name: str | None) -> None:
123128
self._name = name
124129

125-
@TreeNode.parent.setter
126-
def parent(self, new_parent: DataTree) -> None:
130+
@property
131+
def parent(self: DataTree) -> DataTree | None:
132+
"""Parent of this node."""
133+
return self._parent
134+
135+
@parent.setter
136+
def parent(self: DataTree, new_parent: DataTree) -> None:
127137
if new_parent and self.name is None:
128138
raise ValueError("Cannot set an unnamed node as a child of another node")
129139
self._set_parent(new_parent, self.name)
@@ -134,7 +144,7 @@ def ds(self) -> Dataset:
134144
return self._ds
135145

136146
@ds.setter
137-
def ds(self, data: Union[Dataset, DataArray] = None):
147+
def ds(self, data: Union[Dataset, DataArray] = None) -> None:
138148
if not isinstance(data, (Dataset, DataArray)) and data is not None:
139149
raise TypeError(
140150
f"{type(data)} object is not an xarray Dataset, DataArray, or None"
@@ -168,7 +178,7 @@ def is_empty(self) -> bool:
168178
"""False if node contains any data or attrs. Does not look at children."""
169179
return not (self.has_data or self.has_attrs)
170180

171-
def _pre_attach(self, parent: TreeNode) -> None:
181+
def _pre_attach(self: DataTree, parent: DataTree) -> None:
172182
"""
173183
Method which superclass calls before setting parent, here used to prevent having two
174184
children with duplicate names (or a data variable with the same name as a child).
@@ -186,8 +196,8 @@ def __str__(self):
186196
return tree_repr(self)
187197

188198
def get(
189-
self, key: str, default: DataTree | DataArray = None
190-
) -> DataTree | DataArray | None:
199+
self: DataTree, key: str, default: Optional[DataTree | DataArray] = None
200+
) -> Optional[DataTree | DataArray]:
191201
"""
192202
Access child nodes stored in this node as a DataTree or variables or coordinates stored in this node as a
193203
DataArray.
@@ -207,7 +217,7 @@ def get(
207217
else:
208218
return default
209219

210-
def __getitem__(self, key: str) -> DataTree | DataArray:
220+
def __getitem__(self: DataTree, key: str) -> DataTree | DataArray:
211221
"""
212222
Access child nodes stored in this tree as a DataTree or variables or coordinates stored in this tree as a
213223
DataArray.
@@ -272,7 +282,7 @@ def __setitem__(
272282
else:
273283
raise ValueError("Invalid format for key")
274284

275-
def update(self, other: Dataset | Mapping[str, DataTree | CoercibleValue]) -> None:
285+
def update(self, other: Dataset | Mapping[str, DataTree | DataArray]) -> None:
276286
"""
277287
Update this node's children and / or variables.
278288
@@ -285,10 +295,8 @@ def update(self, other: Dataset | Mapping[str, DataTree | CoercibleValue]) -> No
285295
if isinstance(v, DataTree):
286296
new_children[k] = v
287297
elif isinstance(v, (DataArray, Variable)):
288-
# TODO this should also accomodate other types that can be coerced into Variables
298+
# TODO this should also accommodate other types that can be coerced into Variables
289299
new_variables[k] = v
290-
elif isinstance(v, Dataset):
291-
new_variables = v.variables
292300
else:
293301
raise TypeError(f"Type {type(v)} cannot be assigned to a DataTree")
294302

@@ -298,7 +306,7 @@ def update(self, other: Dataset | Mapping[str, DataTree | CoercibleValue]) -> No
298306
@classmethod
299307
def from_dict(
300308
cls,
301-
d: MutableMapping[str, Any],
309+
d: MutableMapping[str, DataTree | Dataset | DataArray],
302310
name: str = None,
303311
) -> DataTree:
304312
"""
@@ -322,15 +330,16 @@ def from_dict(
322330
"""
323331

324332
# First create the root node
333+
# TODO there is a real bug here: what happens if root_data is of type DataTree?
325334
root_data = d.pop("/", None)
326-
obj = cls(name=name, data=root_data, parent=None, children=None)
335+
obj = cls(name=name, data=root_data, parent=None, children=None) # type: ignore[arg-type]
327336

328337
if d:
329338
# Populate tree with children determined from data_objects mapping
330339
for path, data in d.items():
331340
# Create and set new node
332341
node_name = NodePath(path).name
333-
new_node = cls(name=node_name, data=data)
342+
new_node = cls(name=node_name, data=data) # type: ignore[arg-type]
334343
obj._set_item(
335344
path,
336345
new_node,
@@ -346,8 +355,8 @@ def nbytes(self) -> int:
346355
def isomorphic(
347356
self,
348357
other: DataTree,
349-
from_root=False,
350-
strict_names=False,
358+
from_root: bool = False,
359+
strict_names: bool = False,
351360
) -> bool:
352361
"""
353362
Two DataTrees are considered isomorphic if every node has the same number of children.
@@ -386,7 +395,7 @@ def isomorphic(
386395
except (TypeError, TreeIsomorphismError):
387396
return False
388397

389-
def equals(self, other: DataTree, from_root=True) -> bool:
398+
def equals(self, other: DataTree, from_root: bool = True) -> bool:
390399
"""
391400
Two DataTrees are equal if they have isomorphic node structures, with matching node names,
392401
and if they have matching variables and coordinates, all of which are equal.
@@ -479,7 +488,8 @@ def map_over_subtree(
479488
"""
480489
# TODO this signature means that func has no way to know which node it is being called upon - change?
481490

482-
return map_over_subtree(func)(self, *args, **kwargs)
491+
# TODO fix this typing error
492+
return map_over_subtree(func)(self, *args, **kwargs) # type: ignore[operator]
483493

484494
def map_over_subtree_inplace(
485495
self,
@@ -516,31 +526,6 @@ def render(self):
516526
for ds_line in repr(node.ds)[1:]:
517527
print(f"{fill}{ds_line}")
518528

519-
# TODO re-implement using anytree findall function?
520-
def get_all(self, *tags: Hashable) -> DataTree:
521-
"""
522-
Return a DataTree containing the stored objects whose path contains all of the given tags,
523-
where the tags can be present in any order.
524-
"""
525-
matching_children = {
526-
c.tags: c.get_node(tags)
527-
for c in self.descendants
528-
if all(tag in c.tags for tag in tags)
529-
}
530-
return DataTree(data_objects=matching_children)
531-
532-
# TODO re-implement using anytree find function?
533-
def get_any(self, *tags: Hashable) -> DataTree:
534-
"""
535-
Return a DataTree containing the stored objects whose path contains any of the given tags.
536-
"""
537-
matching_children = {
538-
c.tags: c.get_node(tags)
539-
for c in self.descendants
540-
if any(tag in c.tags for tag in tags)
541-
}
542-
return DataTree(data_objects=matching_children)
543-
544529
def merge(self, datatree: DataTree) -> DataTree:
545530
"""Merge all the leaves of a second DataTree into this one."""
546531
raise NotImplementedError
@@ -549,23 +534,7 @@ def merge_child_nodes(self, *paths, new_path: T_Path) -> DataTree:
549534
"""Merge a set of child nodes into a single new node."""
550535
raise NotImplementedError
551536

552-
def merge_child_datasets(
553-
self,
554-
*paths: T_Path,
555-
compat: str = "no_conflicts",
556-
join: str = "outer",
557-
fill_value: Any = dtypes.NA,
558-
combine_attrs: str = "override",
559-
) -> Dataset:
560-
"""Merge the datasets at a set of child nodes and return as a single Dataset."""
561-
datasets = [self.get(path).ds for path in paths]
562-
return merge(
563-
datasets,
564-
compat=compat,
565-
join=join,
566-
fill_value=fill_value,
567-
combine_attrs=combine_attrs,
568-
)
537+
# TODO some kind of .collapse() or .flatten() method to merge a subtree
569538

570539
def as_array(self) -> DataArray:
571540
return self.ds.as_dataarray()

0 commit comments

Comments
 (0)