@@ -1,5 +1,6 @@
 import { HF_HUB_URL, HF_ROUTER_URL } from "../config";
 import { FAL_AI_API_BASE_URL } from "../providers/fal-ai";
+import { NEBIUS_API_BASE_URL } from "../providers/nebius";
 import { REPLICATE_API_BASE_URL } from "../providers/replicate";
 import { SAMBANOVA_API_BASE_URL } from "../providers/sambanova";
 import { TOGETHER_API_BASE_URL } from "../providers/together";
@@ -39,8 +40,7 @@ export async function makeRequestOptions(
 	let otherArgs = remainingArgs;
 	const provider = maybeProvider ?? "hf-inference";
 
-	const { forceTask, includeCredentials, taskHint, wait_for_model, use_cache, dont_load_model, chatCompletion } =
-		options ?? {};
+	const { forceTask, includeCredentials, taskHint, chatCompletion } = options ?? {};
 
 	if (endpointUrl && provider !== "hf-inference") {
 		throw new Error(`Cannot use endpointUrl with a third-party provider.`);
@@ -107,18 +107,6 @@ export async function makeRequestOptions(
 		headers["Content-Type"] = "application/json";
 	}
 
-	if (provider === "hf-inference") {
-		if (wait_for_model) {
-			headers["X-Wait-For-Model"] = "true";
-		}
-		if (use_cache === false) {
-			headers["X-Use-Cache"] = "false";
-		}
-		if (dont_load_model) {
-			headers["X-Load-Model"] = "0";
-		}
-	}
-
 	if (provider === "replicate") {
 		headers["Prefer"] = "wait";
 	}
@@ -149,7 +137,7 @@ export async function makeRequestOptions(
 			? args.data
 			: JSON.stringify({
 					...otherArgs,
-					...(chatCompletion || provider === "together" ? { model } : undefined),
+					...(chatCompletion || provider === "together" || provider === "nebius" ? { model } : undefined),
 				}),
 		...(credentials ? { credentials } : undefined),
 		signal: options?.signal,
@@ -184,6 +172,22 @@ function makeUrl(params: {
 				: FAL_AI_API_BASE_URL;
 			return `${baseUrl}/${params.model}`;
 		}
+		case "nebius": {
+			const baseUrl = shouldProxy
+				? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
+				: NEBIUS_API_BASE_URL;
+
+			if (params.taskHint === "text-to-image") {
+				return `${baseUrl}/v1/images/generations`;
+			}
+			if (params.taskHint === "text-generation") {
+				if (params.chatCompletion) {
+					return `${baseUrl}/v1/chat/completions`;
+				}
+				return `${baseUrl}/v1/completions`;
+			}
+			return baseUrl;
+		}
 		case "replicate": {
 			const baseUrl = shouldProxy
 				? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
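For context, a minimal standalone sketch (not part of the diff) of how the new `nebius` case resolves request URLs. The `NEBIUS_API_BASE_URL` value below is an assumed placeholder; the real constant lives in `../providers/nebius`, and the proxy branch is omitted here.

// Standalone sketch of the URL routing the "nebius" case adds to makeUrl.
// NEBIUS_API_BASE_URL is a hypothetical placeholder value for illustration.
const NEBIUS_API_BASE_URL = "https://nebius.example.invalid";

function nebiusUrl(taskHint: string | undefined, chatCompletion: boolean): string {
	const baseUrl = NEBIUS_API_BASE_URL;
	if (taskHint === "text-to-image") {
		return `${baseUrl}/v1/images/generations`;
	}
	if (taskHint === "text-generation") {
		// Chat-style requests go to the chat completions route, plain text
		// generation to the completions route.
		return chatCompletion ? `${baseUrl}/v1/chat/completions` : `${baseUrl}/v1/completions`;
	}
	// Other tasks fall back to the provider base URL.
	return baseUrl;
}

// Example: a chat completion request resolves to <base>/v1/chat/completions.
console.log(nebiusUrl("text-generation", true));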