字符集查找指的是类似 strpbrk 的接口:在目标字符串中扫描,直到匹配指定字符集的任意字符。本文将使用一个压缩的 bitmap 来替代接口的 breakset 参数,提供不同场景的 SIMD 字符集查找算法。

NOTES:

  • 本文是给 StringZilla 做 PR 时顺便整理的测试集,特定背景使用 AVX2 而非 AVX512。
  • 即使是标量而非 SIMD,使用 bitmap(而非字符串)也有相当的性能收益,详见 benchmark
  • 本文工作重点是场景/benchmark,算法当中最重要的核心思路不是原创的(我太菜了)。但是目前没有确切的最早来源,个人猜测作者是 Daniel Lemire 或者 Wojciech Muła。

WORK IN PROGRESS 警告:代码和注释是完整的,但是详细解释我后续需要补充。(赶车回家)

ASCII character

假设 char 的可用范围是 [0, 127],也就是 ASCII 场景。我们需要的压缩 bitmap 将是一个 u8[16] 的数组(或者是其他内存布局兼容的设计,要求每个字符占据 1 个比特位)。

// 伪代码,用于将常规的字符串转为需要的 bitmap
using bitmap_t = u8[16];
void set(bitmap_t &charset, char ch) {
    charset[ch >> 3] |= 1 << (ch & 7);
}

上面代码提供一个 bitmap 内存布局的描述,非常直接的压缩策略:对一个字符 ch,其高四位决定了 bitmap 字节槽的下标,低三位决定了字节槽内bit的下标。(不使用第八位 bit 7。)

// charset: u8[16] as a bitmap of char field.
//
// Core algorithm:
// u8[16] -> 16 x 8 matrix
//        -> 16 x 8
//           ^^table
// b3-b6 (w4): byte index (row).
// b0-b2 (w3): bit index (col).
ssize_t find_charset_avx2_ascii128(std::ranges::range auto &&rng, const auto &charset) {
    static_assert(std::ranges::size(charset) == ((1 << CHAR_BIT) / CHAR_BIT / 2));
    constexpr auto lane = sizeof(__m256i) / sizeof(char);
    // 广播以解决 AVX 的 lane 独立问题
    auto filter128 = _mm_loadu_si128((__m128i *) &charset);
    auto filter256 = _mm256_set_m128i(filter128, filter128);

    // 用向量查表替代向量计算
    // LUT for 1 << b2b1b0.
    const auto shift = _mm256_setr_epi8(
        1, 2, 4, 8, 16, 32, 64, -128, // 1<<7 == -128
        0, 0, 0, 0,  0,  0,  0,    0,
        1, 2, 4, 8, 16, 32, 64, -128,
        0, 0, 0, 0,  0,  0,  0,    0
    );

    // 这里用了一个定制的 ranges view,可以简化代码和扩展更高维度的 SIMD
    // 但是本文仅用于简化代码,即丢弃 rng 无法被 SIMD 处理的尾部,并且每个代理元素的宽度都是以 lane 为单位
    // 一个简单理解但是不够高效的实现是:simdify = stride(lane) | take(size / lane)
    // 更多应用见之前的文章:https://www.bluepuni.com/archives/cpp-explicit-ilp-in-practice/
    auto simd_view = rng | simdify<lane>;
    for(auto &&[index, simd_v] : std::views::enumerate(simd_view)) {
        auto addr = (__m256i *) &simd_v;
        auto chars = _mm256_loadu_si256(addr);

        auto col_index = _mm256_and_si256(chars, _mm256_set1_epi8(0b00000111));
        auto row_index = _mm256_and_si256(chars, _mm256_set1_epi8(0b01111000));
             row_index = _mm256_srli_epi16(row_index, 3);

        auto col = _mm256_shuffle_epi8(shift, col_index);
        auto row = _mm256_shuffle_epi8(filter256, row_index);

        // 标准的 SIMD 字符串匹配三板斧:cmpeq+movemask+count
        auto movemask = _mm256_movemask_epi8(_mm256_cmpeq_epi8(
            // 检查 charset[row_index] & (1 << col_index) 是否为 0
            _mm256_and_si256(row, col), _mm256_setzero_si256()));
        if(unsigned counter = ~movemask) {
            return index * lane + std::countr_zero(counter);
        }
    }
    // 尾部处理,直接标量算就行了
    auto offset = lane * std::ranges::size(simd_view);
    auto scalar_view = rng | std::views::drop(offset);
    for(auto &&[index, _v] : std::views::enumerate(scalar_view)) {
        auto v = static_cast<unsigned char>(_v);
        auto v_row = charset[v / 8];
        auto v_col = 1 << (v % 8);
        if(v_row & v_col) return offset + index;
    }
    return -1;
}

