@@ -497,6 +497,75 @@ def test_dlpack(backend):
497497 np .testing .assert_allclose (a , a1 , atol = 1e-5 )
498498
499499
500+ @pytest .mark .parametrize ("backend" , [lf ("npb" ), lf ("tfb" ), lf ("jaxb" ), lf ("torchb" )])
501+ def test_backend_reshaped_basic (backend ):
502+ a1 = tc .backend .convert_to_tensor (np .arange (27 ))
503+ r1 = tc .backend .reshaped (a1 , 3 )
504+ assert r1 .shape == (3 , 3 , 3 )
505+ np .testing .assert_allclose (tc .backend .numpy (r1 ), np .arange (27 ).reshape (3 , 3 , 3 ))
506+ d , n = 4 , 3
507+ dim = d ** n
508+ mat = np .arange (dim * dim , dtype = np .float32 ).reshape (dim , dim )
509+ a2 = tc .backend .convert_to_tensor (mat )
510+ r2 = tc .backend .reshaped (a2 , d )
511+ assert r2 .shape == (d ,) * (2 * n )
512+ np .testing .assert_allclose (tc .backend .numpy (r2 ), mat .reshape ((d ,) * (2 * n )))
513+
514+
515+ @pytest .mark .parametrize ("backend" , [lf ("npb" ), lf ("tfb" ), lf ("jaxb" ), lf ("torchb" )])
516+ def test_backend_reshaped_zero_size (backend ):
517+ """size == 0 returns a canonical empty vector shape (0,) regardless of input shape."""
518+ a0 = tc .backend .convert_to_tensor (np .array ([], dtype = np .float32 ))
519+ r0 = tc .backend .reshaped (a0 , 3 )
520+ assert r0 .shape == (0 ,)
521+ assert tc .backend .sizen (r0 ) == 0
522+
523+ a1 = tc .backend .convert_to_tensor (np .zeros ((2 , 0 ), dtype = np .float32 ))
524+ r1 = tc .backend .reshaped (a1 , 5 )
525+ assert r1 .shape == (0 ,)
526+ assert tc .backend .sizen (r1 ) == 0
527+
528+
529+ @pytest .mark .parametrize ("backend" , [lf ("npb" ), lf ("tfb" ), lf ("jaxb" ), lf ("torchb" )])
530+ def test_backend_reshaped_dtype_device_preserved (backend ):
531+ """Reshape should not change dtype or device."""
532+ a = tc .backend .ones ([16 ], dtype = "float32" )
533+ dev = tc .backend .device (a )
534+ r = tc .backend .reshaped (a , 2 )
535+ assert r .shape == (2 , 2 , 2 , 2 )
536+ assert tc .backend .dtype (r ) == tc .backend .dtype (a )
537+ assert tc .backend .device (r ) == dev
538+
539+
540+ @pytest .mark .parametrize ("backend" , [lf ("npb" ), lf ("tfb" ), lf ("jaxb" ), lf ("torchb" )])
541+ def test_backend_reshaped_scalar_size_one (backend ):
542+ """size == 1 stays scalar: nleg = 0 so shape () is kept."""
543+ a = tc .backend .ones ([]) # scalar tensor, total size = 1
544+ r = tc .backend .reshaped (a , 2 )
545+ assert r .shape == ()
546+ np .testing .assert_allclose (tc .backend .numpy (r ), tc .backend .numpy (a ))
547+
548+
549+ @pytest .mark .parametrize ("backend" , [lf ("npb" ), lf ("tfb" ), lf ("jaxb" ), lf ("torchb" )])
550+ def test_backend_reshaped_invalid_d_raises (backend ):
551+ """d must be a positive integer: non-int or <=0 should raise."""
552+ a = tc .backend .ones ([4 ], dtype = "float32" )
553+ with pytest .raises (ValueError ):
554+ tc .backend .reshaped (a , 0 )
555+ with pytest .raises (ValueError ):
556+ tc .backend .reshaped (a , - 2 )
557+ with pytest .raises (ValueError ):
558+ tc .backend .reshaped (a , 2.5 ) # not an int
559+
560+
561+ @pytest .mark .parametrize ("backend" , [lf ("npb" ), lf ("tfb" ), lf ("jaxb" ), lf ("torchb" )])
562+ def test_backend_reshaped_non_power_raises (backend ):
563+ """When size is not a power of d, raise ValueError."""
564+ a = tc .backend .convert_to_tensor (np .arange (10 ))
565+ with pytest .raises (ValueError ):
566+ tc .backend .reshaped (a , 3 )
567+
568+
500569@pytest .mark .parametrize ("backend" , [lf ("npb" ), lf ("tfb" ), lf ("jaxb" ), lf ("torchb" )])
501570def test_arg_cmp (backend ):
502571 np .testing .assert_allclose (tc .backend .argmax (tc .backend .ones ([3 ], "float64" )), 0 )
0 commit comments