|  | 
| 11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
| 12 | 12 | # See the License for the specific language governing permissions and | 
| 13 | 13 | # limitations under the License. | 
|  | 14 | +from typing import Optional, Tuple, Union | 
|  | 15 | + | 
| 14 | 16 | from ..utils import deprecate | 
| 15 | 17 | from .controlnets.controlnet import (  # noqa | 
| 16 |  | -    BaseOutput, | 
| 17 | 18 |     ControlNetConditioningEmbedding, | 
| 18 | 19 |     ControlNetModel, | 
| 19 | 20 |     ControlNetOutput, | 
|  | 
| 24 | 25 | class ControlNetOutput(ControlNetOutput): | 
| 25 | 26 |     def __init__(self, *args, **kwargs): | 
| 26 | 27 |         deprecation_message = "Importing `ControlNetOutput` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetOutput`, instead." | 
| 27 |  | -        deprecate("ControlNetOutput", "0.34", deprecation_message) | 
|  | 28 | +        deprecate("diffusers.models.controlnet.ControlNetOutput", "0.34", deprecation_message) | 
| 28 | 29 |         super().__init__(*args, **kwargs) | 
| 29 | 30 | 
 | 
| 30 | 31 | 
 | 
| 31 | 32 | class ControlNetModel(ControlNetModel): | 
| 32 |  | -    def __init__(self, *args, **kwargs): | 
|  | 33 | +    def __init__( | 
|  | 34 | +        self, | 
|  | 35 | +        in_channels: int = 4, | 
|  | 36 | +        conditioning_channels: int = 3, | 
|  | 37 | +        flip_sin_to_cos: bool = True, | 
|  | 38 | +        freq_shift: int = 0, | 
|  | 39 | +        down_block_types: Tuple[str, ...] = ( | 
|  | 40 | +            "CrossAttnDownBlock2D", | 
|  | 41 | +            "CrossAttnDownBlock2D", | 
|  | 42 | +            "CrossAttnDownBlock2D", | 
|  | 43 | +            "DownBlock2D", | 
|  | 44 | +        ), | 
|  | 45 | +        mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", | 
|  | 46 | +        only_cross_attention: Union[bool, Tuple[bool]] = False, | 
|  | 47 | +        block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), | 
|  | 48 | +        layers_per_block: int = 2, | 
|  | 49 | +        downsample_padding: int = 1, | 
|  | 50 | +        mid_block_scale_factor: float = 1, | 
|  | 51 | +        act_fn: str = "silu", | 
|  | 52 | +        norm_num_groups: Optional[int] = 32, | 
|  | 53 | +        norm_eps: float = 1e-5, | 
|  | 54 | +        cross_attention_dim: int = 1280, | 
|  | 55 | +        transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1, | 
|  | 56 | +        encoder_hid_dim: Optional[int] = None, | 
|  | 57 | +        encoder_hid_dim_type: Optional[str] = None, | 
|  | 58 | +        attention_head_dim: Union[int, Tuple[int, ...]] = 8, | 
|  | 59 | +        num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None, | 
|  | 60 | +        use_linear_projection: bool = False, | 
|  | 61 | +        class_embed_type: Optional[str] = None, | 
|  | 62 | +        addition_embed_type: Optional[str] = None, | 
|  | 63 | +        addition_time_embed_dim: Optional[int] = None, | 
|  | 64 | +        num_class_embeds: Optional[int] = None, | 
|  | 65 | +        upcast_attention: bool = False, | 
|  | 66 | +        resnet_time_scale_shift: str = "default", | 
|  | 67 | +        projection_class_embeddings_input_dim: Optional[int] = None, | 
|  | 68 | +        controlnet_conditioning_channel_order: str = "rgb", | 
|  | 69 | +        conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), | 
|  | 70 | +        global_pool_conditions: bool = False, | 
|  | 71 | +        addition_embed_type_num_heads: int = 64, | 
|  | 72 | +    ): | 
| 33 | 73 |         deprecation_message = "Importing `ControlNetModel` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetModel`, instead." | 
| 34 |  | -        deprecate("ControlNetModel", "0.34", deprecation_message) | 
| 35 |  | -        super().__init__(*args, **kwargs) | 
|  | 74 | +        deprecate("diffusers.models.controlnet.ControlNetModel", "0.34", deprecation_message) | 
|  | 75 | +        super().__init__( | 
|  | 76 | +            in_channels=in_channels, | 
|  | 77 | +            conditioning_channels=conditioning_channels, | 
|  | 78 | +            flip_sin_to_cos=flip_sin_to_cos, | 
|  | 79 | +            freq_shift=freq_shift, | 
|  | 80 | +            down_block_types=down_block_types, | 
|  | 81 | +            mid_block_type=mid_block_type, | 
|  | 82 | +            only_cross_attention=only_cross_attention, | 
|  | 83 | +            block_out_channels=block_out_channels, | 
|  | 84 | +            layers_per_block=layers_per_block, | 
|  | 85 | +            downsample_padding=downsample_padding, | 
|  | 86 | +            mid_block_scale_factor=mid_block_scale_factor, | 
|  | 87 | +            act_fn=act_fn, | 
|  | 88 | +            norm_num_groups=norm_num_groups, | 
|  | 89 | +            norm_eps=norm_eps, | 
|  | 90 | +            cross_attention_dim=cross_attention_dim, | 
|  | 91 | +            transformer_layers_per_block=transformer_layers_per_block, | 
|  | 92 | +            encoder_hid_dim=encoder_hid_dim, | 
|  | 93 | +            encoder_hid_dim_type=encoder_hid_dim_type, | 
|  | 94 | +            attention_head_dim=attention_head_dim, | 
|  | 95 | +            num_attention_heads=num_attention_heads, | 
|  | 96 | +            use_linear_projection=use_linear_projection, | 
|  | 97 | +            class_embed_type=class_embed_type, | 
|  | 98 | +            addition_embed_type=addition_embed_type, | 
|  | 99 | +            addition_time_embed_dim=addition_time_embed_dim, | 
|  | 100 | +            num_class_embeds=num_class_embeds, | 
|  | 101 | +            upcast_attention=upcast_attention, | 
|  | 102 | +            resnet_time_scale_shift=resnet_time_scale_shift, | 
|  | 103 | +            projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, | 
|  | 104 | +            controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, | 
|  | 105 | +            conditioning_embedding_out_channels=conditioning_embedding_out_channels, | 
|  | 106 | +            global_pool_conditions=global_pool_conditions, | 
|  | 107 | +            addition_embed_type_num_heads=addition_embed_type_num_heads, | 
|  | 108 | +        ) | 
| 36 | 109 | 
 | 
| 37 | 110 | 
 | 
| 38 | 111 | class ControlNetConditioningEmbedding(ControlNetConditioningEmbedding): | 
| 39 | 112 |     def __init__(self, *args, **kwargs): | 
| 40 | 113 |         deprecation_message = "Importing `ControlNetConditioningEmbedding` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetConditioningEmbedding`, instead." | 
| 41 |  | -        deprecate("ControlNetConditioningEmbedding", "0.34", deprecation_message) | 
|  | 114 | +        deprecate("diffusers.models.controlnet.ControlNetConditioningEmbedding", "0.34", deprecation_message) | 
| 42 | 115 |         super().__init__(*args, **kwargs) | 
0 commit comments