Skip to content

Commit 727c762

Browse files
authored
Merge pull request #1011 from stan-dev/shadowing
Fix issue with function shadowing
2 parents 3d0c89a + f9ceded commit 727c762

File tree

17 files changed

+24860
-17363
lines changed

17 files changed

+24860
-17363
lines changed

src/stan_math_backend/Expression_gen.ml

Lines changed: 56 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,9 @@ open Fmt
66
let ends_with suffix s = String.is_suffix ~suffix s
77
let starts_with prefix s = String.is_prefix ~prefix s
88

9-
let functions_requiring_namespace =
10-
String.Set.of_list
11-
[ "e"; "pi"; "log2"; "log10"; "sqrt2"; "not_a_number"; "positive_infinity"
12-
; "negative_infinity"; "machine_precision"; "abs"; "acos"; "acosh"; "asin"
13-
; "asinh"; "atan"; "atanh"; "cbrt"; "ceil"; "cos"; "cosh"; "erf"; "erfc"
14-
; "exp"; "exp2"; "expm1"; "fabs"; "floor"; "lgamma"; "log"; "log1p"; "log2"
15-
; "log10"; "round"; "sin"; "sinh"; "sqrt"; "tan"; "tanh"; "tgamma"; "trunc"
16-
; "fdim"; "fmax"; "fmin"; "hypot"; "fma"; "complex" ]
17-
189
let stan_namespace_qualify f =
19-
if Set.mem functions_requiring_namespace f then "stan::math::" ^ f else f
10+
if String.is_suffix ~suffix:"functor__" f || String.contains f ':' then f
11+
else "stan::math::" ^ f
2012

2113
(* return true if the types of the two expression are the same *)
2214
let types_match e1 e2 =
@@ -141,7 +133,7 @@ let fn_renames =
141133
~f:(fun (k, v) -> (Internal_fun.to_string k, v))
142134
[ (Internal_fun.FnLength, "stan::math::size")
143135
; (FnNegInf, "stan::math::negative_infinity")
144-
; (FnResizeToMatch, "resize_to_match")
136+
; (FnResizeToMatch, "stan::math::resize_to_match")
145137
; (FnNaN, "std::numeric_limits<double>::quiet_NaN") ]
146138
|> String.Map.of_alist_exn
147139

@@ -191,107 +183,84 @@ let transform_args = function
191183
| transform -> Transformation.fold (fun args arg -> args @ [arg]) [] transform
192184

193185
let rec pp_index ppf = function
194-
| Index.All -> pf ppf "index_omni()"
195-
| Single e -> pf ppf "index_uni(%a)" pp_expr e
196-
| Upfrom e -> pf ppf "index_min(%a)" pp_expr e
186+
| Index.All -> pf ppf "stan::model::index_omni()"
187+
| Single e -> pf ppf "stan::model::index_uni(%a)" pp_expr e
188+
| Upfrom e -> pf ppf "stan::model::index_min(%a)" pp_expr e
197189
| Between (e_low, e_high) ->
198-
pf ppf "index_min_max(%a, %a)" pp_expr e_low pp_expr e_high
199-
| MultiIndex e -> pf ppf "index_multi(%a)" pp_expr e
190+
pf ppf "stan::model::index_min_max(%a, %a)" pp_expr e_low pp_expr e_high
191+
| MultiIndex e -> pf ppf "stan::model::index_multi(%a)" pp_expr e
200192

201193
and pp_indexes ppf = function
202194
| [] -> pf ppf ""
203195
| idxs -> pf ppf "@[<hov 2>%a@]" (list ~sep:comma pp_index) idxs
204196

205197
and pp_logical_op ppf op lhs rhs =
206-
pf ppf "(primitive_value(@,%a)@ %s@ primitive_value(@,%a))" pp_expr lhs op
207-
pp_expr rhs
198+
pf ppf
199+
"(stan::math::primitive_value(@,%a)@ %s@ stan::math::primitive_value(@,%a))"
200+
pp_expr lhs op pp_expr rhs
208201

