Skip to content

Commit a5e48c1

Browse files
committed
Remove excessive asserts-as-type-validation
1 parent 6b1736c commit a5e48c1

File tree

3 files changed

+47
-65
lines changed

3 files changed

+47
-65
lines changed

cmdstanpy/stanfit/__init__.py

Lines changed: 41 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -117,42 +117,42 @@ def from_csv(
117117
method, config_dict['method']
118118
)
119119
)
120+
model: str = config_dict['model'] # type: ignore
120121
try:
121122
if config_dict['method'] == 'sample':
122-
assert isinstance(config_dict['num_samples'], int)
123-
assert isinstance(config_dict['num_warmup'], int)
124-
assert isinstance(config_dict['thin'], int)
125-
assert isinstance(config_dict['model'], str)
126123
save_warmup = config_dict['save_warmup'] == 1
127124
chains = len(csvfiles)
125+
num_samples: int = config_dict['num_samples'] # type: ignore
126+
num_warmup: int = config_dict['num_warmup'] # type: ignore
127+
thin: int = config_dict['thin'] # type: ignore
128128
sampler_args = SamplerArgs(
129-
iter_sampling=config_dict['num_samples'],
130-
iter_warmup=config_dict['num_warmup'],
131-
thin=config_dict['thin'],
129+
iter_sampling=num_samples,
130+
iter_warmup=num_warmup,
131+
thin=thin,
132132
save_warmup=save_warmup,
133133
)
134134
# bugfix 425, check for fixed_params output
135135
try:
136136
check_sampler_csv(
137137
csvfiles[0],
138-
iter_sampling=config_dict['num_samples'],
139-
iter_warmup=config_dict['num_warmup'],
140-
thin=config_dict['thin'],
138+
iter_sampling=num_samples,
139+
iter_warmup=num_warmup,
140+
thin=thin,
141141
save_warmup=save_warmup,
142142
)
143143
except ValueError:
144144
try:
145145
check_sampler_csv(
146146
csvfiles[0],
147-
iter_sampling=config_dict['num_samples'],
148-
iter_warmup=config_dict['num_warmup'],
149-
thin=config_dict['thin'],
147+
iter_sampling=num_samples,
148+
iter_warmup=num_warmup,
149+
thin=thin,
150150
save_warmup=save_warmup,
151151
)
152152
sampler_args = SamplerArgs(
153-
iter_sampling=config_dict['num_samples'],
154-
iter_warmup=config_dict['num_warmup'],
155-
thin=config_dict['thin'],
153+
iter_sampling=num_samples,
154+
iter_warmup=num_warmup,
155+
thin=thin,
156156
save_warmup=save_warmup,
157157
fixed_param=True,
158158
)
@@ -162,8 +162,8 @@ def from_csv(
162162
) from e
163163

164164
cmdstan_args = CmdStanArgs(
165-
model_name=config_dict['model'],
166-
model_exe=config_dict['model'],
165+
model_name=model,
166+
model_exe=model,
167167
chain_ids=[x + 1 for x in range(chains)],
168168
method_args=sampler_args,
169169
)
@@ -180,19 +180,18 @@ def from_csv(
180180
"Cannot find optimization algorithm"
181181
" in file {}.".format(csvfiles[0])
182182
)
183-
assert isinstance(config_dict['algorithm'], str)
184-
assert isinstance(config_dict['model'], str)
183+
algorithm: str = config_dict['algorithm'] # type: ignore
185184
save_iterations = config_dict['save_iterations'] == 1
186185
jacobian = config_dict.get('jacobian', 0) == 1
187186

