diff --git a/ui/options.go b/ui/options.go index b1f7d45..9766639 100644 --- a/ui/options.go +++ b/ui/options.go @@ -14,6 +14,7 @@ type options struct { value string allowEdit bool printTemplate string + minLength int promptTemplates *promptui.PromptTemplates selectTemplates *promptui.SelectTemplates validateFunc promptui.ValidateFunc @@ -109,6 +110,15 @@ func WithAllowEdit(b bool) Option { } } +// WithMinLength sets a minimum length requirement that will be verified +// at the time when the prompt is run. +// checks the input string meets the minimum length requirement. +func WithMinLength(minLength int) Option { + return func(o *options) { + o.minLength = minLength + } +} + // WithPrintTemplate sets the template to use on the print methods. func WithPrintTemplate(template string) Option { return func(o *options) { @@ -143,12 +153,6 @@ func WithValidateNotEmpty() Option { return WithValidateFunc(NotEmpty()) } -// WithValidateMinLength adds a custom validation function to a prompt that -// checks the input string meets the minimum length requirement. -func WithValidateMinLength(minLength int) Option { - return WithValidateFunc(MinLength(minLength)) -} - // WithValidateYesNo adds a custom validation function to a prompt for a Yes/No // prompt. func WithValidateYesNo() Option { diff --git a/ui/options_test.go b/ui/options_test.go new file mode 100644 index 0000000..1e66e95 --- /dev/null +++ b/ui/options_test.go @@ -0,0 +1,32 @@ +package ui + +import "testing" + +func TestWithMinLength(t *testing.T) { + tests := []struct { + name string + length int + }{ + { + name: "negative", + length: -5, + }, + { + name: "zero", + length: 0, + }, + { + name: "positive", + length: 11, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + o := &options{} + WithMinLength(tt.length)(o) + if o.minLength != tt.length { + t.Errorf("want %v, but got %v", tt.length, o.minLength) + } + }) + } +} diff --git a/ui/ui.go b/ui/ui.go index 40fbc0d..0480eb5 100644 --- a/ui/ui.go +++ b/ui/ui.go @@ -5,6 +5,7 @@ import ( "os" "strings" "text/template" + "unicode" "github.com/chzyer/readline" "github.com/manifoldco/promptui" @@ -203,13 +204,33 @@ func PromptPassword(label string, opts ...Option) ([]byte, error) { Validate: o.validateFunc, Templates: o.promptTemplates, } - pass, err := prompt.Run() + + pass, err := runPrompt(prompt.Run, o) if err != nil { - return nil, errors.Wrap(err, "error reading password") + return nil, err } return []byte(pass), nil } +// runPrompt is a helper for the method prompt.Run. This helper will loop the +// prompt indefinitely while the input does not meet a minimum length requirement. +func runPrompt(run func() (string, error), opts *options) (string, error) { + for { + pass, err := run() + if err != nil { + return "", errors.Wrap(err, "error reading password") + } + + pass = strings.TrimRightFunc(pass, unicode.IsSpace) + + if opts.minLength <= 0 || len(pass) >= opts.minLength { + return pass, nil + } + + fmt.Printf(">>> input does not meet minimum length requirement; must be at least %v characters\n", opts.minLength) + } +} + // PromptPasswordGenerate creates and runs a promptui.Prompt with the given label. // This prompt will mask the key entries with \r. If the result password length // is 0, it will generate a new prompt with a generated password that can be diff --git a/ui/ui_test.go b/ui/ui_test.go new file mode 100644 index 0000000..abb0909 --- /dev/null +++ b/ui/ui_test.go @@ -0,0 +1,91 @@ +package ui + +import ( + "errors" + "testing" +) + +func Test_promptRun(t *testing.T) { + promptRunner := func(input []string, err error) func() (string, error) { + i := 0 + return func() (string, error) { + ret := input[i] + i++ + return ret, err + } + } + + tests := []struct { + name string + minLength int + promptRun func() (string, error) + want string + wantErr bool + }{ + { + name: "prompt-error", + minLength: -5, + promptRun: promptRunner([]string{"foobar"}, errors.New("prompt-error")), + want: "foobar", + wantErr: true, + }, + { + name: "negative", + minLength: -5, + promptRun: promptRunner([]string{"foobar"}, nil), + want: "foobar", + wantErr: false, + }, + { + name: "zero", + minLength: 0, + promptRun: promptRunner([]string{"foobar"}, nil), + want: "foobar", + wantErr: false, + }, + { + name: "greater-than-min-length", + minLength: 5, + promptRun: promptRunner([]string{"foobar"}, nil), + want: "foobar", + wantErr: false, + }, + { + name: "equal-min-length", + minLength: 6, + promptRun: promptRunner([]string{"foobar"}, nil), + want: "foobar", + wantErr: false, + }, + { + name: "less-than-min-length", + minLength: 8, + promptRun: promptRunner([]string{"pass", "foobar", "password"}, nil), + want: "password", + wantErr: false, + }, + { + name: "ignore-post-whitespace-characters", + minLength: 7, + promptRun: promptRunner([]string{"pass ", "foobar ", "password "}, nil), + want: "password", + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + val, err := runPrompt(tt.promptRun, &options{minLength: tt.minLength}) + gotErr := err != nil + if gotErr != tt.wantErr { + t.Errorf("expected error=%v, but got error=%v", tt.wantErr, err) + return + } + if gotErr { + return + } + if val != tt.want { + t.Errorf("expected %v, but got %v", tt.want, val) + } + }) + } +} diff --git a/ui/validators.go b/ui/validators.go index 57e4910..7d8b603 100644 --- a/ui/validators.go +++ b/ui/validators.go @@ -5,7 +5,6 @@ import ( "fmt" "net" "strings" - "unicode" "github.com/manifoldco/promptui" ) @@ -74,17 +73,3 @@ func YesNo() promptui.ValidateFunc { } } } - -// MinLength is a validation function that checks for a minimum length. -// An input length <= 0 indicates that the check should not be performed. -func MinLength(minLength int) promptui.ValidateFunc { - return func(s string) error { - if minLength <= 0 { - return nil - } - if len(strings.TrimRightFunc(s, unicode.IsSpace)) < minLength { - return fmt.Errorf("input does not meet minimum length requirement; must be at least %v characters", minLength) - } - return nil - } -} diff --git a/ui/validators_test.go b/ui/validators_test.go index 5b06044..816079d 100644 --- a/ui/validators_test.go +++ b/ui/validators_test.go @@ -90,63 +90,3 @@ func TestDNS(t *testing.T) { }) } } - -func TestMinLength(t *testing.T) { - tests := []struct { - name string - length int - input string - wantErr bool - }{ - { - name: "negative", - length: -5, - input: "foobar", - wantErr: false, - }, - { - name: "zero", - length: 0, - input: "localhost", - wantErr: false, - }, - { - name: "greater-than-min-length", - length: 5, - input: "foobar", - wantErr: false, - }, - { - name: "equal-min-length", - length: 6, - input: "foobar", - wantErr: false, - }, - { - name: "less-than-min-length", - length: 8, - input: "foobar", - wantErr: true, - }, - { - name: "ignore-post-whitespace-characters", - length: 7, - input: " pass ", - wantErr: true, - }, - { - name: "ignore-post-whitespace-characters-ok", - length: 6, - input: " pass ", - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - gotErr := MinLength(tt.length)(tt.input) != nil - if gotErr != tt.wantErr { - t.Errorf("MinLength(%v)(%s) = %v, want %v", tt.length, tt.input, gotErr, tt.wantErr) - } - }) - } -}