1010from itertools import product
1111from functools import lru_cache
1212from enum import Enum
13+ from .mag import Mag
1314
1415from .utils import (
1516 add_jobs_flag ,
@@ -113,53 +114,23 @@ def cube_addresses(source_wkw_info):
113114 return wkw_addresses
114115
115116
116- def mag_as_vector (mag ):
117- if isinstance (mag , int ):
118- return (mag , mag , mag )
119- else :
120- return mag
121-
122-
123- def is_mag_greater_than (mag1 , mag2 ):
124- for (m1 , m2 ) in zip (mag1 , mag2 ):
125- if m1 > m2 :
126- return True
127- return False
128-
129-
130- def assert_valid_mag (mag ):
131- if isinstance (mag , int ):
132- assert log2 (mag ) % 1 == 0 , "magnification needs to be power of 2."
133- else :
134- assert len (mag ) == 3 , "magnification must be int or a vector3 of ints"
135- for mag_dim in mag :
136- assert log2 (mag_dim ) % 1 == 0 , "magnification needs to be power of 2."
137-
138-
139- def mag_as_layer_name (mag ):
140- if isinstance (mag , int ):
141- return str (mag )
142- else :
143- x , y , z = mag
144- return "{}-{}-{}" .format (x , y , z )
145-
146-
147117def downsample (
148118 source_wkw_info ,
149119 target_wkw_info ,
150- _source_mag ,
151- _target_mag ,
120+ source_mag : Mag ,
121+ target_mag : Mag ,
152122 interpolation_mode ,
153123 cube_edge_len ,
154124 jobs ,
155125 compress ,
156126):
157- source_mag = mag_as_vector (_source_mag )
158- target_mag = mag_as_vector (_target_mag )
159- assert not is_mag_greater_than (source_mag , target_mag )
160- logging .info ("Downsampling mag {} from mag {}" .format (_target_mag , _source_mag ))
161127
162- mag_factors = [int (t / s ) for (t , s ) in zip (target_mag , source_mag )]
128+ assert source_mag < target_mag
129+ logging .info ("Downsampling mag {} from mag {}" .format (target_mag , source_mag ))
130+
131+ mag_factors = [
132+ t // s for (t , s ) in zip (target_mag .to_array (), source_mag .to_array ())
133+ ]
163134 # Detect the cubes that we want to downsample
164135 source_cube_addresses = cube_addresses (source_wkw_info )
165136
@@ -385,8 +356,8 @@ def downsample_cube(cube_buffer, factors, interpolation_mode):
385356def downsample_mag (
386357 path ,
387358 layer_name ,
388- source_mag ,
389- target_mag ,
359+ source_mag : Mag ,
360+ target_mag : Mag ,
390361 dtype = "uint8" ,
391362 interpolation_mode = "default" ,
392363 cube_edge_len = DEFAULT_EDGE_LEN ,
@@ -403,10 +374,10 @@ def downsample_mag(
403374 interpolation_mode = InterpolationModes [interpolation_mode .upper ()]
404375
405376 source_wkw_info = WkwDatasetInfo (
406- path , layer_name , dtype , mag_as_layer_name ( source_mag )
377+ path , layer_name , dtype , source_mag . to_layer_name ( )
407378 )
408379 target_wkw_info = WkwDatasetInfo (
409- path , layer_name , dtype , mag_as_layer_name ( target_mag )
380+ path , layer_name , dtype , target_mag . to_layer_name ( )
410381 )
411382 downsample (
412383 source_wkw_info ,
@@ -423,18 +394,17 @@ def downsample_mag(
423394def downsample_mags (
424395 path ,
425396 layer_name ,
426- from_mag ,
427- max_mag ,
397+ from_mag : Mag ,
398+ max_mag : Mag ,
428399 dtype ,
429400 interpolation_mode ,
430401 cube_edge_len ,
431402 jobs ,
432403 compress ,
433404):
434- assert log2 (from_mag ) % 1 == 0 , "'from_mag' needs to be power of 2."
435- target_mag = from_mag * 2
405+ target_mag = from_mag .scaled_by (2 )
436406 while target_mag <= max_mag :
437- source_mag = target_mag // 2
407+ source_mag = target_mag . divided_by ( 2 )
438408 downsample_mag (
439409 path ,
440410 layer_name ,
@@ -446,7 +416,7 @@ def downsample_mags(
446416 jobs ,
447417 compress ,
448418 )
449- target_mag = target_mag * 2
419+ target_mag . scale_by ( 2 )
450420
451421
452422if __name__ == "__main__" :
@@ -455,15 +425,15 @@ def downsample_mags(
455425 if args .verbose :
456426 logging .basicConfig (level = logging .DEBUG )
457427
428+ from_mag = Mag (args .from_mag )
429+ max_mag = Mag (args .max )
458430 if args .anisotropic_target_mag :
459- anisotropic_target_mag = tuple (map (int , args .anisotropic_target_mag .split ("-" )))
460- assert_valid_mag (args .from_mag )
461- assert_valid_mag (anisotropic_target_mag )
431+ anisotropic_target_mag = Mag (args .anisotropic_target_mag )
462432
463433 downsample_mag (
464434 args .path ,
465435 args .layer_name ,
466- args . from_mag ,
436+ from_mag ,
467437 anisotropic_target_mag ,
468438 args .dtype ,
469439 args .interpolation_mode ,
@@ -475,8 +445,8 @@ def downsample_mags(
475445 downsample_mags (
476446 args .path ,
477447 args .layer_name ,
478- args . from_mag ,
479- args . max ,
448+ from_mag ,
449+ max_mag ,
480450 args .dtype ,
481451 args .interpolation_mode ,
482452 args .buffer_cube_size ,
0 commit comments