Understanding PyTorch einsum


Since the description of einsum is skimpy in the torch documentation, I decided to write this post to document, compare, and contrast how torch.einsum() behaves compared to numpy.einsum().

Differences:

  • NumPy allows both lowercase and uppercase letters [a-zA-Z] in the "subscript string", whereas PyTorch allows only lowercase letters [a-z] (at the time of writing; newer PyTorch releases also accept [a-zA-Z]).

  • NumPy accepts nd-arrays, plain Python lists or tuples, nested lists/tuples (in any combination), or even PyTorch tensors as operands (i.e. inputs), because the operands only have to be array_like, not strictly NumPy nd-arrays. PyTorch, on the contrary, expects the operands strictly to be PyTorch tensors and throws a TypeError if you pass plain Python lists/tuples (or nested combinations of them) or NumPy nd-arrays.

  • NumPy supports a lot of keyword arguments (e.g. optimize) in addition to the nd-arrays, while PyTorch doesn't offer such flexibility yet.
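To make the operand-type difference concrete, here is a small sketch (assuming recent NumPy and PyTorch installs; the exact error message varies between versions) that feeds the same kind of input to both libraries:

```python
import numpy as np
import torch

# NumPy: operands only need to be array_like, so nested Python lists are fine.
np_result = np.einsum("ij,jk->ik", [[1, 2], [3, 4]], [[1, 0], [0, 1]])
print(np_result)

# PyTorch: operands must be tensors, so a NumPy array is rejected.
arr = np.array([1, 2, 3], dtype=np.int64)
try:
    torch.einsum("i,i->", arr, arr)
    torch_rejected = False
except TypeError as err:
    torch_rejected = True
    print("TypeError:", err)

# Converting the operands to tensors first makes the same call work.
t = torch.as_tensor(arr)
print(torch.einsum("i,i->", t, t))  # tensor(14)
```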

Here are the implementations of some examples both in PyTorch and NumPy:

# input tensors to work with
In [16]: vec
Out[16]: tensor([0, 1, 2, 3])
In [17]: aten
Out[17]:
tensor([[11, 12, 13, 14],
        [21, 22, 23, 24],
        [31, 32, 33, 34],
        [41, 42, 43, 44]])
In [18]: bten
Out[18]:
tensor([[1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3],
        [4, 4, 4, 4]])
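The session only shows the values of vec, aten and bten, not how they were created. One way to reproduce them (my own construction, not necessarily the original one) is:

```python
import torch

vec = torch.arange(4)                      # tensor([0, 1, 2, 3])
rows = torch.arange(1, 5).unsqueeze(1)     # column vector [[1], [2], [3], [4]]
aten = rows * 10 + torch.arange(1, 5)      # entry (i, j) is 10*i + j, for i, j in 1..4
bten = rows.expand(4, 4).clone()           # row i is filled with the value i
```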

1) Matrix multiplication
PyTorch: torch.matmul(aten, bten); aten.mm(bten)
NumPy einsum: np.einsum("ij, jk -> ik", arr1, arr2)

In [19]: torch.einsum('ij, jk -> ik', aten, bten)
Out[19]:
tensor([[130, 130, 130, 130],
        [230, 230, 230, 230],
        [330, 330, 330, 330],
        [430, 430, 430, 430]])
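To see why "ij, jk -> ik" is a matrix product, here is a sketch that spells out the summation over the shared index j with plain loops and checks it against torch.matmul:

```python
import torch

aten = torch.arange(1, 5).unsqueeze(1) * 10 + torch.arange(1, 5)
bten = torch.arange(1, 5).unsqueeze(1).expand(4, 4)

# "j" appears in both inputs but not in the output, so it is summed over:
# out[i, k] = sum_j aten[i, j] * bten[j, k]
via_einsum = torch.einsum("ij,jk->ik", aten, bten)

# Explicit loop version of the same contraction, for comparison only.
via_loops = torch.zeros(4, 4, dtype=aten.dtype)
for i in range(4):
    for k in range(4):
        via_loops[i, k] = sum(aten[i, j] * bten[j, k] for j in range(4))

assert torch.equal(via_einsum, torch.matmul(aten, bten))
assert torch.equal(via_einsum, via_loops)
```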

2) Extract elements along the main-diagonal
PyTorch: torch.diag(aten)
NumPy einsum: np.einsum("ii -> i", arr)

In [28]: torch.einsum('ii -> i', aten)
Out[28]: tensor([11, 22, 33, 44])
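Repeating a subscript within a single operand picks out the diagonal. A short sketch (the batched variant is my own extension of the example) compares it with torch.diag and torch.diagonal:

```python
import torch

aten = torch.arange(1, 5).unsqueeze(1) * 10 + torch.arange(1, 5)

# A subscript repeated within one operand ("ii") selects the diagonal.
diag = torch.einsum("ii->i", aten)
assert torch.equal(diag, torch.diag(aten))

# The same idea extends to a batch of matrices: one diagonal per batch entry.
batch = torch.stack((aten, aten.T))
batch_diag = torch.einsum("bii->bi", batch)
assert torch.equal(batch_diag, torch.diagonal(batch, dim1=1, dim2=2))
print(batch_diag)  # both rows are [11, 22, 33, 44]
```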

3) Hadamard product (i.e. element-wise product of two tensors)
PyTorch: aten * bten
NumPy einsum: np.einsum("ij, ij -> ij", arr1, arr2)

In [34]: torch.einsum('ij, ij -> ij', aten, bten)
Out[34]:
tensor([[ 11,  12,  13,  14],
        [ 42,  44,  46,  48],
        [ 93,  96,  99, 102],
        [164, 168, 172, 176]])

4) Element-wise squaring
PyTorch: aten ** 2
NumPy einsum: np.einsum("ij, ij -> ij", arr, arr)

In [37]: torch.einsum('ij, ij -> ij', aten, aten)
Out[37]:
tensor([[ 121,  144,  169,  196],
        [ 441,  484,  529,  576],
        [ 961, 1024, 1089, 1156],
        [1681, 1764, 1849, 1936]])

General: element-wise nth power can be implemented by repeating the subscript string and the tensor n times. For example, the element-wise 4th power of a tensor can be computed with:

# NumPy: np.einsum('ij, ij, ij, ij -> ij', arr, arr, arr, arr)
In [38]: torch.einsum('ij, ij, ij, ij -> ij', aten, aten, aten, aten)
Out[38]:
tensor([[  14641,   20736,   28561,   38416],
        [ 194481,  234256,  279841,  331776],
        [ 923521, 1048576, 1185921, 1336336],
        [2825761, 3111696, 3418801, 3748096]])
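The "repeat n times" recipe can be generated programmatically. einsum_power below is a hypothetical helper written for illustration; in real code, aten ** n is simpler and cheaper.

```python
import torch

def einsum_power(t: torch.Tensor, n: int) -> torch.Tensor:
    """Element-wise n-th power of a 2D tensor via einsum (illustrative helper)."""
    # Repeat the same subscripts n times: "ij,ij,...,ij->ij"
    equation = ",".join(["ij"] * n) + "->ij"
    return torch.einsum(equation, *([t] * n))

aten = torch.arange(1, 5).unsqueeze(1) * 10 + torch.arange(1, 5)
assert torch.equal(einsum_power(aten, 4), aten ** 4)
print(einsum_power(aten, 4)[0])  # tensor([14641, 20736, 28561, 38416])
```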

5) Trace (i.e. sum of main-diagonal elements)
PyTorch: torch.trace(aten)
NumPy einsum: np.einsum("ii -> ", arr)

In [44]: torch.einsum('ii -> ', aten)
Out[44]: tensor(110)
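The trace pattern combines naturally with the matrix product: contracting "ij" with "ji" and leaving the output empty gives tr(AB) in a single equation. A sketch, rebuilding the same aten and bten:

```python
import torch

aten = torch.arange(1, 5).unsqueeze(1) * 10 + torch.arange(1, 5)
bten = torch.arange(1, 5).unsqueeze(1).expand(4, 4)

# "ii->" sums the diagonal, i.e. the trace.
assert torch.einsum("ii->", aten) == torch.trace(aten)

# Combining both ideas: tr(A @ B) = sum_{i,j} A[i, j] * B[j, i]
tr_ab = torch.einsum("ij,ji->", aten, bten)
assert tr_ab == torch.trace(aten @ bten)
print(tr_ab)  # tensor(1120)
```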

6) Matrix transpose
PyTorch: torch.transpose(aten, 1, 0)
NumPy einsum: np.einsum("ij -> ji", arr)

In [58]: torch.einsum('ij -> ji', aten)
Out[58]:
tensor([[11, 21, 31, 41],
        [12, 22, 32, 42],
        [13, 23, 33, 43],
        [14, 24, 34, 44]])
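Transposition is just a reordering of the output subscripts, so the same notation permutes the axes of higher-rank tensors. A sketch comparing it with permute:

```python
import torch

aten = torch.arange(1, 5).unsqueeze(1) * 10 + torch.arange(1, 5)
assert torch.equal(torch.einsum("ij->ji", aten), aten.T)

# Reordering the output subscripts of a 3D tensor permutes its axes.
x = torch.arange(2 * 3 * 4).reshape(2, 3, 4)
y = torch.einsum("ijk->kji", x)
assert y.shape == (4, 3, 2)
assert torch.equal(y, x.permute(2, 1, 0))
```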

7) Outer Product (of vectors)
PyTorch: torch.ger(vec, vec) (deprecated in newer releases in favor of torch.outer(vec, vec))
NumPy einsum: np.einsum("i, j -> ij", vec, vec)

In [73]: torch.einsum('i, j -> ij', vec, vec)
Out[73]:
tensor([[0, 0, 0, 0],
        [0, 1, 2, 3],
        [0, 2, 4, 6],
        [0, 3, 6, 9]])
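A quick cross-check of the outer product against torch.outer (the name newer PyTorch releases recommend over torch.ger) and against plain broadcasting:

```python
import torch

vec = torch.arange(4)

outer = torch.einsum("i,j->ij", vec, vec)
assert torch.equal(outer, torch.outer(vec, vec))
# Equivalent broadcasting formulation: column vector times row vector.
assert torch.equal(outer, vec[:, None] * vec[None, :])
```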

8) Inner Product (of vectors)
PyTorch: torch.dot(vec1, vec2)
NumPy einsum: np.einsum("i, i -> ", vec1, vec2)

In [76]: torch.einsum('i, i -> ', vec, vec)
Out[76]: tensor(14)
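The inner product is the same pattern with both subscripts equal and an empty output. In this sketch the second vector is made up for illustration, and float tensors are used:

```python
import torch

vec = torch.arange(4, dtype=torch.float32)   # [0., 1., 2., 3.]
other = torch.tensor([1.0, 1.0, 2.0, 2.0])   # arbitrary second vector

# Shared subscript "i" and an empty output: multiply element-wise, then sum.
inner = torch.einsum("i,i->", vec, other)

assert torch.isclose(inner, torch.dot(vec, other))
assert torch.isclose(inner, (vec * other).sum())
print(inner)  # tensor(11.)
```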

9) Sum along axis 0
PyTorch: torch.sum(aten, 0)
NumPy einsum: np.einsum("ij -> j", arr)

In [85]: torch.einsum('ij -> j', aten)
Out[85]: tensor([104, 108, 112, 116])

10) Sum along axis 1
PyTorch: torch.sum(aten, 1)
NumPy einsum: np.einsum("ij -> i", arr)

In [86]: torch.einsum('ij -> i', aten)
Out[86]: tensor([ 50,  90, 130, 170])
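Both axis sums follow one rule: any subscript left out of the output is summed over. A sketch checking both against Tensor.sum:

```python
import torch

aten = torch.arange(1, 5).unsqueeze(1) * 10 + torch.arange(1, 5)

col_sums = torch.einsum("ij->j", aten)  # drop i: sum down each column (axis 0)
row_sums = torch.einsum("ij->i", aten)  # drop j: sum across each row (axis 1)

assert torch.equal(col_sums, aten.sum(dim=0))
assert torch.equal(row_sums, aten.sum(dim=1))
print(col_sums, row_sums)
```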

11) Batch Matrix Multiplication
PyTorch: torch.bmm(batch_tensor_1, batch_tensor_2)
NumPy einsum: np.einsum("bij, bjk -> bik", batch_tensor_1, batch_tensor_2)

# input batch tensors to work with
In [13]: batch_tensor_1 = torch.arange(2 * 4 * 3).reshape(2, 4, 3)
In [14]: batch_tensor_2 = torch.arange(2 * 3 * 4).reshape(2, 3, 4)
In [15]: torch.bmm(batch_tensor_1, batch_tensor_2)
Out[15]:
tensor([[[  20,   23,   26,   29],
         [  56,   68,   80,   92],
         [  92,  113,  134,  155],
         [ 128,  158,  188,  218]],

        [[ 632,  671,  710,  749],
         [ 776,  824,  872,  920],
         [ 920,  977, 1034, 1091],
         [1064, 1130, 1196, 1262]]])
# sanity check with the shapes
In [16]: torch.bmm(batch_tensor_1, batch_tensor_2).shape
Out[16]: torch.Size([2, 4, 4])
# batch matrix multiply using einsum
In [17]: torch.einsum("bij, bjk -> bik", batch_tensor_1, batch_tensor_2)
Out[17]:
tensor([[[  20,   23,   26,   29],
         [  56,   68,   80,   92],
         [  92,  113,  134,  155],
         [ 128,  158,  188,  218]],

        [[ 632,  671,  710,  749],
         [ 776,  824,  872,  920],
         [ 920,  977, 1034, 1091],
         [1064, 1130, 1196, 1262]]])
# sanity check with the shapes
In [18]: torch.einsum("bij, bjk -> bik", batch_tensor_1, batch_tensor_2).shape
Out[18]: torch.Size([2, 4, 4])
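In "bij, bjk -> bik" the batch subscript b is kept in the output, so it is not summed; only j is contracted. The sketch below rebuilds the result one batch entry at a time to show this:

```python
import torch

batch_tensor_1 = torch.arange(2 * 4 * 3).reshape(2, 4, 3)
batch_tensor_2 = torch.arange(2 * 3 * 4).reshape(2, 3, 4)

# "b" is in both inputs and the output, so it is kept (not summed);
# "j" is shared and dropped, so it is contracted inside each batch entry.
via_einsum = torch.einsum("bij,bjk->bik", batch_tensor_1, batch_tensor_2)

# Same result built one batch entry at a time with a plain matrix product.
via_loop = torch.stack([a @ b for a, b in zip(batch_tensor_1, batch_tensor_2)])

assert torch.equal(via_einsum, torch.bmm(batch_tensor_1, batch_tensor_2))
assert torch.equal(via_einsum, via_loop)
```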

12) Sum along axis 2
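PyTorch: torch.sum(batch_ten, 2)
NumPy einsum: np.einsum("ijk -> ij", arr3D)

Here batch_ten is a 3D tensor of shape (2, 4, 4); judging from the outputs below, it holds aten and bten stacked along a new first axis, i.e. torch.stack((aten, bten)).

In [99]: torch.einsum("ijk -> ij", batch_ten)
Out[99]:
tensor([[ 50,  90, 130, 170],
        [  4,   8,  12,  16]])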
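The session does not show how batch_ten was built. Assuming it is torch.stack((aten, bten)), which reproduces the printed sums, items 12 and 13 can be checked like this:

```python
import torch

aten = torch.arange(1, 5).unsqueeze(1) * 10 + torch.arange(1, 5)
bten = torch.arange(1, 5).unsqueeze(1).expand(4, 4)

# Assumption: batch_ten is aten and bten stacked along a new first axis.
batch_ten = torch.stack((aten, bten))  # shape (2, 4, 4)

axis2 = torch.einsum("ijk->ij", batch_ten)
assert torch.equal(axis2, batch_ten.sum(dim=2))
assert torch.einsum("ijk->", batch_ten) == batch_ten.sum()
print(axis2)  # rows: [50, 90, 130, 170] and [4, 8, 12, 16]
```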

13) Sum all the elements in an nD tensor
PyTorch: torch.sum(batch_ten)
NumPy einsum: np.einsum("ijk -> ", arr3D)

In [101]: torch.einsum("ijk -> ", batch_ten)
Out[101]: tensor(480)

14) Sum over multiple axes (i.e. marginalization)
PyTorch: torch.sum(nDten, dim=(0, 1, 2, 3, 4, 6, 7))
NumPy einsum: np.einsum("ijklmnop -> n", nDarr)

# 8D tensor
In [103]: nDten = torch.randn((3,5,4,6,8,2,7,9))
In [104]: nDten.shape
Out[104]: torch.Size([3, 5, 4, 6, 8, 2, 7, 9])
# keep dimension 5 (i.e. "n" here) and marginalize out all the other dimensions
In [111]: esum = torch.einsum("ijklmnop -> n", nDten)
In [112]: esum
Out[112]: tensor([  98.6921, -206.0575])
# the same marginalization with torch.sum over every axis except 5
In [113]: tsum = torch.sum(nDten, dim=(0, 1, 2, 3, 4, 6, 7))
In [115]: torch.allclose(tsum, esum)
Out[115]: True
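Writing out eight subscripts by hand is error-prone. marginal below is a hypothetical helper that builds the equation string from the tensor's rank; it assumes at most 26 dimensions, since it only uses lowercase letters. Float64 is used here so the comparison with torch.sum is not sensitive to summation order:

```python
import string
import torch

def marginal(t: torch.Tensor, keep: int) -> torch.Tensor:
    """Sum t over every axis except `keep` (hypothetical helper)."""
    if t.dim() > 26:
        raise ValueError("only up to 26 dimensions are supported")
    letters = string.ascii_lowercase[: t.dim()]  # e.g. "abcdefgh" for 8D
    equation = f"{letters}->{letters[keep]}"     # e.g. "abcdefgh->f"
    return torch.einsum(equation, t)

torch.manual_seed(0)
nDten = torch.randn(3, 5, 4, 6, 8, 2, 7, 9, dtype=torch.float64)

esum = marginal(nDten, keep=5)
other_dims = tuple(d for d in range(nDten.dim()) if d != 5)
tsum = nDten.sum(dim=other_dims)

assert esum.shape == (2,)
assert torch.allclose(esum, tsum)
```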

15) Double dot product / Frobenius inner product (the same as summing the Hadamard product; cf. 3)
PyTorch: torch.sum(aten * bten)
NumPy einsum: np.einsum("ij, ij -> ", arr1, arr2)

In [120]: torch.einsum("ij, ij -> ", aten, bten)
Out[120]: tensor(1300)
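The Frobenius inner product also equals tr(A^T B). A sketch checking the einsum result against both formulations:

```python
import torch

aten = torch.arange(1, 5).unsqueeze(1) * 10 + torch.arange(1, 5)
bten = torch.arange(1, 5).unsqueeze(1).expand(4, 4)

frob = torch.einsum("ij,ij->", aten, bten)

# Equivalent formulations of the Frobenius inner product <A, B>_F
assert frob == torch.sum(aten * bten)
assert frob == torch.trace(aten.T @ bten)
print(frob)  # tensor(1300)
```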