@@ -219,7 +219,7 @@ let rec inline_function_expression propto adt fim
219219 match pattern with
220220 | Var _ -> ([] , [] , e)
221221 | Lit (_ , _ ) -> ([] , [] , e)
222- | FunApp (t , s , es ) -> (
222+ | FunApp (kind , es ) -> (
223223 let dse_list =
224224 List. map ~f: (inline_function_expression propto adt fim) es
225225 in
@@ -231,30 +231,45 @@ let rec inline_function_expression propto adt fim
231231 List. concat (List. rev (List. map ~f: (function _ , x , _ -> x) dse_list))
232232 in
233233 let es = List. map ~f: (function _ , _ , x -> x) dse_list in
234- let s = if propto then s else Middle.Utils. stdlib_distribution_name s in
235- match Map. find fim s with
236- | None -> (d_list, s_list, {e with pattern= FunApp (t, s, es)})
237- | Some (rt , args , b ) ->
238- let x = Gensym. generate ~prefix: " inline_" () in
239- let handle = handle_early_returns (Some x) in
240- let d_list2, s_list2, (e : Expr.Typed.t ) =
241- ( [ Stmt.Fixed.Pattern. Decl
242- {decl_adtype= adt; decl_id= x; decl_type= Option. value_exn rt}
243- ]
244- (* We should minimize the code that's having its variables
234+ match kind with
235+ | CompilerInternal _ ->
236+ (d_list, s_list, {e with pattern= FunApp (kind, es)})
237+ | UserDefined fname | StanLib fname -> (
238+ let fname =
239+ if propto then fname
240+ else Middle.Utils. stdlib_distribution_name fname
241+ in
242+ match Map. find fim fname with
243+ | None ->
244+ let fun_kind =
245+ match kind with
246+ | Fun_kind. UserDefined _ -> Fun_kind. UserDefined fname
247+ | _ -> StanLib fname
248+ in
249+ (d_list, s_list, {e with pattern= FunApp (fun_kind, es)})
250+ | Some (rt , args , b ) ->
251+ let x = Gensym. generate ~prefix: " inline_" () in
252+ let handle = handle_early_returns (Some x) in
253+ let d_list2, s_list2, (e : Expr.Typed.t ) =
254+ ( [ Stmt.Fixed.Pattern. Decl
255+ { decl_adtype= adt
256+ ; decl_id= x
257+ ; decl_type= Option. value_exn rt } ]
258+ (* We should minimize the code that's having its variables
245259 replaced to avoid conflict with the (two) new dummy
246260 variables introduced by inlining *)
247- , [handle (replace_fresh_local_vars (subst_args_stmt args es b))]
248- , { pattern= Var x
249- ; meta=
250- Expr.Typed.Meta.
251- { type_= Type. to_unsized (Option. value_exn rt)
252- ; adlevel= adt
253- ; loc= Location_span. empty } } )
254- in
255- let d_list = d_list @ d_list2 in
256- let s_list = s_list @ s_list2 in
257- (d_list, s_list, e) )
261+ , [ handle
262+ (replace_fresh_local_vars (subst_args_stmt args es b)) ]
263+ , { pattern= Var x
264+ ; meta=
265+ Expr.Typed.Meta.
266+ { type_= Type. to_unsized (Option. value_exn rt)
267+ ; adlevel= adt
268+ ; loc= Location_span. empty } } )
269+ in
270+ let d_list = d_list @ d_list2 in
271+ let s_list = s_list @ s_list2 in
272+ (d_list, s_list, e) ) )
258273 | TernaryIf (e1 , e2 , e3 ) ->
259274 let dl1, sl1, e1 = inline_function_expression propto adt fim e1 in
260275 let dl2, sl2, e2 = inline_function_expression propto adt fim e2 in
@@ -347,7 +362,7 @@ let rec inline_function_statement propto adt fim Stmt.Fixed.({pattern; meta}) =
347362 | TargetPE e ->
348363 let d, s, e = inline_function_expression propto adt fim e in
349364 slist_concat_no_loc (d @ s) (TargetPE e)
350- | NRFunApp (t , s , es ) ->
365+ | NRFunApp (kind , es ) ->
351366 let dse_list =
352367 List. map ~f: (inline_function_expression propto adt fim) es
353368 in
@@ -362,14 +377,17 @@ let rec inline_function_statement propto adt fim Stmt.Fixed.({pattern; meta}) =
362377 in
363378 let es = List. map ~f: (function _ , _ , x -> x) dse_list in
364379 slist_concat_no_loc (d_list @ s_list)
365- ( match Map. find fim s with
366- | None -> NRFunApp (t, s, es)
367- | Some (_ , args , b ) ->
368- let b = replace_fresh_local_vars b in
369- let b = handle_early_returns None b in
370- (subst_args_stmt args es
371- {pattern= b; meta= Location_span. empty})
372- .pattern )
380+ ( match kind with
381+ | CompilerInternal _ -> NRFunApp (kind, es)
382+ | UserDefined s | StanLib s -> (
383+ match Map. find fim s with
384+ | None -> NRFunApp (kind, es)
385+ | Some (_ , args , b ) ->
386+ let b = replace_fresh_local_vars b in
387+ let b = handle_early_returns None b in
388+ (subst_args_stmt args es
389+ {pattern= b; meta= Location_span. empty})
390+ .pattern ) )
373391 | Return e -> (
374392 match e with
375393 | None -> Return None
@@ -499,7 +517,7 @@ let rec contains_top_break_or_continue Stmt.Fixed.({pattern; _}) =
499517 | Break | Continue -> true
500518 | Assignment (_, _)
501519 | TargetPE _
502- | NRFunApp (_, _, _ )
520+ | NRFunApp (_, _)
503521 | Return _ | Decl _
504522 | While (_, _)
505523 | For _ | Skip ->
@@ -565,7 +583,7 @@ let unroll_loop_one_step_statement _ =
565583 else
566584 IfElse
567585 ( Expr.Fixed.
568- {lower with pattern= FunApp (StanLib , " Geq__" , [upper; lower])}
586+ {lower with pattern= FunApp (StanLib " Geq__" , [upper; lower])}
569587 , { pattern=
570588 (let body_unrolled =
571589 subst_args_stmt [loopvar] [lower]
@@ -581,8 +599,7 @@ let unroll_loop_one_step_statement _ =
581599 { lower with
582600 pattern=
583601 FunApp
584- ( StanLib
585- , " Plus__"
602+ ( StanLib " Plus__"
586603 , [lower; Expr.Helpers. loop_bottom] ) } }
587604 ; meta= Location_span. empty }
588605 in
@@ -666,26 +683,21 @@ and accum_any pred b e = b || expr_any pred e
666683
667684let can_side_effect_top_expr (e : Expr.Typed.t ) =
668685 match e.pattern with
669- | FunApp (t , f , _ ) ->
670- String. suffix f 3 = " _lp"
671- || (t = CompilerInternal && f = Internal_fun. to_string FnReadParam )
672- || (t = CompilerInternal && f = Internal_fun. to_string FnReadData )
673- || (t = CompilerInternal && f = Internal_fun. to_string FnWriteParam )
674- || (t = CompilerInternal && f = Internal_fun. to_string FnConstrain )
675- || (t = CompilerInternal && f = Internal_fun. to_string FnValidateSize )
676- || (t = CompilerInternal && f = Internal_fun. to_string FnValidateSize )
677- || t = CompilerInternal
678- && f = Internal_fun. to_string FnValidateSizeSimplex
679- || t = CompilerInternal
680- && f = Internal_fun. to_string FnValidateSizeUnitVector
681- || (t = CompilerInternal && f = Internal_fun. to_string FnUnconstrain )
686+ | FunApp ((UserDefined f | StanLib f ), _ ) -> String. suffix f 3 = " _lp"
687+ | FunApp
688+ ( CompilerInternal
689+ ( FnReadParam _ | FnReadData | FnWriteParam | FnConstrain _
690+ | FnValidateSize | FnValidateSizeSimplex | FnValidateSizeUnitVector
691+ | FnUnconstrain _ )
692+ , _ ) ->
693+ true
682694 | _ -> false
683695
684696let cannot_duplicate_expr (e : Expr.Typed.t ) =
685697 let pred e =
686698 can_side_effect_top_expr e
687699 || ( match e.pattern with
688- | FunApp (_ , f , _ ) -> String. suffix f 4 = " _rng"
700+ | FunApp (( UserDefined f | StanLib f ) , _ ) -> String. suffix f 4 = " _rng"
689701 | _ -> false )
690702 || (preserve_stability && UnsizedType. is_autodiffable e.meta.type_)
691703 in
@@ -746,7 +758,7 @@ let dead_code_elimination (mir : Program.Typed.t) =
746758 due to side effects. *)
747759 (* TODO: maybe we should revisit that. *)
748760 | Decl _ | TargetPE _
749- | NRFunApp (_, _, _ )
761+ | NRFunApp (_, _)
750762 | Break | Continue | Return _ | Skip ->
751763 stmt
752764 | IfElse (e , b1 , b2 ) -> (
0 commit comments