Skip to content

Commit b54085c

Browse files
rohanmukeshartembilan
authored andcommitted
INT-4566: UPDATE for R2DBC In Channel Adapter
JIRA: https://jira.spring.io/browse/INT-4566 * Rework UPDATE logic according deprecations * Use `ColumnMapRowMapper` for default `Map` payload type * Clean up tests
1 parent 17fc4ea commit b54085c

File tree

2 files changed

+209
-14
lines changed

2 files changed

+209
-14
lines changed

spring-integration-r2dbc/src/main/java/org/springframework/integration/r2dbc/inbound/R2dbcMessageSource.java

Lines changed: 63 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
package org.springframework.integration.r2dbc.inbound;
1818

1919

20+
import java.util.HashMap;
2021
import java.util.Map;
22+
import java.util.function.BiFunction;
2123

2224
import org.reactivestreams.Publisher;
2325

@@ -30,9 +32,13 @@
3032
import org.springframework.expression.spel.support.StandardTypeLocator;
3133
import org.springframework.integration.endpoint.AbstractMessageSource;
3234
import org.springframework.integration.expression.ExpressionUtils;
35+
import org.springframework.r2dbc.core.ColumnMapRowMapper;
36+
import org.springframework.r2dbc.core.DatabaseClient;
3337
import org.springframework.r2dbc.core.RowsFetchSpec;
3438
import org.springframework.util.Assert;
3539

40+
import io.r2dbc.spi.Row;
41+
import io.r2dbc.spi.RowMetadata;
3642
import reactor.core.publisher.Mono;
3743

3844
/**
@@ -54,18 +60,24 @@
5460
*/
5561
public class R2dbcMessageSource extends AbstractMessageSource<Publisher<?>> {
5662

57-
private final R2dbcEntityOperations r2dbcEntityOperations;
63+
private final DatabaseClient databaseClient;
5864

5965
private final ReactiveDataAccessStrategy dataAccessStrategy;
6066

6167
private final Expression queryExpression;
6268

6369
private Class<?> payloadType = Map.class;
6470

71+
private BiFunction<Row, RowMetadata, ?> rowMapper = ColumnMapRowMapper.INSTANCE;
72+
6573
private boolean expectSingleResult = false;
6674

6775
private StandardEvaluationContext evaluationContext;
6876

77+
private String updateSql;
78+
79+
private BiFunction<DatabaseClient.GenericExecuteSpec, Object, DatabaseClient.GenericExecuteSpec> bindFunction;
80+
6981
private volatile boolean initialized = false;
7082

7183
/**
@@ -91,8 +103,8 @@ public R2dbcMessageSource(R2dbcEntityOperations r2dbcEntityOperations, String qu
91103
public R2dbcMessageSource(R2dbcEntityOperations r2dbcEntityOperations, Expression queryExpression) {
92104
Assert.notNull(r2dbcEntityOperations, "'r2dbcEntityOperations' must not be null");
93105
Assert.notNull(queryExpression, "'queryExpression' must not be null");
94-
this.r2dbcEntityOperations = r2dbcEntityOperations;
95-
this.dataAccessStrategy = this.r2dbcEntityOperations.getDataAccessStrategy();
106+
this.databaseClient = r2dbcEntityOperations.getDatabaseClient();
107+
this.dataAccessStrategy = r2dbcEntityOperations.getDataAccessStrategy();
96108
this.queryExpression = queryExpression;
97109
}
98110

@@ -108,10 +120,33 @@ public void setPayloadType(Class<?> payloadType) {
108120
}
109121

110122
/**
111-
* Provide a way to return all the records matching criteria or only and only a one otherwise.
123+
* Provide a way to set update query that will be passed to the
124+
* {@link org.springframework.data.r2dbc.core.DatabaseClient#execute(String)}
125+
* method.
126+
* @param updateSql Update query string.
127+
*/
128+
public void setUpdateSql(String updateSql) {
129+
this.updateSql = updateSql;
130+
}
131+
132+
/**
133+
* Provide a way to set BindFunction which will be used to bind parameters
134+
* in the update query.
135+
* @param bindFunction The bindFunction.
136+
*/
137+
@SuppressWarnings("unchecked")
138+
public void setBindFunction(
139+
BiFunction<DatabaseClient.GenericExecuteSpec, ?, DatabaseClient.GenericExecuteSpec> bindFunction) {
140+
141+
this.bindFunction =
142+
(BiFunction<DatabaseClient.GenericExecuteSpec, Object, DatabaseClient.GenericExecuteSpec>) bindFunction;
143+
}
144+
145+
/**
146+
* Provide a way to manage which find* method to invoke on {@link R2dbcEntityOperations}.
112147
* Default is 'false', which means the {@link #receive()} method will use
113-
* the {@link org.springframework.data.r2dbc.core.DatabaseClient#execute(String)} method and will fetch all. If set
114-
* to 'true'{@link #receive()} will use {@link org.springframework.data.r2dbc.core.DatabaseClient#execute(String)}
148+
* the {@link DatabaseClient#sql(String)} method and will fetch all. If set
149+
* to 'true'{@link #receive()} will use {@link DatabaseClient#sql(String)}
115150
* and will fetch one and the payload of the returned {@link org.springframework.messaging.Message}
116151
* will be the returned target Object of type
117152
* identified by {@link #payloadType} instead of a List.
@@ -136,6 +171,9 @@ protected void onInit() {
136171
*/
137172
((StandardTypeLocator) typeLocator).registerImport("org.springframework.data.relational.core.query");
138173
}
174+
if (!Map.class.isAssignableFrom(this.payloadType)) {
175+
this.rowMapper = this.dataAccessStrategy.getRowMapper(this.payloadType);
176+
}
139177
this.initialized = true;
140178
}
141179

