11/*
2- * Copyright 2014-2018 the original author or authors.
2+ * Copyright 2014-2019 the original author or authors.
33 *
44 * Licensed under the Apache License, Version 2.0 (the "License");
55 * you may not use this file except in compliance with the License.
1717package org .springframework .integration .scattergather ;
1818
1919import org .springframework .aop .support .AopUtils ;
20+ import org .springframework .beans .factory .BeanFactory ;
21+ import org .springframework .beans .factory .BeanInitializationException ;
2022import org .springframework .context .Lifecycle ;
2123import org .springframework .integration .channel .FixedSubscriberChannel ;
2224import org .springframework .integration .channel .QueueChannel ;
25+ import org .springframework .integration .channel .ReactiveStreamsSubscribableChannel ;
2326import org .springframework .integration .context .IntegrationContextUtils ;
2427import org .springframework .integration .core .MessageProducer ;
2528import org .springframework .integration .endpoint .AbstractEndpoint ;
2629import org .springframework .integration .endpoint .EventDrivenConsumer ;
2730import org .springframework .integration .endpoint .PollingConsumer ;
31+ import org .springframework .integration .endpoint .ReactiveStreamsConsumer ;
2832import org .springframework .integration .handler .AbstractReplyProducingMessageHandler ;
2933import org .springframework .integration .support .channel .HeaderChannelRegistry ;
3034import org .springframework .messaging .Message ;
3135import org .springframework .messaging .MessageChannel ;
3236import org .springframework .messaging .MessageDeliveryException ;
3337import org .springframework .messaging .MessageHandler ;
3438import org .springframework .messaging .MessageHeaders ;
35- import org .springframework .messaging .MessagingException ;
3639import org .springframework .messaging .PollableChannel ;
3740import org .springframework .messaging .SubscribableChannel ;
3841import org .springframework .util .Assert ;
@@ -57,13 +60,22 @@ public class ScatterGatherHandler extends AbstractReplyProducingMessageHandler i
5760
5861 private MessageChannel gatherChannel ;
5962
63+ private String errorChannelName = IntegrationContextUtils .ERROR_CHANNEL_BEAN_NAME ;
64+
6065 private long gatherTimeout = -1 ;
6166
6267 private AbstractEndpoint gatherEndpoint ;
6368
6469 private HeaderChannelRegistry replyChannelRegistry ;
6570
6671
72+ public ScatterGatherHandler (MessageHandler scatterer , MessageHandler gatherer ) {
73+ this (new FixedSubscriberChannel (scatterer ), gatherer );
74+ Assert .notNull (scatterer , "'scatterer' must not be null" );
75+ Class <?> scattererClass = AopUtils .getTargetClass (scatterer );
76+ checkClass (scattererClass , "org.springframework.integration.router.RecipientListRouter" , "scatterer" );
77+ }
78+
6779 public ScatterGatherHandler (MessageChannel scatterChannel , MessageHandler gatherer ) {
6880 Assert .notNull (scatterChannel , "'scatterChannel' must not be null" );
6981 Assert .notNull (gatherer , "'gatherer' must not be null" );
@@ -73,13 +85,6 @@ public ScatterGatherHandler(MessageChannel scatterChannel, MessageHandler gather
7385 this .gatherer = gatherer ;
7486 }
7587
76- public ScatterGatherHandler (MessageHandler scatterer , MessageHandler gatherer ) {
77- this (new FixedSubscriberChannel (scatterer ), gatherer );
78- Assert .notNull (scatterer , "'scatterer' must not be null" );
79- Class <?> scattererClass = AopUtils .getTargetClass (scatterer );
80- checkClass (scattererClass , "org.springframework.integration.router.RecipientListRouter" , "scatterer" );
81- }
82-
8388 public void setGatherChannel (MessageChannel gatherChannel ) {
8489 this .gatherChannel = gatherChannel ;
8590 }
@@ -88,8 +93,20 @@ public void setGatherTimeout(long gatherTimeout) {
8893 this .gatherTimeout = gatherTimeout ;
8994 }
9095
96+ /**
97+ * Specify a {@link MessageChannel} bean name for async error processing.
98+ * Defaults to {@link IntegrationContextUtils#ERROR_CHANNEL_BEAN_NAME}.
99+ * @param errorChannelName the {@link MessageChannel} bean name for async error processing.
100+ * @since 5.1.3
101+ */
102+ public void setErrorChannelName (String errorChannelName ) {
103+ Assert .hasText (errorChannelName , "'errorChannelName' must not be empty." );
104+ this .errorChannelName = errorChannelName ;
105+ }
106+
91107 @ Override
92108 protected void doInit () {
109+ BeanFactory beanFactory = getBeanFactory ();
93110 if (this .gatherChannel == null ) {
94111 this .gatherChannel = new FixedSubscriberChannel (this .gatherer );
95112 }
@@ -101,33 +118,39 @@ else if (this.gatherChannel instanceof PollableChannel) {
101118 this .gatherEndpoint = new PollingConsumer ((PollableChannel ) this .gatherChannel , this .gatherer );
102119 ((PollingConsumer ) this .gatherEndpoint ).setReceiveTimeout (this .gatherTimeout );
103120 }
121+ else if (this .gatherChannel instanceof ReactiveStreamsSubscribableChannel ) {
122+ this .gatherEndpoint = new ReactiveStreamsConsumer (this .gatherChannel , this .gatherer );
123+ }
104124 else {
105- throw new MessagingException ("Unsupported 'replyChannel' type [" + this .gatherChannel .getClass () + "]."
106- + "SubscribableChannel or PollableChannel type are supported." );
125+ throw new BeanInitializationException ("Unsupported 'replyChannel' type '" +
126+ this .gatherChannel .getClass () + "'. " +
127+ "'SubscribableChannel', 'PollableChannel' or 'ReactiveStreamsSubscribableChannel' " +
128+ "types are supported." );
107129 }
108- this .gatherEndpoint .setBeanFactory (this . getBeanFactory () );
130+ this .gatherEndpoint .setBeanFactory (beanFactory );
109131 this .gatherEndpoint .afterPropertiesSet ();
110132 }
111133
112- ((MessageProducer ) this .gatherer ).setOutputChannel (new FixedSubscriberChannel (message -> {
113- MessageHeaders headers = message .getHeaders ();
114- if (headers .containsKey (GATHER_RESULT_CHANNEL )) {
115- Object gatherResultChannel = headers .get (GATHER_RESULT_CHANNEL );
116- if (gatherResultChannel instanceof MessageChannel ) {
117- messagingTemplate .send ((MessageChannel ) gatherResultChannel , message );
118- }
119- else if (gatherResultChannel instanceof String ) {
120- messagingTemplate .send ((String ) gatherResultChannel , message );
121- }
122- }
123- else {
124- throw new MessageDeliveryException (message ,
125- "The 'gatherResultChannel' header is required to delivery gather result." );
126- }
127- }));
128-
129- this .replyChannelRegistry = getBeanFactory ()
130- .getBean (IntegrationContextUtils .INTEGRATION_HEADER_CHANNEL_REGISTRY_BEAN_NAME ,
134+ ((MessageProducer ) this .gatherer )
135+ .setOutputChannel (new FixedSubscriberChannel (message -> {
136+ MessageHeaders headers = message .getHeaders ();
137+ if (headers .containsKey (GATHER_RESULT_CHANNEL )) {
138+ Object gatherResultChannel = headers .get (GATHER_RESULT_CHANNEL );
139+ if (gatherResultChannel instanceof MessageChannel ) {
140+ messagingTemplate .send ((MessageChannel ) gatherResultChannel , message );
141+ }
142+ else if (gatherResultChannel instanceof String ) {
143+ messagingTemplate .send ((String ) gatherResultChannel , message );
144+ }
145+ }
146+ else {
147+ throw new MessageDeliveryException (message ,
148+ "The 'gatherResultChannel' header is required to delivery gather result." );
149+ }
150+ }));
151+
152+ this .replyChannelRegistry =
153+ beanFactory .getBean (IntegrationContextUtils .INTEGRATION_HEADER_CHANNEL_REGISTRY_BEAN_NAME ,
131154 HeaderChannelRegistry .class );
132155 }
133156
@@ -137,11 +160,13 @@ protected Object handleRequestMessage(Message<?> requestMessage) {
137160
138161 Object gatherResultChannelName = this .replyChannelRegistry .channelToChannelName (gatherResultChannel );
139162
140- Message <?> scatterMessage = getMessageBuilderFactory ()
141- .fromMessage (requestMessage )
142- .setHeader (GATHER_RESULT_CHANNEL , gatherResultChannelName )
143- .setReplyChannel (this .gatherChannel )
144- .build ();
163+ Message <?> scatterMessage =
164+ getMessageBuilderFactory ()
165+ .fromMessage (requestMessage )
166+ .setHeader (GATHER_RESULT_CHANNEL , gatherResultChannelName )
167+ .setReplyChannel (this .gatherChannel )
168+ .setErrorChannelName (this .errorChannelName )
169+ .build ();
145170
146171 this .messagingTemplate .send (this .scatterChannel , scatterMessage );
147172
@@ -151,7 +176,7 @@ protected Object handleRequestMessage(Message<?> requestMessage) {
151176 .fromMessage (gatherResult )
152177 .removeHeader (GATHER_RESULT_CHANNEL )
153178 .setHeader (MessageHeaders .REPLY_CHANNEL , requestMessage .getHeaders ().getReplyChannel ())
154- .build ( );
179+ .setHeader ( MessageHeaders . ERROR_CHANNEL , requestMessage . getHeaders (). getErrorChannel () );
155180 }
156181
157182 return null ;
@@ -179,7 +204,8 @@ public boolean isRunning() {
179204 private void checkClass (Class <?> gathererClass , String className , String type ) throws LinkageError {
180205 try {
181206 Class <?> clazz = ClassUtils .forName (className , ClassUtils .getDefaultClassLoader ());
182- Assert .isAssignable (clazz , gathererClass , "the '" + type + "' must be an " + className + " instance" );
207+ Assert .isAssignable (clazz , gathererClass , () -> "the '" + type + "' must be an " + className + " " +
208+ "instance" );
183209 }
184210 catch (ClassNotFoundException e ) {
185211 throw new IllegalStateException ("The class for '" + className + "' cannot be loaded" , e );
0 commit comments