Skip to content

Commit ea28dca

Browse files
committed
GH-2216 - Allow concrete class inheritance.
Closes #2216
1 parent 5fe0621 commit ea28dca

File tree

7 files changed

+50
-24
lines changed

7 files changed

+50
-24
lines changed

src/main/java/org/springframework/data/neo4j/core/mapping/Neo4jMappingContext.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
import org.springframework.data.neo4j.core.convert.Neo4jPersistentPropertyConverterFactory;
4949
import org.springframework.data.neo4j.core.convert.Neo4jSimpleTypes;
5050
import org.springframework.data.neo4j.core.schema.IdGenerator;
51+
import org.springframework.data.neo4j.core.schema.Node;
5152
import org.springframework.data.util.ReflectionUtils;
5253
import org.springframework.data.util.TypeInformation;
5354
import org.springframework.lang.Nullable;
@@ -209,11 +210,13 @@ protected <T> Neo4jPersistentEntity<?> createPersistentEntity(TypeInformation<T>
209210
}
210211

211212
private static boolean isValidParentNode(@Nullable Class<?> parentClass) {
212-
if (parentClass == null) {
213+
if (parentClass == null || parentClass.equals(Object.class)) {
213214
return false;
214215
}
215216

216-
return Modifier.isAbstract(parentClass.getModifiers());
217+
// Either a concrete class explicitly annotated as Node or an abstract class
218+
return Modifier.isAbstract(parentClass.getModifiers()) ||
219+
parentClass.isAnnotationPresent(Node.class);
217220
}
218221

219222
/*

src/main/java/org/springframework/data/neo4j/core/mapping/NodeDescriptionStore.java

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,18 @@
1515
*/
1616
package org.springframework.data.neo4j.core.mapping;
1717

18+
import java.lang.reflect.Modifier;
1819
import java.util.Collection;
1920
import java.util.Collections;
21+
import java.util.Comparator;
2022
import java.util.HashMap;
2123
import java.util.HashSet;
2224
import java.util.List;
2325
import java.util.Map;
26+
import java.util.Optional;
2427
import java.util.Set;
25-
import java.util.function.BiFunction;
28+
import java.util.function.Function;
29+
import java.util.stream.Collectors;
2630

2731
import org.springframework.data.mapping.context.AbstractMappingContext;
2832
import org.springframework.lang.Nullable;
@@ -79,23 +83,30 @@ public NodeDescription<?> getNodeDescription(Class<?> targetType) {
7983

8084
public NodeDescriptionAndLabels deriveConcreteNodeDescription(Neo4jPersistentEntity<?> entityDescription, List<String> labels) {
8185

82-
if (labels == null || labels.isEmpty()) {
86+
boolean isConcreteClassThatFulfillsEverything = !Modifier.isAbstract(entityDescription.getUnderlyingClass().getModifiers()) && entityDescription.getStaticLabels().containsAll(labels);
87+
88+
if (labels == null || labels.isEmpty() || isConcreteClassThatFulfillsEverything) {
8389
return new NodeDescriptionAndLabels(entityDescription, Collections.emptyList());
8490
}
8591

8692
Collection<NodeDescription<?>> haystack;
87-
BiFunction<List<String>, NodeDescription<?>, Boolean> selector;
8893
if (entityDescription.describesInterface()) {
8994
haystack = this.values();
90-
selector = (staticLabels, other) -> staticLabels.containsAll(labels) && entityDescription.getType().isAssignableFrom(((Neo4jPersistentEntity<?>) other).getType());
9195
} else {
9296
haystack = entityDescription.getChildNodeDescriptionsInHierarchy();
93-
selector = (staticLabels, other) -> staticLabels.containsAll(labels) && other.getChildNodeDescriptionsInHierarchy().isEmpty();
9497
}
9598

96-
for (NodeDescription<?> childNodeDescription : haystack) {
97-
List<String> staticLabels = childNodeDescription.getStaticLabels();
98-
if (selector.apply(staticLabels, childNodeDescription)) {
99+
if (!haystack.isEmpty()) {
100+
Function<NodeDescription<?>, Integer> count = (nodeDescription) -> Math.toIntExact(nodeDescription.getStaticLabels().stream().filter(labels::contains).count());
101+
Optional<Map.Entry<NodeDescription<?>, Integer>> mostMatchingNodeDescription = haystack.stream()
102+
.filter(nd -> labels.containsAll(nd.getStaticLabels())) // remove candidates having more mandatory labels
103+
.collect(Collectors.toMap(Function.identity(), nodeDescription -> count.apply(nodeDescription)))
104+
.entrySet().stream()
105+
.max(Comparator.comparingInt(Map.Entry::getValue));
106+
107+
if (mostMatchingNodeDescription.isPresent()) {
108+
NodeDescription<?> childNodeDescription = mostMatchingNodeDescription.get().getKey();
109+
List<String> staticLabels = childNodeDescription.getStaticLabels();
99110
Set<String> surplusLabels = new HashSet<>(labels);
100111
surplusLabels.removeAll(staticLabels);
101112
return new NodeDescriptionAndLabels(childNodeDescription, surplusLabels);

src/test/java/org/springframework/data/neo4j/integration/imperative/DynamicLabelsIT.java

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,14 @@
4444
import org.springframework.context.annotation.Bean;
4545
import org.springframework.context.annotation.Configuration;
4646
import org.springframework.data.neo4j.config.AbstractNeo4jConfig;
47+
import org.springframework.data.neo4j.config.Neo4jEntityScanner;
4748
import org.springframework.data.neo4j.core.DatabaseSelectionProvider;
4849
import org.springframework.data.neo4j.core.Neo4jTemplate;
50+
import org.springframework.data.neo4j.core.convert.Neo4jConversions;
51+
import org.springframework.data.neo4j.core.mapping.Neo4jMappingContext;
4952
import org.springframework.data.neo4j.core.transaction.Neo4jBookmarkManager;
5053
import org.springframework.data.neo4j.core.transaction.Neo4jTransactionManager;
54+
import org.springframework.data.neo4j.integration.shared.common.EntitiesWithDynamicLabels;
5155
import org.springframework.data.neo4j.integration.shared.common.EntitiesWithDynamicLabels.DynamicLabelsWithMultipleNodeLabels;
5256
import org.springframework.data.neo4j.integration.shared.common.EntitiesWithDynamicLabels.DynamicLabelsWithNodeLabel;
5357
import org.springframework.data.neo4j.integration.shared.common.EntitiesWithDynamicLabels.ExtendedBaseClass1;
@@ -83,7 +87,7 @@ class EntityWithSingleStaticLabelAndGeneratedId extends SpringTestBase {
8387
@Override
8488
Long createTestEntity(Transaction transaction) {
8589
Record r = transaction
86-
.run("CREATE (e:SimpleDynamicLabels:Foo:Bar:Baz:Foobar) RETURN id(e) as existingEntityId").single();
90+
.run("CREATE (e:InheritedSimpleDynamicLabels:SimpleDynamicLabels:Foo:Bar:Baz:Foobar) RETURN id(e) as existingEntityId").single();
8791
long newId = r.get("existingEntityId").asLong();
8892
transaction.commit();
8993
return newId;
@@ -108,7 +112,7 @@ void shouldUpdateDynamicLabels(@Autowired Neo4jTemplate template) {
108112
});
109113

110114
List<String> labels = getLabels(existingEntityId);
111-
assertThat(labels).containsExactlyInAnyOrder("SimpleDynamicLabels", "Fizz", "Bar", "Baz", "Foobar");
115+
assertThat(labels).containsExactlyInAnyOrder("SimpleDynamicLabels", "InheritedSimpleDynamicLabels", "Fizz", "Bar", "Baz", "Foobar");
112116
}
113117

114118
@Test
@@ -151,7 +155,7 @@ class EntityWithInheritedDynamicLabels extends SpringTestBase {
151155
@Override
152156
Long createTestEntity(Transaction transaction) {
153157
Record r = transaction
154-
.run("CREATE (e:InheritedSimpleDynamicLabels:Foo:Bar:Baz:Foobar) RETURN id(e) as existingEntityId")
158+
.run("CREATE (e:InheritedSimpleDynamicLabels:SimpleDynamicLabels:Foo:Bar:Baz:Foobar) RETURN id(e) as existingEntityId")
155159
.single();
156160
long newId = r.get("existingEntityId").asLong();
157161
transaction.commit();
@@ -179,7 +183,7 @@ void shouldUpdateDynamicLabels(@Autowired Neo4jTemplate template) {
179183
});
180184

181185
List<String> labels = getLabels(existingEntityId);
182-
assertThat(labels).containsExactlyInAnyOrder("InheritedSimpleDynamicLabels", "Fizz", "Bar", "Baz", "Foobar");
186+
assertThat(labels).containsExactlyInAnyOrder("SimpleDynamicLabels", "InheritedSimpleDynamicLabels", "Fizz", "Bar", "Baz", "Foobar");
183187
}
184188

185189
@Test
@@ -195,7 +199,7 @@ void shouldWriteDynamicLabels(@Autowired Neo4jTemplate template) {
195199
});
196200

197201
List<String> labels = getLabels(id);
198-
assertThat(labels).containsExactlyInAnyOrder("InheritedSimpleDynamicLabels", "A", "B", "C");
202+
assertThat(labels).containsExactlyInAnyOrder("SimpleDynamicLabels", "InheritedSimpleDynamicLabels", "A", "B", "C");
199203
}
200204
}
201205

@@ -507,6 +511,14 @@ public PlatformTransactionManager transactionManager(Driver driver, DatabaseSele
507511
public TransactionTemplate transactionTemplate(PlatformTransactionManager transactionManager) {
508512
return new TransactionTemplate(transactionManager);
509513
}
514+
515+
@Bean
516+
public Neo4jMappingContext neo4jMappingContext(Neo4jConversions neo4JConversions) throws ClassNotFoundException {
517+
518+
Neo4jMappingContext mappingContext = new Neo4jMappingContext(neo4JConversions);
519+
mappingContext.setInitialEntitySet(Neo4jEntityScanner.get().scan(EntitiesWithDynamicLabels.class.getPackage().getName()));
520+
return mappingContext;
521+
}
510522
}
511523
}
512524
}

src/test/java/org/springframework/data/neo4j/integration/properties/PropertyIT.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ void assignedIdWithVersionShouldNotOverwriteUnknownProperties() {
8383

8484
try (Session session = driver.session(bookmarkCapture.createSessionConfig())) {
8585
session.run(
86-
"CREATE (m:SimplePropertyContainerWithVersion {id: 'id1', version: 1, knownProperty: 'A', unknownProperty: 'Mr. X'}) RETURN id(m)")
86+
"CREATE (m:SimplePropertyContainerWithVersion:SimplePropertyContainer {id: 'id1', version: 1, knownProperty: 'A', unknownProperty: 'Mr. X'}) RETURN id(m)")
8787
.consume();
8888
bookmarkCapture.seedWith(session.lastBookmark());
8989
}
@@ -111,7 +111,7 @@ void generatedIdWithVersionShouldNotOverwriteUnknownProperties() {
111111
Long id;
112112
try (Session session = driver.session(bookmarkCapture.createSessionConfig())) {
113113
id = session
114-
.run("CREATE (m:SimpleGeneratedIDPropertyContainerWithVersion {version: 1, knownProperty: 'A', unknownProperty: 'Mr. X'}) RETURN id(m)")
114+
.run("CREATE (m:SimpleGeneratedIDPropertyContainerWithVersion:SimpleGeneratedIDPropertyContainer {version: 1, knownProperty: 'A', unknownProperty: 'Mr. X'}) RETURN id(m)")
115115
.single().get(0).asLong();
116116
bookmarkCapture.seedWith(session.lastBookmark());
117117
}

src/test/java/org/springframework/data/neo4j/integration/properties/ReactivePropertyIT.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ void assignedIdWithVersionShouldNotOverwriteUnknownProperties() {
8585

8686
try (Session session = driver.session(bookmarkCapture.createSessionConfig())) {
8787
session.run(
88-
"CREATE (m:SimplePropertyContainerWithVersion {id: 'id1', version: 1, knownProperty: 'A', unknownProperty: 'Mr. X'}) RETURN id(m)")
88+
"CREATE (m:SimplePropertyContainer:SimplePropertyContainerWithVersion {id: 'id1', version: 1, knownProperty: 'A', unknownProperty: 'Mr. X'}) RETURN id(m)")
8989
.consume();
9090
bookmarkCapture.seedWith(session.lastBookmark());
9191
}
@@ -113,7 +113,7 @@ void generatedIdWithVersionShouldNotOverwriteUnknownProperties() {
113113
Long id;
114114
try (Session session = driver.session(bookmarkCapture.createSessionConfig())) {
115115
id = session
116-
.run("CREATE (m:SimpleGeneratedIDPropertyContainerWithVersion {version: 1, knownProperty: 'A', unknownProperty: 'Mr. X'}) RETURN id(m)")
116+
.run("CREATE (m:SimpleGeneratedIDPropertyContainer:SimpleGeneratedIDPropertyContainerWithVersion {version: 1, knownProperty: 'A', unknownProperty: 'Mr. X'}) RETURN id(m)")
117117
.single().get(0).asLong();
118118
bookmarkCapture.seedWith(session.lastBookmark());
119119
}

src/test/java/org/springframework/data/neo4j/integration/reactive/ReactiveDynamicLabelsIT.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ class EntityWithInheritedDynamicLabels extends SpringTestBase {
151151
@Override
152152
Long createTestEntity(Transaction transaction) {
153153
Record r = transaction
154-
.run("" + "CREATE (e:InheritedSimpleDynamicLabels:Foo:Bar:Baz:Foobar) " + "RETURN id(e) as existingEntityId")
154+
.run("" + "CREATE (e:SimpleDynamicLabels:InheritedSimpleDynamicLabels:Foo:Bar:Baz:Foobar) " + "RETURN id(e) as existingEntityId")
155155
.single();
156156
long newId = r.get("existingEntityId").asLong();
157157
transaction.commit();
@@ -175,7 +175,7 @@ void shouldUpdateDynamicLabels(@Autowired ReactiveNeo4jTemplate template) {
175175
return template.save(entity);
176176
}).as(transactionalOperator::transactional)
177177
.thenMany(getLabels(existingEntityId)).sort().as(StepVerifier::create)
178-
.expectNext("Bar", "Baz", "Fizz", "Foobar", "InheritedSimpleDynamicLabels").verifyComplete();
178+
.expectNext("Bar", "Baz", "Fizz", "Foobar", "InheritedSimpleDynamicLabels", "SimpleDynamicLabels").verifyComplete();
179179
}
180180

181181
@Test
@@ -190,7 +190,7 @@ void shouldWriteDynamicLabels(@Autowired ReactiveNeo4jTemplate template) {
190190
template.save(entity).map(SimpleDynamicLabels::getId)
191191
.as(transactionalOperator::transactional)
192192
.flatMapMany(this::getLabels).sort().as(StepVerifier::create)
193-
.expectNext("A", "B", "C", "InheritedSimpleDynamicLabels").verifyComplete();
193+
.expectNext("A", "B", "C", "InheritedSimpleDynamicLabels", "SimpleDynamicLabels").verifyComplete();
194194
}
195195
}
196196

src/test/kotlin/org/springframework/data/neo4j/integration/imperative/KotlinInheritanceIT.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ class KotlinInheritanceIT @Autowired constructor(
153153
// Note: The open base class used here is not abstract, there fore labels are not inherited
154154
val cnt = driver.session(bookmarkCapture.createSessionConfig()).use { session ->
155155
session.readTransaction { tx ->
156-
tx.run("MATCH (t:ConcreteNodeWithOpenKotlinBase) WHERE NOT t:OpenKotlinBase AND id(t) = \$id RETURN count(t)", mapOf("id" to thing.id)).single()[0].asLong()
156+
tx.run("MATCH (t:ConcreteNodeWithOpenKotlinBase:OpenKotlinBase) WHERE id(t) = \$id RETURN count(t)", mapOf("id" to thing.id)).single()[0].asLong()
157157
}
158158
}
159159
assertThat(cnt).isEqualTo(1L)
@@ -185,7 +185,7 @@ class KotlinInheritanceIT @Autowired constructor(
185185
// Note: The open base class used here is not abstract, there fore labels are not inherited
186186
val cnt = driver.session(bookmarkCapture.createSessionConfig()).use { session ->
187187
session.readTransaction { tx ->
188-
tx.run("MATCH (t:ConcreteDataNodeWithOpenKotlinBase) WHERE NOT t:OpenKotlinBase AND id(t) = \$id RETURN count(t)", mapOf("id" to thing.id)).single()[0].asLong()
188+
tx.run("MATCH (t:ConcreteDataNodeWithOpenKotlinBase) WHERE id(t) = \$id RETURN count(t)", mapOf("id" to thing.id)).single()[0].asLong()
189189
}
190190
}
191191
assertThat(cnt).isEqualTo(1L)

0 commit comments

Comments
 (0)