@geofflangdale posed the question on Twitter of how to vectorise this:

__mmask64 reference_impl(__m512i indices, __mmask64 valids) {
  __mmask64 result = 0;
  for (int i = 0; i < 64; ++i) {
    if (valids.bit[i]) {
      result ^= 1ull << indices.byte[i];
    }
  }
  return result;
}
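
The .byte[] and .bit[] accessors here are pseudocode. For anyone wanting to actually run the reference, a directly compilable sketch might look like the following (the name reference_impl_compilable is mine; it assumes AVX-512F for the store and indices in the range 0-63):

#include <immintrin.h>
#include <stdint.h>

__mmask64 reference_impl_compilable(__m512i indices, __mmask64 valids) {
  // Spill the vector to memory so that individual bytes can be read.
  uint8_t idx[64];
  _mm512_storeu_si512((__m512i*)idx, indices);
  __mmask64 result = 0;
  for (int i = 0; i < 64; ++i) {
    if ((valids >> i) & 1) {
      result ^= 1ull << idx[i];
    }
  }
  return result;
}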

After a week of code golf also involving @HaroldAptroot, we ended up with:

__mmask64 simd_impl(__m512i indices, __mmask64 valids) {
  // Convert indices to bits within each qword lane.
  __m512i khi = _mm512_setr_epi8(
    0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
    0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
    0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04,
    0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08,
    0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10,
    0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
    0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40,
    0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80
  );
  __m512i hi0 = _mm512_permutexvar_epi8(indices, khi);
  __m512i klo = _mm512_set1_epi64(0x0102040810204080);
  __m512i lo0 = _mm512_maskz_shuffle_epi8(valids, klo, indices);
  __m512i kid = _mm512_set1_epi64(0x8040201008040201);
  __m512i hi1 = _mm512_gf2p8affine_epi64_epi8(kid, hi0, 0);
  __m512i lo1 = _mm512_gf2p8affine_epi64_epi8(kid, lo0, 0);
  __m512i x0  = _mm512_gf2p8affine_epi64_epi8(hi1, lo1, 0);
  // Combine results from various qword lanes.
  __m512i ktr = _mm512_setr_epi8(
    0,  8, 16, 24, 32, 40, 48, 56,
    1,  9, 17, 25, 33, 41, 49, 57,
    2, 10, 18, 26, 34, 42, 50, 58,
    3, 11, 19, 27, 35, 43, 51, 59,
    4, 12, 20, 28, 36, 44, 52, 60,
    5, 13, 21, 29, 37, 45, 53, 61,
    6, 14, 22, 30, 38, 46, 54, 62,
    7, 15, 23, 31, 39, 47, 55, 63);
  __m512i x1  = _mm512_permutexvar_epi8(ktr, x0);
  __m512i x2  = _mm512_gf2p8affine_epi64_epi8(kid, x1, 0);
  // Reduce 64 bytes down to 64 bits.
  __m512i kff = _mm512_set1_epi8(0xff);
  __m512i x3  = _mm512_gf2p8affine_epi64_epi8(x2, kff, 0);
  return _mm512_movepi8_mask(x3);
}

NB: If the valid indices can be assumed to be distinct, then the final reduction from 64 bytes to 64 bits can instead be:

  return _mm512_cmpneq_epi8_mask(x2, _mm512_setzero_si512());

As is often the case, simd_impl looks nothing like reference_impl, despite doing the same thing. In particular, simd_impl contains no shifts, and instead contains alternating shuffles and invocations of the mysterious _mm512_gf2p8affine_epi64_epi8, which is the intrinsic function corresponding to the gf2p8affineqb assembly instruction. To understand how simd_impl works, we're going to have to first understand what gf2p8affineqb does.

There are various ways of understanding what gf2p8affineqb does, but for the purposes of this blog post, I think the following Python pseudo-code is most useful:

def gf2p8affineqb(src1 : vector, src2 : vector, imm8 : u8) -> vector:
  assert len(src1.byte) == len(src2.byte)
  dst = vector()
  for i in range(len(src1.byte)):
    munged_src2 = munge(src2.qword[i // 8])
    dst.byte[i] = xor_selected(src1.byte[i], munged_src2, imm8)
  return dst

def xor_selected(src1 : u8, munged_src2 : u64, imm8 : u8) -> u8:
  result = imm8
  for i in range(8):
    if src1.bit[i]:
      result ^= munged_src2.byte[i]
  return result

def munge(x : u64) -> u64:
  return transpose8x8(byte_swap(x))
  # Or equivalently:
  return bitrev_in_each_byte(transpose8x8(x))

def transpose8x8(x : u64) -> u64:
  result = 0
  for i in range(8):
    for j in range(8):
      result.byte[i].bit[j] = x.byte[j].bit[i]
  return result

def byte_swap(x : u64) -> u64:
  result = 0
  for i in range(8):
    result.byte[i] = x.byte[7 - i]
  return result

def bitrev_in_each_byte(x : u64) -> u64:
  result = 0
  for i in range(8):
    result.byte[i] = bitrev(x.byte[i])
  return result

def bitrev(x : u8) -> u8:
  result = 0
  for i in range(8):
    result.bit[i] = x.bit[7 - i]
  return result

The mathematically inclined might notice that the above is in fact doing matrix multiplication of two 8x8 matrices of bits:

def gf2p8affineqb(src1 : vector, src2 : vector, imm8 : u8) -> vector:
  assert len(src1.byte) == len(src2.byte)
  dst = vector()
  for i in range(len(src1.qword)):
    dst.qword[i] = matmul(src1.qword[i], munge(src2.qword[i]))
  for i in range(len(src1.byte)):
    dst.byte[i] ^= imm8
  return dst

def matmul(lhs : u64, rhs : u64) -> u64:
  result = 0
  for i in range(8):
    for j in range(8):
      for k in range(8):
        b = lhs.byte[i].bit[j] * rhs.byte[j].bit[k] # * or &
        result.byte[i].bit[k] += b                  # + or ^
  return result

def munge(x : u64) -> u64:
  # Same as previously

The xor_selected view of gf2p8affineqb and the matmul view of gf2p8affineqb are complementary: I think that the xor_selected view makes it clearer what is going on, but the matmul view is useful for higher level transformations and optimisations. As a middle ground between the two views, matmul can be re-expressed as byte-level operations by unrolling the k loop:

def matmul(lhs : u64, rhs : u64) -> u64:
  result = 0
  for i in range(8):
    for j in range(8):
      if lhs.byte[i].bit[j]:
        result.byte[i] ^= rhs.byte[j]
  return result

One observation from the matmul view is that when src1.qword[i] is the identity matrix, we end up with dst.qword[i] being munge(src2.qword[i]). As a 64-bit integer, said identity matrix is 0x8040201008040201 (i.e. in byte i, just bit i is set). This explains __m512i kid = _mm512_set1_epi64(0x8040201008040201) in simd_impl (kid is just an identity matrix) and also explains __m512i hi1 = _mm512_gf2p8affine_epi64_epi8(kid, hi0, 0) and __m512i lo1 = _mm512_gf2p8affine_epi64_epi8(kid, lo0, 0) - these are just applying munge to every qword (as for what said munges are achieving, we'll get to later).
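
As a quick scalar sanity check of that observation (a sketch; a uint64_t stands in for one qword lane, and matmul is transcribed from the byte-level pseudocode above):

#include <stdint.h>
#include <stdio.h>

// Byte-level matmul from the pseudocode above, on a single qword.
static uint64_t matmul(uint64_t lhs, uint64_t rhs) {
  uint64_t result = 0;
  for (int i = 0; i < 8; ++i) {
    for (int j = 0; j < 8; ++j) {
      if ((lhs >> (i * 8 + j)) & 1) {                    // lhs.byte[i].bit[j]
        result ^= ((rhs >> (j * 8)) & 0xff) << (i * 8);  // result.byte[i] ^= rhs.byte[j]
      }
    }
  }
  return result;
}

int main(void) {
  uint64_t kid = 0x8040201008040201ull; // identity matrix
  uint64_t m   = 0x0123456789abcdefull; // arbitrary test value
  // matmul(kid, m) == m, so gf2p8affineqb(kid, x, 0) == matmul(kid, munge(x)) == munge(x).
  printf("%s\n", matmul(kid, m) == m ? "kid is the identity" : "bug");
  return 0;
}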

Changing tack somewhat, it is time to gradually transform reference_impl to make it look more like matmul. For this, we'll start with a simplified version of reference_impl that takes 8 indices rather than 64:

__mmask64 reference_impl_1(__m64i indices, __mmask8 valids) {
  __mmask64 result = 0;
  for (int i = 0; i < 8; ++i) {
    if (valids.bit[i]) {
      result ^= 1ull << indices.byte[i];
    }
  }
  return result;
}

The first transformation is to split each 6-bit index into its low 3 bits and high 3 bits, so that we can address bytes of result:

__mmask64 reference_impl_2(__m64i indices, __mmask8 valids) {
  __mmask64 result = 0;
  for (int i = 0; i < 8; ++i) {
    if (valids.bit[i]) {
      uint8_t b = indices.byte[i];
      uint8_t hi = b >> 3;
      uint8_t lo = b  & 7;
      result.byte[hi] ^= 1 << lo;
    }
  }
  return result;
}

Next up we perform loop fission; doing the exact same work, but using two loops rather than one (so that we can focus on the loops separately):

__mmask64 reference_impl_3(__m64i indices, __mmask8 valids) {
  __m64i hi;
  __m64i lo;
  for (int i = 0; i < 8; ++i) {
    uint8_t b = indices.byte[i];
    hi.byte[i] = b >> 3;
    lo.byte[i] = b  & 7;
  }
  __mmask64 result = 0;
  for (int i = 0; i < 8; ++i) {
    if (valids.bit[i]) {
      result.byte[hi.byte[i]] ^= 1 << lo.byte[i];
    }
  }
  return result;
}

Then the if and the 1 << can also be moved from the 2nd loop to the 1st loop:

__mmask64 reference_impl_4(__m64i indices, __mmask8 valids) {
  __m64i hi;
  __m64i lo;
  for (int i = 0; i < 8; ++i) {
    uint8_t b = indices.byte[i];
    hi.byte[i] = b >> 3;
    lo.byte[i] = valids.bit[i] ? 1 << (b & 7) : 0;
  }
  __mmask64 result = 0;
  for (int i = 0; i < 8; ++i) {
    result.byte[hi.byte[i]] ^= lo.byte[i];
  }
  return result;
}

Then a transformation that looks utterly deranged, but is key to the SIMD transformation; rather than directly indexing using hi.byte[i], we'll loop over the 8 possible values of hi.byte[i] and act when we find the right value:

__mmask64 reference_impl_5(__m64i indices, __mmask8 valids) {
  __m64i hi;
  __m64i lo;
  for (int i = 0; i < 8; ++i) {
    uint8_t b = indices.byte[i];
    hi.byte[i] = b >> 3;
    lo.byte[i] = valids.bit[i] ? 1 << (b & 7) : 0;
  }
  __mmask64 result = 0;
  for (int i = 0; i < 8; ++i) {
    for (int j = 0; j < 8; ++j) {
      if (hi.byte[i] == j) {
        result.byte[j] ^= lo.byte[i];
      }
    }
  }
  return result;
}

Next up we perform loop interchange of the two nested loops:

__mmask64 reference_impl_6(__m64i indices, __mmask8 valids) {
  __m64i hi;
  __m64i lo;
  for (int i = 0; i < 8; ++i) {
    uint8_t b = indices.byte[i];
    hi.byte[i] = b >> 3;
    lo.byte[i] = valids.bit[i] ? 1 << (b & 7) : 0;
  }
  __mmask64 result = 0;
  for (int i = 0; i < 8; ++i) {
    for (int j = 0; j < 8; ++j) {
      if (hi.byte[j] == i) {
        result.byte[i] ^= lo.byte[j];
      }
    }
  }
  return result;
}

Then another transformation that initially looks deranged; the == in hi.byte[j] == i is annoying, and can be replaced by a bit test if we one-hot encode hi:

__mmask64 reference_impl_7(__m64i indices, __mmask8 valids) {
  __m64i hi;
  __m64i lo;
  for (int i = 0; i < 8; ++i) {
    uint8_t b = indices.byte[i];
    hi.byte[i] = 1 << (b >> 3);
    lo.byte[i] = valids.bit[i] ? 1 << (b & 7) : 0;
  }
  __mmask64 result = 0;
  for (int i = 0; i < 8; ++i) {
    for (int j = 0; j < 8; ++j) {
      if (hi.byte[j].bit[i]) {
        result.byte[i] ^= lo.byte[j];
      }
    }
  }
  return result;
}

Then one final transformation to get where we want to be; apply transpose8x8 to hi, and undo it by changing .byte[j].bit[i] to .byte[i].bit[j]:

__mmask64 reference_impl_8(__m64i indices, __mmask8 valids) {
  __m64i hi;
  __m64i lo;
  for (int i = 0; i < 8; ++i) {
    uint8_t b = indices.byte[i];
    hi.byte[i] = 1 << (b >> 3);
    lo.byte[i] = valids.bit[i] ? 1 << (b & 7) : 0;
  }
  __mmask64 result = 0;
  for (int i = 0; i < 8; ++i) {
    for (int j = 0; j < 8; ++j) {
      if (transpose8x8(hi).byte[i].bit[j]) {
        result.byte[i] ^= lo.byte[j];
      }
    }
  }
  return result;
}

A number of these transformations seemed pointless or even unhelpful, but having done them all, the latter half of reference_impl_8 is exactly result = matmul(transpose8x8(hi), lo).

The expression matmul(transpose8x8(A), B) looks deceptively similar to the matmul(A, munge(B)) done by gf2p8affineqb(A, B, 0); if munge were just transpose8x8, then gf2p8affineqb(munge(A), munge(B), 0) would be exactly matmul(transpose8x8(A), B). Unfortunately, munge also does a bit or byte reversal, causing gf2p8affineqb(munge(A), munge(B), 0) to actually be matmul(transpose8x8(A), bitrev_in_each_byte(B)). To derive this, note that munge(A) is bitrev_in_each_byte(transpose8x8(A)), that munge(munge(B)) is byte_swap(bitrev_in_each_byte(B)), and that the bitrev_in_each_byte on A then cancels out with the byte_swap on B.
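
That cancellation is easy to check with scalar code. The following sketch transcribes the pseudocode helpers onto uint64_t and verifies the identity on random inputs (rand64 is an ad-hoc helper for generating test values):

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

static uint64_t transpose8x8(uint64_t x) {
  uint64_t r = 0;
  for (int i = 0; i < 8; ++i)
    for (int j = 0; j < 8; ++j)
      r |= ((x >> (j * 8 + i)) & 1) << (i * 8 + j); // result.byte[i].bit[j] = x.byte[j].bit[i]
  return r;
}

static uint64_t byte_swap(uint64_t x) {
  uint64_t r = 0;
  for (int i = 0; i < 8; ++i)
    r |= ((x >> ((7 - i) * 8)) & 0xff) << (i * 8);  // result.byte[i] = x.byte[7 - i]
  return r;
}

static uint64_t bitrev_in_each_byte(uint64_t x) {
  uint64_t r = 0;
  for (int i = 0; i < 8; ++i)
    for (int j = 0; j < 8; ++j)
      r |= ((x >> (i * 8 + (7 - j))) & 1) << (i * 8 + j);
  return r;
}

static uint64_t munge(uint64_t x) { return transpose8x8(byte_swap(x)); }

static uint64_t matmul(uint64_t lhs, uint64_t rhs) {
  uint64_t r = 0;
  for (int i = 0; i < 8; ++i)
    for (int j = 0; j < 8; ++j)
      if ((lhs >> (i * 8 + j)) & 1)
        r ^= ((rhs >> (j * 8)) & 0xff) << (i * 8);
  return r;
}

static uint64_t rand64(void) {
  uint64_t r = 0;
  for (int i = 0; i < 4; ++i) r = (r << 16) ^ (uint64_t)(rand() & 0xffff);
  return r;
}

int main(void) {
  for (int t = 0; t < 1000; ++t) {
    uint64_t a = rand64(), b = rand64();
    uint64_t lhs = matmul(munge(a), munge(munge(b))); // gf2p8affineqb(munge(a), munge(b), 0)
    uint64_t rhs = matmul(transpose8x8(a), bitrev_in_each_byte(b));
    if (lhs != rhs) { printf("mismatch\n"); return 1; }
  }
  printf("identity holds\n");
  return 0;
}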

The expression matmul(transpose8x8(A), bitrev_in_each_byte(B)) is very close to what we want, and the errant bitrev_in_each_byte can be cancelled out by doing another bitrev_in_each_byte on B:

__mmask64 reference_impl_9(__m64i indices, __mmask8 valids) {
  __m64i hi;
  __m64i lo;
  for (int i = 0; i < 8; ++i) {
    uint8_t b = indices.byte[i];
    hi.byte[i] = 1 << (b >> 3);
    lo.byte[i] = bitrev(valids.bit[i] ? 1 << (b & 7) : 0);
  }
  __mmask64 result = gf2p8affineqb(munge(hi), munge(lo), 0);
  return result;
}
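
The derivation so far can also be checked end-to-end with scalar code. This sketch (ref1 and ref9 are my scalar stand-ins for reference_impl_1 and reference_impl_9, with gf2p8affineqb(A, B, 0) modelled as matmul(A, munge(B)) per the pseudocode) compares the two on random 6-bit indices:

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

// munge written directly at the bit level: result.byte[i].bit[j] = x.byte[7 - j].bit[i].
static uint64_t munge(uint64_t x) {
  uint64_t r = 0;
  for (int i = 0; i < 8; ++i)
    for (int j = 0; j < 8; ++j)
      r |= ((x >> ((7 - j) * 8 + i)) & 1) << (i * 8 + j);
  return r;
}

// Byte-level matmul from the pseudocode.
static uint64_t matmul(uint64_t lhs, uint64_t rhs) {
  uint64_t r = 0;
  for (int i = 0; i < 8; ++i)
    for (int j = 0; j < 8; ++j)
      if ((lhs >> (i * 8 + j)) & 1)
        r ^= ((rhs >> (j * 8)) & 0xff) << (i * 8);
  return r;
}

static uint8_t bitrev8(uint8_t x) {
  uint8_t r = 0;
  for (int i = 0; i < 8; ++i) r |= ((x >> (7 - i)) & 1) << i;
  return r;
}

// Scalar reference_impl_1: indices must be in the range 0-63.
static uint64_t ref1(const uint8_t indices[8], uint8_t valids) {
  uint64_t result = 0;
  for (int i = 0; i < 8; ++i)
    if ((valids >> i) & 1) result ^= 1ull << indices[i];
  return result;
}

// Scalar reference_impl_9, with the gf2p8affineqb call expanded via matmul/munge.
static uint64_t ref9(const uint8_t indices[8], uint8_t valids) {
  uint64_t hi = 0, lo = 0;
  for (int i = 0; i < 8; ++i) {
    uint8_t b = indices[i];
    hi |= (uint64_t)(1u << (b >> 3)) << (i * 8);
    lo |= (uint64_t)bitrev8(((valids >> i) & 1) ? 1u << (b & 7) : 0) << (i * 8);
  }
  return matmul(munge(hi), munge(munge(lo))); // gf2p8affineqb(munge(hi), munge(lo), 0)
}

int main(void) {
  for (int t = 0; t < 1000; ++t) {
    uint8_t indices[8], valids = (uint8_t)rand();
    for (int i = 0; i < 8; ++i) indices[i] = rand() & 63;
    if (ref1(indices, valids) != ref9(indices, valids)) { printf("mismatch\n"); return 1; }
  }
  printf("reference_impl_9 matches reference_impl_1\n");
  return 0;
}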

The 1st loop is easy to express in a SIMD manner via a pair of table lookups, thereby giving us the first chunk of simd_impl:

__mmask64 simd_impl(__m512i indices, __mmask64 valids) {
  // Convert indices to bits within each qword lane.
  __m512i khi = _mm512_setr_epi8(
    0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
    0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
    0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04,
    0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08,
    0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10,
    0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
    0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40,
    0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80
  );
  __m512i hi0 = _mm512_permutexvar_epi8(indices, khi);
  __m512i klo = _mm512_set1_epi64(0x0102040810204080);
  __m512i lo0 = _mm512_maskz_shuffle_epi8(valids, klo, indices);
  __m512i kid = _mm512_set1_epi64(0x8040201008040201);
  __m512i hi1 = _mm512_gf2p8affine_epi64_epi8(kid, hi0, 0); // munge
  __m512i lo1 = _mm512_gf2p8affine_epi64_epi8(kid, lo0, 0); // munge
  __m512i x0  = _mm512_gf2p8affine_epi64_epi8(hi1, lo1, 0);
}

At this point, x0.qword[i] contains reference_impl_9(indices.qword[i], valids.byte[i]). To finish up, "all" we need to do is xor together the eight qwords of x0. The traditional way of doing this would be a shuffle followed by a xor to reduce eight to four, another shuffle followed by a xor to reduce four to two, and yet another shuffle followed by a xor to reduce two to one.

We can do better than the traditional approach though. The first step is to do one big shuffle rather than three sequential shuffles, where the result of the big shuffle moves the eight bytes qword[i].byte[0] to be contiguous, then the eight bytes qword[i].byte[1] to be contiguous, and so on. Seen differently, the big shuffle is a transpose of an 8x8 matrix of bytes.

After this big shuffle, the remaining problem is to take each contiguous group of eight bytes and xor them together. If we wanted to add together each contiguous group of eight bytes, then _mm512_sad_epu8 against zero would be one option, but we want xor rather than add. There are a few different ways of approaching the problem, but one cute way is to apply transpose8x8 to each contiguous group of eight bytes, after which we just need to xor together each contiguous group of eight bits. Applying transpose8x8 on its own is hard, but we can apply munge fairly easily, which does transpose8x8 followed by bitrev_in_each_byte, and the bitrev_in_each_byte is harmless given that we're about to xor together the bits in each byte (a scalar sketch of this trick follows the next chunk of code). This gives us the next chunk of simd_impl:

  // Combine results from various qword lanes.
  __m512i ktr = _mm512_setr_epi8(
    0,  8, 16, 24, 32, 40, 48, 56,
    1,  9, 17, 25, 33, 41, 49, 57,
    2, 10, 18, 26, 34, 42, 50, 58,
    3, 11, 19, 27, 35, 43, 51, 59,
    4, 12, 20, 28, 36, 44, 52, 60,
    5, 13, 21, 29, 37, 45, 53, 61,
    6, 14, 22, 30, 38, 46, 54, 62,
    7, 15, 23, 31, 39, 47, 55, 63);
  __m512i x1  = _mm512_permutexvar_epi8(ktr, x0); // transpose bytes
  __m512i x2  = _mm512_gf2p8affine_epi64_epi8(kid, x1, 0); // munge
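
As promised, here's a scalar sketch of why the trick works: for any qword (in our case, a qword of x1, whose eight bytes are the eight bytes we want to xor together), xoring its bytes together gives the same answer as applying munge and then xoring together the bits within each byte. The bit order inside each byte doesn't affect an xor of all its bits, which is why munge works just as well as a plain transpose8x8 here:

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

// munge written directly at the bit level: result.byte[i].bit[j] = x.byte[7 - j].bit[i].
static uint64_t munge(uint64_t x) {
  uint64_t r = 0;
  for (int i = 0; i < 8; ++i)
    for (int j = 0; j < 8; ++j)
      r |= ((x >> ((7 - j) * 8 + i)) & 1) << (i * 8 + j);
  return r;
}

static uint8_t xor_of_bytes(uint64_t x) {
  uint8_t r = 0;
  for (int i = 0; i < 8; ++i) r ^= (uint8_t)(x >> (i * 8));
  return r;
}

// Bit i of the result is the xor of the eight bits of byte i of x.
static uint8_t xor_bits_within_each_byte(uint64_t x) {
  uint8_t r = 0;
  for (int i = 0; i < 8; ++i) {
    uint8_t byte = (uint8_t)(x >> (i * 8)), parity = 0;
    for (int j = 0; j < 8; ++j) parity ^= (byte >> j) & 1;
    r |= parity << i;
  }
  return r;
}

static uint64_t rand64(void) {
  uint64_t r = 0;
  for (int i = 0; i < 4; ++i) r = (r << 16) ^ (uint64_t)(rand() & 0xffff);
  return r;
}

int main(void) {
  for (int t = 0; t < 1000; ++t) {
    uint64_t x = rand64();
    if (xor_of_bytes(x) != xor_bits_within_each_byte(munge(x))) { printf("mismatch\n"); return 1; }
  }
  printf("byte-xor equals per-byte bit-xor after munge\n");
  return 0;
}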

If the valid indices can be assumed to be distinct, then we can or (rather than xor) together the bits in each byte, which is just _mm512_cmpneq_epi8_mask against zero.

If we really do need to xor the bits together, then what we want is this function applied to every byte:

def xor_together_bits(x : u8) -> u8:
  result = 0
  for i in range(8):
    if x.bit[i]:
      result ^= 0xff
  return result

If you're thinking that xor_together_bits looks very similar to xor_selected, then you'd be right: xor_together_bits is just xor_selected where every byte of munged_src2 is 0xff, and it so happens that if every byte of src2 is 0xff, then the same is true for munged_src2. This gives the final chunk of simd_impl:

  // Reduce 64 bytes down to 64 bits.
  __m512i kff = _mm512_set1_epi8(0xff);
  __m512i x3  = _mm512_gf2p8affine_epi64_epi8(x2, kff, 0);
  return _mm512_movepi8_mask(x3);
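
For completeness, one way to test simd_impl would be something along these lines (a sketch, not from the code golf above; scalar_ref is just the reference loop in compilable form, indices are kept in the range 0-63, and the CPU needs AVX-512BW, AVX-512VBMI and GFNI, e.g. compile with -mavx512f -mavx512bw -mavx512vbmi -mgfni):

#include <immintrin.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

__mmask64 simd_impl(__m512i indices, __mmask64 valids); // as defined above

static __mmask64 scalar_ref(const uint8_t idx[64], __mmask64 valids) {
  __mmask64 result = 0;
  for (int i = 0; i < 64; ++i)
    if ((valids >> i) & 1) result ^= 1ull << idx[i];
  return result;
}

int main(void) {
  for (int trial = 0; trial < 1000; ++trial) {
    uint8_t idx[64];
    __mmask64 valids = 0;
    for (int i = 0; i < 64; ++i) {
      idx[i] = rand() & 63;
      valids |= (__mmask64)(rand() & 1) << i;
    }
    __m512i v = _mm512_loadu_si512((const __m512i*)idx);
    if (simd_impl(v, valids) != scalar_ref(idx, valids)) {
      printf("mismatch on trial %d\n", trial);
      return 1;
    }
  }
  printf("all trials match\n");
  return 0;
}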