Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions align_pmids.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
#!/usr/bin/env python3
"""
Helper script to align PMIDs between predictions and feedback files.

This script helps ensure both CSV files have overlapping PMIDs for confusion matrix analysis.
"""

import pandas as pd
import sys
from pathlib import Path


def align_pmids(predictions_path, feedback_path, output_dir=None):
    """
    Align PMIDs between predictions and feedback files.

    Both CSV files are loaded, the set of PMIDs present in *both* files is
    computed, and each file is filtered down to the rows whose PMID is in
    that intersection. The filtered tables are written to
    ``analysis_results_aligned.csv`` and ``feedback_aligned.csv`` inside
    ``output_dir``.

    Args:
        predictions_path: Path to the predictions CSV (must have a "PMID" column).
        feedback_path: Path to the feedback CSV (must have a "PMID" column).
        output_dir: Directory for the aligned files. Defaults to the current
            directory. Missing directories (including parents) are created.

    Returns:
        None. If the two files share no PMIDs a warning is printed and no
        files are written.

    Raises:
        KeyError: If either input file has no "PMID" column.
    """
    print("Loading files...")
    preds = pd.read_csv(predictions_path)
    fb = pd.read_csv(feedback_path)

    # Fail with an actionable message instead of a bare KeyError('PMID').
    for label, frame, path in (
        ("Predictions", preds, predictions_path),
        ("Feedback", fb, feedback_path),
    ):
        if "PMID" not in frame.columns:
            raise KeyError(f"{label} file '{path}' has no 'PMID' column")

    # Blank cells are read as NaN; they are not PMIDs and must not be counted.
    pred_pmids = set(preds["PMID"].dropna().unique())
    fb_pmids = set(fb["PMID"].dropna().unique())
    overlap = pred_pmids & fb_pmids

    print("\nData Summary:")
    print(f" Predictions file: {len(pred_pmids)} unique PMIDs")
    print(f" Feedback file: {len(fb_pmids)} unique PMIDs")
    print(f" Overlapping PMIDs: {len(overlap)}")

    if not overlap:
        print("\n⚠️ WARNING: No overlapping PMIDs found!")
        print("You need to align the data before running confusion matrix analysis.")
        return

    output_dir = Path(".") if output_dir is None else Path(output_dir)
    # parents=True so a nested output directory such as "out/run1" works.
    output_dir.mkdir(parents=True, exist_ok=True)

    # Create aligned files
    preds_aligned = preds[preds["PMID"].isin(overlap)].copy()
    fb_aligned = fb[fb["PMID"].isin(overlap)].copy()

    preds_output = output_dir / "analysis_results_aligned.csv"
    fb_output = output_dir / "feedback_aligned.csv"

    preds_aligned.to_csv(preds_output, index=False)
    fb_aligned.to_csv(fb_output, index=False)

    print("\n✅ Created aligned files:")
    print(f" {preds_output} ({len(preds_aligned)} rows)")
    print(f" {fb_output} ({len(fb_aligned)} rows)")
    print("\nYou can now run confusion_matrix_analysis.py with these aligned files:")
    print(f" python confusion_matrix_analysis.py {preds_output} {fb_output}")

    # Show sample overlapping PMIDs
    if len(overlap) <= 10:
        print(f"\nOverlapping PMIDs: {sorted(overlap)}")
    else:
        print(f"\nSample overlapping PMIDs (first 10): {sorted(overlap)[:10]}")


if __name__ == "__main__":
    # Command-line entry point: two required CSV paths plus an optional
    # output directory. Any usage or missing-file problem exits with status 1.
    argc = len(sys.argv)
    if argc < 3:
        print("Usage: python align_pmids.py <predictions.csv> <feedback.csv> [output_dir]")
        print("\nExample:")
        print(" python align_pmids.py analysis_results.csv feedback.csv")
        sys.exit(1)

    predictions_path, feedback_path = Path(sys.argv[1]), Path(sys.argv[2])
    output_dir = None if argc <= 3 else sys.argv[3]

    # Validate the inputs in the same order they were given on the command line.
    if not predictions_path.exists():
        print(f"Error: Predictions file not found: {predictions_path}")
        sys.exit(1)
    if not feedback_path.exists():
        print(f"Error: Feedback file not found: {feedback_path}")
        sys.exit(1)

    align_pmids(predictions_path, feedback_path, output_dir)
128 changes: 113 additions & 15 deletions cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,106 @@ def check_image(self):
except subprocess.CalledProcessError:
return False

