1717package org .springframework .integration .r2dbc .inbound ;
1818
1919
20+ import java .util .HashMap ;
2021import java .util .Map ;
22+ import java .util .function .BiFunction ;
2123
2224import org .reactivestreams .Publisher ;
2325
3032import org .springframework .expression .spel .support .StandardTypeLocator ;
3133import org .springframework .integration .endpoint .AbstractMessageSource ;
3234import org .springframework .integration .expression .ExpressionUtils ;
35+ import org .springframework .r2dbc .core .ColumnMapRowMapper ;
36+ import org .springframework .r2dbc .core .DatabaseClient ;
3337import org .springframework .r2dbc .core .RowsFetchSpec ;
3438import org .springframework .util .Assert ;
3539
40+ import io .r2dbc .spi .Row ;
41+ import io .r2dbc .spi .RowMetadata ;
3642import reactor .core .publisher .Mono ;
3743
3844/**
5460 */
5561public class R2dbcMessageSource extends AbstractMessageSource <Publisher <?>> {
5662
57- private final R2dbcEntityOperations r2dbcEntityOperations ;
63+ private final DatabaseClient databaseClient ;
5864
5965 private final ReactiveDataAccessStrategy dataAccessStrategy ;
6066
6167 private final Expression queryExpression ;
6268
6369 private Class <?> payloadType = Map .class ;
6470
71+ private BiFunction <Row , RowMetadata , ?> rowMapper = ColumnMapRowMapper .INSTANCE ;
72+
6573 private boolean expectSingleResult = false ;
6674
6775 private StandardEvaluationContext evaluationContext ;
6876
77+ private String updateSql ;
78+
79+ private BiFunction <DatabaseClient .GenericExecuteSpec , Object , DatabaseClient .GenericExecuteSpec > bindFunction ;
80+
6981 private volatile boolean initialized = false ;
7082
7183 /**
@@ -91,8 +103,8 @@ public R2dbcMessageSource(R2dbcEntityOperations r2dbcEntityOperations, String qu
91103 public R2dbcMessageSource (R2dbcEntityOperations r2dbcEntityOperations , Expression queryExpression ) {
92104 Assert .notNull (r2dbcEntityOperations , "'r2dbcEntityOperations' must not be null" );
93105 Assert .notNull (queryExpression , "'queryExpression' must not be null" );
94- this .r2dbcEntityOperations = r2dbcEntityOperations ;
95- this .dataAccessStrategy = this . r2dbcEntityOperations .getDataAccessStrategy ();
106+ this .databaseClient = r2dbcEntityOperations . getDatabaseClient () ;
107+ this .dataAccessStrategy = r2dbcEntityOperations .getDataAccessStrategy ();
96108 this .queryExpression = queryExpression ;
97109 }
98110
@@ -108,10 +120,33 @@ public void setPayloadType(Class<?> payloadType) {
108120 }
109121
110122 /**
111- * Provide a way to return all the records matching criteria or only and only a one otherwise.
123+ * Provide a way to set update query that will be passed to the
124+ * {@link org.springframework.data.r2dbc.core.DatabaseClient#execute(String)}
125+ * method.
126+ * @param updateSql Update query string.
127+ */
128+ public void setUpdateSql (String updateSql ) {
129+ this .updateSql = updateSql ;
130+ }
131+
132+ /**
133+ * Provide a way to set BindFunction which will be used to bind parameters
134+ * in the update query.
135+ * @param bindFunction The bindFunction.
136+ */
137+ @ SuppressWarnings ("unchecked" )
138+ public void setBindFunction (
139+ BiFunction <DatabaseClient .GenericExecuteSpec , ?, DatabaseClient .GenericExecuteSpec > bindFunction ) {
140+
141+ this .bindFunction =
142+ (BiFunction <DatabaseClient .GenericExecuteSpec , Object , DatabaseClient .GenericExecuteSpec >) bindFunction ;
143+ }
144+
145+ /**
146+ * Provide a way to manage which find* method to invoke on {@link R2dbcEntityOperations}.
112147 * Default is 'false', which means the {@link #receive()} method will use
113- * the {@link org.springframework.data.r2dbc.core. DatabaseClient#execute (String)} method and will fetch all. If set
114- * to 'true'{@link #receive()} will use {@link org.springframework.data.r2dbc.core. DatabaseClient#execute (String)}
148+ * the {@link DatabaseClient#sql (String)} method and will fetch all. If set
149+ * to 'true'{@link #receive()} will use {@link DatabaseClient#sql (String)}
115150 * and will fetch one and the payload of the returned {@link org.springframework.messaging.Message}
116151 * will be the returned target Object of type
117152 * identified by {@link #payloadType} instead of a List.
@@ -136,6 +171,9 @@ protected void onInit() {
136171 */
137172 ((StandardTypeLocator ) typeLocator ).registerImport ("org.springframework.data.relational.core.query" );
138173 }
174+ if (!Map .class .isAssignableFrom (this .payloadType )) {
175+ this .rowMapper = this .dataAccessStrategy .getRowMapper (this .payloadType );
176+ }
139177 this .initialized = true ;
140178 }
141179
@@ -155,17 +193,31 @@ protected Object doReceive() {
155193 Mono .fromSupplier (() -> this .queryExpression .getValue (this .evaluationContext ))
156194 .map (this ::prepareFetch );
157195 if (this .expectSingleResult ) {
158- return queryMono .flatMap (RowsFetchSpec ::one );
196+ return queryMono .flatMap (RowsFetchSpec ::one )
197+ .flatMap (this ::executeUpdate );
198+ }
199+
200+ return queryMono .flatMapMany (RowsFetchSpec ::all )
201+ .flatMap (this ::executeUpdate );
202+ }
203+
204+ private Mono <Object > executeUpdate (Object result ) {
205+ if (this .updateSql != null ) {
206+ DatabaseClient .GenericExecuteSpec genericExecuteSpec = this .databaseClient .sql (this .updateSql );
207+ if (this .bindFunction != null ) {
208+ genericExecuteSpec = this .bindFunction .apply (genericExecuteSpec , result );
209+ }
210+ return genericExecuteSpec .then ()
211+ .thenReturn (result );
159212 }
160- return queryMono . flatMapMany ( RowsFetchSpec :: all );
213+ return Mono . just ( result );
161214 }
162215
163216 private RowsFetchSpec <?> prepareFetch (Object queryObject ) {
164217 String queryString = evaluateQueryObject (queryObject );
165- return this .r2dbcEntityOperations
166- .getDatabaseClient ()
218+ return this .databaseClient
167219 .sql (queryString )
168- .map (this .dataAccessStrategy . getRowMapper ( this . payloadType ) );
220+ .map (this .rowMapper );
169221 }
170222
171223 private String evaluateQueryObject (Object queryObject ) {
0 commit comments