2121import reactor .core .publisher .Flux ;
2222import reactor .core .publisher .Mono ;
2323
24- import java .beans .FeatureDescriptor ;
2524import java .util .Collections ;
25+ import java .util .LinkedHashSet ;
2626import java .util .List ;
2727import java .util .Map ;
2828import java .util .Optional ;
29+ import java .util .Set ;
2930import java .util .function .BiFunction ;
3031import java .util .function .Function ;
3132import java .util .stream .Collectors ;
3233
3334import org .reactivestreams .Publisher ;
35+
3436import org .springframework .beans .BeansException ;
3537import org .springframework .beans .factory .BeanFactory ;
3638import org .springframework .beans .factory .BeanFactoryAware ;
4648import org .springframework .data .mapping .callback .ReactiveEntityCallbacks ;
4749import org .springframework .data .mapping .context .MappingContext ;
4850import org .springframework .data .projection .EntityProjection ;
49- import org .springframework .data .projection .ProjectionInformation ;
5051import org .springframework .data .projection .SpelAwareProxyProjectionFactory ;
5152import org .springframework .data .r2dbc .convert .R2dbcConverter ;
5253import org .springframework .data .r2dbc .dialect .DialectResolver ;
5657import org .springframework .data .r2dbc .mapping .event .AfterSaveCallback ;
5758import org .springframework .data .r2dbc .mapping .event .BeforeConvertCallback ;
5859import org .springframework .data .r2dbc .mapping .event .BeforeSaveCallback ;
60+ import org .springframework .data .relational .core .mapping .PersistentPropertyTranslator ;
5961import org .springframework .data .relational .core .mapping .RelationalPersistentEntity ;
6062import org .springframework .data .relational .core .mapping .RelationalPersistentProperty ;
6163import org .springframework .data .relational .core .query .Criteria ;
6870import org .springframework .data .relational .core .sql .SqlIdentifier ;
6971import org .springframework .data .relational .core .sql .Table ;
7072import org .springframework .data .relational .domain .RowDocument ;
73+ import org .springframework .data .util .Predicates ;
7174import org .springframework .data .util .ProxyUtils ;
7275import org .springframework .lang .Nullable ;
7376import org .springframework .r2dbc .core .DatabaseClient ;
@@ -332,7 +335,7 @@ private <T> RowsFetchSpec<T> doSelect(Query query, Class<?> entityType, SqlIdent
332335
333336 StatementMapper .SelectSpec selectSpec = statementMapper //
334337 .createSelect (tableName ) //
335- .doWithTable ((table , spec ) -> spec .withProjection (getSelectProjection (table , query , returnType )));
338+ .doWithTable ((table , spec ) -> spec .withProjection (getSelectProjection (table , query , entityType , returnType )));
336339
337340 if (query .getLimit () > 0 ) {
338341 selectSpec = selectSpec .limit (query .getLimit ());
@@ -423,7 +426,8 @@ public <T> RowsFetchSpec<T> query(PreparedOperation<?> operation, Class<T> entit
423426 }
424427
425428 @ Override
426- public <T > RowsFetchSpec <T > query (PreparedOperation <?> operation , Class <?> entityClass , Class <T > resultType ) throws DataAccessException {
429+ public <T > RowsFetchSpec <T > query (PreparedOperation <?> operation , Class <?> entityClass , Class <T > resultType )
430+ throws DataAccessException {
427431
428432 Assert .notNull (operation , "PreparedOperation must not be null" );
429433 Assert .notNull (entityClass , "Entity class must not be null" );
@@ -759,18 +763,16 @@ private <T> RelationalPersistentEntity<T> getRequiredEntity(T entity) {
759763 return (RelationalPersistentEntity ) getRequiredEntity (entityType );
760764 }
761765
762- private <T > List <Expression > getSelectProjection (Table table , Query query , Class <T > returnType ) {
766+ private <T > List <Expression > getSelectProjection (Table table , Query query , Class <?> entityType , Class < T > returnType ) {
763767
764768 if (query .getColumns ().isEmpty ()) {
765769
766- if (returnType .isInterface ()) {
770+ EntityProjection <T , ?> projection = converter .introspectProjection (returnType , entityType );
771+
772+ if (projection .isProjection () && projection .isClosedProjection ()) {
767773
768- ProjectionInformation projectionInformation = projectionFactory . getProjectionInformation ( returnType );
774+ return computeProjectedFields ( table , returnType , projection );
769775
770- if (projectionInformation .isClosed ()) {
771- return projectionInformation .getInputProperties ().stream ().map (FeatureDescriptor ::getName ).map (table ::column )
772- .collect (Collectors .toList ());
773- }
774776 }
775777
776778 return Collections .singletonList (table .asterisk ());
@@ -779,6 +781,36 @@ private <T> List<Expression> getSelectProjection(Table table, Query query, Class
779781 return query .getColumns ().stream ().map (table ::column ).collect (Collectors .toList ());
780782 }
781783
784+ @ SuppressWarnings ("unchecked" )
785+ private <T > List <Expression > computeProjectedFields (Table table , Class <T > returnType ,
786+ EntityProjection <T , ?> projection ) {
787+
788+ if (returnType .isInterface ()) {
789+
790+ Set <String > properties = new LinkedHashSet <>();
791+ projection .forEach (it -> {
792+ properties .add (it .getPropertyPath ().getSegment ());
793+ });
794+
795+ return properties .stream ().map (table ::column ).collect (Collectors .toList ());
796+ }
797+
798+ Set <SqlIdentifier > properties = new LinkedHashSet <>();
799+ // DTO projections use merged metadata between domain type and result type
800+ PersistentPropertyTranslator translator = PersistentPropertyTranslator .create (
801+ mappingContext .getRequiredPersistentEntity (projection .getDomainType ()),
802+ Predicates .negate (RelationalPersistentProperty ::hasExplicitColumnName ));
803+
804+ RelationalPersistentEntity <?> persistentEntity = mappingContext
805+ .getRequiredPersistentEntity (projection .getMappedType ());
806+ for (RelationalPersistentProperty property : persistentEntity ) {
807+ properties .add (translator .translate (property ).getColumnName ());
808+ }
809+
810+ return properties .stream ().map (table ::column ).collect (Collectors .toList ());
811+ }
812+
813+ @ SuppressWarnings ("unchecked" )
782814 public <T > RowsFetchSpec <T > getRowsFetchSpec (DatabaseClient .GenericExecuteSpec executeSpec , Class <?> entityType ,
783815 Class <T > resultType ) {
784816
@@ -791,13 +823,13 @@ public <T> RowsFetchSpec<T> getRowsFetchSpec(DatabaseClient.GenericExecuteSpec e
791823 } else {
792824
793825 EntityProjection <T , ?> projection = converter .introspectProjection (resultType , entityType );
826+ Class <T > typeToRead = projection .isProjection () ? resultType
827+ : resultType .isInterface () ? (Class <T >) entityType : resultType ;
794828
795829 rowMapper = (row , rowMetadata ) -> {
796830
797- RowDocument document = dataAccessStrategy .toRowDocument (resultType , row , rowMetadata .getColumnMetadatas ());
798-
799- return projection .isProjection () ? converter .project (projection , document )
800- : converter .read (resultType , document );
831+ RowDocument document = dataAccessStrategy .toRowDocument (typeToRead , row , rowMetadata .getColumnMetadatas ());
832+ return converter .project (projection , document );
801833 };
802834 }
803835
0 commit comments