From ecdf7af802648be785fbfe57d9fcffaf7b387dc6 Mon Sep 17 00:00:00 2001 From: Eugene Blikh Date: Wed, 3 Sep 2025 14:17:29 +0300 Subject: [PATCH] gentypes: Add support for third-party types The gentypes tool was previously limited to generating optional types for local types that implemented the MarshalMsgpack and UnmarshalMsgpack methods. Thi made it impossible to generate optional types for types from external modules. This commit enhances the generator to support third-party types by introducing the following new flags: * -force: To bypass the check for MarshalMsgpack and UnmarshalMsgpack methods. * -imports: To add necessary imports for the third-party type and custom functions. * -marshal-func: To specify a custom marshal function. * -unmarshal-func: To specify a custom unmarshal function. The generator code has been refactored to use a GenerateOptions struct for better organization. Additionally, this commit: * Adds a new test case for generating an optional type for uuid.UUID. * Updates the README.md with documentation for the new flags and an example. * Moves test files to a more appropriate internal/test directory. Closes #TNTP-3734. --- .golangci.yml | 1 + CHANGELOG.md | 2 + README.md | 65 ++++- cmd/gentypes/flag.go | 34 +++ cmd/gentypes/generate.go | 5 +- cmd/gentypes/generator/extension.go | 73 +++++- cmd/gentypes/generator/type_gen.go.tpl | 4 +- .../{ => internal}/test/fullmsgpackexttype.go | 0 .../test/fullmsgpackexttype_gen.go | 0 .../test/fullmsgpackexttype_test.go | 72 +++--- cmd/gentypes/internal/test/hiddentypealias.go | 8 + .../internal/test/hiddentypealias_gen.go | 241 +++++++++++++++++ .../internal/test/subpackage/hiddentype.go | 18 ++ cmd/gentypes/internal/test/uuid.go | 26 ++ cmd/gentypes/internal/test/uuid_gen.go | 243 ++++++++++++++++++ cmd/gentypes/main.go | 58 ++++- go.mod | 1 + go.sum | 2 + 18 files changed, 789 insertions(+), 64 deletions(-) create mode 100644 cmd/gentypes/flag.go rename cmd/gentypes/{ => internal}/test/fullmsgpackexttype.go (100%) rename cmd/gentypes/{ => internal}/test/fullmsgpackexttype_gen.go (100%) rename cmd/gentypes/{ => internal}/test/fullmsgpackexttype_test.go (63%) create mode 100644 cmd/gentypes/internal/test/hiddentypealias.go create mode 100644 cmd/gentypes/internal/test/hiddentypealias_gen.go create mode 100644 cmd/gentypes/internal/test/subpackage/hiddentype.go create mode 100644 cmd/gentypes/internal/test/uuid.go create mode 100644 cmd/gentypes/internal/test/uuid_gen.go diff --git a/.golangci.yml b/.golangci.yml index b95c11b..6250424 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -57,6 +57,7 @@ linters: - "golang.org/x/tools" - "github.com/vmihailenco/msgpack/v5" - "github.com/tarantool/go-option" + - "github.com/google/uuid" test: files: - "$test" diff --git a/CHANGELOG.md b/CHANGELOG.md index 2d31081..401457a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ Versioning](http://semver.org/spec/v2.0.0.html) except to the first release. ### Added +- gentypes: Add support for third-party types. + ### Changed ### Fixed diff --git a/README.md b/README.md index 9042880..1ef2c11 100644 --- a/README.md +++ b/README.md @@ -178,10 +178,65 @@ Or you can use it to generate file from go: Flags: -• `-package`: Path to the Go package containing types to wrap (default: `"."`) -• `-ext-code`: MessagePack extension code to use for custom types (must be between --128 and 127, no default value) -• `-verbose`: Enable verbose output (default: `false`) + * `-package`: Path to the Go package containing types to wrap (default: `"."`) + * `-ext-code`: MessagePack extension code to use for custom types (must be between + -128 and 127, no default value) + * `-verbose`: Enable verbose output (default: `false`) + * `-force`: Ignore absence of marshal/unmarshal methods on type (default: `false`). + Helpful for types from third-party modules. + * `-imports`: Add imports to generated file (default is empty). + Helpful for types from third-party modules. + * `-marshal-func`: func that should do marshaling (default is `MarshalMsgpack` method). + Helpful for types from third-party modules. + Should be func of type `func(v T) ([]byte, error)` and should + be located in the same dir or should be imported. + * `-unmarshal-func`: func that should do unmarshalling (default is `UnmarshalMsgpack` method). + Helpful for types from third-party modules. + Should be func of type `func(v *T, data []byte) error` and should + be located in the same dir or should be imported. + +#### Generating Optional Types for Third-Party Modules + +Sometimes you need to generate an optional type for a type from a third-party module, +and you can't add `MarshalMsgpack`/`UnmarshalMsgpack` methods to it. +In this case, you can use the `-force`, `-imports`, `-marshal-func`, and `-unmarshal-func` flags. + +For example, to generate an optional type for `github.com/google/uuid.UUID`: + +1. Create a file with marshal and unmarshal functions for the third-party type. + For example, `uuid.go`: + + ```go + package main + + import ( + "errors" + + "github.com/google/uuid" + ) + + func encodeUUID(uuid uuid.UUID) ([]byte, error) { + return uuid[:], nil + } + + var ( + ErrInvalidLength = errors.New("invalid length") + ) + + func decodeUUID(uuid *uuid.UUID, data []byte) error { + if len(data) != len(uuid) { + return ErrInvalidLength + } + copy(uuid[:], data) + return nil + } + ``` + +2. Use the following `go:generate` command: + + ```go + //go:generate go run github.com/tarantool/go-option/cmd/gentypes@latest -package . -imports "github.com/google/uuid" -type UUID -marshal-func "encodeUUID" -unmarshal-func "decodeUUID" -force -ext-code 100 + ``` ### Using Generated Types @@ -238,4 +293,4 @@ BSD 2-Clause License [coverage-url]: https://coveralls.io/github/tarantool/go-option?branch=master [telegram-badge]: https://img.shields.io/badge/Telegram-join%20chat-blue.svg [telegram-en-url]: http://telegram.me/tarantool -[telegram-ru-url]: http://telegram.me/tarantoolru \ No newline at end of file +[telegram-ru-url]: http://telegram.me/tarantoolru diff --git a/cmd/gentypes/flag.go b/cmd/gentypes/flag.go new file mode 100644 index 0000000..121851c --- /dev/null +++ b/cmd/gentypes/flag.go @@ -0,0 +1,34 @@ +package main + +import ( + "maps" + "slices" + "strings" +) + +type stringListFlag []string + +func (s *stringListFlag) String() string { + return strings.Join(*s, ", ") +} + +func (s *stringListFlag) Set(s2 string) error { + *s = append(*s, s2) + return nil +} + +func deleteDuplicates(s stringListFlag) stringListFlag { + uniqMap := map[string]struct{}{} + for _, val := range s { + uniqMap[val] = struct{}{} + } + + return slices.Collect(maps.Keys(uniqMap)) +} + +func (s *stringListFlag) Get() []string { + deduped := deleteDuplicates(*s) + slices.Sort(deduped) + + return deduped +} diff --git a/cmd/gentypes/generate.go b/cmd/gentypes/generate.go index 64dcbb1..ac4713f 100644 --- a/cmd/gentypes/generate.go +++ b/cmd/gentypes/generate.go @@ -1,2 +1,5 @@ -//go:generate go run github.com/tarantool/go-option/cmd/gentypes -ext-code 1 -package test FullMsgpackExtType +//go:generate go run github.com/tarantool/go-option/cmd/gentypes -ext-code 1 -package internal/test FullMsgpackExtType +//go:generate go run github.com/tarantool/go-option/cmd/gentypes -ext-code 2 -force -package internal/test HiddenTypeAlias +//go:generate go run github.com/tarantool/go-option/cmd/gentypes -ext-code 3 -imports github.com/google/uuid -package internal/test -marshal-func encodeUUID -unmarshal-func decodeUUID uuid.UUID + package main diff --git a/cmd/gentypes/generator/extension.go b/cmd/gentypes/generator/extension.go index 7e3f7fa..4955706 100644 --- a/cmd/gentypes/generator/extension.go +++ b/cmd/gentypes/generator/extension.go @@ -6,6 +6,7 @@ import ( _ "embed" "fmt" "strconv" + "strings" "text/template" ) @@ -26,22 +27,72 @@ func InitializeTemplates() { cTypeGenTestTemplate = template.Must(template.New("type_gen_test.go.tpl").Parse(typeGenTestTemplate)) } +const ( + maxNameParts = 2 +) + +func constructTypeName(typeName string) string { + splittedName := strings.SplitN(typeName, ".", maxNameParts) + switch len(splittedName) { + case 1: + typeName = splittedName[0] + case maxNameParts: + typeName = splittedName[1] + default: + panic("invalid type name: " + typeName) + } + + return "Optional" + typeName +} + +// GenerateOptions is the options for the code generation. +type GenerateOptions struct { + // TypeName is the name of the type to generate optional to. + TypeName string + // ExtCode is the extension code. + ExtCode int + // PackageName is the name of the package to generate to. + PackageName string + // Imports is the list of imports to add to the generated code. + Imports []string + // CustomMarshalFunc is the name of the custom marshal function. + CustomMarshalFunc string + // CustomUnmarshalFunc is the name of the custom unmarshal function. + CustomUnmarshalFunc string +} + // GenerateByType generates the code for the optional type. -func GenerateByType(typeName string, code int, packageName string) ([]byte, error) { +func GenerateByType(opts GenerateOptions) ([]byte, error) { var buf bytes.Buffer + if opts.CustomMarshalFunc == "" { + opts.CustomMarshalFunc = "o.value.MarshalMsgpack()" + } else { + opts.CustomMarshalFunc += "(o.value)" + } + + if opts.CustomUnmarshalFunc == "" { + opts.CustomUnmarshalFunc = "o.value.UnmarshalMsgpack(a)" + } else { + opts.CustomUnmarshalFunc += "(&o.value, a)" + } + err := cTypeGenTemplate.Execute(&buf, struct { - Name string - Type string - ExtCode string - PackageName string - Imports []string + Name string + Type string + ExtCode string + PackageName string + Imports []string + CustomMarshalFunc string + CustomUnmarshalFunc string }{ - Name: "Optional" + typeName, - Type: typeName, - ExtCode: strconv.Itoa(code), - PackageName: packageName, - Imports: nil, + Name: constructTypeName(opts.TypeName), + Type: opts.TypeName, + ExtCode: strconv.Itoa(opts.ExtCode), + PackageName: opts.PackageName, + Imports: opts.Imports, + CustomMarshalFunc: opts.CustomMarshalFunc, + CustomUnmarshalFunc: opts.CustomUnmarshalFunc, }) if err != nil { return nil, fmt.Errorf("failed to generateByType: %w", err) diff --git a/cmd/gentypes/generator/type_gen.go.tpl b/cmd/gentypes/generator/type_gen.go.tpl index 6111ff7..874e119 100644 --- a/cmd/gentypes/generator/type_gen.go.tpl +++ b/cmd/gentypes/generator/type_gen.go.tpl @@ -153,7 +153,7 @@ func (o {{.Name}}) UnwrapOrElse(defaultValue func() {{.Type}}) {{.Type}} { } func (o {{.Name}}) encodeValue(encoder *msgpack.Encoder) error { - value, err := o.value.MarshalMsgpack() + value, err := {{ .CustomMarshalFunc }} if err != nil { return err } @@ -199,7 +199,7 @@ func (o *{{.Name}}) decodeValue(decoder *msgpack.Decoder) error { return o.newDecodeError(err) } - if err := o.value.UnmarshalMsgpack(a); err != nil { + if err := {{ .CustomUnmarshalFunc }}; err != nil { return o.newDecodeError(err) } diff --git a/cmd/gentypes/test/fullmsgpackexttype.go b/cmd/gentypes/internal/test/fullmsgpackexttype.go similarity index 100% rename from cmd/gentypes/test/fullmsgpackexttype.go rename to cmd/gentypes/internal/test/fullmsgpackexttype.go diff --git a/cmd/gentypes/test/fullmsgpackexttype_gen.go b/cmd/gentypes/internal/test/fullmsgpackexttype_gen.go similarity index 100% rename from cmd/gentypes/test/fullmsgpackexttype_gen.go rename to cmd/gentypes/internal/test/fullmsgpackexttype_gen.go diff --git a/cmd/gentypes/test/fullmsgpackexttype_test.go b/cmd/gentypes/internal/test/fullmsgpackexttype_test.go similarity index 63% rename from cmd/gentypes/test/fullmsgpackexttype_test.go rename to cmd/gentypes/internal/test/fullmsgpackexttype_test.go index b3a57d3..7a1e262 100644 --- a/cmd/gentypes/test/fullmsgpackexttype_test.go +++ b/cmd/gentypes/internal/test/fullmsgpackexttype_test.go @@ -8,18 +8,18 @@ import ( "github.com/stretchr/testify/require" "github.com/vmihailenco/msgpack/v5" - td "github.com/tarantool/go-option/cmd/gentypes/test" + "github.com/tarantool/go-option/cmd/gentypes/internal/test" ) func TestOptionalMsgpackExtType_RoundtripLL(t *testing.T) { t.Parallel() - input := td.FullMsgpackExtType{ + input := test.FullMsgpackExtType{ A: 412, B: "bababa", } - opt := td.SomeOptionalFullMsgpackExtType(input) + opt := test.SomeOptionalFullMsgpackExtType(input) b := bytes.Buffer{} enc := msgpack.NewEncoder(&b) @@ -27,7 +27,7 @@ func TestOptionalMsgpackExtType_RoundtripLL(t *testing.T) { require.NoError(t, opt.EncodeMsgpack(enc)) - opt2 := td.NoneOptionalFullMsgpackExtType() + opt2 := test.NoneOptionalFullMsgpackExtType() require.NoError(t, opt2.DecodeMsgpack(dec)) assert.Equal(t, opt, opt2) @@ -37,12 +37,12 @@ func TestOptionalMsgpackExtType_RoundtripLL(t *testing.T) { func TestOptionalMsgpackExtType_RoundtripHL(t *testing.T) { t.Parallel() - input := td.FullMsgpackExtType{ + input := test.FullMsgpackExtType{ A: 412, B: "bababa", } - opt := td.SomeOptionalFullMsgpackExtType(input) + opt := test.SomeOptionalFullMsgpackExtType(input) b := bytes.Buffer{} enc := msgpack.NewEncoder(&b) @@ -50,7 +50,7 @@ func TestOptionalMsgpackExtType_RoundtripHL(t *testing.T) { require.NoError(t, enc.Encode(opt)) - opt2 := td.NoneOptionalFullMsgpackExtType() + opt2 := test.NoneOptionalFullMsgpackExtType() require.NoError(t, dec.Decode(&opt2)) assert.Equal(t, opt, opt2) @@ -63,12 +63,12 @@ func TestOptionalFullMsgpackExtType_IsSome(t *testing.T) { t.Run("some", func(t *testing.T) { t.Parallel() - input := td.FullMsgpackExtType{ + input := test.FullMsgpackExtType{ A: 412, B: "bababa", } - opt := td.SomeOptionalFullMsgpackExtType(input) + opt := test.SomeOptionalFullMsgpackExtType(input) assert.True(t, opt.IsSome()) }) @@ -76,7 +76,7 @@ func TestOptionalFullMsgpackExtType_IsSome(t *testing.T) { t.Run("none", func(t *testing.T) { t.Parallel() - opt := td.NoneOptionalFullMsgpackExtType() + opt := test.NoneOptionalFullMsgpackExtType() assert.False(t, opt.IsSome()) }) @@ -88,12 +88,12 @@ func TestOptionalFullMsgpackExtType_IsZero(t *testing.T) { t.Run("some", func(t *testing.T) { t.Parallel() - input := td.FullMsgpackExtType{ + input := test.FullMsgpackExtType{ A: 412, B: "bababa", } - opt := td.SomeOptionalFullMsgpackExtType(input) + opt := test.SomeOptionalFullMsgpackExtType(input) assert.False(t, opt.IsZero()) }) @@ -101,7 +101,7 @@ func TestOptionalFullMsgpackExtType_IsZero(t *testing.T) { t.Run("none", func(t *testing.T) { t.Parallel() - opt := td.NoneOptionalFullMsgpackExtType() + opt := test.NoneOptionalFullMsgpackExtType() assert.True(t, opt.IsZero()) }) @@ -113,12 +113,12 @@ func TestOptionalFullMsgpackExtType_Get(t *testing.T) { t.Run("some", func(t *testing.T) { t.Parallel() - input := td.FullMsgpackExtType{ + input := test.FullMsgpackExtType{ A: 412, B: "bababa", } - opt := td.SomeOptionalFullMsgpackExtType(input) + opt := test.SomeOptionalFullMsgpackExtType(input) val, ok := opt.Get() require.True(t, ok) @@ -128,10 +128,10 @@ func TestOptionalFullMsgpackExtType_Get(t *testing.T) { t.Run("none", func(t *testing.T) { t.Parallel() - opt := td.NoneOptionalFullMsgpackExtType() + opt := test.NoneOptionalFullMsgpackExtType() val, ok := opt.Get() require.False(t, ok) - assert.Equal(t, td.NewEmptyFullMsgpackExtType(), val) + assert.Equal(t, test.NewEmptyFullMsgpackExtType(), val) }) } @@ -141,14 +141,14 @@ func TestOptionalFullMsgpackExtType_MustGet(t *testing.T) { t.Run("some", func(t *testing.T) { t.Parallel() - input := td.FullMsgpackExtType{ + input := test.FullMsgpackExtType{ A: 412, B: "bababa", } - opt := td.SomeOptionalFullMsgpackExtType(input) + opt := test.SomeOptionalFullMsgpackExtType(input) - var val td.FullMsgpackExtType + var val test.FullMsgpackExtType require.NotPanics(t, func() { val = opt.MustGet() @@ -159,7 +159,7 @@ func TestOptionalFullMsgpackExtType_MustGet(t *testing.T) { t.Run("none", func(t *testing.T) { t.Parallel() - opt := td.NoneOptionalFullMsgpackExtType() + opt := test.NoneOptionalFullMsgpackExtType() require.Panics(t, func() { opt.MustGet() }) }) @@ -171,12 +171,12 @@ func TestOptionalFullMsgpackExtType_Unwrap(t *testing.T) { t.Run("some", func(t *testing.T) { t.Parallel() - input := td.FullMsgpackExtType{ + input := test.FullMsgpackExtType{ A: 412, B: "bababa", } - opt := td.SomeOptionalFullMsgpackExtType(input) + opt := test.SomeOptionalFullMsgpackExtType(input) assert.Equal(t, input, opt.Unwrap()) }) @@ -184,8 +184,8 @@ func TestOptionalFullMsgpackExtType_Unwrap(t *testing.T) { t.Run("none", func(t *testing.T) { t.Parallel() - opt := td.NoneOptionalFullMsgpackExtType() - assert.Equal(t, td.NewEmptyFullMsgpackExtType(), opt.Unwrap()) + opt := test.NoneOptionalFullMsgpackExtType() + assert.Equal(t, test.NewEmptyFullMsgpackExtType(), opt.Unwrap()) }) } @@ -195,25 +195,25 @@ func TestOptionalFullMsgpackExtType_UnwrapOr(t *testing.T) { t.Run("some", func(t *testing.T) { t.Parallel() - input := td.FullMsgpackExtType{ + input := test.FullMsgpackExtType{ A: 412, B: "bababa", } - opt := td.SomeOptionalFullMsgpackExtType(input) + opt := test.SomeOptionalFullMsgpackExtType(input) - assert.Equal(t, input, opt.UnwrapOr(td.NewEmptyFullMsgpackExtType())) + assert.Equal(t, input, opt.UnwrapOr(test.NewEmptyFullMsgpackExtType())) }) t.Run("none", func(t *testing.T) { t.Parallel() - alt := td.FullMsgpackExtType{ + alt := test.FullMsgpackExtType{ A: 1, B: "b", } - opt := td.NoneOptionalFullMsgpackExtType() + opt := test.NoneOptionalFullMsgpackExtType() assert.Equal(t, alt, opt.UnwrapOr(alt)) }) } @@ -224,27 +224,27 @@ func TestOptionalFullMsgpackExtType_UnwrapOrElse(t *testing.T) { t.Run("some", func(t *testing.T) { t.Parallel() - input := td.FullMsgpackExtType{ + input := test.FullMsgpackExtType{ A: 412, B: "bababa", } - opt := td.SomeOptionalFullMsgpackExtType(input) + opt := test.SomeOptionalFullMsgpackExtType(input) - assert.Equal(t, input, opt.UnwrapOrElse(td.NewEmptyFullMsgpackExtType)) + assert.Equal(t, input, opt.UnwrapOrElse(test.NewEmptyFullMsgpackExtType)) }) t.Run("none", func(t *testing.T) { t.Parallel() - alt := td.FullMsgpackExtType{ + alt := test.FullMsgpackExtType{ A: 1, B: "b", } - opt := td.NoneOptionalFullMsgpackExtType() + opt := test.NoneOptionalFullMsgpackExtType() - assert.Equal(t, alt, opt.UnwrapOrElse(func() td.FullMsgpackExtType { + assert.Equal(t, alt, opt.UnwrapOrElse(func() test.FullMsgpackExtType { return alt })) }) diff --git a/cmd/gentypes/internal/test/hiddentypealias.go b/cmd/gentypes/internal/test/hiddentypealias.go new file mode 100644 index 0000000..adec75e --- /dev/null +++ b/cmd/gentypes/internal/test/hiddentypealias.go @@ -0,0 +1,8 @@ +package test + +import ( + "github.com/tarantool/go-option/cmd/gentypes/internal/test/subpackage" +) + +// HiddenTypeAlias is a hidden type alias to test. +type HiddenTypeAlias = subpackage.Hidden diff --git a/cmd/gentypes/internal/test/hiddentypealias_gen.go b/cmd/gentypes/internal/test/hiddentypealias_gen.go new file mode 100644 index 0000000..9c1cf31 --- /dev/null +++ b/cmd/gentypes/internal/test/hiddentypealias_gen.go @@ -0,0 +1,241 @@ +// Code generated by github.com/tarantool/go-option; DO NOT EDIT. + +package test + +import ( + "fmt" + + "github.com/vmihailenco/msgpack/v5" + "github.com/vmihailenco/msgpack/v5/msgpcode" + + "github.com/tarantool/go-option" +) + +// OptionalHiddenTypeAlias represents an optional value of type HiddenTypeAlias. +// It can either hold a valid HiddenTypeAlias (IsSome == true) or be empty (IsZero == true). +type OptionalHiddenTypeAlias struct { + value HiddenTypeAlias + exists bool +} + +// SomeOptionalHiddenTypeAlias creates an optional OptionalHiddenTypeAlias with the given HiddenTypeAlias value. +// The returned OptionalHiddenTypeAlias will have IsSome() == true and IsZero() == false. +func SomeOptionalHiddenTypeAlias(value HiddenTypeAlias) OptionalHiddenTypeAlias { + return OptionalHiddenTypeAlias{ + value: value, + exists: true, + } +} + +// NoneOptionalHiddenTypeAlias creates an empty optional OptionalHiddenTypeAlias value. +// The returned OptionalHiddenTypeAlias will have IsSome() == false and IsZero() == true. +// +// Example: +// +// o := NoneOptionalHiddenTypeAlias() +// if o.IsZero() { +// fmt.Println("value is absent") +// } +func NoneOptionalHiddenTypeAlias() OptionalHiddenTypeAlias { + return OptionalHiddenTypeAlias{} +} + +func (o OptionalHiddenTypeAlias) newEncodeError(err error) error { + if err == nil { + return nil + } + return &option.EncodeError{ + Type: "OptionalHiddenTypeAlias", + Parent: err, + } +} + +func (o OptionalHiddenTypeAlias) newDecodeError(err error) error { + if err == nil { + return nil + } + + return &option.DecodeError{ + Type: "OptionalHiddenTypeAlias", + Parent: err, + } +} + +// IsSome returns true if the OptionalHiddenTypeAlias contains a value. +// This indicates the value is explicitly set (not None). +func (o OptionalHiddenTypeAlias) IsSome() bool { + return o.exists +} + +// IsZero returns true if the OptionalHiddenTypeAlias does not contain a value. +// Equivalent to !IsSome(). Useful for consistency with types where +// zero value (e.g. 0, false, zero struct) is valid and needs to be distinguished. +func (o OptionalHiddenTypeAlias) IsZero() bool { + return !o.exists +} + +// IsNil is an alias for IsZero. +// +// This method is provided for compatibility with the msgpack Encoder interface. +func (o OptionalHiddenTypeAlias) IsNil() bool { + return o.IsZero() +} + +// Get returns the stored value and a boolean flag indicating its presence. +// If the value is present, returns (value, true). +// If the value is absent, returns (zero value of HiddenTypeAlias, false). +// +// Recommended usage: +// +// if value, ok := o.Get(); ok { +// // use value +// } +func (o OptionalHiddenTypeAlias) Get() (HiddenTypeAlias, bool) { + return o.value, o.exists +} + +// MustGet returns the stored value if it is present. +// Panics if the value is absent (i.e., IsZero() == true). +// +// Use with caution — only when you are certain the value exists. +// +// Panics with: "optional value is not set" if no value is set. +func (o OptionalHiddenTypeAlias) MustGet() HiddenTypeAlias { + if !o.exists { + panic("optional value is not set") + } + + return o.value +} + +// Unwrap returns the stored value regardless of presence. +// If no value is set, returns the zero value for HiddenTypeAlias. +// +// Warning: Does not check presence. Use IsSome() before calling if you need +// to distinguish between absent value and explicit zero value. +func (o OptionalHiddenTypeAlias) Unwrap() HiddenTypeAlias { + return o.value +} + +// UnwrapOr returns the stored value if present. +// Otherwise, returns the provided default value. +// +// Example: +// +// o := NoneOptionalHiddenTypeAlias() +// v := o.UnwrapOr(someDefaultOptionalHiddenTypeAlias) +func (o OptionalHiddenTypeAlias) UnwrapOr(defaultValue HiddenTypeAlias) HiddenTypeAlias { + if o.exists { + return o.value + } + + return defaultValue +} + +// UnwrapOrElse returns the stored value if present. +// Otherwise, calls the provided function and returns its result. +// Useful when the default value requires computation or side effects. +// +// Example: +// +// o := NoneOptionalHiddenTypeAlias() +// v := o.UnwrapOrElse(func() HiddenTypeAlias { return computeDefault() }) +func (o OptionalHiddenTypeAlias) UnwrapOrElse(defaultValue func() HiddenTypeAlias) HiddenTypeAlias { + if o.exists { + return o.value + } + + return defaultValue() +} + +func (o OptionalHiddenTypeAlias) encodeValue(encoder *msgpack.Encoder) error { + value, err := o.value.MarshalMsgpack() + if err != nil { + return err + } + + err = encoder.EncodeExtHeader(2, len(value)) + if err != nil { + return err + } + + _, err = encoder.Writer().Write(value) + if err != nil { + return err + } + + return nil +} + +// EncodeMsgpack encodes the OptionalHiddenTypeAlias value using MessagePack format. +// - If the value is present, it is encoded as HiddenTypeAlias. +// - If the value is absent (None), it is encoded as nil. +// +// Returns an error if encoding fails. +func (o OptionalHiddenTypeAlias) EncodeMsgpack(encoder *msgpack.Encoder) error { + if o.exists { + return o.newEncodeError(o.encodeValue(encoder)) + } + + return o.newEncodeError(encoder.EncodeNil()) +} + +func (o *OptionalHiddenTypeAlias) decodeValue(decoder *msgpack.Decoder) error { + tp, length, err := decoder.DecodeExtHeader() + switch { + case err != nil: + return o.newDecodeError(err) + case tp != 2: + return o.newDecodeError(fmt.Errorf("invalid extension code: %d", tp)) + } + + a := make([]byte, length) + if err := decoder.ReadFull(a); err != nil { + return o.newDecodeError(err) + } + + if err := o.value.UnmarshalMsgpack(a); err != nil { + return o.newDecodeError(err) + } + + o.exists = true + return nil +} + +func (o *OptionalHiddenTypeAlias) checkCode(code byte) bool { + return msgpcode.IsExt(code) +} + +// DecodeMsgpack decodes a OptionalHiddenTypeAlias value from MessagePack format. +// Supports two input types: +// - nil: interpreted as no value (NoneOptionalHiddenTypeAlias) +// - HiddenTypeAlias: interpreted as a present value (SomeOptionalHiddenTypeAlias) +// +// Returns an error if the input type is unsupported or decoding fails. +// +// After successful decoding: +// - on nil: exists = false, value = default zero value +// - on HiddenTypeAlias: exists = true, value = decoded value +func (o *OptionalHiddenTypeAlias) DecodeMsgpack(decoder *msgpack.Decoder) error { + code, err := decoder.PeekCode() + if err != nil { + return o.newDecodeError(err) + } + + switch { + case code == msgpcode.Nil: + o.exists = false + + return o.newDecodeError(decoder.Skip()) + case o.checkCode(code): + err := o.decodeValue(decoder) + if err != nil { + return o.newDecodeError(err) + } + o.exists = true + + return err + default: + return o.newDecodeError(fmt.Errorf("unexpected code: %d", code)) + } +} diff --git a/cmd/gentypes/internal/test/subpackage/hiddentype.go b/cmd/gentypes/internal/test/subpackage/hiddentype.go new file mode 100644 index 0000000..5516fd3 --- /dev/null +++ b/cmd/gentypes/internal/test/subpackage/hiddentype.go @@ -0,0 +1,18 @@ +// Package subpackage contains a hidden type for testing type aliases generation. +package subpackage + +// Hidden is a hidden type for testing type aliases generation. +type Hidden struct { + Hidden string +} + +// MarshalMsgpack implements the MsgpackMarshaler interface. +func (h *Hidden) MarshalMsgpack() ([]byte, error) { + return []byte(h.Hidden), nil +} + +// UnmarshalMsgpack implements the MsgpackUnmarshaler interface. +func (h *Hidden) UnmarshalMsgpack(bytes []byte) error { + h.Hidden = string(bytes) + return nil +} diff --git a/cmd/gentypes/internal/test/uuid.go b/cmd/gentypes/internal/test/uuid.go new file mode 100644 index 0000000..f029755 --- /dev/null +++ b/cmd/gentypes/internal/test/uuid.go @@ -0,0 +1,26 @@ +package test + +import ( + "errors" + + "github.com/google/uuid" +) + +func encodeUUID(uuid uuid.UUID) ([]byte, error) { + return uuid[:], nil +} + +var ( + // ErrInvalidLength is returned when the length of the input data is invalid. + ErrInvalidLength = errors.New("invalid length") +) + +func decodeUUID(uuid *uuid.UUID, data []byte) error { + if len(data) != len(uuid) { + return ErrInvalidLength + } + + copy(uuid[:], data) + + return nil +} diff --git a/cmd/gentypes/internal/test/uuid_gen.go b/cmd/gentypes/internal/test/uuid_gen.go new file mode 100644 index 0000000..ac64ff7 --- /dev/null +++ b/cmd/gentypes/internal/test/uuid_gen.go @@ -0,0 +1,243 @@ +// Code generated by github.com/tarantool/go-option; DO NOT EDIT. + +package test + +import ( + "github.com/google/uuid" + + "fmt" + + "github.com/vmihailenco/msgpack/v5" + "github.com/vmihailenco/msgpack/v5/msgpcode" + + "github.com/tarantool/go-option" +) + +// OptionalUUID represents an optional value of type uuid.UUID. +// It can either hold a valid uuid.UUID (IsSome == true) or be empty (IsZero == true). +type OptionalUUID struct { + value uuid.UUID + exists bool +} + +// SomeOptionalUUID creates an optional OptionalUUID with the given uuid.UUID value. +// The returned OptionalUUID will have IsSome() == true and IsZero() == false. +func SomeOptionalUUID(value uuid.UUID) OptionalUUID { + return OptionalUUID{ + value: value, + exists: true, + } +} + +// NoneOptionalUUID creates an empty optional OptionalUUID value. +// The returned OptionalUUID will have IsSome() == false and IsZero() == true. +// +// Example: +// +// o := NoneOptionalUUID() +// if o.IsZero() { +// fmt.Println("value is absent") +// } +func NoneOptionalUUID() OptionalUUID { + return OptionalUUID{} +} + +func (o OptionalUUID) newEncodeError(err error) error { + if err == nil { + return nil + } + return &option.EncodeError{ + Type: "OptionalUUID", + Parent: err, + } +} + +func (o OptionalUUID) newDecodeError(err error) error { + if err == nil { + return nil + } + + return &option.DecodeError{ + Type: "OptionalUUID", + Parent: err, + } +} + +// IsSome returns true if the OptionalUUID contains a value. +// This indicates the value is explicitly set (not None). +func (o OptionalUUID) IsSome() bool { + return o.exists +} + +// IsZero returns true if the OptionalUUID does not contain a value. +// Equivalent to !IsSome(). Useful for consistency with types where +// zero value (e.g. 0, false, zero struct) is valid and needs to be distinguished. +func (o OptionalUUID) IsZero() bool { + return !o.exists +} + +// IsNil is an alias for IsZero. +// +// This method is provided for compatibility with the msgpack Encoder interface. +func (o OptionalUUID) IsNil() bool { + return o.IsZero() +} + +// Get returns the stored value and a boolean flag indicating its presence. +// If the value is present, returns (value, true). +// If the value is absent, returns (zero value of uuid.UUID, false). +// +// Recommended usage: +// +// if value, ok := o.Get(); ok { +// // use value +// } +func (o OptionalUUID) Get() (uuid.UUID, bool) { + return o.value, o.exists +} + +// MustGet returns the stored value if it is present. +// Panics if the value is absent (i.e., IsZero() == true). +// +// Use with caution — only when you are certain the value exists. +// +// Panics with: "optional value is not set" if no value is set. +func (o OptionalUUID) MustGet() uuid.UUID { + if !o.exists { + panic("optional value is not set") + } + + return o.value +} + +// Unwrap returns the stored value regardless of presence. +// If no value is set, returns the zero value for uuid.UUID. +// +// Warning: Does not check presence. Use IsSome() before calling if you need +// to distinguish between absent value and explicit zero value. +func (o OptionalUUID) Unwrap() uuid.UUID { + return o.value +} + +// UnwrapOr returns the stored value if present. +// Otherwise, returns the provided default value. +// +// Example: +// +// o := NoneOptionalUUID() +// v := o.UnwrapOr(someDefaultOptionalUUID) +func (o OptionalUUID) UnwrapOr(defaultValue uuid.UUID) uuid.UUID { + if o.exists { + return o.value + } + + return defaultValue +} + +// UnwrapOrElse returns the stored value if present. +// Otherwise, calls the provided function and returns its result. +// Useful when the default value requires computation or side effects. +// +// Example: +// +// o := NoneOptionalUUID() +// v := o.UnwrapOrElse(func() uuid.UUID { return computeDefault() }) +func (o OptionalUUID) UnwrapOrElse(defaultValue func() uuid.UUID) uuid.UUID { + if o.exists { + return o.value + } + + return defaultValue() +} + +func (o OptionalUUID) encodeValue(encoder *msgpack.Encoder) error { + value, err := encodeUUID(o.value) + if err != nil { + return err + } + + err = encoder.EncodeExtHeader(3, len(value)) + if err != nil { + return err + } + + _, err = encoder.Writer().Write(value) + if err != nil { + return err + } + + return nil +} + +// EncodeMsgpack encodes the OptionalUUID value using MessagePack format. +// - If the value is present, it is encoded as uuid.UUID. +// - If the value is absent (None), it is encoded as nil. +// +// Returns an error if encoding fails. +func (o OptionalUUID) EncodeMsgpack(encoder *msgpack.Encoder) error { + if o.exists { + return o.newEncodeError(o.encodeValue(encoder)) + } + + return o.newEncodeError(encoder.EncodeNil()) +} + +func (o *OptionalUUID) decodeValue(decoder *msgpack.Decoder) error { + tp, length, err := decoder.DecodeExtHeader() + switch { + case err != nil: + return o.newDecodeError(err) + case tp != 3: + return o.newDecodeError(fmt.Errorf("invalid extension code: %d", tp)) + } + + a := make([]byte, length) + if err := decoder.ReadFull(a); err != nil { + return o.newDecodeError(err) + } + + if err := decodeUUID(&o.value, a); err != nil { + return o.newDecodeError(err) + } + + o.exists = true + return nil +} + +func (o *OptionalUUID) checkCode(code byte) bool { + return msgpcode.IsExt(code) +} + +// DecodeMsgpack decodes a OptionalUUID value from MessagePack format. +// Supports two input types: +// - nil: interpreted as no value (NoneOptionalUUID) +// - uuid.UUID: interpreted as a present value (SomeOptionalUUID) +// +// Returns an error if the input type is unsupported or decoding fails. +// +// After successful decoding: +// - on nil: exists = false, value = default zero value +// - on uuid.UUID: exists = true, value = decoded value +func (o *OptionalUUID) DecodeMsgpack(decoder *msgpack.Decoder) error { + code, err := decoder.PeekCode() + if err != nil { + return o.newDecodeError(err) + } + + switch { + case code == msgpcode.Nil: + o.exists = false + + return o.newDecodeError(decoder.Skip()) + case o.checkCode(code): + err := o.decodeValue(decoder) + if err != nil { + return o.newDecodeError(err) + } + o.exists = true + + return err + default: + return o.newDecodeError(fmt.Errorf("unexpected code: %d", code)) + } +} diff --git a/cmd/gentypes/main.go b/cmd/gentypes/main.go index 67e5a4c..d4a9dd9 100644 --- a/cmd/gentypes/main.go +++ b/cmd/gentypes/main.go @@ -24,9 +24,13 @@ const ( ) var ( - packagePath string - extCode int - verbose bool + packagePath string + extCode int + verbose bool + force bool + imports stringListFlag + customMarshalFunc string + customUnmarshalFunc string ) func logfuncf(format string, args ...interface{}) { @@ -98,6 +102,26 @@ func printFile(prefix string, data []byte) { } } +func isExternalDep(name string) bool { + return strings.Contains(name, ".") +} + +const ( + maxNameParts = 2 +) + +func constructFileName(name string) string { + parts := strings.SplitN(name, ".", maxNameParts) + switch { + case len(parts) == 1: + name = parts[0] + case len(parts) == maxNameParts: + name = parts[1] + } + + return strings.ToLower(name) + "_gen.go" +} + func main() { //nolint:funlen generator.InitializeTemplates() @@ -106,7 +130,10 @@ func main() { //nolint:funlen flag.StringVar(&packagePath, "package", "./", "input and output path") flag.IntVar(&extCode, "ext-code", undefinedExtCode, "extension code") flag.BoolVar(&verbose, "verbose", false, "print verbose output") - + flag.BoolVar(&force, "force", false, "generate files even if methods do not exist") + flag.Var(&imports, "imports", "imports to add to generated files") + flag.StringVar(&customMarshalFunc, "marshal-func", "", "custom marshal function") + flag.StringVar(&customUnmarshalFunc, "unmarshal-func", "", "custom unmarshal function") flag.Parse() switch { @@ -164,19 +191,32 @@ func main() { //nolint:funlen // Check for existence of all types that we want to generate. typeSpecDef, ok := analyzer.TypeSpecEntryByName(typeName) - if !ok { + switch { + case isExternalDep(typeName): + fmt.Println("typename contains dot, probably third party type:", typeName) + case !ok: fmt.Println("failed to find struct:", typeName) os.Exit(1) } - fmt.Println("generating optional for:", typeName) + fmt.Println("generating optional:", typeName) - if !typeSpecDef.HasMethod("MarshalMsgpack") || !typeSpecDef.HasMethod("UnmarshalMsgpack") { + switch { + case force || isExternalDep(typeName): + // Skipping check for MarshalMsgpack and UnmarshalMsgpack methods. + case !typeSpecDef.HasMethod("MarshalMsgpack") || !typeSpecDef.HasMethod("UnmarshalMsgpack"): fmt.Println("failed to find MarshalMsgpack or UnmarshalMsgpack method for struct:", typeName) os.Exit(1) } - generatedGoSources, err := generator.GenerateByType(typeName, extCode, analyzer.PackageName()) + generatedGoSources, err := generator.GenerateByType(generator.GenerateOptions{ + TypeName: typeName, + ExtCode: extCode, + PackageName: pkg.Name, + Imports: imports, + CustomMarshalFunc: customMarshalFunc, + CustomUnmarshalFunc: customUnmarshalFunc, + }) if err != nil { fmt.Println("failed to generate optional types:") fmt.Println(" ", err) @@ -191,7 +231,7 @@ func main() { //nolint:funlen } err = os.WriteFile( - filepath.Join(packagePath, strings.ToLower(typeName)+"_gen.go"), + filepath.Join(packagePath, constructFileName(typeName)), formattedGoSource, defaultGoPermissions, ) diff --git a/go.mod b/go.mod index 3dcb760..944f136 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/tarantool/go-option go 1.23.0 require ( + github.com/google/uuid v1.6.0 github.com/stretchr/testify v1.11.1 github.com/vmihailenco/msgpack/v5 v5.4.1 golang.org/x/text v0.28.0 diff --git a/go.sum b/go.sum index 4c300ea..8e31417 100644 --- a/go.sum +++ b/go.sum @@ -3,6 +3,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=