diff --git a/binding.gyp b/binding.gyp
index 01eb48a..cb288d4 100644
--- a/binding.gyp
+++ b/binding.gyp
@@ -12,6 +12,7 @@
         "src/language.cc",
         "src/logger.cc",
         "src/lookaheaditerator.cc",
+        "src/query_iterator.cc",
         "src/node.cc",
         "src/parser.cc",
         "src/query.cc",
diff --git a/index.js b/index.js
index b2655b5..bc887d3 100644
--- a/index.js
+++ b/index.js
@@ -1,5 +1,5 @@
 const binding = require('node-gyp-build')(__dirname);
-const {Query, Parser, NodeMethods, Tree, TreeCursor, LookaheadIterator} = binding;
+const {Query, QueryIterator, Parser, NodeMethods, Tree, TreeCursor, LookaheadIterator} = binding;
 
 const util = require('util');
 
@@ -15,10 +15,10 @@ Object.defineProperty(Tree.prototype, 'rootNode', {
       Due to a race condition arising from Jest's worker pool, "this"
       has no knowledge of the native extension if the extension has not
       yet loaded when multiple Jest tests are being run simultaneously.
-      If the extension has correctly loaded, "this" should be an instance 
+      If the extension has correctly loaded, "this" should be an instance
       of the class whose prototype we are acting on (in this case, Tree).
-      Furthermore, the race condition sometimes results in the function in 
-      question being undefined even when the context is correct, so we also 
+      Furthermore, the race condition sometimes results in the function in
+      question being undefined even when the context is correct, so we also
       perform a null function check.
     */
     if (this instanceof Tree && rootNode) {
@@ -420,7 +420,7 @@ Object.defineProperties(TreeCursor.prototype, {
  * Query
  */
 
-const {_matches, _captures} = Query.prototype;
+const {_matches, _matchesIter, _captures, _capturesIter} = Query.prototype;
 
 const PREDICATE_STEP_TYPE = {
   DONE: 0,
@@ -660,6 +660,32 @@ Query.prototype.matches = function(
   return results;
 }
 
+/**
+ * Lazily iterate over the matches of this query within `node`.
+ *
+ * Takes the same options as `Query.prototype.matches`, but returns a
+ * `QueryIterator` that produces one match per `next()` call instead of
+ * collecting every match into an array up front.
+ */
+Query.prototype.matchesIter = function(
+  node,
+  {
+    startPosition = ZERO_POINT,
+    endPosition = ZERO_POINT,
+    startIndex = 0,
+    endIndex = 0,
+    matchLimit = 0xFFFFFFFF,
+    maxStartDepth = 0xFFFFFFFF
+  } = {}
+) {
+  marshalNode(node);
+  return _matchesIter.call(this, node.tree,
+    startPosition.row, startPosition.column,
+    endPosition.row, endPosition.column,
+    startIndex, endIndex, matchLimit, maxStartDepth
+  );
+}
+
 Query.prototype.captures = function(
   node,
   {
@@ -710,6 +736,83 @@ Query.prototype.captures = function(
   return results;
 }
 
+/**
+ * Lazily iterate over the captures of this query within `node`.
+ *
+ * Takes the same options as `Query.prototype.captures`, but returns a
+ * `QueryIterator` that produces one capture per `next()` call instead of
+ * collecting every capture into an array up front.
+ */
+Query.prototype.capturesIter = function(
+  node,
+  {
+    startPosition = ZERO_POINT,
+    endPosition = ZERO_POINT,
+    startIndex = 0,
+    endIndex = 0,
+    matchLimit = 0xFFFFFFFF,
+    maxStartDepth = 0xFFFFFFFF
+  } = {}
+) {
+  marshalNode(node);
+  return _capturesIter.call(this, node.tree,
+    startPosition.row, startPosition.column,
+    endPosition.row, endPosition.column,
+    startIndex, endIndex, matchLimit, maxStartDepth
+  );
+}
+
+/**
+ * QueryIterator
+ */
+
+const {_next: _nextQueryResult} = QueryIterator.prototype;
+
+/**
+ * Advance the iterator, following the JavaScript iterator protocol.
+ *
+ * Each native `_next` call yields a single raw match, encoded as
+ * `[patternIndex, captureIndex?, ...captureNames]` plus its marshalled
+ * nodes (`captureIndex` is only present for capture iterators). Raw
+ * matches rejected by the query's predicates are skipped.
+ */
+QueryIterator.prototype.next = function () {
+  const {query, captures: isCaptureIterator} = this;
+
+  for (;;) {
+    const {done, value} = _nextQueryResult.call(this);
+    if (done) {
+      return {done: true, value: undefined};
+    }
+
+    const [returnedMatch, returnedNodes] = value;
+    const nodes = unmarshalNodes(returnedNodes, this.tree);
+
+    let i = 0;
+    const patternIndex = returnedMatch[i++];
+    const captureIndex = isCaptureIterator ? returnedMatch[i++] : -1;
+    const captures = [];
+    for (let nodeIndex = 0; i < returnedMatch.length; i++, nodeIndex++) {
+      captures.push({name: returnedMatch[i], node: nodes[nodeIndex]});
+    }
+
+    if (!query.predicates[patternIndex].every(p => p(captures))) {
+      continue;
+    }
+
+    const result = isCaptureIterator
+      ? captures[captureIndex]
+      : {pattern: patternIndex, captures};
+    const setProperties = query.setProperties[patternIndex];
+    const assertedProperties = query.assertedProperties[patternIndex];
+    const refutedProperties = query.refutedProperties[patternIndex];
+    if (setProperties) result.setProperties = setProperties;
+    if (assertedProperties) result.assertedProperties = assertedProperties;
+    if (refutedProperties) result.refutedProperties = refutedProperties;
+    return {done: false, value: result};
+  }
+}
+
 /*
  * LookaheadIterator
  */
diff --git a/src/addon_data.h b/src/addon_data.h
index 36ecb29..f67b7f1 100644
--- a/src/addon_data.h
+++ b/src/addon_data.h
@@ -40,6 +40,9 @@ class AddonData final {
 
   // lookaheaditerator
   Napi::FunctionReference lookahead_iterator_constructor;
+
+  // query_iterator
+  Napi::FunctionReference query_iterator_constructor;
 };
 
 } // namespace node_tree_sitter
diff --git a/src/binding.cc b/src/binding.cc
index 0ccee9b..f3e0946 100644
--- a/src/binding.cc
+++ b/src/binding.cc
@@ -5,6 +5,7 @@
 #include "./node.h"
 #include "./parser.h"
 #include "./query.h"
+#include "./query_iterator.h"
 #include "./tree.h"
 #include "./tree_cursor.h"
 
@@ -24,6 +25,7 @@ Napi::Object InitAll(Napi::Env env, Napi::Object exports) {
   LookaheadIterator::Init(env, exports);
   Parser::Init(env, exports);
   Query::Init(env, exports);
+  QueryIterator::Init(env, exports);
   Tree::Init(env, exports);
   TreeCursor::Init(env, exports);
 
diff --git a/src/query.cc b/src/query.cc
index 1c80739..9d8ba3e 100644
--- a/src/query.cc
+++ b/src/query.cc
@@ -39,7 +39,9 @@ void Query::Init(Napi::Env env, Napi::Object exports) {
     InstanceAccessor("matchLimit", &Query::MatchLimit, nullptr, napi_default_method),
 
     InstanceMethod("_matches", &Query::Matches, napi_default_method),
+    InstanceMethod("_matchesIter", &Query::MatchesIter, napi_default_method),
     InstanceMethod("_captures", &Query::Captures, napi_default_method),
+    InstanceMethod("_capturesIter", &Query::CapturesIter, napi_default_method),
     InstanceMethod("_getPredicates", &Query::GetPredicates, napi_default_method),
     InstanceMethod("disableCapture", &Query::DisableCapture, napi_default_method),
     InstanceMethod("disablePattern", &Query::DisablePattern, napi_default_method),
@@ -247,6 +249,26 @@ Napi::Value Query::Matches(const Napi::CallbackInfo &info) {
   return result;
 }
 
+// Constructs a QueryIterator for `_matchesIter` / `_capturesIter`. The native
+// constructor receives (captures, query, ...original call arguments).
+static Napi::Value NewQueryIterator(const Napi::CallbackInfo &info, bool captures) {
+  Napi::Env env = info.Env();
+  auto *data = env.GetInstanceData<AddonData>();
+
+  std::vector<napi_value> args;
+  args.reserve(info.Length() + 2);
+  args.push_back(Boolean::New(env, captures));
+  args.push_back(info.This());
+  for (size_t i = 0; i < info.Length(); i++) {
+    args.push_back(info[i]);
+  }
+  return data->query_iterator_constructor.New(args);
+}
+
+Napi::Value Query::MatchesIter(const Napi::CallbackInfo &info) {
+  return NewQueryIterator(info, false);
+}
+
 Napi::Value Query::Captures(const Napi::CallbackInfo &info) {
   Napi::Env env = info.Env();
   auto *data = env.GetInstanceData<AddonData>();
@@ -337,6 +359,10 @@ Napi::Value Query::Captures(const Napi::CallbackInfo &info) {
   return result;
 }
 
+Napi::Value Query::CapturesIter(const Napi::CallbackInfo &info) {
+  return NewQueryIterator(info, true);
+}
+
 Napi::Value Query::DisableCapture(const Napi::CallbackInfo &info) {
   std::string string = info[0].As<String>().Utf8Value();
   const char *capture_name = string.c_str();
diff --git a/src/query.h b/src/query.h
index 89d7dc3..b3ff26a 100644
--- a/src/query.h
+++ b/src/query.h
@@ -5,7 +5,6 @@
 #include "tree_sitter/api.h"
 
 #include <napi.h>
-#include <unordered_set>
 
 namespace node_tree_sitter {
 
@@ -17,12 +16,19 @@ class Query final : public Napi::ObjectWrap<Query> {
   explicit Query(const Napi::CallbackInfo &info);
   ~Query() final;
 
+  // Read-only access to the underlying TSQuery for other native classes.
+  const TSQuery *Get() const {
+    return query_;
+  }
+
  private:
   TSQuery *query_;
 
   Napi::Value New(const Napi::CallbackInfo &);
   Napi::Value Matches(const Napi::CallbackInfo &);
   Napi::Value Captures(const Napi::CallbackInfo &);
+  Napi::Value MatchesIter(const Napi::CallbackInfo &);
+  Napi::Value CapturesIter(const Napi::CallbackInfo &);
   Napi::Value GetPredicates(const Napi::CallbackInfo &);
   Napi::Value DisableCapture(const Napi::CallbackInfo &);
   Napi::Value DisablePattern(const Napi::CallbackInfo &);
diff --git a/src/query_iterator.cc b/src/query_iterator.cc
new file mode 100644
index 0000000..887ca58
--- /dev/null
+++ b/src/query_iterator.cc
@@ -0,0 +1,157 @@
+#include "query_iterator.h"
+#include "query.h"
+#include "tree.h"
+#include "node.h"
+
+#include <napi.h>
+#include <cstdint>
+#include <vector>
+
+using namespace Napi;
+
+namespace node_tree_sitter {
+
+void QueryIterator::Init(Napi::Env env, Napi::Object exports) {
+  const auto accessor_attributes =
+    static_cast<napi_property_attributes>(napi_enumerable | napi_configurable);
+
+  Function ctor = DefineClass(env, "QueryIterator", {
+    InstanceMethod<&QueryIterator::Iterator>(Symbol::WellKnown(env, "iterator"), napi_default_method),
+    InstanceMethod<&QueryIterator::Next>("_next", napi_default_method),
+    InstanceAccessor<&QueryIterator::GetCaptures>("captures", accessor_attributes),
+    InstanceAccessor<&QueryIterator::GetQuery>("query", accessor_attributes),
+    InstanceAccessor<&QueryIterator::GetTree>("tree", accessor_attributes),
+  });
+
+  auto *data = env.GetInstanceData<AddonData>();
+  data->query_iterator_constructor = Napi::Persistent(ctor);
+  exports["QueryIterator"] = ctor;
+}
+
+// JS-visible arguments: (captures, query, tree, startRow, startColumn, endRow,
+// endColumn, startIndex, endIndex, matchLimit, maxStartDepth). Everything after
+// `tree` is optional.
+QueryIterator::QueryIterator(const Napi::CallbackInfo &info) : Napi::ObjectWrap<QueryIterator>(info) {
+  Napi::Env env = info.Env();
+
+  if (!info[0].IsBoolean()) {
+    throw Error::New(env, "Missing argument captures");
+  }
+  captures_ = info[0].As<Boolean>().Value();
+
+  const Query *query = Query::UnwrapQuery(info[1]);
+  if (query == nullptr) {
+    throw Error::New(env, "Missing argument query");
+  }
+  query_ = Napi::Persistent(info[1]);
+
+  const Tree *tree = Tree::UnwrapTree(info[2]);
+  if (tree == nullptr) {
+    throw Error::New(env, "Missing argument tree");
+  }
+  tree_ = Napi::Persistent(info[2]);
+
+  // Reads an optional numeric argument, falling back when absent or not a number.
+  const auto uint32_arg = [&info](size_t index, uint32_t fallback) -> uint32_t {
+    if (info.Length() > index && info[index].IsNumber()) {
+      return info[index].As<Number>().Uint32Value();
+    }
+    return fallback;
+  };
+
+  // NOTE(review): columns and end_index are doubled but start_index is not.
+  // Presumably the doubling converts UTF-16 code units to bytes; confirm that
+  // leaving start_index unscaled is intended.
+  const uint32_t start_row = uint32_arg(3, 0);
+  const uint32_t start_column = uint32_arg(4, 0) << 1;
+  const uint32_t end_row = uint32_arg(5, 0);
+  const uint32_t end_column = uint32_arg(6, 0) << 1;
+  const uint32_t start_index = uint32_arg(7, 0);
+  const uint32_t end_index = uint32_arg(8, 0) << 1;
+  const uint32_t match_limit = uint32_arg(9, UINT32_MAX);
+  const uint32_t max_start_depth = uint32_arg(10, UINT32_MAX);
+
+  query_cursor_ = ts_query_cursor_new();
+  const TSQuery *ts_query = query->Get();
+  TSNode root_node = node_methods::UnmarshalNode(env, tree);
+  TSPoint start_point = {start_row, start_column};
+  TSPoint end_point = {end_row, end_column};
+  ts_query_cursor_set_point_range(query_cursor_, start_point, end_point);
+  ts_query_cursor_set_byte_range(query_cursor_, start_index, end_index);
+  ts_query_cursor_set_match_limit(query_cursor_, match_limit);
+  ts_query_cursor_set_max_start_depth(query_cursor_, max_start_depth);
+  ts_query_cursor_exec(query_cursor_, ts_query, root_node);
+}
+
+QueryIterator::~QueryIterator() {
+  ts_query_cursor_delete(query_cursor_);
+}
+
+// Implements [Symbol.iterator]: the iterator is its own iterable.
+Napi::Value QueryIterator::Iterator(const Napi::CallbackInfo &info) {
+  return info.This();
+}
+
+// Returns {done: true} when the cursor is exhausted; otherwise
+// {done: false, value: [js_match, js_nodes]} where js_match is
+// [patternIndex, captureIndex (capture mode only), ...captureNames].
+Napi::Value QueryIterator::Next(const Napi::CallbackInfo &info) {
+  Napi::Env env = info.Env();
+  auto result = Object::New(env);
+  const Query *query = Query::UnwrapQuery(query_.Value());
+  const Tree *tree = Tree::UnwrapTree(tree_.Value());
+
+  TSQueryMatch match;
+  uint32_t capture_index = 0;
+  const bool has_next = captures_
+    ? ts_query_cursor_next_capture(query_cursor_, &match, &capture_index)
+    : ts_query_cursor_next_match(query_cursor_, &match);
+  if (!has_next) {
+    result["done"] = true;
+    return result;
+  }
+
+  Array js_match = Array::New(env);
+  unsigned index = 0;
+  js_match[index++] = Number::New(env, match.pattern_index);
+  if (captures_) {
+    js_match[index++] = Number::New(env, capture_index);
+  }
+
+  std::vector<TSNode> nodes;
+  nodes.reserve(match.capture_count);
+  for (uint16_t i = 0; i < match.capture_count; i++) {
+    const TSQueryCapture &capture = match.captures[i];
+
+    uint32_t capture_name_len = 0;
+    const char *capture_name = ts_query_capture_name_for_id(
+      query->Get(), capture.index, &capture_name_len);
+
+    // Pass the reported length rather than relying on NUL termination.
+    js_match[index++] = String::New(env, capture_name, capture_name_len);
+    nodes.push_back(capture.node);
+  }
+
+  auto js_nodes = node_methods::GetMarshalNodes(info, tree, nodes.data(), nodes.size());
+
+  auto value = Array::New(env);
+  value[0U] = js_match;
+  value[1U] = js_nodes;
+  result["done"] = false;
+  result["value"] = value;
+  return result;
+}
+
+Napi::Value QueryIterator::GetCaptures(const Napi::CallbackInfo &info) {
+  return Boolean::New(info.Env(), captures_);
+}
+
+Napi::Value QueryIterator::GetQuery(const Napi::CallbackInfo &info) {
+  return query_.Value();
+}
+
+Napi::Value QueryIterator::GetTree(const Napi::CallbackInfo &info) {
+  return tree_.Value();
+}
+
+} // namespace node_tree_sitter
diff --git a/src/query_iterator.h b/src/query_iterator.h
new file mode 100644
index 0000000..6207d9b
--- /dev/null
+++ b/src/query_iterator.h
@@ -0,0 +1,34 @@
+#ifndef NODE_TREE_SITTER_QUERY_ITERATOR_H_
+#define NODE_TREE_SITTER_QUERY_ITERATOR_H_
+
+#include <napi.h>
+#include "tree_sitter/api.h"
+
+namespace node_tree_sitter {
+
+// JS-facing iterator over a TSQueryCursor. Yields raw matches or captures
+// (selected by `captures_`) one at a time through `_next`.
+class QueryIterator final : public Napi::ObjectWrap<QueryIterator> {
+ public:
+  static void Init(Napi::Env env, Napi::Object exports);
+
+  explicit QueryIterator(const Napi::CallbackInfo &info);
+  ~QueryIterator() final;
+
+ private:
+  bool captures_ = false;
+  // Persistent references to the JS Query and Tree wrappers this iterator uses.
+  Napi::Reference<Napi::Value> query_;
+  Napi::Reference<Napi::Value> tree_;
+  TSQueryCursor *query_cursor_ = nullptr;
+
+  Napi::Value Iterator(const Napi::CallbackInfo &info);
+  Napi::Value Next(const Napi::CallbackInfo &info);
+  Napi::Value GetCaptures(const Napi::CallbackInfo &info);
+  Napi::Value GetQuery(const Napi::CallbackInfo &info);
+  Napi::Value GetTree(const Napi::CallbackInfo &info);
+};
+
+} // namespace node_tree_sitter
+
+#endif // NODE_TREE_SITTER_QUERY_ITERATOR_H_
diff --git a/test/query_test.js b/test/query_test.js
index 4fd2b71..dfe6e80 100644
--- a/test/query_test.js
+++ b/test/query_test.js
@@ -109,6 +109,20 @@ describe("Query", () => {
       ]);
     });
 
+    it("returns all of the matches (iterator) for the given query", () => {
+      const tree = parser.parse("function one() { two(); function three() {} }");
+      const query = new Query(JavaScript, `
+        (function_declaration name: (identifier) @fn-def)
+        (call_expression function: (identifier) @fn-ref)
+      `);
+      const matches = [...query.matchesIter(tree.rootNode)];
+      assert.deepEqual(formatMatches(tree, matches), [
+        { pattern: 0, captures: [{ name: "fn-def", text: "one" }] },
+        { pattern: 1, captures: [{ name: "fn-ref", text: "two" }] },
+        { pattern: 0, captures: [{ name: "fn-def", text: "three" }] },
+      ]);
+    });
+
     it("can search in a specified ranges", () => {
       const tree = parser.parse("[a, b,\nc, d,\ne, f,\ng, h]");
       const query = new Query(JavaScript, "(identifier) @element");
@@ -186,6 +200,47 @@ describe("Query", () => {
       ]);
     });
 
+    it("returns all of the captures (iterator) for the given query, in order", () => {
+      const tree = parser.parse(`
+        a({
+          bc: function de() {
+            const fg = function hi() {}
+          },
+          jk: function lm() {
+            const no = function pq() {}
+          },
+        });
+      `);
+      const query = new Query(JavaScript, `
+        (pair
+          key: _ @method.def
+          (function_expression
+            name: (identifier) @method.alias))
+        (variable_declarator
+          name: _ @function.def
+          value: (function_expression
+            name: (identifier) @function.alias))
+        ":" @delimiter
+        "=" @operator
+      `);
+
+      const captures = [...query.capturesIter(tree.rootNode)];
+      assert.deepEqual(formatCaptures(tree, captures), [
+        { name: "method.def", text: "bc" },
+        { name: "delimiter", text: ":" },
+        { name: "method.alias", text: "de" },
+        { name: "function.def", text: "fg" },
+        { name: "operator", text: "=" },
+        { name: "function.alias", text: "hi" },
+        { name: "method.def", text: "jk" },
+        { name: "delimiter", text: ":" },
+        { name: "method.alias", text: "lm" },
+        { name: "function.def", text: "no" },
+        { name: "operator", text: "=" },
+        { name: "function.alias", text: "pq" },
+      ]);
+    });
+
     it("handles conditions that compare the text of capture to literal strings", () => {
       const tree = parser.parse(`
         const ab = require('./ab');
diff --git a/tree-sitter.d.ts b/tree-sitter.d.ts
index 9631782..6135471 100644
--- a/tree-sitter.d.ts
+++ b/tree-sitter.d.ts
@@ -187,7 +187,9 @@ declare module "tree-sitter" {
     constructor(language: any, source: string | Buffer);
 
     captures(node: SyntaxNode, options?: QueryOptions): QueryCapture[];
+    capturesIter(node: SyntaxNode, options?: QueryOptions): IterableIterator<QueryCapture>;
     matches(node: SyntaxNode, options?: QueryOptions): QueryMatch[];
+    matchesIter(node: SyntaxNode, options?: QueryOptions): IterableIterator<QueryMatch>;
     disableCapture(captureName: string): void;
     disablePattern(patternIndex: number): void;
     isPatternGuaranteedAtStep(byteOffset: number): boolean;