Skip to content

Commit 6826a26

Browse files
committed
Add support for struct and field projection
1 parent b39bbd8 commit 6826a26

File tree

3 files changed

+258
-8
lines changed

3 files changed

+258
-8
lines changed

crates/formality-check/src/mini_rust_check.rs

Lines changed: 113 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@ use std::iter::zip;
33
use formality_core::{Fallible, Map, Upcast};
44
use formality_prove::Env;
55
use formality_rust::grammar::minirust::ArgumentExpression::{ByValue, InPlace};
6-
use formality_rust::grammar::minirust::PlaceExpression::Local;
7-
use formality_rust::grammar::minirust::ValueExpression::{Constant, Fn, Load};
6+
use formality_rust::grammar::minirust::PlaceExpression::*;
7+
use formality_rust::grammar::minirust::ValueExpression::{Constant, Fn, Load, Struct};
88
use formality_rust::grammar::minirust::{
99
self, ty_is_int, ArgumentExpression, BasicBlock, BbId, LocalId, PlaceExpression,
1010
ValueExpression,
1111
};
12-
use formality_rust::grammar::FnBoundData;
13-
use formality_types::grammar::{CrateId, FnId};
14-
use formality_types::grammar::{Relation, Ty, Wcs};
12+
use formality_rust::grammar::{FnBoundData, StructBoundData};
13+
use formality_types::grammar::{AdtId, Relation, Ty, TyData, Wcs};
14+
use formality_types::grammar::{CrateId, FnId, RigidName};
1515

1616
use crate::{Check, CrateItem};
1717
use anyhow::bail;
@@ -95,6 +95,7 @@ impl Check<'_> {
9595
callee_input_tys: Map::new(),
9696
crate_id: crate_id.clone(),
9797
fn_args: body.args.clone(),
98+
adt_tys: Map::new(),
9899
};
99100

100101
// (4) Check statements in body are valid
@@ -276,6 +277,38 @@ impl Check<'_> {
276277
};
277278
place_ty = ty;
278279
}
280+
Field(field_projection) => {
281+
let Local(ref local_id) = *field_projection.root else {
282+
bail!("Only Local is allowed as the root of FieldProjection")
283+
};
284+
285+
// Check if the index is valid for the tuple.
286+
// FIXME: use let chain here?
287+
288+
let Some(ty) = env.local_variables.get(&local_id) else {
289+
bail!("The local id used in PlaceExpression::Field is invalid.")
290+
};
291+
292+
// Get the ADT type information.
293+
let TyData::RigidTy(rigid_ty) = ty.data() else {
294+
bail!("The type for field projection must be rigid ty")
295+
};
296+
297+
// FIXME: it'd be nice to have the field information in ty data
298+
let RigidName::AdtId(ref adt_id) = rigid_ty.name else {
299+
bail!("The type for field projection must be adt")
300+
};
301+
302+
let Some(tys) = env.adt_tys.get(&adt_id) else {
303+
bail!("The ADT used is invalid.")
304+
};
305+
306+
if field_projection.index >= tys.len() {
307+
bail!("The field index used in PlaceExpression::Field is invalid.")
308+
}
309+
310+
place_ty = tys[field_projection.index].clone();
311+
}
279312
}
280313
Ok(place_ty.clone())
281314
}
@@ -317,6 +350,7 @@ impl Check<'_> {
317350
if fn_declared.id == *fn_id {
318351
let fn_bound_data =
319352
typeck_env.env.instantiate_universally(&fn_declared.binder);
353+
// FIXME: maybe we should store the information somewhere else, like in the value expression?
320354
// Store the callee information in typeck_env, we will need this when type checking Terminator::Call.
321355
typeck_env
322356
.callee_input_tys
@@ -341,6 +375,77 @@ impl Check<'_> {
341375
// it will be rejected by the parser.
342376
Ok(constant.get_ty())
343377
}
378+
Struct(value_expressions, ty) => {
379+
let mut struct_field_ty = Vec::new();
380+
381+
// Check if the adt type is valid in current crate.
382+
if let TyData::RigidTy(rigid_ty) = ty.data() {
383+
if let RigidName::AdtId(adt_id) = &rigid_ty.name {
384+
// Find the crate that is currently being typeck.
385+
let curr_crate = self
386+
.program
387+
.crates
388+
.iter()
389+
.find(|c| c.id == typeck_env.crate_id)
390+
.unwrap();
391+
392+
// Find the struct from current crate.
393+
let target_struct = curr_crate.items.iter().find(|item| {
394+
match item {
395+
CrateItem::Struct(struct_item) => {
396+
if struct_item.id == *adt_id {
397+
// Get the ty data of the field.
398+
let (
399+
_,
400+
StructBoundData {
401+
where_clauses: _,
402+
fields,
403+
},
404+
) = struct_item.binder.open();
405+
for field in fields {
406+
struct_field_ty.push(field.ty);
407+
}
408+
return true;
409+
}
410+
false
411+
}
412+
_ => false,
413+
}
414+
});
415+
416+
if target_struct.is_none() {
417+
bail!("The type used in Tuple is not declared in current crate")
418+
}
419+
// We will need the adt type information when type checking field projection.
420+
typeck_env
421+
.adt_tys
422+
.insert(adt_id.clone(), struct_field_ty.clone());
423+
}
424+
}
425+
426+
// Make sure the length of value expression matches the length of field of adt.
427+
if value_expressions.len() != struct_field_ty.len() {
428+
bail!("The length of ValueExpression::Tuple does not match the type of the ADT declared")
429+
}
430+
431+
let expression_ty_pair = zip(value_expressions, struct_field_ty);
432+
433+
// FIXME: we only support const in value expression of tuple for now, we can add support
434+
// more in future.
435+
436+
for (value_expression, declared_ty) in expression_ty_pair {
437+
let Constant(_) = value_expression else {
438+
bail!("Only Constant is supported in ValueExpression::Tuple for now.")
439+
};
440+
let ty = self.check_value(typeck_env, value_expression)?;
441+
442+
// Make sure the type matches the declared adt.
443+
if ty != declared_ty {
444+
bail!("The type in ValueExpression::Tuple does not match the ADT declared")
445+
}
446+
}
447+
Ok(ty.clone())
448+
}
344449
}
345450
}
346451

