From ec895a6609528521b884514c0dbce68105b7c301 Mon Sep 17 00:00:00 2001
From: BlackTechX <155609122+BlackTechX011@users.noreply.github.com>
Date: Fri, 15 Aug 2025 09:14:54 +0000
Subject: [PATCH] modified: README.md
 deleted: infer.py
 deleted: infer.sh
 modified: requirements.txt
 new file: songbloom.py
---
 README.md        | 137 ++++++++++++++++++++++++++++-------------
 infer.py         |  88 ---------------------------
 infer.sh         |   3 -
 requirements.txt |  24 ++++++--
 songbloom.py     | 155 +++++++++++++++++++++++++++++++++++++++++++++++
 5 files changed, 267 insertions(+), 140 deletions(-)
 delete mode 100644 infer.py
 delete mode 100644 infer.sh
 create mode 100644 songbloom.py

diff --git a/README.md b/README.md
index 96722d5..0b55cd4 100644
--- a/README.md
+++ b/README.md
@@ -1,83 +1,134 @@
-# [SongBloom]: *Coherent Song Generation via Interleaved Autoregressive Sketching and Diffusion Refinement*
-We propose **SongBloom**, a novel framework for full-length song generation that leverages an interleaved paradigm of autoregressive sketching and diffusion-based refinement. SongBloom employs an autoregressive diffusion model that combines the high fidelity of diffusion models with the scalability of language models.
-Specifically, it gradually extends a musical sketch from short to long and refines the details from coarse to fine-grained. The interleaved generation paradigm effectively integrates prior semantic and acoustic context to guide the generation process.
-Experimental results demonstrate that SongBloom outperforms existing methods across both subjective and objective metrics and achieves performance comparable to the state-of-the-art commercial music generation platforms.
+# SongBloom: Coherent Song Generation via Interleaved Autoregressive Sketching and Diffusion Refinement
 
-![img](docs/architecture.png)
 
+<div align="center">
+
-Demo page: [https://cypress-yang.github.io/SongBloom_demo](https://cypress-yang.github.io/SongBloom_demo)
+[![Paper](https://img.shields.io/badge/arXiv-2506.07634-b31b1b.svg)](https://arxiv.org/abs/2506.07634)
+[![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-yellow)](https://huggingface.co/CypressYang/SongBloom)
+[![License: Apache 2.0](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://www.apache.org/licenses/LICENSE-2.0)
 
-ArXiv: [https://arxiv.org/abs/2506.07634](https://arxiv.org/abs/2506.07634)
+
+</div>
 
+We propose **SongBloom**, a novel framework for full-length song generation that leverages an interleaved paradigm of autoregressive sketching and diffusion-based refinement. By combining a high-fidelity diffusion model with a scalable language model, SongBloom gradually extends a musical sketch from short to long and refines details from coarse to fine-grained.
+
+This interleaved paradigm effectively integrates prior semantic and acoustic context to guide the generation process, achieving state-of-the-art results in coherent, full-length song creation.
+
+### ▶️ [**Check out the Demos**](https://cypress-yang.github.io/SongBloom_demo/)
+
+![SongBloom Architecture](docs/architecture.png)
+
+## 🚀 Getting Started
+
+Follow these three simple steps to generate your first song with SongBloom.
+
-## Prepare Environments
+### Step 1: Set Up Your Environment
+
+First, clone the repository and set up the Conda environment.
 ```bash
+# Clone the repository
+git clone https://github.com/Cypress-Yang/SongBloom.git
+cd SongBloom
+
+# Create and activate the Conda environment
 conda create -n SongBloom python==3.8.12
 conda activate SongBloom
-# yum install libsndfile
-# pip install torch==2.2.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu118 # For different CUDA version
+
+# On Linux, you may need to install libsndfile first:
+#   sudo apt-get install libsndfile1   (Debian/Ubuntu)
+#   sudo yum install libsndfile        (RHEL/CentOS)
+
+# Install all required Python packages
 pip install -r requirements.txt
 ```
+
+> **Note:** `requirements.txt` pins specific PyTorch and Torchaudio versions. If you need builds for a different CUDA version, install the appropriate PyTorch and Torchaudio binaries from the [official site](https://pytorch.org/get-started/previous-versions/).
 
-## Data Preparation
-A .jsonl file, where each line is a json object:
+### Step 2: Prepare Your Songs (`.songbloom` file)
 
-```json
-{
-    "idx": "The index of each sample",
-    "lyrics": "The lyrics to be generated",
-    "prompt_wav": "The path of the style prompt audio",
-}
-```
+Instead of complex command-line arguments, you define your songs in a simple `.songbloom` file using the human-readable TOML format. Create a file like `my_songs.songbloom`:
 
-One example can be refered to as: [example/test.jsonl](example/test.jsonl)
-
-The prompt wav should be a 10-second, 48kHz audio clip.
-
-The details about lyric format can be found in [docs/lyric_format.md](docs/lyric_format.md).
+```toml
+# File: my_songs.songbloom
+# Define one or more songs to generate.
+
+[sunset_lullaby]
+lyrics = "the sun is setting low and the stars are starting to glow"
+prompt_wav = "prompts/my_awesome_prompt.wav" # 10s, 48kHz audio clip
+n_samples = 2 # Optional: Number of variations to generate (default is 1)
+output_name = "sunset_song_final" # Optional: Filename for the output
+
+[city_rhythm]
+lyrics = "walking through the city streets with a rhythm in my feet"
+prompt_wav = "inputs/city_beat.mp3"
+```
+
+* The prompt audio should ideally be a **10-second, 48kHz** audio clip.
+* For details on lyric formatting, see [`docs/lyric_format.md`](docs/lyric_format.md).
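+
+Under the hood, `songbloom.py` reads this file with the `toml` package: every top-level table becomes one generation request, and `n_samples` / `output_name` fall back to defaults when omitted. A minimal, illustrative sketch of that parsing step (not the full script):
+
+```python
+import toml
+
+# Parse the .songbloom file; each [table] is one song request.
+song_requests = toml.load("my_songs.songbloom")
+
+for song_name, details in song_requests.items():
+    lyrics = details.get("lyrics")                        # required
+    prompt_wav = details.get("prompt_wav")                # required
+    n_samples = details.get("n_samples", 1)               # optional, default 1
+    output_name = details.get("output_name", song_name)   # optional, defaults to the table name
+    print(song_name, lyrics, prompt_wav, n_samples, output_name)
+```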
 
-## Inference
+### Step 3: Generate Music!
 
+Now, run the main script, pointing it to your configuration file. The model and necessary assets will be downloaded automatically on the first run.
+
 ```bash
-source set_env.sh
+# Basic usage
+python3 songbloom.py my_songs.songbloom
 
-python3 infer.py --input-jsonl example/test.jsonl
+# Specify a different output directory
+python3 songbloom.py my_songs.songbloom --output-dir "path/to/my/music"
 
-# For GPUs with low VRAM like RTX4090, you should set the dtype as bfloat16
-python3 infer.py --input-jsonl example/test.jsonl --dtype bfloat16
+# For GPUs with lower VRAM (e.g., RTX 4090), use bfloat16 to reduce memory usage
+python3 songbloom.py my_songs.songbloom --dtype bfloat16
 
-# SongBloom also supports flash-attn (optional). To enable it, please install flash-attn (v2.6.3 is used during training) manually and set os.environ['DISABLE_FLASH_ATTN'] = "0" in infer.py:8
 ```
+
+> **Flash Attention**: To enable flash-attn for a potential speed-up, install the library manually (v2.6.3 was used during training) and change `DISABLE_FLASH_ATTN` from `"1"` to `"0"` at the top of `songbloom.py`.
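+
+You can also drive the same pipeline from your own Python code. The snippet below is a condensed, illustrative sketch of what `songbloom.py` does internally (download assets, load the config, build the sampler, generate); see `songbloom.py` for the full version with error handling:
+
+```python
+import torch
+import torchaudio
+from songbloom import hf_download, load_model_config
+from SongBloom.models.songbloom.songbloom_pl import SongBloom_Sampler
+
+# Download weights/configs (cached after the first run) and build the model.
+cfg = load_model_config(hf_download(local_dir="./cache"), parent_dir="./cache")
+model = SongBloom_Sampler.build_from_trainer(cfg, strict=True, dtype=torch.float32)
+model.set_generation_params(**cfg.inference)
+
+# Prepare a mono style prompt, resampled and trimmed to the first 10 seconds.
+wav, sr = torchaudio.load("prompts/my_awesome_prompt.wav")
+wav = torchaudio.functional.resample(wav, sr, model.sample_rate)
+wav = wav.mean(dim=0, keepdim=True)[..., :10 * model.sample_rate]
+
+# Generate one sample and save it.
+song = model.generate("the sun is setting low and the stars are starting to glow", wav)
+torchaudio.save("sunset.flac", song[0].cpu().float(), model.sample_rate)
+```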
+
-## Models
+## 📦 Models
 
-| Name | Size | Max Length | Prompt type | 🤗 |
+All models are available on the [Hugging Face Hub](https://huggingface.co/CypressYang/SongBloom).
+
+| Name | Size | Max Length | Prompt type | Link |
 | -------------------- | ---- | ---------- | ----------- | -------------------------------------------- |
-| songbloom_full_150s | 2B | 2m30s | 10s wav | [link](https://huggingface.co/CypressYang/SongBloom) |
-| songbloom_mulan_150s | 2B | 2m30s | 10s wav / text description | coming soon |
-| ... | | | | |
+| `songbloom_full_150s` | 2B | 2m 30s | 10s wav | [🤗 HF Repo](https://huggingface.co/CypressYang/SongBloom) |
+| `songbloom_mulan_150s` | 2B | 2m 30s | 10s wav / text | *Coming Soon* |
 
-## TODO List
+## 📝 TODO List
 
-- [ ] Support Text Description
-- [ ] Full version
+- [ ] Support Text Description Prompts
+- [ ] Release full-length model version
+
+## 📈 Star History
+
+<div align="center">
+
+[![Star History Chart](https://api.star-history.com/svg?repos=Cypress-Yang/SongBloom&type=Date)](https://star-history.com/#Cypress-Yang/SongBloom&Date)
+
+</div>
+
+## ✨ Contributors
+
+A huge thank you to all the amazing people who have contributed to this project!
+
+<div align="center">
+  <a href="https://github.com/Cypress-Yang/SongBloom/graphs/contributors">
+    <img src="https://contrib.rocks/image?repo=Cypress-Yang/SongBloom" />
+  </a>
+</div>
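+
+## 🔄 Migrating from `infer.py`
+
+This release replaces the old `infer.py` / `infer.sh` JSONL workflow with the `.songbloom` format. If you still have a JSONL song list (one JSON object with `idx`, `lyrics`, and `prompt_wav` per line, such as the `example/test.jsonl` sample shipped with earlier versions), a short script like this illustrative sketch converts it:
+
+```python
+import json
+import toml
+
+# Read the old JSONL song list (one JSON object per line)...
+with open("example/test.jsonl") as f:
+    samples = [json.loads(line) for line in f if line.strip()]
+
+# ...and write each entry as a [table] in a .songbloom (TOML) file.
+songs = {s["idx"]: {"lyrics": s["lyrics"], "prompt_wav": s["prompt_wav"]}
+         for s in samples}
+with open("my_songs.songbloom", "w") as f:
+    toml.dump(songs, f)
+```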
 
 ## Citation
 
+If you find SongBloom useful in your research, please cite our paper:
+
-```
+```bibtex
 @article{yang2025songbloom,
-title={SongBloom: Coherent Song Generation via Interleaved Autoregressive Sketching and Diffusion Refinement},
-author={Yang, Chenyu and Wang, Shuai and Chen, Hangting and Tan, Wei and Yu, Jianwei and Li, Haizhou},
-journal={arXiv preprint arXiv:2506.07634},
-year={2025}
+  title={SongBloom: Coherent Song Generation via Interleaved Autoregressive Sketching and Diffusion Refinement},
+  author={Yang, Chenyu and Wang, Shuai and Chen, Hangting and Tan, Wei and Yu, Jianwei and Li, Haizhou},
+  journal={arXiv preprint arXiv:2506.07634},
+  year={2025}
 }
 ```
 
 ## License
 
-SongBloom (codes and weights) is released under the [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0).
+The code and model weights for SongBloom are released under the [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0).
\ No newline at end of file
diff --git a/infer.py b/infer.py
deleted file mode 100644
index 5af9fb5..0000000
--- a/infer.py
+++ /dev/null
@@ -1,88 +0,0 @@
-import os, sys
-import torch, torchaudio
-import argparse
-import json
-from omegaconf import MISSING, OmegaConf,DictConfig
-from huggingface_hub import hf_hub_download
-
-os.environ['DISABLE_FLASH_ATTN'] = "1"
-from SongBloom.models.songbloom.songbloom_pl import SongBloom_Sampler
-
-
-def hf_download(repo_id="CypressYang/SongBloom", model_name="songbloom_full_150s", local_dir="./cache", **kwargs):
-    cfg_path = hf_hub_download(
-        repo_id=repo_id, filename=f"{model_name}.yaml", local_dir=local_dir, **kwargs)
-    ckpt_path = hf_hub_download(
-        repo_id=repo_id, filename=f"{model_name}.pt", local_dir=local_dir, **kwargs)
-
-    vae_cfg_path = hf_hub_download(
-        repo_id=repo_id, filename="stable_audio_1920_vae.json", local_dir=local_dir, **kwargs)
-    vae_ckpt_path = hf_hub_download(
-        repo_id=repo_id, filename="autoencoder_music_dsp1920.ckpt", local_dir=local_dir, **kwargs)
-
-    g2p_path = hf_hub_download(
-        repo_id=repo_id, filename="vocab_g2p.yaml", local_dir=local_dir, **kwargs)
-
-
-
-
-
-
-def load_config(cfg_file, parent_dir="./") -> DictConfig:
-    OmegaConf.register_new_resolver("eval", lambda x: eval(x))
-    OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx])
-    OmegaConf.register_new_resolver("get_fname", lambda x: os.path.splitext(os.path.basename(x))[0])
-    OmegaConf.register_new_resolver("load_yaml", lambda x: OmegaConf.load(x))
-    OmegaConf.register_new_resolver("dynamic_path", lambda x: x.replace("???", parent_dir))
-    # cmd_cfg = OmegaConf.from_cli()
-
-    file_cfg = OmegaConf.load(open(cfg_file, 'r')) if cfg_file is not None \
-        else OmegaConf.create()
-
-
-    return file_cfg
-
-
-
-def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--repo-id", type=str, default="CypressYang/SongBloom")
-    parser.add_argument("--model-name", type=str, default="songbloom_full_150s")
-    parser.add_argument("--local-dir", type=str, default="./cache")
-    parser.add_argument("--input-jsonl", type=str, required=True)
-    parser.add_argument("--output-dir", type=str, default="./output")
-    parser.add_argument("--n-samples", type=int, default=2)
-    parser.add_argument("--dtype", type=str, default='float32', choices=['float32', 'bfloat16'])
-
-    args = parser.parse_args()
-
-    hf_download(args.repo_id, args.model_name, args.local_dir)
-    cfg = load_config(f"{args.local_dir}/{args.model_name}.yaml", parent_dir=args.local_dir)
-
-    dtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
-    model = SongBloom_Sampler.build_from_trainer(cfg, strict=True, dtype=dtype)
-    model.set_generation_params(**cfg.inference)
-
-    os.makedirs(args.output_dir, exist_ok=True)
-
-    input_lines = open(args.input_jsonl, 'r').readlines()
-    input_lines = [json.loads(l.strip()) for l in input_lines]
-
-    for test_sample in input_lines:
-        # print(test_sample)
-        idx, lyrics, prompt_wav = test_sample["idx"], test_sample["lyrics"], test_sample["prompt_wav"]
-
-        prompt_wav, sr = torchaudio.load(prompt_wav)
-        if sr != model.sample_rate:
-            prompt_wav = torchaudio.functional.resample(prompt_wav, sr, model.sample_rate)
-        prompt_wav = prompt_wav.mean(dim=0, keepdim=True).to(dtype)
-        prompt_wav = prompt_wav[..., :10*model.sample_rate]
-        # breakpoint()
-        for i in range(args.n_samples):
-            wav = model.generate(lyrics, prompt_wav)
-            torchaudio.save(f'{args.output_dir}/{idx}_s{i}.flac', wav[0].cpu().float(), model.sample_rate)
-
-
-if __name__ == "__main__":
-
-    main()
diff --git a/infer.sh b/infer.sh
deleted file mode 100644
index 70d6d98..0000000
--- a/infer.sh
+++ /dev/null
@@ -1,3 +0,0 @@
-source set_env.sh
-
-python3 infer.py --input-jsonl example/test.jsonl
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 68327ff..f8915f1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,17 +1,29 @@
+# Core dependencies for PyTorch and audio processing
+torch==2.2.0
+torchaudio==2.2.0
+
+# Model and experiment management
+lightning==2.2.1
 huggingface-hub==0.24.6
+transformers==4.44.1
+omegaconf==2.2.0
+
+# Configuration file parsing for the user-friendly interface
+toml
+
+# NLP and text processing libraries
+# Chinese language support
 jieba-fast==0.53
 pypinyin==0.51.0
 cn2an==0.5.22
+# English language support
 wordsegment==1.3.1
 g2p-en==2.1.0
-lightning==2.2.1
 nltk==3.8.1
-omegaconf==2.2.0
-torch==2.2.0
-torchaudio==2.2.0
-transformers==4.44.1
-einops==0.8.0
 spacy==3.7.4
 num2words==0.5.13
+
+# Tensor manipulation and audio codecs
+einops==0.8.0
 descript-audio-codec==1.0.0
 vector_quantize_pytorch==1.14.8
\ No newline at end of file
diff --git a/songbloom.py b/songbloom.py
new file mode 100644
index 0000000..9142ff7
--- /dev/null
+++ b/songbloom.py
@@ -0,0 +1,155 @@
+import os
+import argparse
+import json
+import torch
+import torchaudio
+from omegaconf import OmegaConf, DictConfig
+from huggingface_hub import hf_hub_download
+import toml  # Using the toml library for the new config format
+
+# It's good practice to keep environment variable settings at the top
+os.environ['DISABLE_FLASH_ATTN'] = "1"
+
+from SongBloom.models.songbloom.songbloom_pl import SongBloom_Sampler
+
+def hf_download(repo_id="CypressYang/SongBloom", model_name="songbloom_full_150s", local_dir="./cache", **kwargs):
+    """
+    Downloads model and configuration files from Hugging Face Hub.
+    Prints status messages for a better user experience.
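+    Returns the local path of the model's YAML configuration file.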
+    """
+    print("Downloading necessary files from Hugging Face Hub...")
+
+    cfg_path = hf_hub_download(
+        repo_id=repo_id, filename=f"{model_name}.yaml", local_dir=local_dir, **kwargs
+    )
+    ckpt_path = hf_hub_download(
+        repo_id=repo_id, filename=f"{model_name}.pt", local_dir=local_dir, **kwargs
+    )
+    vae_cfg_path = hf_hub_download(
+        repo_id=repo_id, filename="stable_audio_1920_vae.json", local_dir=local_dir, **kwargs
+    )
+    vae_ckpt_path = hf_hub_download(
+        repo_id=repo_id, filename="autoencoder_music_dsp1920.ckpt", local_dir=local_dir, **kwargs
+    )
+    g2p_path = hf_hub_download(
+        repo_id=repo_id, filename="vocab_g2p.yaml", local_dir=local_dir, **kwargs
+    )
+
+    print("All files downloaded successfully.")
+    return cfg_path
+
+def load_model_config(cfg_file, parent_dir="./") -> DictConfig:
+    """
+    Loads and resolves the OmegaConf configuration.
+    """
+    OmegaConf.register_new_resolver("eval", eval)
+    OmegaConf.register_new_resolver("concat", lambda *x: [item for sublist in x for item in sublist])
+    OmegaConf.register_new_resolver("get_fname", lambda x: os.path.splitext(os.path.basename(x))[0])
+    OmegaConf.register_new_resolver("load_yaml", OmegaConf.load)
+    OmegaConf.register_new_resolver("dynamic_path", lambda x: x.replace("???", parent_dir))
+
+    return OmegaConf.load(cfg_file)
+
+def main():
+    """
+    Main function to drive the song generation process.
+    """
+    parser = argparse.ArgumentParser(
+        description="Generate songs with SongBloom using a user-friendly .songbloom configuration file."
+    )
+    parser.add_argument(
+        "input_file",
+        type=str,
+        help="Path to your .songbloom configuration file."
+    )
+    parser.add_argument(
+        "--output-dir",
+        type=str,
+        default="./output",
+        help="Directory to save the generated audio files."
+    )
+    parser.add_argument(
+        "--repo-id",
+        type=str,
+        default="CypressYang/SongBloom",
+        help="Hugging Face repository ID for the model."
+    )
+    parser.add_argument(
+        "--model-name",
+        type=str,
+        default="songbloom_full_150s",
+        help="The name of the model to use."
+    )
+    parser.add_argument(
+        "--local-dir",
+        type=str,
+        default="./cache",
+        help="Local directory to cache downloaded models."
+    )
+    parser.add_argument(
+        "--dtype",
+        type=str,
+        default='float32',
+        choices=['float32', 'bfloat16'],
+        help="Data type for model inference."
+    )
+
+    args = parser.parse_args()
+
+    if not os.path.exists(args.input_file):
+        print(f"Error: Input file not found at '{args.input_file}'")
+        return
+
+    # --- Model Loading ---
+    print("--- Initializing SongBloom ---")
+    cfg_path = hf_download(args.repo_id, args.model_name, args.local_dir)
+    cfg = load_model_config(cfg_path, parent_dir=args.local_dir)
+
+    dtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
+    print(f"Loading model with {args.dtype} precision...")
+    model = SongBloom_Sampler.build_from_trainer(cfg, strict=True, dtype=dtype)
+    model.set_generation_params(**cfg.inference)
+    print("Model loaded successfully.")
+
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    # --- Song Generation ---
+    print(f"\n--- Loading songs from {args.input_file} ---")
+    song_requests = toml.load(args.input_file)
+
+    for song_name, details in song_requests.items():
+        print(f"\nProcessing song: '{song_name}'")
+
+        lyrics = details.get("lyrics")
+        prompt_wav_path = details.get("prompt_wav")
+        output_name = details.get("output_name", song_name)  # Default to song name if not provided
+        n_samples = details.get("n_samples", 1)
+
+        if not lyrics or not prompt_wav_path:
+            print(f"  -> Skipping '{song_name}' due to missing 'lyrics' or 'prompt_wav'.")
+            continue
+
+        try:
+            prompt_wav, sr = torchaudio.load(prompt_wav_path)
+        except (FileNotFoundError, RuntimeError):  # torchaudio raises RuntimeError for missing/unreadable files
+            print(f"  -> Skipping '{song_name}': could not load prompt audio at '{prompt_wav_path}'")
+            continue
+
+        if sr != model.sample_rate:
+            prompt_wav = torchaudio.functional.resample(prompt_wav, sr, model.sample_rate)
+
+        prompt_wav = prompt_wav.mean(dim=0, keepdim=True).to(dtype)
+        prompt_wav = prompt_wav[..., :10 * model.sample_rate]
+
+        for i in range(n_samples):
+            print(f"  -> Generating sample {i + 1} of {n_samples}...")
+            wav = model.generate(lyrics, prompt_wav)
+
+            output_filename = f'{args.output_dir}/{output_name}_sample{i + 1}.flac'
+            torchaudio.save(output_filename, wav[0].cpu().float(), model.sample_rate)
+            print(f"  -> Saved to {output_filename}")
+
+    print("\n--- All songs processed. ---")
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file