diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 2d78e24..0557e37 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -16,9 +16,6 @@ jobs: matrix: zig-version: [master] os: [ubuntu-22.04, macos-latest, windows-latest] - include: - - zig-version: "0.14.0" - os: ubuntu-22.04 runs-on: ${{ matrix.os }} steps: - name: Checkout diff --git a/README.md b/README.md index f0e5814..d865a68 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ Provides the necessary building blocks to develop Language Server Protocol imple ## Installation > [!NOTE] -> The minimum supported Zig version is `0.14.0`. +> The default branch requires Zig `0.15.0-dev.1018+1a998886c` or later. Checkout the `0.14.x` branch when using Zig 0.14 ```bash # Initialize a `zig build` project if you haven't already diff --git a/build.zig b/build.zig index 378b41c..22b4c4b 100644 --- a/build.zig +++ b/build.zig @@ -1,17 +1,17 @@ const std = @import("std"); const builtin = @import("builtin"); -const minimum_zig_version = std.SemanticVersion.parse("0.14.0") catch unreachable; +const minimum_zig_version = "0.15.0-dev.1018+1a998886c"; pub fn build(b: *std.Build) void { - comptime if (builtin.zig_version.order(minimum_zig_version) == .lt) { + comptime if (builtin.zig_version.order(std.SemanticVersion.parse("0.15.0-dev.1018+1a998886c") catch unreachable) == .lt) { @compileError(std.fmt.comptimePrint( \\Your Zig version does not meet the minimum build requirement: - \\ required Zig version: {[minimum_zig_version]} - \\ actual Zig version: {[current_version]} + \\ required Zig version: {[minimum_zig_version]s} + \\ actual Zig version: {[current_version]s} \\ , .{ - .current_version = builtin.zig_version, + .current_version = builtin.zig_version_string, .minimum_zig_version = minimum_zig_version, })); }; diff --git a/build.zig.zon b/build.zig.zon index 46fc999..69d0790 100644 --- a/build.zig.zon +++ b/build.zig.zon @@ -1,7 +1,7 @@ .{ .name = .lsp_codegen, .version = "0.1.0", - .minimum_zig_version = "0.14.0", + .minimum_zig_version = "0.15.0-dev.1023+f551c7c58", .dependencies = .{}, .paths = .{ "build.zig", diff --git a/examples/hello_client.zig b/examples/hello_client.zig index e39e90a..8f9b945 100644 --- a/examples/hello_client.zig +++ b/examples/hello_client.zig @@ -78,8 +78,10 @@ pub fn main() !void { // Language servers can support multiple communication channels (e.g. stdio, pipes, sockets). // See https://microsoft.github.io/language-server-protocol/specifications/specification-current/#implementationConsiderations // - // The `TransportOverStdio` implements the necessary logic to read and write messages over stdio. - var transport: lsp.TransportOverStdio = .init(child_process.stdout.?, child_process.stdin.?); + // The `lsp.Transport.Stdio` implements the necessary logic to read and write messages over stdio. + var read_buffer: [256]u8 = undefined; + var stdio_transport: lsp.Transport.Stdio = .init(&read_buffer, child_process.stdout.?, child_process.stdin.?); + const transport: *lsp.Transport = &stdio_transport.transport; // The order of exchanged messages will look similar to this: // @@ -90,7 +92,7 @@ pub fn main() !void { // 5. send `exit` notification std.log.debug("sending 'initialize' request to server", .{}); - try transport.any().writeRequest( + try transport.writeRequest( gpa, .{ .number = 0 }, "initialize", // https://microsoft.github.io/language-server-protocol/specifications/specification-current/#initialize @@ -102,7 +104,7 @@ pub fn main() !void { // Wait for the response from the server // For the sake of simplicity, we will block here and read messages until the response to our request has been found. All other messages will be ignored. // A more sophisticated client implementation will need to handle messages asynchronously. - const initialize_response = try readAndIgnoreUntilResponse(gpa, transport.any(), .{ .number = 0 }, "initialize"); + const initialize_response = try readAndIgnoreUntilResponse(gpa, transport, .{ .number = 0 }, "initialize"); defer initialize_response.deinit(); const initialize_result: lsp.types.InitializeResult = initialize_response.value; @@ -126,7 +128,7 @@ pub fn main() !void { } std.log.debug("sending 'initialized' notification to server", .{}); - try transport.any().writeNotification( + try transport.writeNotification( gpa, "initialized", // https://microsoft.github.io/language-server-protocol/specifications/specification-current/#initialized lsp.types.InitializedParams, @@ -138,7 +140,7 @@ pub fn main() !void { std.log.info("This document recently came in by the CLI.", .{}); std.log.debug("sending 'textDocument/didOpen' notification to server", .{}); - try transport.any().writeNotification( + try transport.writeNotification( gpa, "textDocument/didOpen", // https://microsoft.github.io/language-server-protocol/specifications/specification-current/#textDocument_didOpen lsp.types.DidOpenTextDocumentParams, @@ -155,7 +157,7 @@ pub fn main() !void { std.log.info("Just to double check, could you verify that it is formatted correctly?", .{}); std.log.debug("sending 'textDocument/formatting' request to server", .{}); - try transport.any().writeRequest( + try transport.writeRequest( gpa, .{ .number = 1 }, "textDocument/formatting", // https://microsoft.github.io/language-server-protocol/specifications/specification-current/#textDocument_formatting @@ -167,7 +169,7 @@ pub fn main() !void { .{ .emit_null_optional_fields = false }, ); - const formatting_response = try readAndIgnoreUntilResponse(gpa, transport.any(), .{ .number = 1 }, "textDocument/formatting"); + const formatting_response = try readAndIgnoreUntilResponse(gpa, transport, .{ .number = 1 }, "textDocument/formatting"); defer formatting_response.deinit(); const text_edits = formatting_response.value orelse &.{}; @@ -183,7 +185,7 @@ pub fn main() !void { // Even though this is a request, we do not wait for a response because we are going to close the server anyway. std.log.debug("sending 'shutdown' request to server", .{}); - try transport.any().writeRequest( + try transport.writeRequest( gpa, .{ .number = 2 }, "shutdown", // https://microsoft.github.io/language-server-protocol/specifications/specification-current/#shutdown @@ -193,7 +195,7 @@ pub fn main() !void { ); std.log.debug("sending 'exit' notification to server", .{}); - try transport.any().writeNotification( + try transport.writeNotification( gpa, "exit", // https://microsoft.github.io/language-server-protocol/specifications/specification-current/#exit void, @@ -206,7 +208,11 @@ pub fn main() !void { } fn fatalWithUsage(comptime format: []const u8, args: anytype) noreturn { - std.io.getStdErr().writeAll(usage) catch {}; + { + const stderr = std.debug.lockStderrWriter(&.{}); + defer std.debug.unlockStderrWriter(); + stderr.writeAll(usage) catch {}; + } std.log.err(format, args); std.process.exit(1); } @@ -219,7 +225,7 @@ fn fatal(comptime format: []const u8, args: anytype) noreturn { /// Do not use such a function in an actual implementation. fn readAndIgnoreUntilResponse( allocator: std.mem.Allocator, - transport: lsp.AnyTransport, + transport: *lsp.Transport, id: lsp.JsonRPCMessage.ID, comptime method: []const u8, ) !std.json.Parsed(lsp.ResultType(method)) { diff --git a/examples/hello_server.zig b/examples/hello_server.zig index b20c373..142f491 100644 --- a/examples/hello_server.zig +++ b/examples/hello_server.zig @@ -44,8 +44,10 @@ pub fn main() !void { // Language servers can support multiple communication channels (e.g. stdio, pipes, sockets). // See https://microsoft.github.io/language-server-protocol/specifications/specification-current/#implementationConsiderations // - // The `TransportOverStdio` implements the necessary logic to read and write messages over stdio. - var transport: lsp.TransportOverStdio = .init(std.io.getStdIn(), std.io.getStdOut()); + // The `lsp.Transport.Stdio` implements the necessary logic to read and write messages over stdio. + var read_buffer: [256]u8 = undefined; + var stdio_transport: lsp.Transport.Stdio = .init(&read_buffer, .stdin(), .stdout()); + const transport: *lsp.Transport = &stdio_transport.transport; // keep track of opened documents var documents: std.StringArrayHashMapUnmanaged([]const u8) = .empty; @@ -95,7 +97,7 @@ pub fn main() !void { .request => |request| switch (request.params) { .initialize => |params| { _ = params.capabilities; // the client capabilities tell the server what "features" the client supports - try transport.any().writeResponse( + try transport.writeResponse( gpa, request.id, lsp.types.InitializeResult, @@ -111,11 +113,11 @@ pub fn main() !void { .{ .emit_null_optional_fields = false }, ); }, - .shutdown => try transport.any().writeResponse(gpa, request.id, void, {}, .{}), + .shutdown => try transport.writeResponse(gpa, request.id, void, {}, .{}), .@"textDocument/formatting" => |params| { const source = documents.get(params.textDocument.uri) orelse { // We should read the document from the file system - try transport.any().writeResponse(gpa, request.id, void, {}, .{}); + try transport.writeResponse(gpa, request.id, void, {}, .{}); continue; }; const source_z = try gpa.dupeZ(u8, source); @@ -125,7 +127,7 @@ pub fn main() !void { defer tree.deinit(gpa); if (tree.errors.len != 0) { - try transport.any().writeResponse(gpa, request.id, void, {}, .{}); + try transport.writeResponse(gpa, request.id, void, {}, .{}); continue; } @@ -133,7 +135,7 @@ pub fn main() !void { defer gpa.free(formatte_source); if (std.mem.eql(u8, source, formatte_source)) { - try transport.any().writeResponse(gpa, request.id, void, {}, .{}); + try transport.writeResponse(gpa, request.id, void, {}, .{}); continue; } @@ -145,9 +147,9 @@ pub fn main() !void { .newText = formatte_source, }}; - try transport.any().writeResponse(gpa, request.id, []const lsp.types.TextEdit, result, .{}); + try transport.writeResponse(gpa, request.id, []const lsp.types.TextEdit, result, .{}); }, - .other => try transport.any().writeResponse(gpa, request.id, void, {}, .{}), + .other => try transport.writeResponse(gpa, request.id, void, {}, .{}), }, .notification => |notification| switch (notification.params) { .initialized => {}, diff --git a/examples/my_first_server.zig b/examples/my_first_server.zig index 868fb2c..22ec9fe 100644 --- a/examples/my_first_server.zig +++ b/examples/my_first_server.zig @@ -17,14 +17,16 @@ pub fn main() !void { }; // language server typically communicate over stdio (stdin and stdout) - var transport: lsp.TransportOverStdio = .init(std.io.getStdIn(), std.io.getStdOut()); + var read_buffer: [256]u8 = undefined; + var stdio_transport: lsp.Transport.Stdio = .init(&read_buffer, .stdin(), .stdout()); + const transport: *lsp.Transport = &stdio_transport.transport; var handler: Handler = .init(gpa); defer handler.deinit(); try lsp.basic_server.run( gpa, - transport.any(), + transport, &handler, std.log.err, ); diff --git a/src/basic_server.zig b/src/basic_server.zig index b915b0b..22ea6c8 100644 --- a/src/basic_server.zig +++ b/src/basic_server.zig @@ -15,7 +15,7 @@ const types = lsp.types; pub fn run( allocator: std.mem.Allocator, - transport: lsp.AnyTransport, + transport: *lsp.Transport, /// Must be a pointer to a container type (e.g. `struct`) that implements /// the desired LSP methods. /// diff --git a/src/lsp.zig b/src/lsp.zig index e831d1b..a65d245 100644 --- a/src/lsp.zig +++ b/src/lsp.zig @@ -707,23 +707,23 @@ pub const JsonRPCMessage = union(enum) { test "emit_null_optional_fields" { try std.testing.expectFmt( \\{"jsonrpc":"2.0","method":"exit"} - , "{}", .{std.json.fmt(JsonRPCMessage{ .notification = .{ .method = "exit", .params = null } }, .{ .emit_null_optional_fields = false })}); + , "{f}", .{parser.jsonFmt(JsonRPCMessage{ .notification = .{ .method = "exit", .params = null } }, .{ .emit_null_optional_fields = false })}); try std.testing.expectFmt( \\{"jsonrpc":"2.0","method":"exit","params":null} - , "{}", .{std.json.fmt(JsonRPCMessage{ .notification = .{ .method = "exit", .params = null } }, .{ .emit_null_optional_fields = true })}); + , "{f}", .{parser.jsonFmt(JsonRPCMessage{ .notification = .{ .method = "exit", .params = null } }, .{ .emit_null_optional_fields = true })}); try std.testing.expectFmt( \\{"jsonrpc":"2.0","method":"exit","params":null} - , "{}", .{std.json.fmt(JsonRPCMessage{ .notification = .{ .method = "exit", .params = .null } }, .{ .emit_null_optional_fields = false })}); + , "{f}", .{parser.jsonFmt(JsonRPCMessage{ .notification = .{ .method = "exit", .params = .null } }, .{ .emit_null_optional_fields = false })}); try std.testing.expectFmt( \\{"jsonrpc":"2.0","method":"exit","params":null} - , "{}", .{std.json.fmt(JsonRPCMessage{ .notification = .{ .method = "exit", .params = .null } }, .{ .emit_null_optional_fields = true })}); + , "{f}", .{parser.jsonFmt(JsonRPCMessage{ .notification = .{ .method = "exit", .params = .null } }, .{ .emit_null_optional_fields = true })}); try std.testing.expectFmt( \\{"jsonrpc":"2.0","result":null} - , "{}", .{std.json.fmt(JsonRPCMessage{ .response = .{ .id = null, .result_or_error = .{ .result = null } } }, .{ .emit_null_optional_fields = false })}); + , "{f}", .{parser.jsonFmt(JsonRPCMessage{ .response = .{ .id = null, .result_or_error = .{ .result = null } } }, .{ .emit_null_optional_fields = false })}); try std.testing.expectFmt( \\{"jsonrpc":"2.0","id":null,"result":null} - , "{}", .{std.json.fmt(JsonRPCMessage{ .response = .{ .id = null, .result_or_error = .{ .result = null } } }, .{ .emit_null_optional_fields = true })}); + , "{f}", .{parser.jsonFmt(JsonRPCMessage{ .response = .{ .id = null, .result_or_error = .{ .result = null } } }, .{ .emit_null_optional_fields = true })}); } fn testParse(message: []const u8, expected: JsonRPCMessage, parse_options: std.json.ParseOptions) !void { @@ -872,13 +872,13 @@ test TypedJsonRPCRequest { try std.testing.expectFmt( \\{"jsonrpc":"2.0","id":42,"method":"name","params":null} - , "{}", .{std.json.fmt(Request{ .id = .{ .number = 42 }, .method = "name", .params = null }, .{})}); + , "{f}", .{parser.jsonFmt(Request{ .id = .{ .number = 42 }, .method = "name", .params = null }, .{})}); try std.testing.expectFmt( \\{"jsonrpc":"2.0","id":"42","method":"name"} - , "{}", .{std.json.fmt(Request{ .id = .{ .string = "42" }, .method = "name", .params = null }, .{ .emit_null_optional_fields = false })}); + , "{f}", .{parser.jsonFmt(Request{ .id = .{ .string = "42" }, .method = "name", .params = null }, .{ .emit_null_optional_fields = false })}); try std.testing.expectFmt( \\{"jsonrpc":"2.0","id":42,"method":"name","params":true} - , "{}", .{std.json.fmt(Request{ .id = .{ .number = 42 }, .method = "name", .params = true }, .{})}); + , "{f}", .{parser.jsonFmt(Request{ .id = .{ .number = 42 }, .method = "name", .params = true }, .{})}); } pub fn TypedJsonRPCNotification( @@ -924,13 +924,13 @@ test TypedJsonRPCNotification { try std.testing.expectFmt( \\{"jsonrpc":"2.0","method":"name","params":null} - , "{}", .{std.json.fmt(Notification{ .method = "name", .params = null }, .{})}); + , "{f}", .{parser.jsonFmt(Notification{ .method = "name", .params = null }, .{})}); try std.testing.expectFmt( \\{"jsonrpc":"2.0","method":"name"} - , "{}", .{std.json.fmt(Notification{ .method = "name", .params = null }, .{ .emit_null_optional_fields = false })}); + , "{f}", .{parser.jsonFmt(Notification{ .method = "name", .params = null }, .{ .emit_null_optional_fields = false })}); try std.testing.expectFmt( \\{"jsonrpc":"2.0","method":"name","params":true} - , "{}", .{std.json.fmt(Notification{ .method = "name", .params = true }, .{})}); + , "{f}", .{parser.jsonFmt(Notification{ .method = "name", .params = true }, .{})}); } pub fn TypedJsonRPCResponse( @@ -982,13 +982,13 @@ test TypedJsonRPCResponse { try std.testing.expectFmt( \\{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"message","data":null}} - , "{}", .{std.json.fmt(Response{ + , "{f}", .{parser.jsonFmt(Response{ .id = null, .result_or_error = .{ .@"error" = .{ .code = .invalid_request, .message = "message", .data = .null } }, }, .{})}); try std.testing.expectFmt( \\{"jsonrpc":"2.0","id":5,"result":true} - , "{}", .{std.json.fmt(Response{ + , "{f}", .{parser.jsonFmt(Response{ .id = .{ .number = 5 }, .result_or_error = .{ .result = true }, }, .{})}); @@ -1000,69 +1000,47 @@ test TypedJsonRPCResponse { pub const BaseProtocolHeader = struct { content_length: usize, - pub const max_header_length: usize = 1024; + pub const minimum_reader_buffer_size: usize = 1024; pub const ParseError = error{ EndOfStream, /// The message is longer than `std.math.maxInt(usize)`. OversizedMessage, - /// The header field is longer than `max_header_length`. The ": " doesn't count towards the header field length. + /// The header field is longer than buffer size of the `std.io.Reader` which is at least `minimum_reader_buffer_size`. OversizedHeaderField, /// The header is missing the mandatory `Content-Length` field. MissingContentLength, + /// The header field `Content-Length` has been specified multiple times. + DuplicateContentLength, /// The header field value of `Content-Length` is not a valid unsigned integer. InvalidContentLength, /// The header is ill-formed. InvalidHeaderField, }; - /// It is strongly recommended to provide a buffering reader because the parser has to read 1-byte at a time. - pub inline fn parse(reader: anytype) (@TypeOf(reader).Error || ParseError)!BaseProtocolHeader { - return @errorCast(parseAny(reader.any())); - } - - /// It is strongly recommended to provide a buffering reader because the parser has to read 1-byte at a time. - /// - /// Type erased version of `parse`. - pub fn parseAny(reader: std.io.AnyReader) (anyerror || ParseError)!BaseProtocolHeader { + /// The maximum parsable header field length is controlled by `reader.buffer.len`. + /// Asserts that `reader.buffer.len >= minimum_reader_buffer_size`. + pub fn parse(reader: *std.io.Reader) (std.io.Reader.Error || ParseError)!BaseProtocolHeader { + std.debug.assert(@import("builtin").is_test or reader.buffer.len >= minimum_reader_buffer_size); var content_length: ?usize = null; - outer: while (true) { - var maybe_colon_index: ?usize = null; + while (true) { + var header = reader.takeDelimiterInclusive('\n') catch |err| switch (err) { + error.StreamTooLong => return error.OversizedHeaderField, + else => |e| return e, + }; + if (!std.mem.endsWith(u8, header, "\r\n")) return error.InvalidHeaderField; + header.len -= "\r\n".len; - var buffer: [max_header_length]u8 = undefined; - var buffer_index: usize = 0; + if (header.len == 0) break; - while (true) { - const byte: u8 = try reader.readByte(); - switch (byte) { - '\n' => return error.InvalidHeaderField, - '\r' => { - if (try reader.readByte() != '\n') return error.InvalidHeaderField; - if (buffer_index == 0) break :outer; - break; - }, - ':' => { - // The ": " is not being added to the buffer here! - if (try reader.readByte() != ' ') return error.InvalidHeaderField; - if (maybe_colon_index != null) return error.InvalidHeaderField; // duplicate ':' - maybe_colon_index = buffer_index; - }, - else => { - if (buffer_index == max_header_length) return error.OversizedHeaderField; - buffer[buffer_index] = byte; - buffer_index += 1; - }, - } - } - - const colon_index = maybe_colon_index orelse return error.InvalidHeaderField; + const colon_index = std.mem.indexOf(u8, header, ": ") orelse return error.InvalidHeaderField; - const header = buffer[0..buffer_index]; const header_name = header[0..colon_index]; - const header_value = header[colon_index..]; + const header_value = header[colon_index + 2 ..]; if (!std.ascii.eqlIgnoreCase(header_name, "content-length")) continue; + if (content_length != null) return error.DuplicateContentLength; content_length = std.fmt.parseUnsigned(usize, header_value, 10) catch |err| switch (err) { error.Overflow => return error.OversizedMessage, @@ -1084,12 +1062,14 @@ pub const BaseProtocolHeader = struct { try expectParseError("\r\n\r\n", error.MissingContentLength); try expectParseError("content-length: 32\r\n", error.EndOfStream); + try expectParseError("content-length: \r\n\r\n", error.InvalidContentLength); try expectParseError("content-length 32\r\n\r\n", error.InvalidHeaderField); try expectParseError("content-length:32\r\n\r\n", error.InvalidHeaderField); try expectParseError("contentLength: 32\r\n\r\n", error.MissingContentLength); + try expectParseError("content-length: 32\r\ncontent-length: 32\r\n\r\n", error.DuplicateContentLength); try expectParseError("content-length: abababababab\r\n\r\n", error.InvalidContentLength); + try expectParseError("content-length: : 32\r\n\r\n", error.InvalidContentLength); try expectParseError("content-length: 9999999999999999999999999999999999\r\n\r\n", error.OversizedMessage); - try expectParseError("content-length: : 32\r\n\r\n", error.InvalidHeaderField); try expectParse("content-length: 32\r\n\r\n", .{ .content_length = 32 }); try expectParse("Content-Length: 32\r\n\r\n", .{ .content_length = 32 }); @@ -1098,412 +1078,135 @@ pub const BaseProtocolHeader = struct { try expectParse("Content-Type: impostor\r\ncontent-length: 42\r\n\r\n", .{ .content_length = 42 }); } - pub fn format( - header: BaseProtocolHeader, - comptime fmt: []const u8, - options: std.fmt.FormatOptions, - out_stream: anytype, - ) !void { - _ = options; - if (fmt.len != 0) std.fmt.invalidFmtError(fmt, header); - try std.fmt.format(out_stream, "Content-Length: {d}\r\n\r\n", .{header.content_length}); - } - - test format { - try std.testing.expectFmt("Content-Length: 0\r\n\r\n", "{}", .{BaseProtocolHeader{ .content_length = 0 }}); - try std.testing.expectFmt("Content-Length: 42\r\n\r\n", "{}", .{BaseProtocolHeader{ .content_length = 42 }}); - try std.testing.expectFmt("Content-Length: 4294967295\r\n\r\n", "{}", .{BaseProtocolHeader{ .content_length = std.math.maxInt(u32) }}); - } - - fn expectParse(input: []const u8, expected_header: BaseProtocolHeader) !void { - var fbs = std.io.fixedBufferStream(input); - const actual_header = try parse(fbs.reader()); - try std.testing.expectEqual(expected_header.content_length, actual_header.content_length); - } - - fn expectParseError(input: []const u8, expected_error: ParseError) !void { - var fbs = std.io.fixedBufferStream(input); - try std.testing.expectError(expected_error, parse(fbs.reader())); - } -}; - -pub const TestingTransport = struct { - allocator: std.mem.Allocator, - read_buffer: []const u8, - read_pos: usize = 0, - write_buffer: std.ArrayListUnmanaged(u8) = .{}, - - comptime { - std.debug.assert(@import("builtin").is_test); - } - - pub fn initWriteOnly(allocator: std.mem.Allocator) TestingTransport { - return .{ - .allocator = allocator, - .read_buffer = &.{}, - }; - } + test "parse with oversized header field" { + const stream = struct { + fn stream(_: *std.io.Reader, _: *std.io.Writer, _: std.io.Limit) std.io.Reader.StreamError!usize { + return error.EndOfStream; + } + }.stream; - pub fn initReadOnly(read_buffer: []const u8) TestingTransport { - return .{ - .allocator = std.testing.failing_allocator, - .read_buffer = read_buffer, + var buffer: [128]u8 = @splat(0); + var reader: std.io.Reader = .{ + .vtable = &.{ + .stream = &stream, + .discard = undefined, + }, + .buffer = &buffer, + .end = buffer.len, + .seek = 0, }; + try std.testing.expectError(error.OversizedHeaderField, parse(&reader)); } - pub fn deinit(transport: *TestingTransport) void { - transport.write_buffer.deinit(transport.allocator); - transport.* = undefined; + pub fn format(header: BaseProtocolHeader, writer: *std.io.Writer) std.io.Writer.Error!void { + try writer.print("Content-Length: {d}\r\n\r\n", .{header.content_length}); } - pub fn getWritten(transport: TestingTransport) []const u8 { - return transport.write_buffer.items; - } - - pub fn any(transport: *TestingTransport) AnyTransport { - return .{ .impl = .{ - .transport = transport, - .readJsonMessage = @ptrCast(&readJsonMessage), - .writeJsonMessage = @ptrCast(&writeJsonMessage), - } }; - } - - pub fn readJsonMessage(transport: *TestingTransport, allocator: std.mem.Allocator) (std.mem.Allocator.Error || AnyTransport.ReadError)![]u8 { - var fbs = std.io.fixedBufferStream(transport.read_buffer); - fbs.pos = transport.read_pos; - defer transport.read_pos = fbs.pos; - - const reader = fbs.reader(); - - const header = try BaseProtocolHeader.parse(reader); - - const json_message = try allocator.alloc(u8, header.content_length); - errdefer allocator.free(json_message); - - try reader.readNoEof(json_message); - - return json_message; - } - - test readJsonMessage { - var testing_transport: TestingTransport = .initReadOnly("Content-Length: 70\r\n\r\n" ++ - \\{ - \\ "jsonrpc": "2.0", - \\ "method": "methodName", - \\ "params": {} - \\} - ); - const json_message = try testing_transport.readJsonMessage(std.testing.allocator); - defer std.testing.allocator.free(json_message); - - try std.testing.expectEqualStrings( - \\{ - \\ "jsonrpc": "2.0", - \\ "method": "methodName", - \\ "params": {} - \\} - , json_message); - - const result = testing_transport.any().readJsonMessage(std.testing.allocator); - if (result) |message| std.testing.allocator.free(message) else |_| {} - try std.testing.expectError(error.EndOfStream, result); - } - - pub fn writeJsonMessage(transport: *TestingTransport, json_message: []const u8) AnyTransport.WriteError!void { - const header: BaseProtocolHeader = .{ .content_length = json_message.len }; - - var buffer: [64]u8 = undefined; - const prefix = std.fmt.bufPrint(&buffer, "{}", .{header}) catch unreachable; - - transport.write_buffer.appendSlice(transport.allocator, prefix) catch @panic("OOM"); - transport.write_buffer.appendSlice(transport.allocator, json_message) catch @panic("OOM"); - } - - test writeJsonMessage { - var testing_transport: TestingTransport = .initWriteOnly(std.testing.allocator); - defer testing_transport.deinit(); - - try testing_transport.writeJsonMessage( - \\{ - \\ "jsonrpc": "2.0", - \\ "method": "methodName", - \\ "params": {} - \\} - ); - - try std.testing.expectEqualStrings("Content-Length: 70\r\n\r\n" ++ - \\{ - \\ "jsonrpc": "2.0", - \\ "method": "methodName", - \\ "params": {} - \\} - , testing_transport.getWritten()); - } -}; - -pub const TransportOverStdio = struct { - impl: struct { - in: std.io.BufferedReader(512, std.fs.File.Reader), - out: std.fs.File, - }, - - pub fn init(read_from: std.fs.File, write_to: std.fs.File) TransportOverStdio { - return .{ .impl = .{ - .in = std.io.bufferedReaderSize(512, read_from.reader()), - .out = write_to, - } }; - } - - pub fn any(transport: *TransportOverStdio) AnyTransport { - return .{ .impl = .{ - .transport = transport, - .readJsonMessage = @ptrCast(&readJsonMessage), - .writeJsonMessage = @ptrCast(&writeJsonMessage), - } }; + test format { + try std.testing.expectFmt("Content-Length: 0\r\n\r\n", "{f}", .{BaseProtocolHeader{ .content_length = 0 }}); + try std.testing.expectFmt("Content-Length: 42\r\n\r\n", "{f}", .{BaseProtocolHeader{ .content_length = 42 }}); + try std.testing.expectFmt("Content-Length: 4294967295\r\n\r\n", "{f}", .{BaseProtocolHeader{ .content_length = std.math.maxInt(u32) }}); + if (@sizeOf(usize) == @sizeOf(u64)) { + try std.testing.expectFmt("Content-Length: 18446744073709551615\r\n\r\n", "{f}", .{BaseProtocolHeader{ .content_length = std.math.maxInt(usize) }}); + } } - pub fn readJsonMessage(transport: *TransportOverStdio, allocator: std.mem.Allocator) (std.mem.Allocator.Error || AnyTransport.ReadError)![]u8 { - const reader = transport.impl.in.reader().any(); - - const header = BaseProtocolHeader.parseAny(reader) catch |err| return @as(AnyTransport.ReadError, @errorCast(err)); - - const json_message = try allocator.alloc(u8, header.content_length); - errdefer allocator.free(json_message); - - reader.readNoEof(json_message) catch |err| return @as(AnyTransport.ReadError, @errorCast(err)); - - return json_message; + fn expectParse(input: []const u8, expected_header: BaseProtocolHeader) !void { + var reader: std.io.Reader = .fixed(input); + const actual_header = try parse(&reader); + try std.testing.expectEqual(expected_header.content_length, actual_header.content_length); } - pub fn writeJsonMessage(transport: *TransportOverStdio, json_message: []const u8) AnyTransport.WriteError!void { - const header: BaseProtocolHeader = .{ .content_length = json_message.len }; - - var buffer: [64]u8 = undefined; - const prefix = std.fmt.bufPrint(&buffer, "{}", .{header}) catch unreachable; - - var iovecs: [2]std.posix.iovec_const = .{ - .{ .base = prefix.ptr, .len = prefix.len }, - .{ .base = json_message.ptr, .len = json_message.len }, - }; + fn expectParseError(input: []const u8, expected_error: ParseError) !void { + var buffer: [128]u8 = undefined; + var reader: std.io.Reader = .fixed(&buffer); + reader.end = input.len; + @memcpy(buffer[0..input.len], input); - try transport.impl.out.writevAll(&iovecs); + try std.testing.expectError(expected_error, parse(&reader)); } }; -test TransportOverStdio { - if (comptime @import("builtin").target.os.tag == .windows and - @import("builtin").zig_version.order(std.SemanticVersion.parse("0.15.0-dev.920+b461d07a5") catch unreachable) != .lt) - { - // https://github.com/ziglang/zig/pull/24146 - // This would force us to create a different module in the build system just for tests. - return error.SkipZigTest; - } - - var tmp_dir = std.testing.tmpDir(.{}); - defer tmp_dir.cleanup(); - - var send_to_client: ?std.fs.File = try tmp_dir.dir.createFile("a", .{}); - defer if (send_to_client) |file| file.close(); - - var send_to_server: ?std.fs.File = try tmp_dir.dir.createFile("b", .{}); - defer if (send_to_server) |file| file.close(); - - const receive_from_server = try tmp_dir.dir.openFile("a", .{}); - defer receive_from_server.close(); - - const receive_from_client = try tmp_dir.dir.openFile("b", .{}); - defer receive_from_client.close(); - - var server_transport = TransportOverStdio.init(receive_from_client, send_to_client.?); - var client_transport = TransportOverStdio.init(receive_from_server, send_to_server.?); - - // Server -> Client - try server_transport.writeJsonMessage("\"hello from server\""); - - // Client <- Server - const message_from_server = try client_transport.readJsonMessage(std.testing.allocator); - defer std.testing.allocator.free(message_from_server); - try std.testing.expectEqualStrings("\"hello from server\"", message_from_server); - - // Client -> Server - try client_transport.any().writeJsonMessage("\"hello from client\""); - - // Server <- Client - const message_from_client = try server_transport.readJsonMessage(std.testing.allocator); - defer std.testing.allocator.free(message_from_client); - - try std.testing.expectEqualStrings("\"hello from client\"", message_from_client); - - send_to_client.?.close(); // Do not call `server_transport.writeJsonMessage` anymore - send_to_client = null; - - send_to_server.?.close(); // Do not call `client_transport.writeJsonMessage` anymore - send_to_server = null; - - var buffer: [512]u8 = undefined; - try std.testing.expectEqual(0, receive_from_server.read(&buffer)); - try std.testing.expectEqual(0, receive_from_client.read(&buffer)); - - try std.testing.expectError(error.EndOfStream, server_transport.readJsonMessage(std.testing.allocator)); - try std.testing.expectError(error.EndOfStream, client_transport.readJsonMessage(std.testing.allocator)); - - try std.testing.expectError(error.EndOfStream, server_transport.any().readJsonMessage(std.testing.allocator)); - try std.testing.expectError(error.EndOfStream, client_transport.any().readJsonMessage(std.testing.allocator)); -} - -/// This implementation is totally untested and probably completely wrong. Who is going to use this anyway... -const TransportOverStream = struct { - impl: struct { - server: std.net.Server, - connection: std.net.Server.Connection, - buffered_reader: std.io.BufferedReader(512, std.net.Stream.Reader), - }, - - pub fn initSocket(address: std.net.Address, options: std.net.Address.ListenOptions) (std.net.Address.ListenError || std.net.Server.AcceptError)!TransportOverStream { - var server = try address.listen(options); - errdefer server.deinit(); - const connection = try server.accept(); - errdefer connection.stream.close(); - - return .{ .impl = .{ - .server = server, - .connection = connection, - .buffered_reader = std.io.bufferedReaderSize(512, connection.stream.reader()), - } }; - } - - pub fn deinit(transport: *TransportOverStream) void { - transport.impl.connection.stream.close(); - transport.impl.server.deinit(); - transport.* = undefined; - } - - pub fn any(transport: *TransportOverStream) AnyTransport { - return .{ .impl = .{ - .transport = transport, - .readJsonMessage = @ptrCast(&readJsonMessage), - .writeJsonMessage = @ptrCast(&writeJsonMessage), - } }; - } - - pub fn readJsonMessage(transport: *TransportOverStream, allocator: std.mem.Allocator) (std.mem.Allocator.Error || AnyTransport.ReadError)![]u8 { - const reader = transport.impl.buffered_reader.reader().any(); - - const header = BaseProtocolHeader.parseAny(reader) catch |err| return @as(AnyTransport.ReadError, @errorCast(err)); - - const json_message = try allocator.alloc(u8, header.content_length); - errdefer allocator.free(json_message); +pub const TestingTransport = if (!@import("builtin").is_test) @compileError("Use 'std.io.Reader.fixed' or 'std.io.Writer.Allocating' instead."); +pub const TransportOverStdio = if (!@import("builtin").is_test) @compileError("Use 'Transport.Stdio' instead."); +pub const AnyTransport = if (!@import("builtin").is_test) @compileError("Use 'Transport' instead."); - reader.readNoEof(json_message) catch |err| return @as(AnyTransport.ReadError, @errorCast(err)); +const lsp = @This(); - return json_message; - } - - pub fn writeJsonMessage(transport: *TransportOverStream, json_message: []const u8) AnyTransport.WriteError!void { - const header: BaseProtocolHeader = .{ .content_length = json_message.len }; - - var buffer: [64]u8 = undefined; - const prefix = std.fmt.bufPrint(&buffer, "{}", .{header}) catch unreachable; - - var iovecs: [2]std.posix.iovec_const = .{ - .{ .base = prefix.ptr, .len = prefix.len }, - .{ .base = json_message.ptr, .len = json_message.len }, - }; +pub const Transport = struct { + vtable: *const VTable, - try transport.impl.connection.stream.writevAll(&iovecs); - } -}; - -pub const ThreadSafeTransportConfig = struct { - ChildTransport: type, - /// Makes `readJsonMessage` thread-safe. - thread_safe_read: bool, - /// Makes `writeJsonMessage` thread-safe. - thread_safe_write: bool, - MutexType: ?type = null, -}; - -/// Wraps a non-thread-safe transport and makes it thread-safe. -pub fn ThreadSafeTransport(config: ThreadSafeTransportConfig) type { - return struct { - child_transport: config.ChildTransport, - in_mutex: @TypeOf(in_mutex_init) = in_mutex_init, - out_mutex: @TypeOf(out_mutex_init) = out_mutex_init, + pub const VTable = struct { + readJsonMessage: *const fn (transport: *Transport, allocator: std.mem.Allocator) ReadError![]u8, + writeJsonMessage: *const fn (transport: *Transport, json_message: []const u8) WriteError!void, + }; - // Is there any better name of this? - const Self = @This(); + pub const ReadError = std.posix.ReadError || error{EndOfStream} || BaseProtocolHeader.ParseError || std.mem.Allocator.Error; + pub const WriteError = std.posix.WriteError; - pub fn any(transport: *Self) AnyTransport { - return .{ .impl = .{ - .transport = transport, - .readJsonMessage = @ptrCast(&readJsonMessage), - .writeJsonMessage = @ptrCast(&writeJsonMessage), - } }; + pub const Stdio = struct { + transport: Transport, + reader: std.io.Reader, + read_from: std.fs.File, + write_to: std.fs.File, + + pub fn init( + read_buffer: []u8, + read_from: std.fs.File, + write_to: std.fs.File, + ) Stdio { + return .{ + .transport = .{ + .vtable = &.{ + .readJsonMessage = &Stdio.readJsonMessage, + .writeJsonMessage = &Stdio.writeJsonMessage, + }, + }, + .reader = std.fs.File.Reader.initInterface(read_buffer), + .read_from = read_from, + .write_to = write_to, + }; } - pub fn readJsonMessage(transport: *Self, allocator: std.mem.Allocator) (std.mem.Allocator.Error || AnyTransport.ReadError)![]u8 { - transport.in_mutex.lock(); - defer transport.in_mutex.unlock(); - - return try transport.child_transport.readJsonMessage(allocator); + fn readJsonMessage(transport: *Transport, allocator: std.mem.Allocator) ReadError![]u8 { + const stdio: *Stdio = @fieldParentPtr("transport", transport); + var file_reader: std.fs.File.Reader = .{ + .file = stdio.read_from, + .interface = stdio.reader, + }; + defer stdio.reader = file_reader.interface; + return lsp.readJsonMessage(&file_reader.interface, allocator) catch |err| switch (err) { + error.ReadFailed => return file_reader.err.?, + else => |e| return e, + }; } - pub fn writeJsonMessage(transport: *Self, json_message: []const u8) AnyTransport.WriteError!void { - transport.out_mutex.lock(); - defer transport.out_mutex.unlock(); - - return try transport.child_transport.writeJsonMessage(json_message); + fn writeJsonMessage(transport: *Transport, json_message: []const u8) WriteError!void { + const stdio: *Stdio = @fieldParentPtr("transport", transport); + var file_writer = stdio.write_to.writer(&.{}); + return lsp.writeJsonMessage(&file_writer.interface, json_message) catch |err| switch (err) { + error.WriteFailed => return file_writer.err.?, + }; } - - const in_mutex_init = if (config.MutexType) |T| - T{} - else if (config.thread_safe_read) - std.Thread.Mutex{} - else - DummyMutex{}; - - const out_mutex_init = if (config.MutexType) |T| - T{} - else if (config.thread_safe_write) - std.Thread.Mutex{} - else - DummyMutex{}; - - const DummyMutex = struct { - fn lock(_: *DummyMutex) void {} - fn unlock(_: *DummyMutex) void {} - }; }; -} -/// A type-erased Transport. -pub const AnyTransport = struct { - impl: struct { - transport: *anyopaque, - readJsonMessage: *const fn (transport: *anyopaque, allocator: std.mem.Allocator) (std.mem.Allocator.Error || ReadError)![]u8, - writeJsonMessage: *const fn (transport: *anyopaque, json_message: []const u8) WriteError!void, - }, - - pub const ReadError = std.posix.ReadError || error{EndOfStream} || BaseProtocolHeader.ParseError; - pub const WriteError = std.posix.WriteError; - - pub fn readJsonMessage(transport: AnyTransport, allocator: std.mem.Allocator) (std.mem.Allocator.Error || ReadError)![]u8 { - return try transport.impl.readJsonMessage(transport.impl.transport, allocator); + pub fn readJsonMessage(transport: *Transport, allocator: std.mem.Allocator) ReadError![]u8 { + return try transport.vtable.readJsonMessage(transport, allocator); } - pub fn writeJsonMessage(transport: AnyTransport, json_message: []const u8) WriteError!void { - try transport.impl.writeJsonMessage(transport.impl.transport, json_message); + pub fn writeJsonMessage(transport: *Transport, json_message: []const u8) WriteError!void { + return try transport.vtable.writeJsonMessage(transport, json_message); } pub fn writeRequest( - transport: AnyTransport, + transport: *Transport, allocator: std.mem.Allocator, id: JsonRPCMessage.ID, method: []const u8, comptime Params: type, params: Params, options: std.json.StringifyOptions, - ) (std.mem.Allocator.Error || WriteError)!void { + ) (WriteError || std.mem.Allocator.Error)!void { const request: TypedJsonRPCRequest(Params) = .{ .id = id, .method = method, @@ -1514,38 +1217,14 @@ pub const AnyTransport = struct { try transport.writeJsonMessage(json_message); } - test writeRequest { - var testing_transport: TestingTransport = .initWriteOnly(std.testing.allocator); - defer testing_transport.deinit(); - - try writeRequest( - testing_transport.any(), - std.testing.allocator, - .{ .number = 0 }, - "my/method", - void, - {}, - .{ .whitespace = .indent_2 }, - ); - - try std.testing.expectEqualStrings("Content-Length: 76\r\n\r\n" ++ - \\{ - \\ "jsonrpc": "2.0", - \\ "id": 0, - \\ "method": "my/method", - \\ "params": null - \\} - , testing_transport.getWritten()); - } - pub fn writeNotification( - transport: AnyTransport, + transport: *Transport, allocator: std.mem.Allocator, method: []const u8, comptime Params: type, params: Params, options: std.json.StringifyOptions, - ) (std.mem.Allocator.Error || WriteError)!void { + ) (WriteError || std.mem.Allocator.Error)!void { const request: TypedJsonRPCNotification(Params) = .{ .method = method, .params = params, @@ -1555,36 +1234,14 @@ pub const AnyTransport = struct { try transport.writeJsonMessage(json_message); } - test writeNotification { - var testing_transport: TestingTransport = .initWriteOnly(std.testing.allocator); - defer testing_transport.deinit(); - - try writeNotification( - testing_transport.any(), - std.testing.allocator, - "my/method", - void, - {}, - .{ .whitespace = .indent_2 }, - ); - - try std.testing.expectEqualStrings("Content-Length: 65\r\n\r\n" ++ - \\{ - \\ "jsonrpc": "2.0", - \\ "method": "my/method", - \\ "params": null - \\} - , testing_transport.getWritten()); - } - pub fn writeResponse( - transport: AnyTransport, + transport: *Transport, allocator: std.mem.Allocator, id: ?JsonRPCMessage.ID, comptime Result: type, result: Result, options: std.json.StringifyOptions, - ) (std.mem.Allocator.Error || WriteError)!void { + ) (WriteError || std.mem.Allocator.Error)!void { const request: TypedJsonRPCResponse(Result) = .{ .id = id, .result_or_error = .{ .result = result }, @@ -1594,35 +1251,13 @@ pub const AnyTransport = struct { try transport.writeJsonMessage(json_message); } - test writeResponse { - var testing_transport: TestingTransport = .initWriteOnly(std.testing.allocator); - defer testing_transport.deinit(); - - try writeResponse( - testing_transport.any(), - std.testing.allocator, - .{ .number = 0 }, - void, - {}, - .{ .whitespace = .indent_2 }, - ); - - try std.testing.expectEqualStrings("Content-Length: 51\r\n\r\n" ++ - \\{ - \\ "jsonrpc": "2.0", - \\ "id": 0, - \\ "result": null - \\} - , testing_transport.getWritten()); - } - pub fn writeErrorResponse( - transport: AnyTransport, + transport: *Transport, allocator: std.mem.Allocator, id: ?JsonRPCMessage.ID, err: JsonRPCMessage.Response.Error, options: std.json.StringifyOptions, - ) (std.mem.Allocator.Error || WriteError)!void { + ) (WriteError || std.mem.Allocator.Error)!void { const request: TypedJsonRPCResponse(void) = .{ .id = id, .result_or_error = .{ .@"error" = err }, @@ -1631,32 +1266,270 @@ pub const AnyTransport = struct { defer allocator.free(json_message); try transport.writeJsonMessage(json_message); } +}; - test writeErrorResponse { - var testing_transport: TestingTransport = .initWriteOnly(std.testing.allocator); - defer testing_transport.deinit(); +pub const ThreadSafeTransportConfig = struct { + /// Makes `readJsonMessage` thread-safe. + thread_safe_read: bool, + /// Makes `writeJsonMessage` thread-safe. + thread_safe_write: bool, + MutexType: type = std.Thread.Mutex, +}; - try writeErrorResponse( - testing_transport.any(), - std.testing.allocator, - null, - .{ .code = .internal_error, .message = "my message" }, - .{ .whitespace = .indent_2 }, - ); +/// Wraps a non-thread-safe transport and makes it thread-safe. +pub fn ThreadSafeTransport(config: ThreadSafeTransportConfig) type { + return struct { + transport: Transport, + child_transport: *Transport, + in_mutex: @TypeOf(in_mutex_init) = in_mutex_init, + out_mutex: @TypeOf(out_mutex_init) = out_mutex_init, - try std.testing.expectEqualStrings("Content-Length: 120\r\n\r\n" ++ - \\{ - \\ "jsonrpc": "2.0", - \\ "id": null, - \\ "error": { - \\ "code": -32603, - \\ "message": "my message", - \\ "data": null - \\ } - \\} - , testing_transport.getWritten()); - } -}; + // Is there any better name of this? + const Self = @This(); + + pub fn init(child_transport: *Transport) Self { + return .{ + .transport = .{ + .vtable = &.{ + .readJsonMessage = Self.readJsonMessage, + .writeJsonMessage = Self.writeJsonMessage, + }, + }, + .child_transport = child_transport, + }; + } + + pub fn readJsonMessage(transport: *Transport, allocator: std.mem.Allocator) Transport.ReadError![]u8 { + const self: *Self = @fieldParentPtr("transport", transport); + + self.in_mutex.lock(); + defer self.in_mutex.unlock(); + + return try self.child_transport.readJsonMessage(allocator); + } + + pub fn writeJsonMessage(transport: *Transport, json_message: []const u8) Transport.WriteError!void { + const self: *Self = @fieldParentPtr("transport", transport); + + self.out_mutex.lock(); + defer self.out_mutex.unlock(); + + return try self.child_transport.writeJsonMessage(json_message); + } + + const in_mutex_init = if (config.thread_safe_read) + config.MutexType{} + else + DummyMutex{}; + + const out_mutex_init = if (config.thread_safe_write) + config.MutexType{} + else + DummyMutex{}; + + const DummyMutex = struct { + fn lock(_: *DummyMutex) void {} + fn unlock(_: *DummyMutex) void {} + }; + }; +} + +pub fn readJsonMessage(reader: *std.io.Reader, allocator: std.mem.Allocator) (std.io.Reader.ReadAllocError || BaseProtocolHeader.ParseError)![]u8 { + const header: BaseProtocolHeader = try .parse(reader); + return try reader.readAlloc(allocator, header.content_length); +} + +test readJsonMessage { + var reader: std.io.Reader = .fixed("Content-Length: 2\r\n\r\n{}"); + + const json_message = try readJsonMessage(&reader, std.testing.allocator); + defer std.testing.allocator.free(json_message); + + try std.testing.expectEqualStrings("{}", json_message); +} + +pub fn writeJsonMessage(writer: *std.io.Writer, json_message: []const u8) std.io.Writer.Error!void { + const header: BaseProtocolHeader = .{ .content_length = json_message.len }; + var buffer: [64]u8 = undefined; + const prefix = std.fmt.bufPrint(&buffer, "{f}", .{header}) catch unreachable; + var data: [2][]const u8 = .{ prefix, json_message }; + try writer.writeVecAll(&data); + try writer.flush(); +} + +test writeJsonMessage { + var aw: std.io.Writer.Allocating = .init(std.testing.allocator); + defer aw.deinit(); + + try writeJsonMessage(&aw.writer, "{}"); + try std.testing.expectEqualStrings("Content-Length: 2\r\n\r\n{}", aw.getWritten()); +} + +pub fn writeRequest( + writer: *std.io.Writer, + allocator: std.mem.Allocator, + id: JsonRPCMessage.ID, + method: []const u8, + comptime Params: type, + params: Params, + options: std.json.StringifyOptions, +) (std.io.Writer.Error || std.mem.Allocator.Error)!void { + const request: TypedJsonRPCRequest(Params) = .{ + .id = id, + .method = method, + .params = params, + }; + const json_message = try std.json.stringifyAlloc(allocator, request, options); + defer allocator.free(json_message); + try writeJsonMessage(writer, json_message); +} + +test writeRequest { + var buffer: std.ArrayListUnmanaged(u8) = .empty; + var aw: std.io.Writer.Allocating = .fromArrayList(std.testing.allocator, &buffer); + defer aw.deinit(); + + try writeRequest( + &aw.writer, + std.testing.allocator, + .{ .number = 0 }, + "my/method", + void, + {}, + .{ .whitespace = .indent_2 }, + ); + + try std.testing.expectEqualStrings("Content-Length: 76\r\n\r\n" ++ + \\{ + \\ "jsonrpc": "2.0", + \\ "id": 0, + \\ "method": "my/method", + \\ "params": null + \\} + , aw.getWritten()); +} + +pub fn writeNotification( + writer: *std.io.Writer, + allocator: std.mem.Allocator, + method: []const u8, + comptime Params: type, + params: Params, + options: std.json.StringifyOptions, +) (std.io.Writer.Error || std.mem.Allocator.Error)!void { + const request: TypedJsonRPCNotification(Params) = .{ + .method = method, + .params = params, + }; + const json_message = try std.json.stringifyAlloc(allocator, request, options); + defer allocator.free(json_message); + try writeJsonMessage(writer, json_message); +} + +test writeNotification { + var buffer: std.ArrayListUnmanaged(u8) = .empty; + var aw: std.io.Writer.Allocating = .fromArrayList(std.testing.allocator, &buffer); + defer aw.deinit(); + + try writeNotification( + &aw.writer, + std.testing.allocator, + "my/method", + void, + {}, + .{ .whitespace = .indent_2 }, + ); + + try std.testing.expectEqualStrings("Content-Length: 65\r\n\r\n" ++ + \\{ + \\ "jsonrpc": "2.0", + \\ "method": "my/method", + \\ "params": null + \\} + , aw.getWritten()); +} + +pub fn writeResponse( + writer: *std.io.Writer, + allocator: std.mem.Allocator, + id: ?JsonRPCMessage.ID, + comptime Result: type, + result: Result, + options: std.json.StringifyOptions, +) (std.io.Writer.Error || std.mem.Allocator.Error)!void { + const request: TypedJsonRPCResponse(Result) = .{ + .id = id, + .result_or_error = .{ .result = result }, + }; + const json_message = try std.json.stringifyAlloc(allocator, request, options); + defer allocator.free(json_message); + try writeJsonMessage(writer, json_message); +} + +test writeResponse { + var buffer: std.ArrayListUnmanaged(u8) = .empty; + var aw: std.io.Writer.Allocating = .fromArrayList(std.testing.allocator, &buffer); + defer aw.deinit(); + + try writeResponse( + &aw.writer, + std.testing.allocator, + .{ .number = 0 }, + void, + {}, + .{ .whitespace = .indent_2 }, + ); + + try std.testing.expectEqualStrings("Content-Length: 51\r\n\r\n" ++ + \\{ + \\ "jsonrpc": "2.0", + \\ "id": 0, + \\ "result": null + \\} + , aw.getWritten()); +} + +pub fn writeErrorResponse( + writer: *std.io.Writer, + allocator: std.mem.Allocator, + id: ?JsonRPCMessage.ID, + err: JsonRPCMessage.Response.Error, + options: std.json.StringifyOptions, +) (std.io.Writer.Error || std.mem.Allocator.Error)!void { + const request: TypedJsonRPCResponse(void) = .{ + .id = id, + .result_or_error = .{ .@"error" = err }, + }; + const json_message = try std.json.stringifyAlloc(allocator, request, options); + defer allocator.free(json_message); + try writeJsonMessage(writer, json_message); +} + +test writeErrorResponse { + var buffer: std.ArrayListUnmanaged(u8) = .empty; + var aw: std.io.Writer.Allocating = .fromArrayList(std.testing.allocator, &buffer); + defer aw.deinit(); + + try writeErrorResponse( + &aw.writer, + std.testing.allocator, + null, + .{ .code = .internal_error, .message = "my message" }, + .{ .whitespace = .indent_2 }, + ); + + try std.testing.expectEqualStrings("Content-Length: 120\r\n\r\n" ++ + \\{ + \\ "jsonrpc": "2.0", + \\ "id": null, + \\ "error": { + \\ "code": -32603, + \\ "message": "my message", + \\ "data": null + \\ } + \\} + , aw.getWritten()); +} pub const minimum_logging_buffer_size: usize = 128; @@ -1676,8 +1549,8 @@ pub fn bufPrintLogMessage( buffer, message_type, struct { - fn format(writer: std.io.AnyWriter, opaque_params: *const anyopaque) void { - std.fmt.format(writer, fmt, @as(*const @TypeOf(args), @alignCast(@ptrCast(opaque_params))).*) catch {}; + fn format(writer: *std.io.Writer, opaque_params: *const anyopaque) std.io.Writer.Error!void { + return writer.print(fmt, @as(*const @TypeOf(args), @alignCast(@ptrCast(opaque_params))).*); } }.format, &args, @@ -1687,88 +1560,110 @@ pub fn bufPrintLogMessage( fn bufPrintLogMessageTypeErased( buffer: []u8, message_type: types.MessageType, - format_fn: *const fn (std.io.AnyWriter, opaque_params: *const anyopaque) void, + format_fn: *const fn (*std.io.Writer, opaque_params: *const anyopaque) std.io.Writer.Error!void, opaque_params: *const anyopaque, ) []u8 { std.debug.assert(buffer.len >= minimum_logging_buffer_size); - const json_message_suffix: []const u8 = "\"}}"; - var fbs = std.io.fixedBufferStream(buffer[0 .. buffer.len - json_message_suffix.len]); - const writer = fbs.writer(); + var writer: std.io.Writer = .fixed(buffer); writer.print( - \\{{"jsonrpc":"2.0","method":"window/logMessage","params":{{"type":{},"message":" - , .{std.json.fmt(message_type, .{})}) catch unreachable; - - const json_writer: std.io.Writer(*std.io.FixedBufferStream([]u8), error{NoSpaceLeft}, jsonWrite) = .{ - .context = &fbs, - }; + \\{{"jsonrpc":"2.0","method":"window/logMessage","params":{{"type":{f},"message":" + , .{parser.jsonFmt(message_type, .{})}) catch unreachable; - format_fn(json_writer.any(), opaque_params); + const json_message_suffix = "\"}}".*; + const ellipses = "...".*; - fbs.buffer = buffer; - fbs.writer().writeAll(json_message_suffix) catch unreachable; + const no_space_left = no_space_left: { + const reserved_trailing_buffer_space = json_message_suffix.len + ellipses.len; + writer.buffer.len -= reserved_trailing_buffer_space; + defer writer.buffer.len += reserved_trailing_buffer_space; - return fbs.getWritten(); -} - -fn jsonWrite(fbs: *std.io.FixedBufferStream([]u8), bytes: []const u8) error{NoSpaceLeft}!usize { - var write_cursor: usize = 0; - var i: usize = 0; - while (i < bytes.len) : (i += 1) { - switch (bytes[i]) { - 0x20...0x21, 0x23...0x5B, 0x5D...0xFF => {}, - 0x00...0x1F, '\\', '\"' => { - try fbsWriteOrEllipses(fbs, bytes[write_cursor..i]); - - // either write an escape code in its entirety or don't at all - const pos = fbs.pos; - errdefer fbs.pos = pos; + var json_transform: JsonTransform = .init(&writer); + format_fn(&json_transform.interface, opaque_params) catch break :no_space_left true; + break :no_space_left false; + }; - const writer = fbs.writer(); + if (no_space_left) (writer.writableArray(ellipses.len) catch undefined).* = ellipses; + (writer.writableArray(json_message_suffix.len) catch undefined).* = json_message_suffix; - switch (bytes[i]) { - '\\' => try writer.writeAll("\\\\"), - '\"' => try writer.writeAll("\\\""), - 0x08 => try writer.writeAll("\\b"), - 0x0C => try writer.writeAll("\\f"), - '\n' => try writer.writeAll("\\n"), - '\r' => try writer.writeAll("\\r"), - '\t' => try writer.writeAll("\\t"), - else => { - try fbs.writer().writeAll("\\u"); - try std.fmt.formatIntValue(bytes[i], "x", std.fmt.FormatOptions{ .width = 4, .fill = '0' }, fbs.writer()); - }, - } + return writer.buffered(); +} - write_cursor = i + 1; +/// `std.json.encodeJsonString` but streaming. Write into `interface` and the +/// escaped JSON string will be written into `out`. +/// +/// The output will only be written into `out.buffer` and will never be +/// drained. So it is best used with `std.io.Writer.fixed` to write into a +/// fixed sized buffer. +const JsonTransform = struct { + out: *std.io.Writer, + interface: std.io.Writer, + + fn init(out: *std.io.Writer) JsonTransform { + return .{ + .out = out, + .interface = .{ + .vtable = &.{ + .drain = drain, + .flush = std.io.Writer.noopFlush, + }, + .buffer = &.{}, }, - } + }; } - try fbsWriteOrEllipses(fbs, bytes[write_cursor..]); - return bytes.len; -} - -fn fbsWriteOrEllipses( - fbs: *std.io.FixedBufferStream([]u8), - bytes: []const u8, -) error{NoSpaceLeft}!void { - const ellipses: []const u8 = "..."; - - const pos_before_write = fbs.pos; - const amt = fbs.write(bytes) catch 0; - if (amt == bytes.len) return; + fn drain(w: *std.io.Writer, data: []const []const u8, splat: usize) std.io.Writer.Error!usize { + if (data.len == 0) return 0; + + const json_transform: *JsonTransform = @fieldParentPtr("interface", w); + const out = json_transform.out; + + var bytes_written: usize = 0; + outer: for (data, 0..) |bytes, i| { + const is_last = i == data.len - 1; + for (0..if (is_last) splat else 1) |_| { + for (bytes) |c| { + switch (c) { + 0x20...0x21, 0x23...0x5B, 0x5D...0xFF => { + const buffer = out.unusedCapacitySlice(); + if (buffer.len < 1) break :outer; + buffer[0] = c; + out.advance(1); + bytes_written += 1; + }, + 0x00...0x1F, '\\', '\"' => { + const encoded: [2]u8 = switch (c) { + '\\' => "\\\\".*, + '\"' => "\\\"".*, + 0x08 => "\\b".*, + 0x0C => "\\f".*, + '\n' => "\\n".*, + '\r' => "\\r".*, + '\t' => "\\t".*, + else => { + const buffer = out.unusedCapacitySlice(); + if (buffer.len < 6) break :outer; + + buffer[0..2].* = "\\u".*; + const amt = std.fmt.printInt(buffer[2..6], c, 10, .lower, .{ .fill = '0', .width = 4 }); + std.debug.assert(amt == 4); + out.advance(6); + bytes_written += 1; + continue; + }, + }; - // try to move the buffer position back so that we have space for the ellipses - fbs.pos = @max( - pos_before_write, // make sure that we don't backtrack beyond an escape code - @min(fbs.pos, fbs.buffer.len - ellipses.len), - ); - if (fbs.buffer.len - fbs.pos >= ellipses.len) { - fbs.writer().writeAll(ellipses) catch unreachable; + (try out.writableArray(encoded.len)).* = encoded; + bytes_written += 1; + }, + } + } + } + } + if (bytes_written == 0) return error.WriteFailed; + return bytes_written; } - return error.NoSpaceLeft; -} +}; test bufPrintLogMessage { var buffer: [1024]u8 = undefined; @@ -1822,7 +1717,7 @@ test "bufPrintLogMessage - avoid buffer overflow with escape codes" { ); try std.testing.expectEqualStrings( - \\{"jsonrpc":"2.0","method":"window/logMessage","params":{"type":42,"message":"\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000"}} + \\{"jsonrpc":"2.0","method":"window/logMessage","params":{"type":42,"message":"\u0000\u0000\u0000\u0000\u0000\u0000\u0000..."}} , json_message); } @@ -2524,23 +2419,23 @@ test "Message - ignore_unknown_fields" { test "Message - stringify emit_null_optional_fields" { try std.testing.expectFmt( \\{"jsonrpc":"2.0","method":"exit"} - , "{}", .{std.json.fmt(ExampleMessage{ .notification = .{ .params = .exit } }, .{ .emit_null_optional_fields = false })}); + , "{f}", .{parser.jsonFmt(ExampleMessage{ .notification = .{ .params = .exit } }, .{ .emit_null_optional_fields = false })}); try std.testing.expectFmt( \\{"jsonrpc":"2.0","method":"exit","params":null} - , "{}", .{std.json.fmt(ExampleMessage{ .notification = .{ .params = .exit } }, .{ .emit_null_optional_fields = true })}); + , "{f}", .{parser.jsonFmt(ExampleMessage{ .notification = .{ .params = .exit } }, .{ .emit_null_optional_fields = true })}); try std.testing.expectFmt( \\{"jsonrpc":"2.0","method":"foo"} - , "{}", .{std.json.fmt(ExampleMessage{ .notification = .{ .params = .{ .other = .{ .method = "foo", .params = null } } } }, .{ .emit_null_optional_fields = false })}); + , "{f}", .{parser.jsonFmt(ExampleMessage{ .notification = .{ .params = .{ .other = .{ .method = "foo", .params = null } } } }, .{ .emit_null_optional_fields = false })}); try std.testing.expectFmt( \\{"jsonrpc":"2.0","method":"foo","params":null} - , "{}", .{std.json.fmt(ExampleMessage{ .notification = .{ .params = .{ .other = .{ .method = "foo", .params = null } } } }, .{ .emit_null_optional_fields = true })}); + , "{f}", .{parser.jsonFmt(ExampleMessage{ .notification = .{ .params = .{ .other = .{ .method = "foo", .params = null } } } }, .{ .emit_null_optional_fields = true })}); try std.testing.expectFmt( \\{"jsonrpc":"2.0","method":"foo","params":null} - , "{}", .{std.json.fmt(ExampleMessage{ .notification = .{ .params = .{ .other = .{ .method = "foo", .params = .null } } } }, .{ .emit_null_optional_fields = false })}); + , "{f}", .{parser.jsonFmt(ExampleMessage{ .notification = .{ .params = .{ .other = .{ .method = "foo", .params = .null } } } }, .{ .emit_null_optional_fields = false })}); try std.testing.expectFmt( \\{"jsonrpc":"2.0","method":"foo","params":null} - , "{}", .{std.json.fmt(ExampleMessage{ .notification = .{ .params = .{ .other = .{ .method = "foo", .params = .null } } } }, .{ .emit_null_optional_fields = true })}); + , "{f}", .{parser.jsonFmt(ExampleMessage{ .notification = .{ .params = .{ .other = .{ .method = "foo", .params = .null } } } }, .{ .emit_null_optional_fields = true })}); } test "Message.Request" { diff --git a/src/main.zig b/src/main.zig index 515339e..8da4c67 100644 --- a/src/main.zig +++ b/src/main.zig @@ -18,13 +18,13 @@ pub fn main() !void { const parsed_meta_model = try std.json.parseFromSlice(MetaModel, gpa, @embedFile("meta-model"), .{}); defer parsed_meta_model.deinit(); - var buffer: std.ArrayListUnmanaged(u8) = .empty; - defer buffer.deinit(gpa); + var aw: std.io.Writer.Allocating = .init(gpa); + defer aw.deinit(); @setEvalBranchQuota(100_000); - try writeMetaModel(buffer.writer(gpa), parsed_meta_model.value); + writeMetaModel(&aw.writer, parsed_meta_model.value) catch return error.OutOfMemory; - const source = try buffer.toOwnedSliceSentinel(gpa, 0); + const source = try aw.toOwnedSliceSentinel(0); defer gpa.free(source); var zig_tree: std.zig.Ast = try .parse(gpa, source, .zig); @@ -44,53 +44,29 @@ pub fn main() !void { try out_file.writeAll(output_source); } -fn formatDocs( +const FormatDocs = struct { text: []const u8, - comptime fmt: []const u8, - options: std.fmt.FormatOptions, - writer: anytype, -) @TypeOf(writer).Error!void { - _ = options; - if (fmt.len != 1) std.fmt.invalidFmtError(fmt, text); - const prefix = switch (fmt[0]) { - 'n' => "// ", - 'd' => "/// ", - '!' => "//! ", - else => std.fmt.invalidFmtError(fmt, text), - }; - var iterator = std.mem.splitScalar(u8, text, '\n'); - while (iterator.next()) |line| try writer.print("{s}{s}\n", .{ prefix, line }); -} + comment_kind: CommentKind, -/// The format specifier must be one of: -/// * `{n}` writes normal (`//`) comments. -/// * `{d}` writes doc-comments (`///`) comments. -/// * `{!}` writes top-level-doc-comments (`//!`) comments. -fn fmtDocs(text: []const u8) std.fmt.Formatter(formatDocs) { - return .{ .data = text }; -} + const CommentKind = enum { + normal, + doc, + top_level, + }; +}; -fn formatQuotedString( - string: []const u8, - comptime fmt: []const u8, - options: std.fmt.FormatOptions, - writer: anytype, -) @TypeOf(writer).Error!void { - _ = options; - if (fmt.len == 0) { - try writer.print("\"{}\"", .{std.zig.fmtEscapes(string)}); - } else if (std.mem.eql(u8, fmt, "'")) { - try writer.print("\'{'}\'", .{std.zig.fmtEscapes(string)}); - } else { - std.fmt.invalidFmtError(fmt, string); - } +fn renderDocs(ctx: FormatDocs, writer: *std.io.Writer) std.io.Writer.Error!void { + const prefix = switch (ctx.comment_kind) { + .normal => "// ", + .doc => "/// ", + .top_level => "//! ", + }; + var iterator = std.mem.splitScalar(u8, ctx.text, '\n'); + while (iterator.next()) |line| try writer.print("{s}{s}\n", .{ prefix, line }); } -/// The format specifier must be one of: -/// * `{}` writes a double-quoted string. -/// * `{''}` writes a single-quoted string. -fn fmtQuotedString(string: []const u8) std.fmt.Formatter(formatQuotedString) { - return .{ .data = string }; +fn fmtDocs(text: []const u8, comment_kind: FormatDocs.CommentKind) std.fmt.Formatter(FormatDocs, renderDocs) { + return .{ .data = .{ .text = text, .comment_kind = comment_kind } }; } fn messageDirectionName(message_direction: MetaModel.MessageDirection) []const u8 { @@ -101,7 +77,7 @@ fn messageDirectionName(message_direction: MetaModel.MessageDirection) []const u }; } -fn guessTypeName(meta_model: MetaModel, writer: anytype, typ: MetaModel.Type, i: usize) @TypeOf(writer).Error!void { +fn guessTypeName(meta_model: MetaModel, writer: *std.io.Writer, typ: MetaModel.Type, i: usize) std.io.Writer.Error!void { switch (typ) { .base => |base| switch (base.name) { .URI => try writer.writeAll("uri"), @@ -114,7 +90,7 @@ fn guessTypeName(meta_model: MetaModel, writer: anytype, typ: MetaModel.Type, i: .boolean => try writer.writeAll("bool"), .null => try writer.writeAll("@\"null\""), }, - .reference => |ref| try writer.print("{p}", .{std.zig.fmtId(ref.name)}), + .reference => |ref| try writer.print("{f}", .{std.zig.fmtId(ref.name)}), .array => |arr| { try writer.writeAll("array_of_"); try guessTypeName(meta_model, writer, arr.element.*, 0); @@ -144,15 +120,13 @@ fn isTypeNull(typ: MetaModel.Type) bool { return (ort.items.len == 2 and ort.items[1] == .base and ort.items[1].base.name == .null) or (ort.items[ort.items.len - 1] == .base and ort.items[ort.items.len - 1].base.name == .null); } -fn formatType( - data: struct { meta_model: *const MetaModel, ty: MetaModel.Type }, - comptime fmt: []const u8, - options: std.fmt.FormatOptions, - writer: anytype, -) @TypeOf(writer).Error!void { - _ = options; - if (fmt.len != 0) std.fmt.invalidFmtError(fmt, data.ty); - switch (data.ty) { +const FormatType = struct { + meta_model: *const MetaModel, + ty: MetaModel.Type, +}; + +fn renderType(ctx: FormatType, writer: *std.io.Writer) std.io.Writer.Error!void { + switch (ctx.ty) { .base => |base| switch (base.name) { .URI => try writer.writeAll("URI"), .DocumentUri => try writer.writeAll("DocumentUri"), @@ -164,8 +138,8 @@ fn formatType( .boolean => try writer.writeAll("bool"), .null => try writer.writeAll("?void"), }, - .reference => |ref| try writer.print("{p}", .{std.zig.fmtId(ref.name)}), - .array => |arr| try writer.print("[]const {}", .{fmtType(arr.element.*, data.meta_model)}), + .reference => |ref| try writer.print("{f}", .{std.zig.fmtId(ref.name)}), + .array => |arr| try writer.print("[]const {f}", .{fmtType(arr.element.*, ctx.meta_model)}), .map => |map| { try writer.writeAll("parser.Map("); switch (map.key) { @@ -175,17 +149,17 @@ fn formatType( .integer => writer.writeAll("i32"), .string => writer.writeAll("[]const u8"), }, - .reference => |ref| try writer.print("{}", .{fmtType(.{ .reference = ref }, data.meta_model)}), + .reference => |ref| try writer.print("{f}", .{fmtType(.{ .reference = ref }, ctx.meta_model)}), } - try writer.print(", {})", .{fmtType(map.value.*, data.meta_model)}); + try writer.print(", {f})", .{fmtType(map.value.*, ctx.meta_model)}); }, .@"and" => |andt| { try writer.writeAll("struct {\n"); for (andt.items) |item| { if (item != .reference) @panic("Unimplemented and subject encountered!"); - try writer.print("// And {s}\n{}\n\n", .{ + try writer.print("// And {s}\n{f}\n\n", .{ item.reference.name, - fmtReference(item.reference, null, data.meta_model), + fmtReference(item.reference, null, ctx.meta_model), }); } try writer.writeAll("}"); @@ -195,7 +169,7 @@ fn formatType( // There are no triple optional ors (I believe), // so this should work every time if (ort.items.len == 2 and ort.items[1] == .base and ort.items[1].base.name == .null) { - try writer.print("?{}", .{fmtType(ort.items[0], data.meta_model)}); + try writer.print("?{f}", .{fmtType(ort.items[0], ctx.meta_model)}); } else if (isOrActuallyEnum(ort)) { try writer.writeAll("enum {"); for (ort.items) |sub_type| { @@ -209,8 +183,8 @@ fn formatType( try writer.writeAll("union(enum) {\n"); for (ort.items[0..if (has_null) ort.items.len - 1 else ort.items.len], 0..) |sub_type, i| { - try guessTypeName(data.meta_model.*, writer, sub_type, i); - try writer.print(": {},\n", .{fmtType(sub_type, data.meta_model)}); + try guessTypeName(ctx.meta_model.*, writer, sub_type, i); + try writer.print(": {f},\n", .{fmtType(sub_type, ctx.meta_model)}); } try writer.writeAll( \\pub const jsonParse = parser.UnionParser(@This()).jsonParse; @@ -224,7 +198,7 @@ fn formatType( try writer.writeAll("struct {"); for (tup.items, 0..) |ty, i| { if (i != 0) try writer.writeByte(','); - try writer.print(" {}", .{fmtType(ty, data.meta_model)}); + try writer.print(" {f}", .{fmtType(ty, ctx.meta_model)}); } try writer.writeAll(" }"); }, @@ -232,69 +206,60 @@ fn formatType( try writer.writeAll("struct {"); if (lit.value.properties.len != 0) { for (lit.value.properties) |property| { - try writer.print("\n{}", .{fmtProperty(property, data.meta_model)}); + try writer.print("\n{f}", .{fmtProperty(property, ctx.meta_model)}); } try writer.writeByte('\n'); } try writer.writeByte('}'); }, - .stringLiteral => |lit| try writer.print("[]const u8 = \"{}\"", .{std.zig.fmtEscapes(lit.value)}), + .stringLiteral => |lit| try writer.print("[]const u8 = \"{f}\"", .{std.zig.fmtString(lit.value)}), .integerLiteral => |lit| try writer.print("i32 = {d}", .{lit.value}), .booleanLiteral => |lit| try writer.print("bool = {}", .{lit.value}), } } -fn fmtType(ty: MetaModel.Type, meta_model: *const MetaModel) std.fmt.Formatter(formatType) { +fn fmtType(ty: MetaModel.Type, meta_model: *const MetaModel) std.fmt.Formatter(FormatType, renderType) { return .{ .data = .{ .meta_model = meta_model, .ty = ty } }; } -fn formatProperty( - data: struct { meta_model: *const MetaModel, property: MetaModel.Property }, - comptime fmt: []const u8, - options: std.fmt.FormatOptions, - writer: anytype, -) @TypeOf(writer).Error!void { - if (fmt.len != 0) std.fmt.invalidFmtError(fmt, data.value); - _ = options; - - const isUndefinedable = data.property.optional orelse false; - const isNull = isTypeNull(data.property.type); +const FormatProperty = struct { + meta_model: *const MetaModel, + property: MetaModel.Property, +}; + +fn renderProperty(ctx: FormatProperty, writer: *std.io.Writer) std.io.Writer.Error!void { + const isUndefinedable = ctx.property.optional orelse false; + const isNull = isTypeNull(ctx.property.type); // WORKAROUND: recursive SelectionRange - const isSelectionRange = data.property.type == .reference and std.mem.eql(u8, data.property.type.reference.name, "SelectionRange"); + const isSelectionRange = ctx.property.type == .reference and std.mem.eql(u8, ctx.property.type.reference.name, "SelectionRange"); - if (data.property.documentation) |docs| try writer.print("{d}", .{fmtDocs(docs)}); + if (ctx.property.documentation) |docs| try writer.print("{f}", .{fmtDocs(docs, .doc)}); - try writer.print("{p}: {s}{}{s},", .{ - std.zig.fmtId(data.property.name), + try writer.print("{f}: {s}{f}{s},", .{ + std.zig.fmtIdPU(ctx.property.name), if (isSelectionRange) "?*" else if (isUndefinedable and !isNull) "?" else "", - fmtType(data.property.type, data.meta_model), + fmtType(ctx.property.type, ctx.meta_model), if (isNull or isUndefinedable) " = null" else "", }); } -fn fmtProperty(property: MetaModel.Property, meta_model: *const MetaModel) std.fmt.Formatter(formatProperty) { +fn fmtProperty(property: MetaModel.Property, meta_model: *const MetaModel) std.fmt.Formatter(FormatProperty, renderProperty) { return .{ .data = .{ .meta_model = meta_model, .property = property } }; } -fn formatProperties( - data: struct { - meta_model: *const MetaModel, - structure: MetaModel.Structure, - maybe_extender: ?MetaModel.Structure, - }, - comptime fmt: []const u8, - options: std.fmt.FormatOptions, - writer: anytype, -) @TypeOf(writer).Error!void { - if (fmt.len != 0) std.fmt.invalidFmtError(fmt, data.value); - _ = options; - - const properties: []MetaModel.Property = data.structure.properties; - const extends: []MetaModel.Type = data.structure.extends orelse &.{}; - const mixins: []MetaModel.Type = data.structure.mixins orelse &.{}; +const FormatProperties = struct { + meta_model: *const MetaModel, + structure: MetaModel.Structure, + maybe_extender: ?MetaModel.Structure, +}; + +fn renderProperties(ctx: FormatProperties, writer: *std.io.Writer) std.io.Writer.Error!void { + const properties: []MetaModel.Property = ctx.structure.properties; + const extends: []MetaModel.Type = ctx.structure.extends orelse &.{}; + const mixins: []MetaModel.Type = ctx.structure.mixins orelse &.{}; skip: for (properties) |property| { - if (data.maybe_extender) |ext| { + if (ctx.maybe_extender) |ext| { for (ext.properties) |ext_property| { if (std.mem.eql(u8, property.name, ext_property.name)) { // std.log.info("Skipping implemented field emission: {s}", .{property.name}); @@ -302,135 +267,137 @@ fn formatProperties( } } } - try writer.print("\n{}", .{fmtProperty(property, data.meta_model)}); + try writer.print("\n{f}", .{fmtProperty(property, ctx.meta_model)}); } for (extends) |ext| { if (ext != .reference) @panic("Expected reference for extends!"); - try writer.print("\n\n// Extends `{s}`{}", .{ + try writer.print("\n\n// Extends `{s}`{f}", .{ ext.reference.name, - fmtReference(ext.reference, data.structure, data.meta_model), + fmtReference(ext.reference, ctx.structure, ctx.meta_model), }); } for (mixins) |ext| { if (ext != .reference) @panic("Expected reference for mixin!"); - try writer.print("\n\n// Uses mixin `{s}`{}", .{ + try writer.print("\n\n// Uses mixin `{s}`{f}", .{ ext.reference.name, - fmtReference(ext.reference, data.structure, data.meta_model), + fmtReference(ext.reference, ctx.structure, ctx.meta_model), }); } } -fn fmtProperties(structure: MetaModel.Structure, maybe_extender: ?MetaModel.Structure, meta_model: *const MetaModel) std.fmt.Formatter(formatProperties) { +fn fmtProperties( + structure: MetaModel.Structure, + maybe_extender: ?MetaModel.Structure, + meta_model: *const MetaModel, +) std.fmt.Formatter(FormatProperties, renderProperties) { return .{ .data = .{ .meta_model = meta_model, .structure = structure, .maybe_extender = maybe_extender } }; } -fn formatReference( - data: struct { - meta_model: *const MetaModel, - reference: MetaModel.ReferenceType, - maybe_extender: ?MetaModel.Structure, - }, - comptime fmt: []const u8, - options: std.fmt.FormatOptions, - writer: anytype, -) @TypeOf(writer).Error!void { - if (fmt.len != 0) std.fmt.invalidFmtError(fmt, data.reference); - _ = options; - - for (data.meta_model.structures) |s| { - if (std.mem.eql(u8, s.name, data.reference.name)) { - try writer.print("{}", .{fmtProperties(s, data.maybe_extender, data.meta_model)}); +const FormatReference = struct { + meta_model: *const MetaModel, + reference: MetaModel.ReferenceType, + maybe_extender: ?MetaModel.Structure, +}; + +fn renderReference(ctx: FormatReference, writer: *std.io.Writer) std.io.Writer.Error!void { + for (ctx.meta_model.structures) |s| { + if (std.mem.eql(u8, s.name, ctx.reference.name)) { + try writer.print("{f}", .{fmtProperties(s, ctx.maybe_extender, ctx.meta_model)}); return; } } } -fn fmtReference(reference: MetaModel.ReferenceType, maybe_extender: ?MetaModel.Structure, meta_model: *const MetaModel) std.fmt.Formatter(formatReference) { +fn fmtReference( + reference: MetaModel.ReferenceType, + maybe_extender: ?MetaModel.Structure, + meta_model: *const MetaModel, +) std.fmt.Formatter(FormatReference, renderReference) { return .{ .data = .{ .meta_model = meta_model, .reference = reference, .maybe_extender = maybe_extender } }; } -fn writeRequest(writer: anytype, meta_model: MetaModel, request: MetaModel.Request) @TypeOf(writer).Error!void { - if (request.documentation) |docs| try writer.print("{n}", .{fmtDocs(docs)}); +fn writeRequest(writer: *std.io.Writer, meta_model: MetaModel, request: MetaModel.Request) std.io.Writer.Error!void { + if (request.documentation) |docs| try writer.print("{f}", .{fmtDocs(docs, .normal)}); try writer.print( \\.{{ \\ .method = "{s}", - \\ .documentation = {?}, + \\ .documentation = {?f}, \\ .direction = .{s}, - \\ .Params = {?}, - \\ .Result = {}, - \\ .PartialResult = {?}, - \\ .ErrorData = {?}, - \\ .registration = .{{ .method = {?}, .Options = {?} }}, + \\ .Params = {?f}, + \\ .Result = {f}, + \\ .PartialResult = {?f}, + \\ .ErrorData = {?f}, + \\ .registration = .{{ .method = {?f}, .Options = {?f} }}, \\}}, \\ , .{ request.method, - if (request.documentation) |documentation| fmtQuotedString(documentation) else null, + if (request.documentation) |documentation| jsonFmt(documentation, .{}) else null, messageDirectionName(request.messageDirection), // NOTE: Multiparams not used here, so we dont have to implement them :) if (request.params) |params| fmtType(params.Type, &meta_model) else null, fmtType(request.result, &meta_model), if (request.partialResult) |ty| fmtType(ty, &meta_model) else null, if (request.errorData) |ty| fmtType(ty, &meta_model) else null, - if (request.registrationMethod) |method| fmtQuotedString(method) else null, + if (request.registrationMethod) |method| jsonFmt(method, .{}) else null, if (request.registrationOptions) |ty| fmtType(ty, &meta_model) else null, }); } -fn writeNotification(writer: anytype, meta_model: MetaModel, notification: MetaModel.Notification) @TypeOf(writer).Error!void { - if (notification.documentation) |docs| try writer.print("{n}", .{fmtDocs(docs)}); +fn writeNotification(writer: *std.io.Writer, meta_model: MetaModel, notification: MetaModel.Notification) std.io.Writer.Error!void { + if (notification.documentation) |docs| try writer.print("{f}", .{fmtDocs(docs, .normal)}); try writer.print( \\.{{ \\ .method = "{s}", - \\ .documentation = {?}, + \\ .documentation = {?f}, \\ .direction = .{s}, - \\ .Params = {?}, - \\ .registration = .{{ .method = {?}, .Options = {?} }}, + \\ .Params = {?f}, + \\ .registration = .{{ .method = {?f}, .Options = {?f} }}, \\}}, \\ , .{ notification.method, - if (notification.documentation) |documentation| fmtQuotedString(documentation) else null, + if (notification.documentation) |documentation| jsonFmt(documentation, .{}) else null, messageDirectionName(notification.messageDirection), // NOTE: Multiparams not used here, so we dont have to implement them :) if (notification.params) |params| fmtType(params.Type, &meta_model) else null, - if (notification.registrationMethod) |method| fmtQuotedString(method) else null, + if (notification.registrationMethod) |method| jsonFmt(method, .{}) else null, if (notification.registrationOptions) |ty| fmtType(ty, &meta_model) else null, }); } -fn writeStructure(writer: anytype, meta_model: MetaModel, structure: MetaModel.Structure) @TypeOf(writer).Error!void { +fn writeStructure(writer: *std.io.Writer, meta_model: MetaModel, structure: MetaModel.Structure) std.io.Writer.Error!void { if (std.mem.eql(u8, structure.name, "LSPObject")) return; - if (structure.documentation) |docs| try writer.print("{d}", .{fmtDocs(docs)}); - try writer.print("pub const {p} = struct {{{}\n}};\n\n", .{ + if (structure.documentation) |docs| try writer.print("{f}", .{fmtDocs(docs, .doc)}); + try writer.print("pub const {f} = struct {{{f}\n}};\n\n", .{ std.zig.fmtId(structure.name), fmtProperties(structure, null, &meta_model), }); } -fn writeEnumeration(writer: anytype, meta_model: MetaModel, enumeration: MetaModel.Enumeration) @TypeOf(writer).Error!void { +fn writeEnumeration(writer: *std.io.Writer, meta_model: MetaModel, enumeration: MetaModel.Enumeration) std.io.Writer.Error!void { _ = meta_model; - if (enumeration.documentation) |docs| try writer.print("{d}", .{fmtDocs(docs)}); + if (enumeration.documentation) |docs| try writer.print("{f}", .{fmtDocs(docs, .doc)}); const container_kind = switch (enumeration.type.name) { .string => "union(enum)", .integer => "enum(i32)", .uinteger => "enum(u32)", }; - try writer.print("pub const {p} = {s} {{\n", .{ std.zig.fmtId(enumeration.name), container_kind }); + try writer.print("pub const {f} = {s} {{\n", .{ std.zig.fmtId(enumeration.name), container_kind }); // WORKAROUND: the enumeration value `pascal` appears twice in LanguageKind var found_pascal = false; var contains_empty_enum = false; for (enumeration.values) |entry| { - if (entry.documentation) |docs| try writer.print("{d}", .{fmtDocs(docs)}); + if (entry.documentation) |docs| try writer.print("{f}", .{fmtDocs(docs, .doc)}); switch (entry.value) { .string => |value| { if (std.mem.eql(u8, value, "pascal")) { @@ -439,9 +406,9 @@ fn writeEnumeration(writer: anytype, meta_model: MetaModel, enumeration: MetaMod } if (value.len == 0) contains_empty_enum = true; const name = if (value.len == 0) "empty" else value; - try writer.print("{p},\n", .{std.zig.fmtId(name)}); + try writer.print("{f},\n", .{std.zig.fmtIdP(name)}); }, - .number => |value| try writer.print("{p} = {d},\n", .{ std.zig.fmtId(entry.name), value }), + .number => |value| try writer.print("{f} = {d},\n", .{ std.zig.fmtIdP(entry.name), value }), } } @@ -472,14 +439,14 @@ fn writeEnumeration(writer: anytype, meta_model: MetaModel, enumeration: MetaMod try writer.writeAll("};\n\n"); } -fn writeTypeAlias(writer: anytype, meta_model: MetaModel, type_alias: MetaModel.TypeAlias) @TypeOf(writer).Error!void { +fn writeTypeAlias(writer: *std.io.Writer, meta_model: MetaModel, type_alias: MetaModel.TypeAlias) std.io.Writer.Error!void { if (std.mem.startsWith(u8, type_alias.name, "LSP")) return; - if (type_alias.documentation) |docs| try writer.print("{d}", .{fmtDocs(docs)}); - try writer.print("pub const {p} = {};\n\n", .{ std.zig.fmtId(type_alias.name), fmtType(type_alias.type, &meta_model) }); + if (type_alias.documentation) |docs| try writer.print("{f}", .{fmtDocs(docs, .doc)}); + try writer.print("pub const {f} = {f};\n\n", .{ std.zig.fmtId(type_alias.name), fmtType(type_alias.type, &meta_model) }); } -fn writeMetaModel(writer: anytype, meta_model: MetaModel) !void { +fn writeMetaModel(writer: *std.io.Writer, meta_model: MetaModel) std.io.Writer.Error!void { try writer.writeAll(@embedFile("lsp_types_base.zig") ++ "\n"); try writer.writeAll("// Type Aliases\n\n"); @@ -509,3 +476,23 @@ fn writeMetaModel(writer: anytype, meta_model: MetaModel) !void { } try writer.writeAll("};\n"); } + +/// Like `std.json.fmt` but supports `std.io.Writer`. +pub fn jsonFmt(value: anytype, options: std.json.StringifyOptions) std.fmt.Alt(FormatJson(@TypeOf(value)), FormatJson(@TypeOf(value)).format) { + return .{ .data = .{ .value = value, .options = options } }; +} + +fn FormatJson(comptime T: type) type { + return struct { + value: T, + options: std.json.StringifyOptions, + + pub fn format(data: @This(), writer: *std.io.Writer) std.io.Writer.Error!void { + const any_writer: std.io.AnyWriter = .{ + .context = @ptrCast(writer), + .writeFn = @ptrCast(&std.io.Writer.write), + }; + std.json.stringify(data.value, data.options, any_writer) catch |err| return @errorCast(err); + } + }; +} diff --git a/src/parser.zig b/src/parser.zig index 622400e..ebcdbd6 100644 --- a/src/parser.zig +++ b/src/parser.zig @@ -2,6 +2,26 @@ const std = @import("std"); +/// Like `std.json.fmt` but supports `std.io.Writer`. +pub fn jsonFmt(value: anytype, options: std.json.StringifyOptions) std.fmt.Alt(FormatJson(@TypeOf(value)), FormatJson(@TypeOf(value)).format) { + return .{ .data = .{ .value = value, .options = options } }; +} + +fn FormatJson(comptime T: type) type { + return struct { + value: T, + options: std.json.StringifyOptions, + + pub fn format(data: @This(), writer: *std.io.Writer) std.io.Writer.Error!void { + const any_writer: std.io.AnyWriter = .{ + .context = writer, + .writeFn = @ptrCast(&std.io.Writer.write), // cast discards const qualifier but should be fine. I hope + }; + std.json.stringify(data.value, data.options, any_writer) catch |err| return @errorCast(err); + } + }; +} + pub fn Map(comptime Key: type, comptime Value: type) type { if (Key != []const u8) @compileError("TODO support non string Key's"); return std.json.ArrayHashMap(Value); @@ -100,8 +120,8 @@ test "UnionParser.jsonStringify" { pub const jsonStringify = UnionParser(@This()).jsonStringify; }; - try std.testing.expectFmt("5", "{}", .{std.json.fmt(U{ .number = 5 }, .{})}); - try std.testing.expectFmt("\"foo\"", "{}", .{std.json.fmt(U{ .string = "foo" }, .{})}); + try std.testing.expectFmt("5", "{f}", .{jsonFmt(U{ .number = 5 }, .{})}); + try std.testing.expectFmt("\"foo\"", "{f}", .{jsonFmt(U{ .string = "foo" }, .{})}); } pub fn EnumCustomStringValues(comptime T: type, comptime contains_empty_enum: bool) type { @@ -184,11 +204,11 @@ test EnumCustomStringValues { pub const jsonStringify = EnumCustomStringValues(@This(), false).jsonStringify; }; - try std.testing.expectFmt("\"foo\"", "{}", .{std.json.fmt(E{ .foo = {} }, .{})}); - try std.testing.expectFmt("\"bar\"", "{}", .{std.json.fmt(E{ .bar = {} }, .{})}); - try std.testing.expectFmt("\"baz\"", "{}", .{std.json.fmt(E{ .baz = {} }, .{})}); - try std.testing.expectFmt("\"\"", "{}", .{std.json.fmt(E{ .custom_value = "" }, .{})}); - try std.testing.expectFmt("\"boo\"", "{}", .{std.json.fmt(E{ .custom_value = "boo" }, .{})}); + try std.testing.expectFmt("\"foo\"", "{f}", .{jsonFmt(E{ .foo = {} }, .{})}); + try std.testing.expectFmt("\"bar\"", "{f}", .{jsonFmt(E{ .bar = {} }, .{})}); + try std.testing.expectFmt("\"baz\"", "{f}", .{jsonFmt(E{ .baz = {} }, .{})}); + try std.testing.expectFmt("\"\"", "{f}", .{jsonFmt(E{ .custom_value = "" }, .{})}); + try std.testing.expectFmt("\"boo\"", "{f}", .{jsonFmt(E{ .custom_value = "boo" }, .{})}); try expectParseEqual(E, E.foo, "\"foo\""); try expectParseEqual(E, E.bar, "\"bar\""); @@ -219,13 +239,13 @@ test EnumCustomStringValues { const boo: E = .{ .unknown_value = "boo" }; const zoo: E = .{ .unknown_value = "zoo" }; - try std.testing.expectFmt("\"foo\"", "{}", .{std.json.fmt(foo, .{})}); - try std.testing.expectFmt("\"bar\"", "{}", .{std.json.fmt(bar, .{})}); - try std.testing.expectFmt("\"baz\"", "{}", .{std.json.fmt(baz, .{})}); - try std.testing.expectFmt("\"\"", "{}", .{std.json.fmt(empty, .{})}); - try std.testing.expectFmt("\"\"", "{}", .{std.json.fmt(custom_empty, .{})}); - try std.testing.expectFmt("\"boo\"", "{}", .{std.json.fmt(boo, .{})}); - try std.testing.expectFmt("\"zoo\"", "{}", .{std.json.fmt(zoo, .{})}); + try std.testing.expectFmt("\"foo\"", "{f}", .{jsonFmt(foo, .{})}); + try std.testing.expectFmt("\"bar\"", "{f}", .{jsonFmt(bar, .{})}); + try std.testing.expectFmt("\"baz\"", "{f}", .{jsonFmt(baz, .{})}); + try std.testing.expectFmt("\"\"", "{f}", .{jsonFmt(empty, .{})}); + try std.testing.expectFmt("\"\"", "{f}", .{jsonFmt(custom_empty, .{})}); + try std.testing.expectFmt("\"boo\"", "{f}", .{jsonFmt(boo, .{})}); + try std.testing.expectFmt("\"zoo\"", "{f}", .{jsonFmt(zoo, .{})}); try expectParseEqual(E, foo, "\"foo\""); try expectParseEqual(E, bar, "\"bar\""); @@ -267,10 +287,10 @@ test EnumCustomStringValues { const true_empty: E = .{ .custom_value = "" }; const boo: E = .{ .custom_value = "boo" }; - try std.testing.expectFmt("\"foo\"", "{}", .{std.json.fmt(foo, .{})}); - try std.testing.expectFmt("\"empty\"", "{}", .{std.json.fmt(empty, .{})}); - try std.testing.expectFmt("\"\"", "{}", .{std.json.fmt(true_empty, .{})}); - try std.testing.expectFmt("\"boo\"", "{}", .{std.json.fmt(boo, .{})}); + try std.testing.expectFmt("\"foo\"", "{f}", .{jsonFmt(foo, .{})}); + try std.testing.expectFmt("\"empty\"", "{f}", .{jsonFmt(empty, .{})}); + try std.testing.expectFmt("\"\"", "{f}", .{jsonFmt(true_empty, .{})}); + try std.testing.expectFmt("\"boo\"", "{f}", .{jsonFmt(boo, .{})}); try expectParseEqual(E, foo, "\"foo\""); try expectParseEqual(E, empty, "\"empty\""); @@ -306,9 +326,9 @@ test EnumStringifyAsInt { baz, pub const jsonStringify = EnumStringifyAsInt(@This()).jsonStringify; }; - try std.testing.expectFmt("0", "{}", .{std.json.fmt(E.foo, .{})}); - try std.testing.expectFmt("1", "{}", .{std.json.fmt(E.bar, .{})}); - try std.testing.expectFmt("2", "{}", .{std.json.fmt(E.baz, .{})}); + try std.testing.expectFmt("0", "{f}", .{jsonFmt(E.foo, .{})}); + try std.testing.expectFmt("1", "{f}", .{jsonFmt(E.bar, .{})}); + try std.testing.expectFmt("2", "{f}", .{jsonFmt(E.baz, .{})}); } { @@ -318,9 +338,9 @@ test EnumStringifyAsInt { baz = 5, pub const jsonStringify = EnumStringifyAsInt(@This()).jsonStringify; }; - try std.testing.expectFmt("2", "{}", .{std.json.fmt(E.foo, .{})}); - try std.testing.expectFmt("3", "{}", .{std.json.fmt(E.bar, .{})}); - try std.testing.expectFmt("5", "{}", .{std.json.fmt(E.baz, .{})}); + try std.testing.expectFmt("2", "{f}", .{jsonFmt(E.foo, .{})}); + try std.testing.expectFmt("3", "{f}", .{jsonFmt(E.bar, .{})}); + try std.testing.expectFmt("5", "{f}", .{jsonFmt(E.baz, .{})}); } { @@ -331,22 +351,19 @@ test EnumStringifyAsInt { _, pub const jsonStringify = EnumStringifyAsInt(@This()).jsonStringify; }; - try std.testing.expectFmt("0", "{}", .{std.json.fmt(E.foo, .{})}); - try std.testing.expectFmt("3", "{}", .{std.json.fmt(E.bar, .{})}); - try std.testing.expectFmt("4", "{}", .{std.json.fmt(E.baz, .{})}); - try std.testing.expectFmt("7", "{}", .{std.json.fmt(@as(E, @enumFromInt(7)), .{})}); + try std.testing.expectFmt("0", "{f}", .{jsonFmt(E.foo, .{})}); + try std.testing.expectFmt("3", "{f}", .{jsonFmt(E.bar, .{})}); + try std.testing.expectFmt("4", "{f}", .{jsonFmt(E.baz, .{})}); + try std.testing.expectFmt("7", "{f}", .{jsonFmt(@as(E, @enumFromInt(7)), .{})}); } } fn expectParseEqual(comptime T: type, comptime expected: anytype, s: []const u8) !void { - var arena_allocator = std.heap.ArenaAllocator.init(std.testing.allocator); + var arena_allocator: std.heap.ArenaAllocator = .init(std.testing.allocator); defer arena_allocator.deinit(); const arena = arena_allocator.allocator(); - const std_builtin_type_rename = comptime std.SemanticVersion.parse("0.14.0-dev.1346+31fef6f11") catch unreachable; - const error_set_tag = comptime if (@import("builtin").zig_version.order(std_builtin_type_rename) == .lt) .ErrorSet else .error_set; - - if (@typeInfo(@TypeOf(expected)) != error_set_tag) { + if (@typeInfo(@TypeOf(expected)) != .error_set) { const actual_from_slice = try std.json.parseFromSliceLeaky(T, arena, s, .{}); try std.testing.expectEqualDeep(@as(T, expected), actual_from_slice);