Skip to content

Commit 5aa2878

Browse files
committed
fix types
1 parent 196225b commit 5aa2878

File tree

1 file changed

+69
-71
lines changed

1 file changed

+69
-71
lines changed

cf_xarray/accessor.py

Lines changed: 69 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ class GridMapping:
9797
The coordinate reference system object
9898
array : xarray.DataArray
9999
The grid mapping variable as a DataArray containing the CRS parameters
100-
coordinates : tuple[str, ...]
100+
coordinates : tuple[Hashable, ...]
101101
Names of coordinate variables associated with this grid mapping. For grid mappings
102102
that are explicitly listed with coordinates in the grid_mapping attribute
103103
(e.g., ``'spatial_ref: crs_4326: latitude longitude'``), this contains those coordinates.
@@ -109,7 +109,7 @@ class GridMapping:
109109
name: str
110110
crs: Any # really pyproj.CRS
111111
array: xr.DataArray
112-
coordinates: tuple[str, ...]
112+
coordinates: tuple[Hashable, ...]
113113

114114
def __repr__(self) -> str:
115115
# Try to get EPSG code first, fallback to shorter description
@@ -496,7 +496,7 @@ def _get_bounds(obj: DataArray | Dataset, key: Hashable) -> list[Hashable]:
496496
return list(results)
497497

498498

