Numpy mean of flattened large array slower than mean of mean of all axes

Let's call your original matrix `mat`, with `mat.shape = (10000, 32, 32, 3)`. Visually, this is like having a "stack" of 10,000 rectangular prisms of floats (I think of them as LEGOs), each of size 32x32x3.
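For concreteness, here's a minimal setup sketch; the random float data is just an assumed stand-in for whatever your actual array holds:

```python
import numpy as np

# Stand-in data (assumption): 10,000 "LEGOs", each 32x32x3, of random floats.
mat = np.random.rand(10000, 32, 32, 3)
print(mat.shape)  # (10000, 32, 32, 3)
```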

Now let's think about what you did in terms of floating point operations (flops):

In Method A, you do `mat.mean(axis=0).mean(axis=0).mean(axis=0)`. Let's break this down:

  1. You take the mean of each position (i,j,k) across all 10,000 LEGOs. This gives you back a single LEGO of size 32x32x3 which now contains the first set of means. You have performed 10,000 additions and 1 division per mean, of which there are 32 × 32 × 3 = 3072. In total, you've done 30,723,072 flops.
  2. You then take the mean again, this time of each position (j,k), where i is now the layer (vertical position) you are currently on. This gives you a piece of paper with 32x3 means written on it. You have performed 32 additions and 1 division per mean, of which there are 32 × 3 = 96. In total, you've done 3,168 flops.
  3. Finally, you take the mean of each column k, where j is now the row you are currently on. This gives you a stub with 3 means written on it. You have performed 32 additions and 1 division per mean, of which there are 3. In total, you've done 99 flops.

The grand total of all this is 30,723,072 + 3,168 + 99 = 30,726,339 flops.
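A quick sketch of Method A, showing the shapes at each step and reproducing the flop tally above as arithmetic (this reuses the `mat` from the setup sketch):

```python
# Method A: chained means, reducing one axis at a time.
step1 = mat.mean(axis=0)    # shape (32, 32, 3): one LEGO of means
step2 = step1.mean(axis=0)  # shape (32, 3):     the "piece of paper"
step3 = step2.mean(axis=0)  # shape (3,):        the final "stub"

# Flop tally from the three steps above (additions + one division per mean):
flops_a = 32*32*3 * (10000 + 1) + 32*3 * (32 + 1) + 3 * (32 + 1)
print(flops_a)  # 30726339
```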

In Method B, you do `mat_reshaped = mat.reshape(-1,3); mat_means = mat_reshaped.mean(axis=0)`. Let's break this down:

  1. You reshaped everything, so `mat_reshaped` is a long roll of paper of size 10,240,000x3. You take the mean of each column k, where j is now the row you are currently on. This gives you a stub with 3 means written on it. You have performed 10,240,000 additions and 1 division per mean, of which there are 3. In total, you've done 30,720,003 flops.
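The same as code (again reusing `mat`; the final `allclose` call is just a sanity check that the two methods agree, up to floating-point rounding differences between the two orders of summation):

```python
# Method B: flatten the first three axes, then take one big mean.
mat_reshaped = mat.reshape(-1, 3)      # shape (10240000, 3)
mat_means = mat_reshaped.mean(axis=0)  # shape (3,)

flops_b = 3 * (10240000 + 1)
print(flops_b)  # 30720003

# Both methods compute the same quantity, up to rounding:
print(np.allclose(step3, mat_means))
```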

So now you're saying to yourself, "What! All of that work, only to show that the slower method actually does *less* work?!" Here's the catch: although Method B has less work to do, it does not have a lot less work to do (30,720,003 vs. 30,726,339 flops, a difference of about 0.02%), so from a flop standpoint alone we would expect the two methods to have similar runtimes.
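If you want to see the gap on your own machine, here's a quick timing sketch (absolute numbers will vary with your hardware and NumPy version; only the relative difference matters):

```python
from timeit import timeit

# Time both methods over 100 repetitions each.
t_a = timeit(lambda: mat.mean(axis=0).mean(axis=0).mean(axis=0), number=100)
t_b = timeit(lambda: mat.reshape(-1, 3).mean(axis=0), number=100)
print(f"Method A: {t_a:.3f} s, Method B: {t_b:.3f} s")
```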

You also have to consider the size of your reshaped array in Method B: a matrix with 10,240,000 rows is huge! It's inefficient for the computer to stream through all of that in a single pass, and more (cache-unfriendly) memory accesses mean longer runtimes. The fact is that in its original 10,000x32x32x3 shape, the matrix was already partitioned into convenient slices that the computer could access more efficiently. This is actually a common technique when handling giant matrices; see Jaime's response to a similar question, or even this article: both talk about how breaking a big matrix into smaller slices helps your program be more memory efficient, and therefore makes it run faster.
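As an illustration of that idea (this is not what NumPy runs internally, just a sketch of the chunking technique; the chunk size is an arbitrary choice for illustration, not a tuned value):

```python
# Compute the column means of the big 2-D array in cache-friendly chunks,
# accumulating partial sums instead of touching all rows in one pass.
def chunked_mean(a, chunk=100_000):
    total = np.zeros(a.shape[1])
    for start in range(0, a.shape[0], chunk):
        total += a[start:start + chunk].sum(axis=0)
    return total / a.shape[0]

print(chunked_mean(mat_reshaped))  # same 3 means as above, up to rounding
```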