1+ import os
2+ import re
3+
4+ # Import the global categories list from raffle.py
5+ from .raffle import ALL_CATEGORIES
6+
7+ class TagCategoryStrength :
8+ @classmethod
9+ def INPUT_TYPES (s ):
10+ return {
11+ "required" : {
12+ "input_tags" : ("STRING" , {
13+ "multiline" : True ,
14+ "forceInput" : True ,
15+ "default" : "" ,
16+ "tooltip" : "Input tags to adjust (comma-separated)"
17+ }),
18+ "category_adjustments" : ("STRING" , {
19+ "multiline" : True ,
20+ "default" : "" ,
21+ "tooltip" : "Category adjustments in format: (category_name:strength), e.g., (artist:1.4), (meta:0.5), (poses:1.2)"
22+ }),
23+ },
24+ "optional" : {
25+ "preserve_existing_weights" : ("BOOLEAN" , {
26+ "default" : True ,
27+ "tooltip" : "If enabled, tags that already have weights like (tag:1.2) will keep their existing weights instead of being adjusted"
28+ }),
29+ }
30+ }
31+
32+ CATEGORY = "Raffle"
33+ RETURN_TYPES = ("STRING" , "STRING" )
34+ RETURN_NAMES = ("Adjusted tags" , "Debug info" )
35+ OUTPUT_TOOLTIPS = (
36+ "Tags with category-based strength adjustments applied" ,
37+ "Information about which tags were adjusted and their categories"
38+ )
39+ FUNCTION = "adjust_tag_categories"
40+
41+ def __init__ (self ):
42+ self ._tag_to_category_cache = None
43+
44+
45+
46+ def _load_tag_categories (self ):
47+ """Load the tag-to-category mapping from categorized_tags.txt"""
48+ if self ._tag_to_category_cache is not None :
49+ return self ._tag_to_category_cache
50+
51+ extension_path = os .path .normpath (os .path .dirname (__file__ ))
52+ categorized_tags_file_path = os .path .join (extension_path , "lists" , "categorized_tags.txt" )
53+
54+ if not os .path .exists (categorized_tags_file_path ):
55+ raise ValueError (f"Categorized tags file not found at { categorized_tags_file_path } " )
56+
57+ tag_to_category = {}
58+
59+ try :
60+ with open (categorized_tags_file_path , 'r' , encoding = 'utf-8' ) as f :
61+ for line in f :
62+ line = line .strip ()
63+ if not line :
64+ continue
65+
66+ # Parse format: [category] tag
67+ parts = line .split ('] ' , 1 )
68+ if len (parts ) != 2 :
69+ continue
70+
71+ category = parts [0 ][1 :] # Remove the leading [
72+ tag = parts [1 ]
73+
74+ tag_to_category [tag ] = category
75+
76+ except Exception as e :
77+ raise ValueError (f"Error reading categorized tags file: { str (e )} " )
78+
79+ self ._tag_to_category_cache = tag_to_category
80+ return tag_to_category
81+
82+ def _parse_category_adjustments (self , adjustments_string ):
83+ """Parse category adjustments from string format like (category:strength)"""
84+ adjustments = {}
85+
86+ if not adjustments_string .strip ():
87+ return adjustments
88+
89+ # Get valid categories for validation
90+ valid_categories_set = set (ALL_CATEGORIES )
91+
92+ # Split by commas and validate each part
93+ parts = [part .strip () for part in adjustments_string .split (',' ) if part .strip ()]
94+
95+ # Pattern for valid format: (category:strength)
96+ pattern = r'^\(([^:]+):([^)]+)\)$'
97+
98+ for part in parts :
99+ match = re .match (pattern , part )
100+
101+ if not match :
102+ # This part doesn't match the expected format
103+ raise ValueError (f"Invalid format '{ part } '. Expected format: (category:strength), e.g., (artist:1.4)" )
104+
105+ category , strength = match .groups ()
106+ category = category .strip ()
107+
108+ # Validate that the category exists in the valid categories
109+ if category not in valid_categories_set :
110+ valid_categories_str = ', ' .join (sorted (ALL_CATEGORIES ))
111+ raise ValueError (f"Invalid category '{ category } '. Valid categories are: { valid_categories_str } " )
112+
113+ try :
114+ strength_value = float (strength .strip ())
115+ adjustments [category ] = strength_value
116+ except ValueError :
117+ raise ValueError (f"Invalid strength value '{ strength } ' for category '{ category } '. Must be a number." )
118+
119+ return adjustments
120+
121+ def _normalize_tags (self , tag_string ):
122+ """Normalize a string of tags to a consistent format"""
123+ # Replace newlines with commas
124+ tag_string = tag_string .replace ('\r \n ' , '\n ' )
125+ tag_string = tag_string .replace ('\n ' , ',' )
126+
127+ # Remove multiple consecutive spaces
128+ while ' ' in tag_string :
129+ tag_string = tag_string .replace (' ' , ' ' )
130+
131+ # Remove multiple consecutive commas
132+ while ',,' in tag_string :
133+ tag_string = tag_string .replace (',,' , ',' )
134+
135+ # Split on commas and normalize each tag
136+ tags = tag_string .replace (', ' , ',' ).split (',' )
137+
138+ return [tag .strip () for tag in tags if tag .strip ()]
139+
140+ def _extract_tag_and_weight (self , tag ):
141+ """Extract tag name and existing weight from a tag like 'tag' or '(tag:1.2)'"""
142+ tag = tag .strip ()
143+
144+ # Check if tag already has weight in format (tag:weight)
145+ weight_pattern = r'^\(([^:]+):([^)]+)\)$'
146+ match = re .match (weight_pattern , tag )
147+
148+ if match :
149+ tag_name = match .group (1 ).strip ()
150+ try :
151+ weight = float (match .group (2 ).strip ())
152+ return tag_name , weight
153+ except ValueError :
154+ # If weight is invalid, treat as regular tag
155+ return tag , None
156+ else :
157+ return tag , None
158+
159+ def _apply_weight_to_tag (self , tag_name , weight ):
160+ """Apply weight to a tag, formatting it as (tag:weight)"""
161+ if weight == 1.0 :
162+ return tag_name
163+ else :
164+ return f"({ tag_name } :{ weight } )"
165+
166+ def adjust_tag_categories (self , input_tags , category_adjustments , preserve_existing_weights = True ):
167+ # Load tag-to-category mapping
168+ tag_to_category = self ._load_tag_categories ()
169+
170+ # Parse category adjustments
171+ adjustments = self ._parse_category_adjustments (category_adjustments )
172+
173+ # Normalize input tags
174+ tags = self ._normalize_tags (input_tags )
175+
176+ if not tags :
177+ return "" , "No input tags provided"
178+
179+ adjusted_tags = []
180+ debug_info = []
181+
182+ for tag in tags :
183+ # Extract tag name and existing weight
184+ tag_name , existing_weight = self ._extract_tag_and_weight (tag )
185+
186+ # Normalize tag name (replace spaces with underscores for lookup)
187+ normalized_tag_name = tag_name .replace (' ' , '_' )
188+
189+ # Find category for this tag
190+ category = tag_to_category .get (normalized_tag_name )
191+
192+ if category and category in adjustments :
193+ # This tag has a category adjustment
194+ adjustment_strength = adjustments [category ]
195+
196+ if preserve_existing_weights and existing_weight is not None :
197+ # Keep existing weight
198+ final_weight = existing_weight
199+ adjusted_tags .append (self ._apply_weight_to_tag (tag_name , final_weight ))
200+ debug_info .append (f"{ tag_name } [{ category } ] - kept existing weight: { final_weight } " )
201+ else :
202+ # Apply category adjustment
203+ if existing_weight is not None :
204+ # Multiply existing weight by adjustment
205+ final_weight = existing_weight * adjustment_strength
206+ else :
207+ # Apply adjustment to default weight of 1.0
208+ final_weight = adjustment_strength
209+
210+ adjusted_tags .append (self ._apply_weight_to_tag (tag_name , final_weight ))
211+ debug_info .append (f"{ tag_name } [{ category } ] - adjusted to: { final_weight } " )
212+ else :
213+ # No adjustment for this tag
214+ adjusted_tags .append (tag )
215+ if category :
216+ debug_info .append (f"{ tag_name } [{ category } ] - no adjustment" )
217+ else :
218+ debug_info .append (f"{ tag_name } [unknown category] - no adjustment" )
219+
220+ # Create debug output
221+ debug_output = f"Applied adjustments: { adjustments } \n \n Tag adjustments:\n " + "\n " .join (debug_info )
222+
223+ return ", " .join (adjusted_tags ), debug_output
0 commit comments