@@ -791,6 +791,71 @@ def test_sentence_similary_alibaba_nlp_gte(self):
        scores = (embeddings[:1] @ embeddings[1:].T) * 100
        print(scores.tolist())

+    @never_test()
+    def test_imagetext2text_generation_gemma3_4b_it(self):
+        """
+        clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k gemma3_4b_it
+        """
+        from transformers import AutoProcessor, Gemma3ForConditionalGeneration
+
+        model_id = "google/gemma-3-4b-it"
+        # model_id = "google/gemma-3n-e4b-it"
+        # model_id = "qnaug/gemma-3-4b-med"
+        # model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
+        # data = get_untrained_model_with_inputs(
+        #     model_id, verbose=1, add_second_input=True,
+        #     same_as_pretrained=True, use_pretrained=True
+        # )
+        # model = data["model"]
+        model = Gemma3ForConditionalGeneration.from_pretrained(
+            model_id, device_map="cpu"
+        ).eval()
+        print(f"-- model.device={model.device}")
+        processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
+        print(f"-- processor={type(processor)}")
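+        # AutoProcessor for a multimodal checkpoint typically bundles the tokenizer
+        # and the image processor, so the same object prepares both the text prompt
+        # and the image below (assumed behaviour for gemma-3-4b-it).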
+
+        messages = [
+            {
+                "role": "system",
+                "content": [{"type": "text", "text": "You are a helpful assistant."}],
+            },
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image",
+                        "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg",
+                    },
+                    {"type": "text", "text": "Describe this image in detail."},
+                ],
+            },
+        ]
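+        # apply_chat_template with tokenize=True and return_dict=True is expected to
+        # download the image URL, run the image processor, and return model-ready
+        # tensors (input_ids, attention_mask, pixel_values, possibly token_type_ids);
+        # this assumes a recent transformers version with multimodal chat templates.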
+        inputs = processor.apply_chat_template(
+            messages,
+            tokenize=True,
+            add_generation_prompt=True,
+            return_dict=True,
+            return_tensors="pt",
+        ).to(model.device, dtype=torch.bfloat16)
+        # if "token_type_ids" in inputs:
+        #     print(
+        #         f"-- remove token_type_ids: "
+        #         f"{self.string_type(inputs['token_type_ids'], with_shape=True)}"
+        #     )
+        # inputs.pop("token_type_ids", None)
+        print(f"-- inputs={self.string_type(inputs)}")
+
+        print()
+        # steal forward creates a bug...
+        with steal_forward(model):  # , torch.inference_mode():
+            generated_ids = model.generate(
+                **inputs, max_new_tokens=300, do_sample=False, cache_implementation="hybrid"
+            )
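+        # generate returns the prompt tokens followed by the new tokens, so slicing
+        # at inputs["input_ids"].shape[1] keeps only the generated continuation.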
+        output_text = processor.decode(
+            generated_ids[0][inputs["input_ids"].shape[1] :], skip_special_tokens=False
+        )
+        print(output_text)
+

if __name__ == "__main__":
    unittest.main(verbosity=2)