-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Expand file tree
/
Copy pathtest_generate_max_tokens.py
More file actions
279 lines (227 loc) · 11.8 KB
/
test_generate_max_tokens.py
File metadata and controls
279 lines (227 loc) · 11.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
"""Regression tests for generate() max_tokens=None TypeError.
Verifies the fix for https://github.com/nari-labs/dia/issues/281:
Error during audio generation or saving:
'<' not supported between instances of 'int' and 'NoneType'
Root cause: when max_tokens=None is passed (e.g. from cli.py's default
--max-tokens=None), the generate() loop compares `dec_step < max_tokens`
which raises TypeError because int < None is not supported in Python 3.
The fix resolves None to config.decoder_config.max_position_embeddings
at the top of generate(), so the comparison is always int < int.
These tests parse source code directly to avoid requiring heavy deps
(torch, torchaudio, descript-audio-codec) in the test environment.
"""
import ast
import os
import sys
import textwrap
import pytest
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
def _read_source(relpath: str) -> str:
with open(os.path.join(_ROOT, relpath)) as f:
return f.read()
def _parse_function(source: str, funcname: str) -> ast.FunctionDef:
"""Parse a module and return the AST node for *funcname* (top-level or method)."""
tree = ast.parse(source)
for node in ast.walk(tree):
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == funcname:
return node
raise ValueError(f"{funcname} not found")
# ---------------------------------------------------------------------------
# 1. Signature: max_tokens accepts None
# ---------------------------------------------------------------------------
class TestGenerateSignature:
    """Ensure generate() signature allows None for max_tokens."""

    @staticmethod
    def _find_param(func, name):
        """Return ``(arg_node, default_node)`` for parameter *name* of *func*.

        Positional-only, positional-or-keyword and keyword-only parameters are
        all searched. ``func.args.defaults`` is right-aligned with the
        positional parameters, while ``func.args.kw_defaults`` is index-aligned
        with ``func.args.kwonlyargs``. ``default_node`` is ``None`` when the
        parameter has no default.

        Raises:
            AssertionError: if *func* has no parameter called *name*.
        """
        positional = list(getattr(func.args, "posonlyargs", [])) + list(func.args.args)
        first_with_default = len(positional) - len(func.args.defaults)
        for index, arg in enumerate(positional):
            if arg.arg == name:
                if index >= first_with_default:
                    return arg, func.args.defaults[index - first_with_default]
                return arg, None
        for arg, default in zip(func.args.kwonlyargs, func.args.kw_defaults):
            if arg.arg == name:
                return arg, default
        raise AssertionError(f"{func.name}() has no parameter named {name!r}")

    def test_max_tokens_default_is_none(self):
        """generate() default for max_tokens should be None (resolved at runtime)."""
        source = _read_source("dia/model.py")
        func = _parse_function(source, "generate")
        _, default_node = self._find_param(func, "max_tokens")
        # A missing default must fail: callers are expected to be able to omit
        # max_tokens entirely and let generate() resolve it.
        assert default_node is not None, "max_tokens should have a default of None"
        assert isinstance(default_node, ast.Constant) and default_node.value is None, (
            f"max_tokens default should be None, got: {ast.dump(default_node)}"
        )

    def test_max_tokens_annotation_allows_none(self):
        """max_tokens type hint must include None."""
        source = _read_source("dia/model.py")
        func = _parse_function(source, "generate")
        arg, _ = self._find_param(func, "max_tokens")
        assert arg.annotation is not None, "max_tokens should be annotated"
        # Inspect generate()'s own annotation rather than grepping the module,
        # where another function's `max_tokens: int | None` could match.
        annotation = (ast.get_source_segment(source, arg.annotation) or "").strip("'\"")
        normalized = annotation.replace(" ", "")
        assert normalized in {
            "int|None",
            "None|int",
            "Optional[int]",
            "typing.Optional[int]",
            "Union[int,None]",
            "typing.Union[int,None]",
        }, f"max_tokens annotation should allow None, got: {annotation!r}"
