Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@ import (
// signals and heartbeats are automatically handled background in another goroutine.
func Dial(ctx context.Context, net string, laddr, raddr *sctp.SCTPAddr, cfg *Config) (*Conn, error) {
var err error

cfg.SCTPConfig.sctpInfo = &sctp.SndRcvInfo{PPID: 3, Stream: 0}
conn := &Conn{
muState: new(sync.RWMutex),
mode: modeClient,
stateChan: make(chan State),
established: make(chan struct{}),
sctpInfo: &sctp.SndRcvInfo{PPID: 3, Stream: 0},
cfg: cfg,
}

Expand All @@ -36,13 +37,25 @@ func Dial(ctx context.Context, net string, laddr, raddr *sctp.SCTPAddr, cfg *Con
return nil, fmt.Errorf("invalid network: %s", net)
}

conn.sctpConn, err = sctp.DialSCTP(n, laddr, raddr)
conn.cfg.SCTPConfig.sctpConn, err = sctp.DialSCTP(n, laddr, raddr)
if err != nil {
return nil, err
}

r, err := conn.sctpConn.GetStatus()
if conn.cfg.SCTPConfig.SctpSackInfo != nil && conn.cfg.SCTPConfig.SctpSackInfo.Enabled {
err = conn.cfg.SCTPConfig.sctpConn.SetSackTimer(&sctp.SackTimer{
SackDelay: conn.cfg.SCTPConfig.SctpSackInfo.SackDelay,
SackFrequency: conn.cfg.SCTPConfig.SctpSackInfo.SackFrequency,
})
if err != nil {
conn.cfg.SCTPConfig.sctpConn.Close()
return nil, fmt.Errorf("failed to set sack timer: %w", err)
}
}

r, err := conn.cfg.SCTPConfig.sctpConn.GetStatus()
if err != nil {
conn.cfg.SCTPConfig.sctpConn.Close()
return nil, fmt.Errorf("failed to get sctpConn status: %w", err)
}
conn.maxMessageStreamID = r.Ostreams - 1 // removing 1 for management messages of stream ID 0
Expand All @@ -55,12 +68,12 @@ func Dial(ctx context.Context, net string, laddr, raddr *sctp.SCTPAddr, cfg *Con
select {
case _, ok := <-conn.established:
if !ok {
conn.sctpConn.Close()
conn.cfg.SCTPConfig.sctpConn.Close()
return nil, ErrFailedToEstablish
}
return conn, nil
case <-time.After(10 * time.Second):
conn.sctpConn.Close()
conn.cfg.SCTPConfig.sctpConn.Close()
return nil, ErrTimeout
}
}
54 changes: 54 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package m3ua
import (
"time"

"github.com/ishidawataru/sctp"
"github.com/wmnsk/go-m3ua/messages/params"
)

Expand All @@ -25,9 +26,37 @@ func NewHeartbeatInfo(interval, timer time.Duration, data []byte) *HeartbeatInfo
}
}

// SctpSackInfo is a set of information for SCTP SACK timer configuration.
//
// SackDelay sack_delay: This parameter contains the number of milliseconds the
// user is requesting that the delayed SACK timer be set to. Note
// that this value is defined in [RFC4960] to be between 200 and 500
// milliseconds.
//
// SackFrequency sack_freq: This parameter contains the number of packets that must
// be received before a SACK is sent without waiting for the delay
// timer to expire. The default value is 2; setting this value to 1
// will disable the delayed SACK algorithm.
type SctpSackInfo struct {
Enabled bool
SackDelay uint32
SackFrequency uint32
}

// SCTPConfig holds all SCTP-related configuration parameters.
// This separates SCTP layer configuration from M3UA layer configuration.
type SCTPConfig struct {
*SctpSackInfo
// sctpConn is the underlying SCTP association
sctpConn *sctp.SCTPConn
// sctpInfo is SndRcvInfo in SCTP association
sctpInfo *sctp.SndRcvInfo
}

