Skip to content

Commit 8626779

Browse files
authored
Merge pull request #871 from SteveBronder/feature/cleanup-cpp-meta-info
Cleanup some of the C++
2 parents 282439e + f21d61f commit 8626779

File tree

11 files changed

+9092
-7452
lines changed

11 files changed

+9092
-7452
lines changed

src/stan_math_backend/Cpp_Json.ml

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@ module Str = Re.Str
44

55
let rec sizedtype_to_json (st : Expr.Typed.t SizedType.t) : Yojson.Basic.t =
66
let emit_cpp_expr e =
7-
Fmt.strf "<< %a >>" Expression_gen.pp_expr e
8-
|> Str.global_replace (Str.regexp "[\n\r\t ]+") " "
7+
Fmt.strf "+ std::to_string(%a) +" Expression_gen.pp_expr e
98
in
109
match st with
1110
| SInt -> `Assoc [("name", `String "int")]
@@ -44,17 +43,20 @@ let%expect_test "outvar to json pretty" =
4443
"name": "var_one",
4544
"type": {
4645
"name": "array",
47-
"length": "<< K >>",
48-
"element_type": { "name": "vector", "length": "<< N >>" }
46+
"length": "+ std::to_string(K) +",
47+
"element_type": { "name": "vector", "length": "+ std::to_string(N) +" }
4948
},
5049
"block": "parameters"
5150
} |}]
5251

