@geofflangdale posed the question on Twitter of how to vectorise this:

```
__mmask64 reference_impl(__m512i indices, __mmask64 valids) {
    __mmask64 result = 0;
    for (int i = 0; i < 64; ++i) {
        if (valids.bit[i]) {
            result ^= 1ull << indices.byte[i];
        }
    }
    return result;
}
```
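To check the transformations that follow, it can help to have an executable version of `reference_impl`. Below is a minimal Python model of my own (not from the original thread). It assumes `indices` is a list of 64 integers in `range(64)`, so that the shift is well defined, and `valids` is a 64-bit integer:

```python
def reference_impl(indices: list[int], valids: int) -> int:
    """Python model of reference_impl: xor together 1 << indices[i] for each valid i."""
    assert len(indices) == 64 and all(0 <= b < 64 for b in indices)
    result = 0
    for i in range(64):
        if (valids >> i) & 1:  # valids.bit[i]
            result ^= 1 << indices[i]
    return result


print(reference_impl([3, 5] + [0] * 62, 0b11))  # 40 (bits 3 and 5 set)
print(reference_impl([3, 3] + [0] * 62, 0b11))  # 0 (the repeated index cancels)
```

The second call shows that repeated valid indices cancel each other out, because the accumulation uses xor rather than or.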

After a week of code golf also involving @HaroldAptroot, we ended up with:

```
__mmask64 simd_impl(__m512i indices, __mmask64 valids) {
    __m512i khi = _mm512_setr_epi8(
        0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
        0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
        0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04,
        0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08,
        0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10,
        0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
        0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40,
        0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80
    );
    __m512i hi0 = _mm512_permutexvar_epi8(indices, khi);
    __m512i klo = _mm512_set1_epi64(0x0102040810204080);
    __m512i lo0 = _mm512_maskz_shuffle_epi8(valids, klo, indices);
    __m512i kid = _mm512_set1_epi64(0x8040201008040201);
    __m512i hi1 = _mm512_gf2p8affine_epi64_epi8(kid, hi0, 0);
    __m512i lo1 = _mm512_gf2p8affine_epi64_epi8(kid, lo0, 0);
    __m512i x0 = _mm512_gf2p8affine_epi64_epi8(hi1, lo1, 0);
    __m512i ktr = _mm512_setr_epi8(
        0, 8, 16, 24, 32, 40, 48, 56,
        1, 9, 17, 25, 33, 41, 49, 57,
        2, 10, 18, 26, 34, 42, 50, 58,
        3, 11, 19, 27, 35, 43, 51, 59,
        4, 12, 20, 28, 36, 44, 52, 60,
        5, 13, 21, 29, 37, 45, 53, 61,
        6, 14, 22, 30, 38, 46, 54, 62,
        7, 15, 23, 31, 39, 47, 55, 63);
    __m512i x1 = _mm512_permutexvar_epi8(ktr, x0);
    __m512i x2 = _mm512_gf2p8affine_epi64_epi8(kid, x1, 0);
    __m512i kff = _mm512_set1_epi8(0xff);
    __m512i x3 = _mm512_gf2p8affine_epi64_epi8(x2, kff, 0);
    return _mm512_movepi8_mask(x3);
}
```

NB: If the valid indices can be assumed to be distinct, then the final reduction from 64 bytes to 64 bits can instead be:

```
return _mm512_cmpneq_epi8_mask(x2, _mm512_setzero_si512());
```

As is often the case, `simd_impl` looks nothing like `reference_impl`, despite doing the same thing. In particular, `simd_impl` contains no shifts, and instead contains alternating shuffles and invocations of the mysterious `_mm512_gf2p8affine_epi64_epi8`, which is the intrinsic function corresponding to the `gf2p8affineqb` assembly instruction. To understand how `simd_impl` works, we're first going to have to understand what `gf2p8affineqb` does.

There are various ways of understanding what `gf2p8affineqb` does, but for the purposes of this blog post, I think the following Python pseudo-code is most useful:

```
def gf2p8affineqb(src1 : vector, src2 : vector, imm8 : u8) -> vector:
    assert len(src1.byte) == len(src2.byte)
    dst = vector()
    for i in range(len(src1.byte)):
        munged_src2 = munge(src2.qword[i // 8])
        dst.byte[i] = xor_selected(src1.byte[i], munged_src2, imm8)
    return dst

def xor_selected(src1 : u8, munged_src2 : u64, imm8 : u8) -> u8:
    result = imm8
    for i in range(8):
        if src1.bit[i]:
            result ^= munged_src2.byte[i]
    return result

def munge(x : u64) -> u64:
    return transpose8x8(byte_swap(x))
    # Equivalently: return bitrev_in_each_byte(transpose8x8(x))

def transpose8x8(x : u64) -> u64:
    result = 0
    for i in range(8):
        for j in range(8):
            result.byte[i].bit[j] = x.byte[j].bit[i]
    return result

def byte_swap(x : u64) -> u64:
    result = 0
    for i in range(8):
        result.byte[i] = x.byte[7 - i]
    return result

def bitrev_in_each_byte(x : u64) -> u64:
    result = 0
    for i in range(8):
        result.byte[i] = bitrev(x.byte[i])
    return result

def bitrev(x : u8) -> u8:
    result = 0
    for i in range(8):
        result.bit[i] = x.bit[7 - i]
    return result
```
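To make the pseudo-code concrete, here is a small executable Python model of it, written for this post rather than taken from any library. It represents qwords as Python ints and vectors as lists of bytes. `affine_byte_parity` is a hypothetical helper that implements the parity-based formulation: bit `i` of the result is the parity of `src2.byte[7 - i] & src1`, xor-ed with `imm8.bit[i]`. As far as I know, this is how Intel's documentation describes the instruction, so it serves as an independent cross-check of the `xor_selected` view.

```python
import random


def transpose8x8(x: int) -> int:
    """result.byte[i].bit[j] = x.byte[j].bit[i]"""
    r = 0
    for i in range(8):
        for j in range(8):
            r |= ((x >> (8 * j + i)) & 1) << (8 * i + j)
    return r


def byte_swap(x: int) -> int:
    return int.from_bytes(x.to_bytes(8, "little"), "big")


def bitrev(b: int) -> int:
    return int(f"{b:08b}"[::-1], 2)


def bitrev_in_each_byte(x: int) -> int:
    return sum(bitrev((x >> (8 * i)) & 0xFF) << (8 * i) for i in range(8))


def munge(x: int) -> int:
    return transpose8x8(byte_swap(x))


def xor_selected(src1: int, munged_src2: int, imm8: int) -> int:
    result = imm8
    for i in range(8):
        if (src1 >> i) & 1:
            result ^= (munged_src2 >> (8 * i)) & 0xFF
    return result


def gf2p8affineqb(src1: list[int], src2: list[int], imm8: int) -> list[int]:
    """src1, src2: equal-length lists of bytes; each group of 8 bytes of src2 is one qword."""
    assert len(src1) == len(src2) and len(src1) % 8 == 0
    dst = []
    for i in range(len(src1)):
        q = 8 * (i // 8)
        qword = int.from_bytes(bytes(src2[q : q + 8]), "little")
        dst.append(xor_selected(src1[i], munge(qword), imm8))
    return dst


def affine_byte_parity(a: int, x: int, imm8: int) -> int:
    """Parity form: bit i = parity(a.byte[7 - i] & x) ^ imm8.bit[i]."""
    r = 0
    for i in range(8):
        p = bin(((a >> (8 * (7 - i))) & 0xFF) & x).count("1") & 1
        r |= (p ^ ((imm8 >> i) & 1)) << i
    return r


rng = random.Random(0)
x = rng.getrandbits(64)
print(munge(x) == bitrev_in_each_byte(transpose8x8(x)))  # True: both munge definitions agree
```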

The mathematically inclined might notice that the above is in fact multiplying two 8x8 matrices of bits (over GF(2), so addition is xor and multiplication is and):

```
def gf2p8affineqb(src1 : vector, src2 : vector, imm8 : u8) -> vector:
    assert len(src1.byte) == len(src2.byte)
    dst = vector()
    for i in range(len(src1.qword)):
        dst.qword[i] = matmul(src1.qword[i], munge(src2.qword[i]))
    for i in range(len(src1.byte)):
        dst.byte[i] ^= imm8
    return dst

def matmul(lhs : u64, rhs : u64) -> u64:
    result = 0
    for i in range(8):
        for j in range(8):
            for k in range(8):
                b = lhs.byte[i].bit[j] * rhs.byte[j].bit[k]
                result.byte[i].bit[k] ^= b  # addition over GF(2) is xor
    return result

def munge(x : u64) -> u64:
    return transpose8x8(byte_swap(x))  # as before
```
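A quick way to convince yourself that the two views agree is to run both on random inputs. The Python sketch below does that one qword at a time, with qwords as ints. The helper names `affine_qword_matmul` and `affine_qword_xor_selected` are my own.

```python
import random


def transpose8x8(x: int) -> int:
    """result.byte[i].bit[j] = x.byte[j].bit[i]"""
    r = 0
    for i in range(8):
        for j in range(8):
            r |= ((x >> (8 * j + i)) & 1) << (8 * i + j)
    return r


def munge(x: int) -> int:
    # byte_swap, then transpose8x8
    return transpose8x8(int.from_bytes(x.to_bytes(8, "little"), "big"))


def matmul(lhs: int, rhs: int) -> int:
    """8x8 bit-matrix product over GF(2); byte i of a qword is row i."""
    result = 0
    for i in range(8):
        for j in range(8):
            for k in range(8):
                b = (lhs >> (8 * i + j)) & (rhs >> (8 * j + k)) & 1
                result ^= b << (8 * i + k)
    return result


def affine_qword_matmul(x: int, a: int, imm8: int) -> int:
    """One qword of gf2p8affineqb(x, a, imm8), via the matmul view."""
    return matmul(x, munge(a)) ^ (imm8 * 0x0101010101010101)


def affine_qword_xor_selected(x: int, a: int, imm8: int) -> int:
    """The same qword, via the xor_selected view."""
    m = munge(a)
    r = 0
    for i in range(8):
        byte = imm8
        for k in range(8):
            if (x >> (8 * i + k)) & 1:
                byte ^= (m >> (8 * k)) & 0xFF
        r |= byte << (8 * i)
    return r


rng = random.Random(1)
x, a = rng.getrandbits(64), rng.getrandbits(64)
print(affine_qword_matmul(x, a, 0x5A) == affine_qword_xor_selected(x, a, 0x5A))  # True
```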

The `xor_selected` view of `gf2p8affineqb` and the `matmul` view of `gf2p8affineqb` are complementary: I think that the `xor_selected` view makes it clearer what is going on, but the `matmul` view is useful for higher-level transformations and optimisations. As a middle ground between the two views, `matmul` can be re-expressed as byte-level operations by unrolling the `k` loop:

```
def matmul(lhs : u64, rhs : u64) -> u64:
    result = 0
    for i in range(8):
        for j in range(8):
            if lhs.byte[i].bit[j]:
                result.byte[i] ^= rhs.byte[j]
    return result
```
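The unrolling is purely mechanical. As a sanity check, here is a Python sketch that compares the bit-level triple loop against the byte-level version. The names `matmul_bits` and `matmul_bytes` are illustrative.

```python
import random


def matmul_bits(lhs: int, rhs: int) -> int:
    """Triple-loop form: result.byte[i].bit[k] ^= lhs.byte[i].bit[j] & rhs.byte[j].bit[k]."""
    result = 0
    for i in range(8):
        for j in range(8):
            for k in range(8):
                result ^= ((lhs >> (8 * i + j)) & (rhs >> (8 * j + k)) & 1) << (8 * i + k)
    return result


def matmul_bytes(lhs: int, rhs: int) -> int:
    """Unrolled-k form: if lhs.byte[i].bit[j] is set, xor rhs.byte[j] into result.byte[i]."""
    result = 0
    for i in range(8):
        for j in range(8):
            if (lhs >> (8 * i + j)) & 1:
                result ^= ((rhs >> (8 * j)) & 0xFF) << (8 * i)
    return result


rng = random.Random(2)
lhs, rhs = rng.getrandbits(64), rng.getrandbits(64)
print(matmul_bits(lhs, rhs) == matmul_bytes(lhs, rhs))  # True
```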

One observation from the `matmul` view is that when `src1.qword[i]` is the identity matrix (and `imm8` is zero), we end up with `dst.qword[i]` being `munge(src2.qword[i])`. As a 64-bit integer, said identity matrix is `0x8040201008040201` (i.e. in byte `i`, just bit `i` is set). This explains `__m512i kid = _mm512_set1_epi64(0x8040201008040201)` in `simd_impl` (`kid` is just an identity matrix), and also explains `__m512i hi1 = _mm512_gf2p8affine_epi64_epi8(kid, hi0, 0)` and `__m512i lo1 = _mm512_gf2p8affine_epi64_epi8(kid, lo0, 0)`: these are just applying `munge` to every qword (as for what said munges are *achieving*, we'll get to that later).
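The identity-matrix claim can be checked directly. This Python sketch models one qword of `gf2p8affineqb` via the `xor_selected` view, in a helper `affine_qword` that I made up for illustration. It then confirms that passing `0x8040201008040201` as `src1` just yields `munge(src2)`:

```python
import random

KID = 0x8040201008040201  # byte i has just bit i set


def transpose8x8(x: int) -> int:
    r = 0
    for i in range(8):
        for j in range(8):
            r |= ((x >> (8 * j + i)) & 1) << (8 * i + j)
    return r


def munge(x: int) -> int:
    return transpose8x8(int.from_bytes(x.to_bytes(8, "little"), "big"))


def affine_qword(x: int, a: int, imm8: int = 0) -> int:
    """One qword of gf2p8affineqb(x, a, imm8), via the xor_selected view."""
    m = munge(a)
    r = 0
    for i in range(8):
        byte = imm8
        for k in range(8):
            if (x >> (8 * i + k)) & 1:
                byte ^= (m >> (8 * k)) & 0xFF
        r |= byte << (8 * i)
    return r


rng = random.Random(3)
b = rng.getrandbits(64)
print(affine_qword(KID, b) == munge(b))  # True
```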

Changing tack somewhat, it is time to gradually transform `reference_impl` to make it look more like `matmul`. For this, we'll start with a simplified version of `reference_impl` that takes 8 indices rather than 64:

```
__mmask64 reference_impl_1(__m64i indices, __mmask8 valids) {
    __mmask64 result = 0;
    for (int i = 0; i < 8; ++i) {
        if (valids.bit[i]) {
            result ^= 1ull << indices.byte[i];
        }
    }
    return result;
}
```

The first transformation is to split each 6-bit index into its low 3 bits and high 3 bits, so that the high bits select a byte of `result` and the low bits select a bit within that byte:

```
__mmask64 reference_impl_2(__m64i indices, __mmask8 valids) {
    __mmask64 result = 0;
    for (int i = 0; i < 8; ++i) {
        if (valids.bit[i]) {
            uint8_t b = indices.byte[i];
            uint8_t hi = b >> 3;
            uint8_t lo = b & 7;
            result.byte[hi] ^= 1 << lo;
        }
    }
    return result;
}
```
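As a small sanity check of this step, the following Python models of `reference_impl_1` and `reference_impl_2` agree on random inputs. They assume each index is in `range(64)` and represent `result` as eight bytes in the second model.

```python
import random


def reference_impl_1(indices: list[int], valids: int) -> int:
    result = 0
    for i in range(8):
        if (valids >> i) & 1:
            result ^= 1 << indices[i]
    return result


def reference_impl_2(indices: list[int], valids: int) -> int:
    result = bytearray(8)  # result.byte[0..7]
    for i in range(8):
        if (valids >> i) & 1:
            b = indices[i]
            hi, lo = b >> 3, b & 7  # which byte, and which bit within that byte
            result[hi] ^= 1 << lo
    return int.from_bytes(result, "little")


rng = random.Random(4)
indices = [rng.randrange(64) for _ in range(8)]
valids = rng.getrandbits(8)
print(reference_impl_1(indices, valids) == reference_impl_2(indices, valids))  # True
```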

Next up, we perform loop *fission*: doing exactly the same work, but using two loops rather than one (so that we can focus on each loop separately):

```
__mmask64 reference_impl_3(__m64i indices, __mmask8 valids) {
    __m64i hi;
    __m64i lo;
    for (int i = 0; i < 8; ++i) {
        uint8_t b = indices.byte[i];
        hi.byte[i] = b >> 3;
        lo.byte[i] = b & 7;
    }
    __mmask64 result = 0;
    for (int i = 0; i < 8; ++i) {
        if (valids.bit[i]) {
            result.byte[hi.byte[i]] ^= 1 << lo.byte[i];
        }
    }
    return result;
}
```

Then the `if` and the `1 <<` can also be moved from the 2nd loop to the 1st loop:

```
__mmask64 reference_impl_4(__m64i indices, __mmask8 valids) {
    __m64i hi;
    __m64i lo;
    for (int i = 0; i < 8; ++i) {
        uint8_t b = indices.byte[i];
        hi.byte[i] = b >> 3;
        lo.byte[i] = valids.bit[i] ? 1 << (b & 7) : 0;
    }
    __mmask64 result = 0;
    for (int i = 0; i < 8; ++i) {
        result.byte[hi.byte[i]] ^= lo.byte[i];
    }
    return result;
}
```

Then comes a transformation that looks utterly deranged, but is key to the SIMD transformation: rather than directly indexing using `hi.byte[i]`, we'll loop over the 8 possible values of `hi.byte[i]` and act when we find the right value:

```
__mmask64 reference_impl_5(__m64i indices, __mmask8 valids) {
    __m64i hi;
    __m64i lo;
    for (int i = 0; i < 8; ++i) {
        uint8_t b = indices.byte[i];
        hi.byte[i] = b >> 3;
        lo.byte[i] = valids.bit[i] ? 1 << (b & 7) : 0;
    }
    __mmask64 result = 0;
    for (int i = 0; i < 8; ++i) {
        for (int j = 0; j < 8; ++j) {
            if (hi.byte[i] == j) {
                result.byte[j] ^= lo.byte[i];
            }
        }
    }
    return result;
}
```

Next up, we interchange the two nested loops (and swap the names `i` and `j`, so that `i` remains the outer loop variable):

```
__mmask64 reference_impl_6(__m64i indices, __mmask8 valids) {
    __m64i hi;
    __m64i lo;
    for (int i = 0; i < 8; ++i) {
        uint8_t b = indices.byte[i];
        hi.byte[i] = b >> 3;
        lo.byte[i] = valids.bit[i] ? 1 << (b & 7) : 0;
    }
    __mmask64 result = 0;
    for (int i = 0; i < 8; ++i) {
        for (int j = 0; j < 8; ++j) {
            if (hi.byte[j] == i) {
                result.byte[i] ^= lo.byte[j];
            }
        }
    }
    return result;
}
```

Then comes another transformation that initially looks deranged: the `==` in `hi.byte[j] == i` is annoying, and can be replaced by a bit test if we one-hot encode `hi`:

```
__mmask64 reference_impl_7(__m64i indices, __mmask8 valids) {
    __m64i hi;
    __m64i lo;
    for (int i = 0; i < 8; ++i) {
        uint8_t b = indices.byte[i];
        hi.byte[i] = 1 << (b >> 3);
        lo.byte[i] = valids.bit[i] ? 1 << (b & 7) : 0;
    }
    __mmask64 result = 0;
    for (int i = 0; i < 8; ++i) {
        for (int j = 0; j < 8; ++j) {
            if (hi.byte[j].bit[i]) {
                result.byte[i] ^= lo.byte[j];
            }
        }
    }
    return result;
}
```

Then one final transformation gets us where we want to be: apply `transpose8x8` to `hi`, and compensate for it by changing `.byte[j].bit[i]` to `.byte[i].bit[j]`:

```
__mmask64 reference_impl_8(__m64i indices, __mmask8 valids) {
    __m64i hi;
    __m64i lo;
    for (int i = 0; i < 8; ++i) {
        uint8_t b = indices.byte[i];
        hi.byte[i] = 1 << (b >> 3);
        lo.byte[i] = valids.bit[i] ? 1 << (b & 7) : 0;
    }
    __mmask64 result = 0;
    for (int i = 0; i < 8; ++i) {
        for (int j = 0; j < 8; ++j) {
            if (transpose8x8(hi).byte[i].bit[j]) {
                result.byte[i] ^= lo.byte[j];
            }
        }
    }
    return result;
}
```

A number of these transformations seemed pointless or even unhelpful, but having done them all, the latter half of `reference_impl_8` is exactly `result = matmul(transpose8x8(hi), lo)`.
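This claim is easy to test. The Python sketch below builds `hi` and `lo` as 64-bit integers and compares `matmul(transpose8x8(hi), lo)` with `reference_impl_1`. The helper `first_loop` is hypothetical; it stands in for the first loop of `reference_impl_8`.

```python
import random


def reference_impl_1(indices: list[int], valids: int) -> int:
    result = 0
    for i in range(8):
        if (valids >> i) & 1:
            result ^= 1 << indices[i]
    return result


def first_loop(indices: list[int], valids: int) -> tuple[int, int]:
    """The first loop of reference_impl_8: one-hot hi, and lo masked by valids."""
    hi = lo = 0
    for i in range(8):
        b = indices[i]
        hi |= (1 << (b >> 3)) << (8 * i)
        lo |= ((1 << (b & 7)) if (valids >> i) & 1 else 0) << (8 * i)
    return hi, lo


def transpose8x8(x: int) -> int:
    r = 0
    for i in range(8):
        for j in range(8):
            r |= ((x >> (8 * j + i)) & 1) << (8 * i + j)
    return r


def matmul(lhs: int, rhs: int) -> int:
    """Triple-loop matmul over GF(2), as in the matmul view."""
    result = 0
    for i in range(8):
        for j in range(8):
            for k in range(8):
                result ^= ((lhs >> (8 * i + j)) & (rhs >> (8 * j + k)) & 1) << (8 * i + k)
    return result


rng = random.Random(4)
indices = [rng.randrange(64) for _ in range(8)]
valids = rng.getrandbits(8)
hi, lo = first_loop(indices, valids)
print(matmul(transpose8x8(hi), lo) == reference_impl_1(indices, valids))  # True
```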

The expression `matmul(transpose8x8(A), B)` looks deceptively similar to the `matmul(A, munge(B))` done by `gf2p8affineqb(A, B, 0)`, and *if* `munge` were just `transpose8x8`, then `gf2p8affineqb(munge(A), munge(B), 0)` would be exactly `matmul(transpose8x8(A), B)`. Unfortunately, `munge` also performs a byte reversal (or, equivalently, a bit reversal within each byte), causing `gf2p8affineqb(munge(A), munge(B), 0)` to actually be `matmul(transpose8x8(A), bitrev_in_each_byte(B))`. If deriving this, note that `munge(A)` is `bitrev_in_each_byte(transpose8x8(A))` and `munge(munge(B))` is `byte_swap(bitrev_in_each_byte(B))`. The `bitrev_in_each_byte` on `A` then cancels out with the `byte_swap` on `B`: reversing the bits of each row of the left matrix and reversing the order of the rows of the right matrix both just reverse the `j` index in `matmul`.
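The identity can also be checked numerically. This sketch assumes the same bit and byte conventions as the pseudo-code above, with byte 0 being the least significant byte of a qword. It compares both sides on random matrices, along with the intermediate `munge(munge(B))` claim. `affine_qword` is an illustrative helper that models one qword of `gf2p8affineqb` via the `xor_selected` view.

```python
import random


def transpose8x8(x: int) -> int:
    r = 0
    for i in range(8):
        for j in range(8):
            r |= ((x >> (8 * j + i)) & 1) << (8 * i + j)
    return r


def byte_swap(x: int) -> int:
    return int.from_bytes(x.to_bytes(8, "little"), "big")


def bitrev_in_each_byte(x: int) -> int:
    return sum(int(f"{(x >> (8 * i)) & 0xFF:08b}"[::-1], 2) << (8 * i) for i in range(8))


def munge(x: int) -> int:
    return transpose8x8(byte_swap(x))


def matmul(lhs: int, rhs: int) -> int:
    """Byte-level matmul: if lhs.byte[i].bit[j], xor rhs.byte[j] into result.byte[i]."""
    result = 0
    for i in range(8):
        for j in range(8):
            if (lhs >> (8 * i + j)) & 1:
                result ^= ((rhs >> (8 * j)) & 0xFF) << (8 * i)
    return result


def affine_qword(x: int, a: int, imm8: int = 0) -> int:
    """One qword of gf2p8affineqb(x, a, imm8), via the xor_selected view."""
    m = munge(a)
    r = 0
    for i in range(8):
        byte = imm8
        for k in range(8):
            if (x >> (8 * i + k)) & 1:
                byte ^= (m >> (8 * k)) & 0xFF
        r |= byte << (8 * i)
    return r


rng = random.Random(5)
a, b = rng.getrandbits(64), rng.getrandbits(64)
print(affine_qword(munge(a), munge(b)) == matmul(transpose8x8(a), bitrev_in_each_byte(b)))  # True
```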

The expression `matmul(transpose8x8(A), bitrev_in_each_byte(B))` is very close to what we want, and the errant `bitrev_in_each_byte` can be cancelled out by doing *another* `bitrev_in_each_byte` on `B`:

```
__mmask64 reference_impl_9(__m64i indices, __mmask8 valids) {
    __m64i hi;
    __m64i lo;
    for (int i = 0; i < 8; ++i) {
        uint8_t b = indices.byte[i];
        hi.byte[i] = 1 << (b >> 3);
        lo.byte[i] = bitrev(valids.bit[i] ? 1 << (b & 7) : 0);
    }
    __mmask64 result = gf2p8affineqb(munge(hi), munge(lo), 0);
    return result;
}
```
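Putting the pieces together, a Python model of `reference_impl_9` should agree with `reference_impl_1`. In this sketch, `gf2p8affineqb` is emulated one qword at a time via the `xor_selected` view, in a hypothetical helper `affine_qword`:

```python
import random


def transpose8x8(x: int) -> int:
    r = 0
    for i in range(8):
        for j in range(8):
            r |= ((x >> (8 * j + i)) & 1) << (8 * i + j)
    return r


def munge(x: int) -> int:
    return transpose8x8(int.from_bytes(x.to_bytes(8, "little"), "big"))


def bitrev(v: int) -> int:
    return int(f"{v:08b}"[::-1], 2)


def affine_qword(x: int, a: int, imm8: int = 0) -> int:
    """One qword of gf2p8affineqb(x, a, imm8), via the xor_selected view."""
    m = munge(a)
    r = 0
    for i in range(8):
        byte = imm8
        for k in range(8):
            if (x >> (8 * i + k)) & 1:
                byte ^= (m >> (8 * k)) & 0xFF
        r |= byte << (8 * i)
    return r


def reference_impl_1(indices: list[int], valids: int) -> int:
    result = 0
    for i in range(8):
        if (valids >> i) & 1:
            result ^= 1 << indices[i]
    return result


def reference_impl_9(indices: list[int], valids: int) -> int:
    hi = lo = 0
    for i in range(8):
        b = indices[i]
        hi |= (1 << (b >> 3)) << (8 * i)
        lo |= bitrev((1 << (b & 7)) if (valids >> i) & 1 else 0) << (8 * i)
    return affine_qword(munge(hi), munge(lo), 0)


rng = random.Random(6)
indices = [rng.randrange(64) for _ in range(8)]
valids = rng.getrandbits(8)
print(reference_impl_9(indices, valids) == reference_impl_1(indices, valids))  # True
```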

The 1st loop is easy to express in a SIMD manner via a pair of table lookups, thereby giving us the first chunk of `simd_impl`:

```
__mmask64 simd_impl(__m512i indices, __mmask64 valids) {
    __m512i khi = _mm512_setr_epi8(
        0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
        0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
        0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04,
        0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08,
        0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10,
        0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
        0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40,
        0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80
    );
    __m512i hi0 = _mm512_permutexvar_epi8(indices, khi);
    __m512i klo = _mm512_set1_epi64(0x0102040810204080);
    __m512i lo0 = _mm512_maskz_shuffle_epi8(valids, klo, indices);
    __m512i kid = _mm512_set1_epi64(0x8040201008040201);
    __m512i hi1 = _mm512_gf2p8affine_epi64_epi8(kid, hi0, 0);
    __m512i lo1 = _mm512_gf2p8affine_epi64_epi8(kid, lo0, 0);
    __m512i x0 = _mm512_gf2p8affine_epi64_epi8(hi1, lo1, 0);
}
```
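The two lookups rely on specific behaviour of the shuffle instructions. `_mm512_permutexvar_epi8` uses the low 6 bits of each index to select from all 64 bytes. `_mm512_maskz_shuffle_epi8`, by contrast, uses the low 4 bits of each index to select within the same 128-bit lane, and zeroes bytes whose mask bit is clear. Because `klo` repeats every 8 bytes, only `b & 7` matters for it. The following Python sketch is my own emulation of both intrinsics, assuming every index is below 64. It checks that `hi0` holds `1 << (b >> 3)` and that `lo0` holds `bitrev(1 << (b & 7))` masked by `valids`:

```python
import random

KHI = [v for v in (0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80) for _ in range(8)]
KLO = list((0x0102040810204080).to_bytes(8, "little")) * 8  # _mm512_set1_epi64


def permutexvar_epi8(idx: list[int], a: list[int]) -> list[int]:
    """Model of _mm512_permutexvar_epi8(idx, a): dst.byte[i] = a.byte[idx.byte[i] & 63]."""
    return [a[idx[i] & 63] for i in range(64)]


def maskz_shuffle_epi8(k: int, a: list[int], b: list[int]) -> list[int]:
    """Model of _mm512_maskz_shuffle_epi8(k, a, b): in-lane byte shuffle, zeroed where k.bit[i] is 0."""
    dst = []
    for i in range(64):
        if not (k >> i) & 1 or b[i] & 0x80:
            dst.append(0)
        else:
            dst.append(a[(i & ~15) | (b[i] & 15)])  # same 128-bit lane, low 4 bits of index
    return dst


def bitrev(v: int) -> int:
    return int(f"{v:08b}"[::-1], 2)


rng = random.Random(7)
indices = [rng.randrange(64) for _ in range(64)]
valids = rng.getrandbits(64)
hi0 = permutexvar_epi8(indices, KHI)
lo0 = maskz_shuffle_epi8(valids, KLO, indices)
print(all(hi0[i] == 1 << (b >> 3) for i, b in enumerate(indices)))  # True
```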

At this point, `x0.qword[i]` contains `reference_impl_9(indices.qword[i], valids.byte[i])`. To finish up, "all" we need to do is xor together the eight qwords of `x0`. The traditional way of doing this would be a shuffle followed by a xor to reduce eight to four, another shuffle followed by a xor to reduce four to two, and yet another shuffle followed by a xor to reduce two to one. We can do better than the traditional approach though. The first step is to do one big shuffle rather than three sequential shuffles, where the big shuffle moves the eight bytes `qword[i].byte[0]` (one from each qword) to be contiguous, then the eight bytes `qword[i].byte[1]` to be contiguous, and so on. Seen differently, the big shuffle is a transpose of an 8x8 matrix of bytes. After this big shuffle, the remaining problem is to take each contiguous group of eight bytes and xor them together. If we wanted to *add* together each contiguous group of eight bytes, then `_mm512_sad_epu8` against zero would be one option, but we want xor rather than add. There are a few different ways of approaching the problem, but one cute way is to apply `transpose8x8` to each contiguous group of eight bytes. After that, byte `i` of each group holds bit `i` of each of the eight original bytes, so we just need to xor together each contiguous group of eight *bits*. Applying `transpose8x8` on its own is hard, but we can fairly easily apply `munge`, which does `transpose8x8` followed by `bitrev_in_each_byte`, and the `bitrev_in_each_byte` is harmless given that we're about to xor together the bits in each byte. This gives us the next chunk of `simd_impl`:

```
    __m512i ktr = _mm512_setr_epi8(
        0, 8, 16, 24, 32, 40, 48, 56,
        1, 9, 17, 25, 33, 41, 49, 57,
        2, 10, 18, 26, 34, 42, 50, 58,
        3, 11, 19, 27, 35, 43, 51, 59,
        4, 12, 20, 28, 36, 44, 52, 60,
        5, 13, 21, 29, 37, 45, 53, 61,
        6, 14, 22, 30, 38, 46, 54, 62,
        7, 15, 23, 31, 39, 47, 55, 63);
    __m512i x1 = _mm512_permutexvar_epi8(ktr, x0);
    __m512i x2 = _mm512_gf2p8affine_epi64_epi8(kid, x1, 0);
```
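The combination of the byte transpose and the parity of each munged byte can be modelled in Python as follows. This is a sketch in which the intrinsics are emulated on lists of 64 bytes, and `xor_qwords` is a hypothetical name. It checks that the result equals the plain xor of the eight qwords:

```python
import random

# ktr from simd_impl: output byte 8*a + b takes input byte 8*b + a.
KTR = [8 * b + a for a in range(8) for b in range(8)]


def transpose8x8(x: int) -> int:
    r = 0
    for i in range(8):
        for j in range(8):
            r |= ((x >> (8 * j + i)) & 1) << (8 * i + j)
    return r


def munge(x: int) -> int:
    return transpose8x8(int.from_bytes(x.to_bytes(8, "little"), "big"))


def xor_qwords(x0: list[int]) -> int:
    """x0: 64 bytes (eight qwords). Returns the xor of the eight qwords as a 64-bit int."""
    x1 = [x0[KTR[n]] for n in range(64)]  # models _mm512_permutexvar_epi8(ktr, x0)
    result = 0
    for a in range(8):
        q = int.from_bytes(bytes(x1[8 * a : 8 * a + 8]), "little")
        m = munge(q)  # models what gf2p8affine(kid, x1, 0) does to qword a
        for i in range(8):
            # The parity of byte i of m is bit i of byte a of the xor of all qwords.
            parity = bin((m >> (8 * i)) & 0xFF).count("1") & 1
            result |= parity << (8 * a + i)
    return result


rng = random.Random(8)
x0 = [rng.getrandbits(8) for _ in range(64)]
expected = 0
for q in range(8):
    expected ^= int.from_bytes(bytes(x0[8 * q : 8 * q + 8]), "little")
print(xor_qwords(x0) == expected)  # True
```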

If the valid indices can be assumed to be distinct, then we can *or* (rather than xor) together the bits in each byte, which is just `_mm512_cmpneq_epi8_mask` against zero.

If we really do need to xor the bits together, then what we want is this function applied to every byte:

```
def xor_together_bits(x : u8) -> u8:
    result = 0
    for i in range(8):
        if x.bit[i]:
            result ^= 0xff
    return result
```
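The claim that an all-`0xff` matrix turns `xor_selected` into `xor_together_bits` can be checked exhaustively over all 256 byte values. A Python sketch follows; the helper names are illustrative.

```python
ALL_ONES = 0xFFFFFFFFFFFFFFFF  # every byte of src2 is 0xff


def transpose8x8(x: int) -> int:
    r = 0
    for i in range(8):
        for j in range(8):
            r |= ((x >> (8 * j + i)) & 1) << (8 * i + j)
    return r


def munge(x: int) -> int:
    return transpose8x8(int.from_bytes(x.to_bytes(8, "little"), "big"))


def xor_selected(src1: int, munged_src2: int, imm8: int) -> int:
    result = imm8
    for i in range(8):
        if (src1 >> i) & 1:
            result ^= (munged_src2 >> (8 * i)) & 0xFF
    return result


def xor_together_bits(x: int) -> int:
    result = 0
    for i in range(8):
        if (x >> i) & 1:
            result ^= 0xFF
    return result


print(munge(ALL_ONES) == ALL_ONES)  # True: munging an all-ones matrix changes nothing
print(all(xor_selected(x, munge(ALL_ONES), 0) == xor_together_bits(x) for x in range(256)))  # True
```

Since the result is either `0x00` or `0xff`, taking just its top bit (as `_mm512_movepi8_mask` does) recovers the parity.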

If you're thinking that `xor_together_bits` looks very similar to `xor_selected`, then you'd be right: `xor_together_bits` is just `xor_selected` where every byte of `munged_src2` is `0xff` (and `imm8` is zero), and it so happens that if every byte of `src2` is `0xff`, then the same is true for `munged_src2`. Taking the top bit of each resulting byte with `_mm512_movepi8_mask` then gives the final chunk of `simd_impl`:

```
    __m512i kff = _mm512_set1_epi8(0xff);
    __m512i x3 = _mm512_gf2p8affine_epi64_epi8(x2, kff, 0);
    return _mm512_movepi8_mask(x3);
```
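Finally, the whole pipeline can be emulated end to end. The Python sketch below models each intrinsic used by `simd_impl` on lists of 64 bytes, and compares the result against the scalar reference, including the distinct-indices variant. The intrinsic models are my own emulations of the documented semantics, not library calls, and `simd_impl_model` and `reference_impl_model` are hypothetical names.

```python
import random


def transpose8x8(x: int) -> int:
    r = 0
    for i in range(8):
        for j in range(8):
            r |= ((x >> (8 * j + i)) & 1) << (8 * i + j)
    return r


def munge(x: int) -> int:
    return transpose8x8(int.from_bytes(x.to_bytes(8, "little"), "big"))


def gf2p8affine(x: list[int], a: list[int], imm8: int) -> list[int]:
    """Model of _mm512_gf2p8affine_epi64_epi8(x, a, imm8), via the xor_selected view."""
    dst = []
    for q in range(8):
        m = munge(int.from_bytes(bytes(a[8 * q : 8 * q + 8]), "little"))
        for i in range(8 * q, 8 * q + 8):
            byte = imm8
            for k in range(8):
                if (x[i] >> k) & 1:
                    byte ^= (m >> (8 * k)) & 0xFF
            dst.append(byte)
    return dst


def permutexvar_epi8(idx: list[int], a: list[int]) -> list[int]:
    return [a[idx[i] & 63] for i in range(64)]


def maskz_shuffle_epi8(k: int, a: list[int], b: list[int]) -> list[int]:
    return [0 if not (k >> i) & 1 or b[i] & 0x80 else a[(i & ~15) | (b[i] & 15)] for i in range(64)]


def set1_epi64(v: int) -> list[int]:
    return list(v.to_bytes(8, "little")) * 8


KHI = [1 << (n // 8) for n in range(64)]
KTR = [8 * (n % 8) + n // 8 for n in range(64)]


def simd_impl_model(indices: list[int], valids: int, assume_distinct: bool = False) -> int:
    hi0 = permutexvar_epi8(indices, KHI)
    lo0 = maskz_shuffle_epi8(valids, set1_epi64(0x0102040810204080), indices)
    kid = set1_epi64(0x8040201008040201)
    hi1 = gf2p8affine(kid, hi0, 0)
    lo1 = gf2p8affine(kid, lo0, 0)
    x0 = gf2p8affine(hi1, lo1, 0)
    x1 = permutexvar_epi8(KTR, x0)
    x2 = gf2p8affine(kid, x1, 0)
    if assume_distinct:
        # _mm512_cmpneq_epi8_mask(x2, _mm512_setzero_si512())
        return sum(1 << i for i in range(64) if x2[i] != 0)
    x3 = gf2p8affine(x2, [0xFF] * 64, 0)
    # _mm512_movepi8_mask: collect the top bit of each byte
    return sum((x3[i] >> 7) << i for i in range(64))


def reference_impl_model(indices: list[int], valids: int) -> int:
    result = 0
    for i in range(64):
        if (valids >> i) & 1:
            result ^= 1 << indices[i]
    return result


rng = random.Random(9)
indices = [rng.randrange(64) for _ in range(64)]
valids = rng.getrandbits(64)
print(simd_impl_model(indices, valids) == reference_impl_model(indices, valids))  # True
```

The `assume_distinct` path can disagree with the reference when a valid index repeats across different qwords, which is why that shortcut needs the distinctness assumption.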