Skip to content

Commit 1622fca

Browse files
Merge pull request #5 from AndreaBissoli/master
Improvements to DFD interface and some local LLM features with LM Studio
2 parents 9caeb3d + 17cdc7b commit 1622fca

File tree

13 files changed

+1265
-451
lines changed

13 files changed

+1265
-451
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@ venv/
22
__pycache__/
33
test_functions.py
44
.streamlit/secrets.toml
5-
.devcontainer
5+
.devcontainer
6+
benchmarks/

llms/dfd.py

Lines changed: 435 additions & 96 deletions
Large diffs are not rendered by default.

llms/linddun_go.py

Lines changed: 55 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@
1515
import google.generativeai as genai
1616
import random
1717
from openai import OpenAI
18-
from mistralai.client import MistralClient
19-
from mistralai.models.chat_completion import ChatMessage
18+
from mistralai import Mistral
2019
from misc.utils import (
2120
match_number_color,
2221
match_letter,
@@ -29,7 +28,7 @@
2928
LINDDUN_GO_JUDGE_PROMPT,
3029
)
3130

32-
from pydantic import BaseModel
31+
from pydantic import BaseModel
3332

3433
def linddun_go_gen_markdown(threats):
3534
"""
@@ -58,7 +57,7 @@ def linddun_go_gen_markdown(threats):
5857

5958
return markdown_output
6059

61-
def get_deck(file="misc/deck.json"):
60+
def get_deck(shuffled=False, file="misc/deck.json"):
6261
"""
6362
This function reads the deck of cards from a JSON file.
6463
@@ -75,10 +74,14 @@ def get_deck(file="misc/deck.json"):
7574
"""
7675
with open(file, 'r') as deck_file:
7776
deck = json.load(deck_file)
78-
return deck["cards"]
77+
78+
result = deck["cards"]
79+
if shuffled:
80+
random.shuffle(result)
81+
return result
7982

8083

81-
def get_linddun_go(api_key, model_name, inputs, threats_to_analyze, temperature):
84+
def get_linddun_go(api_key, model_name, inputs, threats_to_analyze, temperature, lmstudio=False):
8285
"""
8386
This function generates a single-agent LINDDUN threat model from the prompt.
8487
@@ -98,11 +101,12 @@ def get_linddun_go(api_key, model_name, inputs, threats_to_analyze, temperature)
98101
- reply: boolean. Whether the threat was deemed present or not in the application by the LLM.
99102
- reason: string. The reason for the detection or non-detection of the threat.
100103
"""
101-
client = OpenAI(api_key=api_key)
102-
deck = get_deck()
103-
104-
# Shuffle the deck of cards, simulating the experience of drawing cards from the deck
105-
random.shuffle(deck)
104+
if lmstudio:
105+
client = OpenAI(base_url="http://localhost:1234/v1", api_key="lm-studio")
106+
else:
107+
client = OpenAI(api_key=api_key)
108+
109+
deck = get_deck(shuffled=True)
106110

107111
threats = []
108112

@@ -124,10 +128,10 @@ def get_linddun_go(api_key, model_name, inputs, threats_to_analyze, temperature)
124128
},
125129
]
126130

127-
if model_name in ["gpt-4o", "gpt-4o-mini"]:
131+
if model_name in ["gpt-4o", "gpt-4o-mini"] or lmstudio:
128132
class Threat(BaseModel):
129-
reply: bool
130133
reason: str
134+
reply: bool
131135
response = client.beta.chat.completions.parse(
132136
model=model_name,
133137
messages=messages,
@@ -155,7 +159,7 @@ class Threat(BaseModel):
155159

156160

157161

158-
def get_multiagent_linddun_go(keys, models, inputs, temperature, rounds, threats_to_analyze, llms_to_use):
162+
def get_multiagent_linddun_go(keys, models, inputs, temperature, rounds, threats_to_analyze, llms_to_use, lmstudio=False):
159163
"""
160164
This function generates a multi-agent LINDDUN threat model from the prompt.
161165
@@ -178,23 +182,23 @@ def get_multiagent_linddun_go(keys, models, inputs, temperature, rounds, threats
178182
- reason: string. The reason for the detection or non-detection of the threat.
179183
"""
180184

181-
# Initialize the LLM clients
182-
openai_client = OpenAI(api_key=keys["openai_api_key"]) if "OpenAI API" in llms_to_use else None
183-
mistral_client = MistralClient(api_key=keys["mistral_api_key"]) if "Mistral API" in llms_to_use else None
184-
if "Google AI API" in llms_to_use:
185-
genai.configure(api_key=keys["google_api_key"])
186-
google_client = genai.GenerativeModel(
187-
models["google_model"], generation_config={"response_mime_type": "application/json"}
188-
)
185+
if lmstudio:
186+
lmstudio_client = OpenAI(base_url="http://localhost:1234/v1", api_key="lm-studio")
189187
else:
190-
google_client = None
188+
# Initialize the LLM clients
189+
openai_client = OpenAI(api_key=keys["openai_api_key"]) if "OpenAI API" in llms_to_use else None
190+
mistral_client = Mistral(api_key=keys["mistral_api_key"]) if "Mistral API" in llms_to_use else None
191+
if "Google AI API" in llms_to_use:
192+
genai.configure(api_key=keys["google_api_key"])
193+
google_client = genai.GenerativeModel(
194+
models["google_model"], generation_config={"response_mime_type": "application/json"}
195+
)
196+
else:
197+
google_client = None
191198

192199
threats = []
193-
deck = get_deck()
200+
deck = get_deck(shuffled=True)
194201

195-
# Shuffle the deck of cards, simulating the experience of drawing cards from the deck
196-
random.shuffle(deck)
197-
198202

199203
for card in deck[0:threats_to_analyze]:
200204
question = "\n".join(card["questions"])
@@ -238,14 +242,23 @@ def get_multiagent_linddun_go(keys, models, inputs, temperature, rounds, threats
238242
system_prompt,
239243
user_prompt
240244
)
245+
elif llms_to_use[llm] == "Local LM Studio":
246+
response_content = get_response_openai(
247+
lmstudio_client,
248+
models["lmstudio_model"],
249+
temperature,
250+
system_prompt,
251+
user_prompt,
252+
lmstudio=lmstudio
253+
)
241254
else:
242255
raise Exception("Invalid LLM provider")
243256

244257

245258
previous_analysis[i] = response_content
246259

247260
# Judge the final verdict based on the previous analysis
248-
final_verdict = judge(keys, models, previous_analysis, temperature)
261+
final_verdict = judge(keys, models, previous_analysis, temperature, lmstudio=lmstudio)
249262
final_verdict["question"] = question
250263
final_verdict["threat_title"] = title
251264
final_verdict["threat_description"] = description
@@ -255,7 +268,7 @@ def get_multiagent_linddun_go(keys, models, inputs, temperature, rounds, threats
255268

256269
return threats
257270

258-
def get_response_openai(client, model, temperature, system_prompt, user_prompt):
271+
def get_response_openai(client, model, temperature, system_prompt, user_prompt, lmstudio=False):
259272
"""
260273
This function generates a response from the OpenAI API.
261274
@@ -281,10 +294,10 @@ def get_response_openai(client, model, temperature, system_prompt, user_prompt):
281294
"content": user_prompt,
282295
},
283296
]
284-
if model in ["gpt-4o", "gpt-4o-mini"]:
297+
if model in ["gpt-4o", "gpt-4o-mini"] or lmstudio:
285298
class Threat(BaseModel):
286-
reply: bool
287299
reason: str
300+
reply: bool
288301
response = client.beta.chat.completions.parse(
289302
model=model,
290303
response_format=Threat,
@@ -322,8 +335,8 @@ def get_response_mistral(client, model, temperature, system_prompt, user_prompt)
322335
model=model,
323336
response_format={"type": "json_object"},
324337
messages=[
325-
ChatMessage(role="system", content=system_prompt),
326-
ChatMessage(role="user", content=user_prompt),
338+
{"role":"system", "content":system_prompt},
339+
{"role":"user", "content":user_prompt},
327340
],
328341
max_tokens=4096,
329342
temperature=temperature,
@@ -367,7 +380,7 @@ def get_response_google(client, temperature, system_prompt, user_prompt):
367380

368381
return json.loads(response.candidates[0].content.parts[0].text)
369382

370-
def judge(keys, models, previous_analysis, temperature):
383+
def judge(keys, models, previous_analysis, temperature, lmstudio=False):
371384
"""
372385
This function judges the final verdict based on the previous analysis.
373386
@@ -382,7 +395,10 @@ def judge(keys, models, previous_analysis, temperature):
382395
- reply: boolean. Whether the threat was deemed present or not in the application by the LLM.
383396
- reason: string. The reason for the detection or non-detection of the threat.
384397
"""
385-
client = OpenAI(api_key=keys["openai_api_key"])
398+
if lmstudio:
399+
client = OpenAI(base_url="http://localhost:1234/v1", api_key="lm-studio")
400+
else:
401+
client = OpenAI(api_key=keys["openai_api_key"])
386402
messages=[
387403
{
388404
"role": "system",
@@ -393,20 +409,20 @@ def judge(keys, models, previous_analysis, temperature):
393409
"content": LINDDUN_GO_PREVIOUS_ANALYSIS_PROMPT(previous_analysis)
394410
},
395411
]
396-
if models["openai_model"] in ["gpt-4o", "gpt-4o-mini"]:
412+
if models["openai_model"] in ["gpt-4o", "gpt-4o-mini"] or lmstudio:
397413
class Threat(BaseModel):
398-
reply: bool
399414
reason: str
415+
reply: bool
400416
response = client.beta.chat.completions.parse(
401-
model=models["openai_model"],
417+
model=models["openai_model"] if not lmstudio else models["lmstudio_model"],
402418
response_format=Threat,
403419
temperature=temperature,
404420
messages=messages,
405421
max_tokens=4096,
406422
)
407423
else:
408424
response = client.chat.completions.create(
409-
model=models["openai_model"],
425+
model=models["openai_model"] if not lmstudio else models["lmstudio_model"],
410426
messages=messages,
411427
response_format={"type": "json_object"},
412428
temperature=temperature,

llms/linddun_pro.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def mapping_table(edge, category):
7979
return (True, True, True)
8080

8181

82-
def get_linddun_pro(api_key, model, dfd, edge, category, description, temperature):
82+
def get_linddun_pro(api_key, model, dfd, edge, category, boundaries, temperature):
8383
"""
8484
This function generates a LINDDUN Pro threat model from the information provided.
8585
@@ -92,9 +92,15 @@ def get_linddun_pro(api_key, model, dfd, edge, category, description, temperatur
9292
- to: string. The entity where the data flow ends
9393
- typeto: string. The type of the entity where the data flow ends
9494
- trusted: bool. Whether the data flow is trusted
95+
- boundary: string. The trust boundary id of the data flow
96+
- description: string. The description of the data flow
9597
- edge (dict): The specific edge of the DFD to find threats for. The dictionary has the same keys as the DFD.
9698
- category (str): The LINDDUN category to look for in the threat model, in the format "Linking", "Identifying", etc.
97-
- description (str): A brief description of the data flow.
99+
- boundaries (dict): The trust boundaries of the application. The dictionary has the following keys:
100+
- id: string. The ID of the trust boundary.
101+
- title: string. The title of the trust boundary.
102+
- description: string. The description of the trust boundary.
103+
- color: string. The color of the trust boundary.
98104
- temperature (float): The temperature to use for the model.
99105
100106
Returns:
@@ -124,7 +130,7 @@ def get_linddun_pro(api_key, model, dfd, edge, category, description, temperatur
124130
},
125131
{
126132
"role": "user",
127-
"content": LINDDUN_PRO_USER_PROMPT(dfd, edge, category, description, source, data_flow, destination, tree),
133+
"content": LINDDUN_PRO_USER_PROMPT(dfd, edge, category, source, data_flow, destination, boundaries, tree),
128134
},
129135
]
130136

0 commit comments

Comments
 (0)