Skip to content

Commit 21540c6

Browse files
committed
update to master and move mem_patterns tests to the compiler optims file
2 parents aaea223 + 727c762 commit 21540c6

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+28785
-16210
lines changed

Jenkinsfile

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,11 @@ pipeline {
2929
agent none
3030
parameters {
3131
booleanParam(name:"skip_end_to_end", defaultValue: false, description:"Skip end-to-end tests ")
32-
string(defaultValue: '', name: 'cmdstan_pr',
32+
string(defaultValue: 'develop', name: 'cmdstan_pr',
3333
description: "CmdStan PR to test against. Will check out this PR in the downstream Stan repo.")
34-
string(defaultValue: '', name: 'stan_pr',
34+
string(defaultValue: 'develop', name: 'stan_pr',
3535
description: "Stan PR to test against. Will check out this PR in the downstream Stan repo.")
36-
string(defaultValue: '', name: 'math_pr',
36+
string(defaultValue: 'develop', name: 'math_pr',
3737
description: "Math PR to test against. Will check out this PR in the downstream Math repo.")
3838
}
3939
options {parallelsAlwaysFailFast()}
@@ -283,12 +283,15 @@ pipeline {
283283

284284
unstash 'ubuntu-exe'
285285

286-
sh """
287-
git clone --recursive https://github.com/stan-dev/math.git
288-
cp bin/stanc math/test/expressions/stanc
289-
"""
290-
291286
script {
287+
sh """
288+
git clone --recursive https://github.com/stan-dev/math.git
289+
"""
290+
utils.checkout_pr("math", "math", params.math_pr)
291+
sh """
292+
cp bin/stanc math/test/expressions/stanc
293+
"""
294+
292295
dir("math") {
293296
sh """
294297
echo O=0 >> make/local

src/middle/Stan_math_signatures.ml

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ let distributions =
271271
; (full_lpdf, "student_t", [DVReal; DVReal; DVReal; DVReal], SoA)
272272
; (full_lpdf, "std_normal", [DVReal], SoA)
273273
; (full_lpdf, "uniform", [DVReal; DVReal; DVReal], SoA)
274-
; ([Lpdf; Rng], "von_mises", [DVReal; DVReal; DVReal], SoA)
274+
; (full_lpdf, "von_mises", [DVReal; DVReal; DVReal], SoA)
275275
; (full_lpdf, "weibull", [DVReal; DVReal; DVReal], SoA)
276276
; ([Lpdf], "wiener", [DVReal; DVReal; DVReal; DVReal; DVReal], SoA)
277277
; ([Lpdf], "wishart", [DMatrix; DReal; DMatrix], SoA) ]
@@ -1688,6 +1688,16 @@ let () =
16881688
, ReturnType UReal
16891689
, [UVector; UMatrix; UVector; UVector; UReal]
16901690
, AoS ) ;
1691+
add_unqualified
1692+
( "normal_id_glm_lpdf"
1693+
, ReturnType UReal
1694+
, [UReal; UMatrix; UReal; UVector; UReal]
1695+
, AoS ) ;
1696+
add_unqualified
1697+
( "normal_id_glm_lpdf"
1698+
, ReturnType UReal
1699+
, [UReal; UMatrix; UVector; UVector; UReal]
1700+
, AoS ) ;
16911701
add_unqualified
16921702
( "normal_id_glm_lpdf"
16931703
, ReturnType UReal
@@ -1703,11 +1713,21 @@ let () =
17031713
, ReturnType UReal
17041714
, [UVector; URowVector; UReal; UVector; UVector]
17051715
, AoS ) ;
1716+
add_unqualified
1717+
( "normal_id_glm_lpdf"
1718+
, ReturnType UReal
1719+
, [UVector; URowVector; UVector; UVector; UReal]
1720+
, AoS ) ;
17061721
add_unqualified
17071722
( "normal_id_glm_lpdf"
17081723
, ReturnType UReal
17091724
, [UVector; URowVector; UVector; UVector; UVector]
17101725
, AoS ) ;
1726+
add_unqualified
1727+
( "normal_id_glm_lpdf"
1728+
, ReturnType UReal
1729+
, [UVector; URowVector; UReal; UVector; UReal]
1730+
, AoS ) ;
17111731
add_nullary "not_a_number" ;
17121732
add_unqualified ("num_elements", ReturnType UInt, [UMatrix], SoA) ;
17131733
add_unqualified ("num_elements", ReturnType UInt, [UVector], SoA) ;

src/stan_math_backend/Expression_gen.ml

Lines changed: 56 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,9 @@ open Fmt
66
let ends_with suffix s = String.is_suffix ~suffix s
77
let starts_with prefix s = String.is_prefix ~prefix s
88

9-
let functions_requiring_namespace =
10-
String.Set.of_list
11-
[ "e"; "pi"; "log2"; "log10"; "sqrt2"; "not_a_number"; "positive_infinity"
12-
; "negative_infinity"; "machine_precision"; "abs"; "acos"; "acosh"; "asin"
13-
; "asinh"; "atan"; "atanh"; "cbrt"; "ceil"; "cos"; "cosh"; "erf"; "erfc"
14-
; "exp"; "exp2"; "expm1"; "fabs"; "floor"; "lgamma"; "log"; "log1p"; "log2"
15-
; "log10"; "round"; "sin"; "sinh"; "sqrt"; "tan"; "tanh"; "tgamma"; "trunc"
16-
; "fdim"; "fmax"; "fmin"; "hypot"; "fma"; "complex" ]
17-
189
let stan_namespace_qualify f =
19-
if Set.mem functions_requiring_namespace f then "stan::math::" ^ f else f
10+
if String.is_suffix ~suffix:"functor__" f || String.contains f ':' then f
11+
else "stan::math::" ^ f
2012

2113
(* return true if the types of the two expression are the same *)
2214
let types_match e1 e2 =
@@ -157,7 +149,7 @@ let fn_renames =
157149
~f:(fun (k, v) -> (Internal_fun.to_string k, v))
158150
[ (Internal_fun.FnLength, "stan::math::size")
159151
; (FnNegInf, "stan::math::negative_infinity")
160-
; (FnResizeToMatch, "resize_to_match")
152+
; (FnResizeToMatch, "stan::math::resize_to_match")
161153
; (FnNaN, "std::numeric_limits<double>::quiet_NaN") ]
162154
|> String.Map.of_alist_exn
163155

@@ -207,110 +199,86 @@ let transform_args = function
207199
| transform -> Transformation.fold (fun args arg -> args @ [arg]) [] transform
208200

209201
let rec pp_index ppf = function
210-
| Index.All -> pf ppf "index_omni()"
211-
| Single e -> pf ppf "index_uni(%a)" pp_expr e
212-
| Upfrom e -> pf ppf "index_min(%a)" pp_expr e
202+
| Index.All -> pf ppf "stan::model::index_omni()"
203+
| Single e -> pf ppf "stan::model::index_uni(%a)" pp_expr e
204+
| Upfrom e -> pf ppf "stan::model::index_min(%a)" pp_expr e
213205
| Between (e_low, e_high) ->
214-
pf ppf "index_min_max(%a, %a)" pp_expr e_low pp_expr e_high
215-
| MultiIndex e -> pf ppf "index_multi(%a)" pp_expr e
206+
pf ppf "stan::model::index_min_max(%a, %a)" pp_expr e_low pp_expr e_high
207+
| MultiIndex e -> pf ppf "stan::model::index_multi(%a)" pp_expr e
216208

217209
and pp_indexes ppf = function
218210
| [] -> pf ppf ""
219211
| idxs -> pf ppf "@[<hov 2>%a@]" (list ~sep:comma pp_index) idxs
220212

221213
and pp_logical_op ppf op lhs rhs =
222-
pf ppf "(primitive_value(@,%a)@ %s@ primitive_value(@,%a))" pp_expr lhs op
223-
pp_expr rhs
214+
pf ppf
215+
"(stan::math::primitive_value(@,%a)@ %s@ stan::math::primitive_value(@,%a))"
216+
pp_expr lhs op pp_expr rhs
224217

225218
and pp_unary ppf fm es = pf ppf fm pp_expr (List.hd_exn es)
226-
and pp_binary ppf fm es = pf ppf fm pp_expr (first es) pp_expr (second es)
219+
220+
and pp_binary_op ppf op es =
221+
pf ppf "(%a@ %s@ %a)" pp_expr (first es) op pp_expr (second es)
227222

228223
and pp_binary_f ppf f es =
229224
pf ppf "%s(@,%a,@ %a)" f pp_expr (first es) pp_expr (second es)
230225

231226
and first es = List.nth_exn es 0
232227
and second es = List.nth_exn es 1
233228

234-
and pp_scalar_binary ppf scalar_fmt generic_fmt es =
235-
pp_binary ppf
236-
( if is_scalar (first es) && is_scalar (second es) then scalar_fmt
237-
else generic_fmt )
238-
es
229+
and pp_scalar_binary ppf op fn es =
230+
if is_scalar (first es) && is_scalar (second es) then pp_binary_op ppf op es
231+
else pp_binary_f ppf fn es
239232

240-
and gen_operator_app op_expr =
241-
match op_expr with
242-
| Operator.Plus ->
243-
fun ppf es -> pp_scalar_binary ppf "(%a@ +@ %a)" "add(@,%a,@ %a)" es
233+
and gen_operator_app op ppf es =
234+
match op with
235+
| Operator.Plus -> pp_scalar_binary ppf "+" "stan::math::add" es
244236
| PMinus ->
245-
fun ppf es ->
246-
pp_unary ppf
247-
(if is_scalar (List.hd_exn es) then "-%a" else "minus(@,%a)")
248-
es
249-
| PPlus -> fun ppf es -> pp_unary ppf "%a" es
237+
pp_unary ppf
238+
(if is_scalar (List.hd_exn es) then "-%a" else "stan::math::minus(@,%a)")
239+
es
240+
| PPlus -> pp_unary ppf "%a" es
250241
| Transpose ->
251-
fun ppf es ->
252-
pp_unary ppf
253-
(if is_scalar (List.hd_exn es) then "%a" else "transpose(@,%a)")
254-
es
255-
| PNot -> fun ppf es -> pp_unary ppf "logical_negation(@,%a)" es
256-
| Minus ->
257-
fun ppf es -> pp_scalar_binary ppf "(%a@ -@ %a)" "subtract(@,%a,@ %a)" es
258-
| Times ->
259-
fun ppf es -> pp_scalar_binary ppf "(%a@ *@ %a)" "multiply(@,%a,@ %a)" es
242+
pp_unary ppf
243+
( if is_scalar (List.hd_exn es) then "%a"
244+
else "stan::math::transpose(@,%a)" )
245+
es
246+
| PNot -> pp_unary ppf "stan::math::logical_negation(@,%a)" es
247+
| Minus -> pp_scalar_binary ppf "-" "stan::math::subtract" es
248+
| Times -> pp_scalar_binary ppf "*" "stan::math::multiply" es
260249
| Divide | IntDivide ->
261-
fun ppf es ->
262-
if
263-
is_matrix (second es)
264-
&& (is_matrix (first es) || is_row_vector (first es))
265-
then pp_binary_f ppf "mdivide_right" es
266-
else pp_scalar_binary ppf "(%a@ /@ %a)" "divide(@,%a,@ %a)" es
267-
| Modulo -> fun ppf es -> pp_binary_f ppf "modulus" es
268-
| LDivide -> fun ppf es -> pp_binary_f ppf "mdivide_left" es
250+
if
251+
is_matrix (second es)
252+
&& (is_matrix (first es) || is_row_vector (first es))
253+
then pp_binary_f ppf "stan::math::mdivide_right" es
254+
else pp_scalar_binary ppf "/" "stan::math::divide" es
255+
| Modulo -> pp_binary_f ppf "stan::math::modulus" es
256+
| LDivide -> pp_binary_f ppf "stan::math::mdivide_left" es
269257
| And | Or ->
270258
Common.FatalError.fatal_error_msg
271259
[%message "And/Or should have been converted to an expression"]
272-
| EltTimes ->
273-
fun ppf es ->
274-
pp_scalar_binary ppf "(%a@ *@ %a)" "elt_multiply(@,%a,@ %a)" es
275-
| EltDivide ->
276-
fun ppf es ->
277-
pp_scalar_binary ppf "(%a@ /@ %a)" "elt_divide(@,%a,@ %a)" es
278-
| Pow -> fun ppf es -> pp_binary_f ppf "pow" es
279-
| EltPow -> fun ppf es -> pp_binary_f ppf "pow" es
280-
| Equals -> fun ppf es -> pp_binary_f ppf "logical_eq" es
281-
| NEquals -> fun ppf es -> pp_binary_f ppf "logical_neq" es
282-
| Less -> fun ppf es -> pp_binary_f ppf "logical_lt" es
283-
| Leq -> fun ppf es -> pp_binary_f ppf "logical_lte" es
284-
| Greater -> fun ppf es -> pp_binary_f ppf "logical_gt" es
285-
| Geq -> fun ppf es -> pp_binary_f ppf "logical_gte" es
260+
| EltTimes -> pp_scalar_binary ppf "*" "stan::math::elt_multiply" es
261+
| EltDivide -> pp_scalar_binary ppf "/" "stan::math::elt_divide" es
262+
| Pow -> pp_binary_f ppf "stan::math::pow" es
263+
| EltPow -> pp_binary_f ppf "stan::math::pow" es
264+
| Equals -> pp_binary_f ppf "stan::math::logical_eq" es
265+
| NEquals -> pp_binary_f ppf "stan::math::logical_neq" es
266+
| Less -> pp_binary_f ppf "stan::math::logical_lt" es
267+
| Leq -> pp_binary_f ppf "stan::math::logical_lte" es
268+
| Greater -> pp_binary_f ppf "stan::math::logical_gt" es
269+
| Geq -> pp_binary_f ppf "stan::math::logical_gte" es
286270

287271
and gen_misc_special_math_app (f : string)
288272
(mem_pattern : Common.Helpers.mem_pattern)
289273
(ret_type : UnsizedType.returntype option) =
290274
match f with
291275
| "lmultiply" ->
292-
Some (fun ppf es -> pp_binary ppf "multiply_log(@,%a,@ %a)" es)
276+
Some (fun ppf es -> pp_binary_f ppf "stan::math::multiply_log" es)
293277
| "lchoose" ->
294-
Some (fun ppf es -> pp_binary ppf "binomial_coefficient_log(@,%a,@ %a)" es)
295-
| "target" -> Some (fun ppf _ -> pf ppf "get_lp(lp__, lp_accum__)")
296-
| "get_lp" -> Some (fun ppf _ -> pf ppf "get_lp(lp__, lp_accum__)")
297-
| "max" | "min" ->
298-
Some
299-
(fun ppf es ->
300-
let f = match es with [_; _] -> "std::" ^ f | _ -> f in
301-
pp_call ppf (f, pp_expr, es) )
302-
| "ceil" ->
303-
let std_prefix_data_scalar f = function
304-
| [ Expr.
305-
{ Fixed.meta=
306-
Typed.Meta.{adlevel= DataOnly; type_= UInt | UReal; _}
307-
; _ } ] ->
308-
"std::" ^ f
309-
| _ -> f in
310278
Some
311-
(fun ppf es ->
312-
let f = std_prefix_data_scalar f es in
313-
pp_call ppf (f, pp_expr, es) )
279+
(fun ppf es -> pp_binary_f ppf "stan::math::binomial_coefficient_log" es)
280+
| "target" -> Some (fun ppf _ -> pf ppf "stan::math::get_lp(lp__, lp_accum__)")
281+
| "get_lp" -> Some (fun ppf _ -> pf ppf "stan::math::get_lp(lp__, lp_accum__)")
314282
| f when Map.mem fn_renames f ->
315283
Some (fun ppf es -> pp_call ppf (Map.find_exn fn_renames f, pp_expr, es))
316284
| "rep_matrix" | "rep_vector" | "rep_row_vector" | "append_row" | "append_col"
@@ -542,7 +510,8 @@ and pp_promoted ad ut ppf e =
542510
(local_scalar ut ad) pp_expr e )
543511

544512
and pp_indexed ppf (vident, indices, pretty) =
545-
pf ppf "@[<hov 2>rvalue(@,%s,@ %S,@ %a)@]" vident pretty pp_indexes indices
513+
pf ppf "@[<hov 2>stan::model::rvalue(@,%s,@ %S,@ %a)@]" vident pretty
514+
pp_indexes indices
546515

547516
and pp_indexed_simple ppf (obj, idcs) =
548517
let idx_minus_one = function
@@ -569,7 +538,7 @@ and pp_expr ppf Expr.Fixed.({pattern; meta} as e) =
569538
match pattern with
570539
| Var s -> pf ppf "%s" s
571540
| Lit (Str, s) -> pf ppf "\"%s\"" (Cpp_str.escaped s)
572-
| Lit (Imaginary, s) -> pf ppf "to_complex(0, %s)" s
541+
| Lit (Imaginary, s) -> pf ppf "stan::math::to_complex(0, %s)" s
573542
| Lit ((Real | Int), s) -> pf ppf "%s" s
574543
| FunApp
575544
( StanLib (op, _, _)
@@ -587,7 +556,6 @@ and pp_expr ppf Expr.Fixed.({pattern; meta} as e) =
587556
gen_fun_app suffix ppf f es mem_pattern ret_type
588557
| FunApp (CompilerInternal f, es) ->
589558
pp_compiler_internal_fn meta.adlevel meta.type_ f ppf es
590-
(* stan_namespace_qualify? *)
591559
| FunApp (UserDefined (f, suffix), es) ->
592560
pp_user_defined_fun ppf (f, suffix, es)
593561
| EAnd (e1, e2) -> pp_logical_op ppf "&&" e1 e2
@@ -675,7 +643,7 @@ let%expect_test "pp_expr9" =
675643

676644
let%expect_test "pp_expr10" =
677645
printf "%s" (pp_unlocated (Indexed (dummy_locate (Var "a"), [All]))) ;
678-
[%expect {| rvalue(a, "a", index_omni()) |}]
646+
[%expect {| stan::model::rvalue(a, "a", stan::model::index_omni()) |}]
679647

680648
let%expect_test "pp_expr11" =
681649
printf "%s"

src/stan_math_backend/Stan_math_code_gen.ml

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,8 @@ let pp_returntype ppf arg_types rt =
114114

115115
let pp_eigen_arg_to_ref ppf arg_types =
116116
let pp_ref ppf name =
117-
pf ppf "@[<hv 8>const auto& %s = to_ref(%s);@]" name (name ^ "_arg__") in
117+
pf ppf "@[<hv 8>const auto& %s = stan::math::to_ref(%s);@]" name
118+
(name ^ "_arg__") in
118119
pf ppf "@[<v>%a@]@ " (list ~sep:cut pp_ref)
119120
(List.filter_map
120121
~f:(fun (_, name, ut) ->
@@ -188,7 +189,7 @@ let get_templates_and_args exprs fdargs =
188189
let pp_template_decorator ppf = function
189190
| [] -> ()
190191
| templates ->
191-
pf ppf "@[<hov>template <%a>@]@ " (list ~sep:comma string) templates
192+
pf ppf "template @[<hov><%a>@]@ " (list ~sep:comma string) templates
192193

193194
let mk_extra_args templates args =
194195
List.map ~f:(fun (t, v) -> t ^ "& " ^ v) (List.zip_exn templates args)
@@ -903,19 +904,9 @@ let pp_model ppf ({Program.prog_name; _} as p) =
903904
"%%NAME%%3 %%VERSION%%" stanc_args_to_print ;
904905
pf ppf "@ %a@]@]@ };" pp_model_public p
905906

906-
(** The C++ aliases needed for the model class*)
907907
let usings =
908908
{|
909-
using stan::io::dump;
910-
using stan::model::assign;
911-
using stan::model::index_uni;
912-
using stan::model::index_max;
913-
using stan::model::index_min;
914-
using stan::model::index_min_max;
915-
using stan::model::index_multi;
916-
using stan::model::index_omni;
917909
using stan::model::model_base_crtp;
918-
using stan::model::rvalue;
919910
using namespace stan::math;
920911
|}
921912

0 commit comments

Comments
 (0)