Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions impl/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@
<artifactId>a2a-java-sdk-tests-server-common</artifactId>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>jakarta.servlet</groupId>
<artifactId>jakarta.servlet-api</artifactId>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>io.github.a2asdk</groupId>
<artifactId>a2a-java-sdk-tests-server-common</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
package org.wildfly.extras.a2a.server.apps.jakarta;

import java.util.Enumeration;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.Executor;
import java.util.concurrent.Flow;

import jakarta.enterprise.inject.Instance;
import jakarta.inject.Inject;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.ws.rs.Consumes;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.POST;
Expand All @@ -13,6 +17,7 @@
import jakarta.ws.rs.core.Context;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response;
import jakarta.ws.rs.core.SecurityContext;
import jakarta.ws.rs.ext.ExceptionMapper;
import jakarta.ws.rs.ext.Provider;
import jakarta.ws.rs.sse.Sse;
Expand All @@ -21,6 +26,9 @@
import com.fasterxml.jackson.core.JsonParseException;
import com.fasterxml.jackson.databind.JsonMappingException;
import io.a2a.server.ExtendedAgentCard;
import io.a2a.server.ServerCallContext;
import io.a2a.server.auth.UnauthenticatedUser;
import io.a2a.server.auth.User;
import io.a2a.server.requesthandlers.JSONRPCHandler;
import io.a2a.server.util.async.Internal;
import io.a2a.spec.AgentCard;
Expand Down Expand Up @@ -70,6 +78,10 @@ public class A2AServerResource {
@Internal
Executor executor;


@Inject
Instance<CallContextFactory> callContextFactory;

/**
* Handles incoming POST requests to the main A2A endpoint. Dispatches the
* request to the appropriate JSON-RPC handler method and returns the response.
Expand All @@ -80,10 +92,14 @@ public class A2AServerResource {
@POST
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public JSONRPCResponse<?> handleNonStreamingRequests(NonStreamingJSONRPCRequest<?> request) {
public JSONRPCResponse<?> handleNonStreamingRequests(
NonStreamingJSONRPCRequest<?> request, @Context HttpServletRequest httpRequest,
@Context SecurityContext securityContext) {

ServerCallContext context = createCallContext(httpRequest, securityContext);
LOGGER.debug("Handling non-streaming request");
try {
return processNonStreamingRequest(request);
return processNonStreamingRequest(request, context);
} finally {
LOGGER.debug("Completed non-streaming request");
}
Expand All @@ -96,9 +112,13 @@ public JSONRPCResponse<?> handleNonStreamingRequests(NonStreamingJSONRPCRequest<
@POST
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.SERVER_SENT_EVENTS)
public void handleStreamingRequests(StreamingJSONRPCRequest<?> request, @Context SseEventSink sseEventSink, @Context Sse sse) {
public void handleStreamingRequests(
StreamingJSONRPCRequest<?> request, @Context SseEventSink sseEventSink,
@Context Sse sse, @Context HttpServletRequest httpRequest,
@Context SecurityContext securityContext) {
ServerCallContext context = createCallContext(httpRequest, securityContext);
LOGGER.debug("Handling streaming request");
executor.execute(() -> processStreamingRequest(request, sseEventSink, sse));
executor.execute(() -> processStreamingRequest(request, sseEventSink, sse, context));
LOGGER.debug("Submitted streaming request for async processing");
}

Expand Down Expand Up @@ -142,33 +162,35 @@ public Response getAuthenticatedExtendedAgentCard() {
.build();
}

private JSONRPCResponse<?> processNonStreamingRequest(NonStreamingJSONRPCRequest<?> request) {
if (request instanceof GetTaskRequest) {
return jsonRpcHandler.onGetTask((GetTaskRequest) request);
} else if (request instanceof CancelTaskRequest) {
return jsonRpcHandler.onCancelTask((CancelTaskRequest) request);
} else if (request instanceof SetTaskPushNotificationConfigRequest) {
return jsonRpcHandler.setPushNotificationConfig((SetTaskPushNotificationConfigRequest) request);
} else if (request instanceof GetTaskPushNotificationConfigRequest) {
return jsonRpcHandler.getPushNotificationConfig((GetTaskPushNotificationConfigRequest) request);
} else if (request instanceof SendMessageRequest) {
return jsonRpcHandler.onMessageSend((SendMessageRequest) request);
} else if (request instanceof ListTaskPushNotificationConfigRequest) {
return jsonRpcHandler.listPushNotificationConfig((ListTaskPushNotificationConfigRequest) request);
} else if (request instanceof DeleteTaskPushNotificationConfigRequest) {
return jsonRpcHandler.deletePushNotificationConfig((DeleteTaskPushNotificationConfigRequest) request);
private JSONRPCResponse<?> processNonStreamingRequest(NonStreamingJSONRPCRequest<?> request,
ServerCallContext context) {
if (request instanceof GetTaskRequest req) {
return jsonRpcHandler.onGetTask(req, context);
} else if (request instanceof CancelTaskRequest req) {
return jsonRpcHandler.onCancelTask(req, context);
} else if (request instanceof SetTaskPushNotificationConfigRequest req) {
return jsonRpcHandler.setPushNotificationConfig(req, context);
} else if (request instanceof GetTaskPushNotificationConfigRequest req) {
return jsonRpcHandler.getPushNotificationConfig(req, context);
} else if (request instanceof SendMessageRequest req) {
return jsonRpcHandler.onMessageSend(req, context);
} else if (request instanceof ListTaskPushNotificationConfigRequest req) {
return jsonRpcHandler.listPushNotificationConfig(req, context);
} else if (request instanceof DeleteTaskPushNotificationConfigRequest req) {
return jsonRpcHandler.deletePushNotificationConfig(req, context);
} else {
return generateErrorResponse(request, new UnsupportedOperationError());
}
}

private void processStreamingRequest(StreamingJSONRPCRequest<?> request, SseEventSink sseEventSink, Sse sse) {
private void processStreamingRequest(StreamingJSONRPCRequest<?> request, SseEventSink sseEventSink, Sse sse,
ServerCallContext context) {
Flow.Publisher<? extends JSONRPCResponse<?>> publisher;
if (request instanceof SendStreamingMessageRequest) {
publisher = jsonRpcHandler.onMessageSendStream((SendStreamingMessageRequest) request);
if (request instanceof SendStreamingMessageRequest req) {
publisher = jsonRpcHandler.onMessageSendStream(req, context);
handleStreamingResponse(publisher, sseEventSink, sse);
} else if (request instanceof TaskResubscriptionRequest) {
publisher = jsonRpcHandler.onResubscribeToTask((TaskResubscriptionRequest) request);
} else if (request instanceof TaskResubscriptionRequest req) {
publisher = jsonRpcHandler.onResubscribeToTask(req, context);
handleStreamingResponse(publisher, sseEventSink, sse);
}
}
Expand Down Expand Up @@ -218,6 +240,46 @@ public static void setStreamingIsSubscribedRunnable(Runnable streamingIsSubscrib
A2AServerResource.streamingIsSubscribedRunnable = streamingIsSubscribedRunnable;
}

private ServerCallContext createCallContext(HttpServletRequest request, SecurityContext securityContext) {

if (callContextFactory.isUnsatisfied()) {
User user;

if (securityContext.getUserPrincipal() == null) {
user = UnauthenticatedUser.INSTANCE;
} else {
user = new User() {
@Override
public boolean isAuthenticated() {
return true;
}

@Override
public String getUsername() {
return securityContext.getUserPrincipal().getName();
}
};
}
Map<String, Object> state = new HashMap<>();
// TODO Python's impl has
// state['auth'] = request.auth
// in jsonrpc_app.py. Figure out what this maps to in what we have here

Map<String, String> headers = new HashMap<>();
for (Enumeration<String> headerNames = request.getHeaderNames(); headerNames.hasMoreElements() ; ) {
String name = headerNames.nextElement();
headers.put(name, headers.get(name));
}

state.put("headers", headers);

return new ServerCallContext(user, state);
} else {
CallContextFactory builder = callContextFactory.get();
return builder.build(request);
}
}

@Provider
public static class JsonParseExceptionMapper implements ExceptionMapper<JsonParseException> {

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package org.wildfly.extras.a2a.server.apps.jakarta;


import jakarta.servlet.http.HttpServletRequest;

import io.a2a.server.ServerCallContext;

public interface CallContextFactory {
ServerCallContext build(HttpServletRequest request);
}