Skip to content

Commit 17b1a7f

Browse files
committed
Merge branch 'main' into zeineldeen_att_decoder
2 parents 72ae8ac + 83ff39e commit 17b1a7f

29 files changed

+2388
-44
lines changed

.github/workflows/black.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ jobs:
1010
check-black-formatting:
1111
runs-on: ubuntu-latest
1212
steps:
13-
- uses: actions/checkout@v2
14-
- uses: actions/setup-python@v2
13+
- uses: actions/checkout@v4
14+
- uses: actions/setup-python@v4
1515
with:
1616
python-version: 3.8
1717
cache: 'pip'

.github/workflows/model_tests.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@ jobs:
1010
test-jobs:
1111
runs-on: ubuntu-latest
1212
steps:
13-
- uses: actions/checkout@v2
13+
- uses: actions/checkout@v4
1414
with:
1515
repository: "rwth-i6/i6_models"
1616
path: ""
17-
- uses: actions/setup-python@v2
17+
- uses: actions/setup-python@v4
1818
with:
1919
python-version: 3.8
2020
cache: 'pip'

.github/workflows/publish.yml

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
name: Publish
2+
3+
on:
4+
workflow_run:
5+
workflows: ["model_tests"]
6+
branches: [main]
7+
types:
8+
- completed
9+
10+
jobs:
11+
publish:
12+
if: >-
13+
github.event.workflow_run.conclusion == 'success' &&
14+
github.event.workflow_run.head_branch == 'main' &&
15+
github.event.workflow_run.event == 'push' &&
16+
github.repository == 'rwth-i6/i6_models'
17+
runs-on: ubuntu-latest
18+
19+
steps:
20+
- uses: actions/checkout@v4
21+
22+
- uses: actions/setup-python@v4
23+
with:
24+
python-version: 3.8
25+
26+
- name: Install Python deps
27+
run: |
28+
echo "PATH=$PATH:$HOME/.local/bin" >> $GITHUB_ENV
29+
pip3 install --user --upgrade pip setuptools wheel twine
30+
31+
- run: python3 setup.py sdist
32+
33+
# https://github.com/marketplace/actions/pypi-publish
34+
- name: Publish to PyPI
35+
# https://github.com/pypa/gh-action-pypi-publish/issues/112
36+
uses: pypa/gh-action-pypi-publish@release/v1.4
37+
with:
38+
user: __token__
39+
password: ${{ secrets.pypi_password }}

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ share/python-wheels/
2525
.installed.cfg
2626
*.egg
2727
MANIFEST
28+
/_setup_info_generated.py
2829

2930
# PyInstaller
3031
# Usually these files are written by a python script from a template

MANIFEST.in

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# https://packaging.python.org/guides/using-manifest-in/
2+
3+
include MANIFEST.in
4+
include _setup_info_generated.py
5+
6+
include LICENSE
7+
include CODEOWNERS
8+
9+
include .editorconfig
10+
include .kateconfig
11+
include .gitmodules
12+
include .gitignore
13+
14+
include *.py
15+
include *.rst
16+
include *.md
17+
include *.txt
18+
include *.toml
19+
graft i6_models
20+
21+
graft tests
22+
23+
global-exclude *.py[cod]
24+
global-exclude __pycache__
25+
global-exclude .history*

i6_models/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,11 @@
1+
"""
2+
i6_models root import
3+
"""
14

5+
from .__setup__ import get_version_str as _get_version_str
6+
7+
__long_version__ = _get_version_str(
8+
fallback="1.0.0+unknown", long=True, verbose_error=True
9+
) # `SemVer <https://semver.org/>`__ compatible
10+
__version__ = __long_version__[: __long_version__.index("+")] # distutils.version.StrictVersion compatible
11+
__git_version__ = __long_version__ # just an alias, to keep similar to other projects

