Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 5 additions & 17 deletions spider2-lite/evaluation_suite/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,11 +261,8 @@ def evaluate_spider2sql(args):
error_info = dbms_error_info
else:
pred_pd = pd.read_csv(os.path.join("temp", f"{id}.csv"))
if '_' in id:
pattern = re.compile(rf'^{re.escape(id)}(_[a-z])?\.csv$')
else:
pattern = re.compile(rf'^{re.escape(id)}(_[a-z])?\.csv$')

pattern = re.compile(rf'^{re.escape(id)}(_[a-z])?\.csv$')

all_files = os.listdir(gold_result_dir)
csv_files = [file for file in all_files if pattern.match(file)]
csv_files = sorted(csv_files)
Expand Down Expand Up @@ -293,10 +290,7 @@ def evaluate_spider2sql(args):
error_info = dbms_error_info
else:
pred_pd = pd.read_csv(os.path.join("temp", f"{id}.csv"))
if '_' in id:
pattern = re.compile(rf'^{re.escape(id)}(_[a-z])?\.csv$')
else:
pattern = re.compile(rf'^{re.escape(id)}(_[a-z])?\.csv$')
pattern = re.compile(rf'^{re.escape(id)}(_[a-z])?\.csv$')

all_files = os.listdir(gold_result_dir)
csv_files = [file for file in all_files if pattern.match(file)]
Expand Down Expand Up @@ -324,10 +318,7 @@ def evaluate_spider2sql(args):
error_info = dbms_error_info
else:
pred_pd = pd.read_csv(os.path.join("temp", f"{id}.csv"))
if '_' in id:
pattern = re.compile(rf'^{re.escape(id)}(_[a-z])?\.csv$')
else:
pattern = re.compile(rf'^{re.escape(id)}(_[a-z])?\.csv$')
pattern = re.compile(rf'^{re.escape(id)}(_[a-z])?\.csv$')

all_files = os.listdir(gold_result_dir)
csv_files = [file for file in all_files if pattern.match(file)]
Expand All @@ -350,10 +341,7 @@ def evaluate_spider2sql(args):
elif mode == "exec_result":
try:
pred_pd = pd.read_csv(os.path.join(args.result_dir, f"{id}.csv"))
if '_' in id:
pattern = re.compile(rf'^{re.escape(id)}(_[a-z])?\.csv$')
else:
pattern = re.compile(rf'^{re.escape(id)}(_[a-z])?\.csv$')
pattern = re.compile(rf'^{re.escape(id)}(_[a-z])?\.csv$')
all_files = os.listdir(gold_result_dir)
csv_files = [file for file in all_files if pattern.match(file)]
csv_files = sorted(csv_files)
Expand Down