
Commit 3d98ea4

feat: implementing 'OpenAI compatible' provider
- Applying LLM provider interfaces to implement the 'openaicompatible' mode
1 parent f90fdd6 commit 3d98ea4

5 files changed: 248 additions & 150 deletions

ext/ai/js/ai.js

Lines changed: 21 additions & 124 deletions
@@ -1,6 +1,5 @@
 import 'ext:ai/onnxruntime/onnx.js';
-import { parseJSON, parseJSONOverEventStream } from './llm/utils/json_parser.ts';
-import { LLMSession } from './llm/llm_session.ts';
+import { LLMSession, providers } from './llm/llm_session.ts';
 
 const core = globalThis.Deno.core;
 
@@ -9,11 +8,15 @@ class Session {
   init;
   is_ext_inference_api;
   inferenceAPIHost;
+  extraOpts;
 
-  constructor(model) {
+  // TODO:(kallebysantos) get 'provider' type here and use type checking to suggest Inputs when run
+  constructor(model, opts = {}) {
     this.model = model;
     this.is_ext_inference_api = false;
+    this.extraOpts = opts;
 
+    // TODO:(kallebysantos) do we still need gte-small?
     if (model === 'gte-small') {
       this.init = core.ops.op_ai_init_model(model);
     } else {
@@ -28,131 +31,25 @@ class Session {
       const stream = opts.stream ?? false;
 
       /** @type {'ollama' | 'openaicompatible'} */
+      // TODO:(kallebysantos) get mode from 'new' and apply type checking based on that
       const mode = opts.mode ?? 'ollama';
 
-      if (mode === 'ollama') {
-        // Using the new LLMSession API
-        const llmSession = LLMSession.fromProvider('ollama', {
-          inferenceAPIHost: this.inferenceAPIHost,
-          model: this.model,
-        });
-
-        return await llmSession.run({
-          prompt,
-          stream,
-          signal: opts.signal,
-          timeout: opts.timeout,
-        });
-      }
-
-      // default timeout 60s
-      const timeout = typeof opts.timeout === 'number' ? opts.timeout : 60;
-      const timeoutMs = timeout * 1000;
-
-      switch (mode) {
-        case 'openaicompatible':
-          break;
-
-        default:
-          throw new TypeError(`invalid mode: ${mode}`);
-      }
-
-      const timeoutSignal = AbortSignal.timeout(timeoutMs);
-      const signals = [opts.signal, timeoutSignal]
-        .filter((it) => it instanceof AbortSignal);
-
-      const signal = AbortSignal.any(signals);
-
-      const path = '/v1/chat/completions';
-      const body = prompt;
-
-      const res = await fetch(
-        new URL(path, this.inferenceAPIHost),
-        {
-          method: 'POST',
-          headers: {
-            'Content-Type': 'application/json',
-          },
-          body: JSON.stringify({
-            model: this.model,
-            stream,
-            ...body,
-          }),
-        },
-        { signal },
-      );
-
-      if (!res.ok) {
-        throw new Error(
-          `Failed to fetch inference API host. Status ${res.status}: ${res.statusText}`,
-        );
+      if (!Object.keys(providers).includes(mode)) {
+        throw new TypeError(`invalid mode: ${mode}`);
       }
 
-      if (!res.body) {
-        throw new Error('Missing body');
-      }
-
-      const parseGenFn = stream === true ? parseJSONOverEventStream : parseJSON;
-      const itr = parseGenFn(res.body, signal);
-
-      if (stream) {
-        return (async function* () {
-          for await (const message of itr) {
-            if ('error' in message) {
-              if (message.error instanceof Error) {
-                throw message.error;
-              } else {
-                throw new Error(message.error);
-              }
-            }
-
-            yield message;
-
-            switch (mode) {
-              case 'openaicompatible': {
-                const finishReason = message.choices[0].finish_reason;
-
-                if (finishReason) {
-                  if (finishReason !== 'stop') {
-                    throw new Error('Expected a completed response.');
-                  }
-
-                  return;
-                }
-
-                break;
-              }
-
-              default:
-                throw new Error('unreachable');
-            }
-          }
-
-          throw new Error(
-            'Did not receive done or success response in stream.',
-          );
-        })();
-      } else {
-        const message = await itr.next();
-
-        if (message.value && 'error' in message.value) {
-          const error = message.value.error;
-
-          if (error instanceof Error) {
-            throw error;
-          } else {
-            throw new Error(error);
-          }
-        }
-
-        const finish = message.value.choices[0].finish_reason === 'stop';
-
-        if (finish !== true) {
-          throw new Error('Expected a completed response.');
-        }
-
-        return message.value;
-      }
+      const llmSession = LLMSession.fromProvider(mode, {
+        inferenceAPIHost: this.inferenceAPIHost,
+        model: this.model,
+        ...this.extraOpts, // allows custom provider initialization like 'apiKey'
+      });
+
+      return await llmSession.run({
+        prompt,
+        stream,
+        signal: opts.signal,
+        timeout: opts.timeout,
+      });
     }
 
     if (this.init) {
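
For orientation, here is a sketch of how the reworked run() path could be exercised from user code after this change. It is an assumption-laden example, not code from this commit: the model name and timeout are made up, the OpenAI-style prompt body mirrors the request shape removed above, and 'apiKey' is taken from the new "custom provider initialization" comment.

// Hypothetical usage sketch; constructor options are forwarded to the provider via extraOpts.
const session = new Session('gpt-4o-mini', {
  apiKey: Deno.env.get('OPENAI_API_KEY'), // assumed provider-specific option
});

// 'openaicompatible' must be a key of the providers map exported by llm_session.ts.
const output = await session.run(
  { messages: [{ role: 'user', content: 'Hello!' }] }, // assumed OpenAI-style chat body
  { mode: 'openaicompatible', stream: false, timeout: 60 },
);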

ext/ai/js/llm/llm_session.ts

Lines changed: 14 additions & 5 deletions
@@ -1,4 +1,5 @@
 import { OllamaLLMSession } from './providers/ollama.ts';
+import { OpenAILLMSession } from './providers/openai.ts';
 
 // @ts-ignore deno_core environment
 const core = globalThis.Deno.core;
@@ -20,18 +21,25 @@ export type LLMRunInput = {
 };
 
 export interface ILLMProviderOptions {
-  inferenceAPIHost: string;
   model: string;
+  inferenceAPIHost: string;
+}
+
+export interface ILLMProviderInput {
+  prompt: string | object;
+  signal: AbortSignal;
 }
 
 export interface ILLMProvider {
   // TODO:(kallebysantos) remove 'any'
-  getStream(prompt: string, signal: AbortSignal): Promise<AsyncIterable<any>>;
-  getText(prompt: string, signal: AbortSignal): Promise<any>;
+  // TODO: (kallebysantos) standardised output format
+  getStream(input: ILLMProviderInput): Promise<AsyncIterable<any>>;
+  getText(input: ILLMProviderInput): Promise<any>;
 }
 
 export const providers = {
   'ollama': OllamaLLMSession,
+  'openaicompatible': OpenAILLMSession,
 } satisfies Record<string, new (opts: ILLMProviderOptions) => ILLMProvider>;
 
 export type LLMProviderName = keyof typeof providers;
@@ -65,10 +73,11 @@ export class LLMSession {
       .filter((it) => it instanceof AbortSignal);
     const signal = AbortSignal.any(abortSignals);
 
+    const llmInput: ILLMProviderInput = { prompt: opts.prompt, signal };
     if (isStream) {
-      return this.#inner.getStream(opts.prompt, signal);
+      return this.#inner.getStream(llmInput);
     }
 
-    return this.#inner.getText(opts.prompt, signal);
+    return this.#inner.getText(llmInput);
   }
 }
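
The new file ext/ai/js/llm/providers/openai.ts registered above belongs to this commit but is not shown in this excerpt. For orientation only, a minimal ILLMProvider for an OpenAI-compatible /v1/chat/completions endpoint could be sketched as below, reconstructed from the interface above and the request logic removed from ai.js; the OpenAIProviderOptions fields, the Authorization header, and the use of parseJSONOverEventStream for streaming are assumptions rather than the actual implementation.

import { ILLMProvider, ILLMProviderInput, ILLMProviderOptions } from '../llm_session.ts';
import { parseJSONOverEventStream } from '../utils/json_parser.ts';

// Sketch only: the real openai.ts added by this commit is not shown here.
export type OpenAIProviderOptions = ILLMProviderOptions & {
  apiKey?: string; // assumed, based on the "custom provider initialization like 'apiKey'" comment
};

export class OpenAILLMSession implements ILLMProvider {
  constructor(private opts: OpenAIProviderOptions) {}

  async getStream({ prompt, signal }: ILLMProviderInput): Promise<AsyncIterable<any>> {
    const res = await this.#chatCompletions(prompt, true, signal);
    // Chat-completion streams arrive as server-sent events, so parse them as an event stream.
    return parseJSONOverEventStream(res.body!, signal);
  }

  async getText({ prompt, signal }: ILLMProviderInput): Promise<any> {
    const res = await this.#chatCompletions(prompt, false, signal);
    return await res.json();
  }

  async #chatCompletions(prompt: string | object, stream: boolean, signal: AbortSignal) {
    const res = await fetch(new URL('/v1/chat/completions', this.opts.inferenceAPIHost), {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
        ...(this.opts.apiKey ? { Authorization: `Bearer ${this.opts.apiKey}` } : {}),
      },
      body: JSON.stringify({ model: this.opts.model, stream, ...(prompt as object) }),
      signal,
    });

    if (!res.ok || !res.body) {
      throw new Error(`Failed to fetch inference API host. Status ${res.status}: ${res.statusText}`);
    }

    return res;
  }
}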

ext/ai/js/llm/providers/ollama.ts

Lines changed: 21 additions & 21 deletions
@@ -1,7 +1,10 @@
-import { ILLMProvider, ILLMProviderOptions } from '../llm_session.ts';
+import { ILLMProvider, ILLMProviderInput, ILLMProviderOptions } from '../llm_session.ts';
 import { parseJSON } from '../utils/json_parser.ts';
 
 export type OllamaProviderOptions = ILLMProviderOptions;
+export type OllamaProviderInput = ILLMProviderInput & {
+  prompt: string;
+};
 
 export type OllamaMessage = {
   model: string;
@@ -26,10 +29,13 @@ export class OllamaLLMSession implements ILLMProvider {
 
   // ref: https://github.com/ollama/ollama-js/blob/6a4bfe3ab033f611639dfe4249bdd6b9b19c7256/src/utils.ts#L26
   async getStream(
-    prompt: string,
-    signal: AbortSignal,
+    { prompt, signal }: OllamaProviderInput,
   ): Promise<AsyncIterable<OllamaMessage>> {
-    const generator = await this.generate(prompt, signal, true);
+    const generator = await this.generate(
+      prompt,
+      signal,
+      true,
+    ) as AsyncGenerator<OllamaMessage>;
 
     const stream = async function* () {
       for await (const message of generator) {
@@ -55,22 +61,10 @@ export class OllamaLLMSession implements ILLMProvider {
     return stream();
   }
 
-  async getText(prompt: string, signal: AbortSignal): Promise<OllamaMessage> {
-    const generator = await this.generate(prompt, signal);
-
-    const message = await generator.next();
-
-    if (message.value && 'error' in message.value) {
-      const error = message.value.error;
-
-      if (error instanceof Error) {
-        throw error;
-      } else {
-        throw new Error(error);
-      }
-    }
-
-    const response = message.value;
+  async getText(
+    { prompt, signal }: OllamaProviderInput,
+  ): Promise<OllamaMessage> {
+    const response = await this.generate(prompt, signal) as OllamaMessage;
 
     if (!response?.done) {
       throw new Error('Expected a completed response.');
@@ -110,6 +104,12 @@ export class OllamaLLMSession implements ILLMProvider {
       throw new Error('Missing body');
     }
 
-    return parseJSON<OllamaMessage>(res.body, signal);
+    if (stream) {
+      return parseJSON<OllamaMessage>(res.body, signal);
+    }
+
+    const result: OllamaMessage = await res.json();
+
+    return result;
   }
 }
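
To illustrate the new call shape, where positional prompt/signal arguments are replaced by a single OllamaProviderInput, a small usage sketch follows; the model name and host are made up, and in practice LLMSession.fromProvider('ollama', ...) constructs the provider rather than user code.

import { OllamaLLMSession } from './providers/ollama.ts';

// Hypothetical direct usage; values are illustrative only.
const ollama = new OllamaLLMSession({
  model: 'llama3', // assumed model name
  inferenceAPIHost: 'http://localhost:11434', // assumed Ollama host
});

const signal = AbortSignal.timeout(60_000);

// Non-streaming: generate() now resolves straight to a parsed OllamaMessage.
const message = await ollama.getText({ prompt: 'Why is the sky blue?', signal });
console.log(message);

// Streaming: getStream() yields OllamaMessage chunks until the model reports done.
for await (const chunk of await ollama.getStream({ prompt: 'Tell me a story', signal })) {
  console.log(chunk);
}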