52+
(*Adds a backslash to all the inner quotes and then
53+
unslash the ones near a plus*)
5354
let replace_cpp_expr s =
5455
s
5556
|> Str.global_replace (Str.regexp {|"|}) {|\"|}
56-
|> Str.global_replace (Str.regexp {|\\"<<|}) {|" <<|}
57-
|> Str.global_replace (Str.regexp {|>>\\"|}) {|<< "|}
57+
|> Str.global_replace (Str.regexp {|\\"\+|}) {|" +|}
58+
|> Str.global_replace (Str.regexp {|\+\\"|}) {|+ "|}
59+
|> Str.global_replace (Str.regexp {|\\n|}) {||}
5860
5961
let wrap_in_quotes s = "\"" ^ s ^ "\""
6062

@@ -70,4 +72,4 @@ let%expect_test "outvar to json" =
7072
|> out_var_interpolated_json_str |> print_endline ;
7173
[%expect
7274
{|
73-
"[{\"name\":\"var_one\",\"type\":{\"name\":\"array\",\"length\":" << K << ",\"element_type\":{\"name\":\"vector\",\"length\":" << N << "}},\"block\":\"parameters\"}]" |}]
75+
"[{\"name\":\"var_one\",\"type\":{\"name\":\"array\",\"length\":" + std::to_string(K) + ",\"element_type\":{\"name\":\"vector\",\"length\":" + std::to_string(N) + "}},\"block\":\"parameters\"}]" |}]

src/stan_math_backend/Stan_math_code_gen.ml

Lines changed: 75 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -354,11 +354,15 @@ let includes = "#include <stan/model/model_header.hpp>"
354354
let pp_validate_data ppf (name, st) =
355355
if String.is_suffix ~suffix:"__" name then ()
356356
else
357+
let pp_stdvector ppf args =
358+
let pp_cast ppf x = pf ppf "static_cast<size_t>(%a)" pp_expr x in
359+
pf ppf "@[<hov 2> std::vector<size_t>{@,%a}@]" (list ~sep:comma pp_cast)
360+
args
361+
in
357362
pf ppf "@[<hov 4>context__.validate_dims(@,%S,@,%S,@,%S,@,%a);@]@ "
358363
"data initialization" name
359364
(stantype_prim_str (SizedType.to_unsized st))
360-
pp_call
361-
("context__.to_vec", pp_expr, SizedType.get_dims st)
365+
pp_stdvector (SizedType.get_dims st)
362366

363367
(** Print the constructor of the model class.
364368
Read in data steps:
@@ -374,26 +378,19 @@ let pp_ctor ppf p =
374378
in
375379
pf ppf "%s(@[<hov 0>%a) : model_base_crtp(0) @]" p.Program.prog_name
376380
(list ~sep:comma string) params ;
377-
let pp_mul ppf () = pf ppf " * " in
378-
let pp_num_param ppf dims =
379-
pf ppf "num_params_r__ += %a;" (list ~sep:pp_mul pp_expr) dims
380-
in
381-
let get_param_st = function
382-
| _, {Program.out_block= Parameters; out_unconstrained_st= st; _} -> (
383-
match SizedType.get_dims st with
384-
| [] -> Some [Expr.Helpers.loop_bottom]
385-
| ls -> Some ls )
386-
| _ -> None
387-
in
388381
let data_idents = List.map ~f:fst p.input_vars |> String.Set.of_list in
389382
let pp_stmt_topdecl_size_only ppf (Stmt.Fixed.({pattern; meta}) as s) =
390383
match pattern with
391384
| Decl {decl_id; decl_type; _} when decl_id <> "pos__" -> (
392385
match decl_type with
393-
| Sized st ->
386+
| Sized st -> (
394387
Locations.pp_smeta ppf meta ;
395-
if Set.mem data_idents decl_id then pp_validate_data ppf (decl_id, st) ;
396-
pp_set_size ppf (decl_id, st, DataOnly)
388+
let is_input_data = Set.mem data_idents decl_id in
389+
match is_input_data with
390+
| true ->
391+
pp_validate_data ppf (decl_id, st) ;
392+
pp_set_size ppf (decl_id, st, DataOnly, false)
393+
| false -> pp_set_size ppf (decl_id, st, DataOnly, true) )
397394
| Unsized _ -> () )
398395
| _ -> pp_statement ppf s
399396
in
@@ -412,11 +409,29 @@ let pp_ctor ppf p =
412409
pp_located_error ppf
413410
(pp_block, (list ~sep:cut pp_stmt_topdecl_size_only, prepare_data)) ;
414411
cut ppf () ;
415-
pf ppf "num_params_r__ = 0U;@ " ;
416-
pp_located_error ppf
417-
( pp_block
418-
, ( list ~sep:cut pp_num_param
419-
, List.filter_map ~f:get_param_st output_vars ) ) )
412+
let get_param_st = function
413+
| _, {Program.out_block= Parameters; out_unconstrained_st= st; _} -> (
414+
match SizedType.get_dims st with
415+
| [] -> Some [Expr.Helpers.loop_bottom]
416+
| ls -> Some ls )
417+
| _ -> None
418+
in
419+
let output_params = List.filter_map ~f:get_param_st output_vars in
420+
let pp_mul ppf () = pf ppf " * " in
421+
let pp_num_param ppf (dims : Expr.Typed.t list) =
422+
match dims with
423+
| [a] -> pf ppf "@[%a@]@," (list ~sep:pp_mul pp_expr) [a]
424+
| _ -> pf ppf "@[(%a)@]@," (list ~sep:pp_mul pp_expr) dims
425+
in
426+
let pp_plus ppf () = pf ppf " + " in
427+
let pp_set_params ppf pars =
428+
(list ~sep:pp_plus pp_num_param) ppf pars
429+
in
430+
match output_params with
431+
| [] -> pf ppf "num_params_r__ = 0U;@,"
432+
| _ ->
433+
pf ppf "@[<hov 2>num_params_r__ = %a;@]@," pp_set_params
434+
output_params )
420435
, p )
421436

422437
let rec top_level_decls Stmt.Fixed.({pattern; _}) =
@@ -440,8 +455,8 @@ let pp_model_private ppf {Program.prepare_data; _} =
440455
@param cv_attr Optional parameter to add method attributes.
441456
@param ppbody (?A pretty printer of the method's body)
442457
*)
443-
let pp_method ppf rt name params intro ?(outro = nop) ?(cv_attr = ["const"])
444-
ppbody =
458+
let pp_method ppf rt name params intro ?(outro = nop)
459+
?(cv_attr : string list = ["const"]) ppbody =
445460
pf ppf "@[<v 2>inline %s %s(@[<hov>@,%a@]) %a " rt name
446461
(list ~sep:comma string) params (list ~sep:cut string) cv_attr ;
447462
pf ppf "{@,%a@ " intro () ;
@@ -453,39 +468,40 @@ let pp_method ppf rt name params intro ?(outro = nop) ?(cv_attr = ["const"])
453468
@param ppf A pretty printer.
454469
*)
455470
let pp_get_param_names ppf {Program.output_vars; _} =
456-
let add_param = fmt "names__.emplace_back(%S);" in
457-
pp_method ppf "void" "get_param_names" ["std::vector<std::string>& names__"]
458-
nop (fun ppf ->
459-
pf ppf "names__.clear();@ " ;
460-
(list ~sep:cut add_param) ppf (List.map ~f:fst output_vars) )
471+
let add_param = fmt "%S" in
472+
pp_method ppf "void" "get_param_names"
473+
["std::vector<std::string>& names__"]
474+
nop
475+
(fun ppf ->
476+
pf ppf "@[<hov 2>names__ = std::vector<std::string>{%a};@]@,"
477+
(list ~sep:comma add_param)
478+
(List.map ~f:fst output_vars) )
479+
~cv_attr:["const"]
461480

462481
(** Print the `get_dims` method of the model class. *)
463482
let pp_get_dims ppf {Program.output_vars; _} =
464483
let pp_cast ppf cast_dims =
465-
pf ppf "static_cast<size_t>(%a)@," pp_expr cast_dims
484+
pf ppf "@[<hov 2>static_cast<size_t>(%a)@]@," pp_expr cast_dims
466485
in
467486
let pp_pack ppf inner_dims =
468487
pf ppf "std::vector<size_t>{@[<hov>@,%a@]}" (list ~sep:comma pp_cast)
469488
inner_dims
470489
in
471-
let pp_add_pack ppf dims =
472-
pf ppf "dimss__.emplace_back(%a);@," pp_pack dims
490+
let pp_add_pack ppf dims = pf ppf "%a" pp_pack dims in
491+
let dim_list =
492+
List.(
493+
map ~f:(fun (_, {Program.out_constrained_st= st; _}) -> st) output_vars)
473494
in
474-
let pp_output_var ppf =
475-
(list ~sep:cut pp_add_pack)
476-
ppf
477-
List.(
478-
map ~f:SizedType.get_dims
479-
(map
480-
~f:(fun (_, {Program.out_constrained_st= st; _}) -> st)
481-
output_vars))
495+
let pp_output_var ppf dims =
496+
(list ~sep:comma pp_add_pack) ppf List.(map ~f:SizedType.get_dims dims)
482497
in
483-
let params = ["std::vector<std::vector<size_t>>& dimss__"] in
484-
let cv_attr = ["const"] in
485-
pp_method ppf "void" "get_dims" params
486-
(const string "dimss__.clear();")
487-
(fun ppf -> pp_output_var ppf)
488-
~cv_attr
498+
pp_method ppf "void" "get_dims"
499+
["std::vector<std::vector<size_t>>& dimss__"]
500+
nop
501+
(fun ppf ->
502+
pf ppf "@[<hov 2>dimss__ = std::vector<std::vector<size_t>>{%a};@]@,"
503+
pp_output_var dim_list )
504+
~cv_attr:["const"]
489505

490506
let pp_method_b ppf rt name params intro ?(outro = nop) ?(cv_attr = ["const"])
491507
body =
@@ -544,7 +560,7 @@ let rec pp_for_loop_iteratee ?(index_ids = []) ppf (iteratee, dims, pp_body) =
544560
| [] -> pp_body ppf (iteratee, index_ids)
545561
| dim :: dims ->
546562
iter dim (fun ppf (i, idcs) ->
547-
pf ppf "%a" pp_block
563+
pf ppf "@[%a @]" pp_block
548564
(pp_for_loop_iteratee ~index_ids:idcs, (i, dims, pp_body)) )
549565

550566
(** Print the `constrained_param_names` method of the model class. *)
@@ -575,15 +591,14 @@ let pp_constrained_param_names ppf {Program.output_vars; _} =
575591
let dims = List.rev (SizedType.get_dims st) in
576592
pp_for_loop_iteratee ppf (decl_id, dims, emit_name)
577593
in
578-
let cv_attr = ["const"; "final"] in
579594
pp_method ppf "void" "constrained_param_names" params nop
580595
(fun ppf ->
581596
(list ~sep:cut pp_param_names) ppf paramvars ;
582597
pf ppf "@,if (emit_transformed_parameters__) %a@," pp_block
583598
(list ~sep:cut pp_param_names, tparamvars) ;
584599
pf ppf "@,if (emit_generated_quantities__) %a@," pp_block
585600
(list ~sep:cut pp_param_names, gqvars) )
586-
~cv_attr
601+
~cv_attr:["const"; "final"]
587602

588603
(* Print the `unconstrained_param_names` method of the model class.
589604
This is just a copy of constrained, I need to figure out which one is wrong
@@ -695,11 +710,9 @@ let pp_log_prob ppf Program.({prog_name; log_prob; _}) =
695710
@param outvars The parameters to gather the sizes for.
696711
*)
697712
let pp_outvar_metadata ppf (method_name, outvars) =
698-
let intro = const string "std::stringstream s__;" in
699-
let outro ppf () = pf ppf "@ return s__.str();" in
700713
let json_str = Cpp_Json.out_var_interpolated_json_str outvars in
701-
let ppbody ppf = pf ppf "s__ << %s;" json_str in
702-
pp_method ppf "std::string" method_name [] intro ~outro ppbody
714+
let ppbody ppf = pf ppf "@[<hov 2>return std::string(%s);@]@," json_str in
715+
pp_method ppf "std::string" method_name [] nop ppbody ~cv_attr:["const"]
703716

704717
(** Print the `get_unconstrained_sizedtypes` method of the model class *)
705718
let pp_unconstrained_types ppf {Program.output_vars; _} =
@@ -729,14 +742,13 @@ let pp_overloads ppf () =
729742
const bool emit_transformed_parameters = true,
730743
const bool emit_generated_quantities = true,
731744
std::ostream* pstream = nullptr) const {
732-
std::vector<double> vars_vec(vars.size());
745+
std::vector<double> vars_vec;
746+
vars_vec.reserve(vars.size());
733747
std::vector<int> params_i;
734748
write_array_impl(base_rng, params_r, params_i, vars_vec,
735749
emit_transformed_parameters, emit_generated_quantities, pstream);
736-
vars.resize(vars_vec.size());
737-
for (int i = 0; i < vars.size(); ++i) {
738-
vars.coeffRef(i) = vars_vec[i];
739-
}
750+
vars = Eigen::Map<Eigen::Matrix<double,Eigen::Dynamic,1>>(
751+
vars_vec.data(), vars_vec.size());
740752
}
741753

