diff --git a/connect.go b/connect.go index a8ee7c0..f3a63f3 100644 --- a/connect.go +++ b/connect.go @@ -30,7 +30,7 @@ import ( "strings" "sync" "time" - + "github.com/gofrs/uuid" "github.com/nats-io/nats.go" "github.com/nats-io/nats.go/jetstream" @@ -295,10 +295,10 @@ func (opts Options) connect() (*Conn, error) { } if opts.ConnectionToken != "" && opts.Password != "" { - return nil, memphisError(errors.New("you have to connect with one of the following methods: connection token / password")) + return nil, errInvalidConnectionType } if opts.ConnectionToken == "" && opts.Password == "" { - return nil, memphisError(errors.New("you have to connect with one of the following methods: connection token / password")) + return nil, errInvalidConnectionType } connId, err := uuid.NewV4() @@ -398,21 +398,21 @@ func (c *Conn) startConn() error { if (opts.TLSOpts.TlsCert != "") || (opts.TLSOpts.TlsKey != "") || (opts.TLSOpts.CaFile != "") { if opts.TLSOpts.TlsCert == "" { - return memphisError(errors.New("must provide a TLS cert file")) + return errMissingTLSCertFile } if opts.TLSOpts.TlsKey == "" { - return memphisError(errors.New("must provide a TLS key file")) + return errMissingTLSKeyFile } if opts.TLSOpts.CaFile == "" { - return memphisError(errors.New("must provide a TLS ca file")) + return errMissingTLSCaFile } cert, err := tls.LoadX509KeyPair(opts.TLSOpts.TlsCert, opts.TLSOpts.TlsKey) if err != nil { - return memphisError(errors.New("memphis: error loading client certificate: " + err.Error())) + return errLoadClientCertFailed(err) } cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0]) if err != nil { - return memphisError(errors.New("memphis: error parsing client certificate: " + err.Error())) + return errLoadClientCertFailed(err) } TLSConfig := &tls.Config{MinVersion: tls.VersionTLS12} TLSConfig.Certificates = []tls.Certificate{cert} @@ -420,7 +420,7 @@ func (c *Conn) startConn() error { pemData, err := os.ReadFile(opts.TLSOpts.CaFile) if err != nil { - return memphisError(errors.New("memphis: error loading ca file: " + err.Error())) + return errLoadClientCertFailed(err) } certs.AppendCertsFromPEM(pemData) TLSConfig.RootCAs = certs @@ -852,7 +852,7 @@ func (c *Conn) FetchMessages(stationName string, consumerName string, opts ...Fe } } if defaultOpts.BatchSize > maxBatchSize || defaultOpts.BatchSize < 1 { - return nil, memphisError(errors.New("Batch size can not be greater than " + strconv.Itoa(maxBatchSize) + " or less than 1")) + return nil, errInvalidBatchSize(maxBatchSize) } if cons == nil { if defaultOpts.GenUniqueSuffix { @@ -984,12 +984,12 @@ func (c *Conn) GetPartitionFromKey(key string, stationName string) (int, error) func (c *Conn) ValidatePartitionNumber(partitionNumber int, stationName string) error { if partitionNumber < 0 || partitionNumber >= len(c.stationPartitions[stationName].PartitionsList) { - return errors.New("Partition number is out of range") + return errPartitionNumOutOfRange } for _, partition := range c.stationPartitions[stationName].PartitionsList { if partition == partitionNumber { return nil } } - return fmt.Errorf("Partition %v does not exist in station %v", partitionNumber, stationName) + return errPartitionNotInStation(partitionNumber, stationName) } diff --git a/consumer.go b/consumer.go index 4a4495e..b152bcc 100644 --- a/consumer.go +++ b/consumer.go @@ -41,12 +41,6 @@ const ( lastConsumerDestroyReqVersion = 1 ) -var ( - ConsumerErrStationUnreachable = errors.New("station unreachable") - ConsumerErrConsumeInactive = errors.New("consumer is inactive") - ConsumerErrDelayDlsMsg = errors.New("cannot delay DLS message") -) - // Consumer - memphis consumer object. type Consumer struct { Name string @@ -116,7 +110,7 @@ func (m *Msg) DataDeserialized() (any, error) { sd, err := m.conn.getSchemaDetails(m.internalStationName) if err != nil { - return nil, memphisError(errors.New("Schema validation has failed: " + err.Error())) + return nil, errSchemaValidationFailed(err) } var msgBytes []byte @@ -125,12 +119,12 @@ func (m *Msg) DataDeserialized() (any, error) { } else if jsMsg, ok := m.msg.(jetstream.Msg); ok { msgBytes = jsMsg.Data() } else { - return nil, errors.New("Message format is not supported") + return nil, errInvalidMessageFormat } _, err = sd.validateMsg(msgBytes) if err != nil { - return nil, memphisError(errors.New("Deserialization has been failed since the message format does not align with the currently attached schema: " + err.Error())) + return nil, errMessageMisalignedSchema(err) } switch sd.schemaType { @@ -139,7 +133,7 @@ func (m *Msg) DataDeserialized() (any, error) { err = proto.Unmarshal(msgBytes, pMsg) if err != nil { if strings.Contains(err.Error(), "cannot parse invalid wire-format data") { - err = errors.New("invalid message format, expecting protobuf") + err = errExpectingProtobuf } return data, memphisError(err) } @@ -148,13 +142,13 @@ func (m *Msg) DataDeserialized() (any, error) { panic(err) } if err := json.Unmarshal(jsonBytes, &data); err != nil { - err = errors.New("Bad JSON format - " + err.Error()) + err = errBadJSON(err) return data, memphisError(err) } return data, nil case "json": if err := json.Unmarshal(msgBytes, &data); err != nil { - err = errors.New("Bad JSON format - " + err.Error()) + err = errBadJSON(err) return data, memphisError(err) } return data, nil @@ -162,7 +156,7 @@ func (m *Msg) DataDeserialized() (any, error) { return string(msgBytes), nil case "avro": if err := json.Unmarshal(msgBytes, &data); err != nil { - err = errors.New("Bad JSON format - " + err.Error()) + err = errBadJSON(err) return data, memphisError(err) } return data, nil @@ -188,7 +182,7 @@ func (m *Msg) GetSequenceNumber() (uint64, error) { } seq = meta.Sequence.Stream } else { - return 0, errors.New("message format is not supported") + return 0, errInvalidMessageFormat } return seq, nil @@ -217,7 +211,7 @@ func (m *Msg) Ack() error { } else if jsMsg, ok := m.msg.(jetstream.Msg); ok { err = jsMsg.Ack() } else { - return errors.New("message format is not supported") + return errInvalidMessageFormat } if err != nil { var headers nats.Header @@ -331,10 +325,10 @@ func (m *Msg) Delay(duration time.Duration) error { } else if jsMsg, ok := m.msg.(jetstream.Msg); ok { return jsMsg.NakWithDelay(duration) } else { - return errors.New("Message format is not supported") + return errInvalidMessageFormat } } - return memphisError(ConsumerErrDelayDlsMsg) + return errConsumerErrDelayDlsMsg } // ConsumerErrHandler is used to process asynchronous errors. @@ -462,19 +456,19 @@ func (opts *ConsumerOpts) createConsumer(c *Conn, options ...RequestOpt) (*Consu } if consumer.StartConsumeFromSequence == 0 { - return nil, memphisError(errors.New("startConsumeFromSequence has to be a positive number")) + return nil, errStartConsumeNotPositive } if consumer.LastMessages < -1 { - return nil, memphisError(errors.New("min value for LastMessages is -1")) + return nil, errLastMessagesNegative } if consumer.StartConsumeFromSequence > 1 && consumer.LastMessages > -1 { - return nil, memphisError(errors.New("Consumer creation options can't contain both startConsumeFromSequence and lastMessages")) + return nil, errBothStartConsumeAndLastMessages } if consumer.BatchSize > maxBatchSize || consumer.BatchSize < 1 { - return nil, memphisError(errors.New("Batch size can not be greater than " + strconv.Itoa(maxBatchSize) + " or less than 1")) + return nil, errInvalidBatchSize(maxBatchSize) } sn := getInternalName(consumer.stationName) @@ -579,7 +573,7 @@ func (c *Consumer) pingConsumer() { if generalErr != nil { if strings.Contains(generalErr.Error(), "consumer not found") || strings.Contains(generalErr.Error(), "stream not found") { c.subscriptionActive = false - c.callErrHandler(ConsumerErrStationUnreachable) + c.callErrHandler(errConsumerErrStationUnreachable) } } case <-c.pingQuit: @@ -674,7 +668,7 @@ func (c *Consumer) Consume(handlerFunc ConsumeHandler, opts ...ConsumingOpt) err // StopConsume - stops the continuous consume operation. func (c *Consumer) StopConsume() { if !c.consumeActive { - c.callErrHandler(ConsumerErrConsumeInactive) + c.callErrHandler(errConsumerErrConsumeInactive) return } c.consumeQuit <- struct{}{} @@ -683,14 +677,14 @@ func (c *Consumer) StopConsume() { func (c *Consumer) fetchSubscription(partitionKey string, partitionNum int) ([]*Msg, error) { if !c.subscriptionActive { - return nil, memphisError(errors.New("station unreachable")) + return nil, errUnreachableStation } wrappedMsgs := make([]*Msg, 0, c.BatchSize) partitionNumber := 1 if len(c.jsConsumers) > 1 { if partitionKey != "" && partitionNum > 0 { - return nil, memphisError(fmt.Errorf("can not use both partition number and partition key")) + return nil, errBothPartitionNumAndKey } if partitionKey != "" { partitionFromKey, err := c.conn.GetPartitionFromKey(partitionKey, c.stationName) @@ -712,12 +706,12 @@ func (c *Consumer) fetchSubscription(partitionKey string, partitionNum int) ([]* batch, err := c.jsConsumers[partitionNumber].Fetch(c.BatchSize, jetstream.FetchMaxWait(c.BatchMaxTimeToWait)) if err != nil && err != nats.ErrTimeout { c.subscriptionActive = false - c.callErrHandler(ConsumerErrStationUnreachable) + c.callErrHandler(errConsumerErrStationUnreachable) c.StopConsume() } if batch.Error() != nil && batch.Error() != nats.ErrTimeout { c.subscriptionActive = false - c.callErrHandler(ConsumerErrStationUnreachable) + c.callErrHandler(errConsumerErrStationUnreachable) c.StopConsume() } // msgs := batch.Messages() @@ -758,11 +752,11 @@ func (c *Consumer) fetchSubscriprionWithTimeout(partitionKey string, partitionNu batch, err := c.jsConsumers[partitionNumber].Fetch(c.BatchSize, jetstream.FetchMaxWait(c.BatchMaxTimeToWait)) if err != nil && err != nats.ErrTimeout { - c.callErrHandler(ConsumerErrStationUnreachable) + c.callErrHandler(errConsumerErrStationUnreachable) return []*Msg{}, nil } if batch.Error() != nil && batch.Error() != nats.ErrTimeout { - c.callErrHandler(ConsumerErrStationUnreachable) + c.callErrHandler(errConsumerErrStationUnreachable) return []*Msg{}, nil } @@ -776,7 +770,7 @@ func (c *Consumer) fetchSubscriprionWithTimeout(partitionKey string, partitionNu // Fetch - immediately fetch a batch of messages. func (c *Consumer) Fetch(batchSize int, prefetch bool, opts ...ConsumingOpt) ([]*Msg, error) { if batchSize > maxBatchSize || batchSize < 1 { - return nil, memphisError(errors.New("Batch size can not be greater than " + strconv.Itoa(maxBatchSize) + " or less than 1")) + return nil, errInvalidBatchSize(maxBatchSize) } defaultOpts := getDefaultConsumingOptions() diff --git a/memphis_errors.go b/memphis_errors.go new file mode 100644 index 0000000..0b9d931 --- /dev/null +++ b/memphis_errors.go @@ -0,0 +1,67 @@ +package memphis + +import ( + "fmt" + "strconv" + "errors" +) + +var ( + errInvalidConnectionType = memphisError(errors.New("you have to connect with one of the following methods: connection token / password")) + errMissingTLSCertFile = memphisError(errors.New("must provide a TLS cert file")) + errMissingTLSKeyFile = memphisError(errors.New("must provide a TLS key file")) + errMissingTLSCaFile = memphisError(errors.New("must provide a TLS ca file")) + errPartitionNumOutOfRange = memphisError(errors.New("partition number is out of range")) + errConsumerErrStationUnreachable = memphisError(errors.New("station unreachable")) + errConsumerErrConsumeInactive = memphisError(errors.New("consumer is inactive")) + errConsumerErrDelayDlsMsg = memphisError(errors.New("cannot delay DLS message")) + errInvalidMessageFormat = memphisError(errors.New("message format is not supported")) + errExpectingProtobuf = memphisError(errors.New("invalid message format, expecting protobuf")) + errBothPartitionNumAndKey = memphisError(errors.New("can not use both partition number and partition key")) + errStartConsumeNotPositive = memphisError(errors.New("startConsumeFromSequence has to be a positive number")) + errLastMessagesNegative = memphisError(errors.New("min value for LastMessages is -1")) + errBothStartConsumeAndLastMessages = memphisError(errors.New("Consumer creation options can't contain both startConsumeFromSequence and lastMessages")) + errUnreachableStation = memphisError(errors.New("station unreachable")) + errInvalidStationName = memphisError(errors.New("station name should be either string or []string")) + errInvalidHeaderKey = memphisError(errors.New("keys in headers should not start with $memphis")) + errUnsupportedMsgType = memphisError(errors.New("unsupported message type")) + errEmptyMsgId = memphisError(errors.New("msg id can not be empty")) + errPartitionNotInKey = memphisError(errors.New("failed to get partition from key")) + errMissingFunctionsListener = memphisError(errors.New("functions listener doesn't exist")) + errMissingSchemaListener = memphisError(errors.New("schema listener doesn't exist")) + errStationNotSubedToSchema = memphisError(errors.New("station subscription doesn't exist")) + errInvalidSchmeaType = memphisError(errors.New("invalid schema type")) + errExpectinGraphQL = memphisError(errors.New("invalid message format, expecting GraphQL")) +) + +func errInvalidAvroFormat(err error) error{ + return memphisError(fmt.Errorf("Bad Avro format - %s", err.Error())) +} + +func errProducerNotInCache(producerName string) error{ + return memphisError(fmt.Errorf("%s not exists on the map", producerName)) +} + +func errLoadClientCertFailed(err error) error{ + return memphisError(fmt.Errorf("memphis: error loading client certificate: %s", err.Error())) +} + +func errInvalidBatchSize(maxBatchSize int) error{ + return memphisError(fmt.Errorf("Batch size can not be greater than %s or less than 1", strconv.Itoa(maxBatchSize))) +} + +func errPartitionNotInStation(partitionNumber int, stationName string) error { + return memphisError(fmt.Errorf("partition %v does not exist in station %v", partitionNumber, stationName)) +} + +func errSchemaValidationFailed(err error) error { + return memphisError(fmt.Errorf("Schema validation has failed: %s", err.Error())) +} + +func errMessageMisalignedSchema(err error) error { + return memphisError(fmt.Errorf("Deserialization has been failed since the message format does not align with the currently attached schema: %s", err.Error())) +} + +func errBadJSON(err error) error { + return memphisError(fmt.Errorf("Bad JSON format - %s", err.Error())) +} \ No newline at end of file diff --git a/producer.go b/producer.go index 120bf2e..aa1985d 100644 --- a/producer.go +++ b/producer.go @@ -162,7 +162,7 @@ func (c *Conn) CreateProducer(stationName interface{}, name string, opts ...Prod case string: case []string: default: - return nil, memphisError(errors.New("station name should be either string or []string")) + return nil, errInvalidStationName } name = strings.ToLower(name) @@ -245,7 +245,7 @@ func (c *Conn) Produce(stationName interface{}, name string, message any, opts [ case string: case []string: default: - return memphisError(errors.New("station name should be either string or []string")) + return errInvalidStationName } if singleStationName, ok := stationName.(string); ok { @@ -294,7 +294,7 @@ func (c *Conn) getProducerFromCache(stationName, name string) (*Producer, error) pn := fmt.Sprintf("%s_%s", stationName, name) pm := c.getProducersMap() if pm.getProducer(pn) == nil { - return nil, fmt.Errorf("%s not exists on the map", pn) + return nil, errProducerNotInCache(pn) } return pm.getProducer(pn), nil @@ -485,7 +485,7 @@ func (p *Producer) produceToSingleStation(message any, opts ...ProduceOpt) error func (hdr *Headers) validateHeaderKey(key string) error { if strings.HasPrefix(key, "$memphis") { - return memphisError(errors.New("keys in headers should not start with $memphis")) + return errInvalidHeaderKey } return nil } @@ -521,12 +521,12 @@ func (opts *ProduceOpts) produce(p *Producer) error { streamName = fmt.Sprintf("%v$%v", sn, p.conn.stationPartitions[sn].PartitionsList[0]) } else if len(p.conn.stationPartitions[sn].PartitionsList) > 1 { if opts.ProducerPartitionNumber > 0 && opts.ProducerPartitionKey != "" { - return memphisError(fmt.Errorf("Can not use both partition number and partition key")) + return errBothPartitionNumAndKey } if opts.ProducerPartitionKey != "" { partitionNumber, err := p.conn.GetPartitionFromKey(opts.ProducerPartitionKey, sn) if err != nil { - return memphisError(fmt.Errorf("failed to get partition from key")) + return errPartitionNotInKey } streamName = fmt.Sprintf("%v$%v", sn, partitionNumber) } else if opts.ProducerPartitionNumber > 0 { @@ -644,7 +644,7 @@ func (p *Producer) sendMsgToDls(msg any, headers map[string][]string, err error) func (p *Producer) validateMsg(msg any, headers map[string][]string) ([]byte, error) { sd, err := p.getSchemaDetails() if err != nil { - return nil, memphisError(errors.New("Schema validation has failed: " + err.Error())) + return nil, errSchemaValidationFailed(err) } var originalMsgBytes []byte @@ -674,7 +674,7 @@ func (p *Producer) validateMsg(msg any, headers map[string][]string) ([]byte, er return nil, memphisError(err) } } else { - return nil, memphisError(errors.New("unsupported message type")) + return nil, errUnsupportedMsgType } } @@ -688,7 +688,7 @@ func (p *Producer) validateMsg(msg any, headers map[string][]string) ([]byte, er } p.sendMsgToDls(msgToSend, headers, err) - return nil, memphisError(errors.New("Schema validation has failed: " + err.Error())) + return nil, errSchemaValidationFailed(err) } originalMsgBytes = msgBytes } @@ -762,7 +762,7 @@ func SyncProduce() ProduceOpt { func MsgId(id string) ProduceOpt { return func(opts *ProduceOpts) error { if id == "" { - return errors.New("msg id can not be empty") + return errEmptyMsgId } opts.MsgHeaders.MsgHeaders["msg-id"] = []string{id} return nil diff --git a/station.go b/station.go index 3e726a4..5d626c5 100644 --- a/station.go +++ b/station.go @@ -473,7 +473,7 @@ func (c *Conn) removeFunctionsUpdatesListener(stationName string) error { sfs, ok := c.stationFunctionSubs[sn] if !ok { - return memphisError(errors.New("functions listener doesn't exist")) + return errMissingFunctionsListener } sfs.StationFunctionsMu.Lock() @@ -499,7 +499,7 @@ func (c *Conn) removeSchemaUpdatesListener(stationName string) error { defer stationUpdatesSubsLock.Unlock() sus, ok := c.stationUpdatesSubs[sn] if !ok { - return memphisError(errors.New("listener doesn't exist")) + return errMissingSchemaListener } sus.refCount-- @@ -522,7 +522,7 @@ func (c *Conn) getSchemaDetails(stationName string) (schemaDetails, error) { sus, ok := c.stationUpdatesSubs[sn] if !ok { - return schemaDetails{}, memphisError(errors.New("station subscription doesn't exist")) + return schemaDetails{}, errStationNotSubedToSchema } return sus.schemaDetails, nil @@ -655,7 +655,7 @@ func (sd *schemaDetails) validateMsg(msg any) ([]byte, error) { case "avro": return sd.validAvroSchemaMsg(msg) default: - return nil, memphisError(errors.New("invalid schema type")) + return nil, errInvalidSchmeaType } } @@ -687,14 +687,14 @@ func (sd *schemaDetails) validateProtoMsg(msg any) ([]byte, error) { return nil, memphisError(err) } default: - return nil, memphisError(errors.New("unsupported message type")) + return nil, errUnsupportedMsgType } protoMsg := dynamicpb.NewMessage(sd.msgDescriptor) err = proto.Unmarshal(msgBytes, protoMsg) if err != nil { if strings.Contains(err.Error(), "cannot parse invalid wire-format data") { - err = errors.New("invalid message format, expecting protobuf") + err = errExpectingProtobuf } return msgBytes, memphisError(err) } @@ -713,7 +713,7 @@ func (sd *schemaDetails) validJsonSchemaMsg(msg any) ([]byte, error) { case []byte: msgBytes = msg.([]byte) if err := json.Unmarshal(msgBytes, &message); err != nil { - err = errors.New("Bad JSON format - " + err.Error()) + err = errBadJSON(err) return nil, memphisError(err) } case map[string]interface{}: @@ -734,7 +734,7 @@ func (sd *schemaDetails) validJsonSchemaMsg(msg any) ([]byte, error) { return nil, memphisError(err) } } else { - return nil, memphisError(errors.New("unsupported message type")) + return nil, errUnsupportedMsgType } } if err = sd.jsonSchema.Validate(message); err != nil { @@ -773,7 +773,7 @@ func (sd *schemaDetails) validateGraphQlMsg(msg any) ([]byte, error) { validateErrorGql = strings.Join(validateErrors, resultErr) } if strings.Contains(validateErrorGql, "syntax error") { - return nil, memphisError(errors.New("invalid message format, expecting GraphQL")) + return nil, errExpectinGraphQL } return msgBytes, memphisError(errors.New(validateErrorGql)) @@ -796,7 +796,7 @@ func (sd *schemaDetails) validAvroSchemaMsg(msg any) ([]byte, error) { case []byte: msgBytes = msg.([]byte) if err := json.Unmarshal(msgBytes, &message); err != nil { - err = errors.New("Bad Avro format - " + err.Error()) + err = errInvalidAvroFormat(err) return nil, memphisError(err) } case map[string]interface{}: @@ -805,7 +805,7 @@ func (sd *schemaDetails) validAvroSchemaMsg(msg any) ([]byte, error) { return nil, memphisError(err) } if err := json.Unmarshal(msgBytes, &message); err != nil { - err = errors.New("Bad Avro format - " + err.Error()) + err = errInvalidAvroFormat(err) return nil, memphisError(err) } @@ -825,7 +825,7 @@ func (sd *schemaDetails) validAvroSchemaMsg(msg any) ([]byte, error) { return nil, memphisError(err) } } else { - return nil, memphisError(errors.New("unsupported message type")) + return nil, errUnsupportedMsgType } }