@@ -2,22 +2,14 @@
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from transformers import (
-     PixtralImageProcessor,
-     PixtralVisionModel,
- )
- from PIL import Image
- import requests
- import safetensors

from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Config,
    ExLlamaV2Cache,
    ExLlamaV2Tokenizer,
-     ExLlamaV2MultimodalProjector
+     ExLlamaV2MultimodalProjector,
+     ExLlamaV2VisionTower
)

from exllamav2.generator import (
@@ -26,55 +18,42 @@
    ExLlamaV2MMEmbedding
)

- # Unquantized model used for this experiment:
+ from PIL import Image
+ import requests
+
+ # Get an input image
+
+ url = "https://pbs.twimg.com/media/BAeuBsnCIAAUITV.jpg:large"
+ image = Image.open(requests.get(url, stream = True).raw)
+
+ # Unquantized model used for this experiment:
#
# https://huggingface.co/mistral-community/pixtral-12b/

model_directory = "/mnt/str/models/pixtral-12b"
config = ExLlamaV2Config(model_directory)
-
- # PixtralVisionModel expects vision tower keys to be prefixed with "vision_encoder", but the checkpoint prefixes
- # them with "vision_tower". Patch the model implementation to allow the model to load with from_pretrained.
-
- PixtralVisionModel.base_model_prefix = "vision_tower"
+ config.max_seq_len = 32768  # default is 1M
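+ # (32K is more than enough for one image plus a short reply; allocating for the 1M default would be wasteful.)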

# Load multimodal projector

multimodal_projector = ExLlamaV2MultimodalProjector(config)
multimodal_projector.load()

- with torch.inference_mode():
-
-     # Initialize preprocessor, vision model and multimodal projector
+ # Load vision tower and preprocessor

-     image_processor = PixtralImageProcessor.from_pretrained(model_directory, device_map = "cuda:0")
-     vision_model = PixtralVisionModel.from_pretrained(
-         model_directory,
-         device_map = "cuda:0",
-         hidden_act = "silu"
-     )
+ vision_tower = ExLlamaV2VisionTower(config)
+ vision_tower.load(progress = True)

-     # multimodal_projector = ExLlamaV2MultimodalProjector()
-     # safetensors.torch.load_model(
-     #     multimodal_projector,
-     #     os.path.join(model_directory, "model-00001-of-00006.safetensors"),
-     #     strict = False,
-     # )
-     # multimodal_projector.half().to("cuda:0")
+ # Preprocess

-     # Get an input image and process it
+ image_tensor = vision_tower.preprocess(image)
+ image_tensor = image_tensor.cuda()
+ image_size = tuple(image_tensor.shape[1:])
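+ # (image_tensor is expected to be (channels, height, width); image_size keeps just the spatial dims.)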

-     # url = "https://i.imgur.com/JMDz9pC.jpeg"
-     # image = Image.open(requests.get(url, stream = True).raw)
-     image_path = "car2.jpg"
-     image = Image.open(image_path)
+ # Produce embeddings

-     inputs = image_processor(image, return_tensors = "pt")
-     pixel_values = [inputs["pixel_values"][0][0].to("cuda:0", torch.half)]
-     image_features = vision_model(pixel_values)
-     image_features = multimodal_projector.forward(image_features.hidden_states[0].half())
-     image_features = image_features[0]
-     image_size = inputs["image_sizes"][0][0]
+ embeddings = vision_tower.process(image_tensor)
+ embeddings = multimodal_projector.forward(embeddings)[0]
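+ # (After projection there is one embedding vector per image patch, in the same space as the model's token embeddings, so they can be interleaved with img_break / img_end below.)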

# Load EXL2 model

@@ -94,12 +73,12 @@
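[Unchanged lines omitted by the diff; judging from what follows, they load the EXL2 model and tokenizer and define id_break, id_end, features_x and features_y for this image.]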
img_break = model.modules[0].forward(torch.tensor([id_break], dtype = torch.long)).to("cuda:0")
img_end = model.modules[0].forward(torch.tensor([id_end], dtype = torch.long)).to("cuda:0")

- dim = image_features.shape[-1]
- image_features = image_features.view((features_y, features_x, dim))
+ dim = embeddings.shape[-1]
+ embeddings = embeddings.view((features_y, features_x, dim))
break_col = img_break.expand(features_y, -1, -1)
- image_features = torch.cat((image_features, break_col), dim = 1)
- image_features = image_features.view((features_y * (features_x + 1)), dim)
- image_features = torch.cat((image_features, img_end), dim = 0)
+ embeddings = torch.cat((embeddings, break_col), dim = 1)
+ embeddings = embeddings.view((features_y * (features_x + 1)), dim)
+ embeddings = torch.cat((embeddings, img_end), dim = 0)
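+ # Each row of patch embeddings is now followed by the break-token embedding, with the end-token embedding appended at the very end.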

# Create generator

@@ -111,25 +90,25 @@

# Create an MMEmbedding for the image features and a prompt containing the placeholder string

- image_tokens = ExLlamaV2MMEmbedding(
+ image_tokens_a = ExLlamaV2MMEmbedding(
    model = model,
-     embeddings = image_features,
-     text_alias = "{{EMBED_HERE}}"
+     embeddings = embeddings,
+     text_alias = "{{EMBED_A}}"
)

- prompt = "[INST] {{EMBED_HERE}}\nDescribe the image. [/INST]"
+ prompt = "[INST]{{EMBED_A}}\nDescribe the image.[/INST]"
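+ # The generator replaces the text_alias placeholder in the prompt with the image embeddings defined above.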

# Pass embeddings to generator

output = generator.generate(
    prompt = prompt,
-     max_new_tokens = 200,
+     max_new_tokens = 500,
    add_bos = True,
    encode_special_tokens = True,
    decode_special_tokens = True,
    stop_conditions = [tokenizer.eos_token_id],
-     # gen_settings = ExLlamaV2Sampler.Settings.greedy(),
-     embeddings = [image_tokens],
+     gen_settings = ExLlamaV2Sampler.Settings.greedy(),
+     embeddings = [image_tokens_a],
)
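+ # Greedy sampling makes the generated description deterministic for a given image and prompt.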

- print(output)
+ print(output)