Skip to content

Commit 1b88aa9

Browse files
authored
Merge pull request #65 from waldronlab/feature/add-type-hints
Add comprehensive type hints to core functions
2 parents ea46e42 + ca744d4 commit 1b88aa9

File tree

10 files changed: 87 additions and 44 deletions.

10 files changed: 87 additions and 44 deletions.

app/api/app.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
import logging
44
import os
55
import traceback
6+
from typing import Dict, Any
67
from fastapi import FastAPI, Request
78
from fastapi.middleware.cors import CORSMiddleware
89
from fastapi.responses import JSONResponse
910
from fastapi.exceptions import RequestValidationError
11+
from app.api.models.api_models import HealthResponse
1012

1113
from app.api.routers import (
1214
bugsigdb_analysis,
@@ -78,7 +80,7 @@
7880

7981

8082
@app.get("/")
81-
async def root():
83+
async def root() -> Dict[str, str]:
8284
"""API root endpoint."""
8385
return {
8486
"message": "BioAnalyzer Backend API",
@@ -91,15 +93,17 @@ async def root():
9193

9294

9395
@app.get("/health")
94-
async def health_check():
96+
async def health_check() -> HealthResponse:
9597
"""Health check endpoint."""
9698
from app.api.routers.system import health_check as system_health_check
9799

98100
return await system_health_check()
99101

100102

101103
@app.exception_handler(RequestValidationError)
102-
async def validation_exception_handler(request: Request, exc: RequestValidationError):
104+
async def validation_exception_handler(
105+
request: Request, exc: RequestValidationError
106+
) -> JSONResponse:
103107
"""Handle request validation errors."""
104108
logger.warning(f"Validation error: {exc.errors()}")
105109
return JSONResponse(
@@ -113,7 +117,9 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
113117

114118

115119
@app.exception_handler(Exception)
116-
async def global_exception_handler(request: Request, exc: Exception):
120+
async def global_exception_handler(
121+
request: Request, exc: Exception
122+
) -> JSONResponse:
117123
"""Handle unexpected exceptions with credential masking."""
118124
from app.utils.credential_masking import mask_exception_message, mask_string
119125

app/api/routers/bugsigdb_analysis.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from fastapi import APIRouter, HTTPException
44
import logging
5+
from typing import Dict, Any
56
from app.utils.credential_masking import mask_exception_message
67
from app.api.utils.constants import ESSENTIAL_FIELDS_INFO, STATUS_VALUES
78
from app.services.bugsigdb_analyzer import analyze_paper_simple
@@ -10,7 +11,7 @@
1011
router = APIRouter(prefix="/api/v1", tags=["BugSigDB Analysis"])
1112

1213

13-
async def _run_analysis(pmid: str):
14+
async def _run_analysis(pmid: str) -> Dict[str, Any]:
1415
"""Run analysis with validated PMID."""
1516
from app.api.utils.api_utils import validate_pmid
1617

@@ -22,15 +23,15 @@ async def _run_analysis(pmid: str):
2223
return result
2324

2425

25-
def _handle_analysis_error(pmid: str, e: Exception):
26+
def _handle_analysis_error(pmid: str, e: Exception) -> None:
2627
"""Handle analysis errors consistently."""
2728
safe_error = mask_exception_message(e)
2829
logger.error(f"Error in analysis for PMID {pmid}: {safe_error}")
2930
raise HTTPException(status_code=500, detail=f"Analysis error: {str(e)}")
3031

3132

3233
@router.get("/analyze/{pmid}")
33-
async def analyze_paper(pmid: str):
34+
async def analyze_paper(pmid: str) -> Dict[str, Any]:
3435
"""Analyze paper for BugSigDB fields."""
3536
try:
3637
return await _run_analysis(pmid)
@@ -41,7 +42,7 @@ async def analyze_paper(pmid: str):
4142

4243

4344
@router.post("/analyze/{pmid}")
44-
async def analyze_paper_post(pmid: str):
45+
async def analyze_paper_post(pmid: str) -> Dict[str, Any]:
4546
"""Analyze paper for BugSigDB fields (POST method)."""
4647
try:
4748
return await _run_analysis(pmid)

app/api/routers/system.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,10 @@
3434
logger = logging.getLogger(__name__)
3535
router = APIRouter(prefix="/api/v1", tags=["System"])
3636

37-
_unified_qa = None
38-
_pubmed_retriever = None
37+
_unified_qa: Optional[UnifiedQA] = None
38+
_pubmed_retriever: Optional[PubMedRetriever] = None
3939

40-
def get_unified_qa():
40+
def get_unified_qa() -> Optional[UnifiedQA]:
4141
"""Get or initialize UnifiedQA instance."""
4242
global _unified_qa
4343
if _unified_qa is None:
@@ -55,7 +55,7 @@ def get_unified_qa():
5555
unified_qa = None # Will be initialized on first use via get_unified_qa()
5656

5757

58-
def get_pubmed_retriever():
58+
def get_pubmed_retriever() -> Optional[PubMedRetriever]:
5959
"""Get or initialize PubMedRetriever instance."""
6060
global _pubmed_retriever
6161
if _pubmed_retriever is None:
@@ -69,15 +69,15 @@ def get_pubmed_retriever():
6969

7070

7171
@router.get("/")
72-
async def root():
72+
async def root() -> Any:
7373
"""Redirect to the frontend application."""
7474
from fastapi.responses import RedirectResponse
7575

7676
return RedirectResponse(url="/static/index.html")
7777

7878

7979
@router.get("/health")
80-
async def health_check():
80+
async def health_check() -> HealthResponse:
8181
"""Health check endpoint to verify service is running."""
8282
try:
8383
current_time = get_current_timestamp()
@@ -92,7 +92,7 @@ async def health_check():
9292

9393

9494
@router.get("/config")
95-
async def get_config():
95+
async def get_config() -> ConfigResponse:
9696
"""
9797
**Get configuration settings for the frontend.**
9898

app/api/utils/api_utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import re
44
import logging
5-
from typing import Dict, List, Optional
5+
from typing import Dict, List, Optional, Any
66
from datetime import datetime
77
import pytz
88

@@ -24,7 +24,7 @@ def extract_taxa(text: str) -> List[str]:
2424
return list(set(taxa))
2525

2626

27-
def create_default_field_structure(field_name: str) -> Dict:
27+
def create_default_field_structure(field_name: str) -> Dict[str, Any]:
2828
"""Create a default structure for a missing field."""
2929
field_structures = {
3030
"host_species": {
@@ -83,7 +83,7 @@ def create_default_field_structure(field_name: str) -> Dict:
8383
)
8484

8585

86-
def validate_field_structure(field_data: Dict, field_name: str) -> bool:
86+
def validate_field_structure(field_data: Dict[str, Any], field_name: str) -> bool:
8787
"""Validate that a field has the correct structure."""
8888
required_keys = {
8989
"status",
@@ -99,7 +99,7 @@ def validate_field_structure(field_data: Dict, field_name: str) -> bool:
9999
return required_keys.issubset(field_data.keys())
100100

101101

102-
def create_comprehensive_fallback_analysis() -> Dict:
102+
def create_comprehensive_fallback_analysis() -> Dict[str, Any]:
103103
"""Create a comprehensive fallback analysis when parsing completely fails."""
104104
return {
105105
"host_species": create_default_field_structure("host_species"),
@@ -131,7 +131,7 @@ def generate_curation_summary(parsed_analysis: Dict, missing_fields: List[str])
131131

132132
def get_paper_metadata_from_csv(
133133
pmid: str, csv_path: str = "data/full_dump.csv"
134-
) -> Optional[Dict]:
134+
) -> Optional[Dict[str, str]]:
135135
"""Get paper metadata from CSV file."""
136136
try:
137137
import csv

app/services/advanced_rag.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def get_rerank_metrics(self) -> Dict[str, Any]:
211211
"""Get performance metrics from the re-ranker."""
212212
return self.reranker.get_metrics()
213213

214-
def get_summary_stats(self, summaries: List[ChunkSummary]) -> Dict:
214+
def get_summary_stats(self, summaries: List[ChunkSummary]) -> Dict[str, Any]:
215215
"""Get statistics about summaries."""
216216
if not summaries:
217217
return {

app/services/bugsigdb_analyzer.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"""
66

77
import logging
8-
from typing import Dict, Optional, List
8+
from typing import Dict, Optional, List, Any
99
import asyncio
1010
import json
1111

@@ -23,12 +23,12 @@
2323

2424
logger = logging.getLogger(__name__)
2525

26-
_unified_qa = None
27-
_pubmed_retriever = None
28-
_cache_manager = None
26+
_unified_qa: Optional[UnifiedQA] = None
27+
_pubmed_retriever: Optional[PubMedRetriever] = None
28+
_cache_manager: Optional[CacheManager] = None
2929

3030

31-
def get_unified_qa():
31+
def get_unified_qa() -> Optional[UnifiedQA]:
3232
"""Get or initialize UnifiedQA instance."""
3333
global _unified_qa
3434
if _unified_qa is None:
@@ -49,7 +49,7 @@ def get_unified_qa():
4949
return _unified_qa
5050

5151

52-
def get_pubmed_retriever():
52+
def get_pubmed_retriever() -> Optional[PubMedRetriever]:
5353
"""Get or initialize PubMedRetriever instance."""
5454
global _pubmed_retriever
5555
if _pubmed_retriever is None:
@@ -61,15 +61,15 @@ def get_pubmed_retriever():
6161
return _pubmed_retriever
6262

6363

64-
def get_cache_manager():
64+
def get_cache_manager() -> CacheManager:
6565
"""Get or initialize CacheManager instance."""
6666
global _cache_manager
6767
if _cache_manager is None:
6868
_cache_manager = CacheManager()
6969
return _cache_manager
7070

7171

72-
ESSENTIAL_FIELDS = {
72+
ESSENTIAL_FIELDS: Dict[str, str] = {
7373
"host_species": "What host species is being studied in this research?",
7474
"body_site": "What body site or anatomical location was sampled for microbiome analysis?",
7575
"condition": "What disease, treatment, or condition is being studied?",
@@ -79,7 +79,7 @@ def get_cache_manager():
7979
}
8080

8181

82-
async def analyze_paper_simple(pmid: str) -> Optional[Dict]:
82+
async def analyze_paper_simple(pmid: str) -> Optional[Dict[str, Any]]:
8383
"""Extract BugSigDB fields from a paper.
8484
8585
Uses v1 API flow: direct LLM queries per field. Fast but less accurate than RAG.

app/utils/config.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import os
22
from pathlib import Path
33
import logging
4+
from typing import List, Optional
45
from .credential_masking import mask_exception_message, mask_string
56

67
try:
78
from dotenv import load_dotenv # type: ignore
89
except ImportError:
910

10-
def load_dotenv(*args, **kwargs): # type: ignore[no-redef]
11+
def load_dotenv(*args: object, **kwargs: object) -> None: # type: ignore[no-redef]
1112
"""Fallback when python-dotenv is not installed."""
1213
logger = logging.getLogger(__name__)
1314
logger.warning(
@@ -45,7 +46,7 @@ def load_dotenv(*args, **kwargs): # type: ignore[no-redef]
4546
LLM_MODEL = os.getenv("LLM_MODEL", "") or None
4647

4748
DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "gemini")
48-
AVAILABLE_MODELS = []
49+
AVAILABLE_MODELS: List[str] = []
4950

5051
if GEMINI_API_KEY:
5152
AVAILABLE_MODELS.append("gemini")
@@ -57,7 +58,7 @@ def load_dotenv(*args, **kwargs): # type: ignore[no-redef]
5758
AVAILABLE_MODELS.append("ollama")
5859

5960

60-
def validate_gemini_key():
61+
def validate_gemini_key() -> bool:
6162
"""Validate Gemini API key by configuring the client."""
6263
if not GEMINI_API_KEY:
6364
return False
@@ -75,9 +76,9 @@ def validate_gemini_key():
7576
return False
7677

7778

78-
def validate_env_vars():
79+
def validate_env_vars() -> bool:
7980
"""Validate that required environment variables are set."""
80-
missing_vars = []
81+
missing_vars: List[str] = []
8182

8283
if not NCBI_API_KEY:
8384
missing_vars.append("NCBI_API_KEY")
@@ -103,9 +104,9 @@ def validate_env_vars():
103104
validate_env_vars()
104105

105106

106-
def check_required_vars():
107+
def check_required_vars() -> bool:
107108
"""Check if all required environment variables are set."""
108-
missing_vars = []
109+
missing_vars: List[str] = []
109110

110111
if not NCBI_API_KEY:
111112
missing_vars.append("NCBI_API_KEY")
@@ -212,7 +213,7 @@ def check_required_vars():
212213
MAX_LOG_FILES = 5 # Keep 5 rotated log files
213214

214215

215-
def setup_logging():
216+
def setup_logging() -> logging.Logger:
216217
"""Configure logging with file rotation.
217218
218219
Falls back to console-only logging if file handlers can't be created.
@@ -237,7 +238,7 @@ def setup_logging():
237238
root_logger.removeHandler(handler)
238239

239240
# Try to create file handlers, but handle permission errors gracefully
240-
file_handlers_created = []
241+
file_handlers_created: List[str] = []
241242

242243
try:
243244
# Main application log handler with rotation

app/utils/credential_masking.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"""
77

88
import re
9-
from typing import Optional, Iterable
9+
from typing import Optional, Iterable, Any, Dict
1010

1111
# Environment / config keys that should always be masked
1212
API_KEY_ENV_VARS: set[str] = {
@@ -99,7 +99,7 @@ def mask_string(text: str, show_last: int = 4) -> str:
9999
return masked
100100

101101

102-
def mask_dict(data: dict, keys_to_mask: Optional[Iterable[str]] = None) -> dict:
102+
def mask_dict(data: Dict[str, Any], keys_to_mask: Optional[Iterable[str]] = None) -> Dict[str, Any]:
103103
"""
104104
Recursively mask sensitive values in dictionaries.
105105
"""
@@ -122,7 +122,7 @@ def mask_dict(data: dict, keys_to_mask: Optional[Iterable[str]] = None) -> dict:
122122
return masked
123123

124124

125-
def safe_log_message(message: str, *args, **kwargs) -> str:
125+
def safe_log_message(message: str, *args: Any, **kwargs: Any) -> str:
126126
"""
127127
Format a log message and safely mask any credentials.
128128
"""

cli.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def _get_env_file_values(self) -> Dict[str, str]:
9595
except Exception:
9696
return {}
9797

98-
def _collect_env_flags(self) -> list:
98+
def _collect_env_flags(self) -> List[str]:
9999
"""Collect docker -e flags for known env vars if present in the host environment."""
100100
flags = []
101101
env_file = self._get_env_file_path()
@@ -113,7 +113,7 @@ def _build_api_url(self, path: str) -> str:
113113
suffix = path.lstrip("/")
114114
return f"{base}/{suffix}" if suffix else base
115115

116-
def _validate_environment(self):
116+
def _validate_environment(self) -> None:
117117
"""Warn about missing critical env vars before starting containers."""
118118
env_file = self._get_env_file_path()
119119
env_file_values = self._get_env_file_values()

0 commit comments

Comments
 (0)