@@ -51,6 +51,7 @@ def __init__(
         tokenizer=None,
         max_length=150,
         model_type: str = "causal_lm",
+        device=None,
     ):
         if expert_weights is None:
             expert_weights = [-0.5, 0.5]
@@ -94,6 +95,13 @@ def __init__(
         )
         self.content_feature = "comment_text"
 
+        if device is not None:
+            self.device = torch.device(device)
+        else:
+            self.device = torch.device(
+                "cuda" if torch.cuda.is_available() else "cpu"
+            )
+
     def load_models(self, experts: list, expert_weights: list = None):
         """Load expert models."""
         if expert_weights is not None:
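
The device resolution added above follows the usual PyTorch idiom. A minimal standalone sketch of the same logic (resolve_device is an illustrative name, not part of this patch):

import torch

def resolve_device(device=None) -> torch.device:
    # torch.device() wraps both strings ("cuda:0") and existing torch.device
    # objects, so callers may pass either; None falls back to auto-detection.
    if device is not None:
        return torch.device(device)
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

assert resolve_device("cpu") == torch.device("cpu")
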
@@ -102,7 +110,9 @@ def load_models(self, experts: list, expert_weights: list = None):
         for expert in experts:
             if isinstance(expert, str):
                 expert = BartForConditionalGeneration.from_pretrained(
-                    expert, forced_bos_token_id=self.tokenizer.bos_token_id
+                    expert,
+                    forced_bos_token_id=self.tokenizer.bos_token_id,
+                    device_map="auto",
                 )
             expert_models.append(expert)
         self.experts = expert_models
@@ -200,15 +210,21 @@ def train_models(
 
         if model_type is None:
             gminus = BartForConditionalGeneration.from_pretrained(
-                base_model, forced_bos_token_id=self.tokenizer.bos_token_id
+                base_model,
+                forced_bos_token_id=self.tokenizer.bos_token_id,
+                device_map="auto",
             )
         elif model_type == "causal_lm":
             gminus = AutoModelForCausalLM.from_pretrained(
-                base_model, forced_bos_token_id=self.tokenizer.bos_token_id
+                base_model,
+                forced_bos_token_id=self.tokenizer.bos_token_id,
+                device_map="auto",
             )
         elif model_type == "seq2seq_lm":
             gminus = AutoModelForSeq2SeqLM.from_pretrained(
-                base_model, forced_bos_token_id=self.tokenizer.bos_token_id
+                base_model,
+                forced_bos_token_id=self.tokenizer.bos_token_id,
+                device_map="auto",
             )
         else:
             raise Exception(f"unsupported model type {model_type}")
@@ -254,15 +270,21 @@ def train_models(
 
         if model_type is None:
             gplus = BartForConditionalGeneration.from_pretrained(
-                base_model, forced_bos_token_id=self.tokenizer.bos_token_id
+                base_model,
+                forced_bos_token_id=self.tokenizer.bos_token_id,
+                device_map="auto",
             )
         elif model_type == "causal_lm":
             gplus = AutoModelForCausalLM.from_pretrained(
-                base_model, forced_bos_token_id=self.tokenizer.bos_token_id
+                base_model,
+                forced_bos_token_id=self.tokenizer.bos_token_id,
+                device_map="auto",
             )
         elif model_type == "seq2seq_lm":
             gplus = AutoModelForSeq2SeqLM.from_pretrained(
-                base_model, forced_bos_token_id=self.tokenizer.bos_token_id
+                base_model,
+                forced_bos_token_id=self.tokenizer.bos_token_id,
+                device_map="auto",
             )
         else:
             raise Exception(f"unsupported model type {model_type}")
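
Note that device_map="auto" hands weight placement to Hugging Face Accelerate, so the accelerate package must be installed, and a model dispatched this way should generally not be moved again with .to(). A minimal sketch of the loading pattern, with an illustrative checkpoint:

from transformers import AutoModelForCausalLM

# requires: pip install accelerate
model = AutoModelForCausalLM.from_pretrained(
    "gpt2",             # illustrative checkpoint, not from this patch
    device_map="auto",  # let accelerate pick the device placement
)
print(model.hf_device_map)  # e.g. {'': 0} on a single-GPU machine
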
@@ -380,6 +402,7 @@ def rephrase(
                     model=expert,
                     tokenizer=self.tokenizer,
                     top_k=self.tokenizer.vocab_size,
+                    device=self.device,
                 )
             )
         for idx in range(len(masked_sentence_tokens)):
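
Passing device= when constructing a transformers pipeline makes it move both the model and each batch of input tensors. A sketch of the fill-mask pattern used above (the checkpoint is illustrative; the patch reuses the already-loaded experts):

import torch
from transformers import pipeline

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
fmp = pipeline(
    "fill-mask",
    model="facebook/bart-base",  # illustrative mask-filling checkpoint
    device=device,               # int, str, and torch.device are all accepted
)
print(fmp("The movie was <mask>.")[0]["token_str"])
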
@@ -477,9 +500,10 @@ def compute_mask_logits(
         self, model, sequence, verbose: bool = False, mask: bool = True
     ):
         """Compute mask logits."""
+        model.to(self.device)
         if verbose:
             print(f"input sequence: {sequence}")
-        subseq_ids = self.tokenizer(sequence, return_tensors="pt")
+        subseq_ids = self.tokenizer(sequence, return_tensors="pt").to(self.device)
         if verbose:
             raw_outputs = model.generate(**subseq_ids)
             print(sequence)
@@ -502,9 +526,12 @@ def compute_mask_logits_multiple(
         self, model, sequences, verbose: bool = False, mask: bool = True
     ):
         """Compute mask logits multiple."""
+        model.to(self.device)
         if verbose:
             print(f"input sequences: {sequences}")
-        subseq_ids = self.tokenizer(sequences, return_tensors="pt", padding=True)
+        subseq_ids = self.tokenizer(
+            sequences, return_tensors="pt", padding=True
+        ).to(self.device)
         if verbose:
             raw_outputs = model.generate(**subseq_ids)
             print(sequences)
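
The tokenizer returns a BatchEncoding, whose .to() moves every contained tensor (input_ids, attention_mask) in one call, keeping model and inputs on the same device. A self-contained sketch with an illustrative checkpoint:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tok = AutoTokenizer.from_pretrained("gpt2")
tok.pad_token = tok.eos_token  # gpt2 ships without a pad token
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

batch = tok(
    ["a first sequence", "a longer second sequence"],
    return_tensors="pt", padding=True,
).to(device)  # BatchEncoding.to() relocates all tensors at once
logits = model(**batch).logits  # lives on `device`, like its inputs
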
@@ -554,6 +581,7 @@ def score(
             model=model,
             tokenizer=self.tokenizer,
             top_k=10,
+            device=self.device,
         )
         for masked_sentence in masked_sentences:
             # approximated probabilities for top_k tokens
@@ -567,7 +595,9 @@ def score(
         js_distances = []
         for distr_pair in distr_pairs:
             js_distance = jensenshannon(
-                distr_pair[0], distr_pair[1], axis=1
+                distr_pair[0].detach().cpu().numpy(),
+                distr_pair[1].detach().cpu().numpy(),
+                axis=1,
             )
             if normalize:
                 js_distance = js_distance / np.average(js_distance)
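
scipy's jensenshannon operates on NumPy arrays (and its axis= argument needs scipy >= 1.7), so tensors that may sit on a GPU must first be detached and copied to host memory; .detach().cpu().numpy() is the standard idiom. A toy check:

import torch
from scipy.spatial.distance import jensenshannon

p = torch.tensor([[0.1, 0.9], [0.5, 0.5]])
q = torch.tensor([[0.2, 0.8], [0.5, 0.5]])
# axis=1 yields one distance per row, i.e. per probability distribution
d = jensenshannon(p.detach().cpu().numpy(), q.detach().cpu().numpy(), axis=1)
print(d.shape)  # (2,)
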
@@ -653,7 +683,10 @@ def reflect(
             chat_tokenizer.chat_template = chat_template
 
         converse_pipeline = pipeline(
-            "conversational", model=chat_model, tokenizer=chat_tokenizer
+            "conversational",
+            model=chat_model,
+            tokenizer=chat_tokenizer,
+            device=self.device,
         )
 
         for text_id in range(len(texts)):
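
device is a construction-time argument of pipeline(); per-call kwargs such as pad_token_id are treated as generation settings instead. A sketch, assuming a transformers version that still ships the conversational pipeline (it was removed in later releases; the checkpoint is illustrative):

from transformers import Conversation, pipeline

chat = pipeline(
    "conversational",
    model="facebook/blenderbot-400M-distill",  # illustrative chat checkpoint
    device=-1,  # -1 = CPU, 0 = first CUDA device; torch.device works too
)
out = chat(Conversation("Hello!"), pad_token_id=chat.tokenizer.eos_token_id)
print(out.generated_responses[-1])
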
@@ -729,6 +762,7 @@ def reflect(
                 conversation_output = converse_pipeline(
                     formatted_messages,
                     pad_token_id=converse_pipeline.tokenizer.eos_token_id,
+                    # device was set at pipeline construction time
                 )
                 if verbose:
                     print(f"chat conversation:\n{conversation_output}")
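
Taken together, the patch lets a caller pin the whole flow to a single device. A hedged usage sketch, assuming this diff patches the TMaRCo class from trustyai-detoxify (the import path and expert checkpoints are my assumption, not shown in the diff):

import torch
from trustyai.detoxify import TMaRCo  # assumed import path

# device accepts a string, a torch.device, or None (auto-detects CUDA)
marco = TMaRCo(device="cuda:0" if torch.cuda.is_available() else "cpu")
marco.load_models(["trustyai/gminus", "trustyai/gplus"])  # assumed checkpoints
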