Skip to content

Commit 3d6a38a

Browse files
committed
fix(graph): 修复图谱文件上传的 bug，支持从 MinIO URL 加载知识图谱数据
添加从 MinIO 存储下载文件的功能：当检测到输入路径为 URL 时自动处理，下载完成后会创建临时文件并确保最终清理。
1 parent eb79ad4 commit 3d6a38a

File tree

1 file changed

+54
-7
lines changed

1 file changed

+54
-7
lines changed

src/knowledge/graph.py

Lines changed: 54 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
import json
22
import os
3+
import tempfile
34
import traceback
45
import warnings
6+
from urllib.parse import urlparse
57

68
from neo4j import GraphDatabase as GD
79

810
from src import config
911
from src.models import select_embedding_model
12+
from src.storage.minio.client import get_minio_client, StorageError
1013
from src.utils import logger
1114
from src.utils.datetime_utils import utc_isoformat
1215

@@ -347,15 +350,59 @@ async def jsonl_file_add_entity(self, file_path, kgdb_name="neo4j"):
347350
self.use_database(kgdb_name) # 切换到指定数据库
348351
logger.info(f"Start adding entity to {kgdb_name} with {file_path}")
349352

350-
def read_triples(file_path):
351-
with open(file_path, encoding="utf-8") as file:
352-
for line in file:
353-
if line.strip():
354-
yield json.loads(line.strip())
353+
# 检测 file_path 是否是 URL
354+
parsed_url = urlparse(file_path)
355+
temp_file_path = None
355356

356-
triples = list(read_triples(file_path))
357+
try:
358+
if parsed_url.scheme in ('http', 'https'): # 如果是 URL
359+
logger.info(f"检测到 URL,正在从 MinIO 下载文件: {file_path}")
360+
361+
# 从 URL 解析 bucket_name 和 object_name
362+
# URL 格式: http://host:port/bucket_name/object_name
363+
path_parts = parsed_url.path.lstrip('/').split('/', 1)
364+
if len(path_parts) < 2:
365+
raise ValueError(f"无法解析 MinIO URL: {file_path}")
366+
367+
bucket_name = path_parts[0]
368+
object_name = path_parts[1]
369+
370+
# 从 MinIO 下载文件
371+
minio_client = get_minio_client()
372+
file_data = await minio_client.adownload_file(bucket_name, object_name)
373+
374+
# 创建临时文件保存下载的内容
375+
with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False, encoding='utf-8') as temp_file:
376+
temp_file.write(file_data.decode('utf-8'))
377+
temp_file_path = temp_file.name
378+
379+
logger.info(f"文件已下载到临时路径: {temp_file_path}")
380+
actual_file_path = temp_file_path
381+
else:
382+
# 本地文件路径
383+
actual_file_path = file_path
357384

358-
await self.txt_add_vector_entity(triples, kgdb_name)
385+
def read_triples(file_path):
386+
with open(file_path, encoding="utf-8") as file:
387+
for line in file:
388+
if line.strip():
389+
yield json.loads(line.strip())
390+
391+
triples = list(read_triples(actual_file_path))
392+
393+
await self.txt_add_vector_entity(triples, kgdb_name)
394+
395+
except Exception as e:
396+
logger.error(f"处理文件失败: {e}")
397+
raise
398+
finally:
399+
# 清理临时文件
400+
if temp_file_path and os.path.exists(temp_file_path):
401+
try:
402+
os.unlink(temp_file_path)
403+
logger.info(f"已删除临时文件: {temp_file_path}")
404+
except Exception as e:
405+
logger.warning(f"删除临时文件失败: {e}")
359406

360407
self.status = "open"
361408
# 更新并保存图数据库信息

0 commit comments

Comments
 (0)