209202
and pp_unary ppf fm es = pf ppf fm pp_expr (List.hd_exn es)
210-
and pp_binary ppf fm es = pf ppf fm pp_expr (first es) pp_expr (second es)
203+
204+
and pp_binary_op ppf op es =
205+
pf ppf "(%a@ %s@ %a)" pp_expr (first es) op pp_expr (second es)
211206

212207
and pp_binary_f ppf f es =
213208
pf ppf "%s(@,%a,@ %a)" f pp_expr (first es) pp_expr (second es)
214209

215210
and first es = List.nth_exn es 0
216211
and second es = List.nth_exn es 1
217212

218-
and pp_scalar_binary ppf scalar_fmt generic_fmt es =
219-
pp_binary ppf
220-
( if is_scalar (first es) && is_scalar (second es) then scalar_fmt
221-
else generic_fmt )
222-
es
213+
and pp_scalar_binary ppf op fn es =
214+
if is_scalar (first es) && is_scalar (second es) then pp_binary_op ppf op es
215+
else pp_binary_f ppf fn es
223216

224-
and gen_operator_app = function
225-
| Operator.Plus ->
226-
fun ppf es -> pp_scalar_binary ppf "(%a@ +@ %a)" "add(@,%a,@ %a)" es
217+
and gen_operator_app op ppf es =
218+
match op with
219+
| Operator.Plus -> pp_scalar_binary ppf "+" "stan::math::add" es
227220
| PMinus ->
228-
fun ppf es ->
229-
pp_unary ppf
230-
(if is_scalar (List.hd_exn es) then "-%a" else "minus(@,%a)")
231-
es
232-
| PPlus -> fun ppf es -> pp_unary ppf "%a" es
221+
pp_unary ppf
222+
(if is_scalar (List.hd_exn es) then "-%a" else "stan::math::minus(@,%a)")
223+
es
224+
| PPlus -> pp_unary ppf "%a" es
233225
| Transpose ->
234-
fun ppf es ->
235-
pp_unary ppf
236-
(if is_scalar (List.hd_exn es) then "%a" else "transpose(@,%a)")
237-
es
238-
| PNot -> fun ppf es -> pp_unary ppf "logical_negation(@,%a)" es
239-
| Minus ->
240-
fun ppf es -> pp_scalar_binary ppf "(%a@ -@ %a)" "subtract(@,%a,@ %a)" es
241-
| Times ->
242-
fun ppf es -> pp_scalar_binary ppf "(%a@ *@ %a)" "multiply(@,%a,@ %a)" es
226+
pp_unary ppf
227+
( if is_scalar (List.hd_exn es) then "%a"
228+
else "stan::math::transpose(@,%a)" )
229+
es
230+
| PNot -> pp_unary ppf "stan::math::logical_negation(@,%a)" es
231+
| Minus -> pp_scalar_binary ppf "-" "stan::math::subtract" es
232+
| Times -> pp_scalar_binary ppf "*" "stan::math::multiply" es
243233
| Divide | IntDivide ->
244-
fun ppf es ->
245-
if
246-
is_matrix (second es)
247-
&& (is_matrix (first es) || is_row_vector (first es))
248-
then pp_binary_f ppf "mdivide_right" es
249-
else pp_scalar_binary ppf "(%a@ /@ %a)" "divide(@,%a,@ %a)" es
250-
| Modulo -> fun ppf es -> pp_binary_f ppf "modulus" es
251-
| LDivide -> fun ppf es -> pp_binary_f ppf "mdivide_left" es
234+
if
235+
is_matrix (second es)
236+
&& (is_matrix (first es) || is_row_vector (first es))
237+
then pp_binary_f ppf "stan::math::mdivide_right" es
238+
else pp_scalar_binary ppf "/" "stan::math::divide" es
239+
| Modulo -> pp_binary_f ppf "stan::math::modulus" es
240+
| LDivide -> pp_binary_f ppf "stan::math::mdivide_left" es
252241
| And | Or ->
253242
Common.FatalError.fatal_error_msg
254243
[%message "And/Or should have been converted to an expression"]
255-
| EltTimes ->
256-
fun ppf es ->
257-
pp_scalar_binary ppf "(%a@ *@ %a)" "elt_multiply(@,%a,@ %a)" es
258-
| EltDivide ->
259-
fun ppf es ->
260-
pp_scalar_binary ppf "(%a@ /@ %a)" "elt_divide(@,%a,@ %a)" es
261-
| Pow -> fun ppf es -> pp_binary_f ppf "pow" es
262-
| EltPow -> fun ppf es -> pp_binary_f ppf "pow" es
263-
| Equals -> fun ppf es -> pp_binary_f ppf "logical_eq" es
264-
| NEquals -> fun ppf es -> pp_binary_f ppf "logical_neq" es
265-
| Less -> fun ppf es -> pp_binary_f ppf "logical_lt" es
266-
| Leq -> fun ppf es -> pp_binary_f ppf "logical_lte" es
267-
| Greater -> fun ppf es -> pp_binary_f ppf "logical_gt" es
268-
| Geq -> fun ppf es -> pp_binary_f ppf "logical_gte" es
244+
| EltTimes -> pp_scalar_binary ppf "*" "stan::math::elt_multiply" es
245+
| EltDivide -> pp_scalar_binary ppf "/" "stan::math::elt_divide" es
246+
| Pow -> pp_binary_f ppf "stan::math::pow" es
247+
| EltPow -> pp_binary_f ppf "stan::math::pow" es
248+
| Equals -> pp_binary_f ppf "stan::math::logical_eq" es
249+
| NEquals -> pp_binary_f ppf "stan::math::logical_neq" es
250+
| Less -> pp_binary_f ppf "stan::math::logical_lt" es
251+
| Leq -> pp_binary_f ppf "stan::math::logical_lte" es
252+
| Greater -> pp_binary_f ppf "stan::math::logical_gt" es
253+
| Geq -> pp_binary_f ppf "stan::math::logical_gte" es
269254

