11# SPDX-License-Identifier: Apache-2.0 
22from  dataclasses  import  dataclass , field 
3- from  typing  import  Optional 
3+ from  typing  import  List ,  Optional ,  Tuple 
44
55from  fastvideo .v1 .configs .models .encoders .base  import  (ImageEncoderArchConfig ,
66                                                       ImageEncoderConfig ,
77                                                       TextEncoderArchConfig ,
88                                                       TextEncoderConfig )
99
1010
11+ def  _is_transformer_layer (n : str , m ) ->  bool :
12+     return  "layers"  in  n  and  str .isdigit (n .split ("." )[- 1 ])
13+ 
14+ 
15+ def  _is_embeddings (n : str , m ) ->  bool :
16+     return  n .endswith ("embeddings" )
17+ 
18+ 
1119@dataclass  
1220class  CLIPTextArchConfig (TextEncoderArchConfig ):
1321    vocab_size : int  =  49408 
@@ -27,6 +35,15 @@ class CLIPTextArchConfig(TextEncoderArchConfig):
2735    bos_token_id : int  =  49406 
2836    eos_token_id : int  =  49407 
2937    text_len : int  =  77 
38+     stacked_params_mapping : List [Tuple [str , str ,
39+                                        str ]] =  field (default_factory = lambda : [
40+                                            # (param_name, shard_name, shard_id) 
41+                                            ("qkv_proj" , "q_proj" , "q" ),
42+                                            ("qkv_proj" , "k_proj" , "k" ),
43+                                            ("qkv_proj" , "v_proj" , "v" ),
44+                                        ])
45+     _fsdp_shard_conditions : list  =  field (
46+         default_factory = lambda : [_is_transformer_layer , _is_embeddings ])
3047
3148
3249@dataclass  
@@ -45,6 +62,13 @@ class CLIPVisionArchConfig(ImageEncoderArchConfig):
4562    attention_dropout : float  =  0.0 
4663    initializer_range : float  =  0.02 
4764    initializer_factor : float  =  1.0 
65+     stacked_params_mapping : List [Tuple [str , str ,
66+                                        str ]] =  field (default_factory = lambda : [
67+                                            # (param_name, shard_name, shard_id) 
68+                                            ("qkv_proj" , "q_proj" , "q" ),
69+                                            ("qkv_proj" , "k_proj" , "k" ),
70+                                            ("qkv_proj" , "v_proj" , "v" ),
71+                                        ])
4872
4973
5074@dataclass  
0 commit comments