Skip to content

Commit 8ba7553

Browse files
committed
Simplify guardrail implementation and improve error handling in runner
1 parent 44e9fd0 commit 8ba7553

File tree

4 files changed

+18
-76
lines changed

4 files changed

+18
-76
lines changed

examples/guardrail.js

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,7 @@ const inappropriateContentGuardrail = new Guardrail.Input({
1515
message.toLowerCase().includes(word)
1616
);
1717

18-
return {
19-
output_info: {
20-
checked_for: "inappropriate content",
21-
found: containsInappropriate ? "yes" : "no"
22-
},
23-
tripwire_triggered: containsInappropriate
24-
};
18+
return containsInappropriate;
2519
}
2620
});
2721

@@ -49,15 +43,7 @@ const minimumLengthGuardrail = new Guardrail.Output({
4943
const outputStr = String(output);
5044
const isTooShort = outputStr.length < minimumLength;
5145

52-
return {
53-
output_info: {
54-
checked_for: "minimum length",
55-
required_length: minimumLength,
56-
actual_length: outputStr.length,
57-
is_too_short: isTooShort
58-
},
59-
tripwire_triggered: isTooShort
60-
};
46+
return isTooShort;
6147
}
6248
});
6349

src/agent.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ class Agent {
107107
* @returns {LLMProvider} - Initialized provider instance
108108
*/
109109
initializeProvider(provider, config) {
110-
console.log('Initializing provider:', provider, config);
110+
// console.log('Initializing provider:', provider, config);
111111
switch (provider.toLowerCase()) {
112112
case 'openai':
113113
return new OpenAIProvider(config);

src/guardrail.js

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,8 @@ class Input {
1111
this.check = check;
1212
}
1313

14-
async validate(agent, input, context) {
15-
const result = await this.check(context, agent, input);
16-
return {
17-
guardrail: { name: this.name },
18-
output: result
19-
};
20-
}
21-
2214
async run(agent, input, context) {
23-
return this.validate(agent, input, context);
15+
return await this.check(context, agent, input);
2416
}
2517
}
2618

@@ -37,16 +29,8 @@ class Output {
3729
this.check = check;
3830
}
3931

40-
async validate(agent, output, context) {
41-
const result = await this.check(context, agent, output);
42-
return {
43-
guardrail: { name: this.name },
44-
output: result
45-
};
46-
}
47-
4832
async run(agent, output, context) {
49-
return this.validate(agent, output, context);
33+
return await this.check(context, agent, output);
5034
}
5135
}
5236

src/runner.js

Lines changed: 13 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ class Runner {
3737
let currentAgent = startingAgent;
3838
let messages = [];
3939
let finalOutput = null;
40+
const inputGuardrailResults = [];
41+
const outputGuardrailResults = [];
4042

4143
// Validate provider configuration
4244
currentAgent.llmProvider.validateConfig();
@@ -48,19 +50,11 @@ class Runner {
4850
messages = [...input];
4951
}
5052

51-
// Run input guardrails inline
52-
const inputGuardrailResults = [];
53+
// Run input guardrails
5354
for (const guardrail of currentAgent?.guardrails?.input || []) {
54-
console.log(`Running input guardrail: ${guardrail.name}`);
55-
try {
56-
const result = await guardrail.validate(currentAgent, messages, context);
57-
inputGuardrailResults.push(result);
58-
if (result.output && result.output.tripwireTriggered) {
59-
throw new Error(`Input guardrail ${guardrail.name} triggered: ${JSON.stringify(result.output)}`);
60-
}
61-
} catch (error) {
62-
console.error(`Error running input guardrail: ${error.message}`);
63-
throw error;
55+
const result = await guardrail.run(currentAgent, messages, context);
56+
if (result) {
57+
inputGuardrailResults.push(guardrail.name);
6458
}
6559
}
6660

@@ -112,6 +106,9 @@ class Runner {
112106

113107
// Check for tool calls
114108
if (assistantMessage.tool_calls && assistantMessage.tool_calls.length > 0) {
109+
110+
// console.log("Tool calls detected:", assistantMessage.tool_calls);
111+
115112
for (const toolCall of assistantMessage.tool_calls) {
116113
const { name, arguments: args } = toolCall.function;
117114

@@ -178,22 +175,6 @@ class Runner {
178175
// No tool calls, this is a final output
179176
finalOutput = assistantMessage.content;
180177

181-
// Run output guardrails inline
182-
const outputGuardrailResults = [];
183-
for (const guardrail of currentAgent?.guardrails?.output || []) {
184-
console.log(`Running output guardrail: ${guardrail.name}`);
185-
try {
186-
const result = await guardrail.validate(currentAgent, finalOutput, context);
187-
outputGuardrailResults.push(result);
188-
if (result.output && result.output.tripwireTriggered) {
189-
throw new Error(`Output guardrail ${guardrail.name} triggered: ${JSON.stringify(result.output)}`);
190-
}
191-
} catch (error) {
192-
console.error(`Error running output guardrail: ${error.message}`);
193-
throw error;
194-
}
195-
}
196-
197178
// If we have a final output, we're done
198179
break;
199180
}
@@ -204,23 +185,14 @@ class Runner {
204185
throw new Error(`Max turns (${maxTurns}) exceeded`);
205186
}
206187

207-
// Run output guardrails inline (final)
208-
const outputGuardrailResults = [];
188+
// Run output guardrails
209189
for (const guardrail of currentAgent?.guardrails?.output || []) {
210-
console.log(`Running output guardrail: ${guardrail.name}`);
211-
try {
212-
const result = await guardrail.validate(currentAgent, finalOutput, context);
213-
outputGuardrailResults.push(result);
214-
if (result.output && result.output.tripwireTriggered) {
215-
throw new Error(`Output guardrail ${guardrail.name} triggered: ${JSON.stringify(result.output)}`);
216-
}
217-
} catch (error) {
218-
console.error(`Error running output guardrail: ${error.message}`);
219-
throw error;
190+
const result = await guardrail.run(currentAgent, finalOutput, context);
191+
if (result) {
192+
outputGuardrailResults.push(guardrail.name);
220193
}
221194
}
222195

223-
224196
const result = new RunResult({
225197
input,
226198
messages,

0 commit comments

Comments
 (0)