numerically stable way to multiply log probability matrices in numpy


logsumexp works by evaluating the right-hand side of the equation

log(∑ exp[a]) = max(a) + log(∑ exp[a - max(a)])

I.e., it pulls out the max before starting to sum, to prevent overflow in exp. The same can be applied before doing vector dot products:
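As a concrete illustration (values chosen here purely for demonstration), the naive evaluation overflows where the shifted one does not:

```python
import numpy as np

a = np.array([1000.0, 1000.0])

# Naive: exp(1000) overflows to inf, so the whole result is inf.
with np.errstate(over='ignore'):
    naive = np.log(np.sum(np.exp(a)))

# Shifted: pull out max(a) first, so exp operates on a - max(a) = zeros.
shifted = np.max(a) + np.log(np.sum(np.exp(a - np.max(a))))

print(naive)    # inf
print(shifted)  # 1000 + log(2)
```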

log(exp[a] ⋅ exp[b]) = log(∑ exp[a] × exp[b]) = log(∑ exp[a + b]) = max(a + b) + log(∑ exp[a + b - max(a + b)])     { this is logsumexp(a + b) }

but by taking a different turn in the derivation, we obtain

log(∑ exp[a] × exp[b]) = max(a) + max(b) + log(∑ exp[a - max(a)] × exp[b - max(b)]) = max(a) + max(b) + log(exp[a - max(a)] ⋅ exp[b - max(b)])

The final form has a vector dot product in its innards. It also extends readily to matrix multiplication, so we get the algorithm

def logdotexp(A, B):
    max_A = np.max(A)
    max_B = np.max(B)
    C = np.dot(np.exp(A - max_A), np.exp(B - max_B))
    np.log(C, out=C)
    C += max_A + max_B
    return C

This creates two A-sized temporaries and two B-sized ones, but one of each can be eliminated by

exp_A = A - max_A
np.exp(exp_A, out=exp_A)

and similarly for B. (If the input matrices may be modified by the function, all the temporaries can be eliminated.)
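Putting this together, the temporaries-reduced version might look like the following sketch (assuming the inputs are float arrays that must not be modified, so one shifted copy of each is still made):

```python
import numpy as np

def logdotexp(A, B):
    # One A-sized and one B-sized temporary; exp is applied in place
    # on the shifted copies instead of allocating a second array.
    max_A = np.max(A)
    max_B = np.max(B)
    exp_A = A - max_A            # the only A-sized temporary
    np.exp(exp_A, out=exp_A)     # exponentiate in place
    exp_B = B - max_B            # the only B-sized temporary
    np.exp(exp_B, out=exp_B)
    C = np.dot(exp_A, exp_B)
    np.log(C, out=C)
    C += max_A + max_B
    return C
```

A quick check in linear space: if A and B are the element-wise logs of two small positive matrices, `np.exp(logdotexp(A, B))` should recover their ordinary matrix product.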


Suppose A.shape==(n,r) and B.shape==(r,m). Computing the matrix product C=A*B actually involves n*m summations. To get stable results in log-space, you need the logsumexp trick in each of these summations. Fortunately, with numpy broadcasting it is easy to control the stability of the rows of A and the columns of B separately.

Here is the code:

def logdotexp(A, B):
    max_A = np.max(A, 1, keepdims=True)
    max_B = np.max(B, 0, keepdims=True)
    C = np.dot(np.exp(A - max_A), np.exp(B - max_B))
    np.log(C, out=C)
    C += max_A + max_B
    return C

Note:

The reasoning behind this is similar to Fred Foo's answer, but that answer uses a single maximum value for each matrix. Since it does not consider all n*m summations, some elements of the final matrix may still be unstable, as mentioned in one of the comments.

Comparing with the currently accepted answer using @identity-m's counterexample:

def logdotexp_less_stable(A, B):
    max_A = np.max(A)
    max_B = np.max(B)
    C = np.dot(np.exp(A - max_A), np.exp(B - max_B))
    np.log(C, out=C)
    C += max_A + max_B
    return C

print('old method:')
print(logdotexp_less_stable([[0,0],[0,0]], [[-1000,0], [-1000,0]]))
print('new method:')
print(logdotexp([[0,0],[0,0]], [[-1000,0], [-1000,0]]))

which prints

old method:
[[      -inf 0.69314718]
 [      -inf 0.69314718]]
new method:
[[-9.99306853e+02  6.93147181e-01]
 [-9.99306853e+02  6.93147181e-01]]
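Beyond this single counterexample, one way to sanity-check the row/column-shifted version is to compare it against an element-wise reference built with np.logaddexp.reduce, which evaluates each of the n*m logsumexp reductions stably on its own (the random test matrices below are an assumption for illustration):

```python
import numpy as np

def logdotexp(A, B):
    # Row/column-shifted version from the answer above.
    A = np.asarray(A, dtype=float)
    B = np.asarray(B, dtype=float)
    max_A = np.max(A, 1, keepdims=True)
    max_B = np.max(B, 0, keepdims=True)
    C = np.dot(np.exp(A - max_A), np.exp(B - max_B))
    np.log(C, out=C)
    C += max_A + max_B
    return C

def logdotexp_reference(A, B):
    # Element-wise reference: for each (i, j), a stable logsumexp
    # over a[i, :] + b[:, j] via np.logaddexp.reduce.
    A = np.asarray(A, dtype=float)
    B = np.asarray(B, dtype=float)
    return np.logaddexp.reduce(A[:, :, None] + B[None, :, :], axis=1)

rng = np.random.default_rng(0)
A = rng.normal(scale=30, size=(3, 4))
B = rng.normal(scale=30, size=(4, 5))
assert np.allclose(logdotexp(A, B), logdotexp_reference(A, B))
```

The reference is O(n*r*m) in memory because it materializes the full broadcast sum, so it is only practical for small test matrices.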


You are accessing columns of res and b, which gives poor locality of reference. One thing to try is to store these in column-major order.
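For instance (a hypothetical sketch, since the res and b arrays from the question are not shown here), numpy can store an array in column-major (Fortran) order with np.asfortranarray, so that column slices become contiguous in memory:

```python
import numpy as np

b = np.zeros((1000, 1000))     # default C order: rows are contiguous
b_f = np.asfortranarray(b)     # copy in Fortran order: columns are contiguous

# A column of the Fortran-ordered array is a contiguous view,
# whereas a column of the C-ordered array is strided.
assert b_f.flags['F_CONTIGUOUS']
assert b_f[:, 0].flags['C_CONTIGUOUS']
assert not b[:, 0].flags['C_CONTIGUOUS']
```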