Skip to content

Commit 4b7142a

Browse files
committed
Start AST-level promotion
1 parent 7a06aad commit 4b7142a

File tree

9 files changed

+187
-173
lines changed

9 files changed

+187
-173
lines changed

src/frontend/Ast.ml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ type ('e, 'f) expression =
4444
| ImagNumeral of string
4545
| FunApp of 'f * identifier * 'e list
4646
| CondDistApp of 'f * identifier * 'e list
47+
| Promotion of 'e * UnsizedType.t
4748
(* GetLP is deprecated *)
4849
| GetLP
4950
| GetTarget
@@ -304,7 +305,11 @@ let rec id_of_lvalue {lval; _} =
304305

305306
let rec get_loc_expr (e : untyped_expression) =
306307
match e.expr with
307-
| TernaryIf (e, _, _) | BinOp (e, _, _) | PostfixOp (e, _) | Indexed (e, _) ->
308+
| TernaryIf (e, _, _)
309+
|BinOp (e, _, _)
310+
|PostfixOp (e, _)
311+
|Indexed (e, _)
312+
|Promotion (e, _) ->
308313
get_loc_expr e
309314
| PrefixOp (_, e) | ArrayExpr (e :: _) | RowVectorExpr (e :: _) | Paren e ->
310315
e.emeta.loc.begin_loc

src/frontend/Ast_to_Mir.ml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ and trans_expr {Ast.expr; Ast.emeta} =
8181
FunApp (CompilerInternal FnMakeRowVec, trans_exprs eles) |> ewrap
8282
| Indexed (lhs, indices) ->
8383
Indexed (trans_expr lhs, List.map ~f:trans_idx indices) |> ewrap
84+
| Promotion (e, ty) -> Promotion (trans_expr e, ty) |> ewrap
8485

8586
and trans_idx = function
8687
| Ast.All -> All

src/frontend/Canonicalize.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ let rec no_parens {expr; emeta} =
142142
| i -> map_index keep_parens i )
143143
l )
144144
; emeta }
145-
| ArrayExpr _ | RowVectorExpr _ | FunApp _ | CondDistApp _ ->
145+
| ArrayExpr _ | RowVectorExpr _ | FunApp _ | CondDistApp _ | Promotion _ ->
146146
{expr= map_expression no_parens ident expr; emeta}
147147

148148
and keep_parens {expr; emeta} =

src/frontend/Pretty_printing.ml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ and pp_expression ppf ({expr= e_content; emeta= {loc; _}} : untyped_expression)
272272
| ArrayExpr es -> pf ppf "{@[%a}@]" pp_list_of_expression (es, loc)
273273
| RowVectorExpr es -> pf ppf "[@[%a]@]" pp_list_of_expression (es, loc)
274274
| Paren e -> pf ppf "(%a)" pp_expression e
275+
| Promotion (e, _) -> pp_expression ppf e
275276
| Indexed (e, l) -> (
276277
match l with
277278
| [] -> pf ppf "%a" pp_expression e

src/frontend/SignatureMismatch.ml

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -111,38 +111,40 @@ let rec compare_errors e1 e2 =
111111
| InputMismatch _, _ | _, SuffixMismatch _ -> 1 ) )
112112

113113
let rec check_same_type depth t1 t2 =
114-
let wrap_func = Option.map ~f:(fun e -> TypeMismatch (t1, t2, Some e)) in
114+
let wrap_func = Result.map_error ~f:(fun e -> TypeMismatch (t1, t2, Some e)) in
115115
match (t1, t2) with
116-
| t1, t2 when t1 = t2 -> None
117-
| UnsizedType.(UReal, UInt) when depth < 1 -> None
118-
| UnsizedType.(UComplex, UInt) when depth < 1 -> None
119-
| UnsizedType.(UComplex, UReal) when depth < 1 -> None
116+
| t1, t2 when t1 = t2 -> Ok ()
117+
| UnsizedType.(UReal, UInt) when depth < 1 -> Ok ()
118+
| UnsizedType.(UComplex, UInt) when depth < 1 -> Ok ()
119+
| UnsizedType.(UComplex, UReal) when depth < 1 -> Ok ()
120120
| UFun (_, _, s1, _), UFun (_, _, s2, _)
121121
when Fun_kind.without_propto s1 <> Fun_kind.without_propto s2 ->
122-
Some
122+
Error
123123
(SuffixMismatch (Fun_kind.without_propto s1, Fun_kind.without_propto s2))
124124
|> wrap_func
125125
| UFun (_, rt1, _, _), UFun (_, rt2, _, _) when rt1 <> rt2 ->
126-
Some (ReturnTypeMismatch (rt1, rt2)) |> wrap_func
126+
Error (ReturnTypeMismatch (rt1, rt2)) |> wrap_func
127127
| UFun (l1, _, _, _), UFun (l2, _, _, _) ->
128128
check_compatible_arguments (depth + 1) l2 l1
129-
|> Option.map ~f:(fun e -> InputMismatch e)
129+
|> Result.map_error ~f:(fun e -> InputMismatch e)
130130
|> wrap_func
131-
| t1, t2 -> Some (TypeMismatch (t1, t2, None))
131+
| t1, t2 -> Error (TypeMismatch (t1, t2, None))
132132

