字符集查找指的是类似 strpbrk 的接口:在目标字符串中扫描,直到匹配指定字符集的任意字符。本文将使用一个压缩的 bitmap 来替代接口的 breakset 参数,提供不同场景的 SIMD 字符集查找算法。
NOTES:
- 本文是给 StringZilla 做 PR 时顺便整理的测试集,特定背景使用 AVX2 而非 AVX512。
- 即使是标量而非 SIMD,使用 bitmap(而非字符串)也有相当的性能收益,详见 benchmark。
- 本文工作重点是场景/benchmark,算法当中最重要的核心思路不是原创的(我太菜了)。但是目前没有确切的最早来源,个人猜测作者是 Daniel Lemire 或者 Wojciech Muła。
WORK IN PROGRESS 警告:代码和注释是完整的,但是详细解释我后续需要补充。(赶车回家)
ASCII character
假设 char 的可用范围是 [0, 127],也就是 ASCII 场景。我们需要的压缩 bitmap 将是一个 u8[16] 的数组(或者是其他内存布局兼容的设计,要求每个字符占据 1 个比特位)。
// 伪代码,用于将常规的字符串转为需要的 bitmap
using bitmap_t = u8[16];
void set(bitmap_t &charset, char ch) {
charset[ch >> 3] |= 1 << (ch & 7);
}
上面代码提供一个 bitmap 内存布局的描述,非常直接的压缩策略:对一个字符 ch,其高四位决定了 bitmap 字节槽的下标,低三位决定了字节槽内位的下标。(不使用第八位 bit 7。)
// charset: u8[16] as a bitmap of char field.
//
// Core algorithm:
// u8[16] -> 16 x 8 matrix
// -> 16 x 8
// ^^table
// b3-b6 (w4): byte index (row).
// b0-b2 (w3): bit index (col).
ssize_t find_charset_avx2_ascii128(std::ranges::range auto &&rng, const auto &charset) {
static_assert(std::ranges::size(charset) == ((1 << CHAR_BIT) / CHAR_BIT / 2));
constexpr auto lane = sizeof(__m256i) / sizeof(char);
// 广播以解决 AVX 的 lane 独立问题
auto filter128 = _mm_loadu_si128((__m128i *) &charset);
auto filter256 = _mm256_set_m128i(filter128, filter128);
// 用向量查表替代向量计算
// LUT for 1 << b2b1b0.
const auto shift = _mm256_setr_epi8(
1, 2, 4, 8, 16, 32, 64, -128, // 1<<7 == -128
0, 0, 0, 0, 0, 0, 0, 0,
1, 2, 4, 8, 16, 32, 64, -128,
0, 0, 0, 0, 0, 0, 0, 0
);
// 这里用了一个定制的 ranges view,可以简化代码和扩展更高维度的 SIMD
// 但是本文仅用于简化代码,即丢弃 rng 无法被 SIMD 处理的尾部,并且每个代理元素的宽度都是以 lane 为单位
// 一个简单理解但是不够高效的实现是:simdify = stride(lane) | take(size / lane)
// 更多应用见之前的文章:https://www.bluepuni.com/archives/cpp-explicit-ilp-in-practice/
auto simd_view = rng | simdify<lane>;
for(auto &&[index, simd_v] : std::views::enumerate(simd_view)) {
auto addr = (__m256i *) &simd_v;
auto chars = _mm256_loadu_si256(addr);
auto col_index = _mm256_and_si256(chars, _mm256_set1_epi8(0b00000111));
auto row_index = _mm256_and_si256(chars, _mm256_set1_epi8(0b01111000));
row_index = _mm256_srli_epi16(row_index, 3);
auto col = _mm256_shuffle_epi8(shift, col_index);
auto row = _mm256_shuffle_epi8(filter256, row_index);
// 标准的 SIMD 字符串匹配三板斧:cmpeq+movemask+count
auto movemask = _mm256_movemask_epi8(_mm256_cmpeq_epi8(
// 检查 charset[row_index] & (1 << col_index) 是否为 0
_mm256_and_si256(row, col), _mm256_setzero_si256()));
if(unsigned counter = ~movemask) {
return index * lane + std::countr_zero(counter);
}
}
// 尾部处理,直接标量算就行了
auto offset = lane * std::ranges::size(simd_view);
auto scalar_view = rng | std::views::drop(offset);
for(auto &&[index, _v] : std::views::enumerate(scalar_view)) {
auto v = static_cast<unsigned char>(_v);
auto v_row = charset[v / 8];
auto v_col = 1 << (v % 8);
if(v_row & v_col) return offset + index;
}
return -1;
}
核心算法就是查表。u8[16] 可视为 16x8 的矩阵,也就是 16 行 8 列。很显然 16 是一个很适合用于 shuffle 的关键数字。
Safe ASCII character
前面的算法无法解决高位污染问题,就是假设你的字符串本身含有非法 ASCII 字符时([-128, 0)),算法会失效。这很不利于 JSON/HTML 场景,比如字符串可能是 UTF-8,但是查找的字符集也就只是固定 ASCII 范围的 []{}:" 等等。
+ template <avx2_ascii128_config Config>
ssize_t find_charset_avx2_ascii128(std::ranges::range auto &&rng, const auto &charset) {
// ...
for(auto &&[index, simd_v] : std::views::enumerate(simd_view)) {
auto addr = (__m256i *) &simd_v;
auto chars = _mm256_loadu_si256(addr);
auto col_index = _mm256_and_si256(chars, _mm256_set1_epi8(0b00000111));
auto row_index = _mm256_and_si256(chars, _mm256_set1_epi8(0b01111000));
row_index = _mm256_srli_epi16(row_index, 3);
+ if constexpr (Config.overflow) {
+ auto sign_bits = _mm256_and_si256(chars, _mm256_set1_epi8(0b10000000));
+ row_index = _mm256_or_si256(row_index, sign_bits);
+ }
auto col = _mm256_shuffle_epi8(shift, col_index);
auto row = _mm256_shuffle_epi8(filter256, row_index);
auto movemask = _mm256_movemask_epi8(_mm256_cmpeq_epi8(
_mm256_and_si256(row, col), _mm256_setzero_si256()));
if(unsigned counter = ~movemask) {
return index * lane + std::countr_zero(counter);
}
}
// ...
for(auto &&[index, _v] : std::views::enumerate(scalar_view)) {
auto v = static_cast<unsigned char>(_v);
+ if constexpr (Config.overflow) if(v >= 128) continue;
// ...
}
return -1;
}
SIMD 部分只需要花费额外的成本就能修复,因为 shuffle 遇到 0x80 会置零,我们需要还原它的槽下标高位信息。注意还需要处理标量。
此时超出 ASCII 范围的字符会被安全地忽略。
Full character
假设 char 的可用范围是完整的 [-128, 127],比如 UTF-8 场景。我们需要修改 bitmap 的定义和核心算法。
// 伪代码,将 bitmap 扩展翻倍
using bitmap_t = u8[32];
void set(bitmap_t &charset, char ch) {
// 依然是低三位决定槽内下标,但是剩余高五位决定槽
charset[ch >> 3] |= 1 << (ch & 7);
}
bitmap 的数据结构需要翻倍已容纳字符范围。
// charset: u8[32] as a bitmap of char field.
//
// Core algorithm:
// u8[32] -> 32 x 8 matrix
// -> 16 x 2 x 8
// -> 16 x 8 (table1) + 16 x 8 (table2)
// -> 16 x 8 + 16 x 8
// ^^table-odd ^^table-even
// b3-b7 (w5): byte index (row). We use b3 to select the table, which determines the parity.
// b0-b2 (w3): bit index (col).
ssize_t find_charset_avx2(std::ranges::range auto &&rng, const auto &charset) {
// TODO: check u64[4]...
static_assert(std::ranges::size(charset) == ((1 << CHAR_BIT) / CHAR_BIT));
constexpr auto lane = sizeof(__m256i) / sizeof(char);
const auto filter_lo = _mm_loadu_si128((__m128i *)(&charset));
const auto filter_hi = _mm_loadu_si128((__m128i *)(&charset) + 1);
const auto filter_mask = _mm_set1_epi16(0x00ff);
const auto filter_lo_even = _mm_and_si128(filter_lo, filter_mask);
const auto filter_hi_even = _mm_and_si128(filter_hi, filter_mask);
const auto filter_lo_odd = _mm_srli_epi16(filter_lo, 8);
const auto filter_hi_odd = _mm_srli_epi16(filter_hi, 8);
// e4 e3 e2 e1
const auto filter_even_128 = _mm_packus_epi16(filter_lo_even, filter_hi_even);
// o4 o3 o2 o1
const auto filter_odd_128 = _mm_packus_epi16(filter_lo_odd, filter_hi_odd);
// e4 e3 e2 e1 | e4 e3 e2 e1
const auto filter_even = _mm256_set_m128i(filter_even_128, filter_even_128);
// o4 o3 o2 o1 | o4 o3 o2 o1
const auto filter_odd = _mm256_set_m128i(filter_odd_128, filter_odd_128);
// LUT for 1 << (b3b2b1b0 % 8). Equivalent to 1 << b2b1b0.
const auto shift_mod = _mm256_setr_epi8(
1, 2, 4, 8, 16, 32, 64, -128, // 1<<7 == -128
1, 2, 4, 8, 16, 32, 64, -128,
1, 2, 4, 8, 16, 32, 64, -128,
1, 2, 4, 8, 16, 32, 64, -128
);
auto simd_view = rng | simdify<lane>;
for(auto &&[index, simd_v] : std::views::enumerate(simd_view)) {
auto addr = &simd_v;
auto chars = _mm256_loadu_si256((__m256i*) addr);
// b0-b3, conflict with b3
auto col_index = _mm256_and_si256(chars, _mm256_set1_epi8(0x0f));
// b4-b7, conflict with b3
auto row_index = _mm256_and_si256(_mm256_srli_epi16(chars, 4), _mm256_set1_epi8(0x0f));
// 求出 1 << (n % 8)
// 实际上前面可以 set 0x07,然后这里查表就是 1<<n 了,表里每 lane-128 只用前半段
auto col = _mm256_shuffle_epi8(shift_mod, col_index);
auto row_even = _mm256_shuffle_epi8(filter_even, row_index);
auto row_odd = _mm256_shuffle_epi8(filter_odd, row_index);
// blend 要求使用 bit 7,那就把原来的 bit 3 左移 4 位,用于确定奇偶
auto parity = _mm256_slli_epi16(chars, 4);
auto row = _mm256_blendv_epi8(row_even, row_odd, parity);
auto movemask = _mm256_movemask_epi8(
_mm256_cmpeq_epi8(_mm256_andnot_si256(row, col), _mm256_setzero_si256()));
if(movemask) {
return index * lane + std::countr_zero(unsigned(movemask));
}
}
auto offset = lane * std::ranges::size(simd_view);
auto scalar_view = rng | std::views::drop(offset);
for(auto &&[index, _v] : std::views::enumerate(scalar_view)) {
auto v = static_cast<unsigned char>(_v);
auto v_row = charset[v / 8];
auto v_col = 1 << (v % 8);
if(v_row & v_col) return offset + index;
}
return -1;
}
问题在于核心算法的改动。矩阵已经扩展到 32 行 8 列,这会触及 AVX 的 lane 独立问题,无法单趟 shuffle 解决。因此采用了拆表的技巧。比特位 b3 作为选择表的依据,又因为 b3 位于 b7b6b5b4b3 & 1,它同时是一个决定奇偶的比特位。所以拆表就是拆为奇数表和偶数表。这正是前面 filter_even 和 filter_odd 的来源。
Safe but faster ASCII character
假设我们的需求再次回到了 safe ASCII 场景,即要求字符串本身支持完整的 char 范围,但是 charset 限制于 ASCII。我们想要更快,可以从 bitmap 的内存布局入手优化。
using bitmap_t = u8[16];
// 「转置」的 bitmap
void set(bitmap_t &charset, char ch) {
charset[ch & 15] |= 1 << (ch >> 4);
}
现在的 bitmap 内存布局改为低四位决定字节槽下标,高三位决定槽内位下标。
// charset: u8[16] as a bitmap of char field.
//
// Core algorithm:
// u8[16] -> 16 x 8 matrix
// -> 16 x 8
// ^^table
// b0-b3 (w4): byte index (row).
// b4-b6 (w3): bit index (col).
template <avx2_ascii128_transposed_config Config>
ssize_t find_charset_avx2_ascii128_transposed(std::ranges::range auto &&rng, const auto &_charset) {
char charset[16];
if constexpr (!Config.transposed) {
std::ranges::fill(charset, 0);
for(int c = 0; c < 128; c++) {
if(_charset[c / 8] >> (c % 8) & 1) {
charset[c % 16] |= 1 << (c / 16);
}
}
} else {
for(int i = 0; i < 16; ++i) charset[i] = _charset[i];
}
static_assert(std::ranges::size(charset) == ((1 << CHAR_BIT) / CHAR_BIT / 2));
constexpr auto lane = sizeof(__m256i) / sizeof(char);
auto filter128 = _mm_loadu_si128((__m128i *) &charset);
auto filter256 = _mm256_set_m128i(filter128, filter128);
// LUT for 1 << b7b6b5b4, but for high bit (ascii>=128, b7=1), set to 0.
const auto shift_ignored = _mm256_setr_epi8(
1, 2, 4, 8, 16, 32, 64, -128,
0, 0, 0, 0, 0, 0, 0, 0,
1, 2, 4, 8, 16, 32, 64, -128,
0, 0, 0, 0, 0, 0, 0, 0
);
auto simd_view = rng | simdify<lane>;
for(auto &&[index, simd_v] : std::views::enumerate(simd_view)) {
auto addr = (__m256i *) &simd_v;
auto chars = _mm256_loadu_si256(addr);
auto &row_index = chars; // We use lowbits.
auto col_index = _mm256_and_si256(_mm256_srli_epi16(chars, 4), _mm256_set1_epi8(0x0f));
auto row = _mm256_shuffle_epi8(filter256, row_index);
auto col = _mm256_shuffle_epi8(shift_ignored, col_index);
auto movemask = _mm256_movemask_epi8(_mm256_cmpeq_epi8(
_mm256_and_si256(row, col), _mm256_setzero_si256()));
if(unsigned counter = ~movemask) {
return index * lane + std::countr_zero(counter);
}
}
auto offset = lane * std::ranges::size(simd_view);
auto scalar_view = rng | std::views::drop(offset);
for(auto &&[index, _v] : std::views::enumerate(scalar_view)) {
auto v = static_cast<unsigned char>(_v);
if(v >= 128) continue;
auto v_row = charset[v % 16];
auto v_col = 1 << (v / 16);
if(v_row & v_col) return offset + index;
}
return -1;
}
这个策略更好地利用了 SIMD 低位特性,row 隐藏地忽略 shuffle 高下标。可以更为节省指令数。
性能对比
| 算法/负载 | 35B | 350B | 3.5KB | 35KB | 350KB |
|---|---|---|---|---|---|
| avx2_full | 15.5 | 19.3 | 38.8 | 39.0 | 39.4 |
| avx2_ascii | 17.3 | 17.2 | 44.5 | 47.6 | 51.7 |
| avx2_ascii_safe | 19.4 | 22.4 | 39.6 | 40.3 | 44.5 |
| avx2_ascii_transposed_false | 0.61 | 4.74 | 26.4 | 45.6 | 56.0 |
| avx2_ascii_transposed_true | 19.7 | 24.0 | 47.8 | 49.9 | 56.0 |
(单位: Gi/s,使用 clang++-20 -O3 -march=znver3 编译)
做了一个非常详尽的 benchmark。红色排名最高,灰色最低,蓝绿中等水平,金色优秀。另外,标量算法(bitmap、字符串、标准库、还有各种编译时确定 if-or 链路)都太差了,没在榜内。
通用场景的结论是:ASCII 范围内转置版本最优,不管是否需要安全处理非法值。
References
???