270255
and gen_misc_special_math_app f =
271256
match f with
272257
| "lmultiply" ->
273-
Some (fun ppf es -> pp_binary ppf "multiply_log(@,%a,@ %a)" es)
258+
Some (fun ppf es -> pp_binary_f ppf "stan::math::multiply_log" es)
274259
| "lchoose" ->
275-
Some (fun ppf es -> pp_binary ppf "binomial_coefficient_log(@,%a,@ %a)" es)
276-
| "target" -> Some (fun ppf _ -> pf ppf "get_lp(lp__, lp_accum__)")
277-
| "get_lp" -> Some (fun ppf _ -> pf ppf "get_lp(lp__, lp_accum__)")
278-
| "max" | "min" ->
279-
Some
280-
(fun ppf es ->
281-
let f = match es with [_; _] -> "std::" ^ f | _ -> f in
282-
pp_call ppf (f, pp_expr, es) )
283-
| "ceil" ->
284-
let std_prefix_data_scalar f = function
285-
| [ Expr.
286-
{ Fixed.meta=
287-
Typed.Meta.{adlevel= DataOnly; type_= UInt | UReal; _}
288-
; _ } ] ->
289-
"std::" ^ f
290-
| _ -> f in
291260
Some
292-
(fun ppf es ->
293-
let f = std_prefix_data_scalar f es in
294-
pp_call ppf (f, pp_expr, es) )
261+
(fun ppf es -> pp_binary_f ppf "stan::math::binomial_coefficient_log" es)
262+
| "target" -> Some (fun ppf _ -> pf ppf "stan::math::get_lp(lp__, lp_accum__)")
263+
| "get_lp" -> Some (fun ppf _ -> pf ppf "stan::math::get_lp(lp__, lp_accum__)")
295264
| f when Map.mem fn_renames f ->
296265
Some (fun ppf es -> pp_call ppf (Map.find_exn fn_renames f, pp_expr, es))
297266
| _ -> None
@@ -498,7 +467,8 @@ and pp_promoted ad ut ppf e =
498467
(local_scalar ut ad) pp_expr e )
499468

