Skip to content

Commit 42fda15

Browse files
committed
Merge fix-dateline-clean: handle antimeridian-crossing bounding boxes (fix #18)
2 parents 82b8112 + eb7db69 commit 42fda15

File tree

4 files changed

+366
-29
lines changed

4 files changed

+366
-29
lines changed

sardem/cop_dem.py

Lines changed: 106 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,19 +32,116 @@ def download_and_stitch(
3232
https://spacedata.copernicus.eu/web/cscda/dataset-details?articleId=394198
3333
https://copernicus-dem-30m.s3.amazonaws.com/readme.html
3434
"""
35+
import os
36+
import tempfile
37+
3538
from osgeo import gdal
3639

3740
gdal.UseExceptions()
38-
# TODO: does downloading make it run any faster?
39-
# if download_vrt:
40-
# cache_dir = utils.get_cache_dir()
41-
# vrt_filename = os.path.join(cache_dir, "cop_global.vrt")
42-
# if not os.path.exists(vrt_filename):
43-
# make_cop_vrt(vrt_filename)
44-
# else:
41+
4542
if vrt_filename is None:
4643
vrt_filename = "/vsicurl/https://raw.githubusercontent.com/scottstanie/sardem/master/sardem/data/cop_global.vrt" # noqa
4744

45+
bboxes = utils.check_dateline(bbox)
46+
47+
if len(bboxes) == 1:
48+
_download_single_bbox(
49+
output_name,
50+
bbox,
51+
vrt_filename,
52+
keep_egm,
53+
xrate,
54+
yrate,
55+
output_format,
56+
output_type,
57+
)
58+
return
59+
60+
# Dateline crossing: download each sub-bbox, shift tiles, merge
61+
logger.info(
62+
"Dateline crossing detected, downloading {} separate regions".format(
63+
len(bboxes)
64+
)
65+
)
66+
67+
temp_files = []
68+
with tempfile.TemporaryDirectory() as tmpdir:
69+
for idx, sub_bbox in enumerate(bboxes):
70+
temp_file = os.path.join(tmpdir, "dem_part_{}.tif".format(idx))
71+
temp_files.append(temp_file)
72+
logger.info("Downloading region {} of {}".format(idx + 1, len(bboxes)))
73+
_download_single_bbox(
74+
temp_file,
75+
sub_bbox,
76+
vrt_filename,
77+
keep_egm,
78+
xrate,
79+
yrate,
80+
"GTiff",
81+
output_type,
82+
)
83+
84+
# Shift eastern tiles so they're adjacent to western tiles in pixel space
85+
for temp_file in temp_files:
86+
_shift_tile_if_needed(temp_file)
87+
88+
logger.info("Merging {} regions into final DEM".format(len(temp_files)))
89+
vrt_temp = os.path.join(tmpdir, "merged.vrt")
90+
gdal.BuildVRT(vrt_temp, temp_files)
91+
92+
if output_format == "GTiff":
93+
gdal.Warp(
94+
output_name,
95+
vrt_temp,
96+
options=gdal.WarpOptions(
97+
format=output_format,
98+
multithread=True,
99+
callback=gdal.TermProgress,
100+
),
101+
)
102+
else:
103+
gdal.Translate(
104+
output_name,
105+
vrt_temp,
106+
format=output_format,
107+
callback=gdal.TermProgress,
108+
)
109+
110+
111+
def _shift_tile_if_needed(filepath):
112+
"""Shift tile geotransform by -360 if x origin is positive.
113+
114+
Makes eastern tiles (e.g., 179.7 to 180) adjacent to western tiles
115+
(e.g., -180 to -179.8) by shifting to (-180.3 to -180).
116+
Only useful for tiles that are part of a dateline-crossing split.
117+
"""
118+
from osgeo import gdal
119+
120+
ds = gdal.Open(filepath, gdal.GA_Update)
121+
gt = list(ds.GetGeoTransform())
122+
if gt[0] > 0:
123+
logger.info("Shifting tile {} x origin from {} to {}".format(
124+
filepath, gt[0], gt[0] - 360.0
125+
))
126+
gt[0] -= 360.0
127+
ds.SetGeoTransform(gt)
128+
ds.FlushCache()
129+
ds = None
130+
131+
132+
def _download_single_bbox(
133+
output_name,
134+
bbox,
135+
vrt_filename,
136+
keep_egm,
137+
xrate,
138+
yrate,
139+
output_format,
140+
output_type,
141+
):
142+
"""Download a single bbox from the COP DEM."""
143+
from osgeo import gdal
144+
48145
if keep_egm:
49146
t_srs = s_srs = None
50147
else:
@@ -55,7 +152,6 @@ def download_and_stitch(
55152
yres = DEFAULT_RES / yrate
56153
resamp = "bilinear" if (xrate > 1 or yrate > 1) else "nearest"
57154

58-
# access_mode = "overwrite" if overwrite else None
59155
option_dict = dict(
60156
format=output_format,
61157
outputBounds=utils.align_bounds_to_pixel_grid(bbox),
@@ -69,29 +165,24 @@ def download_and_stitch(
69165
warpMemoryLimit=5000,
70166
warpOptions=["NUM_THREADS=4"],
71167
)
72-
# When converting from geoid to ellipsoid, preserve ocean (value=0) as nodata
73-
# COP DEM has ocean areas as 0 (sea level relative to geoid), which would
74-
# otherwise become ~geoid_undulation after the vertical datum conversion
168+
# Preserve ocean (value=0) as nodata during geoid-to-ellipsoid conversion
75169
if not keep_egm:
76170
option_dict["srcNodata"] = 0
77171
option_dict["dstNodata"] = 0
78172

79-
# Used the __RETURN_OPTION_LIST__ to get the list of options for debugging
80173
logger.info("Creating {}".format(output_name))
81174
logger.info("Fetching remote tiles...")
82175
try:
83176
cmd = _gdal_cmd_from_options(vrt_filename, output_name, option_dict)
84177
logger.info("Running GDAL command:")
85178
logger.info(cmd)
86179
except Exception:
87-
# Can't form the cli version due to `deepcopy` Pickle error, just skip
88180
logger.info("Running gdal.Warp with options:")
89181
logger.info(option_dict)
90182
pass
91-
# Now convert to something GDAL can actually use
183+
92184
option_dict["callback"] = gdal.TermProgress
93185
gdal.Warp(output_name, vrt_filename, options=gdal.WarpOptions(**option_dict))
94-
return
95186

96187

97188
def _gdal_cmd_from_options(src, dst, option_dict):

sardem/dem.py

Lines changed: 114 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -414,24 +414,124 @@ def main(
414414
output_format,
415415
)
416416

417-
tile_names = list(Tile(*bbox).srtm1_tile_names())
417+
# Check for dateline crossing
418+
bboxes = utils.check_dateline(bbox)
418419

419-
d = Downloader(tile_names, data_source=data_source, cache_dir=cache_dir)
420-
local_filenames = d.download_all()
420+
if len(bboxes) == 1:
421+
# No dateline crossing, proceed normally
422+
tile_names = list(Tile(*bbox).srtm1_tile_names())
421423

422-
s = Stitcher(tile_names, filenames=local_filenames, data_source=data_source)
423-
stitched_dem = s.load_and_stitch()
424+
d = Downloader(tile_names, data_source=data_source, cache_dir=cache_dir)
425+
local_filenames = d.download_all()
424426

425-
# Now create corresponding rsc file for all the tiles
426-
rsc_dict_tiles = s.create_dem_rsc()
427+
s = Stitcher(tile_names, filenames=local_filenames, data_source=data_source)
428+
stitched_dem = s.load_and_stitch()
427429

428-
logger.info("Cropping stitched DEM to boundaries")
429-
stitched_dem = upsample.resample(stitched_dem, rsc_dict_tiles, bbox)
430-
rsc_dict = rsc_dict_tiles.copy()
431-
rsc_dict["X_FIRST"] = bbox[0]
432-
rsc_dict["Y_FIRST"] = bbox[3]
433-
rsc_dict["FILE_LENGTH"] = stitched_dem.shape[0]
434-
rsc_dict["WIDTH"] = stitched_dem.shape[1]
430+
rsc_dict_tiles = s.create_dem_rsc()
431+
432+
logger.info("Cropping stitched DEM to boundaries")
433+
stitched_dem = upsample.resample(stitched_dem, rsc_dict_tiles, bbox)
434+
rsc_dict = rsc_dict_tiles.copy()
435+
rsc_dict["X_FIRST"] = bbox[0]
436+
rsc_dict["Y_FIRST"] = bbox[3]
437+
rsc_dict["FILE_LENGTH"] = stitched_dem.shape[0]
438+
rsc_dict["WIDTH"] = stitched_dem.shape[1]
439+
else:
440+
# Dateline crossing: download each sub-bbox, merge via GDAL VRT
441+
import tempfile
442+
443+
logger.info(
444+
"Dateline crossing detected, downloading {} separate regions".format(
445+
len(bboxes)
446+
)
447+
)
448+
utils._gdal_installed_correctly()
449+
from osgeo import gdal, osr
450+
451+
temp_files = []
452+
with tempfile.TemporaryDirectory() as tmpdir:
453+
for idx, sub_bbox in enumerate(bboxes):
454+
logger.info(
455+
"Processing region {} of {}".format(idx + 1, len(bboxes))
456+
)
457+
458+
tile_names = list(Tile(*sub_bbox).srtm1_tile_names())
459+
d = Downloader(
460+
tile_names, data_source=data_source, cache_dir=cache_dir
461+
)
462+
local_filenames = d.download_all()
463+
464+
s = Stitcher(
465+
tile_names, filenames=local_filenames, data_source=data_source
466+
)
467+
dem_part = s.load_and_stitch()
468+
469+
rsc_dict_part = s.create_dem_rsc()
470+
dem_part = upsample.resample(dem_part, rsc_dict_part, sub_bbox)
471+
472+
# Write as GeoTIFF for VRT merging
473+
temp_tif = os.path.join(tmpdir, "dem_part_{}.tif".format(idx))
474+
temp_files.append(temp_tif)
475+
476+
x_first = sub_bbox[0]
477+
y_first = sub_bbox[3]
478+
x_step = (sub_bbox[2] - sub_bbox[0]) / dem_part.shape[1]
479+
y_step = -(sub_bbox[3] - sub_bbox[1]) / dem_part.shape[0]
480+
481+
# Shift western tile to 0-360 range so GDAL sees tiles as adjacent
482+
if x_first < 0:
483+
x_first += 360.0
484+
485+
driver = gdal.GetDriverByName("GTiff")
486+
ds = driver.Create(
487+
temp_tif,
488+
dem_part.shape[1],
489+
dem_part.shape[0],
490+
1,
491+
gdal.GDT_Int16
492+
if data_source != "NASA_WATER"
493+
else gdal.GDT_Byte,
494+
)
495+
ds.SetGeoTransform([x_first, x_step, 0, y_first, 0, y_step])
496+
497+
srs = osr.SpatialReference()
498+
srs.ImportFromEPSG(4326)
499+
ds.SetProjection(srs.ExportToWkt())
500+
501+
ds.GetRasterBand(1).WriteArray(dem_part)
502+
ds.FlushCache()
503+
ds = None
504+
505+
# Merge sub-regions via VRT
506+
logger.info("Merging {} regions into VRT".format(len(temp_files)))
507+
vrt_temp = os.path.join(tmpdir, "merged.vrt")
508+
gdal.BuildVRT(vrt_temp, temp_files)
509+
510+
ds = gdal.Open(vrt_temp)
511+
stitched_dem = ds.GetRasterBand(1).ReadAsArray()
512+
geotransform = ds.GetGeoTransform()
513+
ds = None
514+
515+
rsc_dict = collections.OrderedDict.fromkeys(RSC_KEYS)
516+
rsc_dict.update(
517+
{
518+
"X_UNIT": "degrees",
519+
"Y_UNIT": "degrees",
520+
"Z_OFFSET": 0,
521+
"Z_SCALE": 1,
522+
"PROJECTION": "LL",
523+
}
524+
)
525+
x_first = geotransform[0]
526+
# Wrap back to -180..180 if needed
527+
if x_first > 180:
528+
x_first -= 360
529+
rsc_dict["X_FIRST"] = x_first
530+
rsc_dict["Y_FIRST"] = geotransform[3]
531+
rsc_dict["X_STEP"] = geotransform[1]
532+
rsc_dict["Y_STEP"] = geotransform[5]
533+
rsc_dict["FILE_LENGTH"] = stitched_dem.shape[0]
534+
rsc_dict["WIDTH"] = stitched_dem.shape[1]
435535

436536
rsc_filename = output_name + ".rsc"
437537

sardem/tests/test_utils.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,66 @@ def test_shift_integer_bbox():
99
hp = DEFAULT_RES / 2
1010
expected = [-156.0 - hp, 19.0 + hp, -155.0 - hp, 20.0 + hp]
1111
assert utils.shift_integer_bbox(bbox) == pytest.approx(expected)
12+
13+
14+
class TestCheckDateline:
15+
"""Tests for the check_dateline function."""
16+
17+
def test_no_dateline_crossing(self):
18+
bbox = (-156.0, 19.0, -155.0, 20.0)
19+
result = utils.check_dateline(bbox)
20+
assert len(result) == 1
21+
assert result[0] == pytest.approx(bbox)
22+
23+
def test_no_dateline_crossing_positive_lon(self):
24+
bbox = (100.0, -10.0, 110.0, 10.0)
25+
result = utils.check_dateline(bbox)
26+
assert len(result) == 1
27+
assert result[0] == pytest.approx(bbox)
28+
29+
def test_dateline_crossing_standard(self):
30+
# 170E to 170W = from 170 to -170 (crossing 180)
31+
bbox = (170.0, -10.0, -170.0, 10.0)
32+
result = utils.check_dateline(bbox)
33+
assert len(result) == 2
34+
35+
result = sorted(result, key=lambda x: x[0])
36+
37+
# Western part: -180 to -170
38+
assert result[0][0] == pytest.approx(-180.0)
39+
assert result[0][2] == pytest.approx(-170.0)
40+
assert result[0][1] == pytest.approx(-10.0)
41+
assert result[0][3] == pytest.approx(10.0)
42+
43+
# Eastern part: 170 to 180
44+
assert result[1][0] == pytest.approx(170.0)
45+
assert result[1][2] == pytest.approx(180.0)
46+
assert result[1][1] == pytest.approx(-10.0)
47+
assert result[1][3] == pytest.approx(10.0)
48+
49+
def test_dateline_crossing_wide_span(self):
50+
# 160E to 160W = 40 degrees, crossing dateline
51+
bbox = (160.0, -5.0, -160.0, 5.0)
52+
result = utils.check_dateline(bbox)
53+
assert len(result) == 2
54+
55+
result = sorted(result, key=lambda x: x[0])
56+
57+
assert result[0][0] == pytest.approx(-180.0)
58+
assert result[0][2] == pytest.approx(-160.0)
59+
60+
assert result[1][0] == pytest.approx(160.0)
61+
assert result[1][2] == pytest.approx(180.0)
62+
63+
def test_bbox_at_dateline_not_crossing(self):
64+
# 175E to 180 (right at dateline but not crossing)
65+
bbox = (175.0, -10.0, 180.0, 10.0)
66+
result = utils.check_dateline(bbox)
67+
assert len(result) == 1
68+
assert result[0] == pytest.approx(bbox)
69+
70+
def test_bbox_west_of_dateline_not_crossing(self):
71+
bbox = (-180.0, -10.0, -170.0, 10.0)
72+
result = utils.check_dateline(bbox)
73+
assert len(result) == 1
74+
assert result[0] == pytest.approx(bbox)

0 commit comments

Comments
 (0)