|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | import json |
16 | | -from typing import Literal |
| 16 | +from typing import Literal, Optional |
17 | 17 |
|
18 | 18 | import httpx |
19 | 19 | import requests |
@@ -97,49 +97,50 @@ def __init__( |
97 | 97 | url: str, |
98 | 98 | auth_token: str | None = None, |
99 | 99 | auth_method: Literal["header", "querystring"] | None = None, |
| 100 | + httpx_client: Optional[httpx.AsyncClient] = None, |
100 | 101 | ): |
101 | | - if not auth_token: |
102 | | - agent_card_dict = requests.get(url + AGENT_CARD_WELL_KNOWN_PATH).json() |
103 | | - # replace agent_card_url with actual host |
104 | | - agent_card_dict["url"] = url |
| 102 | + req_headers = {} |
| 103 | + req_params = {} |
105 | 104 |
|
106 | | - agent_card_object = _convert_agent_card_dict_to_obj(agent_card_dict) |
107 | | - |
108 | | - logger.debug(f"Agent card of {name}: {agent_card_object}") |
109 | | - super().__init__(name=name, agent_card=agent_card_object) |
110 | | - else: |
| 105 | + if auth_token: |
111 | 106 | if auth_method == "header": |
112 | | - headers = {"Authorization": f"Bearer {auth_token}"} |
113 | | - agent_card_dict = requests.get( |
114 | | - url + AGENT_CARD_WELL_KNOWN_PATH, headers=headers |
115 | | - ).json() |
116 | | - agent_card_dict["url"] = url |
117 | | - |
118 | | - agent_card_object = _convert_agent_card_dict_to_obj(agent_card_dict) |
119 | | - httpx_client = httpx.AsyncClient( |
120 | | - base_url=url, headers=headers, timeout=600 |
121 | | - ) |
122 | | - |
123 | | - logger.debug(f"Agent card of {name}: {agent_card_object}") |
124 | | - super().__init__( |
125 | | - name=name, agent_card=agent_card_object, httpx_client=httpx_client |
126 | | - ) |
| 107 | + req_headers = {"Authorization": f"Bearer {auth_token}"} |
127 | 108 | elif auth_method == "querystring": |
128 | | - agent_card_dict = requests.get( |
129 | | - url + AGENT_CARD_WELL_KNOWN_PATH + f"?token={auth_token}" |
130 | | - ).json() |
131 | | - agent_card_dict["url"] = url |
132 | | - |
133 | | - agent_card_object = _convert_agent_card_dict_to_obj(agent_card_dict) |
134 | | - httpx_client = httpx.AsyncClient( |
135 | | - base_url=url, params={"token": auth_token}, timeout=600 |
136 | | - ) |
137 | | - |
138 | | - logger.debug(f"Agent card of {name}: {agent_card_object}") |
139 | | - super().__init__( |
140 | | - name=name, agent_card=agent_card_object, httpx_client=httpx_client |
141 | | - ) |
| 109 | + req_params = {"token": auth_token} |
142 | 110 | else: |
143 | 111 | raise ValueError( |
144 | 112 | f"Unsupported auth method {auth_method}, use `header` or `querystring` instead." |
145 | 113 | ) |
| 114 | + |
| 115 | + agent_card_dict = requests.get( |
| 116 | + url + AGENT_CARD_WELL_KNOWN_PATH, headers=req_headers, params=req_params |
| 117 | + ).json() |
| 118 | + # replace agent_card_url with actual host |
| 119 | + agent_card_dict["url"] = url |
| 120 | + |
| 121 | + agent_card_object = _convert_agent_card_dict_to_obj(agent_card_dict) |
| 122 | + |
| 123 | + logger.debug(f"Agent card of {name}: {agent_card_object}") |
| 124 | + |
| 125 | + client_to_use = httpx_client |
| 126 | + if auth_token: |
| 127 | + if client_to_use: |
| 128 | + if auth_method == "header": |
| 129 | + client_to_use.headers.update(req_headers) |
| 130 | + elif auth_method == "querystring": |
| 131 | + new_params = dict(client_to_use.params) |
| 132 | + new_params.update(req_params) |
| 133 | + client_to_use.params = new_params |
| 134 | + else: |
| 135 | + if auth_method == "header": |
| 136 | + client_to_use = httpx.AsyncClient( |
| 137 | + base_url=url, headers=req_headers, timeout=600 |
| 138 | + ) |
| 139 | + elif auth_method == "querystring": |
| 140 | + client_to_use = httpx.AsyncClient( |
| 141 | + base_url=url, params=req_params, timeout=600 |
| 142 | + ) |
| 143 | + |
| 144 | + super().__init__( |
| 145 | + name=name, agent_card=agent_card_object, httpx_client=client_to_use |
| 146 | + ) |
0 commit comments