# ---------------------------------------------------------------------------
# 2. None resolution: uses config fallback
# ---------------------------------------------------------------------------
class TestNoneResolution:
    """Verify max_tokens=None is resolved before the generation loop.

    These checks walk generate()'s AST rather than grepping its text, so a
    docstring or comment that merely mentions ``max_tokens is None`` or
    ``max_position_embeddings`` cannot satisfy them.
    """

    def test_none_check_exists(self):
        """generate() must contain a check for max_tokens is None."""
        func = self._get_generate_node(_read_source("dia/model.py"))
        assert self._none_check_lines(func), "generate() must check for max_tokens is None"

    def test_resolves_to_config_value(self):
        """generate() must resolve None to config.decoder_config.max_position_embeddings."""
        func = self._get_generate_node(_read_source("dia/model.py"))
        assert any(
            isinstance(node, ast.Attribute) and node.attr == "max_position_embeddings"
            for node in ast.walk(func)
        ), "generate() must resolve None to config.decoder_config.max_position_embeddings"

    def test_none_resolution_before_loop(self):
        """The None → int resolution must happen before the while loop comparison."""
        func = self._get_generate_node(_read_source("dia/model.py"))
        none_check_lines = self._none_check_lines(func)
        while_loop_lines = [
            node.lineno
            for node in ast.walk(func)
            if isinstance(node, ast.While) and self._mentions_max_tokens(node.test)
        ]
        assert none_check_lines, "Could not find 'max_tokens is None' check"
        assert while_loop_lines, "Could not find 'while ... max_tokens' loop"
        # Line numbers are absolute positions in dia/model.py.
        none_check_line = min(none_check_lines)
        while_loop_line = min(while_loop_lines)
        assert none_check_line < while_loop_line, (
            f"None check (line {none_check_line}) must come before while loop (line {while_loop_line})"
        )

    def test_explicit_int_not_overridden(self):
        """When max_tokens is an explicit int, it should be preserved."""
        func = self._get_generate_node(_read_source("dia/model.py"))
        # The guard should be `if max_tokens is None:` — only triggers for None
        assert any(
            isinstance(node, ast.If) and self._is_none_check(node.test)
            for node in ast.walk(func)
        ), "Resolution should use 'if max_tokens is None:' (identity check)"

    @staticmethod
    def _get_generate_node(full_source: str) -> ast.FunctionDef:
        """Return the first ``def generate`` (in ``ast.walk`` order) in *full_source*.

        Raises:
            ValueError: if no such function exists.
        """
        for node in ast.walk(ast.parse(full_source)):
            if isinstance(node, ast.FunctionDef) and node.name == "generate":
                return node
        raise ValueError("generate() not found")

    @staticmethod
    def _get_generate_source(full_source: str) -> str:
        """Extract the source of Dia.generate from the module source."""
        node = TestNoneResolution._get_generate_node(full_source)
        return ast.get_source_segment(full_source, node) or ""

    @staticmethod
    def _is_none_check(node) -> bool:
        """Return True if *node* is exactly the expression ``max_tokens is None``."""
        return (
            isinstance(node, ast.Compare)
            and isinstance(node.left, ast.Name)
            and node.left.id == "max_tokens"
            and len(node.ops) == 1
            and isinstance(node.ops[0], ast.Is)
            and isinstance(node.comparators[0], ast.Constant)
            and node.comparators[0].value is None
        )

    @classmethod
    def _none_check_lines(cls, func) -> list:
        """Return the line numbers of every ``max_tokens is None`` test in *func*."""
        return [node.lineno for node in ast.walk(func) if cls._is_none_check(node)]

    @staticmethod
    def _mentions_max_tokens(node) -> bool:
        """Return True if the expression *node* references the name ``max_tokens``."""
        return any(isinstance(n, ast.Name) and n.id == "max_tokens" for n in ast.walk(node))