500469
and pp_indexed ppf (vident, indices, pretty) =
501-
pf ppf "@[<hov 2>rvalue(@,%s,@ %S,@ %a)@]" vident pretty pp_indexes indices
470+
pf ppf "@[<hov 2>stan::model::rvalue(@,%s,@ %S,@ %a)@]" vident pretty
471+
pp_indexes indices
502472

503473
and pp_indexed_simple ppf (obj, idcs) =
504474
let idx_minus_one = function
@@ -525,7 +495,7 @@ and pp_expr ppf Expr.Fixed.({pattern; meta} as e) =
525495
match pattern with
526496
| Var s -> pf ppf "%s" s
527497
| Lit (Str, s) -> pf ppf "\"%s\"" (Cpp_str.escaped s)
528-
| Lit (Imaginary, s) -> pf ppf "to_complex(0, %s)" s
498+
| Lit (Imaginary, s) -> pf ppf "stan::math::to_complex(0, %s)" s
529499
| Lit ((Real | Int), s) -> pf ppf "%s" s
530500
| FunApp
531501
( StanLib (op, _, _)
@@ -541,7 +511,6 @@ and pp_expr ppf Expr.Fixed.({pattern; meta} as e) =
541511
gen_fun_app suffix ppf f es mem_pattern
542512
| FunApp (CompilerInternal f, es) ->
543513
pp_compiler_internal_fn meta.adlevel meta.type_ f ppf es
544-
(* stan_namespace_qualify? *)
545514
| FunApp (UserDefined (f, suffix), es) ->
546515
pp_user_defined_fun ppf (f, suffix, es)
547516
| EAnd (e1, e2) -> pp_logical_op ppf "&&" e1 e2
@@ -629,7 +598,7 @@ let%expect_test "pp_expr9" =
629598

630599
let%expect_test "pp_expr10" =
631600
printf "%s" (pp_unlocated (Indexed (dummy_locate (Var "a"), [All]))) ;
632-
[%expect {| rvalue(a, "a", index_omni()) |}]
601+
[%expect {| stan::model::rvalue(a, "a", stan::model::index_omni()) |}]
633602

634603
let%expect_test "pp_expr11" =
635604
printf "%s"

src/stan_math_backend/Stan_math_code_gen.ml

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,8 @@ let pp_returntype ppf arg_types rt =
114114

115115
let pp_eigen_arg_to_ref ppf arg_types =
116116
let pp_ref ppf name =
117-
pf ppf "@[<hv 8>const auto& %s = to_ref(%s);@]" name (name ^ "_arg__") in
117+
pf ppf "@[<hv 8>const auto& %s = stan::math::to_ref(%s);@]" name
118+
(name ^ "_arg__") in
118119
pf ppf "@[<v>%a@]@ " (list ~sep:cut pp_ref)
119120
(List.filter_map
120121
~f:(fun (_, name, ut) ->
@@ -188,7 +189,7 @@ let get_templates_and_args exprs fdargs =
188189
let pp_template_decorator ppf = function
189190
| [] -> ()
190191
| templates ->
191-
pf ppf "@[<hov>template <%a>@]@ " (list ~sep:comma string) templates
192+
pf ppf "template @[<hov><%a>@]@ " (list ~sep:comma string) templates
192193

193194
let mk_extra_args templates args =
194195
List.map ~f:(fun (t, v) -> t ^ "& " ^ v) (List.zip_exn templates args)
@@ -903,19 +904,9 @@ let pp_model ppf ({Program.prog_name; _} as p) =
903904
"%%NAME%%3 %%VERSION%%" stanc_args_to_print ;
904905
pf ppf "@ %a@]@]@ };" pp_model_public p
905906

906-
(** The C++ aliases needed for the model class*)
907907
let usings =
908908
{|
909-
using stan::io::dump;
910-
using stan::model::assign;
911-
using stan::model::index_uni;
912-
using stan::model::index_max;
913-
using stan::model::index_min;
914-
using stan::model::index_min_max;
915-
using stan::model::index_multi;
916-
using stan::model::index_omni;
917909
using stan::model::model_base_crtp;
918-
using stan::model::rvalue;
919910
using namespace stan::math;
920911
|}
921912

src/stan_math_backend/Statement_gen.ml

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@ let pp_call_str ppf (name, args) = pp_call ppf (name, string, args)
88
let pp_block ppf (pp_body, body) = pf ppf "{@;<1 2>@[<v>%a@]@,}" pp_body body
99

1010
let pp_profile ppf (pp_body, name, body) =
11-
let profile =
12-
Fmt.str
13-
"profile<local_scalar_t__> profile__(%s, \
14-
const_cast<profile_map&>(profiles__));"
11+
let profile ppf name =
12+
pf ppf
13+
"@[<hov 2>stan::math::profile<local_scalar_t__> profile__(%s,@ \
14+
const_cast<stan::math::profile_map&>(profiles__));@]"
1515
name in
16-
pf ppf "{@;<1 2>@[<v>%s@;@;%a@]@,}" profile pp_body body
16+
pf ppf "{@;<1 2>@[<v>%a@;@;%a@]@,}" profile name pp_body body
1717

