Skip to content

Commit 60abe82

Browse files
committed
[mypyc] Specialize s[i] == 'x' to a codepoint int compare
Recognizes the AST shape `IndexExpr(str) == StrLiteral` (and the symmetric `StrLiteral == IndexExpr(str)`, plus the `!=` variants) and lowers it to an int compare of codepoints reusing the existing CPyStr_GetItemUnsafeAsInt primitive. Today the pattern lowers to CPyStr_GetItem + CPyStr_EqualLiteral, which allocates or looks up a 1-character PyUnicode object per iteration and goes through a generic string-equality call. After specialization it becomes an inlined PyUnicode_READ plus an int compare -- about 4x faster on bench_str_compare with a 3-compares-per-iteration workload, and closer to ~9x with the more typical 1-compare-per-iteration shape. No annotations required; benefits any code that compares a string index against a 1-character literal. Multi-character / empty literals fall through to the generic path (which still correctly returns False). Bounds checking is preserved -- the helper raises IndexError for out-of-range indices, same as the unspecialized path. Stack: builds on the `ord(s[i])` primitive (#20578) and the librt.strings codepoint helpers (#21462, #21504, #21509, #21521, #21522, #21553).
1 parent 4c8f994 commit 60abe82

3 files changed

Lines changed: 284 additions & 1 deletion

File tree

mypyc/irbuild/expression.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,12 @@
9393
is_list_rprimitive,
9494
is_none_rprimitive,
9595
is_object_rprimitive,
96+
is_str_rprimitive,
97+
is_tagged,
9698
is_tuple_rprimitive,
9799
object_rprimitive,
98100
set_rprimitive,
101+
short_int_rprimitive,
99102
vec_api_by_item_type,
100103
)
101104
from mypyc.irbuild.ast_helpers import is_borrow_friendly_expr, process_conditional
@@ -119,6 +122,7 @@
119122
apply_dunder_specialization,
120123
apply_function_specialization,
121124
apply_method_specialization,
125+
translate_getitem_with_bounds_check,
122126
translate_object_new,
123127
translate_object_setattr,
124128
)
@@ -137,7 +141,12 @@
137141
from mypyc.primitives.list_ops import list_append_op, list_extend_op, list_slice_op
138142
from mypyc.primitives.misc_ops import ellipsis_op, get_module_dict_op, new_slice_op, type_op
139143
from mypyc.primitives.set_ops import set_add_op, set_in_op, set_update_op
140-
from mypyc.primitives.str_ops import str_slice_op
144+
from mypyc.primitives.str_ops import (
145+
str_adjust_index_op,
146+
str_get_item_unsafe_as_int_op,
147+
str_range_check_op,
148+
str_slice_op,
149+
)
141150
from mypyc.primitives.tuple_ops import list_tuple_op, tuple_slice_op
142151

143152
# Name and attribute references
@@ -918,6 +927,16 @@ def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value:
918927
return result
919928

920929
if len(e.operators) == 1:
930+
# s[i] == 'x' / s[i] != 'x' (and the symmetric RHS) -> int compare of
931+
# codepoints. Skips the per-iteration 1-char str allocation/lookup and
932+
# generic str equality call.
933+
if first_op in ("==", "!="):
934+
result = try_specialize_str_index_compare(
935+
builder, first_op, e.operands[0], e.operands[1], e.line
936+
)
937+
if result is not None:
938+
return result
939+
921940
# Special some common simple cases
922941
if first_op in ("is", "is not"):
923942
right_expr = e.operands[1]
@@ -960,6 +979,49 @@ def go(i: int, prev: Value) -> Value:
960979
return go(0, builder.accept(e.operands[0]))
961980

962981

982+
def try_specialize_str_index_compare(
983+
builder: IRBuilder, op: str, lhs: Expression, rhs: Expression, line: int
984+
) -> Value | None:
985+
"""Specialize `s[i] == 'x'` / `s[i] != 'x'` (and the symmetric form with
986+
operands swapped) into an int compare of codepoints.
987+
988+
Returns None if the pattern doesn't match: the indexed base must be str,
989+
the index must be an integer, and the literal must be a 1-character str.
990+
Multi-character or empty literals fall through to the generic str compare
991+
(which still returns False for them, matching today's behavior).
992+
"""
993+
# Normalize so the IndexExpr is on the left.
994+
if isinstance(rhs, IndexExpr) and not isinstance(lhs, IndexExpr):
995+
lhs, rhs = rhs, lhs
996+
# Shape: s[i] {==, !=} "x" where "x" is exactly one codepoint.
997+
if (
998+
not isinstance(lhs, IndexExpr)
999+
or not isinstance(rhs, StrExpr)
1000+
or len(rhs.value) != 1
1001+
or not is_str_rprimitive(builder.node_type(lhs.base))
1002+
):
1003+
return None
1004+
index_type = builder.node_type(lhs.index)
1005+
if not (is_tagged(index_type) or is_fixed_width_rtype(index_type)):
1006+
return None
1007+
1008+
# ord(s[i]) with bounds check; raises IndexError for out-of-range indices,
1009+
# matching the behavior of the generic s[i] path.
1010+
codepoint = translate_getitem_with_bounds_check(
1011+
builder,
1012+
lhs.base,
1013+
[lhs.index],
1014+
lhs,
1015+
str_adjust_index_op,
1016+
str_range_check_op,
1017+
str_get_item_unsafe_as_int_op,
1018+
)
1019+
if codepoint is None:
1020+
return None
1021+
literal_cp = Integer(ord(rhs.value), short_int_rprimitive, line)
1022+
return builder.binary_op(codepoint, literal_cp, op, line)
1023+
1024+
9631025
def try_specialize_in_expr(
9641026
builder: IRBuilder, op: str, lhs: Expression, rhs: Expression, line: int
9651027
) -> Value | None:

