@@ -28,7 +28,7 @@ const headerComment = `// Code generated - DO NOT EDIT.
2828// AbigenArgs is the arguments to the abigen executable. E.g., Bin is the -bin
2929// arg.
3030type AbigenArgs struct {
31- Bin , ABI , Out , Type , Pkg string
31+ Bin , ABI , Out , Type , Pkg , ZkBinPath string
3232}
3333
3434// Abigen calls Abigen with the given arguments
@@ -72,6 +72,9 @@ func Abigen(a AbigenArgs) {
7272 }
7373
7474 ImproveAbigenOutput (a .Out , a .ABI )
75+ if a .ZkBinPath != "" {
76+ ImproveAbigenOutputZks (a .Out , a .ZkBinPath )
77+ }
7578}
7679
7780func ImproveAbigenOutput (path string , abiPath string ) {
@@ -466,3 +469,195 @@ func writeInterface(contractName string, fileNode *ast.File) *ast.File {
466469func addHeader (code []byte ) []byte {
467470 return utils .ConcatBytes ([]byte (headerComment ), code )
468471}
472+
473+ // ZK stack logic
474+ func ImproveAbigenOutputZks (path string , zkBinPath string ) {
475+ bs , err := os .ReadFile (path )
476+ if err != nil {
477+ Exit ("Error while improving abigen output" , err )
478+ }
479+
480+ fset , fileNode := parseFile (bs )
481+
482+ contractName := getContractName (fileNode )
483+
484+ zkByteCode , err := os .ReadFile (zkBinPath )
485+ if err != nil {
486+ Exit ("Error while improving abigen output" , err )
487+ }
488+ zkHexString := string (zkByteCode )
489+
490+ // add zksync binary to the wrapper
491+ fileNode = addZKSyncBin (fileNode , contractName , zkHexString )
492+
493+ // add zksync logic to the deploy function
494+ fileNode = updateDeployMethod (contractName , fset , fileNode )
495+
496+ bs = generateCode (fset , fileNode )
497+
498+ err = os .WriteFile (path , bs , 0600 )
499+ if err != nil {
500+ Exit ("Error while writing improved abigen source" , err )
501+ }
502+ }
503+
504+ // add zksync binary to the wrapper
505+ func addZKSyncBin (fileNode * ast.File , contractName string , zkHexString string ) * ast.File {
506+ // zksync
507+ newVarSpec := & ast.ValueSpec {
508+ Names : []* ast.Ident {ast .NewIdent (contractName + "ZKBin" )},
509+ Values : []ast.Expr {
510+ & ast.BasicLit {
511+ Kind : token .STRING ,
512+ Value : fmt .Sprintf ("(\" %s\" )" , zkHexString ),
513+ },
514+ },
515+ }
516+ newVarDecl := & ast.GenDecl {
517+ Tok : token .VAR ,
518+ Specs : []ast.Spec {newVarSpec },
519+ }
520+
521+ // Insert the new variable declaration at the top of the file (before existing functions)
522+ fileNode .Decls = append (fileNode .Decls , newVarDecl )
523+ return fileNode
524+ }
525+
526+ // add zksync logic to the deploy function
527+ func updateDeployMethod (contractName string , fset * token.FileSet , fileNode * ast.File ) * ast.File {
528+
529+ return astutil .Apply (fileNode , func (cursor * astutil.Cursor ) bool {
530+ x , is := cursor .Node ().(* ast.FuncDecl )
531+ if ! is {
532+ return true
533+ } else if x .Name .Name != "Deploy" + contractName {
534+ return false
535+ }
536+
537+ // only add this import if Deploy method found
538+ astutil .AddImport (fset , fileNode , "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/generated" )
539+
540+ // Extract the parameters from the existing function x
541+ paramList := getConstructorParams (x .Type .Params .List )
542+ // get the `if zksync()` block
543+ zkSyncBlock := getZKSyncBlock (contractName , paramList )
544+ // insert the `if zksync()` block
545+ addZKSyncBlock (* x , zkSyncBlock )
546+ // update the return type in the function signature
547+ updateTxReturnType (* x )
548+ // update the actual return value
549+ updateReturnStmt (* x )
550+
551+ return false
552+ }, nil ).(* ast.File )
553+ }
554+
555+ // get the `if zksync()` block
556+ func getZKSyncBlock (contractName , paramList string ) string {
557+ zkSyncBlock := `if generated.IsZKSync(backend) {
558+ address, ethTx, contractBind, _ := generated.DeployContract(auth, parsed, common.FromHex(%sZKBin), backend, %params)
559+ contractReturn := &%s{address: address, abi: *parsed, %sCaller: %sCaller{contract: contractBind}, %sTransactor: %sTransactor{contract: contractBind},%sFilterer: %sFilterer{contract: contractBind}}
560+ return address, ethTx, contractReturn, err
561+ }`
562+ zkSyncBlock = strings .ReplaceAll (zkSyncBlock , "%s" , contractName )
563+ zkSyncBlock = strings .ReplaceAll (zkSyncBlock , "%params" , paramList )
564+ return strings .ReplaceAll (zkSyncBlock , "%s" , contractName )
565+ }
566+
567+ // Extract the parameters for constructor function
568+ func getConstructorParams (contstructorParams []* ast.Field ) string {
569+ params := []string {}
570+ for i , param := range contstructorParams {
571+ if i > 1 { // Skip auth and backend
572+ for _ , name := range param .Names {
573+ params = append (params , name .Name )
574+ }
575+ }
576+ }
577+ paramList := strings .Join (params , ", " )
578+ return paramList
579+ }
580+
581+ // insert the `if zksync()` block
582+ func addZKSyncBlock (x ast.FuncDecl , zkSyncBlock string ) ast.FuncDecl {
583+ for i , stmt := range x .Body .List {
584+
585+ ifStmt , ok := stmt .(* ast.IfStmt )
586+ if ! ok {
587+ continue
588+ }
589+ binaryExpr , ok := ifStmt .Cond .(* ast.BinaryExpr )
590+ if ! ok {
591+ continue
592+ }
593+ if ident , ok := binaryExpr .X .(* ast.Ident ); ok && ident .Name == "parsed" {
594+ // Creating new statement to insert
595+ newStmt := & ast.ExprStmt {
596+ X : & ast.BasicLit {
597+ Kind : token .STRING ,
598+ Value : zkSyncBlock ,
599+ },
600+ }
601+
602+ // Insert the new statement after the current statement
603+ x .Body .List = append (x .Body .List [:i + 1 ], append ([]ast.Stmt {newStmt }, x .Body .List [i + 1 :]... )... )
604+ break
605+ }
606+ }
607+ return x
608+ }
609+
610+ // convert *types.Transaction to *generated_zks.Transaction
611+ func updateTxReturnType (x ast.FuncDecl ) {
612+ x .Type .Results .List [1 ].Type = & ast.StarExpr {
613+ X : & ast.SelectorExpr {
614+ X : & ast.Ident {Name : "generated" },
615+ Sel : & ast.Ident {Name : "Transaction" },
616+ },
617+ }
618+ }
619+
620+ // convert tx to &Transaction{Transaction: tx, HashZks: tx.Hash()}
621+ func updateReturnStmt (x ast.FuncDecl ) {
622+ for _ , stmt := range x .Body .List {
623+ returnStmt , is := stmt .(* ast.ReturnStmt )
624+ if ! is {
625+ continue
626+ }
627+ if len (returnStmt .Results ) < 3 {
628+ continue
629+ }
630+
631+ txExpr , ok := returnStmt .Results [1 ].(* ast.Ident )
632+ if ! ok {
633+ return
634+ }
635+ if txExpr .Name != "tx" {
636+ return
637+ }
638+
639+ txField := & ast.KeyValueExpr {
640+ Key : ast .NewIdent ("Transaction" ),
641+ Value : ast .NewIdent ("tx" ),
642+ }
643+
644+ hashField := & ast.KeyValueExpr {
645+ Key : ast .NewIdent ("HashZks" ),
646+ Value : & ast.CallExpr {
647+ Fun : & ast.SelectorExpr {
648+ X : ast .NewIdent ("tx" ),
649+ Sel : ast .NewIdent ("Hash" ),
650+ },
651+ },
652+ }
653+ newRet := & ast.CompositeLit {
654+ Type : & ast.SelectorExpr {
655+ X : ast .NewIdent ("generated" ),
656+ Sel : ast .NewIdent ("Transaction" ),
657+ },
658+ Elts : []ast.Expr {txField , hashField },
659+ }
660+ pointerRet := & ast.UnaryExpr {Op : token .AND , X : newRet }
661+ returnStmt .Results [1 ] = pointerRet
662+ }
663+ }
0 commit comments