1616package org .springframework .data .jdbc .core .convert ;
1717
1818import java .util .*;
19+ import java .util .function .BiFunction ;
1920import java .util .function .Function ;
2021import java .util .function .Predicate ;
2122import java .util .stream .Collectors ;
@@ -115,7 +116,7 @@ public class SqlGenerator {
115116
116117 /**
117118 * Create a basic select structure with all the necessary joins
118- *
119+ *
119120 * @param table the table to base the select on
120121 * @param pathFilter a filter for excluding paths from the select. All paths for which the filter returns
121122 * {@literal true} will be skipped when determining columns to select.
@@ -185,6 +186,8 @@ private Condition getSubselectCondition(AggregatePath path,
185186 Table subSelectTable = Table .create (parentPathTableInfo .qualifiedTableName ());
186187
187188 Map <AggregatePath , Column > selectFilterColumns = new TreeMap <>();
189+
190+ // TODO: cannot we simply pass on the columnInfos?
188191 parentPathTableInfo .effectiveIdColumnInfos ().forEach ( //
189192 (ap , ci ) -> //
190193 selectFilterColumns .put (ap , subSelectTable .column (ci .name ())) //
@@ -468,6 +471,8 @@ String createDeleteAllSql(@Nullable PersistentPropertyPath<RelationalPersistentP
468471 * @return the statement as a {@link String}. Guaranteed to be not {@literal null}.
469472 */
470473 String createDeleteByPath (PersistentPropertyPath <RelationalPersistentProperty > path ) {
474+ // TODO: When deleting by path, why do we expect the where-value to be id and not named after the path?
475+ // See SqlGeneratorEmbeddedUnitTests.deleteByPath
471476 return createDeleteByPathAndCriteria (mappingContext .getAggregatePath (path ), this ::equalityCondition );
472477 }
473478
@@ -487,12 +492,10 @@ String createDeleteInByPath(PersistentPropertyPath<RelationalPersistentProperty>
487492 */
488493 private Condition inCondition (Map <AggregatePath , Column > columnMap ) {
489494
490- List <Column > columns = List . copyOf ( columnMap .values () );
495+ Collection <Column > columns = columnMap .values ();
491496
492- if (columns .size () == 1 ) {
493- return Conditions .in (columns .get (0 ), getBindMarker (IDS_SQL_PARAMETER ));
494- }
495- return Conditions .in (TupleExpression .create (columns ), getBindMarker (IDS_SQL_PARAMETER ));
497+ return Conditions .in (columns .size () == 1 ? columns .iterator ().next () : TupleExpression .create (columns ),
498+ getBindMarker (IDS_SQL_PARAMETER ));
496499 }
497500
498501 /**
@@ -501,44 +504,54 @@ private Condition inCondition(Map<AggregatePath, Column> columnMap) {
501504 */
502505 private Condition equalityCondition (Map <AggregatePath , Column > columnMap ) {
503506
504- AggregatePath . ColumnInfos idColumnInfos = mappingContext . getAggregatePath ( entity ). getTableInfo (). idColumnInfos ( );
507+ Assert . isTrue (! columnMap . isEmpty (), "Column map must not be empty" );
505508
506- Condition result = null ;
507- for (Map .Entry <AggregatePath , Column > entry : columnMap .entrySet ()) {
508- BindMarker bindMarker = getBindMarker (idColumnInfos .get (entry .getKey ()).name ());
509- Comparison singleCondition = entry .getValue ().isEqualTo (bindMarker );
509+ AggregatePath .ColumnInfos idColumnInfos = mappingContext .getAggregatePath (entity ).getTableInfo ().idColumnInfos ();
510510
511- result = result == null ? singleCondition : result .and (singleCondition );
512- }
513- Assert .state (result != null , "We need at least one condition" );
514- return result ;
511+ return createPredicate (columnMap , (aggregatePath , column ) -> {
512+ return column .isEqualTo (getBindMarker (idColumnInfos .get (aggregatePath ).name ()));
513+ });
515514 }
516515
517516 /**
518517 * Constructs a function for constructing where a condition. The where condition will be of the form
519518 * {@literal <column-a> IS NOT NULL AND <column-b> IS NOT NULL ... }
520519 */
521520 private Condition isNotNullCondition (Map <AggregatePath , Column > columnMap ) {
521+ return createPredicate (columnMap , (aggregatePath , column ) -> column .isNotNull ());
522+ }
523+
524+ /**
525+ * Constructs a function for constructing where a condition. The where condition will be of the form
526+ * {@literal <column-a> IS NOT NULL AND <column-b> IS NOT NULL ... }
527+ */
528+ private static Condition createPredicate (Map <AggregatePath , Column > columnMap ,
529+ BiFunction <AggregatePath , Column , Condition > conditionFunction ) {
522530
523531 Condition result = null ;
524- for (Column column : columnMap .values ()) {
525- Condition singleCondition = column .isNotNull ();
532+ for (Map .Entry <AggregatePath , Column > entry : columnMap .entrySet ()) {
526533
534+ Condition singleCondition = conditionFunction .apply (entry .getKey (), entry .getValue ());
527535 result = result == null ? singleCondition : result .and (singleCondition );
528536 }
529537 Assert .state (result != null , "We need at least one condition" );
530538 return result ;
531539 }
532540
533541 private String createFindOneSql () {
534-
535542 return render (selectBuilder ().where (equalityIdWhereCondition ()).build ());
536543 }
537544
538545 private Condition equalityIdWhereCondition () {
546+ return equalityIdWhereCondition (getIdColumns ());
547+ }
548+
549+ private Condition equalityIdWhereCondition (Iterable <Column > columns ) {
550+
551+ Assert .isTrue (columns .iterator ().hasNext (), "Identifier columns must not be empty" );
539552
540553 Condition aggregate = null ;
541- for (Column column : getIdColumns () ) {
554+ for (Column column : columns ) {
542555
543556 Comparison condition = column .isEqualTo (getBindMarker (column .getName ()));
544557 aggregate = aggregate == null ? condition : aggregate .and (condition );
@@ -711,19 +724,13 @@ Join getJoin(AggregatePath path) {
711724 Table parentTable = sqlContext .getTable (idDefiningParentPath );
712725 AggregatePath .ColumnInfos idColumnInfos = idDefiningParentPath .getTableInfo ().idColumnInfos ();
713726
714- final Condition [] joinCondition = { null };
715- backRefColumnInfos .forEach ((ap , ci ) -> {
716-
717- Condition elementalCondition = currentTable .column (ci .name ())
718- .isEqualTo (parentTable .column (idColumnInfos .get (ap ).name ()));
719- joinCondition [0 ] = joinCondition [0 ] == null ? elementalCondition : joinCondition [0 ].and (elementalCondition );
720- });
727+ Condition joinCondition = backRefColumnInfos .reduce (Conditions .unrestricted (), (aggregatePath , columnInfo ) -> {
721728
722- return new Join ( //
723- currentTable , //
724- joinCondition [0 ] //
725- );
729+ return currentTable .column (columnInfo .name ())
730+ .isEqualTo (parentTable .column (idColumnInfos .get (aggregatePath ).name ()));
731+ }, Condition ::and );
726732
733+ return new Join (currentTable , joinCondition );
727734 }
728735
729736 private String createFindAllInListSql () {
@@ -862,6 +869,8 @@ private String createDeleteByPathAndCriteria(AggregatePath path,
862869
863870 Map <AggregatePath , Column > columns = new TreeMap <>();
864871 AggregatePath .ColumnInfos columnInfos = path .getTableInfo ().backReferenceColumnInfos ();
872+
873+ // TODO: cannot we simply pass on the columnInfos?
865874 columnInfos .forEach ((ag , ci ) -> columns .put (ag , table .column (ci .name ())));
866875
867876 if (isFirstNonRoot (path )) {
@@ -915,17 +924,20 @@ private Table getTable() {
915924 */
916925 private Column getSingleNonNullColumn () {
917926
927+ // getColumn() is slightly different from the code in any(…). Why?
928+ // AggregatePath.ColumnInfo columnInfo = path.getColumnInfo();
929+ // return getTable(path).column(columnInfo.name()).as(columnInfo.alias());
930+
918931 AggregatePath .ColumnInfos columnInfos = mappingContext .getAggregatePath (entity ).getTableInfo ().idColumnInfos ();
919932 return columnInfos .any ((ap , ci ) -> sqlContext .getTable (columnInfos .fullPath (ap )).column (ci .name ()).as (ci .alias ()));
920933 }
921934
922935 private List <Column > getIdColumns () {
923936
924937 AggregatePath .ColumnInfos columnInfos = mappingContext .getAggregatePath (entity ).getTableInfo ().idColumnInfos ();
925- List <Column > result = new ArrayList <>(columnInfos .size ());
926- columnInfos .forEach ((ap , ci ) -> result .add (sqlContext .getColumn (columnInfos .fullPath (ap ))));
927938
928- return result ;
939+ return columnInfos
940+ .toColumnList ((aggregatePath , columnInfo ) -> sqlContext .getColumn (columnInfos .fullPath (aggregatePath )));
929941 }
930942
931943 private Column getVersionColumn () {
0 commit comments