Skip to content

Commit d1c315e

Browse files
JackTemakialbertz
andauthored
Add PickOptimalParametersJob (#491)
* Add PickOptimalParametersJob Job that takes pickleable parameters and a tk.Variable value, and just returns the paramaters corresponding to the highest or lowest value. Co-authored-by: Albert Zeyer <[email protected]>
1 parent dc66f34 commit d1c315e

File tree

1 file changed

+60
-0
lines changed

1 file changed

+60
-0
lines changed

tools/parameter_tuning.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from sisyphus import Job, Task, tk
2+
from typing import Any, Literal, Sequence, Union
3+
4+
import numpy as np
5+
6+
from i6_core.util import instanciate_delayed
7+
8+
9+
class GetOptimalParametersAsVariableJob(Job):
10+
"""
11+
Pick a set of optimal parameters based on their assigned (dynamic) score value.
12+
Each optimal parameter is outputted individually to be accessible in the Sisyphus manager.
13+
14+
Can be used to e.g. pick best lm-scale and prior scale to a corresponding ScliteJob.out_wer.
15+
"""
16+
17+
def __init__(
18+
self,
19+
*,
20+
parameters: Sequence[Sequence[Any]],
21+
values: Sequence[tk.Variable],
22+
mode: Union[Literal["maximize"], Literal["minimize"]],
23+
):
24+
"""
25+
:param parameters: parameters[best_idx] will be written to self.out_optimal_parameters
26+
as Sisyphus output variables.
27+
parameters[best_idx] (and thus self.out_optimal_parameters) is a sequence of fixed length,
28+
to allow to index into it.
29+
Thus, len(parameters[i]) must be the same for all i.
30+
:param values: best_idx = argmax(values) or argmin(values).
31+
Must have len(values) == len(parameters).
32+
Some calculations might be done using DelayedOps math beforehand.
33+
:param mode: "minimize" or "maximize"
34+
"""
35+
assert len(parameters) == len(values)
36+
for param in parameters[1:]:
37+
assert len(param) == len(parameters[0]), "all entries should have the same number of parameters"
38+
assert mode in ["minimize", "maximize"]
39+
self.parameters = parameters
40+
self.values = values
41+
self.mode = mode
42+
self.num_parameters = len(parameters[0])
43+
44+
self.out_optimal_parameters = [self.output_var("param_%i" % i, pickle=True) for i in range(self.num_parameters)]
45+
46+
def tasks(self):
47+
yield Task("run", mini_task=True)
48+
49+
def run(self):
50+
values = instanciate_delayed(self.values)
51+
52+
if self.mode == "minimize":
53+
index = np.argmin(values)
54+
else:
55+
index = np.argmax(values)
56+
57+
best_parameters = self.parameters[index]
58+
59+
for i, param in enumerate(best_parameters):
60+
self.out_optimal_parameters[i].set(param)

0 commit comments

Comments
 (0)