@@ -155,17 +193,31 @@ protected Object doReceive() {
155193
Mono.fromSupplier(() -> this.queryExpression.getValue(this.evaluationContext))
156194
.map(this::prepareFetch);
157195
if (this.expectSingleResult) {
158-
return queryMono.flatMap(RowsFetchSpec::one);
196+
return queryMono.flatMap(RowsFetchSpec::one)
197+
.flatMap(this::executeUpdate);
198+
}
199+
200+
return queryMono.flatMapMany(RowsFetchSpec::all)
201+
.flatMap(this::executeUpdate);
202+
}
203+
204+
private Mono<Object> executeUpdate(Object result) {
205+
if (this.updateSql != null) {
206+
DatabaseClient.GenericExecuteSpec genericExecuteSpec = this.databaseClient.sql(this.updateSql);
207+
if (this.bindFunction != null) {
208+
genericExecuteSpec = this.bindFunction.apply(genericExecuteSpec, result);
209+
}
210+
return genericExecuteSpec.then()
211+
.thenReturn(result);
159212
}
160-
return queryMono.flatMapMany(RowsFetchSpec::all);
213+
return Mono.just(result);
161214
}
162215

163216
private RowsFetchSpec<?> prepareFetch(Object queryObject) {
164217
String queryString = evaluateQueryObject(queryObject);
165-
return this.r2dbcEntityOperations
166-
.getDatabaseClient()
218+
return this.databaseClient
167219
.sql(queryString)
168-
.map(this.dataAccessStrategy.getRowMapper(this.payloadType));
220+
.map(this.rowMapper);
169221
}
170222

