@@ -1,5 +1,5 @@
 import torch
-from transformers import pipeline
+from transformers import pipeline, BartTokenizer, BartForConditionalGeneration
 import logging
 
 # Set device to GPU if available
@@ -13,6 +13,8 @@
 class ChatbotMemory:
     def __init__(self, conv:list = []):
         self.conversation_history = conv
+        self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
+        self.model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')
 
     def update_memory(self, user_input:str, bot_response:str)->None:
         """
@@ -25,8 +27,8 @@ def update_memory(self, user_input:str, bot_response:str)->None:
2527 """
2628 self .conversation_history .append (f"'user': { user_input } , 'bot': { bot_response } " )
2729
28- if memory_counter (self .conversation_history ) > 1000 :
29- self .conversation_history = compressed_memory (self .conversation_history )
30+ if self . memory_counter (self .conversation_history ) > 1000 :
31+ self .conversation_history = self . compressed_memory (self .conversation_history )
3032 logging .info ("Memory compressed." )
3133
3234 if len (self .conversation_history ) > MAX_MEMORY_SIZE :
@@ -43,43 +45,49 @@ def get_memory(self):
4345 """
4446 return self .conversation_history
4547
46- def _get_compressed_memory (sentence :str )-> str :
47- """
48- Compresses the input sentence using the Facebook BART model for summarization.
48+ def _get_compressed_memory (self , text ):
49+
50+ inputs = self .tokenizer .encode ("summarize: " + text , return_tensors = "pt" , max_length = 1024 , truncation = True )
51+ summary_ids = self .model .generate (inputs , max_length = 150 , min_length = 40 , length_penalty = 2.0 , num_beams = 4 , early_stopping = True )
52+ summary = self .tokenizer .decode (summary_ids [0 ], skip_special_tokens = True )
53+ return summary
54+ # def _get_compressed_memory(sentence:str)->str:
55+ # """
56+ # Compresses the input sentence using the Facebook BART model for summarization.
4957
50- Args:
51- sentence: The input sentence to be compressed.
58+ # Args:
59+ # sentence: The input sentence to be compressed.
5260
53- Returns:
54- str: The compressed summary of the input sentence.
55- """
56- summarizer :str = pipeline ("summarization" ,model = "facebook/bart-large-cnn" ,device = device )
57- summary :str = summarizer (sentence , max_length = 50 , min_length = 5 , do_sample = False )
58- return summary [0 ]['summary_text' ]
61+ # Returns:
62+ # str: The compressed summary of the input sentence.
63+ # """
64+ # summarizer:str = pipeline("summarization",model="facebook/bart-large-cnn",device=device)
65+ # summary:str = summarizer(sentence, max_length=50, min_length=5, do_sample=False)
66+ # return summary[0]['summary_text']
5967
60- def compressed_memory (conv_hist :list )-> list :
61- """
62- Compresses each sentence in the conversation history list using summarization.
68+ def compressed_memory (self , conv_hist :list )-> list :
69+ """
70+ Compresses each sentence in the conversation history list using summarization.
6371
64- Args:
65- conv_hist: List of sentences representing the conversation history.
72+ Args:
73+ conv_hist: List of sentences representing the conversation history.
6674
67- Returns:
68- list: List of compressed summaries for each sentence in the conversation history.
69- """
70- # return [_get_compressed_memory(sentence) for sentence in conv_hist]
71- return [_get_compressed_memory (' ' .join (conv_hist [i :i + 5 ])) for i in range (0 , len (conv_hist ), 5 )]
72-
75+ Returns:
76+ list: List of compressed summaries for each sentence in the conversation history.
77+ """
78+ # return [_get_compressed_memory(sentence) for sentence in conv_hist]
79+ return [self ._get_compressed_memory (' ' .join (conv_hist [i :i + 5 ])) for i in range (0 , len (conv_hist ), 5 )]
7380
74- def memory_counter (conv_hist :list )-> int :
75- """
76- Counts the total number of words in the conversation history list.
7781
78- Args:
79- conv_hist: List of sentences representing the conversation history.
+    def memory_counter(self, conv_hist:list)->int:
+        """
+        Counts the total number of words in the conversation history list.
+
+        Args:
+            conv_hist: List of sentences representing the conversation history.
 
-        Returns:
-            int: Total number of words in the conversation history.
-        """
-        st = ''.join(conv_hist)
-        return len(st.split())
+        Returns:
+            int: Total number of words in the conversation history.
+        """
+        st = ''.join(conv_hist)
+        return len(st.split())
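
For reviewers, here is a minimal usage sketch of ChatbotMemory as it stands after this change. It is an assumption for illustration, not part of the commit: it presumes the class is imported from the module this diff touches and that the module-level MAX_MEMORY_SIZE, device, and logging configuration defined earlier in the file are in place; the sample turns are made up.

# Usage sketch (illustrative, not part of the commit).
# from <module under review> import ChatbotMemory   # module name not shown in this diff
import logging

logging.basicConfig(level=logging.INFO)

# Pass a fresh list explicitly: the default `conv:list = []` is a shared
# mutable default and would be reused across instances.
memory = ChatbotMemory(conv=[])

# Each call appends one turn. BART is loaded once, in __init__, so the first
# construction downloads/loads the model. Once the history grows past roughly
# 1,000 words it is summarized in 5-turn chunks; once it exceeds
# MAX_MEMORY_SIZE entries it is trimmed by the branch not shown in this hunk.
memory.update_memory("Hi, who are you?", "I'm a demo chatbot.")
memory.update_memory("What did I just ask?", "You asked who I am.")

for turn in memory.get_memory():
    print(turn)

Compared with the removed helper, which built a summarization pipeline on every call, the new _get_compressed_memory reuses the tokenizer and model loaded in __init__, so BART is loaded once per ChatbotMemory instance rather than once per compression.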