Skip to content

Commit e3e5d52

Browse files
committed
mise en place de résumé contextuel
1 parent 974e024 commit e3e5d52

File tree

1 file changed

+43
-35
lines changed

1 file changed

+43
-35
lines changed
Lines changed: 43 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import torch
2-
from transformers import pipeline
2+
from transformers import pipeline, BartTokenizer, BartForConditionalGeneration
33
import logging
44

55
# Set device to GPU if available
@@ -13,6 +13,8 @@
1313
class ChatbotMemory:
1414
def __init__(self, conv: list | None = None):
    """Initialize the chatbot memory and its BART summarization model.

    Args:
        conv: Optional initial conversation history. When omitted, a
            fresh list is created for this instance.
    """
    # The previous signature used `conv: list = []`. A mutable default is
    # evaluated once at definition time, so every instance created without
    # an argument would alias — and mutate — the *same* list. Use None as
    # the sentinel and build a new list per instance instead.
    self.conversation_history = [] if conv is None else conv
    # Summarizer used by _get_compressed_memory / compressed_memory.
    self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
    self.model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')
1618

1719
def update_memory(self, user_input:str, bot_response:str)->None:
1820
"""
@@ -25,8 +27,8 @@ def update_memory(self, user_input:str, bot_response:str)->None:
2527
"""
2628
self.conversation_history.append(f"'user': {user_input}, 'bot': {bot_response}")
2729

28-
if memory_counter(self.conversation_history) > 1000:
29-
self.conversation_history = compressed_memory(self.conversation_history)
30+
if self.memory_counter(self.conversation_history) > 1000:
31+
self.conversation_history = self.compressed_memory(self.conversation_history)
3032
logging.info("Memory compressed.")
3133

3234
if len(self.conversation_history) > MAX_MEMORY_SIZE:
@@ -43,43 +45,49 @@ def get_memory(self):
4345
"""
4446
return self.conversation_history
4547

46-
def _get_compressed_memory(sentence:str)->str:
47-
"""
48-
Compresses the input sentence using the Facebook BART model for summarization.
48+
def _get_compressed_memory(self, text: str) -> str:
    """Summarize a chunk of conversation text with the instance's BART model.

    Args:
        text: The conversation text to compress.

    Returns:
        str: The generated summary, with special tokens stripped.
    """
    # NOTE(review): the previous code prepended "summarize: " to the input.
    # That task prefix is a T5 convention; facebook/bart-large-cnn expects
    # the raw document, so the prefix is dropped here.
    inputs = self.tokenizer.encode(
        text, return_tensors="pt", max_length=1024, truncation=True
    )
    # Beam search with a length penalty, bounded to 40–150 tokens — the
    # standard generation settings for the BART-CNN checkpoint.
    summary_ids = self.model.generate(
        inputs,
        max_length=150,
        min_length=40,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True,
    )
    return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
5967

60-
def compressed_memory(conv_hist:list)->list:
61-
"""
62-
Compresses each sentence in the conversation history list using summarization.
68+
def compressed_memory(self, conv_hist:list)->list:
69+
"""
70+
Compresses each sentence in the conversation history list using summarization.
6371
64-
Args:
65-
conv_hist: List of sentences representing the conversation history.
72+
Args:
73+
conv_hist: List of sentences representing the conversation history.
6674
67-
Returns:
68-
list: List of compressed summaries for each sentence in the conversation history.
69-
"""
70-
# return [_get_compressed_memory(sentence) for sentence in conv_hist]
71-
return [_get_compressed_memory(' '.join(conv_hist[i:i+5])) for i in range(0, len(conv_hist), 5)]
72-
75+
Returns:
76+
list: List of compressed summaries for each sentence in the conversation history.
77+
"""
78+
# return [_get_compressed_memory(sentence) for sentence in conv_hist]
79+
return [self._get_compressed_memory(' '.join(conv_hist[i:i+5])) for i in range(0, len(conv_hist), 5)]
7380

74-
def memory_counter(self, conv_hist: list) -> int:
    """Count the total number of words in the conversation history.

    Args:
        conv_hist: List of sentences representing the conversation history.

    Returns:
        int: Total number of whitespace-separated words across all entries.
    """
    # Fix 1: `self` was missing even though update_memory invokes this as
    # `self.memory_counter(...)` — the bound call passed the instance into
    # `conv_hist` and raised a TypeError.
    # Fix 2: join with a space; ''.join would glue the last word of one
    # entry to the first word of the next and undercount.
    return len(' '.join(conv_hist).split())

0 commit comments

Comments
 (0)