Skip to content

Commit f4f778b

Browse files
committed
add transform for assertEqual -> == in code blocks
1 parent 4b47b3b commit f4f778b

File tree

2 files changed

+70
-10
lines changed

2 files changed

+70
-10
lines changed

tools/bin/main.ml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,12 @@ let main () =
5858
| ["-h"] | ["--help"] -> logAndExit (Ok formatDocstringsHelp)
5959
| path :: args -> (
6060
let isStdout = List.mem "--stdout" args in
61+
let transformAssertEqual = List.mem "--transform-assert-equal" args in
6162
let outputMode = if isStdout then `Stdout else `File in
6263
Clflags.color := Some Misc.Color.Never;
6364
match
6465
( Tools.FormatCodeblocks.formatCodeBlocksInFile ~outputMode
65-
~entryPointFile:path,
66+
~transformAssertEqual ~entryPointFile:path,
6667
outputMode )
6768
with
6869
| Ok content, `Stdout -> print_endline content

tools/src/tools.ml

Lines changed: 68 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -678,8 +678,56 @@ let extractEmbedded ~extensionPoints ~filename =
678678
|> List.rev |> array
679679

680680
module FormatCodeblocks = struct
681-
let formatRescriptCodeBlocks content ~displayFilename ~addError
682-
~markdownBlockStartLine =
681+
module Transform = struct
682+
type transform = AssertEqualFnToEquals (** assertEqual(a, b) -> a == b *)
683+
684+
(** Transforms for the code blocks themselves. *)
685+
let transform ~transforms ast =
686+
match transforms with
687+
| [] -> ast
688+
| transforms ->
689+
let hasTransform transform = transforms |> List.mem transform in
690+
let mapper =
691+
{
692+
Ast_mapper.default_mapper with
693+
expr =
694+
(fun mapper exp ->
695+
match exp.pexp_desc with
696+
| Pexp_apply
697+
{
698+
funct =
699+
{
700+
pexp_desc =
701+
Pexp_ident
702+
({txt = Lident "assertEqual"} as identTxt);
703+
} as ident;
704+
partial = false;
705+
args = [(Nolabel, _arg1); (Nolabel, _arg2)] as args;
706+
}
707+
when hasTransform AssertEqualFnToEquals ->
708+
{
709+
exp with
710+
pexp_desc =
711+
Pexp_apply
712+
{
713+
funct =
714+
{
715+
ident with
716+
pexp_desc =
717+
Pexp_ident {identTxt with txt = Lident "=="};
718+
};
719+
args;
720+
partial = false;
721+
transformed_jsx = false;
722+
};
723+
}
724+
| _ -> Ast_mapper.default_mapper.expr mapper exp);
725+
}
726+
in
727+
mapper.structure mapper ast
728+
end
729+
let formatRescriptCodeBlocks content ~transformAssertEqual ~displayFilename
730+
~addError ~markdownBlockStartLine =
683731
let open Cmarkit in
684732
(* Detect ReScript code blocks. *)
685733
let hadCodeBlocks = ref false in
@@ -718,6 +766,12 @@ module FormatCodeblocks = struct
718766
addError (Buffer.contents buf);
719767
code)
720768
else
769+
let parsetree =
770+
if transformAssertEqual then
771+
Transform.transform ~transforms:[AssertEqualFnToEquals]
772+
parsetree
773+
else parsetree
774+
in
721775
Res_printer.print_implementation
722776
~width:Res_multi_printer.default_print_width parsetree ~comments
723777
|> String.trim |> Block_line.list_of_string
@@ -737,7 +791,7 @@ module FormatCodeblocks = struct
737791
in
738792
(newContent, !hadCodeBlocks)
739793

740-
let formatCodeBlocksInFile ~outputMode ~entryPointFile =
794+
let formatCodeBlocksInFile ~outputMode ~transformAssertEqual ~entryPointFile =
741795
let path =
742796
match Filename.is_relative entryPointFile with
743797
| true -> Unix.realpath entryPointFile
@@ -746,7 +800,7 @@ module FormatCodeblocks = struct
746800
let errors = ref [] in
747801
let addError error = errors := error :: !errors in
748802

749-
let makeMapper ~displayFilename =
803+
let makeMapper ~transformAssertEqual ~displayFilename =
750804
{
751805
Ast_mapper.default_mapper with
752806
attribute =
@@ -756,7 +810,8 @@ module FormatCodeblocks = struct
756810
Some (contents, None),
757811
PStr [{pstr_desc = Pstr_eval ({pexp_loc}, _)}] ) ->
758812
let formattedContents, hadCodeBlocks =
759-
formatRescriptCodeBlocks ~addError ~displayFilename
813+
formatRescriptCodeBlocks ~transformAssertEqual ~addError
814+
~displayFilename
760815
~markdownBlockStartLine:pexp_loc.loc_start.pos_lnum contents
761816
in
762817
if hadCodeBlocks && formattedContents <> contents then
@@ -783,8 +838,8 @@ module FormatCodeblocks = struct
783838
in
784839
let displayFilename = Filename.basename path in
785840
let formattedContents, hadCodeBlocks =
786-
formatRescriptCodeBlocks ~addError ~displayFilename
787-
~markdownBlockStartLine:1 content
841+
formatRescriptCodeBlocks ~transformAssertEqual ~addError
842+
~displayFilename ~markdownBlockStartLine:1 content
788843
in
789844
if hadCodeBlocks && formattedContents <> content then
790845
Ok (formattedContents, content)
@@ -797,7 +852,9 @@ module FormatCodeblocks = struct
797852
parser ~filename:path
798853
in
799854
let filename = Filename.basename filename in
800-
let mapper = makeMapper ~displayFilename:filename in
855+
let mapper =
856+
makeMapper ~transformAssertEqual ~displayFilename:filename
857+
in
801858
let astMapped = mapper.structure mapper structure in
802859
Ok
803860
( Res_printer.print_implementation
@@ -810,7 +867,9 @@ module FormatCodeblocks = struct
810867
let {Res_driver.parsetree = signature; comments; source; filename} =
811868
parser ~filename:path
812869
in
813-
let mapper = makeMapper ~displayFilename:filename in
870+
let mapper =
871+
makeMapper ~transformAssertEqual ~displayFilename:filename
872+
in
814873
let astMapped = mapper.signature mapper signature in
815874
Ok
816875
( Res_printer.print_interface

0 commit comments

Comments
 (0)