Skip to content

Commit 98a007c

Browse files
committed
refonte du mécanisme mnésique + élision des listes pr dict
1 parent 6523e51 commit 98a007c

File tree

2 files changed

+59
-111
lines changed

2 files changed

+59
-111
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
venv/
Lines changed: 58 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,23 @@
11
import torch
2-
from transformers import pipeline, BartTokenizer, BartForConditionalGeneration
2+
from transformers import BartTokenizer, BartForConditionalGeneration
33
import logging
44

5-
# Set device to GPU if available
6-
if torch.cuda.is_available():
7-
device:int = 0
8-
else:
9-
device:int = -1
5+
# Configuration du logging
6+
logging.basicConfig(level=logging.INFO)
107

11-
MAX_MEMORY_SIZE:int = 2000
8+
# Détection automatique du device
9+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10+
11+
# Paramètres globaux
12+
MAX_MEMORY_SIZE = 2000 # Limite du nombre de messages
13+
MAX_TOKENS_PER_MESSAGE = 1000 # Limite pour compresser la mémoire
14+
BATCH_SIZE = 5 # Taille du batch pour la compression
1215

1316
class ChatbotMemory:
14-
def __init__(self, conv:list = []):
15-
self.conversation_history = conv
17+
def __init__(self, conv:list=None):
18+
self.conversation_history = conv or []
1619
self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
17-
self.model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')
20+
self.model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn').to(device)
1821

1922
def update_memory(self, user_input:str, bot_response:str)->None:
2023
"""
@@ -23,18 +26,16 @@ def update_memory(self, user_input:str, bot_response:str)->None:
2326
user_input (str): The input provided by the user.
2427
bot_response (str): The response generated by the Chatbot.
2528
Returns:
26-
None
27-
"""
28-
self.conversation_history.append(f"'user': {user_input}, 'bot': {bot_response}")
29-
30-
if self.memory_counter(self.conversation_history) > 1000:
31-
self.conversation_history = self.compressed_memory(self.conversation_history)
32-
logging.info("Memory compressed.")
33-
29+
None """
30+
self.conversation_history.append({'user': user_input, 'bot': bot_response})
31+
32+
if self.memory_counter() > MAX_TOKENS_PER_MESSAGE:
33+
self.conversation_history = self.compressed_memory()
34+
logging.info("Mémoire compressée.")
35+
3436
if len(self.conversation_history) > MAX_MEMORY_SIZE:
3537
self.conversation_history.pop(0)
36-
logging.info("Memory trimmed.")
37-
return 0
38+
logging.info("Mémoire tronquée.")
3839

3940
def get_memory(self):
4041
"""
@@ -44,105 +45,51 @@ def get_memory(self):
4445
The conversation history.
4546
"""
4647
return self.conversation_history
47-
48-
def _get_compressed_memory(self, text):
49-
50-
inputs = self.tokenizer.encode("summarize: " + text, return_tensors="pt", max_length=1024, truncation=True)
51-
summary_ids = self.model.generate(inputs, max_length=150, min_length=40, length_penalty=2.0, num_beams=4, early_stopping=True)
52-
summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
53-
return summary
54-
# def _get_compressed_memory(sentence:str)->str:
55-
# """
56-
# Compresses the input sentence using the Facebook BART model for summarization.
57-
58-
# Args:
59-
# sentence: The input sentence to be compressed.
6048

61-
# Returns:
62-
# str: The compressed summary of the input sentence.
63-
# """
64-
# summarizer:str = pipeline("summarization",model="facebook/bart-large-cnn",device=device)
65-
# summary:str = summarizer(sentence, max_length=50, min_length=5, do_sample=False)
66-
# return summary[0]['summary_text']
67-
68-
def compressed_memory(self, conv_hist:list)->list:
49+
def _get_compressed_memory(self, text:str):
6950
"""
70-
Compresses each sentence in the conversation history list using summarization.
71-
72-
Args:
73-
conv_hist: List of sentences representing the conversation history.
74-
75-
Returns:
76-
list: List of compressed summaries for each sentence in the conversation history.
51+
Résume un bloc de texte.
7752
"""
78-
# return [_get_compressed_memory(sentence) for sentence in conv_hist]
79-
return [self._get_compressed_memory(' '.join(conv_hist[i:i+5])) for i in range(0, len(conv_hist), 5)]
80-
81-
82-
def memory_counter(self, conv_hist:list[str])->int:
53+
inputs = self.tokenizer(
54+
f"summarize: {text}",
55+
return_tensors="pt",
56+
max_length=1024,
57+
truncation=True,
58+
).to(device)
59+
60+
summary_ids = self.model.generate(
61+
inputs.input_ids,
62+
max_length=150,
63+
min_length=40,
64+
length_penalty=2.0,
65+
num_beams=4,
66+
early_stopping=True
67+
)
68+
return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
69+
70+
def compressed_memory(self):
8371
"""
84-
Counts the total number of words in the conversation history list.
85-
86-
Args:
87-
conv_hist: List of sentences representing the conversation history.
72+
Résume l'historique de la conversation par batch de BATCH_SIZE.
73+
"""
74+
combined_history = [f"User: {entry['user']} Bot: {entry['bot']}" for entry in self.conversation_history]
75+
return [
76+
{'summary': self._get_compressed_memory(' '.join(combined_history[i:i+BATCH_SIZE]))}
77+
for i in range(0, len(combined_history), BATCH_SIZE)
78+
]
8879

