|
1 | 1 | import json |
2 | 2 | import os |
| 3 | +import tempfile |
3 | 4 | import traceback |
4 | 5 | import warnings |
| 6 | +from urllib.parse import urlparse |
5 | 7 |
|
6 | 8 | from neo4j import GraphDatabase as GD |
7 | 9 |
|
8 | 10 | from src import config |
9 | 11 | from src.models import select_embedding_model |
| 12 | +from src.storage.minio.client import get_minio_client, StorageError |
10 | 13 | from src.utils import logger |
11 | 14 | from src.utils.datetime_utils import utc_isoformat |
12 | 15 |
|
@@ -347,15 +350,59 @@ async def jsonl_file_add_entity(self, file_path, kgdb_name="neo4j"): |
347 | 350 | self.use_database(kgdb_name) # 切换到指定数据库 |
348 | 351 | logger.info(f"Start adding entity to {kgdb_name} with {file_path}") |
349 | 352 |
|
350 | | - def read_triples(file_path): |
351 | | - with open(file_path, encoding="utf-8") as file: |
352 | | - for line in file: |
353 | | - if line.strip(): |
354 | | - yield json.loads(line.strip()) |
| 353 | + # 检测 file_path 是否是 URL |
| 354 | + parsed_url = urlparse(file_path) |
| 355 | + temp_file_path = None |
355 | 356 |
|
356 | | - triples = list(read_triples(file_path)) |
| 357 | + try: |
| 358 | + if parsed_url.scheme in ('http', 'https'): # 如果是 URL |
| 359 | + logger.info(f"检测到 URL,正在从 MinIO 下载文件: {file_path}") |
| 360 | + |
| 361 | + # 从 URL 解析 bucket_name 和 object_name |
| 362 | + # URL 格式: http://host:port/bucket_name/object_name |
| 363 | + path_parts = parsed_url.path.lstrip('/').split('/', 1) |
| 364 | + if len(path_parts) < 2: |
| 365 | + raise ValueError(f"无法解析 MinIO URL: {file_path}") |
| 366 | + |
| 367 | + bucket_name = path_parts[0] |
| 368 | + object_name = path_parts[1] |
| 369 | + |
| 370 | + # 从 MinIO 下载文件 |
| 371 | + minio_client = get_minio_client() |
| 372 | + file_data = await minio_client.adownload_file(bucket_name, object_name) |
| 373 | + |
| 374 | + # 创建临时文件保存下载的内容 |
| 375 | + with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False, encoding='utf-8') as temp_file: |
| 376 | + temp_file.write(file_data.decode('utf-8')) |
| 377 | + temp_file_path = temp_file.name |
| 378 | + |
| 379 | + logger.info(f"文件已下载到临时路径: {temp_file_path}") |
| 380 | + actual_file_path = temp_file_path |
| 381 | + else: |
| 382 | + # 本地文件路径 |
| 383 | + actual_file_path = file_path |
357 | 384 |
|
358 | | - await self.txt_add_vector_entity(triples, kgdb_name) |
| 385 | + def read_triples(file_path): |
| 386 | + with open(file_path, encoding="utf-8") as file: |
| 387 | + for line in file: |
| 388 | + if line.strip(): |
| 389 | + yield json.loads(line.strip()) |
| 390 | + |
| 391 | + triples = list(read_triples(actual_file_path)) |
| 392 | + |
| 393 | + await self.txt_add_vector_entity(triples, kgdb_name) |
| 394 | + |
| 395 | + except Exception as e: |
| 396 | + logger.error(f"处理文件失败: {e}") |
| 397 | + raise |
| 398 | + finally: |
| 399 | + # 清理临时文件 |
| 400 | + if temp_file_path and os.path.exists(temp_file_path): |
| 401 | + try: |
| 402 | + os.unlink(temp_file_path) |
| 403 | + logger.info(f"已删除临时文件: {temp_file_path}") |
| 404 | + except Exception as e: |
| 405 | + logger.warning(f"删除临时文件失败: {e}") |
359 | 406 |
|
360 | 407 | self.status = "open" |
361 | 408 | # 更新并保存图数据库信息 |
|
0 commit comments