i6_models/__setup__.py

Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
"""
2+
Used by setup.py.
3+
"""
4+
5+
from __future__ import annotations
6+
from pprint import pprint
7+
import os
8+
import sys
9+
10+
11+
_my_dir = os.path.dirname(os.path.abspath(__file__))
12+
# Use realpath to resolve any symlinks. We want the real root-dir, to be able to check the Git revision.
13+
_root_dir = os.path.dirname(os.path.realpath(_my_dir))
14+
15+
16+
def debug_print_file(fn):
17+
"""
18+
:param str fn:
19+
"""
20+
print("%s:" % fn)
21+
if not os.path.exists(fn):
22+
print("<does not exist>")
23+
return
24+
if os.path.isdir(fn):
25+
print("<dir:>")
26+
pprint(sorted(os.listdir(fn)))
27+
return
28+
print(open(fn).read())
29+
30+
31+
def parse_pkg_info(fn):
32+
"""
33+
:param str fn:
34+
:return: dict with info written by distutils. e.g. ``res["Version"]`` is the version.
35+
:rtype: dict[str,str]
36+
"""
37+
res = {}
38+
for ln in open(fn).read().splitlines():
39+
if not ln or not ln[:1].strip():
40+
continue
41+
key, value = ln.split(": ", 1)
42+
res[key] = value
43+
return res
44+
45+
46+
def get_version_str(verbose=False, verbose_error=False, fallback=None, long=False):
47+
"""
48+
:param bool verbose: print exactly how we end up with some version
49+
:param bool verbose_error: print only any potential errors
50+
:param str|None fallback:
51+
:param bool long:
52+
False: Always distutils.version.StrictVersion compatible. just like "1.20190202.154527".
53+
True: Will also add the revision string, like "1.20180724.141845+git.7865d01".
54+
The format might change in the future.
55+
We will keep it `SemVer <https://semver.org/>`__ compatible.
56+
I.e. the string before the `"+"` will be the short version.
57+
We always make sure that there is a `"+"` in the string.
58+
:rtype: str
59+
"""
60+
# Earlier we checked PKG-INFO, via parse_pkg_info. Both in the root-dir and in my-dir.
61+
# Now we should always have _setup_info_generated.py, copied by our own setup.
62+
# Do not use PKG-INFO at all anymore (for now), as it would only have the short version.
63+
# Only check _setup_info_generated in the current dir, not in the root-dir,
64+
# because we want to only use it if this was installed via a package.
65+
# Otherwise, we want the current Git version.
66+
if os.path.exists("%s/_setup_info_generated.py" % _my_dir):
67+
# noinspection PyUnresolvedReferences
68+
from . import _setup_info_generated as info
69+
70+
if verbose:
71+
print("Found _setup_info_generated.py, long version %r, version %r." % (info.long_version, info.version))
72+
if long:
73+
assert "+" in info.long_version
74+
return info.long_version
75+
return info.version
76+
77+
info_in_root_filename = "%s/_setup_info_generated.py" % _root_dir
78+
if os.path.exists(info_in_root_filename):
79+
# The root dir might not be in sys.path, so just load directly.
80+
code = compile(open(info_in_root_filename).read(), info_in_root_filename, "exec")
81+
info = {}
82+
eval(code, info)
83+
version = info["version"]
84+
long_version = info["long_version"]
85+
if verbose:
86+
print("Found %r in root, long version %r, version %r." % (info_in_root_filename, long_version, version))
87+
if long:
88+
assert "+" in long_version
89+
return long_version
90+
return version
91+
92+
if os.path.exists("%s/.git" % _root_dir):
93+
try:
94+
version = git_head_version(git_dir=_root_dir, long=long)
95+
if verbose:
96+
print("Version via Git:", version)
97+
if long:
98+
assert "+" in version
99+
return version
100+
except Exception as exc:
101+
if verbose or verbose_error:
102+
print("Exception while getting Git version:", exc)
103+
sys.excepthook(*sys.exc_info())
104+
if not fallback:
105+
raise # no fallback
106+
107+
if fallback:
108+
if verbose:
109+
print("Version via fallback:", fallback)
110+
if long:
111+
assert "+" in fallback
112+
return fallback
113+
raise Exception("Cannot get RETURNN version.")
114+
115+
116+
def git_head_version(git_dir=_root_dir, long=False):
117+
"""
118+
:param str git_dir:
119+
:param bool long: see :func:`get_version_str`
120+
:rtype: str
121+
"""
122+
commit_date = git_commit_date(git_dir=git_dir) # like "20190202.154527"
123+
version = "1.%s" % commit_date # distutils.version.StrictVersion compatible
124+
if long:
125+
# Keep SemVer compatible.
126+
rev = git_commit_rev(git_dir=git_dir)
127+
version += "+git.%s" % rev
128+
if git_is_dirty(git_dir=git_dir):
129+
version += ".dirty"
130+
return version
131+
132+
133+
def git_commit_date(commit="HEAD", git_dir="."):
134+
"""
135+
:param str commit:
136+
:param str git_dir:
137+
:rtype: str
138+
"""
139+
return (
140+
sys_exec_out("git", "show", "-s", "--format=%ci", commit, cwd=git_dir)
141+
.strip()[:-6]
142+
.replace(":", "")
143+
.replace("-", "")
144+
.replace(" ", ".")
145+
)
146+
147+
148+
def git_commit_rev(commit="HEAD", git_dir=".", length=None):
149+
"""
150+
:param str commit:
151+
:param str git_dir:
152+
:param int|None length:
153+
:rtype: str
154+
"""
155+
if commit is None:
156+
commit = "HEAD"
157+
return sys_exec_out("git", "rev-parse", "--short=%i" % length if length else "--short", commit, cwd=git_dir).strip()
158+
159+
160+
def git_is_dirty(git_dir: str = ".") -> bool:
161+
"""
162+
:param git_dir:
163+
:return: whether it is dirty
164+
"""
165+
r = sys_exec_ret_code("git", "diff", "--no-ext-diff", "--quiet", "--exit-code", cwd=git_dir)
166+
if r == 0:
167+
return False
168+
if r == 1:
169+
return True
170+
assert False, "bad return %i" % r
171+
172+
173+
def sys_exec_out(*args, **kwargs) -> str:
174+
"""
175+
:param str args: for subprocess.Popen
176+
:param kwargs: for subprocess.Popen
177+
:return: stdout as str (assumes utf8)
178+
"""
179+
from subprocess import Popen, PIPE, CalledProcessError
180+
181+
kwargs.setdefault("shell", False)
182+
p = Popen(args, stdin=PIPE, stdout=PIPE, **kwargs)
183+
out, _ = p.communicate()
184+
if p.returncode != 0:
185+
raise CalledProcessError(p.returncode, args)
186+
if isinstance(out, bytes):
187+
out = out.decode("utf8")
188+
assert isinstance(out, str)
189+
return out
190+
191+
192+
def sys_exec_ret_code(*args, **kwargs) -> int:
193+
"""
194+
:param str args: for subprocess.call
195+
:param kwargs: for subprocess.call
196+
:return: return code
197+
"""
198+
import subprocess
199+
200+
res = subprocess.call(args, shell=False, **kwargs)
201+
valid = kwargs.get("valid", (0, 1))
202+
if valid is not None:
203+
if res not in valid:
204+
raise subprocess.CalledProcessError(res, args)
205+
return res