核心算法就是查表。u8[16] 可视为 16x8 的矩阵,也就是 16 行 8 列。很显然 16 是一个很适合用于 shuffle 的关键数字。

Safe ASCII character

前面的算法无法解决高位污染问题,就是假设你的字符串本身含有非法 ASCII 字符时([-128, 0)),算法会失效。这很不利于 JSON/HTML 场景,比如字符串可能是 UTF-8,但是查找的字符集也就只是固定 ASCII 范围的 []{}:" 等等。

+ template <avx2_ascii128_config Config>
ssize_t find_charset_avx2_ascii128(std::ranges::range auto &&rng, const auto &charset) {
    // ...
    for(auto &&[index, simd_v] : std::views::enumerate(simd_view)) {
        auto addr = (__m256i *) &simd_v;
        auto chars = _mm256_loadu_si256(addr);

        auto col_index = _mm256_and_si256(chars, _mm256_set1_epi8(0b00000111));
        auto row_index = _mm256_and_si256(chars, _mm256_set1_epi8(0b01111000));
             row_index = _mm256_srli_epi16(row_index, 3);

+       if constexpr (Config.overflow) {
+           auto sign_bits = _mm256_and_si256(chars, _mm256_set1_epi8(0b10000000));
+           row_index = _mm256_or_si256(row_index, sign_bits);
+       }

        auto col = _mm256_shuffle_epi8(shift, col_index);
        auto row = _mm256_shuffle_epi8(filter256, row_index);

        auto movemask = _mm256_movemask_epi8(_mm256_cmpeq_epi8(
            _mm256_and_si256(row, col), _mm256_setzero_si256()));
        if(unsigned counter = ~movemask) {
            return index * lane + std::countr_zero(counter);
        }
    }
    // ...
    for(auto &&[index, _v] : std::views::enumerate(scalar_view)) {
        auto v = static_cast<unsigned char>(_v);
+       if constexpr (Config.overflow) if(v >= 128) continue;
        // ...
    }
    return -1;
}

SIMD 部分只需要花费额外的成本就能修复,因为 shuffle 遇到 0x80 会置零,我们需要还原它的槽下标高位信息。注意还需要处理标量。

此时超出 ASCII 范围的字符会被安全地忽略。

Full character

假设 char 的可用范围是完整的 [-128, 127],比如 UTF-8 场景。我们需要修改 bitmap 的定义和核心算法。

// 伪代码,将 bitmap 扩展翻倍
using bitmap_t = u8[32];
void set(bitmap_t &charset, char ch) {
    // 依然是低三位决定槽内下标,但是剩余高五位决定槽
    charset[ch >> 3] |= 1 << (ch & 7);
}

bitmap 的数据结构需要翻倍已容纳字符范围。

