forked from ModelTC/LightCompress
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path token_reduction_module.py
More file actions
37 lines (32 loc) · 1.28 KB
/
token_reduction_module.py
File metadata and controls
37 lines (32 loc) · 1.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import time
import torch
from loguru import logger
class TokenReductionModule:
    """Shared base for token-reduction modules.

    Stores the compression ``config``, the ``model`` and its ``blocks``.
    It then copies the vision-token settings that the model exposes through
    ``model.pruning_config`` into ``self.special_config``.
    """

    def __init__(self, config, model, blocks):
        """Keep references to the inputs and derive ``special_config``.

        Args:
            config: dict-like compression config (must support ``.get``).
                Its optional ``'special'`` entry becomes
                ``self.special_config``.
            model: object exposing a ``pruning_config`` mapping.
            blocks: stored as-is on ``self.blocks``; not read by this class.
        """
        self.config = config
        self.model = model
        self.blocks = blocks
        self.set_sparse_config()

    def set_sparse_config(self):
        """Populate ``self.special_config`` from ``self.model.pruning_config``.

        When ``config`` has a ``'special'`` key, ``self.special_config`` is
        that very object. The keys written here are then also visible through
        ``self.config``. Otherwise a new dict is used that is not attached
        to ``config``. Existing values for the keys below are overwritten.

        Keys written:
            is_video_model: copied from ``pruning_config['is_video_model']``.
            vision_token_index, vision_token_length: for video models they
                come from the ``video_*`` entries, which are required
                (``KeyError`` if missing). Otherwise they come from the
                ``image_*`` entries, which are optional (``None`` if missing).
            vision_token_start_index: set for non-video models only, taken
                from the same-named pruning_config entry or ``None``.
        """
        self.special_config = self.config.get('special', {})
        special = self.special_config
        pruning_cfg = self.model.pruning_config

        is_video = pruning_cfg['is_video_model']
        special['is_video_model'] = is_video

        # A "vision" token can be an image or a video token. Map the
        # model-specific source keys onto the generic names stored in
        # special_config. Each lookup is immediately followed by its
        # assignment, in the listed order.
        if is_video:
            # NOTE(review): 'vision_token_start_index' is not written on this
            # path; readers of special_config must tolerate its absence for
            # video models — confirm this is intended.
            for target, source in (
                ('vision_token_index', 'video_token_index'),
                ('vision_token_length', 'video_token_length'),
            ):
                special[target] = pruning_cfg[source]
        else:
            for target, source in (
                ('vision_token_index', 'image_token_index'),
                ('vision_token_start_index', 'vision_token_start_index'),
                ('vision_token_length', 'image_token_length'),
            ):
                special[target] = pruning_cfg.get(source, None)

    def register_reduction_modules(self):
        """Register reduction hooks; does nothing in this class.

        Presumably overridden by concrete reduction methods — confirm.
        """