
Commit 01c7830: V1 of working pipeline (#1)
Parent: a2e1d5d

15 files changed: +1520 / -1780 lines

.gitignore

Lines changed: 5 additions & 2 deletions

@@ -200,7 +200,10 @@ cython_debug/
 model/pretrained/pretrained_models/*
 
 # Prevent any .pth files being added as a saved model.
-model/pretrained/saved_models/*.pth*
+model/saved_models/*.pth*
 
 # No data in the data/train folder
-data/train/*
+data/train/*
+
+# No data in data_split folder
+data_split/*

README.md

Lines changed: 29 additions & 81 deletions
@@ -4,30 +4,6 @@ This project uses deep learning to analyze audio files and detect AI-generated content.
 
 ## Setup
 
-You can set up this project using either `uv` (recommended) or `pip`.
-
-### Option 1: Using uv (Recommended)
-
-1. Install `uv` if you haven't already:
-```bash
-curl -LsSf https://astral.sh/uv/install.sh | sh
-```
-
-2. Create and activate a virtual environment:
-```bash
-uv venv
-source .venv/bin/activate  # On Unix/macOS
-# or
-.venv\Scripts\activate  # On Windows
-```
-
-3. Install the package in development mode:
-```bash
-uv pip install -e .
-```
-
-### Option 2: Using pip
-
 1. Create and activate a virtual environment:
 ```bash
 python -m venv .venv
@@ -50,39 +26,53 @@ For the training step, I used this file from here [Link text][https://github.com
 For the example here, I set up a data folder at the top level with /data/train/ai and /data/train/real
 and added the .mp3 and .wav files that I want to fine-tune against. I got the real data from
 FMA [Link Text](https://github.com/mdeff/fma) for testing, and the AI-generated data from
-Facebook's Music Gen.
+Facebook's MusicGen. The word "ai" must appear in the path of the AI folders, and "real" in the
+path to the real songs.
 
-**NOTE: In /model/pretrained/cnn14.py, I'm hardcoding the path to be /mode/pretrained/pretrained_models/Cnn14_16k_mAP=0.438.pth.gz. This would have to be changed in the future. Cnn14 only takes in gzip files
+**NOTE: In /model/pretrained/cnn14.py, I'm hardcoding the path to be /model/pretrained/pretrained_models/Cnn14_16k_mAP=0.438.pth.gz. This would have to be changed in the future. Cnn14 only takes in gzip files,
 so gzip your file beforehand**
 
 Steps:
-1. First place files in audio-processing-ai/data/train (if you are going to finetune data against your model)
+1. First, place files in audio-processing-ai/data/train (if you are going to fine-tune against your model).
+**All AI files should go in /data/train/ai and all real files go in /data/train/real. This is because we need labeled data for supervised training of the classifier that decides which file is AI music and which is real.**
 2. Figure out the model you are going to fine-tune against
 3. Update this line (PRETRAINED_MODEL_PATH = 'model/pretrained/pretrained_models/Cnn14_16k_mAP=0.438.pth.gz') at cnn14.py to the .pth.gz file location of your choice
 
 To train the model:
 ```bash
-cd audio-processing-ai
-python train.py --epoch 5 --dataFolder data/train/ --savePath model/saved_models/your_model.pth
+python train.py \
+    --num-epochs 5 \
+    --dataFolder data/train/ \
+    --savedPath model/saved_models/your_model.pth \
+    [--resume-from path/to/checkpoint.pth]  # Optional: resume from a checkpoint
 ```
 
-### Inference
-
-If you have an already trained/fine-tuned model and you just want to run the prediction,
-run it as such.
-
-Folder is the path to the audio files you want to test against.
+Required arguments:
+- `--savedPath`: Path where the model will be saved (must end in .pth)
+- `--dataFolder`: Directory containing training data (default: "data/train/")
+- `--num-epochs`: Number of training epochs (default: 5)
 
-Example lists the model path as model/saved_models/your_model.pth but that is changeable
-depending on where you saved it.
+Optional arguments:
+- `--resume-from`: Path to a checkpoint to resume training from
 
-The outputted file is predictions_timestamp.csv
+### Inference
 
 To run predictions on audio files:
 ```bash
-python predict.py --folder path/to/audio/files --model model/saved_models/your_model.pth
+python predict.py \
+    --folder path/to/audio/files \
+    --model model/saved_models/your_model.pth
 ```
 
+Required arguments:
+- `--folder`: Directory containing .mp3/.wav files to analyze
+- `--model`: Path to your trained model (.pth file)
+
+The script will:
+1. Process each audio file in the specified folder
+2. Generate predictions for AI-generated content and audio scene tags
+3. Save results to a CSV file named `predictions_YYYYMMDD_HHMM.csv`
+
 ## Project Structure
 
 - `inference/`: Inference scripts for prediction
@@ -101,45 +91,3 @@ python predict.py --folder path/to/audio/files --model model/saved_models/your_m
 - Training data should be organized in the `data/train/` directory
 - Model checkpoints are saved in `model/saved_models/`
 - The project is installed as a Python package for proper import handling
-
-## Code Quality
-
-This project uses Ruff for both linting and formatting Python code. Ruff is a fast Python linter and formatter written in Rust.
-
-### Using Ruff
-
-1. Install Ruff (it's already included in the dev dependencies):
-```bash
-# Using pip (recommended if you want to use your existing virtual environment)
-pip install -e ".[dev]"
-
-# OR using uv pip (if you want to use uv but keep your current virtual environment)
-uv pip install -e ".[dev]"
-
-# Note: Do NOT use 'uv venv' unless you want to create a new virtual environment
-# with pyenv. If you want to use uv while keeping your current environment,
-# use 'uv pip' instead.
-```
-
-2. Format your code:
-```bash
-ruff format .
-```
-
-3. Lint your code:
-```bash
-ruff check .
-```
-
-4. Fix linting issues automatically:
-```bash
-ruff check --fix .
-```
-
-The Ruff configuration is in `pyproject.toml`. Currently, it:
-- Uses a line length of 88 characters (same as Black)
-- Targets Python 3.9
-- Enables pycodestyle (`E`) and Pyflakes (`F`) rules by default
-- Ignores line length violations (`E501`)
-
-You can customize the Ruff configuration by modifying the `[tool.ruff]` section in `pyproject.toml`.
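Editor's note on the data layout the new README text requires: "ai" and "real" must appear in the paths of the respective training folders. A minimal sketch of that labeling convention, assuming the label is taken from a whole path component (the README only says the word must appear in the path); the helper name `label_from_path` is illustrative, not part of the repo:

```python
from pathlib import Path

def label_from_path(path: str) -> int:
    """Infer the training label from the folder layout the README describes."""
    parts = Path(path).parts
    if "ai" in parts:
        return 1   # AI-generated music
    if "real" in parts:
        return 0   # real recording
    raise ValueError(f"Cannot infer label from path: {path}")

print(label_from_path("data/train/ai/song_001.wav"))    # 1
print(label_from_path("data/train/real/fma_0042.mp3"))  # 0
```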

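The README's NOTE also says cnn14.py only reads gzipped checkpoints. A one-off compression step using only the standard library, assuming the checkpoint was downloaded to the working directory:

```python
import gzip
import shutil

# Compress a downloaded checkpoint so cnn14.py can read it (it expects .pth.gz).
with open("Cnn14_16k_mAP=0.438.pth", "rb") as src:
    with gzip.open("Cnn14_16k_mAP=0.438.pth.gz", "wb") as dst:
        shutil.copyfileobj(src, dst)
```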
check_bias.py

Lines changed: 88 additions & 0 deletions
import glob

import numpy as np
import torch
import torchaudio
from tqdm import tqdm

from model.pretrained.dual_head_cnn14 import DualHeadCnn14


def load_audio(path, sample_rate=16000, target_length=64000):
    waveform, sr = torchaudio.load(path)
    if sr != sample_rate:
        waveform = torchaudio.functional.resample(waveform, sr, sample_rate)

    waveform = waveform.mean(dim=0, keepdim=True)  # mono: [1, T]

    # Ensure the input has exactly target_length samples (e.g., 4 s at 16 kHz)
    if waveform.shape[1] < target_length:
        pad_len = target_length - waveform.shape[1]
        waveform = torch.nn.functional.pad(waveform, (0, pad_len))
    else:
        waveform = waveform[:, :target_length]  # truncate if too long

    return waveform  # [1, target_length]


def get_logits(model, files, device):
    logits = []
    for file in tqdm(files, desc="Evaluating"):
        try:
            waveform = load_audio(file, target_length=64000).to(device)  # [1, T]
            waveform = waveform.unsqueeze(0)  # [1, 1, T]
            waveform = waveform.unsqueeze(2)  # [1, 1, 1, T]

            if waveform.ndim != 4:
                print(f"❌ Invalid input shape {waveform.shape} → expected [1, 1, 1, T]")
                continue

            with torch.no_grad():
                print(f"{file} → waveform shape: {waveform.shape}")
                logit = model(waveform).squeeze().item()
                logits.append(logit)

        except Exception as e:
            print(f"❌ Error processing {file}: {e}")
    return logits


def safe_mean(x):
    return np.mean(x) if len(x) > 0 else float("nan")


# Gather files (wav + mp3) from data/train
real_files = glob.glob("data/train/real/**/*.wav", recursive=True) + \
             glob.glob("data/train/real/**/*.mp3", recursive=True)
ai_files = glob.glob("data/train/ai/**/*.wav", recursive=True) + \
           glob.glob("data/train/ai/**/*.mp3", recursive=True)

print(f"🟩 Found {len(real_files)} real audio files.")
print(f"🟥 Found {len(ai_files)} AI audio files.")

if not real_files:
    print("⚠️ Warning: No real files found.")
if not ai_files:
    print("⚠️ Warning: No AI files found.")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = DualHeadCnn14(pretrained=True)
model.load_state_dict(torch.load("/Users/sumerjoshi/upwork/audio-processing-ai/model/saved_models/Cnn14_16k_mAP_around2000_20250614_1223.pth", map_location=device))
model.eval()
model.to(device)

real_logits = get_logits(model, real_files, device)
ai_logits = get_logits(model, ai_files, device)

avg_real_logit = safe_mean(real_logits)
avg_ai_logit = safe_mean(ai_logits)

print("\n=== Bias Check Results ===")
if real_logits:
    print(f"Real Logit Avg: {avg_real_logit:.4f} | Sigmoid: {torch.sigmoid(torch.tensor(avg_real_logit)).item():.4f}")
else:
    print("⚠️ No real logits computed.")

if ai_logits:
    print(f"AI Logit Avg: {avg_ai_logit:.4f} | Sigmoid: {torch.sigmoid(torch.tensor(avg_ai_logit)).item():.4f}")
else:
    print("⚠️ No AI logits computed.")

check_bias_fixed.py

Lines changed: 83 additions & 0 deletions
import glob

import numpy as np
import torch
from tqdm import tqdm

from model.pretrained.dual_head_cnn14 import DualHeadCnn14Simple
from predict import preprocess_audio as load_audio  # reuse the same logic


def get_logits(model, files, device):
    logits = []
    for file_path in tqdm(files, desc="Evaluating"):
        try:
            waveform = load_audio(file_path=file_path).to(device)  # [1, T]
            print(f"Initial waveform shape: {waveform.shape}")
            if waveform.ndim == 2:
                waveform = waveform.unsqueeze(0)  # [1, 1, T]
            elif waveform.ndim == 3 and waveform.shape[0] == 1 and waveform.shape[1] == 1:
                pass  # already correct
            else:
                print(f"❌ Unexpected input shape: {waveform.shape}")
                continue

            with torch.no_grad():
                input_tensor = waveform  # shape expected by the model
                print(f"{file_path} → waveform shape: {input_tensor.shape}")
                binary_logit, _ = model(input_tensor)
                logit = binary_logit.squeeze().item()
                prob = torch.sigmoid(binary_logit).squeeze().item()

                print(f" → Logit: {logit:.4f}, Sigmoid Prob: {prob:.4f}")
                logits.append(logit)

        except Exception as e:
            print(f"❌ Error processing {file_path}: {e}")
    return logits


def safe_mean(x):
    return np.mean(x) if len(x) > 0 else float("nan")


# Gather files (wav + mp3) from data/train
real_files = glob.glob("data/train/real/**/*.wav", recursive=True) + \
             glob.glob("data/train/real/**/*.mp3", recursive=True)
ai_files = glob.glob("data/train/ai/**/*.wav", recursive=True) + \
           glob.glob("data/train/ai/**/*.mp3", recursive=True)

print(f"🟩 Found {len(real_files)} real audio files.")
print(f"🟥 Found {len(ai_files)} AI audio files.")

if not real_files:
    print("⚠️ Warning: No real files found.")
if not ai_files:
    print("⚠️ Warning: No AI files found.")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

state_dict = torch.load("/Users/sumerjoshi/upwork/audio-processing-ai/model/saved_models/Cnn14_16k_mAP_around2000_samplingAndRealTransformChanges_20250615_0746.pth", map_location=device)
model = DualHeadCnn14Simple(pretrained=False)  # False: the weights come from our own training run
model.load_state_dict(state_dict)
model.eval()
model.to(device)

real_logits = get_logits(model, real_files, device)
ai_logits = get_logits(model, ai_files, device)

avg_real_logit = safe_mean(real_logits)
avg_ai_logit = safe_mean(ai_logits)

print("\n=== Bias Check Results ===")
if real_logits:
    print(f"Real Logit Avg: {avg_real_logit:.4f} | Sigmoid: {torch.sigmoid(torch.tensor(avg_real_logit)).item():.4f}")
else:
    print("⚠️ No real logits computed.")

if ai_logits:
    print(f"AI Logit Avg: {avg_ai_logit:.4f} | Sigmoid: {torch.sigmoid(torch.tensor(avg_ai_logit)).item():.4f}")
else:
    print("⚠️ No AI logits computed.")
