@@ -678,8 +678,56 @@ let extractEmbedded ~extensionPoints ~filename =
678678 |> List. rev |> array
679679
680680module FormatCodeblocks = struct
681- let formatRescriptCodeBlocks content ~displayFilename ~addError
682- ~markdownBlockStartLine =
681+ module Transform = struct
682+ type transform = AssertEqualFnToEquals (* * assertEqual(a, b) -> a == b *)
683+
684+ (* * Transforms for the code blocks themselves. *)
685+ let transform ~transforms ast =
686+ match transforms with
687+ | [] -> ast
688+ | transforms ->
689+ let hasTransform transform = transforms |> List. mem transform in
690+ let mapper =
691+ {
692+ Ast_mapper. default_mapper with
693+ expr =
694+ (fun mapper exp ->
695+ match exp.pexp_desc with
696+ | Pexp_apply
697+ {
698+ funct =
699+ {
700+ pexp_desc =
701+ Pexp_ident
702+ ({txt = Lident " assertEqual" } as identTxt);
703+ } as ident;
704+ partial = false ;
705+ args = [(Nolabel , _arg1); (Nolabel , _arg2)] as args;
706+ }
707+ when hasTransform AssertEqualFnToEquals ->
708+ {
709+ exp with
710+ pexp_desc =
711+ Pexp_apply
712+ {
713+ funct =
714+ {
715+ ident with
716+ pexp_desc =
717+ Pexp_ident {identTxt with txt = Lident " ==" };
718+ };
719+ args;
720+ partial = false ;
721+ transformed_jsx = false ;
722+ };
723+ }
724+ | _ -> Ast_mapper. default_mapper.expr mapper exp);
725+ }
726+ in
727+ mapper.structure mapper ast
728+ end
729+ let formatRescriptCodeBlocks content ~transformAssertEqual ~displayFilename
730+ ~addError ~markdownBlockStartLine =
683731 let open Cmarkit in
684732 (* Detect ReScript code blocks. *)
685733 let hadCodeBlocks = ref false in
@@ -718,6 +766,12 @@ module FormatCodeblocks = struct
718766 addError (Buffer. contents buf);
719767 code)
720768 else
769+ let parsetree =
770+ if transformAssertEqual then
771+ Transform. transform ~transforms: [AssertEqualFnToEquals ]
772+ parsetree
773+ else parsetree
774+ in
721775 Res_printer. print_implementation
722776 ~width: Res_multi_printer. default_print_width parsetree ~comments
723777 |> String. trim |> Block_line. list_of_string
@@ -737,7 +791,7 @@ module FormatCodeblocks = struct
737791 in
738792 (newContent, ! hadCodeBlocks)
739793
740- let formatCodeBlocksInFile ~outputMode ~entryPointFile =
794+ let formatCodeBlocksInFile ~outputMode ~transformAssertEqual ~ entryPointFile =
741795 let path =
742796 match Filename. is_relative entryPointFile with
743797 | true -> Unix. realpath entryPointFile
@@ -746,7 +800,7 @@ module FormatCodeblocks = struct
746800 let errors = ref [] in
747801 let addError error = errors := error :: ! errors in
748802
749- let makeMapper ~displayFilename =
803+ let makeMapper ~transformAssertEqual ~ displayFilename =
750804 {
751805 Ast_mapper. default_mapper with
752806 attribute =
@@ -756,7 +810,8 @@ module FormatCodeblocks = struct
756810 Some (contents, None ),
757811 PStr [{pstr_desc = Pstr_eval ({pexp_loc}, _)}] ) ->
758812 let formattedContents, hadCodeBlocks =
759- formatRescriptCodeBlocks ~add Error ~display Filename
813+ formatRescriptCodeBlocks ~transform AssertEqual ~add Error
814+ ~display Filename
760815 ~markdown BlockStartLine:pexp_loc.loc_start.pos_lnum contents
761816 in
762817 if hadCodeBlocks && formattedContents <> contents then
@@ -783,8 +838,8 @@ module FormatCodeblocks = struct
783838 in
784839 let displayFilename = Filename. basename path in
785840 let formattedContents, hadCodeBlocks =
786- formatRescriptCodeBlocks ~add Error ~display Filename
787- ~markdown BlockStartLine:1 content
841+ formatRescriptCodeBlocks ~transform AssertEqual ~add Error
842+ ~display Filename ~ markdown BlockStartLine:1 content
788843 in
789844 if hadCodeBlocks && formattedContents <> content then
790845 Ok (formattedContents, content)
@@ -797,7 +852,9 @@ module FormatCodeblocks = struct
797852 parser ~filename: path
798853 in
799854 let filename = Filename. basename filename in
800- let mapper = makeMapper ~display Filename:filename in
855+ let mapper =
856+ makeMapper ~transform AssertEqual ~display Filename:filename
857+ in
801858 let astMapped = mapper.structure mapper structure in
802859 Ok
803860 ( Res_printer. print_implementation
@@ -810,7 +867,9 @@ module FormatCodeblocks = struct
810867 let {Res_driver. parsetree = signature; comments; source; filename} =
811868 parser ~filename: path
812869 in
813- let mapper = makeMapper ~display Filename:filename in
870+ let mapper =
871+ makeMapper ~transform AssertEqual ~display Filename:filename
872+ in
814873 let astMapped = mapper.signature mapper signature in
815874 Ok
816875 ( Res_printer. print_interface
0 commit comments