-
Notifications
You must be signed in to change notification settings - Fork 457
Expand file tree
/
Copy pathflickr_30k.py
More file actions
68 lines (58 loc) · 2.42 KB
/
flickr_30k.py
File metadata and controls
68 lines (58 loc) · 2.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from copy import deepcopy
from typing import TYPE_CHECKING
from loguru import logger
from llmcompressor.transformers.finetune.data import TextGenerationDataset
from llmcompressor.typing import Processor
if TYPE_CHECKING:
from llmcompressor.transformers.utils.arg_parser import DatasetArguments
@TextGenerationDataset.register(name="flickr", alias="flickr30k")
class Flickr30K(TextGenerationDataset):
    """
    Text-generation dataset wrapper around the ``lmms-lab/flickr30k``
    image-captioning dataset.

    :param data_args: configuration settings for dataset loading
    :param split: split from dataset to load, for instance `test` or `train[:5%]`
    :param processor: processor or tokenizer to use on dataset
    """

    # Fallback chat template used when the loaded tokenizer does not
    # ship one of its own (see __init__).
    DEFAULT_CHAT_TEMPLATE = (
        "{% for message in messages %}\n"
        "{% if message['role'] == 'user' %}\n"
        "{{ '<|user|>\n' + message['content'] + eos_token }}\n"
        "{% elif message['role'] == 'system' %}\n"
        "{{ '<|system|>\n' + message['content'] + eos_token }}\n"
        "{% elif message['role'] == 'assistant' %}\n"
        "{{ '<|assistant|>\n' + message['content'] + eos_token }}\n"
        "{% endif %}\n"
        "{% if loop.last and add_generation_prompt %}\n"
        "{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
    )

    def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
        # Copy so the caller's argument object is never mutated.
        args_copy = deepcopy(data_args)
        args_copy.dataset = "lmms-lab/flickr30k"
        super().__init__(data_args=args_copy, split=split, processor=processor)

        tokenizer = self.tokenizer
        if tokenizer is None:
            return
        if getattr(tokenizer, "chat_template", None) is not None:
            return
        # The tokenizer is a member of the processor, so assigning here
        # also changes what processor.apply_chat_template produces.
        tokenizer.chat_template = self.DEFAULT_CHAT_TEMPLATE
        logger.warning(
            "tokenizer.chat_template is not set, using default chat template for "
            f"{self.__class__.__name__}"
        )

    def dataset_template(self, sample):
        # Build a single-turn user message asking for a caption of the image.
        user_turn = {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "What does the image show?"},
            ],
        }
        prompt = self.processor.apply_chat_template(
            [user_turn],
            add_generation_prompt=True,
        )
        # Pair the rendered prompt with the sample's image for processing.
        return {"text": prompt, "images": sample["image"]}