@@ -611,18 +611,26 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) {
611611
612612 if cmd.Action == nil {
613613 cmd.Action = helpCommandAction
614- } else if len(cmd.Arguments) > 0 {
615- rargs := cmd.Args().Slice()
616- tracef("calling argparse with %[1]v", rargs)
617- for _, arg := range cmd.Arguments {
618- var err error
619- rargs, err = arg.Parse(rargs)
620- if err != nil {
621- tracef("calling with %[1]v (cmd=%[2]q)", err, cmd.Name)
622- return err
614+ } else {
615+ if err := cmd.checkPersistentRequiredFlags(); err != nil {
616+ cmd.isInError = true
617+ _ = ShowSubcommandHelp(cmd)
618+ return err
619+ }
620+
621+ if len(cmd.Arguments) > 0 {
622+ rargs := cmd.Args().Slice()
623+ tracef("calling argparse with %[1]v", rargs)
624+ for _, arg := range cmd.Arguments {
625+ var err error
626+ rargs, err = arg.Parse(rargs)
627+ if err != nil {
628+ tracef("calling with %[1]v (cmd=%[2]q)", err, cmd.Name)
629+ return err
630+ }
623631 }
632+ cmd.parsedArgs = &stringSliceArgs{v: rargs}
624633 }
625- cmd.parsedArgs = &stringSliceArgs{v: rargs}
626634 }
627635
628636 if err := cmd.Action(ctx, cmd); err != nil {
@@ -929,26 +937,59 @@ func (cmd *Command) lookupFlagSet(name string) *flag.FlagSet {
929937 return nil
930938}
931939
940+ func (cmd *Command) checkRequiredFlag(f Flag) (bool, string) {
941+ if rf, ok := f.(RequiredFlag); ok && rf.IsRequired() {
942+ flagPresent := false
943+ flagName := ""
944+
945+ for _, key := range f.Names() {
946+ flagName = key
947+
948+ if cmd.IsSet(strings.TrimSpace(key)) {
949+ flagPresent = true
950+ }
951+ }
952+
953+ if !flagPresent && flagName != "" {
954+ return false, flagName
955+ }
956+ }
957+ return true, ""
958+ }
959+
932960func (cmd *Command) checkRequiredFlags() requiredFlagsErr {
933961 tracef("checking for required flags (cmd=%[1]q)", cmd.Name)
934962
935963 missingFlags := []string{}
936964
937965 for _, f := range cmd.Flags {
938- if rf, ok := f.(RequiredFlag); ok && rf.IsRequired() {
939- flagPresent := false
940- flagName := ""
966+ if pf, ok := f.(PersistentFlag); !ok || !pf.IsPersistent() {
967+ if ok, name := cmd.checkRequiredFlag(f); !ok {
968+ missingFlags = append(missingFlags, name)
969+ }
970+ }
971+ }
941972
942- for _, key := range f.Names() {
943- flagName = key
973+ if len(missingFlags) != 0 {
974+ tracef("found missing required flags %[1]q (cmd=%[2]q)", missingFlags, cmd.Name)
944975
945- if cmd.IsSet(strings.TrimSpace(key)) {
946- flagPresent = true
947- }
948- }
976+ return &errRequiredFlags{missingFlags: missingFlags}
977+ }
978+
979+ tracef("all required flags set (cmd=%[1]q)", cmd.Name)
980+
981+ return nil
982+ }
983+
984+ func (cmd *Command) checkPersistentRequiredFlags() requiredFlagsErr {
985+ tracef("checking for required flags (cmd=%[1]q)", cmd.Name)
986+
987+ missingFlags := []string{}
949988
950- if !flagPresent && flagName != "" {
951- missingFlags = append(missingFlags, flagName)
989+ for _, f := range cmd.appliedFlags {
990+ if pf, ok := f.(PersistentFlag); ok && pf.IsPersistent() {
991+ if ok, name := cmd.checkRequiredFlag(f); !ok {
992+ missingFlags = append(missingFlags, name)
952993 }
953994 }
954995 }
0 commit comments