4343default_max_ulp_difference = 1
4444
4545operations = [
46- # The following dictionaries may have additional keys like
47- #
48- # size - defines the number of samples: size ** 2
49- #
50- # max_ulp_difference - the maximal allowed ULP difference between
51- # function and reference values
52- #
53- # extra_prec_multiplier - the precison multiplier for mpmath.mp
54- # that defines the precision of computing reference values:
55- # mpmath.mp.prec * extra_prec_multiplier
56- #
57- # When unspecifed, these parameters are retrieved from
58- # functional_algorithms database of support functions.
59- #
60- dict (name = "asin" , mpmath_name = "arcsin" ),
61- dict (name = "acos" , mpmath_name = "arccos" ),
62- dict (name = "atan" , mpmath_name = "arctan" ),
63- dict (name = "asinh" , mpmath_name = "arcsinh" ),
64- dict (name = "acosh" , mpmath_name = "arccosh" ),
65- dict (name = "atanh" , mpmath_name = "arctanh" ),
66- dict (name = "square" , mpmath_name = "square" ),
46+ # The following dictionaries may have additional keys like
47+ #
48+ # size - defines the number of samples: size ** 2
49+ #
50+ # max_ulp_difference - the maximal allowed ULP difference between
51+ # function and reference values
52+ #
53+ # extra_prec_multiplier - the precison multiplier for mpmath.mp
54+ # that defines the precision of computing reference values:
55+ # mpmath.mp.prec * extra_prec_multiplier
56+ #
57+ # When unspecifed, these parameters are retrieved from
58+ # functional_algorithms database of support functions.
59+ #
60+ dict (name = "asin" , mpmath_name = "arcsin" ),
61+ dict (name = "acos" , mpmath_name = "arccos" ),
62+ dict (name = "atan" , mpmath_name = "arctan" ),
63+ dict (name = "asinh" , mpmath_name = "arcsinh" ),
64+ dict (name = "acosh" , mpmath_name = "arccosh" ),
65+ dict (name = "atanh" , mpmath_name = "arctanh" ),
66+ dict (name = "square" , mpmath_name = "square" ),
67+ dict (
68+ name = "log_plus_one" ,
69+ mpmath_name = "log1p" ,
70+ namespace = "stablehlo" ,
71+ passes = "--stablehlo-complex-math-expander" ,
72+ ),
6773]
6874
6975
@@ -127,19 +133,24 @@ def main():
127133 for op in operations :
128134 opname = op ["name" ]
129135 mpmath_opname = op .get ("mpmath_name" , opname )
136+ namespace = op .get ("namespace" , "chlo" )
130137 size_re = size_im = op .get ("size" , default_size )
131-
138+ passes = op . get ( "passes" , "--chlo-legalize-to-stablehlo" )
132139 for dtype in [np .complex64 , np .complex128 , np .float32 , np .float64 ]:
133140 params = fa .utils .function_validation_parameters (opname , dtype )
134141 max_ulp_difference = op .get (
135- "max_ulp_difference" ,
136- params .get ("max_valid_ulp_count" , default_max_ulp_difference ))
142+ "max_ulp_difference" ,
143+ params .get ("max_valid_ulp_count" , default_max_ulp_difference ),
144+ )
137145
138146 nmp = fa .utils .numpy_with_mpmath (
139- extra_prec_multiplier = op .get (
140- "extra_prec_multiplier" ,
141- params .get ("extra_prec_multiplier" , default_extra_prec_multiplier )),
142- flush_subnormals = flush_subnormals ,
147+ extra_prec_multiplier = op .get (
148+ "extra_prec_multiplier" ,
149+ params .get (
150+ "extra_prec_multiplier" , default_extra_prec_multiplier
151+ ),
152+ ),
153+ flush_subnormals = flush_subnormals ,
143154 )
144155
145156 fi = np .finfo (dtype )
@@ -180,7 +191,7 @@ def main():
180191 main_func = m .make_function ("main" , "" , "" , "public" )
181192
182193 ref_samples = main_func .call ("samples" )
183- actual = main_func .composite (f"chlo .{ opname } " , ref_samples )
194+ actual = main_func .composite (f"{ namespace } .{ opname } " , ref_samples )
184195 expected = main_func .call ("expected" )
185196
186197 main_func .void_call (
@@ -202,8 +213,10 @@ def main():
202213 continue
203214
204215 f = open (fname , "w" )
205- f .write ("// RUN: stablehlo-opt --chlo-legalize-to-stablehlo %s |"
206- " stablehlo-translate --interpret\n " )
216+ f .write (
217+ f"// RUN: stablehlo-opt { passes } %s |"
218+ " stablehlo-translate --interpret\n "
219+ )
207220 f .write (
208221 "// This file is generated, see build_tools/math/README.md for more"
209222 " information.\n " )
0 commit comments