Skip to content
155 changes: 77 additions & 78 deletions src/main/java/io/github/treesitter/jtreesitter/Query.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,13 @@
import io.github.treesitter.jtreesitter.internal.*;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.SegmentAllocator;
import java.util.*;
import java.util.function.BiPredicate;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.regex.Pattern;
import java.util.regex.PatternSyntaxException;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import org.jspecify.annotations.NullMarked;
import org.jspecify.annotations.Nullable;

Expand All @@ -25,7 +24,7 @@
@NullMarked
public final class Query implements AutoCloseable {
private final MemorySegment query;
private final MemorySegment cursor;
private final QueryCursorConfig cursorConfig = new QueryCursorConfig();
private final Arena arena;
private final Language language;
private final String source;
Expand Down Expand Up @@ -85,7 +84,6 @@ public final class Query implements AutoCloseable {
this.language = language;
this.source = source;
this.query = query.reinterpret(arena, TreeSitter::ts_query_delete);
cursor = ts_query_cursor_new().reinterpret(arena, TreeSitter::ts_query_cursor_delete);

var captureCount = ts_query_capture_count(this.query);
captureNames = new ArrayList<>(captureCount);
Expand Down Expand Up @@ -262,6 +260,10 @@ private static boolean invalidPredicateChar(char c) {
return !(Character.isLetterOrDigit(c) || c == '_' || c == '-' || c == '.' || c == '?' || c == '!');
}

MemorySegment self() {
return query;
}

/** Get the number of patterns in the query. */
public @Unsigned int getPatternCount() {
return ts_query_pattern_count(query);
Expand All @@ -272,25 +274,30 @@ private static boolean invalidPredicateChar(char c) {
return ts_query_capture_count(query);
}

public List<List<QueryPredicate>> getPredicates() {
return predicates.stream().map(Collections::unmodifiableList).toList();
}

public List<String> getCaptureNames() {
return Collections.unmodifiableList(captureNames);
}

/**
* Get the maximum number of in-progress matches.
* Get the maximum number of in-progress matches of the default {@link QueryCursorConfig}
*
* @apiNote Defaults to {@code -1} (unlimited).
*/
public @Unsigned int getMatchLimit() {
return ts_query_cursor_match_limit(cursor);
return cursorConfig.getMatchLimit();
}

/**
* Get the maximum number of in-progress matches.
* Set the maximum number of in-progress matches of the default {@link QueryCursorConfig}
*
* @throws IllegalArgumentException If {@code matchLimit == 0}.
*/
public Query setMatchLimit(@Unsigned int matchLimit) throws IllegalArgumentException {
if (matchLimit == 0) {
throw new IllegalArgumentException("The match limit cannot equal 0");
}
ts_query_cursor_set_match_limit(cursor, matchLimit);
cursorConfig.setMatchLimit(matchLimit);
return this;
}

Expand All @@ -300,19 +307,23 @@ public Query setMatchLimit(@Unsigned int matchLimit) throws IllegalArgumentExcep
*
* @apiNote Defaults to {@code 0} (unlimited).
* @since 0.23.1
* @deprecated
*/
@Deprecated(forRemoval = true)
public @Unsigned long getTimeoutMicros() {
return ts_query_cursor_timeout_micros(cursor);
return cursorConfig.getTimeoutMicros();
}

/**
* Set the maximum duration in microseconds that query
* execution should be allowed to take before halting.
*
* @since 0.23.1
* @deprecated
*/
@Deprecated(forRemoval = true)
public Query setTimeoutMicros(@Unsigned long timeoutMicros) {
ts_query_cursor_set_timeout_micros(cursor, timeoutMicros);
cursorConfig.setTimeoutMicros(timeoutMicros);
return this;
}

Expand All @@ -323,33 +334,22 @@ public Query setTimeoutMicros(@Unsigned long timeoutMicros) {
* <br>Note that if a pattern includes many children, then they will still be checked.
*/
public Query setMaxStartDepth(@Unsigned int maxStartDepth) {
ts_query_cursor_set_max_start_depth(cursor, maxStartDepth);
cursorConfig.setMaxStartDepth(maxStartDepth);
return this;
}

/** Set the range of bytes in which the query will be executed. */
public Query setByteRange(@Unsigned int startByte, @Unsigned int endByte) {
ts_query_cursor_set_byte_range(cursor, startByte, endByte);
cursorConfig.setByteRange(startByte, endByte);
return this;
}

/** Set the range of points in which the query will be executed. */
public Query setPointRange(Point startPoint, Point endPoint) {
try (var alloc = Arena.ofConfined()) {
MemorySegment start = startPoint.into(alloc), end = endPoint.into(alloc);
ts_query_cursor_set_point_range(cursor, start, end);
}
cursorConfig.setPointRange(startPoint, endPoint);
return this;
}

/**
* Check if the query exceeded its maximum number of
* in-progress matches during its last execution.
*/
public boolean didExceedMatchLimit() {
return ts_query_cursor_did_exceed_match_limit(cursor);
}

/**
* Disable a certain pattern within a query.
*
Expand Down Expand Up @@ -478,36 +478,77 @@ public Map<String, Optional<String>> getPatternAssertions(@Unsigned int index, b
}

/**
* Iterate over all the matches in the order that they were found.
* Execute the query on a given node with the default {@link QueryCursorConfig}.
* @param node The node that the query will run on.
* @return A cursor that can be used to iterate over the matches.
*/
public QueryCursor execute(Node node) {
return new QueryCursor(this, node, cursorConfig);
}

/**
* Execute the query on a given node with the given options. The options override the default options set on the query.
* @param node The node that the query will run on.
* @param options The options that will be used for this query.
* @return A cursor that can be used to iterate over the matches.
*/
public QueryCursor execute(Node node, QueryCursorConfig options) {
return new QueryCursor(this, node, options);
}

/**
* Iterate over all the matches in the order that they were found. The lifetime of the native memory of the returned
* matches is bound to the lifetime of this query object.
*
* @param node The node that the query will run on.
* @implNote The stream is not created lazily such that there is no open {@link QueryCursor} instance left behind.
* For creating a lazy stream use {@link #execute(Node)} and {@link QueryCursor#matchStream()}.
*/
public Stream<QueryMatch> findMatches(Node node) {
return findMatches(node, null);
return findMatches(node, arena, null);
}

/**
* Iterate over all the matches in the order that they were found.
* Iterate over all the matches in the order that they were found. The lifetime of the native memory of the returned
* matches is bound to the lifetime of this query object.
*
* <h4 id="findMatches-example">Predicate Example</h4>
* <p>
* {@snippet lang="java" :
* {@snippet lang = "java":
* Stream<QueryMatch> matches = query.findMatches(tree.getRootNode(), (predicate, match) -> {
* if (!predicate.getName().equals("ieq?")) return true;
* List<QueryPredicateArg> args = predicate.getArgs();
* Node node = match.findNodes(args.getFirst().value()).getFirst();
* return args.getLast().value().equalsIgnoreCase(node.getText());
* });
* }
*}
*
* @param node The node that the query will run on.
* @param node The node that the query will run on.
* @param predicate A function that handles custom predicates.
* @implNote The stream is not created lazily such that there is no open {@link QueryCursor} instance left behind.
* For creating a lazy stream use {@link #execute(Node)} and {@link QueryCursor#matchStream(BiPredicate)}.
*/
public Stream<QueryMatch> findMatches(Node node, @Nullable BiPredicate<QueryPredicate, QueryMatch> predicate) {
try (var alloc = Arena.ofConfined()) {
ts_query_cursor_exec(cursor, query, node.copy(alloc));
return findMatches(node, arena, predicate);
}

/**
* Like {@link #findMatches(Node, BiPredicate)} but the native memory of the returned matches is created using the
* given allocator.
*
* @param node The node that the query will run on.
* @param allocator The allocator that is used to allocate the native memory of the returned matches.
* @param predicate A function that handles custom predicates.
* @implNote The stream is not created lazily such that there is no open {@link QueryCursor} instance left behind.
* For creating a lazy stream use {@link #execute(Node)} and {@link QueryCursor#matchStream(SegmentAllocator, BiPredicate)}.
*/
public Stream<QueryMatch> findMatches(
Node node, SegmentAllocator allocator, @Nullable BiPredicate<QueryPredicate, QueryMatch> predicate) {
try (QueryCursor cursor = this.execute(node)) {
// make sure to load the entire stream into memory before closing the cursor.
// Otherwise, we call for nextMatch after closing the cursor which leads to an exception.
return cursor.matchStream(allocator, predicate).toList().stream();
}
return StreamSupport.stream(new MatchesIterator(node.getTree(), predicate), false);
}

@Override
Expand All @@ -520,52 +561,10 @@ public String toString() {
return "Query{language=%s, source=%s}".formatted(language, source);
}

private boolean matches(@Nullable BiPredicate<QueryPredicate, QueryMatch> predicate, QueryMatch match) {
return predicates.get(match.patternIndex()).stream().allMatch(p -> {
if (p.getClass() != QueryPredicate.class) return p.test(match);
return predicate == null || predicate.test(p, match);
});
}

private void checkIndex(@Unsigned int index) throws IndexOutOfBoundsException {
if (Integer.compareUnsigned(index, getPatternCount()) >= 0) {
throw new IndexOutOfBoundsException(
"Pattern index %s is out of bounds".formatted(Integer.toUnsignedString(index)));
}
}

private final class MatchesIterator extends Spliterators.AbstractSpliterator<QueryMatch> {
private final @Nullable BiPredicate<QueryPredicate, QueryMatch> predicate;
private final Tree tree;

public MatchesIterator(Tree tree, @Nullable BiPredicate<QueryPredicate, QueryMatch> predicate) {
super(Long.MAX_VALUE, Spliterator.IMMUTABLE | Spliterator.NONNULL);
this.predicate = predicate;
this.tree = tree;
}

@Override
public boolean tryAdvance(Consumer<? super QueryMatch> action) {
var hasNoText = tree.getText() == null;
MemorySegment match = arena.allocate(TSQueryMatch.layout());
while (ts_query_cursor_next_match(cursor, match)) {
var count = Short.toUnsignedInt(TSQueryMatch.capture_count(match));
var matchCaptures = TSQueryMatch.captures(match);
var captureList = new ArrayList<QueryCapture>(count);
for (int i = 0; i < count; ++i) {
var capture = TSQueryCapture.asSlice(matchCaptures, i);
var name = captureNames.get(TSQueryCapture.index(capture));
var node = TSNode.allocate(arena).copyFrom(TSQueryCapture.node(capture));
captureList.add(new QueryCapture(name, new Node(node, tree)));
}
var patternIndex = TSQueryMatch.pattern_index(match);
var result = new QueryMatch(patternIndex, captureList);
if (hasNoText || matches(predicate, result)) {
action.accept(result);
return true;
}
}
return false;
}
}
}
Loading