@@ -111,38 +111,40 @@ let rec compare_errors e1 e2 =
111111 | InputMismatch _ , _ | _ , SuffixMismatch _ -> 1 ) )
112112
113113let rec check_same_type depth t1 t2 =
114- let wrap_func = Option. map ~f: (fun e -> TypeMismatch (t1, t2, Some e)) in
114+ let wrap_func = Result. map_error ~f: (fun e -> TypeMismatch (t1, t2, Some e)) in
115115 match (t1, t2) with
116- | t1 , t2 when t1 = t2 -> None
117- | UnsizedType. (UReal, UInt) when depth < 1 -> None
118- | UnsizedType. (UComplex, UInt) when depth < 1 -> None
119- | UnsizedType. (UComplex, UReal) when depth < 1 -> None
116+ | t1 , t2 when t1 = t2 -> Ok ()
117+ | UnsizedType. (UReal, UInt) when depth < 1 -> Ok ()
118+ | UnsizedType. (UComplex, UInt) when depth < 1 -> Ok ()
119+ | UnsizedType. (UComplex, UReal) when depth < 1 -> Ok ()
120120 | UFun (_, _, s1, _), UFun (_, _, s2, _)
121121 when Fun_kind. without_propto s1 <> Fun_kind. without_propto s2 ->
122- Some
122+ Error
123123 (SuffixMismatch (Fun_kind. without_propto s1, Fun_kind. without_propto s2))
124124 |> wrap_func
125125 | UFun (_ , rt1 , _ , _ ), UFun (_ , rt2 , _ , _ ) when rt1 <> rt2 ->
126- Some (ReturnTypeMismatch (rt1, rt2)) |> wrap_func
126+ Error (ReturnTypeMismatch (rt1, rt2)) |> wrap_func
127127 | UFun (l1 , _ , _ , _ ), UFun (l2 , _ , _ , _ ) ->
128128 check_compatible_arguments (depth + 1 ) l2 l1
129- |> Option. map ~f: (fun e -> InputMismatch e)
129+ |> Result. map_error ~f: (fun e -> InputMismatch e)
130130 |> wrap_func
131- | t1 , t2 -> Some (TypeMismatch (t1, t2, None ))
131+ | t1 , t2 -> Error (TypeMismatch (t1, t2, None ))
132132
133133and check_compatible_arguments depth args1 args2 =
134134 match List. zip args1 args2 with
135135 | List.Or_unequal_lengths. Unequal_lengths ->
136- Some (ArgNumMismatch (List. length args1, List. length args2))
136+ Error (ArgNumMismatch (List. length args1, List. length args2))
137137 | Ok l ->
138138 List. find_mapi l ~f: (fun i ((ad1 , ut1 ), (ad2 , ut2 )) ->
139139 match check_same_type depth ut1 ut2 with
140- | Some e -> Some (ArgError (i + 1 , e))
141- | None ->
140+ | Error e -> Some (ArgError (i + 1 , e))
141+ | Ok _ ->
142142 if ad1 = ad2 then None
143143 else if depth < 2 && UnsizedType. autodifftype_can_convert ad1 ad2
144144 then None
145145 else Some (ArgError (i + 1 , DataOnlyError )) )
146+ |> Option. map ~f: Result. fail
147+ |> Option. value ~default: (Ok () )
146148
147149let check_compatible_arguments_mod_conv = check_compatible_arguments 0
148150let max_n_errors = 5
@@ -155,9 +157,13 @@ let extract_function_types f =
155157 Some (return, args, (fun x -> UserDefined x), mem)
156158 | _ -> None
157159
158- let returntype env name args =
160+ let arg_type x = Ast. (x.emeta.ad_level, x.emeta.type_)
161+ let get_arg_types = List. map ~f: arg_type
162+
163+ let returntype env name arg_exprs =
159164 (* NB: Variadic arguments are special-cased in the typechecker and not handled here *)
160165 let name = Utils. stdlib_distribution_name name in
166+ let args = get_arg_types arg_exprs in
161167 Environment. find env name
162168 |> List. filter_map ~f: extract_function_types
163169 |> List. sort ~compare: (fun (x , _ , _ , _ ) (y , _ , _ , _ ) ->
@@ -166,8 +172,9 @@ let returntype env name args =
166172 |> List. fold_until ~init: []
167173 ~f: (fun errors (rt , tys , funkind_constructor , _ ) ->
168174 match check_compatible_arguments 0 tys args with
169- | None -> Stop (Ok (rt, funkind_constructor))
170- | Some e -> Continue (((rt, tys), e) :: errors) )
175+ (* TODO instead of unit, return Ast.typed_expr list which could contain promotions*)
176+ | Ok () -> Stop (Ok (rt, funkind_constructor))
177+ | Error e -> Continue (((rt, tys), e) :: errors) )
171178 ~finish: (fun errors ->
172179 let errors =
173180 List. sort errors ~compare: (fun (_ , e1 ) (_ , e2 ) ->
@@ -182,7 +189,7 @@ let check_variadic_args allow_lpdf mandatory_arg_tys mandatory_fun_arg_tys
182189 in
183190 let minimal_args =
184191 (UnsizedType. AutoDiffable , minimal_func_type) :: mandatory_arg_tys in
185- let wrap_err x = Some (minimal_args, ArgError (1 , x)) in
192+ let wrap_err x = Error (minimal_args, ArgError (1 , x)) in
186193 match args with
187194 | ( _
188195 , ( UnsizedType. UFun (fun_args, ReturnType return_type, suffix, _) as
@@ -195,22 +202,22 @@ let check_variadic_args allow_lpdf mandatory_arg_tys mandatory_fun_arg_tys
195202 let suffix = Fun_kind. without_propto suffix in
196203 if suffix = FnPlain || (allow_lpdf && suffix = FnLpdf () ) then
197204 match check_compatible_arguments 1 mandatory mandatory_fun_arg_tys with
198- | Some x -> wrap_func_error (InputMismatch x)
199- | None -> (
205+ | Error x -> wrap_func_error (InputMismatch x)
206+ | Ok () -> (
200207 match check_same_type 1 return_type fun_return with
201- | Some _ ->
208+ | Error _ ->
202209 wrap_func_error
203210 (ReturnTypeMismatch
204211 (ReturnType fun_return, ReturnType return_type) )
205- | None ->
212+ | Ok () ->
206213 let expected_args =
207214 ((UnsizedType. AutoDiffable , func_type) :: mandatory_arg_tys)
208215 @ variadic_arg_tys in
209216 check_compatible_arguments 0 expected_args args
210- |> Option. map ~f: (fun x -> (expected_args, x)) )
217+ |> Result. map_error ~f: (fun x -> (expected_args, x)) )
211218 else wrap_func_error (SuffixMismatch (FnPlain , suffix))
212219 | (_ , x ) :: _ -> TypeMismatch (minimal_func_type, x, None ) |> wrap_err
213- | [] -> Some ([] , ArgNumMismatch (List. length mandatory_arg_tys, 0 ))
220+ | [] -> Error ([] , ArgNumMismatch (List. length mandatory_arg_tys, 0 ))
214221
215222let pp_signature_mismatch ppf (name , arg_tys , (sigs , omitted )) =
216223 let open Fmt in
0 commit comments