(draft) colpali WIP #394
The first new file adds the image-side model, `ColpaliImageModel`:

```python
import contextlib
from typing import Any, Dict, Iterable, List

import numpy as np
from PIL import Image

from fastembed.common import ImageInput
from fastembed.common.onnx_model import OnnxOutputContext
from fastembed.image.onnx_embedding import OnnxImageEmbedding

supported_onnx_models = [
    {
        "model": "akshayballal/colpali-v1.2-merged",
        # 1030 vectors per image (1024 image-placeholder tokens plus the
        # 6 text-prompt ids below), each 128-dimensional.
        "dim": (1030, 128),
        "description": "Image embeddings, Unimodal (image), Aligned to text latent space via PaliGemma-3B, 512 patches max, 2024.",
        "license": "mit",
        "size_in_GB": 6.08,
        "sources": {
            "hf": "akshayballal/colpali-v1.2-merged-onnx",
        },
        "additional_files": ["model.onnx_data"],
        "model_file": "model.onnx",
    }
]


class ColpaliImageModel(OnnxImageEmbedding):
    # Fixed prompt fed alongside every image: 1024 image-placeholder tokens
    # (id 257152) followed by 6 text-prompt ids, 1030 positions in total.
    empty_text_placeholder = np.array([257152] * 1024 + [2, 50721, 573, 2416, 235265, 108])
    # Uniform attention mask covering all 1030 positions.
    even_attention_mask = np.array([1] * 1030)

    def _preprocess_onnx_input(
        self, onnx_input: Dict[str, np.ndarray], **kwargs
    ) -> Dict[str, np.ndarray]:
        # Replace whatever input_ids the base class produced with one copy of
        # the fixed placeholder prompt per image in the batch.
        onnx_input["input_ids"] = np.array(
            [self.empty_text_placeholder for _ in onnx_input["input_ids"]]
        )
        onnx_input["attention_mask"] = np.array(
            [self.even_attention_mask for _ in onnx_input["input_ids"]]
        )
        return onnx_input

    @classmethod
    def list_supported_models(cls) -> List[Dict[str, Any]]:
        """
        Lists the supported models.

        Returns:
            List[Dict[str, Any]]: A list of dictionaries containing the model information.
        """
        return supported_onnx_models

    def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]:
        return output.model_output.astype(np.float32)

    def onnx_embed(self, images: List[ImageInput], **kwargs) -> OnnxOutputContext:
        with contextlib.ExitStack() as stack:
            # Register opened files with the stack so they are closed on exit;
            # images that are already PIL objects pass through untouched.
            image_files = [
                image if isinstance(image, Image.Image) else stack.enter_context(Image.open(image))
                for image in images
            ]
            encoded = self.processor(image_files)
            onnx_input = self._build_onnx_input(encoded)
            onnx_input = self._preprocess_onnx_input(onnx_input)

            model_output = self.model.run(None, onnx_input)
            # One (1030, 128) multi-vector embedding per input image.
            embeddings = model_output[0].reshape(len(images), *supported_onnx_models[0]["dim"])
            return OnnxOutputContext(model_output=embeddings)
```
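For context, a rough usage sketch. It assumes `ColpaliImageModel` is constructed and called like the other `OnnxImageEmbedding` subclasses in fastembed (a `model_name` constructor argument and an `embed` generator); neither is shown in this diff, so treat the exact signatures as assumptions:

```python
# Hypothetical usage, assuming the class is wired into fastembed's model
# registry the same way as existing OnnxImageEmbedding subclasses.
from pathlib import Path

model = ColpaliImageModel(model_name="akshayballal/colpali-v1.2-merged")
images = [Path("page_1.png"), Path("page_2.png")]

# Each image yields a (1030, 128) matrix of patch-level vectors,
# not a single pooled vector.
for embedding in model.embed(images):
    print(embedding.shape)  # (1030, 128)
```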
The second new file adds the text-side query model, `ColpaliTextModel`:

```python
from typing import Any, Dict, Iterable, List

import numpy as np

from fastembed.common.onnx_model import OnnxOutputContext
from fastembed.text.onnx_embedding import OnnxTextEmbedding

supported_onnx_models = [
    {
        "model": "akshayballal/colpali-v1.2-merged",
        # 16 vectors per query (query tokens plus padding), each 128-dimensional.
        "dim": (16, 128),
        "description": "Text embeddings, Unimodal (text), Aligned to image latent space, ColBERT-compatible, 512 tokens max, 2024.",
        "license": "mit",
        "size_in_GB": 6.08,
        "sources": {
            "hf": "akshayballal/colpali-v1.2-merged-onnx",
        },
        "additional_files": [
            "model.onnx_data",
            "tokenizer.json",
            "tokenizer_config.json",
            "config.json",
        ],
        "model_file": "model.onnx",
    }
]


class ColpaliTextModel(OnnxTextEmbedding):
    query_prefix = "Query: "
    bos_token = "<s>"
    pad_token = "<pad>"
    # Ids prepended to every query in place of the tokenizer's first two ids.
    query_tokens = [2, 9413]
    # The merged vision-language model still expects pixel_values, so queries
    # are paired with an all-zeros image of this shape.
    image_placeholder_size = (3, 448, 448)

    def _preprocess_onnx_input(
        self, onnx_input: Dict[str, np.ndarray], **kwargs
    ) -> Dict[str, np.ndarray]:
        empty_image_placeholder = np.zeros(self.image_placeholder_size, dtype=np.float32)
        onnx_input["pixel_values"] = np.array(
            [empty_image_placeholder for _ in onnx_input["input_ids"]]
        )
        # Placeholder mask; onnx_embed overwrites it with the real one.
        onnx_input["attention_mask"] = np.array([[1] for _ in onnx_input["input_ids"]])
        return onnx_input

    @classmethod
    def list_supported_models(cls) -> List[Dict[str, Any]]:
        """
        Lists the supported models.

        Returns:
            List[Dict[str, Any]]: A list of dictionaries containing the model information.
        """
        return supported_onnx_models

    def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]:
        return output.model_output.astype(np.float32)

    def _preprocess_queries(self, documents: List[str]) -> List[str]:
        # ColPali query format: "<s>Query: {query}" followed by ten <pad>
        # tokens as a query-augmentation buffer and a trailing newline.
        texts_query: List[str] = []
        for query in documents:
            query = self.bos_token + self.query_prefix + query + self.pad_token * 10
            query += "\n"
            texts_query.append(query)
        return texts_query

    def onnx_embed(
        self,
        documents: List[str],
        **kwargs,
    ) -> OnnxOutputContext:
        documents = self._preprocess_queries(documents)
        # Raise the truncation limit well above any realistic query length.
        self.tokenizer.enable_truncation(max_length=10000)
        encoded = self.tokenize(documents, **kwargs)
        # Splice the fixed query ids onto the front, dropping the two ids the
        # tokenizer produced for the literal "<s>" prefix in the string above.
        input_ids = np.array([self.query_tokens + e.ids[2:] for e in encoded])

        attention_mask = np.array([e.attention_mask for e in encoded])
        onnx_input = {"input_ids": np.array(input_ids, dtype=np.int64)}
        onnx_input = self._preprocess_onnx_input(onnx_input, **kwargs)
        onnx_input["attention_mask"] = attention_mask
        model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input)
        return OnnxOutputContext(
            model_output=model_output[0],
            attention_mask=onnx_input.get("attention_mask", attention_mask),
            input_ids=onnx_input.get("input_ids", input_ids),
        )
```
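Since the text model is described as ColBERT-compatible, the two models are presumably meant to be combined with late-interaction (MaxSim) scoring: every query token vector is matched against its best image patch, and the maxima are summed. The scoring step is not part of this diff, so the following is a minimal numpy sketch, assuming one (16, 128) query matrix and (1030, 128) page matrices as produced above:

```python
import numpy as np

def maxsim_score(query_emb: np.ndarray, page_emb: np.ndarray) -> float:
    """ColBERT-style late-interaction score between one query and one page.

    query_emb: (num_query_tokens, 128), e.g. (16, 128) from ColpaliTextModel.
    page_emb:  (num_patches, 128),      e.g. (1030, 128) from ColpaliImageModel.
    """
    # Pairwise dot products between every query token and every image patch.
    sim = query_emb @ page_emb.T  # (16, 1030)
    # For each query token keep its best-matching patch, then sum over tokens.
    return float(sim.max(axis=1).sum())

# Ranking pages for a query then reduces to:
# scores = [maxsim_score(query_embedding, p) for p in page_embeddings]
```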