Skip to content

Commit 06b540d

Browse files
committed
Initial annotation parsing
1 parent bb9ce42 commit 06b540d

File tree

14 files changed

+444
-347
lines changed

14 files changed

+444
-347
lines changed

src/analysis_and_optimization/Optimize.ml

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,13 +84,18 @@ let gen_inline_var (name : string) (id_var : string) =
8484

8585
let replace_fresh_local_vars (fname : string) stmt =
8686
let f (m : (string, string) Core.Map.Poly.t) = function
87-
| Stmt.Fixed.Pattern.Decl {decl_adtype; decl_type; decl_id; initialize} ->
87+
| Stmt.Fixed.Pattern.Decl
88+
{decl_adtype; decl_type; decl_id; decl_annotations; initialize} ->
8889
let new_name =
8990
match Map.Poly.find m decl_id with
9091
| Some existing -> existing
9192
| None -> gen_inline_var fname decl_id in
9293
( Stmt.Fixed.Pattern.Decl
93-
{decl_adtype; decl_id= new_name; decl_type; initialize}
94+
{ decl_adtype
95+
; decl_id= new_name
96+
; decl_type
97+
; decl_annotations
98+
; initialize }
9499
, Map.Poly.set m ~key:decl_id ~data:new_name )
95100
| Stmt.Fixed.Pattern.For {loopvar; lower; upper; body} ->
96101
let new_name =
@@ -201,6 +206,7 @@ let handle_early_returns (fname : string) opt_var stmt =
201206
{ decl_adtype= DataOnly
202207
; decl_id= returned
203208
; decl_type= Sized SInt
209+
; decl_annotations= []
204210
; initialize= true }
205211
; meta= Location_span.empty }
206212
; Stmt.Fixed.
@@ -294,6 +300,7 @@ let rec inline_function_expression propto adt fim (Expr.Fixed.{pattern; _} as e)
294300
(Type.to_unsized decl_type)
295301
; decl_id= inline_return_name
296302
; decl_type
303+
; decl_annotations= []
297304
; initialize= false } ]
298305
(* We should minimize the code that's having its variables
299306
replaced to avoid conflict with the (two) new dummy
@@ -972,6 +979,7 @@ let lazy_code_motion ?(preserve_stability = false) (mir : Program.Typed.t) =
972979
{ decl_adtype= Expr.Typed.adlevel_of key
973980
; decl_id= data
974981
; decl_type= Type.Unsized (Expr.Typed.type_of key)
982+
; decl_annotations= [] (* TODO annotations: correct? *)
975983
; initialize= true }
976984
; meta= Location_span.empty }
977985
:: accum) in

