Skip to content

Commit 7f0f507

Browse files
authored
Merge pull request #3380 from ruby/handle-named-capture-escapes
Handle escapes in named capture names
2 parents c44bbdf + b4b7a69 commit 7f0f507

File tree

4 files changed

+223
-17
lines changed

4 files changed

+223
-17
lines changed

include/prism/util/pm_buffer.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,16 @@ void pm_buffer_append_varsint(pm_buffer_t *buffer, int32_t value);
137137
*/
138138
void pm_buffer_append_double(pm_buffer_t *buffer, double value);
139139

140+
/**
141+
* Append a unicode codepoint to the buffer.
142+
*
143+
* @param buffer The buffer to append to.
144+
* @param value The character to append.
145+
* @returns True if the codepoint was valid and appended successfully, false
146+
* otherwise.
147+
*/
148+
bool pm_buffer_append_unicode_codepoint(pm_buffer_t *buffer, uint32_t value);
149+
140150
/**
141151
* The different types of escaping that can be performed by the buffer when
142152
* appending a slice of Ruby source code.

src/prism.c

Lines changed: 144 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9551,21 +9551,7 @@ escape_write_unicode(pm_parser_t *parser, pm_buffer_t *buffer, const uint8_t fla
95519551
parser->explicit_encoding = PM_ENCODING_UTF_8_ENTRY;
95529552
}
95539553

9554-
if (value <= 0x7F) { // 0xxxxxxx
9555-
pm_buffer_append_byte(buffer, (uint8_t) value);
9556-
} else if (value <= 0x7FF) { // 110xxxxx 10xxxxxx
9557-
pm_buffer_append_byte(buffer, (uint8_t) (0xC0 | (value >> 6)));
9558-
pm_buffer_append_byte(buffer, (uint8_t) (0x80 | (value & 0x3F)));
9559-
} else if (value <= 0xFFFF) { // 1110xxxx 10xxxxxx 10xxxxxx
9560-
pm_buffer_append_byte(buffer, (uint8_t) (0xE0 | (value >> 12)));
9561-
pm_buffer_append_byte(buffer, (uint8_t) (0x80 | ((value >> 6) & 0x3F)));
9562-
pm_buffer_append_byte(buffer, (uint8_t) (0x80 | (value & 0x3F)));
9563-
} else if (value <= 0x10FFFF) { // 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
9564-
pm_buffer_append_byte(buffer, (uint8_t) (0xF0 | (value >> 18)));
9565-
pm_buffer_append_byte(buffer, (uint8_t) (0x80 | ((value >> 12) & 0x3F)));
9566-
pm_buffer_append_byte(buffer, (uint8_t) (0x80 | ((value >> 6) & 0x3F)));
9567-
pm_buffer_append_byte(buffer, (uint8_t) (0x80 | (value & 0x3F)));
9568-
} else {
9554+
if (!pm_buffer_append_unicode_codepoint(buffer, value)) {
95699555
pm_parser_err(parser, start, end, PM_ERR_ESCAPE_INVALID_UNICODE);
95709556
pm_buffer_append_byte(buffer, 0xEF);
95719557
pm_buffer_append_byte(buffer, 0xBF);
@@ -20866,6 +20852,123 @@ typedef struct {
2086620852
bool shared;
2086720853
} parse_regular_expression_named_capture_data_t;
2086820854

20855+
static inline const uint8_t *
20856+
pm_named_capture_escape_hex(pm_buffer_t *unescaped, const uint8_t *cursor, const uint8_t *end) {
20857+
cursor++;
20858+
20859+
if (cursor < end && pm_char_is_hexadecimal_digit(*cursor)) {
20860+
uint8_t value = escape_hexadecimal_digit(*cursor);
20861+
cursor++;
20862+
20863+
if (cursor < end && pm_char_is_hexadecimal_digit(*cursor)) {
20864+
value = (uint8_t) ((value << 4) | escape_hexadecimal_digit(*cursor));
20865+
cursor++;
20866+
}
20867+
20868+
pm_buffer_append_byte(unescaped, value);
20869+
} else {
20870+
pm_buffer_append_string(unescaped, "\\x", 2);
20871+
}
20872+
20873+
return cursor;
20874+
}
20875+
20876+
static inline const uint8_t *
20877+
pm_named_capture_escape_octal(pm_buffer_t *unescaped, const uint8_t *cursor, const uint8_t *end) {
20878+
uint8_t value = (uint8_t) (*cursor - '0');
20879+
cursor++;
20880+
20881+
if (cursor < end && pm_char_is_octal_digit(*cursor)) {
20882+
value = ((uint8_t) (value << 3)) | ((uint8_t) (*cursor - '0'));
20883+
cursor++;
20884+
20885+
if (cursor < end && pm_char_is_octal_digit(*cursor)) {
20886+
value = ((uint8_t) (value << 3)) | ((uint8_t) (*cursor - '0'));
20887+
cursor++;
20888+
}
20889+
}
20890+
20891+
pm_buffer_append_byte(unescaped, value);
20892+
return cursor;
20893+
}
20894+
20895+
static inline const uint8_t *
20896+
pm_named_capture_escape_unicode(pm_parser_t *parser, pm_buffer_t *unescaped, const uint8_t *cursor, const uint8_t *end) {
20897+
const uint8_t *start = cursor - 1;
20898+
cursor++;
20899+
20900+
if (cursor >= end) {
20901+
pm_buffer_append_string(unescaped, "\\u", 2);
20902+
return cursor;
20903+
}
20904+
20905+
if (*cursor != '{') {
20906+
size_t length = pm_strspn_hexadecimal_digit(cursor, MIN(end - cursor, 4));
20907+
uint32_t value = escape_unicode(parser, cursor, length);
20908+
20909+
if (!pm_buffer_append_unicode_codepoint(unescaped, value)) {
20910+
pm_buffer_append_string(unescaped, (const char *) start, (size_t) ((cursor + length) - start));
20911+
}
20912+
20913+
return cursor + length;
20914+
}
20915+
20916+
cursor++;
20917+
for (;;) {
20918+
while (cursor < end && *cursor == ' ') cursor++;
20919+
20920+
if (cursor >= end) break;
20921+
if (*cursor == '}') {
20922+
cursor++;
20923+
break;
20924+
}
20925+
20926+
size_t length = pm_strspn_hexadecimal_digit(cursor, end - cursor);
20927+
uint32_t value = escape_unicode(parser, cursor, length);
20928+
20929+
(void) pm_buffer_append_unicode_codepoint(unescaped, value);
20930+
cursor += length;
20931+
}
20932+
20933+
return cursor;
20934+
}
20935+
20936+
static void
20937+
pm_named_capture_escape(pm_parser_t *parser, pm_buffer_t *unescaped, const uint8_t *source, const size_t length, const uint8_t *cursor) {
20938+
const uint8_t *end = source + length;
20939+
pm_buffer_append_string(unescaped, (const char *) source, (size_t) (cursor - source));
20940+
20941+
for (;;) {
20942+
if (++cursor >= end) {
20943+
pm_buffer_append_byte(unescaped, '\\');
20944+
return;
20945+
}
20946+
20947+
switch (*cursor) {
20948+
case 'x':
20949+
cursor = pm_named_capture_escape_hex(unescaped, cursor, end);
20950+
break;
20951+
case '0': case '1': case '2': case '3': case '4': case '5': case '6': case '7':
20952+
cursor = pm_named_capture_escape_octal(unescaped, cursor, end);
20953+
break;
20954+
case 'u':
20955+
cursor = pm_named_capture_escape_unicode(parser, unescaped, cursor, end);
20956+
break;
20957+
default:
20958+
pm_buffer_append_byte(unescaped, '\\');
20959+
break;
20960+
}
20961+
20962+
const uint8_t *next_cursor = pm_memchr(cursor, '\\', (size_t) (end - cursor), parser->encoding_changed, parser->encoding);
20963+
if (next_cursor == NULL) break;
20964+
20965+
pm_buffer_append_string(unescaped, (const char *) cursor, (size_t) (next_cursor - cursor));
20966+
cursor = next_cursor;
20967+
}
20968+
20969+
pm_buffer_append_string(unescaped, (const char *) cursor, (size_t) (end - cursor));
20970+
}
20971+
2086920972
/**
2087020973
* This callback is called when the regular expression parser encounters a named
2087120974
* capture group.
@@ -20880,13 +20983,32 @@ parse_regular_expression_named_capture(const pm_string_t *capture, void *data) {
2088020983

2088120984
const uint8_t *source = pm_string_source(capture);
2088220985
size_t length = pm_string_length(capture);
20986+
pm_buffer_t unescaped = { 0 };
20987+
20988+
// First, we need to handle escapes within the name of the capture group.
20989+
// This is because regular expressions have three different representations
20990+
// in prism. The first is the plain source code. The second is the
20991+
// representation that will be sent to the regular expression engine, which
20992+
// is the value of the "unescaped" field. This is poorly named, because it
20993+
// actually still contains escapes, just a subset of them that the regular
20994+
// expression engine knows how to handle. The third representation is fully
20995+
// unescaped, which is what we need.
20996+
const uint8_t *cursor = pm_memchr(source, '\\', length, parser->encoding_changed, parser->encoding);
20997+
if (PRISM_UNLIKELY(cursor != NULL)) {
20998+
pm_named_capture_escape(parser, &unescaped, source, length, cursor);
20999+
source = (const uint8_t *) pm_buffer_value(&unescaped);
21000+
length = pm_buffer_length(&unescaped);
21001+
}
2088321002

2088421003
pm_location_t location;
2088521004
pm_constant_id_t name;
2088621005

2088721006
// If the name of the capture group isn't a valid identifier, we do
2088821007
// not add it to the local table.
20889-
if (!pm_slice_is_valid_local(parser, source, source + length)) return;
21008+
if (!pm_slice_is_valid_local(parser, source, source + length)) {
21009+
pm_buffer_free(&unescaped);
21010+
return;
21011+
}
2089021012

2089121013
if (callback_data->shared) {
2089221014
// If the unescaped string is a slice of the source, then we can
@@ -20914,7 +21036,10 @@ parse_regular_expression_named_capture(const pm_string_t *capture, void *data) {
2091421036
if ((depth = pm_parser_local_depth_constant_id(parser, name)) == -1) {
2091521037
// If the local is not already a local but it is a keyword, then we
2091621038
// do not want to add a capture for this.
20917-
if (pm_local_is_keyword((const char *) source, length)) return;
21039+
if (pm_local_is_keyword((const char *) source, length)) {
21040+
pm_buffer_free(&unescaped);
21041+
return;
21042+
}
2091821043

2091921044
// If the identifier is not already a local, then we will add it to
2092021045
// the local table.
@@ -20932,6 +21057,8 @@ parse_regular_expression_named_capture(const pm_string_t *capture, void *data) {
2093221057
pm_node_t *target = (pm_node_t *) pm_local_variable_target_node_create(parser, &location, name, depth == -1 ? 0 : (uint32_t) depth);
2093321058
pm_node_list_append(&callback_data->match->targets, target);
2093421059
}
21060+
21061+
pm_buffer_free(&unescaped);
2093521062
}
2093621063

2093721064
/**

src/util/pm_buffer.c

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,46 @@ pm_buffer_append_double(pm_buffer_t *buffer, double value) {
172172
pm_buffer_append(buffer, source, sizeof(double));
173173
}
174174

175+
/**
176+
* Append a unicode codepoint to the buffer.
177+
*/
178+
bool
179+
pm_buffer_append_unicode_codepoint(pm_buffer_t *buffer, uint32_t value) {
180+
if (value <= 0x7F) {
181+
pm_buffer_append_byte(buffer, (uint8_t) value); // 0xxxxxxx
182+
return true;
183+
} else if (value <= 0x7FF) {
184+
uint8_t bytes[] = {
185+
(uint8_t) (0xC0 | ((value >> 6) & 0x3F)), // 110xxxxx
186+
(uint8_t) (0x80 | (value & 0x3F)) // 10xxxxxx
187+
};
188+
189+
pm_buffer_append_bytes(buffer, bytes, 2);
190+
return true;
191+
} else if (value <= 0xFFFF) {
192+
uint8_t bytes[] = {
193+
(uint8_t) (0xE0 | ((value >> 12) & 0x3F)), // 1110xxxx
194+
(uint8_t) (0x80 | ((value >> 6) & 0x3F)), // 10xxxxxx
195+
(uint8_t) (0x80 | (value & 0x3F)) // 10xxxxxx
196+
};
197+
198+
pm_buffer_append_bytes(buffer, bytes, 3);
199+
return true;
200+
} else if (value <= 0x10FFFF) {
201+
uint8_t bytes[] = {
202+
(uint8_t) (0xF0 | ((value >> 18) & 0x3F)), // 11110xxx
203+
(uint8_t) (0x80 | ((value >> 12) & 0x3F)), // 10xxxxxx
204+
(uint8_t) (0x80 | ((value >> 6) & 0x3F)), // 10xxxxxx
205+
(uint8_t) (0x80 | (value & 0x3F)) // 10xxxxxx
206+
};
207+
208+
pm_buffer_append_bytes(buffer, bytes, 4);
209+
return true;
210+
} else {
211+
return false;
212+
}
213+
}
214+
175215
/**
176216
* Append a slice of source code to the buffer.
177217
*/
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# frozen_string_literal: true
2+
3+
require_relative "../test_helper"
4+
5+
module Prism
6+
class NamedCaptureTest < TestCase
7+
def test_hex_escapes
8+
assert_equal :😀, parse_name("\\xf0\\x9f\\x98\\x80")
9+
end
10+
11+
def test_unicode_escape
12+
assert_equal :し, parse_name("\\u3057")
13+
end
14+
15+
def test_unicode_escapes_bracess
16+
assert_equal :😀, parse_name("\\u{1f600}")
17+
end
18+
19+
def test_octal_escapes
20+
assert_equal :😀, parse_name("\\xf0\\x9f\\x98\\200")
21+
end
22+
23+
private
24+
25+
def parse_name(content)
26+
Prism.parse_statement("/(?<#{content}>)/ =~ ''").targets.first.name
27+
end
28+
end
29+
end

0 commit comments

Comments
 (0)