Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
263 changes: 240 additions & 23 deletions .claude/tools/amplihack/hooks/claude_power_steering.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@

# Try to import Claude SDK
try:
from claude_agent_sdk import ClaudeAgentOptions, query
from claude_agent_sdk import ClaudeAgentOptions, query # type: ignore[import-not-found]

CLAUDE_SDK_AVAILABLE = True
except ImportError:
Expand All @@ -51,6 +51,9 @@

# Security constants
MAX_SDK_RESPONSE_LENGTH = 5000
MAX_CONVERSATION_SUMMARY_LENGTH = (
512_000 # Max chars for SDK conversation context (1M token window)
)
SUSPICIOUS_PATTERNS = [
r"<script",
r"javascript:",
Expand Down Expand Up @@ -78,6 +81,8 @@
"analyze_claims_sync",
"analyze_if_addressed_sync",
"analyze_consideration_sync",
"analyze_workflow_invocation",
"analyze_workflow_invocation_sync",
"CLAUDE_SDK_AVAILABLE",
]

Expand Down Expand Up @@ -239,10 +244,18 @@ async def analyze_consideration(
response_parts = []
async with asyncio.timeout(CHECKER_TIMEOUT):
async for message in query(prompt=prompt, options=options):
if hasattr(message, "text"):
response_parts.append(message.text)
elif hasattr(message, "content"):
response_parts.append(str(message.content))
# Extract text from AssistantMessage content blocks
content = getattr(message, "content", None)
if content is not None:
if isinstance(content, list):
# AssistantMessage: content is list[ContentBlock]
for block in content:
text = getattr(block, "text", None)
if isinstance(text, str):
response_parts.append(text)
elif isinstance(content, str):
# UserMessage: content can be str
response_parts.append(content)

# Join all parts
response = "".join(response_parts)
Expand Down Expand Up @@ -346,7 +359,7 @@ def _extract_reason_from_response(response: str) -> str | None:
response: Full SDK response text

Returns:
Extracted reason string (truncated to 200 chars), or generic fallback
Full extracted reason string, or generic fallback

Note:
Looks for patterns like "NOT SATISFIED: reason" or "UNSATISFIED: reason"
Expand Down Expand Up @@ -414,19 +427,21 @@ def _log_sdk_error(consideration_id: str, error: Exception) -> None:
sys.stderr.flush()


def _format_conversation_summary(conversation: list[dict], max_length: int | None = None) -> str:
def _format_conversation_summary(
conversation: list[dict], max_length: int = MAX_CONVERSATION_SUMMARY_LENGTH
) -> str:
"""Format conversation summary for analysis.

Args:
conversation: List of message dicts
max_length: Optional maximum summary length (None = unlimited, includes all messages)
max_length: Maximum summary length in characters (default: 50000 to prevent oversized prompts)

Returns:
Formatted conversation summary

Note:
All messages in the conversation are included in the analysis unless max_length is specified.
Individual messages longer than 500 chars are truncated for readability.
Individual messages longer than 500 chars are truncated for readability. The summary
is truncated at max_length characters to prevent oversized SDK prompts.
"""
summary_parts = []
current_length = 0
Expand Down Expand Up @@ -474,7 +489,7 @@ def _format_conversation_summary(conversation: list[dict], max_length: int | Non
msg_summary = f"\n**Message {i + 1} ({role}):** {content_text}\n"

# Only check length limit if max_length is specified
if max_length is not None and current_length + len(msg_summary) > max_length:
if current_length + len(msg_summary) > max_length:
truncation_indicator = f"\n[... {len(conversation) - i} more messages ...]"
# Only add truncation indicator if we have room for it
if current_length + len(truncation_indicator) <= max_length:
Expand Down Expand Up @@ -545,10 +560,18 @@ async def generate_final_guidance(
response_parts = []
async with asyncio.timeout(CHECKER_TIMEOUT):
async for message in query(prompt=prompt, options=options):
if hasattr(message, "text"):
response_parts.append(message.text)
elif hasattr(message, "content"):
response_parts.append(str(message.content))
# Extract text from AssistantMessage content blocks
content = getattr(message, "content", None)
if content is not None:
if isinstance(content, list):
# AssistantMessage: content is list[ContentBlock]
for block in content:
text = getattr(block, "text", None)
if isinstance(text, str):
response_parts.append(text)
elif isinstance(content, str):
# UserMessage: content can be str
response_parts.append(content)

guidance = "".join(response_parts).strip()

Expand Down Expand Up @@ -640,10 +663,18 @@ async def analyze_claims(delta_text: str, project_root: Path) -> list[str]:
response_parts = []
async with asyncio.timeout(CHECKER_TIMEOUT):
async for message in query(prompt=prompt, options=options):
if hasattr(message, "text"):
response_parts.append(message.text)
elif hasattr(message, "content"):
response_parts.append(str(message.content))
# Extract text from AssistantMessage content blocks
content = getattr(message, "content", None)
if content is not None:
if isinstance(content, list):
# AssistantMessage: content is list[ContentBlock]
for block in content:
text = getattr(block, "text", None)
if isinstance(text, str):
response_parts.append(text)
elif isinstance(content, str):
# UserMessage: content can be str
response_parts.append(content)

response = "".join(response_parts).strip()

Expand Down Expand Up @@ -755,10 +786,18 @@ async def analyze_if_addressed(
response_parts = []
async with asyncio.timeout(CHECKER_TIMEOUT):
async for message in query(prompt=prompt, options=options):
if hasattr(message, "text"):
response_parts.append(message.text)
elif hasattr(message, "content"):
response_parts.append(str(message.content))
# Extract text from AssistantMessage content blocks
content = getattr(message, "content", None)
if content is not None:
if isinstance(content, list):
# AssistantMessage: content is list[ContentBlock]
for block in content:
text = getattr(block, "text", None)
if isinstance(text, str):
response_parts.append(text)
elif isinstance(content, str):
# UserMessage: content can be str
response_parts.append(content)

response = "".join(response_parts).strip()

Expand Down Expand Up @@ -936,6 +975,184 @@ def analyze_consideration_sync(
return (True, None) # Fail-open on any error


async def analyze_workflow_invocation(
conversation: list[dict], session_type: str, project_root: Path
) -> tuple[bool, str | None]:
"""Use Claude SDK to analyze if workflow was properly invoked.

Context-aware analysis that understands multiple valid invocation patterns:
- Explicit Skill tool invocation (Skill("default-workflow"))
- Explicit Read tool invocation (Read(.claude/workflow/DEFAULT_WORKFLOW.md))
- Implicit step-by-step workflow following (shows systematic approach)
- Async completion (PR created for review, CI running)

Args:
conversation: Session messages (list of dicts)
session_type: Session type (DEVELOPMENT, INVESTIGATION, etc.)
project_root: Project root directory

Returns:
Tuple of (valid, reason):
- valid: True if workflow properly invoked or not required
- reason: String explanation if invalid, None if valid
(Fail-open: returns (True, None) on SDK unavailable or errors)

Note:
Only validates DEVELOPMENT and INVESTIGATION sessions.
Other session types return (True, None) immediately.
"""
if not CLAUDE_SDK_AVAILABLE:
return (True, None) # Fail-open if SDK unavailable

# Only validate DEVELOPMENT and INVESTIGATION sessions
if session_type not in ("DEVELOPMENT", "INVESTIGATION"):
return (True, None)

# Format conversation summary
conv_summary = _format_conversation_summary(conversation)

# Context-aware prompt that understands multiple valid patterns
prompt = f"""Analyze if the workflow was properly invoked in this session.

**Session Type**: {session_type}

**Session Conversation** ({len(conversation)} messages):
{conv_summary}

## Your Task

Determine if the appropriate workflow was properly invoked. A workflow is INVOKED if ANY of these patterns are present:

1. **Explicit Skill tool invocation**: Skill(skill="default-workflow") or Skill(skill="investigation-workflow")
2. **Explicit Read tool invocation**: Read(.claude/workflow/DEFAULT_WORKFLOW.md) or INVESTIGATION_WORKFLOW.md
3. **Implicit workflow following**: Claude systematically follows workflow steps (shows step-by-step execution)
4. **Async completion pattern**: PR created for review with CI running (workflow continues asynchronously)

**IMPORTANT**: Only flag as NOT INVOKED if there is NO evidence of ANY systematic workflow approach.

**Respond with ONE of:**
- "INVOKED: [brief evidence of which pattern was used]" if workflow was properly invoked
- "NOT INVOKED: [brief reason]" if no workflow approach was used

Be conservative - default to INVOKED unless there is clear evidence of ad-hoc work without systematic approach.
"""

try:
options = ClaudeAgentOptions(
cwd=str(project_root),
)

# Query Claude with timeout
response_parts = []
async with asyncio.timeout(CHECKER_TIMEOUT):
async for message in query(prompt=prompt, options=options):
# Extract text from AssistantMessage content blocks
content = getattr(message, "content", None)
if content is not None:
if isinstance(content, list):
# AssistantMessage: content is list[ContentBlock]
for block in content:
text = getattr(block, "text", None)
if isinstance(text, str):
response_parts.append(text)
elif isinstance(content, str):
# UserMessage: content can be str
response_parts.append(content)

# Join all parts
response = "".join(response_parts)

# Sanitize HTML before processing
response = _sanitize_html(response)

# Validate response before processing
if not _validate_sdk_response(response):
# Security validation failed - fail-open (assume valid)
return (True, None)

response_stripped = response.lstrip()
response_lower = response_stripped.lower()

# Check for NOT INVOKED indicator first to avoid matching "invoked" in "not invoked"
if response_lower.startswith("not invoked:") or response_lower.startswith("not invoked"):
# Extract reason from response
idx = response_lower.find("not invoked:")
if idx != -1:
reason = response_stripped[idx + 12 :].strip()
# Clean up and truncate
if reason and len(reason) > 10:
return (False, reason[:200])
return (False, "Workflow not properly invoked")

# Check for INVOKED indicator
if response_lower.startswith("invoked:") or response_lower.startswith("invoked"):
return (True, None)

# Ambiguous response - fail-open (assume valid)
return (True, None)

except Exception as e:
# Log error and fail-open on any error
_log_sdk_error("workflow_invocation", e)
return (True, None)


def analyze_workflow_invocation_sync(
conversation: list[dict], session_type: str, project_root: Path
) -> tuple[bool, str | None]:
"""Synchronous wrapper for analyze_workflow_invocation with shutdown detection.

During shutdown, returns (True, None) immediately to prevent asyncio hang.
Otherwise, runs async analysis to check if workflow was properly invoked.

Args:
conversation: Session messages
session_type: Session type (DEVELOPMENT, INVESTIGATION, etc.)
project_root: Project root

Returns:
Tuple of (valid, reason):
- valid: True if workflow properly invoked or not required
- reason: String explanation if invalid, None if valid
Returns (True, None) during shutdown

Shutdown Behavior:
When AMPLIHACK_SHUTDOWN_IN_PROGRESS=1, immediately returns (True, None)
without starting async operation. This prevents asyncio event loop hangs
during application teardown.

Fail-open philosophy: Assumes workflow is valid during shutdown
to never block the user from exiting.

Example:
>>> # Normal operation - runs full analysis
>>> conversation = [{"role": "user", "content": "Implement feature"}]
>>> valid, reason = analyze_workflow_invocation_sync(
... conversation, "DEVELOPMENT", Path.cwd()
... )
>>> isinstance(valid, bool)
True

>>> # During shutdown - returns valid immediately
>>> os.environ["AMPLIHACK_SHUTDOWN_IN_PROGRESS"] = "1"
>>> valid, reason = analyze_workflow_invocation_sync(
... conversation, "DEVELOPMENT", Path.cwd()
... )
>>> valid
True
>>> reason is None
True
"""
# Shutdown check: bypass async operation during teardown
if is_shutting_down():
return (True, None) # Fail-open: assume valid during shutdown

try:
return asyncio.run(analyze_workflow_invocation(conversation, session_type, project_root))
except Exception:
return (True, None) # Fail-open on any error


# For testing
if __name__ == "__main__":
import argparse
Expand Down
Loading