Skip to content

Commit b333d6e

Browse files
fix: post tool hook for auto triggering parallel agent
1 parent 3d0d1ec commit b333d6e

File tree

1 file changed

+167
-13
lines changed

1 file changed

+167
-13
lines changed

claude-code/hooks/post_tool_use.py

Lines changed: 167 additions & 13 deletions
Original file line numberDiff line numberDiff line change
import hashlib
import json
import logging
import os
import sys
import tempfile
from datetime import datetime
from pathlib import Path

11-
def main():
14+
# --- Session-specific file paths ---
def get_session_specific_paths():
    """Generate session-specific paths based on the working directory.

    Returns a dict with:
        supervisor_log: per-session hook log file in the system temp dir
        state_file:     per-session todo-hash state file in the system temp dir
        project_name:   basename of the current working directory

    A short md5 hash of the cwd (change detection only, not security)
    keeps concurrent sessions in different directories from clobbering
    each other's files.
    """
    cwd = os.getcwd()
    # Create a short hash of the working directory for unique identification
    cwd_hash = hashlib.md5(cwd.encode()).hexdigest()[:8]

    # Use the platform temp dir (honors TMPDIR) instead of hard-coded /tmp,
    # which does not exist on Windows.
    tmp_dir = tempfile.gettempdir()

    return {
        'supervisor_log': os.path.join(tmp_dir, f"claude_supervisor_{cwd_hash}.log"),
        'state_file': os.path.join(tmp_dir, f"claude_todo_hook_{cwd_hash}.state"),
        'project_name': os.path.basename(cwd)
    }
26+
27+
# Get session-specific paths
paths = get_session_specific_paths()

# --- Logging Configuration ---
# All hook logging goes to a per-working-directory file so concurrent
# sessions don't interleave their supervisor logs.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - [%(project)s] - %(message)s',
    filename=paths['supervisor_log'],
    filemode='a',
)

# Tag every log record with the project name so the %(project)s field in
# the format string above always resolves.
_previous_factory = logging.getLogRecordFactory()


def record_factory(*args, **kwargs):
    rec = _previous_factory(*args, **kwargs)
    rec.project = paths['project_name']
    return rec


logging.setLogRecordFactory(record_factory)
45+
46+
def log_to_json_file(input_data):
    """Original functionality: log all tool usage to JSON file.

    Stamps *input_data* with a ``logged_at`` ISO timestamp (mutating the
    caller's dict) and appends it to ``<cwd>/logs/post_tool_use.json``,
    creating the directory and file as needed.

    Returns True on success, False on any failure. Errors are logged and
    never raised, so the hook can't block the tool pipeline.
    """
    try:
        # Ensure log directory exists
        log_dir = Path.cwd() / 'logs'
        log_dir.mkdir(parents=True, exist_ok=True)
        log_path = log_dir / 'post_tool_use.json'

        # Read existing log data or initialize empty list. A corrupt file
        # or one whose top-level value is not a list is treated as empty
        # rather than crashing (and losing this entry).
        log_data = []
        if log_path.exists():
            try:
                with open(log_path, 'r') as f:
                    loaded = json.load(f)
                if isinstance(loaded, list):
                    log_data = loaded
            except (json.JSONDecodeError, IOError):
                log_data = []

        # Add timestamp to the log entry
        input_data['logged_at'] = datetime.now().isoformat()

        # Append new data
        log_data.append(input_data)

        # Write back to file with formatting
        with open(log_path, 'w') as f:
            json.dump(log_data, f, indent=2)

        return True
    except Exception as e:
        logging.error(f"Failed to log to JSON file: {e}")
        return False
