Commit 871b3da

Merge remote-tracking branch 'upstream/main'
2 parents: 5eb9ca2 + cc7a9e8

4 files changed (+44, −12 lines)

4 files changed

+44
-12
lines changed

README.md

Lines changed: 7 additions & 2 deletions
@@ -800,7 +800,7 @@ The list of configurations for various `fms_acceleration` plugins:
   - `--padding_free`: technique to process multiple examples in a single batch without adding padding tokens that waste compute.
   - `--multipack`: technique for *multi-gpu training* to balance out the number of tokens processed on each device, to minimize waiting time.
 - [fast_moe_config](./tuning/config/acceleration_configs/fast_moe.py) (experimental):
-  - `--fast_moe`: trains MoE models in parallel, increasing throughput and decreasing memory usage.
+  - `--fast_moe`: trains MoE models in parallel with [Scatter MoE kernels](https://github.com/foundation-model-stack/fms-acceleration/tree/main/plugins/accelerated-moe#fms-acceleration-for-mixture-of-experts), increasing throughput and decreasing memory usage.

 Notes:
  * `quantized_lora_config` requires that it be used along with the LoRA tuning technique. See the [LoRA tuning section](https://github.com/foundation-model-stack/fms-hf-tuning/tree/main?tab=readme-ov-file#lora-tuning-example) on the LoRA parameters to pass.

@@ -820,8 +820,13 @@ Notes:
    - works only for *multi-gpu*.
    - currently only includes the version of *multipack* optimized for linear attention implementations like *flash-attn*.
  * Notes on Fast MoE
-   - `--fast_moe` is an integer value that configures the amount of expert parallel sharding (ep_degree).
+   - `--fast_moe` takes either an integer or a boolean value.
+     - When an integer `n` is passed, expert parallel sharding is enabled with an expert parallel degree of `n`, with Scatter MoE kernels enabled.
+     - When a boolean is passed, the expert parallel degree defaults to 1 and the behaviour is as follows:
+       - if True, Scatter MoE kernels are used with experts sharded according to the top-level sharding protocol (e.g. FSDP).
+       - if False, Scatter MoE kernels are used with experts fully replicated across ranks.
    - `world_size` must be divisible by the `ep_degree`.
+   - the number of experts in the MoE module must be divisible by the `ep_degree`.
    - Running fast moe modifies the state dict of the model, which must be post-processed; this happens automatically, and the converted checkpoint can be found in the `hf_converted_checkpoint` folder within every saved checkpoint directory. Alternatively, the same conversion can be performed manually through the [checkpoint utils](https://github.com/foundation-model-stack/fms-acceleration/blob/main/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py) script.
    - The typical use case for this script is to run:
    ```

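To make the three `--fast_moe` modes described in the notes above concrete, here is a minimal sketch of the equivalent programmatic configuration, mirroring the updated test in `tests/test_sft_trainer.py` below. The import path is inferred from the file location `tuning/config/acceleration_configs/fast_moe.py` in this commit, and the model/data/training argument objects are placeholders, so treat it as an illustration rather than a verbatim recipe.

```python
# Sketch only: the import path is inferred from the location of fast_moe.py
# in this commit; model_args/data_args/train_args are placeholders.
from tuning.config.acceleration_configs.fast_moe import FastMoe, FastMoeConfig

# Integer n: expert-parallel sharding with ep_degree=n, Scatter MoE kernels enabled.
cfg_ep2 = FastMoeConfig(fast_moe=FastMoe(ep_degree=2))

# True: Scatter MoE kernels, experts sharded via the top-level protocol (e.g. FSDP).
cfg_fsdp_sharded = FastMoeConfig(fast_moe=FastMoe(ep_degree=True))

# False: Scatter MoE kernels, experts fully replicated across ranks.
cfg_replicated = FastMoeConfig(fast_moe=FastMoe(ep_degree=False))

# Any of these is then passed to training exactly as in the updated test below:
# sft_trainer.train(model_args, data_args, train_args, fast_moe_config=cfg_ep2)
```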
tests/test_sft_trainer.py

Lines changed: 7 additions & 5 deletions
@@ -1352,21 +1352,23 @@ def test_run_e2e_with_hf_dataset_id(data_args):
     reason="Only runs if fms-accelerate is installed along with accelerated-moe plugin",
 )
 @pytest.mark.parametrize(
-    "dataset_path",
+    "dataset_path, ep_degree",
     [
-        TWITTER_COMPLAINTS_DATA_JSONL,
+        (TWITTER_COMPLAINTS_DATA_JSONL, 1),
+        (TWITTER_COMPLAINTS_DATA_JSONL, True),
+        (TWITTER_COMPLAINTS_DATA_JSONL, False),
     ],
 )
-def test_run_moe_ft_and_inference(dataset_path):
-    """Check if we can finetune a moe model and check if hf checkpoint is created"""
+def test_run_moe_ft_and_inference_ep1_kernels(dataset_path, ep_degree):
+    """Check if we can finetune a moe model with moe kernels and ep_degree=1"""
     with tempfile.TemporaryDirectory() as tempdir:
         data_args = copy.deepcopy(DATA_ARGS)
         data_args.training_data_path = dataset_path
         model_args = copy.deepcopy(MODEL_ARGS)
         model_args.model_name_or_path = "Isotonic/TinyMixtral-4x248M-MoE"
         train_args = copy.deepcopy(TRAIN_ARGS)
         train_args.output_dir = tempdir
-        fast_moe_config = FastMoeConfig(fast_moe=FastMoe(ep_degree=1))
+        fast_moe_config = FastMoeConfig(fast_moe=FastMoe(ep_degree=ep_degree))
         sft_trainer.train(
             model_args, data_args, train_args, fast_moe_config=fast_moe_config
         )

tuning/config/acceleration_configs/fast_moe.py

Lines changed: 11 additions & 2 deletions
@@ -13,7 +13,9 @@
 # limitations under the License.

 # Standard
-from dataclasses import dataclass
+from dataclasses import dataclass, field
+from typing import Union
+import argparse
 import os

 # Third Party

@@ -42,8 +44,15 @@
 @parsable_dataclass
 @dataclass
 class FastMoe:
+    ep_degree: Union[int, bool] = 1
+    disable_distributed: bool = field(
+        default=False, metadata={"help": argparse.SUPPRESS}
+    )

-    ep_degree: int = 1
+    def __post_init__(self):
+        if isinstance(self.ep_degree, bool):
+            self.disable_distributed = self.ep_degree
+            self.ep_degree = 1


 @dataclass

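As a quick illustration of the `__post_init__` normalization added above: a boolean `ep_degree` is folded into the new `disable_distributed` field and the degree is reset to 1, while integers pass through untouched. A minimal sketch, assuming the import path implied by the file location:

```python
# Sketch of FastMoe.__post_init__ normalization (import path assumed from
# tuning/config/acceleration_configs/fast_moe.py).
from tuning.config.acceleration_configs.fast_moe import FastMoe

moe = FastMoe(ep_degree=4)
print(moe.ep_degree, moe.disable_distributed)  # 4 False  (integer passes through)

moe = FastMoe(ep_degree=True)
print(moe.ep_degree, moe.disable_distributed)  # 1 True   (bool folded into disable_distributed)

moe = FastMoe(ep_degree=False)
print(moe.ep_degree, moe.disable_distributed)  # 1 False
```

Because `disable_distributed` carries `argparse.SUPPRESS` in its metadata, the `parsable_dataclass` helper (see the `utils.py` change below) skips it when building the casting order; it is meant to be driven only through the boolean form of `--fast_moe`.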
tuning/config/acceleration_configs/utils.py

Lines changed: 19 additions & 3 deletions
@@ -14,7 +14,8 @@

 # Standard
 from dataclasses import fields, is_dataclass
-from typing import List, Type, get_type_hints
+from typing import List, Type, Union, get_type_hints
+import argparse

 # Third Party
 from transformers.hf_argparser import DataClass, string_to_bool

@@ -39,6 +40,15 @@ def ensure_nested_dataclasses_initialized(dataclass: DataClass):
         setattr(dataclass, f.name, values)


+def bool_or_int(v):
+    if isinstance(v, str):
+        if v.isdigit():
+            return int(v)
+    elif isinstance(v, int):
+        return v
+    return string_to_bool(v)
+
+
 class EnsureTypes:
     """EnsureTypes is a caster with an internal state to memorize the
     the casting order, so that we can apply the correct casting type.

@@ -47,7 +57,7 @@ class EnsureTypes:
     """

     def __init__(self, *types: Type):
-        _map = {bool: string_to_bool}
+        _map = {bool: string_to_bool, Union[bool, int]: bool_or_int}
         self.types = [_map.get(t, t) for t in types]
         self.reset()

@@ -76,7 +86,13 @@ def parsable_dataclass(cls):
     if not is_dataclass(cls):
         raise ValueError("parsable only works with dataclass")

-    types = [fi.type for fi in fields(cls)]
+    types = (
+        fi.type
+        for fi in fields(cls)
+        if not any(
+            v == argparse.SUPPRESS for k, v in fi.metadata.items() if k == "help"
+        )
+    )

     class ParsableDataclass(cls, List):

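Finally, a small sketch of how the new `bool_or_int` caster resolves command-line strings for the `Union[int, bool]` annotation (import path again inferred from the file location above): digit strings become integers, other strings fall through to `string_to_bool`, and values that are already integers are returned unchanged.

```python
# Sketch of bool_or_int, which EnsureTypes now maps to the Union[bool, int] type
# (import path inferred from tuning/config/acceleration_configs/utils.py).
from tuning.config.acceleration_configs.utils import bool_or_int

print(bool_or_int("2"))      # 2     -> e.g. --fast_moe 2 selects ep_degree=2
print(bool_or_int("true"))   # True  -> shard experts via the top-level protocol
print(bool_or_int("false"))  # False -> replicate experts across ranks
print(bool_or_int(4))        # 4     -> ints already parsed are returned as-is
```

Note that a digit string such as "1" is returned as the integer 1, not a boolean, so `--fast_moe 1` selects expert parallel degree 1 rather than the boolean path.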