# ---------------------------------------------------------------------------
# 3. CLI integration
# ---------------------------------------------------------------------------
class TestCLI:
    """Verify cli.py correctly handles max_tokens=None.

    All checks inspect cli.py's AST, so matches inside comments or string
    literals do not count.
    """

    @staticmethod
    def _parse_cli():
        """Parse cli.py and return its ``ast.Module``."""
        return ast.parse(_read_source("cli.py"))

    def test_default_max_tokens_is_none(self):
        """CLI --max-tokens should default to None."""
        tree = self._parse_cli()
        # Target the --max-tokens add_argument() call specifically; a bare
        # "default=None" grep would match any other argument.
        calls = [
            node
            for node in ast.walk(tree)
            if isinstance(node, ast.Call)
            and isinstance(node.func, ast.Attribute)
            and node.func.attr == "add_argument"
            and any(isinstance(a, ast.Constant) and a.value == "--max-tokens" for a in node.args)
        ]
        assert calls, "cli.py should define a --max-tokens argument"
        for call in calls:
            default = next((kw.value for kw in call.keywords if kw.arg == "default"), None)
            # Omitting default= is fine: argparse then uses None.
            assert default is None or (isinstance(default, ast.Constant) and default.value is None), (
                "CLI --max-tokens should default to None"
            )

    def test_does_not_pass_none_to_generate(self):
        """CLI should only pass max_tokens when explicitly set by the user."""
        tree = self._parse_cli()
        # The fix: an `args.max_tokens is not None` check before building kwargs.
        assert any(
            isinstance(node, ast.Compare)
            and isinstance(node.left, ast.Attribute)
            and node.left.attr == "max_tokens"
            and isinstance(node.left.value, ast.Name)
            and node.left.value.id == "args"
            and any(isinstance(op, ast.IsNot) for op in node.ops)
            and any(isinstance(c, ast.Constant) and c.value is None for c in node.comparators)
            for node in ast.walk(tree)
        ), "CLI should guard against passing None max_tokens to generate()"

    def test_imports_default_sample_rate(self):
        """CLI should import DEFAULT_SAMPLE_RATE from dia.model."""
        tree = self._parse_cli()
        # Accept `from ... import DEFAULT_SAMPLE_RATE` or a qualified
        # reference such as `model.DEFAULT_SAMPLE_RATE`.
        assert any(
            (isinstance(node, ast.ImportFrom) and any(a.name == "DEFAULT_SAMPLE_RATE" for a in node.names))
            or (isinstance(node, ast.Attribute) and node.attr == "DEFAULT_SAMPLE_RATE")
            for node in ast.walk(tree)
        ), "CLI should use DEFAULT_SAMPLE_RATE from dia.model"

    def test_no_hardcoded_sample_rate(self):
        """CLI should not hardcode sample_rate = 44100."""
        tree = self._parse_cli()

        def is_44100(value):
            return isinstance(value, ast.Constant) and value.value == 44100

        offending_lines = []
        for node in ast.walk(tree):
            if isinstance(node, ast.Assign):
                names = [t.id for t in node.targets if isinstance(t, ast.Name)]
                if "sample_rate" in names and is_44100(node.value):
                    offending_lines.append(node.lineno)
            elif isinstance(node, ast.AnnAssign):
                if (
                    isinstance(node.target, ast.Name)
                    and node.target.id == "sample_rate"
                    and is_44100(node.value)
                ):
                    offending_lines.append(node.lineno)
            elif isinstance(node, ast.keyword):
                # Also catch call sites such as f(..., sample_rate=44100).
                if node.arg == "sample_rate" and is_44100(node.value):
                    offending_lines.append(node.value.lineno)
        assert not offending_lines, (
            "CLI should not hardcode sample_rate = 44100; use DEFAULT_SAMPLE_RATE "
            f"(cli.py line(s) {sorted(offending_lines)})"
        )
# ---------------------------------------------------------------------------
# 4. _prepare_generation None handling
# ---------------------------------------------------------------------------
class TestPrepareGeneration:
    """Verify _prepare_generation also handles max_tokens=None."""

    def test_accepts_none_default(self):
        """_prepare_generation should accept None for max_tokens.

        If max_tokens has a default it must be ``None``; if it is required,
        its annotation must at least admit ``None``. A required max_tokens
        without such an annotation fails instead of passing vacuously.
        """
        source = _read_source("dia/model.py")
        func = _parse_function(source, "_prepare_generation")
        positional = list(getattr(func.args, "posonlyargs", [])) + list(func.args.args)
        # `defaults` is right-aligned with the positional parameters;
        # `kw_defaults` is index-aligned with `kwonlyargs` (None = no default).
        first_with_default = len(positional) - len(func.args.defaults)
        candidates = [
            (arg, func.args.defaults[i - first_with_default] if i >= first_with_default else None)
            for i, arg in enumerate(positional)
        ] + list(zip(func.args.kwonlyargs, func.args.kw_defaults))
        matches = [(arg, default) for arg, default in candidates if arg.arg == "max_tokens"]
        assert matches, "_prepare_generation() has no max_tokens parameter"
        arg, default_node = matches[0]
        if default_node is not None:
            assert isinstance(default_node, ast.Constant) and default_node.value is None, (
                f"max_tokens default should be None, got: {ast.dump(default_node)}"
            )
        else:
            annotation = ast.get_source_segment(source, arg.annotation) if arg.annotation else None
            assert annotation and ("None" in annotation or "Optional" in annotation), (
                "max_tokens has no default, so its annotation must allow None"
            )