133133
and check_compatible_arguments depth args1 args2 =
134134
match List.zip args1 args2 with
135135
| List.Or_unequal_lengths.Unequal_lengths ->
136-
Some (ArgNumMismatch (List.length args1, List.length args2))
136+
Error (ArgNumMismatch (List.length args1, List.length args2))
137137
| Ok l ->
138138
List.find_mapi l ~f:(fun i ((ad1, ut1), (ad2, ut2)) ->
139139
match check_same_type depth ut1 ut2 with
140-
| Some e -> Some (ArgError (i + 1, e))
141-
| None ->
140+
| Error e -> Some (ArgError (i + 1, e))
141+
| Ok _ ->
142142
if ad1 = ad2 then None
143143
else if depth < 2 && UnsizedType.autodifftype_can_convert ad1 ad2
144144
then None
145145
else Some (ArgError (i + 1, DataOnlyError)) )
146+
|> Option.map ~f:Result.fail
147+
|> Option.value ~default:(Ok ())
146148

147149
let check_compatible_arguments_mod_conv = check_compatible_arguments 0
148150
let max_n_errors = 5
@@ -155,9 +157,13 @@ let extract_function_types f =
155157
Some (return, args, (fun x -> UserDefined x), mem)
156158
| _ -> None
157159

158-
let returntype env name args =
160+
let arg_type x = Ast.(x.emeta.ad_level, x.emeta.type_)
161+
let get_arg_types = List.map ~f:arg_type
162+
163+
let returntype env name arg_exprs =
159164
(* NB: Variadic arguments are special-cased in the typechecker and not handled here *)
160165
let name = Utils.stdlib_distribution_name name in
166+
let args = get_arg_types arg_exprs in
161167
Environment.find env name
162168
|> List.filter_map ~f:extract_function_types
163169
|> List.sort ~compare:(fun (x, _, _, _) (y, _, _, _) ->
@@ -166,8 +172,9 @@ let returntype env name args =
166172
|> List.fold_until ~init:[]
167173
~f:(fun errors (rt, tys, funkind_constructor, _) ->
168174
match check_compatible_arguments 0 tys args with
169-
| None -> Stop (Ok (rt, funkind_constructor))
170-
| Some e -> Continue (((rt, tys), e) :: errors) )
175+
(* TODO instead of unit, return Ast.typed_expr list which could contain promotions*)
176+
| Ok () -> Stop (Ok (rt, funkind_constructor))
177+
| Error e -> Continue (((rt, tys), e) :: errors) )
171178
~finish:(fun errors ->
172179
let errors =
173180
List.sort errors ~compare:(fun (_, e1) (_, e2) ->
@@ -182,7 +189,7 @@ let check_variadic_args allow_lpdf mandatory_arg_tys mandatory_fun_arg_tys
182189
in
183190
let minimal_args =
184191
(UnsizedType.AutoDiffable, minimal_func_type) :: mandatory_arg_tys in
185-
let wrap_err x = Some (minimal_args, ArgError (1, x)) in
192+
let wrap_err x = Error (minimal_args, ArgError (1, x)) in
186193
match args with
187194
| ( _
188195
, ( UnsizedType.UFun (fun_args, ReturnType return_type, suffix, _) as
@@ -195,22 +202,22 @@ let check_variadic_args allow_lpdf mandatory_arg_tys mandatory_fun_arg_tys
195202
let suffix = Fun_kind.without_propto suffix in
196203
if suffix = FnPlain || (allow_lpdf && suffix = FnLpdf ()) then
197204
match check_compatible_arguments 1 mandatory mandatory_fun_arg_tys with
198-
| Some x -> wrap_func_error (InputMismatch x)
199-
| None -> (
205+
| Error x -> wrap_func_error (InputMismatch x)
206+
| Ok () -> (
200207
match check_same_type 1 return_type fun_return with
201-
| Some _ ->
208+
| Error _ ->
202209
wrap_func_error
203210
(ReturnTypeMismatch
204211
(ReturnType fun_return, ReturnType return_type) )
205-
| None ->
212+
| Ok () ->
206213
let expected_args =
207214
((UnsizedType.AutoDiffable, func_type) :: mandatory_arg_tys)
208215
@ variadic_arg_tys in
209216
check_compatible_arguments 0 expected_args args
210-
|> Option.map ~f:(fun x -> (expected_args, x)) )
217+
|> Result.map_error ~f:(fun x -> (expected_args, x)) )
211218
else wrap_func_error (SuffixMismatch (FnPlain, suffix))
212219
| (_, x) :: _ -> TypeMismatch (minimal_func_type, x, None) |> wrap_err
213-
| [] -> Some ([], ArgNumMismatch (List.length mandatory_arg_tys, 0))
220+
| [] -> Error ([], ArgNumMismatch (List.length mandatory_arg_tys, 0))
214221

215222
let pp_signature_mismatch ppf (name, arg_tys, (sigs, omitted)) =
216223
let open Fmt in

src/frontend/SignatureMismatch.mli

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@ type signature_error =
2222
val check_compatible_arguments_mod_conv :
2323
(UnsizedType.autodifftype * UnsizedType.t) list
2424
-> (UnsizedType.autodifftype * UnsizedType.t) list
25-
-> function_mismatch option
25+
-> (unit (* Ast.typed_expression list *), function_mismatch) result
2626

2727
val returntype :
2828
Environment.t
2929
-> string
30-
-> (UnsizedType.autodifftype * UnsizedType.t) list
30+
-> Ast.typed_expression list
3131
-> ( UnsizedType.returntype * (bool Middle.Fun_kind.suffix -> Ast.fun_kind)
3232
, signature_error list * bool )
3333
result
@@ -38,8 +38,9 @@ val check_variadic_args :
3838
-> (UnsizedType.autodifftype * UnsizedType.t) list
3939
-> UnsizedType.t
4040
-> (UnsizedType.autodifftype * UnsizedType.t) list
41-
-> ((UnsizedType.autodifftype * UnsizedType.t) list * function_mismatch)
42-
option
41+
-> ( unit
42+
, (UnsizedType.autodifftype * UnsizedType.t) list * function_mismatch )
43+
result
4344

4445
val pp_signature_mismatch :
4546
Format.formatter

src/frontend/Typechecker.ml

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,7 @@ let check_fn ~is_cond_dist loc tenv id es =
424424
(Env.nearest_ident tenv id.name) )
425425
|> error
426426
| _ (* a function *) -> (
427-
match SignatureMismatch.returntype tenv id.name (get_arg_types es) with
427+
match SignatureMismatch.returntype tenv id.name es with
428428
| Ok (Void, _) ->
429429
Semantic_error.returning_fn_expected_nonreturning_found loc id.name
430430
|> error
@@ -458,11 +458,11 @@ let check_reduce_sum ~is_cond_dist loc id es =
458458
SignatureMismatch.check_variadic_args true mandatory_args
459459
mandatory_fun_args UReal (get_arg_types es)
460460
with
461-
| None ->
461+
| Ok () ->
462462
mk_typed_expression
463463
~expr:(mk_fun_app ~is_cond_dist (StanLib FnPlain, id, es))
464464
~ad_level:(expr_ad_lub es) ~type_:UnsizedType.UReal ~loc
465-
| Some (expected_args, err) ->
465+
| Error (expected_args, err) ->
466466
Semantic_error.illtyped_reduce_sum loc id.name
467467
(List.map ~f:type_of_expr_typed es)
468468
expected_args err
@@ -477,7 +477,7 @@ let check_reduce_sum ~is_cond_dist loc id es =
477477
let expected_args, err =
478478
SignatureMismatch.check_variadic_args true mandatory_args
479479
mandatory_fun_args UReal (get_arg_types es)
480-
|> Option.value_exn in
480+
|> Result.error |> Option.value_exn in
481481
Semantic_error.illtyped_reduce_sum_generic loc id.name
482482
(List.map ~f:type_of_expr_typed es)
483483
expected_args err
@@ -498,12 +498,12 @@ let check_variadic_ode ~is_cond_dist loc id es =
498498
Stan_math_signatures.variadic_ode_mandatory_fun_args
499499
Stan_math_signatures.variadic_ode_fun_return_type (get_arg_types es)
500500
with
501-
| None ->
501+
| Ok () ->
502502
mk_typed_expression
503503
~expr:(mk_fun_app ~is_cond_dist (StanLib FnPlain, id, es))
504504
~ad_level:(expr_ad_lub es)
505505
~type_:Stan_math_signatures.variadic_ode_return_type ~loc
506-
| Some (expected_args, err) ->
506+
| Error (expected_args, err) ->
507507
Semantic_error.illtyped_variadic_ode loc id.name
508508
(List.map ~f:type_of_expr_typed es)
509509
expected_args err
@@ -657,6 +657,10 @@ and check_expression cf tenv ({emeta; expr} : Ast.untyped_expression) :
657657
es |> List.map ~f:ce |> check_funapp loc cf tenv ~is_cond_dist:false id
658658
| CondDistApp ((), id, es) ->
659659
es |> List.map ~f:ce |> check_funapp loc cf tenv ~is_cond_dist:true id
660+
| Promotion (e, _) ->
661+
(* Should never happen: promotions are produced during typechecking *)
662+
Common.FatalError.fatal_error_msg
663+
[%message "Promotion in untyped AST" (e : Ast.untyped_expression)]
660664

661665
and check_expression_of_int_type cf tenv e name =
662666
let te = check_expression cf tenv e in
@@ -698,7 +702,7 @@ let check_nrfn loc tenv id es =
698702
(Env.nearest_ident tenv id.name)
699703
|> error
700704
| _ (* a function *) -> (
701-
match SignatureMismatch.returntype tenv id.name (get_arg_types es) with
705+
match SignatureMismatch.returntype tenv id.name es with
702706
| Ok (Void, fnk) ->
703707
mk_typed_statement
704708
~stmt:(NRFunApp (fnk (Fun_kind.suffix_from_name id.name), id, es))

0 commit comments

Comments
 (0)