1515import google .generativeai as genai
1616import random
1717from openai import OpenAI
18- from mistralai .client import MistralClient
19- from mistralai .models .chat_completion import ChatMessage
18+ from mistralai import Mistral
2019from misc .utils import (
2120 match_number_color ,
2221 match_letter ,
2928 LINDDUN_GO_JUDGE_PROMPT ,
3029)
3130
32- from pydantic import BaseModel
31+ from pydantic import BaseModel
3332
3433def linddun_go_gen_markdown (threats ):
3534 """
@@ -58,7 +57,7 @@ def linddun_go_gen_markdown(threats):
5857
5958 return markdown_output
6059
61- def get_deck (file = "misc/deck.json" ):
60+ def get_deck (shuffled = False , file = "misc/deck.json" ):
6261 """
6362 This function reads the deck of cards from a JSON file.
6463
@@ -75,10 +74,14 @@ def get_deck(file="misc/deck.json"):
7574 """
7675 with open (file , 'r' ) as deck_file :
7776 deck = json .load (deck_file )
78- return deck ["cards" ]
77+
78+ result = deck ["cards" ]
79+ if shuffled :
80+ random .shuffle (result )
81+ return result
7982
8083
81- def get_linddun_go (api_key , model_name , inputs , threats_to_analyze , temperature ):
84+ def get_linddun_go (api_key , model_name , inputs , threats_to_analyze , temperature , lmstudio = False ):
8285 """
8386 This function generates a single-agent LINDDUN threat model from the prompt.
8487
@@ -98,11 +101,12 @@ def get_linddun_go(api_key, model_name, inputs, threats_to_analyze, temperature)
98101 - reply: boolean. Whether the threat was deemed present or not in the application by the LLM.
99102 - reason: string. The reason for the detection or non-detection of the threat.
100103 """
101- client = OpenAI (api_key = api_key )
102- deck = get_deck ()
103-
104- # Shuffle the deck of cards, simulating the experience of drawing cards from the deck
105- random .shuffle (deck )
104+ if lmstudio :
105+ client = OpenAI (base_url = "http://localhost:1234/v1" , api_key = "lm-studio" )
106+ else :
107+ client = OpenAI (api_key = api_key )
108+
109+ deck = get_deck (shuffled = True )
106110
107111 threats = []
108112
@@ -124,10 +128,10 @@ def get_linddun_go(api_key, model_name, inputs, threats_to_analyze, temperature)
124128 },
125129 ]
126130
127- if model_name in ["gpt-4o" , "gpt-4o-mini" ]:
131+ if model_name in ["gpt-4o" , "gpt-4o-mini" ] or lmstudio :
128132 class Threat (BaseModel ):
129- reply : bool
130133 reason : str
134+ reply : bool
131135 response = client .beta .chat .completions .parse (
132136 model = model_name ,
133137 messages = messages ,
@@ -155,7 +159,7 @@ class Threat(BaseModel):
155159
156160
157161
158- def get_multiagent_linddun_go (keys , models , inputs , temperature , rounds , threats_to_analyze , llms_to_use ):
162+ def get_multiagent_linddun_go (keys , models , inputs , temperature , rounds , threats_to_analyze , llms_to_use , lmstudio = False ):
159163 """
160164 This function generates a multi-agent LINDDUN threat model from the prompt.
161165
@@ -178,23 +182,23 @@ def get_multiagent_linddun_go(keys, models, inputs, temperature, rounds, threats
178182 - reason: string. The reason for the detection or non-detection of the threat.
179183 """
180184
181- # Initialize the LLM clients
182- openai_client = OpenAI (api_key = keys ["openai_api_key" ]) if "OpenAI API" in llms_to_use else None
183- mistral_client = MistralClient (api_key = keys ["mistral_api_key" ]) if "Mistral API" in llms_to_use else None
184- if "Google AI API" in llms_to_use :
185- genai .configure (api_key = keys ["google_api_key" ])
186- google_client = genai .GenerativeModel (
187- models ["google_model" ], generation_config = {"response_mime_type" : "application/json" }
188- )
185+ if lmstudio :
186+ lmstudio_client = OpenAI (base_url = "http://localhost:1234/v1" , api_key = "lm-studio" )
189187 else :
190- google_client = None
188+ # Initialize the LLM clients
189+ openai_client = OpenAI (api_key = keys ["openai_api_key" ]) if "OpenAI API" in llms_to_use else None
190+ mistral_client = Mistral (api_key = keys ["mistral_api_key" ]) if "Mistral API" in llms_to_use else None
191+ if "Google AI API" in llms_to_use :
192+ genai .configure (api_key = keys ["google_api_key" ])
193+ google_client = genai .GenerativeModel (
194+ models ["google_model" ], generation_config = {"response_mime_type" : "application/json" }
195+ )
196+ else :
197+ google_client = None
191198
192199 threats = []
193- deck = get_deck ()
200+ deck = get_deck (shuffled = True )
194201
195- # Shuffle the deck of cards, simulating the experience of drawing cards from the deck
196- random .shuffle (deck )
197-
198202
199203 for card in deck [0 :threats_to_analyze ]:
200204 question = "\n " .join (card ["questions" ])
@@ -238,14 +242,23 @@ def get_multiagent_linddun_go(keys, models, inputs, temperature, rounds, threats
238242 system_prompt ,
239243 user_prompt
240244 )
245+ elif llms_to_use [llm ] == "Local LM Studio" :
246+ response_content = get_response_openai (
247+ lmstudio_client ,
248+ models ["lmstudio_model" ],
249+ temperature ,
250+ system_prompt ,
251+ user_prompt ,
252+ lmstudio = lmstudio
253+ )
241254 else :
242255 raise Exception ("Invalid LLM provider" )
243256
244257
245258 previous_analysis [i ] = response_content
246259
247260 # Judge the final verdict based on the previous analysis
248- final_verdict = judge (keys , models , previous_analysis , temperature )
261+ final_verdict = judge (keys , models , previous_analysis , temperature , lmstudio = lmstudio )
249262 final_verdict ["question" ] = question
250263 final_verdict ["threat_title" ] = title
251264 final_verdict ["threat_description" ] = description
@@ -255,7 +268,7 @@ def get_multiagent_linddun_go(keys, models, inputs, temperature, rounds, threats
255268
256269 return threats
257270
258- def get_response_openai (client , model , temperature , system_prompt , user_prompt ):
271+ def get_response_openai (client , model , temperature , system_prompt , user_prompt , lmstudio = False ):
259272 """
260273 This function generates a response from the OpenAI API.
261274
@@ -281,10 +294,10 @@ def get_response_openai(client, model, temperature, system_prompt, user_prompt):
281294 "content" : user_prompt ,
282295 },
283296 ]
284- if model in ["gpt-4o" , "gpt-4o-mini" ]:
297+ if model in ["gpt-4o" , "gpt-4o-mini" ] or lmstudio :
285298 class Threat (BaseModel ):
286- reply : bool
287299 reason : str
300+ reply : bool
288301 response = client .beta .chat .completions .parse (
289302 model = model ,
290303 response_format = Threat ,
@@ -322,8 +335,8 @@ def get_response_mistral(client, model, temperature, system_prompt, user_prompt)
322335 model = model ,
323336 response_format = {"type" : "json_object" },
324337 messages = [
325- ChatMessage ( role = " system" , content = system_prompt ) ,
326- ChatMessage ( role = " user" , content = user_prompt ) ,
338+ { " role" : " system" , " content" : system_prompt } ,
339+ { " role" : " user" , " content" : user_prompt } ,
327340 ],
328341 max_tokens = 4096 ,
329342 temperature = temperature ,
@@ -367,7 +380,7 @@ def get_response_google(client, temperature, system_prompt, user_prompt):
367380
368381 return json .loads (response .candidates [0 ].content .parts [0 ].text )
369382
370- def judge (keys , models , previous_analysis , temperature ):
383+ def judge (keys , models , previous_analysis , temperature , lmstudio = False ):
371384 """
372385 This function judges the final verdict based on the previous analysis.
373386
@@ -382,7 +395,10 @@ def judge(keys, models, previous_analysis, temperature):
382395 - reply: boolean. Whether the threat was deemed present or not in the application by the LLM.
383396 - reason: string. The reason for the detection or non-detection of the threat.
384397 """
385- client = OpenAI (api_key = keys ["openai_api_key" ])
398+ if lmstudio :
399+ client = OpenAI (base_url = "http://localhost:1234/v1" , api_key = "lm-studio" )
400+ else :
401+ client = OpenAI (api_key = keys ["openai_api_key" ])
386402 messages = [
387403 {
388404 "role" : "system" ,
@@ -393,20 +409,20 @@ def judge(keys, models, previous_analysis, temperature):
393409 "content" : LINDDUN_GO_PREVIOUS_ANALYSIS_PROMPT (previous_analysis )
394410 },
395411 ]
396- if models ["openai_model" ] in ["gpt-4o" , "gpt-4o-mini" ]:
412+ if models ["openai_model" ] in ["gpt-4o" , "gpt-4o-mini" ] or lmstudio :
397413 class Threat (BaseModel ):
398- reply : bool
399414 reason : str
415+ reply : bool
400416 response = client .beta .chat .completions .parse (
401- model = models ["openai_model" ],
417+ model = models ["openai_model" ] if not lmstudio else models [ "lmstudio_model" ] ,
402418 response_format = Threat ,
403419 temperature = temperature ,
404420 messages = messages ,
405421 max_tokens = 4096 ,
406422 )
407423 else :
408424 response = client .chat .completions .create (
409- model = models ["openai_model" ],
425+ model = models ["openai_model" ] if not lmstudio else models [ "lmstudio_model" ] ,
410426 messages = messages ,
411427 response_format = {"type" : "json_object" },
412428 temperature = temperature ,
0 commit comments