Skip to content

Commit a2e1d5d

Browse files
committed
Tests fixed and added ruff and updated README.md
1 parent c385adc commit a2e1d5d

File tree

16 files changed

+485
-216
lines changed

16 files changed

+485
-216
lines changed

README.md

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,45 @@ python predict.py --folder path/to/audio/files --model model/saved_models/your_m
101101
- Training data should be organized in the `data/train/` directory
102102
- Model checkpoints are saved in `model/saved_models/`
103103
- The project is installed as a Python package for proper import handling
104+
105+
## Code Quality
106+
107+
This project uses Ruff for both linting and formatting Python code. Ruff is a fast Python linter and formatter written in Rust.
108+
109+
### Using Ruff
110+
111+
1. Install Ruff (it's already included in the dev dependencies):
112+
```bash
113+
# Using pip (recommended if you want to use your existing virtual environment)
114+
pip install -e ".[dev]"
115+
116+
# OR using uv pip (if you want to use uv but keep your current virtual environment)
117+
uv pip install -e ".[dev]"
118+
119+
# Note: Do NOT use 'uv venv' unless you want to create a new virtual environment
120+
# with pyenv. If you want to use uv while keeping your current environment,
121+
# use 'uv pip' instead.
122+
```
123+
124+
2. Format your code:
125+
```bash
126+
ruff format .
127+
```
128+
129+
3. Lint your code:
130+
```bash
131+
ruff check .
132+
```
133+
134+
4. Fix linting issues automatically:
135+
```bash
136+
ruff check --fix .
137+
```
138+
139+
The Ruff configuration is in `pyproject.toml`. Currently, it:
140+
- Uses a line length of 88 characters (same as Black)
141+
- Targets Python 3.9
142+
- Enables pycodestyle (`E`) and Pyflakes (`F`) rules by default
143+
- Ignores line length violations (`E501`)
144+
145+
You can customize the Ruff configuration by modifying the `[tool.ruff]` section in `pyproject.toml`.

dataset/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
"""Dataset package for audio processing AI."""
1+
"""Dataset package for audio processing AI."""

dataset/ai_audio_dataset.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,24 @@
55
from typing import Tuple
66
from zipfile import ZipFile
77

8+
89
class AIAudioDataset(torch.utils.data.Dataset):
910
def __init__(self, root_dir, sample_rate=16000, duration=10.0) -> None:
1011
"""
1112
Args:
12-
root_dir: str Path to .mp3 and .wav files to create the training.
13+
root_dir: str Path to .mp3 and .wav files to create the training.
1314
sample_rate: int set to 16khz for training
1415
duration: float set to 10 seconds for sampling
1516
"""
16-
self.paths = list(Path(root_dir).rglob("*.wav")) + list(Path(root_dir).rglob("*.mp3"))
17+
self.paths = list(Path(root_dir).rglob("*.wav")) + list(
18+
Path(root_dir).rglob("*.mp3")
19+
)
1720
self.sample_rate = sample_rate
1821
self.duration = duration
1922
self.audio_len = int(sample_rate * duration)
2023

21-
2224
def __len__(self) -> int:
2325
return len(self.paths)
24-
2526

2627
def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]:
2728
path = self.paths[idx]
@@ -31,7 +32,7 @@ def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]:
3132
waveform = torchaudio.transforms.Resample(sr, self.sample_rate)(waveform)
3233

3334
if waveform.shape[0] > 1:
34-
waveform = waveform.mean(dim=0,keepdim=True)
35+
waveform = waveform.mean(dim=0, keepdim=True)
3536

3637
total_len = waveform.shape[1]
3738

@@ -40,18 +41,18 @@ def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]:
4041
waveform = torch.nn.functional.pad(waveform, (0, pad_len))
4142
else:
4243
start = random.randint(0, total_len - self.audio_len)
43-
waveform = waveform[:, start:start + self.audio_len]
44+
waveform = waveform[:, start : start + self.audio_len]
4445

4546
mel_spec = torchaudio.transforms.MelSpectrogram(
4647
sample_rate=self.sample_rate,
4748
n_fft=1024,
4849
hop_length=320,
4950
n_mels=64,
5051
f_min=50,
51-
f_max=8000
52+
f_max=8000,
5253
)(waveform)
5354

5455
logmel = torch.log(mel_spec + 1e-6)
5556

5657
label = 1 if "ai" in str(path).lower() else 0
57-
return logmel.squeeze(0), torch.tensor(label)
58+
return logmel.squeeze(0), torch.tensor(label)

inference/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
"""Inference package for audio processing AI."""
1+
"""Inference package for audio processing AI."""

model/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
"""Model package for audio processing AI."""
1+
"""Model package for audio processing AI."""

0 commit comments

Comments
 (0)