# ---------------------------------------------------------------------------
# 5. Constant consistency
# ---------------------------------------------------------------------------
class TestConstants:
    """Verify DEFAULT_SAMPLE_RATE is defined and used consistently."""

    def test_default_sample_rate_defined(self):
        """DEFAULT_SAMPLE_RATE should be defined in dia/model.py."""
        tree = ast.parse(_read_source("dia/model.py"))
        values = []
        # Module-level statements only: the CLI is expected to import the
        # constant from dia.model.
        for stmt in tree.body:
            if isinstance(stmt, ast.Assign):
                if any(isinstance(t, ast.Name) and t.id == "DEFAULT_SAMPLE_RATE" for t in stmt.targets):
                    values.append(stmt.value)
            elif isinstance(stmt, ast.AnnAssign) and stmt.value is not None:
                if isinstance(stmt.target, ast.Name) and stmt.target.id == "DEFAULT_SAMPLE_RATE":
                    values.append(stmt.value)
        assert values, "DEFAULT_SAMPLE_RATE should be assigned at module level in dia/model.py"
        # Comparing the parsed constant also accepts spellings like 44_100.
        assert any(isinstance(v, ast.Constant) and v.value == 44100 for v in values), (
            "DEFAULT_SAMPLE_RATE should be 44100"
        )

    def test_save_audio_uses_constant(self):
        """save_audio() should use DEFAULT_SAMPLE_RATE, not a hardcoded value."""
        tree = ast.parse(_read_source("dia/model.py"))
        save_audio = next(
            (
                node
                for node in ast.walk(tree)
                if isinstance(node, ast.FunctionDef) and node.name == "save_audio"
            ),
            None,
        )
        # Fail loudly if save_audio() is renamed or removed rather than
        # silently skipping the check.
        assert save_audio is not None, "save_audio() not found in dia/model.py"
        # Count only real code references (body or signature defaults), not
        # mentions in the docstring.
        assert any(
            (isinstance(node, ast.Name) and node.id == "DEFAULT_SAMPLE_RATE")
            or (isinstance(node, ast.Attribute) and node.attr == "DEFAULT_SAMPLE_RATE")
            for node in ast.walk(save_audio)
        ), "save_audio() should use DEFAULT_SAMPLE_RATE"
# ---------------------------------------------------------------------------
# 6. Issue #281 reproduction: the exact error path
# ---------------------------------------------------------------------------
class TestIssue281Reproduction:
    """Verify the specific error from issue #281 cannot occur."""

    @staticmethod
    def _is_none_check(node) -> bool:
        """Return True if *node* is exactly the expression ``max_tokens is None``."""
        return (
            isinstance(node, ast.Compare)
            and isinstance(node.left, ast.Name)
            and node.left.id == "max_tokens"
            and len(node.ops) == 1
            and isinstance(node.ops[0], ast.Is)
            and isinstance(node.comparators[0], ast.Constant)
            and node.comparators[0].value is None
        )

    @staticmethod
    def _is_max_tokens_ordering(node) -> bool:
        """Return True if *node* is a <, <=, > or >= comparison with ``max_tokens`` as an operand.

        These are the comparisons that raise ``TypeError`` when max_tokens is
        still None; identity checks (``is``/``is not``) are safe and ignored.
        """
        if not isinstance(node, ast.Compare):
            return False
        if not any(isinstance(op, (ast.Lt, ast.LtE, ast.Gt, ast.GtE)) for op in node.ops):
            return False
        operands = [node.left, *node.comparators]
        return any(isinstance(o, ast.Name) and o.id == "max_tokens" for o in operands)

    def test_int_lt_none_raises_typeerror(self):
        """Confirm the original error: int < None is a TypeError in Python 3."""
        with pytest.raises(TypeError, match="not supported"):
            _ = 5 < None  # noqa: B015

    def test_no_unguarded_max_tokens_comparison(self):
        """All ordering comparisons involving max_tokens in generate() must
        happen after the None → int resolution."""
        source = _read_source("dia/model.py")
        # _parse_function raises ValueError if generate() is missing, so this
        # test cannot pass vacuously.
        func = _parse_function(source, "generate")
        nodes = list(ast.walk(func))
        guard_lines = [node.lineno for node in nodes if self._is_none_check(node)]
        assert guard_lines, "Missing None guard"
        guard_line = min(guard_lines)
        # Only comparisons strictly before the guard's line are flagged; the
        # guard line itself may legitimately combine the None check with a
        # comparison (e.g. `max_tokens is None or max_tokens > limit`).
        unguarded = sorted(
            node.lineno
            for node in nodes
            if self._is_max_tokens_ordering(node) and node.lineno < guard_line
        )
        assert not unguarded, f"Unguarded max_tokens comparison at dia/model.py line(s) {unguarded}"