Commit 4f73af6

📝
1 parent ca3ebf9 commit 4f73af6

docs/source/learners/llm.rst

Lines changed: 59 additions & 59 deletions
@@ -192,65 +192,65 @@ For this, you can extend the ``AutoLLM`` class and implement the required methods:

::

    from typing import List

    import torch

    from ontolearner import AutoLLM


    class MistralLLM(AutoLLM):

        def load(self, model_id: str) -> None:
            from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
            # Mistral3ForConditionalGeneration ships with ``transformers``.
            from transformers import Mistral3ForConditionalGeneration

            self.tokenizer = MistralTokenizer.from_hf_hub(model_id)

            # Spread the weights across all visible GPUs unless running on CPU.
            device_map = "cpu" if self.device == "cpu" else "balanced"
            self.model = Mistral3ForConditionalGeneration.from_pretrained(
                model_id,
                device_map=device_map,
                torch_dtype=torch.bfloat16,
                token=self.token,
            )

            # MistralTokenizer defines no pad token of its own; fall back to EOS.
            if not hasattr(self.tokenizer, "pad_token_id") or self.tokenizer.pad_token_id is None:
                self.tokenizer.pad_token_id = self.model.generation_config.eos_token_id

            self.label_mapper.fit()

        def generate(self, inputs: List[str], max_new_tokens: int = 50) -> List[str]:
            from mistral_common.protocol.instruct.request import ChatCompletionRequest

            # Tokenize every prompt through the chat template.
            tokenized_list = []
            for prompt in inputs:
                messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
                tokenized = self.tokenizer.encode_chat_completion(ChatCompletionRequest(messages=messages))
                tokenized_list.append(tokenized.tokens)

            # Right-pad every prompt to the longest one and build attention masks.
            max_len = max(len(tokens) for tokens in tokenized_list)
            input_ids, attention_masks = [], []
            for tokens in tokenized_list:
                pad_length = max_len - len(tokens)
                input_ids.append(tokens + [self.tokenizer.pad_token_id] * pad_length)
                attention_masks.append([1] * len(tokens) + [0] * pad_length)

            input_ids = torch.tensor(input_ids).to(self.model.device)
            attention_masks = torch.tensor(attention_masks).to(self.model.device)

            outputs = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_masks,
                eos_token_id=self.model.generation_config.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id,
                max_new_tokens=max_new_tokens,
            )

            # Every output row begins with the padded prompt (``max_len`` tokens),
            # so slice it off and decode only the newly generated tokens.
            decoded_outputs = []
            for tokens in outputs:
                output_text = self.tokenizer.decode(tokens[max_len:].tolist())
                decoded_outputs.append(output_text)

            return self.label_mapper.predict(decoded_outputs)

Once your custom class is defined, you can pass it into ``AutoLLMLearner``:
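
For example (a minimal sketch: the ``llm`` keyword and the bare
``MistralLLM()`` instantiation are assumptions for illustration, so consult
the ``AutoLLMLearner`` reference for the exact constructor arguments)::

    from ontolearner import AutoLLMLearner

    # Assumption: the learner accepts the custom AutoLLM instance directly;
    # any prompting or label-mapping configuration is omitted here.
    learner = AutoLLMLearner(llm=MistralLLM())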