78+
79+
def handle_todo_write_reflection(input_data):
    """Handle TodoWrite-specific reflection prompting.

    Hashes the 'content' fields of the submitted todos and compares the
    hash against the one persisted in the session state file. Returns the
    reflection prompt string when the todo list changed, or None when the
    list is empty, unchanged, or any error occurs (errors are logged,
    never raised).

    Side effect: on a changed list, rewrites the state file with the new
    hash, timestamp, and todo count.
    """
    try:
        tool_input_data = input_data.get("tool_input", {})
        todo_objects = tool_input_data.get("todos", [])

        if not todo_objects:
            logging.info("TodoWrite called, but 'todos' list is empty. Skipping reflection.")
            return None

        # Extract todo content for hashing — only task text is hashed, so
        # status-only updates do not re-trigger the reflection prompt.
        tasks_to_process_content = [task.get("content", "") for task in todo_objects]
        todo_content_full = "\n".join(tasks_to_process_content)

        # Calculate hash of current todo list (md5 as a cheap change
        # detector, not for security)
        current_hash = hashlib.md5(todo_content_full.encode()).hexdigest()

        # Check if state file exists and compare hashes
        last_hash = ""
        if os.path.exists(paths['state_file']):
            try:
                with open(paths['state_file'], 'r') as f:
                    state_data = json.load(f)
                    last_hash = state_data.get('hash', '')
                    last_time = state_data.get('timestamp', '')
                    logging.debug(f"Last state: hash={last_hash[:8]}..., time={last_time}")
            except (json.JSONDecodeError, IOError):
                # State file corrupted or old format, treat as new
                pass

        if current_hash == last_hash:
            logging.info("Todo list has not changed. Skipping reflection prompt.")
            return None

        logging.info(f"New todo list detected (hash: {current_hash[:8]}...). Preparing reflection prompt.")

        # The reflection prompt to inject (returned verbatim to the model)
        reflection_prompt = """
**Supervisor's Prompt: Review and Parallelize the Plan**

The initial plan has been drafted. Now, **think** to optimize its execution.

1. **Analyze Dependencies**: Critically review the list of tasks.
2. **Group for Parallelism**: Identify any tasks that are independent and can be executed concurrently. Group them into a parallel stage.
3. **Format for Parallel Execution**: To run a group of tasks in parallel, you **must** place multiple `<invoke name="Task">` calls inside a **single** `<function_calls>` block in your response.

Reminder of example format for running two tasks in parallel:
```xml
<function_calls>
<invoke name="Task">
<parameter name="description">First parallel task...</parameter>
<parameter name="prompt">Details for the first task...</parameter>
<parameter name="subagent_type">appropriate-agent-type</parameter>
</invoke>
<invoke name="Task">
<parameter name="description">Second parallel task...</parameter>
<parameter name="prompt">Details for the second task...</parameter>
<parameter name="subagent_type">appropriate-agent-type</parameter>
</invoke>
</function_calls>
```

Please present your analysis of parallel stages and then proceed with the first stage using the correct format.
"""

        # Save new state with timestamp so an unchanged list is skipped next time
        state_data = {
            'hash': current_hash,
            'timestamp': datetime.now().isoformat(),
            'todo_count': len(todo_objects)
        }
        with open(paths['state_file'], 'w') as f:
            json.dump(state_data, f, indent=2)
        logging.info(f"Updated state file with new hash: {current_hash[:8]}...")

        return reflection_prompt

    except Exception as e:
        logging.exception(f"Error in TodoWrite reflection handler: {e}")
        return None
159+
160+
def main():
    """Main entry point for the hook.

    Reads the PostToolUse payload from stdin, appends it to the JSON audit
    log, and — for TodoWrite calls whose todo list changed — emits a
    reflection prompt back via hookSpecificOutput. Always exits 0 so the
    hook can never block the tool pipeline.
    """
    logging.info("--- Post-Tool-Use Hook Triggered ---")

    try:
        payload = json.load(sys.stdin)

        tool_name = payload.get("tool_name", "unknown")
        logging.info(f"Tool used: {tool_name}")

        # Audit trail first: every tool call is recorded regardless of type.
        log_to_json_file(payload)

        # Only TodoWrite calls can produce a reflection prompt.
        reflection = (
            handle_todo_write_reflection(payload)
            if tool_name == "TodoWrite"
            else None
        )

        if reflection:
            response = {
                "hookSpecificOutput": {
                    "hookEventName": "PostToolUse",
                    "additionalContext": reflection,
                }
            }
            logging.info("Injecting reflection prompt for task parallelization.")
            print(json.dumps(response), flush=True)

        # Exit cleanly whether or not a prompt was injected.
        sys.exit(0)

    except json.JSONDecodeError as e:
        logging.error(f"Failed to parse JSON input: {e}")
        sys.exit(0)
    except Exception:
        logging.exception("An unexpected error occurred in the post-tool-use hook.")
        sys.exit(0)
46200

47201
if __name__ == '__main__':

0 commit comments

Comments
 (0)