How do I get numpy.einsum to play well with sympy?
Einsum basically supersedes tensordot (but not dot, because dot normally uses optimized linear algebra packages); code-wise it is completely different.
Here is an object einsum. It's untested (for more complex things), but I think it should work. Doing the same thing in C is probably even simpler, since you can steal everything but the loop itself from the real einsum function. So if you feel like it, implement it and make more people happy.
https://gist.github.com/seberg/5236560
I will not guarantee anything, especially not for weirder corner cases. I am sure you can also translate einsum notation to tensordot notation, and that is probably a bit faster since the loops would end up being mostly in C.
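To make the einsum-to-tensordot translation concrete, here is a small sketch with plain float arrays. The array names and shapes are made up for illustration. Each index shared by the two operands becomes one pair of axis positions in the `axes` argument.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.random((2, 3))
B = rng.random((3, 4))

# 'ij,jk->ik': the shared index j is axis 1 of A and axis 0 of B.
via_einsum = np.einsum('ij,jk->ik', A, B)
via_tensordot = np.tensordot(A, B, axes=([1], [0]))

# 'ijk,jkl->il': j and k are axes (1, 2) of T and axes (0, 1) of U.
T = rng.random((2, 3, 4))
U = rng.random((3, 4, 5))
via_einsum2 = np.einsum('ijk,jkl->il', T, U)
via_tensordot2 = np.tensordot(T, U, axes=([1, 2], [0, 1]))
```

In both cases the uncontracted axes of the first operand come first in the tensordot result, followed by those of the second operand, which matches the output order requested in the einsum strings.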
Here is a much simpler implementation that splits the `einsum` into multiple `tensordot` calls.
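    from functools import reduce

    import numpy as np


    def einsum(string, *args):
        # One list of index letters per operand, e.g. 'ij,jk' -> [['i', 'j'], ['j', 'k']].
        index_groups = [list(group) for group in string.split(',')]
        assert len(index_groups) == len(args)
        tensor_indices_tuples = zip(index_groups, args)
        # Fold the operands pairwise; keep only the tensor of the final (indices, tensor) pair.
        return reduce(einsum_for_two, tensor_indices_tuples)[1]


    def einsum_for_two(tensor_indices1, tensor_indices2):
        string1, tensor1 = tensor_indices1
        string2, tensor2 = tensor_indices2
        # Indices present in both operands are summed over.
        sum_over_indices = set(string1).intersection(set(string2))
        new_string = string1 + string2
        axes = ([], [])
        for i in sum_over_indices:
            # Drop the shared index from both halves of the combined signature.
            new_string.remove(i)
            new_string.remove(i)
            axes[0].append(string1.index(i))
            axes[1].append(string2.index(i))
        return new_string, np.tensordot(tensor1, tensor2, axes)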
First it separates the `einsum` arguments into tuples of (indices, tensor). Then it reduces the list as follows:
- Takes the first two tuples and evaluates a simple `einsum_for_two` on them. It also returns the new index signature.
- The value of `einsum_for_two` is used with the next tuple in the list as the new arguments for `einsum_for_two`.
- Continues until there is only one tuple left. The index signature is discarded and only the tensor is returned.
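The reduction steps above can be traced on the index signatures alone, without any tensors. This is only an illustrative sketch: `merge_signatures` and `traced` are hypothetical helpers that mimic the index bookkeeping and are not part of the code above.

```python
from functools import reduce


def merge_signatures(sig1, sig2):
    # Indices shared by both operands are summed over and disappear.
    shared = set(sig1) & set(sig2)
    return [i for i in sig1 + sig2 if i not in shared]


steps = []


def traced(sig1, sig2):
    # Record every pairwise step so the fold can be inspected afterwards.
    out = merge_signatures(sig1, sig2)
    steps.append(("".join(sig1), "".join(sig2), "".join(out)))
    return out


signatures = [list(s) for s in "ij,jk,kl".split(",")]
final = reduce(traced, signatures)
# steps is now [('ij', 'jk', 'ik'), ('ik', 'kl', 'il')]; final is ['i', 'l'].
```

For `'ij,jk,kl'` the fold takes two steps: `ij` and `jk` combine to `ik`, which then combines with `kl` to give `il`.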
It is probably slow (but you are using object `dtype` anyway). It does not do many correctness checks on the input.
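As a quick check that `np.tensordot` handles object arrays, here is a sketch that uses `fractions.Fraction` from the standard library as a stand-in for sympy expressions. The assumption is that any element type supporting `+` and `*` behaves the same way; the sympy case itself is not tested here.

```python
from fractions import Fraction

import numpy as np

# Object arrays holding exact rationals instead of floats.
a = np.array([[Fraction(1, 2), Fraction(1, 3)],
              [Fraction(1, 4), Fraction(1, 5)]], dtype=object)
b = np.array([[Fraction(2), Fraction(3)],
              [Fraction(4), Fraction(5)]], dtype=object)

# Contract axis 1 of a with axis 0 of b, i.e. 'ij,jk' in einsum notation.
c = np.tensordot(a, b, axes=([1], [0]))
# c stays an object array; c[0, 0] is Fraction(7, 3), computed exactly.
```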
As @seberg noted, my code does not work for traces of tensors.
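One possible way to cover traces is to collapse an index that appears twice within a single operand before the pairwise contraction starts. The helper below is a hypothetical sketch, not part of the original code; it assumes each repeated index occurs exactly twice in the operand and that it should be summed, as in `'ii'` or `'iij'`.

```python
import numpy as np


def collapse_repeated(indices, tensor):
    """Sum over every index that occurs exactly twice in one operand."""
    indices = list(indices)
    repeated = [i for i in dict.fromkeys(indices) if indices.count(i) == 2]
    for i in repeated:
        first = indices.index(i)
        second = indices.index(i, first + 1)
        # np.trace sums the diagonal over the two axes and removes both.
        tensor = np.trace(tensor, axis1=first, axis2=second)
        indices = [j for j in indices if j != i]
    return indices, tensor


matrix = np.arange(9).reshape(3, 3)
trace_indices, trace_value = collapse_repeated('ii', matrix)

cube = np.arange(18).reshape(3, 3, 2)
partial_indices, partial_trace = collapse_repeated('iij', cube)
```

Each (indices, tensor) tuple could be passed through such a helper before the reduction, so that the pairwise step only ever sees indices that occur once per operand.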