Skip to content

Commit 92092ea

Browse files
authored
Merge pull request #1059 from nhuurre/lvalue-indexing-bug
Add warning for lvalue indexing inconsistency.
2 parents 7e7382d + c4cb289 commit 92092ea

File tree

13 files changed

+227
-63
lines changed

13 files changed

+227
-63
lines changed

src/frontend/Canonicalize.ml

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,28 @@ let rec replace_deprecated_expr
9191
ident expr in
9292
{expr; emeta}
9393

94-
let replace_deprecated_lval deprecated_userdefined =
95-
map_lval_with (replace_deprecated_expr deprecated_userdefined) ident
94+
let replace_deprecated_lval deprecated_userdefined {lval; lmeta} =
95+
let is_multiindex = function
96+
| Single {emeta= {type_= Middle.UnsizedType.UInt; _}; _} -> false
97+
| _ -> true in
98+
let rec flatten_multi = function
99+
| LVariable id -> (LVariable id, None)
100+
| LIndexed ({lval; lmeta}, idcs) -> (
101+
let outer =
102+
List.map idcs
103+
~f:(map_index (replace_deprecated_expr deprecated_userdefined))
104+
in
105+
let unwrap = Option.value_map ~default:[] ~f:fst in
106+
match flatten_multi lval with
107+
| lval, inner when List.exists ~f:is_multiindex outer ->
108+
(lval, Some (unwrap inner @ outer, lmeta))
109+
| lval, None -> (LIndexed ({lval; lmeta}, outer), None)
110+
| lval, Some (inner, _) -> (lval, Some (inner @ outer, lmeta)) ) in
111+
let lval =
112+
match flatten_multi lval with
113+
| lval, None -> lval
114+
| lval, Some (idcs, lmeta) -> LIndexed ({lval; lmeta}, idcs) in
115+
{lval; lmeta}
96116

97117
let rec replace_deprecated_stmt
98118
(deprecated_userdefined : Middle.UnsizedType.t Core_kernel.String.Map.t)

src/frontend/Semantic_error.ml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,7 @@ module StatementError = struct
320320
type t =
321321
| CannotAssignToReadOnly of string
322322
| CannotAssignToGlobal of string
323+
| LValueMultiIndexing
323324
| InvalidSamplingPDForPMF
324325
| InvalidSamplingCDForCCDF of string
325326
| InvalidSamplingNoSuchDistribution of string
@@ -352,6 +353,9 @@ module StatementError = struct
352353
Fmt.pf ppf
353354
"Cannot assign to global variable '%s' declared in previous blocks."
354355
name
356+
| LValueMultiIndexing ->
357+
Fmt.pf ppf
358+
"Left hand side of an assignment cannot have nested multi-indexing."
355359
| TargetPlusEqualsOutsideModelOrLogProb ->
356360
Fmt.pf ppf
357361
"Target can only be accessed in the model block or in definitions of \
@@ -604,6 +608,9 @@ let cannot_assign_to_read_only loc name =
604608
let cannot_assign_to_global loc name =
605609
StatementError (loc, StatementError.CannotAssignToGlobal name)
606610

611+
let cannot_assign_to_multiindex loc =
612+
StatementError (loc, StatementError.LValueMultiIndexing)
613+
607614
let invalid_sampling_pdf_or_pmf loc =
608615
StatementError (loc, StatementError.InvalidSamplingPDForPMF)
609616

src/frontend/Semantic_error.mli

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ val empty_array : Location_span.t -> t
103103
val bad_int_literal : Location_span.t -> t
104104
val cannot_assign_to_read_only : Location_span.t -> string -> t
105105
val cannot_assign_to_global : Location_span.t -> string -> t
106+
val cannot_assign_to_multiindex : Location_span.t -> t
106107
val invalid_sampling_pdf_or_pmf : Location_span.t -> t
107108
val invalid_sampling_cdf_or_ccdf : Location_span.t -> string -> t
108109
val invalid_sampling_no_such_dist : Location_span.t -> string -> t

