Skip to content

Commit e429d24

Browse files
committed
Get the adt id from the type directly
1 parent a35f3c6 commit e429d24

File tree

2 files changed

+28
-33
lines changed

2 files changed

+28
-33
lines changed

crates/formality-check/src/mini_rust_check.rs

Lines changed: 19 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ use formality_rust::grammar::minirust::{
1010
ValueExpression,
1111
};
1212
use formality_rust::grammar::FnBoundData;
13-
use formality_types::grammar::{CrateId, FnId, RigidName};
14-
use formality_types::grammar::{Relation, Ty, TyData, VariantId, Wcs};
13+
use formality_types::grammar::{CrateId, FnId};
14+
use formality_types::grammar::{Relation, Ty, VariantId, Wcs};
1515

1616
use crate::{Check, CrateItem};
1717
use anyhow::bail;
@@ -281,23 +281,11 @@ impl Check<'_> {
281281
bail!("Only Local is allowed as the root of FieldProjection")
282282
};
283283

284-
// Check if the index is valid for the tuple.
285-
// FIXME: use let chain here?
286-
287284
let Some(ty) = env.local_variables.get(&local_id) else {
288285
bail!("The local id used in PlaceExpression::Field is invalid.")
289286
};
290287

291-
// Get the ADT type information.
292-
let TyData::RigidTy(rigid_ty) = ty.data() else {
293-
bail!("The type for field projection must be rigid ty")
294-
};
295-
296-
// FIXME: directly get the adt_id information from ty
297-
298-
let RigidName::AdtId(ref adt_id) = rigid_ty.name else {
299-
bail!("The type for field projection must be adt")
300-
};
288+
let adt_id = ty.get_adt_id().unwrap();
301289

302290
let (
303291
_,
@@ -312,6 +300,7 @@ impl Check<'_> {
312300
bail!("The local used for field projection must be struct.")
313301
}
314302

303+
// Check if the index is valid for the tuple.
315304
if field_projection.index >= fields.len() {
316305
bail!("The field index used in PlaceExpression::Field is invalid.")
317306
}
@@ -353,13 +342,13 @@ impl Check<'_> {
353342
.unwrap();
354343

355344
// Find the callee from current crate.
345+
// FIXME: get the information from decl too
356346
let callee = curr_crate.items.iter().find(|item| {
357347
match item {
358348
CrateItem::Fn(fn_declared) => {
359349
if fn_declared.id == *fn_id {
360350
let fn_bound_data =
361351
typeck_env.env.instantiate_universally(&fn_declared.binder);
362-
// FIXME: maybe we should store the information somewhere else, like in the value expression?
363352
// Store the callee information in typeck_env, we will need this when type checking Terminator::Call.
364353
typeck_env
365354
.callee_input_tys
@@ -385,28 +374,25 @@ impl Check<'_> {
385374
Ok(constant.get_ty())
386375
}
387376
Struct(value_expressions, ty) => {
388-
let mut struct_field_ty = Vec::new();
377+
let adt_id = ty.get_adt_id().unwrap();
389378

390379
// Check the validity of the struct.
391-
if let TyData::RigidTy(rigid_ty) = ty.data() {
392-
if let RigidName::AdtId(adt_id) = &rigid_ty.name {
393-
let (
394-
_,
395-
AdtDeclBoundData {
396-
where_clause: _,
397-
variants,
398-
},
399-
) = self.decls.adt_decl(&adt_id).binder.open();
400-
let AdtDeclVariant { name, fields } = variants.last().unwrap();
401-
402-
if *name != VariantId::for_struct() {
403-
bail!("This type used in ValueExpression::Struct should be a struct")
404-
}
380+
let (
381+
_,
382+
AdtDeclBoundData {
383+
where_clause: _,
384+
variants,
385+
},
386+
) = self.decls.adt_decl(&adt_id).binder.open();
387+
let AdtDeclVariant { name, fields } = variants.last().unwrap();
405388

406-
struct_field_ty = fields.iter().map(|field| field.ty.clone()).collect();
407-
}
389+
if *name != VariantId::for_struct() {
390+
bail!("This type used in ValueExpression::Struct should be a struct")
408391
}
409392

393+
let struct_field_ty: Vec<Ty> =
394+
fields.iter().map(|field| field.ty.clone()).collect();
395+
410396
if value_expressions.len() != struct_field_ty.len() {
411397
bail!("The length of ValueExpression::Tuple does not match the type of the ADT declared")
412398
}

crates/formality-types/src/grammar/ty.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,15 @@ impl Ty {
8383
}
8484
.upcast()
8585
}
86+
87+
pub fn get_adt_id(&self) -> Option<AdtId> {
88+
if let TyData::RigidTy(rigid_ty) = self.data() {
89+
if let RigidName::AdtId(ref adt_id) = rigid_ty.name {
90+
return Some(adt_id.clone());
91+
};
92+
};
93+
None
94+
}
8695
}
8796

8897
impl UpcastFrom<TyData> for Ty {

0 commit comments

Comments
 (0)