3333import org .apiguardian .api .API ;
3434
3535import org .springframework .data .domain .Sort ;
36+ import org .springframework .data .falkordb .core .query .FalkorDBQueryRewriter ;
37+ import org .springframework .data .falkordb .core .query .RewrittenQuery ;
38+ import org .springframework .lang .Nullable ;
3639import org .springframework .data .falkordb .core .mapping .DefaultFalkorDBEntityConverter ;
3740import org .springframework .data .falkordb .core .mapping .DefaultFalkorDBPersistentEntity ;
3841import org .springframework .data .falkordb .core .mapping .FalkorDBEntityConverter ;
@@ -56,14 +59,22 @@ public class FalkorDBTemplate implements FalkorDBOperations {
5659
5760 private final FalkorDBEntityConverter entityConverter ;
5861
62+ private final @ Nullable FalkorDBQueryRewriter queryRewriter ;
63+
5964 public FalkorDBTemplate (FalkorDBClient falkorDBClient , FalkorDBMappingContext mappingContext ,
6065 FalkorDBEntityConverter entityConverter ) {
66+ this (falkorDBClient , mappingContext , entityConverter , null );
67+ }
68+
69+ public FalkorDBTemplate (FalkorDBClient falkorDBClient , FalkorDBMappingContext mappingContext ,
70+ FalkorDBEntityConverter entityConverter , @ Nullable FalkorDBQueryRewriter queryRewriter ) {
6171 Assert .notNull (falkorDBClient , "FalkorDBClient must not be null" );
6272 Assert .notNull (mappingContext , "FalkorDBMappingContext must not be null" );
6373 Assert .notNull (entityConverter , "FalkorDBEntityConverter must not be null" );
6474
6575 this .falkorDBClient = falkorDBClient ;
6676 this .mappingContext = mappingContext ;
77+ this .queryRewriter = queryRewriter ;
6778
6879 // If the entity converter is DefaultFalkorDBEntityConverter and doesn't have a
6980 // client,
@@ -165,14 +176,9 @@ public <T> Optional<T> findById(Object id, Class<T> clazz) {
165176 String primaryLabel = getPrimaryLabel (persistentEntity );
166177
167178 String cypher = "MATCH (n:" + primaryLabel + ") WHERE id(n) = $id RETURN n" ;
168- Map <String , Object > parameters = Collections .singletonMap ("id" , id );
179+ Map <String , Object > parameters = Collections .singletonMap ("id" , normalizeInternalId ( id ) );
169180
170- return this .falkorDBClient .query (cypher , parameters , result -> {
171- for (FalkorDBClient .Record record : result .records ()) {
172- return Optional .of (this .entityConverter .read (clazz , record ));
173- }
174- return Optional .empty ();
175- });
181+ return queryForObject (cypher , parameters , clazz );
176182 }
177183
178184 @ Override
@@ -191,15 +197,10 @@ public <T> List<T> findAllById(Iterable<?> ids, Class<T> clazz) {
191197 String primaryLabel = getPrimaryLabel (persistentEntity );
192198
193199 String cypher = "MATCH (n:" + primaryLabel + ") WHERE id(n) IN $ids RETURN n" ;
194- Map <String , Object > parameters = Collections .singletonMap ("ids" , idList );
200+ Map <String , Object > parameters = Collections .singletonMap ("ids" ,
201+ idList .stream ().map (this ::normalizeInternalId ).collect (Collectors .toList ()));
195202
196- return this .falkorDBClient .query (cypher , parameters , result -> {
197- List <T > entities = new ArrayList <>();
198- for (FalkorDBClient .Record record : result .records ()) {
199- entities .add (this .entityConverter .read (clazz , record ));
200- }
201- return entities ;
202- });
203+ return query (cypher , parameters , clazz );
203204 }
204205
205206 @ Override
@@ -211,13 +212,7 @@ public <T> List<T> findAll(Class<T> clazz) {
211212
212213 String cypher = "MATCH (n:" + primaryLabel + ") RETURN n" ;
213214
214- return this .falkorDBClient .query (cypher , Collections .emptyMap (), result -> {
215- List <T > entities = new ArrayList <>();
216- for (FalkorDBClient .Record record : result .records ()) {
217- entities .add (this .entityConverter .read (clazz , record ));
218- }
219- return entities ;
220- });
215+ return query (cypher , Collections .emptyMap (), clazz );
221216 }
222217
223218 @ Override
@@ -238,13 +233,7 @@ public <T> List<T> findAll(Class<T> clazz, Sort sort) {
238233 cypher .append (orderBy );
239234 }
240235
241- return this .falkorDBClient .query (cypher .toString (), Collections .emptyMap (), result -> {
242- List <T > entities = new ArrayList <>();
243- for (FalkorDBClient .Record record : result .records ()) {
244- entities .add (this .entityConverter .read (clazz , record ));
245- }
246- return entities ;
247- });
236+ return query (cypher .toString (), Collections .emptyMap (), clazz );
248237 }
249238
250239 @ Override
@@ -256,7 +245,9 @@ public <T> long count(Class<T> clazz) {
256245
257246 String cypher = "MATCH (n:" + primaryLabel + ") RETURN count(n) as count" ;
258247
259- return this .falkorDBClient .query (cypher , Collections .emptyMap (), result -> {
248+ RewrittenQuery rq = maybeRewrite (cypher , Collections .emptyMap (), clazz );
249+
250+ return this .falkorDBClient .query (rq .getCypher (), rq .getParameters (), result -> {
260251 for (FalkorDBClient .Record record : result .records ()) {
261252 Object count = record .get ("count" );
262253 return (count instanceof Number ) ? ((Number ) count ).longValue () : 0L ;
@@ -274,9 +265,11 @@ public <T> boolean existsById(Object id, Class<T> clazz) {
274265 String primaryLabel = getPrimaryLabel (persistentEntity );
275266
276267 String cypher = "MATCH (n:" + primaryLabel + ") WHERE id(n) = $id RETURN count(n) > 0 as exists" ;
277- Map <String , Object > parameters = Collections .singletonMap ("id" , id );
268+ Map <String , Object > parameters = Collections .singletonMap ("id" , normalizeInternalId ( id ) );
278269
279- return this .falkorDBClient .query (cypher , parameters , result -> {
270+ RewrittenQuery rq = maybeRewrite (cypher , parameters , clazz );
271+
272+ return this .falkorDBClient .query (rq .getCypher (), rq .getParameters (), result -> {
280273 for (FalkorDBClient .Record record : result .records ()) {
281274 Object exists = record .get ("exists" );
282275 return (exists instanceof Boolean ) ? (Boolean ) exists : false ;
@@ -294,7 +287,7 @@ public <T> void deleteById(Object id, Class<T> clazz) {
294287 String primaryLabel = getPrimaryLabel (persistentEntity );
295288
296289 String cypher = "MATCH (n:" + primaryLabel + ") WHERE id(n) = $id DELETE n" ;
297- Map <String , Object > parameters = Collections .singletonMap ("id" , id );
290+ Map <String , Object > parameters = Collections .singletonMap ("id" , normalizeInternalId ( id ) );
298291
299292 this .falkorDBClient .query (cypher , parameters );
300293 }
@@ -315,7 +308,8 @@ public <T> void deleteAllById(Iterable<?> ids, Class<T> clazz) {
315308 String primaryLabel = getPrimaryLabel (persistentEntity );
316309
317310 String cypher = "MATCH (n:" + primaryLabel + ") WHERE id(n) IN $ids DELETE n" ;
318- Map <String , Object > parameters = Collections .singletonMap ("ids" , idList );
311+ Map <String , Object > parameters = Collections .singletonMap ("ids" ,
312+ idList .stream ().map (this ::normalizeInternalId ).collect (Collectors .toList ()));
319313
320314 this .falkorDBClient .query (cypher , parameters );
321315 }
@@ -338,7 +332,9 @@ public <T> List<T> query(String cypher, Map<String, Object> parameters, Class<T>
338332 Assert .notNull (parameters , "Parameters must not be null" );
339333 Assert .notNull (clazz , "Class must not be null" );
340334
341- return this .falkorDBClient .query (cypher , parameters , result -> {
335+ RewrittenQuery rq = maybeRewrite (cypher , parameters , clazz );
336+
337+ return this .falkorDBClient .query (rq .getCypher (), rq .getParameters (), result -> {
342338 List <T > entities = new ArrayList <>();
343339 for (FalkorDBClient .Record record : result .records ()) {
344340 entities .add (this .entityConverter .read (clazz , record ));
@@ -353,7 +349,9 @@ public <T> Optional<T> queryForObject(String cypher, Map<String, Object> paramet
353349 Assert .notNull (parameters , "Parameters must not be null" );
354350 Assert .notNull (clazz , "Class must not be null" );
355351
356- return this .falkorDBClient .query (cypher , parameters , result -> {
352+ RewrittenQuery rq = maybeRewrite (cypher , parameters , clazz );
353+
354+ return this .falkorDBClient .query (rq .getCypher (), rq .getParameters (), result -> {
357355 for (FalkorDBClient .Record record : result .records ()) {
358356 return Optional .of (this .entityConverter .read (clazz , record ));
359357 }
@@ -368,7 +366,15 @@ public <T> T query(String cypher, Map<String, Object> parameters,
368366 Assert .notNull (parameters , "Parameters must not be null" );
369367 Assert .notNull (resultMapper , "Result mapper must not be null" );
370368
371- return this .falkorDBClient .query (cypher , parameters , resultMapper );
369+ RewrittenQuery rq = maybeRewrite (cypher , parameters , null );
370+ return this .falkorDBClient .query (rq .getCypher (), rq .getParameters (), resultMapper );
371+ }
372+
373+ private RewrittenQuery maybeRewrite (String cypher , Map <String , Object > parameters , @ Nullable Class <?> domainType ) {
374+ if (this .queryRewriter == null ) {
375+ return RewrittenQuery .of (cypher , parameters );
376+ }
377+ return this .queryRewriter .rewrite (cypher , parameters , domainType );
372378 }
373379
374380 /**
@@ -389,14 +395,29 @@ public FalkorDBMappingContext getMappingContext() {
389395 return this .mappingContext ;
390396 }
391397
398+ private Object normalizeInternalId (Object id ) {
399+ if (id instanceof Long ) {
400+ Long l = (Long ) id ;
401+ if (l >= Integer .MIN_VALUE && l <= Integer .MAX_VALUE ) {
402+ return l .intValue ();
403+ }
404+ }
405+ return id ;
406+ }
407+
392408 private String getPrimaryLabel (DefaultFalkorDBPersistentEntity <?> persistentEntity ) {
393409 // Get the primary label from the @Node annotation
394410 Node nodeAnnotation = persistentEntity .getType ().getAnnotation (Node .class );
395411 if (nodeAnnotation != null ) {
396412 if (!nodeAnnotation .primaryLabel ().isEmpty ()) {
397413 return nodeAnnotation .primaryLabel ();
398414 }
399- else if (nodeAnnotation .labels ().length > 0 ) {
415+ // Note: @AliasFor is not applied when reading the annotation via plain reflection.
416+ // We therefore need to check both value() and labels().
417+ if (nodeAnnotation .value ().length > 0 ) {
418+ return nodeAnnotation .value ()[0 ];
419+ }
420+ if (nodeAnnotation .labels ().length > 0 ) {
400421 return nodeAnnotation .labels ()[0 ];
401422 }
402423 }
0 commit comments