@@ -207,7 +207,7 @@ let check_postfixop loc op te =
207207 ~type_ ~loc
208208 | _ -> Semantic_error. illtyped_postfix_op loc op te.emeta.type_ |> error
209209
210- let check_variable cf loc tenv id =
210+ let check_id cf loc tenv id =
211211 match Env. find tenv (Utils. stdlib_distribution_name id.name) with
212212 | [] ->
213213 (* OCaml in these situations suggests similar names
@@ -217,9 +217,8 @@ let check_variable cf loc tenv id =
217217 (Env. nearest_ident tenv id.name)
218218 |> error
219219 | {kind = `StanMath ; _} :: _ ->
220- mk_typed_expression ~expr: (Variable id)
221- ~ad_level: (calculate_autodifftype cf MathLibrary UMathLibraryFunction )
222- ~type_: UMathLibraryFunction ~loc
220+ ( calculate_autodifftype cf MathLibrary UMathLibraryFunction
221+ , UnsizedType. UMathLibraryFunction )
223222 | {kind= `Variable {origin= Param | TParam | GQuant ; _}; _} :: _
224223 when cf.in_toplevel_decl ->
225224 Semantic_error. non_data_variable_size_decl loc |> error
@@ -230,9 +229,7 @@ let check_variable cf loc tenv id =
230229 || cf.current_block = Model ) ->
231230 Semantic_error. invalid_unnormalized_fn loc |> error
232231 | {kind = `Variable {origin; _} ; type_} :: _ ->
233- mk_typed_expression ~expr: (Variable id)
234- ~ad_level: (calculate_autodifftype cf origin type_)
235- ~type_ ~loc
232+ (calculate_autodifftype cf origin type_, type_)
236233 (* TODO - When it's time for overloading, will this need
237234 some kind of filter/match on arg types? *)
238235 | { kind= `UserDefined | `UserDeclared _
@@ -241,13 +238,13 @@ let check_variable cf loc tenv id =
241238 let type_ =
242239 UnsizedType. UFun
243240 (args, rt, Fun_kind. suffix_from_name id.name, mem_pattern) in
244- mk_typed_expression ~expr: (Variable id)
245- ~ad_level: (calculate_autodifftype cf Functions type_)
246- ~type_ ~loc
241+ (calculate_autodifftype cf Functions type_, type_)
247242 | {kind = `UserDefined | `UserDeclared _ ; type_} :: _ ->
248- mk_typed_expression ~expr: (Variable id)
249- ~ad_level: (calculate_autodifftype cf Functions type_)
250- ~type_ ~loc
243+ (calculate_autodifftype cf Functions type_, type_)
244+
245+ let check_variable cf loc tenv id =
246+ let ad_level, type_ = check_id cf loc tenv id in
247+ mk_typed_expression ~expr: (Variable id) ~ad_level ~type_ ~loc
251248
252249let get_consistent_types ad_level type_ es =
253250 let f state e =
@@ -298,6 +295,9 @@ let indexing_type idx =
298295 | Single {emeta = {type_ = UnsizedType. UInt ; _} ; _} -> `Single
299296 | _ -> `Multi
300297
298+ let is_multiindex i =
299+ match indexing_type i with `Single -> false | `Multi -> true
300+
301301let inferred_unsizedtype_of_indexed ~loc ut indices =
302302 let rec aux type_ idcs =
303303 match (type_, idcs) with
@@ -730,34 +730,75 @@ let verify_assignment_global loc cf block is_global id =
730730 if (not is_global) || block = cf.current_block then ()
731731 else Semantic_error. cannot_assign_to_global loc id.name |> error
732732
733- let mk_assignment_from_indexed_expr assop lhs rhs =
734- Assignment
735- {assign_lhs= Ast. lvalue_of_expr lhs; assign_op= assop; assign_rhs= rhs}
736-
737- let check_assignment_operator loc assop lhs rhs =
733+ let verify_assignment_operator loc assop lhs rhs =
738734 let err op =
739- Semantic_error. illtyped_assignment loc op lhs.emeta .type_ rhs.emeta.type_
735+ Semantic_error. illtyped_assignment loc op lhs.lmeta .type_ rhs.emeta.type_
740736 in
741- let () =
742- match assop with
743- | Assign | ArrowAssign ->
744- if
745- UnsizedType. check_of_same_type_mod_array_conv " " lhs.emeta.type_
746- rhs.emeta.type_
747- then ()
748- else err Operator. Equals |> error
749- | OperatorAssign op -> (
750- let args = List. map ~f: arg_type [lhs; rhs] in
751- let return_type =
752- Stan_math_signatures. assignmentoperator_stan_math_return_type op args
753- in
754- match return_type with Some Void -> () | _ -> err op |> error ) in
755- mk_typed_statement ~return_type: NoReturnType ~loc
756- ~stmt: (mk_assignment_from_indexed_expr assop lhs rhs)
737+ match assop with
738+ | Assign | ArrowAssign ->
739+ if
740+ UnsizedType. check_of_same_type_mod_array_conv " " lhs.lmeta.type_
741+ rhs.emeta.type_
742+ then ()
743+ else err Operator. Equals |> error
744+ | OperatorAssign op -> (
745+ let args = List. map ~f: arg_type [Ast. expr_of_lvalue lhs; rhs] in
746+ let return_type =
747+ Stan_math_signatures. assignmentoperator_stan_math_return_type op args
748+ in
749+ match return_type with Some Void -> () | _ -> err op |> error )
750+
751+ let check_lvalue cf tenv = function
752+ | {lval = LVariable id ; lmeta = ({loc} : located_meta )} ->
753+ verify_identifier id ;
754+ let ad_level, type_ = check_id cf loc tenv id in
755+ {lval= LVariable id; lmeta= {ad_level; type_; loc}}
756+ | {lval = LIndexed (lval , idcs ); lmeta = {loc} } ->
757+ let rec check_inner = function
758+ | {lval = LVariable id ; lmeta = ({loc} : located_meta )} ->
759+ verify_identifier id ;
760+ let ad_level, type_ = check_id cf loc tenv id in
761+ let var = {lval= LVariable id; lmeta= {ad_level; type_; loc}} in
762+ (var, var, [] )
763+ | {lval = LIndexed (lval , idcs ); lmeta = {loc} } ->
764+ let lval, var, flat = check_inner lval in
765+ let idcs = List. map ~f: (check_index cf tenv) idcs in
766+ let ad_level =
767+ inferred_ad_type_of_indexed lval.lmeta.ad_level idcs in
768+ let type_ =
769+ inferred_unsizedtype_of_indexed ~loc lval.lmeta.type_ idcs in
770+ ( {lval= LIndexed (lval, idcs); lmeta= {ad_level; type_; loc}}
771+ , var
772+ , flat @ idcs ) in
773+ let lval, var, flat = check_inner lval in
774+ let idcs = List. map ~f: (check_index cf tenv) idcs in
775+ let ad_level = inferred_ad_type_of_indexed lval.lmeta.ad_level idcs in
776+ let type_ = inferred_unsizedtype_of_indexed ~loc lval.lmeta.type_ idcs in
777+ if List. exists ~f: is_multiindex flat then (
778+ add_warning loc
779+ " Nested multi-indexing on the left hand side of assignment does not \
780+ behave the same as nested indexing in expressions. This is \
781+ considered a bug and will be disallowed in Stan 2.32.0. The \
782+ indexing can be automatically fixed using the canonicalize flag for \
783+ stanc." ;
784+ let lvalue_rvalue_types_differ =
785+ try
786+ let flat_type =
787+ inferred_unsizedtype_of_indexed ~loc var.lmeta.type_ (flat @ idcs)
788+ in
789+ let rec can_assign = function
790+ | UnsizedType. (UArray t1 , UArray t2 ) -> can_assign (t1, t2)
791+ | UVector , URowVector | URowVector , UVector -> false
792+ | t1 , t2 -> UnsizedType. compare t1 t2 <> 0 in
793+ can_assign (flat_type, type_)
794+ with Errors. SemanticError _ -> true in
795+ if lvalue_rvalue_types_differ then
796+ Semantic_error. cannot_assign_to_multiindex loc |> error ) ;
797+ {lval= LIndexed (lval, idcs); lmeta= {ad_level; type_; loc}}
757798
758799let check_assignment loc cf tenv assign_lhs assign_op assign_rhs =
759800 let assign_id = Ast. id_of_lvalue assign_lhs in
760- let lhs = assign_lhs |> expr_of_lvalue |> check_expression cf tenv in
801+ let lhs = check_lvalue cf tenv assign_lhs in
761802 let rhs = check_expression cf tenv assign_rhs in
762803 let block, global, readonly =
763804 let var = Env. find tenv assign_id.name in
@@ -772,7 +813,9 @@ let check_assignment loc cf tenv assign_lhs assign_op assign_rhs =
772813 |> error in
773814 verify_assignment_global loc cf block global assign_id ;
774815 verify_assignment_read_only loc readonly assign_id ;
775- check_assignment_operator loc assign_op lhs rhs
816+ verify_assignment_operator loc assign_op lhs rhs ;
817+ mk_typed_statement ~return_type: NoReturnType ~loc
818+ ~stmt: (Assignment {assign_lhs= lhs; assign_op; assign_rhs= rhs})
776819
777820(* target plus-equals / increment log-prob *)
778821
@@ -1167,20 +1210,17 @@ and check_sizedtype cf tenv sizedty =
11671210
11681211and check_var_decl_initial_value loc cf tenv id init_val_opt =
11691212 match init_val_opt with
1170- | Some e -> (
1171- let stmt =
1172- Assignment
1173- { assign_lhs= {lval= LVariable id; lmeta= {loc}}
1174- ; assign_op= Assign
1175- ; assign_rhs= e } in
1176- mk_untyped_statement ~loc ~stmt
1177- |> check_statement cf tenv |> snd
1178- |> fun ts ->
1179- match (ts.stmt, ts.smeta.return_type) with
1180- | Assignment {assign_rhs = ue ; _} , NoReturnType -> Some ue
1181- | _ ->
1182- Common.FatalError. fatal_error_msg
1183- [% message " check_var_decl: `Assignment` expected." ] )
1213+ | Some e ->
1214+ let lhs = check_lvalue cf tenv {lval= LVariable id; lmeta= {loc}} in
1215+ let rhs = check_expression cf tenv e in
1216+ if
1217+ UnsizedType. check_of_same_type_mod_array_conv " " lhs.lmeta.type_
1218+ rhs.emeta.type_
1219+ then Some rhs
1220+ else
1221+ Semantic_error. illtyped_assignment loc Equals lhs.lmeta.type_
1222+ rhs.emeta.type_
1223+ |> error
11841224 | None -> None
11851225
11861226and check_transformation cf tenv ut trans =
0 commit comments