import math

import torch
| 3 | +class CurvedRescaleCFG: |
| 4 | + @classmethod |
| 5 | + def INPUT_TYPES(s): |
| 6 | + return {"required": { "model": ("MODEL",), |
| 7 | + "multiplier": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1000000.0, "step": 0.01}), |
| 8 | + "curve_peak_position": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), |
| 9 | + "curve_sharpness": ("FLOAT", {"default": 2.0, "min": 0.01, "max": 1000000.0, "step": 0.01}), |
| 10 | + }} |
| 11 | + RETURN_TYPES = ("MODEL",) |
| 12 | + FUNCTION = "patch" |
| 13 | + |
| 14 | + CATEGORY = "advanced/model" |
| 15 | + |
| 16 | + def patch(self, model, multiplier, curve_peak_position, curve_sharpness): |
| 17 | + def rescale_cfg_advanced_wrapper(args): |
| 18 | + nonlocal multiplier, curve_peak_position, curve_sharpness |
| 19 | + |
| 20 | + cond = args["cond"] |
| 21 | + uncond = args["uncond"] |
| 22 | + cond_scale = args["cond_scale"] |
| 23 | + current_sigma_tensor = args["sigma"] |
| 24 | + x_orig = args["input"] |
| 25 | + |
| 26 | + if cond is None or uncond is None or cond_scale is None or current_sigma_tensor is None or current_sigma_tensor.numel() == 0 or x_orig is None: |
| 27 | + return args.get("uncond_denoised", args.get("cond_denoised", x_orig)) |
| 28 | + |
| 29 | + current_sigma = current_sigma_tensor[0].item() |
| 30 | + |
| 31 | + # Calculate normalized progress (0 at high sigma, 1 at low sigma) |
| 32 | + # Using log scale for better distribution across sigma range |
| 33 | + sigma_normalized = max(0.0, min(1.0, 1.0 - (torch.log(current_sigma_tensor[0] + 1e-10) + 5) / 8)) |
| 34 | + |
| 35 | + # Apply bell curve that starts at 0, peaks at curve_peak_position, and returns to 0 |
| 36 | + # Calculate distance from peak position |
| 37 | + distance_from_peak = abs(sigma_normalized - curve_peak_position) |
| 38 | + |
| 39 | + # Calculate maximum possible distance (furthest edge from peak) |
| 40 | + max_distance = max(curve_peak_position, 1.0 - curve_peak_position) |
| 41 | + |
| 42 | + # Normalize distance: 0 at peak, 1 at furthest edge |
| 43 | + normalized_distance = distance_from_peak / max_distance if max_distance > 0 else 0 |
| 44 | + |
| 45 | + # Create bell curve: 1 at peak, 0 at edges |
| 46 | + # Higher curve_sharpness makes the peak sharper/narrower |
| 47 | + curve_value = (1.0 - normalized_distance) ** curve_sharpness |
| 48 | + |
| 49 | + dynamic_multiplier = multiplier * curve_value |
| 50 | + |
| 51 | + sigma_view = current_sigma_tensor.view(current_sigma_tensor.shape[:1] + (1,) * (cond.ndim - 1)) |
| 52 | + x = x_orig / (sigma_view * sigma_view + 1.0) |
| 53 | + v_pred_cond = ((x - x_orig + cond) * (sigma_view ** 2 + 1.0) ** 0.5) / sigma_view |
| 54 | + v_pred_uncond = ((x - x_orig + uncond) * (sigma_view ** 2 + 1.0) ** 0.5) / sigma_view |
| 55 | + v_pred_cfg = v_pred_uncond + cond_scale * (v_pred_cond - v_pred_uncond) |
| 56 | + ro_pos = torch.std(v_pred_cond, dim=tuple(range(1, v_pred_cond.ndim)), keepdim=True) |
| 57 | + ro_cfg = torch.std(v_pred_cfg, dim=tuple(range(1, v_pred_cfg.ndim)), keepdim=True) |
| 58 | + factor = torch.nan_to_num(ro_pos / (ro_cfg + 1e-5), nan=1.0, posinf=1.0, neginf=1.0) |
| 59 | + v_pred_final = dynamic_multiplier * (v_pred_cfg * factor) + (1.0 - dynamic_multiplier) * v_pred_cfg |
| 60 | + return x_orig - (x - v_pred_final * sigma_view / (sigma_view * sigma_view + 1.0) ** 0.5) |
| 61 | + |
| 62 | + m = model.clone() |
| 63 | + m.set_model_sampler_cfg_function(rescale_cfg_advanced_wrapper) |
| 64 | + return (m, ) |
| 65 | + |
# Registration table read by ComfyUI: node type name -> node class.
NODE_CLASS_MAPPINGS = {"CurvedRescaleCFG": CurvedRescaleCFG}