@@ -1442,7 +1442,6 @@ def forward(
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
-        output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         protein_kv_states: Optional[torch.Tensor] = None,
@@ -1497,7 +1496,11 @@ class EvollaPreTrainedModel(PreTrainedModel):
     config: EvollaConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["EvollaDecoderLayer"]
+    _no_split_modules = [
+        "EvollaDecoderLayer",
+        "EvollaSequenceCompressorResampler",
+        "EvollaSequenceAlignerCrossAttention",
+    ]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = True
     _supports_sdpa = True
@@ -1512,20 +1515,8 @@ class EvollaPreTrainedModel(PreTrainedModel):
 
     def _init_weights(self, module):
         std = self.config.initializer_range
-        if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
-        elif isinstance(module, nn.LayerNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
-        elif isinstance(module, EvollaRMSNorm):
-            module.weight.data.fill_(1.0)
-        elif isinstance(module, EvollaSequenceAlignerCrossAttention):
+        super()._init_weights(module)
+        if isinstance(module, EvollaSequenceAlignerCrossAttention):
             module.gate_attention.zero_()
             module.gate_ffw.zero_()
             module.attention_norm.weight.data.fill_(1.0)
@@ -1594,15 +1585,6 @@ def forward(
             msa_batch_mask (torch.Tensor):
                 The batch mask to decide which protein sequences are purely MSA-based. Should be of shape `(batch_size)` and type `torch.Tensor`. Should be paired with `msa_feats`. Dummy input for now.
             """
-        # If `protein_feats` is not provided, use the `protein_encoder` to get the protein features
-        if protein_input_ids is not None and protein_attention_mask is not None:
-            protein_outputs = self.protein_encoder(
-                input_ids=protein_input_ids,
-                attention_mask=protein_attention_mask,
-            )
-            protein_feats = protein_outputs.sequence_compressor_output
-            protein_batch_mask = torch.tensor([True] * protein_input_ids.shape[0], device=protein_input_ids.device)
-
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
@@ -1621,6 +1603,17 @@ def forward(
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
+        protein_feats = None
+        protein_batch_mask = None
+        # If protein inputs are provided, compute the protein features with the protein encoder
+        if protein_input_ids is not None and protein_attention_mask is not None:
+            protein_outputs = self.protein_encoder(
+                input_ids=protein_input_ids,
+                attention_mask=protein_attention_mask,
+            )
+            protein_feats = protein_outputs.sequence_compressor_output
+            protein_batch_mask = torch.tensor([True] * protein_input_ids.shape[0], device=protein_input_ids.device)
+
         causal_mask = create_causal_mask(
             config=self.config,
             input_embeds=inputs_embeds,
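
As a side note on the `_init_weights` hunk: the pattern being adopted is to delegate the generic `nn.Linear` / `nn.Embedding` / norm initialization to the parent class and keep only the model-specific gate handling local. Below is a minimal, self-contained sketch of that pattern; `BasePreTrained`, `GatedCrossAttention`, and `MyModel` are illustrative stand-ins, not the real Transformers or Evolla classes.

```python
import torch
from torch import nn


class BasePreTrained(nn.Module):
    """Stand-in for the parent class that owns the generic initialization."""

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()


class GatedCrossAttention(nn.Module):
    """Stand-in for a gated cross-attention block with learnable scalar gates."""

    def __init__(self):
        super().__init__()
        self.gate_attention = nn.Parameter(torch.ones(1))
        self.gate_ffw = nn.Parameter(torch.ones(1))


class MyModel(BasePreTrained):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        self.cross_attn = GatedCrossAttention()

    def _init_weights(self, module):
        # Generic modules (Linear, Embedding, norms, ...) are handled upstream.
        super()._init_weights(module)
        # Only the model-specific gates need local, zero-init treatment.
        if isinstance(module, GatedCrossAttention):
            module.gate_attention.data.zero_()
            module.gate_ffw.data.zero_()


model = MyModel()
model.apply(model._init_weights)
print(model.cross_attn.gate_attention)  # gate starts at zero after init
```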