@@ -6,17 +6,9 @@ open Fmt
66let ends_with suffix s = String. is_suffix ~suffix s
77let starts_with prefix s = String. is_prefix ~prefix s
88
9- let functions_requiring_namespace =
10- String.Set. of_list
11- [ " e" ; " pi" ; " log2" ; " log10" ; " sqrt2" ; " not_a_number" ; " positive_infinity"
12- ; " negative_infinity" ; " machine_precision" ; " abs" ; " acos" ; " acosh" ; " asin"
13- ; " asinh" ; " atan" ; " atanh" ; " cbrt" ; " ceil" ; " cos" ; " cosh" ; " erf" ; " erfc"
14- ; " exp" ; " exp2" ; " expm1" ; " fabs" ; " floor" ; " lgamma" ; " log" ; " log1p" ; " log2"
15- ; " log10" ; " round" ; " sin" ; " sinh" ; " sqrt" ; " tan" ; " tanh" ; " tgamma" ; " trunc"
16- ; " fdim" ; " fmax" ; " fmin" ; " hypot" ; " fma" ; " complex" ]
17-
189let stan_namespace_qualify f =
19- if Set. mem functions_requiring_namespace f then " stan::math::" ^ f else f
10+ if String. is_suffix ~suffix: " functor__" f || String. contains f ':' then f
11+ else " stan::math::" ^ f
2012
2113(* return true if the types of the two expression are the same *)
2214let types_match e1 e2 =
@@ -141,7 +133,7 @@ let fn_renames =
141133 ~f: (fun (k , v ) -> (Internal_fun. to_string k, v))
142134 [ (Internal_fun. FnLength , " stan::math::size" )
143135 ; (FnNegInf , " stan::math::negative_infinity" )
144- ; (FnResizeToMatch , " resize_to_match" )
136+ ; (FnResizeToMatch , " stan::math:: resize_to_match" )
145137 ; (FnNaN , " std::numeric_limits<double>::quiet_NaN" ) ]
146138 |> String.Map. of_alist_exn
147139
@@ -191,107 +183,84 @@ let transform_args = function
191183 | transform -> Transformation. fold (fun args arg -> args @ [arg]) [] transform
192184
193185let rec pp_index ppf = function
194- | Index. All -> pf ppf " index_omni()"
195- | Single e -> pf ppf " index_uni(%a)" pp_expr e
196- | Upfrom e -> pf ppf " index_min(%a)" pp_expr e
186+ | Index. All -> pf ppf " stan::model:: index_omni()"
187+ | Single e -> pf ppf " stan::model:: index_uni(%a)" pp_expr e
188+ | Upfrom e -> pf ppf " stan::model:: index_min(%a)" pp_expr e
197189 | Between (e_low , e_high ) ->
198- pf ppf " index_min_max(%a, %a)" pp_expr e_low pp_expr e_high
199- | MultiIndex e -> pf ppf " index_multi(%a)" pp_expr e
190+ pf ppf " stan::model:: index_min_max(%a, %a)" pp_expr e_low pp_expr e_high
191+ | MultiIndex e -> pf ppf " stan::model:: index_multi(%a)" pp_expr e
200192
201193and pp_indexes ppf = function
202194 | [] -> pf ppf " "
203195 | idxs -> pf ppf " @[<hov 2>%a@]" (list ~sep: comma pp_index) idxs
204196
205197and pp_logical_op ppf op lhs rhs =
206- pf ppf " (primitive_value(@,%a)@ %s@ primitive_value(@,%a))" pp_expr lhs op
207- pp_expr rhs
198+ pf ppf
199+ " (stan::math::primitive_value(@,%a)@ %s@ stan::math::primitive_value(@,%a))"
200+ pp_expr lhs op pp_expr rhs
208201
209202and pp_unary ppf fm es = pf ppf fm pp_expr (List. hd_exn es)
210- and pp_binary ppf fm es = pf ppf fm pp_expr (first es) pp_expr (second es)
203+
204+ and pp_binary_op ppf op es =
205+ pf ppf " (%a@ %s@ %a)" pp_expr (first es) op pp_expr (second es)
211206
212207and pp_binary_f ppf f es =
213208 pf ppf " %s(@,%a,@ %a)" f pp_expr (first es) pp_expr (second es)
214209
215210and first es = List. nth_exn es 0
216211and second es = List. nth_exn es 1
217212
218- and pp_scalar_binary ppf scalar_fmt generic_fmt es =
219- pp_binary ppf
220- ( if is_scalar (first es) && is_scalar (second es) then scalar_fmt
221- else generic_fmt )
222- es
213+ and pp_scalar_binary ppf op fn es =
214+ if is_scalar (first es) && is_scalar (second es) then pp_binary_op ppf op es
215+ else pp_binary_f ppf fn es
223216
224- and gen_operator_app = function
225- | Operator. Plus ->
226- fun ppf es -> pp_scalar_binary ppf " (%a@ +@ %a) " " add(@,%a,@ %a) " es
217+ and gen_operator_app op ppf es =
218+ match op with
219+ | Operator. Plus -> pp_scalar_binary ppf " + " " stan::math:: add" es
227220 | PMinus ->
228- fun ppf es ->
229- pp_unary ppf
230- (if is_scalar (List. hd_exn es) then " -%a" else " minus(@,%a)" )
231- es
232- | PPlus -> fun ppf es -> pp_unary ppf " %a" es
221+ pp_unary ppf
222+ (if is_scalar (List. hd_exn es) then " -%a" else " stan::math::minus(@,%a)" )
223+ es
224+ | PPlus -> pp_unary ppf " %a" es
233225 | Transpose ->
234- fun ppf es ->
235- pp_unary ppf
236- (if is_scalar (List. hd_exn es) then " %a" else " transpose(@,%a)" )
237- es
238- | PNot -> fun ppf es -> pp_unary ppf " logical_negation(@,%a)" es
239- | Minus ->
240- fun ppf es -> pp_scalar_binary ppf " (%a@ -@ %a)" " subtract(@,%a,@ %a)" es
241- | Times ->
242- fun ppf es -> pp_scalar_binary ppf " (%a@ *@ %a)" " multiply(@,%a,@ %a)" es
226+ pp_unary ppf
227+ ( if is_scalar (List. hd_exn es) then " %a"
228+ else " stan::math::transpose(@,%a)" )
229+ es
230+ | PNot -> pp_unary ppf " stan::math::logical_negation(@,%a)" es
231+ | Minus -> pp_scalar_binary ppf " -" " stan::math::subtract" es
232+ | Times -> pp_scalar_binary ppf " *" " stan::math::multiply" es
243233 | Divide | IntDivide ->
244- fun ppf es ->
245- if
246- is_matrix (second es)
247- && (is_matrix (first es) || is_row_vector (first es))
248- then pp_binary_f ppf " mdivide_right" es
249- else pp_scalar_binary ppf " (%a@ /@ %a)" " divide(@,%a,@ %a)" es
250- | Modulo -> fun ppf es -> pp_binary_f ppf " modulus" es
251- | LDivide -> fun ppf es -> pp_binary_f ppf " mdivide_left" es
234+ if
235+ is_matrix (second es)
236+ && (is_matrix (first es) || is_row_vector (first es))
237+ then pp_binary_f ppf " stan::math::mdivide_right" es
238+ else pp_scalar_binary ppf " /" " stan::math::divide" es
239+ | Modulo -> pp_binary_f ppf " stan::math::modulus" es
240+ | LDivide -> pp_binary_f ppf " stan::math::mdivide_left" es
252241 | And | Or ->
253242 Common.FatalError. fatal_error_msg
254243 [% message " And/Or should have been converted to an expression" ]
255- | EltTimes ->
256- fun ppf es ->
257- pp_scalar_binary ppf " (%a@ *@ %a)" " elt_multiply(@,%a,@ %a)" es
258- | EltDivide ->
259- fun ppf es ->
260- pp_scalar_binary ppf " (%a@ /@ %a)" " elt_divide(@,%a,@ %a)" es
261- | Pow -> fun ppf es -> pp_binary_f ppf " pow" es
262- | EltPow -> fun ppf es -> pp_binary_f ppf " pow" es
263- | Equals -> fun ppf es -> pp_binary_f ppf " logical_eq" es
264- | NEquals -> fun ppf es -> pp_binary_f ppf " logical_neq" es
265- | Less -> fun ppf es -> pp_binary_f ppf " logical_lt" es
266- | Leq -> fun ppf es -> pp_binary_f ppf " logical_lte" es
267- | Greater -> fun ppf es -> pp_binary_f ppf " logical_gt" es
268- | Geq -> fun ppf es -> pp_binary_f ppf " logical_gte" es
244+ | EltTimes -> pp_scalar_binary ppf " *" " stan::math::elt_multiply" es
245+ | EltDivide -> pp_scalar_binary ppf " /" " stan::math::elt_divide" es
246+ | Pow -> pp_binary_f ppf " stan::math::pow" es
247+ | EltPow -> pp_binary_f ppf " stan::math::pow" es
248+ | Equals -> pp_binary_f ppf " stan::math::logical_eq" es
249+ | NEquals -> pp_binary_f ppf " stan::math::logical_neq" es
250+ | Less -> pp_binary_f ppf " stan::math::logical_lt" es
251+ | Leq -> pp_binary_f ppf " stan::math::logical_lte" es
252+ | Greater -> pp_binary_f ppf " stan::math::logical_gt" es
253+ | Geq -> pp_binary_f ppf " stan::math::logical_gte" es
269254
270255and gen_misc_special_math_app f =
271256 match f with
272257 | "lmultiply" ->
273- Some (fun ppf es -> pp_binary ppf " multiply_log(@,%a,@ %a) " es)
258+ Some (fun ppf es -> pp_binary_f ppf " stan::math:: multiply_log" es)
274259 | "lchoose" ->
275- Some (fun ppf es -> pp_binary ppf " binomial_coefficient_log(@,%a,@ %a)" es)
276- | "target" -> Some (fun ppf _ -> pf ppf " get_lp(lp__, lp_accum__)" )
277- | "get_lp" -> Some (fun ppf _ -> pf ppf " get_lp(lp__, lp_accum__)" )
278- | "max" | "min" ->
279- Some
280- (fun ppf es ->
281- let f = match es with [_; _] -> " std::" ^ f | _ -> f in
282- pp_call ppf (f, pp_expr, es) )
283- | "ceil" ->
284- let std_prefix_data_scalar f = function
285- | [ Expr.
286- { Fixed. meta=
287- Typed.Meta. {adlevel= DataOnly ; type_= UInt | UReal ; _}
288- ; _ } ] ->
289- " std::" ^ f
290- | _ -> f in
291260 Some
292- (fun ppf es ->
293- let f = std_prefix_data_scalar f es in
294- pp_call ppf (f, pp_expr, es) )
261+ (fun ppf es -> pp_binary_f ppf " stan::math::binomial_coefficient_log " es)
262+ | "target" -> Some ( fun ppf _ -> pf ppf " stan::math::get_lp(lp__, lp_accum__) " )
263+ | "get_lp" -> Some ( fun ppf _ -> pf ppf " stan::math::get_lp(lp__, lp_accum__) " )
295264 | f when Map. mem fn_renames f ->
296265 Some (fun ppf es -> pp_call ppf (Map. find_exn fn_renames f, pp_expr, es))
297266 | _ -> None
@@ -498,7 +467,8 @@ and pp_promoted ad ut ppf e =
498467 (local_scalar ut ad) pp_expr e )
499468
500469and pp_indexed ppf (vident , indices , pretty ) =
501- pf ppf " @[<hov 2>rvalue(@,%s,@ %S,@ %a)@]" vident pretty pp_indexes indices
470+ pf ppf " @[<hov 2>stan::model::rvalue(@,%s,@ %S,@ %a)@]" vident pretty
471+ pp_indexes indices
502472
503473and pp_indexed_simple ppf (obj , idcs ) =
504474 let idx_minus_one = function
@@ -525,7 +495,7 @@ and pp_expr ppf Expr.Fixed.({pattern; meta} as e) =
525495 match pattern with
526496 | Var s -> pf ppf " %s" s
527497 | Lit (Str, s ) -> pf ppf " \" %s\" " (Cpp_str. escaped s)
528- | Lit (Imaginary, s ) -> pf ppf " to_complex(0, %s)" s
498+ | Lit (Imaginary, s ) -> pf ppf " stan::math:: to_complex(0, %s)" s
529499 | Lit ((Real | Int ), s ) -> pf ppf " %s" s
530500 | FunApp
531501 ( StanLib (op, _, _)
@@ -541,7 +511,6 @@ and pp_expr ppf Expr.Fixed.({pattern; meta} as e) =
541511 gen_fun_app suffix ppf f es mem_pattern
542512 | FunApp (CompilerInternal f , es ) ->
543513 pp_compiler_internal_fn meta.adlevel meta.type_ f ppf es
544- (* stan_namespace_qualify? *)
545514 | FunApp (UserDefined (f , suffix ), es ) ->
546515 pp_user_defined_fun ppf (f, suffix, es)
547516 | EAnd (e1 , e2 ) -> pp_logical_op ppf " &&" e1 e2
@@ -629,7 +598,7 @@ let%expect_test "pp_expr9" =
629598
630599let % expect_test " pp_expr10" =
631600 printf " %s" (pp_unlocated (Indexed (dummy_locate (Var " a" ), [All ]))) ;
632- [% expect {| rvalue(a, " a" , index_omni() ) | }]
601+ [% expect {| stan ::model:: rvalue(a , "a" , stan ::model:: index_omni()) |}]
633602
634603let % expect_test " pp_expr11" =
635604 printf " %s"
0 commit comments