1818
let rec contains_eigen (ut : UnsizedType.t) : bool =
1919
match ut with
@@ -251,9 +251,10 @@ let pp_decl ppf (vident, pst, adtype, initialize) =
251251

252252
let math_fn_translations = function
253253
| Internal_fun.FnLength -> Some ("length", [])
254-
| FnValidateSize -> Some ("validate_non_negative_index", [])
255-
| FnValidateSizeSimplex -> Some ("validate_positive_index", [])
256-
| FnValidateSizeUnitVector -> Some ("validate_unit_vector_index", [])
254+
| FnValidateSize -> Some ("stan::math::validate_non_negative_index", [])
255+
| FnValidateSizeSimplex -> Some ("stan::math::validate_positive_index", [])
256+
| FnValidateSizeUnitVector ->
257+
Some ("stan::math::validate_unit_vector_index", [])
257258
| FnReadWriteEventsOpenCL x -> Some (x ^ ".wait_for_read_write_events", [])
258259
| _ -> None
259260

@@ -263,7 +264,7 @@ let trans_math_fn f =
263264

264265
let pp_bool_expr ppf expr =
265266
match Expr.Typed.type_of expr with
266-
| UReal -> pp_call ppf ("as_bool", pp_expr, [expr])
267+
| UReal -> pp_call ppf ("stan::math::as_bool", pp_expr, [expr])
267268
| _ -> pp_expr ppf expr
268269

