@@ -6746,6 +6746,7 @@ def test_functional_error(self):
67466746@pytest .mark .skipif (not CVCUDA_AVAILABLE , reason = "test requires CVCUDA" )
67476747@needs_cuda
67486748class TestToCVCUDATensor :
6749+ @pytest .mark .parametrize ("image_type" , (torch .Tensor , tv_tensors .Image ))
67496750 @pytest .mark .parametrize ("dtype" , [torch .uint8 , torch .uint16 , torch .float32 , torch .float64 ])
67506751 @pytest .mark .parametrize ("device" , cpu_and_cuda ())
67516752 @pytest .mark .parametrize ("color_space" , ["RGB" , "GRAY" ])
@@ -6754,30 +6755,24 @@ class TestToCVCUDATensor:
67546755 "fn" ,
67556756 [F .to_cvcuda_tensor , transform_cls_to_functional (transforms .ToCVCUDATensor )],
67566757 )
6757- def test_functional_and_transform (self , dtype , device , color_space , batch_dims , fn ):
6758- input = make_image_tensor (dtype = dtype , device = device , color_space = color_space , batch_dims = batch_dims )
6759- output = fn (input )
6758+ def test_functional_and_transform (self , image_type , dtype , device , color_space , batch_dims , fn ):
6759+ image = make_image (dtype = dtype , device = device , color_space = color_space , batch_dims = batch_dims )
6760+ if image_type is torch .Tensor :
6761+ image = image .as_subclass (torch .Tensor )
6762+ assert is_pure_tensor (image )
6763+ output = fn (image )
67606764
67616765 assert isinstance (output , cvcuda .Tensor )
6762- assert F .get_size (output ) == F .get_size (input )
6766+ assert F .get_size (output ) == F .get_size (image )
67636767 assert output is not None
67646768
67656769 def test_invalid_input_type (self ):
6766- with pytest .raises (TypeError , match = r"inpt should be `torch.Tensor`" ):
6770+ with pytest .raises (TypeError , match = r"inpt should be `` torch.Tensor` `" ):
67676771 F .to_cvcuda_tensor ("invalid_input" )
67686772
67696773 def test_invalid_dimensions (self ):
67706774 with pytest .raises (ValueError , match = r"pic should be 4 dimensional" ):
6771- img_data = torch .randint (
6772- 0 ,
6773- 256 ,
6774- (
6775- 3 ,
6776- 12 ,
6777- 34 ,
6778- ),
6779- dtype = torch .uint8 ,
6780- )
6775+ img_data = torch .randint (0 , 256 , (3 , 1 , 3 ), dtype = torch .uint8 )
67816776 img_data = img_data .cuda ()
67826777 F .to_cvcuda_tensor (img_data )
67836778
0 commit comments