diff --git a/impl/pom.xml b/impl/pom.xml index fe91f06..6844a62 100644 --- a/impl/pom.xml +++ b/impl/pom.xml @@ -27,6 +27,11 @@ a2a-java-sdk-tests-server-common provided + + jakarta.servlet + jakarta.servlet-api + provided + io.github.a2asdk a2a-java-sdk-tests-server-common diff --git a/impl/src/main/java/org/wildfly/extras/a2a/server/apps/jakarta/A2AServerResource.java b/impl/src/main/java/org/wildfly/extras/a2a/server/apps/jakarta/A2AServerResource.java index 62252fa..93a0939 100644 --- a/impl/src/main/java/org/wildfly/extras/a2a/server/apps/jakarta/A2AServerResource.java +++ b/impl/src/main/java/org/wildfly/extras/a2a/server/apps/jakarta/A2AServerResource.java @@ -1,10 +1,14 @@ package org.wildfly.extras.a2a.server.apps.jakarta; +import java.util.Enumeration; +import java.util.HashMap; +import java.util.Map; import java.util.concurrent.Executor; import java.util.concurrent.Flow; import jakarta.enterprise.inject.Instance; import jakarta.inject.Inject; +import jakarta.servlet.http.HttpServletRequest; import jakarta.ws.rs.Consumes; import jakarta.ws.rs.GET; import jakarta.ws.rs.POST; @@ -13,6 +17,7 @@ import jakarta.ws.rs.core.Context; import jakarta.ws.rs.core.MediaType; import jakarta.ws.rs.core.Response; +import jakarta.ws.rs.core.SecurityContext; import jakarta.ws.rs.ext.ExceptionMapper; import jakarta.ws.rs.ext.Provider; import jakarta.ws.rs.sse.Sse; @@ -21,6 +26,9 @@ import com.fasterxml.jackson.core.JsonParseException; import com.fasterxml.jackson.databind.JsonMappingException; import io.a2a.server.ExtendedAgentCard; +import io.a2a.server.ServerCallContext; +import io.a2a.server.auth.UnauthenticatedUser; +import io.a2a.server.auth.User; import io.a2a.server.requesthandlers.JSONRPCHandler; import io.a2a.server.util.async.Internal; import io.a2a.spec.AgentCard; @@ -70,6 +78,10 @@ public class A2AServerResource { @Internal Executor executor; + + @Inject + Instance callContextFactory; + /** * Handles incoming POST requests to the main A2A endpoint. Dispatches the * request to the appropriate JSON-RPC handler method and returns the response. @@ -80,10 +92,14 @@ public class A2AServerResource { @POST @Consumes(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON) - public JSONRPCResponse handleNonStreamingRequests(NonStreamingJSONRPCRequest request) { + public JSONRPCResponse handleNonStreamingRequests( + NonStreamingJSONRPCRequest request, @Context HttpServletRequest httpRequest, + @Context SecurityContext securityContext) { + + ServerCallContext context = createCallContext(httpRequest, securityContext); LOGGER.debug("Handling non-streaming request"); try { - return processNonStreamingRequest(request); + return processNonStreamingRequest(request, context); } finally { LOGGER.debug("Completed non-streaming request"); } @@ -96,9 +112,13 @@ public JSONRPCResponse handleNonStreamingRequests(NonStreamingJSONRPCRequest< @POST @Consumes(MediaType.APPLICATION_JSON) @Produces(MediaType.SERVER_SENT_EVENTS) - public void handleStreamingRequests(StreamingJSONRPCRequest request, @Context SseEventSink sseEventSink, @Context Sse sse) { + public void handleStreamingRequests( + StreamingJSONRPCRequest request, @Context SseEventSink sseEventSink, + @Context Sse sse, @Context HttpServletRequest httpRequest, + @Context SecurityContext securityContext) { + ServerCallContext context = createCallContext(httpRequest, securityContext); LOGGER.debug("Handling streaming request"); - executor.execute(() -> processStreamingRequest(request, sseEventSink, sse)); + executor.execute(() -> processStreamingRequest(request, sseEventSink, sse, context)); LOGGER.debug("Submitted streaming request for async processing"); } @@ -142,33 +162,35 @@ public Response getAuthenticatedExtendedAgentCard() { .build(); } - private JSONRPCResponse processNonStreamingRequest(NonStreamingJSONRPCRequest request) { - if (request instanceof GetTaskRequest) { - return jsonRpcHandler.onGetTask((GetTaskRequest) request); - } else if (request instanceof CancelTaskRequest) { - return jsonRpcHandler.onCancelTask((CancelTaskRequest) request); - } else if (request instanceof SetTaskPushNotificationConfigRequest) { - return jsonRpcHandler.setPushNotificationConfig((SetTaskPushNotificationConfigRequest) request); - } else if (request instanceof GetTaskPushNotificationConfigRequest) { - return jsonRpcHandler.getPushNotificationConfig((GetTaskPushNotificationConfigRequest) request); - } else if (request instanceof SendMessageRequest) { - return jsonRpcHandler.onMessageSend((SendMessageRequest) request); - } else if (request instanceof ListTaskPushNotificationConfigRequest) { - return jsonRpcHandler.listPushNotificationConfig((ListTaskPushNotificationConfigRequest) request); - } else if (request instanceof DeleteTaskPushNotificationConfigRequest) { - return jsonRpcHandler.deletePushNotificationConfig((DeleteTaskPushNotificationConfigRequest) request); + private JSONRPCResponse processNonStreamingRequest(NonStreamingJSONRPCRequest request, + ServerCallContext context) { + if (request instanceof GetTaskRequest req) { + return jsonRpcHandler.onGetTask(req, context); + } else if (request instanceof CancelTaskRequest req) { + return jsonRpcHandler.onCancelTask(req, context); + } else if (request instanceof SetTaskPushNotificationConfigRequest req) { + return jsonRpcHandler.setPushNotificationConfig(req, context); + } else if (request instanceof GetTaskPushNotificationConfigRequest req) { + return jsonRpcHandler.getPushNotificationConfig(req, context); + } else if (request instanceof SendMessageRequest req) { + return jsonRpcHandler.onMessageSend(req, context); + } else if (request instanceof ListTaskPushNotificationConfigRequest req) { + return jsonRpcHandler.listPushNotificationConfig(req, context); + } else if (request instanceof DeleteTaskPushNotificationConfigRequest req) { + return jsonRpcHandler.deletePushNotificationConfig(req, context); } else { return generateErrorResponse(request, new UnsupportedOperationError()); } } - private void processStreamingRequest(StreamingJSONRPCRequest request, SseEventSink sseEventSink, Sse sse) { + private void processStreamingRequest(StreamingJSONRPCRequest request, SseEventSink sseEventSink, Sse sse, + ServerCallContext context) { Flow.Publisher> publisher; - if (request instanceof SendStreamingMessageRequest) { - publisher = jsonRpcHandler.onMessageSendStream((SendStreamingMessageRequest) request); + if (request instanceof SendStreamingMessageRequest req) { + publisher = jsonRpcHandler.onMessageSendStream(req, context); handleStreamingResponse(publisher, sseEventSink, sse); - } else if (request instanceof TaskResubscriptionRequest) { - publisher = jsonRpcHandler.onResubscribeToTask((TaskResubscriptionRequest) request); + } else if (request instanceof TaskResubscriptionRequest req) { + publisher = jsonRpcHandler.onResubscribeToTask(req, context); handleStreamingResponse(publisher, sseEventSink, sse); } } @@ -218,6 +240,46 @@ public static void setStreamingIsSubscribedRunnable(Runnable streamingIsSubscrib A2AServerResource.streamingIsSubscribedRunnable = streamingIsSubscribedRunnable; } + private ServerCallContext createCallContext(HttpServletRequest request, SecurityContext securityContext) { + + if (callContextFactory.isUnsatisfied()) { + User user; + + if (securityContext.getUserPrincipal() == null) { + user = UnauthenticatedUser.INSTANCE; + } else { + user = new User() { + @Override + public boolean isAuthenticated() { + return true; + } + + @Override + public String getUsername() { + return securityContext.getUserPrincipal().getName(); + } + }; + } + Map state = new HashMap<>(); + // TODO Python's impl has + // state['auth'] = request.auth + // in jsonrpc_app.py. Figure out what this maps to in what we have here + + Map headers = new HashMap<>(); + for (Enumeration headerNames = request.getHeaderNames(); headerNames.hasMoreElements() ; ) { + String name = headerNames.nextElement(); + headers.put(name, headers.get(name)); + } + + state.put("headers", headers); + + return new ServerCallContext(user, state); + } else { + CallContextFactory builder = callContextFactory.get(); + return builder.build(request); + } + } + @Provider public static class JsonParseExceptionMapper implements ExceptionMapper { diff --git a/impl/src/main/java/org/wildfly/extras/a2a/server/apps/jakarta/CallContextFactory.java b/impl/src/main/java/org/wildfly/extras/a2a/server/apps/jakarta/CallContextFactory.java new file mode 100644 index 0000000..f850625 --- /dev/null +++ b/impl/src/main/java/org/wildfly/extras/a2a/server/apps/jakarta/CallContextFactory.java @@ -0,0 +1,10 @@ +package org.wildfly.extras.a2a.server.apps.jakarta; + + +import jakarta.servlet.http.HttpServletRequest; + +import io.a2a.server.ServerCallContext; + +public interface CallContextFactory { + ServerCallContext build(HttpServletRequest request); +}