From ec895a6609528521b884514c0dbce68105b7c301 Mon Sep 17 00:00:00 2001
From: BlackTechX <155609122+BlackTechX011@users.noreply.github.com>
Date: Fri, 15 Aug 2025 09:14:54 +0000
Subject: [PATCH] Replace infer.py and infer.sh with a user-friendly
 songbloom.py script; update README.md and requirements.txt
---
README.md | 137 ++++++++++++++++++++++++++++-------------
infer.py | 88 ---------------------------
infer.sh | 3 -
requirements.txt | 24 ++++++--
songbloom.py | 155 +++++++++++++++++++++++++++++++++++++++++++++++
5 files changed, 267 insertions(+), 140 deletions(-)
delete mode 100644 infer.py
delete mode 100644 infer.sh
create mode 100644 songbloom.py
diff --git a/README.md b/README.md
index 96722d5..0b55cd4 100644
--- a/README.md
+++ b/README.md
@@ -1,83 +1,134 @@
-# [SongBloom]: *Coherent Song Generation via Interleaved Autoregressive Sketching and Diffusion Refinement*
-We propose **SongBloom**, a novel framework for full-length song generation that leverages an interleaved paradigm of autoregressive sketching and diffusion-based refinement. SongBloom employs an autoregressive diffusion model that combines the high fidelity of diffusion models with the scalability of language models.
-Specifically, it gradually extends a musical sketch from short to long and refines the details from coarse to fine-grained. The interleaved generation paradigm effectively integrates prior semantic and acoustic context to guide the generation process.
-Experimental results demonstrate that SongBloom outperforms existing methods across both subjective and objective metrics and achieves performance comparable to the state-of-the-art commercial music generation platforms.
+# SongBloom: Coherent Song Generation via Interleaved Autoregressive Sketching and Diffusion Refinement
-
+
-Demo page: [https://cypress-yang.github.io/SongBloom_demo](https://cypress-yang.github.io/SongBloom_demo)
+[arXiv](https://arxiv.org/abs/2506.07634)
+[Hugging Face](https://huggingface.co/CypressYang/SongBloom)
+[License: Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0)
-ArXiv: [https://arxiv.org/abs/2506.07634](https://arxiv.org/abs/2506.07634)
+
-## Prepare Environments
+We propose **SongBloom**, a novel framework for full-length song generation that leverages an interleaved paradigm of autoregressive sketching and diffusion-based refinement. By combining a high-fidelity diffusion model with a scalable language model, SongBloom gradually extends a musical sketch from short to long and refines details from coarse to fine-grained.
+
+This interleaved paradigm effectively integrates prior semantic and acoustic context to guide the generation process, achieving state-of-the-art results in coherent, full-length song creation.
+
+### ▶️ [**Check out the Demos**](https://cypress-yang.github.io/SongBloom_demo/)
+
+
+
+## 🚀 Getting Started
+
+Follow these three simple steps to generate your first song with SongBloom.
+
+### Step 1: Set Up Your Environment
+
+First, clone the repository and set up the Conda environment.
```bash
+# Clone the repository
+git clone https://github.com/Cypress-Yang/SongBloom.git
+cd SongBloom
+
+# Create and activate the Conda environment
conda create -n SongBloom python==3.8.12
conda activate SongBloom
-# yum install libsndfile
-# pip install torch==2.2.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu118 # For different CUDA version
+# On Linux, you may need to install libsndfile first:
+#   Debian/Ubuntu: sudo apt-get install libsndfile1
+#   RHEL/CentOS:   sudo yum install libsndfile
+
+# Install all required Python packages
pip install -r requirements.txt
```
+> **Note:** The `requirements.txt` file includes a specific version of PyTorch for CUDA 11.8. If you have a different CUDA version, please install the appropriate PyTorch and Torchaudio binaries from the [official site](https://pytorch.org/get-started/previous-versions/).
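+>
+> For example, a minimal sketch for a CUDA 12.1 machine (swap `cu121` for your CUDA version), run before `pip install -r requirements.txt`:
+>
+> ```bash
+> pip install torch==2.2.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu121
+> ```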
-## Data Preparation
-A .jsonl file, where each line is a json object:
+### Step 2: Prepare Your Songs (`.songbloom` file)
-```json
-{
- "idx": "The index of each sample",
- "lyrics": "The lyrics to be generated",
- "prompt_wav": "The path of the style prompt audio",
-}
-```
+Instead of complex command-line arguments, you define your songs in a simple `.songbloom` file using the human-readable TOML format. Create a file like `my_songs.songbloom`:
-One example can be refered to as: [example/test.jsonl](example/test.jsonl)
+```toml
+# File: my_songs.songbloom
+# Define one or more songs to generate.
-The prompt wav should be a 10-second, 48kHz audio clip.
+[sunset_lullaby]
+lyrics = "the sun is setting low and the stars are starting to glow"
+prompt_wav = "prompts/my_awesome_prompt.wav" # 10s, 48kHz audio clip
+n_samples = 2 # Optional: Number of variations to generate (default is 1)
+output_name = "sunset_song_final" # Optional: Base name for the output files
-The details about lyric format can be found in [docs/lyric_format.md](docs/lyric_format.md).
+[city_rhythm]
+lyrics = "walking through the city streets with a rhythm in my feet"
+prompt_wav = "inputs/city_beat.mp3"
+```
+* The prompt audio should ideally be a **10-second, 48kHz** audio clip (see the conversion sketch below).
+* For details on lyric formatting, see [`docs/lyric_format.md`](docs/lyric_format.md).
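+
+If your prompt clip is longer than 10 seconds or recorded at a different sample rate, a standard `ffmpeg` call can trim and resample it. A minimal sketch, assuming `ffmpeg` is installed and `input.mp3` is a hypothetical source file:
+
+```bash
+# Keep the first 10 seconds and resample to 48 kHz
+ffmpeg -i input.mp3 -t 10 -ar 48000 prompts/my_awesome_prompt.wav
+```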
-## Inference
+### Step 3: Generate Music!
-```bash
-source set_env.sh
+Now, run the main script, pointing it to your configuration file. The model and necessary assets will be downloaded automatically on the first run.
-python3 infer.py --input-jsonl example/test.jsonl
+```bash
+# Basic usage
+python3 songbloom.py my_songs.songbloom
-# For GPUs with low VRAM like RTX4090, you should set the dtype as bfloat16
-python3 infer.py --input-jsonl example/test.jsonl --dtype bfloat16
+# Specify a different output directory
+python3 songbloom.py my_songs.songbloom --output-dir "path/to/my/music"
-# SongBloom also supports flash-attn (optional). To enable it, please install flash-attn (v2.6.3 is used during training) manually and set os.environ['DISABLE_FLASH_ATTN'] = "0" in infer.py:8
+# For GPUs with lower VRAM (e.g., RTX 4090), use bfloat16 to reduce memory usage
+python3 songbloom.py my_songs.songbloom --dtype bfloat16
```
+> **Flash Attention**: To enable flash-attn for a potential speed-up, install the library manually and change `DISABLE_FLASH_ATTN` from `"1"` to `"0"` at the top of `songbloom.py`.
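+>
+> Concretely, the toggle is a single line near the top of `songbloom.py`:
+>
+> ```python
+> os.environ['DISABLE_FLASH_ATTN'] = "0"  # default is "1", which disables flash-attn
+> ```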
+
-## Models
+## 📦 Models
-| Name | Size | Max Length | Prompt type | 🤗 |
+All models are available on the [Hugging Face Hub](https://huggingface.co/CypressYang/SongBloom).
+
+| Name | Size | Max Length | Prompt type | Link |
| -------------------- | ---- | ---------- | ----------- | -------------------------------------------- |
-| songbloom_full_150s | 2B | 2m30s | 10s wav | [link](https://huggingface.co/CypressYang/SongBloom) |
-| songbloom_mulan_150s | 2B | 2m30s | 10s wav / text description | coming soon |
-| ... | | | | |
+| `songbloom_full_150s` | 2B | 2m 30s | 10s wav | [🤗 HF Repo](https://huggingface.co/CypressYang/SongBloom) |
+| `songbloom_mulan_150s` | 2B | 2m 30s | 10s wav / text | *Coming Soon* |
+
+## 📝 TODO List
+
+- [ ] Support Text Description Prompts
+- [ ] Release full-length model version
+
+
+## 📈 Star History
+
+[Star History Chart](https://star-history.com/#Cypress-Yang/SongBloom&Date)
-## TODO List
+
-- [ ] Support Text Description
-- [ ] Full version
+## ✨ Contributors
+
+A huge thank you to all the amazing people who have contributed to this project!
+
+
## Citation
-```
+If you find SongBloom useful in your research, please cite our paper:
+
+```bibtex
@article{yang2025songbloom,
-title={SongBloom: Coherent Song Generation via Interleaved Autoregressive Sketching and Diffusion Refinement},
-author={Yang, Chenyu and Wang, Shuai and Chen, Hangting and Tan, Wei and Yu, Jianwei and Li, Haizhou},
-journal={arXiv preprint arXiv:2506.07634},
-year={2025}
+ title={SongBloom: Coherent Song Generation via Interleaved Autoregressive Sketching and Diffusion Refinement},
+ author={Yang, Chenyu and Wang, Shuai and Chen, Hangting and Tan, Wei and Yu, Jianwei and Li, Haizhou},
+ journal={arXiv preprint arXiv:2506.07634},
+ year={2025}
}
```
## License
-SongBloom (codes and weights) is released under the [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0).
+The code and model weights for SongBloom are released under the [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0).
\ No newline at end of file
diff --git a/infer.py b/infer.py
deleted file mode 100644
index 5af9fb5..0000000
--- a/infer.py
+++ /dev/null
@@ -1,88 +0,0 @@
-import os, sys
-import torch, torchaudio
-import argparse
-import json
-from omegaconf import MISSING, OmegaConf,DictConfig
-from huggingface_hub import hf_hub_download
-
-os.environ['DISABLE_FLASH_ATTN'] = "1"
-from SongBloom.models.songbloom.songbloom_pl import SongBloom_Sampler
-
-
-def hf_download(repo_id="CypressYang/SongBloom", model_name="songbloom_full_150s", local_dir="./cache", **kwargs):
- cfg_path = hf_hub_download(
- repo_id=repo_id, filename=f"{model_name}.yaml", local_dir=local_dir, **kwargs)
- ckpt_path = hf_hub_download(
- repo_id=repo_id, filename=f"{model_name}.pt", local_dir=local_dir, **kwargs)
-
- vae_cfg_path = hf_hub_download(
- repo_id=repo_id, filename="stable_audio_1920_vae.json", local_dir=local_dir, **kwargs)
- vae_ckpt_path = hf_hub_download(
- repo_id=repo_id, filename="autoencoder_music_dsp1920.ckpt", local_dir=local_dir, **kwargs)
-
- g2p_path = hf_hub_download(
- repo_id=repo_id, filename="vocab_g2p.yaml", local_dir=local_dir, **kwargs)
-
-
-
-
-
-
-def load_config(cfg_file, parent_dir="./") -> DictConfig:
- OmegaConf.register_new_resolver("eval", lambda x: eval(x))
- OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx])
- OmegaConf.register_new_resolver("get_fname", lambda x: os.path.splitext(os.path.basename(x))[0])
- OmegaConf.register_new_resolver("load_yaml", lambda x: OmegaConf.load(x))
- OmegaConf.register_new_resolver("dynamic_path", lambda x: x.replace("???", parent_dir))
- # cmd_cfg = OmegaConf.from_cli()
-
- file_cfg = OmegaConf.load(open(cfg_file, 'r')) if cfg_file is not None \
- else OmegaConf.create()
-
-
- return file_cfg
-
-
-
-def main():
- parser = argparse.ArgumentParser()
- parser.add_argument("--repo-id", type=str, default="CypressYang/SongBloom")
- parser.add_argument("--model-name", type=str, default="songbloom_full_150s")
- parser.add_argument("--local-dir", type=str, default="./cache")
- parser.add_argument("--input-jsonl", type=str, required=True)
- parser.add_argument("--output-dir", type=str, default="./output")
- parser.add_argument("--n-samples", type=int, default=2)
- parser.add_argument("--dtype", type=str, default='float32', choices=['float32', 'bfloat16'])
-
- args = parser.parse_args()
-
- hf_download(args.repo_id, args.model_name, args.local_dir)
- cfg = load_config(f"{args.local_dir}/{args.model_name}.yaml", parent_dir=args.local_dir)
-
- dtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
- model = SongBloom_Sampler.build_from_trainer(cfg, strict=True, dtype=dtype)
- model.set_generation_params(**cfg.inference)
-
- os.makedirs(args.output_dir, exist_ok=True)
-
- input_lines = open(args.input_jsonl, 'r').readlines()
- input_lines = [json.loads(l.strip()) for l in input_lines]
-
- for test_sample in input_lines:
- # print(test_sample)
- idx, lyrics, prompt_wav = test_sample["idx"], test_sample["lyrics"], test_sample["prompt_wav"]
-
- prompt_wav, sr = torchaudio.load(prompt_wav)
- if sr != model.sample_rate:
- prompt_wav = torchaudio.functional.resample(prompt_wav, sr, model.sample_rate)
- prompt_wav = prompt_wav.mean(dim=0, keepdim=True).to(dtype)
- prompt_wav = prompt_wav[..., :10*model.sample_rate]
- # breakpoint()
- for i in range(args.n_samples):
- wav = model.generate(lyrics, prompt_wav)
- torchaudio.save(f'{args.output_dir}/{idx}_s{i}.flac', wav[0].cpu().float(), model.sample_rate)
-
-
-if __name__ == "__main__":
-
- main()
diff --git a/infer.sh b/infer.sh
deleted file mode 100644
index 70d6d98..0000000
--- a/infer.sh
+++ /dev/null
@@ -1,3 +0,0 @@
-source set_env.sh
-
-python3 infer.py --input-jsonl example/test.jsonl
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 68327ff..f8915f1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,17 +1,29 @@
+# Core dependencies for PyTorch and audio processing
+torch==2.2.0
+torchaudio==2.2.0
+
+# Model and experiment management
+lightning==2.2.1
huggingface-hub==0.24.6
+transformers==4.44.1
+omegaconf==2.2.0
+
+# Configuration file parsing for the user-friendly interface
+toml
+
+# NLP and text processing libraries
+# Chinese language support
jieba-fast==0.53
pypinyin==0.51.0
cn2an==0.5.22
+# English language support
wordsegment==1.3.1
g2p-en==2.1.0
-lightning==2.2.1
nltk==3.8.1
-omegaconf==2.2.0
-torch==2.2.0
-torchaudio==2.2.0
-transformers==4.44.1
-einops==0.8.0
spacy==3.7.4
num2words==0.5.13
+
+# Tensor manipulation and audio codecs
+einops==0.8.0
descript-audio-codec==1.0.0
vector_quantize_pytorch==1.14.8
\ No newline at end of file
diff --git a/songbloom.py b/songbloom.py
new file mode 100644
index 0000000..9142ff7
--- /dev/null
+++ b/songbloom.py
@@ -0,0 +1,155 @@
+import os
+import argparse
+import json
+import torch
+import torchaudio
+from omegaconf import OmegaConf, DictConfig
+from huggingface_hub import hf_hub_download
+import toml # Using the toml library for the new config format
+
+# Set DISABLE_FLASH_ATTN before the SongBloom import below so it takes effect; change to "0" to enable flash-attn
+os.environ['DISABLE_FLASH_ATTN'] = "1"
+
+from SongBloom.models.songbloom.songbloom_pl import SongBloom_Sampler
+
+def hf_download(repo_id="CypressYang/SongBloom", model_name="songbloom_full_150s", local_dir="./cache", **kwargs):
+ """
+ Downloads model and configuration files from Hugging Face Hub.
+ Prints status messages for a better user experience.
+ """
+ print("Downloading necessary files from Hugging Face Hub...")
+
+ cfg_path = hf_hub_download(
+ repo_id=repo_id, filename=f"{model_name}.yaml", local_dir=local_dir, **kwargs
+ )
+ ckpt_path = hf_hub_download(
+ repo_id=repo_id, filename=f"{model_name}.pt", local_dir=local_dir, **kwargs
+ )
+ vae_cfg_path = hf_hub_download(
+ repo_id=repo_id, filename="stable_audio_1920_vae.json", local_dir=local_dir, **kwargs
+ )
+ vae_ckpt_path = hf_hub_download(
+ repo_id=repo_id, filename="autoencoder_music_dsp1920.ckpt", local_dir=local_dir, **kwargs
+ )
+ g2p_path = hf_hub_download(
+ repo_id=repo_id, filename="vocab_g2p.yaml", local_dir=local_dir, **kwargs
+ )
+
+ print("All files downloaded successfully.")
+ return cfg_path
+
+def load_model_config(cfg_file, parent_dir="./") -> DictConfig:
+ """
+ Loads and resolves the OmegaConf configuration.
+ """
+ OmegaConf.register_new_resolver("eval", eval)
+ OmegaConf.register_new_resolver("concat", lambda *x: [item for sublist in x for item in sublist])
+ OmegaConf.register_new_resolver("get_fname", lambda x: os.path.splitext(os.path.basename(x))[0])
+ OmegaConf.register_new_resolver("load_yaml", OmegaConf.load)
+ OmegaConf.register_new_resolver("dynamic_path", lambda x: x.replace("???", parent_dir))
+
+ return OmegaConf.load(cfg_file)
+
+def main():
+ """
+ Main function to drive the song generation process.
+ """
+ parser = argparse.ArgumentParser(
+ description="Generate songs with SongBloom using a user-friendly .songbloom configuration file."
+ )
+ parser.add_argument(
+ "input_file",
+ type=str,
+ help="Path to your .songbloom configuration file."
+ )
+ parser.add_argument(
+ "--output-dir",
+ type=str,
+ default="./output",
+ help="Directory to save the generated audio files."
+ )
+ parser.add_argument(
+ "--repo-id",
+ type=str,
+ default="CypressYang/SongBloom",
+ help="Hugging Face repository ID for the model."
+ )
+ parser.add_argument(
+ "--model-name",
+ type=str,
+ default="songbloom_full_150s",
+ help="The name of the model to use."
+ )
+ parser.add_argument(
+ "--local-dir",
+ type=str,
+ default="./cache",
+ help="Local directory to cache downloaded models."
+ )
+ parser.add_argument(
+ "--dtype",
+ type=str,
+ default='float32',
+ choices=['float32', 'bfloat16'],
+ help="Data type for model inference."
+ )
+
+ args = parser.parse_args()
+
+    if not os.path.exists(args.input_file):
+        raise SystemExit(f"Error: Input file not found at '{args.input_file}'")
+
+ # --- Model Loading ---
+ print("--- Initializing SongBloom ---")
+ cfg_path = hf_download(args.repo_id, args.model_name, args.local_dir)
+ cfg = load_model_config(cfg_path, parent_dir=args.local_dir)
+
+ dtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
+ print(f"Loading model with {args.dtype} precision...")
+ model = SongBloom_Sampler.build_from_trainer(cfg, strict=True, dtype=dtype)
+ model.set_generation_params(**cfg.inference)
+ print("Model loaded successfully.")
+
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ # --- Song Generation ---
+ print(f"\n--- Loading songs from {args.input_file} ---")
+ song_requests = toml.load(args.input_file)
+
+    for song_name, details in song_requests.items():
+        # Skip top-level entries that are not [song] tables
+        if not isinstance(details, dict):
+            continue
+ print(f"\nProcessing song: '{song_name}'")
+
+ lyrics = details.get("lyrics")
+ prompt_wav_path = details.get("prompt_wav")
+ output_name = details.get("output_name", song_name) # Default to song name if not provided
+ n_samples = details.get("n_samples", 1)
+
+ if not lyrics or not prompt_wav_path:
+ print(f" -> Skipping '{song_name}' due to missing 'lyrics' or 'prompt_wav'.")
+ continue
+
+ try:
+ prompt_wav, sr = torchaudio.load(prompt_wav_path)
+        except (FileNotFoundError, RuntimeError) as e:
+            print(f"  -> Skipping '{song_name}': could not load prompt audio '{prompt_wav_path}' ({e})")
+ continue
+
+ if sr != model.sample_rate:
+ prompt_wav = torchaudio.functional.resample(prompt_wav, sr, model.sample_rate)
+
+ prompt_wav = prompt_wav.mean(dim=0, keepdim=True).to(dtype)
+ prompt_wav = prompt_wav[..., :10 * model.sample_rate]
+
+ for i in range(n_samples):
+ print(f" -> Generating sample {i + 1} of {n_samples}...")
+ wav = model.generate(lyrics, prompt_wav)
+
+ output_filename = f'{args.output_dir}/{output_name}_sample{i + 1}.flac'
+ torchaudio.save(output_filename, wav[0].cpu().float(), model.sample_rate)
+ print(f" -> Saved to {output_filename}")
+
+ print("\n--- All songs processed. ---")
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file