-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathprompt_clip.py
More file actions
191 lines (157 loc) · 8.29 KB
/
prompt_clip.py
File metadata and controls
191 lines (157 loc) · 8.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union
import torch
import torch.nn as nn
import transformers
from transformers import CLIPTextModel, CLIPTokenizer, CLIPModel, CLIPVisionModel
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers.models.clip.modeling_clip import CLIPVisionTransformer, CLIPEncoder
from transformers.models.clip.configuration_clip import CLIPVisionConfig, CLIPConfig
from transformers.utils import replace_return_docstrings
class CLIPEncoderWithPrompt(CLIPEncoder):
    """
    CLIP transformer encoder with deep visual prompt tuning (VPT-Deep style).

    Before each of the first ``deep_prompt_embeddings.shape[0]`` layers, the
    prompt tokens carried over from the previous layer are replaced with a
    fresh set of learnable, projected prompt embeddings inserted between the
    CLS token (position 0) and the remaining patch tokens, keeping the
    sequence length constant.

    Args:
        config: CLIPConfig-like vision configuration; must expose
            ``image_size``, ``patch_size``, ``hidden_size``,
            ``num_hidden_layers`` and ``layer_norm_eps``.
        prompt_length: number of prompt tokens inserted per layer.
    """

    def __init__(self, config: CLIPConfig, prompt_length: int):
        super().__init__(config)
        self.prompt_length = prompt_length
        embed_dim = config.hidden_size
        # One prompt set per encoder layer.
        total_d_layer = config.num_hidden_layers
        self._init_prompt(prompt_length, embed_dim, total_d_layer)

    def _init_prompt(self, num_tokens: int, prompt_dim: int, total_d_layer: int) -> None:
        """Create the deep prompt parameters and their projection/norm/dropout.

        The original positive and negative branches were byte-identical apart
        from ``abs()``, so they are merged: a negative ``total_d_layer`` is
        treated as a depth of ``abs(total_d_layer)``.  A depth of 0 creates
        no prompt parameters at all (``forward`` would then fail; callers
        only construct this class with a non-zero depth).
        """
        import math
        num_patches = (self.config.image_size // self.config.patch_size) ** 2
        # Xavier-uniform bound, as in the reference VPT implementation.
        val = math.sqrt(6. / float(3 * num_patches + prompt_dim))
        depth = abs(total_d_layer)
        if depth > 0:
            # Per-layer prompts: (depth, num_tokens, prompt_dim).
            self.deep_prompt_embeddings = nn.Parameter(torch.zeros(depth, num_tokens, prompt_dim))
            nn.init.uniform_(self.deep_prompt_embeddings.data, -val, val)
            self.prompt_proj = nn.Linear(prompt_dim, prompt_dim)
            nn.init.kaiming_normal_(self.prompt_proj.weight, a=0, mode='fan_out')
            self.prompt_norm = nn.LayerNorm(prompt_dim, eps=self.config.layer_norm_eps)
            self.prompt_dropout = nn.Dropout(0.1)

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Embedded inputs; position 0 is treated as the CLS token and the
                rest as patch tokens.
            attention_mask (`torch.Tensor`, *optional*):
                Mask to avoid attending to padding tokens (1 = keep, 0 = mask).
            causal_attention_mask (`torch.Tensor`, *optional*):
                Causal mask, forwarded unchanged to each layer.
            output_attentions (`bool`, *optional*):
                Whether to return per-layer attention tensors.
            output_hidden_states (`bool`, *optional*):
                Whether to return all intermediate hidden states.
            return_dict (`bool`, *optional*):
                Whether to return a [`~utils.ModelOutput`] instead of a tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        B, HW, C = hidden_states.shape
        # Number of non-CLS tokens in the original input; used to strip
        # prompts back out while keeping the CLS token.
        HW = HW - 1

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            # FIX: the original used `<=`, which indexes one past the end of
            # deep_prompt_embeddings whenever fewer prompt sets than layers
            # are configured (IndexError at idx == shape[0]).
            if idx < self.deep_prompt_embeddings.shape[0]:
                # Rebuild the sequence as [CLS | this layer's prompts | rest],
                # overwriting the previous layer's prompt slots so the
                # sequence length stays constant.
                # NOTE(review): at idx == 0 there are no prompt slots yet, so
                # the first `prompt_length` patch tokens are dropped — this
                # matches the original code; confirm it is intentional.
                deep_prompt_emb = self.prompt_dropout(
                    self.prompt_proj(self.deep_prompt_embeddings[idx]).expand(B, -1, -1)
                )
                hidden_states = torch.cat((
                    hidden_states[:, 0, :].unsqueeze(1),
                    deep_prompt_emb,
                    hidden_states[:, 1 + self.prompt_length:, :]
                ), dim=1)
            else:
                # No prompts for this layer: keep CLS plus the trailing HW tokens.
                hidden_states = torch.cat((
                    hidden_states[:, 0, :].unsqueeze(1),
                    hidden_states[:, -HW:, :]
                ), dim=1)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    output_attentions,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    output_attentions=output_attentions,
                )
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        # Final output: CLS plus the trailing HW tokens, normalized with the
        # prompt LayerNorm (deviates from the stock CLIPEncoder, which
        # returns hidden states un-normalized).
        hidden_states = torch.cat((
            hidden_states[:, 0, :].unsqueeze(1),
            hidden_states[:, -HW:, :]
        ), dim=1)
        hidden_states = self.prompt_norm(hidden_states)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )
class CLIPVisionTransformerWithPrompt(CLIPVisionTransformer):
    """CLIP vision transformer whose encoder can carry deep visual prompts.

    A non-zero ``prompt_length`` installs the prompt-aware encoder; zero
    keeps the stock ``CLIPEncoder``.
    """

    def __init__(self, config: CLIPVisionConfig, prompt_length):
        super().__init__(config)
        self.encoder = (
            CLIPEncoderWithPrompt(config, prompt_length)
            if prompt_length != 0
            else CLIPEncoder(config)
        )
class CLIPVisionModelWithPrompt(CLIPVisionModel):
    """CLIP vision model that installs a prompt-enabled vision backbone.

    With ``prompt_length == 0`` this degenerates to a plain
    ``CLIPVisionModel`` (stock ``CLIPVisionTransformer`` backbone).
    """

    def __init__(self, config: CLIPVisionConfig, prompt_length):
        super().__init__(config)
        if prompt_length == 0:
            self.vision_model = CLIPVisionTransformer(config)
        else:
            self.vision_model = CLIPVisionTransformerWithPrompt(config, prompt_length)
        # Re-run HF weight initialization / final processing over the
        # freshly swapped-in submodule.
        self.post_init()
#clip = CLIPVisionModelWithPrompt.from_pretrained("../clip-vit-large-patch14/", prompt_length=50)
#print(clip)