1616package org .springframework .data .jdbc .core .convert ;
1717
1818import java .util .*;
19+ import java .util .function .BiFunction ;
1920import java .util .function .Function ;
2021import java .util .function .Predicate ;
2122import java .util .stream .Collectors ;
@@ -118,7 +119,7 @@ public class SqlGenerator {
118119
119120 /**
120121 * Create a basic select structure with all the necessary joins
121- *
122+ *
122123 * @param table the table to base the select on
123124 * @param pathFilter a filter for excluding paths from the select. All paths for which the filter returns
124125 * {@literal true} will be skipped when determining columns to select.
@@ -188,6 +189,8 @@ private Condition getSubselectCondition(AggregatePath path,
188189 Table subSelectTable = Table .create (parentPathTableInfo .qualifiedTableName ());
189190
190191 Map <AggregatePath , Column > selectFilterColumns = new TreeMap <>();
192+
193+ // TODO: cannot we simply pass on the columnInfos?
191194 parentPathTableInfo .effectiveIdColumnInfos ().forEach ( //
192195 (ap , ci ) -> //
193196 selectFilterColumns .put (ap , subSelectTable .column (ci .name ())) //
@@ -471,6 +474,8 @@ String createDeleteAllSql(@Nullable PersistentPropertyPath<RelationalPersistentP
471474 * @return the statement as a {@link String}. Guaranteed to be not {@literal null}.
472475 */
473476 String createDeleteByPath (PersistentPropertyPath <RelationalPersistentProperty > path ) {
477+ // TODO: When deleting by path, why do we expect the where-value to be id and not named after the path?
478+ // See SqlGeneratorEmbeddedUnitTests.deleteByPath
474479 return createDeleteByPathAndCriteria (mappingContext .getAggregatePath (path ), this ::equalityCondition );
475480 }
476481
@@ -490,12 +495,10 @@ String createDeleteInByPath(PersistentPropertyPath<RelationalPersistentProperty>
490495 */
491496 private Condition inCondition (Map <AggregatePath , Column > columnMap ) {
492497
493- List <Column > columns = List . copyOf ( columnMap .values () );
498+ Collection <Column > columns = columnMap .values ();
494499
495- if (columns .size () == 1 ) {
496- return Conditions .in (columns .get (0 ), getBindMarker (IDS_SQL_PARAMETER ));
497- }
498- return Conditions .in (TupleExpression .create (columns ), getBindMarker (IDS_SQL_PARAMETER ));
500+ return Conditions .in (columns .size () == 1 ? columns .iterator ().next () : TupleExpression .create (columns ),
501+ getBindMarker (IDS_SQL_PARAMETER ));
499502 }
500503
501504 /**
@@ -504,44 +507,54 @@ private Condition inCondition(Map<AggregatePath, Column> columnMap) {
504507 */
505508 private Condition equalityCondition (Map <AggregatePath , Column > columnMap ) {
506509
507- AggregatePath . ColumnInfos idColumnInfos = mappingContext . getAggregatePath ( entity ). getTableInfo (). idColumnInfos ( );
510+ Assert . isTrue (! columnMap . isEmpty (), "Column map must not be empty" );
508511
509- Condition result = null ;
510- for (Map .Entry <AggregatePath , Column > entry : columnMap .entrySet ()) {
511- BindMarker bindMarker = getBindMarker (idColumnInfos .get (entry .getKey ()).name ());
512- Comparison singleCondition = entry .getValue ().isEqualTo (bindMarker );
512+ AggregatePath .ColumnInfos idColumnInfos = mappingContext .getAggregatePath (entity ).getTableInfo ().idColumnInfos ();
513513
514- result = result == null ? singleCondition : result .and (singleCondition );
515- }
516- Assert .state (result != null , "We need at least one condition" );
517- return result ;
514+ return createPredicate (columnMap , (aggregatePath , column ) -> {
515+ return column .isEqualTo (getBindMarker (idColumnInfos .get (aggregatePath ).name ()));
516+ });
518517 }
519518
520519 /**
521520 * Constructs a function for constructing where a condition. The where condition will be of the form
522521 * {@literal <column-a> IS NOT NULL AND <column-b> IS NOT NULL ... }
523522 */
524523 private Condition isNotNullCondition (Map <AggregatePath , Column > columnMap ) {
524+ return createPredicate (columnMap , (aggregatePath , column ) -> column .isNotNull ());
525+ }
526+
527+ /**
528+ * Constructs a function for constructing where a condition. The where condition will be of the form
529+ * {@literal <column-a> IS NOT NULL AND <column-b> IS NOT NULL ... }
530+ */
531+ private static Condition createPredicate (Map <AggregatePath , Column > columnMap ,
532+ BiFunction <AggregatePath , Column , Condition > conditionFunction ) {
525533
526534 Condition result = null ;
527- for (Column column : columnMap .values ()) {
528- Condition singleCondition = column .isNotNull ();
535+ for (Map .Entry <AggregatePath , Column > entry : columnMap .entrySet ()) {
529536
537+ Condition singleCondition = conditionFunction .apply (entry .getKey (), entry .getValue ());
530538 result = result == null ? singleCondition : result .and (singleCondition );
531539 }
532540 Assert .state (result != null , "We need at least one condition" );
533541 return result ;
534542 }
535543
536544 private String createFindOneSql () {
537-
538545 return render (selectBuilder ().where (equalityIdWhereCondition ()).build ());
539546 }
540547
541548 private Condition equalityIdWhereCondition () {
549+ return equalityIdWhereCondition (getIdColumns ());
550+ }
551+
552+ private Condition equalityIdWhereCondition (Iterable <Column > columns ) {
553+
554+ Assert .isTrue (columns .iterator ().hasNext (), "Identifier columns must not be empty" );
542555
543556 Condition aggregate = null ;
544- for (Column column : getIdColumns () ) {
557+ for (Column column : columns ) {
545558
546559 Comparison condition = column .isEqualTo (getBindMarker (column .getName ()));
547560 aggregate = aggregate == null ? condition : aggregate .and (condition );
@@ -766,19 +779,13 @@ Join getJoin(AggregatePath path) {
766779 Table parentTable = sqlContext .getTable (idDefiningParentPath );
767780 AggregatePath .ColumnInfos idColumnInfos = idDefiningParentPath .getTableInfo ().idColumnInfos ();
768781
769- final Condition [] joinCondition = { null };
770- backRefColumnInfos .forEach ((ap , ci ) -> {
771-
772- Condition elementalCondition = currentTable .column (ci .name ())
773- .isEqualTo (parentTable .column (idColumnInfos .get (ap ).name ()));
774- joinCondition [0 ] = joinCondition [0 ] == null ? elementalCondition : joinCondition [0 ].and (elementalCondition );
775- });
782+ Condition joinCondition = backRefColumnInfos .reduce (Conditions .unrestricted (), (aggregatePath , columnInfo ) -> {
776783
777- return new Join ( //
778- currentTable , //
779- joinCondition [0 ] //
780- );
784+ return currentTable .column (columnInfo .name ())
785+ .isEqualTo (parentTable .column (idColumnInfos .get (aggregatePath ).name ()));
786+ }, Condition ::and );
781787
788+ return new Join (currentTable , joinCondition );
782789 }
783790
784791 private String createFindAllInListSql () {
@@ -917,6 +924,8 @@ private String createDeleteByPathAndCriteria(AggregatePath path,
917924
918925 Map <AggregatePath , Column > columns = new TreeMap <>();
919926 AggregatePath .ColumnInfos columnInfos = path .getTableInfo ().backReferenceColumnInfos ();
927+
928+ // TODO: cannot we simply pass on the columnInfos?
920929 columnInfos .forEach ((ag , ci ) -> columns .put (ag , table .column (ci .name ())));
921930
922931 if (isFirstNonRoot (path )) {
@@ -970,17 +979,20 @@ private Table getTable() {
970979 */
971980 private Column getSingleNonNullColumn () {
972981
982+ // getColumn() is slightly different from the code in any(…). Why?
983+ // AggregatePath.ColumnInfo columnInfo = path.getColumnInfo();
984+ // return getTable(path).column(columnInfo.name()).as(columnInfo.alias());
985+
973986 AggregatePath .ColumnInfos columnInfos = mappingContext .getAggregatePath (entity ).getTableInfo ().idColumnInfos ();
974987 return columnInfos .any ((ap , ci ) -> sqlContext .getTable (columnInfos .fullPath (ap )).column (ci .name ()).as (ci .alias ()));
975988 }
976989
977990 private List <Column > getIdColumns () {
978991
979992 AggregatePath .ColumnInfos columnInfos = mappingContext .getAggregatePath (entity ).getTableInfo ().idColumnInfos ();
980- List <Column > result = new ArrayList <>(columnInfos .size ());
981- columnInfos .forEach ((ap , ci ) -> result .add (sqlContext .getColumn (columnInfos .fullPath (ap ))));
982993
983- return result ;
994+ return columnInfos
995+ .toColumnList ((aggregatePath , columnInfo ) -> sqlContext .getColumn (columnInfos .fullPath (aggregatePath )));
984996 }
985997
986998 private Column getVersionColumn () {
0 commit comments