Tenstorrent Wormhole Series Part 1: Physicalities

A company called Tenstorrent design and sell PCIe cards for AI acceleration. At the time of writing, they've recently started shipping their Wormhole n150s and Wormhole n300s cards. After taking the shell and the heatsinks off, the underlying circuit board looks like:

(Circuit board photos of the Wormhole n150s and Wormhole n300s, from Tenstorrent's own sales pages)

At the top of the board, there are two QSFP-DD cages for ethernet, which are intended for direct-attach cables between cards, as shown here. The left edge has a PCIe 4.0 x16 connector, which is connected to the immediately adjacent Wormhole ASIC. The Wormhole ASIC is surrounded by six GDDR6 chips (two on the left, four on the right), with each GDDR6 chip containing two 1GB banks, for a total of 12GB GDDR6 connected to the Wormhole ASIC. The right edge has two of what Tenstorrent call a Warp 100 Bridge connector, which seems to be a proprietary 100Gb ethernet connector. After this, the two boards diverge: the n300s card has another Wormhole ASIC with another 12GB of GDDR6 attached, whereas the n150s card has a bunch of empty space. The 2nd Wormhole ASIC on the n300s card is not connected to PCIe. Instead, on the n300s card, there are internal ethernet links between the two ASICs. Accordingly, for the host to communicate with the 2nd ASIC, it needs to use PCIe to communicate with the 1st ASIC, and ask the 1st ASIC to use its internal ethernet to communicate with the 2nd ASIC's internal ethernet.

The Wormhole ASIC itself is described as a 10 by 12 grid of tiles, arranged as:

There are a few different types of tile:

Kind   Count                  Contents (per tile)
ARC    1                      1x Argonaut RISC Core (also connected to PCIe)
D      6 x 3                  Bridge to 2x 1GB GDDR6 (shared by three tiles)
E      16                     1x Baby RISC-V CPU (E variant)
                              256K SRAM
                              Bridge to 100Gb ethernet (if connected)
PCIe   1                      Bridge to host over PCIe
T      80 (not all usable)    5x Baby RISC-V CPU (B/T/T/T/NC variants)
                              1.5MB SRAM
                              1x Matrix unit (2048 multipliers, each 5b x 7b)
                              1x Vector/SIMD unit (32 lanes, each 32b wide)

Each E tile can manage 100Gb ethernet (that is, simultaneous transmit at 100Gb/s and receive at 100Gb/s). On the n300s card, E8 of the 1st ASIC is connected to E0 of the 2nd ASIC, and E9 of the 1st ASIC is connected to E1 of the 2nd ASIC. Other tiles are connected to the QSFP-DD cages and to the Warp 100 Bridge connectors (possibly two E tiles on the 2nd ASIC route to the Warp, and four E tiles on the 1st ASIC route to the QSFP-DD cages). There's not enough I/O on the n150s/n300s boards to allow all the E tiles of a single ASIC to be connected, let alone all the E tiles of two ASICs to be connected; you'd need to buy a Galaxy Server to find a form factor with enough I/O to saturate all the E tiles.

Each tile has four outbound connections to its (north/east/south/west) neighbours, along with four corresponding inbound connections from its neighbours, with each connection being 32 bytes wide (apparently upgraded to 64 bytes wide in the next-generation Blackhole ASIC). The combination of the east-bound and south-bound channels is called NoC (Network on Chip) #0, whilst the combination of the west-bound and north-bound channels is called NoC #1. Going eastward from the eastmost edge lands you on the westmost edge (e.g. going east from the T at (9, 3) gets you to PCIe at (0, 3)). A similar thing is true on the other three edges: they are each connected to the edge on the opposite side. Taking the eastbound connections in row 3 as an example, we have:

It looks like the path between most neighbours is short, but then the wraparound path is extremely long. This would make both software and hardware designers unhappy, so the physical reality interleaves tiles to equalize the distances:

A similar interleaving is done to the columns, meaning that the actual physical locations of the various tiles within the grid are closer to:

The various tile-to-tile connections aren't shown on the above, as it would be a horrible mess of overlapping and dodging lines. It does however match the circuit board photos: the DRAM tiles are on the left and right edges, and tiles routing to the same DRAM chip are adjacent (e.g. the three D5 tiles are physically adjacent, even though their logical Y coordinates are 3, 8, 4). The ethernet tiles also end up on the physical edges, which makes sense. Thankfully, software mostly doesn't need to worry about this interleaving (except for HARVESTING_NOC_LOCATIONS, which you'll note are the Y coordinates of an interleaved column of T tiles), and can instead stick to the original logical view:

That said, there's one physical reality that software can't escape: manufacturing defects. You could imagine TSMC having a master template Wormhole ASIC and then photocopying it thousands of times to create the ASICs that actually ship to customers. As with actual photocopying, the photocopies are slightly lower quality than the original. Sometimes the defects are benign, but other times the defects completely ruin the piece of logic or storage that they occur within. Throwing away an entire ASIC because of a single defect would not be viable, so there are various strategies to deal with the defects. One strategy employed here is to disable (or "harvest") an entire row of T tiles if there is a defect in any of the T tiles in that row, effectively pretending that all the T tiles in that row are empty tiles (in the same way that e.g. (0, 2) is empty but still participates in the connection grid).

An ASIC with one row of disabled T tiles can be used in an n150s board, which is why the n150s is reported as having 72 T tiles despite the ASIC having 80. Similarly, an ASIC with two rows of disabled T tiles can be used in an n300s board, which is why the n300s is reported as having 64 T tiles per ASIC despite there originally being 80. If the photocopying process had fewer defects than planned for, then one (n150s) or two (n300s) rows of T tiles are disabled regardless, for consistency (of what customers receive, of power usage, etc.).

The non-T tiles within a disabled row are kept, though they become slightly harder to address. In particular, some workloads want to really minimise latency to DRAM, and so want to run on a T tile immediately adjacent to a D tile. Some rows have 2 D tiles and thus four T tiles immediately adjacent to a D tile, whereas other rows have only 1 D tile and thus only two T tiles immediately adjacent to a D tile. Whilst each PCIe card will have the same number of usable T tiles (72 for n150s, 64 x 2 for n300s), the number of usable T tiles immediately adjacent to a D tile will vary randomly from card to card depending on which rows were disabled.

That wraps up part 1. If you're reading along, part 2 is next.

Arm CSSC quick reference

The AArch64 FEAT_CSSC (Common Short Sequence Compression) extension adds a few instructions operating on general purpose registers. The extension is optional since Armv8.7, and mandatory since Armv8.9 if FEAT_AdvSIMD is implemented. The most interesting addition is a popcnt instruction operating on general purpose registers, but the full list of new instructions is:

CSSC instruction      Equivalent pre-CSSC sequence (†)
ABS  Rd, Rn           TST  Rn, Rn
                      CNEG Rd, Rn, MI
CNT  Rd, Rn           FMOV d0, Rn
                      CNT  v0.8B, v0.8B
                      ADDV b0, v0.8B
                      FMOV Rd, s0
CTZ  Rd, Rn           RBIT Rd, Rn
                      CLZ  Rd, Rd
SMAX Rd, Rn, Rm       CMP  Rn, Rm
                      CSEL Rd, Rn, Rm, GT
SMAX Rd, Rn, #i8      MOV  Rtmp, #i8
                      CMP  Rn, Rtmp
                      CSEL Rd, Rn, Rtmp, GT
SMIN Rd, Rn, Rm       CMP  Rn, Rm
                      CSEL Rd, Rn, Rm, LT
SMIN Rd, Rn, #i8      MOV  Rtmp, #i8
                      CMP  Rn, Rtmp
                      CSEL Rd, Rn, Rtmp, LT
UMAX Rd, Rn, Rm       CMP  Rn, Rm
                      CSEL Rd, Rn, Rm, HI
UMAX Rd, Rn, #u8      MOV  Rtmp, #u8
                      CMP  Rn, Rtmp
                      CSEL Rd, Rn, Rtmp, HI
UMIN Rd, Rn, Rm       CMP  Rn, Rm
                      CSEL Rd, Rn, Rm, LO
UMIN Rd, Rn, #u8      MOV  Rtmp, #u8
                      CMP  Rn, Rtmp
                      CSEL Rd, Rn, Rtmp, LO

(†) Except that the CSSC instructions do not mutate flags, and CSSC CNT does not mutate any SIMD registers.
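
For a sense of how this surfaces from C, here's a minimal sketch of my own (not from any reference material): with a compiler targeting a CPU that implements FEAT_CSSC (e.g. something like -march=armv8.9-a or a +cssc modifier), the usual builtins can typically lower to the single-instruction forms above, and otherwise fall back to the pre-CSSC sequences.

// Hedged sketch: popcount and count-trailing-zeros on 64-bit GPR values.
// On a CSSC-capable target these typically become the single CNT / CTZ
// instructions from the table; otherwise the compiler emits the longer
// pre-CSSC sequences.
#include <stdint.h>

uint64_t popcount64(uint64_t x) {
  return (uint64_t)__builtin_popcountll(x);  // candidate for CSSC cnt xd, xn
}

uint64_t ctz64(uint64_t x) {
  // __builtin_ctzll(0) is undefined, so guard it; the non-zero path is a
  // candidate for CSSC ctz xd, xn.
  return x ? (uint64_t)__builtin_ctzll(x) : 64;
}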

My new favourite AArch64 CPU instruction: rotate then merge into flags (RMIF)

I find myself writing some CPU emulators at the moment, which has caused the AArch64 (aka. ARM64) RMIF instruction to become my new favourite instruction. It takes a 64-bit general purpose register, rotates it right by a specified number of bits, then selectively merges the low four bits into the flags register. A 6-bit field in the instruction gives the rotate amount, and a 4-bit field in the instruction gives a mask of which flag bits to overwrite versus which to leave unchanged.
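
To make that concrete, here's a minimal C model of the instruction (my own sketch, consistent with the examples that follow): the flags nibble holds N, Z, C, V in bits 3 down to 0, the rotated value's low four bits are the candidate flags, and each mask bit selects whether the corresponding flag is overwritten.

// Hedged sketch: a minimal C model of RMIF. Flags are modelled as a
// 4-bit nibble with N in bit 3, Z in bit 2, C in bit 1 and V in bit 0.
#include <stdint.h>

typedef uint8_t nzcv_t;  // only the low 4 bits are meaningful

static nzcv_t rmif(uint64_t xn, unsigned rot6, unsigned mask4, nzcv_t flags) {
  rot6 &= 63;
  mask4 &= 15;
  // Rotate right by rot6, then take the low four bits as candidate NZCV.
  uint64_t rotated = (xn >> rot6) | (xn << ((64 - rot6) & 63));
  nzcv_t candidate = (nzcv_t)(rotated & 15);
  // Overwrite only the flag bits selected by the mask.
  return (nzcv_t)((flags & ~mask4) | (candidate & mask4));
}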

One use of rmif is to emulate the x86 bt reg, imm instruction, which extracts one bit from a general purpose register, writes that bit to the C flag, and leaves other flags unchanged. Thus bt reg, imm in x86 becomes rmif reg, #((imm - 1) & 63), #2 in AArch64.
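
As a usage sketch against the model above (emulate_bt is a hypothetical helper name, not real emulator code):

// Usage sketch: emulate x86 `bt reg, imm` with the rmif model above.
// Bit `imm` of reg should end up in the C flag (bit 1 of the nibble),
// with all other flags unchanged.
static nzcv_t emulate_bt(uint64_t reg, unsigned imm, nzcv_t flags) {
  return rmif(reg, (imm - 1) & 63, /*mask=*/2, flags);
}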

At the other end of the spectrum is the x86 inc instruction, which adds 1 to a general purpose register, and then sets most flags based on this addition, but leaves the C flag unchanged. To emulate inc reg, we can first save off the old value of the C flag (via csinc tmp, wzr, wzr, cc or adc tmp, wzr, wzr), then do adds reg, reg, #1 to perform the addition and set all the flags, then rmif tmp, #63, #2 to restore the old value of the C flag.

As another example, the AArch32 muls instruction sets the N and Z flags based on the result of the multiplication, but leaves the C and V flags unchanged. To emulate this on AArch64, we can save off all the flags (mrs tmp, NZCV), then do the multiplication, then set N and Z based on the result but also clobber C and V (ands wzr, dst, dst or adds wzr, dst, #0), then restore the old values of C and V (rmif tmp, #28, #3).
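
And a matching usage sketch with the same model (again a hypothetical helper, mirroring the mrs / ands / rmif sequence just described):

// Usage sketch: emulate AArch32 `muls` flag behaviour with the rmif model.
// N and Z come from the 32-bit result; C and V must be preserved.
static nzcv_t emulate_muls_flags(uint32_t result, nzcv_t flags) {
  // mrs tmp, NZCV: the saved register holds N,Z,C,V in bits 31..28.
  uint64_t tmp = (uint64_t)flags << 28;
  // ands wzr, dst, dst: set N and Z from the result, clobber C and V.
  nzcv_t clobbered = (nzcv_t)(((result >> 31) ? 8 : 0) | (result == 0 ? 4 : 0));
  // rmif tmp, #28, #3: restore the old C and V.
  return rmif(tmp, 28, 3, clobbered);
}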

For want of a relative path

Distributing dynamically-linked ELF executables on Linux can be arduous. Some downstream effects of this include:

At first, the problem doesn't look arduous: an ELF executable can contain an rpath or runpath attribute telling the dynamic linker where to find its shared object dependencies, and if that attribute starts with the magic placeholder $ORIGIN/, then the dynamic linker will look in the directory containing the executable (or a directory nearby) for its shared object dependencies. For example, if my_executable depended upon libz.so.1, and my_executable had an rpath or runpath of $ORIGIN/libs, then the executable and the library could be distributed using the following directory structure:

my_executable
libs/
  libz.so.1

This is great, but it has one limitation: an ELF executable also contains an attribute telling the kernel where to find the dynamic linker, and that attribute has to be an absolute path (or a path relative to the current working directory); it cannot be a path relative to the executable. On contemporary x86-64 systems, that absolute path tends to be /lib64/ld-linux-x86-64.so.2. This forces ELF executables to use whatever the system provides at /lib64/ld-linux-x86-64.so.2, which is typically version N of glibc's dynamic linker, for some N. In turn, this forces the ELF executable to use version N of the rest of glibc (libc.so.6, libm.so.6, libpthread.so.0, etc).

Continuing the example, it is likely that my_executable and libz.so.1 were built against some version M of glibc. If M ≤ N, then everything will work fine, but problems often crop up when M > N. One commonly touted solution is to set up a build environment with a very old version M of glibc, build my_executable and libz.so.1 in that environment, and then distribute them and hope for M ≤ N.

The polyfill-glibc project presents another possible solution: build my_executable and libz.so.1 against whatever version of glibc is convenient, and then run polyfill-glibc --target-glibc=N my_executable libz.so.1 to make them compatible with version N of glibc.

Sometimes we don't want either of these solutions, and what we want is to distribute the required version of glibc along with the executable, as in:

my_executable
libs/
  ld-linux-x86-64.so.2
  libc.so.6
  libz.so.1

We can get close to this by adding a launcher script:

launch_my_executable
my_executable
libs/
  ld-linux-x86-64.so.2
  libc.so.6
  libz.so.1

Where launch_my_executable is something like:

#!/usr/bin/env bash

ORIGIN="$(dirname "$(readlink -f "$0")")"
exec "$ORIGIN/libs/ld-linux-x86-64.so.2" --library-path "$ORIGIN/libs:$LD_LIBRARY_PATH" "$ORIGIN/my_executable" "$@"

This will work most of the time, though it comes with caveats:

As an alternative without these caveats, there's an experimental tool in the polyfill-glibc repository called set_relative_interp. For our running example, the tool would be invoked as:

$ set_relative_interp my_executable libs/ld-linux-x86-64.so.2

After running the tool as above, my_executable will use $ORIGIN/libs/ld-linux-x86-64.so.2 as its dynamic linker.

(Ab)using gf2p8affineqb to turn indices into bits

@geofflangdale posed the question on Twitter of how to vectorise this:

__mmask64 reference_impl(__m512i indices, __mmask64 valids) {
  __mmask64 result = 0;
  for (int i = 0; i < 64; ++i) {
    if (valids.bit[i]) {
      result ^= 1ull << indices.byte[i];
    }
  }
  return result;
}

After a week of code golf also involving @HaroldAptroot, we ended up with:

__mmask64 simd_impl(__m512i indices, __mmask64 valids) {
  // Convert indices to bits within each qword lane.
  __m512i khi = _mm512_setr_epi8(
    0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
    0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
    0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04,
    0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08,
    0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10,
    0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
    0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40,
    0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80
  );
  __m512i hi0 = _mm512_permutexvar_epi8(indices, khi);
  __m512i klo = _mm512_set1_epi64(0x0102040810204080);
  __m512i lo0 = _mm512_maskz_shuffle_epi8(valids, klo, indices);
  __m512i kid = _mm512_set1_epi64(0x8040201008040201);
  __m512i hi1 = _mm512_gf2p8affine_epi64_epi8(kid, hi0, 0);
  __m512i lo1 = _mm512_gf2p8affine_epi64_epi8(kid, lo0, 0);
  __m512i x0  = _mm512_gf2p8affine_epi64_epi8(hi1, lo1, 0);
  // Combine results from various qword lanes.
  __m512i ktr = _mm512_setr_epi8(
    0,  8, 16, 24, 32, 40, 48, 56,
    1,  9, 17, 25, 33, 41, 49, 57,
    2, 10, 18, 26, 34, 42, 50, 58,
    3, 11, 19, 27, 35, 43, 51, 59,
    4, 12, 20, 28, 36, 44, 52, 60,
    5, 13, 21, 29, 37, 45, 53, 61,
    6, 14, 22, 30, 38, 46, 54, 62,
    7, 15, 23, 31, 39, 47, 55, 63);
  __m512i x1  = _mm512_permutexvar_epi8(ktr, x0);
  __m512i x2  = _mm512_gf2p8affine_epi64_epi8(kid, x1, 0);
  // Reduce 64 bytes down to 64 bits.
  __m512i kff = _mm512_set1_epi8(0xff);
  __m512i x3  = _mm512_gf2p8affine_epi64_epi8(x2, kff, 0);
  return _mm512_movepi8_mask(x3);
}

NB: If the valid indices can be assumed to be distinct, then the final reduction from 64 bytes to 64 bits can instead be:

  return _mm512_cmpneq_epi8_mask(x2, _mm512_setzero_si512());

As is often the case, simd_impl looks nothing like reference_impl, despite doing the same thing. In particular, simd_impl contains no shifts, and instead contains alternating shuffles and invocations of the mysterious _mm512_gf2p8affine_epi64_epi8, which is the intrinsic function corresponding to the gf2p8affineqb assembly instruction. To understand how simd_impl works, we're going to have to first understand what gf2p8affineqb does.

There are various ways of understanding what gf2p8affineqb does, but for the purposes of this blog post, I think the following Python pseudo-code is most useful:

def gf2p8affineqb(src1 : vector, src2 : vector, imm8 : u8) -> vector:
  assert len(src1.byte) == len(src2.byte)
  dst = vector()
  for i in range(len(src1.byte)):
    munged_src2 = munge(src2.qword[i // 8])
    dst.byte[i] = xor_selected(src1.byte[i], munged_src2, imm8)
  return dst

def xor_selected(src1 : u8, munged_src2 : u64, imm8 : u8) -> u8:
  result = imm8
  for i in range(8):
    if src1.bit[i]:
      result ^= munged_src2.byte[i]
  return result

def munge(x : u64) -> u64:
  return transpose8x8(byte_swap(x))
  # Or equivalently:
  return bitrev_in_each_byte(transpose8x8(x))

def transpose8x8(x : u64) -> u64:
  result = 0
  for i in range(8):
    for j in range(8):
      result.byte[i].bit[j] = x.byte[j].bit[i]
  return result

def byte_swap(x : u64) -> u64:
  result = 0
  for i in range(8):
    result.byte[i] = x.byte[7 - i]
  return result

def bitrev_in_each_byte(x : u64) -> u64:
  result = 0
  for i in range(8):
    result.byte[i] = bitrev(x.byte[i])
  return result

def bitrev(x : u8) -> u8:
  result = 0
  for i in range(8):
    result.bit[i] = x.bit[7 - i]
  return result

The mathematically inclined might notice that the above is in fact doing matrix multiplication of two 8x8 matrices of bits:

def gf2p8affineqb(src1 : vector, src2 : vector, imm8 : u8) -> vector:
  assert len(src1.byte) == len(src2.byte)
  dst = vector()
  for i in range(len(src1.qword)):
    dst.qword[i] = matmul(src1.qword[i], munge(src2.qword[i]))
  for i in range(len(src1.byte)):
    dst.byte[i] ^= imm8
  return dst

def matmul(lhs : u64, rhs : u64) -> u64:
  result = 0
  for i in range(8):
    for j in range(8):
      for k in range(8):
        b = lhs.byte[i].bit[j] * rhs.byte[j].bit[k] # * or &
        result.byte[i].bit[k] += b                  # + or ^
  return result

def munge(x : u64) -> u64:
  # Same as previously

The xor_selected view of gf2p8affineqb and the matmul view of gf2p8affineqb are complementary: I think that the xor_selected view makes it clearer what is going on, but the matmul view is useful for higher level transformations and optimisations. As a middle ground between the two views, matmul can be re-expressed as byte-level operations by unrolling the k loop:

def matmul(lhs : u64, rhs : u64) -> u64:
  result = 0
  for i in range(8):
    for j in range(8):
      if lhs.byte[i].bit[j]:
        result.byte[i] ^= rhs.byte[j]
  return result

One observation from the matmul view is that when src1.qword[i] is the identity matrix, we end up with dst.qword[i] being munge(src2.qword[i]). As a 64-bit integer, said identity matrix is 0x8040201008040201 (i.e. in byte i, just bit i is set). This explains __m512i kid = _mm512_set1_epi64(0x8040201008040201) in simd_impl (kid is just an identity matrix) and also explains __m512i hi1 = _mm512_gf2p8affine_epi64_epi8(kid, hi0, 0) and __m512i lo1 = _mm512_gf2p8affine_epi64_epi8(kid, lo0, 0) - these are just applying munge to every qword (as for what said munges are achieving, we'll get to later).
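
As a quick sanity check of that claim, here's a hedged sketch of my own (requiring GFNI and AVX-512, e.g. -mavx512f -mgfni; the scalar helpers just restate the pseudo-code above) which munges a qword both ways and compares:

// Hedged sketch: check that gf2p8affineqb with the identity matrix as its
// first operand applies munge() to each qword of the second operand.
#include <immintrin.h>
#include <stdint.h>
#include <stdio.h>

static uint64_t transpose8x8(uint64_t x) {
  // result.byte[i].bit[j] = x.byte[j].bit[i]
  uint64_t r = 0;
  for (int i = 0; i < 8; ++i)
    for (int j = 0; j < 8; ++j)
      r |= ((x >> (j * 8 + i)) & 1) << (i * 8 + j);
  return r;
}

static uint64_t munge(uint64_t x) {
  return transpose8x8(__builtin_bswap64(x));  // transpose8x8(byte_swap(x))
}

int main(void) {
  uint64_t x = 0x0123456789abcdefull;  // arbitrary test value
  __m512i kid = _mm512_set1_epi64((long long)0x8040201008040201ull);
  __m512i v = _mm512_gf2p8affine_epi64_epi8(kid, _mm512_set1_epi64((long long)x), 0);
  uint64_t via_simd = (uint64_t)_mm_cvtsi128_si64(_mm512_castsi512_si128(v));
  printf("scalar munge: %016llx, via gf2p8affineqb: %016llx\n",
         (unsigned long long)munge(x), (unsigned long long)via_simd);
  return munge(x) != via_simd;  // exit code 0 on success
}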

Changing tack somewhat, it is time to gradually transform reference_impl to make it look more like matmul. For this, we'll start with a simplified version of reference_impl that takes 8 indices rather than 64:

__mmask64 reference_impl_1(__m64i indices, __mmask8 valids) {
  __mmask64 result = 0;
  for (int i = 0; i < 8; ++i) {
    if (valids.bit[i]) {
      result ^= 1ull << indices.byte[i];
    }
  }
  return result;
}

The first transformation is to split each 6-bit index into its low 3 bits and high 3 bits, so that we can address bytes of result:

__mmask64 reference_impl_2(__m64i indices, __mmask8 valids) {
  __mmask64 result = 0;
  for (int i = 0; i < 8; ++i) {
    if (valids.bit[i]) {
      uint8_t b = indices.byte[i];
      uint8_t hi = b >> 3;
      uint8_t lo = b  & 7;
      result.byte[hi] ^= 1 << lo;
    }
  }
  return result;
}

Next up we perform loop fission; doing the exact same work, but using two loops rather than one (so that we can focus on the loops separately):

__mmask64 reference_impl_3(__m64i indices, __mmask8 valids) {
  __m64i hi;
  __m64i lo;
  for (int i = 0; i < 8; ++i) {
    uint8_t b = indices.byte[i];
    hi.byte[i] = b >> 3;
    lo.byte[i] = b  & 7;
  }
  __mmask64 result = 0;
  for (int i = 0; i < 8; ++i) {
    if (valids.bit[i]) {
      result.byte[hi.byte[i]] ^= 1 << lo.byte[i];
    }
  }
  return result;
}

Then the if and the 1 << can also be moved from the 2nd loop to the 1st loop:

__mmask64 reference_impl_4(__m64i indices, __mmask8 valids) {
  __m64i hi;
  __m64i lo;
  for (int i = 0; i < 8; ++i) {
    uint8_t b = indices.byte[i];
    hi.byte[i] = b >> 3;
    lo.byte[i] = valids.bit[i] ? 1 << (b & 7) : 0;
  }
  __mmask64 result = 0;
  for (int i = 0; i < 8; ++i) {
    result.byte[hi.byte[i]] ^= lo.byte[i];
  }
  return result;
}

Then a transformation that looks utterly deranged, but is key to the SIMD transformation; rather than directly indexing using hi.byte[i], we'll loop over the 8 possible values of hi.byte[i] and act when we find the right value:

__mmask64 reference_impl_5(__m64i indices, __mmask8 valids) {
  __m64i hi;
  __m64i lo;
  for (int i = 0; i < 8; ++i) {
    uint8_t b = indices.byte[i];
    hi.byte[i] = b >> 3;
    lo.byte[i] = valids.bit[i] ? 1 << (b & 7) : 0;
  }
  __mmask64 result = 0;
  for (int i = 0; i < 8; ++i) {
    for (int j = 0; j < 8; ++j) {
      if (hi.byte[i] == j) {
        result.byte[j] ^= lo.byte[i];
      }
    }
  }
  return result;
}

Next up we perform loop interchange of the two nested loops:

__mmask64 reference_impl_6(__m64i indices, __mmask8 valids) {
  __m64i hi;
  __m64i lo;
  for (int i = 0; i < 8; ++i) {
    uint8_t b = indices.byte[i];
    hi.byte[i] = b >> 3;
    lo.byte[i] = valids.bit[i] ? 1 << (b & 7) : 0;
  }
  __mmask64 result = 0;
  for (int i = 0; i < 8; ++i) {
    for (int j = 0; j < 8; ++j) {
      if (hi.byte[j] == i) {
        result.byte[i] ^= lo.byte[j];
      }
    }
  }
  return result;
}

Then another transformation that initially looks deranged; the == in hi.byte[j] == i is annoying, and can be replaced by a bit test if we one-hot encode hi:

__mmask64 reference_impl_7(__m64i indices, __mmask8 valids) {
  __m64i hi;
  __m64i lo;
  for (int i = 0; i < 8; ++i) {
    uint8_t b = indices.byte[i];
    hi.byte[i] = 1 << (b >> 3);
    lo.byte[i] = valids.bit[i] ? 1 << (b & 7) : 0;
  }
  __mmask64 result = 0;
  for (int i = 0; i < 8; ++i) {
    for (int j = 0; j < 8; ++j) {
      if (hi.byte[j].bit[i]) {
        result.byte[i] ^= lo.byte[j];
      }
    }
  }
  return result;
}

Then one final transformation to get where we want to be; apply transpose8x8 to hi, and undo it by changing .byte[j].bit[i] to .byte[i].bit[j]:

__mmask64 reference_impl_8(__m64i indices, __mmask8 valids) {
  __m64i hi;
  __m64i lo;
  for (int i = 0; i < 8; ++i) {
    uint8_t b = indices.byte[i];
    hi.byte[i] = 1 << (b >> 3);
    lo.byte[i] = valids.bit[i] ? 1 << (b & 7) : 0;
  }
  __mmask64 result = 0;
  for (int i = 0; i < 8; ++i) {
    for (int j = 0; j < 8; ++j) {
      if (transpose8x8(hi).byte[i].bit[j]) {
        result.byte[i] ^= lo.byte[j];
      }
    }
  }
  return result;
}

A number of these transformations seemed pointless or even unhelpful, but having done them all, the latter half of reference_impl_8 is exactly result = matmul(transpose8x8(hi), lo).

The expression matmul(transpose8x8(A), B) looks deceptively similar to the matmul(A, munge(B)) done by gf2p8affineqb(A, B, 0), and if munge was just transpose8x8, then gf2p8affineqb(munge(A), munge(B), 0) would be exactly matmul(transpose8x8(A), B). Unfortunately, munge also does a bit or byte reversal, causing gf2p8affineqb(munge(A), munge(B), 0) to actually be matmul(transpose8x8(A), bitrev_in_each_byte(B)) (if deriving this, note that munge(A) is bitrev_in_each_byte(transpose8x8(A)), munge(munge(B)) is byte_swap(bitrev_in_each_byte(B)), and then the bitrev_in_each_byte on A cancels out with the byte_swap on B).

The expression matmul(transpose8x8(A), bitrev_in_each_byte(B)) is very close to what we want, and the errant bitrev_in_each_byte can be cancelled out by doing another bitrev_in_each_byte on B:

__mmask64 reference_impl_9(__m64i indices, __mmask8 valids) {
  __m64i hi;
  __m64i lo;
  for (int i = 0; i < 8; ++i) {
    uint8_t b = indices.byte[i];
    hi.byte[i] = 1 << (b >> 3);
    lo.byte[i] = bitrev(valids.bit[i] ? 1 << (b & 7) : 0);
  }
  __mmask64 result = gf2p8affineqb(munge(hi), munge(lo), 0);
  return result;
}

The 1st loop is easy to express in a SIMD manner via a pair of table lookups, thereby giving us the first chunk of simd_impl:

__mmask64 simd_impl(__m512i indices, __mmask64 valids) {
  // Convert indices to bits within each qword lane.
  __m512i khi = _mm512_setr_epi8(
    0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
    0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
    0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04,
    0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08,
    0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10,
    0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
    0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40,
    0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80
  );
  __m512i hi0 = _mm512_permutexvar_epi8(indices, khi);
  __m512i klo = _mm512_set1_epi64(0x0102040810204080);
  __m512i lo0 = _mm512_maskz_shuffle_epi8(valids, klo, indices);
  __m512i kid = _mm512_set1_epi64(0x8040201008040201);
  __m512i hi1 = _mm512_gf2p8affine_epi64_epi8(kid, hi0, 0); // munge
  __m512i lo1 = _mm512_gf2p8affine_epi64_epi8(kid, lo0, 0); // munge
  __m512i x0  = _mm512_gf2p8affine_epi64_epi8(hi1, lo1, 0);

At this point, x0.qword[i] contains reference_impl_9(indices.qword[i], valids.byte[i]). To finish up, "all" we need to do is xor together the eight qwords of x0. The traditional way of doing this would be a shuffle followed by a xor to reduce eight to four, another shuffle followed by a xor to reduce four to two, and yet another shuffle followed by a xor to reduce two to one. We can do better than the traditional approach though. The first step is to do one big shuffle rather than three sequential shuffles, where the result of the big shuffle moves the eight bytes qword[i].byte[0] to be contiguous, then the eight bytes qword[i].byte[1] to be contiguous, and so on. Seen differently, the big shuffle is a transpose on an 8x8 matrix of bytes. After this big shuffle, the remaining problem is to take each contiguous group of eight bytes and xor them together. If we wanted to add together each contiguous group of eight bytes, then _mm512_sad_epu8 against zero would be one option, but we want xor rather than add. There are a few different ways of approaching the problem, but one cute way is to apply transpose8x8 to each contiguous group of eight bytes, after which we just need to xor together each contiguous group of eight bits. Applying transpose8x8 on its own is hard, but we can apply munge fairly easily, which does transpose8x8 followed by bitrev_in_each_byte, and the bitrev_in_each_byte is harmless given that we're about to xor together the bits in each byte. This gives us the next chunk of simd_impl:

  // Combine results from various qword lanes.
  __m512i ktr = _mm512_setr_epi8(
    0,  8, 16, 24, 32, 40, 48, 56,
    1,  9, 17, 25, 33, 41, 49, 57,
    2, 10, 18, 26, 34, 42, 50, 58,
    3, 11, 19, 27, 35, 43, 51, 59,
    4, 12, 20, 28, 36, 44, 52, 60,
    5, 13, 21, 29, 37, 45, 53, 61,
    6, 14, 22, 30, 38, 46, 54, 62,
    7, 15, 23, 31, 39, 47, 55, 63);
  __m512i x1  = _mm512_permutexvar_epi8(ktr, x0); // transpose bytes
  __m512i x2  = _mm512_gf2p8affine_epi64_epi8(kid, x1, 0); // munge

If the valid indices can be assumed to be distinct, then we can or (rather than xor) together the bits in each byte, which is just _mm512_cmpneq_epi8_mask against zero.

If we really do need to xor the bits together, then what we want is this function applied to every byte:

def xor_together_bits(x : u8) -> u8:
  result = 0
  for i in range(8):
    if x.bit[i]:
      result ^= 0xff
  return result

If you're thinking that xor_together_bits looks very similar to xor_selected, then you'd be right: xor_together_bits is just xor_selected where every byte of munged_src2 is 0xff, and it so happens that if every byte of src2 is 0xff, then the same is true for munged_src2. This gives the final chunk of simd_impl:

  // Reduce 64 bytes down to 64 bits.
  __m512i kff = _mm512_set1_epi8(0xff);
  __m512i x3  = _mm512_gf2p8affine_epi64_epi8(x2, kff, 0);
  return _mm512_movepi8_mask(x3);
}
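
To round things off, here's a hedged test sketch of my own (not part of the original write-up): it compares simd_impl against a plain scalar re-statement of reference_impl on random inputs, assuming simd_impl as defined above is in scope and a compiler/CPU with AVX512BW, AVX512VBMI and GFNI (e.g. -mavx512bw -mavx512vbmi -mgfni).

// Hedged test sketch: random comparison of simd_impl (above) against a
// scalar re-statement of reference_impl. Indices are kept in 0..63, as the
// reference implicitly assumes.
#include <immintrin.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

static uint64_t scalar_impl(const uint8_t indices[64], uint64_t valids) {
  uint64_t result = 0;
  for (int i = 0; i < 64; ++i) {
    if ((valids >> i) & 1) {
      result ^= 1ull << (indices[i] & 63);
    }
  }
  return result;
}

int main(void) {
  srand(42);
  for (int iter = 0; iter < 100000; ++iter) {
    uint8_t indices[64];
    uint64_t valids = 0;
    for (int i = 0; i < 64; ++i) {
      indices[i] = (uint8_t)(rand() & 63);
      valids |= ((uint64_t)(rand() & 1)) << i;
    }
    __m512i v = _mm512_loadu_si512((const void *)indices);
    uint64_t expected = scalar_impl(indices, valids);
    uint64_t actual = (uint64_t)simd_impl(v, (__mmask64)valids);
    if (expected != actual) {
      printf("mismatch at iteration %d\n", iter);
      return 1;
    }
  }
  printf("all ok\n");
  return 0;
}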