Commit 163ebe0

hkupty and StAlKeR7779 authored
std.mem.countScalar: rework to benefit from simd (#25477)
`findScalarPos` might do repetitive work, even when using SIMD. For example, when searching the string `/abcde/fghijk/lm` for the character `/`, a 16-byte wide comparison yields `1000001000000100`, but `findScalarPos` reports only the first `1`, so the rest of the string has to be searched again.

When testing locally, the difference was quite significant:

```
count scalar
5737 iterations 522.83us per iterations
0 bytes per iteration
worst: 2370us median: 512us stddev: 107.64us

count v2
38333 iterations 78.03us per iterations
0 bytes per iteration
worst: 713us median: 76us stddev: 10.62us

count scalar v2
99565 iterations 29.80us per iterations
0 bytes per iteration
worst: 41us median: 29us stddev: 1.04us
```

Note that `count v2` is a simpler string search, similar to the loop that the SIMD approach uses for the remainder of the haystack:

```
pub fn countV2(comptime T: type, haystack: []const T, needle: T) usize {
    const n = haystack.len;
    if (n < 1) return 0;
    var count: usize = 0;
    for (haystack[0..n]) |item| {
        count += @intFromBool(item == needle);
    }
    return count;
}
```

This implies that the compiler generates optimized code for the simpler loop that outperforms the `findScalarPos`-based approach, hence the use of the iterative approach for the remainder of the haystack.

Co-authored-by: StAlKeR7779 <[email protected]>
1 parent 9760068 commit 163ebe0

1 file changed: +17 -3 lines changed

lib/std/mem.zig

Lines changed: 17 additions & 3 deletions
@@ -1706,12 +1706,26 @@ test count {
 
 /// Returns the number of needles inside the haystack
 pub fn countScalar(comptime T: type, haystack: []const T, needle: T) usize {
+    const n = haystack.len;
     var i: usize = 0;
     var found: usize = 0;
 
-    while (findScalarPos(T, haystack, i, needle)) |idx| {
-        i = idx + 1;
-        found += 1;
+    if (use_vectors_for_comparison and
+        (@typeInfo(T) == .int or @typeInfo(T) == .float) and std.math.isPowerOfTwo(@bitSizeOf(T)))
+    {
+        if (std.simd.suggestVectorLength(T)) |block_size| {
+            const Block = @Vector(block_size, T);
+
+            const letter_mask: Block = @splat(needle);
+            while (n - i >= block_size) : (i += block_size) {
+                const haystack_block: Block = haystack[i..][0..block_size].*;
+                found += std.simd.countTrues(letter_mask == haystack_block);
+            }
+        }
+    }
+
+    for (haystack[i..n]) |item| {
+        found += @intFromBool(item == needle);
     }
 
     return found;
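
As an illustration of the per-block step in the diff, here is a minimal sketch that counts every match in a single 16-byte block using the example string from the commit message. The helper name `countInBlock` and the fixed block width of 16 are assumptions made for this sketch; the actual change takes the width from `std.simd.suggestVectorLength(T)`.

```zig
const std = @import("std");

/// Hypothetical helper (not part of std.mem): counts how many of the
/// first 16 bytes of `haystack` equal `needle`, in a single comparison.
fn countInBlock(haystack: []const u8, needle: u8) usize {
    std.debug.assert(haystack.len >= 16);
    const Block = @Vector(16, u8);

    // Load 16 bytes and broadcast the needle across all lanes.
    const haystack_block: Block = haystack[0..16].*;
    const needle_block: Block = @splat(needle);

    // The element-wise comparison yields a vector of 16 bools;
    // countTrues returns the number of lanes that are true.
    return std.simd.countTrues(haystack_block == needle_block);
}

test "all matches in a block are counted at once" {
    // '/' appears at indices 0, 6 and 13 of the example string.
    try std.testing.expectEqual(@as(usize, 3), countInBlock("/abcde/fghijk/lm", '/'));
}
```

Run with `zig test` on the file. Because all three matches are counted from one comparison, the block never has to be re-scanned from the position after each hit, which is the repeated work the commit message describes.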
