Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 43 additions & 37 deletions src/analysis_and_optimization/Mem_pattern.ml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ let rec matrix_set Expr.Fixed.{pattern; meta= Expr.Typed.Meta.{type_; _} as meta
if UnsizedType.contains_eigen_type type_ then union_recur exprs
else Set.Poly.empty
| TernaryIf (_, expr2, expr3) -> union_recur [expr2; expr3]
| Indexed (expr, _) -> matrix_set expr
| Indexed (expr, _) | Promotion (expr, _, _) -> matrix_set expr
| EAnd (expr1, expr2) | EOr (expr1, expr2) -> union_recur [expr1; expr2]
else Set.Poly.empty

Expand Down Expand Up @@ -45,7 +45,7 @@ let is_nonzero_subset ~set ~subset =
&& not (Set.Poly.is_empty subset)

(**
* Check an expression to count how many times we see a single index.
* Check an expression to count how many times we see a single index.
* @param acc An accumulator from previous folds of multiple expressions.
* @param pattern The expression patterns to match against
*)
Expand All @@ -62,6 +62,7 @@ let rec count_single_idx_exprs (acc : int) Expr.Fixed.{pattern; _} : int =
acc
+ count_single_idx_exprs 0 idx_expr
+ List.fold_left ~init:0 ~f:count_single_idx indexed
| Promotion (expr, _, _) -> count_single_idx_exprs acc expr
| EAnd (lhs, rhs) ->
acc + count_single_idx_exprs 0 lhs + count_single_idx_exprs 0 rhs
| EOr (lhs, rhs) ->
Expand All @@ -72,9 +73,9 @@ let rec count_single_idx_exprs (acc : int) Expr.Fixed.{pattern; _} : int =
(**
* Check an Index to count how many times we see a single index.
* @param acc An accumulator from previous folds of multiple expressions.
* @param idx An Index to match. For Single types this adds 1 to the
* acc. For Upfrom and MultiIndex types we check the inner expression
* for a Single index. All and Between cannot be Single cell access
* @param idx An Index to match. For Single types this adds 1 to the
* acc. For Upfrom and MultiIndex types we check the inner expression
* for a Single index. All and Between cannot be Single cell access
* and so pass acc along.
*)
and count_single_idx (acc : int) (idx : Expr.Typed.Meta.t Expr.Fixed.t Index.t)
Expand All @@ -84,13 +85,13 @@ and count_single_idx (acc : int) (idx : Expr.Typed.Meta.t Expr.Fixed.t Index.t)
| Single _ -> acc + 1

(**
* Find indices on Matrix and Vector types that perform single
* Find indices on Matrix and Vector types that perform single
* cell access. Returns true if it finds
* a vector, row vector, matrix, or matrix with single cell access
* as well as an array of any of the above that is accessing the
* as well as an array of any of the above that is accessing the
* inner matrix types cell.
* @param ut An UnsizedType to match against.
* @param index This list is checked for Single cell access
* @param index This list is checked for Single cell access
* either at the top level or within the `Index` types of the list.
*)
let rec is_uni_eigen_loop_indexing in_loop (ut : UnsizedType.t)
Expand Down Expand Up @@ -122,7 +123,7 @@ let is_fun_soa_supported name exprs =
* see the docs for `query_initial_demotable_funs`.
* @param in_loop a boolean to signify if the expression exists inside
* of a loop. If so, the names of matrix and vector like objects
* will be returned if the matrix or vector is accessed by single
* will be returned if the matrix or vector is accessed by single
* cell indexing.
*)
let rec query_initial_demotable_expr (in_loop : bool) ~(acc : string Set.Poly.t)
Expand All @@ -147,6 +148,7 @@ let rec query_initial_demotable_expr (in_loop : bool) ~(acc : string Set.Poly.t)
Set.Poly.union acc index_demotes
| Var (_ : string) | Lit ((_ : Expr.Fixed.Pattern.litType), (_ : string)) ->
acc
| Promotion (expr, _, _) -> query_expr acc expr
| TernaryIf (predicate, texpr, fexpr) ->
let predicate_demotes = query_expr acc predicate in
Set.Poly.union
Expand All @@ -159,18 +161,18 @@ let rec query_initial_demotable_expr (in_loop : bool) ~(acc : string Set.Poly.t)
Set.Poly.union (query_expr full_lhs_rhs lhs) (query_expr full_lhs_rhs rhs)

