Skip to content

Commit 8090401

Browse files
committed
CurvedRescaleCFG
1 parent 072c400 commit 8090401

File tree

3 files changed

+76
-3
lines changed

3 files changed

+76
-3
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
__pycache__
2-
history_folder/
2+
history_folder/
3+
.claude

__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,23 @@
11
from . import raffle
from . import preview_history  # Import the renamed module
from . import tag_category_strength  # Import the new module
from . import curved_rescale_cfg  # Import the curved rescale cfg module
from .raffle import Raffle
from .preview_history import PreviewHistory  # Import the renamed class
from .tag_category_strength import TagCategoryStrength  # Import the new class
from .curved_rescale_cfg import CurvedRescaleCFG  # Import the curved rescale cfg class

# Node-name -> node-class registry picked up by the ComfyUI loader.
NODE_CLASS_MAPPINGS = {
    "Raffle": Raffle,
    "PreviewHistory": PreviewHistory,  # Add the renamed mapping
    "TagCategoryStrength": TagCategoryStrength,  # Add the new mapping
    "CurvedRescaleCFG": CurvedRescaleCFG,  # Add the curved rescale cfg mapping
}

# Node-name -> human-readable label shown in the ComfyUI node menu.
NODE_DISPLAY_NAME_MAPPINGS = {
    "Raffle": "Raffle",
    "PreviewHistory": "Preview History (Raffle)",  # Add the renamed display name
    "TagCategoryStrength": "Tag Category Strength (Raffle)",  # Add the new display name
    "CurvedRescaleCFG": "Curved Rescale CFG (Raffle)",  # Add the curved rescale cfg display name
}

__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']

curved_rescale_cfg.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import torch
2+
3+
class CurvedRescaleCFG:
    """ComfyUI model-patch node: RescaleCFG whose rescale strength follows a
    bell curve over normalized sampling progress instead of being constant.

    The effective multiplier is ~0 at the start and end of sampling and peaks
    at ``multiplier`` when the normalized progress equals
    ``curve_peak_position``; ``curve_sharpness`` controls how narrow the
    peak is (higher = narrower).
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model": ("MODEL",),
                              "multiplier": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1000000.0, "step": 0.01}),
                              "curve_peak_position": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                              "curve_sharpness": ("FLOAT", {"default": 2.0, "min": 0.01, "max": 1000000.0, "step": 0.01}),
                              }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "advanced/model"

    def patch(self, model, multiplier, curve_peak_position, curve_sharpness):
        """Return ``(patched_model,)`` — a clone of ``model`` whose sampler
        CFG function applies the curved rescale.

        Args:
            model: ComfyUI MODEL to clone and patch.
            multiplier: peak rescale strength (blend weight toward the
                std-rescaled CFG result).
            curve_peak_position: where in [0, 1] normalized progress the
                curve peaks.
            curve_sharpness: exponent shaping the bell curve's falloff.
        """

        def rescale_cfg_advanced_wrapper(args):
            # Closure reads multiplier / curve_peak_position / curve_sharpness
            # from the enclosing scope; they are never reassigned, so no
            # `nonlocal` declaration is needed.
            cond = args["cond"]
            uncond = args["uncond"]
            cond_scale = args["cond_scale"]
            current_sigma_tensor = args["sigma"]
            x_orig = args["input"]

            # Bail out gracefully if the sampler handed us incomplete args.
            if cond is None or uncond is None or cond_scale is None or current_sigma_tensor is None or current_sigma_tensor.numel() == 0 or x_orig is None:
                return args.get("uncond_denoised", args.get("cond_denoised", x_orig))

            # Normalized progress: ~0 at high sigma (start of sampling),
            # ~1 at low sigma (end). Log scale spreads progress more evenly
            # across a typical sigma schedule; the +5 / 8 constants map
            # log-sigma roughly from [-5, 3] onto [0, 1] (NOTE(review):
            # tuned constants — presumably chosen for common schedules).
            sigma_normalized = max(0.0, min(1.0, 1.0 - (torch.log(current_sigma_tensor[0] + 1e-10) + 5) / 8))

            # Bell curve over [0, 1]: 1 at curve_peak_position, 0 at the
            # farthest edge; higher curve_sharpness narrows the peak.
            distance_from_peak = abs(sigma_normalized - curve_peak_position)
            # Farthest edge of [0, 1] from the peak position.
            max_distance = max(curve_peak_position, 1.0 - curve_peak_position)
            # 0 at the peak, 1 at the farthest edge (guard division by zero).
            normalized_distance = distance_from_peak / max_distance if max_distance > 0 else 0
            curve_value = (1.0 - normalized_distance) ** curve_sharpness

            dynamic_multiplier = multiplier * curve_value

            # Standard RescaleCFG math performed in v-prediction space
            # (cf. ComfyUI's RescaleCFG node), blended with the dynamic
            # multiplier instead of a constant one.
            sigma_view = current_sigma_tensor.view(current_sigma_tensor.shape[:1] + (1,) * (cond.ndim - 1))
            x = x_orig / (sigma_view * sigma_view + 1.0)
            v_pred_cond = ((x - x_orig + cond) * (sigma_view ** 2 + 1.0) ** 0.5) / sigma_view
            v_pred_uncond = ((x - x_orig + uncond) * (sigma_view ** 2 + 1.0) ** 0.5) / sigma_view
            v_pred_cfg = v_pred_uncond + cond_scale * (v_pred_cond - v_pred_uncond)
            # Per-sample std over all non-batch dims; rescale the CFG output
            # so its std matches the cond prediction's std.
            ro_pos = torch.std(v_pred_cond, dim=tuple(range(1, v_pred_cond.ndim)), keepdim=True)
            ro_cfg = torch.std(v_pred_cfg, dim=tuple(range(1, v_pred_cfg.ndim)), keepdim=True)
            factor = torch.nan_to_num(ro_pos / (ro_cfg + 1e-5), nan=1.0, posinf=1.0, neginf=1.0)
            v_pred_final = dynamic_multiplier * (v_pred_cfg * factor) + (1.0 - dynamic_multiplier) * v_pred_cfg
            # Convert the final v-prediction back to a denoised sample.
            return x_orig - (x - v_pred_final * sigma_view / (sigma_view * sigma_view + 1.0) ** 0.5)

        m = model.clone()
        m.set_model_sampler_cfg_function(rescale_cfg_advanced_wrapper)
        return (m, )
65+
66+
# Standalone registry for when this module is loaded on its own;
# maps the node identifier to its implementing class.
NODE_CLASS_MAPPINGS = {"CurvedRescaleCFG": CurvedRescaleCFG}

0 commit comments

Comments
 (0)