Commit d32e0cf
[Feature] Offload all text encoders by default (hao-ai-lab#594)
1 parent 3d6a0df


43 files changed: +357 / -256 lines

.github/workflows/pr-test.yml

Lines changed: 3 additions & 3 deletions
```diff
@@ -122,17 +122,17 @@ jobs:
       # Actual tests
       encoder-test:
         - 'fastvideo/v1/models/encoders/**'
-        - 'fastvideo/v1/models/loaders/**'
+        - 'fastvideo/v1/models/loader/**'
         - 'fastvideo/v1/tests/encoders/**'
         - *common-paths
       vae-test:
         - 'fastvideo/v1/models/vaes/**'
-        - 'fastvideo/v1/models/loaders/**'
+        - 'fastvideo/v1/models/loader/**'
         - 'fastvideo/v1/tests/vaes/**'
         - *common-paths
       transformer-test:
         - 'fastvideo/v1/models/dits/**'
-        - 'fastvideo/v1/models/loaders/**'
+        - 'fastvideo/v1/models/loader/**'
         - 'fastvideo/v1/tests/transformers/**'
         - 'fastvideo/v1/layers/**'
         - 'fastvideo/v1/attention/**'
```

examples/inference/basic/basic.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -10,7 +10,7 @@ def main():
     # attempt to identify the optimal arguments.
     generator = VideoGenerator.from_pretrained(
         "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
-        # if num_gpus > 1, FastVideo will automatically handle distributed setup
+        # FastVideo will automatically handle distributed setup
         num_gpus=2,
         use_fsdp_inference=True,
         use_cpu_offload=False
```

fastvideo/v1/configs/models/base.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 from dataclasses import dataclass, field, fields
-from typing import Any, Dict
+from typing import Any, Dict, List, Tuple
 
 from fastvideo.v1.logger import init_logger
 
@@ -12,7 +12,9 @@
 # 3. Any field in ArchConfig is fixed upon initialization, and should be hidden away from users
 @dataclass
 class ArchConfig:
-    pass
+    stacked_params_mapping: List[Tuple[str, str, str]] = field(
+        default_factory=list
+    )  # mapping from huggingface weight names to custom names
 
 
 @dataclass
```
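The new `stacked_params_mapping` field is pure data; the code that consumes it lives under `fastvideo/v1/models/loader/` (the directory renamed above in pr-test.yml). For context, here is a minimal sketch of the vLLM-style convention these `(param_name, shard_name, shard_id)` triples usually follow when a checkpoint is streamed into a model with fused projections. The `load_weights` helper and the `weight_loader` attribute are assumptions for illustration, not FastVideo's actual API:

```python
# Illustrative only: how (param_name, shard_name, shard_id) triples are
# conventionally consumed while loading HF weights into a model whose
# q/k/v (or gate/up) projections are fused into a single parameter.
from typing import Iterable, List, Tuple

import torch
import torch.nn as nn


def load_weights(model: nn.Module,
                 hf_weights: Iterable[Tuple[str, torch.Tensor]],
                 stacked_params_mapping: List[Tuple[str, str, str]]) -> None:
    params = dict(model.named_parameters())
    for name, tensor in hf_weights:
        for param_name, shard_name, shard_id in stacked_params_mapping:
            if shard_name not in name:
                continue
            # e.g. "...self_attn.q_proj.weight" -> "...self_attn.qkv_proj.weight"
            fused = params[name.replace(shard_name, param_name)]
            # assumption: fused params carry a weight_loader that copies this
            # shard into the rows reserved for "q"/"k"/"v" (or 0/1)
            fused.weight_loader(fused, tensor, shard_id)
            break
        else:
            params[name].data.copy_(tensor)  # unstacked weight: direct copy
```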

fastvideo/v1/configs/models/dits/stepvideo.py

Lines changed: 3 additions & 5 deletions
```diff
@@ -5,13 +5,11 @@
 from fastvideo.v1.configs.models.dits.base import DiTArchConfig, DiTConfig
 
 
-def is_blocks(n: str, m) -> bool:
-    return "blocks" in n and str.isdigit(n.split(".")[-1])
-
-
 @dataclass
 class StepVideoArchConfig(DiTArchConfig):
-    _fsdp_shard_conditions: list = field(default_factory=lambda: [is_blocks])
+    _fsdp_shard_conditions: list = field(
+        default_factory=lambda:
+        [lambda n, m: "transformer_blocks" in n and n.split(".")[-1].isdigit()])
 
     _param_names_mapping: dict = field(
         default_factory=lambda: {
```
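Note this is a behavior change, not just a reformat: the removed `is_blocks` matched any dotted name containing "blocks" that ends in a digit, while the new inline lambda only matches names under `transformer_blocks`. A quick worked check of the new predicate (each shard condition takes a module's dotted name and the module object, unused here):

```python
def cond(n: str, m) -> bool:
    # matches names like "transformer_blocks.3" but not "transformer_blocks.norm"
    return "transformer_blocks" in n and n.split(".")[-1].isdigit()

assert cond("transformer_blocks.0", None)          # numbered block: sharded
assert not cond("transformer_blocks.norm", None)   # not a numbered block
assert not cond("blocks.0", None)  # matched the old is_blocks, not the new lambda
```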

fastvideo/v1/configs/models/encoders/base.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -32,8 +32,11 @@ class TextEncoderArchConfig(EncoderArchConfig):
     output_past: bool = True
     scalable_attention: bool = True
     tie_word_embeddings: bool = False
-
+    stacked_params_mapping: List[Tuple[str, str, str]] = field(
+        default_factory=list
+    )  # mapping from huggingface weight names to custom names
     tokenizer_kwargs: Dict[str, Any] = field(default_factory=dict)
+    _fsdp_shard_conditions: list = field(default_factory=lambda: [])
 
     def __post_init__(self) -> None:
         self.tokenizer_kwargs = {
```

fastvideo/v1/configs/models/encoders/clip.py

Lines changed: 25 additions & 1 deletion
```diff
@@ -1,13 +1,21 @@
 # SPDX-License-Identifier: Apache-2.0
 from dataclasses import dataclass, field
-from typing import Optional
+from typing import List, Optional, Tuple
 
 from fastvideo.v1.configs.models.encoders.base import (ImageEncoderArchConfig,
                                                        ImageEncoderConfig,
                                                        TextEncoderArchConfig,
                                                        TextEncoderConfig)
 
 
+def _is_transformer_layer(n: str, m) -> bool:
+    return "layers" in n and str.isdigit(n.split(".")[-1])
+
+
+def _is_embeddings(n: str, m) -> bool:
+    return n.endswith("embeddings")
+
+
 @dataclass
 class CLIPTextArchConfig(TextEncoderArchConfig):
     vocab_size: int = 49408
@@ -27,6 +35,15 @@ class CLIPTextArchConfig(TextEncoderArchConfig):
     bos_token_id: int = 49406
     eos_token_id: int = 49407
     text_len: int = 77
+    stacked_params_mapping: List[Tuple[str, str,
+                                       str]] = field(default_factory=lambda: [
+                                           # (param_name, shard_name, shard_id)
+                                           ("qkv_proj", "q_proj", "q"),
+                                           ("qkv_proj", "k_proj", "k"),
+                                           ("qkv_proj", "v_proj", "v"),
+                                       ])
+    _fsdp_shard_conditions: list = field(
+        default_factory=lambda: [_is_transformer_layer, _is_embeddings])
 
 
 @dataclass
@@ -45,6 +62,13 @@ class CLIPVisionArchConfig(ImageEncoderArchConfig):
     attention_dropout: float = 0.0
     initializer_range: float = 0.02
     initializer_factor: float = 1.0
+    stacked_params_mapping: List[Tuple[str, str,
+                                       str]] = field(default_factory=lambda: [
+                                           # (param_name, shard_name, shard_id)
+                                           ("qkv_proj", "q_proj", "q"),
+                                           ("qkv_proj", "k_proj", "k"),
+                                           ("qkv_proj", "v_proj", "v"),
+                                       ])
 
 
 @dataclass
```
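To see the new CLIP fields in action: each mapping entry rewrites one HF projection name to the fused `qkv_proj` and tags which shard it fills, and the two module-private predicates added above pick out the FSDP shard units. The weight name below is a typical `transformers` CLIPTextModel name, assumed here purely for illustration:

```python
from fastvideo.v1.configs.models.encoders.clip import (_is_embeddings,
                                                       _is_transformer_layer)

# Typical HF CLIPTextModel weight name, assumed for illustration.
mapping = [("qkv_proj", "q_proj", "q"),
           ("qkv_proj", "k_proj", "k"),
           ("qkv_proj", "v_proj", "v")]
name = "text_model.encoder.layers.0.self_attn.k_proj.weight"
for param_name, shard_name, shard_id in mapping:
    if shard_name in name:
        print(name.replace(shard_name, param_name), "->", shard_id)
        # text_model.encoder.layers.0.self_attn.qkv_proj.weight -> k

# The shard predicates select numbered encoder layers and the embeddings:
assert _is_transformer_layer("text_model.encoder.layers.11", None)
assert _is_embeddings("text_model.embeddings", None)
assert not _is_transformer_layer("text_model.final_layer_norm", None)
```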

fastvideo/v1/configs/models/encoders/llama.py

Lines changed: 25 additions & 1 deletion
```diff
@@ -1,11 +1,23 @@
 # SPDX-License-Identifier: Apache-2.0
 from dataclasses import dataclass, field
-from typing import Optional
+from typing import List, Optional, Tuple
 
 from fastvideo.v1.configs.models.encoders.base import (TextEncoderArchConfig,
                                                        TextEncoderConfig)
 
 
+def _is_transformer_layer(n: str, m) -> bool:
+    return "layers" in n and str.isdigit(n.split(".")[-1])
+
+
+def _is_embeddings(n: str, m) -> bool:
+    return n.endswith("embed_tokens")
+
+
+def _is_final_norm(n: str, m) -> bool:
+    return n.endswith("norm")
+
+
 @dataclass
 class LlamaArchConfig(TextEncoderArchConfig):
     vocab_size: int = 32000
@@ -32,6 +44,18 @@ class LlamaArchConfig(TextEncoderArchConfig):
     head_dim: Optional[int] = None
     hidden_state_skip_layer: int = 2
     text_len: int = 256
+    stacked_params_mapping: List[Tuple[str, str, str]] = field(
+        default_factory=lambda: [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),  # type: ignore
+            (".gate_up_proj", ".up_proj", 1),  # type: ignore
+        ])
+    _fsdp_shard_conditions: list = field(
+        default_factory=lambda:
+        [_is_transformer_layer, _is_embeddings, _is_final_norm])
 
 
 @dataclass
```
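The two MLP entries carry integer shard ids (hence the `# type: ignore` against the `Tuple[str, str, str]` annotation). By the usual fused-MLP convention, 0 and 1 select the `gate_proj` and `up_proj` halves of the stacked `gate_up_proj` weight. A minimal sketch of that convention, with small stand-in shapes assumed for illustration:

```python
import torch

# Small stand-in shapes, purely for illustration.
hidden_size, intermediate_size = 8, 16

# gate_up_proj stacks [gate_proj; up_proj] along the output dimension.
gate_up = torch.empty(2 * intermediate_size, hidden_size)


def load_gate_up_shard(fused: torch.Tensor, shard: torch.Tensor,
                       shard_id: int) -> None:
    # shard_id 0 -> gate_proj rows, shard_id 1 -> up_proj rows
    start = shard_id * intermediate_size
    fused[start:start + intermediate_size].copy_(shard)


load_gate_up_shard(gate_up, torch.randn(intermediate_size, hidden_size), 0)
load_gate_up_shard(gate_up, torch.randn(intermediate_size, hidden_size), 1)
```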

fastvideo/v1/configs/models/encoders/t5.py

Lines changed: 23 additions & 1 deletion
```diff
@@ -1,11 +1,23 @@
 # SPDX-License-Identifier: Apache-2.0
 from dataclasses import dataclass, field
-from typing import Optional
+from typing import List, Optional, Tuple
 
 from fastvideo.v1.configs.models.encoders.base import (TextEncoderArchConfig,
                                                        TextEncoderConfig)
 
 
+def _is_transformer_layer(n: str, m) -> bool:
+    return "block" in n and str.isdigit(n.split(".")[-1])
+
+
+def _is_embeddings(n: str, m) -> bool:
+    return n.endswith("shared")
+
+
+def _is_final_layernorm(n: str, m) -> bool:
+    return n.endswith("final_layer_norm")
+
+
 @dataclass
 class T5ArchConfig(TextEncoderArchConfig):
     vocab_size: int = 32128
@@ -29,6 +41,16 @@ class T5ArchConfig(TextEncoderArchConfig):
     eos_token_id: int = 1
     classifier_dropout: float = 0.0
     text_len: int = 512
+    stacked_params_mapping: List[Tuple[str, str,
+                                       str]] = field(default_factory=lambda: [
+                                           # (param_name, shard_name, shard_id)
+                                           (".qkv_proj", ".q", "q"),
+                                           (".qkv_proj", ".k", "k"),
+                                           (".qkv_proj", ".v", "v"),
+                                       ])
+    _fsdp_shard_conditions: list = field(
+        default_factory=lambda:
+        [_is_transformer_layer, _is_embeddings, _is_final_layernorm])
 
     # Referenced from https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/configuration_t5.py
     def __post_init__(self):
```
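T5 names its attention projections `q`/`k`/`v` rather than `q_proj`/`k_proj`/`v_proj`, which is why these entries carry a leading dot: it keeps `.q` from matching those single letters inside longer identifiers. A worked example, assuming a typical HF `T5EncoderModel` weight name and using the module-private predicates added above:

```python
from fastvideo.v1.configs.models.encoders.t5 import (_is_embeddings,
                                                     _is_final_layernorm,
                                                     _is_transformer_layer)

# Typical HF T5EncoderModel weight name, assumed for illustration.
mapping = [(".qkv_proj", ".q", "q"),
           (".qkv_proj", ".k", "k"),
           (".qkv_proj", ".v", "v")]
name = "encoder.block.0.layer.0.SelfAttention.q.weight"
for param_name, shard_name, shard_id in mapping:
    if shard_name in name:
        print(name.replace(shard_name, param_name), "->", shard_id)
        # encoder.block.0.layer.0.SelfAttention.qkv_proj.weight -> q
        break

# Predicates select numbered blocks, the shared embedding, and the final norm:
assert _is_transformer_layer("encoder.block.5", None)
assert _is_embeddings("shared", None)
assert _is_final_layernorm("encoder.final_layer_norm", None)
```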

fastvideo/v1/dataset/benchmarks/benchmark_parquet_dataset_iterable_style.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -11,7 +11,7 @@
     build_parquet_iterable_style_dataloader)
 from fastvideo.v1.distributed import get_world_rank
 from fastvideo.v1.distributed.parallel_state import (
-    cleanup_dist_env_and_memory, get_torch_device,
+    cleanup_dist_env_and_memory, get_local_torch_device,
     maybe_init_distributed_environment_and_model_parallel)
 from fastvideo.v1.logger import init_logger
 
@@ -148,8 +148,8 @@ def main() -> None:
             break
 
         # Move data to device
-        latents = latents.to(get_torch_device())
-        embeddings = embeddings.to(get_torch_device())
+        latents = latents.to(get_local_torch_device())
+        embeddings = embeddings.to(get_local_torch_device())
 
         # Calculate actual batch size
         batch_size = latents.size(0)
```

fastvideo/v1/dataset/benchmarks/benchmark_parquet_dataset_map_style.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -13,7 +13,7 @@
     build_parquet_map_style_dataloader)
 from fastvideo.v1.distributed import get_world_rank
 from fastvideo.v1.distributed.parallel_state import (
-    cleanup_dist_env_and_memory, get_torch_device,
+    cleanup_dist_env_and_memory, get_local_torch_device,
     maybe_init_distributed_environment_and_model_parallel)
 from fastvideo.v1.logger import init_logger
 
@@ -165,8 +165,8 @@ def main() -> None:
             break
 
         # Move data to device
-        latents = latents.to(get_torch_device())
-        embeddings = embeddings.to(get_torch_device())
+        latents = latents.to(get_local_torch_device())
+        embeddings = embeddings.to(get_local_torch_device())
 
         # Calculate actual batch size
         batch_size = latents.size(0)
```
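Both benchmarks swap `get_torch_device` for `get_local_torch_device`, which presumably resolves to the accelerator owned by this process's local rank rather than one global device. A hedged sketch of what such a helper conventionally does; the real implementation lives in `fastvideo.v1.distributed.parallel_state` and may differ:

```python
import os

import torch


def get_local_torch_device_sketch() -> torch.device:
    """Assumed behavior: map this process to its node-local accelerator."""
    if not torch.cuda.is_available():
        return torch.device("cpu")
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    return torch.device(f"cuda:{local_rank}")
```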
