@@ -400,7 +400,10 @@ let rec inline_function_statement propto adt fim Stmt.Fixed.{pattern; meta} =
400400 ( [inline_function_statement propto adt fim body]
401401 @ map_no_loc s_upper )
402402 ; meta= Location_span. empty } ) } )
403- | Profile (_ , l ) | Block l ->
403+ | Profile (name , l ) ->
404+ Profile
405+ (name, List. map l ~f: (inline_function_statement propto adt fim))
406+ | Block l ->
404407 Block (List. map l ~f: (inline_function_statement propto adt fim))
405408 | SList l ->
406409 SList (List. map l ~f: (inline_function_statement propto adt fim))
@@ -535,7 +538,7 @@ let unroll_loop_one_step_statement _ =
535538 ( Expr.Fixed.
536539 { lower with
537540 pattern=
538- FunApp (StanLib (" Geq__" , FnPlain , SoA ), [upper; lower]) }
541+ FunApp (StanLib (" Geq__" , FnPlain , AoS ), [upper; lower]) }
539542 , { pattern=
540543 (let body_unrolled =
541544 subst_args_stmt [loopvar] [lower]
@@ -550,7 +553,7 @@ let unroll_loop_one_step_statement _ =
550553 { lower with
551554 pattern=
552555 FunApp
553- ( StanLib (" Plus__" , FnPlain , SoA )
556+ ( StanLib (" Plus__" , FnPlain , AoS )
554557 , [lower; Expr.Helpers. loop_bottom] ) } }
555558 ; meta= Location_span. empty } in
556559 match body_unrolled.pattern with
@@ -722,7 +725,10 @@ let dead_code_elimination (mir : Program.Typed.t) =
722725 && is_skip_break_continue body.pattern
723726 then Skip
724727 else For {loopvar; lower; upper; body}
725- | Profile (_ , l ) | Block l ->
728+ | Profile (name , l ) ->
729+ let l' = List. filter ~f: (fun x -> x.pattern <> Skip ) l in
730+ if List. length l' = 0 then Skip else Profile (name, l')
731+ | Block l ->
726732 let l' = List. filter ~f: (fun x -> x.pattern <> Skip ) l in
727733 if List. length l' = 0 then Skip else Block l'
728734 | SList l ->
@@ -1118,6 +1124,64 @@ let optimize_ad_levels (mir : Program.Typed.t) =
11181124 stmt in
11191125 transform_program_blockwise mir transform
11201126
1127+ (* *
1128+ * Deduces whether types can be Structures of Arrays (SoA/fast) or
1129+ * Arrays of Structs (AoS/slow). See the docs in
1130+ * Mem_pattern.query_demotable_stmt/exprs* functions for
1131+ * details on the rules surrounding when demotion from
1132+ * SoA -> AoS needs to happen.
1133+ *
1134+ * This first does a simple iter over
1135+ * the log_prob portion of the MIR, finding the names of all matrices
1136+ * (and arrays of matrices) where either the Stan math function
1137+ * does not support SoA or the object is single cell accessed within a
1138+ * For or While loop. These are the initial variables
1139+ * given to the monotone framework. Then log_prob has all matrix-like objects
1140+ * and the functions that use them set to SoA. After that the
1141+ * monotone framework is used to deduce assignment paths between AoS and SoA
1142+ * which need to be demoted to AoS, as well as updating
1143+ * functions and objects after these assignment passes that then
1144+ * also need to be AoS.
1145+ *
1146+ * @param mir: The program's whole MIR. Only its log_prob field is rewritten.
1147+ *)
1148+ let optimize_soa (mir : Program.Typed.t ) =
1149+ let gen_aos_variables (* per-node step handed to optimize_minimal_variables *)
1150+ (flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t )
1151+ (l : int ) (aos_variables : string Set.Poly.t ) =
1152+ let mir_node mir_idx = Map. find_exn flowgraph_to_mir mir_idx in
1153+ match (mir_node l).pattern with
1154+ | stmt -> Mem_pattern. query_demotable_stmt aos_variables stmt in
1155+ let initial_variables = (* folded over every log_prob statement, starting from the empty set *)
1156+ List. fold ~init: Set.Poly. empty
1157+ ~f: (Mem_pattern. query_initial_demotable_stmt false )
1158+ mir.log_prob in
1159+ (* Debug aid (disabled): prints every name in initial_variables, one per line.
1160+ let print_set s =
1161+ Set.Poly.iter ~f:print_endline s in
1162+ let () = print_set initial_variables in
1163+ *)
1164+ let mod_exprs aos_exits mod_expr =
1165+ Mir_utils. map_rec_expr (Mem_pattern. modify_expr_pattern aos_exits) mod_expr
1166+ in
1167+ let modify_stmt_patt stmt_pattern variable_set =
1168+ Mem_pattern. modify_stmt_pattern stmt_pattern variable_set in
1169+ let transform stmt =
1170+ optimize_minimal_variables ~gen_variables: gen_aos_variables
1171+ ~update_expr: mod_exprs ~update_stmt: modify_stmt_patt ~initial_variables
1172+ stmt ~extra_variables: (fun _ -> initial_variables) in
1173+ let transform' s = (* wraps the statement list in a single SList, transforms it, then unwraps *)
1174+ match transform {pattern= SList s; meta= Location_span. empty} with
1175+ | { pattern=
1176+ SList (l : (Expr.Typed.Meta.t, Stmt.Located.Meta.t ) Stmt.Fixed. t list )
1177+ ; _ } ->
1178+ l
1179+ | _ -> (* transform handed back something other than an SList *)
1180+ raise
1181+ (Failure " Something went wrong with program transformation packing!" )
1182+ in
1183+ {mir with log_prob= transform' mir.log_prob}
1184+
11211185(* Apparently you need to completely copy/paste type definitions between
11221186 ml and mli files?*)
11231187type optimization_settings =
@@ -1134,7 +1198,8 @@ type optimization_settings =
11341198 ; partial_evaluation : bool
11351199 ; lazy_code_motion : bool
11361200 ; optimize_ad_levels : bool
1137- ; preserve_stability : bool }
1201+ ; preserve_stability : bool
1202+ ; optimize_soa : bool }
11381203
11391204let settings_const b =
11401205 { function_inlining= b
@@ -1150,7 +1215,8 @@ let settings_const b =
11501215 ; partial_evaluation= b
11511216 ; lazy_code_motion= b
11521217 ; optimize_ad_levels= b
1153- ; preserve_stability= not b }
1218+ ; preserve_stability= not b
1219+ ; optimize_soa= b }
11541220
11551221let all_optimizations : optimization_settings = settings_const true
11561222let no_optimizations : optimization_settings = settings_const false
@@ -1159,7 +1225,7 @@ type optimization_level = O0 | O1 | Oexperimental
11591225
11601226let level_optimizations (lvl : optimization_level ) : optimization_settings =
11611227 match lvl with
1162- | O0 -> { no_optimizations with allow_uninitialized_decls = false }
1228+ | O0 -> no_optimizations
11631229 | O1 ->
11641230 { function_inlining= false
11651231 ; static_loop_unrolling= false
@@ -1174,7 +1240,8 @@ let level_optimizations (lvl : optimization_level) : optimization_settings =
11741240 ; lazy_code_motion= false
11751241 ; allow_uninitialized_decls= false
11761242 ; optimize_ad_levels= true
1177- ; preserve_stability= false }
1243+ ; preserve_stability= false
1244+ ; optimize_soa= true }
11781245 | Oexperimental -> all_optimizations
11791246
11801247let optimization_suite ?(settings = all_optimizations) mir =
@@ -1220,7 +1287,8 @@ let optimization_suite ?(settings = all_optimizations) mir =
12201287 ; (optimize_ad_levels, settings.optimize_ad_levels)
12211288 (* Book: Machine idioms and instruction combining *)
12221289 (* Matthijs: Everything < block_fixing *)
1223- ; (block_fixing, settings.block_fixing) ] in
1290+ ; (block_fixing, settings.block_fixing)
1291+ ; (optimize_soa, settings.optimize_soa) ] in
12241292 let optimizations =
12251293 List. filter_map maybe_optimizations ~f: (fun (fn , flag ) ->
12261294 if flag then Some fn else None ) in
0 commit comments