@@ -402,4 +507,7 @@ struct TypeckEnv {
402507

403508
/// LocalId of function argument.
404509
fn_args: Vec<LocalId>,
510+
511+
/// Type information of adt
512+
adt_tys: Map<AdtId, Vec<Ty>>,
405513
}

crates/formality-rust/src/grammar/minirust.rs

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,13 @@ use formality_types::grammar::{Parameter, RigidName, ScalarId, Ty, TyData};
55
use crate::grammar::minirust::ConstTypePair::*;
66
use crate::grammar::FnId;
77

8+
use std::sync::Arc;
9+
810
// This definition is based on [MiniRust](https://github.com/minirust/minirust/blob/master/spec/lang/syntax.md).
911

1012
id!(BbId);
1113
id!(LocalId);
14+
id!(FieldId);
1215

1316
// Example:
1417
//
@@ -138,8 +141,11 @@ pub enum ValueExpression {
138141
Constant(ConstTypePair),
139142
#[grammar(fn_id $v0)]
140143
Fn(FnId),
141-
// #[grammar($(v0) as $v1)]
142-
// Tuple(Vec<ValueExpression>, Ty),
144+
// FIXME: minirust uses typle to represent arrays, structs, tuples (including unit).
145+
// But I think it will be quite annoying to do typecking when we have all these types
146+
// together, so I added a variant just for struct.
147+
#[grammar(struct ${v0} as $v1)] // FIXME: use comma separated
148+
Struct(Vec<ValueExpression>, Ty),
143149
// Union
144150
// Variant
145151
// GetDiscriminant
@@ -229,7 +235,17 @@ pub enum PlaceExpression {
229235
#[grammar(local($v0))]
230236
Local(LocalId),
231237
// Deref(Arc<ValueExpression>),
232-
// Field(Arc<PlaceExpression>, FieldId),
238+
// Project to a field.
239+
#[grammar($v0)]
240+
Field(FieldProjection),
233241
// Index
234242
// Downcast
235243
}
244+
245+
#[term($root.$index)]
246+
pub struct FieldProjection {
247+
/// The place to base the projection on.
248+
pub root: Arc<PlaceExpression>,
249+
/// The field to project to.
250+
pub index: usize,
251+
}

src/test/mir_fn_bodies.rs

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,37 @@ fn test_storage_live_dead() {
254254
)
255255
}
256256

