Why is 2 * (i * i) faster than 2 * i * i in Java?
There is a slight difference in the ordering of the bytecode.
2 * (i * i):

```
iconst_2
iload0
iload0
imul
imul
iadd
```
vs 2 * i * i:

```
iconst_2
iload0
imul
iload0
imul
iadd
```
At first sight this should not make a difference; if anything, the second version is more optimal, since it uses one stack slot less.
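As a sanity check: the two expressions are mathematically equivalent — by left associativity, 2 * i * i parses as (2 * i) * i — so only the order of the two imul operations differs. A minimal sketch (class name is just for illustration):

```java
public class Assoc {
    public static void main(String[] args) {
        int nA = 0, nB = 0, nC = 0;
        for (int i = 0; i < 1000; i++) {
            nA += 2 * (i * i);  // i * i first, then doubled
            nB += 2 * i * i;    // parsed as (2 * i) * i
            nC += (2 * i) * i;  // explicit form of the second version
        }
        // all three agree; only the order of the multiplications differs
        System.out.println(nA == nB && nB == nC);  // prints true
    }
}
```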
So we need to dig deeper into the lower level: the JIT-generated machine code¹.
Remember that the JIT tends to unroll small loops very aggressively. Indeed, we observe a 16x unrolling for the 2 * (i * i) case:
```
030   B2: # B2 B3 <- B1 B2  Loop: B2-B2 inner main of N18 Freq: 1e+006
030     addl    R11, RBP    # int
033     movl    RBP, R13    # spill
036     addl    RBP, #14    # int
039     imull   RBP, RBP    # int
03c     movl    R9, R13     # spill
03f     addl    R9, #13     # int
043     imull   R9, R9      # int
047     sall    RBP, #1
049     sall    R9, #1
04c     movl    R8, R13     # spill
04f     addl    R8, #15     # int
053     movl    R10, R8     # spill
056     movdl   XMM1, R8    # spill
05b     imull   R10, R8     # int
05f     movl    R8, R13     # spill
062     addl    R8, #12     # int
066     imull   R8, R8      # int
06a     sall    R10, #1
06d     movl    [rsp + #32], R10    # spill
072     sall    R8, #1
075     movl    RBX, R13    # spill
078     addl    RBX, #11    # int
07b     imull   RBX, RBX    # int
07e     movl    RCX, R13    # spill
081     addl    RCX, #10    # int
084     imull   RCX, RCX    # int
087     sall    RBX, #1
089     sall    RCX, #1
08b     movl    RDX, R13    # spill
08e     addl    RDX, #8     # int
091     imull   RDX, RDX    # int
094     movl    RDI, R13    # spill
097     addl    RDI, #7     # int
09a     imull   RDI, RDI    # int
09d     sall    RDX, #1
09f     sall    RDI, #1
0a1     movl    RAX, R13    # spill
0a4     addl    RAX, #6     # int
0a7     imull   RAX, RAX    # int
0aa     movl    RSI, R13    # spill
0ad     addl    RSI, #4     # int
0b0     imull   RSI, RSI    # int
0b3     sall    RAX, #1
0b5     sall    RSI, #1
0b7     movl    R10, R13    # spill
0ba     addl    R10, #2     # int
0be     imull   R10, R10    # int
0c2     movl    R14, R13    # spill
0c5     incl    R14         # int
0c8     imull   R14, R14    # int
0cc     sall    R10, #1
0cf     sall    R14, #1
0d2     addl    R14, R11    # int
0d5     addl    R14, R10    # int
0d8     movl    R10, R13    # spill
0db     addl    R10, #3     # int
0df     imull   R10, R10    # int
0e3     movl    R11, R13    # spill
0e6     addl    R11, #5     # int
0ea     imull   R11, R11    # int
0ee     sall    R10, #1
0f1     addl    R10, R14    # int
0f4     addl    R10, RSI    # int
0f7     sall    R11, #1
0fa     addl    R11, R10    # int
0fd     addl    R11, RAX    # int
100     addl    R11, RDI    # int
103     addl    R11, RDX    # int
106     movl    R10, R13    # spill
109     addl    R10, #9     # int
10d     imull   R10, R10    # int
111     sall    R10, #1
114     addl    R10, R11    # int
117     addl    R10, RCX    # int
11a     addl    R10, RBX    # int
11d     addl    R10, R8     # int
120     addl    R9, R10     # int
123     addl    RBP, R9     # int
126     addl    RBP, [RSP + #32 (32-bit)]   # int
12a     addl    R13, #16    # int
12e     movl    R11, R13    # spill
131     imull   R11, R13    # int
135     sall    R11, #1
138     cmpl    R13, #999999985
13f     jl      B2          # loop end  P=1.000000 C=6554623.000000
```
We see that there is 1 register that is "spilled" onto the stack.
And for the 2 * i * i version:
```
05a   B3: # B2 B4 <- B1 B2  Loop: B3-B2 inner main of N18 Freq: 1e+006
05a     addl    RBX, R11    # int
05d     movl    [rsp + #32], RBX    # spill
061     movl    R11, R8     # spill
064     addl    R11, #15    # int
068     movl    [rsp + #36], R11    # spill
06d     movl    R11, R8     # spill
070     addl    R11, #14    # int
074     movl    R10, R9     # spill
077     addl    R10, #16    # int
07b     movdl   XMM2, R10   # spill
080     movl    RCX, R9     # spill
083     addl    RCX, #14    # int
086     movdl   XMM1, RCX   # spill
08a     movl    R10, R9     # spill
08d     addl    R10, #12    # int
091     movdl   XMM4, R10   # spill
096     movl    RCX, R9     # spill
099     addl    RCX, #10    # int
09c     movdl   XMM6, RCX   # spill
0a0     movl    RBX, R9     # spill
0a3     addl    RBX, #8     # int
0a6     movl    RCX, R9     # spill
0a9     addl    RCX, #6     # int
0ac     movl    RDX, R9     # spill
0af     addl    RDX, #4     # int
0b2     addl    R9, #2      # int
0b6     movl    R10, R14    # spill
0b9     addl    R10, #22    # int
0bd     movdl   XMM3, R10   # spill
0c2     movl    RDI, R14    # spill
0c5     addl    RDI, #20    # int
0c8     movl    RAX, R14    # spill
0cb     addl    RAX, #32    # int
0ce     movl    RSI, R14    # spill
0d1     addl    RSI, #18    # int
0d4     movl    R13, R14    # spill
0d7     addl    R13, #24    # int
0db     movl    R10, R14    # spill
0de     addl    R10, #26    # int
0e2     movl    [rsp + #40], R10    # spill
0e7     movl    RBP, R14    # spill
0ea     addl    RBP, #28    # int
0ed     imull   RBP, R11    # int
0f1     addl    R14, #30    # int
0f5     imull   R14, [RSP + #36 (32-bit)]   # int
0fb     movl    R10, R8     # spill
0fe     addl    R10, #11    # int
102     movdl   R11, XMM3   # spill
107     imull   R11, R10    # int
10b     movl    [rsp + #44], R11    # spill
110     movl    R10, R8     # spill
113     addl    R10, #10    # int
117     imull   RDI, R10    # int
11b     movl    R11, R8     # spill
11e     addl    R11, #8     # int
122     movdl   R10, XMM2   # spill
127     imull   R10, R11    # int
12b     movl    [rsp + #48], R10    # spill
130     movl    R10, R8     # spill
133     addl    R10, #7     # int
137     movdl   R11, XMM1   # spill
13c     imull   R11, R10    # int
140     movl    [rsp + #52], R11    # spill
145     movl    R11, R8     # spill
148     addl    R11, #6     # int
14c     movdl   R10, XMM4   # spill
151     imull   R10, R11    # int
155     movl    [rsp + #56], R10    # spill
15a     movl    R10, R8     # spill
15d     addl    R10, #5     # int
161     movdl   R11, XMM6   # spill
166     imull   R11, R10    # int
16a     movl    [rsp + #60], R11    # spill
16f     movl    R11, R8     # spill
172     addl    R11, #4     # int
176     imull   RBX, R11    # int
17a     movl    R11, R8     # spill
17d     addl    R11, #3     # int
181     imull   RCX, R11    # int
185     movl    R10, R8     # spill
188     addl    R10, #2     # int
18c     imull   RDX, R10    # int
190     movl    R11, R8     # spill
193     incl    R11         # int
196     imull   R9, R11     # int
19a     addl    R9, [RSP + #32 (32-bit)]    # int
19f     addl    R9, RDX     # int
1a2     addl    R9, RCX     # int
1a5     addl    R9, RBX     # int
1a8     addl    R9, [RSP + #60 (32-bit)]    # int
1ad     addl    R9, [RSP + #56 (32-bit)]    # int
1b2     addl    R9, [RSP + #52 (32-bit)]    # int
1b7     addl    R9, [RSP + #48 (32-bit)]    # int
1bc     movl    R10, R8     # spill
1bf     addl    R10, #9     # int
1c3     imull   R10, RSI    # int
1c7     addl    R10, R9     # int
1ca     addl    R10, RDI    # int
1cd     addl    R10, [RSP + #44 (32-bit)]   # int
1d2     movl    R11, R8     # spill
1d5     addl    R11, #12    # int
1d9     imull   R13, R11    # int
1dd     addl    R13, R10    # int
1e0     movl    R10, R8     # spill
1e3     addl    R10, #13    # int
1e7     imull   R10, [RSP + #40 (32-bit)]   # int
1ed     addl    R10, R13    # int
1f0     addl    RBP, R10    # int
1f3     addl    R14, RBP    # int
1f6     movl    R10, R8     # spill
1f9     addl    R10, #16    # int
1fd     cmpl    R10, #999999985
204     jl      B2          # loop end  P=1.000000 C=7419903.000000
```
Here we observe much more "spilling" and more accesses to the stack [RSP + ...], due to more intermediate results that need to be preserved.
Thus the answer to the question is simple: 2 * (i * i) is faster than 2 * i * i because the JIT generates more optimal assembly code for the first case.
But of course neither version is any good; the loop could genuinely benefit from vectorization, since any x86-64 CPU has at least SSE2 support.
So it's an issue of the optimizer; as is often the case, it unrolls too aggressively and shoots itself in the foot, all the while missing out on various other opportunities.
In fact, modern x86-64 CPUs break down the instructions further into micro-ops (µops) and with features like register renaming, µop caches and loop buffers, loop optimization takes a lot more finesse than a simple unrolling for optimal performance. According to Agner Fog's optimization guide:
> The gain in performance due to the µop cache can be quite considerable if the average instruction length is more than 4 bytes. The following methods of optimizing the use of the µop cache may be considered:
>
> - Make sure that critical loops are small enough to fit into the µop cache.
> - Align the most critical loop entries and function entries by 32.
> - Avoid unnecessary loop unrolling.
> - Avoid instructions that have extra load time
> . . .
Regarding those load times: even the fastest L1D hit costs 4 cycles, plus an extra register and µop, so yes, even a few accesses to memory will hurt performance in tight loops.
But back to the vectorization opportunity — to see how fast it can be, we can compile a similar C application with GCC, which outright vectorizes it (AVX2 shown; SSE2 is similar)²:
```
  vmovdqa ymm0, YMMWORD PTR .LC0[rip]
  vmovdqa ymm3, YMMWORD PTR .LC1[rip]
  xor     eax, eax
  vpxor   xmm2, xmm2, xmm2
.L2:
  vpmulld ymm1, ymm0, ymm0
  inc     eax
  vpaddd  ymm0, ymm0, ymm3
  vpslld  ymm1, ymm1, 1
  vpaddd  ymm2, ymm2, ymm1
  cmp     eax, 125000000      ; 8 calculations per iteration
  jne     .L2
  vmovdqa xmm0, xmm2
  vextracti128 xmm2, ymm2, 1
  vpaddd  xmm2, xmm0, xmm2
  vpsrldq xmm0, xmm2, 8
  vpaddd  xmm0, xmm2, xmm0
  vpsrldq xmm1, xmm0, 4
  vpaddd  xmm0, xmm0, xmm1
  vmovd   eax, xmm0
  vzeroupper
```
With run times:
- SSE: 0.24 s, or 2 times as fast.
- AVX: 0.15 s, or 3 times as fast.
- AVX2: 0.08 s, or 5 times as fast.
¹ To get JIT-generated assembly output, use a debug JVM and run with -XX:+PrintOptoAssembly.
² The C version is compiled with the -fwrapv flag, which lets GCC treat signed integer overflow as a two's-complement wrap-around.
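That flag is needed because Java's int arithmetic always wraps modulo 2^32, so the loop's result is only well-defined under wrap-around semantics; without -fwrapv, the repeated signed overflow would be undefined behavior in C. A sketch of the wrap-around on the Java side (class name illustrative; the printed value is 2·Σi² mod 2^32, computed by hand):

```java
public class WrapAround {
    public static void main(String[] args) {
        int n = 0;
        for (int i = 0; i < 1000000000; i++) {
            n += 2 * (i * i);  // overflows int many times; wraps mod 2^32
        }
        System.out.println(n);  // 119860736 = 2 * sum(i^2) mod 2^32
    }
}
```

This also matches the per-iteration output shown in the last answer: 785527736 there includes a 1000-iteration warm-up contributing an extra 665667000.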
(Editor's note: this answer is contradicted by evidence from looking at the asm, as shown by another answer. This was a guess backed up by some experiments, but it turned out not to be correct.)
When the multiplication is 2 * (i * i), the JVM is able to factor out the multiplication by 2 from the loop, resulting in this equivalent but more efficient code:
```java
int n = 0;
for (int i = 0; i < 1000000000; i++) {
    n += i * i;
}
n *= 2;
```
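Note that hoisting the doubling out of the loop is valid even though n overflows: Java int arithmetic is arithmetic modulo 2^32, and multiplication by 2 distributes over addition in that ring. A small check with a range chosen so that i * i deliberately overflows:

```java
public class Hoist {
    public static void main(String[] args) {
        int a = 0, b = 0;
        for (int i = 0; i < 100000; i++) {
            a += 2 * (i * i);  // doubled inside the loop, wraps freely
            b += i * i;        // accumulate, double once at the end
        }
        b *= 2;
        System.out.println(a == b);  // prints true
    }
}
```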
but when the multiplication is (2 * i) * i, the JVM doesn't optimize it, since the multiplication by a constant is no longer right before the n += addition.
Here are a few reasons why I think this is the case:
- Adding an if (n == 0) n = 1 statement at the start of the loop results in both versions being as efficient, since factoring out the multiplication no longer guarantees that the result will be the same
- The optimized version (with the multiplication by 2 factored out) is exactly as fast as the 2 * (i * i) version
Here is the test code that I used to draw these conclusions:
```java
public static void main(String[] args) {
    long fastVersion = 0;
    long slowVersion = 0;
    long optimizedVersion = 0;
    long modifiedFastVersion = 0;
    long modifiedSlowVersion = 0;
    for (int i = 0; i < 10; i++) {
        fastVersion += fastVersion();
        slowVersion += slowVersion();
        optimizedVersion += optimizedVersion();
        modifiedFastVersion += modifiedFastVersion();
        modifiedSlowVersion += modifiedSlowVersion();
    }
    System.out.println("Fast version: " + (double) fastVersion / 1000000000 + " s");
    System.out.println("Slow version: " + (double) slowVersion / 1000000000 + " s");
    System.out.println("Optimized version: " + (double) optimizedVersion / 1000000000 + " s");
    System.out.println("Modified fast version: " + (double) modifiedFastVersion / 1000000000 + " s");
    System.out.println("Modified slow version: " + (double) modifiedSlowVersion / 1000000000 + " s");
}

private static long fastVersion() {
    long startTime = System.nanoTime();
    int n = 0;
    for (int i = 0; i < 1000000000; i++) {
        n += 2 * (i * i);
    }
    return System.nanoTime() - startTime;
}

private static long slowVersion() {
    long startTime = System.nanoTime();
    int n = 0;
    for (int i = 0; i < 1000000000; i++) {
        n += 2 * i * i;
    }
    return System.nanoTime() - startTime;
}

private static long optimizedVersion() {
    long startTime = System.nanoTime();
    int n = 0;
    for (int i = 0; i < 1000000000; i++) {
        n += i * i;
    }
    n *= 2;
    return System.nanoTime() - startTime;
}

private static long modifiedFastVersion() {
    long startTime = System.nanoTime();
    int n = 0;
    for (int i = 0; i < 1000000000; i++) {
        if (n == 0) n = 1;
        n += 2 * (i * i);
    }
    return System.nanoTime() - startTime;
}

private static long modifiedSlowVersion() {
    long startTime = System.nanoTime();
    int n = 0;
    for (int i = 0; i < 1000000000; i++) {
        if (n == 0) n = 1;
        n += 2 * i * i;
    }
    return System.nanoTime() - startTime;
}
```
And here are the results:
```
Fast version: 5.7274411 s
Slow version: 7.6190804 s
Optimized version: 5.1348007 s
Modified fast version: 7.1492705 s
Modified slow version: 7.2952668 s
```
Byte codes: https://cs.nyu.edu/courses/fall00/V22.0201-001/jvm2.html
Byte codes viewer: https://github.com/Konloch/bytecode-viewer
On my JDK (Windows 10 64 bit, 1.8.0_65-b17) I can reproduce and explain:
```java
public static void main(String[] args) {
    int repeat = 10;
    long A = 0;
    long B = 0;
    for (int i = 0; i < repeat; i++) {
        A += test();
        B += testB();
    }
    System.out.println(A / repeat + " ms");
    System.out.println(B / repeat + " ms");
}

private static long test() {
    int n = 0;
    for (int i = 0; i < 1000; i++) {
        n += multi(i);
    }
    long startTime = System.currentTimeMillis();
    for (int i = 0; i < 1000000000; i++) {
        n += multi(i);
    }
    long ms = (System.currentTimeMillis() - startTime);
    System.out.println(ms + " ms A " + n);
    return ms;
}

private static long testB() {
    int n = 0;
    for (int i = 0; i < 1000; i++) {
        n += multiB(i);
    }
    long startTime = System.currentTimeMillis();
    for (int i = 0; i < 1000000000; i++) {
        n += multiB(i);
    }
    long ms = (System.currentTimeMillis() - startTime);
    System.out.println(ms + " ms B " + n);
    return ms;
}

private static int multiB(int i) {
    return 2 * (i * i);
}

private static int multi(int i) {
    return 2 * i * i;
}
```
Output:
```
...
405 ms A 785527736
327 ms B 785527736
404 ms A 785527736
329 ms B 785527736
404 ms A 785527736
328 ms B 785527736
404 ms A 785527736
328 ms B 785527736
410 ms
333 ms
```
So why? The byte code is this:
```
private static multiB(int arg0) { // 2 * (i * i)
    <localVar:index=0, name=i, desc=I, sig=null, start=L1, end=L2>
    L1 {
        iconst_2
        iload0
        iload0
        imul
        imul
        ireturn
    }
    L2 {
    }
}

private static multi(int arg0) { // 2 * i * i
    <localVar:index=0, name=i, desc=I, sig=null, start=L1, end=L2>
    L1 {
        iconst_2
        iload0
        imul
        iload0
        imul
        ireturn
    }
    L2 {
    }
}
```
The difference being: with brackets (2 * (i * i)):
- push const on stack
- push local on stack
- push local on stack
- multiply top of stack
- multiply top of stack
Without brackets (2 * i * i):
- push const on stack
- push local on stack
- multiply top of stack
- push local on stack
- multiply top of stack
Loading everything onto the stack first and then working back down is faster than alternating between pushing onto the stack and operating on it.