@@ -6,17 +6,9 @@ open Fmt
66let ends_with suffix s = String. is_suffix ~suffix s
77let starts_with prefix s = String. is_prefix ~prefix s
88
9- let functions_requiring_namespace =
10- String.Set. of_list
11- [ " e" ; " pi" ; " log2" ; " log10" ; " sqrt2" ; " not_a_number" ; " positive_infinity"
12- ; " negative_infinity" ; " machine_precision" ; " abs" ; " acos" ; " acosh" ; " asin"
13- ; " asinh" ; " atan" ; " atanh" ; " cbrt" ; " ceil" ; " cos" ; " cosh" ; " erf" ; " erfc"
14- ; " exp" ; " exp2" ; " expm1" ; " fabs" ; " floor" ; " lgamma" ; " log" ; " log1p" ; " log2"
15- ; " log10" ; " round" ; " sin" ; " sinh" ; " sqrt" ; " tan" ; " tanh" ; " tgamma" ; " trunc"
16- ; " fdim" ; " fmax" ; " fmin" ; " hypot" ; " fma" ; " complex" ]
17-
189let stan_namespace_qualify f =
19- if Set. mem functions_requiring_namespace f then " stan::math::" ^ f else f
10+ if String. is_suffix ~suffix: " functor__" f || String. contains f ':' then f
11+ else " stan::math::" ^ f
2012
2113(* return true if the types of the two expression are the same *)
2214let types_match e1 e2 =
@@ -157,7 +149,7 @@ let fn_renames =
157149 ~f: (fun (k , v ) -> (Internal_fun. to_string k, v))
158150 [ (Internal_fun. FnLength , " stan::math::size" )
159151 ; (FnNegInf , " stan::math::negative_infinity" )
160- ; (FnResizeToMatch , " resize_to_match" )
152+ ; (FnResizeToMatch , " stan::math:: resize_to_match" )
161153 ; (FnNaN , " std::numeric_limits<double>::quiet_NaN" ) ]
162154 |> String.Map. of_alist_exn
163155
@@ -207,110 +199,86 @@ let transform_args = function
207199 | transform -> Transformation. fold (fun args arg -> args @ [arg]) [] transform
208200
209201let rec pp_index ppf = function
210- | Index. All -> pf ppf " index_omni()"
211- | Single e -> pf ppf " index_uni(%a)" pp_expr e
212- | Upfrom e -> pf ppf " index_min(%a)" pp_expr e
202+ | Index. All -> pf ppf " stan::model:: index_omni()"
203+ | Single e -> pf ppf " stan::model:: index_uni(%a)" pp_expr e
204+ | Upfrom e -> pf ppf " stan::model:: index_min(%a)" pp_expr e
213205 | Between (e_low , e_high ) ->
214- pf ppf " index_min_max(%a, %a)" pp_expr e_low pp_expr e_high
215- | MultiIndex e -> pf ppf " index_multi(%a)" pp_expr e
206+ pf ppf " stan::model:: index_min_max(%a, %a)" pp_expr e_low pp_expr e_high
207+ | MultiIndex e -> pf ppf " stan::model:: index_multi(%a)" pp_expr e
216208
217209and pp_indexes ppf = function
218210 | [] -> pf ppf " "
219211 | idxs -> pf ppf " @[<hov 2>%a@]" (list ~sep: comma pp_index) idxs
220212
221213and pp_logical_op ppf op lhs rhs =
222- pf ppf " (primitive_value(@,%a)@ %s@ primitive_value(@,%a))" pp_expr lhs op
223- pp_expr rhs
214+ pf ppf
215+ " (stan::math::primitive_value(@,%a)@ %s@ stan::math::primitive_value(@,%a))"
216+ pp_expr lhs op pp_expr rhs
224217
225218and pp_unary ppf fm es = pf ppf fm pp_expr (List. hd_exn es)
226- and pp_binary ppf fm es = pf ppf fm pp_expr (first es) pp_expr (second es)
219+
220+ and pp_binary_op ppf op es =
221+ pf ppf " (%a@ %s@ %a)" pp_expr (first es) op pp_expr (second es)
227222
228223and pp_binary_f ppf f es =
229224 pf ppf " %s(@,%a,@ %a)" f pp_expr (first es) pp_expr (second es)
230225
231226and first es = List. nth_exn es 0
232227and second es = List. nth_exn es 1
233228
234- and pp_scalar_binary ppf scalar_fmt generic_fmt es =
235- pp_binary ppf
236- ( if is_scalar (first es) && is_scalar (second es) then scalar_fmt
237- else generic_fmt )
238- es
229+ and pp_scalar_binary ppf op fn es =
230+ if is_scalar (first es) && is_scalar (second es) then pp_binary_op ppf op es
231+ else pp_binary_f ppf fn es
239232
240- and gen_operator_app op_expr =
241- match op_expr with
242- | Operator. Plus ->
243- fun ppf es -> pp_scalar_binary ppf " (%a@ +@ %a)" " add(@,%a,@ %a)" es
233+ and gen_operator_app op ppf es =
234+ match op with
235+ | Operator. Plus -> pp_scalar_binary ppf " +" " stan::math::add" es
244236 | PMinus ->
245- fun ppf es ->
246- pp_unary ppf
247- (if is_scalar (List. hd_exn es) then " -%a" else " minus(@,%a)" )
248- es
249- | PPlus -> fun ppf es -> pp_unary ppf " %a" es
237+ pp_unary ppf
238+ (if is_scalar (List. hd_exn es) then " -%a" else " stan::math::minus(@,%a)" )
239+ es
240+ | PPlus -> pp_unary ppf " %a" es
250241 | Transpose ->
251- fun ppf es ->
252- pp_unary ppf
253- (if is_scalar (List. hd_exn es) then " %a" else " transpose(@,%a)" )
254- es
255- | PNot -> fun ppf es -> pp_unary ppf " logical_negation(@,%a)" es
256- | Minus ->
257- fun ppf es -> pp_scalar_binary ppf " (%a@ -@ %a)" " subtract(@,%a,@ %a)" es
258- | Times ->
259- fun ppf es -> pp_scalar_binary ppf " (%a@ *@ %a)" " multiply(@,%a,@ %a)" es
242+ pp_unary ppf
243+ ( if is_scalar (List. hd_exn es) then " %a"
244+ else " stan::math::transpose(@,%a)" )
245+ es
246+ | PNot -> pp_unary ppf " stan::math::logical_negation(@,%a)" es
247+ | Minus -> pp_scalar_binary ppf " -" " stan::math::subtract" es
248+ | Times -> pp_scalar_binary ppf " *" " stan::math::multiply" es
260249 | Divide | IntDivide ->
261- fun ppf es ->
262- if
263- is_matrix (second es)
264- && (is_matrix (first es) || is_row_vector (first es))
265- then pp_binary_f ppf " mdivide_right" es
266- else pp_scalar_binary ppf " (%a@ /@ %a)" " divide(@,%a,@ %a)" es
267- | Modulo -> fun ppf es -> pp_binary_f ppf " modulus" es
268- | LDivide -> fun ppf es -> pp_binary_f ppf " mdivide_left" es
250+ if
251+ is_matrix (second es)
252+ && (is_matrix (first es) || is_row_vector (first es))
253+ then pp_binary_f ppf " stan::math::mdivide_right" es
254+ else pp_scalar_binary ppf " /" " stan::math::divide" es
255+ | Modulo -> pp_binary_f ppf " stan::math::modulus" es
256+ | LDivide -> pp_binary_f ppf " stan::math::mdivide_left" es
269257 | And | Or ->
270258 Common.FatalError. fatal_error_msg
271259 [% message " And/Or should have been converted to an expression" ]
272- | EltTimes ->
273- fun ppf es ->
274- pp_scalar_binary ppf " (%a@ *@ %a)" " elt_multiply(@,%a,@ %a)" es
275- | EltDivide ->
276- fun ppf es ->
277- pp_scalar_binary ppf " (%a@ /@ %a)" " elt_divide(@,%a,@ %a)" es
278- | Pow -> fun ppf es -> pp_binary_f ppf " pow" es
279- | EltPow -> fun ppf es -> pp_binary_f ppf " pow" es
280- | Equals -> fun ppf es -> pp_binary_f ppf " logical_eq" es
281- | NEquals -> fun ppf es -> pp_binary_f ppf " logical_neq" es
282- | Less -> fun ppf es -> pp_binary_f ppf " logical_lt" es
283- | Leq -> fun ppf es -> pp_binary_f ppf " logical_lte" es
284- | Greater -> fun ppf es -> pp_binary_f ppf " logical_gt" es
285- | Geq -> fun ppf es -> pp_binary_f ppf " logical_gte" es
260+ | EltTimes -> pp_scalar_binary ppf " *" " stan::math::elt_multiply" es
261+ | EltDivide -> pp_scalar_binary ppf " /" " stan::math::elt_divide" es
262+ | Pow -> pp_binary_f ppf " stan::math::pow" es
263+ | EltPow -> pp_binary_f ppf " stan::math::pow" es
264+ | Equals -> pp_binary_f ppf " stan::math::logical_eq" es
265+ | NEquals -> pp_binary_f ppf " stan::math::logical_neq" es
266+ | Less -> pp_binary_f ppf " stan::math::logical_lt" es
267+ | Leq -> pp_binary_f ppf " stan::math::logical_lte" es
268+ | Greater -> pp_binary_f ppf " stan::math::logical_gt" es
269+ | Geq -> pp_binary_f ppf " stan::math::logical_gte" es
286270
287271and gen_misc_special_math_app (f : string )
288272 (mem_pattern : Common.Helpers.mem_pattern )
289273 (ret_type : UnsizedType.returntype option ) =
290274 match f with
291275 | "lmultiply" ->
292- Some (fun ppf es -> pp_binary ppf " multiply_log(@,%a,@ %a) " es)
276+ Some (fun ppf es -> pp_binary_f ppf " stan::math:: multiply_log" es)
293277 | "lchoose" ->
294- Some (fun ppf es -> pp_binary ppf " binomial_coefficient_log(@,%a,@ %a)" es)
295- | "target" -> Some (fun ppf _ -> pf ppf " get_lp(lp__, lp_accum__)" )
296- | "get_lp" -> Some (fun ppf _ -> pf ppf " get_lp(lp__, lp_accum__)" )
297- | "max" | "min" ->
298- Some
299- (fun ppf es ->
300- let f = match es with [_; _] -> " std::" ^ f | _ -> f in
301- pp_call ppf (f, pp_expr, es) )
302- | "ceil" ->
303- let std_prefix_data_scalar f = function
304- | [ Expr.
305- { Fixed. meta=
306- Typed.Meta. {adlevel= DataOnly ; type_= UInt | UReal ; _}
307- ; _ } ] ->
308- " std::" ^ f
309- | _ -> f in
310278 Some
311- (fun ppf es ->
312- let f = std_prefix_data_scalar f es in
313- pp_call ppf (f, pp_expr, es) )
279+ (fun ppf es -> pp_binary_f ppf " stan::math::binomial_coefficient_log " es)
280+ | "target" -> Some ( fun ppf _ -> pf ppf " stan::math::get_lp(lp__, lp_accum__) " )
281+ | "get_lp" -> Some ( fun ppf _ -> pf ppf " stan::math::get_lp(lp__, lp_accum__) " )
314282 | f when Map. mem fn_renames f ->
315283 Some (fun ppf es -> pp_call ppf (Map. find_exn fn_renames f, pp_expr, es))
316284 | " rep_matrix" | " rep_vector" | " rep_row_vector" | " append_row" | " append_col"
@@ -542,7 +510,8 @@ and pp_promoted ad ut ppf e =
542510 (local_scalar ut ad) pp_expr e )
543511
544512and pp_indexed ppf (vident , indices , pretty ) =
545- pf ppf " @[<hov 2>rvalue(@,%s,@ %S,@ %a)@]" vident pretty pp_indexes indices
513+ pf ppf " @[<hov 2>stan::model::rvalue(@,%s,@ %S,@ %a)@]" vident pretty
514+ pp_indexes indices
546515
547516and pp_indexed_simple ppf (obj , idcs ) =
548517 let idx_minus_one = function
@@ -569,7 +538,7 @@ and pp_expr ppf Expr.Fixed.({pattern; meta} as e) =
569538 match pattern with
570539 | Var s -> pf ppf " %s" s
571540 | Lit (Str, s ) -> pf ppf " \" %s\" " (Cpp_str. escaped s)
572- | Lit (Imaginary, s ) -> pf ppf " to_complex(0, %s)" s
541+ | Lit (Imaginary, s ) -> pf ppf " stan::math:: to_complex(0, %s)" s
573542 | Lit ((Real | Int ), s ) -> pf ppf " %s" s
574543 | FunApp
575544 ( StanLib (op, _, _)
@@ -587,7 +556,6 @@ and pp_expr ppf Expr.Fixed.({pattern; meta} as e) =
587556 gen_fun_app suffix ppf f es mem_pattern ret_type
588557 | FunApp (CompilerInternal f , es ) ->
589558 pp_compiler_internal_fn meta.adlevel meta.type_ f ppf es
590- (* stan_namespace_qualify? *)
591559 | FunApp (UserDefined (f , suffix ), es ) ->
592560 pp_user_defined_fun ppf (f, suffix, es)
593561 | EAnd (e1 , e2 ) -> pp_logical_op ppf " &&" e1 e2
@@ -675,7 +643,7 @@ let%expect_test "pp_expr9" =
675643
676644let % expect_test " pp_expr10" =
677645 printf " %s" (pp_unlocated (Indexed (dummy_locate (Var " a" ), [All ]))) ;
678- [% expect {| rvalue(a, " a" , index_omni() ) | }]
646+ [% expect {| stan ::model:: rvalue(a , "a" , stan ::model:: index_omni()) |}]
679647
680648let % expect_test " pp_expr11" =
681649 printf " %s"
0 commit comments