Skip to content

Commit 8d068b4

Browse files
committed
Merge branch 'feature/soa-optim' of github.com:SteveBronder/stanc3 into feature/soa-optim
2 parents ec1ea37 + 138d740 commit 8d068b4

28 files changed

+7146
-5563
lines changed

src/analysis_and_optimization/Mem_pattern.ml

Lines changed: 90 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@ let rec matrix_set Expr.Fixed.{pattern; meta= Expr.Typed.Meta.{type_; _} as meta
1313
match pattern with
1414
| Var s -> Set.Poly.singleton (Dataflow_types.VVar s, meta)
1515
| Lit _ -> Set.Poly.empty
16-
| FunApp (_, exprs) -> union_recur exprs
16+
| FunApp (_, exprs) ->
17+
if UnsizedType.contains_eigen_type type_ then union_recur exprs
18+
else Set.Poly.empty
1719
| TernaryIf (_, expr2, expr3) -> union_recur [expr2; expr3]
1820
| Indexed (expr, _) -> matrix_set expr
1921
| EAnd (expr1, expr2) | EOr (expr1, expr2) -> union_recur [expr1; expr2]
@@ -123,30 +125,38 @@ let is_fun_soa_supported name exprs =
123125
* will be returned if the matrix or vector is accessed by single
124126
* cell indexing.
125127
*)
126-
let rec query_initial_demotable_expr (in_loop : bool) Expr.Fixed.{pattern; _} =
127-
let query_expr = query_initial_demotable_expr in_loop in
128+
let rec query_initial_demotable_expr (in_loop : bool) ~(acc : string Set.Poly.t)
129+
Expr.Fixed.{pattern; _} : string Set.Poly.t =
130+
let query_expr (accum : string Set.Poly.t) =
131+
query_initial_demotable_expr in_loop ~acc:accum in
128132
match pattern with
129133
| FunApp (kind, (exprs : Expr.Typed.Meta.t Expr.Fixed.t list)) ->
130-
query_initial_demotable_funs in_loop kind exprs
134+
query_initial_demotable_funs in_loop acc kind exprs
131135
| Indexed ((Expr.Fixed.{meta= {type_; _}; _} as expr), indexed) ->
132136
let index_set =
133137
Set.Poly.union_list
134138
(List.map
135139
~f:
136140
(Index.apply ~default:Set.Poly.empty ~merge:Set.Poly.union
137-
query_expr )
141+
(query_expr acc) )
138142
indexed ) in
139-
if is_uni_eigen_loop_indexing in_loop type_ indexed then
140-
Set.Poly.union (query_var_eigen_names expr) index_set
141-
else Set.Poly.union (query_expr expr) index_set
143+
let index_demotes =
144+
if is_uni_eigen_loop_indexing in_loop type_ indexed then
145+
Set.Poly.union (query_var_eigen_names expr) index_set
146+
else Set.Poly.union (query_expr acc expr) index_set in
147+
Set.Poly.union acc index_demotes
142148
| Var (_ : string) | Lit ((_ : Expr.Fixed.Pattern.litType), (_ : string)) ->
143-
Set.Poly.empty
149+
acc
144150
| TernaryIf (predicate, texpr, fexpr) ->
151+
let predicate_demotes = query_expr acc predicate in
145152
Set.Poly.union
146-
(Set.Poly.union (query_expr predicate) (query_var_eigen_names texpr))
153+
(Set.Poly.union predicate_demotes (query_var_eigen_names texpr))
147154
(query_var_eigen_names fexpr)
148155
| EAnd (lhs, rhs) | EOr (lhs, rhs) ->
149-
Set.Poly.union (query_expr lhs) (query_expr rhs)
156+
(*We need to get the demotes from both sides*)
157+
let full_lhs_rhs =
158+
Set.Poly.union (query_expr acc lhs) (query_expr acc rhs) in
159+
Set.Poly.union (query_expr full_lhs_rhs lhs) (query_expr full_lhs_rhs rhs)
150160

