Understanding tensordot
The idea with `tensordot` is pretty simple: we input the arrays and the respective axes along which the sum-reductions are intended. The axes that take part in the sum-reduction are removed in the output, and all of the remaining axes from the input arrays are spread out as separate axes in the output, keeping the order in which the input arrays are fed.
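The shape rule alone can be written down in a few lines. This is a minimal sketch, assuming non-negative axis indices passed as tuples; `tensordot_shape` is a hypothetical helper for illustration, not part of NumPy:

```python
def tensordot_shape(shape_a, shape_b, axes_a, axes_b):
    """Hypothetical helper: output shape of np.tensordot(a, b, axes=(axes_a, axes_b)).

    Assumes non-negative axis indices given as tuples/lists of ints.
    """
    if len(axes_a) != len(axes_b):
        raise ValueError("need the same number of axes from a and b")
    # Each pair of summed axes must have the same length
    for ia, ib in zip(axes_a, axes_b):
        if shape_a[ia] != shape_b[ib]:
            raise ValueError(f"a axis {ia} has length {shape_a[ia]}, b axis {ib} has length {shape_b[ib]}")
    # Summed axes disappear; the free axes of a come first, then those of b, each in original order
    free_a = [n for i, n in enumerate(shape_a) if i not in axes_a]
    free_b = [n for i, n in enumerate(shape_b) if i not in axes_b]
    return tuple(free_a + free_b)


print(tensordot_shape((2, 6, 5), (3, 2, 4), (0,), (1,)))      # (6, 5, 3, 4)
print(tensordot_shape((3, 2, 4), (2, 6, 5), (1,), (0,)))      # (3, 4, 6, 5)
print(tensordot_shape((2, 3, 5), (3, 2, 4), (0, 1), (1, 0)))  # (5, 4)
```

The three calls mirror the shapes used in the cases below.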
Let's look at a few sample cases with one and two axes of sum-reduction, and also swap the input places to see how the order is kept in the output.
I. One axis of sum-reduction
Inputs :
In [7]: A = np.random.randint(2, size=(2, 6, 5))
   ...: B = np.random.randint(2, size=(3, 2, 4))
   ...:
Case #1:
In [9]: np.tensordot(A, B, axes=((0),(1))).shape
Out[9]: (6, 5, 3, 4)

A : (2, 6, 5) -> reduction of axis=0
B : (3, 2, 4) -> reduction of axis=1

Output : `(2, 6, 5)`, `(3, 2, 4)` ===(2 gone)==> `(6,5)` + `(3,4)` => `(6,5,3,4)`
Case #2 (same as case #1 but the inputs are fed swapped):
In [8]: np.tensordot(B, A, axes=((1),(0))).shape
Out[8]: (3, 4, 6, 5)

B : (3, 2, 4) -> reduction of axis=1
A : (2, 6, 5) -> reduction of axis=0

Output : `(3, 2, 4)`, `(2, 6, 5)` ===(2 gone)==> `(3,4)` + `(6,5)` => `(3,4,6,5)`
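Case #2 is not a different computation, only a reordering of case #1's output axes. A quick NumPy check, sketched with `np.random.default_rng` and an arbitrary seed (both my choices, not the original session's):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.integers(2, size=(2, 6, 5))
B = rng.integers(2, size=(3, 2, 4))

r1 = np.tensordot(A, B, axes=(0, 1))   # shape (6, 5, 3, 4)
r2 = np.tensordot(B, A, axes=(1, 0))   # shape (3, 4, 6, 5)

# Swapping the inputs only moves B's free axes in front of A's free axes,
# so r2 is r1 with its (6, 5) block and (3, 4) block exchanged
print(np.array_equal(r2, r1.transpose(2, 3, 0, 1)))  # True
```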
II. Two axes of sum-reduction
Inputs :
In [11]: A = np.random.randint(2, size=(2, 3, 5))
    ...: B = np.random.randint(2, size=(3, 2, 4))
    ...:
Case #1:
In [12]: np.tensordot(A, B, axes=((0,1),(1,0))).shape
Out[12]: (5, 4)

A : (2, 3, 5) -> reduction of axis=(0,1)
B : (3, 2, 4) -> reduction of axis=(1,0)

Output : `(2, 3, 5)`, `(3, 2, 4)` ===(2,3 gone)==> `(5)` + `(4)` => `(5,4)`
Case #2:
In [14]: np.tensordot(B, A, axes=((1,0),(0,1))).shape
Out[14]: (4, 5)

B : (3, 2, 4) -> reduction of axis=(1,0)
A : (2, 3, 5) -> reduction of axis=(0,1)

Output : `(3, 2, 4)`, `(2, 3, 5)` ===(2,3 gone)==> `(4)` + `(5)` => `(4,5)`
This extends to any number of summed axes, as long as each pair of summed axes has the same length.
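To make the sum-reduction of case II #1 concrete, here is a deliberately slow sketch with explicit loops, compared against `np.tensordot`. The arrays are generated fresh with an arbitrary seed, so the values differ from the session above:

```python
import numpy as np

rng = np.random.default_rng(1)
A = rng.integers(2, size=(2, 3, 5))
B = rng.integers(2, size=(3, 2, 4))

# out[k, m] = sum over i, j of A[i, j, k] * B[j, i, m]
# (A's axis 0 pairs with B's axis 1, A's axis 1 pairs with B's axis 0)
out = np.zeros((5, 4), dtype=A.dtype)
for k in range(5):
    for m in range(4):
        for i in range(2):
            for j in range(3):
                out[k, m] += A[i, j, k] * B[j, i, m]

print(np.array_equal(out, np.tensordot(A, B, axes=((0, 1), (1, 0)))))  # True
```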
`tensordot` swaps (transposes) axes and reshapes the inputs so it can apply `np.dot` to two 2d arrays. It then reshapes the result back to the target shape. It may be easier to experiment than to explain. There's no special tensor math going on, just extending `dot` to work in higher dimensions; "tensor" here just means arrays with more than 2 dimensions. If you are already comfortable with `einsum`, then it is simplest to compare the results to that.
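The transpose/reshape/dot idea can be sketched directly. `my_tensordot` is a hypothetical name; this follows the same steps NumPy's implementation takes but is not its actual source, and it skips input validation (mismatched axis lengths, repeated axes) and the integer `axes=N` form:

```python
import numpy as np
from math import prod

def my_tensordot(a, b, axes_a, axes_b):
    """Hypothetical re-implementation of tensordot via transpose + reshape + np.dot."""
    axes_a = [ax % a.ndim for ax in axes_a]   # allow negative axis indices
    axes_b = [ax % b.ndim for ax in axes_b]
    free_a = [ax for ax in range(a.ndim) if ax not in axes_a]
    free_b = [ax for ax in range(b.ndim) if ax not in axes_b]

    n_free_a = prod(a.shape[ax] for ax in free_a)
    n_free_b = prod(b.shape[ax] for ax in free_b)
    n_sum = prod(a.shape[ax] for ax in axes_a)   # equals the product over axes_b

    # Move a's summed axes to the end and b's summed axes to the front (in matching
    # order), then flatten both to 2-D so a single np.dot does all the summing
    a2 = a.transpose(free_a + axes_a).reshape(n_free_a, n_sum)
    b2 = b.transpose(axes_b + free_b).reshape(n_sum, n_free_b)
    res = np.dot(a2, b2)

    # Only a reshape is needed at the end: rows already run over a's free axes
    # and columns over b's free axes, in their original order
    return res.reshape([a.shape[ax] for ax in free_a] + [b.shape[ax] for ax in free_b])


rng = np.random.default_rng(2)
A = rng.integers(2, size=(2, 3, 5))
B = rng.integers(2, size=(3, 2, 4))
print(np.array_equal(my_tensordot(A, B, [0], [1]), np.tensordot(A, B, [0, 1])))                  # True
print(np.array_equal(my_tensordot(A, B, [0, 1], [1, 0]), np.tensordot(A, B, [(0, 1), (1, 0)])))  # True
```

Passing empty axis lists gives an outer product, since `n_sum` is then 1.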
A sample test, summing over one pair of axes:
In [823]: np.tensordot(A,B,[0,1]).shape
Out[823]: (3, 5, 3, 4)

In [824]: np.einsum('ijk,lim',A,B).shape
Out[824]: (3, 5, 3, 4)

In [825]: np.allclose(np.einsum('ijk,lim',A,B),np.tensordot(A,B,[0,1]))
Out[825]: True
Another, summing over two pairs of axes:
In [826]: np.tensordot(A,B,[(0,1),(1,0)]).shape
Out[826]: (5, 4)

In [827]: np.einsum('ijk,jim',A,B).shape
Out[827]: (5, 4)

In [828]: np.allclose(np.einsum('ijk,jim',A,B),np.tensordot(A,B,[(0,1),(1,0)]))
Out[828]: True
We could do the same with the `(1,0)` pair. Given the mix of dimensions, I don't think there's another combination.
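A sketch of that `(1,0)` pairing, assuming `A` and `B` have the shapes `(2, 3, 5)` and `(3, 2, 4)` used in the second example above (the values here are random with an arbitrary seed):

```python
import numpy as np

rng = np.random.default_rng(3)
A = rng.integers(2, size=(2, 3, 5))
B = rng.integers(2, size=(3, 2, 4))

# Pair A's axis 1 (length 3) with B's axis 0 (length 3)
r = np.tensordot(A, B, [1, 0])
print(r.shape)                                        # (2, 5, 2, 4)
# Implicit einsum output is the unrepeated indices in alphabetical order: i, k, l, m
print(np.array_equal(r, np.einsum('ijk,jlm', A, B)))  # True
```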
The answers above are great and helped me a lot in understanding `tensordot`, but they don't show the actual math behind the operations. That's why I worked out the equivalent operations in TF 2 for myself and decided to share them here:
import tensorflow as tf

a = tf.constant([1,2.])
b = tf.constant([2,3.])
print(f"{tf.tensordot(a, b, 0)}\t tf.einsum('i,j', a, b)\t\t- ((the last 0 axes of a), (the first 0 axes of b))")
print(f"{tf.tensordot(a, b, ((),()))}\t tf.einsum('i,j', a, b)\t\t- ((() axis of a), (() axis of b))")
print(f"{tf.tensordot(b, a, 0)}\t tf.einsum('i,j->ji', a, b)\t- ((the last 0 axes of b), (the first 0 axes of a))")
print(f"{tf.tensordot(a, b, 1)}\t\t tf.einsum('i,i', a, b)\t\t- ((the last 1 axes of a), (the first 1 axes of b))")
print(f"{tf.tensordot(a, b, ((0,), (0,)))}\t\t tf.einsum('i,i', a, b)\t\t- ((0th axis of a), (0th axis of b))")
print(f"{tf.tensordot(a, b, (0,0))}\t\t tf.einsum('i,i', a, b)\t\t- ((0th axis of a), (0th axis of b))")

[[2. 3.]
 [4. 6.]]   tf.einsum('i,j', a, b)       - ((the last 0 axes of a), (the first 0 axes of b))
[[2. 3.]
 [4. 6.]]   tf.einsum('i,j', a, b)       - ((() axis of a), (() axis of b))
[[2. 4.]
 [3. 6.]]   tf.einsum('i,j->ji', a, b)   - ((the last 0 axes of b), (the first 0 axes of a))
8.0         tf.einsum('i,i', a, b)       - ((the last 1 axes of a), (the first 1 axes of b))
8.0         tf.einsum('i,i', a, b)       - ((0th axis of a), (0th axis of b))
8.0         tf.einsum('i,i', a, b)       - ((0th axis of a), (0th axis of b))
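NumPy's `np.tensordot` uses the same convention for an integer `axes=N`: sum over the last N axes of the first array and the first N axes of the second. A short sketch reproducing the 1d cases above in NumPy, plus a 3-D check of the "last N / first N" rule (the `x` and `y` arrays are made up for illustration):

```python
import numpy as np

a = np.array([1, 2.])
b = np.array([2, 3.])

# axes=0: nothing is summed, so every element of a multiplies every element of b (outer product)
print(np.array_equal(np.tensordot(a, b, 0), np.outer(a, b)))    # True
# Swapping the inputs transposes the outer product, like einsum('i,j->ji', a, b)
print(np.array_equal(np.tensordot(b, a, 0), np.outer(a, b).T))  # True
# axes=1: the last axis of a is summed against the first axis of b (inner product)
print(float(np.tensordot(a, b, 1)))                             # 8.0

# The same rule in 3-D: axes=2 contracts x's last two axes (3, 4) with y's first two (3, 4)
x = np.arange(24.).reshape(2, 3, 4)
y = np.arange(60.).reshape(3, 4, 5)
print(np.allclose(np.tensordot(x, y, 2), np.einsum('ijk,jkl->il', x, y)))  # True
```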
And for the `(2,2)` shape:
import tensorflow as tf

a = tf.constant([[1,2], [-2,3.]])
b = tf.constant([[-2,3], [0,4.]])
print(f"{tf.tensordot(a, b, 0)}\t tf.einsum('ij,kl', a, b)\t- ((the last 0 axes of a), (the first 0 axes of b))")
print(f"{tf.tensordot(a, b, (0,0))}\t tf.einsum('ij,ik', a, b)\t- ((0th axis of a), (0th axis of b))")
print(f"{tf.tensordot(a, b, (0,1))}\t tf.einsum('ij,ki', a, b)\t- ((0th axis of a), (1st axis of b))")
print(f"{tf.tensordot(a, b, 1)}\t tf.matmul(a, b)\t\t- ((the last 1 axes of a), (the first 1 axes of b))")
print(f"{tf.tensordot(a, b, ((1,), (0,)))}\t tf.einsum('ij,jk', a, b)\t- ((1st axis of a), (0th axis of b))")
print(f"{tf.tensordot(a, b, (1, 0))}\t tf.matmul(a, b)\t\t- ((1st axis of a), (0th axis of b))")
print(f"{tf.tensordot(a, b, 2)}\t tf.reduce_sum(tf.multiply(a, b))\t- ((the last 2 axes of a), (the first 2 axes of b))")
print(f"{tf.tensordot(a, b, ((0,1), (0,1)))}\t tf.einsum('ij,ij->', a, b)\t\t- ((0th axis of a, 1st axis of a), (0th axis of b, 1st axis of b))")

[[[[-2. 3.]
   [ 0. 4.]]
  [[-4. 6.]
   [ 0. 8.]]]
 [[[ 4. -6.]
   [-0. -8.]]
  [[-6. 9.]
   [ 0. 12.]]]]   tf.einsum('ij,kl', a, b)   - ((the last 0 axes of a), (the first 0 axes of b))
[[-2. -5.]
 [-4. 18.]]       tf.einsum('ij,ik', a, b)   - ((0th axis of a), (0th axis of b))
[[-8. -8.]
 [ 5. 12.]]       tf.einsum('ij,ki', a, b)   - ((0th axis of a), (1st axis of b))
[[-2. 11.]
 [ 4. 6.]]        tf.matmul(a, b)            - ((the last 1 axes of a), (the first 1 axes of b))
[[-2. 11.]
 [ 4. 6.]]        tf.einsum('ij,jk', a, b)   - ((1st axis of a), (0th axis of b))
[[-2. 11.]
 [ 4. 6.]]        tf.matmul(a, b)            - ((1st axis of a), (0th axis of b))
16.0              tf.reduce_sum(tf.multiply(a, b))   - ((the last 2 axes of a), (the first 2 axes of b))
16.0              tf.einsum('ij,ij->', a, b)         - ((0th axis of a, 1st axis of a), (0th axis of b, 1st axis of b))
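The 2x2 results can also be read as plain matrix algebra. A NumPy sketch with the same `a` and `b` values, checking each contraction against an equivalent `@` expression and against the numbers printed above (the `@` rewrites are my own framing, not part of the TF comparison):

```python
import numpy as np

a = np.array([[1, 2], [-2, 3.]])
b = np.array([[-2, 3], [0, 4.]])

# axes (0, 0): sum_i a[i, j] * b[i, k], which is a.T @ b
r00 = np.tensordot(a, b, (0, 0))
print(np.array_equal(r00, a.T @ b), np.array_equal(r00, [[-2, -5], [-4, 18]]))  # True True

# axes (0, 1): sum_i a[i, j] * b[k, i], which is a.T @ b.T
r01 = np.tensordot(a, b, (0, 1))
print(np.array_equal(r01, a.T @ b.T), np.array_equal(r01, [[-8, -8], [5, 12]]))  # True True

# axes=1 (equivalently (1, 0)): an ordinary matrix product
r1 = np.tensordot(a, b, 1)
print(np.array_equal(r1, a @ b), np.array_equal(r1, [[-2, 11], [4, 6]]))  # True True

# axes=2: multiply elementwise and sum everything
print(float(np.tensordot(a, b, 2)), float((a * b).sum()))  # 16.0 16.0
```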