-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathconfig.py
More file actions
127 lines (98 loc) · 3.39 KB
/
config.py
File metadata and controls
127 lines (98 loc) · 3.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""Global configuration for the phase-2 jitter-robust unfolded SBL project."""
from __future__ import annotations
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any
import numpy as np
@dataclass
class ExperimentConfig:
    """Experiment settings for UAV-assisted IoT sensing with coprime array.

    Every setting is a plain dataclass field with a default, so
    ``ExperimentConfig()`` yields the baseline setup and individual values
    can be overridden by keyword or by attribute assignment (the dataclass
    is not frozen). The derived quantities ``num_sensors``,
    ``angle_grid_deg``, ``num_grid`` and ``checkpoint_path`` are properties
    recomputed on every access, so they always reflect the current fields.
    """

    # Strict paper parameters
    # Sensor positions; presumably in units of half a wavelength (per the
    # field name) -- confirm against the array-manifold code.
    array_positions_half_lambda: tuple[float, ...] = (0, 3, 5, 6, 9, 10, 12, 15)
    num_targets: int = 10
    num_snapshots: int = 10
    snr_min_db: float = 0.0
    snr_max_db: float = 20.0
    jitter_std_deg: float = 10.0
    angle_min_deg: int = -60
    angle_max_deg: int = 60
    angle_step_deg: int = 1

    # Phase-2 refinement requirements
    num_layers: int = 20
    lambda_l1: float = 1e-3

    # Dataset and optimization
    train_samples: int = 20000
    val_samples: int = 3000
    test_samples: int = 4000
    batch_size: int = 64
    num_workers: int = 0
    num_epochs: int = 100
    learning_rate: float = 5e-4
    weight_decay: float = 1e-5
    lr_step_size: int = 25
    lr_gamma: float = 0.7
    grad_clip_norm: float = 1.0
    kl_warmup_epochs: int = 15
    ema_decay: float = 0.999
    constraint_start_epoch: int = 8
    constraint_ramp_epochs: int = 24
    early_stop_patience: int = 30

    # Model hyperparameters
    ista_init_step: float = 0.18
    ista_init_thresh: float = 0.02
    interpolator_channels: int = 64
    interpolator_dropout: float = 0.05
    completion_steps: int = 4
    completion_tau: float = 0.05
    jitter_prior_std_deg: float = 12.0
    offgrid_max_deg: float = 0.5
    offgrid_sigma_deg: float = 0.35
    sparsify_bias: float = 0.12
    local_norm_kernel: int = 5

    # Composite loss weights
    lambda_rec: float = 0.3
    lambda_phys: float = 0.2
    lambda_kl: float = 5e-3
    lambda_rank: float = 0.15
    ranking_margin: float = 0.2

    # Reproducibility and paths
    seed: int = 2026
    checkpoint_dir: str = "checkpoints"
    checkpoint_name: str = "jitter_unfolded_sbl_phase2.pt"
    figure_dir: str = "figures"

    @property
    def num_sensors(self) -> int:
        """Number of physical sensors (length of the position tuple)."""
        return len(self.array_positions_half_lambda)

    @property
    def angle_grid_deg(self) -> np.ndarray:
        """Discrete angle grid in degrees as a 1-D ``float32`` array.

        The grid starts at ``angle_min_deg`` and advances in increments of
        ``angle_step_deg``. It includes ``angle_max_deg`` when the step
        divides the span and otherwise stops at the last point below it, so
        no grid point ever exceeds ``angle_max_deg``. An empty array is
        returned when ``angle_max_deg < angle_min_deg``.

        Raises:
            ValueError: If ``angle_step_deg`` is not strictly positive.
        """
        step = self.angle_step_deg
        if step <= 0:
            raise ValueError(f"angle_step_deg must be positive, got {step!r}")
        span = self.angle_max_deg - self.angle_min_deg
        # Derive the point count explicitly instead of relying on the
        # exclusive stop of np.arange: ``arange(min, max + step, step)``
        # overshoots ``max`` whenever the step does not divide the span, and
        # its length is unreliable for non-integer steps. The small
        # tolerance keeps the endpoint when a float step divides the span
        # only up to rounding error.
        count = max(int(np.floor(span / step + 1e-9)) + 1, 0)
        grid = self.angle_min_deg + step * np.arange(count)
        return grid.astype(np.float32)

    @property
    def num_grid(self) -> int:
        """Grid size (number of points in ``angle_grid_deg``)."""
        return int(self.angle_grid_deg.size)

    @property
    def checkpoint_path(self) -> Path:
        """Checkpoint file path: ``checkpoint_dir`` joined with ``checkpoint_name``."""
        return Path(self.checkpoint_dir) / self.checkpoint_name

    def to_dict(self) -> dict[str, Any]:
        """Convert the dataclass fields to a dictionary (properties excluded)."""
        return asdict(self)
def update_config_from_args(cfg: ExperimentConfig, args: Any) -> ExperimentConfig:
    """Overlay argparse-style overrides onto ``cfg`` and return it.

    For every key reported by ``cfg.to_dict()``, the attribute of the same
    name on ``args`` is copied onto ``cfg`` when it exists and is not
    ``None``. Attributes of ``args`` with no matching config key are
    ignored. ``cfg`` is modified in place; the same object is returned.
    """
    for field_name in cfg.to_dict():
        # A missing attribute and an explicit None both mean "no override".
        override = getattr(args, field_name, None)
        if override is None:
            continue
        setattr(cfg, field_name, override)
    return cfg
def ensure_output_dirs(cfg: ExperimentConfig) -> None:
    """Make sure the checkpoint and figure directories exist.

    Creates ``cfg.checkpoint_dir`` and then ``cfg.figure_dir``, including
    any missing parent directories; directories that already exist are
    left as they are.
    """
    for dir_attr in ("checkpoint_dir", "figure_dir"):
        target = Path(getattr(cfg, dir_attr))
        target.mkdir(parents=True, exist_ok=True)