Skip to content

Commit 6fc5e77

Browse files
authored
Merge pull request #76 from stratosphereips/harpo-saving-prompts
add functionality for storing prompts and responses in CSV
2 parents df10cf7 + 1d8dbd0 commit 6fc5e77

File tree

2 files changed

+42
-10
lines changed

2 files changed

+42
-10
lines changed

agents/attackers/llm_qa/llm_action_planner.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,27 @@ def __init__(self, model_name: str, goal: str, memory_len: int = 10, api_url=Non
7474
self.memory_len = memory_len
7575
self.logger = logging.getLogger("REACT-agent")
7676
self.update_instructions(goal.lower())
77-
77+
self.prompts = []
78+
self.states = []
79+
self.responses = []
80+
81+
def get_prompts(self) -> list:
    """Return every prompt that has been sent to the LLM.

    Returns:
        list: the recorded prompt message lists, in the order they were sent.
    """
    recorded_prompts = self.prompts
    return recorded_prompts
85+
86+
def get_responses(self) -> list:
    """Return the responses received from the LLM.

    Only Stage 2 responses are recorded, so only those are returned.

    Returns:
        list: the stored LLM responses, in arrival order.
    """
    recorded_responses = self.responses
    return recorded_responses
91+
92+
def get_states(self) -> list:
    """Return the observation states recorded for each LLM query.

    Each entry is a state serialized to JSON format.

    Returns:
        list: the stored JSON state strings, in recording order.
    """
    recorded_states = self.states
    return recorded_states
97+
7898
def update_instructions(self, new_goal: str) -> None:
7999
template = jinja2.Environment().from_string(self.config['prompts']['INSTRUCTIONS_TEMPLATE'])
80100
self.instructions = template.render(goal=new_goal)
@@ -141,6 +161,8 @@ def parse_response(self, llm_response: str, state: Observation.state):
141161

142162

143163
def get_action_from_obs_react(self, observation: Observation, memory_buf: list) -> tuple:
164+
self.states.append(observation.state.as_json())
165+
144166
status_prompt = create_status_from_state(observation.state)
145167
Q1 = self.config['questions'][0]['text']
146168
Q4 = self.config['questions'][3]['text']
@@ -168,8 +190,10 @@ def get_action_from_obs_react(self, observation: Observation, memory_buf: list)
168190
{"role": "user", "content": memory_prompt},
169191
{"role": "user", "content": Q4},
170192
]
171-
193+
self.prompts.append(messages)
194+
172195
response = self.openai_query(messages, max_tokens=80, fmt={"type": "json_object"})
196+
self.responses.append(response)
173197
self.logger.info(f"(Stage 2) Response from LLM: {response}")
174198
print(f"(Stage 2) Response from LLM: {response}")
175199
return self.parse_response(response, observation.state)

agents/attackers/llm_qa/llm_agent_qa.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,11 @@
113113
num_actions_repeated = []
114114
reward_memory = ""
115115

116-
states = []
117-
prompts = []
118-
responses = []
119-
evaluations = []
116+
117+
# Create an empty DataFrame for storing prompts and responses, and evaluations
118+
prompt_table = pd.DataFrame(columns=["state", "prompt", "response", "evaluation"])
119+
120+
120121
# We are still not using this, but we keep track
121122
is_detected = False
122123

@@ -126,7 +127,7 @@
126127
print("Done")
127128
for episode in range(1, args.test_episodes + 1):
128129
actions_took_in_episode = []
129-
130+
evaluations = [] # used for prompt table storage.
130131
logger.info(f"Running episode {episode}")
131132
print(f"Running episode {episode}")
132133

@@ -151,9 +152,7 @@
151152
for i in range(num_iterations):
152153
good_action = False
153154
#is_json_ok = True
154-
states.append(observation.state.as_json())
155155
is_valid, response_dict, action = llm_query.get_action_from_obs_react(observation, memories)
156-
157156
if is_valid:
158157
observation = agent.make_step(action)
159158
logger.info(f"Observation received: {observation}")
@@ -282,7 +281,16 @@
282281
)
283282
break
284283

285-
284+
episode_prompt_table = {
285+
"state": llm_query.get_states(),
286+
"prompt": llm_query.get_prompts(),
287+
"response": llm_query.get_responses(),
288+
"evaluation": evaluations,
289+
}
290+
episode_prompt_table = pd.DataFrame(episode_prompt_table)
291+
prompt_table = pd.concat([prompt_table,episode_prompt_table],axis=0,ignore_index=True)
292+
293+
prompt_table.to_csv("states_prompts_responses_new.csv", index=False)
286294

287295
# After all episodes are done. Compute statistics
288296
test_win_rate = (wins / (args.test_episodes)) * 100

0 commit comments

Comments
 (0)