@@ -25,7 +25,7 @@ import { getCrudDialect } from '../../client/crud/dialects';
2525import type { BaseCrudDialect } from '../../client/crud/dialects/base-dialect' ;
2626import { InternalError , QueryError } from '../../client/errors' ;
2727import type { ClientOptions } from '../../client/options' ;
28- import { getModel , getRelationForeignKeyFieldPairs , requireField } from '../../client/query-utils' ;
28+ import { getIdFields , getModel , getRelationForeignKeyFieldPairs , requireField } from '../../client/query-utils' ;
2929import type {
3030 BinaryExpression ,
3131 BinaryOperator ,
@@ -111,7 +111,6 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
111111 }
112112
113113 @expr ( 'field' )
114- // @ts -expect-error
115114 private _field ( expr : FieldExpression , context : ExpressionTransformerContext < Schema > ) {
116115 const fieldDef = requireField ( this . schema , context . model , expr . field ) ;
117116 if ( ! fieldDef . relation ) {
@@ -162,8 +161,9 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
162161 return this . transformCollectionPredicate ( expr , context ) ;
163162 }
164163
165- const left = this . transform ( expr . left , context ) ;
166- const right = this . transform ( expr . right , context ) ;
164+ const { normalizedLeft, normalizedRight } = this . normalizeBinaryOperationOperands ( expr , context ) ;
165+ const left = this . transform ( normalizedLeft , context ) ;
166+ const right = this . transform ( normalizedRight , context ) ;
167167
168168 if ( op === 'in' ) {
169169 if ( this . isNullNode ( left ) ) {
@@ -195,6 +195,22 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
195195 return BinaryOperationNode . create ( left , this . transformOperator ( op ) , right ) ;
196196 }
197197
198+ private normalizeBinaryOperationOperands ( expr : BinaryExpression , context : ExpressionTransformerContext < Schema > ) {
199+ let normalizedLeft : Expression = expr . left ;
200+ if ( this . isRelationField ( expr . left , context . model ) ) {
201+ invariant ( ExpressionUtils . isNull ( expr . right ) , 'only null comparison is supported for relation field' ) ;
202+ const idFields = getIdFields ( this . schema , context . model ) ;
203+ normalizedLeft = this . makeOrAppendMember ( normalizedLeft , idFields [ 0 ] ! ) ;
204+ }
205+ let normalizedRight : Expression = expr . right ;
206+ if ( this . isRelationField ( expr . right , context . model ) ) {
207+ invariant ( ExpressionUtils . isNull ( expr . left ) , 'only null comparison is supported for relation field' ) ;
208+ const idFields = getIdFields ( this . schema , context . model ) ;
209+ normalizedRight = this . makeOrAppendMember ( normalizedRight , idFields [ 0 ] ! ) ;
210+ }
211+ return { normalizedLeft, normalizedRight } ;
212+ }
213+
198214 private transformCollectionPredicate ( expr : BinaryExpression , context : ExpressionTransformerContext < Schema > ) {
199215 invariant ( expr . op === '?' || expr . op === '!' || expr . op === '^' , 'expected "?" or "!" or "^" operator' ) ;
200216
@@ -211,11 +227,15 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
211227 ) ;
212228
213229 let newContextModel : string ;
214- if ( ExpressionUtils . isField ( expr . left ) ) {
215- const fieldDef = requireField ( this . schema , context . model , expr . left . field ) ;
230+ const fieldDef = this . getFieldDefFromFieldRef ( expr . left , context . model ) ;
231+ if ( fieldDef ) {
232+ invariant ( fieldDef . relation , `field is not a relation: ${ JSON . stringify ( expr . left ) } ` ) ;
216233 newContextModel = fieldDef . type ;
217234 } else {
218- invariant ( ExpressionUtils . isField ( expr . left . receiver ) ) ;
235+ invariant (
236+ ExpressionUtils . isMember ( expr . left ) && ExpressionUtils . isField ( expr . left . receiver ) ,
237+ 'left operand must be member access with field receiver' ,
238+ ) ;
219239 const fieldDef = requireField ( this . schema , context . model , expr . left . receiver . field ) ;
220240 newContextModel = fieldDef . type ;
221241 for ( const member of expr . left . members ) {
@@ -396,16 +416,14 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
396416
397417 if ( ExpressionUtils . isThis ( expr . receiver ) ) {
398418 if ( expr . members . length === 1 ) {
399- // optimize for the simple this.scalar case
400- const fieldDef = requireField ( this . schema , context . model , expr . members [ 0 ] ! ) ;
401- invariant ( ! fieldDef . relation , 'this.relation access should have been transformed into relation access' ) ;
402- return this . createColumnRef ( expr . members [ 0 ] ! , restContext ) ;
419+ // `this.relation` case, equivalent to field access
420+ return this . _field ( ExpressionUtils . field ( expr . members [ 0 ] ! ) , context ) ;
421+ } else {
422+ // transform the first segment into a relation access, then continue with the rest of the members
423+ const firstMemberFieldDef = requireField ( this . schema , context . model , expr . members [ 0 ] ! ) ;
424+ receiver = this . transformRelationAccess ( expr . members [ 0 ] ! , firstMemberFieldDef . type , restContext ) ;
425+ members = expr . members . slice ( 1 ) ;
403426 }
404-
405- // transform the first segment into a relation access, then continue with the rest of the members
406- const firstMemberFieldDef = requireField ( this . schema , context . model , expr . members [ 0 ] ! ) ;
407- receiver = this . transformRelationAccess ( expr . members [ 0 ] ! , firstMemberFieldDef . type , restContext ) ;
408- members = expr . members . slice ( 1 ) ;
409427 } else {
410428 receiver = this . transform ( expr . receiver , restContext ) ;
411429 }
@@ -559,4 +577,23 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
559577 return conditions . reduce ( ( acc , condition ) => ExpressionUtils . binary ( acc , '&&' , condition ) ) ;
560578 }
561579 }
580+
581+ private isRelationField ( expr : Expression , model : GetModels < Schema > ) {
582+ const fieldDef = this . getFieldDefFromFieldRef ( expr , model ) ;
583+ return ! ! fieldDef ?. relation ;
584+ }
585+
586+ private getFieldDefFromFieldRef ( expr : Expression , model : GetModels < Schema > ) : FieldDef | undefined {
587+ if ( ExpressionUtils . isField ( expr ) ) {
588+ return requireField ( this . schema , model , expr . field ) ;
589+ } else if (
590+ ExpressionUtils . isMember ( expr ) &&
591+ expr . members . length === 1 &&
592+ ExpressionUtils . isThis ( expr . receiver )
593+ ) {
594+ return requireField ( this . schema , model , expr . members [ 0 ] ! ) ;
595+ } else {
596+ return undefined ;
597+ }
598+ }
562599}
0 commit comments