Converting floats to strings, part 1

Much has been written about converting floating point numbers to strings: Bruce Dawson has lots to say, David M. Gay's venerable dtoa is a classic, and Florian Loitsch's relatively recent grisu paper is worth studying. Often the problem is framed as "converting a floating point number to the shortest possible decimal string representation", but this framing is neither necessary nor sufficient for implementing the %e / %f / %g formats of sprintf. Furthermore, this framing introduces significant complexity. As such, I'd like to begin by considering a much simpler framing: convert a double-precision floating point number to a decimal string, such that the string gives the exact mathematical real number represented by the float.

First of all, we need a quick reminder of what an IEEE754 double-precision floating point number looks like under the hood. If we ignore negative numbers, infinities, NaNs, and denormals, then we have just an 11-bit exponent e and a 53-bit mantissa m (52 bits are stored explicitly, and the leading 1 bit is implicit), which together represent the number m (an integer) times $2^e$ (where e might be negative). Denormals also fit the pattern of m times $2^e$, albeit with a different encoding of m and e.

Armed with this knowledge, we can pull the m and e fields out of a number:

#include <stdint.h>
#include <stdio.h>

typedef union {
  double n;
  uint64_t u64;
  struct {
    uint32_t lo; /* assumes a little-endian target: low word first */
    uint32_t hi;
  } u32;
} TValue;

void decode(double n) {
  TValue t;
  t.n = n;
  if ((t.u32.hi << 1) >= 0xffe00000) {
    if (((t.u32.hi & 0x000fffff) | t.u32.lo) != 0) {
      printf("NaN\n");
    } else {
      printf("Infinity\n");
    }
  } else {
    int32_t e = (t.u32.hi >> 20) & 0x7ff;
    uint64_t m = t.u32.hi & 0xfffff;
    if (e == 0) {
      e++;
    } else {
      m |= 0x100000;
    }
    e -= 1043;
    if (t.u32.lo) {
      e -= 32;
      m = (m << 32) | t.u32.lo;
    }
    printf("%llu * 2^%d\n", (long long unsigned)m, (int)e);
  }
}

Some example outputs of this decode function are:

n               Output
0./0.           NaN
1./0.           Infinity
0.              0 * 2^-1042
1.              1048576 * 2^-20
10.             1310720 * 2^-17
0.1             7205759403792794 * 2^-56
1e-308          2024022533073106 * 2^-1074
pow(2, 1020)    1048576 * 2^1000

These outputs are not in decimal, and they're not particularly convenient for human comprehension, but they are 100% accurate.
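One way to check the "100% accurate" claim is to rebuild the double from m and e and compare it with the input. The sketch below is a variant of decode, not the original: it returns the fields instead of printing them, reads the bits with memcpy instead of the TValue union, and the names decode_parts and scale2 are hypothetical. scale2 doubles or halves one step at a time. This is exact here because every intermediate value is representable, and it also avoids depending on libm's ldexp.

```c
#include <stdint.h>
#include <string.h>

/* Hypothetical variant of decode(): same m/e logic, returning the fields.
   Only meaningful for finite inputs. */
static void decode_parts(double n, uint64_t* m_out, int32_t* e_out) {
  uint64_t bits;
  memcpy(&bits, &n, sizeof bits);
  uint32_t hi = (uint32_t)(bits >> 32), lo = (uint32_t)bits;
  int32_t e = (hi >> 20) & 0x7ff;
  uint64_t m = hi & 0xfffff;
  if (e == 0) {
    e++;            /* denormal: no implicit bit, same scale as e == 1 */
  } else {
    m |= 0x100000;  /* normal: add the implicit leading 1 bit */
  }
  e -= 1043;        /* 1023 exponent bias + 20 mantissa bits held in hi */
  if (lo) {
    e -= 32;
    m = (m << 32) | lo;
  }
  *m_out = m;
  *e_out = e;
}

/* Computes m * 2^e by repeated doubling or halving; exact for the (m, e)
   pairs produced above, since each intermediate is representable. */
static double scale2(uint64_t m, int32_t e) {
  double x = (double)m;  /* m < 2^53, so this conversion is exact */
  for (; e > 0; --e) x *= 2.0;
  for (; e < 0; ++e) x *= 0.5;
  return x;
}
```

For example, decode_parts(0.1, &m, &e) yields the same 7205759403792794 and -56 as the table above, and scale2(m, e) gives back exactly 0.1.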

From m and e values, one classical approach for getting decimal digits is:

  1. Convert m to a binary bignum (easy).
  2. Multiply the binary bignum by $2^e$ (easy).
  3. In a loop, compute the binary bignum modulo ten to get one digit (hard), and then divide the binary bignum by ten (hard).

I prefer the following approach, as it gets rid of the hard part:

  1. Convert m to a decimal bignum (not easy, but not hard).
  2. Multiply the decimal bignum by $2^e$ (not easy, but not hard).
  3. In a loop, print the digits of the decimal bignum (easy).

What do I mean by binary bignum and decimal bignum? A binary bignum is a number stored as multiple uint32_t pieces, with each piece being in the range 0 through $2^{32}-1$, and the number being $\sum_i \mathit{piece}_i \cdot (2^{32})^i$. On the other hand, a decimal bignum is a number stored as multiple uint32_t pieces, with each piece being in the range 0 through $10^9-1$, and the number being $\sum_i \mathit{piece}_i \cdot (10^9)^i$. $10^9$ is chosen as it is the largest power of ten which fits into a uint32_t (even better, $10^9$ requires 29.9 bits to store, so only 2.1 bits out of every 32 are wasted).
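To make the decimal representation concrete, here is a small sketch that splits a uint64_t into base-$10^9$ pieces, least significant first. nd_from_u64 and nd_to_u64 are hypothetical helpers and are not part of the code that follows. For example, the mantissa of 0.1, 7205759403792794, becomes the two pieces 7205759 and 403792794.

```c
#include <stdint.h>

/* Hypothetical helper: writes x into nd as base-10^9 pieces, least
   significant piece at index 0, and returns the index of the top piece. */
static int32_t nd_from_u64(uint32_t* nd, uint64_t x) {
  int32_t ndhi = 0;
  nd[0] = (uint32_t)(x % 1000000000);
  while ((x /= 1000000000) != 0) {
    nd[++ndhi] = (uint32_t)(x % 1000000000);
  }
  return ndhi;
}

/* Inverse: evaluates sum(nd[i] * (10^9)^i) for i = 0..ndhi. Only valid
   when the result fits in a uint64_t. */
static uint64_t nd_to_u64(const uint32_t* nd, int32_t ndhi) {
  uint64_t x = 0;
  for (int32_t i = ndhi; i >= 0; --i) x = x * 1000000000 + nd[i];
  return x;
}
```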

I'd like to represent a decimal bignum using the following variables:

uint32_t nd[128];
int32_t ndlo;
int32_t ndhi;

For ndlo <= i <= ndhi, $\mathit{piece}_i$ is nd[i & 127], and otherwise $\mathit{piece}_i$ is zero. Negative values of i hold the part of the number after the decimal point.
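A minimal accessor spelling out that rule might look like the sketch below (nd_get is a hypothetical name). Like the rest of this code, it relies on i & 127 mapping negative indices onto the top of the array, so that nd behaves as a 128-entry ring buffer and piece -1 lives in nd[127]. That holds on two's-complement targets, which in practice means every target.

```c
#include <stdint.h>

/* Hypothetical helper: returns piece i of the decimal bignum, treating
   anything outside [ndlo, ndhi] as zero. */
static uint32_t nd_get(const uint32_t* nd, int32_t ndlo, int32_t ndhi,
                       int32_t i) {
  return (ndlo <= i && i <= ndhi) ? nd[i & 127] : 0;
}
```

For 1.5, for instance, one would have ndlo = -1, ndhi = 0, nd[0] = 1 and nd[127] = 500000000.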

As it happens, we already know how to print the digits of an individual decimal bignum piece. Using one of those functions from last time, printing an entire decimal bignum is painless:

void nd_print(char* p, uint32_t* nd, int32_t ndlo, int32_t ndhi) {
  int32_t i;
  for (i = ndhi; i >= 0; --i) {
    nasonov9(p, nd[i & 127]); p += 9;
  }
  *p++ = '.';
  for (; i >= ndlo; --i) {
    nasonov9(p, nd[i & 127]); p += 9;
  }
  *p = 0;
}
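nd_print depends on nasonov9, which needs SSE2. For experimenting on other targets, a sketch like the following swaps in a plain digit loop in the style of divmod9, under the hypothetical name digits9. nd_print itself is unchanged apart from taking a const pointer.

```c
#include <stdint.h>
#include <string.h>

/* Stand-in for nasonov9: writes exactly nine digits of n (< 10^9),
   zero-padded, without a terminating NUL. */
static void digits9(char* p, uint32_t n) {
  for (int i = 8; i >= 0; --i) {
    p[i] = (char)('0' + n % 10);
    n /= 10;
  }
}

/* Same structure as nd_print: integer pieces ndhi..0, then '.', then
   fractional pieces -1..ndlo. */
static void nd_print(char* p, const uint32_t* nd, int32_t ndlo,
                     int32_t ndhi) {
  int32_t i;
  for (i = ndhi; i >= 0; --i) {
    digits9(p, nd[i & 127]); p += 9;
  }
  *p++ = '.';
  for (; i >= ndlo; --i) {
    digits9(p, nd[i & 127]); p += 9;
  }
  *p = 0;
}
```

With nd[0] = 1, nd[127] = 500000000, ndlo = -1 and ndhi = 0, this writes 000000001.500000000.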

Multiplying by $2^e$ takes a bit of effort. Let's start by considering the case where e is negative, at which point we're really dividing by $2^{-e}$. In turn, this is dividing by 2, $-e$ times. Just dividing by 2 doesn't take that much code:

int32_t nd_div2(uint32_t* nd, int32_t ndlo, int32_t ndhi) {
  uint32_t i = ndhi & 127, carry = 0;
  for (;;) {
    uint32_t val = nd[i];
    nd[i] = (val >> 1) + carry;
    carry = (val & 1) * 500000000;
    if (i == (ndlo & 127)) break;
    i = (i - 1) & 127;
  }
  if (carry) nd[--ndlo & 127] = carry;
  return ndlo;
}
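As a worked example, dividing 1 by two three times should leave 0.125. That is, a zero integer piece plus a single fractional piece of 125000000, stored at nd[-1 & 127] = nd[127]. The sketch repeats nd_div2 with added comments so that it compiles on its own.

```c
#include <stdint.h>

/* nd_div2 as above: halves the bignum in place, walking from the most
   significant piece down, and returns the (possibly decreased) ndlo. */
static int32_t nd_div2(uint32_t* nd, int32_t ndlo, int32_t ndhi) {
  uint32_t i = ndhi & 127, carry = 0;
  for (;;) {
    uint32_t val = nd[i];
    nd[i] = (val >> 1) + carry;
    /* An odd piece leaves half a unit, i.e. 5 * 10^8 of the piece below. */
    carry = (val & 1) * 500000000;
    if (i == (ndlo & 127)) break;
    i = (i - 1) & 127;
  }
  /* A leftover carry becomes a new, lower fractional piece. */
  if (carry) nd[--ndlo & 127] = carry;
  return ndlo;
}
```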

We can generalise this to dividing by $2^k$:

int32_t nd_div2k(uint32_t* nd, int32_t ndlo, int32_t ndhi,
                 uint32_t k) {
  uint32_t mask = (1U << k) - 1, mul = 1000000000 >> k;
  uint32_t i = ndhi & 127, carry = 0;
  for (;;) {
    uint32_t val = nd[i];
    nd[i] = (val >> k) + carry;
    carry = (val & mask) * mul;
    if (i == (ndlo & 127)) break;
    i = (i - 1) & 127;
  }
  if (carry) nd[--ndlo & 127] = carry;
  return ndlo;
}

The above will work for k between zero and nine inclusive. If k is larger than nine, then $10^9 / 2^k$ is no longer an integer (as $10^9 = 2^9 \cdot 5^9$), so 1000000000 >> k no longer computes it exactly. As such, the complete solution has to start by dividing in batches of nine:

int32_t nd_div2k(uint32_t* nd, int32_t ndlo, int32_t ndhi,
                 uint32_t k) {
  while (k >= 9) {
    uint32_t i = ndhi & 127, carry = 0;
    for (;;) {
      uint32_t val = nd[i];
      nd[i] = (val >> 9) + carry;
      carry = (val & 0x1ff) * 1953125;
      if (i == (ndlo & 127)) break;
      i = (i - 1) & 127;
    }
    if (carry) nd[--ndlo & 127] = carry;
    k -= 9;
  }
  if (k) {
    uint32_t mask = (1U << k) - 1, mul = 1000000000 >> k;
    uint32_t i = ndhi & 127, carry = 0;
    for (;;) {
      uint32_t val = nd[i];
      nd[i] = (val >> k) + carry;
      carry = (val & mask) * mul;
      if (i == (ndlo & 127)) break;
      i = (i - 1) & 127;
    }
    if (carry) nd[--ndlo & 127] = carry;
  }
  return ndlo;
}
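To see the batching in action, dividing 1 by $2^{10}$ should produce 0.0009765625. One batch of nine gives 0.001953125, which is a single fractional piece of 1953125. The remaining one-bit step halves that and spills 500000000 into a new, lower piece. The sketch restates the complete nd_div2k compactly by folding both loops into a hypothetical helper, nd_div2k_step. The arithmetic is unchanged.

```c
#include <stdint.h>

/* One pass dividing by 2^k for 1 <= k <= 9; mul = 10^9 / 2^k is exact. */
static int32_t nd_div2k_step(uint32_t* nd, int32_t ndlo, int32_t ndhi,
                             uint32_t k) {
  uint32_t mask = (1U << k) - 1, mul = 1000000000 >> k;
  uint32_t i = ndhi & 127, carry = 0;
  for (;;) {
    uint32_t val = nd[i];
    nd[i] = (val >> k) + carry;
    carry = (val & mask) * mul;  /* remainder, rescaled into the next piece */
    if (i == (ndlo & 127)) break;
    i = (i - 1) & 127;
  }
  if (carry) nd[--ndlo & 127] = carry;
  return ndlo;
}

/* Equivalent to the complete nd_div2k: batches of nine, then the rest. */
static int32_t nd_div2k(uint32_t* nd, int32_t ndlo, int32_t ndhi,
                        uint32_t k) {
  for (; k >= 9; k -= 9) ndlo = nd_div2k_step(nd, ndlo, ndhi, 9);
  if (k) ndlo = nd_div2k_step(nd, ndlo, ndhi, k);
  return ndlo;
}
```

At the other extreme, nd_div2k(nd, 0, 0, 1074) applied to 1 yields the smallest denormal, which needs 120 fractional pieces (1074 digits, rounded up to a multiple of nine).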

We can then go through the same process for multiplying, starting with multiplication by 2:

int32_t nd_mul2(uint32_t* nd, int32_t ndhi) {
  uint32_t carry_in = 0;
  for (uint32_t i = 0; i <= (uint32_t)ndhi; i++) {
    uint32_t val = (nd[i] << 1) | carry_in;
    carry_in = val / 1000000000;
    nd[i] = val - carry_in * 1000000000;
  }
  if (carry_in) nd[++ndhi] = carry_in;
  return ndhi;
}

By promoting val to 64 bits, this can be generalised to multiplying by $2^k$ for small k:

int32_t nd_mul2k(uint32_t* nd, int32_t ndhi, uint32_t k) {
  uint32_t carry_in = 0;
  for (uint32_t i = 0; i <= (uint32_t)ndhi; i++) {
    uint64_t val = ((uint64_t)nd[i] << k) | carry_in;
    carry_in = (uint32_t)(val / 1000000000);
    nd[i] = (uint32_t)val - carry_in * 1000000000;
  }
  if (carry_in) nd[++ndhi] = carry_in;
  return ndhi;
}

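This time the constraint on k comes from wanting val to be less than $(10^9)^2$, so that the carry out of each piece is itself a valid piece. As val is below $10^9 \cdot 2^k$, this limits k to 29 ($2^{29} < 10^9 < 2^{30}$). It also turns out to be useful to make carry_in a parameter (it must be below $2^{\min(k, 29)}$ for | to behave like +). Note that, unlike nd_div2k, these multiplication routines only touch pieces 0 through ndhi and index nd without & 127, so they assume ndlo is zero. All of which leads to the complete code for multiplying by $2^k$: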

int32_t nd_mul2k(uint32_t* nd, int32_t ndhi, uint32_t k,
                 uint32_t carry_in) {
  while (k >= 29) {
    for (uint32_t i = 0; i <= (uint32_t)ndhi; i++) {
      uint64_t val = ((uint64_t)nd[i] << 29) | carry_in;
      carry_in = (uint32_t)(val / 1000000000);
      nd[i] = (uint32_t)val - carry_in * 1000000000;
    }
    if (carry_in) {
      nd[++ndhi] = carry_in; carry_in = 0;
    }
    k -= 29;
  }
  if (k) {
    for (uint32_t i = 0; i <= (uint32_t)ndhi; i++) {
      uint64_t val = ((uint64_t)nd[i] << k) | carry_in;
      carry_in = (uint32_t)(val / 1000000000);
      nd[i] = (uint32_t)val - carry_in * 1000000000;
    }
    if (carry_in) nd[++ndhi] = carry_in;
  }
  return ndhi;
}
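As a quick check of the multiplication side, starting from 1 and multiplying by $2^{64}$ should give 18446744073709551616. This splits into the pieces 18, 446744073 and 709551616. The sketch restates the complete nd_mul2k compactly through a hypothetical helper, nd_mul2k_step. Like the original, it only touches pieces 0 through ndhi.

```c
#include <stdint.h>

/* One pass multiplying by 2^k for 1 <= k <= 29, adding carry_in
   (which must be below 2^k) into the lowest piece. */
static int32_t nd_mul2k_step(uint32_t* nd, int32_t ndhi, uint32_t k,
                             uint32_t carry_in) {
  for (uint32_t i = 0; i <= (uint32_t)ndhi; i++) {
    uint64_t val = ((uint64_t)nd[i] << k) | carry_in;  /* < 10^9 * 2^k */
    carry_in = (uint32_t)(val / 1000000000);
    nd[i] = (uint32_t)val - carry_in * 1000000000;     /* val mod 10^9 */
  }
  if (carry_in) nd[++ndhi] = carry_in;
  return ndhi;
}

/* Equivalent to the complete nd_mul2k: batches of 29, then the rest. */
static int32_t nd_mul2k(uint32_t* nd, int32_t ndhi, uint32_t k,
                        uint32_t carry_in) {
  for (; k >= 29; k -= 29) {
    ndhi = nd_mul2k_step(nd, ndhi, 29, carry_in);
    carry_in = 0;
  }
  if (k) ndhi = nd_mul2k_step(nd, ndhi, k, carry_in);
  return ndhi;
}
```

The carry_in parameter is what lets print fold in the low 29 bits of the mantissa. For 0.1, nd_mul2k(nd, 0, 29, 429496730) applied to nd[0] = 13421772 produces the pieces 7205759 and 403792794.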

We can plug these routines into the decode function from earlier to create a print function:

void print(double n) {
  TValue t;
  t.n = n;
  if ((t.u32.hi << 1) >= 0xffe00000) {
    if (((t.u32.hi & 0x000fffff) | t.u32.lo) != 0) {
      printf("NaN\n");
    } else {
      printf("Infinity\n");
    }
  } else {
    char buf[1154]; /* 128 pieces of 9 digits, plus '.' and NUL */
    uint32_t nd[128];
    int32_t ndlo = 0;
    int32_t ndhi = 0;
    int32_t e = (t.u32.hi >> 20) & 0x7ff;
    nd[0] = t.u32.hi & 0xfffff;
    if (e == 0) {
      e++;
    } else {
      nd[0] |= 0x100000;
    }
    e -= 1043;
    if (t.u32.lo) {
      e -= 32;
      nd[0] = (nd[0] << 3) | (t.u32.lo >> 29);
      ndhi = nd_mul2k(nd, ndhi, 29, t.u32.lo & 0x1fffffff);
    }
    if (e >= 0) {
      ndhi = nd_mul2k(nd, ndhi, (uint32_t)e, 0);
    } else {
      ndlo = nd_div2k(nd, ndlo, ndhi, (uint32_t)-e);
    }
    nd_print(buf, nd, ndlo, ndhi);
    printf("%s\n", buf);
  }
}

Some example outputs of this print function are:

n               Output
0./0.           NaN
1./0.           Infinity
0.              000000000.
1.              000000001.
10.             000000010.
0.1             000000000000000000.1000000000000000055511
                15123125782702118158340454101562500000000
1e-308          000000000000000000.0000000000000000000000
                00000000000000000000000000000000000000000
                00000000000000000000000000000000000000000
                00000000000000000000000000000000000000000
                00000000000000000000000000000000000000000
                00000000000000000000000000000000000000000
                00000000000000000000000000000000000000000
                00000000000000000000000000000000000000009
                99999999999999909326625337248461995470488
                73403204569370722504933164788134100221702
                36685306110285951575783017584918228243784
                38792553200763769833775473829862512856683
                41346193998972906543693727922885247662294
                86591679434355446221493480729436132941672
                16662821737555414480159115639791276054897
                20142038977058035153396077150619905566488
                97702602917109778267250244017165230316273
                90652604144008597950935492433262042405635
                56399326294969169893097546113480479123599
                46979384052000893178607312050101591177117
                04697471514344499487123311264707354172378
                09953873785021982614510236627959137966047
                18812599767273565216024053297899062477635
                21525981391443887618575275588619928089116
                90506171197530846785775640581096161907433
                18668839610809435427125598308539800029848
                265694454312324523925781250000000
pow(2, 1020)    00000001123558209288947442330815744243140
                45851123561183894160795893800723582922378
                43810195794279832650471001320007117491962
                08485367436055090103890580296441496713277
                36104933390540928297688887250778808824658
                17684505312860552384417646403930092119569
                40880170232270940691778664363999670287115
                4982269052209770601514008576.

Some of these outputs could be tidied up by trimming leading and trailing zeroes, but other than that, this is the output we were aiming for. We can see that the double-precision floating point number closest to 0.1 is actually ever so slightly more than 0.1, that the double-precision floating point number closest to 1e-308 is actually ever so slightly less than 1e-308, and that in some cases the string representation of the exact decimal represented by a float can be extremely long (though, to be fair, 1e-308 is almost as bad as it gets for length).
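The tidying mentioned above could be done as a post-pass over nd_print's output. The following is a minimal sketch, with nd_trim as a hypothetical name. It drops trailing zeros after the point (and the point itself if nothing is left), then drops leading zeros while keeping at least one digit before the point.

```c
#include <string.h>

/* Hypothetical post-pass for nd_print output, e.g. turns
   "000000001.500000000" into "1.5" and "000000000." into "0". */
static void nd_trim(char* s) {
  char* dot = strchr(s, '.');
  if (dot) {
    char* end = dot + strlen(dot);
    while (end > dot + 1 && end[-1] == '0') --end;  /* trailing zeros */
    if (end == dot + 1) end = dot;                   /* bare point */
    *end = '\0';
  }
  char* start = s;
  /* Leading zeros, but keep the digit just before '.' or the end. */
  while (start[0] == '0' && start[1] != '.' && start[1] != '\0') ++start;
  memmove(s, start, strlen(start) + 1);
}
```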

Next time we'll see how to adapt nd_print into something which can behave like the %e, %f, and %g formats of sprintf (i.e. outputs which you might actually want, as opposed to pedantic exact representations).

Converting 9-digit integers to strings

Much has been written on the internet about converting integers to strings (e.g. A, B, C). Normally, when this is considered, half of the problem (or the interesting trick in the solution) revolves around knowing how long the resulting string will be. As such, I'd like to consider a slight variant of the problem: converting an integer between 0 and 999999999 to a nine digit string, and always emitting the full nine digits (so 456 becomes 000000456, for example). We can begin with a simple libc-based approach:

void sprintf9(char* p, uint32_t n) {
  sprintf(p, "%09u", n);
}

Using sprintf is short and pithy, but potentially not the most efficient choice. Another option might be:

void divmod9(char* p, uint32_t n) {
  for (uint32_t i = 9; i--; n /= 10) {
    p[i] = '0' + (n % 10);
  }
}

It might look like this approach is going to be slowed down by looping and dividing quite a lot, but Clang will fully unroll the loop, turn all the division into multiplication, and turn the addition into bitwise-or. The result is impressively efficient. In particular, the divide-by-multiplying trick is something that everyone should read up on - I'd recommend this explanation by "fish" for those unfamiliar with it.
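As a concrete instance of that trick, unsigned 32-bit division by ten can be done with one widening multiply and a shift. 0xCCCCCCCD is $\lceil 2^{35} / 10 \rceil$, a constant compilers commonly use for this. The sketch below (div10 is a hypothetical name) is not taken from Clang's output; it only shows that the identity holds for every uint32_t.

```c
#include <stdint.h>

/* n / 10 without a divide: multiply by ceil(2^35 / 10) and shift right
   by 35. The constant exceeds 2^35 / 10 by 0.2, so the product overshoots
   n * 2^35 / 10 by under 2^35 / 10, which is never enough to reach the
   next multiple of 2^35; hence the result is exact for every 32-bit n. */
static uint32_t div10(uint32_t n) {
  return (uint32_t)(((uint64_t)n * 0xCCCCCCCDu) >> 35);
}
```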

A third alternative is this enigma from LuaJIT:

#define WINT_R(x, sh, sc) { \
  uint32_t d = (x * (((1<<sh)+sc-1)/sc)) >> sh; \
  x -= d * sc; \
  *p++ = (char)('0' + d); }

void lj_strfmt_wuint9(char* p, uint32_t n) {
  uint32_t v = n / 10000, w;
  n -= v * 10000;
  w = v / 10000;
  v -= w * 10000;
  *p++ = (char)('0' + w);
  WINT_R(v, 23, 1000)
  WINT_R(v, 12, 100)
  WINT_R(v, 10, 10)
  *p++ = (char)('0' + v);
  WINT_R(n, 23, 1000)
  WINT_R(n, 12, 100)
  WINT_R(n, 10, 10)
  *p++ = (char)('0' + n);
}

If you consider (x * (((1<<sh)+sc-1)/sc)) >> sh in the realm of mathematical real numbers, rather than the realm of C integer arithmetic, then the expression looks kind of like x * (((1 << sh) / sc) >> sh), which in turn looks kind of like x * (1 / sc), which is of course x / sc. This intuition turns out to be spot-on: this expression is the same divide-by-multiplying trick as used by Clang - sc is the value we want to divide by (what fish calls d), sh is what fish calls k, and (((1<<sh)+sc-1)/sc) is what fish calls m. The choices of sh look slightly magical, but 23 is the smallest k which suffices for dividing numbers in the range 0 through 9999 by 1000, 12 is the smallest k for 0 through 999 by 100, and 10 is the smallest k for 0 through 99 by 10.
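That minimality claim is easy to check by brute force. The sketch below (wint_r_ok and smallest_k are hypothetical names) evaluates the same expression as WINT_R, using 64-bit arithmetic to avoid overflow. It returns the first sh for which the expression matches x / sc for every x from 0 through max.

```c
#include <stdint.h>

/* Does (x * (((1<<sh)+sc-1)/sc)) >> sh equal x / sc for all x in [0, max]? */
static int wint_r_ok(uint32_t max, uint32_t sc, uint32_t sh) {
  uint64_t m = ((UINT64_C(1) << sh) + sc - 1) / sc;  /* ceil(2^sh / sc) */
  for (uint64_t x = 0; x <= max; x++) {
    if (((x * m) >> sh) != x / sc) return 0;
  }
  return 1;
}

/* Smallest sh in [0, 31] that works, or -1 if none does. */
static int smallest_k(uint32_t max, uint32_t sc) {
  for (uint32_t sh = 0; sh < 32; sh++) {
    if (wint_r_ok(max, sc, sh)) return (int)sh;
  }
  return -1;
}
```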

Another option is this SSE2 implementation, which is clear as mud:

#include <emmintrin.h>

void vectorised9(char* p, uint32_t n) {
  __m128i a = _mm_set1_epi32(n);
  __m128i b = _mm_srli_epi64(
      _mm_mul_epu32(a, _mm_set1_epi32(879609303)), 43);
  __m128i c = _mm_shuffle_epi32(_mm_mul_epu32(b,
      _mm_setr_epi32(10000, 0, 429497, 0)), 0x47);
  p[0] = '0' | _mm_cvtsi128_si32(c);
  __m128i d = _mm_sub_epi32(_mm_unpacklo_epi64(b, a),
      _mm_mul_epu32(c, _mm_setr_epi32(10000, 0, 1, 0)));
  __m128i e = _mm_srli_epi32(
      _mm_mul_epu32(d, _mm_set1_epi32(5243)), 19);
  __m128i f = _mm_or_si128(e,
      _mm_shuffle_epi32(_mm_sub_epi32(d,
          _mm_mul_epu32(e, _mm_set1_epi32(100))), 0x91));
  __m128i g = _mm_mulhi_epu16(f, _mm_set1_epi32(6554));
  __m128i h = _mm_slli_si128(_mm_sub_epi32(f,
      _mm_mullo_epi16(g, _mm_set1_epi32(10))), 2);
  __m128i i = _mm_packus_epi16(_mm_or_si128(g, h), h);
  _mm_storel_epi64((__m128i*)(p + 1),
      _mm_or_si128(i, _mm_set1_epi32(0x30303030)));
}

After you get over all the underscores, you can see that this code is full of the same divide-by-multiplying trick. For example, the assignment to g is doing WINT_R(f, 16, 10) four times in parallel (note the choice of k = 16 rather than the minimal k = 10; it could use 10, but using 16 allows it to get a shift for free as part of _mm_mulhi_epu16).
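_mm_mulhi_epu16 multiplies unsigned 16-bit lanes and keeps the high 16 bits of each 32-bit product, which amounts to a right shift by 16. The following scalar model of one lane of the g and h computations (mulhi_u16 and split2 are hypothetical names) shows why 6554, which is $\lceil 2^{16} / 10 \rceil$, splits a two-digit value into its tens and ones.

```c
#include <stdint.h>

/* Scalar model of one _mm_mulhi_epu16 lane: high half of a 16x16 product. */
static uint16_t mulhi_u16(uint16_t a, uint16_t b) {
  return (uint16_t)(((uint32_t)a * b) >> 16);
}

/* Splits f (0..99) into tens and ones, mirroring g and (f - g * 10). */
static void split2(uint16_t f, uint16_t* tens, uint16_t* ones) {
  *tens = mulhi_u16(f, 6554);
  *ones = (uint16_t)(f - *tens * 10);
}
```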

Of course, a consideration of different implementations wouldn't be complete without a benchmark. As such, I present a highly unscientific benchmark of how long it takes a single core of my MacBook to convert every integer between 0 and 999999999 to a string:

                  clang -O2   clang -O2 -m32 -msse2
sprintf9          1m31s       1m52s
divmod9           8.7s        11.7s
lj_strfmt_wuint9  9.5s        11.4s
vectorised9       5.2s        5.2s

We can conclude that vectorisation is a win, that 32-bit code is slightly slower than 64-bit code (except for the vectorised solution), that LuaJIT's cleverness isn't an improvement upon the simple divmod9 (once Clang has optimised the hell out of divmod9), and that sprintf is dog slow.
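For anyone wanting to repeat the comparison, a harness along these lines is one option. It is a hypothetical sketch, not the setup behind the numbers above. It times a converter with the standard clock() function and folds every output byte into a checksum. The checksum keeps the compiler from discarding the work and lets different converters be checked against each other, at the cost of adding some overhead of its own.

```c
#include <stdint.h>
#include <time.h>

/* Converter under test: divmod9 from above. */
static void divmod9(char* p, uint32_t n) {
  for (uint32_t i = 9; i--; n /= 10) {
    p[i] = '0' + (n % 10);
  }
}

/* Runs fn over [0, limit), returning a checksum of all digits written
   and storing the elapsed CPU time in *seconds. */
static uint64_t time9(void (*fn)(char*, uint32_t), uint32_t limit,
                      double* seconds) {
  char buf[16];  /* room for implementations that also write a NUL */
  uint64_t sum = 0;
  clock_t start = clock();
  for (uint32_t n = 0; n < limit; n++) {
    fn(buf, n);
    for (int i = 0; i < 9; i++) sum += (unsigned char)buf[i];
  }
  *seconds = (double)(clock() - start) / CLOCKS_PER_SEC;
  return sum;
}
```

Calling time9(divmod9, 1000000000, &secs) covers the same range as above. For a limit of 10, any correct converter yields a checksum of 4365.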

Of course, your choice of compiler is important. I've used Clang for the above, as it is the default easy option for MacBooks, but we shouldn't forget gcc. Then again, perhaps we should forget gcc: its 64-bit compilation of vectorised9 turns _mm_set1_epi32(n) into a store-to-memory, a load-from-memory, and a shuffle — the shuffle is required, but the store and load are definitely not. Meanwhile, gcc's 32-bit compilation of vectorised9 turns _mm_storel_epi64 into a store-to-memory, two loads-from-memory, and two more stores-to-memory — everything except the first store is completely superfluous.

Next time we'll see how converting 9-digit integers to strings turns out to be very useful in the context of converting floating point numbers to strings.

Update:

@nasonov points out a more efficient vectorised implementation, which clocks in at 3.6s (64-bit) or 4.2s (32-bit) in my benchmark setup:

#include <emmintrin.h>

void nasonov9(char* p, uint32_t u) {
  uint32_t v = u / 10000;
  uint32_t w = v / 10000;
  u -= v * 10000;
  v -= w * 10000;

  const __m128i first_madd =
      _mm_set_epi16(-32768, -32768, 0, 26215, 0, 10486, 0, 8389);
  const __m128i mask =
      _mm_set_epi16(0xffff, 0, 0xfffc, 0, 0xfff0, 0, 0xff80, 0);
  const __m128i second_madd =
      _mm_set_epi16(-256, -640, 64, -160, 16, -20, 2, 0);

  __m128i x = _mm_madd_epi16(_mm_set1_epi16(v), first_madd);
  __m128i y = _mm_madd_epi16(_mm_set1_epi16(u), first_madd);
  x = _mm_and_si128(x, mask);
  y = _mm_and_si128(y, mask);
  x = _mm_or_si128(x, _mm_slli_si128(x, 2));
  y = _mm_or_si128(y, _mm_slli_si128(y, 2));
  x = _mm_madd_epi16(x, second_madd);
  y = _mm_madd_epi16(y, second_madd);

  __m128i z = _mm_srli_epi16(_mm_packs_epi32(x, y), 8);
  z = _mm_packs_epi16(z, z);
  p[0] = '0' | w;
  _mm_storel_epi64((__m128i*)(p + 1),
      _mm_or_si128(z, _mm_set1_epi32(0x30303030)));
}

Finding a 5-year-old LuaJIT bug

LuaJIT#311 is the kind of bug report which compiler authors love: a relatively small self-contained file of code, which when fed through the compiler, does the wrong thing. The issue is now resolved, but I'd like to write about the thought process by which it was resolved.

With a bunch of code omitted for brevity, the issue claims that the following file, row.lua, is problematic:

-- ... lots of code ...

for i=0,100 do
  jit.flush()

  r=TIntArrayNative.new{ 0,1,2,-1,-2,-3 }
  res=ff( TIntArrayNative.new{ 1,2,3,2,1,0 } )
  assert( cEql(r, res), 'ff1 failed' )

  r=TIntArrayNative.new{ 0,0,0 }
  res=ff( TIntArrayNative.new{ 0, 0, 0 } )
  assert( cEql(r, res), 'ff2 failed' )

  r=TIntArrayNative.new{ 0,1,-1 }
  res=ff( TIntArrayNative.new{ 0, 1, 0 } )
  assert( cEql(r, res), 'ff3 failed' )

  r=TIntArrayNative.new{ 0,-1,-2,-3, 1,2,3 }
  res=ff( TIntArrayNative.new{ 0,-1,-2,-4,0,1,2 } )
  assert( cEql(r, res), 'ff4 failed' )
end

From this code, we can conclude that the issue isn't perfectly reproducible (hence the 100-iteration loop), and that the issue only affects some traces and not others (the call to jit.flush at the top of the loop will result in LuaJIT sometimes choosing to compile different traces on each iteration). Bearing this in mind, trying to reproduce the issue was often fruitless:

$ luajit.exe row.lua && echo OK
OK

But, sometimes, an issue would present itself:

$ luajit.exe row.lua && echo OK
luajit.exe: row.lua:199: ff4 failed
stack traceback:
        [C]: in function 'assert'
        row.lua:199: in main chunk
        [C]: at 0x7ff61f3c1eb0

Given the proportion of times the issue didn't appear, "not perfectly reproducible" was perhaps an understatement, so increasing reproducibility was the first priority. To begin, the -jdump=t flag can give an idea of what code gets selected for tracing - the output is along the lines of:

$ luajit.exe -jdump=t row.lua 
---- TRACE flush

---- TRACE 1 start row.lua:69
---- TRACE 1 stop -> return

---- TRACE flush

---- TRACE 1 start row.lua:107
---- TRACE 1 stop -> loop

---- TRACE 2 start row.lua:69
---- TRACE 2 stop -> return

---- TRACE flush

---- TRACE 1 start row.lua:74
---- TRACE 1 stop -> return

---- TRACE 2 start row.lua:69
---- TRACE 2 stop -> return

---- TRACE 3 start row.lua:170
---- TRACE 3 abort row.lua:175 -- leaving loop in root trace

---- TRACE flush

---- TRACE 1 start row.lua:107
---- TRACE 1 stop -> loop

---- TRACE 2 start row.lua:145
---- TRACE 2 stop -> loop

---- TRACE flush

---- TRACE 1 start row.lua:69
---- TRACE 1 stop -> return

---- TRACE 2 start row.lua:170
---- TRACE 2 stop -> loop

Now comes the game of deciding which trace site to investigate. The traces which finish with stop -> return are less interesting than the traces which finish with stop -> loop (as looping traces have more opportunities for mischief). With that in mind, the row.lua:107 loop is:

      for _,v in ipairs(n) do 
        a[k] = v
        k = k + 1
      end

The row.lua:145 loop is:

  for i=1,#a do
    if a[i] > a[i-1] then
      if r[i-1] > 0 then
        r[i] = r[i-1] + 1
      else
        r[i] = 1
      end
    end

    if a[i] < a[i-1] then
      if r[i-1] < 0 then
        r[i] = r[i-1] - 1
      else
        r[i] = -1
      end
    end
  end

The row.lua:170 loop is:

  for i=0,#a1 do
    if a1[i] ~= a2[i] then
      return false
    end
  end

Of these three loops, row.lua:145 contains the most branches, and therefore is the most likely to sometimes have LuaJIT end up choosing a different sequence of branch directions to trace. Given the hypothesis that branch choice is crucial to the issue, this seemed like a good loop to focus on. One way to focus on it is to prevent JIT compilation of other things (by means of jit.on() and jit.off()), and then to play around with JIT compiler parameters relating to trace selection, eventually stumbling upon this:

@@ -142,6 +142,8 @@ end
 function ff(a)
   local r=TIntArrayNative.new(#a)
 
+  jit.opt.start("hotloop=2")
+  jit.on()
   for i=1,#a do
     if a[i] > a[i-1] then
       if r[i-1] > 0 then
@@ -159,6 +161,7 @@ function ff(a)
       end
     end
   end
+  jit.off()
 
   return r
 end

With this diff in place, the issue becomes reproducible every time, and the next stage of investigation can begin - by which I mean the level of data dumping can be increased from -jdump=t all the way up to -jdump=bitmsrx:

$ luajit.exe -jdump=bitmsrx row.lua 
---- TRACE flush

---- TRACE 1 start row.lua:69
... some bytecode ...
---- TRACE 1 IR
... some IR ...
---- TRACE 1 mcode 184
... some machine code ...
---- TRACE 1 stop -> return

---- TRACE 2 start row.lua:147
... lots of bytecode ...
---- TRACE 2 IR
... lots of IR ...
0039 ------------ LOOP ------------
... more IR ...
---- TRACE 2 mcode 225
... lots of machine code ...
---- TRACE 2 stop -> loop

---- TRACE 2 exit 4
---- TRACE 2 exit 4
---- TRACE 3 start row.lua:74
... some bytecode ...
---- TRACE 3 IR
... some IR ...
---- TRACE 3 mcode 108
... some machine code ...
---- TRACE 3 stop -> return

---- TRACE 3 exit 0
---- TRACE 3 exit 0
---- TRACE 3 exit 0
---- TRACE 3 exit 0
---- TRACE 3 exit 0
---- TRACE 3 exit 0
---- TRACE 2 exit 2
---- TRACE 2 exit 2
---- TRACE 3 exit 0
---- TRACE 3 exit 0
---- TRACE 3 exit 0
---- TRACE 3 exit 0
---- TRACE 4 start 3/0 row.lua:75
... some bytecode ...
---- TRACE 4 IR
... some IR ...
---- TRACE 4 mcode 99
... some machine code ...
---- TRACE 4 stop -> return

---- TRACE 2 exit 1
---- TRACE 2 exit 5
---- TRACE 2 exit 7
---- TRACE 2 exit 1
---- TRACE 2 exit 1
luajit.exe: row.lua:202: ff4 failed
stack traceback:
  [C]: in function 'assert'
  row.lua:202: in main chunk
  [C]: at 0x0100001440

The loop of interest has become TRACE 2, and some of the metamethods it invokes have become TRACE 1, TRACE 3, and TRACE 4. After a quick look over the IR of all these traces, it was the IR of TRACE 2 which caught my attention:

---- TRACE 2 IR
....              SNAP   #0   [ ---- ]
0001 rax   >  int SLOAD  #4    CRI
0002       >  int LE     0001  +2147483646
0003 rbp      int SLOAD  #3    CI
0004 r10   >  cdt SLOAD  #1    T
0005          u16 FLOAD  0004  cdata.ctypeid
0006       >  int EQ     0005  +102
0007          p64 ADD    0004  +8  
0008 rdx      p64 XLOAD  0007  
0012          i64 BSHL   0003  +3  
0013          p64 ADD    0012  0008
0014          p64 ADD    0013  +8  
0015 r8       i64 XLOAD  0014  
0019 r9       i64 XLOAD  0013  
....              SNAP   #1   [ ---- ---- ---- 0003 0001 ---- 0003 ]
0021       >  i64 GE     0019  0015
....              SNAP   #2   [ ---- ---- ---- 0003 0001 ---- ---- ]
0022       >  i64 GT     0019  0015
....              SNAP   #3   [ ---- ---- ---- 0003 0001 ---- 0003 ]
0023 rsi   >  cdt SLOAD  #2    T
0024          u16 FLOAD  0023  cdata.ctypeid
0025       >  int EQ     0024  +102
0026          p64 ADD    0023  +8  
0027 rcx      p64 XLOAD  0026  
0031          p64 ADD    0027  0012
0032          i64 XLOAD  0031  
....              SNAP   #4   [ ---- ---- ---- 0003 0001 ---- 0003 ]
0034       >  i64 GE     0032  +0  
0035          p64 ADD    0031  +8  
0036          i64 XSTORE 0035  -1  
0037 rbp    + int ADD    0003  +1  
....              SNAP   #5   [ ---- ---- ---- ]
0038       >  int LE     0037  0001
....              SNAP   #6   [ ---- ---- ---- 0037 0001 ---- 0037 ]
0039 ------------ LOOP ------------
0040          i64 BSHL   0037  +3  
0041          p64 ADD    0040  0008
0042          p64 ADD    0041  +8  
0043 rbx      i64 XLOAD  0042  
0044 r15      i64 XLOAD  0041  
....              SNAP   #7   [ ---- ---- ---- 0037 0001 ---- 0037 ]
0045       >  i64 GE     0044  0043
....              SNAP   #8   [ ---- ---- ---- 0037 0001 ---- 0037 ]
0046       >  i64 GT     0044  0043
0047          p64 ADD    0040  0027
0048          p64 ADD    0047  +8  
0049          i64 XSTORE 0048  -1  
0050 rbp    + int ADD    0037  +1  
....              SNAP   #9   [ ---- ---- ---- ]
0051       >  int LE     0050  0001
0052 rbp      int PHI    0037  0050
---- TRACE 2 mcode 225

This being a looping trace, all of the instructions before -- LOOP -- should correspond to one iteration of the loop, and similarly all of the instructions after -- LOOP -- should correspond to one iteration of the loop (the difference being that the second set of instructions can rely on a bunch of assumptions and invariants set up by the first set of instructions, and thus can be shorter and more efficient). For example, 0012 i64 BSHL 0003 +3 and 0040 i64 BSHL 0037 +3 represent the same thing (namely, the multiplication of i by 8 as part of array indexing). Similarly, 0021 > i64 GE 0019 0015 and 0045 > i64 GE 0044 0043 both represent the > in if a[i] > a[i-1] then, and 0022 > i64 GT 0019 0015 and 0046 > i64 GT 0044 0043 both represent the < in if a[i] < a[i-1] then.

The thing which jumped out at me was 0034 > i64 GE 0032 +0 (which comes from the < in if r[i-1] < 0 then), which has no corresponding instruction after -- LOOP --. In other words, LuaJIT concluded that it only needed to check r[i-1] < 0 was false once, and that thereafter, it could assume r[i-1] < 0 was false without needing to check. The question became: why did LuaJIT conclude this?

To answer this question, attention moved to the implementation of LuaJIT's "LOOP" optimisation (in lj_opt_loop.c), and in particular the part which emits all of the instructions after -- LOOP --:

/* Copy and substitute all recorded instructions and snapshots. */
for (ins = REF_FIRST; ins < invar; ins++) {
  ...
  /* Substitute instruction operands. */
  ir = IR(ins);
  op1 = ir->op1;
  if (!irref_isk(op1)) op1 = subst[op1];
  op2 = ir->op2;
  if (!irref_isk(op2)) op2 = subst[op2];
  if (irm_kind(lj_ir_mode[ir->o]) == IRM_N &&
      op1 == ir->op1 && op2 == ir->op2) {
    subst[ins] = (IRRef1)ins;  /* Shortcut. */
  } else {
    /* Re-emit substituted instruction to the FOLD/CSE/etc. pipeline. */
    IRType1 t = ir->t;
    IRRef ref = tref_ref(emitir(ir->ot & ~IRT_ISPHI, op1, op2));
    subst[ins] = (IRRef1)ref;
    ...
  }
}

The key parts of this code are:

To see what happened to 0034 > i64 GE 0032 +0, a conditional breakpoint can be set on the ir = IR(ins) line (with the condition being J->cur.traceno == 2 && ins == REF_BASE + 34):

The instruction's opcode ((IROp)ir->o) is GE, as expected. The original left-hand operand (ir->op1) is 0x8020, which after dropping the 8 and converting the 20 from hex to decimal, gives 0032, as expected. The remaining watches are less obvious: they're all less than 0x8000, which means that they are constants rather than instructions. Knowing that they happen to be 64-bit signed integer constants, their values are just another watch away:

This shows that the right-hand operand (op2) is the constant 0, as expected. Furthermore, LuaJIT thinks that the left-hand operand (op1) is the constant -1. Given that the original source code was r[i-1] < 0, this might seem slightly surprising, but it is in fact LuaJIT being clever: the previous loop iteration did r[i] = -1 (this is 0036 i64 XSTORE 0035 -1 in the IR), and i increments by one on each iteration (this is 0037 rbp + int ADD 0003 +1 / 0050 rbp + int ADD 0037 +1 / 0052 rbp int PHI 0037 0050 in the IR), and so LuaJIT is correct to conclude that on this iteration, r[i-1] is -1 (reaching this conclusion requires quite a lot of cleverness around re-association and forwarding and aliasing, but that's a story for another day).

So, at this point, everything is lined up for emitir to emit (or find, etc.) the instruction int64_t(-1) >= int64_t(0). Given both operands are constants, this instruction should get constant-folded, and indeed emitir ends up at fold_kfold_int64comp from lj_opt_fold.c:

/* Special return values for the fold functions. */
enum {
  ...
  FAILFOLD,   /* Guard would always fail. */
  DROPFOLD,   /* Guard eliminated. */
  ...
};
...
#define CONDFOLD(cond)  ((TRef)FAILFOLD + (TRef)(cond))
...
LJFOLD(LT KINT64 KINT64)
LJFOLD(GE KINT64 KINT64)
LJFOLD(LE KINT64 KINT64)
LJFOLD(GT KINT64 KINT64)
LJFOLD(ULT KINT64 KINT64)
LJFOLD(UGE KINT64 KINT64)
LJFOLD(ULE KINT64 KINT64)
LJFOLD(UGT KINT64 KINT64)
LJFOLDF(kfold_int64comp)
{
#if LJ_HASFFI
  uint64_t a = ir_k64(fleft)->u64, b = ir_k64(fright)->u64;
  switch ((IROp)fins->o) {
  case IR_LT: return CONDFOLD(a < b);
  case IR_GE: return CONDFOLD(a >= b);
  case IR_LE: return CONDFOLD(a <= b);
  case IR_GT: return CONDFOLD(a > b);
  case IR_ULT: return CONDFOLD((uint64_t)a < (uint64_t)b);
  case IR_UGE: return CONDFOLD((uint64_t)a >= (uint64_t)b);
  case IR_ULE: return CONDFOLD((uint64_t)a <= (uint64_t)b);
  case IR_UGT: return CONDFOLD((uint64_t)a > (uint64_t)b);
  default: lua_assert(0); return FAILFOLD;
  }
#else
  UNUSED(J); lua_assert(0); return FAILFOLD;
#endif
}

Any sane person should conclude that int64_t(-1) >= int64_t(0) is false (FAILFOLD), but if you follow the code in the above snippet, you'll see that it returns DROPFOLD (true) for int64_t(-1) IR_GE int64_t(0), because it always does unsigned comparison. Therein lies the bug: the four instructions with U in their name are meant to do unsigned comparison, but the four instructions without U in their name are meant to do signed comparison. At runtime, they do indeed do this, but at constant-folding-time, the signed comparisons do the wrong thing if one operand is negative and the other is non-negative.

The fix is simple: have fold_kfold_int64comp perform signed comparisons for the four instructions which represent signed comparisons.
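In standalone form, the corrected logic looks something like the sketch below. The CmpOp enum and fold_cmp64 are hypothetical stand-ins for LuaJIT's IR opcodes and fold rule, not LuaJIT's actual patch. The constants arrive as raw 64-bit patterns. The signed opcodes reinterpret them as int64_t, and the U-prefixed opcodes compare them as unsigned.

```c
#include <stdbool.h>
#include <stdint.h>

/* Hypothetical stand-ins for IR_LT ... IR_UGT. */
typedef enum { OP_LT, OP_GE, OP_LE, OP_GT,
               OP_ULT, OP_UGE, OP_ULE, OP_UGT } CmpOp;

/* Folds a comparison of two 64-bit constants given as raw bits. The
   (int64_t) casts assume a two's-complement target. */
static bool fold_cmp64(CmpOp op, uint64_t a, uint64_t b) {
  int64_t sa = (int64_t)a, sb = (int64_t)b;
  switch (op) {
  case OP_LT:  return sa < sb;   /* signed */
  case OP_GE:  return sa >= sb;
  case OP_LE:  return sa <= sb;
  case OP_GT:  return sa > sb;
  case OP_ULT: return a < b;     /* unsigned */
  case OP_UGE: return a >= b;
  case OP_ULE: return a <= b;
  case OP_UGT: return a > b;
  }
  return false;
}
```

With the inputs from this bug, fold_cmp64(OP_GE, (uint64_t)-1, 0) is false, so the guard would always fail (FAILFOLD) instead of being dropped. In contrast, fold_cmp64(OP_UGE, (uint64_t)-1, 0) is true.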

Micro-optimisations can speed up CPython

Last time, I bemoaned what compilers did to some of the CPython interpreter main loop. Following those remarks, there are three obvious courses of action:

  1. Make targeted improvements to the compilers.
  2. Write the interpreter main loop directly in assembly.
  3. Tweak the C source code to make it more amenable to good compilation.

Option three is the easiest to explore, so let's start with a random benchmark to use as a baseline:

Python 3.6.0+ (default, Mar  7 2017, 00:04:40) 
[GCC 4.8.5 20150623 (Red Hat 4.8.5-11)] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from performance.benchmarks.bm_nbody import bench_nbody
>>> bench_nbody(10, 'sun', 100000)
9.907924063038081

The following patch is a little long, but each individual change is relatively boring, and all the changes are motivated by what we saw in gcc's assembly:

diff --git a/Python/ceval.c b/Python/ceval.c
index d5172b9..79ccf2a 100644
--- a/Python/ceval.c
+++ b/Python/ceval.c
@@ -729,7 +729,7 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
     int opcode;        /* Current opcode */
     int oparg;         /* Current opcode argument, if any */
     enum why_code why; /* Reason for block stack unwind */
-    PyObject **fastlocals, **freevars;
+    PyObject **freevars;
     PyObject *retval = NULL;            /* Return value */
     PyThreadState *tstate = PyThreadState_GET();
     PyCodeObject *co;
@@ -865,7 +865,7 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
 /* Code access macros */
 
 /* The integer overflow is checked by an assertion below. */
-#define INSTR_OFFSET()  (sizeof(_Py_CODEUNIT) * (int)(next_instr - first_instr))
+#define INSTR_OFFSET()  ((char*)next_instr - (char*)first_instr)
 #define NEXTOPARG()  do { \
         _Py_CODEUNIT word = *next_instr; \
         opcode = _Py_OPCODE(word); \
@@ -959,7 +959,7 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
 
 /* Local variable macros */
 
-#define GETLOCAL(i)     (fastlocals[i])
+#define GETLOCAL(i)     (f->f_localsplus[i])
 
 /* The SETLOCAL() macro must not DECREF the local variable in-place and
    then store the new value; it must copy the old value to a temporary
@@ -1045,7 +1045,6 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
     co = f->f_code;
     names = co->co_names;
     consts = co->co_consts;
-    fastlocals = f->f_localsplus;
     freevars = f->f_localsplus + co->co_nlocals;
     assert(PyBytes_Check(co->co_code));
     assert(PyBytes_GET_SIZE(co->co_code) <= INT_MAX);
@@ -1228,7 +1227,7 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
             FAST_DISPATCH();
 
         TARGET(LOAD_FAST) {
-            PyObject *value = GETLOCAL(oparg);
+            PyObject *value = GETLOCAL((unsigned)oparg);
             if (value == NULL) {
                 format_exc_check_arg(PyExc_UnboundLocalError,
                                      UNBOUNDLOCAL_ERROR_MSG,
@@ -1242,7 +1241,7 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
 
         PREDICTED(LOAD_CONST);
         TARGET(LOAD_CONST) {
-            PyObject *value = GETITEM(consts, oparg);
+            PyObject *value = GETITEM(consts, (unsigned)oparg);
             Py_INCREF(value);
             PUSH(value);
             FAST_DISPATCH();
@@ -1251,7 +1250,7 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
         PREDICTED(STORE_FAST);
         TARGET(STORE_FAST) {
             PyObject *value = POP();
-            SETLOCAL(oparg, value);
+            SETLOCAL((unsigned)oparg, value);
             FAST_DISPATCH();
         }
 
@@ -1526,7 +1525,7 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
 
         TARGET(LIST_APPEND) {
             PyObject *v = POP();
-            PyObject *list = PEEK(oparg);
+            PyObject *list = PEEK((size_t)(unsigned)oparg);
             int err;
             err = PyList_Append(list, v);
             Py_DECREF(v);
@@ -1731,7 +1730,7 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
             _Py_IDENTIFIER(__annotations__);
             PyObject *ann_dict;
             PyObject *ann = POP();
-            PyObject *name = GETITEM(names, oparg);
+            PyObject *name = GETITEM(names, (unsigned)oparg);
             int err;
             if (f->f_locals == NULL) {
                 PyErr_Format(PyExc_SystemError,
@@ -2155,7 +2154,7 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
         }
 
         TARGET(STORE_NAME) {
-            PyObject *name = GETITEM(names, oparg);
+            PyObject *name = GETITEM(names, (unsigned)oparg);
             PyObject *v = POP();
             PyObject *ns = f->f_locals;
             int err;
@@ -2176,7 +2175,7 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
         }
 
         TARGET(DELETE_NAME) {
-            PyObject *name = GETITEM(names, oparg);
+            PyObject *name = GETITEM(names, (unsigned)oparg);
             PyObject *ns = f->f_locals;
             int err;
             if (ns == NULL) {
@@ -2198,7 +2197,7 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
         TARGET(UNPACK_SEQUENCE) {
             PyObject *seq = POP(), *item, **items;
             if (PyTuple_CheckExact(seq) &&
-                PyTuple_GET_SIZE(seq) == oparg) {
+                PyTuple_GET_SIZE(seq) == (Py_ssize_t)(size_t)(unsigned)oparg) {
                 items = ((PyTupleObject *)seq)->ob_item;
                 while (oparg--) {
                     item = items[oparg];
@@ -2206,7 +2205,7 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
                     PUSH(item);
                 }
             } else if (PyList_CheckExact(seq) &&
-                       PyList_GET_SIZE(seq) == oparg) {
+                       PyList_GET_SIZE(seq) == (Py_ssize_t)(size_t)(unsigned)oparg) {
                 items = ((PyListObject *)seq)->ob_item;
                 while (oparg--) {
                     item = items[oparg];
@@ -2215,7 +2214,7 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
                 }
             } else if (unpack_iterable(seq, oparg, -1,
                                        stack_pointer + oparg)) {
-                STACKADJ(oparg);
+                STACKADJ((unsigned)oparg);
             } else {
                 /* unpack_iterable() raised an exception */
                 Py_DECREF(seq);
@@ -2241,7 +2240,7 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
         }
 
         TARGET(STORE_ATTR) {
-            PyObject *name = GETITEM(names, oparg);
+            PyObject *name = GETITEM(names, (unsigned)oparg);
             PyObject *owner = TOP();
             PyObject *v = SECOND();
             int err;
@@ -2255,7 +2254,7 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
         }
 
         TARGET(DELETE_ATTR) {
-            PyObject *name = GETITEM(names, oparg);
+            PyObject *name = GETITEM(names, (unsigned)oparg);
             PyObject *owner = POP();
             int err;
             err = PyObject_SetAttr(owner, name, (PyObject *)NULL);
@@ -2266,7 +2265,7 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
         }
 
         TARGET(STORE_GLOBAL) {
-            PyObject *name = GETITEM(names, oparg);
+            PyObject *name = GETITEM(names, (unsigned)oparg);
             PyObject *v = POP();
             int err;
             err = PyDict_SetItem(f->f_globals, name, v);
@@ -2277,7 +2276,7 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
         }
 
         TARGET(DELETE_GLOBAL) {
-            PyObject *name = GETITEM(names, oparg);
+            PyObject *name = GETITEM(names, (unsigned)oparg);
             int err;
             err = PyDict_DelItem(f->f_globals, name);
             if (err != 0) {
@@ -2289,7 +2288,7 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
         }
 
         TARGET(LOAD_NAME) {
-            PyObject *name = GETITEM(names, oparg);
+            PyObject *name = GETITEM(names, (unsigned)oparg);
             PyObject *locals = f->f_locals;
             PyObject *v;
             if (locals == NULL) {
@@ -2340,7 +2339,7 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
         }
 
         TARGET(LOAD_GLOBAL) {
-            PyObject *name = GETITEM(names, oparg);
+            PyObject *name = GETITEM(names, (unsigned)oparg);
             PyObject *v;
             if (PyDict_CheckExact(f->f_globals)
                 && PyDict_CheckExact(f->f_builtins))
@@ -2385,9 +2384,9 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
         }
 
         TARGET(DELETE_FAST) {
-            PyObject *v = GETLOCAL(oparg);
+            PyObject *v = GETLOCAL((unsigned)oparg);
             if (v != NULL) {
-                SETLOCAL(oparg, NULL);
+                SETLOCAL((unsigned)oparg, NULL);
                 DISPATCH();
             }
             format_exc_check_arg(
@@ -2488,7 +2487,7 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
         }
 
         TARGET(BUILD_TUPLE) {
-            PyObject *tup = PyTuple_New(oparg);
+            PyObject *tup = PyTuple_New((unsigned)oparg);
             if (tup == NULL)
                 goto error;
             while (--oparg >= 0) {
@@ -2500,7 +2499,7 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
         }
 
         TARGET(BUILD_LIST) {
-            PyObject *list =  PyList_New(oparg);
+            PyObject *list =  PyList_New((unsigned)oparg);
             if (list == NULL)
                 goto error;
             while (--oparg >= 0) {
@@ -2571,7 +2570,7 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
                     err = PySet_Add(set, item);
                 Py_DECREF(item);
             }
-            STACKADJ(-oparg);
+            STACKADJ(-(size_t)(unsigned)oparg);
             if (err != 0) {
                 Py_DECREF(set);
                 goto error;
@@ -2601,7 +2600,7 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
 
         TARGET(BUILD_MAP) {
             Py_ssize_t i;
-            PyObject *map = _PyDict_NewPresized((Py_ssize_t)oparg);
+            PyObject *map = _PyDict_NewPresized((size_t)(unsigned)oparg);
             if (map == NULL)
                 goto error;
             for (i = oparg; i > 0; i--) {
@@ -2684,12 +2683,12 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
             PyObject *map;
             PyObject *keys = TOP();
             if (!PyTuple_CheckExact(keys) ||
-                PyTuple_GET_SIZE(keys) != (Py_ssize_t)oparg) {
+                PyTuple_GET_SIZE(keys) != (Py_ssize_t)(size_t)(unsigned)oparg) {
                 PyErr_SetString(PyExc_SystemError,
                                 "bad BUILD_CONST_KEY_MAP keys argument");
                 goto error;
             }
-            map = _PyDict_NewPresized((Py_ssize_t)oparg);
+            map = _PyDict_NewPresized((size_t)(unsigned)oparg);
             if (map == NULL) {
                 goto error;
             }
@@ -2746,7 +2745,7 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
             for (i = oparg; i > 0; i--) {
                 PyObject *arg = PEEK(i);
                 if (_PyDict_MergeEx(sum, arg, 2) < 0) {
-                    PyObject *func = PEEK(2 + oparg);
+                    PyObject *func = PEEK(2 + (unsigned)oparg);
                     if (PyErr_ExceptionMatches(PyExc_AttributeError)) {
                         PyErr_Format(PyExc_TypeError,
                                 "%.200s%.200s argument after ** "
@@ -2810,7 +2809,7 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
         }
 
         TARGET(LOAD_ATTR) {
-            PyObject *name = GETITEM(names, oparg);
+            PyObject *name = GETITEM(names, (unsigned)oparg);
             PyObject *owner = TOP();
             PyObject *res = PyObject_GetAttr(owner, name);
             Py_DECREF(owner);
@@ -2835,7 +2834,7 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
         }
 
         TARGET(IMPORT_NAME) {
-            PyObject *name = GETITEM(names, oparg);
+            PyObject *name = GETITEM(names, (unsigned)oparg);
             PyObject *fromlist = POP();
             PyObject *level = TOP();
             PyObject *res;
@@ -2869,7 +2868,7 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
         }
 
         TARGET(IMPORT_FROM) {
-            PyObject *name = GETITEM(names, oparg);
+            PyObject *name = GETITEM(names, (unsigned)oparg);
             PyObject *from = TOP();
             PyObject *res;
             res = import_from(from, name);
@@ -2880,7 +2879,7 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
         }
 
         TARGET(JUMP_FORWARD) {
-            JUMPBY(oparg);
+            JUMPBY((unsigned)oparg);
             FAST_DISPATCH();
         }
 
@@ -2894,7 +2893,7 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
             }
             if (cond == Py_False) {
                 Py_DECREF(cond);
-                JUMPTO(oparg);
+                JUMPTO((unsigned)oparg);
                 FAST_DISPATCH();
             }
             err = PyObject_IsTrue(cond);
@@ -2902,7 +2901,7 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
             if (err > 0)
                 err = 0;
             else if (err == 0)
-                JUMPTO(oparg);
+                JUMPTO((unsigned)oparg);
             else
                 goto error;
             DISPATCH();
@@ -2918,14 +2917,14 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
             }
             if (cond == Py_True) {
                 Py_DECREF(cond);
-                JUMPTO(oparg);
+                JUMPTO((unsigned)oparg);
                 FAST_DISPATCH();
             }
             err = PyObject_IsTrue(cond);
             Py_DECREF(cond);
             if (err > 0) {
                 err = 0;
-                JUMPTO(oparg);
+                JUMPTO((unsigned)oparg);
             }
             else if (err == 0)
                 ;
@@ -2943,7 +2942,7 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
                 FAST_DISPATCH();
             }
             if (cond == Py_False) {
-                JUMPTO(oparg);
+                JUMPTO((unsigned)oparg);
                 FAST_DISPATCH();
             }
             err = PyObject_IsTrue(cond);
@@ -2953,7 +2952,7 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
                 err = 0;
             }
             else if (err == 0)
-                JUMPTO(oparg);
+                JUMPTO((unsigned)oparg);
             else
                 goto error;
             DISPATCH();
@@ -2968,13 +2967,13 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
                 FAST_DISPATCH();
             }
             if (cond == Py_True) {
-                JUMPTO(oparg);
+                JUMPTO((unsigned)oparg);
                 FAST_DISPATCH();
             }
             err = PyObject_IsTrue(cond);
             if (err > 0) {
                 err = 0;
-                JUMPTO(oparg);
+                JUMPTO((unsigned)oparg);
             }
             else if (err == 0) {
                 STACKADJ(-1);
@@ -2987,7 +2986,7 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
 
         PREDICTED(JUMP_ABSOLUTE);
         TARGET(JUMP_ABSOLUTE) {
-            JUMPTO(oparg);
+            JUMPTO((unsigned)oparg);
 #if FAST_LOOPS
             /* Enabling this path speeds-up all while and for-loops by bypassing
                the per-loop checks for signals.  By default, this should be turned-off
@@ -3065,7 +3064,7 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
             /* iterator ended normally */
             STACKADJ(-1);
             Py_DECREF(iter);
-            JUMPBY(oparg);
+            JUMPBY((unsigned)oparg);
             PREDICT(POP_BLOCK);
             DISPATCH();
         }
@@ -3076,7 +3075,7 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
         }
 
         TARGET(CONTINUE_LOOP) {
-            retval = PyLong_FromLong(oparg);
+            retval = PyLong_FromLong((unsigned)oparg);
             if (retval == NULL)
                 goto error;
             why = WHY_CONTINUE;
@@ -3755,7 +3754,7 @@ format_missing(const char *kind, PyCodeObject *co, PyObject *names)
 
 static void
 missing_arguments(PyCodeObject *co, Py_ssize_t missing, Py_ssize_t defcount,
-                  PyObject **fastlocals)
+                  PyFrameObject *f)
 {
     Py_ssize_t i, j = 0;
     Py_ssize_t start, end;
@@ -3793,7 +3792,7 @@ missing_arguments(PyCodeObject *co, Py_ssize_t missing, Py_ssize_t defcount,
 
 static void
 too_many_positional(PyCodeObject *co, Py_ssize_t given, Py_ssize_t defcount,
-                    PyObject **fastlocals)
+                    PyFrameObject *f)
 {
     int plural;
     Py_ssize_t kwonly_given = 0;
@@ -3863,7 +3862,7 @@ _PyEval_EvalCodeWithName(PyObject *_co, PyObject *globals, PyObject *locals,
     PyCodeObject* co = (PyCodeObject*)_co;
     PyFrameObject *f;
     PyObject *retval = NULL;
-    PyObject **fastlocals, **freevars;
+    PyObject **freevars;
     PyThreadState *tstate;
     PyObject *x, *u;
     const Py_ssize_t total_args = co->co_argcount + co->co_kwonlyargcount;
@@ -3883,7 +3882,6 @@ _PyEval_EvalCodeWithName(PyObject *_co, PyObject *globals, PyObject *locals,
     if (f == NULL) {
         return NULL;
     }
-    fastlocals = f->f_localsplus;
     freevars = f->f_localsplus + co->co_nlocals;
 
     /* Create a dictionary for keyword parameters (**kwags) */
@@ -3990,7 +3988,7 @@ _PyEval_EvalCodeWithName(PyObject *_co, PyObject *globals, PyObject *locals,
 
     /* Check the number of positional arguments */
     if (argcount > co->co_argcount && !(co->co_flags & CO_VARARGS)) {
-        too_many_positional(co, argcount, defcount, fastlocals);
+        too_many_positional(co, argcount, defcount, f);
         goto fail;
     }
 
@@ -4004,7 +4002,7 @@ _PyEval_EvalCodeWithName(PyObject *_co, PyObject *globals, PyObject *locals,
             }
         }
         if (missing) {
-            missing_arguments(co, missing, defcount, fastlocals);
+            missing_arguments(co, missing, defcount, f);
             goto fail;
         }
         if (n > m)
@@ -4039,7 +4037,7 @@ _PyEval_EvalCodeWithName(PyObject *_co, PyObject *globals, PyObject *locals,
             missing++;
         }
         if (missing) {
-            missing_arguments(co, missing, -1, fastlocals);
+            missing_arguments(co, missing, -1, f);
             goto fail;
         }
     }
@@ -4845,7 +4843,6 @@ _PyFunction_FastCall(PyCodeObject *co, PyObject **args, Py_ssize_t nargs,
 {
     PyFrameObject *f;
     PyThreadState *tstate = PyThreadState_GET();
-    PyObject **fastlocals;
     Py_ssize_t i;
     PyObject *result;
 
@@ -4861,11 +4858,9 @@ _PyFunction_FastCall(PyCodeObject *co, PyObject **args, Py_ssize_t nargs,
         return NULL;
     }
 
-    fastlocals = f->f_localsplus;
-
     for (i = 0; i < nargs; i++) {
         Py_INCREF(*args);
-        fastlocals[i] = *args++;
+        f->f_localsplus[(size_t)i] = *args++;
     }
     result = PyEval_EvalFrameEx(f,0);
 
@@ -5335,9 +5330,8 @@ unicode_concatenate(PyObject *v, PyObject *w,
         switch (opcode) {
         case STORE_FAST:
         {
-            PyObject **fastlocals = f->f_localsplus;
-            if (GETLOCAL(oparg) == v)
-                SETLOCAL(oparg, NULL);
+            if (GETLOCAL((unsigned)oparg) == v)
+                SETLOCAL((unsigned)oparg, NULL);
             break;
         }
         case STORE_DEREF:
@@ -5352,7 +5346,7 @@ unicode_concatenate(PyObject *v, PyObject *w,
         case STORE_NAME:
         {
             PyObject *names = f->f_code->co_names;
-            PyObject *name = GETITEM(names, oparg);
+            PyObject *name = GETITEM(names, (unsigned)oparg);
             PyObject *locals = f->f_locals;
             if (PyDict_CheckExact(locals) &&
                 PyDict_GetItem(locals, name) == v) {

With all of these changes applied, we get a 1.3% speedup:

Python 3.6.0+ (default, Mar  7 2017, 00:06:13) 
[GCC 4.8.5 20150623 (Red Hat 4.8.5-11)] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from performance.benchmarks.bm_nbody import bench_nbody
>>> bench_nbody(10, 'sun', 100000)
9.777307317999657

A 1.3% speedup is simultaneously not very much, and also a surprisingly large amount for eliminating just a few instructions here and there. Of course, your mileage may vary, and this is just one randomly chosen benchmark.
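
The two most common kinds of change in the patch can be checked in isolation. The sketch below is a hypothetical standalone model, not CPython code. CodeUnit stands in for _Py_CODEUNIT, and the helper names are invented. It shows two things:

  1. The new INSTR_OFFSET computes the same value as the old one, without the shift-then-double round trip seen in gcc's output.
  2. Casting oparg to unsigned leaves its value unchanged as long as oparg is non-negative, which holds when it comes from word >> 8.

The expected code-generation benefit is an assumption about typical compilers on x86-64: a 32-bit zero extension is usually implicit there, whereas a sign extension needs an explicit movsxd. The C code itself cannot demonstrate this.

```c
#include <assert.h>
#include <stdint.h>

typedef uint16_t CodeUnit; /* stand-in for CPython 3.6's _Py_CODEUNIT */

/* Old macro: element count (a byte difference divided by 2),
   truncated to int, then multiplied back up to bytes. */
static int instr_offset_old(const CodeUnit *next_instr, const CodeUnit *first_instr)
{
  return (int)(sizeof(CodeUnit) * (int)(next_instr - first_instr));
}

/* New macro: the byte difference, computed directly. */
static int instr_offset_new(const CodeUnit *next_instr, const CodeUnit *first_instr)
{
  return (int)((const char *)next_instr - (const char *)first_instr);
}

/* oparg as decoded from a single code unit: always in [0, 255]. */
static int decode_oparg(CodeUnit word)
{
  return word >> 8;
}

/* For non-negative oparg, indexing through an unsigned copy selects the
   same element; the cast lets the compiler zero-extend the index instead
   of sign-extending it. */
static void *get_local(void **localsplus, int oparg)
{
  return localsplus[(unsigned)oparg];
}
```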

What do compilers do with the CPython interpreter main loop?

Compilers are notoriously bad at compiling the main loop of a programming language interpreter, and the CPython interpreter main loop is no exception: it is hard to compile it perfectly. The difficulty of compilation scales with the number of opcodes which the interpreter has, which in the case of CPython is more than 100, but we can actually get a feel for how well a compiler is doing by looking at just one opcode.

For this exercise I'll look at CPython 3.6's LOAD_FAST opcode, which in C is:

TARGET(LOAD_FAST) {
  PyObject *value = GETLOCAL(oparg);
  if (value == NULL) {
    format_exc_check_arg(PyExc_UnboundLocalError,
                         UNBOUNDLOCAL_ERROR_MSG,
                         PyTuple_GetItem(co->co_varnames, oparg));
    goto error;
  }
  Py_INCREF(value);
  PUSH(value);
  FAST_DISPATCH();
}

This opcode does a very small task: it loads a single local variable and pushes it onto the Python stack. After expanding various macros, the code becomes:

TARGET(LOAD_FAST) {
  PyObject *value = fastlocals[oparg];
  if (value == NULL) {
    // ... error handling ...
  }
  value->ob_refcnt++;
  *stack_pointer++ = value;
  if (_Py_TracingPossible) {
    // ... slow path ...
  }
  f->f_lasti = (sizeof(uint16_t) * (int)(next_instr - first_instr));
  uint16_t word = *next_instr;
  opcode = word & 255;
  oparg = word >> 8;
  next_instr++;
  goto *opcode_targets[opcode];
}
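
To make the decoding step of that expansion concrete, here is a toy stack machine. Each 16-bit code unit carries the opcode in its low byte and oparg in its high byte, decoded exactly as above with word & 255 and word >> 8. This is a simplified sketch under several assumptions:

  1. The opcode numbers and the INSTR macro are hypothetical.
  2. Values are long rather than PyObject *.
  3. There is no reference counting and no unbound-local check.
  4. Dispatch goes through a plain switch statement rather than computed goto.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical opcode numbers for this sketch; not CPython's values. */
enum { OP_LOAD_FAST = 1, OP_LOAD_CONST = 2, OP_ADD = 3, OP_RETURN = 4 };

/* Pack an opcode (low byte) and an oparg (high byte) into one code unit. */
#define INSTR(op, arg) ((uint16_t)((op) | ((arg) << 8)))

/* Toy interpreter loop: no stack bounds checks, no error handling. */
static long run(const uint16_t *code, const long *locals, const long *consts)
{
  long stack[16];
  long *stack_pointer = stack;
  const uint16_t *next_instr = code;
  for (;;) {
    uint16_t word = *next_instr++;  /* fetch, then advance */
    int opcode = word & 255;        /* low byte */
    int oparg = word >> 8;          /* high byte */
    switch (opcode) {
    case OP_LOAD_FAST:  *stack_pointer++ = locals[oparg]; break;
    case OP_LOAD_CONST: *stack_pointer++ = consts[oparg]; break;
    case OP_ADD: {
      long b = *--stack_pointer;
      long a = *--stack_pointer;
      *stack_pointer++ = a + b;
      break;
    }
    case OP_RETURN: return *--stack_pointer;
    default: return -1;             /* unknown opcode */
    }
  }
}
```

For example, the program LOAD_FAST 1, LOAD_CONST 0, ADD, RETURN with locals {10, 32} and consts {100} returns 132.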

With this code in hand, we can start looking at what compilers do with it, starting with gcc on Linux x64:

; PyObject *value = fastlocals[oparg]
48 8B 44 24 38           mov     rax, [rsp+38h]
49 63 F5                 movsxd  rsi, r13d
48 8B 04 F0              mov     rax, [rax+rsi*8]
; if (value == NULL)
48 85 C0                 test    rax, rax
0F 84 86 3A 00 00        jz      loc_545A2B
; value->ob_refcnt++
; *stack_pointer++ = value
4C 89 F2                 mov     rdx, r14
48 83 00 01              add     qword ptr [rax], 1
49 83 C6 08              add     r14, 8
48 89 02                 mov     [rdx], rax
; if (_Py_TracingPossible)
8B 05 97 AA 3B 00        mov     eax, cs:_Py_TracingPossible
85 C0                    test    eax, eax
0F 85 A3 CD FF FF        jnz     loc_53ED64
; f->f_lasti = (sizeof(uint16_t) * (int)(next_instr - first_instr))
; next_instr++
48 89 EA                 mov     rdx, rbp
48 2B 14 24              sub     rdx, [rsp]
48 83 C5 02              add     rbp, 2
48 D1 FA                 sar     rdx, 1
01 D2                    add     edx, edx
89 53 78                 mov     [rbx+78h], edx
; word = *next_instr
0F B7 55 FE              movzx   edx, word ptr [rbp-2]
; opcode = word & 255
44 0F B6 C2              movzx   r8d, dl
; oparg = word >> 8
0F B6 F6                 movzx   esi, dh
; goto *opcode_targets[opcode]
49 63 D0                 movsxd  rdx, r8d
41 89 F5                 mov     r13d, esi
48 8B 14 D5 20 51 60 00  mov     rdx, ds:opcode_targets_11490[rdx*8]
FF E2                    jmp     rdx

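From this, we can infer that gcc made the following choices:

  1. f is kept in rbx (f->f_lasti is stored to [rbx+78h]).
  2. next_instr is kept in rbp, and stack_pointer in r14.
  3. oparg is kept in r13d as a 32-bit value, and is sign-extended with movsxd whenever it is used as an index.
  4. fastlocals is spilled to the stack at [rsp+38h], and first_instr is spilled to [rsp].
  5. opcode is decoded into r8d, and is likewise sign-extended before indexing opcode_targets.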

We can also spot a number of sad things:

It feels like there are quite a few things which gcc could improve upon. Despite that, it could be the case that gcc is doing better than other compilers. To find out, we'll have to consider a few other compilers. With that in mind, we can look at what Clang on OSX x64 does:

; PyObject *value = fastlocals[oparg]
49 63 F7                 movsxd  rsi, r15d
49 8B 84 F6 78 01 00 00  mov     rax, [r14+rsi*8+178h]
; if (value == NULL)
48 85 C0                 test    rax, rax
0F 84 51 41 00 00        jz      loc_F7B6D
; value->ob_refcnt++
48 FF 00                 inc     qword ptr [rax]
; *stack_pointer++ = value
49 89 45 00              mov     [r13+0], rax
49 83 C5 08              add     r13, 8
; if (_Py_TracingPossible)
44 8B 3D D2 D5 1A 00     mov     r15d, cs:__Py_TracingPossible
45 85 FF                 test    r15d, r15d
0F 85 B7 BB FF FF        jnz     loc_EF5EE
; f->f_lasti = (sizeof(uint16_t) * (int)(next_instr - first_instr))
48 8B 85 48 FE FF FF     mov     rax, [rbp-1B8h]
48 2B 85 28 FE FF FF     sub     rax, [rbp-1D8h]
48 D1 F8                 sar     rax, 1
48 98                    cdqe
48 01 C0                 add     rax, rax
41 89 46 78              mov     [r14+78h], eax
; word = *next_instr
48 8B 95 48 FE FF FF     mov     rdx, [rbp-1B8h]
0F B7 02                 movzx   eax, word ptr [rdx]
; opcode = word & 255
0F B6 D8                 movzx   ebx, al
; oparg = word >> 8
0F B6 C4                 movzx   eax, ah
41 89 C7                 mov     r15d, eax
; next_instr++
48 83 C2 02              add     rdx, 2
48 89 95 48 FE FF FF     mov     [rbp-1B8h], rdx
; goto *opcode_targets[opcode]
48 63 C3                 movsxd  rax, ebx
48 8D 0D 87 01 13 00     lea     rcx, _opcode_targets_11343
48 8B 04 C1              mov     rax, [rcx+rax*8]
FF E0                    jmp     rax

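From this, we can infer that clang made the following choices:

  1. f is kept in r14 (f->f_lasti is at [r14+78h]). fastlocals is not kept separately; it is addressed as r14+178h.
  2. stack_pointer is kept in r13.
  3. oparg is kept in r15d, and is sign-extended with movsxd when used as an index. r15d is also borrowed to load _Py_TracingPossible.
  4. next_instr is spilled to the stack at [rbp-1B8h], and first_instr at [rbp-1D8h].
  5. opcode is decoded into ebx and sign-extended before indexing opcode_targets, whose address is materialised with a rip-relative lea.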

Again, we can critique this assembly:

Overall, Clang did some things better than gcc, made some of the same mistakes, and did some things worse than gcc.

Next up is MSVC on Windows x64. This compiler is at a slight disadvantage, as it doesn't support computed goto statements, and instead has to fall back to using a switch statement. Bearing that in mind, the assembly is:

; PyObject *value = fastlocals[oparg]
48 63 C3                  movsxd  rax, ebx
49 8B 8C C7 78 01 00 00   mov     rcx, [r15+rax*8+178h]
; if (value == NULL)
48 85 C9                  test    rcx, rcx
0F 84 5C 68 04 00         jz      loc_1E0791BF
; value->ob_refcnt++
48 FF 01                  inc     qword ptr [rcx]
; *stack_pointer++ = value
49 89 0C 24               mov     [r12], rcx
49 83 C4 08               add     r12, 8
4C 89 64 24 48            mov     [rsp+48h], r12
; end of switch case
E9 77 FF FF FF            jmp     loc_1E0328EF

loc_1E0328EF:
; if (_Py_TracingPossible)
; f->f_lasti = (sizeof(uint16_t) * (int)(next_instr - first_instr))
48 8B C2                 mov     rax, rdx
49 2B C6                 sub     rax, r14
48 D1 F8                 sar     rax, 1
03 C0                    add     eax, eax
83 3D 8F 2F 33 00 00     cmp     cs:_Py_TracingPossible, 0
41 89 47 78              mov     [r15+78h], eax
0F 85 DB 3D 04 00        jnz     loc_1E0766E6
; word = *next_instr
0F B7 1A                 movzx   ebx, word ptr [rdx]
; opcode = word & 255
44 0F B6 EB              movzx   r13d, bl
; oparg = word >> 8
C1 EB 08                 shr     ebx, 8
; next_instr++
48 83 C2 02              add     rdx, 2
48 89 54 24 40           mov     [rsp+40h], rdx
; spill oparg
89 5C 24 70              mov     dword ptr [rsp+70h], ebx
; align instruction stream
0F 1F 40 00              nop     dword ptr [rax+00h]
                         db      66h, 66h, 66h
66 0F 1F 84 00 00 00 00  nop     word ptr [rax+rax+00000000h]
; check opcode in valid range
41 8D 45 FF              lea     eax, [r13-1]
3D 9D 00 00 00           cmp     eax, 9Dh
0F 87 48 6A 04 00        ja      loc_1E079387
; goto *opcode_targets[opcode] (actually jump to switch case)
48 63 C8                 movsxd  rcx, eax
41 8B 84 88 CC 67 03 00  mov     eax, ds:(off_1E0367CC - 1E000000h)[r8+rcx*4]
49 03 C0                 add     rax, r8
FF E0                    jmp     rax

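For MSVC, we can infer:

  1. f is kept in r15 (f->f_lasti is at [r15+78h]), with fastlocals addressed as r15+178h.
  2. stack_pointer is kept in r12, but is also written back to the stack at [rsp+48h].
  3. next_instr is kept in rdx, but is also written back to [rsp+40h]. first_instr is kept in r14.
  4. oparg is kept in ebx (sign-extended when used as an index), and is also spilled to [rsp+70h].
  5. opcode is decoded into r13d. r8 holds the image base, which is used to turn switch jump table entries into addresses.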

As per the established pattern, the critique on MSVC's code is:

MSVC ends up being like gcc in some regards, like clang in others, and sometimes unlike either. The lack of computed goto statements is certainly painful though, and accounts for four entries in the critique list.

Having bashed the three major compilers for being imperfect, I'm now obliged to provide what I think is the perfect assembly code for this opcode - if I was writing the CPython interpreter main loop in assembly [3] then this is what I'd write for LOAD_FAST:

; PyObject *value = fastlocals[oparg]
48 8B 94 CB 00 01 00 00  mov     rdx, [rbx+rcx*8+100h]
; word = *next_instr
41 0F B7 04 2E           movzx   eax, word ptr [r14+rbp]
; if (value == NULL)
48 85 D2                 test    rdx, rdx
0F 84 F7 12 00 00        jz      loc_1E00696D
; *stack_pointer++ = value
49 89 14 24              mov     [r12], rdx
49 83 C4 08              add     r12, 8
; f->f_lasti = (sizeof(uint16_t) * (int)(next_instr - first_instr))
89 2B                    mov     [rbx], ebp
; value->ob_refcnt++
48 83 02 01              add     qword ptr [rdx], 1
; if (_Py_TracingPossible)
41 F6 47 F8 01           test    byte ptr [r15-8], 1
0F 85 8F BA FF FF        jnz     loc_1E00111E
; oparg = word >> 8
0F B6 CC                 movzx   ecx, ah
; opcode = word & 255
0F B6 C0                 movzx   eax, al
; next_instr++
83 C5 02                 add     ebp, 2
; goto *opcode_targets[opcode]
41 FF 24 C7              jmp     qword ptr [r15+rax*8]

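My assembly makes the following register assignment choices:

  1. f is kept in rbx, biased by offsetof(PyFrameObject, f_lasti), so f->f_lasti is at [rbx] and fastlocals at rbx+100h.
  2. first_instr is kept in r14, and next_instr is kept in ebp as a byte offset from first_instr.
  3. stack_pointer is kept in r12.
  4. oparg is kept, zero-extended, in rcx.
  5. The address of opcode_targets is kept in r15, with _Py_TracingPossible stored just before it at [r15-8].
  6. word, and then opcode, live in eax.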

The combination of storing next_instr as an offset and keeping f biased by offsetof(PyFrameObject, f_lasti) means that f->f_lasti = (sizeof(uint16_t) * (int)(next_instr - first_instr)) is two bytes / one instruction, versus 19 bytes / six instructions for gcc. Keeping f biased has no downside. It also has the occasional other upside: some fields toward the end of PyFrameObject can be accessed with a one-byte displacement rather than a four-byte displacement.

Storing next_instr as an offset has the minor downside of making the *next_instr memory operand slightly more complex ([r14+rbp] rather than [rbp]), but this is a very low cost. The offset approach also makes certain jump-related opcodes slightly cleaner, and it avoids a REX prefix on next_instr++.

Keeping the jump table address in r15 is expensive, as POSIX x64 only has six non-volatile registers and this burns one of those six for a runtime constant. However, it makes opcode dispatch cheap, which is important given that dispatch is replicated into all 100+ opcodes. It also has some upsides: for example, rip-relative lea instructions can instead be r15-relative, and thus be executed on a wider range of ports. I also change _Py_TracingPossible from being a 32-bit variable to being a 1-bit variable, and put this variable just before the jump table, so that it can be addressed with a one-byte offset from r15.

The other notable thing to point out is pulling word = *next_instr up towards the start of the instruction stream. I want to give the CPU as much time as possible to perform that load, as it is critical for control flow.
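
The biasing and offset tricks can be modelled in portable C. This is a hypothetical sketch, not the author's assembly or CPython's real PyFrameObject layout. Frame, REL, bias_frame, get_local and fetch are invented for illustration. A C compiler is also not obliged to turn this into the single-instruction store that the hand-written assembly achieves.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Hypothetical, heavily simplified frame layout (not PyFrameObject). */
typedef struct {
  void *f_code;
  int f_lasti;
  void *f_localsplus[4];
} Frame;

/* Distance in bytes from f_lasti to another field; may be negative. */
#define REL(field) \
  ((ptrdiff_t)offsetof(Frame, field) - (ptrdiff_t)offsetof(Frame, f_lasti))

/* The biased frame pointer points at f_lasti instead of at the frame. */
static char *bias_frame(Frame *f)
{
  return (char *)f + offsetof(Frame, f_lasti);
}

/* fastlocals[i] is a load at a fixed displacement from the biased
   pointer (in the hand-written assembly that displacement is 100h). */
static void *get_local(char *fb, unsigned i)
{
  return ((void **)(fb + REL(f_localsplus)))[i];
}

/* One fetch step with next_instr kept as a byte offset from first_instr:
   recording f_lasti is a single store of that offset through fb. */
static uint16_t fetch(char *fb, const uint16_t *first_instr, int *next_off)
{
  uint16_t word = *(const uint16_t *)((const char *)first_instr + *next_off);
  *(int *)(fb + REL(f_lasti)) = *next_off;  /* REL(f_lasti) is 0 */
  *next_off += (int)sizeof(uint16_t);       /* next_instr++ */
  return word;
}
```

After two calls to fetch on a two-instruction program starting at offset 0, f_lasti holds 2 and the offset has advanced to 4, matching f->f_lasti = INSTR_OFFSET() followed by next_instr++.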

That is one opcode - LOAD_FAST - considered in detail. Only 100+ other opcodes to go...

[1] There are two kinds of assignment to oparg. The first kind we've already seen, namely oparg = word >> 8, which fairly obviously can't make oparg negative. The other kind is in the EXTENDED_ARG opcode, which does oparg |= oldoparg << 8. Here we have to appeal to language lawyering to claim that oldoparg being non-negative implies that oldoparg << 8 is non-negative (signed overflow is undefined and all that). From there it is one simple step to claim that oparg being non-negative and oldoparg << 8 being non-negative implies that oparg | (oldoparg << 8) is non-negative.

[2] The ah/bh/ch/dh registers can only be accessed if a REX prefix is not used. The r8 through r15 registers can only be accessed if a REX prefix is used. QED.

[3] If I was writing the CPython interpreter main loop in assembly. If. I mean, I'd have to be crazy to write that much assembly...