151161
(**
152162
* Query a function to detect if it or any of its used
@@ -165,25 +175,28 @@ let rec query_initial_demotable_expr (in_loop : bool) Expr.Fixed.{pattern; _} =
165175
* to the UDF.
166176
* exprs The expression list passed to the functions.
167177
*)
168-
and query_initial_demotable_funs (in_loop : bool) (kind : 'a Fun_kind.t)
169-
(exprs : Typed.Meta.t Expr.Fixed.t list) : string Set.Poly.t =
170-
let query_expr = query_initial_demotable_expr in_loop in
178+
and query_initial_demotable_funs (in_loop : bool) (acc : string Set.Poly.t)
179+
(kind : 'a Fun_kind.t) (exprs : Typed.Meta.t Expr.Fixed.t list) :
180+
string Set.Poly.t =
181+
let query_expr accum = query_initial_demotable_expr in_loop ~acc:accum in
171182
let top_level_eigen_names =
172183
Set.Poly.union_list (List.map ~f:query_var_eigen_names exprs) in
173-
let demoted_eigen_names = Set.Poly.union_list (List.map ~f:query_expr exprs) in
184+
let demoted_eigen_names = List.fold ~init:acc ~f:query_expr exprs in
185+
let demoted_and_top_level_names =
186+
Set.Poly.union demoted_eigen_names top_level_eigen_names in
174187
match kind with
175188
| Fun_kind.StanLib (name, (_ : bool Fun_kind.suffix), _) -> (
176189
match name with
177-
| "check_matching_dims" -> Set.Poly.empty
190+
| "check_matching_dims" -> acc
178191
| name -> (
179192
match is_fun_soa_supported name exprs with
180-
| true -> demoted_eigen_names
181-
| false -> Set.Poly.union demoted_eigen_names top_level_eigen_names ) )
193+
| true -> Set.Poly.union acc demoted_eigen_names
194+
| false -> Set.Poly.union acc demoted_and_top_level_names ) )
182195
| CompilerInternal (Internal_fun.FnMakeArray | FnMakeRowVec) ->
183-
Set.Poly.union demoted_eigen_names top_level_eigen_names
184-
| CompilerInternal (_ : 'a Internal_fun.t) -> Set.Poly.empty
196+
Set.Poly.union acc demoted_and_top_level_names
197+
| CompilerInternal (_ : 'a Internal_fun.t) -> acc
185198
| UserDefined ((_ : string), (_ : bool Fun_kind.suffix)) ->
186-
Set.Poly.union demoted_eigen_names top_level_eigen_names
199+
Set.Poly.union acc demoted_and_top_level_names
187200

188201
(**
189202
* Check whether any functions in the right hand side expression of an assignment
@@ -295,7 +308,8 @@ and is_any_ad_real_data_matrix_expr_fun (kind : 'a Fun_kind.t)
295308
*
296309
* For assignments:
297310
* We demote the LHS variable if any of the following are true:
298-
* 1. None of the RHS's functions are able to accept SoA matrices
311+
* 1. None of the RHS's functions are able to accept SoA matrices
312+
* and the RHS is not an internal compiler function.
299313
* 2. A single cell of the LHS is being assigned within a loop.
300314
* 3. The top level expression on the RHS is a combination of only
301315
* data matrices and scalar types. Operations on data matrix and
@@ -311,66 +325,85 @@ and is_any_ad_real_data_matrix_expr_fun (kind : 'a Fun_kind.t)
311325
* @param in_loop A boolean to specify the logic of indexing expressions. See
312326
* `query_initial_demotable_expr` for an explanation of the logic.
313327
*)
314-
let rec query_initial_demotable_stmt (in_loop : bool)
328+
let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t)
315329
(Stmt.Fixed.{pattern; _} :
316330
(Expr.Typed.Meta.t, Stmt.Located.Meta.t) Stmt.Fixed.t ) :
317331
string Set.Poly.t =
318-
let query_expr = query_initial_demotable_expr in_loop in
332+
let query_expr (accum : string Set.Poly.t) =
333+
query_initial_demotable_expr in_loop ~acc:accum in
319334
match pattern with
320335
| Stmt.Fixed.Pattern.Assignment
321336
( ((name : string), (ut : UnsizedType.t), idx)
322337
, (Expr.Fixed.{meta= Expr.Typed.Meta.{type_; adlevel; _}; _} as rhs) ) ->
338+
let idx_list =
339+
List.fold ~init:acc
340+
~f:(fun accum x ->
341+
Index.folder accum
342+
(fun acc -> query_initial_demotable_expr in_loop ~acc)
343+
x )
344+
idx in
323345
let idx_demotable =
324-
let idx_list =
325-
Set.Poly.union_list
326-
(List.map
327-
~f:
328-
(Index.apply ~default:Set.Poly.empty ~merge:Set.Poly.union
329-
query_expr )
330-
idx ) in
331346
(* RHS (2)*)
332347
match is_uni_eigen_loop_indexing in_loop ut idx with
333348
| true -> Set.Poly.add idx_list name
334349
| false -> idx_list in
335-
let rhs_demotable = query_expr rhs in
350+
let rhs_demotable_names = query_expr acc rhs in
336351
(* RHS (3)*)
337-
let check_bad_assign =
352+
let check_if_rhs_ad_real_data_matrix_expr =
338353
match (UnsizedType.contains_eigen_type type_, adlevel) with
339354
| true, UnsizedType.AutoDiffable ->
340355
is_any_ad_real_data_matrix_expr rhs
341356
|| not (is_any_soa_supported_expr rhs)
342357
| _ -> false in
343358
(* RHS (1)*)
344359
let is_all_rhs_aos =
345-
let all_eigen_names = query_var_eigen_names rhs in
346-
is_nonzero_subset ~set:all_eigen_names ~subset:rhs_demotable in
347-
if is_all_rhs_aos || check_bad_assign then
348-
let base_set = Set.Poly.union idx_demotable rhs_demotable in
349-
Set.Poly.add (Set.Poly.union base_set (query_var_eigen_names rhs)) name
350-
else Set.Poly.union idx_demotable rhs_demotable
351-
| NRFunApp (kind, exprs) -> query_initial_demotable_funs in_loop kind exprs
360+
let all_rhs_eigen_names = query_var_eigen_names rhs in
361+
is_nonzero_subset ~subset:all_rhs_eigen_names ~set:rhs_demotable_names
362+
in
363+
let is_not_supported_func =
364+
match rhs.pattern with
365+
| FunApp (CompilerInternal _, _) -> false
366+
| FunApp (UserDefined _, _) -> true
367+
| _ -> false in
368+
let is_eigen_stmt = UnsizedType.contains_eigen_type rhs.meta.type_ in
369+
let assign_demotes =
370+
if
371+
is_eigen_stmt
372+
&& ( is_all_rhs_aos || check_if_rhs_ad_real_data_matrix_expr
373+
|| is_not_supported_func )
374+
then
375+
let base_set = Set.Poly.union idx_demotable rhs_demotable_names in
376+
Set.Poly.add
377+
(Set.Poly.union base_set (query_var_eigen_names rhs))
378+
name
379+
else Set.Poly.union idx_demotable rhs_demotable_names in
380+
Set.Poly.union acc assign_demotes
381+
| NRFunApp (kind, exprs) ->
382+
query_initial_demotable_funs in_loop acc kind exprs
352383
| IfElse (predicate, true_stmt, op_false_stmt) ->
353-
let demotable_rhs =
354-
Option.value_map
355-
~f:(query_initial_demotable_stmt in_loop)
356-
~default:Set.Poly.empty op_false_stmt in
357-
Set.Poly.union_list
358-
[ query_expr predicate; query_initial_demotable_stmt in_loop true_stmt
359-
; demotable_rhs ]
384+
let predicate_acc = query_expr acc predicate in
385+
Set.Poly.union acc
386+
(Set.Poly.union_list
387+
[ predicate_acc
388+
; query_initial_demotable_stmt in_loop predicate_acc true_stmt
389+
; Option.value_map
390+
~f:(query_initial_demotable_stmt in_loop predicate_acc)
391+
~default:Set.Poly.empty op_false_stmt ] )
360392
| Return optional_expr ->
361-
Option.value_map ~f:query_expr ~default:Set.Poly.empty optional_expr
393+
Option.value_map ~f:(query_expr acc) ~default:Set.Poly.empty optional_expr
362394
| SList lst | Profile (_, lst) | Block lst ->
363395
Set.Poly.union_list
364-
(List.map ~f:(query_initial_demotable_stmt in_loop) lst)
365-
| TargetPE expr -> query_expr expr
396+
(List.map ~f:(query_initial_demotable_stmt in_loop acc) lst)
397+
| TargetPE expr -> query_expr acc expr
366398
| For {lower; upper; body; _} ->
367399
Set.Poly.union
368-
(Set.Poly.union (query_expr lower) (query_expr upper))
369-
(query_initial_demotable_stmt true body)
400+
(Set.Poly.union (query_expr acc lower) (query_expr acc upper))
401+
(query_initial_demotable_stmt true acc body)
370402
| While (predicate, body) ->
371-
Set.Poly.union (query_expr predicate)
372-
(query_initial_demotable_stmt true body)
373-
| Skip | Break | Continue | Decl _ -> Set.Poly.empty
403+
Set.Poly.union_list
404+
[ acc; query_expr acc predicate
405+
; query_initial_demotable_stmt true acc body ]
406+
| Skip | Break | Continue | Decl _ -> acc
374407

375408
(** Look through a statement to see whether the objects used in it need to be
376409
* modified from SoA to AoS. Returns the set of object names that need to be demoted
@@ -402,14 +435,6 @@ let query_demotable_stmt (aos_exits : string Set.Poly.t)
402435
(* All other statements do not need logic here*)
403436
| _ -> Set.Poly.empty
404437

405-
(**
406-
* Search through an expression for the names of all types that hold matrices
407-
* and vectors.
408-
**)
409-
let query_eigen_names (expr : Typed.Meta.t Expr.Fixed.t) : string Set.Poly.t =
410-
let get_expr_names (Dataflow_types.VVar s, _) = Some s in
411-
Set.Poly.filter_map ~f:get_expr_names (matrix_set expr)
412-
413438
(**
414439
* Modify a function and its subexpressions from SoA to AoS and vice versa.
415440
* This performs demotion for sub expressions recursively. The top level
@@ -428,7 +453,8 @@ let query_eigen_names (expr : Typed.Meta.t Expr.Fixed.t) : string Set.Poly.t =
428453
let rec modify_kind ?force_demotion:(force = false)
429454
(modifiable_set : string Set.Poly.t) (kind : 'a Fun_kind.t)
430455
(exprs : Expr.Typed.Meta.t Expr.Fixed.t list) =
431-
let expr_names = Set.Poly.union_list (List.map ~f:query_eigen_names exprs) in
456+
let expr_names =
457+
Set.Poly.union_list (List.map ~f:query_var_eigen_names exprs) in
432458
let is_all_in_list =
433459
is_nonzero_subset ~set:modifiable_set ~subset:expr_names in
434460
match kind with

src/analysis_and_optimization/Optimize.ml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1153,10 +1153,9 @@ let optimize_soa (mir : Program.Typed.t) =
11531153
match (mir_node l).pattern with
11541154
| stmt -> Mem_pattern.query_demotable_stmt aos_variables stmt in
11551155
let initial_variables =
1156-
Set.Poly.union_list
1157-
(List.map
1158-
~f:(Mem_pattern.query_initial_demotable_stmt false)
1159-
mir.log_prob ) in
1156+
List.fold ~init:Set.Poly.empty
1157+
~f:(Mem_pattern.query_initial_demotable_stmt false)
1158+
mir.log_prob in
11601159
(*
11611160
let print_set s =
11621161
Set.Poly.iter ~f:print_endline s in

src/frontend/Canonicalize.ml

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,28 @@ let rec replace_deprecated_expr
9191
ident expr in
9292
{expr; emeta}
9393

94-
let replace_deprecated_lval deprecated_userdefined =
95-
map_lval_with (replace_deprecated_expr deprecated_userdefined) ident
94+
let replace_deprecated_lval deprecated_userdefined {lval; lmeta} =
95+
let is_multiindex = function
96+
| Single {emeta= {type_= Middle.UnsizedType.UInt; _}; _} -> false
97+
| _ -> true in
98+
let rec flatten_multi = function
99+
| LVariable id -> (LVariable id, None)
100+
| LIndexed ({lval; lmeta}, idcs) -> (
101+
let outer =
102+
List.map idcs
103+
~f:(map_index (replace_deprecated_expr deprecated_userdefined))
104+
in
105+
let unwrap = Option.value_map ~default:[] ~f:fst in
106+
match flatten_multi lval with
107+
| lval, inner when List.exists ~f:is_multiindex outer ->
108+
(lval, Some (unwrap inner @ outer, lmeta))
109+
| lval, None -> (LIndexed ({lval; lmeta}, outer), None)
110+
| lval, Some (inner, _) -> (lval, Some (inner @ outer, lmeta)) ) in
111+
let lval =
112+
match flatten_multi lval with
113+
| lval, None -> lval
114+
| lval, Some (idcs, lmeta) -> LIndexed ({lval; lmeta}, idcs) in
115+
{lval; lmeta}
96116

97117
let rec replace_deprecated_stmt
98118
(deprecated_userdefined : Middle.UnsizedType.t Core_kernel.String.Map.t)

src/frontend/Semantic_error.ml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,7 @@ module StatementError = struct
320320
type t =
321321
| CannotAssignToReadOnly of string
322322
| CannotAssignToGlobal of string
323+
| LValueMultiIndexing
323324
| InvalidSamplingPDForPMF
324325
| InvalidSamplingCDForCCDF of string
325326
| InvalidSamplingNoSuchDistribution of string
@@ -352,6 +353,9 @@ module StatementError = struct
352353
Fmt.pf ppf
353354
"Cannot assign to global variable '%s' declared in previous blocks."
354355
name
356+
| LValueMultiIndexing ->
357+
Fmt.pf ppf
358+
"Left hand side of an assignment cannot have nested multi-indexing."
355359
| TargetPlusEqualsOutsideModelOrLogProb ->
356360
Fmt.pf ppf
357361
"Target can only be accessed in the model block or in definitions of \
@@ -604,6 +608,9 @@ let cannot_assign_to_read_only loc name =
604608
let cannot_assign_to_global loc name =
605609
StatementError (loc, StatementError.CannotAssignToGlobal name)
606610

611+
let cannot_assign_to_multiindex loc =
612+
StatementError (loc, StatementError.LValueMultiIndexing)
613+
607614
let invalid_sampling_pdf_or_pmf loc =
608615
StatementError (loc, StatementError.InvalidSamplingPDForPMF)
609616

src/frontend/Semantic_error.mli

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ val empty_array : Location_span.t -> t
103103
val bad_int_literal : Location_span.t -> t
104104
val cannot_assign_to_read_only : Location_span.t -> string -> t
105105
val cannot_assign_to_global : Location_span.t -> string -> t
106+
val cannot_assign_to_multiindex : Location_span.t -> t
106107
val invalid_sampling_pdf_or_pmf : Location_span.t -> t
107108
val invalid_sampling_cdf_or_ccdf : Location_span.t -> string -> t
108109
val invalid_sampling_no_such_dist : Location_span.t -> string -> t

0 commit comments

Comments
 (0)