Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/analysis_and_optimization/Mir_utils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ let rec expr_var_set Expr.Fixed.{pattern; meta} =
| TernaryIf (expr1, expr2, expr3) -> union_recur [expr1; expr2; expr3]
| Indexed (expr, ix) ->
Set.Poly.union_list (expr_var_set expr :: List.map ix ~f:index_var_set)
| Promotion (expr, _) -> expr_var_set expr
| EAnd (expr1, expr2) | EOr (expr1, expr2) -> union_recur [expr1; expr2]

and index_var_set ix =
Expand Down Expand Up @@ -361,6 +362,7 @@ let rec expr_depth Expr.Fixed.{pattern; _} =
+ max (expr_depth e)
(Option.value ~default:0
(List.max_elt ~compare:compare_int (List.map ~f:idx_depth l)) )
| Promotion (expr, _) -> 1 + expr_depth expr
| EAnd (e1, e2) | EOr (e1, e2) ->
1
+ Option.value ~default:0
Expand Down Expand Up @@ -405,6 +407,10 @@ let rec update_expr_ad_levels autodiffable_variables
let e1 = update_expr_ad_levels autodiffable_variables e1 in
let e2 = update_expr_ad_levels autodiffable_variables e2 in
{pattern= EOr (e1, e2); meta= {e.meta with adlevel= ad_level_sup [e1; e2]}}
| Promotion (expr, ut) ->
let expr' = update_expr_ad_levels autodiffable_variables expr in
{ pattern= Promotion (expr', ut)
; meta= {e.meta with adlevel= ad_level_sup [expr']} }
| Indexed (ixed, i_list) ->
let ixed = update_expr_ad_levels autodiffable_variables ixed in
let i_list =
Expand Down
2 changes: 2 additions & 0 deletions src/analysis_and_optimization/Monotone_framework.ml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ let print_mfp to_string (mfp : (int, 'a entry_exit) Map.Poly.t)
let rec free_vars_expr (e : Expr.Typed.t) =
match e.pattern with
| Var x -> Set.Poly.singleton x
| Promotion (expr, _) -> free_vars_expr expr
| Lit (_, _) -> Set.Poly.empty
| FunApp (kind, l) -> free_vars_fnapp kind l
| TernaryIf (e1, e2, e3) ->
Expand Down Expand Up @@ -544,6 +545,7 @@ let rec used_subexpressions_expr (e : Expr.Typed.t) =
(Expr.Typed.Set.singleton e)
( match e.pattern with
| Var _ | Lit (_, _) -> Expr.Typed.Set.empty
| Promotion (expr, _) -> used_subexpressions_expr expr
| FunApp (k, l) ->
Expr.Typed.Set.union_list
(List.map ~f:used_subexpressions_expr (l @ Fun_kind.collect_exprs k))
Expand Down
7 changes: 5 additions & 2 deletions src/analysis_and_optimization/Optimize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,9 @@ let rec inline_function_expression propto adt fim (Expr.Fixed.{pattern; _} as e)
match pattern with
| Var _ -> ([], [], e)
| Lit (_, _) -> ([], [], e)
| Promotion (expr, ut) ->
let d, sl, expr' = inline_function_expression propto adt fim expr in
(d, sl, {e with pattern= Promotion (expr', ut)})
| FunApp (kind, es) -> (
let d_list, s_list, es =
inline_list (inline_function_expression propto adt fim) es in
Expand Down Expand Up @@ -1024,8 +1027,8 @@ let block_fixing mir =
(* TODO: add tests *)
(* TODO: add pass to get rid of redundant declarations? *)

(**
* A generic optimization pass for finding a minimal set of variables that
(**
* A generic optimization pass for finding a minimal set of variables that
* are generated by some circumstance, and then updating the MIR with that set.
* @param gen_variables: the variables that must be added to the set at
* the given statement
Expand Down
1 change: 1 addition & 0 deletions src/analysis_and_optimization/Partial_evaluator.ml
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ let rec eval_expr ?(preserve_stability = false) (e : Expr.Typed.t) =
pattern=
( match e.pattern with
| Var _ | Lit (_, _) -> e.pattern
| Promotion (expr, ut) -> Promotion (eval_expr expr, ut)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[optional] this could be a bit smarter, like promoting literals to literals and collapsing promotion of promotion.

| FunApp (kind, l) -> (
let l = List.map ~f:(eval_expr ~preserve_stability) l in
match kind with
Expand Down
18 changes: 14 additions & 4 deletions src/frontend/Ast.ml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ type ('e, 'f) expression =
| ImagNumeral of string
| FunApp of 'f * identifier * 'e list
| CondDistApp of 'f * identifier * 'e list
| Promotion of 'e * UnsizedType.t
(* GetLP is deprecated *)
| GetLP
| GetTarget
Expand Down Expand Up @@ -250,9 +251,14 @@ type typed_program = typed_statement program [@@deriving sexp, compare, map]
(** Forgetful function from typed to untyped expressions *)
let rec untyped_expression_of_typed_expression ({expr; emeta} : typed_expression)
: untyped_expression =
{ expr=
map_expression untyped_expression_of_typed_expression (fun _ -> ()) expr
; emeta= {loc= emeta.loc} }
match expr with
| Promotion (e, _) -> untyped_expression_of_typed_expression e
| _ ->
{ expr=
map_expression untyped_expression_of_typed_expression
(fun _ -> ())
expr
; emeta= {loc= emeta.loc} }

let rec untyped_lvalue_of_typed_lvalue ({lval; lmeta} : typed_lval) :
untyped_lval =
Expand Down Expand Up @@ -304,7 +310,11 @@ let rec id_of_lvalue {lval; _} =

let rec get_loc_expr (e : untyped_expression) =
match e.expr with
| TernaryIf (e, _, _) | BinOp (e, _, _) | PostfixOp (e, _) | Indexed (e, _) ->
| TernaryIf (e, _, _)
|BinOp (e, _, _)
|PostfixOp (e, _)
|Indexed (e, _)
|Promotion (e, _) ->
get_loc_expr e
| PrefixOp (_, e) | ArrayExpr (e :: _) | RowVectorExpr (e :: _) | Paren e ->
e.emeta.loc.begin_loc
Expand Down
1 change: 1 addition & 0 deletions src/frontend/Ast_to_Mir.ml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ and trans_expr {Ast.expr; Ast.emeta} =
FunApp (CompilerInternal FnMakeRowVec, trans_exprs eles) |> ewrap
| Indexed (lhs, indices) ->
Indexed (trans_expr lhs, List.map ~f:trans_idx indices) |> ewrap
| Promotion (e, ty) -> Promotion (trans_expr e, ty) |> ewrap

and trans_idx = function
| Ast.All -> All
Expand Down
2 changes: 1 addition & 1 deletion src/frontend/Canonicalize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ let rec no_parens {expr; emeta} =
| i -> map_index keep_parens i )
l )
; emeta }
| ArrayExpr _ | RowVectorExpr _ | FunApp _ | CondDistApp _ ->
| ArrayExpr _ | RowVectorExpr _ | FunApp _ | CondDistApp _ | Promotion _ ->
{expr= map_expression no_parens ident expr; emeta}

and keep_parens {expr; emeta} =
Expand Down
1 change: 1 addition & 0 deletions src/frontend/Pretty_printing.ml
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ and pp_expression ppf ({expr= e_content; emeta= {loc; _}} : untyped_expression)
| ArrayExpr es -> pf ppf "{@[%a}@]" pp_list_of_expression (es, loc)
| RowVectorExpr es -> pf ppf "[@[%a]@]" pp_list_of_expression (es, loc)
| Paren e -> pf ppf "(%a)" pp_expression e
| Promotion (e, _) -> pp_expression ppf e
| Indexed (e, l) -> (
match l with
| [] -> pf ppf "%a" pp_expression e
Expand Down
79 changes: 51 additions & 28 deletions src/frontend/SignatureMismatch.ml
Original file line number Diff line number Diff line change
Expand Up @@ -110,37 +110,50 @@ let rec compare_errors e1 e2 =
| SuffixMismatch _, _ | _, InputMismatch _ -> -1
| InputMismatch _, _ | _, SuffixMismatch _ -> 1 ) )

type promotions = None | RealPromotion | ComplexPromotion

let rec check_same_type depth t1 t2 =
let wrap_func = Option.map ~f:(fun e -> TypeMismatch (t1, t2, Some e)) in
let wrap_func = Result.map_error ~f:(fun e -> TypeMismatch (t1, t2, Some e)) in
match (t1, t2) with
| t1, t2 when t1 = t2 -> None
| UnsizedType.(UReal, UInt) when depth < 1 -> None
| t1, t2 when t1 = t2 -> Ok None
| UnsizedType.(UReal, UInt) when depth < 1 -> Ok RealPromotion
| UnsizedType.(UComplex, UInt) when depth < 1 -> Ok ComplexPromotion
| UnsizedType.(UComplex, UReal) when depth < 1 -> Ok ComplexPromotion
(* Arrays: Try to recursively promote, but make sure the error is for these types,
not the recursive call *)
| UArray nt1, UArray nt2 ->
check_same_type depth nt1 nt2
|> Result.map_error ~f:(function
| TypeMismatch _ -> TypeMismatch (t1, t2, None)
| e -> e )
| UFun (_, _, s1, _), UFun (_, _, s2, _)
when Fun_kind.without_propto s1 <> Fun_kind.without_propto s2 ->
Some
Error
(SuffixMismatch (Fun_kind.without_propto s1, Fun_kind.without_propto s2))
|> wrap_func
| UFun (_, rt1, _, _), UFun (_, rt2, _, _) when rt1 <> rt2 ->
Some (ReturnTypeMismatch (rt1, rt2)) |> wrap_func
| UFun (l1, _, _, _), UFun (l2, _, _, _) ->
check_compatible_arguments (depth + 1) l2 l1
|> Option.map ~f:(fun e -> InputMismatch e)
|> wrap_func
| t1, t2 -> Some (TypeMismatch (t1, t2, None))
Error (ReturnTypeMismatch (rt1, rt2)) |> wrap_func
| UFun (l1, _, _, _), UFun (l2, _, _, _) -> (
match check_compatible_arguments (depth + 1) l2 l1 with
| Ok _ -> Ok None
| Error e -> Error (InputMismatch e) |> wrap_func )
| t1, t2 -> Error (TypeMismatch (t1, t2, None))

and check_compatible_arguments depth args1 args2 =
match List.zip args1 args2 with
and check_compatible_arguments depth typs args2 :
(promotions list, function_mismatch) result =
match List.zip typs args2 with
| List.Or_unequal_lengths.Unequal_lengths ->
Some (ArgNumMismatch (List.length args1, List.length args2))
Error (ArgNumMismatch (List.length typs, List.length args2))
| Ok l ->
List.find_mapi l ~f:(fun i ((ad1, ut1), (ad2, ut2)) ->
List.mapi l ~f:(fun i ((ad1, ut1), (ad2, ut2)) ->
match check_same_type depth ut1 ut2 with
| Some e -> Some (ArgError (i + 1, e))
| None ->
if ad1 = ad2 then None
| Error e -> Error (ArgError (i + 1, e))
| Ok p ->
if ad1 = ad2 then Ok p
else if depth < 2 && UnsizedType.autodifftype_can_convert ad1 ad2
then None
else Some (ArgError (i + 1, DataOnlyError)) )
then Ok p
else Error (ArgError (i + 1, DataOnlyError)) )
|> Result.all

let check_compatible_arguments_mod_conv = check_compatible_arguments 0
let max_n_errors = 5
Expand All @@ -153,6 +166,16 @@ let extract_function_types f =
Some (return, args, (fun x -> UserDefined x), mem)
| _ -> None

let promote es promotions =
List.map2_exn es promotions ~f:(fun (exp : Ast.typed_expression) prom ->
let emeta = exp.emeta in
match prom with
| RealPromotion when emeta.type_ <> UReal ->
Ast.{expr= Ast.Promotion (exp, UReal); emeta}
| ComplexPromotion when emeta.type_ <> UComplex ->
{expr= Promotion (exp, UComplex); emeta}
| _ -> exp )

let returntype env name args =
(* NB: Variadic arguments are special-cased in the typechecker and not handled here *)
let name = Utils.stdlib_distribution_name name in
Expand All @@ -164,8 +187,8 @@ let returntype env name args =
|> List.fold_until ~init:[]
~f:(fun errors (rt, tys, funkind_constructor, _) ->
match check_compatible_arguments 0 tys args with
| None -> Stop (Ok (rt, funkind_constructor))
| Some e -> Continue (((rt, tys), e) :: errors) )
| Ok p -> Stop (Ok (rt, funkind_constructor, p))
| Error e -> Continue (((rt, tys), e) :: errors) )
~finish:(fun errors ->
let errors =
List.sort errors ~compare:(fun (_, e1) (_, e2) ->
Expand All @@ -180,7 +203,7 @@ let check_variadic_args allow_lpdf mandatory_arg_tys mandatory_fun_arg_tys
in
let minimal_args =
(UnsizedType.AutoDiffable, minimal_func_type) :: mandatory_arg_tys in
let wrap_err x = Some (minimal_args, ArgError (1, x)) in
let wrap_err x = Error (minimal_args, ArgError (1, x)) in
match args with
| ( _
, ( UnsizedType.UFun (fun_args, ReturnType return_type, suffix, _) as
Expand All @@ -193,22 +216,22 @@ let check_variadic_args allow_lpdf mandatory_arg_tys mandatory_fun_arg_tys
let suffix = Fun_kind.without_propto suffix in
if suffix = FnPlain || (allow_lpdf && suffix = FnLpdf ()) then
match check_compatible_arguments 1 mandatory mandatory_fun_arg_tys with
| Some x -> wrap_func_error (InputMismatch x)
| None -> (
| Error x -> wrap_func_error (InputMismatch x)
| Ok _ -> (
match check_same_type 1 return_type fun_return with
| Some _ ->
| Error _ ->
wrap_func_error
(ReturnTypeMismatch
(ReturnType fun_return, ReturnType return_type) )
| None ->
| Ok _ ->
let expected_args =
((UnsizedType.AutoDiffable, func_type) :: mandatory_arg_tys)
@ variadic_arg_tys in
check_compatible_arguments 0 expected_args args
|> Option.map ~f:(fun x -> (expected_args, x)) )
|> Result.map_error ~f:(fun x -> (expected_args, x)) )
else wrap_func_error (SuffixMismatch (FnPlain, suffix))
| (_, x) :: _ -> TypeMismatch (minimal_func_type, x, None) |> wrap_err
| [] -> Some ([], ArgNumMismatch (List.length mandatory_arg_tys, 0))
| [] -> Error ([], ArgNumMismatch (List.length mandatory_arg_tys, 0))

let pp_signature_mismatch ppf (name, arg_tys, (sigs, omitted)) =
let open Fmt in
Expand Down
20 changes: 16 additions & 4 deletions src/frontend/SignatureMismatch.mli
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,27 @@ type signature_error =
(UnsizedType.returntype * (UnsizedType.autodifftype * UnsizedType.t) list)
* function_mismatch

(** Indicate a promotion by the resulting type *)
type promotions = private None | RealPromotion | ComplexPromotion

val check_compatible_arguments_mod_conv :
(UnsizedType.autodifftype * UnsizedType.t) list
-> (UnsizedType.autodifftype * UnsizedType.t) list
-> function_mismatch option
-> (promotions list, function_mismatch) result

val promote :
Ast.typed_expression list -> promotions list -> Ast.typed_expression list
(** Given a list of expressions (arguments) and a list of [promotions],
return a list of expressions which include the
[Promotion] expression as appropiate *)

val returntype :
Environment.t
-> string
-> (UnsizedType.autodifftype * UnsizedType.t) list
-> ( UnsizedType.returntype * (bool Middle.Fun_kind.suffix -> Ast.fun_kind)
-> ( UnsizedType.returntype
* (bool Middle.Fun_kind.suffix -> Ast.fun_kind)
* promotions list
, signature_error list * bool )
result

Expand All @@ -38,8 +49,9 @@ val check_variadic_args :
-> (UnsizedType.autodifftype * UnsizedType.t) list
-> UnsizedType.t
-> (UnsizedType.autodifftype * UnsizedType.t) list
-> ((UnsizedType.autodifftype * UnsizedType.t) list * function_mismatch)
option
-> ( promotions list
, (UnsizedType.autodifftype * UnsizedType.t) list * function_mismatch )
result

val pp_signature_mismatch :
Format.formatter
Expand Down
Loading