Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 58 additions & 31 deletions tests/models/transformers/test_models_transformer_cogview4.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,39 +12,79 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from diffusers import CogView4Transformer2DModel
from diffusers.utils.torch_utils import randn_tensor

from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin
from ..testing_utils import (
BaseModelTesterConfig,
MemoryTesterMixin,
ModelTesterMixin,
TrainingTesterMixin,
)


enable_full_determinism()


class CogView3PlusTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = CogView4Transformer2DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
class CogView4TransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return CogView4Transformer2DModel

@property
def main_input_name(self) -> str:
return "hidden_states"

@property
def uses_custom_attn_processor(self) -> bool:
return True

@property
def dummy_input(self):
def output_shape(self) -> tuple:
return (4, 8, 8)

@property
def input_shape(self) -> tuple:
return (4, 8, 8)

@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)

def get_init_dict(self) -> dict:
return {
"patch_size": 2,
"in_channels": 4,
"num_layers": 2,
"attention_head_dim": 4,
"num_attention_heads": 4,
"out_channels": 4,
"text_embed_dim": 8,
"time_embed_dim": 8,
"condition_dim": 4,
}
Comment on lines +57 to +68
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@askserge, is this using the same set model initialization parameters?


def get_dummy_inputs(self) -> dict:
batch_size = 2
num_channels = 4
height = 8
width = 8
embedding_dim = 8
sequence_length = 8

hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
hidden_states = randn_tensor(
(batch_size, num_channels, height, width), generator=self.generator, device=torch_device
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
)
original_size = torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
target_size = torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
crop_coords = torch.tensor([0, 0]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device)

return {
"hidden_states": hidden_states,
Expand All @@ -55,29 +95,16 @@ def dummy_input(self):
"crop_coords": crop_coords,
}

@property
def input_shape(self):
return (4, 8, 8)

@property
def output_shape(self):
return (4, 8, 8)
class TestCogView4Transformer(CogView4TransformerTesterConfig, ModelTesterMixin):
pass


class TestCogView4TransformerMemory(CogView4TransformerTesterConfig, MemoryTesterMixin):
pass

def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 2,
"in_channels": 4,
"num_layers": 2,
"attention_head_dim": 4,
"num_attention_heads": 4,
"out_channels": 4,
"text_embed_dim": 8,
"time_embed_dim": 8,
"condition_dim": 4,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict

class TestCogView4TransformerTraining(CogView4TransformerTesterConfig, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"CogView4Transformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
Loading