|
23 | 23 | from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict |
24 | 24 | from ..utils import ( |
25 | 25 | USE_PEFT_BACKEND, |
| 26 | + _get_detailed_type, |
26 | 27 | _get_model_file, |
| 28 | + _is_valid_type, |
27 | 29 | is_accelerate_available, |
28 | 30 | is_torch_version, |
29 | 31 | is_transformers_available, |
@@ -577,29 +579,36 @@ def LinearStrengthModel(start, finish, size): |
577 | 579 | pipeline.set_ip_adapter_scale(ip_strengths) |
578 | 580 | ``` |
579 | 581 | """ |
580 | | - transformer = self.transformer |
581 | | - if not isinstance(scale, list): |
582 | | - scale = [[scale] * transformer.config.num_layers] |
583 | | - elif isinstance(scale, list) and isinstance(scale[0], int) or isinstance(scale[0], float): |
584 | | - if len(scale) != transformer.config.num_layers: |
585 | | - raise ValueError(f"Expected list of {transformer.config.num_layers} scales, got {len(scale)}.") |
| 582 | + |
| 583 | + scale_type = Union[int, float] |
| 584 | + num_ip_adapters = self.transformer.encoder_hid_proj.num_ip_adapters |
| 585 | + num_layers = self.transformer.config.num_layers |
| 586 | + |
| 587 | + # Single value for all layers of all IP-Adapters |
| 588 | + if isinstance(scale, scale_type): |
| 589 | + scale = [scale for _ in range(num_ip_adapters)] |
| 590 | + # List of per-layer scales for a single IP-Adapter |
| 591 | + elif _is_valid_type(scale, List[scale_type]) and num_ip_adapters == 1: |
586 | 592 | scale = [scale] |
| 593 | + # Invalid scale type |
| 594 | + elif not _is_valid_type(scale, List[Union[scale_type, List[scale_type]]]): |
| 595 | + raise TypeError(f"Unexpected type {_get_detailed_type(scale)} for scale.") |
587 | 596 |
|
588 | | - scale_configs = scale |
| 597 | + if len(scale) != num_ip_adapters: |
| 598 | + raise ValueError(f"Cannot assign {len(scale)} scales to {num_ip_adapters} IP-Adapters.") |
589 | 599 |
|
590 | | - key_id = 0 |
591 | | - for attn_name, attn_processor in transformer.attn_processors.items(): |
592 | | - if isinstance(attn_processor, (FluxIPAdapterJointAttnProcessor2_0)): |
593 | | - if len(scale_configs) != len(attn_processor.scale): |
594 | | - raise ValueError( |
595 | | - f"Cannot assign {len(scale_configs)} scale_configs to " |
596 | | - f"{len(attn_processor.scale)} IP-Adapter." |
597 | | - ) |
598 | | - elif len(scale_configs) == 1: |
599 | | - scale_configs = scale_configs * len(attn_processor.scale) |
600 | | - for i, scale_config in enumerate(scale_configs): |
601 | | - attn_processor.scale[i] = scale_config[key_id] |
602 | | - key_id += 1 |
| 600 | + if any(len(s) != num_layers for s in scale if isinstance(s, list)): |
| 601 | + invalid_scale_sizes = {len(s) for s in scale if isinstance(s, list)} - {num_layers} |
| 602 | + raise ValueError( |
| 603 | + f"Expected list of {num_layers} scales, got {', '.join(str(x) for x in invalid_scale_sizes)}." |
| 604 | + ) |
| 605 | + |
| 606 | + # Scalars are transformed to lists with length num_layers |
| 607 | + scale_configs = [[s] * num_layers if isinstance(s, scale_type) else s for s in scale] |
| 608 | + |
| 609 | + # Set scales. zip over scale_configs prevents going into single transformer layers |
| 610 | + for attn_processor, *scale in zip(self.transformer.attn_processors.values(), *scale_configs): |
| 611 | + attn_processor.scale = scale |
603 | 612 |
|
604 | 613 | def unload_ip_adapter(self): |
605 | 614 | """ |
|
0 commit comments