@@ -21,42 +21,128 @@ let rec contains_eigen (ut : UnsizedType.t) : bool =
2121 | UMatrix | URowVector | UVector -> true
2222 | UInt | UReal | UMathLibraryFunction | UFun _ -> false
2323
24- let pp_set_size ppf (decl_id , st , adtype , (needs_filled : bool )) =
25- (* TODO: generate optimal adtypes for expressions and declarations *)
26- let real_nan =
27- match adtype with
28- | UnsizedType. AutoDiffable -> " DUMMY_VAR__"
29- | DataOnly -> " std::numeric_limits<double>::quiet_NaN()"
24+ (* Fill only needs to happen for containers
25+ * Note: This should probably be moved into its own function as data
26+ * does not need to be filled as we are promised user input data has the correct
27+ * dimensions. Transformed data must be filled as incorrect slices could lead
28+ * to elements of objects in transform data not being set by the user.
29+ *)
30+ let pp_filler ppf (decl_id , st , nan_type , needs_filled ) =
31+ match (needs_filled, contains_eigen (SizedType. to_unsized st)) with
32+ | true , true ->
33+ pf ppf " @[<hov 2>stan::math::fill(%s, %s);@]@," decl_id nan_type
34+ | _ -> ()
35+
36+ (* Pretty print a sized type*)
37+ let pp_st ppf (st , adtype ) =
38+ pf ppf " %a" pp_unsizedtype_local (adtype, SizedType. to_unsized st)
39+
40+ let pp_ut ppf (ut , adtype ) = pf ppf " %a" pp_unsizedtype_local (adtype, ut)
41+
42+ (* Get a string representing for the NaN type of the given type *)
43+ let nan_type (st , adtype ) =
44+ match (adtype, st) with
45+ | UnsizedType. AutoDiffable , _ -> " DUMMY_VAR__"
46+ | DataOnly , _ -> " std::numeric_limits<double>::quiet_NaN()"
47+
48+ (* Pretty printer for the right hand side of expressions to initialize objects.
49+ * For scalar types this sets the value to NaN and for containers initializes the memory.
50+ *)
51+ let rec pp_initialize ppf (st , adtype ) =
52+ let init_nan = nan_type (st, adtype) in
53+ match st with
54+ | SizedType. SInt -> pf ppf " std::numeric_limits<int>::min()"
55+ | SReal -> pf ppf " %s" init_nan
56+ | SVector d | SRowVector d -> pf ppf " %a(%a)" pp_st (st, adtype) pp_expr d
57+ | SMatrix (d1 , d2 ) ->
58+ pf ppf " %a(%a, %a)" pp_st (st, adtype) pp_expr d1 pp_expr d2
59+ | SArray (t , d ) ->
60+ pf ppf " %a(%a, %a)" pp_st (st, adtype) pp_expr d pp_initialize (t, adtype)
61+
62+ (* Initialize an object of a given size.*)
63+ let pp_assign_sized ppf (decl_id , st , adtype ) =
64+ let init_nan = nan_type (st, adtype) in
65+ let pp_assign ppf (decl_id , st , adtype ) =
66+ pf ppf " @[<hov 2>%s = %a;@]@," decl_id pp_initialize (st, adtype)
3067 in
31- let rec pp_size_ctor ppf st =
32- let pp_st ppf st =
33- pf ppf " %a" pp_unsizedtype_local (adtype, SizedType. to_unsized st)
34- in
68+ pf ppf " @[%a%a@]@," pp_assign (decl_id, st, adtype) pp_filler
69+ (decl_id, st, init_nan, true )
70+
71+ let % expect_test " set size mat array" =
72+ let int = Expr.Helpers. int in
73+ strf " @[<v>%a@]" pp_assign_sized
74+ (" d" , SArray (SArray (SMatrix (int 2 , int 3 ), int 4 ), int 5 ), DataOnly )
75+ |> print_endline ;
76+ [% expect
77+ {|
78+ d = std ::vector< std::vector< Eigen ::Matrix < double, - 1 , - 1 >>> (5 , std ::vector< Eigen ::Matrix < double, - 1 , - 1 >> (4 , Eigen ::Matrix < double, - 1 , - 1 > (2 , 3 )));
79+ stan ::math::fill(d , std ::numeric_limits< double> ::quiet_NaN()); | }]
80+
81+ (* Initialize Data and Transformed Data
82+ * This function is used in the model's constructor to
83+ * 1. Initialize memory for the data and transformed data
84+ * 2. If an Eigen type, place that memory into the class's Map
85+ * 3. Set the initial values of that data to NaN.
86+ * @param ppf A pretty printer
87+ * @param decl_id The name of the model class member
88+ * @param st The type of the class member
89+ *)
90+ let pp_assign_data ppf
91+ ((decl_id , st , needs_filled ) : string * Expr.Typed. t SizedType. t * bool ) =
92+ let init_nan = nan_type (st, DataOnly ) in
93+ let pp_assign ppf (decl_id , st ) =
3594 match st with
36- | SizedType. SInt -> pf ppf " std::numeric_limits<int>::min()"
37- | SReal -> pf ppf " %s" real_nan
38- | SVector d | SRowVector d -> pf ppf " %a(%a)" pp_st st pp_expr d
39- | SMatrix (d1 , d2 ) -> pf ppf " %a(%a, %a)" pp_st st pp_expr d1 pp_expr d2
40- | SArray (t , d ) -> pf ppf " %a(%a, %a)" pp_st st pp_expr d pp_size_ctor t
95+ | SizedType. SVector _ | SRowVector _ | SMatrix _ ->
96+ pf ppf " @[<hov 2>%s__ = %a;@]@," decl_id pp_initialize (st, DataOnly )
97+ | SInt | SReal | SArray _ ->
98+ pf ppf " @[<hov 2>%s = %a;@]@," decl_id pp_initialize (st, DataOnly )
4199 in
42- let print_fill ppf st =
43- match (contains_eigen (SizedType. to_unsized st), needs_filled) with
44- | true , true -> pf ppf " stan::math::fill(%s, %s);" decl_id real_nan
45- | _ , _ -> ()
100+ let pp_placement_new ppf (decl_id , st ) =
101+ match st with
102+ | SizedType. SVector d | SRowVector d ->
103+ pf ppf " @[<hov 2>new (&%s) Eigen::Map<%a>(%s__.data(), %a);@]@,"
104+ decl_id pp_st (st, DataOnly ) decl_id pp_expr d
105+ | SMatrix (d1 , d2 ) ->
106+ pf ppf " @[<hov 2>new (&%s) Eigen::Map<%a>(%s__.data(), %a, %a);@]@,"
107+ decl_id pp_st (st, DataOnly ) decl_id pp_expr d1 pp_expr d2
108+ | _ -> ()
46109 in
47- pf ppf " @[<hov 0>%s = %a;@,%a @]@," decl_id pp_size_ctor st print_fill st
110+ pf ppf " @[%a%a%a@]@," pp_assign (decl_id, st) pp_placement_new (decl_id, st)
111+ pp_filler
112+ (decl_id, st, init_nan, needs_filled)
48113
49- let % expect_test " set size mat array" =
114+ let % expect_test " set size map int array" =
50115 let int = Expr.Helpers. int in
51- strf " @[<v>%a@]" pp_set_size
52- ( " d"
53- , SArray (SArray (SMatrix (int 2 , int 3 ), int 4 ), int 5 )
54- , DataOnly
55- , false )
116+ strf " @[<v>%a@]" pp_assign_data
117+ (" darrmat" , SArray (SArray (SInt , int 4 ), int 5 ), false )
56118 |> print_endline ;
57119 [% expect
58120 {|
59- d = std ::vector< std::vector< Eigen ::Matrix < double, - 1 , - 1 >>> (5 , std ::vector< Eigen ::Matrix < double, - 1 , - 1 >> (4 , Eigen ::Matrix < double, - 1 , - 1 > (2 , 3 ))); | }]
121+ darrmat = std ::vector< std::vector< int >> (5 , std ::vector< int > (4 , std ::numeric_limits< int > ::min())); | }]
122+
123+ let % expect_test " set size map mat array" =
124+ let int = Expr.Helpers. int in
125+ strf " @[<v>%a@]" pp_assign_data
126+ (" darrmat" , SArray (SArray (SMatrix (int 2 , int 3 ), int 4 ), int 5 ), true )
127+ |> print_endline ;
128+ [% expect
129+ {|
130+ darrmat = std ::vector< std::vector< Eigen ::Matrix < double, - 1 , - 1 >>> (5 , std ::vector< Eigen ::Matrix < double, - 1 , - 1 >> (4 , Eigen ::Matrix < double, - 1 , - 1 > (2 , 3 )));
131+ stan ::math::fill(darrmat , std ::numeric_limits< double> ::quiet_NaN()); | }]
132+
133+ let % expect_test " set size map mat" =
134+ let int = Expr.Helpers. int in
135+ strf " @[<v>%a@]" pp_assign_data (" dmat" , SMatrix (int 2 , int 3 ), false )
136+ |> print_endline ;
137+ [% expect
138+ {|
139+ dmat__ = Eigen ::Matrix < double, - 1 , - 1 > (2 , 3 );
140+ new (& dmat) Eigen ::Map < Eigen ::Matrix < double, - 1 , - 1 >> (dmat__.data() , 2 , 3 ); | }]
141+
142+ let % expect_test " set size map int" =
143+ strf " @[<v>%a@]" pp_assign_data (" dint" , SInt , true ) |> print_endline ;
144+ [% expect {|
145+ dint = std ::numeric_limits< int > ::min(); | }]
60146
61147(* * [pp_for_loop ppf (loopvar, lower, upper, pp_body, body)] tries to
62148 pretty print a for-loop from lower to upper given some loopvar.*)
@@ -70,7 +156,51 @@ let rec integer_el_type = function
70156 | SInt -> true
71157 | SArray (st , _ ) -> integer_el_type st
72158
73- let pp_decl ppf (vident , ut , adtype ) =
159+ (* Print the private members of the model class
160+ * Accounting for types that can be moved to OpenCL.
161+ * @param ppf A formatter
162+ * @param vident name of the private member.
163+ * @param ut The unsized type to print.
164+ *)
165+ let pp_data_decl ppf (vident , ut ) =
166+ let opencl_check = (Transform_Mir. is_opencl_var vident, ut) in
167+ let pp_type =
168+ match opencl_check with
169+ | _ , UnsizedType. (UInt | UReal ) | false , _ -> pp_unsizedtype_local
170+ | true , UArray UInt -> fun ppf _ -> pf ppf " matrix_cl<int>"
171+ | true , _ -> fun ppf _ -> pf ppf " matrix_cl<double>"
172+ in
173+ match (opencl_check, ut) with
174+ | (false , _ ), ut -> (
175+ match ut with
176+ | UnsizedType. URowVector | UVector | UMatrix ->
177+ pf ppf " %a %s__;" pp_type (DataOnly , ut) vident
178+ | _ -> pf ppf " %a %s;" pp_type (DataOnly , ut) vident )
179+ | (true , _ ), _ -> pf ppf " %a %s;" pp_type (DataOnly , ut) vident
180+
181+ (* Create strings representing maps of Eigen types*)
182+ let pp_map_decl ppf (vident , ut ) =
183+ let scalar = local_scalar ut DataOnly in
184+ match ut with
185+ | UnsizedType. UInt | UReal -> ()
186+ | UMatrix ->
187+ pf ppf " Eigen::Map<Eigen::Matrix<%s, -1, -1>> %s{nullptr, 0, 0};" scalar
188+ vident
189+ | URowVector ->
190+ pf ppf " Eigen::Map<Eigen::Matrix<%s, 1, -1>> %s{nullptr, 0};" scalar
191+ vident
192+ | UVector ->
193+ pf ppf " Eigen::Map<Eigen::Matrix<%s, -1, 1>> %s{nullptr, 0};" scalar
194+ vident
195+ | x ->
196+ raise_s
197+ [% message
198+ " Error during Map data construction for " vident " of type "
199+ (x : UnsizedType.t )
200+ " . This should never happen, if you see this please file a bug \
201+ report." ]
202+
203+ let pp_unsized_decl ppf (vident , ut , adtype ) =
74204 let pp_type =
75205 match (Transform_Mir. is_opencl_var vident, ut) with
76206 | _ , UnsizedType. (UInt | UReal ) | false , _ -> pp_unsizedtype_local
@@ -80,14 +210,14 @@ let pp_decl ppf (vident, ut, adtype) =
80210 pf ppf " %a %s;" pp_type (adtype, ut) vident
81211
82212let pp_sized_decl ppf (vident , st , adtype ) =
83- pf ppf " %a@,%a" pp_decl
213+ pf ppf " %a@,%a" pp_unsized_decl
84214 (vident, SizedType. to_unsized st, adtype)
85- pp_set_size (vident, st, adtype, true )
215+ pp_assign_sized (vident, st, adtype)
86216
87- let pp_possibly_sized_decl ppf (vident , pst , adtype ) =
217+ let pp_decl ppf (vident , pst , adtype ) =
88218 match pst with
89219 | Type. Sized st -> pp_sized_decl ppf (vident, st, adtype)
90- | Unsized ut -> pp_decl ppf (vident, ut, adtype)
220+ | Unsized ut -> pp_unsized_decl ppf (vident, ut, adtype)
91221
92222let math_fn_translations = function
93223 | Internal_fun. FnLength -> Some (" length" , [] )
@@ -207,7 +337,7 @@ let rec pp_statement (ppf : Format.formatter) Stmt.Fixed.({pattern; meta}) =
207337 | Block ls -> pp_block ppf (pp_stmt_list, ls)
208338 | SList ls -> pp_stmt_list ppf ls
209339 | Decl {decl_adtype; decl_id; decl_type} ->
210- pp_possibly_sized_decl ppf (decl_id, decl_type, decl_adtype)
340+ pp_decl ppf (decl_id, decl_type, decl_adtype)
211341
212342and pp_block_s ppf body =
213343 match body.pattern with
0 commit comments