11import type { Node as SyntaxNode } from "web-tree-sitter" ;
22import { FuncFile } from "./FuncFile" ;
33import { Func } from "@server/languages/func/psi/Decls" ;
4+ import { Expression } from "./FuncNode" ;
5+ import { closestNamedSibling } from "@server/psi/utils" ;
46
57type Binding = {
68 identifier : SyntaxNode ,
79 producer_exp : SyntaxNode [ ]
810}
911
1012type BindingResult = {
13+ expression : Expression
1114 lhs : SyntaxNode [ ] ,
1215 rhs : SyntaxNode [ ] ,
1316 bindings : Map < string , Binding >
@@ -18,7 +21,8 @@ export class FunCBindingResolver {
1821
1922 protected funcMap : Map < string , Func > ;
2023 protected bindings : Map < string , Binding > ;
21- constructor ( file : FuncFile ) {
24+
25+ constructor ( readonly file : FuncFile ) {
2226 this . bindings = new Map ( ) ;
2327 this . funcMap = new Map ( ) ;
2428
@@ -45,6 +49,16 @@ export class FunCBindingResolver {
4549 }
4650 }
4751 if ( curChild . isNamed ) {
52+ // If modirying method call
53+ /*
54+ if (curChild.type == "method_call" && curChild.children[0]?.text == "~") {
55+ const firstArg = closestNamedSibling(curChild, 'prev', (sibling => sibling.type == "identifier"))
56+ if (firstArg) {
57+ // Not really lhs, but semantically it is
58+ lhs.push(firstArg)
59+ }
60+ }
61+ */
4862 if ( equalsFound ) {
4963 rhs . push ( curChild ) ;
5064 } else {
@@ -54,6 +68,7 @@ export class FunCBindingResolver {
5468 }
5569
5670 let bindRes : BindingResult = {
71+ expression : new Expression ( expression , this . file ) ,
5772 lhs,
5873 rhs,
5974 bindings : new Map ( )
@@ -67,7 +82,7 @@ export class FunCBindingResolver {
6782 }
6883
6984 const pattern = lhs [ 0 ]
70- this . walkPattern ( pattern , rhs [ 0 ] ) ;
85+ this . walkPattern ( pattern , rhs ) ;
7186
7287 // Copy the map for the output
7388 for ( let [ k , v ] of this . bindings . entries ( ) ) {
@@ -78,72 +93,147 @@ export class FunCBindingResolver {
7893 return bindRes ;
7994 }
8095
81- private walkPattern ( pattern : SyntaxNode , value : SyntaxNode ) {
96+ private walkPattern ( pattern : SyntaxNode , value : SyntaxNode [ ] ) {
8297 if ( ! pattern || pattern . type == "underscore" ) {
8398 return
8499 }
85100
86- switch ( pattern . type ) {
87- case "identifier" :
88- this . bindIdentifier ( pattern , value ) ;
89- break ;
90- case "local_vars_declaration" :
91- const curLhs = pattern . childForFieldName ( "lhs" ) ;
92- if ( ! curLhs ) {
93- throw new Error ( "No lhs in var declaration" )
94- }
95- this . walkPattern ( curLhs , value ) ;
96- break ;
97- case "var_declaration" :
98- this . bindIdentifier ( pattern . childForFieldName ( "name" ) ! , value ) ;
99- break ;
100- case "tensor_vars_declaration" :
101- case "tensor_expression" :
102- this . bindTensor ( pattern , value ) ;
103- break ;
101+ try {
102+ switch ( pattern . type ) {
103+ case "identifier" :
104+ this . bindIdentifier ( pattern , value ) ;
105+ break ;
106+ case "local_vars_declaration" :
107+ const curLhs = pattern . childForFieldName ( "lhs" ) ;
108+ if ( ! curLhs ) {
109+ throw new Error ( "No lhs in var declaration" )
110+ }
111+ this . walkPattern ( curLhs , value ) ;
112+ break ;
113+ case "var_declaration" :
114+ this . bindIdentifier ( pattern . childForFieldName ( "name" ) ! , value ) ;
115+ break ;
116+ case "tensor_vars_declaration" :
117+ case "tensor_expression" :
118+ case "tuple_expression" :
119+ case "parenthesized_expression" :
120+ case "nested_tensor_declaration" :
121+ case "tuple_vars_declaration" :
122+ this . bindCollection ( pattern , value ) ;
123+ break ;
124+ }
125+ } catch ( e ) {
126+ console . error ( `Failed to waks pattern ${ e } ${ pattern } , ${ value } ` )
104127 }
105128 }
106129
107- private bindIdentifier ( target : SyntaxNode , value : SyntaxNode ) {
130+ private bindIdentifier ( target : SyntaxNode , value : SyntaxNode [ ] , checkMethodRhs : boolean = true ) {
131+ if ( checkMethodRhs ) {
132+ value . forEach ( curNode => {
133+ if ( curNode . type == "method_call" ) {
134+ this . bindToMethodCall ( target , curNode ) ;
135+ } else {
136+ // In case calls are in tensor expressions
137+ curNode . descendantsOfType ( "method_call" ) . forEach ( methodCall => {
138+ if ( methodCall ) {
139+ this . bindToMethodCall ( target , methodCall ) ;
140+ }
141+ } )
142+ }
143+ } )
144+ }
108145 this . bindings . set ( target . text , {
109146 identifier : target ,
110- producer_exp : [ value ]
147+ producer_exp : value
111148 } ) ;
112149 }
113150
114- private bindTensor ( target : SyntaxNode , value : SyntaxNode ) {
115- const curValueType = value . type ;
116- if ( curValueType == "function_application" ) {
117- this . bindToFunctionCall ( target , value ) ;
118- } else if ( curValueType == "tensor_expression" ) {
119-
120- for ( let i = 0 ; i < target . namedChildCount ; i ++ ) {
121- const nextTarget = target . namedChildren [ i ] ;
122- if ( ! nextTarget ) {
123- continue ;
151+ private bindCollection ( target : SyntaxNode , value : SyntaxNode [ ] ) {
152+ if ( value . length >= 2 ) {
153+ value . forEach ( ( curNode ) => {
154+ if ( curNode . type == "method_call" ) {
155+ this . bindToMethodCall ( target , curNode ) ;
124156 }
125- const nextValue = value . namedChildren [ i ] ;
126- if ( ! nextValue ) {
127- throw new Error ( "Undefined value" ) ;
157+ } )
158+ } else if ( value . length == 1 ) {
159+ const curValue = value [ 0 ] ;
160+ const curValueType = curValue . type ;
161+ if ( curValueType == "function_application" ) {
162+ this . bindToFunctionCall ( target , curValue ) ;
163+ } else if ( curValueType == "tensor_expression" || curValueType == "tuple_expression" ) {
164+
165+ for ( let i = 0 ; i < target . namedChildCount ; i ++ ) {
166+ const nextTarget = target . namedChildren [ i ] ;
167+ if ( ! nextTarget ) {
168+ continue ;
169+ }
170+ const nextValue = curValue . namedChildren [ i ] ;
171+ if ( ! nextValue ) {
172+ throw new Error ( `Undefined next value ${ curValue } ` ) ;
173+ }
174+ this . walkPattern ( nextTarget , [ nextValue ] ) ;
128175 }
129- this . walkPattern ( nextTarget , nextValue ) ;
176+ } else {
177+ throw new TypeError ( `Type ${ curValueType } is not yet supported!` ) ;
130178 }
131- } else {
132- throw new TypeError ( `Type ${ curValueType } is not yet supported!` ) ;
133- }
134- /*
135- switch (value.type) {
136- case "function_application":
137- break;
138- case "tensor_expression":
139- this.walkPattern(target, value);
140- break;
141- default:
142- throw new Error(`Failed to bind tensor to ${value.type} ${target} ${value}`)
143- }
144- */
179+ }
145180 }
146181
182+ private bindToMethodCall ( target : SyntaxNode , value : SyntaxNode ) {
183+ const isModifying = value . children [ 0 ] ?. text == "~" ;
184+ const methodName = value . childForFieldName ( "method_name" ) ! . text ;
185+ let methodDecl = this . funcMap . get ( methodName ) ;
186+ if ( ! methodDecl ) {
187+ // Thre could be method with ~ prefix being part of the name
188+ if ( isModifying && methodName [ 0 ] !== "~" ) {
189+ methodDecl = this . funcMap . get ( "~" + methodName ) ;
190+ }
191+ }
192+ if ( ! methodDecl ) {
193+ throw new Error ( `Failed to get method declaration ${ methodName } ` )
194+ }
195+ const retType = methodDecl . returnType ( ) ;
196+
197+ if ( ! retType ) {
198+ throw new Error ( `Method ${ methodName } has no return type` )
199+ }
200+ if ( retType . node . type !== "tensor_type" ) {
201+ throw new TypeError ( `Expected tensor_type for modifying method return type got ${ retType . node . type } ` )
202+ }
203+
204+ // For non-modofiying method bind as normal function call;
205+ let bindScope = retType . node ;
206+
207+ if ( isModifying ) {
208+ const firstArg = closestNamedSibling ( value , 'prev' , ( sybl => sybl . type == "identifier" ) ) ;
209+ if ( ! firstArg ) {
210+ throw new Error ( `First arg not found for modifying method call ${ value } ` )
211+ }
212+ this . bindIdentifier ( firstArg , [ value ] , false ) ;
213+ // Next tensor type
214+ let retTensor : SyntaxNode | undefined ;
215+ const childrenCount = bindScope . namedChildCount ;
216+ // First is bound to the first method arg already
217+ for ( let i = 1 ; i < childrenCount ; i ++ ) {
218+ const curChild = bindScope . namedChild ( i ) ;
219+ if ( curChild ?. type == "tensor_type" ) {
220+ retTensor = curChild ;
221+ break ;
222+ }
223+ }
224+ if ( ! retTensor ) {
225+ throw new Error ( `Return tensor not defined for method ${ methodDecl } ` )
226+ }
227+ // If sub tensor is empty, we can return at this point
228+ if ( retTensor . namedChildCount == 0 ) {
229+ return ;
230+ }
231+ // Otherwise bind to the sub-tensor
232+ bindScope = retTensor ;
233+ }
234+
235+ this . bindToReturnType ( target , value , bindScope , false )
236+ }
147237 private bindToFunctionCall ( target : SyntaxNode , value : SyntaxNode ) {
148238 const funcIdentifier = value . childForFieldName ( "callee" ) ;
149239 if ( ! funcIdentifier ) {
@@ -161,8 +251,20 @@ export class FunCBindingResolver {
161251 this . bindToReturnType ( target , value , retType . node ) ;
162252
163253 }
164- private bindToReturnType ( target : SyntaxNode , callNode : SyntaxNode , retType : SyntaxNode ) {
165- const targetFiltered = target . type == "tensor_vars_declaration" ? target . childrenForFieldName ( "vars" ) . filter ( v => v ?. isNamed ) : target . namedChildren ;
254+ private bindToReturnType ( target : SyntaxNode , callNode : SyntaxNode , retType : SyntaxNode , checkMethodRhs : boolean = true ) {
255+ const targetType = target . type ;
256+ let targetFiltered : ( SyntaxNode | null ) [ ] ;
257+ // Hacky,but drop types
258+ if ( targetType == "tensor_vars_declaration" ) {
259+ targetFiltered = target . childrenForFieldName ( "vars" ) . filter ( v => v ?. isNamed ) ;
260+ } else if ( targetType == "var_declaration" || targetType == "identifier" ) {
261+ // Name is only part of var declaration
262+ const identifierNode = target . childForFieldName ( "name" ) ?? target ;
263+ this . bindIdentifier ( identifierNode , [ callNode ] , checkMethodRhs ) ;
264+ return ;
265+ } else {
266+ targetFiltered = target . namedChildren ;
267+ }
166268
167269 if ( targetFiltered . length != retType . namedChildCount ) {
168270 throw new Error ( `Return type arity error ${ target } ${ retType } ` ) ;
@@ -183,24 +285,29 @@ export class FunCBindingResolver {
183285 const patternType = pattern . type ;
184286
185287 switch ( patternType ) {
288+ case "tuple_vars_declaration" :
186289 case "tuple_expression" :
187- if ( bindType != "tuple_type_expression " ) {
290+ if ( bindType != "tuple_type " ) {
188291 throw new Error ( `Can't map ${ patternType } to ${ bindType } ` )
189292 }
190- this . bindToReturnType ( pattern , callNode , bindRhs ) ;
293+ this . bindToReturnType ( pattern , callNode , bindRhs , checkMethodRhs ) ;
294+ break ;
295+ case "local_vars_declaration" :
296+ this . bindToReturnType ( pattern . childForFieldName ( "lhs" ) ! , callNode , bindRhs , checkMethodRhs ) ;
191297 break ;
192298 case "tensor_var_declaration" :
299+ case "nested_tensor_declaration" :
193300 case "tensor_expression" :
194301 if ( bindType !== "tensor_type" ) {
195302 throw new Error ( `Cant map ${ patternType } to ${ bindType } ` )
196303 }
197- this . bindToReturnType ( pattern , callNode , bindRhs ) ;
304+ this . bindToReturnType ( pattern , callNode , bindRhs , checkMethodRhs ) ;
198305 break ;
199306 case "var_declaration" :
200- this . bindIdentifier ( pattern . childForFieldName ( "name" ) ! , callNode ) ;
307+ this . bindIdentifier ( pattern . childForFieldName ( "name" ) ! , [ callNode ] , checkMethodRhs ) ;
201308 break ;
202309 case "identifier" :
203- this . bindIdentifier ( pattern , callNode ) ;
310+ this . bindIdentifier ( pattern , [ callNode ] , checkMethodRhs ) ;
204311 break ;
205312 }
206313 }
0 commit comments