Skip to content

Commit 2029ee7

Browse files
committed
Support custom http client for RemoteVeAgent
1 parent 6016f10 commit 2029ee7

File tree

1 file changed

+40
-39
lines changed

1 file changed

+40
-39
lines changed

veadk/a2a/remote_ve_agent.py

Lines changed: 40 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import json
16-
from typing import Literal
16+
from typing import Literal, Optional
1717

1818
import httpx
1919
import requests
@@ -97,49 +97,50 @@ def __init__(
9797
url: str,
9898
auth_token: str | None = None,
9999
auth_method: Literal["header", "querystring"] | None = None,
100+
httpx_client: Optional[httpx.AsyncClient] = None,
100101
):
101-
if not auth_token:
102-
agent_card_dict = requests.get(url + AGENT_CARD_WELL_KNOWN_PATH).json()
103-
# replace agent_card_url with actual host
104-
agent_card_dict["url"] = url
102+
req_headers = {}
103+
req_params = {}
105104

106-
agent_card_object = _convert_agent_card_dict_to_obj(agent_card_dict)
107-
108-
logger.debug(f"Agent card of {name}: {agent_card_object}")
109-
super().__init__(name=name, agent_card=agent_card_object)
110-
else:
105+
if auth_token:
111106
if auth_method == "header":
112-
headers = {"Authorization": f"Bearer {auth_token}"}
113-
agent_card_dict = requests.get(
114-
url + AGENT_CARD_WELL_KNOWN_PATH, headers=headers
115-
).json()
116-
agent_card_dict["url"] = url
117-
118-
agent_card_object = _convert_agent_card_dict_to_obj(agent_card_dict)
119-
httpx_client = httpx.AsyncClient(
120-
base_url=url, headers=headers, timeout=600
121-
)
122-
123-
logger.debug(f"Agent card of {name}: {agent_card_object}")
124-
super().__init__(
125-
name=name, agent_card=agent_card_object, httpx_client=httpx_client
126-
)
107+
req_headers = {"Authorization": f"Bearer {auth_token}"}
127108
elif auth_method == "querystring":
128-
agent_card_dict = requests.get(
129-
url + AGENT_CARD_WELL_KNOWN_PATH + f"?token={auth_token}"
130-
).json()
131-
agent_card_dict["url"] = url
132-
133-
agent_card_object = _convert_agent_card_dict_to_obj(agent_card_dict)
134-
httpx_client = httpx.AsyncClient(
135-
base_url=url, params={"token": auth_token}, timeout=600
136-
)
137-
138-
logger.debug(f"Agent card of {name}: {agent_card_object}")
139-
super().__init__(
140-
name=name, agent_card=agent_card_object, httpx_client=httpx_client
141-
)
109+
req_params = {"token": auth_token}
142110
else:
143111
raise ValueError(
144112
f"Unsupported auth method {auth_method}, use `header` or `querystring` instead."
145113
)
114+
115+
agent_card_dict = requests.get(
116+
url + AGENT_CARD_WELL_KNOWN_PATH, headers=req_headers, params=req_params
117+
).json()
118+
# replace agent_card_url with actual host
119+
agent_card_dict["url"] = url
120+
121+
agent_card_object = _convert_agent_card_dict_to_obj(agent_card_dict)
122+
123+
logger.debug(f"Agent card of {name}: {agent_card_object}")
124+
125+
client_to_use = httpx_client
126+
if auth_token:
127+
if client_to_use:
128+
if auth_method == "header":
129+
client_to_use.headers.update(req_headers)
130+
elif auth_method == "querystring":
131+
new_params = dict(client_to_use.params)
132+
new_params.update(req_params)
133+
client_to_use.params = new_params
134+
else:
135+
if auth_method == "header":
136+
client_to_use = httpx.AsyncClient(
137+
base_url=url, headers=req_headers, timeout=600
138+
)
139+
elif auth_method == "querystring":
140+
client_to_use = httpx.AsyncClient(
141+
base_url=url, params=req_params, timeout=600
142+
)
143+
144+
super().__init__(
145+
name=name, agent_card=agent_card_object, httpx_client=client_to_use
146+
)

0 commit comments

Comments
 (0)