1- from typing import Dict , Any , Optional , Callable , TypeVar , ParamSpec , Union
1+ from typing import Dict , Any , Optional , Callable , TypeVar , ParamSpec , Union , Awaitable
22from traceloop .sdk .evaluator .config import EvaluatorDetails
33from traceloop .sdk .evaluator .evaluator import Evaluator
4- from .types import InputExtractor , OutputSchema
4+ from .types import OutputSchema
55import httpx
66import asyncio
77from functools import wraps
1717def guardrail (
1818 evaluator : EvaluatorSpec ,
1919 on_evaluation_complete : Optional [Callable [[OutputSchema , Any ], Any ]] = None
20- ):
20+ ) -> Callable [[ Callable [ P , Awaitable [ Dict [ str , Any ]]]], Callable [ P , Awaitable [ Dict [ str , Any ]]]] :
2121 """
2222 Decorator that executes a guardrails evaluator on the decorated function's output.
2323
@@ -31,35 +31,57 @@ def guardrail(
3131 Result from on_evaluation_complete callback if provided, otherwise original result or error message
3232 """
3333 # Extract evaluator details as tuple (slug, version, config, required_fields) - same pattern as experiments
34+ slug : str
35+ evaluator_version : Optional [str ]
36+ evaluator_config : Optional [Dict [str , Any ]]
37+ required_input_fields : Optional [list [str ]]
38+
3439 if isinstance (evaluator , str ):
3540 # Simple string slug - use default field mapping
36- evaluator_details = (evaluator , None , None , None )
41+ slug = evaluator
42+ evaluator_version = None
43+ evaluator_config = None
44+ required_input_fields = None
3745 elif isinstance (evaluator , EvaluatorDetails ):
3846 # EvaluatorDetails object with config
39- evaluator_details = (
40- evaluator .slug ,
41- evaluator .version ,
42- evaluator .config ,
43- evaluator .required_input_fields
44- )
47+ slug = evaluator .slug
48+ evaluator_version = evaluator .version
49+ evaluator_config = evaluator .config
50+ required_input_fields = evaluator .required_input_fields
4551 else :
4652 raise ValueError (f"evaluator must be str or EvaluatorDetails, got { type (evaluator )} " )
4753
48- slug , evaluator_version , evaluator_config , required_input_fields = evaluator_details
49-
5054 def decorator (func : Callable [P , R ]) -> Callable [P , Dict [str , Any ]]:
5155 @wraps (func )
5256 async def async_wrapper (* args : P .args , ** kwargs : P .kwargs ) -> Dict [str , Any ]:
5357 # Execute the original function - should return a dict with fields matching required_input_fields
54- original_result = await func (* args , ** kwargs )
58+ result = await func (* args , ** kwargs ) # type: ignore[misc]
5559
56- # Validate that original_result is a dict
57- if not isinstance (original_result , dict ):
60+ # Ensure we have a dict
61+ if not isinstance (result , dict ):
5862 raise ValueError (
59- f"Function { func .__name__ } must return a dict, got { type (original_result )} . "
63+ f"Function { func .__name__ } must return a dict, got { type (result )} . "
6064 f"Required fields: { required_input_fields or 'unknown' } "
6165 )
6266
67+ original_result : Dict [str , Any ] = result
68+
69+ # Build evaluator_data based on required_input_fields or use all fields from result
70+ evaluator_data : Dict [str , str ] = {}
71+ if required_input_fields :
72+ # Use only the required fields from the function result
73+ for field in required_input_fields :
74+ if field not in original_result :
75+ raise ValueError (
76+ f"Function { func .__name__ } must return dict with field '{ field } '. "
77+ f"Got: { list (original_result .keys ())} "
78+ )
79+ evaluator_data [field ] = str (original_result [field ])
80+ else :
81+ # No required fields specified, use all fields from result
82+ for field , value in original_result .items ():
83+ evaluator_data [field ] = str (value )
84+
6385 try :
6486 from traceloop .sdk import Traceloop
6587 client_instance = Traceloop .get ()
@@ -68,12 +90,14 @@ async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> Dict[str, Any]:
6890 return original_result
6991
7092 evaluator_result = await client_instance .guardrails .execute_evaluator (
71- slug , original_result , evaluator_version , evaluator_config
93+ slug , evaluator_data , evaluator_version , evaluator_config
7294 )
7395
7496 # Use callback if provided, otherwise use default behavior
7597 if on_evaluation_complete :
76- return on_evaluation_complete (evaluator_result , original_result )
98+ callback_result = on_evaluation_complete (evaluator_result , original_result )
99+ # Callback should return a dict, but we can't enforce this at compile time
100+ return callback_result # type: ignore[no-any-return]
77101 else :
78102 # Default behavior: return original result (dict) regardless of pass/fail
79103 # Users should use callback for custom behavior
@@ -83,17 +107,19 @@ async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> Dict[str, Any]:
83107 def sync_wrapper (* args : P .args , ** kwargs : P .kwargs ) -> Dict [str , Any ]:
84108
85109 # Execute the original function - should return a dict with fields matching required_input_fields
86- original_result = func (* args , ** kwargs )
110+ result = func (* args , ** kwargs )
87111
88- # Validate that original_result is a dict
89- if not isinstance (original_result , dict ):
112+ # Ensure we have a dict
113+ if not isinstance (result , dict ):
90114 raise ValueError (
91- f"Function { func .__name__ } must return a dict, got { type (original_result )} . "
115+ f"Function { func .__name__ } must return a dict, got { type (result )} . "
92116 f"Required fields: { required_input_fields or 'unknown' } "
93117 )
94118
119+ original_result : Dict [str , Any ] = result
120+
95121 # Build evaluator_data based on required_input_fields or use all fields from result
96- evaluator_data = {}
122+ evaluator_data : Dict [ str , str ] = {}
97123 if required_input_fields :
98124 # Use only the required fields from the function result
99125 for field in required_input_fields :
@@ -102,11 +128,11 @@ def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> Dict[str, Any]:
102128 f"Function { func .__name__ } must return dict with field '{ field } '. "
103129 f"Got: { list (original_result .keys ())} "
104130 )
105- evaluator_data [field ] = InputExtractor ( source = original_result [field ])
131+ evaluator_data [field ] = str ( original_result [field ])
106132 else :
107133 # No required fields specified, use all fields from result
108134 for field , value in original_result .items ():
109- evaluator_data [field ] = InputExtractor ( source = value )
135+ evaluator_data [field ] = str ( value )
110136
111137 # Get client instance
112138 try :
@@ -125,19 +151,21 @@ def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> Dict[str, Any]:
125151
126152 # Use callback if provided, otherwise use default behavior
127153 if on_evaluation_complete :
128- return on_evaluation_complete (evaluator_result , original_result )
154+ callback_result = on_evaluation_complete (evaluator_result , original_result )
155+ # Callback should return a dict, but we can't enforce this at compile time
156+ return callback_result # type: ignore[no-any-return]
129157 else :
130158 # Default behavior: return original result (dict) regardless of pass/fail
131159 # Users should use callback for custom behavior
132160 return original_result
133161
134162 # Return appropriate wrapper based on function type
135163 if asyncio .iscoroutinefunction (func ):
136- return async_wrapper
164+ return async_wrapper # type: ignore[return-value]
137165 else :
138166 return sync_wrapper
139167
140- return decorator
168+ return decorator # type: ignore[return-value]
141169
142170
143171class Guardrails :
0 commit comments