499-
def _parse_grid_mapping_attribute(grid_mapping_attr: str) -> dict[str, list[str]]:
499+
def _parse_grid_mapping_attribute(grid_mapping_attr: str) -> dict[str, list[Hashable]]:
500500
"""
501501
Parse a grid_mapping attribute that may contain multiple grid mappings.
502502
@@ -522,7 +522,7 @@ def _parse_grid_mapping_attribute(grid_mapping_attr: str) -> dict[str, list[str]
522522
if not grid_mappings:
523523
return {grid_mapping_attr.strip(): []}
524524

525-
result = {}
525+
result: dict[str, list[Hashable]] = {}
526526

527527
# Now extract coordinates for each grid mapping
528528
# Split the string to find what comes after each grid mapping variable
@@ -545,13 +545,76 @@ def _parse_grid_mapping_attribute(grid_mapping_attr: str) -> dict[str, list[str]
545545
coords = coord_text.split() if coord_text else []
546546
# Filter out the next grid mapping variable if it got captured
547547
coords = [c for c in coords if c not in grid_mappings]
548-
result[gm] = coords
548+
result[gm] = coords # type: ignore[assignment]
549549
else:
550550
result[gm] = []
551551

552552
return result
553553

554554

555+
def _create_grid_mapping(
556+
var_name: str,
557+
obj_dataset: Dataset,
558+
grid_mapping_dict: dict[str, list[Hashable]],
559+
) -> GridMapping:
560+
"""
561+
Create a GridMapping dataclass instance from a grid mapping variable.
562+
563+
Parameters
564+
----------
565+
var_name : str
566+
Name of the grid mapping variable
567+
obj_dataset : Dataset
568+
Dataset containing the grid mapping variable
569+
grid_mapping_dict : dict[str, list[Hashable]]
570+
Dictionary mapping grid mapping variable names to their coordinate variables
571+
572+
Returns
573+
-------
574+
GridMapping
575+
GridMapping dataclass instance
576+
577+
Notes
578+
-----
579+
Assumes pyproj is available (should be checked by caller).
580+
"""
581+
from pyproj import (
582+
CRS, # Safe to import since grid_mappings property checks availability
583+
)
584+
585+
var = obj_dataset._variables[var_name]
586+
587+
# Create DataArray from Variable, preserving the name
588+
# Use reset_coords(drop=True) to avoid coordinate conflicts
589+
if var_name in obj_dataset.coords:
590+
da = obj_dataset.coords[var_name].reset_coords(drop=True)
591+
else:
592+
da = obj_dataset[var_name].reset_coords(drop=True)
593+
594+
# Get the CF grid mapping name from the variable's attributes
595+
cf_name = var.attrs.get("grid_mapping_name", var_name)
596+
597+
# Create CRS from the grid mapping variable
598+
try:
599+
crs = CRS.from_cf(var.attrs)
600+
except Exception:
601+
# If CRS creation fails, use None
602+
crs = None
603+
604+
# Get associated coordinate variables, fallback to dimension names
605+
coordinates: list[Hashable] = grid_mapping_dict.get(var_name, [])
606+
if not coordinates:
607+
# For DataArrays, find the data variable that references this grid mapping
608+
for _data_var_name, data_var in obj_dataset.data_vars.items():
609+
if "grid_mapping" in data_var.attrs:
610+
gm_attr = data_var.attrs["grid_mapping"]
611+
if var_name in gm_attr:
612+
coordinates = list(data_var.dims)
613+
break
614+
615+
return GridMapping(name=cf_name, crs=crs, array=da, coordinates=tuple(coordinates))
616+
617+
555618
def _get_grid_mapping_name(obj: DataArray | Dataset, key: str) -> list[str]:
556619
"""
557620
Translate from grid mapping name attribute to appropriate variable name.
@@ -2462,71 +2525,6 @@ def add_canonical_attributes(
24622525

24632526
return obj
24642527

2465-
def _create_grid_mapping(
2466-
self,
2467-
var_name: str,
2468-
obj_dataset: Dataset,
2469-
grid_mapping_dict: dict[str, list[str]],
2470-
) -> GridMapping:
2471-
"""
2472-
Create a GridMapping dataclass instance from a grid mapping variable.
2473-
2474-
Parameters
2475-
----------
2476-
var_name : str
2477-
Name of the grid mapping variable
2478-
obj_dataset : Dataset
2479-
Dataset containing the grid mapping variable
2480-
grid_mapping_dict : dict[str, list[str]]
2481-
Dictionary mapping grid mapping variable names to their coordinate variables
2482-
2483-
Returns
2484-
-------
2485-
GridMapping
2486-
GridMapping dataclass instance
2487-
2488-
Notes
2489-
-----
2490-
Assumes pyproj is available (should be checked by caller).
2491-
"""
2492-
from pyproj import (
2493-
CRS, # Safe to import since grid_mappings property checks availability
2494-
)
2495-
2496-
var = obj_dataset._variables[var_name]
2497-
2498-
# Create DataArray from Variable, preserving the name
2499-
# Use reset_coords(drop=True) to avoid coordinate conflicts
2500-
if var_name in obj_dataset.coords:
2501-
da = obj_dataset.coords[var_name].reset_coords(drop=True)
2502-
else:
2503-
da = obj_dataset[var_name].reset_coords(drop=True)
2504-
2505-
# Get the CF grid mapping name from the variable's attributes
2506-
cf_name = var.attrs.get("grid_mapping_name", var_name)
2507-
2508-
# Create CRS from the grid mapping variable
2509-
try:
2510-
crs = CRS.from_cf(var.attrs)
2511-
except Exception:
2512-
# If CRS creation fails, use None
2513-
crs = None
2514-
2515-
# Get associated coordinate variables, fallback to dimension names
2516-
coordinates = grid_mapping_dict.get(var_name, [])
2517-
if not coordinates:
2518-
# For DataArrays, find the data variable that references this grid mapping
2519-
for _data_var_name, data_var in obj_dataset.data_vars.items():
2520-
if "grid_mapping" in data_var.attrs:
2521-
gm_attr = data_var.attrs["grid_mapping"]
2522-
if var_name in gm_attr:
2523-
coordinates = list(data_var.dims)
2524-
break
2525-
2526-
return GridMapping(
2527-
name=cf_name, crs=crs, array=da, coordinates=tuple(coordinates)
2528-
)
2529-
25302528
@property
25312529
def grid_mappings(self) -> tuple[GridMapping, ...]:
25322530
"""
@@ -2611,7 +2609,7 @@ def grid_mappings(self) -> tuple[GridMapping, ...]:
26112609
continue
26122610

26132611
grid_mappings.append(
2614-
self._create_grid_mapping(var_name, obj_dataset, grid_mapping_dict)
2612+
_create_grid_mapping(var_name, obj_dataset, grid_mapping_dict)
26152613
)
26162614

26172615
return tuple(grid_mappings)

0 commit comments

Comments
 (0)