diff --git a/ext/prism/extension.c b/ext/prism/extension.c index e8f678d341..7abd93ec62 100644 --- a/ext/prism/extension.c +++ b/ext/prism/extension.c @@ -24,6 +24,7 @@ VALUE rb_cPrismParseResult; VALUE rb_cPrismLexResult; VALUE rb_cPrismParseLexResult; VALUE rb_cPrismStringQuery; +VALUE rb_cPrismScope; VALUE rb_cPrismDebugEncoding; @@ -38,6 +39,10 @@ ID rb_id_option_partial_script; ID rb_id_option_scopes; ID rb_id_option_version; ID rb_id_source_for; +ID rb_id_forwarding_positionals; +ID rb_id_forwarding_keywords; +ID rb_id_forwarding_block; +ID rb_id_forwarding_all; /******************************************************************************/ /* IO of Ruby code */ @@ -95,14 +100,53 @@ build_options_scopes(pm_options_t *options, VALUE scopes) { for (size_t scope_index = 0; scope_index < scopes_count; scope_index++) { VALUE scope = rb_ary_entry(scopes, scope_index); - // Check that the scope is an array. If it's not, then raise a type - // error. - if (!RB_TYPE_P(scope, T_ARRAY)) { - rb_raise(rb_eTypeError, "wrong argument type %"PRIsVALUE" (expected Array)", rb_obj_class(scope)); + // The scope can be either an array or it can be a Prism::Scope object. + // Parse out the correct values here from either. + VALUE locals; + uint8_t forwarding = PM_OPTIONS_SCOPE_FORWARDING_NONE; + + if (RB_TYPE_P(scope, T_ARRAY)) { + locals = scope; + } else if (rb_obj_is_kind_of(scope, rb_cPrismScope)) { + locals = rb_ivar_get(scope, rb_intern("@locals")); + if (!RB_TYPE_P(locals, T_ARRAY)) { + rb_raise(rb_eTypeError, "wrong argument type %"PRIsVALUE" (expected Array)", rb_obj_class(locals)); + } + + VALUE names = rb_ivar_get(scope, rb_intern("@forwarding")); + if (!RB_TYPE_P(names, T_ARRAY)) { + rb_raise(rb_eTypeError, "wrong argument type %"PRIsVALUE" (expected Array)", rb_obj_class(names)); + } + + size_t names_count = RARRAY_LEN(names); + for (size_t name_index = 0; name_index < names_count; name_index++) { + VALUE name = rb_ary_entry(names, name_index); + + // Check that the name is a symbol. If it's not, then raise + // a type error. + if (!RB_TYPE_P(name, T_SYMBOL)) { + rb_raise(rb_eTypeError, "wrong argument type %"PRIsVALUE" (expected Symbol)", rb_obj_class(name)); + } + + ID id = SYM2ID(name); + if (id == rb_id_forwarding_positionals) { + forwarding |= PM_OPTIONS_SCOPE_FORWARDING_POSITIONALS; + } else if (id == rb_id_forwarding_keywords) { + forwarding |= PM_OPTIONS_SCOPE_FORWARDING_KEYWORDS; + } else if (id == rb_id_forwarding_block) { + forwarding |= PM_OPTIONS_SCOPE_FORWARDING_BLOCK; + } else if (id == rb_id_forwarding_all) { + forwarding |= PM_OPTIONS_SCOPE_FORWARDING_ALL; + } else { + rb_raise(rb_eArgError, "invalid forwarding value: %" PRIsVALUE, name); + } + } + } else { + rb_raise(rb_eTypeError, "wrong argument type %"PRIsVALUE" (expected Array or Prism::Scope)", rb_obj_class(scope)); } // Initialize the scope array. - size_t locals_count = RARRAY_LEN(scope); + size_t locals_count = RARRAY_LEN(locals); pm_options_scope_t *options_scope = &options->scopes[scope_index]; if (!pm_options_scope_init(options_scope, locals_count)) { rb_raise(rb_eNoMemError, "failed to allocate memory"); @@ -110,7 +154,7 @@ build_options_scopes(pm_options_t *options, VALUE scopes) { // Iterate over the locals and add them to the scope. for (size_t local_index = 0; local_index < locals_count; local_index++) { - VALUE local = rb_ary_entry(scope, local_index); + VALUE local = rb_ary_entry(locals, local_index); // Check that the local is a symbol. If it's not, then raise a // type error. @@ -123,6 +167,9 @@ build_options_scopes(pm_options_t *options, VALUE scopes) { const char *name = rb_id2name(SYM2ID(local)); pm_string_constant_init(scope_local, name, strlen(name)); } + + // Now set the forwarding options. + pm_options_scope_forwarding_set(options_scope, forwarding); } } @@ -1302,6 +1349,7 @@ Init_prism(void) { rb_cPrismLexResult = rb_define_class_under(rb_cPrism, "LexResult", rb_cPrismResult); rb_cPrismParseLexResult = rb_define_class_under(rb_cPrism, "ParseLexResult", rb_cPrismResult); rb_cPrismStringQuery = rb_define_class_under(rb_cPrism, "StringQuery", rb_cObject); + rb_cPrismScope = rb_define_class_under(rb_cPrism, "Scope", rb_cObject); // Intern all of the IDs eagerly that we support so that we don't have to do // it every time we parse. @@ -1316,6 +1364,10 @@ Init_prism(void) { rb_id_option_scopes = rb_intern_const("scopes"); rb_id_option_version = rb_intern_const("version"); rb_id_source_for = rb_intern("for"); + rb_id_forwarding_positionals = rb_intern("*"); + rb_id_forwarding_keywords = rb_intern("**"); + rb_id_forwarding_block = rb_intern("&"); + rb_id_forwarding_all = rb_intern("..."); /** * The version of the prism library. diff --git a/include/prism/options.h b/include/prism/options.h index 45eb81caa8..2f64701b0c 100644 --- a/include/prism/options.h +++ b/include/prism/options.h @@ -39,8 +39,26 @@ typedef struct pm_options_scope { /** The names of the locals in the scope. */ pm_string_t *locals; + + /** Flags for the set of forwarding parameters in this scope. */ + uint8_t forwarding; } pm_options_scope_t; +/** The default value for parameters. */ +static const uint8_t PM_OPTIONS_SCOPE_FORWARDING_NONE = 0x0; + +/** When the scope is fowarding with the * parameter. */ +static const uint8_t PM_OPTIONS_SCOPE_FORWARDING_POSITIONALS = 0x1; + +/** When the scope is fowarding with the ** parameter. */ +static const uint8_t PM_OPTIONS_SCOPE_FORWARDING_KEYWORDS = 0x2; + +/** When the scope is fowarding with the & parameter. */ +static const uint8_t PM_OPTIONS_SCOPE_FORWARDING_BLOCK = 0x4; + +/** When the scope is fowarding with the ... parameter. */ +static const uint8_t PM_OPTIONS_SCOPE_FORWARDING_ALL = 0x8; + // Forward declaration needed by the callback typedef. struct pm_options; @@ -337,6 +355,14 @@ PRISM_EXPORTED_FUNCTION bool pm_options_scope_init(pm_options_scope_t *scope, si */ PRISM_EXPORTED_FUNCTION const pm_string_t * pm_options_scope_local_get(const pm_options_scope_t *scope, size_t index); +/** + * Set the forwarding option on the given scope struct. + * + * @param scope The scope struct to set the forwarding on. + * @param forwarding The forwarding value to set. + */ +PRISM_EXPORTED_FUNCTION void pm_options_scope_forwarding_set(pm_options_scope_t *scope, uint8_t forwarding); + /** * Free the internal memory associated with the options. * @@ -386,6 +412,7 @@ PRISM_EXPORTED_FUNCTION void pm_options_free(pm_options_t *options); * | # bytes | field | * | ------- | -------------------------- | * | `4` | the number of locals | + * | `1` | the forwarding flags | * | ... | the locals | * * Each local is laid out as follows: diff --git a/java/org/prism/ParsingOptions.java b/java/org/prism/ParsingOptions.java index 0fc9d03e47..ff2e995e6b 100644 --- a/java/org/prism/ParsingOptions.java +++ b/java/org/prism/ParsingOptions.java @@ -35,6 +35,69 @@ public byte getValue() { */ public enum CommandLine { A, E, L, N, P, X }; + /** + * The forwarding options for a given scope in the parser. + */ + public enum Forwarding { + NONE(0), + POSITIONAL(1), + KEYWORD(2), + BLOCK(4), + ALL(8); + + private final int value; + + Forwarding(int value) { + this.value = value; + } + + public byte getValue() { + return (byte) value; + } + }; + + /** + * Represents a scope in the parser. + */ + public static class Scope { + private byte[][] locals; + private Forwarding[] forwarding; + + Scope(byte[][] locals) { + this(locals, new Forwarding[0]); + } + + Scope(Forwarding[] forwarding) { + this(new byte[0][], forwarding); + } + + Scope(byte[][] locals, Forwarding[] forwarding) { + this.locals = locals; + this.forwarding = forwarding; + } + + public byte[][] getLocals() { + return locals; + } + + public int getForwarding() { + int value = 0; + for (Forwarding f : forwarding) { + value |= f.getValue(); + } + return value; + } + } + + public static byte[] serialize(byte[] filepath, int line, byte[] encoding, boolean frozenStringLiteral, EnumSet commandLine, SyntaxVersion version, boolean encodingLocked, boolean mainScript, boolean partialScript, byte[][][] scopes) { + Scope[] normalizedScopes = new Scope[scopes.length]; + for (int i = 0; i < scopes.length; i++) { + normalizedScopes[i] = new Scope(scopes[i]); + } + + return serialize(filepath, line, encoding, frozenStringLiteral, commandLine, version, encodingLocked, mainScript, partialScript, normalizedScopes); + } + /** * Serialize parsing options into byte array. * @@ -50,7 +113,7 @@ public enum CommandLine { A, E, L, N, P, X }; * @param scopes scopes surrounding the code that is being parsed with local variable names defined in every scope * ordered from the outermost scope to the innermost one */ - public static byte[] serialize(byte[] filepath, int line, byte[] encoding, boolean frozenStringLiteral, EnumSet commandLine, SyntaxVersion version, boolean encodingLocked, boolean mainScript, boolean partialScript, byte[][][] scopes) { + public static byte[] serialize(byte[] filepath, int line, byte[] encoding, boolean frozenStringLiteral, EnumSet commandLine, SyntaxVersion version, boolean encodingLocked, boolean mainScript, boolean partialScript, Scope[] scopes) { final ByteArrayOutputStream output = new ByteArrayOutputStream(); // filepath @@ -91,12 +154,17 @@ public static byte[] serialize(byte[] filepath, int line, byte[] encoding, boole write(output, serializeInt(scopes.length)); // local variables in each scope - for (byte[][] scope : scopes) { + for (Scope scope : scopes) { + byte[][] locals = scope.getLocals(); + // number of locals - write(output, serializeInt(scope.length)); + write(output, serializeInt(locals.length)); + + // forwarding flags + output.write(scope.getForwarding()); // locals - for (byte[] local : scope) { + for (byte[] local : locals) { write(output, serializeInt(local.length)); write(output, local); } diff --git a/javascript/src/index.js b/javascript/src/index.js index 8c361b5416..3c40c4c16b 100644 --- a/javascript/src/index.js +++ b/javascript/src/index.js @@ -11,7 +11,9 @@ export * from "./nodes.js"; /** * Load the prism wasm module and return a parse function. * - * @returns {Promise<(source: string) => ParseResult>} + * @typedef {import("./parsePrism.js").Options} Options + * + * @returns {Promise<(source: string, options?: Options) => ParseResult>} */ export async function loadPrism() { const wasm = await WebAssembly.compile(await readFile(fileURLToPath(new URL("prism.wasm", import.meta.url)))); @@ -20,7 +22,7 @@ export async function loadPrism() { const instance = await WebAssembly.instantiate(wasm, wasi.getImportObject()); wasi.initialize(instance); - return function (source) { - return parsePrism(instance.exports, source); + return function (source, options = {}) { + return parsePrism(instance.exports, source, options); } } diff --git a/javascript/src/parsePrism.js b/javascript/src/parsePrism.js index 1d0233e9c9..535e6e9a0f 100644 --- a/javascript/src/parsePrism.js +++ b/javascript/src/parsePrism.js @@ -3,9 +3,26 @@ import { ParseResult, deserialize } from "./deserialize.js"; /** * Parse the given source code. * + * @typedef {{ + * locals?: string[], + * forwarding?: string[] + * }} Scope + * + * @typedef {{ + * filepath?: string, + * line?: number, + * encoding?: string | false, + * frozen_string_literal?: boolean, + * command_line?: string, + * version?: string, + * main_script?: boolean, + * partial_script?: boolean, + * scopes?: (string[] | Scope)[] + * }} Options + * * @param {WebAssembly.Exports} prism * @param {string} source - * @param {Object} options + * @param {Options} options * @returns {ParseResult} */ export function parsePrism(prism, source, options = {}) { @@ -35,7 +52,12 @@ export function parsePrism(prism, source, options = {}) { return result; } -// Dump the command line options into a serialized format. +/** + * Dump the command line options into a serialized format. + * + * @param {Options} options + * @returns {number} + */ function dumpCommandLineOptions(options) { if (options.command_line === undefined) { return 0; @@ -62,15 +84,30 @@ function dumpCommandLineOptions(options) { return value; } -// Convert a boolean value into a serialized byte. +/** + * Convert a boolean value into a serialized byte. + * + * @param {boolean} value + * @returns {number} + */ function dumpBooleanOption(value) { return (value === undefined || value === false || value === null) ? 0 : 1; } -// Converts the given options into a serialized options string. +/** + * Converts the given options into a serialized options string. + * + * @param {Options} options + * @returns {Uint8Array} + */ function dumpOptions(options) { - const values = []; + /** @type {PackTemplate} */ const template = []; + + /** @type {PackValues} */ + const values = []; + + /** @type {TextEncoder} */ const encoder = new TextEncoder(); template.push("L") @@ -131,16 +168,37 @@ function dumpOptions(options) { values.push(scopes.length); for (const scope of scopes) { + let locals; + let forwarding = 0; + + if (Array.isArray(scope)) { + locals = scope; + } else { + locals = scope.locals || []; + + for (const forward of (scope.forwarding || [])) { + switch (forward) { + case "*": forwarding |= 0x1; break; + case "**": forwarding |= 0x2; break; + case "&": forwarding |= 0x4; break; + case "...": forwarding |= 0x8; break; + default: throw new Error(`invalid forwarding value: ${forward}`); + } + } + } + template.push("L"); - values.push(scope.length); + values.push(locals.length); + + template.push("C"); + values.push(forwarding); - for (const local of scope) { - const name = local.name; + for (const local of locals) { template.push("L"); - values.push(name.length); + values.push(local.length); template.push("A") - values.push(encoder.encode(name)); + values.push(encoder.encode(local)); } } } else { @@ -150,40 +208,17 @@ function dumpOptions(options) { return pack(values, template); } -function totalSizeOf(values, template) { - let size = 0; - - for (let i = 0; i < values.length; i ++) { - size += sizeOf(values, template, i); - } - - return size; -} - -function sizeOf(values, template, index) { - switch (template[index]) { - // arbitrary binary string - case "A": - return values[index].length; - - // l: signed 32-bit integer, L: unsigned 32-bit integer - case "l": - case "L": - return 4; - - // 8-bit unsigned integer - case "C": - return 1; - } -} - -// platform-agnostic implementation of Node's os.endianness() -function endianness() { - const arr = new Uint8Array(4); - new Uint32Array(arr.buffer)[0] = 0xffcc0011; - return arr[0] === 0xff ? "BE" : "LE"; -} - +/** + * Pack the given values using the given template. This function matches a + * subset of the functionality from Ruby's Array#pack method. + * + * @typedef {(number | string)[]} PackValues + * @typedef {string[]} PackTemplate + * + * @param {PackValues} values + * @param {PackTemplate} template + * @returns {Uint8Array} + */ function pack(values, template) { const littleEndian = endianness() === "LE"; const buffer = new ArrayBuffer(totalSizeOf(values, template)); @@ -217,3 +252,57 @@ function pack(values, template) { return new Uint8Array(buffer); } + +/** + * Returns the total size of the given values in bytes. + * + * @param {PackValues} values + * @param {PackTemplate} template + * @returns {number} + */ +function totalSizeOf(values, template) { + let size = 0; + + for (let i = 0; i < values.length; i ++) { + size += sizeOf(values, template, i); + } + + return size; +} + +/** + * Returns the size of the given value inside the list of values at the + * specified index in bytes. + * + * @param {PackValues} values + * @param {PackTemplate} template + * @param {number} index + * @returns {number} + */ +function sizeOf(values, template, index) { + switch (template[index]) { + // arbitrary binary string + case "A": + return values[index].length; + + // l: signed 32-bit integer, L: unsigned 32-bit integer + case "l": + case "L": + return 4; + + // 8-bit unsigned integer + case "C": + return 1; + } +} + +/** + * Platform-agnostic implementation of Node's os.endianness(). + * + * @returns {"BE" | "LE"} + */ +function endianness() { + const arr = new Uint8Array(4); + new Uint32Array(arr.buffer)[0] = 0xffcc0011; + return arr[0] === 0xff ? "BE" : "LE"; +} diff --git a/javascript/test.js b/javascript/test.js index 58801708cc..ce0e3d9033 100644 --- a/javascript/test.js +++ b/javascript/test.js @@ -5,6 +5,10 @@ import * as nodes from "./src/nodes.js"; const parse = await loadPrism(); +function statement(result) { + return result.value.statements.body[0]; +} + test("node", () => { const result = parse("foo"); assert(result.value instanceof nodes.ProgramNode); @@ -12,12 +16,12 @@ test("node", () => { test("node? present", () => { const result = parse("foo.bar"); - assert(result.value.statements.body[0].receiver instanceof nodes.CallNode); + assert(statement(result).receiver instanceof nodes.CallNode); }); test("node? absent", () => { const result = parse("foo"); - assert(result.value.statements.body[0].receiver === null); + assert(statement(result).receiver === null); }); test("node[]", () => { @@ -27,7 +31,7 @@ test("node[]", () => { test("string", () => { const result = parse('"foo"'); - const node = result.value.statements.body[0]; + const node = statement(result); assert(!node.isForcedUtf8Encoding()) assert(!node.isForcedBinaryEncoding()) @@ -39,7 +43,7 @@ test("string", () => { test("forced utf-8 string using \\u syntax", () => { const result = parse('# encoding: utf-8\n"\\u{9E7F}"'); - const node = result.value.statements.body[0]; + const node = statement(result); const str = node.unescaped; assert(node.isForcedUtf8Encoding()); @@ -52,7 +56,7 @@ test("forced utf-8 string using \\u syntax", () => { test("forced utf-8 string with invalid byte sequence", () => { const result = parse('# encoding: utf-8\n"\\xFF\\xFF\\xFF"'); - const node = result.value.statements.body[0]; + const node = statement(result); const str = node.unescaped; assert(node.isForcedUtf8Encoding()); @@ -68,7 +72,7 @@ test("ascii string with embedded utf-8 character", () => { // # encoding: ascii\n"鹿"' const ascii_str = new Buffer.from([35, 32, 101, 110, 99, 111, 100, 105, 110, 103, 58, 32, 97, 115, 99, 105, 105, 10, 34, 233, 185, 191, 34]); const result = parse(ascii_str); - const node = result.value.statements.body[0]; + const node = statement(result); const str = node.unescaped; assert(!node.isForcedUtf8Encoding()); @@ -81,7 +85,7 @@ test("ascii string with embedded utf-8 character", () => { test("forced binary string", () => { const result = parse('# encoding: ascii\n"\\xFF\\xFF\\xFF"'); - const node = result.value.statements.body[0]; + const node = statement(result); const str = node.unescaped; assert(!node.isForcedUtf8Encoding()); @@ -96,7 +100,7 @@ test("forced binary string with Unicode character", () => { // # encoding: us-ascii\n"\\xFF鹿\\xFF" const ascii_str = Buffer.from([35, 32, 101, 110, 99, 111, 100, 105, 110, 103, 58, 32, 97, 115, 99, 105, 105, 10, 34, 92, 120, 70, 70, 233, 185, 191, 92, 120, 70, 70, 34]); const result = parse(ascii_str); - const node = result.value.statements.body[0]; + const node = statement(result); const str = node.unescaped; assert(!node.isForcedUtf8Encoding()); @@ -114,12 +118,12 @@ test("constant", () => { test("constant? present", () => { const result = parse("def foo(*bar); end"); - assert(result.value.statements.body[0].parameters.rest.name === "bar"); + assert(statement(result).parameters.rest.name === "bar"); }); test("constant? absent", () => { const result = parse("def foo(*); end"); - assert(result.value.statements.body[0].parameters.rest.name === null); + assert(statement(result).parameters.rest.name === null); }); test("constant[]", async() => { @@ -134,27 +138,27 @@ test("location", () => { test("location? present", () => { const result = parse("def foo = bar"); - assert(result.value.statements.body[0].equalLoc !== null); + assert(statement(result).equalLoc !== null); }); test("location? absent", () => { const result = parse("def foo; bar; end"); - assert(result.value.statements.body[0].equalLoc === null); + assert(statement(result).equalLoc === null); }); test("uint8", () => { const result = parse("-> { _3 }"); - assert(result.value.statements.body[0].parameters.maximum === 3); + assert(statement(result).parameters.maximum === 3); }); test("uint32", () => { const result = parse("foo = 1"); - assert(result.value.statements.body[0].depth === 0); + assert(statement(result).depth === 0); }); test("flags", () => { const result = parse("/foo/mi"); - const regexp = result.value.statements.body[0]; + const regexp = statement(result); assert(regexp.isIgnoreCase()); assert(regexp.isMultiLine()); @@ -163,25 +167,44 @@ test("flags", () => { test("integer (decimal)", () => { const result = parse("10"); - assert(result.value.statements.body[0].value == 10); + assert(statement(result).value === 10); }); test("integer (hex)", () => { const result = parse("0xA"); - assert(result.value.statements.body[0].value == 10); + assert(statement(result).value === 10); }); test("integer (2 nodes)", () => { const result = parse("4294967296"); - assert(result.value.statements.body[0].value == 4294967296n); + assert(statement(result).value === 4294967296n); }); test("integer (3 nodes)", () => { const result = parse("18446744073709552000"); - assert(result.value.statements.body[0].value == 18446744073709552000n); + assert(statement(result).value === 18446744073709552000n); }); test("double", () => { const result = parse("1.0"); - assert(result.value.statements.body[0].value == 1.0); + assert(statement(result).value === 1.0); +}); + +test("scopes", () => { + let result; + + result = parse("foo"); + assert(statement(result) instanceof nodes.CallNode); + + result = parse("foo", { scopes: [["foo"]] }); + assert(statement(result) instanceof nodes.LocalVariableReadNode); + + result = parse("foo", { scopes: [{ locals: ["foo"] }] }); + assert(statement(result) instanceof nodes.LocalVariableReadNode); + + result = parse("foo(*)"); + assert(result.errors.length > 0); + + result = parse("foo(*)", { scopes: [{ forwarding: ["*"] }] }); + assert(result.errors.length === 0); }); diff --git a/lib/prism/ffi.rb b/lib/prism/ffi.rb index eda61b3ead..35b91e41b2 100644 --- a/lib/prism/ffi.rb +++ b/lib/prism/ffi.rb @@ -478,10 +478,35 @@ def dump_options(options) values << scopes.length scopes.each do |scope| + locals = nil + forwarding = 0 + + case scope + when Array + locals = scope + when Scope + locals = scope.locals + + scope.forwarding.each do |forward| + case forward + when :* then forwarding |= 0x1 + when :** then forwarding |= 0x2 + when :& then forwarding |= 0x4 + when :"..." then forwarding |= 0x8 + else raise ArgumentError, "invalid forwarding value: #{forward}" + end + end + else + raise TypeError, "wrong argument type #{scope.class.inspect} (expected Array or Prism::Scope)" + end + template << "L" - values << scope.length + values << locals.length + + template << "C" + values << forwarding - scope.each do |local| + locals.each do |local| name = local.name template << "L" values << name.bytesize diff --git a/lib/prism/parse_result.rb b/lib/prism/parse_result.rb index e76ea7e17e..9a3e7c5b79 100644 --- a/lib/prism/parse_result.rb +++ b/lib/prism/parse_result.rb @@ -879,4 +879,32 @@ def deep_freeze freeze end end + + # This object is passed to the various Prism.* methods that accept the + # `scopes` option as an element of the list. It defines both the local + # variables visible at that scope as well as the forwarding parameters + # available at that scope. + class Scope + # The list of local variables that are defined in this scope. This should be + # defined as an array of symbols. + attr_reader :locals + + # The list of local variables that are forwarded to the next scope. This + # should by defined as an array of symbols containing the specific values of + # :*, :**, :&, or :"...". + attr_reader :forwarding + + # Create a new scope object with the given locals and forwarding. + def initialize(locals, forwarding) + @locals = locals + @forwarding = forwarding + end + end + + # Create a new scope with the given locals and forwarding options that is + # suitable for passing into one of the Prism.* methods that accepts the + # `scopes` option. + def self.scope(locals: [], forwarding: []) + Scope.new(locals, forwarding) + end end diff --git a/rbi/prism.rbi b/rbi/prism.rbi index 1ba5d1fc3d..8866e7b3f2 100644 --- a/rbi/prism.rbi +++ b/rbi/prism.rbi @@ -60,4 +60,7 @@ module Prism sig { params(filepath: String, command_line: T.nilable(String), encoding: T.nilable(T.any(FalseClass, Encoding)), freeze: T.nilable(T::Boolean), frozen_string_literal: T.nilable(T::Boolean), line: T.nilable(Integer), main_script: T.nilable(T::Boolean), partial_script: T.nilable(T::Boolean), scopes: T.nilable(T::Array[T::Array[Symbol]]), version: T.nilable(String)).returns(T::Boolean) } def self.parse_file_failure?(filepath, command_line: nil, encoding: nil, freeze: nil, frozen_string_literal: nil, line: nil, main_script: nil, partial_script: nil, scopes: nil, version: nil); end + + sig { params(locals: T::Array[Symbol], forwarding: T::Array[Symbol]).returns(Prism::Scope) } + def self.scope(locals: [], forwarding: []); end end diff --git a/rbi/prism/parse_result.rbi b/rbi/prism/parse_result.rbi index 8d52ed3daf..6f2bbb6146 100644 --- a/rbi/prism/parse_result.rbi +++ b/rbi/prism/parse_result.rbi @@ -391,3 +391,14 @@ class Prism::Token sig { params(other: T.untyped).returns(T::Boolean) } def ==(other); end end + +class Prism::Scope + sig { returns(T::Array[Symbol]) } + def locals; end + + sig { returns(T::Array[Symbol]) } + def forwarding; end + + sig { params(locals: T::Array[Symbol], forwarding: T::Array[Symbol]).void } + def initialize(locals, forwarding); end +end diff --git a/sig/prism/parse_result.rbs b/sig/prism/parse_result.rbs index 164421114f..8dea016a68 100644 --- a/sig/prism/parse_result.rbs +++ b/sig/prism/parse_result.rbs @@ -182,4 +182,11 @@ module Prism def pretty_print: (untyped q) -> untyped def ==: (untyped other) -> bool end + + class Scope + attr_reader locals: Array[Symbol] + attr_reader forwarding: Array[Symbol] + + def initialize: (Array[Symbol] locals, Array[Symbol] forwarding) -> void + end end diff --git a/src/options.c b/src/options.c index b5be140820..a457178ce8 100644 --- a/src/options.c +++ b/src/options.c @@ -181,6 +181,7 @@ PRISM_EXPORTED_FUNCTION bool pm_options_scope_init(pm_options_scope_t *scope, size_t locals_count) { scope->locals_count = locals_count; scope->locals = xcalloc(locals_count, sizeof(pm_string_t)); + scope->forwarding = PM_OPTIONS_SCOPE_FORWARDING_NONE; return scope->locals != NULL; } @@ -192,6 +193,14 @@ pm_options_scope_local_get(const pm_options_scope_t *scope, size_t index) { return &scope->locals[index]; } +/** + * Set the forwarding option on the given scope struct. + */ +PRISM_EXPORTED_FUNCTION void +pm_options_scope_forwarding_set(pm_options_scope_t *scope, uint8_t forwarding) { + scope->forwarding = forwarding; +} + /** * Free the internal memory associated with the options. */ @@ -300,6 +309,9 @@ pm_options_read(pm_options_t *options, const char *data) { return; } + uint8_t forwarding = (uint8_t) *data++; + pm_options_scope_forwarding_set(&options->scopes[scope_index], forwarding); + for (size_t local_index = 0; local_index < locals_count; local_index++) { uint32_t local_length = pm_options_read_u32(data); data += 4; diff --git a/src/prism.c b/src/prism.c index e2654be228..3cfcdd8be5 100644 --- a/src/prism.c +++ b/src/prism.c @@ -22492,7 +22492,7 @@ pm_parser_init(pm_parser_t *parser, const uint8_t *source, size_t size, const pm // Scopes given from the outside are not allowed to have numbered // parameters. - parser->current_scope->parameters |= PM_SCOPE_PARAMETERS_IMPLICIT_DISALLOWED; + parser->current_scope->parameters = ((pm_scope_parameters_t) scope->forwarding) | PM_SCOPE_PARAMETERS_IMPLICIT_DISALLOWED; for (size_t local_index = 0; local_index < scope->locals_count; local_index++) { const pm_string_t *local = pm_options_scope_local_get(scope, local_index); diff --git a/templates/sig/prism.rbs.erb b/templates/sig/prism.rbs.erb index 96cadec9dd..2f30cbc29f 100644 --- a/templates/sig/prism.rbs.erb +++ b/templates/sig/prism.rbs.erb @@ -84,4 +84,6 @@ module Prism ?scopes: Array[Array[Symbol]], ?verbose: bool ) -> ParseResult + + def self.scope: (?locals: Array[Symbol], ?forwarding: Array[Symbol]) -> Scope end diff --git a/test/prism/api/parse_test.rb b/test/prism/api/parse_test.rb index 55b2731225..bbce8a8fad 100644 --- a/test/prism/api/parse_test.rb +++ b/test/prism/api/parse_test.rb @@ -140,6 +140,24 @@ def test_version end end + def test_scopes + assert_kind_of Prism::CallNode, Prism.parse_statement("foo") + assert_kind_of Prism::LocalVariableReadNode, Prism.parse_statement("foo", scopes: [[:foo]]) + assert_kind_of Prism::LocalVariableReadNode, Prism.parse_statement("foo", scopes: [Prism.scope(locals: [:foo])]) + + assert Prism.parse_failure?("foo(*)") + assert Prism.parse_success?("foo(*)", scopes: [Prism.scope(forwarding: [:*])]) + + assert Prism.parse_failure?("foo(**)") + assert Prism.parse_success?("foo(**)", scopes: [Prism.scope(forwarding: [:**])]) + + assert Prism.parse_failure?("foo(&)") + assert Prism.parse_success?("foo(&)", scopes: [Prism.scope(forwarding: [:&])]) + + assert Prism.parse_failure?("foo(...)") + assert Prism.parse_success?("foo(...)", scopes: [Prism.scope(forwarding: [:"..."])]) + end + private def find_source_file_node(program)