src/frontend/Ast.ml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,13 +171,15 @@ type ('e, 's, 'l, 'f) statement =
171171
{ decl_type: 'e SizedType.t
172172
; transformation: 'e Transformation.t
173173
; is_global: bool
174+
; annotations: string list
174175
; variables: 'e variable list }
175176
| FunDef of
176177
{ returntype: UnsizedType.returntype
177178
; funname: identifier
178179
; arguments:
179180
(Middle.UnsizedType.autodifftype * Middle.UnsizedType.t * identifier)
180181
list
182+
; annotations: string list
181183
; body: 's }
182184
[@@deriving sexp, hash, compare, map, fold]
183185

src/frontend/Ast_to_Mir.ml

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -443,8 +443,8 @@ let var_constrain_check_stmts dconstrain loc adlevel decl_id decl_var trans
443443
@ check_decl decl_var st trans loc adlevel
444444
| _ -> []
445445

446-
let create_decl_with_assign decl_id declc decl_type initial_value transform
447-
smeta =
446+
let create_decl_with_assign decl_id declc decl_type initial_value
447+
decl_annotations transform smeta =
448448
let rhs = Option.map ~f:trans_expr initial_value in
449449
let decl_adtype =
450450
UnsizedType.fill_adtype_for_type declc.dadlevel (Type.to_unsized decl_type)
@@ -458,7 +458,9 @@ let create_decl_with_assign decl_id declc decl_type initial_value transform
458458
() } in
459459
let decl =
460460
Stmt.
461-
{ Fixed.pattern= Decl {decl_adtype; decl_id; decl_type; initialize= true}
461+
{ Fixed.pattern=
462+
Decl
463+
{decl_adtype; decl_id; decl_type; decl_annotations; initialize= true}
462464
; meta= smeta } in
463465
let rhs_assignment =
464466
Option.map
@@ -583,6 +585,7 @@ let rec trans_stmt ud_dists (declc : decl_context) (ts : Ast.typed_statement) =
583585
{ decl_adtype= Expr.Typed.adlevel_of iteratee'
584586
; decl_id= loopvar.name
585587
; decl_type= Unsized decl_type
588+
; decl_annotations= []
586589
; initialize= true } } in
587590
let assignment var =
588591
Stmt.Fixed.
@@ -598,15 +601,16 @@ let rec trans_stmt ud_dists (declc : decl_context) (ts : Ast.typed_statement) =
598601
Common.ICE.internal_compiler_error
599602
[%message
600603
"Found function definition statement outside of function block"]
601-
| Ast.VarDecl {decl_type; transformation; variables; is_global= _} ->
604+
| Ast.VarDecl {decl_type; transformation; variables; annotations; is_global= _}
605+
->
602606
List.concat_map
603607
~f:(fun {identifier; initial_value} ->
604608
let transform = Transformation.map trans_expr transformation in
605609
let decl_id = identifier.Ast.name in
606610
let size_checks, dt = check_sizedtype decl_id decl_type in
607611
size_checks
608-
@ create_decl_with_assign decl_id declc dt initial_value transform
609-
smeta)
612+
@ create_decl_with_assign decl_id declc dt initial_value annotations
613+
transform smeta)
610614
variables
611615
| Ast.Block stmts -> Block (List.concat_map ~f:trans_stmt stmts) |> swrap
612616
| Ast.Profile (name, stmts) ->
@@ -629,6 +633,7 @@ and trans_packed_assign loc trans_stmt lvals rhs assign_op =
629633
{ decl_adtype= rhs.emeta.ad_level
630634
; decl_id= sym
631635
; decl_type= Unsized rhs_type
636+
; decl_annotations= []
632637
; initialize= false }
633638
; meta= rhs.emeta.loc } in
634639
let assign =
@@ -696,12 +701,13 @@ and trans_single_assignment smeta assign_lhs assign_rhs assign_op =
696701