i6_models/assemblies/conformer/conformer_v1.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def __init__(self, cfg: ConformerBlockV1Config):
5252
def forward(self, x: torch.Tensor, /, sequence_mask: torch.Tensor) -> torch.Tensor:
5353
"""
5454
:param x: input tensor of shape [B, T, F]
55-
:param sequence_mask: mask tensor where 0 defines positions within the sequence and 1 outside, shape: [B, T]
55+
:param sequence_mask: mask tensor where 1 defines positions within the sequence and 0 outside, shape: [B, T]
5656
:return: torch.Tensor of shape [B, T, F]
5757
"""
5858
x = 0.5 * self.ff1(x) + x # [B, T, F]
@@ -98,7 +98,7 @@ def __init__(self, cfg: ConformerEncoderV1Config):
9898
def forward(self, data_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
9999
"""
100100
:param data_tensor: input tensor of shape [B, T', F]
101-
:param sequence_mask: mask tensor where 0 defines positions within the sequence and 1 outside, shape: [B, T']
101+
:param sequence_mask: mask tensor where 1 defines positions within the sequence and 0 outside, shape: [B, T']
102102
:return: (output, out_seq_mask)
103103
where output is torch.Tensor of shape [B, T, F'],
104104
out_seq_mask is a torch.Tensor of shape [B, T]
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .e_branchformer_v1 import *

0 commit comments

Comments
 (0)