mypyc/test-data/irbuild-str.test

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,3 +1025,153 @@ def is_digit(x):
10251025
L0:
10261026
r0 = CPyStr_IsDigit(x)
10271027
return r0
1028+
1029+
[case testStrIndexEqLiteral]
1030+
def is_comma(s: str, i: int) -> bool:
1031+
return s[i] == ","
1032+
def is_comma_swapped(s: str, i: int) -> bool:
1033+
return "," == s[i]
1034+
def is_comma_ne(s: str, i: int) -> bool:
1035+
return s[i] != ","
1036+
[out]
1037+
def is_comma(s, i):
1038+
s :: str
1039+
i :: int
1040+
r0 :: native_int
1041+
r1 :: bit
1042+
r2, r3 :: i64
1043+
r4 :: ptr
1044+
r5 :: c_ptr
1045+
r6, r7 :: i64
1046+
r8, r9 :: bool
1047+
r10 :: short_int
1048+
r11 :: bit
1049+
L0:
1050+
r0 = i & 1
1051+
r1 = r0 == 0
1052+
if r1 goto L1 else goto L2 :: bool
1053+
L1:
1054+
r2 = i >> 1
1055+
r3 = r2
1056+
goto L3
1057+
L2:
1058+
r4 = i ^ 1
1059+
r5 = r4
1060+
r6 = CPyLong_AsInt64(r5)
1061+
r3 = r6
1062+
keep_alive i
1063+
L3:
1064+
r7 = CPyStr_AdjustIndex(s, r3)
1065+
r8 = CPyStr_RangeCheck(s, r7)
1066+
if r8 goto L5 else goto L4 :: bool
1067+
L4:
1068+
r9 = raise IndexError('index out of range')
1069+
unreachable
1070+
L5:
1071+
r10 = CPyStr_GetItemUnsafeAsInt(s, r7)
1072+
r11 = int_eq r10, 88
1073+
return r11
1074+
def is_comma_swapped(s, i):
1075+
s :: str
1076+
i :: int
1077+
r0 :: native_int
1078+
r1 :: bit
1079+
r2, r3 :: i64
1080+
r4 :: ptr
1081+
r5 :: c_ptr
1082+
r6, r7 :: i64
1083+
r8, r9 :: bool
1084+
r10 :: short_int
1085+
r11 :: bit
1086+
L0:
1087+
r0 = i & 1
1088+
r1 = r0 == 0
1089+
if r1 goto L1 else goto L2 :: bool
1090+
L1:
1091+
r2 = i >> 1
1092+
r3 = r2
1093+
goto L3
1094+
L2:
1095+
r4 = i ^ 1
1096+
r5 = r4
1097+
r6 = CPyLong_AsInt64(r5)
1098+
r3 = r6
1099+
keep_alive i
1100+
L3:
1101+
r7 = CPyStr_AdjustIndex(s, r3)
1102+
r8 = CPyStr_RangeCheck(s, r7)
1103+
if r8 goto L5 else goto L4 :: bool
1104+
L4:
1105+
r9 = raise IndexError('index out of range')
1106+
unreachable
1107+
L5:
1108+
r10 = CPyStr_GetItemUnsafeAsInt(s, r7)
1109+
r11 = int_eq r10, 88
1110+
return r11
1111+
def is_comma_ne(s, i):
1112+
s :: str
1113+
i :: int
1114+
r0 :: native_int
1115+
r1 :: bit
1116+
r2, r3 :: i64
1117+
r4 :: ptr
1118+
r5 :: c_ptr
1119+
r6, r7 :: i64
1120+
r8, r9 :: bool
1121+
r10 :: short_int
1122+
r11 :: bit
1123+
L0:
1124+
r0 = i & 1
1125+
r1 = r0 == 0
1126+
if r1 goto L1 else goto L2 :: bool
1127+
L1:
1128+
r2 = i >> 1
1129+
r3 = r2
1130+
goto L3
1131+
L2:
1132+
r4 = i ^ 1
1133+
r5 = r4
1134+
r6 = CPyLong_AsInt64(r5)
1135+
r3 = r6
1136+
keep_alive i
1137+
L3:
1138+
r7 = CPyStr_AdjustIndex(s, r3)
1139+
r8 = CPyStr_RangeCheck(s, r7)
1140+
if r8 goto L5 else goto L4 :: bool
1141+
L4:
1142+
r9 = raise IndexError('index out of range')
1143+
unreachable
1144+
L5:
1145+
r10 = CPyStr_GetItemUnsafeAsInt(s, r7)
1146+
r11 = int_ne r10, 88
1147+
return r11
1148+
1149+
[case testStrIndexEqLiteralNoSpecialize]
1150+
def two_char_literal(s: str, i: int) -> bool:
1151+
# Multi-char literals don't match the specialization; falls through to
1152+
# the generic str equality path.
1153+
return s[i] == "ab"
1154+
def empty_literal(s: str, i: int) -> bool:
1155+
# Empty string literals also fall through; the generic path returns False.
1156+
return s[i] == ""
1157+
[out]
1158+
def two_char_literal(s, i):
1159+
s :: str
1160+
i :: int
1161+
r0, r1 :: str
1162+
r2 :: bool
1163+
L0:
1164+
r0 = CPyStr_GetItem(s, i)
1165+
r1 = 'ab'
1166+
r2 = CPyStr_EqualLiteral(r0, r1, 2)
1167+
return r2
1168+
def empty_literal(s, i):
1169+
s :: str
1170+
i :: int
1171+
r0, r1 :: str
1172+
r2 :: bool
1173+
L0:
1174+
r0 = CPyStr_GetItem(s, i)
1175+
r1 = ''
1176+
r2 = CPyStr_EqualLiteral(r0, r1, 0)
1177+
return r2