// charset: u8[32] as a bitmap of char field.
//
// Core algorithm:
// u8[32] -> 32 x 8 matrix
//        -> 16 x 2 x 8
//        -> 16 x 8 (table1) + 16 x 8 (table2)
//        -> 16 x 8          + 16 x 8
//           ^^table-odd       ^^table-even
// b3-b7 (w5): byte index (row). We use b3 to select the table, which determines the parity.
// b0-b2 (w3): bit index (col).
ssize_t find_charset_avx2(std::ranges::range auto &&rng, const auto &charset) {
    // TODO: check u64[4]...
    static_assert(std::ranges::size(charset) == ((1 << CHAR_BIT) / CHAR_BIT));
    constexpr auto lane = sizeof(__m256i) / sizeof(char);

    const auto filter_lo = _mm_loadu_si128((__m128i *)(&charset));
    const auto filter_hi = _mm_loadu_si128((__m128i *)(&charset) + 1);
    const auto filter_mask = _mm_set1_epi16(0x00ff);
    const auto filter_lo_even = _mm_and_si128(filter_lo, filter_mask);
    const auto filter_hi_even = _mm_and_si128(filter_hi, filter_mask);
    const auto filter_lo_odd = _mm_srli_epi16(filter_lo, 8);
    const auto filter_hi_odd = _mm_srli_epi16(filter_hi, 8);
    // e4 e3 e2 e1
    const auto filter_even_128 = _mm_packus_epi16(filter_lo_even, filter_hi_even);
    // o4 o3 o2 o1
    const auto filter_odd_128 = _mm_packus_epi16(filter_lo_odd, filter_hi_odd);

    // e4 e3 e2 e1 | e4 e3 e2 e1
    const auto filter_even = _mm256_set_m128i(filter_even_128, filter_even_128);
    // o4 o3 o2 o1 | o4 o3 o2 o1
    const auto filter_odd = _mm256_set_m128i(filter_odd_128, filter_odd_128);

    // LUT for 1 << (b3b2b1b0 % 8). Equivalent to 1 << b2b1b0.
    const auto shift_mod = _mm256_setr_epi8(
        1, 2, 4, 8, 16, 32, 64, -128, // 1<<7 == -128
        1, 2, 4, 8, 16, 32, 64, -128,
        1, 2, 4, 8, 16, 32, 64, -128,
        1, 2, 4, 8, 16, 32, 64, -128
    );

    auto simd_view = rng | simdify<lane>;
    for(auto &&[index, simd_v] : std::views::enumerate(simd_view)) {
        auto addr = &simd_v;
        auto chars = _mm256_loadu_si256((__m256i*) addr);

        // b0-b3, conflict with b3
        auto col_index = _mm256_and_si256(chars, _mm256_set1_epi8(0x0f));
        // b4-b7, conflict with b3
        auto row_index = _mm256_and_si256(_mm256_srli_epi16(chars, 4), _mm256_set1_epi8(0x0f));

        // 求出 1 << (n % 8)
        // 实际上前面可以 set 0x07,然后这里查表就是 1<<n 了,表里每 lane-128 只用前半段
        auto col = _mm256_shuffle_epi8(shift_mod, col_index);

        auto row_even = _mm256_shuffle_epi8(filter_even, row_index);
        auto row_odd  = _mm256_shuffle_epi8(filter_odd, row_index);
        // blend 要求使用 bit 7,那就把原来的 bit 3 左移 4 位,用于确定奇偶
        auto parity = _mm256_slli_epi16(chars, 4);
        auto row = _mm256_blendv_epi8(row_even, row_odd, parity);

        auto movemask = _mm256_movemask_epi8(
            _mm256_cmpeq_epi8(_mm256_andnot_si256(row, col), _mm256_setzero_si256()));

        if(movemask) {
            return index * lane + std::countr_zero(unsigned(movemask));
        }
    }
    auto offset = lane * std::ranges::size(simd_view);
    auto scalar_view = rng | std::views::drop(offset);
    for(auto &&[index, _v] : std::views::enumerate(scalar_view)) {
        auto v = static_cast<unsigned char>(_v);
        auto v_row = charset[v / 8];
        auto v_col = 1 << (v % 8);
        if(v_row & v_col) return offset + index;
    }
    return -1;
}

问题在于核心算法的改动。矩阵已经扩展到 32 行 8 列,这会触及 AVX 的 lane 独立问题,无法单趟 shuffle 解决。因此采用了拆表的技巧。比特位 b3 作为选择表的依据,又因为 b3 位于 b7b6b5b4b3 & 1,它同时是一个决定奇偶的比特位。所以拆表就是拆为奇数表和偶数表。这正是前面 filter_evenfilter_odd 的来源。

Safe but faster ASCII character

假设我们的需求再次回到了 safe ASCII 场景,即要求字符串本身支持完整的 char 范围,但是 charset 限制于 ASCII。我们想要更快,可以从 bitmap 的内存布局入手优化。

