Skip to content

Commit 1d8dbd0

Browse files
committed
add functionality for storing prompts and responses in CSV
1 parent 546431e commit 1d8dbd0

File tree

2 files changed

+42
-10
lines changed

2 files changed

+42
-10
lines changed

agents/attackers/llm_qa/llm_action_planner.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,27 @@ def __init__(self, model_name: str, goal: str, memory_len: int = 10, api_url=Non
7474
self.memory_len = memory_len
7575
self.logger = logging.getLogger("REACT-agent")
7676
self.update_instructions(goal.lower())
77-
77+
self.prompts = []
78+
self.states = []
79+
self.responses = []
80+
81+
def get_prompts(self) -> list:
    """Return the list of prompts that have been sent to the LLM so far."""
    prompt_history = self.prompts
    return prompt_history
85+
86+
def get_responses(self) -> list:
    """Return the LLM responses collected so far.

    Only Stage 2 responses are recorded in this list.
    """
    response_history = self.responses
    return response_history
91+
92+
def get_states(self) -> list:
    """Return the recorded game states, each serialized as a JSON string."""
    state_history = self.states
    return state_history
97+
7898
def update_instructions(self, new_goal: str) -> None:
    """Re-render the instruction prompt for a new goal.

    Builds a Jinja2 template from the configured INSTRUCTIONS_TEMPLATE
    string and stores the rendered result on ``self.instructions``.
    """
    env = jinja2.Environment()
    tmpl = env.from_string(self.config['prompts']['INSTRUCTIONS_TEMPLATE'])
    self.instructions = tmpl.render(goal=new_goal)
@@ -141,6 +161,8 @@ def parse_response(self, llm_response: str, state: Observation.state):
141161

142162

143163
def get_action_from_obs_react(self, observation: Observation, memory_buf: list) -> tuple:
164+
self.states.append(observation.state.as_json())
165+
144166
status_prompt = create_status_from_state(observation.state)
145167
Q1 = self.config['questions'][0]['text']
146168
Q4 = self.config['questions'][3]['text']
@@ -168,7 +190,9 @@ def get_action_from_obs_react(self, observation: Observation, memory_buf: list)
168190
{"role": "user", "content": memory_prompt},
169191
{"role": "user", "content": Q4},
170192
]
171-
193+
self.prompts.append(messages)
194+
172195
response = self.openai_query(messages, max_tokens=80, fmt={"type": "json_object"})
196+
self.responses.append(response)
173197
self.logger.info(f"(Stage 2) Response from LLM: {response}")
174198
return self.parse_response(response, observation.state)

agents/attackers/llm_qa/llm_agent_qa.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,11 @@
100100
num_actions_repeated = []
101101
reward_memory = ""
102102

103-
states = []
104-
prompts = []
105-
responses = []
106-
evaluations = []
103+
104+
# Create an empty DataFrame for storing prompts and responses, and evaluations
105+
prompt_table = pd.DataFrame(columns=["state", "prompt", "response", "evaluation"])
106+
107+
107108
# We are still not using this, but we keep track
108109
is_detected = False
109110

@@ -113,7 +114,7 @@
113114
print("Done")
114115
for episode in range(1, args.test_episodes + 1):
115116
actions_took_in_episode = []
116-
117+
evaluations = [] # used for prompt table storage.
117118
logger.info(f"Running episode {episode}")
118119
print(f"Running episode {episode}")
119120

@@ -138,9 +139,7 @@
138139
for i in range(num_iterations):
139140
good_action = False
140141
#is_json_ok = True
141-
states.append(observation.state.as_json())
142142
is_valid, response_dict, action = llm_query.get_action_from_obs_react(observation, memories)
143-
144143
if is_valid:
145144
observation = agent.make_step(action)
146145
logger.info(f"Observation received: {observation}")
@@ -265,7 +264,16 @@
265264
)
266265
break
267266

268-
267+
episode_prompt_table = {
268+
"state": llm_query.get_states(),
269+
"prompt": llm_query.get_prompts(),
270+
"response": llm_query.get_responses(),
271+
"evaluation": evaluations,
272+
}
273+
episode_prompt_table = pd.DataFrame(episode_prompt_table)
274+
prompt_table = pd.concat([prompt_table,episode_prompt_table],axis=0,ignore_index=True)
275+
276+
prompt_table.to_csv("states_prompts_responses_new.csv", index=False)
269277

270278
# After all episodes are done. Compute statistics
271279
test_win_rate = (wins / (args.test_episodes)) * 100

0 commit comments

Comments
 (0)