@@ -17,7 +17,7 @@ let rec matrix_set Expr.Fixed.{pattern; meta= Expr.Typed.Meta.{type_; _} as meta
1717 if UnsizedType. contains_eigen_type type_ then union_recur exprs
1818 else Set.Poly. empty
1919 | TernaryIf (_ , expr2 , expr3 ) -> union_recur [expr2; expr3]
20- | Indexed (expr , _ ) -> matrix_set expr
20+ | Indexed (expr , _ ) | Promotion ( expr , _ , _ ) -> matrix_set expr
2121 | EAnd (expr1 , expr2 ) | EOr (expr1 , expr2 ) -> union_recur [expr1; expr2]
2222 else Set.Poly. empty
2323
@@ -45,7 +45,7 @@ let is_nonzero_subset ~set ~subset =
4545 && not (Set.Poly. is_empty subset)
4646
4747(* *
48- * Check an expression to count how many times we see a single index.
48+ * Check an expression to count how many times we see a single index.
4949 * @param acc An accumulator from previous folds of multiple expressions.
5050 * @param pattern The expression patterns to match against
5151 *)
@@ -62,6 +62,7 @@ let rec count_single_idx_exprs (acc : int) Expr.Fixed.{pattern; _} : int =
6262 acc
6363 + count_single_idx_exprs 0 idx_expr
6464 + List. fold_left ~init: 0 ~f: count_single_idx indexed
65+ | Promotion (expr , _ , _ ) -> count_single_idx_exprs acc expr
6566 | EAnd (lhs , rhs ) ->
6667 acc + count_single_idx_exprs 0 lhs + count_single_idx_exprs 0 rhs
6768 | EOr (lhs , rhs ) ->
@@ -72,9 +73,9 @@ let rec count_single_idx_exprs (acc : int) Expr.Fixed.{pattern; _} : int =
7273(* *
7374 * Check an Index to count how many times we see a single index.
7475 * @param acc An accumulator from previous folds of multiple expressions.
75- * @param idx An Index to match. For Single types this adds 1 to the
76- * acc. For Upfrom and MultiIndex types we check the inner expression
77- * for a Single index. All and Between cannot be Single cell access
76+ * @param idx An Index to match. For Single types this adds 1 to the
77+ * acc. For Upfrom and MultiIndex types we check the inner expression
78+ * for a Single index. All and Between cannot be Single cell access
7879 * and so pass acc along.
7980 *)
8081and count_single_idx (acc : int ) (idx : Expr.Typed.Meta.t Expr.Fixed.t Index.t )
@@ -84,13 +85,13 @@ and count_single_idx (acc : int) (idx : Expr.Typed.Meta.t Expr.Fixed.t Index.t)
8485 | Single _ -> acc + 1
8586
8687(* *
87- * Find indices on Matrix and Vector types that perform single
88+ * Find indices on Matrix and Vector types that perform single
8889 * cell access. Returns true if it finds
8990 * a vector, row vector, matrix, or matrix with single cell access
90- * as well as an array of any of the above that is accessing the
91+ * as well as an array of any of the above that is accessing the
9192 * inner matrix types cell.
9293 * @param ut An UnsizedType to match against.
93- * @param index This list is checked for Single cell access
94+ * @param index This list is checked for Single cell access
9495 * either at the top level or within the `Index` types of the list.
9596 *)
9697let rec is_uni_eigen_loop_indexing in_loop (ut : UnsizedType.t )
@@ -122,7 +123,7 @@ let is_fun_soa_supported name exprs =
122123 * see the docs for `query_initial_demotable_funs`.
123124 * @param in_loop a boolean to signify if the expression exists inside
124125 * of a loop. If so, the names of matrix and vector like objects
125- * will be returned if the matrix or vector is accessed by single
126+ * will be returned if the matrix or vector is accessed by single
126127 * cell indexing.
127128 *)
128129let rec query_initial_demotable_expr (in_loop : bool ) ~(acc : string Set.Poly.t )
@@ -147,6 +148,7 @@ let rec query_initial_demotable_expr (in_loop : bool) ~(acc : string Set.Poly.t)
147148 Set.Poly. union acc index_demotes
148149 | Var (_ : string ) | Lit ((_ : Expr.Fixed.Pattern.litType ), (_ : string )) ->
149150 acc
151+ | Promotion (expr , _ , _ ) -> query_expr acc expr
150152 | TernaryIf (predicate , texpr , fexpr ) ->
151153 let predicate_demotes = query_expr acc predicate in
152154 Set.Poly. union
@@ -159,18 +161,18 @@ let rec query_initial_demotable_expr (in_loop : bool) ~(acc : string Set.Poly.t)
159161 Set.Poly. union (query_expr full_lhs_rhs lhs) (query_expr full_lhs_rhs rhs)
160162
161163(* *
162- * Query a function to detect if it or any of its used
164+ * Query a function to detect if it or any of its used
163165 * expression's objects or expressions should be demoted to AoS.
164166 *
165167 * The logic here demotes the expressions in a function to AoS if
166- * the function's inner expression returns has a meta type containing a matrix
168+ * the function's inner expression returns has a meta type containing a matrix
167169 * and either of :
168170 * (1) The function is user defined and the UDFs inputs are matrices.
169171 * (2) The Stan math function cannot support AoS
170172 * @param in_loop A boolean to specify the logic of indexing expressions. See
171173 * `query_initial_demotable_expr` for an explanation of the logic.
172- * @param kind The function type, for StanLib functions we check if the
173- * function supports SoA and for UserDefined functions we always fail
174+ * @param kind The function type, for StanLib functions we check if the
175+ * function supports SoA and for UserDefined functions we always fail
174176 * and return back all of the names of the objects passed in expressions
175177 * to the UDF.
176178 * exprs The expression list passed to the functions.
@@ -212,7 +214,8 @@ let rec is_any_soa_supported_expr
212214 match pattern with
213215 | FunApp (kind , (exprs : Expr.Typed.Meta.t Expr.Fixed.t list )) ->
214216 is_any_soa_supported_fun_expr kind exprs
215- | Indexed (expr , (_ : Typed.Meta.t Fixed.t Index.t list )) ->
217+ | Indexed (expr, (_ : Typed.Meta.t Fixed.t Index.t list ))
218+ | Promotion (expr , _ , _ ) ->
216219 is_any_soa_supported_expr expr
217220 | Var (_ : string ) | Lit ((_ : Expr.Fixed.Pattern.litType ), (_ : string )) ->
218221 true
@@ -248,7 +251,8 @@ let rec is_any_ad_real_data_matrix_expr
248251 match pattern with
249252 | FunApp (kind , (exprs : Expr.Typed.Meta.t Expr.Fixed.t list )) ->
250253 is_any_ad_real_data_matrix_expr_fun kind exprs
251- | Indexed (expr , _ ) -> is_any_ad_real_data_matrix_expr expr
254+ | Indexed (expr , _ ) | Promotion (expr , _ , _ ) ->
255+ is_any_ad_real_data_matrix_expr expr
252256 | Var (_ : string ) | Lit ((_ : Expr.Fixed.Pattern.litType ), (_ : string )) ->
253257 false
254258 | TernaryIf (_ , texpr , fexpr ) ->
@@ -303,15 +307,15 @@ and is_any_ad_real_data_matrix_expr_fun (kind : 'a Fun_kind.t)
303307
304308(* *
305309 * Query to find the initial set of objects in statements that cannot be SoA.
306- * This is mostly recursive over expressions and statements, with the exception of
310+ * This is mostly recursive over expressions and statements, with the exception of
307311 * functions and Assignments.
308312 *
309313 * For assignments:
310314 * We demote the LHS variable if any of the following are true:
311- * 1. None of the RHS's functions are able to accept SoA matrices
315+ * 1. None of the RHS's functions are able to accept SoA matrices
312316 * and the rhs is not an internal compiler function.
313317 * 2. A single cell of the LHS is being assigned within a loop.
314- * 3. The top level expression on the RHS is a combination of only
318+ * 3. The top level expression on the RHS is a combination of only
315319 * data matrices and scalar types. Operations on data matrix and
316320 * scalar values in Stan math will return a AoS matrix. We currently
317321 * have no way to tell Stan math to return a SoA matrix.
@@ -408,11 +412,11 @@ let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t)
408412(* * Look through a statement to see whether the objects used in it need to be
409413 * modified from SoA to AoS. Returns the set of object names that need demoted
410414 * in a statement, if any.
411- * This function looks at Assignment statements, and returns back the
415+ * This function looks at Assignment statements, and returns back the
412416 * set of top level object names given:
413- * 1. If the name of the lhs assignee is in the `aos_exits`, all the names
417+ * 1. If the name of the lhs assignee is in the `aos_exits`, all the names
414418 * of the expressions with a type containing a matrix are returned.
415- * 2. If the names of the rhs objects containing matrix types are in the subset of
419+ * 2. If the names of the rhs objects containing matrix types are in the subset of
416420 * aos_exits.
417421 * @param aos_exits A set of variables that can be demoted.
418422 * @param pattern The Stmt pattern to query.
@@ -437,15 +441,15 @@ let query_demotable_stmt (aos_exits : string Set.Poly.t)
437441
438442(* *
439443 * Modify a function and it's subexpressions from SoA <-> AoS and vice versa.
440- * This performs demotion for sub expressions recursively. The top level
441- * expression and it's sub expressions are demoted to SoA if
442- * 1. The names of the variables in the subexpressions returning
444+ * This performs demotion for sub expressions recursively. The top level
445+ * expression and it's sub expressions are demoted to SoA if
446+ * 1. The names of the variables in the subexpressions returning
443447 * objects holding matrices are all in the modifiable set.
444448 * 2. The function does not support SoA
445449 * 3. The `force` argument is `true`
446- * @param force_demotion If true, forces an expression and it's sub-expressions
450+ * @param force_demotion If true, forces an expression and it's sub-expressions
447451 * to be AoS.
448- * @param modifiable_set The set of names that are either demotable
452+ * @param modifiable_set The set of names that are either demotable
449453 * to AoS or promotable to SoA.
450454 * @param kind A `Fun_kind.t`
451455 * @param exprs A list of expressions going into the function.
@@ -474,15 +478,15 @@ let rec modify_kind ?force_demotion:(force = false)
474478 ( kind
475479 , List. map ~f: (modify_expr ~force_demotion: force modifiable_set) exprs )
476480
477- (* *
481+ (* *
478482 * Modify an expression and it's subexpressions from SoA <-> AoS
479483 * and vice versa. The only real paths in the below is on the
480484 * functions and ternary expressions.
481485 *
482486 * The logic for functions is defined in `modify_kind`.
483487 * `TernaryIf` is forcefully demoted to AoS if the type of the expression
484488 * contains a matrix.
485- * @param force_demotion If true, forces an expression and it's sub-expressions
489+ * @param force_demotion If true, forces an expression and it's sub-expressions
486490 * to be AoS.
487491 * @param modifiable_set The name of the variables whose
488492 * associated expressions we want to modify.
@@ -519,11 +523,13 @@ and modify_expr_pattern ?force_demotion:(force = false)
519523 , List. map ~f: (Index. map (mod_expr ~force_demotion: force)) indexed )
520524 | EAnd (lhs , rhs ) -> EAnd (mod_expr lhs, mod_expr rhs)
521525 | EOr (lhs , rhs ) -> EOr (mod_expr lhs, mod_expr rhs)
526+ | Promotion (expr , type_ , ad_level ) ->
527+ Promotion (mod_expr expr, type_, ad_level)
522528 | Var (_ : string ) | Lit ((_ : Expr.Fixed.Pattern.litType ), (_ : string )) ->
523529 pattern
524530
525- (* *
526- * Given a Set of strings containing the names of objects that can be
531+ (* *
532+ * Given a Set of strings containing the names of objects that can be
527533* modified from AoS <-> SoA and vice versa, modify them within the expression.
528534* @param mem_pattern The memory pattern to change expressions to.
529535* @param modifiable_set The name of the variables whose
@@ -536,15 +542,15 @@ and modify_expr ?force_demotion:(force = false)
536542 pattern= modify_expr_pattern ~force_demotion: force modifiable_set pattern }
537543
538544(* *
539- * Modify statement patterns in the MIR from AoS <-> SoA and vice versa
545+ * Modify statement patterns in the MIR from AoS <-> SoA and vice versa
540546* For `Decl` and `Assignment`'s reading in parameters, we demote to AoS
541547* if the `decl_id` (or assign name) is in the modifiable set and
542548* otherwise promote the statement to `SoA`.
543549* For general `Assignment` statements, we check if the assignee is in
544550* the demotable set. If so, we force demotion of all of the rhs expressions.
545551* All other statements recurse over their statements and expressions.
546- *
547- * @param pattern The statement pattern to modify
552+ *
553+ * @param pattern The statement pattern to modify
548554* @param modifiable_set The name of the variable we are searching for.
549555*)
550556let rec modify_stmt_pattern
@@ -624,11 +630,11 @@ let rec modify_stmt_pattern
624630 | Skip | Break | Continue | Decl _ -> pattern
625631
626632(* *
627- * Modify statement patterns in the MIR from AoS <-> SoA and vice versa
628- * @param mem_pattern A mem_pattern to modify expressions to. For the
629- * given memory pattern, this modifies
633+ * Modify statement patterns in the MIR from AoS <-> SoA and vice versa
634+ * @param mem_pattern A mem_pattern to modify expressions to. For the
635+ * given memory pattern, this modifies
630636* statement patterns and expressions to it.
631- * @param stmt The statement to modify.
637+ * @param stmt The statement to modify.
632638* @param modifiable_set The name of the variable we are searching for.
633639*)
634640and modify_stmt (Stmt.Fixed. {pattern; _} as stmt )
0 commit comments