diff --git a/client.go b/client.go index ec328d9..9417956 100644 --- a/client.go +++ b/client.go @@ -18,12 +18,13 @@ import ( // signals and heartbeats are automatically handled background in another goroutine. func Dial(ctx context.Context, net string, laddr, raddr *sctp.SCTPAddr, cfg *Config) (*Conn, error) { var err error + + cfg.SCTPConfig.sctpInfo = &sctp.SndRcvInfo{PPID: 3, Stream: 0} conn := &Conn{ muState: new(sync.RWMutex), mode: modeClient, stateChan: make(chan State), established: make(chan struct{}), - sctpInfo: &sctp.SndRcvInfo{PPID: 3, Stream: 0}, cfg: cfg, } @@ -36,13 +37,25 @@ func Dial(ctx context.Context, net string, laddr, raddr *sctp.SCTPAddr, cfg *Con return nil, fmt.Errorf("invalid network: %s", net) } - conn.sctpConn, err = sctp.DialSCTP(n, laddr, raddr) + conn.cfg.SCTPConfig.sctpConn, err = sctp.DialSCTP(n, laddr, raddr) if err != nil { return nil, err } - r, err := conn.sctpConn.GetStatus() + if conn.cfg.SCTPConfig.SctpSackInfo != nil && conn.cfg.SCTPConfig.SctpSackInfo.Enabled { + err = conn.cfg.SCTPConfig.sctpConn.SetSackTimer(&sctp.SackTimer{ + SackDelay: conn.cfg.SCTPConfig.SctpSackInfo.SackDelay, + SackFrequency: conn.cfg.SCTPConfig.SctpSackInfo.SackFrequency, + }) + if err != nil { + conn.cfg.SCTPConfig.sctpConn.Close() + return nil, fmt.Errorf("failed to set sack timer: %w", err) + } + } + + r, err := conn.cfg.SCTPConfig.sctpConn.GetStatus() if err != nil { + conn.cfg.SCTPConfig.sctpConn.Close() return nil, fmt.Errorf("failed to get sctpConn status: %w", err) } conn.maxMessageStreamID = r.Ostreams - 1 // removing 1 for management messages of stream ID 0 @@ -55,12 +68,12 @@ func Dial(ctx context.Context, net string, laddr, raddr *sctp.SCTPAddr, cfg *Con select { case _, ok := <-conn.established: if !ok { - conn.sctpConn.Close() + conn.cfg.SCTPConfig.sctpConn.Close() return nil, ErrFailedToEstablish } return conn, nil case <-time.After(10 * time.Second): - conn.sctpConn.Close() + conn.cfg.SCTPConfig.sctpConn.Close() return nil, ErrTimeout } } diff --git a/config.go b/config.go index 821df44..df73029 100644 --- a/config.go +++ b/config.go @@ -7,6 +7,7 @@ package m3ua import ( "time" + "github.com/ishidawataru/sctp" "github.com/wmnsk/go-m3ua/messages/params" ) @@ -25,9 +26,37 @@ func NewHeartbeatInfo(interval, timer time.Duration, data []byte) *HeartbeatInfo } } +// SctpSackInfo is a set of information for SCTP SACK timer configuration. +// +// SackDelay sack_delay: This parameter contains the number of milliseconds the +// user is requesting that the delayed SACK timer be set to. Note +// that this value is defined in [RFC4960] to be between 200 and 500 +// milliseconds. +// +// SackFrequency sack_freq: This parameter contains the number of packets that must +// be received before a SACK is sent without waiting for the delay +// timer to expire. The default value is 2; setting this value to 1 +// will disable the delayed SACK algorithm. +type SctpSackInfo struct { + Enabled bool + SackDelay uint32 + SackFrequency uint32 +} + +// SCTPConfig holds all SCTP-related configuration parameters. +// This separates SCTP layer configuration from M3UA layer configuration. +type SCTPConfig struct { + *SctpSackInfo + // sctpConn is the underlying SCTP association + sctpConn *sctp.SCTPConn + // sctpInfo is SndRcvInfo in SCTP association + sctpInfo *sctp.SndRcvInfo +} + // Config is a configuration that defines a M3UA server. type Config struct { *HeartbeatInfo + *SCTPConfig AspIdentifier *params.Param TrafficModeType *params.Param NetworkAppearance *params.Param @@ -49,6 +78,7 @@ type Config struct { // values. func NewConfig(opc, dpc uint32, si, ni, mp, sls uint8) *Config { return &Config{ + SCTPConfig: &SCTPConfig{}, OriginatingPointCode: opc, DestinationPointCode: dpc, ServiceIndicator: si, @@ -72,6 +102,28 @@ func (c *Config) EnableHeartbeat(interval, timer time.Duration) *Config { return c } +// SetSackConfig sets the SCTP SACK timer configuration. +// +// sackDelay is the number of milliseconds for the delayed SACK timer +// (per RFC4960, should be between 200 and 500 ms). +// +// sackFrequency is the number of packets to receive before sending a SACK +// without waiting for the delay timer. Setting to 1 disables the delayed +// SACK algorithm. +// +// Note: sackDelay=0, sackFrequency=1 (disables delayed SACK) +func (c *Config) SetSackConfig(sackDelay, sackFrequency uint32) *Config { + if c.SCTPConfig == nil { + c.SCTPConfig = &SCTPConfig{} + } + c.SCTPConfig.SctpSackInfo = &SctpSackInfo{ + Enabled: true, + SackDelay: sackDelay, + SackFrequency: sackFrequency, + } + return c +} + // SetAspIdentifier sets AspIdentifier in Config. func (c *Config) SetAspIdentifier(id uint32) *Config { c.AspIdentifier = params.NewAspIdentifier(id) @@ -109,6 +161,7 @@ func (c *Config) SetCorrelationID(id uint32) *Config { func NewClientConfig(hbInfo *HeartbeatInfo, opc, dpc, aspID, tmt, nwApr, corrID uint32, rtCtxs []uint32, si, ni, mp, sls uint8) *Config { return &Config{ HeartbeatInfo: hbInfo, + SCTPConfig: &SCTPConfig{}, AspIdentifier: params.NewAspIdentifier(aspID), TrafficModeType: params.NewTrafficModeType(tmt), NetworkAppearance: params.NewNetworkAppearance(nwApr), @@ -130,6 +183,7 @@ func NewClientConfig(hbInfo *HeartbeatInfo, opc, dpc, aspID, tmt, nwApr, corrID func NewServerConfig(hbInfo *HeartbeatInfo, opc, dpc, aspID, tmt, nwApr, corrID uint32, rtCtxs []uint32, si, ni, mp, sls uint8) *Config { return &Config{ HeartbeatInfo: hbInfo, + SCTPConfig: &SCTPConfig{}, AspIdentifier: params.NewAspIdentifier(aspID), TrafficModeType: params.NewTrafficModeType(tmt), NetworkAppearance: params.NewNetworkAppearance(nwApr), diff --git a/conn.go b/conn.go index 4ebf83e..19bf142 100644 --- a/conn.go +++ b/conn.go @@ -23,7 +23,7 @@ const ( modeServer ) -// Conn represents a M3UA connection, which satisfies standard net.Conn interface. +// Conn represents a M3UA connection, which satisfies the standard net.Conn interface. type Conn struct { // maxMessageStreamID is the maximum negotiated sctp stream ID used, // must not be zero, must vary from 1 to maxMessageStreamID @@ -40,15 +40,11 @@ type Conn struct { established chan struct{} // beatAckChan notifies that heartbeat gets the ack as expected beatAckChan chan struct{} - // dataChan is to pass the ProtocolDataPayload(=payload on M3UA DATA) to user + // dataChan is to pass the ProtocolDataPayload(=payload on M3UA DATA) to the user dataChan chan *params.ProtocolDataPayload - // errChan is to pass errors to goroutine that monitors status + // errChan is to pass errors to a goroutine that monitors status errChan chan error - // sctpConn is the underlying SCTP association - sctpConn *sctp.SCTPConn - // sctpInfo is SndRcvInfo in SCTP association - sctpInfo *sctp.SndRcvInfo - // cfg is a configuration that is required to communicate between M3UA endpoints + // cfg is a configuration required to communicate between M3UA endpoints cfg *Config // Condition to allow heartbeat, only after the state is AspUp beatAllow *sync.Cond @@ -126,9 +122,9 @@ func (c *Conn) WriteToStream(b []byte, streamID uint16) (n int, err error) { } // taken by value to avoid race condition on the stream id - info := *c.sctpInfo + info := *c.cfg.SCTPConfig.sctpInfo info.Stream = streamID - n, err = c.sctpConn.SCTPWrite(d, &info) + n, err = c.cfg.SCTPConfig.sctpConn.SCTPWrite(d, &info) if err != nil { return 0, err } @@ -160,9 +156,9 @@ func (c *Conn) WritePDToStream(protocolData *params.Param, streamID uint16) (n i } // taken by value to avoid race condition on the stream id - info := *c.sctpInfo + info := *c.cfg.SCTPConfig.sctpInfo info.Stream = streamID - n, err = c.sctpConn.SCTPWrite(d, &info) + n, err = c.cfg.SCTPConfig.sctpConn.SCTPWrite(d, &info) if err != nil { return 0, err } @@ -180,12 +176,12 @@ func (c *Conn) WriteSignal(m3 messages.M3UA) (n int, err error) { } // taken by value to avoid race condition on the stream id - sctpInfo := *c.sctpInfo + sctpInfo := *c.cfg.SCTPConfig.sctpInfo if m3.MessageClass() != messages.MsgClassTransfer { sctpInfo.Stream = 0 } - nn, err := c.sctpConn.SCTPWrite(buf, &sctpInfo) + nn, err := c.cfg.SCTPConfig.sctpConn.SCTPWrite(buf, &sctpInfo) if err != nil { return 0, fmt.Errorf("failed to write M3UA: %w", err) } @@ -200,39 +196,39 @@ func (c *Conn) Close() error { defer c.muState.Unlock() if c.state == StateAspDown { - return c.sctpConn.Close() + return c.cfg.SCTPConfig.sctpConn.Close() } close(c.established) close(c.beatAckChan) close(c.dataChan) c.state = StateAspDown - return c.sctpConn.Close() + return c.cfg.SCTPConfig.sctpConn.Close() } // LocalAddr returns the local network address. func (c *Conn) LocalAddr() net.Addr { - return c.sctpConn.LocalAddr() + return c.cfg.SCTPConfig.sctpConn.LocalAddr() } // RemoteAddr returns the remote network address. func (c *Conn) RemoteAddr() net.Addr { - return c.sctpConn.RemoteAddr() + return c.cfg.SCTPConfig.sctpConn.RemoteAddr() } // SetDeadline sets the read and write deadlines associated. func (c *Conn) SetDeadline(t time.Time) error { - return c.sctpConn.SetDeadline(t) + return c.cfg.SCTPConfig.sctpConn.SetDeadline(t) } // SetReadDeadline sets the deadline for future Read calls. func (c *Conn) SetReadDeadline(t time.Time) error { - return c.sctpConn.SetReadDeadline(t) + return c.cfg.SCTPConfig.sctpConn.SetReadDeadline(t) } // SetWriteDeadline sets the deadline for future Write calls. func (c *Conn) SetWriteDeadline(t time.Time) error { - return c.sctpConn.SetWriteDeadline(t) + return c.cfg.SCTPConfig.sctpConn.SetWriteDeadline(t) } // State returns current state of Conn. @@ -244,7 +240,7 @@ func (c *Conn) State() State { // StreamID returns sctpInfo.Stream of Conn. func (c *Conn) StreamID() uint16 { - return c.sctpInfo.Stream + return c.cfg.SCTPConfig.sctpInfo.Stream } // MaxMessageStreamID returns the maximum negotiated sctp stream ID @@ -253,6 +249,23 @@ func (c *Conn) MaxMessageStreamID() uint16 { return c.maxMessageStreamID } +// SetSctpSackConfig sets the SCTP SACK timer configuration on an active connection. +// +// sackDelay is the number of milliseconds for the delayed SACK timer +// (per RFC4960, should be between 200 and 500 ms). +// +// sackFrequency is the number of packets to receive before sending a SACK +// without waiting for the delay timer. Setting to 1 disables the delayed +// SACK algorithm. +// +// Note: sackDelay=0, sackFrequency=1 (disables delayed SACK) +func (c *Conn) SetSctpSackConfig(sackDelay, sackFrequency uint32) error { + return c.cfg.SCTPConfig.sctpConn.SetSackTimer(&sctp.SackTimer{ + SackDelay: sackDelay, + SackFrequency: sackFrequency, + }) +} + // chooseStreamID generates a random uint16 from 1 to max (inclusive) func (c *Conn) chooseStreamID() uint16 { if c.maxMessageStreamID == 1 { diff --git a/fsm.go b/fsm.go index c2d68f7..5268f2d 100644 --- a/fsm.go +++ b/fsm.go @@ -226,7 +226,7 @@ func (c *Conn) monitor(ctx context.Context) { } // Read from conn to see something coming from the peer. - n, _, err := c.sctpConn.SCTPRead(buf) + n, _, err := c.cfg.SCTPConfig.sctpConn.SCTPRead(buf) if err != nil { c.Close() return diff --git a/server.go b/server.go index 3b58cca..4095bd7 100644 --- a/server.go +++ b/server.go @@ -41,12 +41,13 @@ func Listen(net string, laddr *sctp.SCTPAddr, cfg *Config) (*Listener, error) { // After successfully establishing the association with peer, Payload can be read with Read() func. // Other signals are automatically handled background in another goroutine. func (l *Listener) Accept(ctx context.Context) (*Conn, error) { + + l.Config.SCTPConfig.sctpInfo = &sctp.SndRcvInfo{PPID: 3, Stream: 0} conn := &Conn{ muState: new(sync.RWMutex), mode: modeServer, stateChan: make(chan State), established: make(chan struct{}), - sctpInfo: &sctp.SndRcvInfo{PPID: 3, Stream: 0}, cfg: l.Config, } @@ -60,14 +61,26 @@ func (l *Listener) Accept(ctx context.Context) (*Conn, error) { } var ok bool - conn.sctpConn, ok = c.(*sctp.SCTPConn) + conn.cfg.SCTPConfig.sctpConn, ok = c.(*sctp.SCTPConn) if !ok { c.Close() return nil, fmt.Errorf("failed to assert server connection") } - r, err := conn.sctpConn.GetStatus() + if conn.cfg.SCTPConfig.SctpSackInfo != nil && conn.cfg.SCTPConfig.SctpSackInfo.Enabled { + err = conn.cfg.SCTPConfig.sctpConn.SetSackTimer(&sctp.SackTimer{ + SackDelay: conn.cfg.SCTPConfig.SctpSackInfo.SackDelay, + SackFrequency: conn.cfg.SCTPConfig.SctpSackInfo.SackFrequency, + }) + if err != nil { + conn.cfg.SCTPConfig.sctpConn.Close() + return nil, fmt.Errorf("failed to set sack timer: %w", err) + } + } + + r, err := conn.cfg.SCTPConfig.sctpConn.GetStatus() if err != nil { + conn.cfg.SCTPConfig.sctpConn.Close() return nil, fmt.Errorf("failed to get sctpConn status: %w", err) } conn.maxMessageStreamID = r.Ostreams - 1 // removing 1 for management messages of stream ID 0 @@ -80,12 +93,12 @@ func (l *Listener) Accept(ctx context.Context) (*Conn, error) { select { case _, ok := <-conn.established: if !ok { - conn.sctpConn.Close() + conn.cfg.SCTPConfig.sctpConn.Close() return nil, ErrFailedToEstablish } return conn, nil case <-time.After(10 * time.Second): - conn.sctpConn.Close() + conn.cfg.SCTPConfig.sctpConn.Close() return nil, ErrTimeout } }