Skip to content

Commit 4748033

Browse files
authored
Performance improvement: Convert unnecessary (?) deep copies to shallow copies (#358)
1 parent d3b75a2 commit 4748033

File tree

3 files changed

+82
-52
lines changed

3 files changed

+82
-52
lines changed

cf_xarray/accessor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1797,7 +1797,7 @@ def guess_coord_axis(self, verbose: bool = False) -> DataArray | Dataset:
17971797
DataArray or Dataset
17981798
with appropriate attributes added
17991799
"""
1800-
obj = self._obj.copy(deep=True)
1800+
obj = self._obj.copy(deep=False)
18011801
for var in obj.coords.variables:
18021802
var_is_coord = any(var in val for val in obj.cf.coordinates.values())
18031803
if not var_is_coord and obj[var].ndim == 1 and _is_datetime_like(obj[var]):
@@ -1937,7 +1937,7 @@ def add_canonical_attributes(
19371937
info, table, aliases = parse_cf_standard_name_table(source)
19381938

19391939
# Loop over standard names
1940-
ds = self._maybe_to_dataset().copy()
1940+
ds = self._maybe_to_dataset().copy(deep=False)
19411941
attrs_to_print: dict = {}
19421942
for std_name, var_names in ds.cf.standard_names.items():
19431943

@@ -2209,7 +2209,7 @@ def add_bounds(self, keys: str | Iterable[str], *, dim=None):
22092209
apply_mapper(_get_all, self._obj, key, error=False, default=[key])
22102210
)
22112211

2212-
obj = self._maybe_to_dataset(self._obj.copy(deep=True))
2212+
obj = self._maybe_to_dataset(self._obj.copy(deep=False))
22132213

22142214
bad_vars: set[str] = variables - set(obj.variables)
22152215
if bad_vars:
@@ -2286,7 +2286,7 @@ def bounds_to_vertices(
22862286
else:
22872287
coords = keys
22882288

2289-
obj = self._maybe_to_dataset(self._obj.copy(deep=True))
2289+
obj = self._maybe_to_dataset(self._obj.copy(deep=False))
22902290

22912291
for coord in coords:
22922292
try:

cf_xarray/datasets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
* 110e3
1212
)
1313

14-
ds_no_attrs = airds.copy(deep=True)
14+
ds_no_attrs = airds.copy(deep=False)
1515
for _variable in ds_no_attrs.variables:
1616
ds_no_attrs[_variable].attrs = {}
1717

cf_xarray/tests/test_accessor.py

Lines changed: 77 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def test_coordinates_quantified():
230230

231231

232232
def test_cell_measures():
233-
ds = airds.copy(deep=True)
233+
ds = airds.copy(deep=False)
234234
ds["foo"] = xr.DataArray(ds["cell_area"], attrs=dict(standard_name="foo_std_name"))
235235
ds["air"].attrs["cell_measures"] += " foo_measure: foo"
236236
assert ("foo_std_name" in ds.cf["air_temperature"].cf) and ("foo_measure" in ds.cf)
@@ -307,7 +307,7 @@ def test_getitem_standard_name():
307307
expected = airds["lat"]
308308
assert_identical(actual, expected)
309309

310-
ds = airds.copy(deep=True)
310+
ds = airds.copy(deep=False)
311311
ds["air2"] = ds.air
312312
with pytest.raises(KeyError):
313313
ds.cf["air_temperature"]
@@ -340,19 +340,18 @@ def test_getitem_ancillary_variables():
340340

341341

342342
def test_rename_like():
343-
original = popds.copy(deep=True)
344-
345343
# it'll match for axis: X (lon, nlon) and coordinate="longitude" (lon, TLONG)
346344
# so delete the axis attributes
347-
newair = airds.copy(deep=True)
345+
original = popds
346+
newair = airds.copy(deep=False)
348347
del newair.lon.attrs["axis"]
349348
del newair.lat.attrs["axis"]
350349

351350
renamed = popds.cf["TEMP"].cf.rename_like(newair)
351+
assert original.TEMP.attrs["coordinates"] == "TLONG TLAT"
352352
for k in ["TLONG", "TLAT"]:
353-
assert k not in renamed.coords
354353
assert k in original.coords
355-
assert original.TEMP.attrs["coordinates"] == "TLONG TLAT"
354+
assert k not in renamed.coords
356355

357356
assert "lon" in renamed.coords
358357
assert "lat" in renamed.coords
@@ -479,7 +478,7 @@ def test_pos_args_methods():
479478

480479
def test_preserve_unused_keys():
481480

482-
ds = airds.copy(deep=True)
481+
ds = airds.copy(deep=False)
483482
ds.time.attrs.clear()
484483
actual = ds.cf.sel(X=260, Y=40, time=airds.time[:2], method="nearest")
485484
expected = ds.sel(lon=260, lat=40, time=airds.time[:2], method="nearest")
@@ -528,7 +527,7 @@ def test_args_methods(obj):
528527

529528
def test_dataarray_getitem():
530529

531-
air = airds.air.copy()
530+
air = airds.air.copy(deep=False)
532531
air.name = None
533532

534533
assert_identical(air.cf["longitude"], air["lon"])
@@ -543,7 +542,7 @@ def test_dataarray_getitem():
543542

544543
def test_dataarray_plot():
545544

546-
obj = airds.air.copy(deep=True)
545+
obj = airds.air.copy(deep=False)
547546

548547
rv = obj.isel(time=1).transpose("lon", "lat").cf.plot()
549548
assert isinstance(rv, mpl.collections.QuadMesh)
@@ -636,14 +635,14 @@ def test_getitem(obj, key, expected_key):
636635
def test_getitem_errors(obj):
637636
with pytest.raises(KeyError):
638637
obj.cf["XX"]
639-
obj2 = obj.copy(deep=True)
638+
obj2 = obj.copy(deep=False)
640639
obj2.lon.attrs = {}
641640
with pytest.raises(KeyError):
642641
obj2.cf["X"]
643642

644643

645644
def test_bad_cell_measures_attribute():
646-
air2 = airds.copy(deep=True)
645+
air2 = airds.copy(deep=False)
647646
air2.air.attrs["cell_measures"] = "--OPT"
648647
with pytest.warns(UserWarning):
649648
air2.cf["air"]
@@ -736,20 +735,21 @@ def test_plot_xincrease_yincrease():
736735

737736
@pytest.mark.parametrize("dims", ["time2", "lat", "time", ["lat", "lon"]])
738737
def test_add_bounds(dims):
739-
obj = airds.copy(deep=True)
738+
ds = airds
739+
original = ds.copy(deep=True)
740740

741741
expected = {}
742742
expected["lat"] = xr.concat(
743743
[
744-
obj.lat.copy(data=np.arange(76.25, 16.0, -2.5)),
745-
obj.lat.copy(data=np.arange(73.75, 13.6, -2.5)),
744+
ds.lat.copy(data=np.arange(76.25, 16.0, -2.5)),
745+
ds.lat.copy(data=np.arange(73.75, 13.6, -2.5)),
746746
],
747747
dim="bounds",
748748
)
749749
expected["lon"] = xr.concat(
750750
[
751-
obj.lon.copy(data=np.arange(198.75, 325 - 1.25, 2.5)),
752-
obj.lon.copy(data=np.arange(201.25, 325 + 1.25, 2.5)),
751+
ds.lon.copy(data=np.arange(198.75, 325 - 1.25, 2.5)),
752+
ds.lon.copy(data=np.arange(201.25, 325 + 1.25, 2.5)),
753753
],
754754
dim="bounds",
755755
)
@@ -759,8 +759,8 @@ def test_add_bounds(dims):
759759
dtb2 = pd.Timedelta("3h")
760760
expected["time"] = xr.concat(
761761
[
762-
obj.time.copy(data=pd.date_range(start=t0 - dtb2, end=t1 - dtb2, freq=dt)),
763-
obj.time.copy(data=pd.date_range(start=t0 + dtb2, end=t1 + dtb2, freq=dt)),
762+
ds.time.copy(data=pd.date_range(start=t0 - dtb2, end=t1 - dtb2, freq=dt)),
763+
ds.time.copy(data=pd.date_range(start=t0 + dtb2, end=t1 + dtb2, freq=dt)),
764764
],
765765
dim="bounds",
766766
)
@@ -769,8 +769,9 @@ def test_add_bounds(dims):
769769
expected["lon"].attrs.clear()
770770
expected["time"].attrs.clear()
771771

772-
obj.coords["time2"] = obj.time
773-
added = obj.cf.add_bounds(dims)
772+
added = ds.copy(deep=False)
773+
added.coords["time2"] = ds.time
774+
added = added.cf.add_bounds(dims)
774775
if isinstance(dims, str):
775776
dims = (dims,)
776777

@@ -780,6 +781,8 @@ def test_add_bounds(dims):
780781
assert added[dim].attrs["bounds"] == name
781782
assert_allclose(added[name].reset_coords(drop=True), expected[dim])
782783

784+
_check_unchanged(original, ds)
785+
783786

784787
def test_add_bounds_multiple():
785788
# Test multiple dimensions
@@ -810,7 +813,7 @@ def test_add_bounds_nd_variable():
810813

811814

812815
def test_bounds():
813-
ds = airds.copy(deep=True).cf.add_bounds("lat")
816+
ds = airds.copy(deep=False).cf.add_bounds("lat")
814817

815818
actual = ds.cf.bounds
816819
expected = {"Y": ["lat_bounds"], "lat": ["lat_bounds"], "latitude": ["lat_bounds"]}
@@ -858,32 +861,36 @@ def test_bounds():
858861

859862
def test_bounds_to_vertices():
860863
# All available
861-
ds = airds.cf.add_bounds(["lon", "lat"])
862-
dsc = ds.cf.bounds_to_vertices()
863-
assert "lon_vertices" in dsc
864-
assert "lat_vertices" in dsc
864+
ds = airds
865+
original = ds.copy(deep=True)
866+
dsb = ds.cf.add_bounds(["lon", "lat"])
867+
dsv = dsb.cf.bounds_to_vertices()
868+
assert "lon_vertices" in dsv
869+
assert "lat_vertices" in dsv
865870

866871
# Giving key
867-
dsc = ds.cf.bounds_to_vertices("longitude")
868-
assert "lon_vertices" in dsc
869-
assert "lat_vertices" not in dsc
872+
dsv = dsb.cf.bounds_to_vertices("longitude")
873+
assert "lon_vertices" in dsv
874+
assert "lat_vertices" not in dsv
870875

871-
dsc = ds.cf.bounds_to_vertices(["longitude", "latitude"])
872-
assert "lon_vertices" in dsc
873-
assert "lat_vertices" in dsc
876+
dsv = dsb.cf.bounds_to_vertices(["longitude", "latitude"])
877+
assert "lon_vertices" in dsv
878+
assert "lat_vertices" in dsv
874879

875880
# Error
876881
with pytest.raises(ValueError):
877-
dsc = ds.cf.bounds_to_vertices("T")
882+
dsv = dsb.cf.bounds_to_vertices("T")
878883

879884
# Words on datetime arrays to
880-
ds = airds.cf.add_bounds("time")
881-
dsc = ds.cf.bounds_to_vertices()
882-
assert "time_bounds" in dsc
885+
dsb = dsb.cf.add_bounds("time")
886+
dsv = dsb.cf.bounds_to_vertices()
887+
assert "time_bounds" in dsv
888+
889+
_check_unchanged(original, ds)
883890

884891

885892
def test_get_bounds_dim_name():
886-
ds = airds.copy(deep=True).cf.add_bounds("lat")
893+
ds = airds.cf.add_bounds("lat")
887894
assert ds.cf.get_bounds_dim_name("latitude") == "bounds"
888895
assert ds.cf.get_bounds_dim_name("lat") == "bounds"
889896

@@ -921,6 +928,26 @@ def _make_names(prefixes):
921928
]
922929

923930

931+
def _check_unchanged(old, new):
932+
# Check data array attributes or global dataset attributes
933+
assert type(old) == type(new)
934+
assert old.attrs.keys() == new.attrs.keys() # set comparison
935+
for att, old_val in old.attrs.items():
936+
assert id(old_val) == id(new.attrs[att])
937+
938+
# Check coordinate attributes and data variable attributes
939+
dicts = [(old.coords, new.coords)]
940+
if isinstance(old, xr.Dataset):
941+
dicts.append((old.data_vars, new.data_vars))
942+
for old_dict, new_dict in dicts:
943+
assert old_dict.keys() == new_dict.keys() # set comparison
944+
for key, old_obj in old_dict.items():
945+
new_obj = new_dict[key]
946+
assert old_obj.attrs.keys() == new_obj.attrs.keys() # set comparison
947+
for att, old_val in old_obj.attrs.items():
948+
assert id(old_val) == id(new_obj.attrs[att]) # numpy-safe comparison
949+
950+
924951
_TIME_NAMES = ["t"] + _make_names(
925952
[
926953
"time",
@@ -1015,7 +1042,7 @@ def test_attributes():
10151042
with pytest.raises(AttributeError):
10161043
airds.da.cf.chunks
10171044

1018-
airds2 = airds.copy(deep=True)
1045+
airds2 = airds.copy(deep=False)
10191046
airds2.lon.attrs = {}
10201047
actual = airds2.cf.sizes
10211048
expected = {"lon": 50, "Y": 25, "T": 4, "latitude": 25, "time": 4}
@@ -1046,7 +1073,7 @@ def test_attributes():
10461073
assert_identical(ds1.cf.data_vars["T"], ds1["T"])
10471074

10481075
# multiple latitudes but only one latitude data_var
1049-
ds = popds.copy(deep=True)
1076+
ds = popds.copy(deep=False)
10501077
for var in ["ULAT", "TLAT"]:
10511078
ds[var].attrs["standard_name"] = "latitude"
10521079
ds = ds.reset_coords("ULAT")
@@ -1068,7 +1095,7 @@ def test_Z_vs_vertical_ROMS():
10681095
romsds.z_rho_dummy.reset_coords(drop=True), romsds.temp.cf["vertical"]
10691096
)
10701097

1071-
romsds = romsds.copy(deep=True)
1098+
romsds = romsds.copy(deep=False)
10721099

10731100
romsds.temp.attrs.clear()
10741101
# look in encoding
@@ -1109,12 +1136,12 @@ def test_param_vcoord_ocean_s_coord():
11091136
romsds.cf.decode_vertical_coords(outnames={"s_rho": "ZZZ_rho"})
11101137
assert "ZZZ_rho" in romsds.coords
11111138

1112-
copy = romsds.copy(deep=True)
1139+
copy = romsds.copy(deep=False)
11131140
del copy["zeta"]
11141141
with pytest.raises(KeyError):
11151142
copy.cf.decode_vertical_coords(outnames={"s_rho": "z_rho"})
11161143

1117-
copy = romsds.copy(deep=True)
1144+
copy = romsds.copy(deep=False)
11181145
copy.s_rho.attrs["formula_terms"] = "s: s_rho C: Cs_r depth: h depth_c: hc"
11191146
with pytest.raises(KeyError):
11201147
copy.cf.decode_vertical_coords(outnames={"s_rho": "z_rho"})
@@ -1125,7 +1152,7 @@ def test_param_vcoord_ocean_sigma_coordinate():
11251152
pomds.cf.decode_vertical_coords(outnames={"sigma": "z"})
11261153
assert_allclose(pomds.z.reset_coords(drop=True), expected.reset_coords(drop=True))
11271154

1128-
copy = pomds.copy(deep=True)
1155+
copy = pomds.copy(deep=False)
11291156
del copy["zeta"]
11301157
with pytest.raises(AssertionError):
11311158
copy.cf.decode_vertical_coords()
@@ -1146,7 +1173,7 @@ def test_formula_terms():
11461173
assert romsds["temp"].cf.formula_terms == srhoterms
11471174
assert romsds["s_rho"].cf.formula_terms == srhoterms
11481175

1149-
s_rho = romsds["s_rho"].copy(deep=True)
1176+
s_rho = romsds["s_rho"].copy(deep=False)
11501177
del s_rho.attrs["standard_name"]
11511178
del s_rho.s_rho.attrs["standard_name"] # TODO: xarray bug
11521179
assert s_rho.cf.formula_terms == srhoterms
@@ -1535,6 +1562,7 @@ def test_datetime_like(reshape):
15351562
def test_add_canonical_attributes(override, skip, verbose, capsys):
15361563

15371564
ds = airds
1565+
original = ds.copy(deep=True)
15381566
cf_ds = ds.cf.add_canonical_attributes(
15391567
override=override, skip=skip, verbose=verbose
15401568
)
@@ -1581,6 +1609,8 @@ def test_add_canonical_attributes(override, skip, verbose, capsys):
15811609
cf_da.attrs.pop("history")
15821610
assert_identical(cf_da, cf_ds["air"])
15831611

1612+
_check_unchanged(original, ds)
1613+
15841614

15851615
@pytest.mark.parametrize("op", ["ge", "gt", "eq", "ne", "le", "lt"])
15861616
def test_flag_features(op):
@@ -1617,20 +1647,20 @@ def test_flag_errors():
16171647
def test_missing_variables():
16181648

16191649
# Bounds
1620-
ds = mollwds.copy(deep=True)
1650+
ds = mollwds.copy(deep=False)
16211651
ds = ds.drop_vars("lon_bounds")
16221652
assert ds.cf.bounds == {"lat": ["lat_bounds"], "latitude": ["lat_bounds"]}
16231653

16241654
with pytest.raises(KeyError, match=r"No results found for 'longitude'."):
16251655
ds.cf.get_bounds("longitude")
16261656

16271657
# Cell measures
1628-
ds = airds.copy(deep=True)
1658+
ds = airds.copy(deep=False)
16291659
ds = ds.drop_vars("cell_area")
16301660
assert ds.cf.cell_measures == {}
16311661

16321662
# Formula terms
1633-
ds = vert.copy(deep=True)
1663+
ds = vert.copy(deep=False)
16341664
ds = ds.drop_vars("ap")
16351665
assert ds.cf.formula_terms == {"lev": {"b": "b", "ps": "ps"}}
16361666

0 commit comments

Comments
 (0)