Skip to content

Commit 7fe27dc

Browse files
authored
feat: implement agentic power-steering analysis (#2365)
Replaces regex-based power-steering with Claude SDK analysis. Fixes Copilot review issues: NOT INVOKED priority, bounded conversation summary (512K), negation logic fix, UTF-8 encoding, docstring accuracy.
1 parent 1967972 commit 7fe27dc

14 files changed

+3632
-1185
lines changed

.claude/tools/amplihack/hooks/claude_power_steering.py

Lines changed: 240 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939

4040
# Try to import Claude SDK
4141
try:
42-
from claude_agent_sdk import ClaudeAgentOptions, query
42+
from claude_agent_sdk import ClaudeAgentOptions, query # type: ignore[import-not-found]
4343

4444
CLAUDE_SDK_AVAILABLE = True
4545
except ImportError:
@@ -51,6 +51,9 @@
5151

5252
# Security constants
5353
MAX_SDK_RESPONSE_LENGTH = 5000
54+
MAX_CONVERSATION_SUMMARY_LENGTH = (
55+
512_000 # Max chars for SDK conversation context (1M token window)
56+
)
5457
SUSPICIOUS_PATTERNS = [
5558
r"<script",
5659
r"javascript:",
@@ -78,6 +81,8 @@
7881
"analyze_claims_sync",
7982
"analyze_if_addressed_sync",
8083
"analyze_consideration_sync",
84+
"analyze_workflow_invocation",
85+
"analyze_workflow_invocation_sync",
8186
"CLAUDE_SDK_AVAILABLE",
8287
]
8388

@@ -239,10 +244,18 @@ async def analyze_consideration(
239244
response_parts = []
240245
async with asyncio.timeout(CHECKER_TIMEOUT):
241246
async for message in query(prompt=prompt, options=options):
242-
if hasattr(message, "text"):
243-
response_parts.append(message.text)
244-
elif hasattr(message, "content"):
245-
response_parts.append(str(message.content))
247+
# Extract text from AssistantMessage content blocks
248+
content = getattr(message, "content", None)
249+
if content is not None:
250+
if isinstance(content, list):
251+
# AssistantMessage: content is list[ContentBlock]
252+
for block in content:
253+
text = getattr(block, "text", None)
254+
if isinstance(text, str):
255+
response_parts.append(text)
256+
elif isinstance(content, str):
257+
# UserMessage: content can be str
258+
response_parts.append(content)
246259

247260
# Join all parts
248261
response = "".join(response_parts)
@@ -346,7 +359,7 @@ def _extract_reason_from_response(response: str) -> str | None:
346359
response: Full SDK response text
347360
348361
Returns:
349-
Extracted reason string (truncated to 200 chars), or generic fallback
362+
Full extracted reason string, or generic fallback
350363
351364
Note:
352365
Looks for patterns like "NOT SATISFIED: reason" or "UNSATISFIED: reason"
@@ -414,19 +427,21 @@ def _log_sdk_error(consideration_id: str, error: Exception) -> None:
414427
sys.stderr.flush()
415428

416429

417-
def _format_conversation_summary(conversation: list[dict], max_length: int | None = None) -> str:
430+
def _format_conversation_summary(
431+
conversation: list[dict], max_length: int = MAX_CONVERSATION_SUMMARY_LENGTH
432+
) -> str:
418433
"""Format conversation summary for analysis.
419434
420435
Args:
421436
conversation: List of message dicts
422-
max_length: Optional maximum summary length (None = unlimited, includes all messages)
437+
max_length: Maximum summary length in characters (default: MAX_CONVERSATION_SUMMARY_LENGTH, i.e. 512,000, to prevent oversized prompts)
423438
424439
Returns:
425440
Formatted conversation summary
426441
427442
Note:
428-
All messages in the conversation are included in the analysis unless max_length is specified.
429-
Individual messages longer than 500 chars are truncated for readability.
443+
Individual messages longer than 500 chars are truncated for readability. The summary
444+
is truncated at max_length characters to prevent oversized SDK prompts.
430445
"""
431446
summary_parts = []
432447
current_length = 0
@@ -474,7 +489,7 @@ def _format_conversation_summary(conversation: list[dict], max_length: int | Non
474489
msg_summary = f"\n**Message {i + 1} ({role}):** {content_text}\n"
475490

476491
# Stop adding messages once the next one would push the summary past max_length
477-
if max_length is not None and current_length + len(msg_summary) > max_length:
492+
if current_length + len(msg_summary) > max_length:
478493
truncation_indicator = f"\n[... {len(conversation) - i} more messages ...]"
479494
# Only add truncation indicator if we have room for it
480495
if current_length + len(truncation_indicator) <= max_length:
@@ -545,10 +560,18 @@ async def generate_final_guidance(
545560
response_parts = []
546561
async with asyncio.timeout(CHECKER_TIMEOUT):
547562
async for message in query(prompt=prompt, options=options):
548-
if hasattr(message, "text"):
549-
response_parts.append(message.text)
550-
elif hasattr(message, "content"):
551-
response_parts.append(str(message.content))
563+
# Extract text from AssistantMessage content blocks
564+
content = getattr(message, "content", None)
565+
if content is not None:
566+
if isinstance(content, list):
567+
# AssistantMessage: content is list[ContentBlock]
568+
for block in content:
569+
text = getattr(block, "text", None)
570+
if isinstance(text, str):
571+
response_parts.append(text)
572+
elif isinstance(content, str):
573+
# UserMessage: content can be str
574+
response_parts.append(content)
552575

553576
guidance = "".join(response_parts).strip()
554577

@@ -640,10 +663,18 @@ async def analyze_claims(delta_text: str, project_root: Path) -> list[str]:
640663
response_parts = []
641664
async with asyncio.timeout(CHECKER_TIMEOUT):
642665
async for message in query(prompt=prompt, options=options):
643-
if hasattr(message, "text"):
644-
response_parts.append(message.text)
645-
elif hasattr(message, "content"):
646-
response_parts.append(str(message.content))
666+
# Extract text from AssistantMessage content blocks
667+
content = getattr(message, "content", None)
668+
if content is not None:
669+
if isinstance(content, list):
670+
# AssistantMessage: content is list[ContentBlock]
671+
for block in content:
672+
text = getattr(block, "text", None)
673+
if isinstance(text, str):
674+
response_parts.append(text)
675+
elif isinstance(content, str):
676+
# UserMessage: content can be str
677+
response_parts.append(content)
647678

648679
response = "".join(response_parts).strip()
649680

@@ -755,10 +786,18 @@ async def analyze_if_addressed(
755786
response_parts = []
756787
async with asyncio.timeout(CHECKER_TIMEOUT):
757788
async for message in query(prompt=prompt, options=options):
758-
if hasattr(message, "text"):
759-
response_parts.append(message.text)
760-
elif hasattr(message, "content"):
761-
response_parts.append(str(message.content))
789+
# Extract text from AssistantMessage content blocks
790+
content = getattr(message, "content", None)
791+
if content is not None:
792+
if isinstance(content, list):
793+
# AssistantMessage: content is list[ContentBlock]
794+
for block in content:
795+
text = getattr(block, "text", None)
796+
if isinstance(text, str):
797+
response_parts.append(text)
798+
elif isinstance(content, str):
799+
# UserMessage: content can be str
800+
response_parts.append(content)
762801

763802
response = "".join(response_parts).strip()
764803

@@ -936,6 +975,184 @@ def analyze_consideration_sync(
936975
return (True, None) # Fail-open on any error
937976

938977

978+
def _workflow_message_text_parts(message: object) -> list[str]:
    """Return the text fragments carried by one SDK stream message.

    Args:
        message: A message object yielded by the SDK ``query`` stream.

    Returns:
        - ``[content]`` when ``message.content`` is a plain string.
        - The ``text`` attribute of every block that has a string ``text``
          when ``message.content`` is a list (blocks without one are skipped).
        - ``[]`` when there is no ``content`` attribute or it has another type.
    """
    content = getattr(message, "content", None)
    if isinstance(content, str):
        return [content]
    if isinstance(content, list):
        candidates = (getattr(block, "text", None) for block in content)
        return [text for text in candidates if isinstance(text, str)]
    return []


def _parse_workflow_invocation_verdict(response: str) -> tuple[bool, str | None]:
    """Turn the model's textual verdict into a ``(valid, reason)`` tuple.

    Only a response that *starts* (case-insensitively, ignoring leading
    whitespace) with "NOT INVOKED" is treated as a failure. Everything else,
    including "INVOKED: ..." and ambiguous text, is treated as valid
    (fail-open).

    Args:
        response: Sanitized and validated SDK response text.

    Returns:
        ``(True, None)`` when the workflow is considered invoked, otherwise
        ``(False, reason)`` where ``reason`` is at most 200 characters.
    """
    marker = "not invoked"
    verdict = response.lstrip()

    # Test the negative marker only: "invoked" is a substring of
    # "not invoked", and every non-negative outcome maps to (True, None).
    if not verdict.lower().startswith(marker):
        return (True, None)

    # Take the reason from directly after the leading marker instead of
    # searching the whole response, so a later "not invoked:" in the body
    # cannot be mistaken for the verdict and a missing colon does not
    # discard the explanation.
    reason = verdict[len(marker):].lstrip(" \t\r\n:;,.-\u2013\u2014").rstrip()

    # Very short remainders are not useful explanations; use a generic one.
    if len(reason) > 10:
        return (False, reason[:200])
    return (False, "Workflow not properly invoked")


async def analyze_workflow_invocation(
    conversation: list[dict], session_type: str, project_root: Path
) -> tuple[bool, str | None]:
    """Use Claude SDK to analyze if workflow was properly invoked.

    Context-aware analysis that understands multiple valid invocation patterns:
    - Explicit Skill tool invocation (Skill("default-workflow"))
    - Explicit Read tool invocation (Read(.claude/workflow/DEFAULT_WORKFLOW.md))
    - Implicit step-by-step workflow following (shows systematic approach)
    - Async completion (PR created for review, CI running)

    Args:
        conversation: Session messages (list of dicts)
        session_type: Session type (DEVELOPMENT, INVESTIGATION, etc.)
        project_root: Project root directory, passed to the SDK as ``cwd``

    Returns:
        Tuple of (valid, reason):
        - valid: True if workflow properly invoked or not required
        - reason: String explanation (max 200 chars) if invalid, None if valid
        (Fail-open: returns (True, None) when the SDK is unavailable, the
        response fails validation, the response is ambiguous, or any
        exception is raised)

    Note:
        Only validates DEVELOPMENT and INVESTIGATION sessions.
        Other session types return (True, None) immediately.
    """
    if not CLAUDE_SDK_AVAILABLE:
        return (True, None)  # Fail-open if SDK unavailable

    # Only validate DEVELOPMENT and INVESTIGATION sessions
    if session_type not in ("DEVELOPMENT", "INVESTIGATION"):
        return (True, None)

    # Format conversation summary (bounded by the helper's default max_length)
    conv_summary = _format_conversation_summary(conversation)

    # Context-aware prompt that understands multiple valid patterns
    prompt = f"""Analyze if the workflow was properly invoked in this session.

**Session Type**: {session_type}

**Session Conversation** ({len(conversation)} messages):
{conv_summary}

## Your Task

Determine if the appropriate workflow was properly invoked. A workflow is INVOKED if ANY of these patterns are present:

1. **Explicit Skill tool invocation**: Skill(skill="default-workflow") or Skill(skill="investigation-workflow")
2. **Explicit Read tool invocation**: Read(.claude/workflow/DEFAULT_WORKFLOW.md) or INVESTIGATION_WORKFLOW.md
3. **Implicit workflow following**: Claude systematically follows workflow steps (shows step-by-step execution)
4. **Async completion pattern**: PR created for review with CI running (workflow continues asynchronously)

**IMPORTANT**: Only flag as NOT INVOKED if there is NO evidence of ANY systematic workflow approach.

**Respond with ONE of:**
- "INVOKED: [brief evidence of which pattern was used]" if workflow was properly invoked
- "NOT INVOKED: [brief reason]" if no workflow approach was used

Be conservative - default to INVOKED unless there is clear evidence of ad-hoc work without systematic approach.
"""

    try:
        options = ClaudeAgentOptions(
            cwd=str(project_root),
        )

        # Query Claude with timeout, collecting every text fragment streamed back
        response_parts: list[str] = []
        async with asyncio.timeout(CHECKER_TIMEOUT):
            async for message in query(prompt=prompt, options=options):
                response_parts.extend(_workflow_message_text_parts(message))

        # Sanitize HTML before processing
        response = _sanitize_html("".join(response_parts))

        # Validate response before processing
        if not _validate_sdk_response(response):
            # Security validation failed - fail-open (assume valid)
            return (True, None)

        return _parse_workflow_invocation_verdict(response)

    except Exception as e:
        # Log error and fail-open on any error (including timeout)
        _log_sdk_error("workflow_invocation", e)
        return (True, None)
1098+
1099+
1100+
def analyze_workflow_invocation_sync(
    conversation: list[dict], session_type: str, project_root: Path
) -> tuple[bool, str | None]:
    """Synchronous wrapper for analyze_workflow_invocation with shutdown detection.

    During shutdown, returns (True, None) immediately to prevent asyncio hang.
    Otherwise, runs async analysis to check if workflow was properly invoked.

    Args:
        conversation: Session messages
        session_type: Session type (DEVELOPMENT, INVESTIGATION, etc.)
        project_root: Project root

    Returns:
        Tuple of (valid, reason):
        - valid: True if workflow properly invoked or not required
        - reason: String explanation if invalid, None if valid
        Returns (True, None) during shutdown or if the analysis cannot be run.

    Shutdown Behavior:
        When is_shutting_down() reports that teardown is in progress, this
        returns (True, None) without creating or running the coroutine.
        This prevents asyncio event loop hangs during application teardown.

        Fail-open philosophy: Assumes workflow is valid during shutdown
        to never block the user from exiting.

    Example:
        >>> conversation = [{"role": "user", "content": "Implement feature"}]
        >>> valid, reason = analyze_workflow_invocation_sync(
        ...     conversation, "DEVELOPMENT", Path.cwd()
        ... )
        >>> isinstance(valid, bool)
        True
    """
    # Shutdown check: bypass async operation during teardown
    if is_shutting_down():
        return (True, None)  # Fail-open: assume valid during shutdown

    coro = analyze_workflow_invocation(conversation, session_type, project_root)
    try:
        return asyncio.run(coro)
    except Exception as e:
        # asyncio.run() raises (e.g. when a loop is already running) before
        # the coroutine is ever started; close it so it is not reported as
        # "never awaited". Closing an already-finished coroutine is a no-op.
        coro.close()
        try:
            _log_sdk_error("workflow_invocation", e)
        except Exception:
            pass  # Logging is best-effort; it must never defeat fail-open
        return (True, None)  # Fail-open on any error
1154+
1155+
9391156
# For testing
9401157
if __name__ == "__main__":
9411158
import argparse

0 commit comments

Comments
 (0)