src/frontend/Typechecker.ml

Lines changed: 91 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ let check_postfixop loc op te =
207207
~type_ ~loc
208208
| _ -> Semantic_error.illtyped_postfix_op loc op te.emeta.type_ |> error
209209

210-
let check_variable cf loc tenv id =
210+
let check_id cf loc tenv id =
211211
match Env.find tenv (Utils.stdlib_distribution_name id.name) with
212212
| [] ->
213213
(* OCaml in these situations suggests similar names
@@ -217,9 +217,8 @@ let check_variable cf loc tenv id =
217217
(Env.nearest_ident tenv id.name)
218218
|> error
219219
| {kind= `StanMath; _} :: _ ->
220-
mk_typed_expression ~expr:(Variable id)
221-
~ad_level:(calculate_autodifftype cf MathLibrary UMathLibraryFunction)
222-
~type_:UMathLibraryFunction ~loc
220+
( calculate_autodifftype cf MathLibrary UMathLibraryFunction
221+
, UnsizedType.UMathLibraryFunction )
223222
| {kind= `Variable {origin= Param | TParam | GQuant; _}; _} :: _
224223
when cf.in_toplevel_decl ->
225224
Semantic_error.non_data_variable_size_decl loc |> error
@@ -230,9 +229,7 @@ let check_variable cf loc tenv id =
230229
|| cf.current_block = Model ) ->
231230
Semantic_error.invalid_unnormalized_fn loc |> error
232231
| {kind= `Variable {origin; _}; type_} :: _ ->
233-
mk_typed_expression ~expr:(Variable id)
234-
~ad_level:(calculate_autodifftype cf origin type_)
235-
~type_ ~loc
232+
(calculate_autodifftype cf origin type_, type_)
236233
(* TODO - When it's time for overloading, will this need
237234
some kind of filter/match on arg types? *)
238235
| { kind= `UserDefined | `UserDeclared _
@@ -241,13 +238,13 @@ let check_variable cf loc tenv id =
241238
let type_ =
242239
UnsizedType.UFun
243240
(args, rt, Fun_kind.suffix_from_name id.name, mem_pattern) in
244-
mk_typed_expression ~expr:(Variable id)
245-
~ad_level:(calculate_autodifftype cf Functions type_)
246-
~type_ ~loc
241+
(calculate_autodifftype cf Functions type_, type_)
247242
| {kind= `UserDefined | `UserDeclared _; type_} :: _ ->
248-
mk_typed_expression ~expr:(Variable id)
249-
~ad_level:(calculate_autodifftype cf Functions type_)
250-
~type_ ~loc
243+
(calculate_autodifftype cf Functions type_, type_)
244+
245+
let check_variable cf loc tenv id =
246+
let ad_level, type_ = check_id cf loc tenv id in
247+
mk_typed_expression ~expr:(Variable id) ~ad_level ~type_ ~loc
251248

252249
let get_consistent_types ad_level type_ es =
253250
let f state e =
@@ -298,6 +295,9 @@ let indexing_type idx =
298295
| Single {emeta= {type_= UnsizedType.UInt; _}; _} -> `Single
299296
| _ -> `Multi
300297

298+
let is_multiindex i =
299+
match indexing_type i with `Single -> false | `Multi -> true
300+
301301
let inferred_unsizedtype_of_indexed ~loc ut indices =
302302
let rec aux type_ idcs =
303303
match (type_, idcs) with
@@ -730,34 +730,75 @@ let verify_assignment_global loc cf block is_global id =
730730
if (not is_global) || block = cf.current_block then ()
731731
else Semantic_error.cannot_assign_to_global loc id.name |> error
732732

733-
let mk_assignment_from_indexed_expr assop lhs rhs =
734-
Assignment
735-
{assign_lhs= Ast.lvalue_of_expr lhs; assign_op= assop; assign_rhs= rhs}
736-
737-
let check_assignment_operator loc assop lhs rhs =
733+
let verify_assignment_operator loc assop lhs rhs =
738734
let err op =
739-
Semantic_error.illtyped_assignment loc op lhs.emeta.type_ rhs.emeta.type_
735+
Semantic_error.illtyped_assignment loc op lhs.lmeta.type_ rhs.emeta.type_
740736
in
741-
let () =
742-
match assop with
743-
| Assign | ArrowAssign ->
744-
if
745-
UnsizedType.check_of_same_type_mod_array_conv "" lhs.emeta.type_
746-
rhs.emeta.type_
747-
then ()
748-
else err Operator.Equals |> error
749-
| OperatorAssign op -> (
750-
let args = List.map ~f:arg_type [lhs; rhs] in
751-
let return_type =
752-
Stan_math_signatures.assignmentoperator_stan_math_return_type op args
753-
in
754-
match return_type with Some Void -> () | _ -> err op |> error ) in
755-
mk_typed_statement ~return_type:NoReturnType ~loc
756-
~stmt:(mk_assignment_from_indexed_expr assop lhs rhs)
737+
match assop with
738+
| Assign | ArrowAssign ->
739+
if
740+
UnsizedType.check_of_same_type_mod_array_conv "" lhs.lmeta.type_
741+
rhs.emeta.type_
742+
then ()
743+
else err Operator.Equals |> error
744+
| OperatorAssign op -> (
745+
let args = List.map ~f:arg_type [Ast.expr_of_lvalue lhs; rhs] in
746+
let return_type =
747+
Stan_math_signatures.assignmentoperator_stan_math_return_type op args
748+
in
749+
match return_type with Some Void -> () | _ -> err op |> error )
750+
751+
let check_lvalue cf tenv = function
752+
| {lval= LVariable id; lmeta= ({loc} : located_meta)} ->
753+
verify_identifier id ;
754+
let ad_level, type_ = check_id cf loc tenv id in
755+
{lval= LVariable id; lmeta= {ad_level; type_; loc}}
756+
| {lval= LIndexed (lval, idcs); lmeta= {loc}} ->
757+
let rec check_inner = function
758+
| {lval= LVariable id; lmeta= ({loc} : located_meta)} ->
759+
verify_identifier id ;
760+
let ad_level, type_ = check_id cf loc tenv id in
761+
let var = {lval= LVariable id; lmeta= {ad_level; type_; loc}} in
762+
(var, var, [])
763+
| {lval= LIndexed (lval, idcs); lmeta= {loc}} ->
764+
let lval, var, flat = check_inner lval in
765+
let idcs = List.map ~f:(check_index cf tenv) idcs in
766+
let ad_level =
767+
inferred_ad_type_of_indexed lval.lmeta.ad_level idcs in
768+
let type_ =
769+
inferred_unsizedtype_of_indexed ~loc lval.lmeta.type_ idcs in
770+
( {lval= LIndexed (lval, idcs); lmeta= {ad_level; type_; loc}}
771+
, var
772+
, flat @ idcs ) in
773+
let lval, var, flat = check_inner lval in
774+
let idcs = List.map ~f:(check_index cf tenv) idcs in
775+
let ad_level = inferred_ad_type_of_indexed lval.lmeta.ad_level idcs in
776+
let type_ = inferred_unsizedtype_of_indexed ~loc lval.lmeta.type_ idcs in
777+
if List.exists ~f:is_multiindex flat then (
778+
add_warning loc
779+
"Nested multi-indexing on the left hand side of assignment does not \
780+
behave the same as nested indexing in expressions. This is \
781+
considered a bug and will be disallowed in Stan 2.32.0. The \
782+
indexing can be automatically fixed using the canonicalize flag for \
783+
stanc." ;
784+
let lvalue_rvalue_types_differ =
785+
try
786+
let flat_type =
787+
inferred_unsizedtype_of_indexed ~loc var.lmeta.type_ (flat @ idcs)
788+
in
789+
let rec can_assign = function
790+
| UnsizedType.(UArray t1, UArray t2) -> can_assign (t1, t2)
791+
| UVector, URowVector | URowVector, UVector -> false
792+
| t1, t2 -> UnsizedType.compare t1 t2 <> 0 in
793+
can_assign (flat_type, type_)
794+
with Errors.SemanticError _ -> true in
795+
if lvalue_rvalue_types_differ then
796+
Semantic_error.cannot_assign_to_multiindex loc |> error ) ;
797+
{lval= LIndexed (lval, idcs); lmeta= {ad_level; type_; loc}}
757798

758799
let check_assignment loc cf tenv assign_lhs assign_op assign_rhs =
759800
let assign_id = Ast.id_of_lvalue assign_lhs in
760-
let lhs = assign_lhs |> expr_of_lvalue |> check_expression cf tenv in
801+
let lhs = check_lvalue cf tenv assign_lhs in
761802
let rhs = check_expression cf tenv assign_rhs in
762803
let block, global, readonly =
763804
let var = Env.find tenv assign_id.name in
@@ -772,7 +813,9 @@ let check_assignment loc cf tenv assign_lhs assign_op assign_rhs =
772813
|> error in
773814
verify_assignment_global loc cf block global assign_id ;
774815
verify_assignment_read_only loc readonly assign_id ;
775-
check_assignment_operator loc assign_op lhs rhs
816+
verify_assignment_operator loc assign_op lhs rhs ;
817+
mk_typed_statement ~return_type:NoReturnType ~loc
818+
~stmt:(Assignment {assign_lhs= lhs; assign_op; assign_rhs= rhs})
776819

777820
(* target plus-equals / increment log-prob *)
778821

@@ -1167,20 +1210,17 @@ and check_sizedtype cf tenv sizedty =
11671210

11681211
and check_var_decl_initial_value loc cf tenv id init_val_opt =
11691212
match init_val_opt with
1170-
| Some e -> (
1171-
let stmt =
1172-
Assignment
1173-
{ assign_lhs= {lval= LVariable id; lmeta= {loc}}
1174-
; assign_op= Assign
1175-
; assign_rhs= e } in
1176-
mk_untyped_statement ~loc ~stmt
1177-
|> check_statement cf tenv |> snd
1178-
|> fun ts ->
1179-
match (ts.stmt, ts.smeta.return_type) with
1180-
| Assignment {assign_rhs= ue; _}, NoReturnType -> Some ue
1181-
| _ ->
1182-
Common.FatalError.fatal_error_msg
1183-
[%message " check_var_decl: `Assignment` expected."] )
1213+
| Some e ->
1214+
let lhs = check_lvalue cf tenv {lval= LVariable id; lmeta= {loc}} in
1215+
let rhs = check_expression cf tenv e in
1216+
if
1217+
UnsizedType.check_of_same_type_mod_array_conv "" lhs.lmeta.type_
1218+
rhs.emeta.type_
1219+
then Some rhs
1220+
else
1221+
Semantic_error.illtyped_assignment loc Equals lhs.lmeta.type_
1222+
rhs.emeta.type_
1223+
|> error
11841224
| None -> None
11851225

11861226
and check_transformation cf tenv ut trans =
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
data {
2+
array[5] vector[5] x;
3+
}
4+
transformed data {
5+
array[5] vector[5] y;
6+
y[:][1] = x[:][1];
7+
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
data {
2+
vector[5] x;
3+
}
4+
transformed data {
5+
vector[5] y;
6+
y[:][:] = x[:][:];
7+
}

test/integration/bad/stanc.expected

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1358,6 +1358,28 @@ Semantic error in 'lp-error.stan', line 5, column 2 to column 6:
13581358
-------------------------------------------------
13591359

13601360
Identifier 'lp__' clashes with reserved keyword.
1361+
$ ../../../../install/default/bin/stanc lvalue_indexes1.stan
1362+
Semantic error in 'lvalue_indexes1.stan', line 6, column 4 to column 11:
1363+
-------------------------------------------------
1364+
4: transformed data {
1365+
5: array[5] vector[5] y;
1366+
6: y[:][1] = x[:][1];
1367+
^
1368+
7: }
1369+
-------------------------------------------------
1370+
1371+
Left hand side of an assignment cannot have nested multi-indexing.
1372+
$ ../../../../install/default/bin/stanc lvalue_indexes2.stan
1373+
Semantic error in 'lvalue_indexes2.stan', line 6, column 4 to column 11:
1374+
-------------------------------------------------
1375+
4: transformed data {
1376+
5: vector[5] y;
1377+
6: y[:][:] = x[:][:];
1378+
^
1379+
7: }
1380+
-------------------------------------------------
1381+
1382+
Left hand side of an assignment cannot have nested multi-indexing.
13611383
$ ../../../../install/default/bin/stanc matrix_expr_bad1.stan
13621384
Semantic error in 'matrix_expr_bad1.stan', line 2, column 2 to column 34:
13631385
-------------------------------------------------

test/integration/cli-args/canonicalize/canonical.expected

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,9 @@ transformed data {
183183
array[0] int x_i;
184184
array[0] real x_r;
185185
matrix[N, N] K = gp_exp_quad_cov(x_quad, 1.0, 1.0);
186+
array[5, 5] real idxs;
187+
idxs[1][ : ] = idxs[1][ : ];
188+
idxs[ : , 1] = idxs[ : ][2];
186189
}
187190
parameters {
188191
real x;
@@ -232,6 +235,11 @@ generated quantities {
232235
x_r, x_i);
233236
}
234237

238+
Warning in 'deprecated.stan', line 36, column 2: Nested multi-indexing on the
239+
left hand side of assignment does not behave the same as nested indexing
240+
in expressions. This is considered a bug and will be disallowed in Stan
241+
2.32.0. The indexing can be automatically fixed using the canonicalize
242+
flag for stanc.
235243
$ ../../../../../install/default/bin/stanc --print-canonical funs.stanfunctions
236244
// comment test comment
237245
void test(int x, int y) {

test/integration/cli-args/canonicalize/deprecated.stan

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,10 @@ transformed data {
3131
array[0] int x_i;
3232
array[0] real x_r;
3333
matrix[N, N] K = cov_exp_quad(x_quad, 1.0, 1.0);
34-
}
34+
real idxs[5,5];
35+
idxs[1][:] = idxs[1][:];
36+
idxs[:][1] = idxs[:][2];
37+
}
3538
parameters {
3639
real x;
3740
array[3] real theta;

test/integration/cli-args/canonicalize/deprecations-only.expected

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,9 @@ transformed data {
163163
array[0] int x_i;
164164
array[0] real x_r;
165165
matrix[N, N] K = gp_exp_quad_cov(x_quad, 1.0, 1.0);
166+
array[5, 5] real idxs;
167+
idxs[1][ : ] = idxs[1][ : ];
168+
idxs[ : , 1] = idxs[ : ][2];
166169
}
167170
parameters {
168171
real x;
@@ -212,6 +215,11 @@ generated quantities {
212215
x_r, x_i);
213216
}
214217

218+
Warning in 'deprecated.stan', line 36, column 2: Nested multi-indexing on the
219+
left hand side of assignment does not behave the same as nested indexing
220+
in expressions. This is considered a bug and will be disallowed in Stan
221+
2.32.0. The indexing can be automatically fixed using the canonicalize
222+
flag for stanc.
215223
$ ../../../../../install/default/bin/stanc --auto-format --canonicalize deprecations funs.stanfunctions
216224
// comment test comment
217225
void test(int x, int y) {

0 commit comments

Comments
 (0)