# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from dataclasses import dataclass
from typing import Dict, List, Optional, Type, Union

import torch
import torch.distributed._functional_collectives as funcol

from ..models._modeling_parallel import (
    ContextParallelConfig,
    ContextParallelInput,
    ContextParallelModelPlan,
    ContextParallelOutput,
)
from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from .hooks import HookRegistry, ModelHook


logger = get_logger(__name__)  # pylint: disable=invalid-name

_CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE = "cp_input---{}"
_CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE = "cp_output---{}"


# TODO(aryan): consolidate with ._helpers.TransformerBlockMetadata
@dataclass
class ModuleForwardMetadata:
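    """Caches the positional-argument indices of a module's ``forward`` signature so that a named parameter can be
    looked up from ``args``/``kwargs`` without re-inspecting the signature on every call."""
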
    cached_parameter_indices: Optional[Dict[str, int]] = None
    _cls: Optional[Type] = None

    def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
        kwargs = kwargs or {}

        if identifier in kwargs:
            return kwargs[identifier], True, None

        if self.cached_parameter_indices is not None:
            index = self.cached_parameter_indices.get(identifier, None)
            if index is None:
                raise ValueError(f"Parameter '{identifier}' not found in cached indices.")
            return args[index], False, index

        if self._cls is None:
            raise ValueError("Model class is not set for metadata.")

        parameters = list(inspect.signature(self._cls.forward).parameters.keys())
        parameters = parameters[1:]  # skip `self`
        self.cached_parameter_indices = {param: i for i, param in enumerate(parameters)}

        if identifier not in self.cached_parameter_indices:
            raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")

        index = self.cached_parameter_indices[identifier]

        if index >= len(args):
            raise ValueError(
                f"Expected at least {index + 1} positional arguments to access parameter '{identifier}', but got "
                f"{len(args)}."
            )

        return args[index], False, index


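# The plan maps dotted submodule names (one "*" wildcard is allowed for ModuleList entries) either to a dict of
# {forward argument name: ContextParallelInput} rules, which shard inputs on entry, or to ContextParallelOutput
# rules, which gather outputs on exit. A minimal sketch (module names and dims below are hypothetical, and
# `transformer`/`cp_config` stand in for a loaded model and its ContextParallelConfig):
#
#     plan = {
#         "rope": {"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3)},
#         "proj_out": ContextParallelOutput(gather_dim=1),
#     }
#     apply_context_parallel(transformer, cp_config, plan)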
def apply_context_parallel(
    module: torch.nn.Module,
    parallel_config: ContextParallelConfig,
    plan: Dict[str, ContextParallelModelPlan],
) -> None:
    """Apply context parallelism to a model according to the given plan."""
    logger.debug(f"Applying context parallel with CP mesh: {parallel_config._mesh} and plan: {plan}")

    for module_id, cp_model_plan in plan.items():
        submodule = _get_submodule_by_name(module, module_id)
        if not isinstance(submodule, list):
            submodule = [submodule]

        logger.debug(f"Applying context parallel hooks to {module_id=}, which resolved to {len(submodule)} submodule(s)")

        for m in submodule:
            if isinstance(cp_model_plan, dict):
                hook = ContextParallelSplitHook(cp_model_plan, parallel_config)
                hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id)
            elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
                if isinstance(cp_model_plan, ContextParallelOutput):
                    cp_model_plan = [cp_model_plan]
                if not all(isinstance(x, ContextParallelOutput) for x in cp_model_plan):
                    raise ValueError(
                        f"Expected all elements of cp_model_plan to be ContextParallelOutput, but got {cp_model_plan}"
                    )
                hook = ContextParallelGatherHook(cp_model_plan, parallel_config)
                hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id)
            else:
                raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")
            registry = HookRegistry.check_if_exists_or_initialize(m)
            registry.register_hook(hook, hook_name)


def remove_context_parallel(module: torch.nn.Module, plan: Dict[str, ContextParallelModelPlan]) -> None:
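    """Remove the context parallel hooks previously registered by `apply_context_parallel` for the given plan."""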
    for module_id, cp_model_plan in plan.items():
        submodule = _get_submodule_by_name(module, module_id)
        if not isinstance(submodule, list):
            submodule = [submodule]

        for m in submodule:
            registry = HookRegistry.check_if_exists_or_initialize(m)
            if isinstance(cp_model_plan, dict):
                hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id)
            elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
                hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id)
            else:
                raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")
            registry.remove_hook(hook_name)


class ContextParallelSplitHook(ModelHook):
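    """Hook that shards the inputs declared in ``metadata`` across the context parallel mesh before the wrapped
    module's forward runs. Entries marked with ``split_output=True`` are instead handled in ``post_forward``, so the
    corresponding outputs are sharded after the forward completes."""
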
    def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ContextParallelConfig) -> None:
        super().__init__()
        self.metadata = metadata
        self.parallel_config = parallel_config
        self.module_forward_metadata = None

    def initialize_hook(self, module):
        cls = unwrap_module(module).__class__
        self.module_forward_metadata = ModuleForwardMetadata(_cls=cls)
        return module

    def pre_forward(self, module, *args, **kwargs):
        args_list = list(args)

        for name, cpm in self.metadata.items():
            if isinstance(cpm, ContextParallelInput) and cpm.split_output:
                continue

            # The parameter may have been passed positionally or as a keyword argument; resolve it either way
            input_val, is_kwarg, index = self.module_forward_metadata._get_parameter_from_args_kwargs(
                name, args_list, kwargs
            )

            if input_val is None:
                continue

            # The input_val may be a tensor or a list/tuple of tensors. In certain cases, the user may choose to
            # shard the output instead of the input for a particular layer by setting split_output=True
            if isinstance(input_val, torch.Tensor):
                input_val = self._prepare_cp_input(input_val, cpm)
            elif isinstance(input_val, (list, tuple)):
                if len(input_val) != len(cpm):
                    raise ValueError(
                        f"Expected input to have {len(cpm)} elements to match the context parallel plan, but got "
                        f"{len(input_val)}."
                    )
                sharded_input_val = []
                for i, x in enumerate(input_val):
                    if torch.is_tensor(x) and not cpm[i].split_output:
                        x = self._prepare_cp_input(x, cpm[i])
                    sharded_input_val.append(x)
                input_val = sharded_input_val
            else:
                raise ValueError(f"Unsupported input type: {type(input_val)}")

            if is_kwarg:
                kwargs[name] = input_val
            elif index is not None and index < len(args_list):
                args_list[index] = input_val
            else:
                raise ValueError(
                    f"An unexpected error occurred while processing the input '{name}'. Please open an "
                    f"issue at https://github.com/huggingface/diffusers/issues and provide a minimal reproducible "
                    f"example along with the full stack trace."
                )

        return tuple(args_list), kwargs

    def post_forward(self, module, output):
        is_tensor = isinstance(output, torch.Tensor)
        is_tensor_list = isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)

        if not is_tensor and not is_tensor_list:
            raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")

        output = [output] if is_tensor else list(output)
        for index, cpm in self.metadata.items():
            if not isinstance(cpm, ContextParallelInput) or not cpm.split_output:
                continue
            if index >= len(output):
                raise ValueError(f"Index {index} out of bounds for output of length {len(output)}.")
            current_output = output[index]
            current_output = self._prepare_cp_input(current_output, cpm)
            output[index] = current_output

        return output[0] if is_tensor else tuple(output)

    def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> torch.Tensor:
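        """Validate ``expected_dims`` (if set) and shard ``x`` along ``split_dim`` across the flattened CP mesh."""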
        if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims:
            raise ValueError(
                f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions."
            )
        return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)


class ContextParallelGatherHook(ModelHook):
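    """Hook that all-gathers the wrapped module's outputs along each ``gather_dim`` declared in ``metadata``,
    restoring full-length tensors after the sharded forward pass. ``None`` entries leave the corresponding output
    untouched."""
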
    def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ContextParallelConfig) -> None:
        super().__init__()
        self.metadata = metadata
        self.parallel_config = parallel_config

    def post_forward(self, module, output):
        is_tensor = isinstance(output, torch.Tensor)

        if is_tensor:
            output = [output]
        elif not (isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)):
            raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")

        output = list(output)

        if len(output) != len(self.metadata):
            raise ValueError(f"Expected output to have {len(self.metadata)} elements, but got {len(output)}.")

        for i, cpm in enumerate(self.metadata):
            if cpm is None:
                continue
            output[i] = EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.parallel_config._flattened_mesh)

        return output[0] if is_tensor else tuple(output)


class AllGatherFunction(torch.autograd.Function):
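    """Differentiable all-gather: the forward pass gathers shards from all ranks along ``dim``; the backward pass
    returns only this rank's chunk of the incoming gradient."""
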
    @staticmethod
    def forward(ctx, tensor, dim, group):
        ctx.dim = dim
        ctx.group = group
        ctx.world_size = torch.distributed.get_world_size(group)
        ctx.rank = torch.distributed.get_rank(group)
        return funcol.all_gather_tensor(tensor, dim, group=group)

    @staticmethod
    def backward(ctx, grad_output):
        grad_chunks = torch.chunk(grad_output, ctx.world_size, dim=ctx.dim)
        return grad_chunks[ctx.rank], None, None


class EquipartitionSharder:
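    """Splits a tensor into equal contiguous chunks along a dimension across the flattened device mesh, each rank
    keeping the chunk matching its rank index, and all-gathers the chunks back together in ``unshard``."""
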
    @classmethod
    def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
        # NOTE: the following assertion does not have to be true in general. We simply enforce it for now
        # because the alternate case has not yet been tested/required for any model.
        assert tensor.size()[dim] % mesh.size() == 0, (
            "Tensor size along dimension to be sharded must be divisible by mesh size"
        )

        # The following is not fullgraph compatible with Dynamo (fails in DeviceMesh.get_rank)
        # return tensor.chunk(mesh.size(), dim=dim)[mesh.get_rank()]

        return tensor.chunk(mesh.size(), dim=dim)[torch.distributed.get_rank(mesh.get_group())]

    @classmethod
    def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
        tensor = tensor.contiguous()
        tensor = AllGatherFunction.apply(tensor, dim, mesh.get_group())
        return tensor


def _get_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
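    """Resolve a dotted submodule name to a module, or to a list of modules when a single '*' wildcard is used to
    match every entry of a ``ModuleList``."""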
    if name.count("*") > 1:
        raise ValueError("Wildcard '*' can only be used once in the name")
    return _find_submodule_by_name(model, name)


def _find_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
    if name == "":
        return model
    first_atom, remaining_name = name.split(".", 1) if "." in name else (name, "")
    if first_atom == "*":
        if not isinstance(model, torch.nn.ModuleList):
            raise ValueError("Wildcard '*' can only be used with ModuleList")
        submodules = []
        for submodule in model:
            subsubmodules = _find_submodule_by_name(submodule, remaining_name)
            if not isinstance(subsubmodules, list):
                subsubmodules = [subsubmodules]
            submodules.extend(subsubmodules)
        return submodules
    else:
        if hasattr(model, first_atom):
            submodule = getattr(model, first_atom)
            return _find_submodule_by_name(submodule, remaining_name)
        else:
            raise ValueError(f"'{first_atom}' is not a submodule of '{model.__class__.__name__}'")
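

# For illustration only (the layout below is hypothetical, not taken from a real model): given
#
#     model.blocks = torch.nn.ModuleList([Block(), Block()])  # each Block has an `attn` submodule
#
# `_get_submodule_by_name(model, "blocks.0.attn")` returns a single module, while
# `_get_submodule_by_name(model, "blocks.*.attn")` returns the `attn` submodule of every block as a list.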