Skip to content

Commit 801a5ad

Browse files
committed
重构: ai chat 系统提示词支持传入更多的数据库信息,减少交互次数
1 parent f5e28e0 commit 801a5ad

File tree

12 files changed

+122
-32
lines changed

12 files changed

+122
-32
lines changed

client/lib/models/ai.dart

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,15 @@ abstract class AIChatUserMessageModel with _$AIChatUserMessageModel {
124124
const factory AIChatUserMessageModel({
125125
required AIChatMessageId id,
126126
required String content,
127+
String? ref,
127128
}) = _AIChatUserMessageModel;
128129

129-
String toMessage() => content;
130+
String toMessage() {
131+
final refText = ref?.trim() ?? '';
132+
if (refText.isEmpty) return content;
133+
// ref 作为“额外上下文”,拼到 user message 里供 LLM 使用,但 UI 仍可只展示 content。
134+
return '$content\n\nref:\n$refText';
135+
}
130136
}
131137

132138
@freezed

client/lib/models/instances.dart

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,5 +87,6 @@ abstract class PaginationInstanceListModel with _$PaginationInstanceListModel {
8787
abstract class InstanceMetadataModel with _$InstanceMetadataModel {
8888
const factory InstanceMetadataModel({
8989
required List<MetaDataNode> metadata,
90+
required String? version,
9091
}) = _InstanceMetadataModel;
9192
}

client/lib/models/sessions.dart

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ abstract class SessionAIChatModel with _$SessionAIChatModel {
288288
required SessionId sessionId,
289289
required String? currentSchema,
290290
required DatabaseType? dbType,
291-
required List<MetaDataNode>? metadata,
291+
required InstanceMetadataModel? metadata,
292292
required ConnId? connId,
293293
required SQLConnectState? state,
294294
required AIChatModel chatModel,

client/lib/repositories/instances/instances.dart

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,8 @@ class InstanceRepoImpl extends InstanceRepo {
257257
conn = SessionConn(model: instance);
258258
await conn.connect();
259259
final metadataNode = await conn.metadata();
260-
return InstanceMetadataModel(metadata: metadataNode);
260+
final version = await conn.version();
261+
return InstanceMetadataModel(metadata: metadataNode, version: version);
261262
} catch (e) {
262263
rethrow;
263264
} finally {

client/lib/repositories/instances/session_conn.dart

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,10 @@ class SessionConn {
233233
Future<List<MetaDataNode>> metadata() async {
234234
return await conn2!.metadata();
235235
}
236+
237+
Future<String?> version() async {
238+
return await conn2!.version();
239+
}
236240
}
237241

238242
@Riverpod(keepAlive: true)

client/lib/screens/sessions/ai_chat/input_user.dart

Lines changed: 66 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,18 +30,79 @@ class SessionChatInputCard extends ConsumerStatefulWidget {
3030
}
3131

3232
class _SessionChatInputCardState extends ConsumerState<SessionChatInputCard> {
33+
Set<String> _extractMentionedTables(String encoded) {
34+
final segments = MentionSegmentSerializer.decode(encoded);
35+
final tableNames = <String>{};
36+
37+
// 只取“真正的 mention token”
38+
for (final s in segments) {
39+
if (s is MentionSegment) {
40+
tableNames.add(s.label);
41+
}
42+
}
43+
44+
return tableNames;
45+
}
46+
47+
MetaDataNode? _findSchemaNode(SessionAIChatModel chatModel) {
48+
if (chatModel.metadata == null || chatModel.currentSchema == null) return null;
49+
final root = MetaDataNode(MetaType.instance, "", items: chatModel.metadata!.metadata);
50+
MetaDataNode? schemaNode;
51+
root.visitor((node, _) {
52+
if (schemaNode != null) return false;
53+
if (node.type == MetaType.schema && node.value == chatModel.currentSchema) {
54+
schemaNode = node;
55+
return false;
56+
}
57+
return true;
58+
});
59+
return schemaNode;
60+
}
61+
62+
String _buildTableRef(SessionAIChatModel chatModel, Iterable<String> mentionedTables) {
63+
final schemaNode = _findSchemaNode(chatModel);
64+
if (schemaNode == null) return '';
65+
66+
final tables = mentionedTables.toSet().toList()..sort();
67+
if (tables.isEmpty) return '';
68+
69+
final b = StringBuffer();
70+
for (final tableName in tables) {
71+
MetaDataNode? tableNode;
72+
for (final n in (schemaNode.items ?? const <MetaDataNode>[])) {
73+
if (n.type == MetaType.table && n.value == tableName) {
74+
tableNode = n;
75+
break;
76+
}
77+
}
78+
79+
// 直接复用 MetaDataNode.toString() 的 JSON 序列化(见 db_driver_metadata.dart)
80+
if (tableNode != null) {
81+
b.writeln(tableNode.toString());
82+
}
83+
}
84+
85+
return b.toString().trimRight();
86+
}
87+
3388
Future<void> _sendMessage(AIChatId chatId, SessionAIChatModel chatModel) async {
3489
final chatInputController = SessionController.sessionController(chatModel.sessionId).chatInputController;
3590
final encoded = chatInputController.text.trim();
36-
final text = MentionSegmentSerializer.decode(encoded).map((s) => s.toDisplayText()).join().trim();
91+
final segments = MentionSegmentSerializer.decode(encoded);
92+
final text = segments.map((s) => s.toDisplayText()).join().trim();
3793
chatInputController.clear();
3894

95+
// 如果用户通过 @ 提及了表,则把表结构信息放到 ref 里
96+
final mentionedTables = _extractMentionedTables(encoded);
97+
final refText = _buildTableRef(chatModel, mentionedTables);
98+
3999
// 调用AIChatService的chat方法
40100
await ref.read(aIChatServiceProvider.notifier).chat(
41101
chatId,
42102
chatModel.llmAgents.lastUsedLLMAgent!.id,
43103
genChatSystemPrompt(chatModel),
44104
message: text,
105+
refText: refText.isEmpty ? null : refText,
45106
);
46107

47108
final scrollController = SessionController.sessionController(chatModel.sessionId).aiChatScrollController;
@@ -80,9 +141,8 @@ class _SessionChatInputCardState extends ConsumerState<SessionChatInputCard> {
80141
// 输入框
81142
ChatInputFieldWidget(
82143
model: widget.model,
83-
onSubmitted: widget.model.canSendMessage()
84-
? () => _sendMessage(widget.model.chatModel.id, widget.model)
85-
: null,
144+
onSubmitted:
145+
widget.model.canSendMessage() ? () => _sendMessage(widget.model.chatModel.id, widget.model) : null,
86146
),
87147

88148
const SizedBox(height: kSpacingSmall),
@@ -375,12 +435,9 @@ class _ChatInputFieldWidgetState extends ConsumerState<ChatInputFieldWidget> {
375435
if (widget.model.metadata == null || widget.model.currentSchema == null) {
376436
return [];
377437
}
378-
final schema = MetaDataNode(MetaType.instance, "", items: widget.model.metadata!);
438+
final schema = MetaDataNode(MetaType.instance, "", items: widget.model.metadata!.metadata);
379439
final schemaNodes = schema.getChildren(MetaType.schema, widget.model.currentSchema!);
380-
return schemaNodes
381-
.where((e) => e.type == MetaType.table)
382-
.map((e) => e.value)
383-
.toList();
440+
return schemaNodes.where((e) => e.type == MetaType.table).map((e) => e.value).toList();
384441
}
385442

386443
List<String> _filterAndSortTables(List<String> allTables, String query) {

client/lib/services/ai/chat.dart

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,13 @@ class AIChatService extends _$AIChatService {
131131
}
132132

133133
/// 进行AI对话,请求接口,存储消息并刷新使用 provider 来动态刷新页面
134-
Future<void> chat(AIChatId id, LLMAgentId agentId, String systemPrompt, {String? message}) async {
134+
Future<void> chat(
135+
AIChatId id,
136+
LLMAgentId agentId,
137+
String systemPrompt, {
138+
String? message,
139+
String? refText,
140+
}) async {
135141
final repo = ref.read(aiChatRepoProvider);
136142
final model = repo.getAIChatById(id);
137143
if (model == null) {
@@ -142,7 +148,13 @@ class AIChatService extends _$AIChatService {
142148
if (message != null) {
143149
repo.addMessage(
144150
id,
145-
AIChatMessageItem.userMessage(AIChatUserMessageModel(id: AIChatMessageId.generate(), content: message)),
151+
AIChatMessageItem.userMessage(
152+
AIChatUserMessageModel(
153+
id: AIChatMessageId.generate(),
154+
content: message,
155+
ref: refText,
156+
),
157+
),
146158
);
147159
}
148160

client/lib/services/ai/prompt.dart

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,38 +8,32 @@ To confirm that you are available, please only return a number 1 to me.
88

99
const chatTemplate = """
1010
你是一个智能SQL客户端助手. 你正在与一个使用数据库工具的用户对话. 你正在帮助用户回答关于数据库的问题.
11-
## 数据库信息:
11+
## 当前数据库connection的一些基本信息:
1212
db type: {dbType}
13+
db version: {dbVersion}
14+
current schema: {currentSchema}
1315
1416
## 用户输入的格式:
1517
用户会通过@符号指定表名并在当前对话里将表信息传递给你, 给你辅助回答问题.
16-
@table_name 表示表名, 例如: @users.
17-
ref:
18-
table_name信息:
19-
- 表名: users
20-
- 表描述: 用户表
21-
- 表字段:
22-
- id: 用户ID
23-
- name: 用户名
24-
- email: 用户邮箱
25-
- created_at: 创建时间
26-
- updated_at: 更新时间
18+
@table_name 表示表名, 例如: @users. 在`ref:`后面会传递表信息给你, 你需要根据表信息来辅助回答问题.
2719
2820
## 注意点:
2921
- 你只能回答或解决与数据库相关的问题;
3022
- 如果回复包含SQL, 每个SQL应该被包裹在一个 ```sql``` 块中;
23+
- 信任用户传递的表信息,除非用户显式的表达你需要重新查询它;
3124
- 数据库的query查询是非常重要的工具, 你除了使用它进行数据库信息获取外,还可以用它来进行任务逻辑计算, 例如:`SELECT 100 * 30 as result`;
3225
- 在使用query工具时尽可能一次获取更多想要的信息, 避免多次调用query工具;
3326
- 在使用query工具时要保持返回必要信息, 不要返回无关信息,例如:只返回需要的列和行;
3427
- 在使用query工具时要注意性能问题,例如: 可使用limit等限制返回数据量, 避免返回过多数据导致性能问题;
3528
""";
3629

3730
String genChatSystemPrompt(SessionAIChatModel model) {
38-
String prompt = chatTemplate;
39-
if (model.dbType != null) {
40-
prompt = prompt.replaceAll("{dbType}", model.dbType!.name);
41-
}
42-
return prompt;
31+
final dbVersion = (model.metadata?.version ?? "").trim();
32+
final currentSchema = (model.currentSchema ?? "").trim();
33+
return chatTemplate
34+
.replaceAll("{dbType}", model.dbType?.name ?? "-")
35+
.replaceAll("{dbVersion}", dbVersion.isEmpty ? "-" : dbVersion)
36+
.replaceAll("{currentSchema}", currentSchema.isEmpty ? "-" : currentSchema);
4337
}
4438

4539
// 导入任务的文件命名

client/lib/services/sessions/session_chat.dart

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class SessionAIChatNotifier extends _$SessionAIChatNotifier {
4545
sessionId: session.sessionId,
4646
currentSchema: session.currentSchema,
4747
dbType: session.dbType,
48-
metadata: metadata?.value?.metadata,
48+
metadata: metadata?.value,
4949
connId: session.connId,
5050
state: session.connState,
5151
llmAgents: llmAgents,

pkg/db_driver/lib/src/db_driver_interface.dart

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ abstract class BaseConnection {
106106
Future<List<String>> schemas();
107107
Future<String?> getCurrentSchema();
108108
Future<void> setCurrentSchema(String schema);
109+
Future<String> version();
109110

110111
void listen(
111112
{Function()? onCloseCallback,

0 commit comments

Comments
 (0)