89-
Returns:
90-
int: Total number of words in the conversation history.
80+
def memory_counter(self):
81+
"""
82+
Compte le nombre total de tokens dans l'historique.
9183
"""
92-
st = ''.join(conv_hist)
93-
return len(st.split())
84+
all_text = ' '.join(f"{entry['user']} {entry['bot']}" for entry in self.conversation_history)
85+
return len(self.tokenizer.tokenize(all_text))
9486

9587
if __name__ == "__main__":
9688
chat_memory = ChatbotMemory()
97-
user_input = "Bonjour, comment allez-vous?"
98-
bot_response = "Je vais bien, merci ! Et vous ?"
99-
chat_memory.update_memory(user_input, bot_response)
89+
90+
chat_memory.update_memory("Bonjour, comment allez-vous?", "Je vais bien, merci ! Et vous ?")
91+
chat_memory.update_memory("Écris un poème de 300 mots", "Sous l'éclat argenté d'une lune éternelle... (poème)")
10092

101-
# Obtenir l'historique des conversations
10293
historique = chat_memory.get_memory()
103-
print(historique)
104-
105-
user_input = "écris un poème de 300 mots"
106-
bot_response = "\
107-
L’écho des mondes\
108-
Sous l’éclat argenté d’une lune éternelle,\
109-
Les murmures du vent dansent, doux et rebelles.\
110-
Ils transportent des rêves, des promesses lointaines,\
111-
Effleurant l’horizon de montagnes anciennes.\
112-
\
113-
Dans la forêt profonde, les arbres se confient,\
114-
Leurs racines tissent des récits infinis.\
115-
Chaque feuille, chaque branche, porte un secret du temps,\
116-
Un souffle de sagesse, une parole en suspens.\
117-
\
118-
Les rivières murmurent en suivant leur chemin,\
119-
Leurs eaux claires reflètent le bleu du matin.\
120-
Elles emportent les peines, les joies et les cris,\
121-
Traversant les âges, ignorant les périls.\
122-
\
123-
Le soleil se lève, doré, majestueux,\
124-
Réveillant la terre d’un sommeil lumineux.\
125-
Il peint des ombres dansantes sur les collines,\
126-
Offrant à chaque instant des lueurs divines.\
127-
\
128-
Dans ce vaste univers où tout semble figé,\
129-
Les étoiles veillent, telles des âmes égarées.\
130-
Elles brillent de loin, gardiennes silencieuses,\
131-
De secrets millénaires et de vies mystérieuses.\
132-
\
133-
Chaque battement de cœur, chaque souffle d’air,\
134-
Est un fragment du tout, un lien universel.\
135-
Nous marchons sur ce fil, entre ombre et lumière,\
136-
Cherchant notre place dans cette grande sphère.\
137-
\
138-
Et dans le silence, là où tout se résout,\
139-
L’écho des mondes résonne, doux mais flou.\
140-
Il nous rappelle que l’immensité est en nous,\
141-
Et que, dans chaque instant, réside l’infini goût.\
142-
\
143-
Le temps s’efface, les frontières se dissolvent,\
144-
Dans ce ballet cosmique où les âmes s’envolent.\
145-
Nous sommes poussière d’étoile, brève, éphémère,\
146-
Mais dans l’éternité, notre essence prospère."
147-
148-
chat_memory.update_memory(user_input, bot_response)
94+
for entry in historique:
95+
print(entry)

0 commit comments

Comments
 (0)