@@ -103,7 +103,6 @@ def test_basics():
103103 blur_x .compute_at (blur_y , x ).vectorize (x , 8 )
104104 blur_y .compile_jit ()
105105
106-
107106def test_basics2 ():
108107 input = hl .ImageParam (hl .Float (32 ), 3 , "input" )
109108 hl .Param (hl .Float (32 ), "r_sigma" , 0.1 )
@@ -344,19 +343,30 @@ def test_typed_funcs():
344343 with assert_throws (hl .HalideError , "it is undefined" ):
345344 assert f .dimensions () == 0
346345
346+ with assert_throws (hl .HalideError , "it is undefined" ):
347+ assert f [x , y ].type () == hl .Int (32 )
348+
347349 f = hl .Func (hl .Int (32 ), 2 , "f" )
348350 assert not f .defined ()
349351 assert f .type () == hl .Int (32 )
350352 assert f .types () == [hl .Int (32 )]
351353 assert f .outputs () == 1
352354 assert f .dimensions () == 2
355+ assert f [x , y ].type () == hl .Int (32 )
356+
357+ with assert_throws (hl .HalideError , "has not yet been defined" ):
358+ # While we can ask for the type of f[x, y], because it's a
359+ # typed Func, we still can't use it as an Expr
360+ g = hl .Func ("g" )
361+ g [x , y ] = f [x , y ]
353362
354363 f = hl .Func ([hl .Int (32 ), hl .Float (64 )], 3 , "f" )
355364 assert not f .defined ()
356365 with assert_throws (hl .HalideError , "it returns a Tuple" ):
357366 assert f .type () == hl .Int (32 )
358367
359368 assert f .types () == [hl .Int (32 ), hl .Float (64 )]
369+ assert f [x , y ].types () == [hl .Int (32 ), hl .Float (64 )]
360370 assert f .outputs () == 2
361371 assert f .dimensions () == 3
362372
@@ -597,6 +607,31 @@ def test_print_ir():
597607 p = hl .Pipeline ()
598608 assert str (p ) == "<halide.Pipeline Pipeline()>"
599609
610+ def test_split_vars ():
611+ f = hl .Func ("f" )
612+ (x , xo , xi ) = hl .vars ("x xo xi" )
613+ f [x ] = x
614+ r = hl .RDom ([(0 , 10 ), (0 , 10 )], "r" )
615+ f [x ] += x + r .x + r .y
616+
617+ f .split (x , xo , xi , 8 )
618+
619+ vars = f .split_vars ()
620+ assert len (vars ) == 3
621+ assert vars [0 ].name () == xi .name ()
622+ assert vars [1 ].name () == xo .name ()
623+ assert vars [2 ].name () == hl .Var .outermost ().name ()
624+
625+ (rxo , rxi ) = (hl .RVar ("rxo" ), hl .RVar ("rxi" ))
626+ f .update ().split (r .x , rxo , rxi , 4 )
627+
628+ vars = f .update ().split_vars ()
629+ assert len (vars ) == 5
630+ assert isinstance (vars [0 ], hl .RVar ) and vars [0 ].name () == rxi .name ()
631+ assert isinstance (vars [1 ], hl .RVar ) and vars [1 ].name () == rxo .name ()
632+ assert isinstance (vars [2 ], hl .RVar ) and vars [2 ].name () == r .y .name ()
633+ assert isinstance (vars [3 ], hl .Var ) and vars [3 ].name () == x .name ()
634+ assert isinstance (vars [4 ], hl .Var ) and vars [4 ].name () == hl .Var .outermost ().name ()
600635
601636if __name__ == "__main__" :
602637 test_compiletime_error ()
@@ -622,3 +657,4 @@ def test_print_ir():
622657 test_implicit_update_by_int ()
623658 test_implicit_update_by_float ()
624659 test_print_ir ()
660+ test_split_vars ()
0 commit comments