Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import soundfile as sf
import torch

from dia.model import Dia
from dia.model import DEFAULT_SAMPLE_RATE, Dia


def set_seed(seed: int):
Expand Down Expand Up @@ -117,22 +117,23 @@ def main():
# Generate audio
print("Generating audio...")
try:
sample_rate = 44100 # Default assumption

output_audio = model.generate(
gen_kwargs = dict(
text=args.text,
audio_prompt=args.audio_prompt,
max_tokens=args.max_tokens,
cfg_scale=args.cfg_scale,
temperature=args.temperature,
top_p=args.top_p,
)
if args.max_tokens is not None:
gen_kwargs["max_tokens"] = args.max_tokens

output_audio = model.generate(**gen_kwargs)
print("Audio generation complete.")

print(f"Saving audio to {args.output}...")
os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)

sf.write(args.output, output_audio, sample_rate)
sf.write(args.output, output_audio, DEFAULT_SAMPLE_RATE)
print(f"Audio successfully saved to {args.output}")

except Exception as e:
Expand Down
7 changes: 5 additions & 2 deletions dia/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,7 @@ def save_audio(self, path: str, audio: np.ndarray):
def generate(
self,
text: str | list[str],
max_tokens: int = 3072,
max_tokens: int | None = None,
cfg_scale: float = 3.0,
temperature: float = 1.2,
top_p: float = 0.95,
Expand All @@ -610,7 +610,8 @@ def generate(
Args:
text: The input text prompt, or a list of text prompts for batch generation.
max_tokens: The maximum number of audio tokens to generate per prompt.
Defaults to the model's configured audio length if None.
Defaults to the model's configured decoder max_position_embeddings
if None.
cfg_scale: The scale factor for classifier-free guidance (CFG). Higher values
lead to stronger guidance towards the text prompt.
temperature: The temperature for sampling. Higher values increase randomness.
Expand All @@ -637,6 +638,8 @@ def generate(
sequence if no audio was generated for it.
"""
batch_size = len(text) if isinstance(text, list) else 1
if max_tokens is None:
max_tokens = self.config.decoder_config.max_position_embeddings
audio_eos_value = self.config.eos_token_id
audio_pad_value = self.config.pad_token_id
delay_pattern = self.config.delay_pattern
Expand Down
279 changes: 279 additions & 0 deletions tests/test_generate_max_tokens.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,279 @@
"""Regression tests for generate() max_tokens=None TypeError.

Verifies the fix for https://github.com/nari-labs/dia/issues/281:
Error during audio generation or saving:
'<' not supported between instances of 'int' and 'NoneType'

Root cause: when max_tokens=None is passed (e.g. from cli.py's default
--max-tokens=None), the generate() loop compares `dec_step < max_tokens`
which raises TypeError because int < None is not supported in Python 3.

The fix resolves None to config.decoder_config.max_position_embeddings
at the top of generate(), so the comparison is always int < int.

These tests parse source code directly to avoid requiring heavy deps
(torch, torchaudio, descript-audio-codec) in the test environment.
"""

import ast
import os
import sys
import textwrap

import pytest

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

# Repository root: the parent of the tests/ directory containing this file.
_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


def _read_source(relpath: str) -> str:
    """Return the text of *relpath*, resolved relative to the repository root."""
    full_path = os.path.join(_ROOT, relpath)
    with open(full_path) as handle:
        return handle.read()


def _parse_function(source: str, funcname: str) -> ast.FunctionDef:
"""Parse a module and return the AST node for *funcname* (top-level or method)."""
tree = ast.parse(source)
for node in ast.walk(tree):
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == funcname:
return node
raise ValueError(f"{funcname} not found")


# ---------------------------------------------------------------------------
# 1. Signature: max_tokens accepts None
# ---------------------------------------------------------------------------

class TestGenerateSignature:
    """Ensure generate() signature allows None for max_tokens."""

    def test_max_tokens_default_is_none(self):
        """generate() default for max_tokens should be None (resolved at runtime)."""
        source = _read_source("dia/model.py")
        func = _parse_function(source, "generate")
        arg_names = [a.arg for a in func.args.args]
        idx = arg_names.index("max_tokens")
        # Positional defaults align right-to-left with args ('self' has none).
        num_positional = len(func.args.args) - len(func.args.defaults)
        default_idx = idx - num_positional
        # Previously this was an `if` guard, so a missing default let the
        # test pass vacuously; fail loudly instead.
        assert default_idx >= 0, "max_tokens must have a default value"
        default_node = func.args.defaults[default_idx]
        assert isinstance(default_node, ast.Constant) and default_node.value is None, (
            f"max_tokens default should be None, got: {ast.dump(default_node)}"
        )

    def test_max_tokens_annotation_allows_none(self):
        """max_tokens type hint must include None."""
        source = _read_source("dia/model.py")
        # Look for the annotation string in the source (either union style).
        assert "max_tokens: int | None" in source or "max_tokens: Optional[int]" in source, (
            "max_tokens annotation should allow None"
        )


# ---------------------------------------------------------------------------
# 2. None resolution: uses config fallback
# ---------------------------------------------------------------------------

class TestNoneResolution:
"""Verify max_tokens=None is resolved before the generation loop."""

def test_none_check_exists(self):
"""generate() must contain a check for max_tokens is None."""
source = _read_source("dia/model.py")
func_source = self._get_generate_source(source)
assert "max_tokens is None" in func_source, (
"generate() must check for max_tokens is None"
)

def test_resolves_to_config_value(self):
"""generate() must resolve None to config.decoder_config.max_position_embeddings."""
source = _read_source("dia/model.py")
func_source = self._get_generate_source(source)
assert "max_position_embeddings" in func_source, (
"generate() must resolve None to config.decoder_config.max_position_embeddings"
)

def test_none_resolution_before_loop(self):
"""The None → int resolution must happen before the while loop comparison."""
source = _read_source("dia/model.py")
func_source = self._get_generate_source(source)
lines = func_source.split("\n")

none_check_line = None
while_loop_line = None
for i, line in enumerate(lines):
if "max_tokens is None" in line and none_check_line is None:
none_check_line = i
if "while" in line and "max_tokens" in line and while_loop_line is None:
while_loop_line = i

assert none_check_line is not None, "Could not find 'max_tokens is None' check"
assert while_loop_line is not None, "Could not find 'while ... max_tokens' loop"
assert none_check_line < while_loop_line, (
f"None check (line {none_check_line}) must come before while loop (line {while_loop_line})"
)

def test_explicit_int_not_overridden(self):
"""When max_tokens is an explicit int, it should be preserved."""
source = _read_source("dia/model.py")
func_source = self._get_generate_source(source)
# The guard should be `if max_tokens is None:` — only triggers for None
assert "if max_tokens is None:" in func_source, (
"Resolution should use 'if max_tokens is None:' (identity check)"
)

@staticmethod
def _get_generate_source(full_source: str) -> str:
"""Extract the source of Dia.generate from the module source."""
tree = ast.parse(full_source)
for node in ast.walk(tree):
if isinstance(node, ast.FunctionDef) and node.name == "generate":
return ast.get_source_segment(full_source, node) or ""
raise ValueError("generate() not found")


# ---------------------------------------------------------------------------
# 3. CLI integration
# ---------------------------------------------------------------------------

class TestCLI:
    """Verify cli.py correctly handles max_tokens=None."""

    def test_default_max_tokens_is_none(self):
        """CLI --max-tokens should default to None."""
        cli_source = _read_source("cli.py")
        # argparse must not supply a concrete token budget by default.
        assert "default=None" in cli_source, (
            "CLI --max-tokens should default to None"
        )

    def test_does_not_pass_none_to_generate(self):
        """CLI should only pass max_tokens when explicitly set by the user."""
        cli_source = _read_source("cli.py")
        # The guard keeps a None max_tokens out of the generate() kwargs.
        assert "args.max_tokens is not None" in cli_source, (
            "CLI should guard against passing None max_tokens to generate()"
        )

    def test_imports_default_sample_rate(self):
        """CLI should import DEFAULT_SAMPLE_RATE from dia.model."""
        cli_source = _read_source("cli.py")
        assert "DEFAULT_SAMPLE_RATE" in cli_source, (
            "CLI should use DEFAULT_SAMPLE_RATE from dia.model"
        )

    def test_no_hardcoded_sample_rate(self):
        """CLI should not hardcode sample_rate = 44100."""
        cli_source = _read_source("cli.py")
        # Drop trailing comments so a commented-out constant does not trip this.
        code_only = "\n".join(
            line.partition("#")[0] for line in cli_source.split("\n")
        )
        assert "sample_rate = 44100" not in code_only, (
            "CLI should not hardcode sample_rate = 44100; use DEFAULT_SAMPLE_RATE"
        )


# ---------------------------------------------------------------------------
# 4. _prepare_generation None handling
# ---------------------------------------------------------------------------

class TestPrepareGeneration:
    """Verify _prepare_generation also handles max_tokens=None."""

    def test_accepts_none_default(self):
        """_prepare_generation should accept None for max_tokens."""
        source = _read_source("dia/model.py")
        func = _parse_function(source, "_prepare_generation")
        arg_names = [a.arg for a in func.args.args]
        idx = arg_names.index("max_tokens")
        # Positional defaults align right-to-left with args.
        num_positional = len(func.args.args) - len(func.args.defaults)
        default_idx = idx - num_positional
        if default_idx < 0:
            # Previously this case passed silently (vacuous `if` guard);
            # surface it as an explicit skip instead of a green result.
            pytest.skip("_prepare_generation.max_tokens has no default to check")
        default_node = func.args.defaults[default_idx]
        assert isinstance(default_node, ast.Constant) and default_node.value is None, (
            "_prepare_generation max_tokens default should be None, "
            f"got: {ast.dump(default_node)}"
        )


# ---------------------------------------------------------------------------
# 5. Constant consistency
# ---------------------------------------------------------------------------

class TestConstants:
    """Verify DEFAULT_SAMPLE_RATE is defined and used consistently."""

    def test_default_sample_rate_defined(self):
        """DEFAULT_SAMPLE_RATE should be defined in dia/model.py."""
        source = _read_source("dia/model.py")
        assert "DEFAULT_SAMPLE_RATE = 44100" in source

    def test_save_audio_uses_constant(self):
        """save_audio() should use DEFAULT_SAMPLE_RATE, not a hardcoded value."""
        source = _read_source("dia/model.py")
        assert "DEFAULT_SAMPLE_RATE" in source
        # Locate the save_audio method; previously a missing method made
        # this test pass silently because the loop simply never asserted.
        tree = ast.parse(source)
        save_audio_src = None
        for node in ast.walk(tree):
            if isinstance(node, ast.FunctionDef) and node.name == "save_audio":
                save_audio_src = ast.get_source_segment(source, node) or ""
                break
        assert save_audio_src is not None, "save_audio() not found in dia/model.py"
        assert "DEFAULT_SAMPLE_RATE" in save_audio_src, (
            "save_audio() should use DEFAULT_SAMPLE_RATE"
        )


# ---------------------------------------------------------------------------
# 6. Issue #281 reproduction: the exact error path
# ---------------------------------------------------------------------------

class TestIssue281Reproduction:
    """Verify the specific error from issue #281 cannot occur."""

    def test_int_lt_none_raises_typeerror(self):
        """Confirm the original error: int < None is a TypeError in Python 3."""
        with pytest.raises(TypeError, match="not supported"):
            _ = 5 < None  # noqa: B015

    def test_no_unguarded_max_tokens_comparison(self):
        """All comparisons involving max_tokens in generate() must happen
        after the None → int resolution."""
        source = _read_source("dia/model.py")
        # _parse_function raises ValueError if generate() is missing, so this
        # test can no longer pass vacuously (the original walked the AST and
        # silently did nothing when no match was found).
        func = _parse_function(source, "generate")
        func_source = ast.get_source_segment(source, func) or ""
        lines = func_source.split("\n")

        # Find the None guard line.
        guard_line = None
        for i, line in enumerate(lines):
            if "max_tokens is None" in line:
                guard_line = i
                break
        assert guard_line is not None, "Missing None guard"

        # No `< max_tokens` / `>= max_tokens` comparison may appear before
        # the guard.  (The original version also contained a dead loop here
        # that computed a stripped line and never used it; removed.)
        for i, line in enumerate(lines[:guard_line]):
            assert "< max_tokens" not in line and ">= max_tokens" not in line, (
                f"Unguarded max_tokens comparison at line {i}: {line.strip()}"
            )