Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 78 additions & 1 deletion core/capabilities/remote/executable/request/client_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ type ClientRequest struct {

requestTimeout time.Duration

responsePolicy responsePolicy

respSent bool
mux sync.Mutex
wg *sync.WaitGroup
Expand Down Expand Up @@ -194,6 +196,7 @@ func newClientRequest(ctx context.Context, lggr logger.Logger, requestID string,
meteringResponses: make(map[[32]byte][]commoncap.MeteringNodeDetail),
errorCount: make(map[string]int),
responseReceived: responseReceived,
responsePolicy: newResponsePolicy(remoteCapabilityInfo, capMethodName),
responseCh: make(chan clientResponse, 1),
wg: &wg,
lggr: lggr,
Expand Down Expand Up @@ -263,6 +266,16 @@ func (c *ClientRequest) Cancel(err error) {
c.mux.Lock()
defer c.mux.Unlock()
if !c.respSent {
if c.responsePolicy != nil {
payload, ok, buildErr := c.responsePolicy.BuildDeterministicResponse(true)
if buildErr != nil {
c.lggr.Warnw("failed to build deterministic policy response", "error", buildErr)
}
if ok {
c.sendResponse(clientResponse{Result: payload})
return
}
}
c.sendResponse(clientResponse{Err: err})
}
}
Expand Down Expand Up @@ -330,14 +343,29 @@ func (c *ClientRequest) OnMessage(_ context.Context, msg *types.MessageBody) err
lggr.Warnw("received multiple unique responses for the same request", "count for responseID", len(c.responseIDCount))
}

if c.responseIDCount[responseID] == c.requiredIdenticalResponses {
if c.responseIDCount[responseID] == c.requiredIdenticalResponses && !c.shouldDeferIdenticalResponse(msg.Payload) {
payload, err := c.encodePayloadWithMetadata(msg, commoncap.ResponseMetadata{Metering: nodeReports})
if err != nil {
return fmt.Errorf("failed to encode payload with metadata: %w", err)
}

c.sendResponse(clientResponse{Result: payload})
}

if !c.respSent {
if c.responsePolicy != nil {
c.responsePolicy.ObserveOKResponse(msg, metadata)
payload, ok, buildErr := c.responsePolicy.BuildDeterministicResponse(c.allResponsesReceived())
if buildErr != nil {
return fmt.Errorf("failed to build deterministic policy response: %w", buildErr)
}
if ok {
c.sendResponse(clientResponse{Result: payload})
} else if err := c.maybeFinalizeResponsePolicyAfterAllResponses(); err != nil {
return err
}
}
}
} else {
c.lggr.Debugw("received error from peer", "error", msg.Error, "errorMsg", msg.ErrorMsg, "peer", sender)
c.errorCount[msg.ErrorMsg]++
Expand All @@ -347,6 +375,13 @@ func (c *ClientRequest) OnMessage(_ context.Context, msg *types.MessageBody) err
c.lggr.Warn("received multiple different errors for the same request, number of different errors received: %d", len(c.errorCount))
}

if c.responsePolicy != nil && c.responsePolicy.ShouldDeferErrorResponses() {
if err := c.maybeFinalizeResponsePolicyAfterAllResponses(); err != nil {
return err
}
return nil
}

if c.errorCount[msg.ErrorMsg] == c.requiredIdenticalResponses {
c.sendResponse(clientResponse{Err: fmt.Errorf("%s : %s", msg.Error, msg.ErrorMsg)})
} else if c.totalErrorCount == c.remoteNodeCount-c.requiredIdenticalResponses+1 {
Expand Down Expand Up @@ -396,3 +431,45 @@ func (c *ClientRequest) encodePayloadWithMetadata(msg *types.MessageBody, metada

return pb.MarshalCapabilityResponse(resp)
}

func (c *ClientRequest) shouldDeferIdenticalResponse(payload []byte) bool {
if c.responsePolicy == nil {
return false
}
return c.responsePolicy.ShouldDeferIdenticalResponse(payload)
}

func (c *ClientRequest) allResponsesReceived() bool {
if len(c.responseReceived) == 0 {
return false
}

for _, received := range c.responseReceived {
if !received {
return false
}
}

return true
}

func (c *ClientRequest) maybeFinalizeResponsePolicyAfterAllResponses() error {
if c.responsePolicy == nil || c.respSent {
return nil
}

response, ok, err := c.responsePolicy.FinalizeAfterAllResponses(responsePolicyState{
AllResponsesReceived: c.allResponsesReceived(),
ResponseVariants: len(c.responseIDCount),
TotalErrorCount: c.totalErrorCount,
})
if err != nil {
return err
}
if !ok || response == nil {
return nil
}

c.sendResponse(*response)
return nil
}
242 changes: 242 additions & 0 deletions core/capabilities/remote/executable/request/client_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"slices"
"strings"
"testing"
"time"

Expand All @@ -15,10 +16,12 @@ import (
"github.com/smartcontractkit/chainlink-common/pkg/beholder/beholdertest"
commoncap "github.com/smartcontractkit/chainlink-common/pkg/capabilities"
"github.com/smartcontractkit/chainlink-common/pkg/capabilities/pb"
aptoscap "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/chain-capabilities/aptos"
"github.com/smartcontractkit/chainlink-protos/cre/go/values"
"github.com/smartcontractkit/chainlink-protos/workflows/go/events"

"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"

"github.com/smartcontractkit/chainlink-common/pkg/logger"
"github.com/smartcontractkit/chainlink/v2/core/capabilities/remote/executable/request"
Expand Down Expand Up @@ -701,6 +704,177 @@ func Test_ClientRequest_MessageValidation(t *testing.T) {
})
}

func Test_ClientRequest_AptosWriteReportFailedHashAggregation(t *testing.T) {
workflowPeers := []p2ptypes.PeerID{NewP2PPeerID(t), NewP2PPeerID(t)}
workflowDonInfo := commoncap.DON{Members: workflowPeers, ID: 2}
capabilityPeers, capDonInfo, _ := capabilityDon(t, 4, 1)

capInfo := commoncap.CapabilityInfo{
ID: "aptos:ChainSelector:4457093679053095497@1.0.0",
CapabilityType: commoncap.CapabilityTypeTarget,
Description: "Remote Aptos Target",
DON: &capDonInfo,
}

capabilityRequest := commoncap.CapabilityRequest{
Metadata: commoncap.RequestMetadata{
WorkflowID: workflowID1,
WorkflowExecutionID: workflowExecutionID1,
ReferenceID: stepRef1,
},
}

makeReq := func(t *testing.T) *request.ClientRequest {
t.Helper()
dispatcher := &clientRequestTestDispatcher{msgs: make(chan *types.MessageBody, 100)}
req, err := request.NewClientExecuteRequest(
t.Context(),
logger.Test(t),
capabilityRequest,
capInfo,
workflowDonInfo,
dispatcher,
10*time.Minute,
&transmission.TransmissionConfig{
Schedule: transmission.Schedule_AllAtOnce,
DeltaStage: 500 * time.Millisecond,
},
"WriteReport",
)
require.NoError(t, err)
for i := 0; i < len(capabilityPeers); i++ {
<-dispatcher.msgs
}
return req
}

makeMessage := func(sender p2ptypes.PeerID, payload []byte) *types.MessageBody {
return &types.MessageBody{
CapabilityId: capInfo.ID,
CapabilityDonId: capDonInfo.ID,
CallerDonId: workflowDonInfo.ID,
Method: types.MethodExecute,
Payload: payload,
MessageId: []byte("messageID"),
Sender: sender[:],
}
}
makeErrorMessage := func(sender p2ptypes.PeerID, errMsg string) *types.MessageBody {
return &types.MessageBody{
CapabilityId: capInfo.ID,
CapabilityDonId: capDonInfo.ID,
CallerDonId: workflowDonInfo.ID,
Method: types.MethodExecute,
MessageId: []byte("messageID"),
Sender: sender[:],
Error: types.Error_INTERNAL_ERROR,
ErrorMsg: errMsg,
}
}

t.Run("falls back to sorted hash when no failed-hash quorum", func(t *testing.T) {
req := makeReq(t)
defer req.Cancel(errors.New("test end"))

hashHigh := "0x" + strings.Repeat("f", 64)
hashLow := "0x" + strings.Repeat("0", 64)
payloadHigh := mustAptosWriteReportCapabilityResponse(t, aptosWriteFailureStatus(), hashHigh, "node high")
payloadLow := mustAptosWriteReportCapabilityResponse(t, aptosWriteFailureStatus(), hashLow, "node low")

require.NoError(t, req.OnMessage(t.Context(), makeMessage(capabilityPeers[0], payloadHigh)))
require.NoError(t, req.OnMessage(t.Context(), makeMessage(capabilityPeers[1], payloadLow)))

select {
case <-req.ResponseChan():
t.Fatal("expected no immediate response without failed-hash quorum")
default:
}

req.Cancel(errors.New("request expired"))
response := <-req.ResponseChan()
require.NoError(t, response.Err)

reply := mustUnwrapAptosWriteReportReply(t, response.Result)
require.NotEqual(t, aptoscap.TxStatus_TX_STATUS_SUCCESS, reply.GetTxStatus())
require.Equal(t, "0x"+strings.Repeat("0", 64), aptosWriteReplyTxHash(reply))
})

t.Run("returns after 2f+1 failed replies using canonical hash selection", func(t *testing.T) {
req := makeReq(t)
defer req.Cancel(errors.New("test end"))

hashC := "0x" + strings.Repeat("c", 64)
hashB := "0x" + strings.Repeat("b", 64)
hashA := "0x" + strings.Repeat("a", 64)

// We need 2f+1 failed replies (3 for f=1). Hashes can differ.
// Canonical selection is deterministic: lexicographically smallest normalized hash.
payloadOne := mustAptosWriteReportCapabilityResponse(t, aptosWriteFailureStatus(), hashC, "node c")
payloadTwo := mustAptosWriteReportCapabilityResponse(t, aptosWriteFailureStatus(), hashB, "node b")
payloadThree := mustAptosWriteReportCapabilityResponse(t, aptosWriteFailureStatus(), strings.ToUpper(hashA), "node a")

require.NoError(t, req.OnMessage(t.Context(), makeMessage(capabilityPeers[0], payloadOne)))
require.NoError(t, req.OnMessage(t.Context(), makeMessage(capabilityPeers[1], payloadTwo)))
select {
case <-req.ResponseChan():
t.Fatal("expected no immediate response with only two failed replies")
default:
}

require.NoError(t, req.OnMessage(t.Context(), makeMessage(capabilityPeers[2], payloadThree)))

response := <-req.ResponseChan()
require.NoError(t, response.Err)

reply := mustUnwrapAptosWriteReportReply(t, response.Result)
require.NotEqual(t, aptoscap.TxStatus_TX_STATUS_SUCCESS, reply.GetTxStatus())
require.Equal(t, "0x"+strings.Repeat("a", 64), aptosWriteReplyTxHash(reply))
})

t.Run("returns explicit error when all replies are failed without hashes", func(t *testing.T) {
req := makeReq(t)
defer req.Cancel(errors.New("test end"))

payload := mustAptosWriteReportCapabilityResponse(t, aptosWriteFailureStatus(), "", "missing hash")

require.NoError(t, req.OnMessage(t.Context(), makeMessage(capabilityPeers[0], payload)))
require.NoError(t, req.OnMessage(t.Context(), makeMessage(capabilityPeers[1], payload)))
require.NoError(t, req.OnMessage(t.Context(), makeMessage(capabilityPeers[2], payload)))

select {
case <-req.ResponseChan():
t.Fatal("expected no response before all Aptos replies are received")
default:
}

require.NoError(t, req.OnMessage(t.Context(), makeMessage(capabilityPeers[3], payload)))
response := <-req.ResponseChan()
require.Error(t, response.Err)
require.Contains(t, response.Err.Error(), "without deterministic failed hash")
})

t.Run("does not short-circuit Aptos write on repeated peer errors before all responses", func(t *testing.T) {
req := makeReq(t)
defer req.Cancel(errors.New("test end"))

require.NoError(t, req.OnMessage(t.Context(), makeErrorMessage(capabilityPeers[0], "submit failed")))
require.NoError(t, req.OnMessage(t.Context(), makeErrorMessage(capabilityPeers[1], "submit failed")))

select {
case <-req.ResponseChan():
t.Fatal("expected Aptos write request to wait for all responses before returning an error")
default:
}

require.NoError(t, req.OnMessage(t.Context(), makeErrorMessage(capabilityPeers[2], "submit failed")))
require.NoError(t, req.OnMessage(t.Context(), makeErrorMessage(capabilityPeers[3], "submit failed")))

response := <-req.ResponseChan()
require.Error(t, response.Err)
require.Contains(t, response.Err.Error(), "without deterministic failed hash")
})
}

func capabilityDon(t *testing.T, numCapabilityPeers int, f uint8) ([]p2ptypes.PeerID, commoncap.DON, commoncap.CapabilityInfo) {
capabilityPeers := make([]p2ptypes.PeerID, numCapabilityPeers)
for i := range numCapabilityPeers {
Expand All @@ -722,6 +896,74 @@ func capabilityDon(t *testing.T, numCapabilityPeers int, f uint8) ([]p2ptypes.Pe
return capabilityPeers, capDonInfo, capInfo
}

func mustAptosWriteReportCapabilityResponse(t *testing.T, status aptoscap.TxStatus, txHash string, errorMsg string) []byte {
t.Helper()

reply := &aptoscap.WriteReportReply{
TxStatus: status,
}
setAptosWriteReplyTxHash(reply, txHash)
if errorMsg != "" {
reply.ErrorMessage = &errorMsg
}

response := commoncap.CapabilityResponse{}
require.NoError(t, commoncap.SetResponse(&response, false, reply))

payload, err := pb.MarshalCapabilityResponse(response)
require.NoError(t, err)
return payload
}

func mustUnwrapAptosWriteReportReply(t *testing.T, payload []byte) *aptoscap.WriteReportReply {
t.Helper()

response, err := pb.UnmarshalCapabilityResponse(payload)
require.NoError(t, err)

reply := &aptoscap.WriteReportReply{}
_, err = commoncap.UnwrapResponse(response, reply)
require.NoError(t, err)

return reply
}

func aptosWriteFailureStatus() aptoscap.TxStatus {
// Both aptos proto variants model non-success as enum number 0.
return aptoscap.TxStatus(0)
}

func setAptosWriteReplyTxHash(reply *aptoscap.WriteReportReply, txHash string) {
m := reply.ProtoReflect()
fd := m.Descriptor().Fields().ByName("tx_hash")
if fd == nil {
return
}
switch fd.Kind() {
case protoreflect.StringKind:
m.Set(fd, protoreflect.ValueOfString(txHash))
case protoreflect.BytesKind:
m.Set(fd, protoreflect.ValueOfBytes([]byte(txHash)))
}
}

func aptosWriteReplyTxHash(reply *aptoscap.WriteReportReply) string {
m := reply.ProtoReflect()
fd := m.Descriptor().Fields().ByName("tx_hash")
if fd == nil || !m.Has(fd) {
return ""
}
v := m.Get(fd)
switch fd.Kind() {
case protoreflect.StringKind:
return v.String()
case protoreflect.BytesKind:
return string(v.Bytes())
default:
return ""
}
}

type clientRequestTestDispatcher struct {
msgs chan *types.MessageBody
}
Expand Down
Loading
Loading