Skip to content

Commit 53a641d

Browse files
committed
Added TagCategoryStrength node
1 parent b818955 commit 53a641d

File tree

3 files changed

+280
-50
lines changed

3 files changed

+280
-50
lines changed

__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
11
from . import raffle
22
from . import preview_history # Import the renamed module
3+
from . import tag_category_strength # Import the new module
34
from .raffle import Raffle
45
from .preview_history import PreviewHistory # Import the renamed class
6+
from .tag_category_strength import TagCategoryStrength # Import the new class
57

68
NODE_CLASS_MAPPINGS = {
79
"Raffle": Raffle,
8-
"PreviewHistory": PreviewHistory # Add the renamed mapping
10+
"PreviewHistory": PreviewHistory, # Add the renamed mapping
11+
"TagCategoryStrength": TagCategoryStrength # Add the new mapping
912
}
1013
NODE_DISPLAY_NAME_MAPPINGS = {
1114
"Raffle": "Raffle",
12-
"PreviewHistory": "Preview History (Raffle)" # Add the renamed display name
15+
"PreviewHistory": "Preview History (Raffle)", # Add the renamed display name
16+
"TagCategoryStrength": "Tag Category Strength (Raffle)" # Add the new display name
1317
}
1418

1519
__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']

raffle.py

Lines changed: 51 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,55 @@
2525
# Critical categories that should be excluded to maintain workflow
2626
WARNING_ABOUT_NEW_CATEGORIES = {'artist', 'character_name', 'copyright', 'meta'}
2727

