@@ -199,8 +199,8 @@ async def get_hf_models_by_ids(self, model_ids: List[str]) -> List[Dict[str, Any
199199 """Fetches details for multiple HF models by their model_ids."""
200200 if not model_ids :
201201 return []
202- # Query needed columns from hf_models table
203- query = "SELECT hf_model_id, author, hf_pipeline_tag FROM hf_models WHERE hf_model_id = ANY(%s);"
202+ # Query ALL columns from hf_models table, not just selected ones
203+ query = "SELECT * FROM hf_models WHERE hf_model_id = ANY(%s);"
204204 try :
205205 async with self .pool .connection () as conn :
206206 async with conn .cursor (row_factory = dict_row ) as cur :
@@ -217,17 +217,17 @@ async def search_models_by_keyword(
217217 ) -> Tuple [List [Dict [str , Any ]], int ]:
218218 """Searches hf_models table by keyword across relevant fields."""
219219 search_term = f"%{ query } %"
220- # Use COALESCE for potentially null fields like pipeline_tag
220+ # Use COALESCE for potentially null fields like hf_pipeline_tag
221221 # Ensure correct column names are used
222222 where_clause = """
223- model_id ILIKE %s OR
224- author ILIKE %s OR
225- COALESCE(pipeline_tag , '') ILIKE %s
223+ hf_model_id ILIKE %s OR
224+ hf_author ILIKE %s OR
225+ COALESCE(hf_pipeline_tag , '') ILIKE %s
226226 """
227227 params = [search_term ] * 3 # Repeat search term for each ILIKE
228228
229229 # Fields to select (adjust as needed)
230- select_fields = "model_id, author, pipeline_tag, last_modified, tags, likes, downloads, library_name, sha "
230+ select_fields = "hf_model_id, hf_author, hf_pipeline_tag, hf_last_modified, hf_tags, hf_likes, hf_downloads, hf_library_name, hf_sha "
231231
232232 count_sql = f"""
233233 SELECT COUNT(*)
@@ -239,7 +239,7 @@ async def search_models_by_keyword(
239239 SELECT { select_fields }
240240 FROM hf_models
241241 WHERE ({ where_clause } )
242- ORDER BY last_modified DESC
242+ ORDER BY hf_last_modified DESC
243243 LIMIT %s OFFSET %s
244244 """
245245
@@ -291,11 +291,13 @@ async def get_all_hf_models_for_sync(
291291 self , batch_size : int = 1000
292292 ) -> AsyncGenerator [List [Dict [str , Any ]], None ]:
293293 """Fetches all HF models in batches, yielding each batch as a list of dicts."""
294- query = "SELECT * FROM hf_models;" # Select all fields for sync
294+ # Drop the trailing semicolon so ORDER BY / LIMIT / OFFSET can be appended below
295+ query = "SELECT * FROM hf_models" # 移除末尾的分号
295296 offset = 0
296297 try :
297298 while True :
298- batch_query = f"{ query } ORDER BY hf_model_id LIMIT %s OFFSET %s;" # Order for deterministic batches
299+ # Keep batches deterministic: order by the primary key before LIMIT/OFFSET
300+ batch_query = f"{ query } ORDER BY hf_model_id LIMIT %s OFFSET %s" # 移除末尾的分号
299301 async with self .pool .connection () as conn :
300302 async with conn .cursor (row_factory = dict_row ) as cur :
301303 await cur .execute (batch_query , (batch_size , offset ))
@@ -311,82 +313,89 @@ async def get_all_hf_models_for_sync(
311313
async def save_paper(self, paper_data: Dict[str, Any]) -> bool:
    """Insert or update a single paper row, keyed on ``pwc_id``.

    Args:
        paper_data: Column-name -> value mapping; must contain ``pwc_id``.
            ``authors`` and ``categories`` are coerced to lists and
            serialized to JSON strings before being sent to the database.

    Returns:
        True when the upsert executed successfully, False otherwise
        (missing ``pwc_id`` or a database error; errors are logged,
        never raised to the caller).
    """
    # Basic argument validation before touching the pool.
    if not paper_data or "pwc_id" not in paper_data:
        self.logger.error("Cannot save paper: missing data or pwc_id.")
        return False

    try:
        cols: List[str] = []
        vals: List[Any] = []
        excluded_updates: List[str] = []

        for key, value in paper_data.items():
            if key in ("authors", "categories"):
                # Coerce bad/missing input to an empty list, then serialize
                # so the JSON column always receives a proper JSON document.
                if not isinstance(value, list):
                    value = []
                value = json.dumps(value)

            cols.append(key)
            vals.append(value)

            # The conflict target itself must not appear in the SET clause.
            if key != "pwc_id":
                excluded_updates.append(f"{key} = EXCLUDED.{key}")

        # Always refresh updated_at on conflict so the row reflects the
        # last sync time — also guarantees the SET clause is never empty.
        excluded_updates.append("updated_at = CURRENT_TIMESTAMP")

        query = f"""
            INSERT INTO papers ({", ".join(cols)})
            VALUES ({", ".join(["%s"] * len(vals))})
            ON CONFLICT (pwc_id) DO UPDATE SET
                {", ".join(excluded_updates)};
        """

        self.logger.debug(f"Saving paper with pwc_id: {paper_data.get('pwc_id')}")

        async with self.pool.connection() as conn:
            async with conn.cursor() as cur:
                await cur.execute(query, vals)
        # ON CONFLICT DO UPDATE always touches exactly one row, so reaching
        # this point means the upsert succeeded (insert or update).
        return True
    except Exception as e:
        self.logger.error(
            f"Error saving paper with pwc_id {paper_data.get('pwc_id')} to PG: {e}"
        )
        self.logger.debug(traceback.format_exc())
        return False
360366
361367 async def save_hf_models_batch (self , models_data : List [Dict [str , Any ]]) -> None :
362- """Saves a batch of HF models, updating existing ones based on model_id ."""
368+ """Saves a batch of HF models, updating existing ones based on hf_model_id ."""
363369 if not models_data :
364370 return
365371
366372 # Prepare columns based on the first model (assume consistency)
367- # Handle potential JSON fields if needed (e.g., tags )
373+ # Handle potential JSON fields if needed (e.g., hf_tags )
368374 if not models_data [0 ]:
369375 return # Handle empty dict case
370376
371377 cols = list (models_data [0 ].keys ())
372- # Ensure model_id is present for ON CONFLICT
373- if "model_id " not in cols :
378+ # Ensure hf_model_id is present for ON CONFLICT
379+ if "hf_model_id " not in cols :
374380 self .logger .error (
375- "Cannot save HF models batch: 'model_id ' missing in data."
381+ "Cannot save HF models batch: 'hf_model_id ' missing in data."
376382 )
377383 return
378384
379385 excluded_updates = []
380386 for key in cols :
381- if key != "model_id " :
387+ if key != "hf_model_id " :
382388 # Use f-string safely as key comes from dict keys
383389 excluded_updates .append (f"{ key } = EXCLUDED.{ key } " )
384390
391+ # Build the per-row VALUES placeholder list
392+ placeholders = ", " .join (["%s" ] * len (cols ))
393+
385394 # Construct query (Ensure table and column names are correct)
386395 query = f"""
387396 INSERT INTO hf_models ({ ", " .join (cols )} )
388- VALUES %s
389- ON CONFLICT (model_id ) DO UPDATE SET
397+ VALUES ( { placeholders } )
398+ ON CONFLICT (hf_model_id ) DO UPDATE SET
390399 { ", " .join (excluded_updates )} ;
391400 """
392401
@@ -396,8 +405,8 @@ async def save_hf_models_batch(self, models_data: List[Dict[str, Any]]) -> None:
396405 row = []
397406 for col in cols :
398407 value = model .get (col )
399- # FIX: Serialize list fields (like 'tags ') to JSON strings
400- if col == "tags " and isinstance (value , list ):
408+ # FIX: Serialize list fields (like 'hf_tags ') to JSON strings
409+ if col == "hf_tags " and isinstance (value , list ):
401410 row .append (json .dumps (value ))
402411 else :
403412 row .append (value ) # type: ignore[arg-type] # Allow None, db driver handles it
@@ -406,9 +415,12 @@ async def save_hf_models_batch(self, models_data: List[Dict[str, Any]]) -> None:
406415 try :
407416 async with self .pool .connection () as conn :
408417 async with conn .cursor () as cur :
409- await cur .executemany (query , data_tuples )
418+ # Execute the upsert once per row instead of using executemany
419+ for data_tuple in data_tuples :
420+ await cur .execute (query , data_tuple )
421+
410422 self .logger .info (
411- f"Successfully saved/updated batch of { cur . rowcount } HF models."
423+ f"Successfully saved/updated batch of { len ( models_data ) } HF models."
412424 )
413425 except Exception as e :
414426 self .logger .error (f"Error saving hf_models batch to PG: { e } " )
@@ -506,6 +518,13 @@ async def get_all_paper_ids_and_text(
506518
async def close(self) -> None:
    """Close this instance's PostgreSQL connection pool, if one exists.

    Safe to call when no pool was ever created. Shutdown failures are
    logged rather than raised.
    """
    pool = getattr(self, "pool", None)
    if pool is None:
        return
    try:
        await pool.close()
        self.logger.info("PostgreSQL connection pool closed successfully.")
    except Exception as exc:
        self.logger.error(f"Error closing PostgreSQL connection pool: {exc}")
        self.logger.debug(traceback.format_exc())
509528
510529 # --- NEW Method: fetch_one --- #
511530 async def fetch_one (
@@ -710,7 +729,7 @@ async def get_tasks_for_papers(self, paper_ids: List[int]) -> Dict[int, List[str
710729 query = sql .SQL (
711730 """
712731 SELECT paper_id, task_name
713- FROM pwc_tasks -- Corrected table name
732+ FROM pwc_tasks
714733 WHERE paper_id = ANY(%s);
715734 """
716735 )
@@ -745,7 +764,7 @@ async def get_datasets_for_papers(
745764 query = sql .SQL (
746765 """
747766 SELECT paper_id, dataset_name
748- FROM pwc_datasets -- Corrected table name
767+ FROM pwc_datasets
749768 WHERE paper_id = ANY(%s);
750769 """
751770 )
@@ -779,8 +798,8 @@ async def get_repositories_for_papers(
779798 return {}
780799 query = sql .SQL (
781800 """
782- SELECT paper_id, url -- Corrected column name
783- FROM pwc_repositories -- Corrected table name
801+ SELECT paper_id, url
802+ FROM pwc_repositories
784803 WHERE paper_id = ANY(%s);
785804 """
786805 )
0 commit comments