269270
let rec pp_statement (ppf : Format.formatter) Stmt.Fixed.{pattern; meta} =
@@ -305,13 +306,15 @@ let rec pp_statement (ppf : Format.formatter) Stmt.Fixed.{pattern; meta} =
305306
Expr.Fixed.pattern= FunApp (CompilerInternal FnDeepCopy, [e]) }
306307
| _ -> recurse e in
307308
let rhs = maybe_deep_copy rhs in
308-
pf ppf "@[<hov 2>assign(@,%s,@ %a,@ %S%s%a@]);" assignee pp_expr rhs
309+
pf ppf "@[<hov 2>stan::model::assign(@,%s,@ %a,@ %S%s%a@]);" assignee
310+
pp_expr rhs
309311
(str "assigning variable %s" assignee)
310312
(if List.length idcs = 0 then "" else ", ")
311313
pp_indexes idcs
312314
| TargetPE e -> pf ppf "@[<hov 2>lp_accum__.add(@,%a@]);" pp_expr e
313315
| NRFunApp (CompilerInternal FnPrint, args) ->
314-
let pp_arg ppf a = pf ppf "stan_print(pstream__, %a);" pp_expr a in
316+
let pp_arg ppf a =
317+
pf ppf "stan::math::stan_print(pstream__, %a);" pp_expr a in
315318
let args = args @ [Expr.Helpers.str "\n"] in
316319
pf ppf "if (pstream__) %a" pp_block (list ~sep:cut pp_arg, args)
317320
| NRFunApp (CompilerInternal FnReject, args) ->
@@ -323,7 +326,8 @@ let rec pp_statement (ppf : Format.formatter) Stmt.Fixed.{pattern; meta} =
323326
| NRFunApp (CompilerInternal (FnCheck {trans; var_name; var}), args) ->
324327
Option.iter (check_to_string trans) ~f:(fun check_name ->
325328
let function_arg = Expr.Helpers.variable "function__" in
326-
pf ppf "%s(@[<hov>%a@]);" ("check_" ^ check_name)
329+
pf ppf "%s(@[<hov>%a@]);"
330+
("stan::math::check_" ^ check_name)
327331
(list ~sep:comma pp_expr)
328332
(function_arg :: Expr.Helpers.str var_name :: var :: args) )
329333
| NRFunApp (CompilerInternal (FnWriteParam {unconstrain_opt; var}), _) -> (
@@ -344,7 +348,9 @@ let rec pp_statement (ppf : Format.formatter) Stmt.Fixed.{pattern; meta} =
344348
pf ppf "%s(@[<hov>%a@]);" fname (list ~sep:comma pp_expr)
345349
(extra_args @ args)
346350
| NRFunApp (StanLib (fname, _, _), args) ->
347-
pf ppf "%s(@[<hov>%a@]);" fname (list ~sep:comma pp_expr) args
351+
pf ppf "%s(@[<hov>%a@]);"
352+
(stan_namespace_qualify fname)
353+
(list ~sep:comma pp_expr) args
348354
| NRFunApp (UserDefined (fname, suffix), args) ->
349355
pf ppf "%a;" pp_user_defined_fun (fname, suffix, args)
350356
| Break -> string ppf "break;"

test/integration/cli-args/filename_good.expected

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,7 @@
44
#include <stan/model/model_header.hpp>
55
namespace filename_good_model_namespace {
66

7-
using stan::io::dump;
8-
using stan::model::assign;
9-
using stan::model::index_uni;
10-
using stan::model::index_max;
11-
using stan::model::index_min;
12-
using stan::model::index_min_max;
13-
using stan::model::index_multi;
14-
using stan::model::index_omni;
157
using stan::model::model_base_crtp;
16-
using stan::model::rvalue;
178
using namespace stan::math;
189

1910

@@ -125,11 +116,12 @@ class filename_good_model final : public model_base_crtp<filename_good_model> {
125116
(void) function__; // suppress unused var warning
126117

127118
try {
128-
if (logical_negation((primitive_value(emit_transformed_parameters__) ||
129-
primitive_value(emit_generated_quantities__)))) {
119+
if (stan::math::logical_negation((stan::math::primitive_value(
120+
emit_transformed_parameters__) || stan::math::primitive_value(
121+
emit_generated_quantities__)))) {
130122
return ;
131123
}
132-
if (logical_negation(emit_generated_quantities__)) {
124+
if (stan::math::logical_negation(emit_generated_quantities__)) {
133125
return ;
134126
}
135127
} catch (const std::exception& e) {

0 commit comments

Comments
 (0)