28+
# Global list of all available categories - used by both Raffle and TagCategoryStrength
29+
ALL_CATEGORIES = [
30+
'abstract_symbols',
31+
'actions',
32+
'artstyle_technique',
33+
'artist',
34+
'background_objects',
35+
'bodily_fluids',
36+
'camera_angle_perspective',
37+
'camera_focus_subject',
38+
'camera_framing_composition',
39+
'character_count',
40+
'character_name',
41+
'clothes_and_accessories',
42+
'color_scheme',
43+
'content_censorship_methods',
44+
'copyright',
45+
'expressions_and_mental_state',
46+
'female_intimate_anatomy',
47+
'female_physical_descriptors',
48+
'format_and_presentation',
49+
'gaze_direction_and_eye_contact',
50+
'general_clothing_exposure',
51+
'generic_clothing_interactions',
52+
'holding_large_items',
53+
'holding_small_items',
54+
'intentional_design_exposure',
55+
'lighting_and_vfx',
56+
'male_intimate_anatomy',
57+
'male_physical_descriptors',
58+
'meta',
59+
'metadata_and_attribution',
60+
'named_garment_exposure',
61+
'nudity_and_absence_of_clothing',
62+
'one_handed_character_items',
63+
'physical_locations',
64+
'poses',
65+
'publicly_visible_anatomy',
66+
'relationships',
67+
'sex_acts',
68+
'sfw_clothed_anatomy',
69+
'special_backgrounds',
70+
'specific_garment_interactions',
71+
'speech_and_text',
72+
'standard_physical_descriptors',
73+
'thematic_settings',
74+
'two_handed_character_items'
75+
]
76+
2877
class Raffle:
2978
# Class variable to track if the critical categories warning has been shown
3079
_critical_warning_shown = False
@@ -187,54 +236,8 @@ def process_tags(self, exclude_taglists_containing, taglists_must_include, seed,
187236
if not os.path.exists(categorized_tags_file_path):
188237
raise ValueError(f"Categorized tags file not found at {categorized_tags_file_path}")
189238

190-
# Define all available categories and handle exclusions
191-
all_categories = [
192-
'abstract_symbols',
193-
'actions',
194-
'artstyle_technique',
195-
'artist',
196-
'background_objects',
197-
'bodily_fluids',
198-
'camera_angle_perspective',
199-
'camera_focus_subject',
200-
'camera_framing_composition',
201-
'character_count',
202-
'character_name',
203-
'clothes_and_accessories',
204-
'color_scheme',
205-
'content_censorship_methods',
206-
'copyright',
207-
'expressions_and_mental_state',
208-
'female_intimate_anatomy',
209-
'female_physical_descriptors',
210-
'format_and_presentation',
211-
'gaze_direction_and_eye_contact',
212-
'general_clothing_exposure',
213-
'generic_clothing_interactions',
214-
'holding_large_items',
215-
'holding_small_items',
216-
'intentional_design_exposure',
217-
'lighting_and_vfx',
218-
'male_intimate_anatomy',
219-
'male_physical_descriptors',
220-
'meta',
221-
'metadata_and_attribution',
222-
'named_garment_exposure',
223-
'nudity_and_absence_of_clothing',
224-
'one_handed_character_items',
225-
'physical_locations',
226-
'poses',
227-
'publicly_visible_anatomy',
228-
'relationships',
229-
'sex_acts',
230-
'sfw_clothed_anatomy',
231-
'special_backgrounds',
232-
'specific_garment_interactions',
233-
'speech_and_text',
234-
'standard_physical_descriptors',
235-
'thematic_settings',
236-
'two_handed_character_items'
237-
]
239+
# Use the global categories list
240+
all_categories = ALL_CATEGORIES
238241

239242
# Process excluded categories
240243
excluded_categories = []

tag_category_strength.py

Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
import os
2+
import re
3+
4+
# Import the global categories list from raffle.py
5+
from .raffle import ALL_CATEGORIES
6+
7+
class TagCategoryStrength:
8+
@classmethod
9+
def INPUT_TYPES(s):
10+
return {
11+
"required": {
12+
"input_tags": ("STRING", {
13+
"multiline": True,
14+
"forceInput": True,
15+
"default": "",
16+
"tooltip": "Input tags to adjust (comma-separated)"
17+
}),
18+
"category_adjustments": ("STRING", {
19+
"multiline": True,
20+
"default": "",
21+
"tooltip": "Category adjustments in format: (category_name:strength), e.g., (artist:1.4), (meta:0.5), (poses:1.2)"
22+
}),
23+
},
24+
"optional": {
25+
"preserve_existing_weights": ("BOOLEAN", {
26+
"default": True,
27+
"tooltip": "If enabled, tags that already have weights like (tag:1.2) will keep their existing weights instead of being adjusted"
28+
}),
29+
}
30+
}
31+
32+
CATEGORY = "Raffle"
33+
RETURN_TYPES = ("STRING", "STRING")
34+
RETURN_NAMES = ("Adjusted tags", "Debug info")
35+
OUTPUT_TOOLTIPS = (
36+
"Tags with category-based strength adjustments applied",
37+
"Information about which tags were adjusted and their categories"
38+
)
39+
FUNCTION = "adjust_tag_categories"
40+
41+
def __init__(self):
42+
self._tag_to_category_cache = None
43+
44+
45+
46+
def _load_tag_categories(self):
47+
"""Load the tag-to-category mapping from categorized_tags.txt"""
48+
if self._tag_to_category_cache is not None:
49+
return self._tag_to_category_cache
50+
51+
extension_path = os.path.normpath(os.path.dirname(__file__))
52+
categorized_tags_file_path = os.path.join(extension_path, "lists", "categorized_tags.txt")
53+
54+
if not os.path.exists(categorized_tags_file_path):
55+
raise ValueError(f"Categorized tags file not found at {categorized_tags_file_path}")
56+
57+
tag_to_category = {}
58+
59+
try:
60+
with open(categorized_tags_file_path, 'r', encoding='utf-8') as f:
61+
for line in f:
62+
line = line.strip()
63+
if not line:
64+
continue
65+
66+
# Parse format: [category] tag
67+
parts = line.split('] ', 1)
68+
if len(parts) != 2:
69+
continue
70+
71+
category = parts[0][1:] # Remove the leading [
72+
tag = parts[1]
73+
74+
tag_to_category[tag] = category
75+
76+
except Exception as e:
77+
raise ValueError(f"Error reading categorized tags file: {str(e)}")
78+
79+
self._tag_to_category_cache = tag_to_category
80+
return tag_to_category
81+
82+
def _parse_category_adjustments(self, adjustments_string):
83+
"""Parse category adjustments from string format like (category:strength)"""
84+
adjustments = {}
85+
86+
if not adjustments_string.strip():
87+
return adjustments
88+
89+
# Get valid categories for validation
90+
valid_categories_set = set(ALL_CATEGORIES)
91+
92+
# Split by commas and validate each part
93+
parts = [part.strip() for part in adjustments_string.split(',') if part.strip()]
94+
95+
# Pattern for valid format: (category:strength)
96+
pattern = r'^\(([^:]+):([^)]+)\)$'
97+
98+
for part in parts:
99+
match = re.match(pattern, part)
100+
101+
if not match:
102+
# This part doesn't match the expected format
103+
raise ValueError(f"Invalid format '{part}'. Expected format: (category:strength), e.g., (artist:1.4)")
104+
105+
category, strength = match.groups()
106+
category = category.strip()
107+
108+
# Validate that the category exists in the valid categories
109+
if category not in valid_categories_set:
110+
valid_categories_str = ', '.join(sorted(ALL_CATEGORIES))
111+
raise ValueError(f"Invalid category '{category}'. Valid categories are: {valid_categories_str}")
112+
113+
try:
114+
strength_value = float(strength.strip())
115+
adjustments[category] = strength_value
116+
except ValueError:
117+
raise ValueError(f"Invalid strength value '{strength}' for category '{category}'. Must be a number.")
118+
119+
return adjustments
120+
121+
def _normalize_tags(self, tag_string):
122+
"""Normalize a string of tags to a consistent format"""
123+
# Replace newlines with commas
124+
tag_string = tag_string.replace('\r\n', '\n')
125+
tag_string = tag_string.replace('\n', ',')
126+
127+
# Remove multiple consecutive spaces
128+
while ' ' in tag_string:
129+
tag_string = tag_string.replace(' ', ' ')
130+
131+
# Remove multiple consecutive commas
132+
while ',,' in tag_string:
133+
tag_string = tag_string.replace(',,', ',')
134+
135+
# Split on commas and normalize each tag
136+
tags = tag_string.replace(', ', ',').split(',')
137+
138+
return [tag.strip() for tag in tags if tag.strip()]
139+
140+
def _extract_tag_and_weight(self, tag):
141+
"""Extract tag name and existing weight from a tag like 'tag' or '(tag:1.2)'"""
142+
tag = tag.strip()
143+
144+
# Check if tag already has weight in format (tag:weight)
145+
weight_pattern = r'^\(([^:]+):([^)]+)\)$'
146+
match = re.match(weight_pattern, tag)
147+
148+
if match:
149+
tag_name = match.group(1).strip()
150+
try:
151+
weight = float(match.group(2).strip())
152+
return tag_name, weight
153+
except ValueError:
154+
# If weight is invalid, treat as regular tag
155+
return tag, None
156+
else:
157+
return tag, None
158+
159+
def _apply_weight_to_tag(self, tag_name, weight):
160+
"""Apply weight to a tag, formatting it as (tag:weight)"""
161+
if weight == 1.0:
162+
return tag_name
163+
else:
164+
return f"({tag_name}:{weight})"
165+
166+
def adjust_tag_categories(self, input_tags, category_adjustments, preserve_existing_weights=True):
167+
# Load tag-to-category mapping
168+
tag_to_category = self._load_tag_categories()
169+
170+
# Parse category adjustments
171+
adjustments = self._parse_category_adjustments(category_adjustments)
172+
173+
# Normalize input tags
174+
tags = self._normalize_tags(input_tags)
175+
176+
if not tags:
177+
return "", "No input tags provided"
178+
179+
adjusted_tags = []
180+
debug_info = []
181+
182+
for tag in tags:
183+
# Extract tag name and existing weight
184+
tag_name, existing_weight = self._extract_tag_and_weight(tag)
185+
186+
# Normalize tag name (replace spaces with underscores for lookup)
187+
normalized_tag_name = tag_name.replace(' ', '_')
188+
189+
# Find category for this tag
190+
category = tag_to_category.get(normalized_tag_name)
191+
192+
if category and category in adjustments:
193+
# This tag has a category adjustment
194+
adjustment_strength = adjustments[category]
195+
196+
if preserve_existing_weights and existing_weight is not None:
197+
# Keep existing weight
198+
final_weight = existing_weight
199+
adjusted_tags.append(self._apply_weight_to_tag(tag_name, final_weight))
200+
debug_info.append(f"{tag_name} [{category}] - kept existing weight: {final_weight}")
201+
else:
202+
# Apply category adjustment
203+
if existing_weight is not None:
204+
# Multiply existing weight by adjustment
205+
final_weight = existing_weight * adjustment_strength
206+
else:
207+
# Apply adjustment to default weight of 1.0
208+
final_weight = adjustment_strength
209+
210+
adjusted_tags.append(self._apply_weight_to_tag(tag_name, final_weight))
211+
debug_info.append(f"{tag_name} [{category}] - adjusted to: {final_weight}")
212+
else:
213+
# No adjustment for this tag
214+
adjusted_tags.append(tag)
215+
if category:
216+
debug_info.append(f"{tag_name} [{category}] - no adjustment")
217+
else:
218+
debug_info.append(f"{tag_name} [unknown category] - no adjustment")
219+
220+
# Create debug output
221+
debug_output = f"Applied adjustments: {adjustments}\n\nTag adjustments:\n" + "\n".join(debug_info)
222+
223+
return ", ".join(adjusted_tags), debug_output

0 commit comments

Comments
 (0)