742754
template <typename RNG>
@@ -746,7 +758,8 @@ let pp_overloads ppf () =
746758
bool emit_transformed_parameters = true,
747759
bool emit_generated_quantities = true,
748760
std::ostream* pstream = nullptr) const {
749-
write_array_impl(base_rng, params_r, params_i, vars, emit_transformed_parameters, emit_generated_quantities, pstream);
761+
write_array_impl(base_rng, params_r, params_i, vars,
762+
emit_transformed_parameters, emit_generated_quantities, pstream);
750763
}
751764

752765
template <bool propto__, bool jacobian__, typename T_>
@@ -767,13 +780,12 @@ let pp_overloads ppf () =
767780
inline void transform_inits(const stan::io::var_context& context,
768781
Eigen::Matrix<double, Eigen::Dynamic, 1>& params_r,
769782
std::ostream* pstream = nullptr) const final {
770-
std::vector<double> params_r_vec(params_r.size());
783+
std::vector<double> params_r_vec;
784+
params_r_vec.reserve(params_r.size());
771785
std::vector<int> params_i;
772786
transform_inits_impl(context, params_i, params_r_vec, pstream);
773-
params_r.resize(params_r_vec.size());
774-
for (int i = 0; i < params_r.size(); ++i) {
775-
params_r.coeffRef(i) = params_r_vec[i];
776-
}
787+
params_r = Eigen::Map<Eigen::Matrix<double,Eigen::Dynamic,1>>(
788+
params_r_vec.data(), params_r_vec.size());
777789
}
778790
inline void transform_inits(const stan::io::var_context& context,
779791
std::vector<int>& params_i,

src/stan_math_backend/Statement_gen.ml

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,13 @@ let pp_profile ppf (pp_body, name, body) =
1515
in
1616
pf ppf "{@;<1 2>@[<v>%s@;@;%a@]@,}" profile pp_body body
1717

18-
let rec contains_eigen = function
18+
let rec contains_eigen (ut : UnsizedType.t) : bool =
19+
match ut with
1920
| UnsizedType.UArray t -> contains_eigen t
2021
| UMatrix | URowVector | UVector -> true
2122
| UInt | UReal | UMathLibraryFunction | UFun _ -> false
2223

23-
let pp_set_size ppf (decl_id, st, adtype) =
24+
let pp_set_size ppf (decl_id, st, adtype, (needs_filled : bool)) =
2425
(* TODO: generate optimal adtypes for expressions and declarations *)
2526
let real_nan =
2627
match adtype with
@@ -38,25 +39,30 @@ let pp_set_size ppf (decl_id, st, adtype) =
3839
| SMatrix (d1, d2) -> pf ppf "%a(%a, %a)" pp_st st pp_expr d1 pp_expr d2
3940
| SArray (t, d) -> pf ppf "%a(%a, %a)" pp_st st pp_expr d pp_size_ctor t
4041
in
41-
pf ppf "@[<hov 2>%s = %a;@]@," decl_id pp_size_ctor st ;
42-
if contains_eigen (SizedType.to_unsized st) then
43-
pf ppf "@[<hov 2>stan::math::fill(%s, %s);@]@," decl_id real_nan
42+
let print_fill ppf st =
43+
match (contains_eigen (SizedType.to_unsized st), needs_filled) with
44+
| true, true -> pf ppf "stan::math::fill(%s, %s);" decl_id real_nan
45+
| _, _ -> ()
46+
in
47+
pf ppf "@[<hov 0>%s = %a;@,%a @]@," decl_id pp_size_ctor st print_fill st
4448

4549
let%expect_test "set size mat array" =
4650
let int = Expr.Helpers.int in
4751
strf "@[<v>%a@]" pp_set_size
48-
("d", SArray (SArray (SMatrix (int 2, int 3), int 4), int 5), DataOnly)
52+
( "d"
53+
, SArray (SArray (SMatrix (int 2, int 3), int 4), int 5)
54+
, DataOnly
55+
, false )
4956
|> print_endline ;
5057
[%expect
5158
{|
52-
d = std::vector<std::vector<Eigen::Matrix<double, -1, -1>>>(5, std::vector<Eigen::Matrix<double, -1, -1>>(4, Eigen::Matrix<double, -1, -1>(2, 3)));
53-
stan::math::fill(d, std::numeric_limits<double>::quiet_NaN()); |}]
59+
d = std::vector<std::vector<Eigen::Matrix<double, -1, -1>>>(5, std::vector<Eigen::Matrix<double, -1, -1>>(4, Eigen::Matrix<double, -1, -1>(2, 3))); |}]
5460

5561
(** [pp_for_loop ppf (loopvar, lower, upper, pp_body, body)] tries to
5662
pretty print a for-loop from lower to upper given some loopvar.*)
5763
let pp_for_loop ppf (loopvar, lower, upper, pp_body, body) =
58-
pf ppf "@[<hov>for (@[<hov>int %s = %a;@ %s <= %a;@ ++%s@])" loopvar pp_expr
59-
lower loopvar pp_expr upper loopvar ;
64+
pf ppf "@[for (@[int %s = %a;@ %s <= %a;@ ++%s@])" loopvar pp_expr lower
65+
loopvar pp_expr upper loopvar ;
6066
pf ppf " %a@]" pp_body body
6167

6268
let rec integer_el_type = function
@@ -76,7 +82,7 @@ let pp_decl ppf (vident, ut, adtype) =
7682
let pp_sized_decl ppf (vident, st, adtype) =
7783
pf ppf "%a@,%a" pp_decl
7884
(vident, SizedType.to_unsized st, adtype)
79-
pp_set_size (vident, st, adtype)
85+
pp_set_size (vident, st, adtype, true)
8086

8187
let pp_possibly_sized_decl ppf (vident, pst, adtype) =
8288
match pst with

0 commit comments

Comments
 (0)