@@ -782,12 +782,6 @@ def test_tensordot_axes_errors():
782782 with pytest .raises (ValueError ):
783783 dpt .tensordot (m1 , m2 , axes = - 1 )
784784
785- with pytest .raises (ValueError ):
786- dpt .tensordot (m1 , m2 , axes = ((- 1 ,), (1 ,)))
787-
788- with pytest .raises (ValueError ):
789- dpt .tensordot (m1 , m2 , axes = ((1 ,), (- 1 ,)))
790-
791785
792786@pytest .mark .parametrize ("dtype" , _numeric_types )
793787def test_vecdot_1d (dtype ):
@@ -834,7 +828,7 @@ def test_vecdot_axis(dtype):
834828
835829 v2 = dpt .ones ((m1 , n , m2 ), dtype = dtype )
836830
837- r = dpt .vecdot (v1 , v2 , axis = 1 )
831+ r = dpt .vecdot (v1 , v2 , axis = - 2 )
838832
839833 assert r .shape == (
840834 m1 ,
@@ -864,7 +858,7 @@ def test_vecdot_strided(dtype):
864858 :, :n , ::- 1
865859 ]
866860
867- r = dpt .vecdot (v1 , v2 , axis = 1 )
861+ r = dpt .vecdot (v1 , v2 , axis = - 2 )
868862
869863 ref = sum (
870864 el1 * el2
@@ -903,6 +897,9 @@ def test_vector_arg_validation():
903897 with pytest .raises (ValueError ):
904898 dpt .vecdot (v1 , v2 , axis = 2 )
905899
900+ with pytest .raises (ValueError ):
901+ dpt .vecdot (v1 , v2 , axis = - 2 )
902+
906903 q = dpctl .SyclQueue (
907904 v2 .sycl_context , v2 .sycl_device , property = "enable_profiling"
908905 )
0 commit comments