(**
* Query a function to detect if it or any of its used
* Query a function to detect if it or any of its used
* expression's objects or expressions should be demoted to AoS.
*
* The logic here demotes the expressions in a function to AoS if
* the function's inner expression returns has a meta type containing a matrix
* the function's inner expression returns has a meta type containing a matrix
* and either of :
* (1) The function is user defined and the UDFs inputs are matrices.
* (2) The Stan math function cannot support AoS
* @param in_loop A boolean to specify the logic of indexing expressions. See
* `query_initial_demotable_expr` for an explanation of the logic.
* @param kind The function type, for StanLib functions we check if the
* function supports SoA and for UserDefined functions we always fail
* @param kind The function type, for StanLib functions we check if the
* function supports SoA and for UserDefined functions we always fail
* and return back all of the names of the objects passed in expressions
* to the UDF.
* exprs The expression list passed to the functions.
Expand Down Expand Up @@ -212,7 +214,8 @@ let rec is_any_soa_supported_expr
match pattern with
| FunApp (kind, (exprs : Expr.Typed.Meta.t Expr.Fixed.t list)) ->
is_any_soa_supported_fun_expr kind exprs
| Indexed (expr, (_ : Typed.Meta.t Fixed.t Index.t list)) ->
| Indexed (expr, (_ : Typed.Meta.t Fixed.t Index.t list))
|Promotion (expr, _, _) ->
is_any_soa_supported_expr expr
| Var (_ : string) | Lit ((_ : Expr.Fixed.Pattern.litType), (_ : string)) ->
true
Expand Down Expand Up @@ -248,7 +251,8 @@ let rec is_any_ad_real_data_matrix_expr
match pattern with
| FunApp (kind, (exprs : Expr.Typed.Meta.t Expr.Fixed.t list)) ->
is_any_ad_real_data_matrix_expr_fun kind exprs
| Indexed (expr, _) -> is_any_ad_real_data_matrix_expr expr
| Indexed (expr, _) | Promotion (expr, _, _) ->
is_any_ad_real_data_matrix_expr expr
| Var (_ : string) | Lit ((_ : Expr.Fixed.Pattern.litType), (_ : string)) ->
false
| TernaryIf (_, texpr, fexpr) ->
Expand Down Expand Up @@ -303,15 +307,15 @@ and is_any_ad_real_data_matrix_expr_fun (kind : 'a Fun_kind.t)

(**
* Query to find the initial set of objects in statements that cannot be SoA.
* This is mostly recursive over expressions and statements, with the exception of
* This is mostly recursive over expressions and statements, with the exception of
* functions and Assignments.
*
* For assignments:
* We demote the LHS variable if any of the following are true:
* 1. None of the RHS's functions are able to accept SoA matrices
* 1. None of the RHS's functions are able to accept SoA matrices
* and the rhs is not an internal compiler function.
* 2. A single cell of the LHS is being assigned within a loop.
* 3. The top level expression on the RHS is a combination of only
* 3. The top level expression on the RHS is a combination of only
* data matrices and scalar types. Operations on data matrix and
* scalar values in Stan math will return a AoS matrix. We currently
* have no way to tell Stan math to return a SoA matrix.
Expand Down Expand Up @@ -408,11 +412,11 @@ let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t)
(** Look through a statement to see whether the objects used in it need to be
* modified from SoA to AoS. Returns the set of object names that need demoted
* in a statement, if any.
* This function looks at Assignment statements, and returns back the
* This function looks at Assignment statements, and returns back the
* set of top level object names given:
* 1. If the name of the lhs assignee is in the `aos_exits`, all the names
* 1. If the name of the lhs assignee is in the `aos_exits`, all the names
* of the expressions with a type containing a matrix are returned.
* 2. If the names of the rhs objects containing matrix types are in the subset of
* 2. If the names of the rhs objects containing matrix types are in the subset of
* aos_exits.
* @param aos_exits A set of variables that can be demoted.
* @param pattern The Stmt pattern to query.
Expand All @@ -437,15 +441,15 @@ let query_demotable_stmt (aos_exits : string Set.Poly.t)

(**
* Modify a function and it's subexpressions from SoA <-> AoS and vice versa.
* This performs demotion for sub expressions recursively. The top level
* expression and it's sub expressions are demoted to SoA if
* 1. The names of the variables in the subexpressions returning
* This performs demotion for sub expressions recursively. The top level
* expression and it's sub expressions are demoted to SoA if
* 1. The names of the variables in the subexpressions returning
* objects holding matrices are all in the modifiable set.
* 2. The function does not support SoA
* 3. The `force` argument is `true`
* @param force_demotion If true, forces an expression and it's sub-expressions
* @param force_demotion If true, forces an expression and it's sub-expressions
* to be AoS.
* @param modifiable_set The set of names that are either demotable
* @param modifiable_set The set of names that are either demotable
* to AoS or promotable to SoA.
* @param kind A `Fun_kind.t`
* @param exprs A list of expressions going into the function.
Expand Down Expand Up @@ -474,15 +478,15 @@ let rec modify_kind ?force_demotion:(force = false)
( kind
, List.map ~f:(modify_expr ~force_demotion:force modifiable_set) exprs )

(**
(**
* Modify an expression and it's subexpressions from SoA <-> AoS
* and vice versa. The only real paths in the below is on the
* functions and ternary expressions.
*
* The logic for functions is defined in `modify_kind`.
* `TernaryIf` is forcefully demoted to AoS if the type of the expression
* contains a matrix.
* @param force_demotion If true, forces an expression and it's sub-expressions
* @param force_demotion If true, forces an expression and it's sub-expressions
* to be AoS.
* @param modifiable_set The name of the variables whose
* associated expressions we want to modify.
Expand Down Expand Up @@ -519,11 +523,13 @@ and modify_expr_pattern ?force_demotion:(force = false)
, List.map ~f:(Index.map (mod_expr ~force_demotion:force)) indexed )
| EAnd (lhs, rhs) -> EAnd (mod_expr lhs, mod_expr rhs)
| EOr (lhs, rhs) -> EOr (mod_expr lhs, mod_expr rhs)
| Promotion (expr, type_, ad_level) ->
Promotion (mod_expr expr, type_, ad_level)
| Var (_ : string) | Lit ((_ : Expr.Fixed.Pattern.litType), (_ : string)) ->
pattern

(**
* Given a Set of strings containing the names of objects that can be
(**
* Given a Set of strings containing the names of objects that can be
* modified from AoS <-> SoA and vice versa, modify them within the expression.
* @param mem_pattern The memory pattern to change expressions to.
* @param modifiable_set The name of the variables whose
Expand All @@ -536,15 +542,15 @@ and modify_expr ?force_demotion:(force = false)
pattern= modify_expr_pattern ~force_demotion:force modifiable_set pattern }

(**
* Modify statement patterns in the MIR from AoS <-> SoA and vice versa
* Modify statement patterns in the MIR from AoS <-> SoA and vice versa
* For `Decl` and `Assignment`'s reading in parameters, we demote to AoS
* if the `decl_id` (or assign name) is in the modifiable set and
* otherwise promote the statement to `SoA`.
* For general `Assignment` statements, we check if the assignee is in
* the demotable set. If so, we force demotion of all of the rhs expressions.
* All other statements recurse over their statements and expressions.
*
* @param pattern The statement pattern to modify
*
* @param pattern The statement pattern to modify
* @param modifiable_set The name of the variable we are searching for.
*)
let rec modify_stmt_pattern
Expand Down Expand Up @@ -624,11 +630,11 @@ let rec modify_stmt_pattern
| Skip | Break | Continue | Decl _ -> pattern

(**
* Modify statement patterns in the MIR from AoS <-> SoA and vice versa
* @param mem_pattern A mem_pattern to modify expressions to. For the
* given memory pattern, this modifies
* Modify statement patterns in the MIR from AoS <-> SoA and vice versa
* @param mem_pattern A mem_pattern to modify expressions to. For the
* given memory pattern, this modifies
* statement patterns and expressions to it.
* @param stmt The statement to modify.
* @param stmt The statement to modify.
* @param modifiable_set The name of the variable we are searching for.
*)
and modify_stmt (Stmt.Fixed.{pattern; _} as stmt)
Expand Down
6 changes: 6 additions & 0 deletions src/analysis_and_optimization/Mir_utils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ let rec expr_var_set Expr.Fixed.{pattern; meta} =
| TernaryIf (expr1, expr2, expr3) -> union_recur [expr1; expr2; expr3]
| Indexed (expr, ix) ->
Set.Poly.union_list (expr_var_set expr :: List.map ix ~f:index_var_set)
| Promotion (expr, _, _) -> expr_var_set expr
| EAnd (expr1, expr2) | EOr (expr1, expr2) -> union_recur [expr1; expr2]

and index_var_set ix =
Expand Down Expand Up @@ -369,6 +370,7 @@ let rec expr_depth Expr.Fixed.{pattern; _} =
+ max (expr_depth e)
(Option.value ~default:0
(List.max_elt ~compare:compare_int (List.map ~f:idx_depth l)) )
| Promotion (expr, _, _) -> 1 + expr_depth expr
| EAnd (e1, e2) | EOr (e1, e2) ->
1
+ Option.value ~default:0
Expand Down Expand Up @@ -413,6 +415,10 @@ let rec update_expr_ad_levels autodiffable_variables
let e1 = update_expr_ad_levels autodiffable_variables e1 in
let e2 = update_expr_ad_levels autodiffable_variables e2 in
{pattern= EOr (e1, e2); meta= {e.meta with adlevel= ad_level_sup [e1; e2]}}
| Promotion (expr, ut, ad) ->
let expr' = update_expr_ad_levels autodiffable_variables expr in
{ pattern= Promotion (expr', ut, ad)
; meta= {e.meta with adlevel= ad_level_sup [expr']} }
| Indexed (ixed, i_list) ->
let ixed = update_expr_ad_levels autodiffable_variables ixed in
let i_list =
Expand Down
24 changes: 13 additions & 11 deletions src/analysis_and_optimization/Monotone_framework.ml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ let print_mfp to_string (mfp : (int, 'a entry_exit) Map.Poly.t)
let rec free_vars_expr (e : Expr.Typed.t) =
match e.pattern with
| Var x -> Set.Poly.singleton x
| Promotion (expr, _, _) -> free_vars_expr expr
| Lit (_, _) -> Set.Poly.empty
| FunApp (kind, l) -> free_vars_fnapp kind l
| TernaryIf (e1, e2, e3) ->
Expand Down Expand Up @@ -133,15 +134,15 @@ let reverse (type l) (module F : FLOWGRAPH with type labels = l) =
with type labels = l )

(** Modify the end nodes of a flowgraph to depend on its inits
* To force the monotone framework to run until the program never changes
* this function modifies the input `Flowgraph` so that it's end nodes
* To force the monotone framework to run until the program never changes
* this function modifies the input `Flowgraph` so that it's end nodes
* depend on it's initial nodes. The inits of the reverse flowgraph are used
* for this since we normally have both the forward and reverse flowgraphs
* for this since we normally have both the forward and reverse flowgraphs
* available.
* @tparam l Type of the label for each flowgraph, most commonly an int
* @tparam l Type of the label for each flowgraph, most commonly an int
* @param Flowgraph The flowgraph to modify
* @param RevFlowgraph The same flowgraph as `Flowgraph` but reversed.
*
* @param RevFlowgraph The same flowgraph as `Flowgraph` but reversed.
*
*)
let make_circular_flowgraph (type l)
(module Flowgraph : FLOWGRAPH with type labels = l)
Expand Down Expand Up @@ -544,6 +545,7 @@ let rec used_subexpressions_expr (e : Expr.Typed.t) =
(Expr.Typed.Set.singleton e)
( match e.pattern with
| Var _ | Lit (_, _) -> Expr.Typed.Set.empty
| Promotion (expr, _, _) -> used_subexpressions_expr expr
| FunApp (k, l) ->
Expr.Typed.Set.union_list
(List.map ~f:used_subexpressions_expr (l @ Fun_kind.collect_exprs k))
Expand Down Expand Up @@ -1011,14 +1013,14 @@ let lazy_expressions_mfp
let used_not_latest_expressions_mfp = Mf4.mfp () in
(latest_expr, used_not_latest_expressions_mfp)

(** Run the minimal fixed point algorithm to deduce the smallest set of
(** Run the minimal fixed point algorithm to deduce the smallest set of
* variables that satisfy a set of conditions.
* @param Flowgraph The set of nodes to analyze
* @param flowgraph_to_mir Map of nodes to their actual values in the MIR
* @param Flowgraph The set of nodes to analyze
* @param flowgraph_to_mir Map of nodes to their actual values in the MIR
* @param initial_variables The set of variables to start in the set
* @param gen_variable Used in the transfer function to deduce variables
* @param gen_variable Used in the transfer function to deduce variables
* that should be in the set
*
*
*)
let minimal_variables_mfp
(module Circular_Fwd_Flowgraph : Monotone_framework_sigs.FLOWGRAPH
Expand Down
7 changes: 5 additions & 2 deletions src/analysis_and_optimization/Optimize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,9 @@ let rec inline_function_expression propto adt fim (Expr.Fixed.{pattern; _} as e)
match pattern with
| Var _ -> ([], [], e)
| Lit (_, _) -> ([], [], e)
| Promotion (expr, ut, ad) ->
let d, sl, expr' = inline_function_expression propto adt fim expr in
(d, sl, {e with pattern= Promotion (expr', ut, ad)})
| FunApp (kind, es) -> (
let d_list, s_list, es =
inline_list (inline_function_expression propto adt fim) es in
Expand Down Expand Up @@ -1030,8 +1033,8 @@ let block_fixing mir =
(* TODO: add tests *)
(* TODO: add pass to get rid of redundant declarations? *)

(**
* A generic optimization pass for finding a minimal set of variables that
(**
* A generic optimization pass for finding a minimal set of variables that
* are generated by some circumstance, and then updating the MIR with that set.
* @param gen_variables: the variables that must be added to the set at
* the given statement
Expand Down
1 change: 1 addition & 0 deletions src/analysis_and_optimization/Partial_evaluator.ml
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ let rec eval_expr ?(preserve_stability = false) (e : Expr.Typed.t) =
pattern=
( match e.pattern with
| Var _ | Lit (_, _) -> e.pattern
| Promotion (expr, ut, ad) -> Promotion (eval_expr expr, ut, ad)
| FunApp (kind, l) -> (
let l = List.map ~f:(eval_expr ~preserve_stability) l in
match kind with
Expand Down
18 changes: 14 additions & 4 deletions src/frontend/Ast.ml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ type ('e, 'f) expression =
| ImagNumeral of string
| FunApp of 'f * identifier * 'e list
| CondDistApp of 'f * identifier * 'e list
| Promotion of 'e * UnsizedType.t * UnsizedType.autodifftype
(* GetLP is deprecated *)
| GetLP
| GetTarget
Expand Down Expand Up @@ -250,9 +251,14 @@ type typed_program = typed_statement program [@@deriving sexp, compare, map]
(** Forgetful function from typed to untyped expressions *)
let rec untyped_expression_of_typed_expression ({expr; emeta} : typed_expression)
: untyped_expression =
{ expr=
map_expression untyped_expression_of_typed_expression (fun _ -> ()) expr
; emeta= {loc= emeta.loc} }
match expr with
| Promotion (e, _, _) -> untyped_expression_of_typed_expression e
| _ ->
{ expr=
map_expression untyped_expression_of_typed_expression
(fun _ -> ())
expr
; emeta= {loc= emeta.loc} }

let rec untyped_lvalue_of_typed_lvalue ({lval; lmeta} : typed_lval) :
untyped_lval =
Expand Down Expand Up @@ -304,7 +310,11 @@ let rec id_of_lvalue {lval; _} =

let rec get_loc_expr (e : untyped_expression) =
match e.expr with
| TernaryIf (e, _, _) | BinOp (e, _, _) | PostfixOp (e, _) | Indexed (e, _) ->
| TernaryIf (e, _, _)
|BinOp (e, _, _)
|PostfixOp (e, _)
|Indexed (e, _)
|Promotion (e, _, _) ->
get_loc_expr e
| PrefixOp (_, e) | ArrayExpr (e :: _) | RowVectorExpr (e :: _) | Paren e ->
e.emeta.loc.begin_loc
Expand Down
Loading