
Commit 7c876ef

Update Pixtral experiment
1 parent 193a6b2 commit 7c876ef

File tree

2 files changed: +35 -56 lines

File renamed without changes.

experimental/multimodal_pixtral_hf.py

Lines changed: 35 additions & 56 deletions
@@ -2,22 +2,14 @@
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

 import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from transformers import (
-    PixtralImageProcessor,
-    PixtralVisionModel,
-)
-from PIL import Image
-import requests
-import safetensors

 from exllamav2 import (
     ExLlamaV2,
     ExLlamaV2Config,
     ExLlamaV2Cache,
     ExLlamaV2Tokenizer,
-    ExLlamaV2MultimodalProjector
+    ExLlamaV2MultimodalProjector,
+    ExLlamaV2VisionTower
 )

 from exllamav2.generator import (
@@ -26,55 +18,42 @@
     ExLlamaV2MMEmbedding
 )

-# Unquantized model used for this experiment:
+from PIL import Image
+import requests
+
+# Get an input image
+
+url = "https://pbs.twimg.com/media/BAeuBsnCIAAUITV.jpg:large"
+image = Image.open(requests.get(url, stream = True).raw)
+
+# Unquantized model used for experiment:
 #
 # https://huggingface.co/mistral-community/pixtral-12b/

 model_directory = "/mnt/str/models/pixtral-12b"
 config = ExLlamaV2Config(model_directory)
-
-# PixtralVisionModel expects vision tower keys to be prefixed with "vision_encoder", but the checkpoint prefixes
-# them with "vision_tower". Patch the model implementation to allow the model to load with from_pretrained.
-
-PixtralVisionModel.base_model_prefix = "vision_tower"
+config.max_seq_len = 32768 # default is 1M

 # Load multimodal projector

 multimodal_projector = ExLlamaV2MultimodalProjector(config)
 multimodal_projector.load()

-with torch.inference_mode():
-
-    # Initialize preprocessor, vision model and multimodal projector
+# Load vision tower and preprocessor

-    image_processor = PixtralImageProcessor.from_pretrained(model_directory, device_map = "cuda:0")
-    vision_model = PixtralVisionModel.from_pretrained(
-        model_directory,
-        device_map = "cuda:0",
-        hidden_act = "silu"
-    )
+vision_tower = ExLlamaV2VisionTower(config)
+vision_tower.load(progress = True)

-    # multimodal_projector = ExLlamaV2MultimodalProjector()
-    # safetensors.torch.load_model(
-    #     multimodal_projector,
-    #     os.path.join(model_directory, "model-00001-of-00006.safetensors"),
-    #     strict = False,
-    # )
-    # multimodal_projector.half().to("cuda:0")
+# Preprocess

-    # Get an input image and process it
+image_tensor = vision_tower.preprocess(image)
+image_tensor = image_tensor.cuda()
+image_size = tuple(image_tensor.shape[1:])

-    # url = "https://i.imgur.com/JMDz9pC.jpeg"
-    # image = Image.open(requests.get(url, stream = True).raw)
-    image_path = "car2.jpg"
-    image = Image.open(image_path)
+# Produce embeddings

-    inputs = image_processor(image, return_tensors = "pt")
-    pixel_values = [inputs["pixel_values"][0][0].to("cuda:0", torch.half)]
-    image_features = vision_model(pixel_values)
-    image_features = multimodal_projector.forward(image_features.hidden_states[0].half())
-    image_features = image_features[0]
-    image_size = inputs["image_sizes"][0][0]
+embeddings = vision_tower.process(image_tensor)
+embeddings = multimodal_projector.forward(embeddings)[0]

 # Load EXL2 model

@@ -94,12 +73,12 @@
 img_break = model.modules[0].forward(torch.tensor([id_break], dtype = torch.long)).to("cuda:0")
 img_end = model.modules[0].forward(torch.tensor([id_end], dtype = torch.long)).to("cuda:0")

-dim = image_features.shape[-1]
-image_features = image_features.view((features_y, features_x, dim))
+dim = embeddings.shape[-1]
+embeddings = embeddings.view((features_y, features_x, dim))
 break_col = img_break.expand(features_y, -1, -1)
-image_features = torch.cat((image_features, break_col), dim = 1)
-image_features = image_features.view((features_y * (features_x + 1)), dim)
-image_features = torch.cat((image_features, img_end), dim = 0)
+embeddings = torch.cat((embeddings, break_col), dim = 1)
+embeddings = embeddings.view((features_y * (features_x + 1)), dim)
+embeddings = torch.cat((embeddings, img_end), dim = 0)

 # Create generator

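Note on the hunk above: the tensor (renamed from image_features to embeddings) is viewed as a features_y x features_x grid of patch embeddings, an [IMG_BREAK] embedding is appended to the end of each row, and a single [IMG_END] embedding closes the sequence, matching Pixtral's image-token layout. A shape-only sketch of that bookkeeping, with made-up sizes standing in for the values the script computes:

import torch

features_y, features_x, dim = 3, 4, 8                   # hypothetical grid size and hidden dim
embeddings = torch.randn(features_y * features_x, dim)  # stand-in for projected patch embeddings
img_break = torch.randn(1, 1, dim)                      # stand-in for the [IMG_BREAK] embedding
img_end = torch.randn(1, dim)                           # stand-in for the [IMG_END] embedding

grid = embeddings.view(features_y, features_x, dim)
break_col = img_break.expand(features_y, -1, -1)        # one break embedding per grid row
grid = torch.cat((grid, break_col), dim = 1)            # shape: (features_y, features_x + 1, dim)
seq = grid.view(features_y * (features_x + 1), dim)     # flatten rows, break after each row
seq = torch.cat((seq, img_end), dim = 0)                # append the end-of-image embedding

assert seq.shape == (features_y * (features_x + 1) + 1, dim)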
@@ -111,25 +90,25 @@

 # Create an MMEmbedding for the image features and a prompt containing the placeholder string

-image_tokens = ExLlamaV2MMEmbedding(
+image_tokens_a = ExLlamaV2MMEmbedding(
     model = model,
-    embeddings = image_features,
-    text_alias = "{{EMBED_HERE}}"
+    embeddings = embeddings,
+    text_alias = "{{EMBED_A}}"
 )

-prompt = "[INST] {{EMBED_HERE}}\nDescribe the image. [/INST]"
+prompt = "[INST]{{EMBED_A}}\nDescribe the image.[/INST]"

 # Pass embeddings to generator

 output = generator.generate(
     prompt = prompt,
-    max_new_tokens = 200,
+    max_new_tokens = 500,
     add_bos = True,
     encode_special_tokens = True,
     decode_special_tokens = True,
     stop_conditions = [tokenizer.eos_token_id],
-    # gen_settings = ExLlamaV2Sampler.Settings.greedy(),
-    embeddings = [image_tokens],
+    gen_settings = ExLlamaV2Sampler.Settings.greedy(),
+    embeddings = [image_tokens_a],
 )

-print(output)
+print(output)
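The commit leaves the surrounding setup (the "# Load EXL2 model" and "# Create generator" steps) untouched, so those lines do not appear in the diff. For orientation, a minimal sketch of how that setup typically looks with exllamav2's dynamic generator; the exact lines in the script may differ, so treat the calls below as assumptions rather than the file's contents:

from exllamav2 import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Tokenizer
from exllamav2.generator import ExLlamaV2DynamicGenerator

# Sketch only: assumes the standard exllamav2 loading path, not necessarily
# the exact code in experimental/multimodal_pixtral_hf.py.
model = ExLlamaV2(config)
cache = ExLlamaV2Cache(model, lazy = True)    # lazy cache so layers are allocated during autosplit
model.load_autosplit(cache, progress = True)  # split weights across available GPUs
tokenizer = ExLlamaV2Tokenizer(config)

generator = ExLlamaV2DynamicGenerator(
    model = model,
    cache = cache,
    tokenizer = tokenizer,
)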
