Commit 6fc77ea

Merge pull request #7596 from sagemathinc/llm-1shot: improve AI generator prompts

2 parents: 3ae60b4 + b5a59f2

8 files changed: +342 −168 lines changed

src/packages/frontend/client/llm.ts
Lines changed: 14 additions & 12 deletions

@@ -25,6 +25,18 @@ import * as message from "@cocalc/util/message";
 import type { WebappClient } from "./client";
 import type { History } from "./types";
 
+interface QueryLLMProps {
+  input: string;
+  model: LanguageModel;
+  system?: string;
+  history?: History;
+  project_id?: string;
+  path?: string;
+  chatStream?: ChatStream; // if given, uses chat stream
+  tag?: string;
+  startStreamExplicitly?: boolean;
+}
+
 interface EmbeddingsQuery {
   scope: string | string[];
   limit: number; // client automatically deals with large limit by making multiple requests (i.e., there is no limit on the limit)

@@ -41,7 +53,7 @@ export class LLMClient {
     this.client = client;
   }
 
-  public async query(opts): Promise<string> {
+  public async query(opts: QueryLLMProps): Promise<string> {
     return await this.queryLanguageModel(opts);
   }
 
@@ -70,17 +82,7 @@
     path,
     chatStream,
     tag = "",
-  }: {
-    input: string;
-    model: LanguageModel;
-    system?: string;
-    history?: History;
-    project_id?: string;
-    path?: string;
-    chatStream?: ChatStream; // if given, uses chat stream
-    tag?: string;
-    startStreamExplicitly?: boolean;
-  }): Promise<string> {
+  }: QueryLLMProps): Promise<string> {
     system ??= getSystemPrompt(model, path);
 
     // remove all date entries from all history objects

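The extracted QueryLLMProps interface is now the single typed contract shared by the public query() wrapper and the private implementation, replacing the inline parameter type. Below is a minimal sketch of a call site under the new shape; the model id and the helper function are assumptions for illustration, not part of this commit.

// Sketch only: calling the typed query() entry point. The model id
// ("gpt-4o-mini") and this helper are illustrative assumptions; any
// LanguageModel value accepted by the client works the same way.
import { webapp_client } from "@cocalc/frontend/webapp-client";

export async function askOnce(question: string): Promise<string> {
  return await webapp_client.openai_client.query({
    input: question,             // required: the prompt text
    model: "gpt-4o-mini",        // required: a LanguageModel id (assumed)
    system: "Answer concisely.", // optional system prompt
    history: [],                 // optional prior user/assistant turns
    tag: "example",              // optional tag for usage tracking
  });
}

With the interface in place, TypeScript rejects call sites that pass misspelled or missing fields, which the previous untyped opts parameter silently allowed.
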
src/packages/frontend/codemirror/extensions/ai-formula.tsx
Lines changed: 48 additions & 13 deletions

@@ -1,4 +1,5 @@
 import { Button, Descriptions, Divider, Input, Modal, Space } from "antd";
+import { debounce } from "lodash";
 
 import { useLanguageModelSetting } from "@cocalc/frontend/account/useLanguageModelSetting";
 import {

@@ -8,6 +9,7 @@ import {
   useState,
   useTypedRedux,
 } from "@cocalc/frontend/app-framework";
+import type { Message } from "@cocalc/frontend/client/types";
 import {
   HelpIcon,
   Icon,

@@ -20,12 +22,11 @@ import AIAvatar from "@cocalc/frontend/components/ai-avatar";
 import { LLMModelName } from "@cocalc/frontend/components/llm-name";
 import LLMSelector from "@cocalc/frontend/frame-editors/llm/llm-selector";
 import { show_react_modal } from "@cocalc/frontend/misc";
+import { LLMCostEstimation } from "@cocalc/frontend/misc/llm-cost-estimation";
 import track from "@cocalc/frontend/user-tracking";
 import { webapp_client } from "@cocalc/frontend/webapp-client";
 import { isFreeModel } from "@cocalc/util/db-schema/llm-utils";
 import { unreachable } from "@cocalc/util/misc";
-import { LLMCostEstimation } from "../../misc/llm-cost-estimation";
-import { debounce } from "lodash";
 
 type Mode = "tex" | "md";
 

@@ -64,13 +65,18 @@ function AiGenFormula({ mode, text = "", project_id, cb }: Props) {
   useAsyncEffect(
     debounce(
       async () => {
-        const prompt = getPrompt() ?? "";
+        const { input, history, system } = getPrompt() ?? "";
         // compute the number of tokens (this MUST be a lazy import):
         const { getMaxTokens, numTokensUpperBound } = await import(
          "@cocalc/frontend/misc/llm"
        );
 
-        setTokens(numTokensUpperBound(prompt, getMaxTokens(model)));
+        const all = [
+          input,
+          history.map(({ content }) => content).join(" "),
+          system,
+        ].join(" ");
+        setTokens(numTokensUpperBound(all, getMaxTokens(model)));
       },
       1000,
       { leading: true, trailing: true },

@@ -83,20 +89,47 @@ function AiGenFormula({ mode, text = "", project_id, cb }: Props) {
     .getStore("projects")
     .hasLanguageModelEnabled(project_id, LLM_USAGE_TAG);
 
-  function getPrompt() {
-    const description = input || text;
-    const p1 = `Convert the following plain-text description of a formula to a LaTeX formula`;
-    const p2 = `Return the LaTeX formula, and only the formula. Enclose the formula in a single snippet delimited by $. Do not add any explanations.`;
+  function getSystemPrompt(): string {
+    const p1 = `Typset the plain-text description of a mathematical formula as a LaTeX formula. The formula will be`;
+    const p2 = `Return only the LaTeX formula, ready to be inserted into the document. Do not add any explanations.`;
     switch (mode) {
       case "tex":
-        return `${p1} in a *.tex file. Assume the package "amsmath" is available. ${p2}:\n\n${description}`;
+        return `${p1} in a *.tex file. Assume the package "amsmath" is available. ${p2}`;
       case "md":
-        return `${p1} in a markdown file. ${p2}\n\n${description}`;
+        return `${p1} in a markdown file. Formulas are inside of $ or $$. ${p2}`;
       default:
         unreachable(mode);
+        return p1;
     }
   }
 
+  function getPrompt(): { input: string; history: Message[]; system: string } {
+    const system = getSystemPrompt();
+    // 3-shot examples
+    const history: Message[] = [
+      { role: "user", content: "equation e^(i pi) = -1" },
+      { role: "assistant", content: "$$e^{i \\pi} = -1$$" },
+      {
+        role: "user",
+        content: "integral 0 to 2 pi sin(x)^2",
+      },
+      {
+        role: "assistant",
+        content: "$\\int_{0}^{2\\pi} \\sin(x)^2 \\, \\mathrm{d}x$",
+      },
+      {
+        role: "user",
+        content: "equation system: [ 1 + x^2 = a, 1 - y^2 = ln(a) ]",
+      },
+      {
+        role: "assistant",
+        content:
+          "\\begin{cases}\n1 + x^2 = a \\\\\n1 - y^2 = \\ln(a)\n\\end{cases}",
+      },
+    ];
+    return { input: input || text, system, history };
+  }
+
   function wrapFormula(tex: string = "") {
     // wrap single-line formulas in $...$
     // if it is multiline, wrap in \begin{equation}...\end{equation}

@@ -170,12 +203,14 @@ function AiGenFormula({ mode, text = "", project_id, cb }: Props) {
       type: "generate",
       model,
     });
+    const { system, input, history } = getPrompt();
     const reply = await webapp_client.openai_client.query({
-      input: getPrompt(),
+      input,
+      history,
+      system,
+      model,
       project_id,
       tag: LLM_USAGE_TAG,
-      model,
-      system: "",
    });
    const tex = processFormula(reply);
    // significant differece? Also show the full reply

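The rewritten getPrompt() separates the fixed instruction (the system prompt) from three worked user/assistant example pairs passed as chat history, so the model imitates the demonstrated output format rather than just being told about it. Here is a sketch of the general pattern; the types are illustrative stand-ins that mirror, but are not, the Message type imported from @cocalc/frontend/client/types.

// Sketch of the few-shot pattern behind the new getPrompt(): the fixed
// instruction goes into `system`, worked examples are passed as prior
// user/assistant turns in `history`, and the real request is the final
// `input`. All types here are illustrative, not CoCalc's own.
type Role = "user" | "assistant";

interface FewShotMessage {
  role: Role;
  content: string;
}

function buildFewShotPrompt(
  system: string,
  examples: Array<[ask: string, answer: string]>,
  input: string,
): { system: string; history: FewShotMessage[]; input: string } {
  const history: FewShotMessage[] = examples.flatMap(([ask, answer]) => [
    { role: "user" as const, content: ask },
    { role: "assistant" as const, content: answer },
  ]);
  return { system, history, input };
}

// Usage mirroring one of the commit's examples:
const prompt = buildFewShotPrompt(
  "Return only the LaTeX formula. Do not add any explanations.",
  [["equation e^(i pi) = -1", "$$e^{i \\pi} = -1$$"]],
  "quadratic formula",
);
console.log(prompt.history.length); // 2: one user/assistant pair

Keeping the examples in history rather than concatenating them into one instruction string is also what lets the token estimator above count input, history, and system prompt together before sending anything.
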
src/packages/frontend/frame-editors/llm/llm-query-dropdown.tsx
Lines changed: 5 additions & 2 deletions

@@ -11,7 +11,7 @@ import { LLM_PROVIDER } from "@cocalc/util/db-schema/llm-utils";
 import { LLMTools } from "@cocalc/jupyter/types";
 
 interface Props {
-  llmTools?: LLMTools;
+  llmTools?: Pick<LLMTools, "model" | "setModel">;
   task?: string;
   onClick: () => void;
   loading?: boolean;

@@ -87,7 +87,10 @@ export function LLMQueryDropdownButton({
       trigger={["click"]}
       icon={<Icon name="caret-down" />}
       onClick={onClick}
-      menu={{ items: getItems() }}
+      menu={{
+        items: getItems(),
+        style: { maxHeight: "50vh", overflow: "auto" },
+      }}
       loading={loading}
       disabled={disabled}
     >

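Narrowing llmTools from the full LLMTools to Pick<LLMTools, "model" | "setModel"> documents that the dropdown only reads and writes the selected model, and lets callers satisfy the prop with a two-field object. A sketch with a stand-in type follows; LLMToolsLike and its extra field are hypothetical, not the real definition from @cocalc/jupyter/types.

// Sketch: Pick<> narrows the accepted prop to exactly what the dropdown
// uses. LLMToolsLike is a hypothetical stand-in for LLMTools; the extra
// field is invented for illustration.
interface LLMToolsLike {
  model: string;
  setModel: (model: string) => void;
  otherTools?: unknown; // hypothetical field the dropdown never touches
}

type DropdownTools = Pick<LLMToolsLike, "model" | "setModel">;

// A caller can now satisfy the prop with just the two fields:
const tools: DropdownTools = {
  model: "default-model", // assumed model id
  setModel: (m) => console.log("switched to", m),
};

The menu style change in the second hunk is independent: capping the dropdown at 50vh with overflow: "auto" keeps long model lists scrollable instead of letting them overflow the viewport.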