mypyc/test-data/run-strings.test

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1412,3 +1412,74 @@ def test_isdigit_strings() -> None:
14121412
assert not "\u00e9\u00e8".isdigit()
14131413
assert not "123\u00e9".isdigit()
14141414
assert not "\U0001d7ce!".isdigit()
1415+
1416+
[case testStrIndexEqLiteralSpecialize]
1417+
from typing import Any
1418+
1419+
from testutil import assertRaises
1420+
1421+
# The specializer fires on the AST shape `IndexExpr == StrLiteral` (or the
1422+
# symmetric swap, and `!=`). The literal has to be a real source-level
1423+
# string literal (can't be passed in as a parameter), so each test
1424+
# function pins one distinct shape.
1425+
1426+
def eq_comma(s: str, i: int) -> bool:
1427+
# Specialized: s[i] == "x".
1428+
return s[i] == ","
1429+
1430+
def ne_comma(s: str, i: int) -> bool:
1431+
# Specialized: s[i] != "x".
1432+
return s[i] != ","
1433+
1434+
def comma_eq(s: str, i: int) -> bool:
1435+
# Specialized: "x" == s[i]. Operand-swap is normalized.
1436+
return "," == s[i]
1437+
1438+
def eq_two_chars(s: str, i: int) -> bool:
1439+
# Not specialized: literal isn't 1 char. Falls through to the generic
1440+
# str compare, which returns False since s[i] is always 1 codepoint.
1441+
return s[i] == "ab"
1442+
1443+
def eq_empty(s: str, i: int) -> bool:
1444+
# Not specialized: empty literal. Same fall-through.
1445+
return s[i] == ""
1446+
1447+
def test_specialized_path() -> None:
1448+
s = "a,b" # comma at index 1
1449+
assert eq_comma(s, 1)
1450+
assert not eq_comma(s, 0)
1451+
assert not eq_comma(s, 2)
1452+
# != inverts.
1453+
assert ne_comma(s, 0)
1454+
assert not ne_comma(s, 1)
1455+
# Literal on the LHS is normalized to the same shape.
1456+
assert comma_eq(s, 1)
1457+
assert not comma_eq(s, 0)
1458+
1459+
def test_negative_index_is_adjusted() -> None:
1460+
s = "a,b"
1461+
assert eq_comma(s, -2) # -2 -> 1 (',')
1462+
assert not eq_comma(s, -1) # -1 -> 2 ('b')
1463+
1464+
def test_non_1char_literal_falls_through() -> None:
1465+
s = "a,b"
1466+
# Generic str compare answers False because s[i] has length 1.
1467+
assert not eq_two_chars(s, 0)
1468+
assert not eq_two_chars(s, 1)
1469+
assert not eq_empty(s, 0)
1470+
1471+
def test_out_of_range_raises_indexerror() -> None:
1472+
# Bounds-check semantics match the unspecialized s[i] path.
1473+
s = "a,b"
1474+
with assertRaises(IndexError):
1475+
eq_comma(s, 3)
1476+
with assertRaises(IndexError):
1477+
eq_comma(s, -4)
1478+
1479+
def test_any_dispatch_uses_generic_path() -> None:
1480+
# Going through `Any` routes through the interpreted wrapper, which
1481+
# uses the unspecialized lowering. Confirms the str surface still
1482+
# works for callers that bypass the specializer.
1483+
f: Any = eq_comma
1484+
assert f("hello,world", 5) is True
1485+
assert f("hello", 0) is False

0 commit comments

Comments
 (0)