Commit c563272 ("gemma", 1 parent: 54afb59)

2 files changed: +95, -2 lines

_unittests/ut_tasks/try_tasks.py (65 additions, 0 deletions)
@@ -791,6 +791,71 @@ def test_sentence_similary_alibaba_nlp_gte(self):
         scores = (embeddings[:1] @ embeddings[1:].T) * 100
         print(scores.tolist())
 
+    @never_test()
+    def test_imagetext2text_generation_gemma3_4b_it(self):
+        """
+        clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k gemma3_4b_it
+        """
+        from transformers import AutoProcessor, Gemma3ForConditionalGeneration
+
+        model_id = "google/gemma-3-4b-it"
+        # model_id = "google/gemma-3n-e4b-it"
+        # model_id = "qnaug/gemma-3-4b-med"
+        # model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
+        # data = get_untrained_model_with_inputs(
+        #     model_id, verbose=1, add_second_input=True,
+        #     same_as_pretrained=True, use_pretrained=True
+        # )
+        # model = data["model"]
+        model = Gemma3ForConditionalGeneration.from_pretrained(
+            model_id, device_map="cpu"
+        ).eval()
+        print(f"-- model.device={model.device}")
+        processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
+        print(f"-- processor={type(processor)}")
+
+        messages = [
+            {
+                "role": "system",
+                "content": [{"type": "text", "text": "You are a helpful assistant."}],
+            },
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image",
+                        "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg",
+                    },
+                    {"type": "text", "text": "Describe this image in detail."},
+                ],
+            },
+        ]
+        inputs = processor.apply_chat_template(
+            messages,
+            tokenize=True,
+            add_generation_prompt=True,
+            return_dict=True,
+            return_tensors="pt",
+        ).to(model.device, dtype=torch.bfloat16)
+        # if "token_type_ids" in inputs:
+        #     print(
+        #         f"-- remove token_type_ids: "
+        #         f"{self.string_type(inputs['token_type_ids'], with_shape=True)}"
+        #     )
+        #     inputs.pop("token_type_ids", None)
+        print(f"-- inputs={self.string_type(inputs)}")
+
+        print()
+        # steal_forward creates a bug...
+        with steal_forward(model):  # , torch.inference_mode():
+            generated_ids = model.generate(
+                **inputs, max_new_tokens=300, do_sample=False, cache_implementation="hybrid"
+            )
+        output_text = processor.decode(
+            generated_ids[0][inputs["input_ids"].shape[1] :], skip_special_tokens=False
+        )
+        print(output_text)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)

onnx_diagnostic/torch_models/hghub/model_inputs.py (30 additions, 2 deletions)
@@ -193,7 +193,7 @@ def get_untrained_model_with_inputs(
         )
     if verbose:
         print(
-            f"[get_untrained_model_with_inputs] -- done in "
+            f"[get_untrained_model_with_inputs] -- done(1) in "
             f"{time.perf_counter() - begin}s"
         )
     else:
@@ -250,14 +250,36 @@ def get_untrained_model_with_inputs(
         )
     if verbose:
         print(
-            f"[get_untrained_model_with_inputs] -- done in "
+            f"[get_untrained_model_with_inputs] -- done(2) in "
             f"{time.perf_counter() - begin}s"
         )
 
     seed = int(os.environ.get("SEED", "17"))
     torch.manual_seed(seed)
+
+    if verbose:
+        begin = time.perf_counter()
+        print(
+            f"[get_untrained_model_with_inputs] "
+            f"instantiate_specific_model {cls_model}"
+        )
+
     model = instantiate_specific_model(cls_model, config)
+
+    if verbose:
+        print(
+            f"[get_untrained_model_with_inputs] -- done(3) in "
+            f"{time.perf_counter() - begin}s (model is {type(model)})"
+        )
+
     if model is None:
+
+        if verbose:
+            print(
+                f"[get_untrained_model_with_inputs] "
+                f"instantiate_specific_model(2) {cls_model}"
+            )
+
         try:
             if type(config) is dict:
                 model = cls_model(**config)
@@ -268,6 +290,12 @@ def get_untrained_model_with_inputs(
                 f"Unable to instantiate class {cls_model.__name__} with\n{config}"
             ) from e
 
+    if verbose:
+        print(
+            f"[get_untrained_model_with_inputs] -- done(4) in "
+            f"{time.perf_counter() - begin}s (model is {type(model)})"
+        )
+
     # input kwargs
     seed = int(os.environ.get("SEED", "17")) + 1
     torch.manual_seed(seed)
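The done(1) through done(4) tags disambiguate four otherwise identical timing lines in the verbose log. The same begin/perf_counter bookkeeping is repeated around each phase; as a hedged sketch (a hypothetical helper, not part of this commit), the pattern could be factored into a context manager:

# Hypothetical refactor of the repeated verbose timing pattern; the
# commit inlines this logic instead of using a helper like this one.
import time
from contextlib import contextmanager

@contextmanager
def timed_step(label, verbose):
    # Announce the step, run the body, then report elapsed time.
    if verbose:
        print(f"[get_untrained_model_with_inputs] {label}")
    begin = time.perf_counter()
    yield
    if verbose:
        print(
            f"[get_untrained_model_with_inputs] -- {label} done in "
            f"{time.perf_counter() - begin}s"
        )

Usage would then read, for the third phase:

with timed_step(f"instantiate_specific_model {cls_model}", verbose):
    model = instantiate_specific_model(cls_model, config)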
