@@ -13,7 +13,9 @@ let rec matrix_set Expr.Fixed.{pattern; meta= Expr.Typed.Meta.{type_; _} as meta
1313 match pattern with
1414 | Var s -> Set.Poly. singleton (Dataflow_types. VVar s, meta)
1515 | Lit _ -> Set.Poly. empty
16- | FunApp (_ , exprs ) -> union_recur exprs
16+ | FunApp (_ , exprs ) ->
17+ if UnsizedType. contains_eigen_type type_ then union_recur exprs
18+ else Set.Poly. empty
1719 | TernaryIf (_ , expr2 , expr3 ) -> union_recur [expr2; expr3]
1820 | Indexed (expr , _ ) -> matrix_set expr
1921 | EAnd (expr1 , expr2 ) | EOr (expr1 , expr2 ) -> union_recur [expr1; expr2]
@@ -123,30 +125,38 @@ let is_fun_soa_supported name exprs =
123125 * will be returned if the matrix or vector is accessed by single
124126 * cell indexing.
125127 *)
126- let rec query_initial_demotable_expr (in_loop : bool ) Expr.Fixed. {pattern; _} =
127- let query_expr = query_initial_demotable_expr in_loop in
128+ let rec query_initial_demotable_expr (in_loop : bool ) ~(acc : string Set.Poly.t )
129+ Expr.Fixed. {pattern; _} : string Set.Poly.t =
130+ let query_expr (accum : string Set.Poly.t ) =
131+ query_initial_demotable_expr in_loop ~acc: accum in
128132 match pattern with
129133 | FunApp (kind , (exprs : Expr.Typed.Meta.t Expr.Fixed.t list )) ->
130- query_initial_demotable_funs in_loop kind exprs
134+ query_initial_demotable_funs in_loop acc kind exprs
131135 | Indexed ((Expr.Fixed. {meta = {type_; _} ; _} as expr ), indexed ) ->
132136 let index_set =
133137 Set.Poly. union_list
134138 (List. map
135139 ~f:
136140 (Index. apply ~default: Set.Poly. empty ~merge: Set.Poly. union
137- query_expr )
141+ ( query_expr acc) )
138142 indexed ) in
139- if is_uni_eigen_loop_indexing in_loop type_ indexed then
140- Set.Poly. union (query_var_eigen_names expr) index_set
141- else Set.Poly. union (query_expr expr) index_set
143+ let index_demotes =
144+ if is_uni_eigen_loop_indexing in_loop type_ indexed then
145+ Set.Poly. union (query_var_eigen_names expr) index_set
146+ else Set.Poly. union (query_expr acc expr) index_set in
147+ Set.Poly. union acc index_demotes
142148 | Var (_ : string ) | Lit ((_ : Expr.Fixed.Pattern.litType ), (_ : string )) ->
143- Set.Poly. empty
149+ acc
144150 | TernaryIf (predicate , texpr , fexpr ) ->
151+ let predicate_demotes = query_expr acc predicate in
145152 Set.Poly. union
146- (Set.Poly. union (query_expr predicate) (query_var_eigen_names texpr))
153+ (Set.Poly. union predicate_demotes (query_var_eigen_names texpr))
147154 (query_var_eigen_names fexpr)
148155 | EAnd (lhs , rhs ) | EOr (lhs , rhs ) ->
149- Set.Poly. union (query_expr lhs) (query_expr rhs)
156+ (* We need to collect the demoted names from both the lhs and the rhs. *)
157+ let full_lhs_rhs =
158+ Set.Poly. union (query_expr acc lhs) (query_expr acc rhs) in
159+ Set.Poly. union (query_expr full_lhs_rhs lhs) (query_expr full_lhs_rhs rhs)
150160
151161(* *
152162 * Query a function to detect if it or any of its used
@@ -165,25 +175,28 @@ let rec query_initial_demotable_expr (in_loop : bool) Expr.Fixed.{pattern; _} =
165175 * to the UDF.
166176 * exprs The expression list passed to the functions.
167177 *)
168- and query_initial_demotable_funs (in_loop : bool ) (kind : 'a Fun_kind.t )
169- (exprs : Typed.Meta.t Expr.Fixed.t list ) : string Set.Poly.t =
170- let query_expr = query_initial_demotable_expr in_loop in
178+ and query_initial_demotable_funs (in_loop : bool ) (acc : string Set.Poly.t )
179+ (kind : 'a Fun_kind.t ) (exprs : Typed.Meta.t Expr.Fixed.t list ) :
180+ string Set.Poly. t =
181+ let query_expr accum = query_initial_demotable_expr in_loop ~acc: accum in
171182 let top_level_eigen_names =
172183 Set.Poly. union_list (List. map ~f: query_var_eigen_names exprs) in
173- let demoted_eigen_names = Set.Poly. union_list (List. map ~f: query_expr exprs) in
184+ let demoted_eigen_names = List. fold ~init: acc ~f: query_expr exprs in
185+ let demoted_and_top_level_names =
186+ Set.Poly. union demoted_eigen_names top_level_eigen_names in
174187 match kind with
175188 | Fun_kind. StanLib (name , (_ : bool Fun_kind.suffix ), _ ) -> (
176189 match name with
177- | "check_matching_dims" -> Set.Poly. empty
190+ | "check_matching_dims" -> acc
178191 | name -> (
179192 match is_fun_soa_supported name exprs with
180- | true -> demoted_eigen_names
181- | false -> Set.Poly. union demoted_eigen_names top_level_eigen_names ) )
193+ | true -> Set.Poly. union acc demoted_eigen_names
194+ | false -> Set.Poly. union acc demoted_and_top_level_names ) )
182195 | CompilerInternal (Internal_fun. FnMakeArray | FnMakeRowVec ) ->
183- Set.Poly. union demoted_eigen_names top_level_eigen_names
184- | CompilerInternal (_ : 'a Internal_fun.t ) -> Set.Poly. empty
196+ Set.Poly. union acc demoted_and_top_level_names
197+ | CompilerInternal (_ : 'a Internal_fun.t ) -> acc
185198 | UserDefined ((_ : string ), (_ : bool Fun_kind.suffix )) ->
186- Set.Poly. union demoted_eigen_names top_level_eigen_names
199+ Set.Poly. union acc demoted_and_top_level_names
187200
188201(* *
189202 * Check whether any functions in the right hand side expression of an assignment
@@ -295,7 +308,8 @@ and is_any_ad_real_data_matrix_expr_fun (kind : 'a Fun_kind.t)
295308 *
296309 * For assignments:
297310 * We demote the LHS variable if any of the following are true:
298- * 1. None of the RHS's functions are able to accept SoA matrices
311+ * 1. None of the RHS's functions are able to accept SoA matrices
312+ * and the RHS is not an internal compiler function.
299313 * 2. A single cell of the LHS is being assigned within a loop.
300314 * 3. The top level expression on the RHS is a combination of only
301315 * data matrices and scalar types. Operations on data matrix and
@@ -311,66 +325,85 @@ and is_any_ad_real_data_matrix_expr_fun (kind : 'a Fun_kind.t)
311325 * @param in_loop A boolean to specify the logic of indexing expressions. See
312326 * `query_initial_demotable_expr` for an explanation of the logic.
313327 *)
314- let rec query_initial_demotable_stmt (in_loop : bool )
328+ let rec query_initial_demotable_stmt (in_loop : bool ) ( acc : string Set.Poly.t )
315329 (Stmt.Fixed. {pattern; _} :
316330 (Expr.Typed.Meta.t, Stmt.Located.Meta.t) Stmt.Fixed.t ) :
317331 string Set.Poly. t =
318- let query_expr = query_initial_demotable_expr in_loop in
332+ let query_expr (accum : string Set.Poly.t ) =
333+ query_initial_demotable_expr in_loop ~acc: accum in
319334 match pattern with
320335 | Stmt.Fixed.Pattern. Assignment
321336 ( ((name : string ), (ut : UnsizedType.t ), idx)
322337 , (Expr.Fixed. {meta= Expr.Typed.Meta. {type_; adlevel; _}; _} as rhs) ) ->
338+ let idx_list =
339+ List. fold ~init: acc
340+ ~f: (fun accum x ->
341+ Index. folder accum
342+ (fun acc -> query_initial_demotable_expr in_loop ~acc )
343+ x )
344+ idx in
323345 let idx_demotable =
324- let idx_list =
325- Set.Poly. union_list
326- (List. map
327- ~f:
328- (Index. apply ~default: Set.Poly. empty ~merge: Set.Poly. union
329- query_expr )
330- idx ) in
331346 (* RHS (2)*)
332347 match is_uni_eigen_loop_indexing in_loop ut idx with
333348 | true -> Set.Poly. add idx_list name
334349 | false -> idx_list in
335- let rhs_demotable = query_expr rhs in
350+ let rhs_demotable_names = query_expr acc rhs in
336351 (* RHS (3)*)
337- let check_bad_assign =
352+ let check_if_rhs_ad_real_data_matrix_expr =
338353 match (UnsizedType. contains_eigen_type type_, adlevel) with
339354 | true , UnsizedType. AutoDiffable ->
340355 is_any_ad_real_data_matrix_expr rhs
341356 || not (is_any_soa_supported_expr rhs)
342357 | _ -> false in
343358 (* RHS (1)*)
344359 let is_all_rhs_aos =
345- let all_eigen_names = query_var_eigen_names rhs in
346- is_nonzero_subset ~set: all_eigen_names ~subset: rhs_demotable in
347- if is_all_rhs_aos || check_bad_assign then
348- let base_set = Set.Poly. union idx_demotable rhs_demotable in
349- Set.Poly. add (Set.Poly. union base_set (query_var_eigen_names rhs)) name
350- else Set.Poly. union idx_demotable rhs_demotable
351- | NRFunApp (kind , exprs ) -> query_initial_demotable_funs in_loop kind exprs
360+ let all_rhs_eigen_names = query_var_eigen_names rhs in
361+ is_nonzero_subset ~subset: all_rhs_eigen_names ~set: rhs_demotable_names
362+ in
363+ let is_not_supported_func =
364+ match rhs.pattern with
365+ | FunApp (CompilerInternal _ , _ ) -> false
366+ | FunApp (UserDefined _ , _ ) -> true
367+ | _ -> false in
368+ let is_eigen_stmt = UnsizedType. contains_eigen_type rhs.meta.type_ in
369+ let assign_demotes =
370+ if
371+ is_eigen_stmt
372+ && ( is_all_rhs_aos || check_if_rhs_ad_real_data_matrix_expr
373+ || is_not_supported_func )
374+ then
375+ let base_set = Set.Poly. union idx_demotable rhs_demotable_names in
376+ Set.Poly. add
377+ (Set.Poly. union base_set (query_var_eigen_names rhs))
378+ name
379+ else Set.Poly. union idx_demotable rhs_demotable_names in
380+ Set.Poly. union acc assign_demotes
381+ | NRFunApp (kind , exprs ) ->
382+ query_initial_demotable_funs in_loop acc kind exprs
352383 | IfElse (predicate , true_stmt , op_false_stmt ) ->
353- let demotable_rhs =
354- Option. value_map
355- ~f: (query_initial_demotable_stmt in_loop)
356- ~default: Set.Poly. empty op_false_stmt in
357- Set.Poly. union_list
358- [ query_expr predicate; query_initial_demotable_stmt in_loop true_stmt
359- ; demotable_rhs ]
384+ let predicate_acc = query_expr acc predicate in
385+ Set.Poly. union acc
386+ (Set.Poly. union_list
387+ [ predicate_acc
388+ ; query_initial_demotable_stmt in_loop predicate_acc true_stmt
389+ ; Option. value_map
390+ ~f: (query_initial_demotable_stmt in_loop predicate_acc)
391+ ~default: Set.Poly. empty op_false_stmt ] )
360392 | Return optional_expr ->
361- Option. value_map ~f: query_expr ~default: Set.Poly. empty optional_expr
393+ Option. value_map ~f: ( query_expr acc) ~default: Set.Poly. empty optional_expr
362394 | SList lst | Profile (_ , lst ) | Block lst ->
363395 Set.Poly. union_list
364- (List. map ~f: (query_initial_demotable_stmt in_loop) lst)
365- | TargetPE expr -> query_expr expr
396+ (List. map ~f: (query_initial_demotable_stmt in_loop acc ) lst)
397+ | TargetPE expr -> query_expr acc expr
366398 | For {lower; upper; body; _} ->
367399 Set.Poly. union
368- (Set.Poly. union (query_expr lower) (query_expr upper))
369- (query_initial_demotable_stmt true body)
400+ (Set.Poly. union (query_expr acc lower) (query_expr acc upper))
401+ (query_initial_demotable_stmt true acc body)
370402 | While (predicate , body ) ->
371- Set.Poly. union (query_expr predicate)
372- (query_initial_demotable_stmt true body)
373- | Skip | Break | Continue | Decl _ -> Set.Poly. empty
403+ Set.Poly. union_list
404+ [ acc; query_expr acc predicate
405+ ; query_initial_demotable_stmt true acc body ]
406+ | Skip | Break | Continue | Decl _ -> acc
374407
375408(* * Look through a statement to see whether the objects used in it need to be
376409 * modified from SoA to AoS. Returns the set of object names that need to be demoted
@@ -402,14 +435,6 @@ let query_demotable_stmt (aos_exits : string Set.Poly.t)
402435 (* All other statements do not need logic here*)
403436 | _ -> Set.Poly. empty
404437
405- (* *
406- * Search through an expression for the names of all types that hold matrices
407- * and vectors.
408- **)
409- let query_eigen_names (expr : Typed.Meta.t Expr.Fixed.t ) : string Set.Poly.t =
410- let get_expr_names (Dataflow_types. VVar s , _ ) = Some s in
411- Set.Poly. filter_map ~f: get_expr_names (matrix_set expr)
412-
413438(* *
414439 * Modify a function and its subexpressions from SoA <-> AoS and vice versa.
415440 * This performs demotion for sub expressions recursively. The top level
@@ -428,7 +453,8 @@ let query_eigen_names (expr : Typed.Meta.t Expr.Fixed.t) : string Set.Poly.t =
428453let rec modify_kind ?force_demotion :(force = false )
429454 (modifiable_set : string Set.Poly.t ) (kind : 'a Fun_kind.t )
430455 (exprs : Expr.Typed.Meta.t Expr.Fixed.t list ) =
431- let expr_names = Set.Poly. union_list (List. map ~f: query_eigen_names exprs) in
456+ let expr_names =
457+ Set.Poly. union_list (List. map ~f: query_var_eigen_names exprs) in
432458 let is_all_in_list =
433459 is_nonzero_subset ~set: modifiable_set ~subset: expr_names in
434460 match kind with
0 commit comments