171223
private String evaluateQueryObject(Object queryObject) {

spring-integration-r2dbc/src/test/java/org/springframework/integration/r2dbc/inbound/R2dbcMessageSourceTests.java

Lines changed: 146 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@ public class R2dbcMessageSourceTests {
5858

5959
R2dbcEntityTemplate entityTemplate;
6060

61+
@Autowired
62+
R2dbcMessageSource defaultR2dbcMessageSource;
63+
6164
@Autowired
6265
R2dbcMessageSource r2dbcMessageSourceSelectOne;
6366

@@ -70,9 +73,13 @@ public class R2dbcMessageSourceTests {
7073
@BeforeEach
7174
public void setup() {
7275
this.entityTemplate = new R2dbcEntityTemplate(this.client, H2Dialect.INSTANCE);
73-
List<String> statements = Arrays.asList(
74-
"DROP TABLE IF EXISTS person;",
75-
"CREATE table person (id INT AUTO_INCREMENT NOT NULL, name VARCHAR2, age INT NOT NULL);");
76+
r2dbcMessageSourceSelectMany.setExpectSingleResult(false);
77+
defaultR2dbcMessageSource.setBindFunction(null);
78+
79+
List<String> statements =
80+
Arrays.asList(
81+
"DROP TABLE IF EXISTS person;",
82+
"CREATE table person (id INT AUTO_INCREMENT NOT NULL, name VARCHAR2, age INT NOT NULL);");
7683

7784
statements.forEach(it -> this.client.sql(it)
7885
.fetch()
@@ -129,6 +136,136 @@ public void validateSuccessfulQueryWithMultipleElementOfFluxDBObject() {
129136

130137
}
131138

139+
@Test
140+
public void validateSuccessfulUpdateWithSingleElementOfMonoDBObject() {
141+
this.entityTemplate.insert(new Person("Bob", 35))
142+
.then()
143+
.as(StepVerifier::create)
144+
.verifyComplete();
145+
146+
r2dbcMessageSourceSelectMany.setUpdateSql("UPDATE Person SET name='Foo' where age = :age");
147+
r2dbcMessageSourceSelectMany.setBindFunction(
148+
(DatabaseClient.GenericExecuteSpec bindSpec, Person o) -> bindSpec.bind("age", o.getAge()));
149+
r2dbcMessageSourceSelectMany.setExpectSingleResult(true);
150+
151+
StepVerifier.create(r2dbcMessageSourceSelectMany.receive().getPayload())
152+
.assertNext(person -> assertThat(((Person) person).getName()).isEqualTo("Bob"))
153+
.verifyComplete();
154+
155+
this.entityTemplate.select(Person.class)
156+
.all()
157+
.as(StepVerifier::create)
158+
.assertNext(person -> assertThat(person.getName()).isEqualTo("Foo"))
159+
.verifyComplete();
160+
161+
}
162+
163+
@Test
164+
public void validateSuccessfulUpdateWithMultiplesElementsOfFluxDBObject() {
165+
this.entityTemplate.insert(new Person("Bob", 35))
166+
.then()
167+
.as(StepVerifier::create)
168+
.verifyComplete();
169+
170+
this.entityTemplate.insert(new Person("Tom", 40))
171+
.then()
172+
.as(StepVerifier::create)
173+
.verifyComplete();
174+
175+
r2dbcMessageSourceSelectMany.setUpdateSql("UPDATE person SET name='Foo' where id = :id");
176+
r2dbcMessageSourceSelectMany.setBindFunction(
177+
(DatabaseClient.GenericExecuteSpec bindSpec, Person o) -> bindSpec.bind("id", o.getId()));
178+
StepVerifier.create(r2dbcMessageSourceSelectMany.receive().getPayload())
179+
.expectNextCount(2)
180+
.verifyComplete();
181+
182+
this.entityTemplate.select(Person.class)
183+
.all()
184+
.as(StepVerifier::create)
185+
.assertNext(person -> assertThat(person.getName()).isEqualTo("Foo"))
186+
.assertNext(person -> assertThat(person.getName()).isEqualTo("Foo"))
187+
.verifyComplete();
188+
189+
}
190+
191+
@Test
192+
public void validateSuccessfulUpdateWithoutBindFunction() {
193+
this.entityTemplate.insert(new Person("Bob", 35))
194+
.then()
195+
.as(StepVerifier::create)
196+
.verifyComplete();
197+
198+
this.entityTemplate.insert(new Person("Tom", 40))
199+
.then()
200+
.as(StepVerifier::create)
201+
.verifyComplete();
202+
203+
r2dbcMessageSourceSelectMany.setUpdateSql("UPDATE person SET name='Foo' where id = 1");
204+
205+
StepVerifier.create(r2dbcMessageSourceSelectMany.receive().getPayload())
206+
.expectNextCount(2)
207+
.verifyComplete();
208+
209+
this.entityTemplate.select(Person.class)
210+
.all()
211+
.as(StepVerifier::create)
212+
.assertNext(person -> assertThat(person.getName()).isEqualTo("Foo"))
213+
.assertNext(person -> assertThat(person.getName()).isEqualTo("Tom"))
214+
.verifyComplete();
215+
216+
}
217+
218+
@Test
219+
public void validateSuccessfulUpdateWithoutPayloadType() {
220+
this.entityTemplate.insert(new Person("Bob", 35))
221+
.then()
222+
.as(StepVerifier::create)
223+
.verifyComplete();
224+
225+
this.entityTemplate.insert(new Person("Tom", 40))
226+
.then()
227+
.as(StepVerifier::create)
228+
.verifyComplete();
229+
230+
defaultR2dbcMessageSource.setUpdateSql("UPDATE person SET name='Foo' where id = 1");
231+
232+
StepVerifier.create(defaultR2dbcMessageSource.receive().getPayload())
233+
.expectNextCount(2)
234+
.verifyComplete();
235+
236+
this.client.sql("select * from person")
237+
.fetch()
238+
.all()
239+
.as(StepVerifier::create)
240+
.assertNext(person -> assertThat(person.get("name")).isEqualTo("Foo"))
241+
.assertNext(person -> assertThat(person.get("name")).isEqualTo("Tom"))
242+
.verifyComplete();
243+
244+
}
245+
246+
@Test
247+
public void testWrongPayloadTypeInBindFunction() {
248+
this.entityTemplate.insert(new Person("Bob", 35))
249+
.then()
250+
.as(StepVerifier::create)
251+
.verifyComplete();
252+
253+
this.entityTemplate.insert(new Person("Tom", 40))
254+
.then()
255+
.as(StepVerifier::create)
256+
.verifyComplete();
257+
258+
defaultR2dbcMessageSource.setUpdateSql("UPDATE person SET name='Foo' where id = 1");
259+
defaultR2dbcMessageSource.setBindFunction(
260+
(DatabaseClient.GenericExecuteSpec bindSpec, Person o) -> bindSpec.bind("id", o.getId()));
261+
262+
StepVerifier.create(defaultR2dbcMessageSource.receive().getPayload())
263+
.expectErrorMatches(throwable -> throwable instanceof ClassCastException)
264+
.verify();
265+
266+
}
267+
268+
132269
@Test
133270
public void testAnyOtherObjectQueryExpression() {
134271

@@ -145,6 +282,12 @@ static class R2dbcMessageSourceConfiguration {
145282
@Autowired
146283
R2dbcEntityTemplate r2dbcEntityTemplate;
147284

285+
@Bean
286+
public R2dbcMessageSource defaultR2dbcMessageSource() {
287+
return new R2dbcMessageSource(r2dbcEntityTemplate, "select * from " +
288+
"person");
289+
}
290+
148291
@Bean
149292
public R2dbcMessageSource r2dbcMessageSourceSelectOne() {
150293
R2dbcMessageSource r2dbcMessageSource = new R2dbcMessageSource(this.r2dbcEntityTemplate,

0 commit comments

Comments
 (0)