@@ -154,11 +154,7 @@ def tensordot(x1, x2, axes=2):
154154 same_shapes = True
155155 for i in range (n_axes1 ):
156156 axis1 = axes1 [i ]
157- if axis1 < 0 :
158- raise ValueError ("`axes` must be non-negative" )
159157 axis2 = axes2 [i ]
160- if axis2 < 0 :
161- raise ValueError ("`axes` must be non-negative" )
162158 same_shapes = same_shapes and (x1_shape [axis1 ] == x2_shape [axis2 ])
163159 if not same_shapes :
164160 raise ValueError ("shape mismatch in contracted `tensordot` axes" )
@@ -361,7 +357,7 @@ def vecdot(x1, x2, axis=-1):
361357 elif x2_nd > x1_nd :
362358 x1_shape = (1 ,) * (x2_nd - x1_nd ) + x1_shape
363359 x1_nd = len (x1_shape )
364- axis = normalize_axis_index (operator .index (axis ), x1_nd )
360+ axis = normalize_axis_index (operator .index (axis ), min ( x1_nd , x2_nd ) )
365361 if x1_shape [axis ] != x2_shape [axis ]:
366362 raise ValueError (
367363 "given axis must have the same shape for `x1` and `x2`"
@@ -427,7 +423,7 @@ def vecdot(x1, x2, axis=-1):
427423 ht_dot_ev , _ = tli ._dot (
428424 x1 = x1 ,
429425 x2 = x2 ,
430- batch_dims = len (x1 . shape [: - 1 ] ),
426+ batch_dims = len (res_sh ),
431427 x1_outer_dims = 0 ,
432428 x2_outer_dims = 0 ,
433429 inner_dims = 1 ,
@@ -471,7 +467,7 @@ def vecdot(x1, x2, axis=-1):
471467 ht_dot_ev , _ = tli ._dot (
472468 x1 = x1 ,
473469 x2 = buf2 ,
474- batch_dims = len (x1 . shape [: - 1 ] ),
470+ batch_dims = len (res_sh ),
475471 x1_outer_dims = 0 ,
476472 x2_outer_dims = 0 ,
477473 inner_dims = 1 ,
@@ -513,7 +509,7 @@ def vecdot(x1, x2, axis=-1):
513509 ht_dot_ev , _ = tli ._dot (
514510 x1 = buf1 ,
515511 x2 = x2 ,
516- batch_dims = len (x1 . shape [: - 1 ] ),
512+ batch_dims = len (res_sh ),
517513 x1_outer_dims = 0 ,
518514 x2_outer_dims = 0 ,
519515 inner_dims = 1 ,
@@ -560,7 +556,7 @@ def vecdot(x1, x2, axis=-1):
560556 ht_dot_ev , _ = tli ._dot (
561557 x1 = buf1 ,
562558 x2 = buf2 ,
563- batch_dims = len (x1 . shape [: - 1 ] ),
559+ batch_dims = len (res_sh ),
564560 x1_outer_dims = 0 ,
565561 x2_outer_dims = 0 ,
566562 inner_dims = 1 ,
0 commit comments