
Understanding tensordot


The idea with tensordot is pretty simple: we pass in the arrays and the respective axes along which the sum-reductions are intended. The axes that take part in the sum-reduction are removed from the output, and all of the remaining axes of the input arrays are spread out as separate axes in the output, keeping the order in which the input arrays are fed. Each pair of reduced axes must have the same length.
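To make the sum-reduction concrete, here is a minimal sketch that compares np.tensordot against an explicit loop over the shared axis. The shapes, the seeded generator, and the name `expected` are chosen only for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed, illustrative only
A = rng.integers(0, 5, size=(2, 6, 5))
B = rng.integers(0, 5, size=(3, 2, 4))

# Sum over axis 0 of A and axis 1 of B (both have length 2).
expected = np.zeros((6, 5, 3, 4), dtype=A.dtype)
for r in range(2):
    # A[r] has shape (6, 5); B[:, r, :] has shape (3, 4).
    # Their outer product has shape (6, 5, 3, 4).
    expected += np.multiply.outer(A[r], B[:, r, :])

result = np.tensordot(A, B, axes=(0, 1))
print(result.shape, np.array_equal(result, expected))
```

The reduced axis never appears in the output; only the surviving axes of A, followed by the surviving axes of B, remain.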

Let's look at a few sample cases with one and two axes of sum-reduction, and also swap the input positions to see how the order is kept in the output.

I. One axis of sum-reduction

Inputs :

    In [7]: A = np.random.randint(2, size=(2, 6, 5))
       ...: B = np.random.randint(2, size=(3, 2, 4))

Case #1:

    In [9]: np.tensordot(A, B, axes=((0),(1))).shape
    Out[9]: (6, 5, 3, 4)

    A : (2, 6, 5) -> reduction of axis=0
    B : (3, 2, 4) -> reduction of axis=1

    Output : `(2, 6, 5)`, `(3, 2, 4)` ===(2 gone)==> `(6,5)` + `(3,4)` => `(6,5,3,4)`

Case #2 (same as case #1 but the inputs are fed swapped):

    In [8]: np.tensordot(B, A, axes=((1),(0))).shape
    Out[8]: (3, 4, 6, 5)

    B : (3, 2, 4) -> reduction of axis=1
    A : (2, 6, 5) -> reduction of axis=0

    Output : `(3, 2, 4)`, `(2, 6, 5)` ===(2 gone)==> `(3,4)` + `(6,5)` => `(3,4,6,5)`
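Swapping the inputs changes only the order of the output axes, not the values. A small sketch of that relationship, using an illustrative seeded generator and the hypothetical names `ab`, `ba`, and `ba_reordered`:

```python
import numpy as np

rng = np.random.default_rng(1)  # fixed seed, illustrative only
A = rng.integers(0, 5, size=(2, 6, 5))
B = rng.integers(0, 5, size=(3, 2, 4))

ab = np.tensordot(A, B, axes=(0, 1))  # shape (6, 5, 3, 4)
ba = np.tensordot(B, A, axes=(1, 0))  # shape (3, 4, 6, 5)

# In ba, B's surviving axes come first (positions 0, 1) and A's come last
# (positions 2, 3). Reordering them to A-then-B recovers ab.
ba_reordered = ba.transpose(2, 3, 0, 1)
print(np.array_equal(ab, ba_reordered))
```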

II. Two axes of sum-reduction

Inputs :

    In [11]: A = np.random.randint(2, size=(2, 3, 5))
        ...: B = np.random.randint(2, size=(3, 2, 4))

Case #1:

    In [12]: np.tensordot(A, B, axes=((0,1),(1,0))).shape
    Out[12]: (5, 4)

    A : (2, 3, 5) -> reduction of axis=(0,1)
    B : (3, 2, 4) -> reduction of axis=(1,0)

    Output : `(2, 3, 5)`, `(3, 2, 4)` ===(2,3 gone)==> `(5)` + `(4)` => `(5,4)`

Case #2:

    In [14]: np.tensordot(B, A, axes=((1,0),(0,1))).shape
    Out[14]: (4, 5)

    B : (3, 2, 4) -> reduction of axis=(1,0)
    A : (2, 3, 5) -> reduction of axis=(0,1)

    Output : `(3, 2, 4)`, `(2, 3, 5)` ===(2,3 gone)==> `(4)` + `(5)` => `(4,5)`

We can extend this to as many axes as needed, as long as each pair of reduced axes has matching lengths.
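The shape rule can be written down as a few lines of code. The helper `tensordot_shape` below is hypothetical (it is not part of NumPy) and assumes non-negative axis indices; it is only a sketch of the rule described above, checked against np.tensordot on a three-axis reduction.

```python
import numpy as np

def tensordot_shape(shape_a, shape_b, axes_a, axes_b):
    """Predict the output shape of np.tensordot (hypothetical helper).

    Assumes non-negative axis indices.
    """
    for ax_a, ax_b in zip(axes_a, axes_b):
        if shape_a[ax_a] != shape_b[ax_b]:
            raise ValueError("paired axes must have equal length")
    # Surviving axes of A first, then surviving axes of B, in original order.
    keep_a = [n for i, n in enumerate(shape_a) if i not in axes_a]
    keep_b = [n for i, n in enumerate(shape_b) if i not in axes_b]
    return tuple(keep_a + keep_b)

A = np.ones((2, 3, 5, 7))
B = np.ones((7, 3, 2, 4))
axes_a, axes_b = (0, 1, 3), (2, 1, 0)  # pairs of lengths 2-2, 3-3, 7-7

predicted = tensordot_shape(A.shape, B.shape, axes_a, axes_b)
actual = np.tensordot(A, B, axes=(axes_a, axes_b)).shape
print(predicted, actual)
```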


tensordot swaps axes and reshapes the inputs so that it can apply np.dot to two 2-D arrays. It then swaps and reshapes back to the target shape. It may be easier to experiment than to explain. There's no special tensor math going on, just an extension of dot to work in higher dimensions; "tensor" here just means an array with more than two dimensions. If you are already comfortable with einsum, it will be simplest to compare the results to that.
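The transpose, reshape, and dot steps can be reproduced by hand. This is a simplified sketch of the idea for one fixed case, not NumPy's actual source; the intermediate names (`A_t`, `A_2d`, `manual`, and so on) are illustrative.

```python
import numpy as np

rng = np.random.default_rng(2)  # fixed seed, illustrative only
A = rng.integers(0, 5, size=(2, 3, 5))
B = rng.integers(0, 5, size=(3, 2, 4))

# Goal: contract A's axes (0, 1) with B's axes (1, 0).
# Step 1: move the contracted axes to the end of A and to the front of B,
# in matching order (length 2 first, then length 3).
A_t = A.transpose(2, 0, 1)  # (5, 2, 3)
B_t = B.transpose(1, 0, 2)  # (2, 3, 4)

# Step 2: flatten to 2-D so a plain matrix product performs the sum.
A_2d = A_t.reshape(5, 2 * 3)  # (5, 6)
B_2d = B_t.reshape(2 * 3, 4)  # (6, 4)

# Step 3: matrix product, then reshape to the surviving axes.
manual = np.dot(A_2d, B_2d).reshape(5, 4)

reference = np.tensordot(A, B, axes=((0, 1), (1, 0)))
print(np.array_equal(manual, reference))
```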

A sample test, summing on 1 pair of axes (with A and B shaped (2, 3, 5) and (3, 2, 4), as in section II above):

    In [823]: np.tensordot(A,B,[0,1]).shape
    Out[823]: (3, 5, 3, 4)
    In [824]: np.einsum('ijk,lim',A,B).shape
    Out[824]: (3, 5, 3, 4)
    In [825]: np.allclose(np.einsum('ijk,lim',A,B),np.tensordot(A,B,[0,1]))
    Out[825]: True

Another, summing on two pairs:

    In [826]: np.tensordot(A,B,[(0,1),(1,0)]).shape
    Out[826]: (5, 4)
    In [827]: np.einsum('ijk,jim',A,B).shape
    Out[827]: (5, 4)
    In [828]: np.allclose(np.einsum('ijk,jim',A,B),np.tensordot(A,B,[(0,1),(1,0)]))
    Out[828]: True

We could do the same with the (1,0) pair, that is, axis 1 of A with axis 0 of B. Given the mix of dimensions, I don't think there's another valid combination.
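For completeness, the (1,0) pairing can be checked the same way. This sketch assumes A and B have the shapes (2, 3, 5) and (3, 2, 4) used earlier; the seeded generator and the names `td` and `es` are illustrative.

```python
import numpy as np

rng = np.random.default_rng(3)  # fixed seed, illustrative only
A = rng.integers(0, 2, size=(2, 3, 5))
B = rng.integers(0, 2, size=(3, 2, 4))

# Sum A's axis 1 (length 3) against B's axis 0 (length 3).
td = np.tensordot(A, B, [1, 0])
# The shared index j is summed; the implicit output order is i, k, l, m.
es = np.einsum('ijk,jlm', A, B)

print(td.shape, np.allclose(td, es))
```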


The answers above are great and helped me a lot in understanding tensordot, but they don't show the actual math behind the operations. That's why I worked through the equivalent operations in TF 2 for myself and decided to share them here:

    import tensorflow as tf

    a = tf.constant([1,2.])
    b = tf.constant([2,3.])
    print(f"{tf.tensordot(a, b, 0)}\t tf.einsum('i,j', a, b)\t\t- ((the last 0 axes of a), (the first 0 axes of b))")
    print(f"{tf.tensordot(a, b, ((),()))}\t tf.einsum('i,j', a, b)\t\t- ((() axis of a), (() axis of b))")
    print(f"{tf.tensordot(b, a, 0)}\t tf.einsum('i,j->ji', a, b)\t- ((the last 0 axes of b), (the first 0 axes of a))")
    print(f"{tf.tensordot(a, b, 1)}\t\t tf.einsum('i,i', a, b)\t\t- ((the last 1 axes of a), (the first 1 axes of b))")
    print(f"{tf.tensordot(a, b, ((0,), (0,)))}\t\t tf.einsum('i,i', a, b)\t\t- ((0th axis of a), (0th axis of b))")
    print(f"{tf.tensordot(a, b, (0,0))}\t\t tf.einsum('i,i', a, b)\t\t- ((0th axis of a), (0th axis of b))")

Output:

    [[2. 3.]
     [4. 6.]]    tf.einsum('i,j', a, b)     - ((the last 0 axes of a), (the first 0 axes of b))
    [[2. 3.]
     [4. 6.]]    tf.einsum('i,j', a, b)     - ((() axis of a), (() axis of b))
    [[2. 4.]
     [3. 6.]]    tf.einsum('i,j->ji', a, b) - ((the last 0 axes of b), (the first 0 axes of a))
    8.0          tf.einsum('i,i', a, b)     - ((the last 1 axes of a), (the first 1 axes of b))
    8.0          tf.einsum('i,i', a, b)     - ((0th axis of a), (0th axis of b))
    8.0          tf.einsum('i,i', a, b)     - ((0th axis of a), (0th axis of b))

And for inputs of shape (2, 2):

    a = tf.constant([[1,2],
                     [-2,3.]])
    b = tf.constant([[-2,3],
                     [0,4.]])
    print(f"{tf.tensordot(a, b, 0)}\t tf.einsum('ij,kl', a, b)\t- ((the last 0 axes of a), (the first 0 axes of b))")
    print(f"{tf.tensordot(a, b, (0,0))}\t tf.einsum('ij,ik', a, b)\t- ((0th axis of a), (0th axis of b))")
    print(f"{tf.tensordot(a, b, (0,1))}\t tf.einsum('ij,ki', a, b)\t- ((0th axis of a), (1st axis of b))")
    print(f"{tf.tensordot(a, b, 1)}\t tf.matmul(a, b)\t\t- ((the last 1 axes of a), (the first 1 axes of b))")
    print(f"{tf.tensordot(a, b, ((1,), (0,)))}\t tf.einsum('ij,jk', a, b)\t- ((1st axis of a), (0th axis of b))")
    print(f"{tf.tensordot(a, b, (1, 0))}\t tf.matmul(a, b)\t\t- ((1st axis of a), (0th axis of b))")
    print(f"{tf.tensordot(a, b, 2)}\t tf.reduce_sum(tf.multiply(a, b))\t- ((the last 2 axes of a), (the first 2 axes of b))")
    print(f"{tf.tensordot(a, b, ((0,1), (0,1)))}\t tf.einsum('ij,ij->', a, b)\t\t- ((0th axis of a, 1st axis of a), (0th axis of b, 1st axis of b))")

Output:

    [[[[-2.  3.]
       [ 0.  4.]]
      [[-4.  6.]
       [ 0.  8.]]]
     [[[ 4. -6.]
       [-0. -8.]]
      [[-6.  9.]
       [ 0. 12.]]]]  tf.einsum('ij,kl', a, b)   - ((the last 0 axes of a), (the first 0 axes of b))
    [[-2. -5.]
     [-4. 18.]]      tf.einsum('ij,ik', a, b)   - ((0th axis of a), (0th axis of b))
    [[-8. -8.]
     [ 5. 12.]]      tf.einsum('ij,ki', a, b)   - ((0th axis of a), (1st axis of b))
    [[-2. 11.]
     [ 4.  6.]]      tf.matmul(a, b)            - ((the last 1 axes of a), (the first 1 axes of b))
    [[-2. 11.]
     [ 4.  6.]]      tf.einsum('ij,jk', a, b)   - ((1st axis of a), (0th axis of b))
    [[-2. 11.]
     [ 4.  6.]]      tf.matmul(a, b)            - ((1st axis of a), (0th axis of b))
    16.0    tf.reduce_sum(tf.multiply(a, b))    - ((the last 2 axes of a), (the first 2 axes of b))
    16.0    tf.einsum('ij,ij->', a, b)          - ((0th axis of a, 1st axis of a), (0th axis of b, 1st axis of b))
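np.tensordot accepts the same integer shorthand for `axes`, so the (2, 2) cases can be cross-checked in NumPy as well. A minimal sketch using the same values as above; the variable names `outer`, `matmul`, and `full` are illustrative.

```python
import numpy as np

a = np.array([[1, 2], [-2, 3.]])
b = np.array([[-2, 3], [0, 4.]])

outer = np.tensordot(a, b, axes=0)   # no axes summed: shape (2, 2, 2, 2)
matmul = np.tensordot(a, b, axes=1)  # last axis of a with first axis of b
full = np.tensordot(a, b, axes=2)    # both axes summed: a single number

print(outer.shape)
print(matmul)
print(float(full))
```

With `axes=N`, the last N axes of the first array are paired, in order, with the first N axes of the second.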