1818import java .lang .reflect .Modifier ;
1919import java .util .Collection ;
2020import java .util .Collections ;
21- import java .util .Comparator ;
2221import java .util .HashMap ;
2322import java .util .HashSet ;
2423import java .util .List ;
2524import java .util .Map ;
26- import java .util .Optional ;
2725import java .util .Set ;
28- import java .util .function .Function ;
29- import java .util .stream .Collectors ;
26+ import java .util .function .BiFunction ;
3027
3128import org .springframework .data .mapping .context .AbstractMappingContext ;
3229import org .springframework .lang .Nullable ;
@@ -46,6 +43,24 @@ final class NodeDescriptionStore {
4643 */
4744 private final Map <String , NodeDescription <?>> nodeDescriptionsByPrimaryLabel = new HashMap <>();
4845
46+ private final Map <NodeDescription <?>, Map <List <String >, NodeDescriptionAndLabels >> nodeDescriptionAndLabelsCache = new HashMap <>();
47+
48+ private final BiFunction <NodeDescription <?>, List <String >, NodeDescriptionAndLabels > nodeDescriptionAndLabels =
49+ (nodeDescription , labels ) -> {
50+ Map <List <String >, NodeDescriptionAndLabels > listNodeDescriptionAndLabelsMap = nodeDescriptionAndLabelsCache .get (nodeDescription );
51+ if (listNodeDescriptionAndLabelsMap == null ) {
52+ nodeDescriptionAndLabelsCache .put (nodeDescription , new HashMap <>());
53+ listNodeDescriptionAndLabelsMap = nodeDescriptionAndLabelsCache .get (nodeDescription );
54+ }
55+
56+ NodeDescriptionAndLabels cachedNodeDescriptionAndLabels = listNodeDescriptionAndLabelsMap .get (labels );
57+ if (cachedNodeDescriptionAndLabels == null ) {
58+ cachedNodeDescriptionAndLabels = computeConcreteNodeDescription (nodeDescription , labels );
59+ listNodeDescriptionAndLabelsMap .put (labels , cachedNodeDescriptionAndLabels );
60+ }
61+ return cachedNodeDescriptionAndLabels ;
62+ };
63+
4964 public boolean containsKey (String primaryLabel ) {
5065 return nodeDescriptionsByPrimaryLabel .containsKey (primaryLabel );
5166 }
@@ -81,7 +96,11 @@ public NodeDescription<?> getNodeDescription(Class<?> targetType) {
8196 return null ;
8297 }
8398
84- public NodeDescriptionAndLabels deriveConcreteNodeDescription (Neo4jPersistentEntity <?> entityDescription , List <String > labels ) {
99+ public NodeDescriptionAndLabels deriveConcreteNodeDescription (NodeDescription <?> entityDescription , List <String > labels ) {
100+ return nodeDescriptionAndLabels .apply (entityDescription , labels );
101+ }
102+
103+ private NodeDescriptionAndLabels computeConcreteNodeDescription (NodeDescription <?> entityDescription , List <String > labels ) {
85104
86105 boolean isConcreteClassThatFulfillsEverything = !Modifier .isAbstract (entityDescription .getUnderlyingClass ().getModifiers ()) && entityDescription .getStaticLabels ().containsAll (labels );
87106
@@ -97,25 +116,48 @@ public NodeDescriptionAndLabels deriveConcreteNodeDescription(Neo4jPersistentEnt
97116 }
98117
99118 if (!haystack .isEmpty ()) {
100- Function <NodeDescription <?>, Integer > count = (nodeDescription ) -> Math .toIntExact (nodeDescription .getStaticLabels ().stream ().filter (labels ::contains ).count ());
101- Optional <Map .Entry <NodeDescription <?>, Integer >> mostMatchingNodeDescription = haystack .stream ()
102- .filter (nd -> labels .containsAll (nd .getStaticLabels ())) // remove candidates having more mandatory labels
103- .collect (Collectors .toMap (Function .identity (), nodeDescription -> count .apply (nodeDescription )))
104- .entrySet ().stream ()
105- .max (Comparator .comparingInt (Map .Entry ::getValue ));
106-
107- if (mostMatchingNodeDescription .isPresent ()) {
108- NodeDescription <?> childNodeDescription = mostMatchingNodeDescription .get ().getKey ();
109- List <String > staticLabels = childNodeDescription .getStaticLabels ();
110- Set <String > surplusLabels = new HashSet <>(labels );
111- surplusLabels .removeAll (staticLabels );
112- return new NodeDescriptionAndLabels (childNodeDescription , surplusLabels );
119+
120+ NodeDescription <?> mostMatchingNodeDescription = null ;
121+ Map <NodeDescription <?>, Integer > unmatchedLabelsCache = new HashMap <>();
122+ List <String > mostMatchingStaticLabels = null ;
123+
124+ // Remove is faster than "stream, filter, count".
125+ BiFunction <NodeDescription <?>, List <String >, Integer > unmatchedLabelsCount =
126+ (nodeDescription , staticLabels ) -> {
127+ Set <String > staticLabelsClone = new HashSet <>(staticLabels );
128+ labels .forEach (staticLabelsClone ::remove );
129+ return staticLabelsClone .size ();
130+ };
131+
132+ for (NodeDescription <?> nd : haystack ) {
133+ List <String > staticLabels = nd .getStaticLabels ();
134+
135+ if (staticLabels .containsAll (labels )) {
136+ Set <String > surplusLabels = new HashSet <>(labels );
137+ staticLabels .forEach (surplusLabels ::remove );
138+ return new NodeDescriptionAndLabels (nd , surplusLabels );
139+ }
140+
141+ unmatchedLabelsCache .put (nd , unmatchedLabelsCount .apply (nd , staticLabels ));
142+ if (mostMatchingNodeDescription == null ) {
143+ mostMatchingNodeDescription = nd ;
144+ mostMatchingStaticLabels = staticLabels ;
145+ continue ;
146+ }
147+
148+ if (unmatchedLabelsCache .get (nd ) < unmatchedLabelsCache .get (mostMatchingNodeDescription )) {
149+ mostMatchingNodeDescription = nd ;
150+ }
113151 }
152+
153+ Set <String > surplusLabels = new HashSet <>(labels );
154+ mostMatchingStaticLabels .forEach (surplusLabels ::remove );
155+ return new NodeDescriptionAndLabels (mostMatchingNodeDescription , surplusLabels );
114156 }
115157
116158 Set <String > surplusLabels = new HashSet <>(labels );
117159 surplusLabels .remove (entityDescription .getPrimaryLabel ());
118- surplusLabels . removeAll ( entityDescription .getAdditionalLabels ());
160+ entityDescription .getAdditionalLabels (). forEach ( surplusLabels :: remove );
119161 return new NodeDescriptionAndLabels (entityDescription , surplusLabels );
120162 }
121163}
0 commit comments