For this, you can extend the ``AutoLLM`` class and implement the required ``load`` and ``generate`` methods.

::

    from ontolearner import AutoLLM
    from typing import List
    import torch

    class MistralLLM(AutoLLM):

        def load(self, model_id: str) -> None:
            from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
            from transformers import Mistral3ForConditionalGeneration

            self.tokenizer = MistralTokenizer.from_hf_hub(model_id)

            device_map = "cpu" if self.device == "cpu" else "balanced"
            self.model = Mistral3ForConditionalGeneration.from_pretrained(
                model_id,
                device_map=device_map,
                torch_dtype=torch.bfloat16,
                token=self.token
            )

            # MistralTokenizer defines no pad token of its own, so fall back to EOS.
            if not hasattr(self.tokenizer, "pad_token_id") or self.tokenizer.pad_token_id is None:
                self.tokenizer.pad_token_id = self.model.generation_config.eos_token_id

            self.label_mapper.fit()

        def generate(self, inputs: List[str], max_new_tokens: int = 50) -> List[str]:
            from mistral_common.protocol.instruct.request import ChatCompletionRequest

            # Tokenize each prompt through the chat template.
            tokenized_list = []
            for prompt in inputs:
                messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
                tokenized = self.tokenizer.encode_chat_completion(ChatCompletionRequest(messages=messages))
                tokenized_list.append(tokenized.tokens)

            # Left-pad to a common length and build attention masks; decoder-only
            # models expect padding on the left when generating in batches.
            max_len = max(len(tokens) for tokens in tokenized_list)
            input_ids, attention_masks = [], []
            for tokens in tokenized_list:
                pad_length = max_len - len(tokens)
                input_ids.append([self.tokenizer.pad_token_id] * pad_length + tokens)
                attention_masks.append([0] * pad_length + [1] * len(tokens))

            input_ids = torch.tensor(input_ids).to(self.model.device)
            attention_masks = torch.tensor(attention_masks).to(self.model.device)

            outputs = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_masks,
                eos_token_id=self.model.generation_config.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id,
                max_new_tokens=max_new_tokens,
            )

            # Every padded prompt is exactly max_len tokens long, so the
            # completion starts at that offset in each returned sequence.
            decoded_outputs = []
            for tokens in outputs:
                output_text = self.tokenizer.decode(tokens[max_len:].tolist())
                decoded_outputs.append(output_text)

            return self.label_mapper.predict(decoded_outputs)
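
For a quick sanity check, the wrapper above can be exercised on its own before wiring it into a learner. The snippet below is only a minimal sketch: the model identifier is illustrative, and it assumes the ``AutoLLM`` base class can be instantiated without additional required arguments::

    # Minimal smoke test for the custom wrapper defined above.
    # Assumptions: MistralLLM() needs no constructor arguments, and the
    # model id is illustrative; substitute any compatible checkpoint.
    llm = MistralLLM()
    llm.load("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
    answers = llm.generate(["Is a car a type of vehicle? Answer yes or no."], max_new_tokens=10)
    print(answers)  # labels produced by the label mapper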

Once your custom class is defined, you can pass it into ``AutoLLMLearner``: