Why is 2 * (i * i) faster than 2 * i * i in Java?



There is a slight difference in the ordering of the bytecode.

2 * (i * i):

     iconst_2
     iload0
     iload0
     imul
     imul
     iadd

vs 2 * i * i:

     iconst_2
     iload0
     imul
     iload0
     imul
     iadd

At first sight this should not make a difference; if anything, the second version should be better, since it needs one fewer operand-stack slot.
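To make the stack-slot point concrete, here is a toy Java model of the operand stack that runs both instruction sequences (without the trailing iadd into n) and records the peak stack depth. It is only an illustration: it models just iconst_2, iload0 and imul, the names StackDepth, Run and run are hypothetical, and it says nothing about speed, since the JIT compiles both sequences to native code anyway. The record syntax needs Java 16+.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.List;

public class StackDepth {

    /** The value left on the stack and the deepest the stack got. */
    record Run(int value, int maxDepth) {}

    // The two sequences from above, minus the final iadd into n.
    static final List<String> PARENS = List.of("iconst_2", "iload0", "iload0", "imul", "imul");
    static final List<String> NO_PARENS = List.of("iconst_2", "iload0", "imul", "iload0", "imul");

    /** Runs a sequence with local variable 0 holding i. */
    static Run run(List<String> code, int i) {
        Deque<Integer> stack = new ArrayDeque<>();
        int maxDepth = 0;
        for (String op : code) {
            switch (op) {
                case "iconst_2" -> stack.push(2);
                case "iload0" -> stack.push(i);                       // push local 0 (i)
                case "imul" -> stack.push(stack.pop() * stack.pop()); // wraps like any Java int *
                default -> throw new IllegalArgumentException("unsupported opcode: " + op);
            }
            maxDepth = Math.max(maxDepth, stack.size());
        }
        return new Run(stack.pop(), maxDepth);
    }

    public static void main(String[] args) {
        int i = 46341; // i * i exceeds Integer.MAX_VALUE, so this also exercises wrap-around
        Run withParens = run(PARENS, i);
        Run withoutParens = run(NO_PARENS, i);
        System.out.println("2 * (i * i): value " + withParens.value() + ", max depth " + withParens.maxDepth());
        System.out.println("2 * i * i:   value " + withoutParens.value() + ", max depth " + withoutParens.maxDepth());
    }
}
```

Both sequences leave the same value on the stack (9266 for i = 46341), but the parenthesized one peaks at three stack entries versus two.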

So we need to dig deeper into the lower level (JIT).¹

Remember that the JIT tends to unroll small loops very aggressively. Indeed, we observe 16x unrolling for the 2 * (i * i) case:

    030   B2: # B2 B3 <- B1 B2  Loop: B2-B2 inner main of N18 Freq: 1e+006
    030     addl    R11, RBP    # int
    033     movl    RBP, R13    # spill
    036     addl    RBP, #14    # int
    039     imull   RBP, RBP    # int
    03c     movl    R9, R13 # spill
    03f     addl    R9, #13 # int
    043     imull   R9, R9  # int
    047     sall    RBP, #1
    049     sall    R9, #1
    04c     movl    R8, R13 # spill
    04f     addl    R8, #15 # int
    053     movl    R10, R8 # spill
    056     movdl   XMM1, R8    # spill
    05b     imull   R10, R8 # int
    05f     movl    R8, R13 # spill
    062     addl    R8, #12 # int
    066     imull   R8, R8  # int
    06a     sall    R10, #1
    06d     movl    [rsp + #32], R10    # spill
    072     sall    R8, #1
    075     movl    RBX, R13    # spill
    078     addl    RBX, #11    # int
    07b     imull   RBX, RBX    # int
    07e     movl    RCX, R13    # spill
    081     addl    RCX, #10    # int
    084     imull   RCX, RCX    # int
    087     sall    RBX, #1
    089     sall    RCX, #1
    08b     movl    RDX, R13    # spill
    08e     addl    RDX, #8 # int
    091     imull   RDX, RDX    # int
    094     movl    RDI, R13    # spill
    097     addl    RDI, #7 # int
    09a     imull   RDI, RDI    # int
    09d     sall    RDX, #1
    09f     sall    RDI, #1
    0a1     movl    RAX, R13    # spill
    0a4     addl    RAX, #6 # int
    0a7     imull   RAX, RAX    # int
    0aa     movl    RSI, R13    # spill
    0ad     addl    RSI, #4 # int
    0b0     imull   RSI, RSI    # int
    0b3     sall    RAX, #1
    0b5     sall    RSI, #1
    0b7     movl    R10, R13    # spill
    0ba     addl    R10, #2 # int
    0be     imull   R10, R10    # int
    0c2     movl    R14, R13    # spill
    0c5     incl    R14 # int
    0c8     imull   R14, R14    # int
    0cc     sall    R10, #1
    0cf     sall    R14, #1
    0d2     addl    R14, R11    # int
    0d5     addl    R14, R10    # int
    0d8     movl    R10, R13    # spill
    0db     addl    R10, #3 # int
    0df     imull   R10, R10    # int
    0e3     movl    R11, R13    # spill
    0e6     addl    R11, #5 # int
    0ea     imull   R11, R11    # int
    0ee     sall    R10, #1
    0f1     addl    R10, R14    # int
    0f4     addl    R10, RSI    # int
    0f7     sall    R11, #1
    0fa     addl    R11, R10    # int
    0fd     addl    R11, RAX    # int
    100     addl    R11, RDI    # int
    103     addl    R11, RDX    # int
    106     movl    R10, R13    # spill
    109     addl    R10, #9 # int
    10d     imull   R10, R10    # int
    111     sall    R10, #1
    114     addl    R10, R11    # int
    117     addl    R10, RCX    # int
    11a     addl    R10, RBX    # int
    11d     addl    R10, R8 # int
    120     addl    R9, R10 # int
    123     addl    RBP, R9 # int
    126     addl    RBP, [RSP + #32 (32-bit)]   # int
    12a     addl    R13, #16    # int
    12e     movl    R11, R13    # spill
    131     imull   R11, R13    # int
    135     sall    R11, #1
    138     cmpl    R13, #999999985
    13f     jl     B2   # loop end  P=1.000000 C=6554623.000000

We see that there is only one register that is "spilled" onto the stack ([rsp + #32]).

And for the 2 * i * i version:

    05a   B3: # B2 B4 <- B1 B2  Loop: B3-B2 inner main of N18 Freq: 1e+006
    05a     addl    RBX, R11    # int
    05d     movl    [rsp + #32], RBX    # spill
    061     movl    R11, R8 # spill
    064     addl    R11, #15    # int
    068     movl    [rsp + #36], R11    # spill
    06d     movl    R11, R8 # spill
    070     addl    R11, #14    # int
    074     movl    R10, R9 # spill
    077     addl    R10, #16    # int
    07b     movdl   XMM2, R10   # spill
    080     movl    RCX, R9 # spill
    083     addl    RCX, #14    # int
    086     movdl   XMM1, RCX   # spill
    08a     movl    R10, R9 # spill
    08d     addl    R10, #12    # int
    091     movdl   XMM4, R10   # spill
    096     movl    RCX, R9 # spill
    099     addl    RCX, #10    # int
    09c     movdl   XMM6, RCX   # spill
    0a0     movl    RBX, R9 # spill
    0a3     addl    RBX, #8 # int
    0a6     movl    RCX, R9 # spill
    0a9     addl    RCX, #6 # int
    0ac     movl    RDX, R9 # spill
    0af     addl    RDX, #4 # int
    0b2     addl    R9, #2  # int
    0b6     movl    R10, R14    # spill
    0b9     addl    R10, #22    # int
    0bd     movdl   XMM3, R10   # spill
    0c2     movl    RDI, R14    # spill
    0c5     addl    RDI, #20    # int
    0c8     movl    RAX, R14    # spill
    0cb     addl    RAX, #32    # int
    0ce     movl    RSI, R14    # spill
    0d1     addl    RSI, #18    # int
    0d4     movl    R13, R14    # spill
    0d7     addl    R13, #24    # int
    0db     movl    R10, R14    # spill
    0de     addl    R10, #26    # int
    0e2     movl    [rsp + #40], R10    # spill
    0e7     movl    RBP, R14    # spill
    0ea     addl    RBP, #28    # int
    0ed     imull   RBP, R11    # int
    0f1     addl    R14, #30    # int
    0f5     imull   R14, [RSP + #36 (32-bit)]   # int
    0fb     movl    R10, R8 # spill
    0fe     addl    R10, #11    # int
    102     movdl   R11, XMM3   # spill
    107     imull   R11, R10    # int
    10b     movl    [rsp + #44], R11    # spill
    110     movl    R10, R8 # spill
    113     addl    R10, #10    # int
    117     imull   RDI, R10    # int
    11b     movl    R11, R8 # spill
    11e     addl    R11, #8 # int
    122     movdl   R10, XMM2   # spill
    127     imull   R10, R11    # int
    12b     movl    [rsp + #48], R10    # spill
    130     movl    R10, R8 # spill
    133     addl    R10, #7 # int
    137     movdl   R11, XMM1   # spill
    13c     imull   R11, R10    # int
    140     movl    [rsp + #52], R11    # spill
    145     movl    R11, R8 # spill
    148     addl    R11, #6 # int
    14c     movdl   R10, XMM4   # spill
    151     imull   R10, R11    # int
    155     movl    [rsp + #56], R10    # spill
    15a     movl    R10, R8 # spill
    15d     addl    R10, #5 # int
    161     movdl   R11, XMM6   # spill
    166     imull   R11, R10    # int
    16a     movl    [rsp + #60], R11    # spill
    16f     movl    R11, R8 # spill
    172     addl    R11, #4 # int
    176     imull   RBX, R11    # int
    17a     movl    R11, R8 # spill
    17d     addl    R11, #3 # int
    181     imull   RCX, R11    # int
    185     movl    R10, R8 # spill
    188     addl    R10, #2 # int
    18c     imull   RDX, R10    # int
    190     movl    R11, R8 # spill
    193     incl    R11 # int
    196     imull   R9, R11 # int
    19a     addl    R9, [RSP + #32 (32-bit)]    # int
    19f     addl    R9, RDX # int
    1a2     addl    R9, RCX # int
    1a5     addl    R9, RBX # int
    1a8     addl    R9, [RSP + #60 (32-bit)]    # int
    1ad     addl    R9, [RSP + #56 (32-bit)]    # int
    1b2     addl    R9, [RSP + #52 (32-bit)]    # int
    1b7     addl    R9, [RSP + #48 (32-bit)]    # int
    1bc     movl    R10, R8 # spill
    1bf     addl    R10, #9 # int
    1c3     imull   R10, RSI    # int
    1c7     addl    R10, R9 # int
    1ca     addl    R10, RDI    # int
    1cd     addl    R10, [RSP + #44 (32-bit)]   # int
    1d2     movl    R11, R8 # spill
    1d5     addl    R11, #12    # int
    1d9     imull   R13, R11    # int
    1dd     addl    R13, R10    # int
    1e0     movl    R10, R8 # spill
    1e3     addl    R10, #13    # int
    1e7     imull   R10, [RSP + #40 (32-bit)]   # int
    1ed     addl    R10, R13    # int
    1f0     addl    RBP, R10    # int
    1f3     addl    R14, RBP    # int
    1f6     movl    R10, R8 # spill
    1f9     addl    R10, #16    # int
    1fd     cmpl    R10, #999999985
    204     jl     B2   # loop end  P=1.000000 C=7419903.000000

Here we observe much more "spilling" and more accesses to the stack [RSP + ...], due to more intermediate results that need to be preserved.

Thus the answer to the question is simple: 2 * (i * i) is faster than 2 * i * i because the JIT generates better assembly code for the first case.
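For reference, this is what unrolling looks like at the source level. The sketch unrolls the 2 * (i * i) loop by hand, 4x rather than the JIT's 16x to keep it short; it assumes limit >= 0, and the names Unroll, simple and unrolled4 are hypothetical. The main loop runs only while at least four iterations remain and a short post-loop handles the leftovers; the cmpl R13, #999999985 in the first listing looks like the 16x version of that check (1000000000 - 15).

```java
public class Unroll {

    /** The loop from the question. */
    static int simple(int limit) {
        int n = 0;
        for (int i = 0; i < limit; i++) {
            n += 2 * (i * i);
        }
        return n;
    }

    /** Same sum, four iterations per trip through the main loop. Assumes limit >= 0. */
    static int unrolled4(int limit) {
        int n = 0;
        int i = 0;
        // Main loop: runs while at least 4 iterations remain (i + 3 < limit).
        for (; i < limit - 3; i += 4) {
            n += 2 * (i * i);
            n += 2 * ((i + 1) * (i + 1));
            n += 2 * ((i + 2) * (i + 2));
            n += 2 * ((i + 3) * (i + 3));
        }
        // Post-loop: the 0 to 3 leftover iterations.
        for (; i < limit; i++) {
            n += 2 * (i * i);
        }
        return n;
    }

    public static void main(String[] args) {
        System.out.println(simple(1_000_003) == unrolled4(1_000_003)); // prints true
    }
}
```

Each unrolled step computes its own square independently, so the more steps the JIT unrolls, the more intermediate values are live at once; that is where the register pressure and spilling in the listings come from.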


But of course it is obvious that neither the first nor the second version is any good; the loop could really benefit from vectorization, since any x86-64 CPU has at least SSE2 support.

So it's an issue of the optimizer; as is often the case, it unrolls too aggressively and shoots itself in the foot, all the while missing out on various other opportunities.

In fact, modern x86-64 CPUs break instructions down further into micro-ops (µops), and with features like register renaming, µop caches and loop buffers, loop optimization takes a lot more finesse than simple unrolling to reach optimal performance. According to Agner Fog's optimization guide:

The gain in performance due to the µop cache can be quite considerable if the average instruction length is more than 4 bytes. The following methods of optimizing the use of the µop cache may be considered:

  • Make sure that critical loops are small enough to fit into the µop cache.
  • Align the most critical loop entries and function entries by 32.
  • Avoid unnecessary loop unrolling.
  • Avoid instructions that have extra load time
    . . .

Regarding those load times: even the fastest L1D hit costs 4 cycles, an extra register and a µop, so yes, even a few accesses to memory will hurt performance in tight loops.

But back to the vectorization opportunity: to see how fast it can be, we can compile a similar C application with GCC, which outright vectorizes it (AVX2 is shown; SSE2 is similar)²:

      vmovdqa ymm0, YMMWORD PTR .LC0[rip]
      vmovdqa ymm3, YMMWORD PTR .LC1[rip]
      xor eax, eax
      vpxor xmm2, xmm2, xmm2
    .L2:
      vpmulld ymm1, ymm0, ymm0
      inc eax
      vpaddd ymm0, ymm0, ymm3
      vpslld ymm1, ymm1, 1
      vpaddd ymm2, ymm2, ymm1
      cmp eax, 125000000      ; 8 calculations per iteration
      jne .L2
      vmovdqa xmm0, xmm2
      vextracti128 xmm2, ymm2, 1
      vpaddd xmm2, xmm0, xmm2
      vpsrldq xmm0, xmm2, 8
      vpaddd xmm0, xmm2, xmm0
      vpsrldq xmm1, xmm0, 4
      vpaddd xmm0, xmm0, xmm1
      vmovd eax, xmm0
      vzeroupper

With run times:

  • SSE: 0.24 s, or 2 times as fast.
  • AVX: 0.15 s, or 3 times as fast.
  • AVX2: 0.08 s, or 5 times as fast.
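The AVX2 loop handles eight values per iteration ("8 calculations per iteration"). Its structure can be modelled in plain Java: eight independent int "lanes", each stepping by 8, followed by a horizontal reduction that mirrors the vextracti128 / vpsrldq / vpaddd tail. This is a scalar sketch, not real SIMD. It assumes .LC0 holds {0, 1, ..., 7} and .LC1 holds eight 8s, since those constants are not shown in the listing, and LaneSum, lanes and scalar are hypothetical names.

```java
public class LaneSum {

    static final int LANES = 8; // one 256-bit ymm register holds eight 32-bit ints

    static int lanes(int iterations) {
        int[] idx = new int[LANES]; // models ymm0, assumed to start as {0, 1, ..., 7}
        int[] acc = new int[LANES]; // models ymm2, zeroed by vpxor
        for (int k = 0; k < LANES; k++) {
            idx[k] = k;
        }

        for (int it = 0; it < iterations; it++) {  // the listing uses 125000000 iterations
            for (int k = 0; k < LANES; k++) {
                int sq = idx[k] * idx[k];          // vpmulld: low 32 bits of the product
                idx[k] += LANES;                   // vpaddd with ymm3, assumed all 8s
                acc[k] += sq << 1;                 // vpslld by 1, then vpaddd into ymm2
            }
        }

        // Horizontal reduction: fold the upper half onto the lower half
        // (4 lanes, then 2, then 1) until lane 0 holds the total.
        for (int width = LANES / 2; width >= 1; width /= 2) {
            for (int k = 0; k < width; k++) {
                acc[k] += acc[k + width];
            }
        }
        return acc[0];
    }

    /** The plain loop, for comparison. */
    static int scalar(int limit) {
        int n = 0;
        for (int i = 0; i < limit; i++) {
            n += 2 * (i * i);
        }
        return n;
    }

    public static void main(String[] args) {
        int iterations = 125_000; // scaled down from 125000000
        System.out.println(lanes(iterations) == scalar(LANES * iterations)); // prints true
    }
}
```

Because int addition wraps modulo 2^32, it is associative and commutative, so splitting the sum across lanes and adding the lanes together at the end cannot change the result. That property is what makes the reduction vectorizable in the first place.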

¹ To get JIT-generated assembly output, get a debug JVM and run it with -XX:+PrintOptoAssembly.

² The C version is compiled with the -fwrapv flag, which makes GCC treat signed integer overflow as two's-complement wrap-around.
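Java needs no equivalent of -fwrapv: the Java Language Specification (§15.17.1) defines an overflowing int multiplication as keeping the low-order 32 bits, so the loop in the question always wraps. The sketch below contrasts plain * with Math.multiplyExact, which throws ArithmeticException on overflow instead; the class Wrap and its methods are hypothetical names.

```java
public class Wrap {

    /** Plain int multiplication: keeps the low 32 bits of the product. */
    static int wrapping(int a, int b) {
        return a * b;
    }

    /** True if a * b does not fit in an int (Math.multiplyExact throws in that case). */
    static boolean overflows(int a, int b) {
        try {
            Math.multiplyExact(a, b);
            return false;
        } catch (ArithmeticException e) {
            return true;
        }
    }

    public static void main(String[] args) {
        System.out.println(Integer.MAX_VALUE + 1 == Integer.MIN_VALUE); // prints true
        System.out.println(wrapping(46341, 46341));                     // prints -2147479015
        System.out.println(overflows(46341, 46341));                    // prints true
        System.out.println(overflows(46340, 46340));                    // prints false
    }
}
```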


(Editor's note: this answer is contradicted by evidence from looking at the asm, as shown by another answer. This was a guess backed up by some experiments, but it turned out not to be correct.)


When the multiplication is 2 * (i * i), the JVM is able to factor out the multiplication by 2 from the loop, resulting in this equivalent but more efficient code:

    int n = 0;
    for (int i = 0; i < 1000000000; i++) {
        n += i * i;
    }
    n *= 2;

But when the multiplication is (2 * i) * i, the JVM doesn't optimize it, since the multiplication by a constant is no longer right before the n += addition.
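Whether or not the JIT actually performs this rewrite (see the editor's note above), the rewrite itself is legal: int multiplication distributes over wrapping addition modulo 2^32, so both loops below return the same value for any limit, even when the sums overflow. FactorOut, inLoop and hoisted are hypothetical names.

```java
public class FactorOut {

    /** The loop as written in the question. */
    static int inLoop(int limit) {
        int n = 0;
        for (int i = 0; i < limit; i++) {
            n += 2 * (i * i);
        }
        return n;
    }

    /** The rewrite with the multiplication by 2 moved after the loop. */
    static int hoisted(int limit) {
        int n = 0;
        for (int i = 0; i < limit; i++) {
            n += i * i;
        }
        n *= 2;
        return n;
    }

    public static void main(String[] args) {
        // 100_000 is large enough for both sums to overflow int many times over.
        System.out.println(inLoop(100_000) == hoisted(100_000)); // prints true
    }
}
```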

Here are a few reasons why I think this is the case:

  • Adding an if (n == 0) n = 1 statement at the start of the loop makes both versions roughly equally efficient, since factoring out the multiplication no longer guarantees that the result will be the same.
  • The optimized version (with the multiplication by 2 factored out) is roughly as fast as the 2 * (i * i) version.
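The first point can be checked directly: with the guard in place, moving the multiplication out of the loop changes the result, because the n = 1 written by the guard gets doubled too. A minimal sketch (Guarded, guarded and guardedHoisted are hypothetical names):

```java
public class Guarded {

    /** The modified loop: guard, then add 2 * (i * i). */
    static int guarded(int limit) {
        int n = 0;
        for (int i = 0; i < limit; i++) {
            if (n == 0) n = 1;
            n += 2 * (i * i);
        }
        return n;
    }

    /** The same loop with the multiplication by 2 naively moved after it. */
    static int guardedHoisted(int limit) {
        int n = 0;
        for (int i = 0; i < limit; i++) {
            if (n == 0) n = 1;
            n += i * i;
        }
        n *= 2;
        return n;
    }

    public static void main(String[] args) {
        System.out.println(guarded(4));        // prints 29
        System.out.println(guardedHoisted(4)); // prints 30
    }
}
```

With limit = 4 the guarded loop yields 1 + 0 + 2 + 8 + 18 = 29, while the hoisted version yields 2 * (1 + 0 + 1 + 4 + 9) = 30, so the JIT is not allowed to apply the transformation here.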

Here is the test code that I used to draw these conclusions:

    public static void main(String[] args) {
        long fastVersion = 0;
        long slowVersion = 0;
        long optimizedVersion = 0;
        long modifiedFastVersion = 0;
        long modifiedSlowVersion = 0;

        for (int i = 0; i < 10; i++) {
            fastVersion += fastVersion();
            slowVersion += slowVersion();
            optimizedVersion += optimizedVersion();
            modifiedFastVersion += modifiedFastVersion();
            modifiedSlowVersion += modifiedSlowVersion();
        }

        System.out.println("Fast version: " + (double) fastVersion / 1000000000 + " s");
        System.out.println("Slow version: " + (double) slowVersion / 1000000000 + " s");
        System.out.println("Optimized version: " + (double) optimizedVersion / 1000000000 + " s");
        System.out.println("Modified fast version: " + (double) modifiedFastVersion / 1000000000 + " s");
        System.out.println("Modified slow version: " + (double) modifiedSlowVersion / 1000000000 + " s");
    }

    private static long fastVersion() {
        long startTime = System.nanoTime();
        int n = 0;
        for (int i = 0; i < 1000000000; i++) {
            n += 2 * (i * i);
        }
        return System.nanoTime() - startTime;
    }

    private static long slowVersion() {
        long startTime = System.nanoTime();
        int n = 0;
        for (int i = 0; i < 1000000000; i++) {
            n += 2 * i * i;
        }
        return System.nanoTime() - startTime;
    }

    private static long optimizedVersion() {
        long startTime = System.nanoTime();
        int n = 0;
        for (int i = 0; i < 1000000000; i++) {
            n += i * i;
        }
        n *= 2;
        return System.nanoTime() - startTime;
    }

    private static long modifiedFastVersion() {
        long startTime = System.nanoTime();
        int n = 0;
        for (int i = 0; i < 1000000000; i++) {
            if (n == 0) n = 1;
            n += 2 * (i * i);
        }
        return System.nanoTime() - startTime;
    }

    private static long modifiedSlowVersion() {
        long startTime = System.nanoTime();
        int n = 0;
        for (int i = 0; i < 1000000000; i++) {
            if (n == 0) n = 1;
            n += 2 * i * i;
        }
        return System.nanoTime() - startTime;
    }

And here are the results:

    Fast version: 5.7274411 s
    Slow version: 7.6190804 s
    Optimized version: 5.1348007 s
    Modified fast version: 7.1492705 s
    Modified slow version: 7.2952668 s


Byte codes: https://cs.nyu.edu/courses/fall00/V22.0201-001/jvm2.html
Byte codes viewer: https://github.com/Konloch/bytecode-viewer

On my JDK (Windows 10 64 bit, 1.8.0_65-b17) I can reproduce and explain:

    public static void main(String[] args) {
        int repeat = 10;
        long A = 0;
        long B = 0;
        for (int i = 0; i < repeat; i++) {
            A += test();
            B += testB();
        }

        System.out.println(A / repeat + " ms");
        System.out.println(B / repeat + " ms");
    }

    private static long test() {
        int n = 0;
        for (int i = 0; i < 1000; i++) {
            n += multi(i);
        }
        long startTime = System.currentTimeMillis();
        for (int i = 0; i < 1000000000; i++) {
            n += multi(i);
        }
        long ms = (System.currentTimeMillis() - startTime);
        System.out.println(ms + " ms A " + n);
        return ms;
    }

    private static long testB() {
        int n = 0;
        for (int i = 0; i < 1000; i++) {
            n += multiB(i);
        }
        long startTime = System.currentTimeMillis();
        for (int i = 0; i < 1000000000; i++) {
            n += multiB(i);
        }
        long ms = (System.currentTimeMillis() - startTime);
        System.out.println(ms + " ms B " + n);
        return ms;
    }

    private static int multiB(int i) {
        return 2 * (i * i);
    }

    private static int multi(int i) {
        return 2 * i * i;
    }

Output:

    ...
    405 ms A 785527736
    327 ms B 785527736
    404 ms A 785527736
    329 ms B 785527736
    404 ms A 785527736
    328 ms B 785527736
    404 ms A 785527736
    328 ms B 785527736
    410 ms
    333 ms

So why? The bytecode is this:

    private static multiB(int arg0) { // 2 * (i * i)
        <localVar:index=0, name=i , desc=I, sig=null, start=L1, end=L2>

        L1 {
            iconst_2
            iload0
            iload0
            imul
            imul
            ireturn
        }
        L2 {
        }
    }

    private static multi(int arg0) { // 2 * i * i
        <localVar:index=0, name=i , desc=I, sig=null, start=L1, end=L2>

        L1 {
            iconst_2
            iload0
            imul
            iload0
            imul
            ireturn
        }
        L2 {
        }
    }

The difference being:

With brackets (2 * (i * i)):

  • push constant on stack
  • push local on stack
  • push local on stack
  • multiply the top two stack values
  • multiply the top two stack values

Without brackets (2 * i * i):

  • push constant on stack
  • push local on stack
  • multiply the top two stack values
  • push local on stack
  • multiply the top two stack values

Loading everything onto the stack first and then working back down is faster than alternating between pushing onto the stack and operating on it.