Commit 8f64021

break: improving typescript support and refactoring the API
- Improves TypeScript support with dynamic suggestions based on the selected Session type.
- Breaking: LLM models must now be defined inside the `options` argument. This allows stricter TypeScript checking and makes the API easier to extend.
- There is no longer any need to check whether the `inferenceHost` env var is defined, since we can now switch between different LLM providers. Instead, LLM support is enabled when the given type is an allowed provider.
1 parent 3d98ea4 commit 8f64021
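
For illustration, a minimal sketch of the refactored API. The provider name `ollama`, the model `mistral`, and the local endpoint are examples, not part of this commit:

import { Session } from "./ai.ts";

// Hypothetical usage sketch: the provider is selected by the session type,
// and the model is now declared inside `options`.
const session = new Session("ollama", {
  model: "mistral",
  // baseURL is optional; it falls back to the AI_INFERENCE_API_HOST env var.
  baseURL: "http://localhost:11434",
});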

9 files changed: +363 additions, -217 deletions

ext/ai/js/ai.d.ts

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
import { Session } from "./ai.ts";
import { LLMSessionRunInputOptions } from "./llm/llm_session.ts";
import {
  OllamaProviderInput,
  OllamaProviderOptions,
} from "./llm/providers/ollama.ts";
import {
  OpenAIProviderInput,
  OpenAIProviderOptions,
} from "./llm/providers/openai.ts";

export namespace ai {
  export { Session };
  export {
    LLMSessionRunInputOptions as LLMRunOptions,
    OllamaProviderInput as OllamaInput,
    OllamaProviderOptions as OllamaOptions,
    OpenAIProviderInput as OpenAICompatibleInput,
    OpenAIProviderOptions as OpenAICompatibleOptions,
  };
}
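
These aliases expose the provider-specific types under the `ai` namespace. A hedged consumer-side sketch follows; any option fields beyond `model` and `baseURL` are assumptions, since the provider files are not shown in this diff:

// Hypothetical sketch; the option shapes are assumed from ILLMProviderOptions.
const ollamaOpts: ai.OllamaOptions = { model: "mistral" };
const openAIOpts: ai.OpenAICompatibleOptions = {
  model: "gpt-4o-mini", // example model name
  baseURL: "https://api.openai.com/v1",
};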

ext/ai/js/ai.js

Lines changed: 0 additions & 80 deletions
This file was deleted.

ext/ai/js/ai.ts

Lines changed: 141 additions & 0 deletions
@@ -0,0 +1,141 @@
import "./onnxruntime/onnx.js";
import {
  LLMProviderInstance,
  LLMProviderName,
  LLMSession,
  LLMSessionRunInputOptions as LLMInputOptions,
  providers,
} from "./llm/llm_session.ts";

// @ts-ignore deno_core environment
const core = globalThis.Deno.core;

// NOTE:(kallebysantos) do we still need gte-small? Or maybe add another type 'embeddings' with custom model opt.
export type SessionType = LLMProviderName | "gte-small";

export type SessionOptions<T extends SessionType> = T extends LLMProviderName
  ? LLMProviderInstance<T>["options"]
  : never;

export type SessionInput<T extends SessionType> = T extends LLMProviderName
  ? LLMProviderInstance<T>["input"]
  : T extends "gte-small" ? string
  : never;

export type EmbeddingInputOptions = {
  /**
   * Pool embeddings by taking their mean
   */
  mean_pool?: boolean;

  /**
   * Normalize the embeddings result
   */
  normalize?: boolean;
};

export type SessionInputOptions<T extends SessionType> = T extends
  LLMProviderName ? LLMInputOptions
  : EmbeddingInputOptions;

export class Session<T extends SessionType> {
  #model?: string;
  #init?: Promise<void>;

  // TODO:(kallebysantos) get 'provider' type here and use type checking to suggest Inputs when run
  constructor(
    public readonly type: T,
    public readonly options?: SessionOptions<T>,
  ) {
    if (this.isEmbeddingType()) {
      this.#model = "gte-small"; // Default model
      this.#init = core.ops.op_ai_init_model(this.#model);
      return;
    }

    if (this.isLLMType()) {
      if (!Object.keys(providers).includes(type)) {
        throw new TypeError(`invalid type: '${type}'`);
      }

      if (!this.options || !this.options.model) {
        throw new Error(
          `missing required parameter 'model' for type: '${type}'`,
        );
      }

      this.options.baseURL ??= core.ops.op_get_env(
        "AI_INFERENCE_API_HOST",
      ) as string;

      if (!this.options.baseURL) {
        throw new Error(
          `missing required parameter 'baseURL' for type: '${type}'`,
        );
      }
    }
  }

  // /** @param {string | object} prompt Either a String (ollama) or an OpenAI chat completion body object (openaicompatible): https://platform.openai.com/docs/api-reference/chat/create */
  async run(input: SessionInput<T>, options: SessionInputOptions<T>) {
    if (this.isLLMType()) {
      const opts = options as LLMInputOptions;
      const stream = opts.stream ?? false;

      const llmSession = LLMSession.fromProvider(this.type, {
        // safety: We did check `options` during construction
        baseURL: this.options!.baseURL,
        model: this.options!.model,
        ...this.options, // allows custom provider initialization like 'apiKey'
      });

      return await llmSession.run(input, {
        stream,
        signal: opts.signal,
        timeout: opts.timeout,
      });
    }

    if (this.#init) {
      await this.#init;
    }

    const opts = options as EmbeddingInputOptions;

    const mean_pool = opts.mean_pool ?? true;
    const normalize = opts.normalize ?? true;

    const result = await core.ops.op_ai_run_model(
      // @ts-ignore
      this.#model,
      input,
      mean_pool,
      normalize,
    );

    return result;
  }

  private isEmbeddingType(
    this: Session<SessionType>,
  ): this is Session<"gte-small"> {
    return this.type === "gte-small";
  }

  private isLLMType(
    this: Session<SessionType>,
  ): this is Session<LLMProviderName> {
    return this.type !== "gte-small";
  }
}

const MAIN_WORKER_API = {
  tryCleanupUnusedSession: () =>
    /* async */ core.ops.op_ai_try_cleanup_unused_session(),
};

const USER_WORKER_API = {
  Session,
};

export { MAIN_WORKER_API, USER_WORKER_API };
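
A hedged usage sketch of the two session kinds above. The prompt text and model names are examples, and the string input form for `ollama` follows the old doc comment, so it is an assumption:

// Hypothetical usage sketch, not part of this commit.
// Embedding session: defaults to the built-in "gte-small" model.
const embedder = new Session("gte-small");
const embedding = await embedder.run("Hello world", { mean_pool: true, normalize: true });

// LLM session: provider chosen by the type, model declared in `options`.
const llm = new Session("ollama", { model: "mistral" });
const reply = await llm.run("Summarize the release notes", { stream: false, timeout: 60 });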

ext/ai/js/llm/llm_session.ts

Lines changed: 47 additions & 18 deletions
@@ -1,5 +1,5 @@
-import { OllamaLLMSession } from './providers/ollama.ts';
-import { OpenAILLMSession } from './providers/openai.ts';
+import { OllamaLLMSession } from "./providers/ollama.ts";
+import { OpenAILLMSession } from "./providers/openai.ts";

 // @ts-ignore deno_core environment
 const core = globalThis.Deno.core;
@@ -20,30 +20,59 @@ export type LLMRunInput = {
   signal?: AbortSignal;
 };

+export interface ILLMProviderMeta {
+  input: ILLMProviderInput;
+  output: unknown;
+  options: ILLMProviderOptions;
+}
+
 export interface ILLMProviderOptions {
   model: string;
-  inferenceAPIHost: string;
+  baseURL?: string;
 }

-export interface ILLMProviderInput {
-  prompt: string | object;
-  signal: AbortSignal;
-}
+export type ILLMProviderInput<T = string | object> = T extends string ? string
+  : T;

 export interface ILLMProvider {
   // TODO:(kallebysantos) remove 'any'
   // TODO: (kallebysantos) standardised output format
-  getStream(input: ILLMProviderInput): Promise<AsyncIterable<any>>;
-  getText(input: ILLMProviderInput): Promise<any>;
+  getStream(
+    input: ILLMProviderInput,
+    signal: AbortSignal,
+  ): Promise<AsyncIterable<any>>;
+  getText(input: ILLMProviderInput, signal: AbortSignal): Promise<any>;
 }

 export const providers = {
-  'ollama': OllamaLLMSession,
-  'openaicompatible': OpenAILLMSession,
-} satisfies Record<string, new (opts: ILLMProviderOptions) => ILLMProvider>;
+  "ollama": OllamaLLMSession,
+  "openaicompatible": OpenAILLMSession,
+} satisfies Record<
+  string,
+  new (opts: ILLMProviderOptions) => ILLMProvider & ILLMProviderMeta
+>;

 export type LLMProviderName = keyof typeof providers;

+export type LLMProviderClass<T extends LLMProviderName> = (typeof providers)[T];
+export type LLMProviderInstance<T extends LLMProviderName> = InstanceType<
+  LLMProviderClass<T>
+>;
+
+export type LLMSessionRunInputOptions = {
+  /**
+   * Stream response from model. Applies only for LLMs like `mistral` (default: false)
+   */
+  stream?: boolean;
+
+  /**
+   * Automatically abort the request to the model after specified time (in seconds). Applies only for LLMs like `mistral` (default: 60)
+   */
+  timeout?: number;
+
+  signal?: AbortSignal;
+};
+
 export class LLMSession {
   #inner: ILLMProvider;
@@ -53,31 +82,31 @@ export class LLMSession {

   static fromProvider(name: LLMProviderName, opts: ILLMProviderOptions) {
     const ProviderType = providers[name];
-    if (!ProviderType) throw new Error('invalid provider');
+    if (!ProviderType) throw new Error("invalid provider");

     const provider = new ProviderType(opts);

     return new LLMSession(provider);
   }

   run(
-    opts: LLMRunInput,
+    input: ILLMProviderInput,
+    opts: LLMSessionRunInputOptions,
   ): Promise<AsyncIterable<any>> | Promise<any> {
     const isStream = opts.stream ?? false;

-    const timeoutSeconds = typeof opts.timeout === 'number' ? opts.timeout : 60;
+    const timeoutSeconds = typeof opts.timeout === "number" ? opts.timeout : 60;
     const timeoutMs = timeoutSeconds * 1000;

     const timeoutSignal = AbortSignal.timeout(timeoutMs);
     const abortSignals = [opts.signal, timeoutSignal]
       .filter((it) => it instanceof AbortSignal);
     const signal = AbortSignal.any(abortSignals);

-    const llmInput: ILLMProviderInput = { prompt: opts.prompt, signal };
     if (isStream) {
-      return this.#inner.getStream(llmInput);
+      return this.#inner.getStream(input, signal);
     }

-    return this.#inner.getText(llmInput);
+    return this.#inner.getText(input, signal);
   }
 }
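
To show the new `run` signature, where the input is passed separately from the run options and the timeout/abort signals are merged internally, here is a hedged sketch; the model and endpoint are assumptions:

// Hypothetical sketch of the refactored LLMSession API; values are examples only.
const session = LLMSession.fromProvider("ollama", {
  model: "mistral",
  baseURL: "http://localhost:11434", // assumed local Ollama endpoint
});

// input and options are now separate arguments; a 30s timeout is merged with any caller-provided signal.
const output = await session.run("Tell me a short joke", { stream: false, timeout: 30 });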
