
Commit 2a05527

Merge branch 'main' into saksham/blackforestlabs-ai
2 parents 07da906 + 57154a5

File tree

10 files changed: +7088 −7134 lines


.github/workflows/test.yml

Lines changed: 3 additions & 0 deletions
@@ -46,6 +46,7 @@ jobs:
   HF_REPLICATE_KEY: dummy
   HF_SAMBANOVA_KEY: dummy
   HF_TOGETHER_KEY: dummy
+  HF_NOVITA_KEY: dummy
   HF_FIREWORKS_KEY: dummy
   HF_BLACK_FOREST_LABS_KEY: dummy

@@ -89,6 +90,7 @@ jobs:
   HF_REPLICATE_KEY: dummy
   HF_SAMBANOVA_KEY: dummy
   HF_TOGETHER_KEY: dummy
+  HF_NOVITA_KEY: dummy
   HF_FIREWORKS_KEY: dummy
   HF_BLACK_FOREST_LABS_KEY: dummy

@@ -159,5 +161,6 @@ jobs:
   HF_REPLICATE_KEY: dummy
   HF_SAMBANOVA_KEY: dummy
   HF_TOGETHER_KEY: dummy
+  HF_NOVITA_KEY: dummy
   HF_FIREWORKS_KEY: dummy
   HF_BLACK_FOREST_LABS_KEY: dummy

packages/inference/README.md

Lines changed: 1 addition & 0 deletions
@@ -50,6 +50,7 @@ Currently, we support the following providers:
 - [Fal.ai](https://fal.ai)
 - [Fireworks AI](https://fireworks.ai)
 - [Nebius](https://studio.nebius.ai)
+- [Novita](https://novita.ai/?utm_source=github_huggingface&utm_medium=github_readme&utm_campaign=link)
 - [Replicate](https://replicate.com)
 - [Sambanova](https://sambanova.ai)
 - [Together](https://together.xyz)
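
For context (not part of the diff): once the provider is live, a chat completion call routed through Novita could look like the minimal sketch below. It mirrors the call shape added in HfInference.spec.ts further down; the "@huggingface/inference" import path and the placeholder token are assumptions.

import { HfInference } from "@huggingface/inference"; // package name assumed from packages/inference

const client = new HfInference("hf_xxx"); // placeholder access token
const res = await client.chatCompletion({
  model: "meta-llama/llama-3.1-8b-instruct", // model ID used in the new Novita tests
  provider: "novita",
  messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
});
console.log(res.choices[0]?.message?.content);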

packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 17 additions & 13 deletions
@@ -4,6 +4,7 @@ import { NEBIUS_API_BASE_URL } from "../providers/nebius";
 import { REPLICATE_API_BASE_URL } from "../providers/replicate";
 import { SAMBANOVA_API_BASE_URL } from "../providers/sambanova";
 import { TOGETHER_API_BASE_URL } from "../providers/together";
+import { NOVITA_API_BASE_URL } from "../providers/novita";
 import { FIREWORKS_AI_API_BASE_URL } from "../providers/fireworks-ai";
 import { BLACKFORESTLABS_AI_API_BASE_URL } from "../providers/black-forest-labs";
 import type { InferenceProvider } from "../types";

@@ -29,8 +30,6 @@ export async function makeRequestOptions(
     stream?: boolean;
   },
   options?: Options & {
-    /** When a model can be used for multiple tasks, and we want to run a non-default task */
-    forceTask?: string | InferenceTask;
     /** To load default model if needed */
     taskHint?: InferenceTask;
     chatCompletion?: boolean;

@@ -40,14 +39,11 @@
   let otherArgs = remainingArgs;
   const provider = maybeProvider ?? "hf-inference";

-  const { forceTask, includeCredentials, taskHint, chatCompletion } = options ?? {};
+  const { includeCredentials, taskHint, chatCompletion } = options ?? {};

   if (endpointUrl && provider !== "hf-inference") {
     throw new Error(`Cannot use endpointUrl with a third-party provider.`);
   }
-  if (forceTask && provider !== "hf-inference") {
-    throw new Error(`Cannot use forceTask with a third-party provider.`);
-  }
   if (maybeModel && isUrl(maybeModel)) {
     throw new Error(`Model URLs are no longer supported. Use endpointUrl instead.`);
   }

@@ -78,7 +74,6 @@
     : makeUrl({
         authMethod,
         chatCompletion: chatCompletion ?? false,
-        forceTask,
         model,
         provider: provider ?? "hf-inference",
         taskHint,

@@ -152,7 +147,6 @@ function makeUrl(params: {
   model: string;
   provider: InferenceProvider;
   taskHint: InferenceTask | undefined;
-  forceTask?: string | InferenceTask;
 }): string {
   if (params.authMethod === "none" && params.provider !== "hf-inference") {
     throw new Error("Authentication is required when requesting a third-party provider. Please provide accessToken");

@@ -225,6 +219,7 @@ function makeUrl(params: {
     }
     return baseUrl;
   }
+
   case "fireworks-ai": {
     const baseUrl = shouldProxy
       ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)

@@ -234,15 +229,24 @@
     }
     return baseUrl;
   }
+  case "novita": {
+    const baseUrl = shouldProxy
+      ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
+      : NOVITA_API_BASE_URL;
+    if (params.taskHint === "text-generation") {
+      if (params.chatCompletion) {
+        return `${baseUrl}/chat/completions`;
+      }
+      return `${baseUrl}/completions`;
+    }
+    return baseUrl;
+  }
   default: {
     const baseUrl = HF_HUB_INFERENCE_PROXY_TEMPLATE.replaceAll("{{PROVIDER}}", "hf-inference");
-    const url = params.forceTask
-      ? `${baseUrl}/pipeline/${params.forceTask}/${params.model}`
-      : `${baseUrl}/models/${params.model}`;
     if (params.taskHint === "text-generation" && params.chatCompletion) {
-      return url + `/v1/chat/completions`;
+      return `${baseUrl}/models/${params.model}/v1/chat/completions`;
     }
-    return url;
+    return `${baseUrl}/models/${params.model}`;
   }
 }
}
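
To make the new routing concrete, here is a minimal standalone sketch of the URL resolution performed by the "novita" branch above (illustration only, not the library's exported API; the non-proxied base URL comes from providers/novita.ts):

const NOVITA_API_BASE_URL = "https://api.novita.ai/v3/openai";

// Simplified re-implementation of the `case "novita"` branch, for illustration.
function novitaUrl(taskHint?: string, chatCompletion?: boolean): string {
  if (taskHint === "text-generation") {
    return chatCompletion ? `${NOVITA_API_BASE_URL}/chat/completions` : `${NOVITA_API_BASE_URL}/completions`;
  }
  return NOVITA_API_BASE_URL;
}

// novitaUrl("text-generation", true)  -> "https://api.novita.ai/v3/openai/chat/completions"
// novitaUrl("text-generation", false) -> "https://api.novita.ai/v3/openai/completions"
// any other task hint                 -> the bare base URL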

packages/inference/src/providers/consts.ts

Lines changed: 1 addition & 0 deletions
@@ -24,4 +24,5 @@ export const HARDCODED_MODEL_ID_MAPPING: Record<InferenceProvider, Record<ModelI
   replicate: {},
   sambanova: {},
   together: {},
+  novita: {},
 };
packages/inference/src/providers/novita.ts

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+export const NOVITA_API_BASE_URL = "https://api.novita.ai/v3/openai";
+
+/**
+ * See the registered mapping of HF model ID => Novita model ID here:
+ *
+ * https://huggingface.co/api/partners/novita/models
+ *
+ * This is a publicly available mapping.
+ *
+ * If you want to try to run inference for a new model locally before it's registered on huggingface.co,
+ * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
+ *
+ * - If you work at Novita and want to update this mapping, please use the model mapping API we provide on huggingface.co
+ * - If you're a community member and want to add a new supported HF model to Novita, please open an issue on the present repo
+ *   and we will tag Novita team members.
+ *
+ * Thanks!
+ */
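
As the comment above suggests, a dev-only override (before a model is registered on huggingface.co) could look like the sketch below; the import path is assumed, and the mapping entry mirrors the one added in HfInference.spec.ts:

import { HARDCODED_MODEL_ID_MAPPING } from "./consts"; // assumed path, relative to src/providers

// Dev only: map the HF model ID to the Novita model ID until the public mapping is registered.
HARDCODED_MODEL_ID_MAPPING["novita"] = {
  "meta-llama/llama-3.1-8b-instruct": "meta-llama/llama-3.1-8b-instruct",
};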

packages/inference/src/tasks/nlp/featureExtraction.ts

Lines changed: 0 additions & 4 deletions
@@ -1,5 +1,4 @@
 import { InferenceOutputError } from "../../lib/InferenceOutputError";
-import { getDefaultTask } from "../../lib/getDefaultTask";
 import type { BaseArgs, Options } from "../../types";
 import { request } from "../custom/request";

@@ -25,12 +24,9 @@
   args: FeatureExtractionArgs,
   options?: Options
 ): Promise<FeatureExtractionOutput> {
-  const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken, options) : undefined;
-
   const res = await request<FeatureExtractionOutput>(args, {
     ...options,
     taskHint: "feature-extraction",
-    ...(defaultTask === "sentence-similarity" && { forceTask: "feature-extraction" }),
   });
   let isValidOutput = true;

packages/inference/src/tasks/nlp/sentenceSimilarity.ts

Lines changed: 0 additions & 3 deletions
@@ -1,6 +1,5 @@
 import type { SentenceSimilarityInput, SentenceSimilarityOutput } from "@huggingface/tasks";
 import { InferenceOutputError } from "../../lib/InferenceOutputError";
-import { getDefaultTask } from "../../lib/getDefaultTask";
 import type { BaseArgs, Options } from "../../types";
 import { request } from "../custom/request";
 import { omit } from "../../utils/omit";

@@ -14,11 +13,9 @@
   args: SentenceSimilarityArgs,
   options?: Options
 ): Promise<SentenceSimilarityOutput> {
-  const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken, options) : undefined;
   const res = await request<SentenceSimilarityOutput>(prepareInput(args), {
     ...options,
     taskHint: "sentence-similarity",
-    ...(defaultTask === "feature-extraction" && { forceTask: "sentence-similarity" }),
   });

   const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
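
With forceTask gone, both helpers simply send the request with their own taskHint instead of probing getDefaultTask first; nothing changes from the caller's side. A minimal call, as a sketch (`hf` is an HfInference instance, and the model/inputs are taken from the test suite):

const embedding = await hf.featureExtraction({
  model: "facebook/bart-base",
  inputs: "That is a happy person",
});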

packages/inference/src/types.ts

Lines changed: 4 additions & 2 deletions
@@ -29,15 +29,17 @@ export interface Options {
 export type InferenceTask = Exclude<PipelineType, "other">;

 export const INFERENCE_PROVIDERS = [
+  "black-forest-labs",
   "fal-ai",
   "fireworks-ai",
-  "nebius",
   "hf-inference",
+  "nebius",
+  "novita",
   "replicate",
   "sambanova",
   "together",
-  "black-forest-labs",
 ] as const;
+
 export type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number];

 export interface BaseArgs {
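
The reshuffling above is purely alphabetical; the only new member of the exported union is "novita". A trivial sketch of the effect (import path assumed):

import type { InferenceProvider } from "@huggingface/inference"; // path assumed

const provider: InferenceProvider = "novita"; // type-checks after this change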

packages/inference/test/HfInference.spec.ts

Lines changed: 47 additions & 9 deletions
@@ -351,15 +351,6 @@ describe.concurrent("HfInference", () => {
   });
   expect(response).toEqual(expect.arrayContaining([expect.any(Number)]));
 });
-it("FeatureExtraction - same model as sentence similarity", async () => {
-  const response = await hf.featureExtraction({
-    model: "sentence-transformers/paraphrase-xlm-r-multilingual-v1",
-    inputs: "That is a happy person",
-  });
-
-  expect(response.length).toBeGreaterThan(10);
-  expect(response).toEqual(expect.arrayContaining([expect.any(Number)]));
-});
 it("FeatureExtraction - facebook/bart-base", async () => {
   const response = await hf.featureExtraction({
     model: "facebook/bart-base",

@@ -1176,6 +1167,53 @@ describe.concurrent("HfInference", () => {
   TIMEOUT
 );

+describe.concurrent(
+  "Novita",
+  () => {
+    const client = new HfInference(env.HF_NOVITA_KEY);
+
+    HARDCODED_MODEL_ID_MAPPING["novita"] = {
+      "meta-llama/llama-3.1-8b-instruct": "meta-llama/llama-3.1-8b-instruct",
+      "deepseek/deepseek-r1-distill-qwen-14b": "deepseek/deepseek-r1-distill-qwen-14b",
+    };
+
+    it("chatCompletion", async () => {
+      const res = await client.chatCompletion({
+        model: "meta-llama/llama-3.1-8b-instruct",
+        provider: "novita",
+        messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
+      });
+      if (res.choices && res.choices.length > 0) {
+        const completion = res.choices[0].message?.content;
+        expect(completion).toContain("two");
+      }
+    });
+
+    it("chatCompletion stream", async () => {
+      const stream = client.chatCompletionStream({
+        model: "deepseek/deepseek-r1-distill-qwen-14b",
+        provider: "novita",
+        messages: [{ role: "user", content: "Say this is a test" }],
+        stream: true,
+      }) as AsyncGenerator<ChatCompletionStreamOutput>;
+
+      let fullResponse = "";
+      for await (const chunk of stream) {
+        if (chunk.choices && chunk.choices.length > 0) {
+          const content = chunk.choices[0].delta?.content;
+          if (content) {
+            fullResponse += content;
+          }
+        }
+      }
+
+      // Verify we got a meaningful response
+      expect(fullResponse).toBeTruthy();
+      expect(fullResponse.length).toBeGreaterThan(0);
+    });
+  },
+  TIMEOUT
+);
 describe.concurrent(
   "Black Forest Labs",
   () => {
