Skip to content

Commit 2dbceb6

Browse files
committed
[ast][inference] Merge of AST return ModuleAST
It can be asserted that `ast.merge()` will always be of length 1. We don't need to loop over the modules. This removes one loop nest in the inference code. Signed-off-by: Manas <manas18244@iiitd.ac.in>
1 parent 41b6f35 commit 2dbceb6

File tree

2 files changed

+144
-146
lines changed

2 files changed

+144
-146
lines changed

src/ast.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ impl Qast {
100100

101101
/// Merge all modules in AST to one monolith module. Ensure mangling happens
102102
/// and function calls are referenced to their definitions.
103-
pub(crate) fn merge(&mut self) -> Qast {
103+
pub(crate) fn merge(&mut self) -> ModuleAST {
104104
let flattened_ast = self
105105
.modules
106106
.iter()
@@ -133,7 +133,7 @@ impl Qast {
133133
functions: flattened_ast,
134134
};
135135

136-
Qast::new(vec![std::rc::Rc::new(module.into())])
136+
module
137137
}
138138
}
139139

src/inference.rs

Lines changed: 142 additions & 144 deletions
Original file line numberDiff line numberDiff line change
@@ -181,172 +181,170 @@ pub fn infer(ast: &mut Qast) -> Result<()> {
181181

182182
// Merge all modules in one giant monolith module. Easier to do DCE and type
183183
// inference.
184-
let mut ast = ast.merge();
184+
let mut module = ast.merge();
185+
186+
let module_name = module.get_name();
187+
// functions but only collect their names and return types.
188+
for function in &module {
189+
function_table.push(VarAST::new_with_type(
190+
function.get_name().clone(),
191+
function.get_loc().clone(),
192+
function.get_output_type().clone(),
193+
));
194+
// A copy of function prepended with its module name is also added.
195+
// If the function is used inside the module, then we check against
196+
// the value pushed above, and it is called from other module, then
197+
// we check against the value pushed below.
198+
function_table.push(VarAST::new_with_type(
199+
module_name.clone() + "$" + function.get_name(),
200+
function.get_loc().clone(),
201+
function.get_output_type().clone(),
202+
));
203+
}
185204

186-
for mut module in &mut ast {
187-
let module_name = module.get_name();
188-
// functions but only collect their names and return types.
189-
for function in &*module {
190-
function_table.push(VarAST::new_with_type(
191-
function.get_name().clone(),
192-
function.get_loc().clone(),
193-
function.get_output_type().clone(),
194-
));
195-
// A copy of function prepended with its module name is also added.
196-
// If the function is used inside the module, then we check against
197-
// the value pushed above, and it is called from other module, then
198-
// we check against the value pushed below.
199-
function_table.push(VarAST::new_with_type(
200-
module_name.clone() + "$" + function.get_name(),
201-
function.get_loc().clone(),
202-
function.get_output_type().clone(),
203-
));
204-
}
205-
206-
for mut function in &mut *module {
207-
// parameter symbols
208-
let mut parameter_table: SymbolTable<VarAST> = SymbolTable::new();
209-
for param in function.iter_params() {
210-
parameter_table.push(param.clone());
211-
}
212-
213-
// local variables
214-
let mut local_var_table: SymbolTable<VarAST> = SymbolTable::new();
215-
for instruction in &*function {
216-
// only add let-lhs and only if they are type checked
217-
match *instruction.as_ref().borrow() {
218-
Expr::Let(ref def, _) => {
219-
// don't type check lhs-rhs, otherwise along with a
220-
// mismatch error, an unknown type error would also be
221-
// raised if local st doesn't find typed lhs.
222-
let checked: Result<Type> = Ok(def.get_type());
223-
if checked.is_ok_and(|ty| ty != Type::Bottom) {
224-
local_var_table.push(def.clone());
225-
}
205+
for mut function in &mut module {
206+
// parameter symbols
207+
let mut parameter_table: SymbolTable<VarAST> = SymbolTable::new();
208+
for param in function.iter_params() {
209+
parameter_table.push(param.clone());
210+
}
211+
212+
// local variables
213+
let mut local_var_table: SymbolTable<VarAST> = SymbolTable::new();
214+
for instruction in &*function {
215+
// only add let-lhs and only if they are type checked
216+
match *instruction.as_ref().borrow() {
217+
Expr::Let(ref def, _) => {
218+
// don't type check lhs-rhs, otherwise along with a
219+
// mismatch error, an unknown type error would also be
220+
// raised if local st doesn't find typed lhs.
221+
let checked: Result<Type> = Ok(def.get_type());
222+
if checked.is_ok_and(|ty| ty != Type::Bottom) {
223+
local_var_table.push(def.clone());
226224
}
227-
_ => {}
228225
}
226+
_ => {}
229227
}
228+
}
230229

231-
// infer local var types
232-
for instruction in &mut *function {
233-
let instruction_type = infer_expr(instruction);
230+
// infer local var types
231+
for instruction in &mut *function {
232+
let instruction_type = infer_expr(instruction);
234233

235-
if instruction_type.is_some_and(|ty| ty != Type::Bottom) {
236-
match *instruction.as_ref().borrow() {
237-
Expr::Let(ref var, _) => {
238-
if var.is_typed() {
239-
local_var_table.push(var.clone());
240-
}
234+
if instruction_type.is_some_and(|ty| ty != Type::Bottom) {
235+
match *instruction.as_ref().borrow() {
236+
Expr::Let(ref var, _) => {
237+
if var.is_typed() {
238+
local_var_table.push(var.clone());
241239
}
242-
_ => {}
243240
}
241+
_ => {}
244242
}
243+
}
245244

246-
if instruction_type.is_none() || instruction_type == Some(Type::Bottom) {
247-
// we couldn't infer all types for expression
248-
// see if either symbol table contains any information
249-
match infer_from_table(
250-
instruction,
251-
&parameter_table,
252-
&local_var_table,
253-
&function_table,
254-
) {
255-
None => {
256-
// This infers type for let expressions based on the
257-
// symbol table but doesn't update the table
258-
// entries. For e.g.,
259-
// ```quale
260-
// let a: f64 = 42;
261-
// let b = a; // this is inferred as f64 type,
262-
// // but symbol table
263-
// // doesn't contain it after
264-
// // inferring
265-
// let c = b; // hence, this would fail to be
266-
// // inferred
267-
// ```
268-
// So we have to update symbol tables accordingly.
269-
match *instruction.as_ref().borrow() {
270-
Expr::Let(ref var, _) => {
271-
if var.is_typed() {
272-
local_var_table.push(var.clone());
273-
}
245+
if instruction_type.is_none() || instruction_type == Some(Type::Bottom) {
246+
// we couldn't infer all types for expression
247+
// see if either symbol table contains any information
248+
match infer_from_table(
249+
instruction,
250+
&parameter_table,
251+
&local_var_table,
252+
&function_table,
253+
) {
254+
None => {
255+
// This infers type for let expressions based on the
256+
// symbol table but doesn't update the table
257+
// entries. For e.g.,
258+
// ```quale
259+
// let a: f64 = 42;
260+
// let b = a; // this is inferred as f64 type,
261+
// // but symbol table
262+
// // doesn't contain it after
263+
// // inferring
264+
// let c = b; // hence, this would fail to be
265+
// // inferred
266+
// ```
267+
// So we have to update symbol tables accordingly.
268+
match *instruction.as_ref().borrow() {
269+
Expr::Let(ref var, _) => {
270+
if var.is_typed() {
271+
local_var_table.push(var.clone());
274272
}
275-
_ => {}
276273
}
274+
_ => {}
277275
}
278-
Some(untyped) => {
279-
seen_errors = true;
280-
match untyped {
281-
Ok(expr) => {
282-
// unknown type of expression err
283-
let expr = expr.as_ref().borrow();
284-
qcceprintln!(
285-
"{} for `{}` {}",
286-
QccErrorKind::UnknownType,
287-
expr,
288-
expr.get_location()
289-
);
290-
}
291-
Err(err) => {
292-
// err is returned
293-
let instruction = instruction.as_ref().borrow();
294-
qcceprintln!(
295-
"{} on\n\t{}\t{}",
296-
err,
297-
instruction.get_location().row(),
298-
instruction
299-
);
300-
}
276+
}
277+
Some(untyped) => {
278+
seen_errors = true;
279+
match untyped {
280+
Ok(expr) => {
281+
// unknown type of expression err
282+
let expr = expr.as_ref().borrow();
283+
qcceprintln!(
284+
"{} for `{}` {}",
285+
QccErrorKind::UnknownType,
286+
expr,
287+
expr.get_location()
288+
);
289+
}
290+
Err(err) => {
291+
// err is returned
292+
let instruction = instruction.as_ref().borrow();
293+
qcceprintln!(
294+
"{} on\n\t{}\t{}",
295+
err,
296+
instruction.get_location().row(),
297+
instruction
298+
);
301299
}
302300
}
303301
}
304302
}
305303
}
304+
}
306305

307-
// type check between function return type and the last returned
308-
// expression
309-
let fn_return_type = *function.get_output_type();
310-
let fn_name = function.borrow().get_name().clone();
306+
// type check between function return type and the last returned
307+
// expression
308+
let fn_return_type = *function.get_output_type();
309+
let fn_name = function.borrow().get_name().clone();
311310

312-
let last_instruction = function.last_mut();
313-
if last_instruction.is_some() {
314-
let last = last_instruction.unwrap();
311+
let last_instruction = function.last_mut();
312+
if last_instruction.is_some() {
313+
let last = last_instruction.unwrap();
315314

316-
// get last expression's type
317-
let last_instruction_type = infer_expr(last);
315+
// get last expression's type
316+
let last_instruction_type = infer_expr(last);
318317

319-
if fn_return_type == Type::Bottom
320-
&& last_instruction_type.is_some()
321-
&& last_instruction_type != Some(Type::Bottom)
322-
{
323-
function.set_output_type(last_instruction_type.unwrap());
324-
} else {
325-
if last_instruction_type != Some(fn_return_type) {
326-
seen_errors = true;
327-
let err: QccError = QccErrorKind::TypeMismatch.into();
328-
let last_expr = last.as_ref().borrow();
329-
if last_instruction_type.is_none() {
330-
qcceprintln!(
331-
"{} between\n\t`{}` ({}) and `{}` ({}) {}",
332-
err,
333-
last_expr,
334-
Type::Bottom,
335-
fn_name,
336-
fn_return_type,
337-
last.as_ref().borrow().get_location()
338-
);
339-
} else {
340-
qcceprintln!(
341-
"{} between\n\t`{}` ({}) and `{}` ({}) {}",
342-
err,
343-
last_expr,
344-
last_instruction_type.unwrap(),
345-
fn_name,
346-
fn_return_type,
347-
last.as_ref().borrow().get_location()
348-
);
349-
}
318+
if fn_return_type == Type::Bottom
319+
&& last_instruction_type.is_some()
320+
&& last_instruction_type != Some(Type::Bottom)
321+
{
322+
function.set_output_type(last_instruction_type.unwrap());
323+
} else {
324+
if last_instruction_type != Some(fn_return_type) {
325+
seen_errors = true;
326+
let err: QccError = QccErrorKind::TypeMismatch.into();
327+
let last_expr = last.as_ref().borrow();
328+
if last_instruction_type.is_none() {
329+
qcceprintln!(
330+
"{} between\n\t`{}` ({}) and `{}` ({}) {}",
331+
err,
332+
last_expr,
333+
Type::Bottom,
334+
fn_name,
335+
fn_return_type,
336+
last.as_ref().borrow().get_location()
337+
);
338+
} else {
339+
qcceprintln!(
340+
"{} between\n\t`{}` ({}) and `{}` ({}) {}",
341+
err,
342+
last_expr,
343+
last_instruction_type.unwrap(),
344+
fn_name,
345+
fn_return_type,
346+
last.as_ref().borrow().get_location()
347+
);
350348
}
351349
}
352350
}

0 commit comments

Comments
 (0)