-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathchat.py
More file actions
1329 lines (1080 loc) · 59 KB
/
chat.py
File metadata and controls
1329 lines (1080 loc) · 59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Seeed Wiki 优化问答系统
使用预保存的 FAISS 索引和 Ollama nomic-embed-text 模型
"""
import json
import os
import pickle
import numpy as np
import faiss
import ollama
import time
import re
import sys
import readline # 添加 readline 支持,提供更好的输入体验
import hashlib
from functools import lru_cache
from concurrent.futures import ThreadPoolExecutor
import threading
import torch
import gc
import tempfile
class OptimizedQASystem:
def __init__(self):
self.faiss_index = None
self.faiss_metadata = None
self.wiki_pages = []
self.embedding_model = "nomic-embed-text"
# 性能优化相关
self.embedding_cache = {} # embedding 缓存
self.answer_cache = {} # 回答缓存
self.cache_lock = threading.Lock() # 缓存锁
self.executor = ThreadPoolExecutor(max_workers=2) # 线程池
# 流式显示相关
self.streaming_enabled = True # 是否启用流式显示
self.typing_speed = 0.03 # 打字速度(秒/字符)
# TTS相关
self.tts_enabled = False # 是否启用TTS
self.tts_model = None # TTS模型
self.tts_speaker_id = None # 说话人ID
self.tts_device = None # TTS设备
self.audio_playing = False # 音频播放状态
self.tts_available = False # TTS是否可用
self.tts_thread = None # TTS线程
self.tts_queue = [] # TTS任务队列
self.tts_lock = threading.Lock() # TTS锁
self.tts_processing = False # TTS是否正在处理
# 设置 readline 配置
self.setup_readline()
# 检查数据文件
self.check_data_files()
# 检查 Ollama 服务
self.check_ollama_service()
# 检查TTS模块可用性(不初始化)
self.check_tts_availability()
# 初始化系统
self.initialize_system()
def setup_readline(self):
"""设置 readline 配置,提供更好的输入体验"""
try:
# 设置历史文件
histfile = os.path.join(os.path.expanduser("~"), ".seeed_qa_history")
readline.read_history_file(histfile)
readline.set_history_length(1000)
# 设置自动补全
readline.parse_and_bind('tab: complete')
# 设置输入提示符样式
readline.parse_and_bind('set editing-mode emacs')
except Exception as e:
print(f"⚠️ readline 设置失败: {str(e)}")
print("💡 输入体验可能受限,但基本功能正常")
def safe_input(self, prompt):
"""安全的输入函数,提供更好的错误处理"""
try:
# 尝试使用 readline 输入
user_input = input(prompt)
return user_input.strip()
except (EOFError, KeyboardInterrupt):
print("\n👋 用户中断,退出程序")
sys.exit(0)
except Exception as e:
print(f"\n❌ 输入错误: {str(e)}")
return ""
def save_history(self):
"""保存输入历史"""
try:
histfile = os.path.join(os.path.expanduser("~"), ".seeed_qa_history")
readline.write_history_file(histfile)
except Exception:
pass # 忽略历史保存错误
def check_data_files(self):
"""检查必要的数据文件"""
required_files = [
"./data_base/faiss_index.bin",
"./data_base/faiss_metadata.pkl",
"./data_base/seeed_wiki_embeddings_db.json"
]
missing_files = []
for file in required_files:
if not os.path.exists(file):
missing_files.append(file)
if missing_files:
print("❌ 缺少必要的数据文件:")
for file in missing_files:
print(f" - {file}")
print("\n💡 请先运行爬虫脚本获取数据:")
print(" python scrape_with_embeddings.py")
raise FileNotFoundError(f"缺少数据文件: {', '.join(missing_files)}")
print("✅ 所有必要的数据文件已找到")
def check_ollama_service(self):
"""检查 Ollama 服务状态"""
try:
models = ollama.list()
print(f"✅ Ollama 服务正常,可用模型: {len(models['models'])} 个")
model_names = [model['name'] for model in models['models']]
print(model_names)
if 'nomic-embed-text:latest' not in model_names:
print("⚠️ 未找到 nomic-embed-text 模型,正在安装...")
ollama.pull('nomic-embed-text')
print("✅ nomic-embed-text 模型安装完成")
else:
print("✅ nomic-embed-text 模型已安装")
except Exception as e:
print(f"❌ Ollama 服务检查失败: {str(e)}")
raise
def check_tts_availability(self):
"""检查TTS模块是否可用(不实际导入)"""
print("🎤 检查TTS模块可用性...")
# 检查是否安装了必要的包
try:
import importlib.util
# 检查melo-tts
melo_spec = importlib.util.find_spec("melo")
if melo_spec is None:
print("⚠️ Melo TTS模块未安装")
self.tts_available = False
return
# pygame不是必需的,我们只保存文件不播放
print("✅ TTS相关模块已安装")
self.tts_available = True
except Exception as e:
print(f"⚠️ TTS模块检查失败: {str(e)}")
self.tts_available = False
def enable_tts(self):
"""启用TTS功能"""
if not self.tts_available:
raise Exception("TTS模块不可用")
if self.tts_enabled:
return # 已经启用
try:
from melo.api import TTS
print("🔧 配置PyTorch优化...")
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
# 设备配置
self.tts_device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"🎤 检测到设备: {self.tts_device}")
if self.tts_device == 'cuda':
print("🔧 配置CUDA内存...")
torch.cuda.empty_cache()
torch.cuda.set_per_process_memory_fraction(0.6)
print(f"✅ CUDA内存配置完成")
# 初始化TTS模型
print("📥 正在加载TTS模型...")
start_time = time.time()
self.tts_model = TTS(language='ZH', device=self.tts_device)
load_time = time.time() - start_time
print(f"✅ TTS模型加载完成,耗时: {load_time:.2f}秒")
# 获取说话人ID
speaker_ids = self.tts_model.hps.data.spk2id
self.tts_speaker_id = speaker_ids['ZH']
print(f"✅ 说话人ID: {self.tts_speaker_id}")
# 创建音频输出目录
self.audio_output_dir = "./audio_outputs"
os.makedirs(self.audio_output_dir, exist_ok=True)
print(f"✅ 音频输出目录: {self.audio_output_dir}")
# 模型预热
print("🔥 正在进行TTS模型预热...")
warmup_text = "你好"
start_time = time.time()
warmup_path = os.path.join(self.audio_output_dir, "warmup.wav")
self.tts_model.tts_to_file(warmup_text, self.tts_speaker_id, warmup_path, speed=1.0)
warmup_time = time.time() - start_time
print(f"✅ TTS预热完成,耗时: {warmup_time:.2f}秒")
# 清理预热文件
try:
os.unlink(warmup_path)
except:
pass
self.tts_enabled = True
# 启动TTS工作线程
self.tts_thread = threading.Thread(target=self.tts_worker_thread, daemon=True)
self.tts_thread.start()
print("🎉 TTS功能启用成功!")
print("🧵 TTS后台线程已启动")
except Exception as e:
print(f"❌ TTS启用失败: {str(e)}")
self.tts_enabled = False
self.tts_model = None
raise
def disable_tts(self):
"""禁用TTS功能"""
if not self.tts_enabled:
return
print("🔄 正在禁用TTS功能...")
# 停止TTS线程
self.tts_enabled = False
# 等待线程结束
if self.tts_thread and self.tts_thread.is_alive():
print("⏳ 等待TTS线程结束...")
self.tts_thread.join(timeout=5) # 最多等待5秒
# 清空队列
with self.tts_lock:
self.tts_queue.clear()
self.tts_processing = False
# 清理资源
if self.tts_device == 'cuda':
torch.cuda.empty_cache()
gc.collect()
print("✅ TTS功能已禁用")
def tts_worker_thread(self):
"""TTS工作线程"""
while self.tts_enabled:
try:
# 检查队列中是否有任务
with self.tts_lock:
if not self.tts_queue:
time.sleep(0.1) # 短暂休眠
continue
# 获取任务
task = self.tts_queue.pop(0)
self.tts_processing = True
# 处理TTS任务
text = task.get('text', '')
speed = task.get('speed', 1.0)
callback = task.get('callback', None)
if text and self.tts_model:
print(f"🎤 [后台] 开始生成语音: '{text[:30]}...'")
# 生成语音文件
audio_file = self._generate_audio_file(text, speed)
if audio_file and callback:
callback(audio_file)
print(f"🎤 [后台] 语音生成完成: {audio_file}")
with self.tts_lock:
self.tts_processing = False
except Exception as e:
print(f"❌ [后台] TTS处理错误: {str(e)}")
with self.tts_lock:
self.tts_processing = False
time.sleep(1) # 错误后短暂休眠
def _generate_audio_file(self, text, speed=1.0):
"""生成音频文件(内部方法)"""
try:
# 清理文本
clean_text = re.sub(r'[^\w\s\u4e00-\u9fff,。!?;:""''()【】]', '', text)
if not clean_text.strip():
return None
# 限制文本长度
if len(clean_text) > 500:
clean_text = clean_text[:500] + "..."
# 生成文件名
timestamp = int(time.time())
text_hash = hashlib.md5(clean_text.encode('utf-8')).hexdigest()[:8]
filename = f"answer_{timestamp}_{text_hash}.wav"
output_path = os.path.join(self.audio_output_dir, filename)
# 生成语音
start_time = time.time()
self.tts_model.tts_to_file(clean_text, self.tts_speaker_id, output_path, speed=speed)
tts_time = time.time() - start_time
# 获取文件大小
file_size = os.path.getsize(output_path)
file_size_mb = file_size / (1024 * 1024)
print(f"📁 [后台] 文件已保存: {output_path}")
print(f"📊 [后台] 大小: {file_size_mb:.2f} MB, 耗时: {tts_time:.2f}秒")
return output_path
except Exception as e:
print(f"❌ [后台] 音频生成失败: {str(e)}")
return None
def initialize_system(self):
"""初始化系统"""
print("🚀 正在初始化优化问答系统...")
try:
# 加载 FAISS 索引
print("🔍 加载 FAISS 索引...")
if not os.path.exists("./data_base/faiss_index.bin"):
raise FileNotFoundError("FAISS 索引文件不存在")
self.faiss_index = faiss.read_index("./data_base/faiss_index.bin")
if self.faiss_index is None:
raise Exception("FAISS 索引加载失败")
print(f"✅ FAISS 索引加载完成: {self.faiss_index.ntotal} 个向量")
print(f" 索引维度: {self.faiss_index.d}")
print(f" 索引类型: {type(self.faiss_index).__name__}")
# 加载向量元数据
print("📊 加载向量元数据...")
if not os.path.exists("./data_base/faiss_metadata.pkl"):
raise FileNotFoundError("元数据文件不存在")
with open("./data_base/faiss_metadata.pkl", 'rb') as f:
self.faiss_metadata = pickle.load(f)
if not self.faiss_metadata or len(self.faiss_metadata) == 0:
raise Exception("元数据为空")
print(f"✅ 元数据加载完成: {len(self.faiss_metadata)} 条记录")
# 检查索引和元数据的一致性
if self.faiss_index.ntotal != len(self.faiss_metadata):
print(f"⚠️ 警告: 索引向量数({self.faiss_index.ntotal})与元数据记录数({len(self.faiss_metadata)})不匹配")
# 加载 Wiki 页面数据
print("📚 加载 Wiki 页面数据...")
if not os.path.exists("./data_base/seeed_wiki_embeddings_db.json"):
raise FileNotFoundError("Wiki 页面数据文件不存在")
with open("./data_base/seeed_wiki_embeddings_db.json", 'r', encoding='utf-8') as f:
data = json.load(f)
self.wiki_pages = data['pages']
self.metadata = data['metadata']
if not self.wiki_pages or len(self.wiki_pages) == 0:
raise Exception("Wiki 页面数据为空")
print(f"✅ 页面数据加载完成: {len(self.wiki_pages)} 个页面")
# 检查页面数据和元数据的一致性
if len(self.wiki_pages) != len(self.faiss_metadata):
print(f"⚠️ 警告: 页面数据数({len(self.wiki_pages)})与元数据记录数({len(self.faiss_metadata)})不匹配")
# 测试 Embedding 模型
print("🤖 测试 Embedding 模型...")
test_embedding = self.generate_embedding("test")
if test_embedding is None:
raise Exception("Embedding 生成失败")
# 检查 embedding 维度是否与索引匹配
if test_embedding.shape[0] != self.faiss_index.d:
raise Exception(f"Embedding 维度({test_embedding.shape[0]})与索引维度({self.faiss_index.d})不匹配")
print(f"✅ Embedding 模型测试成功: {len(test_embedding)} 维")
print("🎉 系统初始化完成!")
self.show_system_info()
# 加载缓存
self.load_cache()
except Exception as e:
print(f"❌ 系统初始化失败: {str(e)}")
import traceback
traceback.print_exc()
raise
def show_system_info(self):
"""显示系统信息"""
print(f"\n📊 系统信息:")
print(f" 总页面数: {len(self.wiki_pages)}")
print(f" 总向量数: {self.faiss_index.ntotal}")
print(f" 向量维度: {self.metadata['vector_dimension']}")
print(f" 内容类型: {self.metadata['content_type']}")
print(f" Embedding 模型: {self.metadata['embedding_model']}")
print(f" 索引类型: {self.metadata['index_type']}")
print(f" 爬取时间: {self.metadata['crawl_time']}")
print(f" 缓存状态: Embedding缓存 {len(self.embedding_cache)} 项,回答缓存已禁用")
print(f" 流式显示: {'启用' if self.streaming_enabled else '禁用'}")
print(f" 打字速度: {self.typing_speed:.3f} 秒/字符")
print(f" TTS功能: {'启用' if self.tts_enabled else '禁用'}")
print(f" TTS模块: {'可用' if self.tts_available else '不可用'}")
if self.tts_enabled:
print(f" TTS设备: {self.tts_device}")
print(f" 音频输出: 文件保存模式")
print(f" 输出目录: {getattr(self, 'audio_output_dir', '未设置')}")
print(f" TTS线程: {'运行中' if self.tts_thread and self.tts_thread.is_alive() else '未启动'}")
print(f" 队列状态: {len(self.tts_queue)} 个任务")
print(f" 处理状态: {'处理中' if self.tts_processing else '空闲'}")
def load_cache(self):
"""加载缓存数据"""
try:
cache_file = "./data_base/cache_data.pkl"
if os.path.exists(cache_file):
with open(cache_file, 'rb') as f:
cache_data = pickle.load(f)
self.embedding_cache = cache_data.get('embedding_cache', {})
self.answer_cache = cache_data.get('answer_cache', {})
print(f"✅ 缓存加载完成: Embedding {len(self.embedding_cache)} 项,回答 {len(self.answer_cache)} 项")
except Exception as e:
print(f"⚠️ 缓存加载失败: {str(e)}")
def save_cache(self):
"""保存缓存数据"""
try:
cache_file = "./data_base/cache_data.pkl"
cache_data = {
'embedding_cache': self.embedding_cache,
'answer_cache': self.answer_cache
}
with open(cache_file, 'wb') as f:
pickle.dump(cache_data, f)
print(f"✅ 缓存保存完成")
except Exception as e:
print(f"⚠️ 缓存保存失败: {str(e)}")
def clear_cache(self):
"""清空缓存"""
with self.cache_lock:
self.embedding_cache.clear()
self.answer_cache.clear()
print("✅ 缓存已清空")
def typewriter_effect(self, text, speed=None):
"""打字机效果显示文本"""
if speed is None:
speed = self.typing_speed
for char in text:
print(char, end='', flush=True)
time.sleep(speed)
print() # 换行
def stream_response(self, response_generator):
"""流式显示回答"""
full_answer = ""
print("💬 回答: ", end='', flush=True)
try:
for chunk in response_generator:
if 'message' in chunk and 'content' in chunk['message']:
content = chunk['message']['content']
if content:
full_answer += content
print(content, end='', flush=True)
time.sleep(self.typing_speed)
except Exception as e:
print(f"\n⚠️ 流式显示错误: {str(e)}")
print() # 换行
return full_answer
def text_to_speech(self, text, speed=1.0, callback=None):
"""将文本转换为语音并保存到文件(异步处理)"""
if not self.tts_enabled or not self.tts_model or not text.strip():
return False
try:
# 清理文本,移除特殊字符
clean_text = re.sub(r'[^\w\s\u4e00-\u9fff,。!?;:""''()【】]', '', text)
if not clean_text.strip():
return False
# 限制文本长度,避免生成过长的音频
if len(clean_text) > 500:
clean_text = clean_text[:500] + "..."
# 创建TTS任务
task = {
'text': clean_text,
'speed': speed,
'callback': callback or self._default_tts_callback
}
# 添加到队列
with self.tts_lock:
self.tts_queue.append(task)
print(f"🎤 语音任务已加入队列: '{clean_text[:30]}...'")
print(f"📊 队列长度: {len(self.tts_queue)}")
return True
except Exception as e:
print(f"❌ TTS任务添加失败: {str(e)}")
return False
def _default_tts_callback(self, audio_file):
"""默认TTS回调函数"""
if audio_file:
print(f"✅ 语音文件生成完成: {audio_file}")
print(f"💡 您可以使用音频播放器播放此文件")
else:
print("❌ 语音文件生成失败")
def stop_audio(self):
"""停止音频播放(文件保存模式不需要)"""
print("💡 当前使用文件保存模式,无需停止播放")
def show_audio_files(self):
"""显示音频文件信息"""
if not hasattr(self, 'audio_output_dir') or not os.path.exists(self.audio_output_dir):
print("📁 音频输出目录不存在")
return
audio_files = [f for f in os.listdir(self.audio_output_dir) if f.endswith('.wav')]
if not audio_files:
print("📁 音频输出目录为空")
return
print(f"📁 音频文件列表 (共 {len(audio_files)} 个):")
print(f"📂 目录: {self.audio_output_dir}")
print("-" * 60)
for i, filename in enumerate(sorted(audio_files, reverse=True)[:10], 1): # 显示最近10个
file_path = os.path.join(self.audio_output_dir, filename)
file_size = os.path.getsize(file_path)
file_size_mb = file_size / (1024 * 1024)
mod_time = time.ctime(os.path.getmtime(file_path))
print(f"{i:2d}. {filename}")
print(f" 大小: {file_size_mb:.2f} MB")
print(f" 时间: {mod_time}")
print()
if len(audio_files) > 10:
print(f"... 还有 {len(audio_files) - 10} 个文件")
print("💡 您可以使用音频播放器播放这些文件")
def generate_embedding(self, text):
"""使用 Ollama 生成文本的 embedding 向量(带缓存优化)"""
if not text or not text.strip():
print("❌ 输入文本为空")
return None
# 生成文本的哈希值作为缓存键
text_hash = hashlib.md5(text.encode('utf-8')).hexdigest()
# 检查缓存
with self.cache_lock:
if text_hash in self.embedding_cache:
cached_embedding = self.embedding_cache[text_hash]
if cached_embedding is not None and isinstance(cached_embedding, np.ndarray):
return cached_embedding
else:
# 清理无效的缓存项
del self.embedding_cache[text_hash]
try:
print(f"🔍 正在生成文本的 embedding: '{text[:50]}...'")
response = ollama.embeddings(model=self.embedding_model, prompt=text)
if "embedding" not in response:
print(f"❌ Ollama 响应格式错误: {response}")
return None
embedding = response["embedding"]
if not embedding or len(embedding) == 0:
print("❌ 生成的 embedding 为空")
return None
# 转换为 numpy 数组
embedding = np.array(embedding, dtype=np.float32)
# 检查数组是否有效
if np.isnan(embedding).any() or np.isinf(embedding).any():
print("❌ embedding 包含 NaN 或 Inf 值")
return None
# 归一化
norm = np.linalg.norm(embedding)
if norm == 0:
print("❌ embedding 向量的范数为 0")
return None
embedding = embedding / norm
print(f"✅ embedding 生成成功: 维度 {len(embedding)}, 范数 {np.linalg.norm(embedding):.6f}")
# 缓存结果
with self.cache_lock:
self.embedding_cache[text_hash] = embedding
# 限制缓存大小,避免内存溢出
if len(self.embedding_cache) > 1000:
# 删除最旧的缓存项
oldest_key = next(iter(self.embedding_cache))
del self.embedding_cache[oldest_key]
return embedding
except Exception as e:
print(f"❌ Embedding 生成失败: {str(e)}")
import traceback
traceback.print_exc()
return None
def search_knowledge_base(self, query, top_k=20):
"""在知识库中搜索相关内容(优化版本)"""
try:
# 检查 FAISS 索引是否正确加载
if self.faiss_index is None:
print("❌ FAISS 索引未加载")
return []
# 检查元数据是否正确加载
if self.faiss_metadata is None or len(self.faiss_metadata) == 0:
print("❌ 元数据未加载")
return []
# 生成查询的 embedding
query_embedding = self.generate_embedding(query)
if query_embedding is None:
print("❌ 无法生成查询的 embedding")
return []
# 确保 embedding 是正确的 numpy 数组
if not isinstance(query_embedding, np.ndarray):
print(f"❌ embedding 类型错误: {type(query_embedding)}")
return []
# 检查向量维度
expected_dim = self.faiss_index.d
if query_embedding.shape[0] != expected_dim:
print(f"❌ 向量维度不匹配: 期望 {expected_dim}, 实际 {query_embedding.shape[0]}")
return []
# 重塑为正确的形状
query_embedding = query_embedding.reshape(1, -1).astype(np.float32)
# 执行 FAISS 搜索
scores, indices = self.faiss_index.search(query_embedding, top_k)
results = []
for i, (score, idx) in enumerate(zip(scores[0], indices[0])):
if idx < len(self.faiss_metadata):
metadata = self.faiss_metadata[idx]
page_data = self.wiki_pages[idx]
results.append({
'rank': i + 1,
'score': float(score),
'title': metadata['title'],
'url': metadata['url'],
'content': page_data['content'],
'content_length': metadata['content_length'],
'timestamp': metadata['timestamp']
})
return results
except Exception as e:
print(f"❌ 搜索失败: {str(e)}")
import traceback
traceback.print_exc()
return []
def ask_question(self, question):
"""提问并获取回答(优化版本)"""
print(f"\n🤔 用户问题: {question}")
# 禁用回答缓存,每次都实时生成
# question_hash = hashlib.md5(question.encode('utf-8')).hexdigest()
# with self.cache_lock:
# if question_hash in self.answer_cache:
# print("⚡ 使用缓存回答")
# print(f"\n💬 回答:")
# print(f"{self.answer_cache[question_hash]}")
# return
# 搜索知识库
print("🔍 正在搜索知识库...")
start_time = time.time()
search_results = self.search_knowledge_base(question, top_k=20) # 减少搜索数量
search_time = time.time() - start_time
if not search_results:
print("❌ 未找到相关信息")
return
print(f"✅ 搜索完成,耗时: {search_time:.3f} 秒")
print(f"📊 找到 {len(search_results)} 个相关文档")
# 智能选择最相关的结果(减少到5个)
best_results = self.select_best_results(question, search_results, max_results=5)
# 显示搜索结果
print(f"🔍 搜索结果预览:")
for i, result in enumerate(best_results[:3]): # 只显示前3个
print(f" {i+1}. {result['title']}")
print(f" 相关度: {result['score']:.3f}")
print(f" URL: {result['url']}")
print()
# 生成回答
print("🤖 正在生成回答...")
answer_start_time = time.time()
answer = self.generate_answer(question, best_results)
answer_time = time.time() - answer_start_time
# 显示回答
print(f"\n💬 回答:")
print(f"{answer}")
print(f"\n⏱️ 回答生成耗时: {answer_time:.3f} 秒")
# 语音生成回答(异步处理)
if self.tts_enabled and answer:
print("🎤 正在将语音任务加入队列...")
success = self.text_to_speech(answer, speed=1.0)
if success:
print("✅ 语音任务已加入后台处理队列")
print("💡 语音文件将在后台生成,完成后会显示文件路径")
else:
print("❌ 语音任务添加失败")
# 禁用回答缓存,不保存生成的回答
# with self.cache_lock:
# self.answer_cache[question_hash] = answer
# # 限制缓存大小
# if len(self.answer_cache) > 100:
# oldest_key = next(iter(self.answer_cache))
# del self.answer_cache[oldest_key]
def select_best_results(self, question, search_results, max_results=10):
"""智能选择最相关的结果"""
if not search_results:
return []
# 提取问题中的关键词
question_lower = question.lower()
keywords = []
# 中文关键词
chinese_keywords = ['矽递', '科技', '公司', '介绍', '简介', '关于', '什么是', '如何', '怎么']
for keyword in chinese_keywords:
if keyword in question_lower:
keywords.append(keyword)
# 英文关键词
english_keywords = ['seeed', 'studio', 'company', 'introduction', 'about', 'what', 'how']
for keyword in english_keywords:
if keyword in question_lower:
keywords.append(keyword)
# 计算每个结果的相关性分数
scored_results = []
for result in search_results:
score = result['score']
title = result['title'].lower()
content = result['content'].lower()
# 关键词匹配加分
keyword_bonus = 0
for keyword in keywords:
if keyword in title:
keyword_bonus += 0.1
if keyword in content:
keyword_bonus += 0.05
# 标题匹配加分
title_bonus = 0
if any(keyword in title for keyword in keywords):
title_bonus += 0.05
# 计算最终分数
final_score = score + keyword_bonus + title_bonus
scored_results.append((result, final_score))
# 按最终分数排序
scored_results.sort(key=lambda x: x[1], reverse=True)
# 返回前N个最佳结果
best_results = [result for result, score in scored_results[:max_results]]
print(f"🔍 智能选择结果:")
print(f" 关键词: {keywords}")
print(f" 选择结果数: {len(best_results)}")
return best_results
def detect_language(self, text):
"""检测文本语言 - 改进版本"""
# 检测中文字符
chinese_chars = re.findall(r'[\u4e00-\u9fff]', text)
english_chars = re.findall(r'[a-zA-Z]', text)
# 计算中英文比例
total_chars = len(text.replace(' ', '').replace('\n', ''))
if total_chars == 0:
return 'en' # 默认为英文
chinese_ratio = len(chinese_chars) / total_chars
english_ratio = len(english_chars) / total_chars
# 如果中文字符超过10%,或者中文比例大于英文比例,则认为是中文
if chinese_ratio > 0.1 or (chinese_ratio > 0 and chinese_ratio > english_ratio):
return 'zh'
elif english_ratio > 0.5:
return 'en'
else:
# 如果都不明显,检查是否有中文标点符号
chinese_punctuation = re.findall(r'[,。!?;:""''()【】]', text)
if chinese_punctuation:
return 'zh'
return 'en'
def generate_answer(self, question, search_results):
"""基于搜索结果生成回答 - 优化版本"""
if not search_results:
return "抱歉,我在知识库中没有找到相关信息。"
# 检测用户问题的语言
user_language = self.detect_language(question)
print(f"🔍 检测到问题语言: {user_language}")
# 构建上下文信息(优化:限制长度)
context_parts = []
total_length = 0
max_context_length = 3000 # 限制上下文长度
for result in search_results:
title = result['title']
content = result['content']
# 移除 [Introduction] 前缀,清理内容
if content.startswith('[Introduction] '):
content = content[16:]
# 截断过长的内容
if len(content) > 500:
content = content[:500] + "..."
context_part = f"文档标题: {title}\n内容: {content}"
# 检查是否会超出长度限制
if total_length + len(context_part) > max_context_length:
break
context_parts.append(context_part)
total_length += len(context_part)
context = "\n\n".join(context_parts)
# 根据用户语言选择 prompt,强制指定输出语言
if user_language == 'zh':
prompt = f"""请基于以下资料,用自然、连贯的中文回答用户问题。
重要要求:
1. 必须用中文回答,不能使用英文
2. 介绍产品时说"我们的xxx产品..."
3. 严格基于提供的资料回答,绝对不能编造或虚构任何信息
4. 如果资料中没有某个具体信息(如成立时间、具体数据等),绝对不要编造,应该说"资料中未提及"
5. 语言要流畅自然,体现专业且亲切的企业形象
6. 不要分点分段,用一段话概括所有相关信息
7. 不要重复说"我们是Seeed Studio的AI助手"这样的身份介绍
相关资料:
{context}
用户问题: {question}
请用一段连贯的中文回答,使用"我们"的表达方式,严格基于资料内容,不编造任何信息:"""
else:
prompt = f"""Please answer the user's question in natural, coherent English based on the following materials.
Important requirements:
1. Must answer in English, not in Chinese
2. Answer as Seeed Studio's representative, using "we" expressions
3. When introducing products, say "our xxx product..."
4. Strictly base your answer on the provided materials, absolutely do not fabricate or invent any information
5. If specific information (like founding date, specific data, etc.) is not mentioned in the materials, absolutely do not make it up, say "not mentioned in the materials"
6. Make the language fluent and natural, reflecting a professional yet friendly corporate image
7. Don't use bullet points or separate paragraphs, summarize all relevant information in one coherent paragraph
8. Don't repeat identity introductions like "We are Seeed Studio's AI assistant"
Materials:
{context}
User Question: {question}
Please answer using "we" expressions in one coherent English paragraph, strictly based on the materials without fabricating any information:"""
# 使用 Ollama 生成自然语言回答(流式版本)
try:
# 优化:使用更快的模型和更简洁的prompt
system_prompt = f'用{user_language}回答,基于资料,不编造信息,不要重复身份介绍。'
if self.streaming_enabled:
# 流式生成回答
response_generator = ollama.chat(
model='qwen2.5:3b',
messages=[
{'role': 'system', 'content': system_prompt},
{'role': 'user', 'content': prompt}
],
options={
'temperature': 0.7, # 降低随机性,提高一致性
'top_p': 0.9, # 限制词汇选择范围
'num_predict': 300, # 限制生成长度
},
stream=True # 启用流式输出
)
answer = self.stream_response(response_generator)
else:
# 非流式生成回答
response = ollama.chat(
model='qwen2.5:3b',
messages=[
{'role': 'system', 'content': system_prompt},
{'role': 'user', 'content': prompt}
],
options={
'temperature': 0.7,
'top_p': 0.9,
'num_predict': 300,
}
)
answer = response['message']['content'].strip()
# 使用打字机效果显示
print("💬 回答: ", end='', flush=True)
self.typewriter_effect(answer)
# 验证回答语言