Skip to content

Commit 4741cc8

Browse files
authored
explicitly set *_NUM_THREADS in Returnn run (#424)
1 parent c4b50a6 commit 4741cc8

File tree

4 files changed

+16
-4
lines changed

4 files changed

+16
-4
lines changed

returnn/extract_prior.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,10 @@ def create_files(self):
163163

164164
def run(self):
165165
cmd = self._get_run_cmd()
166-
sp.check_call(cmd)
166+
env = os.environ.copy()
167+
env["OMP_NUM_THREADS"] = str(self.rqmt["cpu"])
168+
env["MKL_NUM_THREADS"] = str(self.rqmt["cpu"])
169+
sp.check_call(cmd, env=env)
167170

168171
merged_scores = np.loadtxt(self.out_prior_txt_file.get_path(), delimiter=" ")
169172

returnn/forward.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,10 @@ def run(self):
124124
]
125125

126126
try:
127-
sp.check_call(call, cwd=d)
127+
env = os.environ.copy()
128+
env["OMP_NUM_THREADS"] = str(self.rqmt["cpu"])
129+
env["MKL_NUM_THREADS"] = str(self.rqmt["cpu"])
130+
sp.check_call(call, cwd=d, env=env)
128131
except Exception as e:
129132
print("Run crashed - copy temporary work folder as 'crash_dir'")
130133
shutil.copytree(d, "crash_dir")

returnn/search.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,10 @@ def run(self):
127127
os.path.join(self.returnn_root.get_path(), "rnn.py"),
128128
self.out_returnn_config_file.get_path(),
129129
]
130-
sp.check_call(call)
130+
env = os.environ.copy()
131+
env["OMP_NUM_THREADS"] = str(self.rqmt["cpu"])
132+
env["MKL_NUM_THREADS"] = str(self.rqmt["cpu"])
133+
sp.check_call(call, env=env)
131134

132135
@classmethod
133136
def create_returnn_config(

returnn/training.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,10 @@ def run(self):
353353
print("Cannot read:", exc)
354354
sys.stdout.flush()
355355

356-
sp.check_call(self._get_run_cmd())
356+
env = os.environ.copy()
357+
env["OMP_NUM_THREADS"] = str(self.rqmt["cpu"])
358+
env["MKL_NUM_THREADS"] = str(self.rqmt["cpu"])
359+
sp.check_call(self._get_run_cmd(), env=env)
357360

358361
lrf = self.returnn_config.get("learning_rate_file", "learning_rates")
359362
self._relink(lrf, self.out_learning_rates.get_path())

0 commit comments

Comments
 (0)