Skip to content
24 changes: 12 additions & 12 deletions connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import (
"strings"
"sync"
"time"

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove line

"github.com/gofrs/uuid"
"github.com/nats-io/nats.go"
"github.com/nats-io/nats.go/jetstream"
Expand Down Expand Up @@ -295,10 +295,10 @@ func (opts Options) connect() (*Conn, error) {
}

if opts.ConnectionToken != "" && opts.Password != "" {
return nil, memphisError(errors.New("you have to connect with one of the following methods: connection token / password"))
return nil, errInvalidConnectionType
}
if opts.ConnectionToken == "" && opts.Password == "" {
return nil, memphisError(errors.New("you have to connect with one of the following methods: connection token / password"))
return nil, errInvalidConnectionType
}

connId, err := uuid.NewV4()
Expand Down Expand Up @@ -398,29 +398,29 @@ func (c *Conn) startConn() error {

if (opts.TLSOpts.TlsCert != "") || (opts.TLSOpts.TlsKey != "") || (opts.TLSOpts.CaFile != "") {
if opts.TLSOpts.TlsCert == "" {
return memphisError(errors.New("must provide a TLS cert file"))
return errMissingTLSCertFile
}
if opts.TLSOpts.TlsKey == "" {
return memphisError(errors.New("must provide a TLS key file"))
return errMissingTLSKeyFile
}
if opts.TLSOpts.CaFile == "" {
return memphisError(errors.New("must provide a TLS ca file"))
return errMissingTLSCaFile
}
cert, err := tls.LoadX509KeyPair(opts.TLSOpts.TlsCert, opts.TLSOpts.TlsKey)
if err != nil {
return memphisError(errors.New("memphis: error loading client certificate: " + err.Error()))
return errLoadClientCertFailed(err)
}
cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0])
if err != nil {
return memphisError(errors.New("memphis: error parsing client certificate: " + err.Error()))
return errLoadClientCertFailed(err)
}
TLSConfig := &tls.Config{MinVersion: tls.VersionTLS12}
TLSConfig.Certificates = []tls.Certificate{cert}
certs := x509.NewCertPool()

pemData, err := os.ReadFile(opts.TLSOpts.CaFile)
if err != nil {
return memphisError(errors.New("memphis: error loading ca file: " + err.Error()))
return errLoadClientCertFailed(err)
}
certs.AppendCertsFromPEM(pemData)
TLSConfig.RootCAs = certs
Expand Down Expand Up @@ -852,7 +852,7 @@ func (c *Conn) FetchMessages(stationName string, consumerName string, opts ...Fe
}
}
if defaultOpts.BatchSize > maxBatchSize || defaultOpts.BatchSize < 1 {
return nil, memphisError(errors.New("Batch size can not be greater than " + strconv.Itoa(maxBatchSize) + " or less than 1"))
return nil, errInvalidBatchSize(maxBatchSize)
}
if cons == nil {
if defaultOpts.GenUniqueSuffix {
Expand Down Expand Up @@ -984,12 +984,12 @@ func (c *Conn) GetPartitionFromKey(key string, stationName string) (int, error)

func (c *Conn) ValidatePartitionNumber(partitionNumber int, stationName string) error {
if partitionNumber < 0 || partitionNumber >= len(c.stationPartitions[stationName].PartitionsList) {
return errors.New("Partition number is out of range")
return errPartitionNumOutOfRange
}
for _, partition := range c.stationPartitions[stationName].PartitionsList {
if partition == partitionNumber {
return nil
}
}
return fmt.Errorf("Partition %v does not exist in station %v", partitionNumber, stationName)
return errPartitionNotInStation(partitionNumber, stationName)
}
50 changes: 22 additions & 28 deletions consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,6 @@ const (
lastConsumerDestroyReqVersion = 1
)

var (
ConsumerErrStationUnreachable = errors.New("station unreachable")
ConsumerErrConsumeInactive = errors.New("consumer is inactive")
ConsumerErrDelayDlsMsg = errors.New("cannot delay DLS message")
)

// Consumer - memphis consumer object.
type Consumer struct {
Name string
Expand Down Expand Up @@ -116,7 +110,7 @@ func (m *Msg) DataDeserialized() (any, error) {

sd, err := m.conn.getSchemaDetails(m.internalStationName)
if err != nil {
return nil, memphisError(errors.New("Schema validation has failed: " + err.Error()))
return nil, errSchemaValidationFailed(err)
}
var msgBytes []byte

Expand All @@ -125,12 +119,12 @@ func (m *Msg) DataDeserialized() (any, error) {
} else if jsMsg, ok := m.msg.(jetstream.Msg); ok {
msgBytes = jsMsg.Data()
} else {
return nil, errors.New("Message format is not supported")
return nil, errInvalidMessageFormat
}

_, err = sd.validateMsg(msgBytes)
if err != nil {
return nil, memphisError(errors.New("Deserialization has been failed since the message format does not align with the currently attached schema: " + err.Error()))
return nil, errMessageMisalignedSchema(err)
}

switch sd.schemaType {
Expand All @@ -139,7 +133,7 @@ func (m *Msg) DataDeserialized() (any, error) {
err = proto.Unmarshal(msgBytes, pMsg)
if err != nil {
if strings.Contains(err.Error(), "cannot parse invalid wire-format data") {
err = errors.New("invalid message format, expecting protobuf")
err = errExpectingProtobuf
}
return data, memphisError(err)
}
Expand All @@ -148,21 +142,21 @@ func (m *Msg) DataDeserialized() (any, error) {
panic(err)
}
if err := json.Unmarshal(jsonBytes, &data); err != nil {
err = errors.New("Bad JSON format - " + err.Error())
err = errBadJSON(err)
return data, memphisError(err)
}
return data, nil
case "json":
if err := json.Unmarshal(msgBytes, &data); err != nil {
err = errors.New("Bad JSON format - " + err.Error())
err = errBadJSON(err)
return data, memphisError(err)
}
return data, nil
case "graphql":
return string(msgBytes), nil
case "avro":
if err := json.Unmarshal(msgBytes, &data); err != nil {
err = errors.New("Bad JSON format - " + err.Error())
err = errBadJSON(err)
return data, memphisError(err)
}
return data, nil
Expand All @@ -188,7 +182,7 @@ func (m *Msg) GetSequenceNumber() (uint64, error) {
}
seq = meta.Sequence.Stream
} else {
return 0, errors.New("message format is not supported")
return 0, errInvalidMessageFormat
}

return seq, nil
Expand Down Expand Up @@ -217,7 +211,7 @@ func (m *Msg) Ack() error {
} else if jsMsg, ok := m.msg.(jetstream.Msg); ok {
err = jsMsg.Ack()
} else {
return errors.New("message format is not supported")
return errInvalidMessageFormat
}
if err != nil {
var headers nats.Header
Expand Down Expand Up @@ -331,10 +325,10 @@ func (m *Msg) Delay(duration time.Duration) error {
} else if jsMsg, ok := m.msg.(jetstream.Msg); ok {
return jsMsg.NakWithDelay(duration)
} else {
return errors.New("Message format is not supported")
return errInvalidMessageFormat
}
}
return memphisError(ConsumerErrDelayDlsMsg)
return errConsumerErrDelayDlsMsg
}

// ConsumerErrHandler is used to process asynchronous errors.
Expand Down Expand Up @@ -462,19 +456,19 @@ func (opts *ConsumerOpts) createConsumer(c *Conn, options ...RequestOpt) (*Consu
}

if consumer.StartConsumeFromSequence == 0 {
return nil, memphisError(errors.New("startConsumeFromSequence has to be a positive number"))
return nil, errStartConsumeNotPositive
}

if consumer.LastMessages < -1 {
return nil, memphisError(errors.New("min value for LastMessages is -1"))
return nil, errLastMessagesNegative
}

if consumer.StartConsumeFromSequence > 1 && consumer.LastMessages > -1 {
return nil, memphisError(errors.New("Consumer creation options can't contain both startConsumeFromSequence and lastMessages"))
return nil, errBothStartConsumeAndLastMessages
}

if consumer.BatchSize > maxBatchSize || consumer.BatchSize < 1 {
return nil, memphisError(errors.New("Batch size can not be greater than " + strconv.Itoa(maxBatchSize) + " or less than 1"))
return nil, errInvalidBatchSize(maxBatchSize)
}

sn := getInternalName(consumer.stationName)
Expand Down Expand Up @@ -579,7 +573,7 @@ func (c *Consumer) pingConsumer() {
if generalErr != nil {
if strings.Contains(generalErr.Error(), "consumer not found") || strings.Contains(generalErr.Error(), "stream not found") {
c.subscriptionActive = false
c.callErrHandler(ConsumerErrStationUnreachable)
c.callErrHandler(errConsumerErrStationUnreachable)
}
}
case <-c.pingQuit:
Expand Down Expand Up @@ -674,7 +668,7 @@ func (c *Consumer) Consume(handlerFunc ConsumeHandler, opts ...ConsumingOpt) err
// StopConsume - stops the continuous consume operation.
func (c *Consumer) StopConsume() {
if !c.consumeActive {
c.callErrHandler(ConsumerErrConsumeInactive)
c.callErrHandler(errConsumerErrConsumeInactive)
return
}
c.consumeQuit <- struct{}{}
Expand All @@ -683,14 +677,14 @@ func (c *Consumer) StopConsume() {

func (c *Consumer) fetchSubscription(partitionKey string, partitionNum int) ([]*Msg, error) {
if !c.subscriptionActive {
return nil, memphisError(errors.New("station unreachable"))
return nil, errUnreachableStation
}
wrappedMsgs := make([]*Msg, 0, c.BatchSize)
partitionNumber := 1

if len(c.jsConsumers) > 1 {
if partitionKey != "" && partitionNum > 0 {
return nil, memphisError(fmt.Errorf("can not use both partition number and partition key"))
return nil, errBothPartitionNumAndKey
}
if partitionKey != "" {
partitionFromKey, err := c.conn.GetPartitionFromKey(partitionKey, c.stationName)
Expand All @@ -712,12 +706,12 @@ func (c *Consumer) fetchSubscription(partitionKey string, partitionNum int) ([]*
batch, err := c.jsConsumers[partitionNumber].Fetch(c.BatchSize, jetstream.FetchMaxWait(c.BatchMaxTimeToWait))
if err != nil && err != nats.ErrTimeout {
c.subscriptionActive = false
c.callErrHandler(ConsumerErrStationUnreachable)
c.callErrHandler(errConsumerErrStationUnreachable)
c.StopConsume()
}
if batch.Error() != nil && batch.Error() != nats.ErrTimeout {
c.subscriptionActive = false
c.callErrHandler(ConsumerErrStationUnreachable)
c.callErrHandler(errConsumerErrStationUnreachable)
c.StopConsume()
}
// msgs := batch.Messages()
Expand Down Expand Up @@ -776,7 +770,7 @@ func (c *Consumer) fetchSubscriprionWithTimeout(partitionKey string, partitionNu
// Fetch - immediately fetch a batch of messages.
func (c *Consumer) Fetch(batchSize int, prefetch bool, opts ...ConsumingOpt) ([]*Msg, error) {
if batchSize > maxBatchSize || batchSize < 1 {
return nil, memphisError(errors.New("Batch size can not be greater than " + strconv.Itoa(maxBatchSize) + " or less than 1"))
return nil, errInvalidBatchSize(maxBatchSize)
}

defaultOpts := getDefaultConsumingOptions()
Expand Down
67 changes: 67 additions & 0 deletions memphis_errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package memphis

import (
"fmt"
"strconv"
"errors"
)

var (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we use const here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@shay23b Value of type error cannot be const is what the compiler complains about so I guess not

errInvalidConnectionType = memphisError(errors.New("you have to connect with one of the following methods: connection token / password"))
errMissingTLSCertFile = memphisError(errors.New("must provide a TLS cert file"))
errMissingTLSKeyFile = memphisError(errors.New("must provide a TLS key file"))
errMissingTLSCaFile = memphisError(errors.New("must provide a TLS ca file"))
errPartitionNumOutOfRange = memphisError(errors.New("partition number is out of range"))
errConsumerErrStationUnreachable = memphisError(errors.New("station unreachable"))
errConsumerErrConsumeInactive = memphisError(errors.New("consumer is inactive"))
errConsumerErrDelayDlsMsg = memphisError(errors.New("cannot delay DLS message"))
errInvalidMessageFormat = memphisError(errors.New("message format is not supported"))
errExpectingProtobuf = memphisError(errors.New("invalid message format, expecting protobuf"))
errBothPartitionNumAndKey = memphisError(errors.New("can not use both partition number and partition key"))
errStartConsumeNotPositive = memphisError(errors.New("startConsumeFromSequence has to be a positive number"))
errLastMessagesNegative = memphisError(errors.New("min value for LastMessages is -1"))
errBothStartConsumeAndLastMessages = memphisError(errors.New("Consumer creation options can't contain both startConsumeFromSequence and lastMessages"))
errUnreachableStation = memphisError(errors.New("station unreachable"))
errInvalidStationName = memphisError(errors.New("station name should be either string or []string"))
errInvalidHeaderKey = memphisError(errors.New("keys in headers should not start with $memphis"))
errUnsupportedMsgType = memphisError(errors.New("unsupported message type"))
errEmptyMsgId = memphisError(errors.New("msg id can not be empty"))
errPartitionNotInKey = memphisError(errors.New("failed to get partition from key"))
errMissingFunctionsListener = memphisError(errors.New("functions listener doesn't exist"))
errMissingSchemaListener = memphisError(errors.New("schema listener doesn't exist"))
errStationNotSubedToSchema = memphisError(errors.New("station subscription doesn't exist"))
errInvalidSchmeaType = memphisError(errors.New("invalid schema type"))
errExpectinGraphQL = memphisError(errors.New("invalid message format, expecting GraphQL"))
)

func errInvalidAvroFormat(err error) error{
return memphisError(errors.New("Bad Avro format - " + err.Error()))
}

func errProducerNotInCache(producerName string) error{
return memphisError(fmt.Errorf("%s not exists on the map", producerName))
}

func errLoadClientCertFailed(err error) error{
return memphisError(errors.New("memphis: error loading client certificate: " + err.Error()))
}

func errInvalidBatchSize(maxBatchSize int) error{
return memphisError(errors.New("Batch size can not be greater than " + strconv.Itoa(maxBatchSize) + " or less than 1"))
}

func errPartitionNotInStation(partitionNumber int, stationName string) error {
return memphisError(fmt.Errorf("partition %v does not exist in station %v", partitionNumber, stationName))
}

func errSchemaValidationFailed(err error) error {
return memphisError(errors.New("Schema validation has failed: " + err.Error()))
}

func errMessageMisalignedSchema(err error) error {
return memphisError(errors.New("Deserialization has been failed since the message format does not align with the currently attached schema: " + err.Error()))
}

func errBadJSON(err error) error {
return memphisError(errors.New("Bad JSON format - " + err.Error()))
}
Loading