using bitmap_t = u8[16];
// 「转置」的 bitmap
void set(bitmap_t &charset, char ch) {
    charset[ch & 15] |= 1 << (ch >> 4);
}

现在的 bitmap 内存布局改为低四位决定字节槽下标,高三位决定槽内位下标。

// charset: u8[16] as a bitmap of char field.
//
// Core algorithm:
// u8[16] -> 16 x 8 matrix
//        -> 16 x 8
//           ^^table
// b0-b3 (w4): byte index (row).
// b4-b6 (w3): bit index (col).
template <avx2_ascii128_transposed_config Config>
ssize_t find_charset_avx2_ascii128_transposed(std::ranges::range auto &&rng, const auto &_charset) {
    char charset[16];
    if constexpr (!Config.transposed) {
        std::ranges::fill(charset, 0);
        for(int c = 0; c < 128; c++) {
            if(_charset[c / 8] >> (c % 8) & 1) {
                charset[c % 16] |= 1 << (c / 16);
            }
        }
    } else {
        for(int i = 0; i < 16; ++i) charset[i] = _charset[i];
    }

    static_assert(std::ranges::size(charset) == ((1 << CHAR_BIT) / CHAR_BIT / 2));
    constexpr auto lane = sizeof(__m256i) / sizeof(char);
    auto filter128 = _mm_loadu_si128((__m128i *) &charset);
    auto filter256 = _mm256_set_m128i(filter128, filter128);

    // LUT for 1 << b7b6b5b4, but for high bit (ascii>=128, b7=1), set to 0.
    const auto shift_ignored = _mm256_setr_epi8(
        1, 2, 4, 8, 16, 32, 64, -128,
        0, 0, 0, 0,  0,  0,  0,    0,
        1, 2, 4, 8, 16, 32, 64, -128,
        0, 0, 0, 0,  0,  0,  0,    0
    );

    auto simd_view = rng | simdify<lane>;
    for(auto &&[index, simd_v] : std::views::enumerate(simd_view)) {
        auto addr = (__m256i *) &simd_v;
        auto chars = _mm256_loadu_si256(addr);

        auto &row_index = chars; // We use lowbits.
        auto col_index = _mm256_and_si256(_mm256_srli_epi16(chars, 4), _mm256_set1_epi8(0x0f));

        auto row = _mm256_shuffle_epi8(filter256, row_index);
        auto col = _mm256_shuffle_epi8(shift_ignored, col_index);

        auto movemask = _mm256_movemask_epi8(_mm256_cmpeq_epi8(
            _mm256_and_si256(row, col), _mm256_setzero_si256()));
        if(unsigned counter = ~movemask) {
            return index * lane + std::countr_zero(counter);
        }
    }
    auto offset = lane * std::ranges::size(simd_view);
    auto scalar_view = rng | std::views::drop(offset);
    for(auto &&[index, _v] : std::views::enumerate(scalar_view)) {
        auto v = static_cast<unsigned char>(_v);
        if(v >= 128) continue;
        auto v_row = charset[v % 16];
        auto v_col = 1 << (v / 16);
        if(v_row & v_col) return offset + index;
    }
    return -1;
}

这个策略更好地利用了 SIMD 低位特性,row 隐藏地忽略 shuffle 高下标。可以更为节省指令数。

性能对比

算法/负载 35B 350B 3.5KB 35KB 350KB
avx2_full 15.5 19.3 38.8 39.0 39.4
avx2_ascii 17.3 17.2 44.5 47.6 51.7
avx2_ascii_safe 19.4 22.4 39.6 40.3 44.5
avx2_ascii_transposed_false 0.61 4.74 26.4 45.6 56.0
avx2_ascii_transposed_true 19.7 24.0 47.8 49.9 56.0

(单位: Gi/s,使用 clang++-20 -O3 -march=znver3 编译)

做了一个非常详尽的 benchmark。红色排名最高,灰色最低,蓝绿中等水平,金色优秀。另外,标量算法(bitmap、字符串、标准库、还有各种编译时确定 if-or 链路)都太了,没在榜内。

通用场景的结论是:ASCII 范围内转置版本最优,不管是否需要安全处理非法值。

References

???