diff --git a/impl/pom.xml b/impl/pom.xml
index fe91f06..6844a62 100644
--- a/impl/pom.xml
+++ b/impl/pom.xml
@@ -27,6 +27,11 @@
a2a-java-sdk-tests-server-common
provided
+
+ jakarta.servlet
+ jakarta.servlet-api
+ provided
+
io.github.a2asdk
a2a-java-sdk-tests-server-common
diff --git a/impl/src/main/java/org/wildfly/extras/a2a/server/apps/jakarta/A2AServerResource.java b/impl/src/main/java/org/wildfly/extras/a2a/server/apps/jakarta/A2AServerResource.java
index 62252fa..93a0939 100644
--- a/impl/src/main/java/org/wildfly/extras/a2a/server/apps/jakarta/A2AServerResource.java
+++ b/impl/src/main/java/org/wildfly/extras/a2a/server/apps/jakarta/A2AServerResource.java
@@ -1,10 +1,14 @@
package org.wildfly.extras.a2a.server.apps.jakarta;
+import java.util.Enumeration;
+import java.util.HashMap;
+import java.util.Map;
import java.util.concurrent.Executor;
import java.util.concurrent.Flow;
import jakarta.enterprise.inject.Instance;
import jakarta.inject.Inject;
+import jakarta.servlet.http.HttpServletRequest;
import jakarta.ws.rs.Consumes;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.POST;
@@ -13,6 +17,7 @@
import jakarta.ws.rs.core.Context;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response;
+import jakarta.ws.rs.core.SecurityContext;
import jakarta.ws.rs.ext.ExceptionMapper;
import jakarta.ws.rs.ext.Provider;
import jakarta.ws.rs.sse.Sse;
@@ -21,6 +26,9 @@
import com.fasterxml.jackson.core.JsonParseException;
import com.fasterxml.jackson.databind.JsonMappingException;
import io.a2a.server.ExtendedAgentCard;
+import io.a2a.server.ServerCallContext;
+import io.a2a.server.auth.UnauthenticatedUser;
+import io.a2a.server.auth.User;
import io.a2a.server.requesthandlers.JSONRPCHandler;
import io.a2a.server.util.async.Internal;
import io.a2a.spec.AgentCard;
@@ -70,6 +78,10 @@ public class A2AServerResource {
@Internal
Executor executor;
+
+ @Inject
+ Instance callContextFactory;
+
/**
* Handles incoming POST requests to the main A2A endpoint. Dispatches the
* request to the appropriate JSON-RPC handler method and returns the response.
@@ -80,10 +92,14 @@ public class A2AServerResource {
@POST
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
- public JSONRPCResponse> handleNonStreamingRequests(NonStreamingJSONRPCRequest> request) {
+ public JSONRPCResponse> handleNonStreamingRequests(
+ NonStreamingJSONRPCRequest> request, @Context HttpServletRequest httpRequest,
+ @Context SecurityContext securityContext) {
+
+ ServerCallContext context = createCallContext(httpRequest, securityContext);
LOGGER.debug("Handling non-streaming request");
try {
- return processNonStreamingRequest(request);
+ return processNonStreamingRequest(request, context);
} finally {
LOGGER.debug("Completed non-streaming request");
}
@@ -96,9 +112,13 @@ public JSONRPCResponse> handleNonStreamingRequests(NonStreamingJSONRPCRequest<
@POST
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.SERVER_SENT_EVENTS)
- public void handleStreamingRequests(StreamingJSONRPCRequest> request, @Context SseEventSink sseEventSink, @Context Sse sse) {
+ public void handleStreamingRequests(
+ StreamingJSONRPCRequest> request, @Context SseEventSink sseEventSink,
+ @Context Sse sse, @Context HttpServletRequest httpRequest,
+ @Context SecurityContext securityContext) {
+ ServerCallContext context = createCallContext(httpRequest, securityContext);
LOGGER.debug("Handling streaming request");
- executor.execute(() -> processStreamingRequest(request, sseEventSink, sse));
+ executor.execute(() -> processStreamingRequest(request, sseEventSink, sse, context));
LOGGER.debug("Submitted streaming request for async processing");
}
@@ -142,33 +162,35 @@ public Response getAuthenticatedExtendedAgentCard() {
.build();
}
- private JSONRPCResponse> processNonStreamingRequest(NonStreamingJSONRPCRequest> request) {
- if (request instanceof GetTaskRequest) {
- return jsonRpcHandler.onGetTask((GetTaskRequest) request);
- } else if (request instanceof CancelTaskRequest) {
- return jsonRpcHandler.onCancelTask((CancelTaskRequest) request);
- } else if (request instanceof SetTaskPushNotificationConfigRequest) {
- return jsonRpcHandler.setPushNotificationConfig((SetTaskPushNotificationConfigRequest) request);
- } else if (request instanceof GetTaskPushNotificationConfigRequest) {
- return jsonRpcHandler.getPushNotificationConfig((GetTaskPushNotificationConfigRequest) request);
- } else if (request instanceof SendMessageRequest) {
- return jsonRpcHandler.onMessageSend((SendMessageRequest) request);
- } else if (request instanceof ListTaskPushNotificationConfigRequest) {
- return jsonRpcHandler.listPushNotificationConfig((ListTaskPushNotificationConfigRequest) request);
- } else if (request instanceof DeleteTaskPushNotificationConfigRequest) {
- return jsonRpcHandler.deletePushNotificationConfig((DeleteTaskPushNotificationConfigRequest) request);
+ private JSONRPCResponse> processNonStreamingRequest(NonStreamingJSONRPCRequest> request,
+ ServerCallContext context) {
+ if (request instanceof GetTaskRequest req) {
+ return jsonRpcHandler.onGetTask(req, context);
+ } else if (request instanceof CancelTaskRequest req) {
+ return jsonRpcHandler.onCancelTask(req, context);
+ } else if (request instanceof SetTaskPushNotificationConfigRequest req) {
+ return jsonRpcHandler.setPushNotificationConfig(req, context);
+ } else if (request instanceof GetTaskPushNotificationConfigRequest req) {
+ return jsonRpcHandler.getPushNotificationConfig(req, context);
+ } else if (request instanceof SendMessageRequest req) {
+ return jsonRpcHandler.onMessageSend(req, context);
+ } else if (request instanceof ListTaskPushNotificationConfigRequest req) {
+ return jsonRpcHandler.listPushNotificationConfig(req, context);
+ } else if (request instanceof DeleteTaskPushNotificationConfigRequest req) {
+ return jsonRpcHandler.deletePushNotificationConfig(req, context);
} else {
return generateErrorResponse(request, new UnsupportedOperationError());
}
}
- private void processStreamingRequest(StreamingJSONRPCRequest> request, SseEventSink sseEventSink, Sse sse) {
+ private void processStreamingRequest(StreamingJSONRPCRequest> request, SseEventSink sseEventSink, Sse sse,
+ ServerCallContext context) {
Flow.Publisher extends JSONRPCResponse>> publisher;
- if (request instanceof SendStreamingMessageRequest) {
- publisher = jsonRpcHandler.onMessageSendStream((SendStreamingMessageRequest) request);
+ if (request instanceof SendStreamingMessageRequest req) {
+ publisher = jsonRpcHandler.onMessageSendStream(req, context);
handleStreamingResponse(publisher, sseEventSink, sse);
- } else if (request instanceof TaskResubscriptionRequest) {
- publisher = jsonRpcHandler.onResubscribeToTask((TaskResubscriptionRequest) request);
+ } else if (request instanceof TaskResubscriptionRequest req) {
+ publisher = jsonRpcHandler.onResubscribeToTask(req, context);
handleStreamingResponse(publisher, sseEventSink, sse);
}
}
@@ -218,6 +240,46 @@ public static void setStreamingIsSubscribedRunnable(Runnable streamingIsSubscrib
A2AServerResource.streamingIsSubscribedRunnable = streamingIsSubscribedRunnable;
}
+ private ServerCallContext createCallContext(HttpServletRequest request, SecurityContext securityContext) {
+
+ if (callContextFactory.isUnsatisfied()) {
+ User user;
+
+ if (securityContext.getUserPrincipal() == null) {
+ user = UnauthenticatedUser.INSTANCE;
+ } else {
+ user = new User() {
+ @Override
+ public boolean isAuthenticated() {
+ return true;
+ }
+
+ @Override
+ public String getUsername() {
+ return securityContext.getUserPrincipal().getName();
+ }
+ };
+ }
+ Map state = new HashMap<>();
+ // TODO Python's impl has
+ // state['auth'] = request.auth
+ // in jsonrpc_app.py. Figure out what this maps to in what we have here
+
+ Map headers = new HashMap<>();
+ for (Enumeration headerNames = request.getHeaderNames(); headerNames.hasMoreElements() ; ) {
+ String name = headerNames.nextElement();
+ headers.put(name, headers.get(name));
+ }
+
+ state.put("headers", headers);
+
+ return new ServerCallContext(user, state);
+ } else {
+ CallContextFactory builder = callContextFactory.get();
+ return builder.build(request);
+ }
+ }
+
@Provider
public static class JsonParseExceptionMapper implements ExceptionMapper {
diff --git a/impl/src/main/java/org/wildfly/extras/a2a/server/apps/jakarta/CallContextFactory.java b/impl/src/main/java/org/wildfly/extras/a2a/server/apps/jakarta/CallContextFactory.java
new file mode 100644
index 0000000..f850625
--- /dev/null
+++ b/impl/src/main/java/org/wildfly/extras/a2a/server/apps/jakarta/CallContextFactory.java
@@ -0,0 +1,10 @@
+package org.wildfly.extras.a2a.server.apps.jakarta;
+
+
+import jakarta.servlet.http.HttpServletRequest;
+
+import io.a2a.server.ServerCallContext;
+
+public interface CallContextFactory {
+ ServerCallContext build(HttpServletRequest request);
+}