def load_pmids_from_file(self, file_path: str) -> List[str]:
    """Load PMIDs from a file, supporting txt, csv, xls, and xlsx formats.

    For Excel files, uses Docker to read them since pandas is only available in Docker.

    * ``.xls`` / ``.xlsx``: delegated to ``self._read_excel_via_docker``.
    * ``.csv``: the first column of every non-empty row is taken; a leading
      header cell that is not numeric (e.g. "PMID") is skipped.
    * anything else: treated as text with one PMID per line, where a line
      may also hold several comma-separated PMIDs.

    Args:
        file_path: Path of the file to read.

    Returns:
        The PMIDs as stripped strings, in file order.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist. Raised unwrapped
            so callers can handle a missing file separately.
        Exception: For any other read/parse failure, with the original
            error chained as ``__cause__``.
    """
    file_path_obj = Path(file_path)
    file_ext = file_path_obj.suffix.lower()

    # Report a missing file up front for every format (the Excel branch
    # would otherwise only fail inside Docker with an opaque message).
    if not file_path_obj.exists():
        raise FileNotFoundError(f"File '{file_path}' not found")

    pmids = []

    try:
        if file_ext in ('.xls', '.xlsx'):
            # Use Docker to read Excel files since pandas is only available in Docker
            pmids = self._read_excel_via_docker(file_path)
        elif file_ext == '.csv':
            # Read CSV using standard library (no pandas needed).
            # utf-8-sig drops the BOM that Excel-exported CSVs start with;
            # newline="" is what the csv module requires for correct parsing.
            with open(file_path, "r", encoding="utf-8-sig", newline="") as f:
                for row in csv.reader(f):
                    if row and row[0].strip():
                        pmids.append(row[0].strip())
            # PMIDs are numeric, so a non-numeric first cell is a column header.
            if pmids and not pmids[0].isdigit():
                pmids = pmids[1:]
        else:
            # Default to text file (one PMID per line)
            with open(file_path, "r", encoding="utf-8-sig") as f:
                for line in f:
                    # Handle comma-separated PMIDs in text files; splitting a
                    # line without commas simply yields that one entry.
                    pmids.extend(p.strip() for p in line.split(",") if p.strip())

        return pmids
    except FileNotFoundError:
        # File vanished between the check and the open: keep the specific type.
        raise
    except Exception as e:
        raise Exception(f"Error reading file '{file_path}': {e}") from e

def _read_excel_via_docker(self, file_path: str) -> List[str]:
    """Read Excel file using Docker container where pandas is available.

    The file's directory is bind-mounted at ``/workspace`` in a throwaway
    container of ``self.image_name``; a short Python script inside the
    container reads the first column and prints the PMIDs as a JSON list
    on stdout, which is parsed here.

    Args:
        file_path: Path of the ``.xls`` / ``.xlsx`` file on the host.

    Returns:
        The PMIDs from the first column as stripped strings.

    Raises:
        Exception: If Docker or the image is unavailable, the container
            exits non-zero, or its output is not valid JSON.
    """
    file_path_obj = Path(file_path).resolve()
    file_dir = file_path_obj.parent
    file_name = file_path_obj.name

    # Check if Docker is available
    if not self.check_docker():
        raise Exception("Docker is required to read Excel files but Docker is not available")

    # Check if Docker image exists
    if not self.check_image():
        raise Exception(
            f"Docker image '{self.image_name}' not found. "
            "Please run 'BioAnalyzer build' first."
        )

    # The file path is passed as an argument (sys.argv[1]) rather than being
    # interpolated into the source, so a quote or backslash in the file name
    # can neither break the script nor inject code into it.
    script = """
import json
import sys

import pandas as pd

try:
    # header=None: a sheet without a header row must not lose its first PMID.
    df = pd.read_excel(sys.argv[1], header=None)
    pmids = []
    if df.shape[1] > 0:
        for val in df.iloc[:, 0]:
            if pd.isna(val):
                continue
            # A column containing blanks is read as float (12345.0).
            if isinstance(val, float) and val.is_integer():
                val = int(val)
            text = str(val).strip()
            if text:
                pmids.append(text)
    # PMIDs are numeric, so a non-numeric first cell is a column header.
    if pmids and not pmids[0].isdigit():
        pmids = pmids[1:]
    # Output as JSON for easy parsing
    print(json.dumps(pmids))
except Exception as e:
    print('ERROR: ' + str(e), file=sys.stderr)
    sys.exit(1)
"""

    try:
        # Run the script in Docker
        result = subprocess.run(
            [
                "docker", "run", "--rm",
                "-v", f"{file_dir}:/workspace",
                "-w", "/workspace",
                self.image_name,
                "python", "-c", script, f"/workspace/{file_name}",
            ],
            capture_output=True,
            text=True,
            check=True,
            cwd=file_dir,
        )
    except subprocess.CalledProcessError as e:
        error_msg = e.stderr.strip() if e.stderr else "Unknown error"
        raise Exception(f"Failed to read Excel file via Docker: {error_msg}") from e

    # Parse the JSON output
    try:
        pmids = json.loads(result.stdout.strip())
    except json.JSONDecodeError as e:
        raise Exception(f"Failed to parse Docker output: {result.stdout}") from e
    return pmids

def build_containers(self):
"""Build Docker containers."""
print("🔨 Building BioAnalyzer containers...")
Expand Down Expand Up @@ -2264,9 +2364,8 @@ def main():

if args.file:
try:
with open(args.file, "r") as f:
file_pmids = [line.strip() for line in f if line.strip()]
pmids.extend(file_pmids)
file_pmids = cli.load_pmids_from_file(args.file)
pmids.extend(file_pmids)
except Exception as e:
print(f"❌ Error reading file: {e}")
return
Expand Down Expand Up @@ -2318,18 +2417,17 @@ def main():

if args.file:
try:
with open(args.file, "r") as f:
file_pmids = [line.strip() for line in f if line.strip()]
if not file_pmids:
print(
f"❌ Error: File '{args.file}' is empty or contains no valid PMIDs."
)
print(
" Please add PMIDs to the file (one per line or comma-separated)."
)
return
pmids.extend(file_pmids)
print(f"📁 Loaded {len(file_pmids)} PMID(s) from {args.file}")
file_pmids = cli.load_pmids_from_file(args.file)
if not file_pmids:
print(
f"❌ Error: File '{args.file}' is empty or contains no valid PMIDs."
)
print(
" Please add PMIDs to the file (one per line or comma-separated)."
)
return
pmids.extend(file_pmids)
print(f"📁 Loaded {len(file_pmids)} PMID(s) from {args.file}")
except FileNotFoundError:
print(f"❌ Error: File '{args.file}' not found.")
return
Expand Down
Loading
Loading