numerically stable way to multiply log probability matrices in numpy
logsumexp works by evaluating the right-hand side of the equation
log(∑ exp[a]) = max(a) + log(∑ exp[a - max(a)])
I.e., it pulls out the max before starting to sum, to prevent overflow in exp. The same can be applied before doing vector dot products:
log(exp[a] ⋅ exp[b]) = log(∑ exp[a] × exp[b])
                     = log(∑ exp[a + b])
                     = max(a + b) + log(∑ exp[a + b - max(a + b)])
which is just logsumexp(a + b),
but by taking a different turn in the derivation, we obtain
log(∑ exp[a] × exp[b]) = max(a) + max(b) + log(∑ exp[a - max(a)] × exp[b - max(b)])
                       = max(a) + max(b) + log(exp[a - max(a)] ⋅ exp[b - max(b)])
The final form has a vector dot product in its innards. It also extends readily to matrix multiplication, so we get the algorithm
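    import numpy as np

    def logdotexp(A, B):
        # Shift each matrix by its global maximum so that exp() cannot overflow.
        max_A = np.max(A)
        max_B = np.max(B)
        C = np.dot(np.exp(A - max_A), np.exp(B - max_B))
        np.log(C, out=C)      # take the log in place
        C += max_A + max_B    # undo the shift
        return C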
This creates two A-sized temporaries and two B-sized ones, but one of each can be eliminated by
    exp_A = A - max_A
    np.exp(exp_A, out=exp_A)
and similarly for B. (If the input matrices may be modified by the function, all the temporaries can be eliminated.)
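The temporary-saving variant described above might look like the following minimal sketch. The name `logdotexp_fewer_temps` and the sample matrices are made up for illustration, and the inputs are assumed to be convertible to float arrays.

```python
import numpy as np

def logdotexp_fewer_temps(A, B):
    """log(exp(A) @ exp(B)) with one scalar max per matrix and in-place exp."""
    A = np.asarray(A, dtype=float)
    B = np.asarray(B, dtype=float)
    max_A = np.max(A)
    max_B = np.max(B)

    exp_A = A - max_A            # the only A-sized temporary
    np.exp(exp_A, out=exp_A)     # exponentiate in place
    exp_B = B - max_B            # the only B-sized temporary
    np.exp(exp_B, out=exp_B)

    C = np.dot(exp_A, exp_B)
    np.log(C, out=C)             # reuse C for the result
    C += max_A + max_B           # undo the shift
    return C

# Entries near 1000 would overflow a naive exp(); shifting by the max avoids that.
A = np.array([[1000.0, 1001.0], [1002.0, 1000.5]])
B = np.array([[1000.0, 999.0], [998.0, 1001.0]])
print(logdotexp_fewer_temps(A, B))
```

The inputs are left untouched here. Overwriting A and B directly would remove the remaining two temporaries, at the cost of destroying the caller's data.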
Suppose A.shape==(n,r) and B.shape==(r,m). In computing the matrix product C=A*B, there are actually n*m summations. To get stable results when you're working in log-space, you need the logsumexp trick in each of these summations. Fortunately, with numpy broadcasting it's quite easy to control the stability of the rows of A and the columns of B separately.
Here is the code:
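    import numpy as np

    def logdotexp(A, B):
        max_A = np.max(A, 1, keepdims=True)   # one maximum per row of A, shape (n, 1)
        max_B = np.max(B, 0, keepdims=True)   # one maximum per column of B, shape (1, m)
        C = np.dot(np.exp(A - max_A), np.exp(B - max_B))
        np.log(C, out=C)
        C += max_A + max_B                    # broadcasts to shape (n, m)
        return C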
Note: The reasoning behind this is similar to FredFoo's answer, but he used a single maximum value for each matrix. Since that does not treat each of the n*m summations separately, some elements of the final matrix might still be unstable, as mentioned in one of the comments.
Comparing with the currently accepted answer, using @identity-m's counterexample:
    def logdotexp_less_stable(A, B):
        max_A = np.max(A)
        max_B = np.max(B)
        C = np.dot(np.exp(A - max_A), np.exp(B - max_B))
        np.log(C, out=C)
        C += max_A + max_B
        return C

    print('old method:')
    print(logdotexp_less_stable([[0,0],[0,0]], [[-1000,0], [-1000,0]]))
    print('new method:')
    print(logdotexp([[0,0],[0,0]], [[-1000,0], [-1000,0]]))
which prints
    old method:
    [[ -inf 0.69314718]
     [ -inf 0.69314718]]
    new method:
    [[-9.99306853e+02 6.93147181e-01]
     [-9.99306853e+02 6.93147181e-01]]
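For checking results like these, a brute-force reference can be handy. The sketch below is not from either answer: it uses `np.logaddexp.reduce` to do a separate stable log-sum-exp for every output entry. The name `logdotexp_reference` is made up for illustration, and the intermediate array needs O(n*r*m) memory, so it is probably only practical for small matrices or tests.

```python
import numpy as np

def logdotexp_reference(A, B):
    """Brute-force log(exp(A) @ exp(B)): one stable log-sum-exp per output entry."""
    A = np.asarray(A, dtype=float)
    B = np.asarray(B, dtype=float)
    # S[i, k, j] = A[i, k] + B[k, j]; this uses O(n * r * m) memory.
    S = A[:, :, None] + B[None, :, :]
    # np.logaddexp.reduce folds log(exp(x) + exp(y)) along the shared axis k.
    return np.logaddexp.reduce(S, axis=1)

# Same inputs as the comparison above.
print(logdotexp_reference([[0, 0], [0, 0]], [[-1000, 0], [-1000, 0]]))

# A row and a column whose maxima sit at different positions k.
print(logdotexp_reference([[0, -1000]], [[-1000], [0]]))
```

The second call is a case where shifting by row and column maxima alone is not enough. Every product in the dot product contains a factor exp(-1000), which underflows to zero in double precision. The pairwise version still returns a finite value, log(2) - 1000.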
You are accessing columns of res and b, which has poor locality of reference. One thing to try is to store these in column-major order.
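Assuming res and b are NumPy arrays (the code that creates them is not shown here), switching to column-major storage could be sketched as follows. The array below is a hypothetical stand-in for b.

```python
import numpy as np

# Hypothetical stand-in for `b`; the original array is not shown.
b = np.arange(12, dtype=float).reshape(3, 4)   # row-major (C order) by default
b_f = np.asfortranarray(b)                     # same values, column-major copy

# A column of a C-ordered array is strided; in Fortran order it is contiguous.
print(b[:, 0].flags['C_CONTIGUOUS'])    # False
print(b_f[:, 0].flags['C_CONTIGUOUS'])  # True
print(np.array_equal(b, b_f))           # True
```

Whether this actually helps depends on the access pattern and the array sizes, so it is worth timing both layouts.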