
Commit 0bf30ca

feat: implementing Ollama LLM provider
- Applying LLM provider interfaces to implement the Ollama provider
1 parent f311d25 commit 0bf30ca

File tree

4 files changed: +164 −41 lines changed

- ext/ai/js/ai.js
- ext/ai/js/llm/llm_session.ts
- ext/ai/js/llm/providers/ollama.ts
- ext/ai/lib.rs

ext/ai/js/ai.js

Lines changed: 39 additions & 39 deletions
@@ -1,5 +1,6 @@
 import 'ext:ai/onnxruntime/onnx.js';
 import { parseJSON, parseJSONOverEventStream } from './llm/utils/json_parser.ts';
+import { LLMSession } from './llm/llm_session.ts';
 
 const core = globalThis.Deno.core;
 
@@ -13,10 +14,10 @@ class Session {
     this.model = model;
     this.is_ext_inference_api = false;
 
-    if (model === "gte-small") {
+    if (model === 'gte-small') {
       this.init = core.ops.op_ai_init_model(model);
     } else {
-      this.inferenceAPIHost = core.ops.op_get_env("AI_INFERENCE_API_HOST");
+      this.inferenceAPIHost = core.ops.op_get_env('AI_INFERENCE_API_HOST');
       this.is_ext_inference_api = !!this.inferenceAPIHost; // only enable external inference API if env variable is set
     }
   }
@@ -26,16 +27,30 @@ class Session {
     if (this.is_ext_inference_api) {
       const stream = opts.stream ?? false;
 
+      /** @type {'ollama' | 'openaicompatible'} */
+      const mode = opts.mode ?? 'ollama';
+
+      if (mode === 'ollama') {
+        // Using the new LLMSession API
+        const llmSession = LLMSession.fromProvider('ollama', {
+          inferenceAPIHost: this.inferenceAPIHost,
+          model: this.model,
+        });
+
+        return await llmSession.run({
+          prompt,
+          stream,
+          signal: opts.signal,
+          timeout: opts.timeout,
+        });
+      }
+
       // default timeout 60s
-      const timeout = typeof opts.timeout === "number" ? opts.timeout : 60;
+      const timeout = typeof opts.timeout === 'number' ? opts.timeout : 60;
       const timeoutMs = timeout * 1000;
 
-      /** @type {'ollama' | 'openaicompatible'} */
-      const mode = opts.mode ?? "ollama";
-
       switch (mode) {
-        case "ollama":
-        case "openaicompatible":
+        case 'openaicompatible':
           break;
 
         default:
@@ -48,15 +63,15 @@ class Session {
 
       const signal = AbortSignal.any(signals);
 
-      const path = mode === "ollama" ? "/api/generate" : "/v1/chat/completions";
-      const body = mode === "ollama" ? { prompt } : prompt;
+      const path = '/v1/chat/completions';
+      const body = prompt;
 
       const res = await fetch(
         new URL(path, this.inferenceAPIHost),
         {
-          method: "POST",
+          method: 'POST',
           headers: {
-            "Content-Type": "application/json",
+            'Content-Type': 'application/json',
           },
           body: JSON.stringify({
             model: this.model,
@@ -74,20 +89,16 @@ class Session {
       }
 
       if (!res.body) {
-        throw new Error("Missing body");
+        throw new Error('Missing body');
       }
 
-      const parseGenFn = mode === "ollama"
-        ? parseJSON
-        : stream === true
-        ? parseJSONOverEventStream
-        : parseJSON;
+      const parseGenFn = stream === true ? parseJSONOverEventStream : parseJSON;
       const itr = parseGenFn(res.body, signal);
 
       if (stream) {
         return (async function* () {
           for await (const message of itr) {
-            if ("error" in message) {
+            if ('error' in message) {
               if (message.error instanceof Error) {
                 throw message.error;
               } else {
@@ -98,20 +109,12 @@ class Session {
             yield message;
 
             switch (mode) {
-              case "ollama": {
-                if (message.done) {
-                  return;
-                }
-
-                break;
-              }
-
-              case "openaicompatible": {
+              case 'openaicompatible': {
                 const finishReason = message.choices[0].finish_reason;
 
                 if (finishReason) {
-                  if (finishReason !== "stop") {
-                    throw new Error("Expected a completed response.");
+                  if (finishReason !== 'stop') {
+                    throw new Error('Expected a completed response.');
                   }
 
                   return;
@@ -121,18 +124,18 @@ class Session {
               }
 
               default:
-                throw new Error("unreachable");
+                throw new Error('unreachable');
             }
           }
 
           throw new Error(
-            "Did not receive done or success response in stream.",
+            'Did not receive done or success response in stream.',
           );
         })();
       } else {
         const message = await itr.next();
 
-        if (message.value && "error" in message.value) {
+        if (message.value && 'error' in message.value) {
           const error = message.value.error;
 
           if (error instanceof Error) {
@@ -142,12 +145,10 @@ class Session {
           }
         }
 
-        const finish = mode === "ollama"
-          ? message.value.done
-          : message.value.choices[0].finish_reason === "stop";
+        const finish = message.value.choices[0].finish_reason === 'stop';
 
         if (finish !== true) {
-          throw new Error("Expected a completed response.");
+          throw new Error('Expected a completed response.');
         }
 
         return message.value;
@@ -172,8 +173,7 @@ class Session {
 }
 
 const MAIN_WORKER_API = {
-  tryCleanupUnusedSession: () =>
-    /* async */ core.ops.op_ai_try_cleanup_unused_session(),
+  tryCleanupUnusedSession: () => /* async */ core.ops.op_ai_try_cleanup_unused_session(),
 };
 
 const USER_WORKER_API = {
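
A quick sketch of how the new path is reached from the existing user-facing API, assuming AI_INFERENCE_API_HOST points at a running Ollama server; the model name, prompt, and the way Session is exposed to user code are illustrative assumptions, not part of this diff:

// Hypothetical caller. With AI_INFERENCE_API_HOST set, run() keeps defaulting to
// mode: 'ollama', but now delegates to LLMSession.fromProvider('ollama', ...)
// instead of calling /api/generate inline.
const session = new Session('llama3'); // model name assumed for illustration

// Non-streaming: resolves once Ollama reports done: true.
const output = await session.run('Why is the sky blue?', { stream: false });

// Streaming: assumed to yield incremental OllamaMessage chunks until done: true.
const chunks = await session.run('Why is the sky blue?', { stream: true, timeout: 120 });
for await (const chunk of chunks) {
  console.log(chunk.response);
}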

ext/ai/js/llm/llm_session.ts

Lines changed: 9 additions & 2 deletions
@@ -1,3 +1,5 @@
+import { OllamaLLMSession } from './providers/ollama.ts';
+
 // @ts-ignore deno_core environment
 const core = globalThis.Deno.core;
 
@@ -28,14 +30,20 @@ export interface ILLMProvider {
   getText(prompt: string, signal: AbortSignal): Promise<any>;
 }
 
+export const providers = {
+  'ollama': OllamaLLMSession,
+} satisfies Record<string, new (opts: ILLMProviderOptions) => ILLMProvider>;
+
+export type LLMProviderName = keyof typeof providers;
+
 export class LLMSession {
   #inner: ILLMProvider;
 
   constructor(provider: ILLMProvider) {
     this.#inner = provider;
   }
 
-  static fromProvider(name: string, opts: ILLMProviderOptions) {
+  static fromProvider(name: LLMProviderName, opts: ILLMProviderOptions) {
     const ProviderType = providers[name];
     if (!ProviderType) throw new Error('invalid provider');
 
@@ -64,4 +72,3 @@ export class LLMSession {
     return this.#inner.getText(opts.prompt, signal);
   }
 }
-
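
Because providers is constrained with satisfies Record<string, new (opts: ILLMProviderOptions) => ILLMProvider>, registering another backend only requires an ILLMProvider implementation plus a map entry, and LLMProviderName widens automatically. A hypothetical sketch (the OpenAICompatibleLLMSession class and the 'openaicompatible' entry below are not part of this commit):

// Hypothetical future provider; all names below are illustrative only.
class OpenAICompatibleLLMSession implements ILLMProvider {
  constructor(private opts: ILLMProviderOptions) {}

  async getStream(prompt: string, signal: AbortSignal) {
    // ...stream /v1/chat/completions chunks here...
    throw new Error('not implemented');
  }

  async getText(prompt: string, signal: AbortSignal) {
    // ...fetch a single /v1/chat/completions response here...
    throw new Error('not implemented');
  }
}

// The registry would then read:
export const providers = {
  'ollama': OllamaLLMSession,
  'openaicompatible': OpenAICompatibleLLMSession,
} satisfies Record<string, new (opts: ILLMProviderOptions) => ILLMProvider>;

// LLMProviderName becomes 'ollama' | 'openaicompatible', so
// LLMSession.fromProvider('openaicompatible', opts) type-checks.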

ext/ai/js/llm/providers/ollama.ts

Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
+import { ILLMProvider, ILLMProviderOptions } from '../llm_session.ts';
+import { parseJSON } from '../utils/json_parser.ts';
+
+export type OllamaProviderOptions = ILLMProviderOptions;
+
+export type OllamaMessage = {
+  model: string;
+  created_at: Date;
+  response: string;
+  done: boolean;
+  context: number[];
+  total_duration: number;
+  load_duration: number;
+  prompt_eval_count: number;
+  prompt_eval_duration: number;
+  eval_count: number;
+  eval_duration: number;
+};
+
+export class OllamaLLMSession implements ILLMProvider {
+  opts: OllamaProviderOptions;
+
+  constructor(opts: OllamaProviderOptions) {
+    this.opts = opts;
+  }
+
+  // ref: https://github.com/ollama/ollama-js/blob/6a4bfe3ab033f611639dfe4249bdd6b9b19c7256/src/utils.ts#L26
+  async getStream(
+    prompt: string,
+    signal: AbortSignal,
+  ): Promise<AsyncIterable<OllamaMessage>> {
+    const generator = await this.generate(prompt, signal, true);
+
+    const stream = async function* () {
+      for await (const message of generator) {
+        if ('error' in message) {
+          if (message.error instanceof Error) {
+            throw message.error;
+          } else {
+            throw new Error(message.error as string);
+          }
+        }
+
+        yield message;
+        if (message.done) {
+          return;
+        }
+      }
+
+      throw new Error(
+        'Did not receive done or success response in stream.',
+      );
+    };
+
+    return stream();
+  }
+
+  async getText(prompt: string, signal: AbortSignal): Promise<OllamaMessage> {
+    const generator = await this.generate(prompt, signal);
+
+    const message = await generator.next();
+
+    if (message.value && 'error' in message.value) {
+      const error = message.value.error;
+
+      if (error instanceof Error) {
+        throw error;
+      } else {
+        throw new Error(error);
+      }
+    }
+
+    const response = message.value;
+
+    if (!response?.done) {
+      throw new Error('Expected a completed response.');
+    }
+
+    return response;
+  }
+
+  private async generate(
+    prompt: string,
+    signal: AbortSignal,
+    stream: boolean = false,
+  ) {
+    const res = await fetch(
+      new URL('/api/generate', this.opts.inferenceAPIHost),
+      {
+        method: 'POST',
+        headers: {
+          'Content-Type': 'application/json',
+        },
+        body: JSON.stringify({
+          model: this.opts.model,
+          stream,
+          prompt,
+        }),
+        signal,
+      },
+    );
+
+    if (!res.ok) {
+      throw new Error(
+        `Failed to fetch inference API host. Status ${res.status}: ${res.statusText}`,
+      );
+    }
+
+    if (!res.body) {
+      throw new Error('Missing body');
+    }
+
+    return parseJSON<OllamaMessage>(res.body, signal);
+  }
+}
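
The provider can also be read as a standalone client for Ollama's /api/generate endpoint, which returns newline-delimited JSON that parseJSON turns into OllamaMessage values. A minimal usage sketch, assuming a local Ollama host and model (both illustrative; in ai.js the instance is created via LLMSession.fromProvider('ollama', ...)):

// Hypothetical direct use; host and model are assumptions for illustration.
const ollama = new OllamaLLMSession({
  inferenceAPIHost: 'http://localhost:11434',
  model: 'llama3',
});

const controller = new AbortController();

// getText: one non-streamed /api/generate call; throws unless done === true.
const full = await ollama.getText('Why is the sky blue?', controller.signal);
console.log(full.response);

// getStream: streamed call; yields one OllamaMessage per chunk and returns
// once a chunk with done === true arrives.
const stream = await ollama.getStream('Why is the sky blue?', controller.signal);
for await (const chunk of stream) {
  console.log(chunk.response);
}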

ext/ai/lib.rs

Lines changed: 1 addition & 0 deletions
@@ -55,6 +55,7 @@ deno_core::extension!(
     "onnxruntime/onnx.js",
     "onnxruntime/cache_adapter.js",
     "llm/llm_session.ts",
+    "llm/providers/ollama.ts",
     "llm/utils/json_parser.ts",
     "llm/utils/event_stream_parser.mjs",
     "llm/utils/event_source_stream.mjs",
