diff --git a/.golangci.yml b/.golangci.yml index b95c11b..6250424 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -57,6 +57,7 @@ linters: - "golang.org/x/tools" - "github.com/vmihailenco/msgpack/v5" - "github.com/tarantool/go-option" + - "github.com/google/uuid" test: files: - "$test" diff --git a/CHANGELOG.md b/CHANGELOG.md index 2d31081..401457a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ Versioning](http://semver.org/spec/v2.0.0.html) except to the first release. ### Added +- gentypes: Add support for third-party types. + ### Changed ### Fixed diff --git a/README.md b/README.md index 9042880..1ef2c11 100644 --- a/README.md +++ b/README.md @@ -178,10 +178,65 @@ Or you can use it to generate file from go: Flags: -• `-package`: Path to the Go package containing types to wrap (default: `"."`) -• `-ext-code`: MessagePack extension code to use for custom types (must be between --128 and 127, no default value) -• `-verbose`: Enable verbose output (default: `false`) + * `-package`: Path to the Go package containing types to wrap (default: `"."`) + * `-ext-code`: MessagePack extension code to use for custom types (must be between + -128 and 127, no default value) + * `-verbose`: Enable verbose output (default: `false`) + * `-force`: Ignore absence of marshal/unmarshal methods on type (default: `false`). + Helpful for types from third-party modules. + * `-imports`: Add imports to generated file (default is empty). + Helpful for types from third-party modules. + * `-marshal-func`: func that should do marshaling (default is `MarshalMsgpack` method). + Helpful for types from third-party modules. + Should be func of type `func(v T) ([]byte, error)` and should + be located in the same dir or should be imported. + * `-unmarshal-func`: func that should do unmarshalling (default is `UnmarshalMsgpack` method). + Helpful for types from third-party modules. + Should be func of type `func(v *T, data []byte) error` and should + be located in the same dir or should be imported. + +#### Generating Optional Types for Third-Party Modules + +Sometimes you need to generate an optional type for a type from a third-party module, +and you can't add `MarshalMsgpack`/`UnmarshalMsgpack` methods to it. +In this case, you can use the `-force`, `-imports`, `-marshal-func`, and `-unmarshal-func` flags. + +For example, to generate an optional type for `github.com/google/uuid.UUID`: + +1. Create a file with marshal and unmarshal functions for the third-party type. + For example, `uuid.go`: + + ```go + package main + + import ( + "errors" + + "github.com/google/uuid" + ) + + func encodeUUID(uuid uuid.UUID) ([]byte, error) { + return uuid[:], nil + } + + var ( + ErrInvalidLength = errors.New("invalid length") + ) + + func decodeUUID(uuid *uuid.UUID, data []byte) error { + if len(data) != len(uuid) { + return ErrInvalidLength + } + copy(uuid[:], data) + return nil + } + ``` + +2. Use the following `go:generate` command: + + ```go + //go:generate go run github.com/tarantool/go-option/cmd/gentypes@latest -package . -imports "github.com/google/uuid" -type UUID -marshal-func "encodeUUID" -unmarshal-func "decodeUUID" -force -ext-code 100 + ``` ### Using Generated Types @@ -238,4 +293,4 @@ BSD 2-Clause License [coverage-url]: https://coveralls.io/github/tarantool/go-option?branch=master [telegram-badge]: https://img.shields.io/badge/Telegram-join%20chat-blue.svg [telegram-en-url]: http://telegram.me/tarantool -[telegram-ru-url]: http://telegram.me/tarantoolru \ No newline at end of file +[telegram-ru-url]: http://telegram.me/tarantoolru diff --git a/cmd/gentypes/flag.go b/cmd/gentypes/flag.go new file mode 100644 index 0000000..121851c --- /dev/null +++ b/cmd/gentypes/flag.go @@ -0,0 +1,34 @@ +package main + +import ( + "maps" + "slices" + "strings" +) + +type stringListFlag []string + +func (s *stringListFlag) String() string { + return strings.Join(*s, ", ") +} + +func (s *stringListFlag) Set(s2 string) error { + *s = append(*s, s2) + return nil +} + +func deleteDuplicates(s stringListFlag) stringListFlag { + uniqMap := map[string]struct{}{} + for _, val := range s { + uniqMap[val] = struct{}{} + } + + return slices.Collect(maps.Keys(uniqMap)) +} + +func (s *stringListFlag) Get() []string { + deduped := deleteDuplicates(*s) + slices.Sort(deduped) + + return deduped +} diff --git a/cmd/gentypes/generate.go b/cmd/gentypes/generate.go index 64dcbb1..ac4713f 100644 --- a/cmd/gentypes/generate.go +++ b/cmd/gentypes/generate.go @@ -1,2 +1,5 @@ -//go:generate go run github.com/tarantool/go-option/cmd/gentypes -ext-code 1 -package test FullMsgpackExtType +//go:generate go run github.com/tarantool/go-option/cmd/gentypes -ext-code 1 -package internal/test FullMsgpackExtType +//go:generate go run github.com/tarantool/go-option/cmd/gentypes -ext-code 2 -force -package internal/test HiddenTypeAlias +//go:generate go run github.com/tarantool/go-option/cmd/gentypes -ext-code 3 -imports github.com/google/uuid -package internal/test -marshal-func encodeUUID -unmarshal-func decodeUUID uuid.UUID + package main diff --git a/cmd/gentypes/generator/extension.go b/cmd/gentypes/generator/extension.go index 7e3f7fa..4955706 100644 --- a/cmd/gentypes/generator/extension.go +++ b/cmd/gentypes/generator/extension.go @@ -6,6 +6,7 @@ import ( _ "embed" "fmt" "strconv" + "strings" "text/template" ) @@ -26,22 +27,72 @@ func InitializeTemplates() { cTypeGenTestTemplate = template.Must(template.New("type_gen_test.go.tpl").Parse(typeGenTestTemplate)) } +const ( + maxNameParts = 2 +) + +func constructTypeName(typeName string) string { + splittedName := strings.SplitN(typeName, ".", maxNameParts) + switch len(splittedName) { + case 1: + typeName = splittedName[0] + case maxNameParts: + typeName = splittedName[1] + default: + panic("invalid type name: " + typeName) + } + + return "Optional" + typeName +} + +// GenerateOptions is the options for the code generation. +type GenerateOptions struct { + // TypeName is the name of the type to generate optional to. + TypeName string + // ExtCode is the extension code. + ExtCode int + // PackageName is the name of the package to generate to. + PackageName string + // Imports is the list of imports to add to the generated code. + Imports []string + // CustomMarshalFunc is the name of the custom marshal function. + CustomMarshalFunc string + // CustomUnmarshalFunc is the name of the custom unmarshal function. + CustomUnmarshalFunc string +} + // GenerateByType generates the code for the optional type. -func GenerateByType(typeName string, code int, packageName string) ([]byte, error) { +func GenerateByType(opts GenerateOptions) ([]byte, error) { var buf bytes.Buffer + if opts.CustomMarshalFunc == "" { + opts.CustomMarshalFunc = "o.value.MarshalMsgpack()" + } else { + opts.CustomMarshalFunc += "(o.value)" + } + + if opts.CustomUnmarshalFunc == "" { + opts.CustomUnmarshalFunc = "o.value.UnmarshalMsgpack(a)" + } else { + opts.CustomUnmarshalFunc += "(&o.value, a)" + } + err := cTypeGenTemplate.Execute(&buf, struct { - Name string - Type string - ExtCode string - PackageName string - Imports []string + Name string + Type string + ExtCode string + PackageName string + Imports []string + CustomMarshalFunc string + CustomUnmarshalFunc string }{ - Name: "Optional" + typeName, - Type: typeName, - ExtCode: strconv.Itoa(code), - PackageName: packageName, - Imports: nil, + Name: constructTypeName(opts.TypeName), + Type: opts.TypeName, + ExtCode: strconv.Itoa(opts.ExtCode), + PackageName: opts.PackageName, + Imports: opts.Imports, + CustomMarshalFunc: opts.CustomMarshalFunc, + CustomUnmarshalFunc: opts.CustomUnmarshalFunc, }) if err != nil { return nil, fmt.Errorf("failed to generateByType: %w", err) diff --git a/cmd/gentypes/generator/type_gen.go.tpl b/cmd/gentypes/generator/type_gen.go.tpl index 6111ff7..874e119 100644 --- a/cmd/gentypes/generator/type_gen.go.tpl +++ b/cmd/gentypes/generator/type_gen.go.tpl @@ -153,7 +153,7 @@ func (o {{.Name}}) UnwrapOrElse(defaultValue func() {{.Type}}) {{.Type}} { } func (o {{.Name}}) encodeValue(encoder *msgpack.Encoder) error { - value, err := o.value.MarshalMsgpack() + value, err := {{ .CustomMarshalFunc }} if err != nil { return err } @@ -199,7 +199,7 @@ func (o *{{.Name}}) decodeValue(decoder *msgpack.Decoder) error { return o.newDecodeError(err) } - if err := o.value.UnmarshalMsgpack(a); err != nil { + if err := {{ .CustomUnmarshalFunc }}; err != nil { return o.newDecodeError(err) } diff --git a/cmd/gentypes/test/fullmsgpackexttype.go b/cmd/gentypes/internal/test/fullmsgpackexttype.go similarity index 100% rename from cmd/gentypes/test/fullmsgpackexttype.go rename to cmd/gentypes/internal/test/fullmsgpackexttype.go diff --git a/cmd/gentypes/test/fullmsgpackexttype_gen.go b/cmd/gentypes/internal/test/fullmsgpackexttype_gen.go similarity index 100% rename from cmd/gentypes/test/fullmsgpackexttype_gen.go rename to cmd/gentypes/internal/test/fullmsgpackexttype_gen.go diff --git a/cmd/gentypes/test/fullmsgpackexttype_test.go b/cmd/gentypes/internal/test/fullmsgpackexttype_test.go similarity index 63% rename from cmd/gentypes/test/fullmsgpackexttype_test.go rename to cmd/gentypes/internal/test/fullmsgpackexttype_test.go index b3a57d3..7a1e262 100644 --- a/cmd/gentypes/test/fullmsgpackexttype_test.go +++ b/cmd/gentypes/internal/test/fullmsgpackexttype_test.go @@ -8,18 +8,18 @@ import ( "github.com/stretchr/testify/require" "github.com/vmihailenco/msgpack/v5" - td "github.com/tarantool/go-option/cmd/gentypes/test" + "github.com/tarantool/go-option/cmd/gentypes/internal/test" ) func TestOptionalMsgpackExtType_RoundtripLL(t *testing.T) { t.Parallel() - input := td.FullMsgpackExtType{ + input := test.FullMsgpackExtType{ A: 412, B: "bababa", } - opt := td.SomeOptionalFullMsgpackExtType(input) + opt := test.SomeOptionalFullMsgpackExtType(input) b := bytes.Buffer{} enc := msgpack.NewEncoder(&b) @@ -27,7 +27,7 @@ func TestOptionalMsgpackExtType_RoundtripLL(t *testing.T) { require.NoError(t, opt.EncodeMsgpack(enc)) - opt2 := td.NoneOptionalFullMsgpackExtType() + opt2 := test.NoneOptionalFullMsgpackExtType() require.NoError(t, opt2.DecodeMsgpack(dec)) assert.Equal(t, opt, opt2) @@ -37,12 +37,12 @@ func TestOptionalMsgpackExtType_RoundtripLL(t *testing.T) { func TestOptionalMsgpackExtType_RoundtripHL(t *testing.T) { t.Parallel() - input := td.FullMsgpackExtType{ + input := test.FullMsgpackExtType{ A: 412, B: "bababa", } - opt := td.SomeOptionalFullMsgpackExtType(input) + opt := test.SomeOptionalFullMsgpackExtType(input) b := bytes.Buffer{} enc := msgpack.NewEncoder(&b) @@ -50,7 +50,7 @@ func TestOptionalMsgpackExtType_RoundtripHL(t *testing.T) { require.NoError(t, enc.Encode(opt)) - opt2 := td.NoneOptionalFullMsgpackExtType() + opt2 := test.NoneOptionalFullMsgpackExtType() require.NoError(t, dec.Decode(&opt2)) assert.Equal(t, opt, opt2) @@ -63,12 +63,12 @@ func TestOptionalFullMsgpackExtType_IsSome(t *testing.T) { t.Run("some", func(t *testing.T) { t.Parallel() - input := td.FullMsgpackExtType{ + input := test.FullMsgpackExtType{ A: 412, B: "bababa", } - opt := td.SomeOptionalFullMsgpackExtType(input) + opt := test.SomeOptionalFullMsgpackExtType(input) assert.True(t, opt.IsSome()) }) @@ -76,7 +76,7 @@ func TestOptionalFullMsgpackExtType_IsSome(t *testing.T) { t.Run("none", func(t *testing.T) { t.Parallel() - opt := td.NoneOptionalFullMsgpackExtType() + opt := test.NoneOptionalFullMsgpackExtType() assert.False(t, opt.IsSome()) }) @@ -88,12 +88,12 @@ func TestOptionalFullMsgpackExtType_IsZero(t *testing.T) { t.Run("some", func(t *testing.T) { t.Parallel() - input := td.FullMsgpackExtType{ + input := test.FullMsgpackExtType{ A: 412, B: "bababa", } - opt := td.SomeOptionalFullMsgpackExtType(input) + opt := test.SomeOptionalFullMsgpackExtType(input) assert.False(t, opt.IsZero()) }) @@ -101,7 +101,7 @@ func TestOptionalFullMsgpackExtType_IsZero(t *testing.T) { t.Run("none", func(t *testing.T) { t.Parallel() - opt := td.NoneOptionalFullMsgpackExtType() + opt := test.NoneOptionalFullMsgpackExtType() assert.True(t, opt.IsZero()) }) @@ -113,12 +113,12 @@ func TestOptionalFullMsgpackExtType_Get(t *testing.T) { t.Run("some", func(t *testing.T) { t.Parallel() - input := td.FullMsgpackExtType{ + input := test.FullMsgpackExtType{ A: 412, B: "bababa", } - opt := td.SomeOptionalFullMsgpackExtType(input) + opt := test.SomeOptionalFullMsgpackExtType(input) val, ok := opt.Get() require.True(t, ok) @@ -128,10 +128,10 @@ func TestOptionalFullMsgpackExtType_Get(t *testing.T) { t.Run("none", func(t *testing.T) { t.Parallel() - opt := td.NoneOptionalFullMsgpackExtType() + opt := test.NoneOptionalFullMsgpackExtType() val, ok := opt.Get() require.False(t, ok) - assert.Equal(t, td.NewEmptyFullMsgpackExtType(), val) + assert.Equal(t, test.NewEmptyFullMsgpackExtType(), val) }) } @@ -141,14 +141,14 @@ func TestOptionalFullMsgpackExtType_MustGet(t *testing.T) { t.Run("some", func(t *testing.T) { t.Parallel() - input := td.FullMsgpackExtType{ + input := test.FullMsgpackExtType{ A: 412, B: "bababa", } - opt := td.SomeOptionalFullMsgpackExtType(input) + opt := test.SomeOptionalFullMsgpackExtType(input) - var val td.FullMsgpackExtType + var val test.FullMsgpackExtType require.NotPanics(t, func() { val = opt.MustGet() @@ -159,7 +159,7 @@ func TestOptionalFullMsgpackExtType_MustGet(t *testing.T) { t.Run("none", func(t *testing.T) { t.Parallel() - opt := td.NoneOptionalFullMsgpackExtType() + opt := test.NoneOptionalFullMsgpackExtType() require.Panics(t, func() { opt.MustGet() }) }) @@ -171,12 +171,12 @@ func TestOptionalFullMsgpackExtType_Unwrap(t *testing.T) { t.Run("some", func(t *testing.T) { t.Parallel() - input := td.FullMsgpackExtType{ + input := test.FullMsgpackExtType{ A: 412, B: "bababa", } - opt := td.SomeOptionalFullMsgpackExtType(input) + opt := test.SomeOptionalFullMsgpackExtType(input) assert.Equal(t, input, opt.Unwrap()) }) @@ -184,8 +184,8 @@ func TestOptionalFullMsgpackExtType_Unwrap(t *testing.T) { t.Run("none", func(t *testing.T) { t.Parallel() - opt := td.NoneOptionalFullMsgpackExtType() - assert.Equal(t, td.NewEmptyFullMsgpackExtType(), opt.Unwrap()) + opt := test.NoneOptionalFullMsgpackExtType() + assert.Equal(t, test.NewEmptyFullMsgpackExtType(), opt.Unwrap()) }) } @@ -195,25 +195,25 @@ func TestOptionalFullMsgpackExtType_UnwrapOr(t *testing.T) { t.Run("some", func(t *testing.T) { t.Parallel() - input := td.FullMsgpackExtType{ + input := test.FullMsgpackExtType{ A: 412, B: "bababa", } - opt := td.SomeOptionalFullMsgpackExtType(input) + opt := test.SomeOptionalFullMsgpackExtType(input) - assert.Equal(t, input, opt.UnwrapOr(td.NewEmptyFullMsgpackExtType())) + assert.Equal(t, input, opt.UnwrapOr(test.NewEmptyFullMsgpackExtType())) }) t.Run("none", func(t *testing.T) { t.Parallel() - alt := td.FullMsgpackExtType{ + alt := test.FullMsgpackExtType{ A: 1, B: "b", } - opt := td.NoneOptionalFullMsgpackExtType() + opt := test.NoneOptionalFullMsgpackExtType() assert.Equal(t, alt, opt.UnwrapOr(alt)) }) } @@ -224,27 +224,27 @@ func TestOptionalFullMsgpackExtType_UnwrapOrElse(t *testing.T) { t.Run("some", func(t *testing.T) { t.Parallel() - input := td.FullMsgpackExtType{ + input := test.FullMsgpackExtType{ A: 412, B: "bababa", } - opt := td.SomeOptionalFullMsgpackExtType(input) + opt := test.SomeOptionalFullMsgpackExtType(input) - assert.Equal(t, input, opt.UnwrapOrElse(td.NewEmptyFullMsgpackExtType)) + assert.Equal(t, input, opt.UnwrapOrElse(test.NewEmptyFullMsgpackExtType)) }) t.Run("none", func(t *testing.T) { t.Parallel() - alt := td.FullMsgpackExtType{ + alt := test.FullMsgpackExtType{ A: 1, B: "b", } - opt := td.NoneOptionalFullMsgpackExtType() + opt := test.NoneOptionalFullMsgpackExtType() - assert.Equal(t, alt, opt.UnwrapOrElse(func() td.FullMsgpackExtType { + assert.Equal(t, alt, opt.UnwrapOrElse(func() test.FullMsgpackExtType { return alt })) }) diff --git a/cmd/gentypes/internal/test/hiddentypealias.go b/cmd/gentypes/internal/test/hiddentypealias.go new file mode 100644 index 0000000..adec75e --- /dev/null +++ b/cmd/gentypes/internal/test/hiddentypealias.go @@ -0,0 +1,8 @@ +package test + +import ( + "github.com/tarantool/go-option/cmd/gentypes/internal/test/subpackage" +) + +// HiddenTypeAlias is a hidden type alias to test. +type HiddenTypeAlias = subpackage.Hidden diff --git a/cmd/gentypes/internal/test/hiddentypealias_gen.go b/cmd/gentypes/internal/test/hiddentypealias_gen.go new file mode 100644 index 0000000..9c1cf31 --- /dev/null +++ b/cmd/gentypes/internal/test/hiddentypealias_gen.go @@ -0,0 +1,241 @@ +// Code generated by github.com/tarantool/go-option; DO NOT EDIT. + +package test + +import ( + "fmt" + + "github.com/vmihailenco/msgpack/v5" + "github.com/vmihailenco/msgpack/v5/msgpcode" + + "github.com/tarantool/go-option" +) + +// OptionalHiddenTypeAlias represents an optional value of type HiddenTypeAlias. +// It can either hold a valid HiddenTypeAlias (IsSome == true) or be empty (IsZero == true). +type OptionalHiddenTypeAlias struct { + value HiddenTypeAlias + exists bool +} + +// SomeOptionalHiddenTypeAlias creates an optional OptionalHiddenTypeAlias with the given HiddenTypeAlias value. +// The returned OptionalHiddenTypeAlias will have IsSome() == true and IsZero() == false. +func SomeOptionalHiddenTypeAlias(value HiddenTypeAlias) OptionalHiddenTypeAlias { + return OptionalHiddenTypeAlias{ + value: value, + exists: true, + } +} + +// NoneOptionalHiddenTypeAlias creates an empty optional OptionalHiddenTypeAlias value. +// The returned OptionalHiddenTypeAlias will have IsSome() == false and IsZero() == true. +// +// Example: +// +// o := NoneOptionalHiddenTypeAlias() +// if o.IsZero() { +// fmt.Println("value is absent") +// } +func NoneOptionalHiddenTypeAlias() OptionalHiddenTypeAlias { + return OptionalHiddenTypeAlias{} +} + +func (o OptionalHiddenTypeAlias) newEncodeError(err error) error { + if err == nil { + return nil + } + return &option.EncodeError{ + Type: "OptionalHiddenTypeAlias", + Parent: err, + } +} + +func (o OptionalHiddenTypeAlias) newDecodeError(err error) error { + if err == nil { + return nil + } + + return &option.DecodeError{ + Type: "OptionalHiddenTypeAlias", + Parent: err, + } +} + +// IsSome returns true if the OptionalHiddenTypeAlias contains a value. +// This indicates the value is explicitly set (not None). +func (o OptionalHiddenTypeAlias) IsSome() bool { + return o.exists +} + +// IsZero returns true if the OptionalHiddenTypeAlias does not contain a value. +// Equivalent to !IsSome(). Useful for consistency with types where +// zero value (e.g. 0, false, zero struct) is valid and needs to be distinguished. +func (o OptionalHiddenTypeAlias) IsZero() bool { + return !o.exists +} + +// IsNil is an alias for IsZero. +// +// This method is provided for compatibility with the msgpack Encoder interface. +func (o OptionalHiddenTypeAlias) IsNil() bool { + return o.IsZero() +} + +// Get returns the stored value and a boolean flag indicating its presence. +// If the value is present, returns (value, true). +// If the value is absent, returns (zero value of HiddenTypeAlias, false). +// +// Recommended usage: +// +// if value, ok := o.Get(); ok { +// // use value +// } +func (o OptionalHiddenTypeAlias) Get() (HiddenTypeAlias, bool) { + return o.value, o.exists +} + +// MustGet returns the stored value if it is present. +// Panics if the value is absent (i.e., IsZero() == true). +// +// Use with caution — only when you are certain the value exists. +// +// Panics with: "optional value is not set" if no value is set. +func (o OptionalHiddenTypeAlias) MustGet() HiddenTypeAlias { + if !o.exists { + panic("optional value is not set") + } + + return o.value +} + +// Unwrap returns the stored value regardless of presence. +// If no value is set, returns the zero value for HiddenTypeAlias. +// +// Warning: Does not check presence. Use IsSome() before calling if you need +// to distinguish between absent value and explicit zero value. +func (o OptionalHiddenTypeAlias) Unwrap() HiddenTypeAlias { + return o.value +} + +// UnwrapOr returns the stored value if present. +// Otherwise, returns the provided default value. +// +// Example: +// +// o := NoneOptionalHiddenTypeAlias() +// v := o.UnwrapOr(someDefaultOptionalHiddenTypeAlias) +func (o OptionalHiddenTypeAlias) UnwrapOr(defaultValue HiddenTypeAlias) HiddenTypeAlias { + if o.exists { + return o.value + } + + return defaultValue +} + +// UnwrapOrElse returns the stored value if present. +// Otherwise, calls the provided function and returns its result. +// Useful when the default value requires computation or side effects. +// +// Example: +// +// o := NoneOptionalHiddenTypeAlias() +// v := o.UnwrapOrElse(func() HiddenTypeAlias { return computeDefault() }) +func (o OptionalHiddenTypeAlias) UnwrapOrElse(defaultValue func() HiddenTypeAlias) HiddenTypeAlias { + if o.exists { + return o.value + } + + return defaultValue() +} + +func (o OptionalHiddenTypeAlias) encodeValue(encoder *msgpack.Encoder) error { + value, err := o.value.MarshalMsgpack() + if err != nil { + return err + } + + err = encoder.EncodeExtHeader(2, len(value)) + if err != nil { + return err + } + + _, err = encoder.Writer().Write(value) + if err != nil { + return err + } + + return nil +} + +// EncodeMsgpack encodes the OptionalHiddenTypeAlias value using MessagePack format. +// - If the value is present, it is encoded as HiddenTypeAlias. +// - If the value is absent (None), it is encoded as nil. +// +// Returns an error if encoding fails. +func (o OptionalHiddenTypeAlias) EncodeMsgpack(encoder *msgpack.Encoder) error { + if o.exists { + return o.newEncodeError(o.encodeValue(encoder)) + } + + return o.newEncodeError(encoder.EncodeNil()) +} + +func (o *OptionalHiddenTypeAlias) decodeValue(decoder *msgpack.Decoder) error { + tp, length, err := decoder.DecodeExtHeader() + switch { + case err != nil: + return o.newDecodeError(err) + case tp != 2: + return o.newDecodeError(fmt.Errorf("invalid extension code: %d", tp)) + } + + a := make([]byte, length) + if err := decoder.ReadFull(a); err != nil { + return o.newDecodeError(err) + } + + if err := o.value.UnmarshalMsgpack(a); err != nil { + return o.newDecodeError(err) + } + + o.exists = true + return nil +} + +func (o *OptionalHiddenTypeAlias) checkCode(code byte) bool { + return msgpcode.IsExt(code) +} + +// DecodeMsgpack decodes a OptionalHiddenTypeAlias value from MessagePack format. +// Supports two input types: +// - nil: interpreted as no value (NoneOptionalHiddenTypeAlias) +// - HiddenTypeAlias: interpreted as a present value (SomeOptionalHiddenTypeAlias) +// +// Returns an error if the input type is unsupported or decoding fails. +// +// After successful decoding: +// - on nil: exists = false, value = default zero value +// - on HiddenTypeAlias: exists = true, value = decoded value +func (o *OptionalHiddenTypeAlias) DecodeMsgpack(decoder *msgpack.Decoder) error { + code, err := decoder.PeekCode() + if err != nil { + return o.newDecodeError(err) + } + + switch { + case code == msgpcode.Nil: + o.exists = false + + return o.newDecodeError(decoder.Skip()) + case o.checkCode(code): + err := o.decodeValue(decoder) + if err != nil { + return o.newDecodeError(err) + } + o.exists = true + + return err + default: + return o.newDecodeError(fmt.Errorf("unexpected code: %d", code)) + } +} diff --git a/cmd/gentypes/internal/test/subpackage/hiddentype.go b/cmd/gentypes/internal/test/subpackage/hiddentype.go new file mode 100644 index 0000000..5516fd3 --- /dev/null +++ b/cmd/gentypes/internal/test/subpackage/hiddentype.go @@ -0,0 +1,18 @@ +// Package subpackage contains a hidden type for testing type aliases generation. +package subpackage + +// Hidden is a hidden type for testing type aliases generation. +type Hidden struct { + Hidden string +} + +// MarshalMsgpack implements the MsgpackMarshaler interface. +func (h *Hidden) MarshalMsgpack() ([]byte, error) { + return []byte(h.Hidden), nil +} + +// UnmarshalMsgpack implements the MsgpackUnmarshaler interface. +func (h *Hidden) UnmarshalMsgpack(bytes []byte) error { + h.Hidden = string(bytes) + return nil +} diff --git a/cmd/gentypes/internal/test/uuid.go b/cmd/gentypes/internal/test/uuid.go new file mode 100644 index 0000000..f029755 --- /dev/null +++ b/cmd/gentypes/internal/test/uuid.go @@ -0,0 +1,26 @@ +package test + +import ( + "errors" + + "github.com/google/uuid" +) + +func encodeUUID(uuid uuid.UUID) ([]byte, error) { + return uuid[:], nil +} + +var ( + // ErrInvalidLength is returned when the length of the input data is invalid. + ErrInvalidLength = errors.New("invalid length") +) + +func decodeUUID(uuid *uuid.UUID, data []byte) error { + if len(data) != len(uuid) { + return ErrInvalidLength + } + + copy(uuid[:], data) + + return nil +} diff --git a/cmd/gentypes/internal/test/uuid_gen.go b/cmd/gentypes/internal/test/uuid_gen.go new file mode 100644 index 0000000..ac64ff7 --- /dev/null +++ b/cmd/gentypes/internal/test/uuid_gen.go @@ -0,0 +1,243 @@ +// Code generated by github.com/tarantool/go-option; DO NOT EDIT. + +package test + +import ( + "github.com/google/uuid" + + "fmt" + + "github.com/vmihailenco/msgpack/v5" + "github.com/vmihailenco/msgpack/v5/msgpcode" + + "github.com/tarantool/go-option" +) + +// OptionalUUID represents an optional value of type uuid.UUID. +// It can either hold a valid uuid.UUID (IsSome == true) or be empty (IsZero == true). +type OptionalUUID struct { + value uuid.UUID + exists bool +} + +// SomeOptionalUUID creates an optional OptionalUUID with the given uuid.UUID value. +// The returned OptionalUUID will have IsSome() == true and IsZero() == false. +func SomeOptionalUUID(value uuid.UUID) OptionalUUID { + return OptionalUUID{ + value: value, + exists: true, + } +} + +// NoneOptionalUUID creates an empty optional OptionalUUID value. +// The returned OptionalUUID will have IsSome() == false and IsZero() == true. +// +// Example: +// +// o := NoneOptionalUUID() +// if o.IsZero() { +// fmt.Println("value is absent") +// } +func NoneOptionalUUID() OptionalUUID { + return OptionalUUID{} +} + +func (o OptionalUUID) newEncodeError(err error) error { + if err == nil { + return nil + } + return &option.EncodeError{ + Type: "OptionalUUID", + Parent: err, + } +} + +func (o OptionalUUID) newDecodeError(err error) error { + if err == nil { + return nil + } + + return &option.DecodeError{ + Type: "OptionalUUID", + Parent: err, + } +} + +// IsSome returns true if the OptionalUUID contains a value. +// This indicates the value is explicitly set (not None). +func (o OptionalUUID) IsSome() bool { + return o.exists +} + +// IsZero returns true if the OptionalUUID does not contain a value. +// Equivalent to !IsSome(). Useful for consistency with types where +// zero value (e.g. 0, false, zero struct) is valid and needs to be distinguished. +func (o OptionalUUID) IsZero() bool { + return !o.exists +} + +// IsNil is an alias for IsZero. +// +// This method is provided for compatibility with the msgpack Encoder interface. +func (o OptionalUUID) IsNil() bool { + return o.IsZero() +} + +// Get returns the stored value and a boolean flag indicating its presence. +// If the value is present, returns (value, true). +// If the value is absent, returns (zero value of uuid.UUID, false). +// +// Recommended usage: +// +// if value, ok := o.Get(); ok { +// // use value +// } +func (o OptionalUUID) Get() (uuid.UUID, bool) { + return o.value, o.exists +} + +// MustGet returns the stored value if it is present. +// Panics if the value is absent (i.e., IsZero() == true). +// +// Use with caution — only when you are certain the value exists. +// +// Panics with: "optional value is not set" if no value is set. +func (o OptionalUUID) MustGet() uuid.UUID { + if !o.exists { + panic("optional value is not set") + } + + return o.value +} + +// Unwrap returns the stored value regardless of presence. +// If no value is set, returns the zero value for uuid.UUID. +// +// Warning: Does not check presence. Use IsSome() before calling if you need +// to distinguish between absent value and explicit zero value. +func (o OptionalUUID) Unwrap() uuid.UUID { + return o.value +} + +// UnwrapOr returns the stored value if present. +// Otherwise, returns the provided default value. +// +// Example: +// +// o := NoneOptionalUUID() +// v := o.UnwrapOr(someDefaultOptionalUUID) +func (o OptionalUUID) UnwrapOr(defaultValue uuid.UUID) uuid.UUID { + if o.exists { + return o.value + } + + return defaultValue +} + +// UnwrapOrElse returns the stored value if present. +// Otherwise, calls the provided function and returns its result. +// Useful when the default value requires computation or side effects. +// +// Example: +// +// o := NoneOptionalUUID() +// v := o.UnwrapOrElse(func() uuid.UUID { return computeDefault() }) +func (o OptionalUUID) UnwrapOrElse(defaultValue func() uuid.UUID) uuid.UUID { + if o.exists { + return o.value + } + + return defaultValue() +} + +func (o OptionalUUID) encodeValue(encoder *msgpack.Encoder) error { + value, err := encodeUUID(o.value) + if err != nil { + return err + } + + err = encoder.EncodeExtHeader(3, len(value)) + if err != nil { + return err + } + + _, err = encoder.Writer().Write(value) + if err != nil { + return err + } + + return nil +} + +// EncodeMsgpack encodes the OptionalUUID value using MessagePack format. +// - If the value is present, it is encoded as uuid.UUID. +// - If the value is absent (None), it is encoded as nil. +// +// Returns an error if encoding fails. +func (o OptionalUUID) EncodeMsgpack(encoder *msgpack.Encoder) error { + if o.exists { + return o.newEncodeError(o.encodeValue(encoder)) + } + + return o.newEncodeError(encoder.EncodeNil()) +} + +func (o *OptionalUUID) decodeValue(decoder *msgpack.Decoder) error { + tp, length, err := decoder.DecodeExtHeader() + switch { + case err != nil: + return o.newDecodeError(err) + case tp != 3: + return o.newDecodeError(fmt.Errorf("invalid extension code: %d", tp)) + } + + a := make([]byte, length) + if err := decoder.ReadFull(a); err != nil { + return o.newDecodeError(err) + } + + if err := decodeUUID(&o.value, a); err != nil { + return o.newDecodeError(err) + } + + o.exists = true + return nil +} + +func (o *OptionalUUID) checkCode(code byte) bool { + return msgpcode.IsExt(code) +} + +// DecodeMsgpack decodes a OptionalUUID value from MessagePack format. +// Supports two input types: +// - nil: interpreted as no value (NoneOptionalUUID) +// - uuid.UUID: interpreted as a present value (SomeOptionalUUID) +// +// Returns an error if the input type is unsupported or decoding fails. +// +// After successful decoding: +// - on nil: exists = false, value = default zero value +// - on uuid.UUID: exists = true, value = decoded value +func (o *OptionalUUID) DecodeMsgpack(decoder *msgpack.Decoder) error { + code, err := decoder.PeekCode() + if err != nil { + return o.newDecodeError(err) + } + + switch { + case code == msgpcode.Nil: + o.exists = false + + return o.newDecodeError(decoder.Skip()) + case o.checkCode(code): + err := o.decodeValue(decoder) + if err != nil { + return o.newDecodeError(err) + } + o.exists = true + + return err + default: + return o.newDecodeError(fmt.Errorf("unexpected code: %d", code)) + } +} diff --git a/cmd/gentypes/main.go b/cmd/gentypes/main.go index 67e5a4c..d4a9dd9 100644 --- a/cmd/gentypes/main.go +++ b/cmd/gentypes/main.go @@ -24,9 +24,13 @@ const ( ) var ( - packagePath string - extCode int - verbose bool + packagePath string + extCode int + verbose bool + force bool + imports stringListFlag + customMarshalFunc string + customUnmarshalFunc string ) func logfuncf(format string, args ...interface{}) { @@ -98,6 +102,26 @@ func printFile(prefix string, data []byte) { } } +func isExternalDep(name string) bool { + return strings.Contains(name, ".") +} + +const ( + maxNameParts = 2 +) + +func constructFileName(name string) string { + parts := strings.SplitN(name, ".", maxNameParts) + switch { + case len(parts) == 1: + name = parts[0] + case len(parts) == maxNameParts: + name = parts[1] + } + + return strings.ToLower(name) + "_gen.go" +} + func main() { //nolint:funlen generator.InitializeTemplates() @@ -106,7 +130,10 @@ func main() { //nolint:funlen flag.StringVar(&packagePath, "package", "./", "input and output path") flag.IntVar(&extCode, "ext-code", undefinedExtCode, "extension code") flag.BoolVar(&verbose, "verbose", false, "print verbose output") - + flag.BoolVar(&force, "force", false, "generate files even if methods do not exist") + flag.Var(&imports, "imports", "imports to add to generated files") + flag.StringVar(&customMarshalFunc, "marshal-func", "", "custom marshal function") + flag.StringVar(&customUnmarshalFunc, "unmarshal-func", "", "custom unmarshal function") flag.Parse() switch { @@ -164,19 +191,32 @@ func main() { //nolint:funlen // Check for existence of all types that we want to generate. typeSpecDef, ok := analyzer.TypeSpecEntryByName(typeName) - if !ok { + switch { + case isExternalDep(typeName): + fmt.Println("typename contains dot, probably third party type:", typeName) + case !ok: fmt.Println("failed to find struct:", typeName) os.Exit(1) } - fmt.Println("generating optional for:", typeName) + fmt.Println("generating optional:", typeName) - if !typeSpecDef.HasMethod("MarshalMsgpack") || !typeSpecDef.HasMethod("UnmarshalMsgpack") { + switch { + case force || isExternalDep(typeName): + // Skipping check for MarshalMsgpack and UnmarshalMsgpack methods. + case !typeSpecDef.HasMethod("MarshalMsgpack") || !typeSpecDef.HasMethod("UnmarshalMsgpack"): fmt.Println("failed to find MarshalMsgpack or UnmarshalMsgpack method for struct:", typeName) os.Exit(1) } - generatedGoSources, err := generator.GenerateByType(typeName, extCode, analyzer.PackageName()) + generatedGoSources, err := generator.GenerateByType(generator.GenerateOptions{ + TypeName: typeName, + ExtCode: extCode, + PackageName: pkg.Name, + Imports: imports, + CustomMarshalFunc: customMarshalFunc, + CustomUnmarshalFunc: customUnmarshalFunc, + }) if err != nil { fmt.Println("failed to generate optional types:") fmt.Println(" ", err) @@ -191,7 +231,7 @@ func main() { //nolint:funlen } err = os.WriteFile( - filepath.Join(packagePath, strings.ToLower(typeName)+"_gen.go"), + filepath.Join(packagePath, constructFileName(typeName)), formattedGoSource, defaultGoPermissions, ) diff --git a/go.mod b/go.mod index 3dcb760..944f136 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/tarantool/go-option go 1.23.0 require ( + github.com/google/uuid v1.6.0 github.com/stretchr/testify v1.11.1 github.com/vmihailenco/msgpack/v5 v5.4.1 golang.org/x/text v0.28.0 diff --git a/go.sum b/go.sum index 4c300ea..8e31417 100644 --- a/go.sum +++ b/go.sum @@ -3,6 +3,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=