Skip to content

Commit b964175

Browse files
authored
Merge pull request #1021 from adamhaber/feature/incomplete-probability-calls
Better error messages for incomplete probability calls
2 parents 3699f24 + 5bebf3b commit b964175

20 files changed

+297
-1
lines changed

src/frontend/Semantic_error.ml

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ module TypeError = struct
3434
| ReturningFnExpectedNonReturningFound of string
3535
| ReturningFnExpectedNonFnFound of string
3636
| ReturningFnExpectedUndeclaredIdentFound of string
37+
| ReturningFnExpectedUndeclaredDistSuffixFound of string * string
38+
| ReturningFnExpectedWrongDistSuffixFound of string * string
3739
| NonReturningFnExpectedReturningFound of string
3840
| NonReturningFnExpectedNonFnFound of string
3941
| NonReturningFnExpectedUndeclaredIdentFound of string
@@ -163,6 +165,22 @@ module TypeError = struct
163165
"A returning function was expected but an undeclared identifier \
164166
'%s' was supplied."
165167
fn_name
168+
| ReturningFnExpectedUndeclaredDistSuffixFound (prefix, suffix) ->
169+
Fmt.pf ppf "Function '%s_%s' is not implemented for distribution '%s'."
170+
prefix suffix prefix
171+
| ReturningFnExpectedWrongDistSuffixFound (prefix, suffix) ->
172+
let newsuffix =
173+
match suffix with
174+
| "lpdf" -> "lpmf"
175+
| "lupdf" -> "lupmf"
176+
| "lpmf" -> "lpdf"
177+
| "lupmf" -> "lupdf"
178+
| _ -> raise_s [%message "This should never happen."]
179+
in
180+
Fmt.pf ppf
181+
"Function '%s_%s' is not implemented for distribution '%s', use \
182+
'%s_%s' instead."
183+
prefix suffix prefix prefix newsuffix
166184
| NonReturningFnExpectedUndeclaredIdentFound fn_name ->
167185
Fmt.pf ppf
168186
"A non-returning function was expected but an undeclared identifier \
@@ -487,6 +505,16 @@ let returning_fn_expected_nonfn_found loc name =
487505
let returning_fn_expected_undeclaredident_found loc name =
488506
TypeError (loc, TypeError.ReturningFnExpectedUndeclaredIdentFound name)
489507

508+
let returning_fn_expected_undeclared_dist_suffix_found loc (prefix, suffix) =
509+
TypeError
510+
( loc
511+
, TypeError.ReturningFnExpectedUndeclaredDistSuffixFound (prefix, suffix)
512+
)
513+
514+
let returning_fn_expected_wrong_dist_suffix_found loc (prefix, suffix) =
515+
TypeError
516+
(loc, TypeError.ReturningFnExpectedWrongDistSuffixFound (prefix, suffix))
517+
490518
let nonreturning_fn_expected_returning_found loc name =
491519
TypeError (loc, TypeError.NonReturningFnExpectedReturningFound name)
492520

src/frontend/Semantic_error.mli

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@ val returning_fn_expected_nonfn_found : Location_span.t -> string -> t
3636
val returning_fn_expected_undeclaredident_found :
3737
Location_span.t -> string -> t
3838

39+
val returning_fn_expected_undeclared_dist_suffix_found :
40+
Location_span.t -> string * string -> t
41+
42+
val returning_fn_expected_wrong_dist_suffix_found :
43+
Location_span.t -> string * string -> t
44+
3945
val illtyped_reduce_sum :
4046
Location_span.t
4147
-> string

src/frontend/Typechecker.ml

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,38 @@ let check_fn ~is_cond_dist loc tenv id es =
394394
(Utils.normalized_name id.name)) ->
395395
Semantic_error.returning_fn_expected_nonfn_found loc id.name |> error
396396
| [] ->
397-
Semantic_error.returning_fn_expected_undeclaredident_found loc id.name
397+
( match Utils.split_distribution_suffix id.name with
398+
| Some (prefix, suffix) -> (
399+
let known_families =
400+
List.map
401+
~f:(fun (_, y, _, _) -> y)
402+
Stan_math_signatures.distributions
403+
in
404+
let is_known_family s =
405+
List.mem known_families s ~equal:String.equal
406+
in
407+
match suffix with
408+
| ("lpmf" | "lumpf") when Env.mem tenv (prefix ^ "_lpdf") ->
409+
Semantic_error.returning_fn_expected_wrong_dist_suffix_found loc
410+
(prefix, suffix)
411+
| ("lpdf" | "lumdf") when Env.mem tenv (prefix ^ "_lpmf") ->
412+
Semantic_error.returning_fn_expected_wrong_dist_suffix_found loc
413+
(prefix, suffix)
414+
| _ ->
415+
if
416+
is_known_family prefix
417+
&& List.mem ~equal:String.equal
418+
Utils.cumulative_distribution_suffices_w_rng suffix
419+
then
420+
Semantic_error
421+
.returning_fn_expected_undeclared_dist_suffix_found loc
422+
(prefix, suffix)
423+
else
424+
Semantic_error.returning_fn_expected_undeclaredident_found loc
425+
id.name )
426+
| None ->
427+
Semantic_error.returning_fn_expected_undeclaredident_found loc
428+
id.name )
398429
|> error
399430
| _ (* a function *) -> (
400431
match SignatureMismatch.returntype tenv id.name (get_arg_types es) with

src/middle/Utils.ml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,13 @@ let conditioning_suffices =
1414
["_lpdf"; "_lupdf"; "_lupmf"; "_lpmf"; "_cdf"; "_lcdf"; "_lccdf"]
1515

1616
let conditioning_suffices_w_log = conditioning_suffices @ ["_log"]
17+
18+
let cumulative_distribution_suffices =
19+
["cdf"; "lcdf"; "lccdf"; "cdf_log"; "ccdf_log"]
20+
21+
let cumulative_distribution_suffices_w_rng =
22+
cumulative_distribution_suffices @ ["rng"]
23+
1724
let is_user_ident = Fn.non (String.is_suffix ~suffix:"__")
1825

1926
let unnormalized_suffix = function
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
model {
2+
target += foo_lpdf(1);
3+
}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
data {
2+
int foo_lpmf;
3+
}
4+
model {
5+
target += foo_lpdf(1);
6+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
data {
2+
}
3+
model {
4+
target += von_mises_ccdf_log(1, 0,1);
5+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
(include ../dune)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
data {
2+
}
3+
model {
4+
// known family, known suffix, not implemented
5+
target += binomial_lpdf(1|0,1);
6+
}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
data {
2+
}
3+
model {
4+
// known family, known suffix, not implemented
5+
target += normal_lpmf(1|0,1);
6+
}

0 commit comments

Comments
 (0)