Skip to content

Commit 45a6a33

Browse files
committed
test_83%
1 parent 7ec55e9 commit 45a6a33

15 files changed: +4,545 additions, −928 deletions.

aigraphx/repositories/neo4j_repo.py

Lines changed: 360 additions & 418 deletions
Large diffs are not rendered by default.

aigraphx/repositories/postgres_repo.py

Lines changed: 79 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -199,8 +199,8 @@ async def get_hf_models_by_ids(self, model_ids: List[str]) -> List[Dict[str, Any
199199
"""Fetches details for multiple HF models by their model_ids."""
200200
if not model_ids:
201201
return []
202-
# Query needed columns from hf_models table
203-
query = "SELECT hf_model_id, author, hf_pipeline_tag FROM hf_models WHERE hf_model_id = ANY(%s);"
202+
# Query ALL columns from hf_models table, not just selected ones
203+
query = "SELECT * FROM hf_models WHERE hf_model_id = ANY(%s);"
204204
try:
205205
async with self.pool.connection() as conn:
206206
async with conn.cursor(row_factory=dict_row) as cur:
@@ -217,17 +217,17 @@ async def search_models_by_keyword(
217217
) -> Tuple[List[Dict[str, Any]], int]:
218218
"""Searches hf_models table by keyword across relevant fields."""
219219
search_term = f"%{query}%"
220-
# Use COALESCE for potentially null fields like pipeline_tag
220+
# Use COALESCE for potentially null fields like hf_pipeline_tag
221221
# Ensure correct column names are used
222222
where_clause = """
223-
model_id ILIKE %s OR
224-
author ILIKE %s OR
225-
COALESCE(pipeline_tag, '') ILIKE %s
223+
hf_model_id ILIKE %s OR
224+
hf_author ILIKE %s OR
225+
COALESCE(hf_pipeline_tag, '') ILIKE %s
226226
"""
227227
params = [search_term] * 3 # Repeat search term for each ILIKE
228228

229229
# Fields to select (adjust as needed)
230-
select_fields = "model_id, author, pipeline_tag, last_modified, tags, likes, downloads, library_name, sha"
230+
select_fields = "hf_model_id, hf_author, hf_pipeline_tag, hf_last_modified, hf_tags, hf_likes, hf_downloads, hf_library_name, hf_sha"
231231

232232
count_sql = f"""
233233
SELECT COUNT(*)
@@ -239,7 +239,7 @@ async def search_models_by_keyword(
239239
SELECT {select_fields}
240240
FROM hf_models
241241
WHERE ({where_clause})
242-
ORDER BY last_modified DESC
242+
ORDER BY hf_last_modified DESC
243243
LIMIT %s OFFSET %s
244244
"""
245245

@@ -291,11 +291,13 @@ async def get_all_hf_models_for_sync(
291291
self, batch_size: int = 1000
292292
) -> AsyncGenerator[List[Dict[str, Any]], None]:
293293
"""Fetches all HF models in batches, yielding each batch as a list of dicts."""
294-
query = "SELECT * FROM hf_models;" # Select all fields for sync
294+
# Build the base query without a trailing semicolon so ORDER BY/LIMIT clauses can be appended below
295+
query = "SELECT * FROM hf_models" # trailing semicolon removed
295296
offset = 0
296297
try:
297298
while True:
298-
batch_query = f"{query} ORDER BY hf_model_id LIMIT %s OFFSET %s;" # Order for deterministic batches
299+
# Append ordering and pagination to the base query
300+
batch_query = f"{query} ORDER BY hf_model_id LIMIT %s OFFSET %s" # trailing semicolon removed
299301
async with self.pool.connection() as conn:
300302
async with conn.cursor(row_factory=dict_row) as cur:
301303
await cur.execute(batch_query, (batch_size, offset))
@@ -311,82 +313,89 @@ async def get_all_hf_models_for_sync(
311313

312314
async def save_paper(self, paper_data: Dict[str, Any]) -> bool:
313315
"""Saves a single paper's data, handling potential JSON fields and conflicts."""
314-
# Prepare columns and values, handle JSON serialization
315-
# Note: psycopg handles most Python types including lists/dicts for JSONB
316-
cols = []
317-
vals = []
318-
excluded_updates = []
319-
for key, value in paper_data.items():
320-
# Basic validation/transformation (adjust as needed)
321-
if key == "authors" and not isinstance(value, list):
322-
value = []
323-
if key == "categories" and not isinstance(value, list):
324-
value = []
325-
# if isinstance(value, (dict, list)): # Psycopg handles this
326-
# value = json.dumps(value)
327-
328-
cols.append(key)
329-
vals.append(value)
330-
# For ON CONFLICT, exclude the primary key (pwc_id) from update set
331-
if key != "pwc_id":
332-
# Use f-string safely as key comes from dict keys
333-
excluded_updates.append(f"{key} = EXCLUDED.{key}")
334-
335-
if not cols or "pwc_id" not in cols:
316+
# Basic argument validation
317+
if not paper_data or "pwc_id" not in paper_data:
336318
self.logger.error("Cannot save paper: missing data or pwc_id.")
337319
return False
338-
339-
# Construct query (Ensure table and column names are correct)
340-
# Using pwc_id for conflict resolution
341-
query = f"""
342-
INSERT INTO papers ({", ".join(cols)})
343-
VALUES ({", ".join(["%s"] * len(vals))})
344-
ON CONFLICT (pwc_id) DO UPDATE SET
345-
{", ".join(excluded_updates)},
346-
updated_at = CURRENT_TIMESTAMP;
347-
"""
320+
348321
try:
322+
# Prepare columns and values
323+
cols = []
324+
vals = []
325+
excluded_updates = []
326+
327+
for key, value in paper_data.items():
328+
# Handle special fields
329+
if key in ["authors", "categories"] and value is not None:
330+
# Ensure list values are serialized to JSON
331+
if isinstance(value, list):
332+
value = json.dumps(value)
333+
334+
cols.append(key)
335+
vals.append(value)
336+
337+
# Exclude the primary key from the ON CONFLICT update set
338+
if key != "pwc_id":
339+
excluded_updates.append(f"{key} = EXCLUDED.{key}")
340+
341+
# Build the query
342+
if not excluded_updates:
343+
# If only pwc_id is present, add at least one update field
344+
excluded_updates = ["updated_at = CURRENT_TIMESTAMP"]
345+
346+
query = f"""
347+
INSERT INTO papers ({", ".join(cols)})
348+
VALUES ({", ".join(["%s"] * len(vals))})
349+
ON CONFLICT (pwc_id) DO UPDATE SET
350+
{", ".join(excluded_updates)};
351+
"""
352+
353+
self.logger.debug(f"Saving paper with pwc_id: {paper_data.get('pwc_id')}")
354+
349355
async with self.pool.connection() as conn:
350356
async with conn.cursor() as cur:
351357
await cur.execute(query, vals)
352-
return cur.rowcount > 0 # Check if a row was inserted/updated
358+
# Return True on successful execution, whether the row was inserted or updated
359+
return True
353360
except Exception as e:
354361
self.logger.error(
355362
f"Error saving paper with pwc_id {paper_data.get('pwc_id')} to PG: {e}"
356363
)
357-
# Consider logging paper_data partially for debugging
358364
self.logger.debug(traceback.format_exc())
359365
return False
360366

361367
async def save_hf_models_batch(self, models_data: List[Dict[str, Any]]) -> None:
362-
"""Saves a batch of HF models, updating existing ones based on model_id."""
368+
"""Saves a batch of HF models, updating existing ones based on hf_model_id."""
363369
if not models_data:
364370
return
365371

366372
# Prepare columns based on the first model (assume consistency)
367-
# Handle potential JSON fields if needed (e.g., tags)
373+
# Handle potential JSON fields if needed (e.g., hf_tags)
368374
if not models_data[0]:
369375
return # Handle empty dict case
370376

371377
cols = list(models_data[0].keys())
372-
# Ensure model_id is present for ON CONFLICT
373-
if "model_id" not in cols:
378+
# Ensure hf_model_id is present for ON CONFLICT
379+
if "hf_model_id" not in cols:
374380
self.logger.error(
375-
"Cannot save HF models batch: 'model_id' missing in data."
381+
"Cannot save HF models batch: 'hf_model_id' missing in data."
376382
)
377383
return
378384

379385
excluded_updates = []
380386
for key in cols:
381-
if key != "model_id":
387+
if key != "hf_model_id":
382388
# Use f-string safely as key comes from dict keys
383389
excluded_updates.append(f"{key} = EXCLUDED.{key}")
384390

391+
# Build the VALUES placeholders for each row
392+
placeholders = ", ".join(["%s"] * len(cols))
393+
385394
# Construct query (Ensure table and column names are correct)
386395
query = f"""
387396
INSERT INTO hf_models ({", ".join(cols)})
388-
VALUES %s
389-
ON CONFLICT (model_id) DO UPDATE SET
397+
VALUES ({placeholders})
398+
ON CONFLICT (hf_model_id) DO UPDATE SET
390399
{", ".join(excluded_updates)};
391400
"""
392401

@@ -396,8 +405,8 @@ async def save_hf_models_batch(self, models_data: List[Dict[str, Any]]) -> None:
396405
row = []
397406
for col in cols:
398407
value = model.get(col)
399-
# FIX: Serialize list fields (like 'tags') to JSON strings
400-
if col == "tags" and isinstance(value, list):
408+
# FIX: Serialize list fields (like 'hf_tags') to JSON strings
409+
if col == "hf_tags" and isinstance(value, list):
401410
row.append(json.dumps(value))
402411
else:
403412
row.append(value) # type: ignore[arg-type] # Allow None, db driver handles it
@@ -406,9 +415,12 @@ async def save_hf_models_batch(self, models_data: List[Dict[str, Any]]) -> None:
406415
try:
407416
async with self.pool.connection() as conn:
408417
async with conn.cursor() as cur:
409-
await cur.executemany(query, data_tuples)
418+
# Execute inserts one at a time instead of using executemany
419+
for data_tuple in data_tuples:
420+
await cur.execute(query, data_tuple)
421+
410422
self.logger.info(
411-
f"Successfully saved/updated batch of {cur.rowcount} HF models."
423+
f"Successfully saved/updated batch of {len(models_data)} HF models."
412424
)
413425
except Exception as e:
414426
self.logger.error(f"Error saving hf_models batch to PG: {e}")
@@ -506,6 +518,13 @@ async def get_all_paper_ids_and_text(
506518

507519
async def close(self) -> None:
508520
"""Closes the connection pool (if applicable and managed by this instance)."""
521+
if hasattr(self, 'pool') and self.pool is not None:
522+
try:
523+
await self.pool.close()
524+
self.logger.info("PostgreSQL connection pool closed successfully.")
525+
except Exception as e:
526+
self.logger.error(f"Error closing PostgreSQL connection pool: {e}")
527+
self.logger.debug(traceback.format_exc())
509528

510529
# --- NEW Method: fetch_one --- #
511530
async def fetch_one(
@@ -710,7 +729,7 @@ async def get_tasks_for_papers(self, paper_ids: List[int]) -> Dict[int, List[str
710729
query = sql.SQL(
711730
"""
712731
SELECT paper_id, task_name
713-
FROM pwc_tasks -- Corrected table name
732+
FROM pwc_tasks
714733
WHERE paper_id = ANY(%s);
715734
"""
716735
)
@@ -745,7 +764,7 @@ async def get_datasets_for_papers(
745764
query = sql.SQL(
746765
"""
747766
SELECT paper_id, dataset_name
748-
FROM pwc_datasets -- Corrected table name
767+
FROM pwc_datasets
749768
WHERE paper_id = ANY(%s);
750769
"""
751770
)
@@ -779,8 +798,8 @@ async def get_repositories_for_papers(
779798
return {}
780799
query = sql.SQL(
781800
"""
782-
SELECT paper_id, url -- Corrected column name
783-
FROM pwc_repositories -- Corrected table name
801+
SELECT paper_id, url
802+
FROM pwc_repositories
784803
WHERE paper_id = ANY(%s);
785804
"""
786805
)

0 commit comments

Comments (0)