Skip to content

Commit 282439e

Browse files
authored
Merge pull request #856 from rybern/constraint-refactor-2
Call new deserializer backend for constrained reads
2 parents 036d930 + b8e666e commit 282439e

35 files changed

+7686
-16118
lines changed

Jenkinsfile

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ import org.stan.Utils
33

44
def utils = new org.stan.Utils()
55
def skipExpressionTests = false
6-
76
/* Functions that runs a sh command and returns the stdout */
87
def runShell(String command){
98
def output = sh (returnStdout: true, script: "${command}").trim()

src/analysis_and_optimization/Factor_graph.ml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@ let extract_factors_statement stmt =
2222
match stmt with
2323
| Stmt.Fixed.Pattern.TargetPE e ->
2424
List.map (summation_terms e) ~f:(fun x -> TargetTerm x)
25-
| NRFunApp (_, f, _) when Internal_fun.of_string_opt f = Some FnReject ->
26-
[Reject]
27-
| NRFunApp (_, s, args) when String.suffix s 3 = "_lp" ->
25+
| NRFunApp (CompilerInternal FnReject, _) -> [Reject]
26+
| NRFunApp ((UserDefined s | StanLib s), args) when String.suffix s 3 = "_lp"
27+
->
2828
[LPFunction (s, args)]
2929
| Assignment (_, _)
30-
|NRFunApp (_, _, _)
30+
|NRFunApp (_, _)
3131
|Break | Continue | Return _ | Skip
3232
|IfElse (_, _, _)
3333
|While (_, _)

src/analysis_and_optimization/Mir_utils.ml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ let rec num_expr_value (v : Expr.Typed.t) : (float * string) option =
3030
| {pattern= Fixed.Pattern.Lit (Real, str); _}
3131
|{pattern= Fixed.Pattern.Lit (Int, str); _} ->
3232
Some (float_of_string str, str)
33-
| {pattern= Fixed.Pattern.FunApp (StanLib, "PMinus__", [v]); _} -> (
33+
| {pattern= Fixed.Pattern.FunApp (StanLib "PMinus__", [v]); _} -> (
3434
match num_expr_value v with
3535
| Some (v, s) -> Some (-.v, "-" ^ s)
3636
| None -> None )
@@ -252,7 +252,7 @@ let rec expr_var_set Expr.Fixed.({pattern; meta}) =
252252
match pattern with
253253
| Var s -> Set.Poly.singleton (VVar s, meta)
254254
| Lit _ -> Set.Poly.empty
255-
| FunApp (_, _, exprs) -> union_recur exprs
255+
| FunApp (_, exprs) -> union_recur exprs
256256
| TernaryIf (expr1, expr2, expr3) -> union_recur [expr1; expr2; expr3]
257257
| Indexed (expr, ix) ->
258258
Set.Poly.union_list (expr_var_set expr :: List.map ix ~f:index_var_set)
@@ -270,7 +270,7 @@ and index_var_set ix =
270270
let stmt_rhs stmt =
271271
match stmt with
272272
| Stmt.Fixed.Pattern.For vars -> Set.Poly.of_list [vars.lower; vars.upper]
273-
| NRFunApp (_, _, exprs) -> Set.Poly.of_list exprs
273+
| NRFunApp (_, exprs) -> Set.Poly.of_list exprs
274274
| IfElse (rhs, _, _)
275275
|While (rhs, _)
276276
|Assignment (_, rhs)
@@ -296,7 +296,7 @@ let expr_assigned_var Expr.Fixed.({pattern; _}) =
296296
(** See interface file *)
297297
let rec summation_terms (Expr.Fixed.({pattern; _}) as rhs) =
298298
match pattern with
299-
| FunApp (_, "Plus__", [e1; e2]) ->
299+
| FunApp (StanLib "Plus__", [e1; e2]) ->
300300
List.append (summation_terms e1) (summation_terms e2)
301301
| _ -> [rhs]
302302

@@ -356,7 +356,7 @@ let expr_subst_stmt m = map_rec_stmt_loc (expr_subst_stmt_base m)
356356
let rec expr_depth Expr.Fixed.({pattern; _}) =
357357
match pattern with
358358
| Var _ | Lit (_, _) -> 0
359-
| FunApp (_, _, l) ->
359+
| FunApp (_, l) ->
360360
1
361361
+ Option.value ~default:0
362362
(List.max_elt ~compare:compare_int (List.map ~f:expr_depth l))
@@ -394,9 +394,9 @@ let rec update_expr_ad_levels autodiffable_variables
394394
Expr.Typed.{e with meta= Meta.{e.meta with adlevel= AutoDiffable}}
395395
else {e with meta= {e.meta with adlevel= DataOnly}}
396396
| Lit (_, _) -> {e with meta= {e.meta with adlevel= DataOnly}}
397-
| FunApp (o, f, l) ->
397+
| FunApp (kind, l) ->
398398
let l = List.map ~f:(update_expr_ad_levels autodiffable_variables) l in
399-
{pattern= FunApp (o, f, l); meta= {e.meta with adlevel= ad_level_sup l}}
399+
{pattern= FunApp (kind, l); meta= {e.meta with adlevel= ad_level_sup l}}
400400
| TernaryIf (e1, e2, e3) ->
401401
let e1 = update_expr_ad_levels autodiffable_variables e1 in
402402
let e2 = update_expr_ad_levels autodiffable_variables e2 in

src/analysis_and_optimization/Monotone_framework.ml

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,7 @@ let rec free_vars_expr (e : Expr.Typed.t) =
2929
match e.pattern with
3030
| Var x -> Set.Poly.singleton x
3131
| Lit (_, _) -> Set.Poly.empty
32-
| FunApp (_, f, l) ->
33-
Set.Poly.union_list (Set.Poly.singleton f :: List.map ~f:free_vars_expr l)
32+
| FunApp (kind, l) -> free_vars_fnapp kind l
3433
| TernaryIf (e1, e2, e3) ->
3534
Set.Poly.union_list (List.map ~f:free_vars_expr [e1; e2; e3])
3635
| Indexed (e, l) ->
@@ -45,6 +44,13 @@ and free_vars_idx (i : Expr.Typed.t Index.t) =
4544
| Single e | Upfrom e | MultiIndex e -> free_vars_expr e
4645
| Between (e1, e2) -> Set.Poly.union (free_vars_expr e1) (free_vars_expr e2)
4746

47+
and free_vars_fnapp kind l =
48+
let arg_vars = List.map ~f:free_vars_expr l in
49+
match kind with
50+
| Fun_kind.UserDefined f ->
51+
Set.Poly.union_list (Set.Poly.singleton f :: List.map ~f:free_vars_expr l)
52+
| _ -> Set.Poly.union_list arg_vars
53+
4854
(** Calculate the free (non-bound) variables in a statement *)
4955
let rec free_vars_stmt
5056
(s : (Expr.Typed.t, Stmt.Located.t) Stmt.Fixed.Pattern.t) =
@@ -53,8 +59,7 @@ let rec free_vars_stmt
5359
free_vars_expr e
5460
| Assignment ((_, _, l), e) ->
5561
Set.Poly.union_list (free_vars_expr e :: List.map ~f:free_vars_idx l)
56-
| NRFunApp (_, f, l) ->
57-
Set.Poly.union_list (Set.Poly.singleton f :: List.map ~f:free_vars_expr l)
62+
| NRFunApp (kind, l) -> free_vars_fnapp kind l
5863
| IfElse (e, b1, Some b2) ->
5964
Set.Poly.union_list
6065
[free_vars_expr e; free_vars_stmt b1.pattern; free_vars_stmt b2.pattern]
@@ -314,7 +319,7 @@ let constant_propagation_transfer
314319
| Decl {decl_id= s; _} | Assignment ((s, _, _ :: _), _) ->
315320
Map.remove m s
316321
| TargetPE _
317-
|NRFunApp (_, _, _)
322+
|NRFunApp (_, _)
318323
|Break | Continue | Return _ | Skip
319324
|IfElse (_, _, _)
320325
|While (_, _)
@@ -373,7 +378,7 @@ let expression_propagation_transfer
373378
in
374379
Set.Poly.fold kills ~init:m ~f:kill_var
375380
| TargetPE _
376-
|NRFunApp (_, _, _)
381+
|NRFunApp (_, _)
377382
|Break | Continue | Return _ | Skip
378383
|IfElse (_, _, _)
379384
|While (_, _)
@@ -414,7 +419,7 @@ let copy_propagation_transfer (globals : string Set.Poly.t)
414419
in
415420
Set.Poly.fold kills ~init:m ~f:kill_var
416421
| TargetPE _
417-
|NRFunApp (_, _, _)
422+
|NRFunApp (_, _)
418423
|Break | Continue | Return _ | Skip
419424
|IfElse (_, _, _)
420425
|While (_, _)
@@ -435,11 +440,11 @@ let assigned_vars_stmt (s : (Expr.Typed.t, 'a) Stmt.Fixed.Pattern.t) =
435440
match s with
436441
| Assignment ((x, _, _), _) -> Set.Poly.singleton x
437442
| TargetPE _ -> Set.Poly.singleton "target"
438-
| NRFunApp (_, s, _) when String.suffix s 3 = "_lp" ->
443+
| NRFunApp ((UserDefined s | StanLib s), _) when String.suffix s 3 = "_lp" ->
439444
Set.Poly.singleton "target"
440445
| For {loopvar= x; _} -> Set.Poly.singleton x
441446
| Decl {decl_id= _; _}
442-
|NRFunApp (_, _, _)
447+
|NRFunApp (_, _)
443448
|Break | Continue | Return _ | Skip
444449
|IfElse (_, _, _)
445450
|While (_, _)
@@ -478,9 +483,10 @@ let reaching_definitions_transfer
478483
|For {loopvar= x; _} ->
479484
Set.filter p ~f:(fun (y, _) -> y = x)
480485
| TargetPE _ -> Set.filter p ~f:(fun (y, _) -> y = "target")
481-
| NRFunApp (_, s, _) when String.suffix s 3 = "_lp" ->
486+
| NRFunApp ((UserDefined s | StanLib s), _)
487+
when String.suffix s 3 = "_lp" ->
482488
Set.filter p ~f:(fun (y, _) -> y = "target")
483-
| NRFunApp (_, _, _)
489+
| NRFunApp (_, _)
484490
|Break | Continue | Return _ | Skip
485491
|IfElse (_, _, _)
486492
|While (_, _)
@@ -523,7 +529,7 @@ let live_variables_transfer (never_kill : string Set.Poly.t)
523529
| Assignment ((x, _, []), _) | Decl {decl_id= x; _} ->
524530
Set.Poly.singleton x
525531
| TargetPE _
526-
|NRFunApp (_, _, _)
532+
|NRFunApp (_, _)
527533
|Break | Continue | Return _ | Skip
528534
|IfElse (_, _, _)
529535
|While (_, _)
@@ -542,7 +548,7 @@ let rec used_subexpressions_expr (e : Expr.Typed.t) =
542548
(Expr.Typed.Set.singleton e)
543549
( match e.pattern with
544550
| Var _ | Lit (_, _) -> Expr.Typed.Set.empty
545-
| FunApp (_, _, l) ->
551+
| FunApp (_, l) ->
546552
Expr.Typed.Set.union_list (List.map ~f:used_subexpressions_expr l)
547553
| TernaryIf (e1, e2, e3) ->
548554
Expr.Typed.Set.union_list
@@ -580,7 +586,7 @@ let rec used_expressions_stmt_help f
580586
[ f e
581587
; used_expressions_stmt_help f b1.pattern
582588
; used_expressions_stmt_help f b2.pattern ]
583-
| NRFunApp (_, _, l) -> Expr.Typed.Set.union_list (List.map ~f l)
589+
| NRFunApp (_, l) -> Expr.Typed.Set.union_list (List.map ~f l)
584590
| Decl _ | Return None | Break | Continue | Skip -> Expr.Typed.Set.empty
585591
| IfElse (e, b, None) | While (e, b) ->
586592
Expr.Typed.Set.union (f e) (used_expressions_stmt_help f b.pattern)
@@ -614,7 +620,7 @@ let top_used_expressions_stmt_help f
614620
(Expr.Typed.Set.union_list
615621
(List.map ~f:(used_expressions_idx_help f) l))
616622
| While (e, _) | IfElse (e, _, _) -> f e
617-
| NRFunApp (_, _, l) -> Expr.Typed.Set.union_list (List.map ~f l)
623+
| NRFunApp (_, l) -> Expr.Typed.Set.union_list (List.map ~f l)
618624
| Profile _ | Block _ | SList _ | Decl _
619625
|Return None
620626
|Break | Continue | Skip ->
@@ -899,7 +905,7 @@ let rec declared_variables_stmt
899905
| Decl {decl_id= x; _} -> Set.Poly.singleton x
900906
| Assignment (_, _)
901907
|TargetPE _
902-
|NRFunApp (_, _, _)
908+
|NRFunApp (_, _)
903909
|Break | Continue | Return _ | Skip ->
904910
Set.Poly.empty
905911
| IfElse (_, b1, Some b2) ->

src/analysis_and_optimization/Optimize.ml

Lines changed: 63 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ let rec inline_function_expression propto adt fim
219219
match pattern with
220220
| Var _ -> ([], [], e)
221221
| Lit (_, _) -> ([], [], e)
222-
| FunApp (t, s, es) -> (
222+
| FunApp (kind, es) -> (
223223
let dse_list =
224224
List.map ~f:(inline_function_expression propto adt fim) es
225225
in
@@ -231,30 +231,45 @@ let rec inline_function_expression propto adt fim
231231
List.concat (List.rev (List.map ~f:(function _, x, _ -> x) dse_list))
232232
in
233233
let es = List.map ~f:(function _, _, x -> x) dse_list in
234-
let s = if propto then s else Middle.Utils.stdlib_distribution_name s in
235-
match Map.find fim s with
236-
| None -> (d_list, s_list, {e with pattern= FunApp (t, s, es)})
237-
| Some (rt, args, b) ->
238-
let x = Gensym.generate ~prefix:"inline_" () in
239-
let handle = handle_early_returns (Some x) in
240-
let d_list2, s_list2, (e : Expr.Typed.t) =
241-
( [ Stmt.Fixed.Pattern.Decl
242-
{decl_adtype= adt; decl_id= x; decl_type= Option.value_exn rt}
243-
]
244-
(* We should minimize the code that's having its variables
234+
match kind with
235+
| CompilerInternal _ ->
236+
(d_list, s_list, {e with pattern= FunApp (kind, es)})
237+
| UserDefined fname | StanLib fname -> (
238+
let fname =
239+
if propto then fname
240+
else Middle.Utils.stdlib_distribution_name fname
241+
in
242+
match Map.find fim fname with
243+
| None ->
244+
let fun_kind =
245+
match kind with
246+
| Fun_kind.UserDefined _ -> Fun_kind.UserDefined fname
247+
| _ -> StanLib fname
248+
in
249+
(d_list, s_list, {e with pattern= FunApp (fun_kind, es)})
250+
| Some (rt, args, b) ->
251+
let x = Gensym.generate ~prefix:"inline_" () in
252+
let handle = handle_early_returns (Some x) in
253+
let d_list2, s_list2, (e : Expr.Typed.t) =
254+
( [ Stmt.Fixed.Pattern.Decl
255+
{ decl_adtype= adt
256+
; decl_id= x
257+
; decl_type= Option.value_exn rt } ]
258+
(* We should minimize the code that's having its variables
245259
replaced to avoid conflict with the (two) new dummy
246260
variables introduced by inlining *)
247-
, [handle (replace_fresh_local_vars (subst_args_stmt args es b))]
248-
, { pattern= Var x
249-
; meta=
250-
Expr.Typed.Meta.
251-
{ type_= Type.to_unsized (Option.value_exn rt)
252-
; adlevel= adt
253-
; loc= Location_span.empty } } )
254-
in
255-
let d_list = d_list @ d_list2 in
256-
let s_list = s_list @ s_list2 in
257-
(d_list, s_list, e) )
261+
, [ handle
262+
(replace_fresh_local_vars (subst_args_stmt args es b)) ]
263+
, { pattern= Var x
264+
; meta=
265+
Expr.Typed.Meta.
266+
{ type_= Type.to_unsized (Option.value_exn rt)
267+
; adlevel= adt
268+
; loc= Location_span.empty } } )
269+
in
270+
let d_list = d_list @ d_list2 in
271+
let s_list = s_list @ s_list2 in
272+
(d_list, s_list, e) ) )
258273
| TernaryIf (e1, e2, e3) ->
259274
let dl1, sl1, e1 = inline_function_expression propto adt fim e1 in
260275
let dl2, sl2, e2 = inline_function_expression propto adt fim e2 in
@@ -347,7 +362,7 @@ let rec inline_function_statement propto adt fim Stmt.Fixed.({pattern; meta}) =
347362
| TargetPE e ->
348363
let d, s, e = inline_function_expression propto adt fim e in
349364
slist_concat_no_loc (d @ s) (TargetPE e)
350-
| NRFunApp (t, s, es) ->
365+
| NRFunApp (kind, es) ->
351366
let dse_list =
352367
List.map ~f:(inline_function_expression propto adt fim) es
353368
in
@@ -362,14 +377,17 @@ let rec inline_function_statement propto adt fim Stmt.Fixed.({pattern; meta}) =
362377
in
363378
let es = List.map ~f:(function _, _, x -> x) dse_list in
364379
slist_concat_no_loc (d_list @ s_list)
365-
( match Map.find fim s with
366-
| None -> NRFunApp (t, s, es)
367-
| Some (_, args, b) ->
368-
let b = replace_fresh_local_vars b in
369-
let b = handle_early_returns None b in
370-
(subst_args_stmt args es
371-
{pattern= b; meta= Location_span.empty})
372-
.pattern )
380+
( match kind with
381+
| CompilerInternal _ -> NRFunApp (kind, es)
382+
| UserDefined s | StanLib s -> (
383+
match Map.find fim s with
384+
| None -> NRFunApp (kind, es)
385+
| Some (_, args, b) ->
386+
let b = replace_fresh_local_vars b in
387+
let b = handle_early_returns None b in
388+
(subst_args_stmt args es
389+
{pattern= b; meta= Location_span.empty})
390+
.pattern ) )
373391
| Return e -> (
374392
match e with
375393
| None -> Return None
@@ -499,7 +517,7 @@ let rec contains_top_break_or_continue Stmt.Fixed.({pattern; _}) =
499517
| Break | Continue -> true
500518
| Assignment (_, _)
501519
|TargetPE _
502-
|NRFunApp (_, _, _)
520+
|NRFunApp (_, _)
503521
|Return _ | Decl _
504522
|While (_, _)
505523
|For _ | Skip ->
@@ -565,7 +583,7 @@ let unroll_loop_one_step_statement _ =
565583
else
566584
IfElse
567585
( Expr.Fixed.
568-
{lower with pattern= FunApp (StanLib, "Geq__", [upper; lower])}
586+
{lower with pattern= FunApp (StanLib "Geq__", [upper; lower])}
569587
, { pattern=
570588
(let body_unrolled =
571589
subst_args_stmt [loopvar] [lower]
@@ -581,8 +599,7 @@ let unroll_loop_one_step_statement _ =
581599
{ lower with
582600
pattern=
583601
FunApp
584-
( StanLib
585-
, "Plus__"
602+
( StanLib "Plus__"
586603
, [lower; Expr.Helpers.loop_bottom] ) } }
587604
; meta= Location_span.empty }
588605
in
@@ -666,26 +683,21 @@ and accum_any pred b e = b || expr_any pred e
666683

667684
let can_side_effect_top_expr (e : Expr.Typed.t) =
668685
match e.pattern with
669-
| FunApp (t, f, _) ->
670-
String.suffix f 3 = "_lp"
671-
|| (t = CompilerInternal && f = Internal_fun.to_string FnReadParam)
672-
|| (t = CompilerInternal && f = Internal_fun.to_string FnReadData)
673-
|| (t = CompilerInternal && f = Internal_fun.to_string FnWriteParam)
674-
|| (t = CompilerInternal && f = Internal_fun.to_string FnConstrain)
675-
|| (t = CompilerInternal && f = Internal_fun.to_string FnValidateSize)
676-
|| (t = CompilerInternal && f = Internal_fun.to_string FnValidateSize)
677-
|| t = CompilerInternal
678-
&& f = Internal_fun.to_string FnValidateSizeSimplex
679-
|| t = CompilerInternal
680-
&& f = Internal_fun.to_string FnValidateSizeUnitVector
681-
|| (t = CompilerInternal && f = Internal_fun.to_string FnUnconstrain)
686+
| FunApp ((UserDefined f | StanLib f), _) -> String.suffix f 3 = "_lp"
687+
| FunApp
688+
( CompilerInternal
689+
( FnReadParam _ | FnReadData | FnWriteParam | FnConstrain _
690+
| FnValidateSize | FnValidateSizeSimplex | FnValidateSizeUnitVector
691+
| FnUnconstrain _ )
692+
, _ ) ->
693+
true
682694
| _ -> false
683695

684696
let cannot_duplicate_expr (e : Expr.Typed.t) =
685697
let pred e =
686698
can_side_effect_top_expr e
687699
|| ( match e.pattern with
688-
| FunApp (_, f, _) -> String.suffix f 4 = "_rng"
700+
| FunApp ((UserDefined f | StanLib f), _) -> String.suffix f 4 = "_rng"
689701
| _ -> false )
690702
|| (preserve_stability && UnsizedType.is_autodiffable e.meta.type_)
691703
in
@@ -746,7 +758,7 @@ let dead_code_elimination (mir : Program.Typed.t) =
746758
due to side effects. *)
747759
(* TODO: maybe we should revisit that. *)
748760
| Decl _ | TargetPE _
749-
|NRFunApp (_, _, _)
761+
|NRFunApp (_, _)
750762
|Break | Continue | Return _ | Skip ->
751763
stmt
752764
| IfElse (e, b1, b2) -> (

0 commit comments

Comments
 (0)