Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@

<build>
<plugins>
<plugin>
<groupId>io.spring.javaformat</groupId>
<artifactId>spring-javaformat-maven-plugin</artifactId>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-checkstyle-plugin</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.util.CollectionUtils;


/**
*
* @author Oleg Zhurakousky
* @since 3.2
*
Expand All @@ -38,8 +36,8 @@ public class AWSCompanionAutoConfiguration {
@Bean
public AWSTypesMessageConverter awsTypesMessageConverter(GenericApplicationContext applicationContext) {
JsonMapper jsonMapper = CollectionUtils.isEmpty(applicationContext.getBeansOfType(JsonMapper.class).values())
? new JacksonMapper(new ObjectMapper())
: applicationContext.getBean(JsonMapper.class);
? new JacksonMapper(new ObjectMapper()) : applicationContext.getBean(JsonMapper.class);
return new AWSTypesMessageConverter(jsonMapper);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
import org.springframework.util.StreamUtils;

/**
*
* @author Oleg Zhurakousky
* @author Anton Barkan
*
Expand Down Expand Up @@ -79,13 +78,13 @@ static boolean isSupportedAWSType(Type type) {
type = FunctionTypeUtils.getImmediateGenericType(type, 0);
}
Class<?> rawType = FunctionTypeUtils.getRawType(type);
return rawType != null && rawType.getPackage() != null &&
rawType.getPackage().getName().startsWith(
"com.amazonaws.services.lambda.runtime.events");
return rawType != null && rawType.getPackage() != null
&& rawType.getPackage().getName().startsWith("com.amazonaws.services.lambda.runtime.events");
}

@SuppressWarnings("rawtypes")
public static Message generateMessage(InputStream payload, Type inputType, boolean isSupplier, JsonMapper jsonMapper, Context context) throws IOException {
public static Message generateMessage(InputStream payload, Type inputType, boolean isSupplier,
JsonMapper jsonMapper, Context context) throws IOException {
if (inputType != null && FunctionTypeUtils.isMessage(inputType)) {
inputType = FunctionTypeUtils.getImmediateGenericType(inputType, 0);
}
Expand All @@ -101,7 +100,8 @@ public static Message generateMessage(InputStream payload, Type inputType, boole
}
}

public static Message<byte[]> generateMessage(byte[] payload, Type inputType, boolean isSupplier, JsonMapper jsonMapper) {
public static Message<byte[]> generateMessage(byte[] payload, Type inputType, boolean isSupplier,
JsonMapper jsonMapper) {
return generateMessage(payload, inputType, isSupplier, jsonMapper, null);
}

Expand All @@ -124,10 +124,8 @@ public static Message<byte[]> generateMessage(byte[] payload, Type inputType, bo

Message<byte[]> requestMessage;

MessageBuilder builder = MessageBuilder
.withPayload(structMessage instanceof Map msg && msg.containsKey("payload")
? (msg.get("payload"))
: payload);
MessageBuilder builder = MessageBuilder.withPayload(
structMessage instanceof Map msg && msg.containsKey("payload") ? (msg.get("payload")) : payload);
if (isApiGateway) {
builder.setHeader(AWSLambdaUtils.AWS_API_GATEWAY, true);
if (JsonMapper.isJsonStringRepresentsCollection(((Map) structMessage).get("body"))) {
Expand Down Expand Up @@ -165,7 +163,8 @@ private static Object convertFromJsonIfNecessary(Object value, JsonMapper object
}

@SuppressWarnings("unchecked")
public static byte[] generateOutputFromObject(Message<?> requestMessage, Object output, JsonMapper objectMapper, Type functionOutputType) {
public static byte[] generateOutputFromObject(Message<?> requestMessage, Object output, JsonMapper objectMapper,
Type functionOutputType) {
Message<byte[]> responseMessage = null;
if (output instanceof Publisher<?>) {
List<Object> result = new ArrayList<>();
Expand Down Expand Up @@ -209,25 +208,27 @@ else if (result.size() > 1) {
}

@SuppressWarnings({ "rawtypes", "unchecked" })
public static byte[] generateOutput(Message requestMessage, Message<?> responseMessage,
JsonMapper objectMapper, Type functionOutputType) {
public static byte[] generateOutput(Message requestMessage, Message<?> responseMessage, JsonMapper objectMapper,
Type functionOutputType) {

if (isSupportedAWSType(functionOutputType)) {
return extractPayload((Message<Object>) responseMessage, objectMapper);
}

byte[] responseBytes = responseMessage == null ? "\"OK\"".getBytes() : extractPayload((Message<Object>) responseMessage, objectMapper);
if (requestMessage.getHeaders().containsKey(AWS_API_GATEWAY) && ((boolean) requestMessage.getHeaders().get(AWS_API_GATEWAY))) {
byte[] responseBytes = responseMessage == null ? "\"OK\"".getBytes()
: extractPayload((Message<Object>) responseMessage, objectMapper);
if (requestMessage.getHeaders().containsKey(AWS_API_GATEWAY)
&& ((boolean) requestMessage.getHeaders().get(AWS_API_GATEWAY))) {
Map<String, Object> response = new HashMap<String, Object>();
response.put(IS_BASE64_ENCODED, responseMessage != null && responseMessage.getHeaders().containsKey(IS_BASE64_ENCODED)
? responseMessage.getHeaders().get(IS_BASE64_ENCODED) : false);
response.put(IS_BASE64_ENCODED,
responseMessage != null && responseMessage.getHeaders().containsKey(IS_BASE64_ENCODED)
? responseMessage.getHeaders().get(IS_BASE64_ENCODED) : false);

AtomicReference<MessageHeaders> headers = new AtomicReference<>();
int statusCode = HttpStatus.OK.value();
if (responseMessage != null) {
headers.set(responseMessage.getHeaders());
statusCode = headers.get().containsKey(STATUS_CODE)
? (int) headers.get().get(STATUS_CODE)
statusCode = headers.get().containsKey(STATUS_CODE) ? (int) headers.get().get(STATUS_CODE)
: HttpStatus.OK.value();
}

Expand All @@ -237,8 +238,8 @@ public static byte[] generateOutput(Message requestMessage, Message<?> responseM
response.put("statusDescription", httpStatus.toString());
}

String body = responseMessage == null
? "\"OK\"" : new String(extractPayload((Message<Object>) responseMessage, objectMapper), StandardCharsets.UTF_8);
String body = responseMessage == null ? "\"OK\"" : new String(
extractPayload((Message<Object>) responseMessage, objectMapper), StandardCharsets.UTF_8);
response.put(BODY, body);
if (responseMessage != null) {
Map<String, String> responseHeaders = new HashMap<>();
Expand All @@ -259,4 +260,5 @@ public static byte[] generateOutput(Message requestMessage, Message<?> responseM
private static boolean isRequestKinesis(Message<Object> requestMessage) {
return requestMessage.getHeaders().containsKey("Records");
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,10 @@
import org.springframework.util.MimeType;

/**
* Implementation of {@link MessageConverter} which uses Jackson or Gson libraries to do the
* actual conversion via {@link JsonMapper} instance.
* Implementation of {@link MessageConverter} which uses Jackson or Gson libraries to do
* the actual conversion via {@link JsonMapper} instance.
*
* @author Oleg Zhurakousky
*
* @since 3.2
*/
class AWSTypesMessageConverter extends JsonMessageConverter {
Expand All @@ -52,8 +51,9 @@ class AWSTypesMessageConverter extends JsonMessageConverter {
private final AtomicReference<S3EventSerializer> s3EventSerializer = new AtomicReference<>();

AWSTypesMessageConverter(JsonMapper jsonMapper) {
this(jsonMapper, new MimeType("application", "json"), new MimeType(CloudEventMessageUtils.APPLICATION_CLOUDEVENTS.getType(),
CloudEventMessageUtils.APPLICATION_CLOUDEVENTS.getSubtype() + "+json"));
this(jsonMapper, new MimeType("application", "json"),
new MimeType(CloudEventMessageUtils.APPLICATION_CLOUDEVENTS.getType(),
CloudEventMessageUtils.APPLICATION_CLOUDEVENTS.getSubtype() + "+json"));
}

AWSTypesMessageConverter(JsonMapper jsonMapper, MimeType... supportedMimeTypes) {
Expand All @@ -69,9 +69,10 @@ protected boolean canConvertFrom(Message<?> message, @Nullable Class<?> targetCl
if (message.getHeaders().containsKey(AWSLambdaUtils.AWS_EVENT)) {
return ((boolean) message.getHeaders().get(AWSLambdaUtils.AWS_EVENT));
}
//TODO Do we really need the ^^ above? It seems like the line below dows the trick
else if (targetClass.getPackage() != null &&
targetClass.getPackage().getName().startsWith("com.amazonaws.services.lambda.runtime.events")) {
// TODO Do we really need the ^^ above? It seems like the line below dows the
// trick
else if (targetClass.getPackage() != null
&& targetClass.getPackage().getName().startsWith("com.amazonaws.services.lambda.runtime.events")) {
return true;
}
return false;
Expand All @@ -82,9 +83,10 @@ protected Object convertFromInternal(Message<?> message, Class<?> targetClass, @
if (message.getPayload().getClass().isAssignableFrom(targetClass)) {
return message.getPayload();
}
if (targetClass.getPackage() != null &&
targetClass.getPackage().getName().startsWith("com.amazonaws.services.lambda.runtime.events")) {
PojoSerializer<?> serializer = LambdaEventSerializers.serializerFor(targetClass, Thread.currentThread().getContextClassLoader());
if (targetClass.getPackage() != null
&& targetClass.getPackage().getName().startsWith("com.amazonaws.services.lambda.runtime.events")) {
PojoSerializer<?> serializer = LambdaEventSerializers.serializerFor(targetClass,
Thread.currentThread().getContextClassLoader());
Object event = serializer.fromJson(new ByteArrayInputStream((byte[]) message.getPayload()));
return event;
}
Expand Down Expand Up @@ -115,24 +117,24 @@ protected boolean canConvertTo(Object payload, @Nullable MessageHeaders headers)
return true;
}


@SuppressWarnings("unchecked")
@Override
protected Object convertToInternal(Object payload, @Nullable MessageHeaders headers,
@Nullable Object conversionHint) {
if (payload instanceof String && headers.containsKey(AWSLambdaUtils.IS_BASE64_ENCODED) && (boolean) headers.get(AWSLambdaUtils.IS_BASE64_ENCODED)) {
if (payload instanceof String && headers.containsKey(AWSLambdaUtils.IS_BASE64_ENCODED)
&& (boolean) headers.get(AWSLambdaUtils.IS_BASE64_ENCODED)) {
return ((String) payload).getBytes(StandardCharsets.UTF_8);
}
if (payload.getClass().getName().equals("com.amazonaws.services.lambda.runtime.events.S3Event")) {
if (this.s3EventSerializer.get() == null) {
this.s3EventSerializer.set(new S3EventSerializer<>().withClassLoader(ClassUtils.getDefaultClassLoader()));
this.s3EventSerializer
.set(new S3EventSerializer<>().withClassLoader(ClassUtils.getDefaultClassLoader()));
}
ByteArrayOutputStream stream = new ByteArrayOutputStream();
this.s3EventSerializer.get().toJson(payload, stream);
return stream.toByteArray();
}


return jsonMapper.toJson(payload);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@
import static org.apache.http.HttpHeaders.USER_AGENT;

/**
* Event loop and necessary configurations to support AWS Lambda
* Custom Runtime - https://docs.aws.amazon.com/lambda/latest/dg/runtimes-custom.html.
* Event loop and necessary configurations to support AWS Lambda Custom Runtime -
* https://docs.aws.amazon.com/lambda/latest/dg/runtimes-custom.html.
*
* @author Oleg Zhurakousky
* @author Mark Sailes
Expand All @@ -69,13 +69,15 @@ public final class CustomRuntimeEventLoop implements SmartLifecycle {
private static Log logger = LogFactory.getLog(CustomRuntimeEventLoop.class);

static final String LAMBDA_VERSION_DATE = "2018-06-01";

private static final String LAMBDA_ERROR_URL_TEMPLATE = "http://{0}/{1}/runtime/invocation/{2}/error";

private static final String LAMBDA_RUNTIME_URL_TEMPLATE = "http://{0}/{1}/runtime/invocation/next";

private static final String LAMBDA_INVOCATION_URL_TEMPLATE = "http://{0}/{1}/runtime/invocation/{2}/response";
private static final String USER_AGENT_VALUE = String.format(
"spring-cloud-function/%s-%s",
System.getProperty("java.runtime.version"),
extractVersion());

private static final String USER_AGENT_VALUE = String.format("spring-cloud-function/%s-%s",
System.getProperty("java.runtime.version"), extractVersion());

private final ConfigurableApplicationContext applicationContext;

Expand Down Expand Up @@ -125,7 +127,9 @@ private void eventLoop(ConfigurableApplicationContext context) {
logger.debug("Event URI: " + eventUri);
}

RequestEntity<Void> requestEntity = RequestEntity.get(URI.create(eventUri)).header(USER_AGENT, USER_AGENT_VALUE).build();
RequestEntity<Void> requestEntity = RequestEntity.get(URI.create(eventUri))
.header(USER_AGENT, USER_AGENT_VALUE)
.build();
FunctionCatalog functionCatalog = context.getBean(FunctionCatalog.class);
RestTemplate rest = new RestTemplate();
JsonMapper mapper = context.getBean(JsonMapper.class);
Expand All @@ -144,18 +148,22 @@ private void eventLoop(ConfigurableApplicationContext context) {
if (response != null && response.hasBody()) {
String requestId = response.getHeaders().getFirst("Lambda-Runtime-Aws-Request-Id");
try {
FunctionInvocationWrapper function = locateFunction(environment, functionCatalog, response.getHeaders());
FunctionInvocationWrapper function = locateFunction(environment, functionCatalog,
response.getHeaders());

ByteArrayInputStream is = new ByteArrayInputStream(response.getBody().getBytes(StandardCharsets.UTF_8));
Message<?> requestMessage = AWSLambdaUtils.generateMessage(is, function.getInputType(), function.isSupplier(), mapper, clientContext);
ByteArrayInputStream is = new ByteArrayInputStream(
response.getBody().getBytes(StandardCharsets.UTF_8));
Message<?> requestMessage = AWSLambdaUtils.generateMessage(is, function.getInputType(),
function.isSupplier(), mapper, clientContext);
requestMessage = enrichTraceHeaders(response.getHeaders(), requestMessage);

Object functionResponse = function.apply(requestMessage);

byte[] responseBytes = AWSLambdaUtils.generateOutputFromObject(requestMessage, functionResponse, mapper, function.getOutputType());
byte[] responseBytes = AWSLambdaUtils.generateOutputFromObject(requestMessage, functionResponse,
mapper, function.getOutputType());

String invocationUrl = MessageFormat
.format(LAMBDA_INVOCATION_URL_TEMPLATE, runtimeApi, LAMBDA_VERSION_DATE, requestId);
String invocationUrl = MessageFormat.format(LAMBDA_INVOCATION_URL_TEMPLATE, runtimeApi,
LAMBDA_VERSION_DATE, requestId);

ResponseEntity<Object> result = rest.exchange(RequestEntity.post(URI.create(invocationUrl))
.header(USER_AGENT, USER_AGENT_VALUE)
Expand All @@ -179,9 +187,7 @@ private Message<?> enrichTraceHeaders(HttpHeaders headers, Message<?> message) {
String headerTrace = trim(headers.getFirst("X-Amzn-Trace-Id"));

// prefer Lambda runtime header, then environment, then inbound header
String resolved = runtimeTrace != null ? runtimeTrace
: envTrace != null ? envTrace
: headerTrace;
String resolved = runtimeTrace != null ? runtimeTrace : envTrace != null ? envTrace : headerTrace;

if (resolved != null) {
System.setProperty("com.amazonaws.xray.traceHeader", resolved);
Expand Down Expand Up @@ -280,7 +286,8 @@ public String toString() {
return context;
}

private void propagateAwsError(String requestId, Exception e, JsonMapper mapper, String runtimeApi, RestTemplate rest) {
private void propagateAwsError(String requestId, Exception e, JsonMapper mapper, String runtimeApi,
RestTemplate rest) {
String errorMessage = e.getMessage();
String errorType = e.getClass().getSimpleName();
StringWriter sw = new StringWriter();
Expand All @@ -293,10 +300,11 @@ private void propagateAwsError(String requestId, Exception e, JsonMapper mapper,
em.put("stackTrace", stackTrace);
byte[] outputBody = mapper.toJson(em);
try {
String errorUrl = MessageFormat.format(LAMBDA_ERROR_URL_TEMPLATE, runtimeApi, LAMBDA_VERSION_DATE, requestId);
ResponseEntity<Object> result = rest.exchange(RequestEntity.post(URI.create(errorUrl))
.header(USER_AGENT, USER_AGENT_VALUE)
.body(outputBody), Object.class);
String errorUrl = MessageFormat.format(LAMBDA_ERROR_URL_TEMPLATE, runtimeApi, LAMBDA_VERSION_DATE,
requestId);
ResponseEntity<Object> result = rest.exchange(
RequestEntity.post(URI.create(errorUrl)).header(USER_AGENT, USER_AGENT_VALUE).body(outputBody),
Object.class);
if (logger.isInfoEnabled()) {
logger.info("Result ERROR status: " + result.getStatusCode());
}
Expand Down Expand Up @@ -366,8 +374,8 @@ private FunctionInvocationWrapper locateFunction(Environment environment, Functi
this.routingFunction = functionCatalog.lookup(RoutingFunction.FUNCTION_NAME, "application/json");
if (this.routingFunction != null && logger.isInfoEnabled()) {
logger.info("Will default to RoutingFunction, since multiple functions available in FunctionCatalog."
+ "Expecting 'spring.cloud.function.definition' or 'spring.cloud.function.routing-expression' as Message headers. "
+ "If invocation is over API Gateway, Message headers can be provided as HTTP headers.");
+ "Expecting 'spring.cloud.function.definition' or 'spring.cloud.function.routing-expression' as Message headers. "
+ "If invocation is over API Gateway, Message headers can be provided as HTTP headers.");
}
function = this.routingFunction;
}
Expand Down Expand Up @@ -399,4 +407,5 @@ private static String extractVersion() {
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,9 @@ private boolean isCustomRuntime(Environment environment) {
return false;
}


private boolean isWebExportEnabled(GenericApplicationContext context) {
Boolean enabled = context.getEnvironment()
.getProperty("spring.cloud.function.web.export.enabled", Boolean.class);
.getProperty("spring.cloud.function.web.export.enabled", Boolean.class);
return enabled != null && enabled;
}

Expand Down
Loading