188187
optimize_args = OptimizeArgs(
189-
algorithm=config_dict['algorithm'],
188+
algorithm=algorithm,
190189
save_iterations=save_iterations,
191190
jacobian=jacobian,
192191
)
193192
cmdstan_args = CmdStanArgs(
194-
model_name=config_dict['model'],
195-
model_exe=config_dict['model'],
193+
model_name=model,
194+
model_exe=model,
196195
chain_ids=None,
197196
method_args=optimize_args,
198197
)
@@ -207,28 +206,19 @@ def from_csv(
207206
"Cannot find variational algorithm"
208207
" in file {}.".format(csvfiles[0])
209208
)
210-
assert isinstance(config_dict['model'], str)
211-
assert isinstance(config_dict['algorithm'], str)
212-
assert isinstance(config_dict['iter'], int)
213-
assert isinstance(config_dict['grad_samples'], int)
214-
assert isinstance(config_dict['elbo_samples'], int)
215-
assert isinstance(config_dict['eta'], (int, float))
216-
assert isinstance(config_dict['tol_rel_obj'], float)
217-
assert isinstance(config_dict['eval_elbo'], int)
218-
assert isinstance(config_dict['output_samples'], int)
219209
variational_args = VariationalArgs(
220-
algorithm=config_dict['algorithm'],
221-
iter=config_dict['iter'],
222-
grad_samples=config_dict['grad_samples'],
223-
elbo_samples=config_dict['elbo_samples'],
224-
eta=config_dict['eta'],
225-
tol_rel_obj=config_dict['tol_rel_obj'],
226-
eval_elbo=config_dict['eval_elbo'],
227-
output_samples=config_dict['output_samples'],
210+
algorithm=config_dict['algorithm'], # type: ignore
211+
iter=config_dict['iter'], # type: ignore
212+
grad_samples=config_dict['grad_samples'], # type: ignore
213+
elbo_samples=config_dict['elbo_samples'], # type: ignore
214+
eta=config_dict['eta'], # type: ignore
215+
tol_rel_obj=config_dict['tol_rel_obj'], # type: ignore
216+
eval_elbo=config_dict['eval_elbo'], # type: ignore
217+
output_samples=config_dict['output_samples'], # type: ignore
228218
)
229219
cmdstan_args = CmdStanArgs(
230-
model_name=config_dict['model'],
231-
model_exe=config_dict['model'],
220+
model_name=model,
221+
model_exe=model,
232222
chain_ids=None,
233223
method_args=variational_args,
234224
)
@@ -238,18 +228,15 @@ def from_csv(
238228
runset._set_retcode(i, 0)
239229
return CmdStanVB(runset)
240230
elif config_dict['method'] == 'laplace':
241-
assert isinstance(config_dict['mode'], str)
242-
assert isinstance(config_dict['draws'], int)
243-
assert isinstance(config_dict['model'], str)
244231
jacobian = config_dict['jacobian'] == 1
245232
laplace_args = LaplaceArgs(
246-
mode=config_dict['mode'],
247-
draws=config_dict['draws'],
233+
mode=config_dict['mode'], # type: ignore
234+
draws=config_dict['draws'], # type: ignore
248235
jacobian=jacobian,
249236
)
250237
cmdstan_args = CmdStanArgs(
251-
model_name=config_dict['model'],
252-
model_exe=config_dict['model'],
238+
model_name=model,
239+
model_exe=model,
253240
chain_ids=None,
254241
method_args=laplace_args,
255242
)
@@ -258,21 +245,18 @@ def from_csv(
258245
for i in range(len(runset._retcodes)):
259246
runset._set_retcode(i, 0)
260247
mode: CmdStanMLE = from_csv(
261-
config_dict['mode'],
248+
config_dict['mode'], # type: ignore
262249
method='optimize',
263250
) # type: ignore
264251
return CmdStanLaplace(runset, mode=mode)
265252
elif config_dict['method'] == 'pathfinder':
266-
assert isinstance(config_dict['num_draws'], int)
267-
assert isinstance(config_dict['num_paths'], int)
268-
assert isinstance(config_dict['model'], str)
269253
pathfinder_args = PathfinderArgs(
270-
num_draws=config_dict['num_draws'],
271-
num_paths=config_dict['num_paths'],
254+
num_draws=config_dict['num_draws'], # type: ignore
255+
num_paths=config_dict['num_paths'], # type: ignore
272256
)
273257
cmdstan_args = CmdStanArgs(
274-
model_name=config_dict['model'],
275-
model_exe=config_dict['model'],
258+
model_name=model,
259+
model_exe=model,
276260
chain_ids=None,
277261
method_args=pathfinder_args,
278262
)

cmdstanpy/stanfit/metadata.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@ def __init__(
2222
"""Initialize object from CSV headers"""
2323
self._cmdstan_config = config
2424

25-
assert isinstance(config['raw_header'], str)
26-
vars = stanio.parse_header(config['raw_header'])
25+
vars = stanio.parse_header(config['raw_header']) # type: ignore
2726

2827
self._method_vars = {
2928
k: v for (k, v) in vars.items() if k.endswith('__')
@@ -58,8 +57,7 @@ def cmdstan_config(self) -> Dict[str, Any]:
5857
@property
5958
def column_names(self) -> Tuple[str, ...]:
6059
col_names = self['column_names']
61-
assert isinstance(col_names, tuple)
62-
return col_names
60+
return col_names # type: ignore
6361

6462
@property
6563
def method_vars(self) -> Dict[str, stanio.Variable]:

cmdstanpy/utils/stancsv.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -323,8 +323,9 @@ def determine_draw_counts(lines: Iterator[bytes]) -> Tuple[int, int]:
323323
if is_fixed_param:
324324
num_warmup = 0
325325
else:
326-
assert adaptation_block_idx is not None
327-
num_warmup = adaptation_block_idx - header_line_idx - 1
326+
num_warmup = (
327+
adaptation_block_idx - header_line_idx - 1 # type: ignore
328+
)
328329
num_sampling = timing_block_idx - sampling_block_idx
329330
return num_warmup, num_sampling
330331

@@ -493,8 +494,7 @@ def parse_sampler_metadata_from_csv(
493494
draws[0]
494495
):
495496
raise_on_invalid_adaptation_block(comments)
496-
max_depth = config["max_depth"]
497-
assert isinstance(max_depth, int)
497+
max_depth: int = config["max_depth"] # type: ignore
498498
max_tree_hits, divs = extract_max_treedepth_and_divergence_counts(
499499
draws, max_depth, num_warmup
500500
)

0 commit comments

Comments
 (0)