|
| 1 | +package local |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "fmt" |
| 6 | + "sync" |
| 7 | + "sync/atomic" |
| 8 | + |
| 9 | + "github.com/keep-network/keep-core/pkg/net" |
| 10 | + "github.com/keep-network/keep-core/pkg/net/internal" |
| 11 | + "github.com/keep-network/keep-core/pkg/net/key" |
| 12 | + "github.com/keep-network/keep-core/pkg/net/retransmission" |
| 13 | +) |
| 14 | + |
// messageHandlerThrottle is the buffer size of each registered handler's
// delivery channel. When a handler's buffer is full, further messages to
// that handler are dropped (see deliver).
const messageHandlerThrottle = 256
| 16 | + |
// messageHandler couples a registered receive callback's context with the
// buffered channel used to deliver incoming messages to it.
type messageHandler struct {
	// ctx controls the handler's lifetime; delivery stops once it is done.
	// NOTE(review): storing a Context in a struct is normally discouraged,
	// but here it is needed to re-check liveness at delivery time.
	ctx context.Context
	// channel buffers incoming messages for this handler
	// (capacity messageHandlerThrottle).
	channel chan net.Message
}
| 21 | + |
// localChannel is an in-process broadcast channel implementation used by
// the local (test) network provider. All peers attached to the same channel
// name receive messages broadcast to it.
type localChannel struct {
	// counter is the monotonically increasing message sequence number;
	// accessed atomically via nextSeqno.
	counter uint64
	// name of the broadcast channel.
	name string
	// identifier is the transport identifier of the local peer, attached
	// as the sender of outgoing messages.
	identifier net.TransportIdentifier
	// staticKey is the local peer's network public key, attached to
	// outgoing messages.
	staticKey *key.NetworkPublic
	// messageHandlersMutex guards messageHandlers.
	messageHandlersMutex sync.Mutex
	// messageHandlers holds all currently registered receive handlers.
	messageHandlers []*messageHandler
	// unmarshalersMutex guards unmarshalersByType.
	unmarshalersMutex sync.Mutex
	// unmarshalersByType maps a message type string to a factory producing
	// an unmarshaler for that type.
	unmarshalersByType map[string]func() net.TaggedUnmarshaler
	// retransmissionTicker drives periodic retransmissions of sent messages.
	retransmissionTicker *retransmission.Ticker
}
| 33 | + |
| 34 | +func (lc *localChannel) nextSeqno() uint64 { |
| 35 | + return atomic.AddUint64(&lc.counter, 1) |
| 36 | +} |
| 37 | + |
| 38 | +func (lc *localChannel) Name() string { |
| 39 | + return lc.name |
| 40 | +} |
| 41 | + |
| 42 | +func (lc *localChannel) Send(ctx context.Context, message net.TaggedMarshaler) error { |
| 43 | + bytes, err := message.Marshal() |
| 44 | + if err != nil { |
| 45 | + return err |
| 46 | + } |
| 47 | + |
| 48 | + unmarshaler, found := lc.unmarshalersByType[string(message.Type())] |
| 49 | + if !found { |
| 50 | + return fmt.Errorf("couldn't find unmarshaler for type %s", string(message.Type())) |
| 51 | + } |
| 52 | + |
| 53 | + unmarshaled := unmarshaler() |
| 54 | + err = unmarshaled.Unmarshal(bytes) |
| 55 | + if err != nil { |
| 56 | + return err |
| 57 | + } |
| 58 | + |
| 59 | + netMessage := internal.BasicMessage( |
| 60 | + lc.identifier, |
| 61 | + unmarshaled, |
| 62 | + "local", |
| 63 | + key.Marshal(lc.staticKey), |
| 64 | + lc.nextSeqno(), |
| 65 | + ) |
| 66 | + |
| 67 | + retransmission.ScheduleRetransmissions( |
| 68 | + ctx, |
| 69 | + lc.retransmissionTicker, |
| 70 | + func() error { |
| 71 | + return broadcastMessage(lc.name, netMessage) |
| 72 | + }, |
| 73 | + ) |
| 74 | + |
| 75 | + return broadcastMessage(lc.name, netMessage) |
| 76 | +} |
| 77 | + |
| 78 | +func (lc *localChannel) deliver(message net.Message) { |
| 79 | + lc.messageHandlersMutex.Lock() |
| 80 | + snapshot := make([]*messageHandler, len(lc.messageHandlers)) |
| 81 | + copy(snapshot, lc.messageHandlers) |
| 82 | + lc.messageHandlersMutex.Unlock() |
| 83 | + |
| 84 | + for _, handler := range snapshot { |
| 85 | + select { |
| 86 | + case handler.channel <- message: |
| 87 | + default: |
| 88 | + logger.Warningf("handler too slow, dropping message") |
| 89 | + } |
| 90 | + } |
| 91 | +} |
| 92 | + |
| 93 | +func (lc *localChannel) Recv(ctx context.Context, handler func(m net.Message)) { |
| 94 | + messageHandler := &messageHandler{ |
| 95 | + ctx: ctx, |
| 96 | + channel: make(chan net.Message, messageHandlerThrottle), |
| 97 | + } |
| 98 | + |
| 99 | + lc.messageHandlersMutex.Lock() |
| 100 | + lc.messageHandlers = append(lc.messageHandlers, messageHandler) |
| 101 | + lc.messageHandlersMutex.Unlock() |
| 102 | + |
| 103 | + handleWithRetransmissions := retransmission.WithRetransmissionSupport(handler) |
| 104 | + |
| 105 | + go func() { |
| 106 | + for { |
| 107 | + select { |
| 108 | + case <-ctx.Done(): |
| 109 | + logger.Debug("context is done, removing handler") |
| 110 | + lc.removeHandler(messageHandler) |
| 111 | + return |
| 112 | + |
| 113 | + case msg := <-messageHandler.channel: |
| 114 | + // Go language specification says that if one or more of the |
| 115 | + // communications in the select statement can proceed, a single |
| 116 | + // one that will proceed is chosen via a uniform pseudo-random |
| 117 | + // selection. |
| 118 | + // Thus, it can happen this communication is called when ctx is |
| 119 | + // already done. Since we guarantee in the network channel API |
| 120 | + // that handler is not called after ctx is done (client code |
| 121 | + // could e.g. perform come cleanup), we need to double-check |
| 122 | + // the context state here. |
| 123 | + if messageHandler.ctx.Err() != nil { |
| 124 | + continue |
| 125 | + } |
| 126 | + |
| 127 | + handleWithRetransmissions(msg) |
| 128 | + } |
| 129 | + } |
| 130 | + }() |
| 131 | +} |
| 132 | + |
| 133 | +func (lc *localChannel) removeHandler(handler *messageHandler) { |
| 134 | + lc.messageHandlersMutex.Lock() |
| 135 | + defer lc.messageHandlersMutex.Unlock() |
| 136 | + |
| 137 | + for i, h := range lc.messageHandlers { |
| 138 | + if h.channel == handler.channel { |
| 139 | + lc.messageHandlers[i] = lc.messageHandlers[len(lc.messageHandlers)-1] |
| 140 | + lc.messageHandlers = lc.messageHandlers[:len(lc.messageHandlers)-1] |
| 141 | + break |
| 142 | + } |
| 143 | + } |
| 144 | +} |
| 145 | + |
| 146 | +func (lc *localChannel) RegisterUnmarshaler( |
| 147 | + unmarshaler func() net.TaggedUnmarshaler, |
| 148 | +) (err error) { |
| 149 | + tpe := unmarshaler().Type() |
| 150 | + |
| 151 | + lc.unmarshalersMutex.Lock() |
| 152 | + defer lc.unmarshalersMutex.Unlock() |
| 153 | + |
| 154 | + lc.unmarshalersByType[tpe] = unmarshaler |
| 155 | + return nil |
| 156 | +} |
| 157 | + |
// SetFilter is a no-op in the local channel implementation: all local peers
// trust each other, so broadcast filtering is not needed. Always returns nil.
func (lc *localChannel) SetFilter(filter net.BroadcastChannelFilter) error {
	return nil // no-op
}
0 commit comments