697702
let trans_fun_def ud_dists (ts : Ast.typed_statement) =
698703
match ts.stmt with
699-
| Ast.FunDef {returntype; funname; arguments; body} ->
704+
| Ast.FunDef {returntype; funname; arguments; annotations; body} ->
700705
[ Program.
701706
{ fdrt= returntype
702707
; fdname= funname.name
703708
; fdsuffix= Fun_kind.(suffix_from_name funname.name |> without_propto)
704709
; fdargs= List.map ~f:trans_arg arguments
710+
; fdannotations= annotations
705711
; fdbody=
706712
trans_stmt ud_dists
707713
{transform_action= IgnoreTransform; dadlevel= AutoDiffable}
@@ -743,6 +749,7 @@ let rec trans_sizedtype_decl declc tr name st =
743749
{ decl_type= Sized SInt
744750
; decl_id
745751
; decl_adtype= DataOnly
752+
; decl_annotations= []
746753
; initialize= true }
747754
; meta= e.meta.loc } in
748755
let assign =
@@ -821,7 +828,12 @@ let trans_block ud_dists declc block prog =
821828
let f stmt (accum1, accum2, accum3) =
822829
match stmt with
823830
| { Ast.stmt=
824-
VarDecl {decl_type= type_; variables; transformation; is_global= true}
831+
VarDecl
832+
{ decl_type= type_
833+
; variables
834+
; transformation
835+
; annotations
836+
; is_global= true }
825837
; smeta } ->
826838
let outvars, sizes, stmts =
827839
List.unzip3
@@ -839,10 +851,11 @@ let trans_block ud_dists declc block prog =
839851
{ out_constrained_st= type_
840852
; out_unconstrained_st= param_size transform type_
841853
; out_block= block
854+
; out_annotations= annotations
842855
; out_trans= transform } ) in
843856
let stmts =
844857
create_decl_with_assign decl_id declc (Sized type_)
845-
initial_value transform smeta.loc in
858+
initial_value annotations transform smeta.loc in
846859
(outvar, size, stmts))
847860
variables in
848861
( outvars @ accum1

src/frontend/Canonicalize.ml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,13 @@ let parens_lval = map_lval_with no_parens Fn.id
6565
let rec parens_stmt ({stmt; smeta} : typed_statement) : typed_statement =
6666
let stmt =
6767
match stmt with
68-
| VarDecl {decl_type= d; transformation= t; variables; is_global} ->
68+
| VarDecl
69+
{decl_type= d; transformation= t; variables; annotations; is_global} ->
6970
VarDecl
7071
{ decl_type= Middle.SizedType.map no_parens d
7172
; transformation= Middle.Transformation.map keep_parens t
7273
; variables= List.map ~f:(map_variable no_parens) variables
74+
; annotations
7375
; is_global }
7476
| For {loop_variable; lower_bound; upper_bound; loop_body} ->
7577
For

src/frontend/Pretty_printing.ml

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,13 @@ let rec pp_transformed_type ppf (st, trans) =
382382
pp_possibly_transformed_type (ty, trans)
383383
| _ -> pf ppf "%a" pp_possibly_transformed_type (st, trans)
384384

385+
let pp_annotations ppf ann =
386+
(* TODO better comment handling? *)
387+
if List.is_empty ann then ()
388+
else
389+
let pp ppf s = pf ppf "%@%s" s in
390+
pf ppf "%a@ " (list ~sep:sp pp) ann
391+
385392
let rec pp_indent_unless_block ppf ((s : untyped_statement), loc) =
386393
match s.stmt with
387394
| Block _ -> pp_statement ppf s
@@ -448,16 +455,23 @@ and pp_statement ppf ({stmt= s_content; smeta= {loc}} as ss : untyped_statement)
448455
pf ppf "profile(%s) {@,%a@,}" name
449456
(indented_box pp_list_of_statements)
450457
(vdsl, loc)
451-
| VarDecl {decl_type= pst; transformation= trans; variables; is_global= _} ->
458+
| VarDecl
459+
{ decl_type= pst
460+
; transformation= trans
461+
; variables
462+
; annotations
463+
; is_global= _ } ->
452464
let pp_var ppf {identifier; initial_value} =
453465
pf ppf "%a%a" pp_identifier identifier
454466
(option (fun ppf e -> pf ppf " = %a" pp_expression e))
455467
initial_value in
456-
pf ppf "@[<h>%a %a;@]" pp_transformed_type (pst, trans)
457-
(list ~sep:comma pp_var) variables
458-
| FunDef {returntype= rt; funname= id; arguments= args; body= b} -> (
468+
pf ppf "@[<hv>%a@[<h>%a %a;@]@]" pp_annotations annotations
469+
pp_transformed_type (pst, trans) (list ~sep:comma pp_var) variables
470+
| FunDef {returntype= rt; funname= id; arguments= args; annotations; body= b}
471+
-> (
459472
let loc_of (_, _, id) = id.id_loc in
460-
pf ppf "%a %a(%a" pp_returntype rt pp_identifier id
473+
pf ppf "@[<hv>%a@[<h>%a %a(@]%a@]" pp_annotations annotations
474+
pp_returntype rt pp_identifier id
461475
(box (pp_list_of pp_args loc_of))
462476
(args, {loc with end_loc= b.smeta.loc.begin_loc});
463477
match b with

src/frontend/Typechecker.ml

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1659,7 +1659,7 @@ and check_transformation cf tenv ut trans =
16591659
TupleTransformation tes
16601660

16611661
and check_var_decl loc cf tenv sized_ty trans
1662-
(variables : untyped_expression Ast.variable list) is_global =
1662+
(variables : untyped_expression Ast.variable list) annotations is_global =
16631663
let checked_type =
16641664
check_sizedtype {cf with in_toplevel_decl= is_global} tenv sized_ty in
16651665
let unsized_type = SizedType.to_unsized checked_type in
@@ -1684,6 +1684,7 @@ and check_var_decl loc cf tenv sized_ty trans
16841684
{ decl_type= checked_type
16851685
; transformation= checked_trans
16861686
; variables= tvariables
1687+
; annotations
16871688
; is_global } in
16881689
(tenv, mk_typed_statement ~stmt ~loc ~return_type:Incomplete)
16891690

@@ -1786,7 +1787,7 @@ and add_function tenv name type_ defined =
17861787
Env.set_raw tenv name (new_fn :: defns)
17871788
else Env.add tenv name type_ defined
17881789

1789-
and check_fundef loc cf tenv return_ty id args body =
1790+
and check_fundef loc cf tenv return_ty id args annotations body =
17901791
List.iter args ~f:(fun (_, _, id) -> verify_identifier id);
17911792
verify_identifier id;
17921793
let arg_types = List.map ~f:(fun (w, y, _) -> (w, y)) args in
@@ -1838,8 +1839,11 @@ and check_fundef loc cf tenv return_ty id args body =
18381839
verify_fundef_return_tys loc return_ty checked_body;
18391840
let stmt =
18401841
FunDef
1841-
{returntype= return_ty; funname= id; arguments= args; body= checked_body}
1842-
in
1842+
{ returntype= return_ty
1843+
; funname= id
1844+
; arguments= args
1845+
; annotations
1846+
; body= checked_body } in
18431847
(* NB: **not** tenv_body, so args don't leak out *)
18441848
(tenv, mk_typed_statement ~return_type:Incomplete ~loc ~stmt)
18451849

@@ -1872,10 +1876,11 @@ and check_statement (cf : context_flags_record) (tenv : Env.t)
18721876
| Block stmts -> (tenv, check_block loc cf tenv stmts)
18731877
| Profile (name, vdsl) -> (tenv, check_profile loc cf tenv name vdsl)
18741878
(* these two are special in that they're allowed to change the type environment *)
1875-
| VarDecl {decl_type; transformation; variables; is_global} ->
1876-
check_var_decl loc cf tenv decl_type transformation variables is_global
1877-
| FunDef {returntype; funname; arguments; body} ->
1878-
check_fundef loc cf tenv returntype funname arguments body
1879+
| VarDecl {decl_type; transformation; variables; annotations; is_global} ->
1880+
check_var_decl loc cf tenv decl_type transformation variables annotations
1881+
is_global
1882+
| FunDef {returntype; funname; arguments; annotations; body} ->
1883+
check_fundef loc cf tenv returntype funname arguments annotations body
18791884

18801885
let verify_fun_def_body_in_block = function
18811886
| {stmt= FunDef {body= {stmt= Block _; _}; _}; _}
@@ -1904,7 +1909,8 @@ let add_userdefined_functions tenv stmts_opt =
19041909
| Some {stmts; _} ->
19051910
let f tenv (s : Ast.untyped_statement) =
19061911
match s with
1907-
| {stmt= FunDef {returntype; funname; arguments; body}; smeta= {loc}} ->
1912+
| { stmt= FunDef {returntype; funname; arguments; body; annotations= _}
1913+
; smeta= {loc} } ->
19081914
let arg_types = Ast.type_of_arguments arguments in
19091915
verify_fundef_overloaded loc tenv funname arg_types returntype;
19101916
let defined =

src/frontend/lexer.mll

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,12 @@ rule token = parse
197197
| identifier as id { lexer_logger ("identifier " ^ id) ;
198198
lexer_pos_logger (lexeme_start_p lexbuf);
199199
Parser.IDENTIFIER (lexeme lexbuf) }
200+
(* TODO annotation proper paren lexxing *)
201+
| "@" (non_space_or_newline+ as ann)
202+
{ lexer_logger ("annotation " ^ ann) ;
203+
lexer_pos_logger (lexeme_start_p lexbuf);
204+
add_separator lexbuf ;
205+
Parser.ANNOTATION (lexeme lexbuf |> Core.String.chop_prefix_exn ~prefix:"@") }
200206
(* End of file *)
201207
| eof { lexer_logger "eof" ;
202208
if Preprocessor.size () = 1

0 commit comments

Comments
 (0)