// Config is a configuration that defines a M3UA server.
type Config struct {
*HeartbeatInfo
*SCTPConfig
AspIdentifier *params.Param
TrafficModeType *params.Param
NetworkAppearance *params.Param
Expand All @@ -49,6 +78,7 @@ type Config struct {
// values.
func NewConfig(opc, dpc uint32, si, ni, mp, sls uint8) *Config {
return &Config{
SCTPConfig: &SCTPConfig{},
OriginatingPointCode: opc,
DestinationPointCode: dpc,
ServiceIndicator: si,
Expand All @@ -72,6 +102,28 @@ func (c *Config) EnableHeartbeat(interval, timer time.Duration) *Config {
return c
}

// SetSackConfig sets the SCTP SACK timer configuration.
//
// sackDelay is the number of milliseconds for the delayed SACK timer
// (per RFC4960, should be between 200 and 500 ms).
//
// sackFrequency is the number of packets to receive before sending a SACK
// without waiting for the delay timer. Setting to 1 disables the delayed
// SACK algorithm.
//
// Note: sackDelay=0, sackFrequency=1 (disables delayed SACK)
func (c *Config) SetSackConfig(sackDelay, sackFrequency uint32) *Config {
if c.SCTPConfig == nil {
c.SCTPConfig = &SCTPConfig{}
}
c.SCTPConfig.SctpSackInfo = &SctpSackInfo{
Enabled: true,
SackDelay: sackDelay,
SackFrequency: sackFrequency,
}
return c
}

// SetAspIdentifier sets AspIdentifier in Config.
func (c *Config) SetAspIdentifier(id uint32) *Config {
c.AspIdentifier = params.NewAspIdentifier(id)
Expand Down Expand Up @@ -109,6 +161,7 @@ func (c *Config) SetCorrelationID(id uint32) *Config {
func NewClientConfig(hbInfo *HeartbeatInfo, opc, dpc, aspID, tmt, nwApr, corrID uint32, rtCtxs []uint32, si, ni, mp, sls uint8) *Config {
return &Config{
HeartbeatInfo: hbInfo,
SCTPConfig: &SCTPConfig{},
AspIdentifier: params.NewAspIdentifier(aspID),
TrafficModeType: params.NewTrafficModeType(tmt),
NetworkAppearance: params.NewNetworkAppearance(nwApr),
Expand All @@ -130,6 +183,7 @@ func NewClientConfig(hbInfo *HeartbeatInfo, opc, dpc, aspID, tmt, nwApr, corrID
func NewServerConfig(hbInfo *HeartbeatInfo, opc, dpc, aspID, tmt, nwApr, corrID uint32, rtCtxs []uint32, si, ni, mp, sls uint8) *Config {
return &Config{
HeartbeatInfo: hbInfo,
SCTPConfig: &SCTPConfig{},
AspIdentifier: params.NewAspIdentifier(aspID),
TrafficModeType: params.NewTrafficModeType(tmt),
NetworkAppearance: params.NewNetworkAppearance(nwApr),
Expand Down
57 changes: 35 additions & 22 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ const (
modeServer
)

// Conn represents a M3UA connection, which satisfies standard net.Conn interface.
// Conn represents a M3UA connection, which satisfies the standard net.Conn interface.
type Conn struct {
// maxMessageStreamID is the maximum negotiated sctp stream ID used,
// must not be zero, must vary from 1 to maxMessageStreamID
Expand All @@ -40,15 +40,11 @@ type Conn struct {
established chan struct{}
// beatAckChan notifies that heartbeat gets the ack as expected
beatAckChan chan struct{}
// dataChan is to pass the ProtocolDataPayload(=payload on M3UA DATA) to user
// dataChan is to pass the ProtocolDataPayload(=payload on M3UA DATA) to the user
dataChan chan *params.ProtocolDataPayload
// errChan is to pass errors to goroutine that monitors status
// errChan is to pass errors to a goroutine that monitors status
errChan chan error
// sctpConn is the underlying SCTP association
sctpConn *sctp.SCTPConn
// sctpInfo is SndRcvInfo in SCTP association
sctpInfo *sctp.SndRcvInfo
// cfg is a configuration that is required to communicate between M3UA endpoints
// cfg is a configuration required to communicate between M3UA endpoints
cfg *Config
// Condition to allow heartbeat, only after the state is AspUp
beatAllow *sync.Cond
Expand Down Expand Up @@ -126,9 +122,9 @@ func (c *Conn) WriteToStream(b []byte, streamID uint16) (n int, err error) {
}

// taken by value to avoid race condition on the stream id
info := *c.sctpInfo
info := *c.cfg.SCTPConfig.sctpInfo
info.Stream = streamID
n, err = c.sctpConn.SCTPWrite(d, &info)
n, err = c.cfg.SCTPConfig.sctpConn.SCTPWrite(d, &info)
if err != nil {
return 0, err
}
Expand Down Expand Up @@ -160,9 +156,9 @@ func (c *Conn) WritePDToStream(protocolData *params.Param, streamID uint16) (n i
}

// taken by value to avoid race condition on the stream id
info := *c.sctpInfo
info := *c.cfg.SCTPConfig.sctpInfo
info.Stream = streamID
n, err = c.sctpConn.SCTPWrite(d, &info)
n, err = c.cfg.SCTPConfig.sctpConn.SCTPWrite(d, &info)
if err != nil {
return 0, err
}
Expand All @@ -180,12 +176,12 @@ func (c *Conn) WriteSignal(m3 messages.M3UA) (n int, err error) {
}

// taken by value to avoid race condition on the stream id
sctpInfo := *c.sctpInfo
sctpInfo := *c.cfg.SCTPConfig.sctpInfo
if m3.MessageClass() != messages.MsgClassTransfer {
sctpInfo.Stream = 0
}

nn, err := c.sctpConn.SCTPWrite(buf, &sctpInfo)
nn, err := c.cfg.SCTPConfig.sctpConn.SCTPWrite(buf, &sctpInfo)
if err != nil {
return 0, fmt.Errorf("failed to write M3UA: %w", err)
}
Expand All @@ -200,39 +196,39 @@ func (c *Conn) Close() error {
defer c.muState.Unlock()

if c.state == StateAspDown {
return c.sctpConn.Close()
return c.cfg.SCTPConfig.sctpConn.Close()
}

close(c.established)
close(c.beatAckChan)
close(c.dataChan)
c.state = StateAspDown
return c.sctpConn.Close()
return c.cfg.SCTPConfig.sctpConn.Close()
}

// LocalAddr returns the local network address.
func (c *Conn) LocalAddr() net.Addr {
return c.sctpConn.LocalAddr()
return c.cfg.SCTPConfig.sctpConn.LocalAddr()
}

// RemoteAddr returns the remote network address.
func (c *Conn) RemoteAddr() net.Addr {
return c.sctpConn.RemoteAddr()
return c.cfg.SCTPConfig.sctpConn.RemoteAddr()
}

// SetDeadline sets the read and write deadlines associated.
func (c *Conn) SetDeadline(t time.Time) error {
return c.sctpConn.SetDeadline(t)
return c.cfg.SCTPConfig.sctpConn.SetDeadline(t)
}

// SetReadDeadline sets the deadline for future Read calls.
func (c *Conn) SetReadDeadline(t time.Time) error {
return c.sctpConn.SetReadDeadline(t)
return c.cfg.SCTPConfig.sctpConn.SetReadDeadline(t)
}

// SetWriteDeadline sets the deadline for future Write calls.
func (c *Conn) SetWriteDeadline(t time.Time) error {
return c.sctpConn.SetWriteDeadline(t)
return c.cfg.SCTPConfig.sctpConn.SetWriteDeadline(t)
}

// State returns current state of Conn.
Expand All @@ -244,7 +240,7 @@ func (c *Conn) State() State {

// StreamID returns sctpInfo.Stream of Conn.
func (c *Conn) StreamID() uint16 {
return c.sctpInfo.Stream
return c.cfg.SCTPConfig.sctpInfo.Stream
}

// MaxMessageStreamID returns the maximum negotiated sctp stream ID
Expand All @@ -253,6 +249,23 @@ func (c *Conn) MaxMessageStreamID() uint16 {
return c.maxMessageStreamID
}

// SetSctpSackConfig sets the SCTP SACK timer configuration on an active connection.
//
// sackDelay is the number of milliseconds for the delayed SACK timer
// (per RFC4960, should be between 200 and 500 ms).
//
// sackFrequency is the number of packets to receive before sending a SACK
// without waiting for the delay timer. Setting to 1 disables the delayed
// SACK algorithm.
//
// Note: sackDelay=0, sackFrequency=1 (disables delayed SACK)
func (c *Conn) SetSctpSackConfig(sackDelay, sackFrequency uint32) error {
return c.cfg.SCTPConfig.sctpConn.SetSackTimer(&sctp.SackTimer{
SackDelay: sackDelay,
SackFrequency: sackFrequency,
})
}

// chooseStreamID generates a random uint16 from 1 to max (inclusive)
func (c *Conn) chooseStreamID() uint16 {
if c.maxMessageStreamID == 1 {
Expand Down
2 changes: 1 addition & 1 deletion fsm.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ func (c *Conn) monitor(ctx context.Context) {
}

// Read from conn to see something coming from the peer.
n, _, err := c.sctpConn.SCTPRead(buf)
n, _, err := c.cfg.SCTPConfig.sctpConn.SCTPRead(buf)
if err != nil {
c.Close()
return
Expand Down
23 changes: 18 additions & 5 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,13 @@ func Listen(net string, laddr *sctp.SCTPAddr, cfg *Config) (*Listener, error) {
// After successfully establishing the association with peer, Payload can be read with Read() func.
// Other signals are automatically handled background in another goroutine.
func (l *Listener) Accept(ctx context.Context) (*Conn, error) {

l.Config.SCTPConfig.sctpInfo = &sctp.SndRcvInfo{PPID: 3, Stream: 0}
conn := &Conn{
muState: new(sync.RWMutex),
mode: modeServer,
stateChan: make(chan State),
established: make(chan struct{}),
sctpInfo: &sctp.SndRcvInfo{PPID: 3, Stream: 0},
cfg: l.Config,
}

Expand All @@ -60,14 +61,26 @@ func (l *Listener) Accept(ctx context.Context) (*Conn, error) {
}

var ok bool
conn.sctpConn, ok = c.(*sctp.SCTPConn)
conn.cfg.SCTPConfig.sctpConn, ok = c.(*sctp.SCTPConn)
if !ok {
c.Close()
return nil, fmt.Errorf("failed to assert server connection")
}

r, err := conn.sctpConn.GetStatus()
if conn.cfg.SCTPConfig.SctpSackInfo != nil && conn.cfg.SCTPConfig.SctpSackInfo.Enabled {
err = conn.cfg.SCTPConfig.sctpConn.SetSackTimer(&sctp.SackTimer{
SackDelay: conn.cfg.SCTPConfig.SctpSackInfo.SackDelay,
SackFrequency: conn.cfg.SCTPConfig.SctpSackInfo.SackFrequency,
})
if err != nil {
conn.cfg.SCTPConfig.sctpConn.Close()
return nil, fmt.Errorf("failed to set sack timer: %w", err)
}
}

r, err := conn.cfg.SCTPConfig.sctpConn.GetStatus()
if err != nil {
conn.cfg.SCTPConfig.sctpConn.Close()
return nil, fmt.Errorf("failed to get sctpConn status: %w", err)
}
conn.maxMessageStreamID = r.Ostreams - 1 // removing 1 for management messages of stream ID 0
Expand All @@ -80,12 +93,12 @@ func (l *Listener) Accept(ctx context.Context) (*Conn, error) {
select {
case _, ok := <-conn.established:
if !ok {
conn.sctpConn.Close()
conn.cfg.SCTPConfig.sctpConn.Close()
return nil, ErrFailedToEstablish
}
return conn, nil
case <-time.After(10 * time.Second):
conn.sctpConn.Close()
conn.cfg.SCTPConfig.sctpConn.Close()
return nil, ErrTimeout
}
}
Expand Down
Loading