257+
/// Test valid program that uses struct.
258+
#[test]
259+
fn test_struct() {
260+
crate::assert_ok!(
261+
[
262+
crate Foo {
263+
struct Dummy {
264+
value: u32,
265+
}
266+
267+
fn foo (u32) -> u32 = minirust(v1) -> v0 {
268+
let v0: u32;
269+
let v1: u32;
270+
let v2: Dummy;
271+
272+
bb0: {
273+
statements {
274+
local(v0) = load(local(v1));
275+
local(v2) = struct { constant(1: u32) } as Dummy;
276+
local(v2).0 = constant(2: u32);
277+
}
278+
return;
279+
}
280+
281+
};
282+
}
283+
]
284+
expect_test::expect![["()"]]
285+
)
286+
}
287+
257288
// Test what will happen if the next block does not exist for Terminator::Call.
258289
#[test]
259290
fn test_no_next_bb_for_call_terminator() {
@@ -737,3 +768,98 @@ fn test_fn_arg_storage_dead() {
737768
expect_test::expect!["Statement::StorageDead: trying to mark function arguments or return local as dead"]
738769
)
739770
}
771+
772+
/// Test the behaviour of using invalid index for the struct field.
773+
#[test]
774+
fn test_invalid_struct_field() {
775+
crate::assert_err!(
776+
[
777+
crate Foo {
778+
struct Dummy {
779+
value: u32,
780+
}
781+
782+
fn foo (u32) -> u32 = minirust(v1) -> v0 {
783+
let v0: u32;
784+
let v1: u32;
785+
let v2: Dummy;
786+
787+
bb0: {
788+
statements {
789+
local(v0) = load(local(v1));
790+
local(v2) = struct { constant(1: u32) } as Dummy;
791+
local(v2).1 = constant(2: u32);
792+
}
793+
return;
794+
}
795+
796+
};
797+
}
798+
]
799+
[]
800+
expect_test::expect!["The field index used in PlaceExpression::Field is invalid."]
801+
)
802+
}
803+
804+
/// Test the behaviour of using non-adt local for field projection.
805+
#[test]
806+
fn test_field_projection_root_non_adt() {
807+
crate::assert_err!(
808+
[
809+
crate Foo {
810+
struct Dummy {
811+
value: u32,
812+
}
813+
814+
fn foo (u32) -> u32 = minirust(v1) -> v0 {
815+
let v0: u32;
816+
let v1: u32;
817+
let v2: Dummy;
818+
819+
bb0: {
820+
statements {
821+
local(v0) = load(local(v1));
822+
local(v2) = struct { constant(1: u32) } as Dummy;
823+
local(v1).1 = constant(2: u32);
824+
}
825+
return;
826+
}
827+
828+
};
829+
}
830+
]
831+
[]
832+
expect_test::expect!["The type for field projection must be adt"]
833+
)
834+
}
835+
836+
/// Test the behaviour of initialising the struct with wrong type.
837+
#[test]
838+
fn test_struct_wrong_type_in_initialisation() {
839+
crate::assert_err!(
840+
[
841+
crate Foo {
842+
struct Dummy {
843+
value: u32,
844+
}
845+
846+
fn foo (u32) -> u32 = minirust(v1) -> v0 {
847+
let v0: u32;
848+
let v1: u32;
849+
let v2: Dummy;
850+
851+
bb0: {
852+
statements {
853+
local(v0) = load(local(v1));
854+
local(v2) = struct { constant(false) } as Dummy;
855+
}
856+
return;
857+
}
858+
859+
};
860+
}
861+
]
862+
[]
863+
expect_test::expect!["The type in ValueExpression::Tuple does not match the ADT declared"]
864+
)
865+
}

0 commit comments

Comments
 (0)