diff --git a/cmd/buf-plugin-required-fields/main.go b/cmd/buf-plugin-required-fields/main.go index 1150bce..7182606 100644 --- a/cmd/buf-plugin-required-fields/main.go +++ b/cmd/buf-plugin-required-fields/main.go @@ -19,6 +19,7 @@ package main import ( "context" + "fmt" "strings" "buf.build/go/bufplugin/check" @@ -37,6 +38,21 @@ const ( requiredRequestFieldsOptionKey = "required_request_fields" ) +// FieldValidator validates a single field. +// Returns an error message and false if validation fails. +type FieldValidator func(field protoreflect.FieldDescriptor) *ValidationError + +// MessageValidator validates a message as a whole, based on the set of fields present in the message. +// Returns an error message and false if validation fails. +type MessageValidator func(message protoreflect.MessageDescriptor, messageFields map[string]bool) *ValidationError + +// ValidationError represents a linting error and includes the error message and +// the descriptor where the linting issue was found. +type ValidationError struct { + Message string + Descriptor protoreflect.Descriptor +} + var ( requiredEntityFieldsRuleSpec = &check.RuleSpec{ ID: requiredEntityFieldsRuleID, @@ -68,35 +84,50 @@ var ( crudMethodWithoutFullEntityPrefixes = []string{"List", "Get", "Delete"} defaultRequiredFields = []string{"id", "name", "account_id", "created_at"} defaultRequiredRequestFields = []string{"account_id"} + preferredEntityFieldNames = map[string]string{ + "updated_at": "last_modified_at", + "last_updated_at": "last_modified_at", + "cloud_provider": "cloud_provider_id", + "cloud_provider_region": "cloud_provider_region_id", + "cloud_region": "cloud_provider_region_id", + "cloud_region_id": "cloud_provider_region_id", + } ) func main() { check.Main(spec) } +// checkEntityFields validates all entity-related messages in a file descriptor. +// It applies: +// - Field-level validators (e.g. preferred naming). +// - Message-level validators (e.g. required fields). func checkEntityFields(ctx context.Context, responseWriter check.ResponseWriter, request check.Request, fileDescriptor descriptor.FileDescriptor) error { requiredFields, err := getRequiredEntityFields(request) if err != nil { return err } - for entityName := range extractEntityNames(fileDescriptor) { msg := fileDescriptor.ProtoreflectFileDescriptor().Messages().ByName(protoreflect.Name(entityName)) if msg == nil { continue } - missingFields := findMissingFields(msg, requiredFields) - if len(missingFields) > 0 { - responseWriter.AddAnnotation( - check.WithMessagef("%q is missing required fields: %v", entityName, missingFields), - check.WithDescriptor(msg), - ) + errors := validateMessage( + msg, + []FieldValidator{preferredFieldNamesValidator(preferredEntityFieldNames)}, + []MessageValidator{missingFieldsValidator(requiredFields)}, + ) + + for _, err := range errors { + responseWriter.AddAnnotation(check.WithMessage(err.Message), check.WithDescriptor(err.Descriptor)) } } return nil } +// checkRequestFields validates messages that end with "Request" and match a known +// CRUD pattern (e.g., ListClustersRequest). It ensures these messages include required fields. func checkRequestFields(ctx context.Context, responseWriter check.ResponseWriter, request check.Request, messageDescriptor protoreflect.MessageDescriptor) error { msgName := string(messageDescriptor.Name()) if !strings.HasSuffix(msgName, "Request") { @@ -110,12 +141,11 @@ func checkRequestFields(ctx context.Context, responseWriter check.ResponseWriter requiredFields = defaultRequiredRequestFields } } - missingFields := findMissingFields(messageDescriptor, requiredFields) - if len(missingFields) > 0 { - responseWriter.AddAnnotation( - check.WithMessagef("%q is missing required fields: %v", msgName, missingFields), - check.WithDescriptor(messageDescriptor), - ) + errors := validateMessage( + messageDescriptor, []FieldValidator{}, []MessageValidator{missingFieldsValidator(requiredFields)}, + ) + for _, err := range errors { + responseWriter.AddAnnotation(check.WithMessage(err.Message), check.WithDescriptor(err.Descriptor)) } return nil @@ -163,21 +193,73 @@ func inferEntityFromMethodName(methodName string) string { return "" } -// findMissingFields checks if a message contains all required fields. -func findMissingFields(msg protoreflect.MessageDescriptor, requiredFields []string) []string { - missingFields := []string{} - fieldMap := make(map[string]bool) +// validateMessage runs a set of field-level and message-level validators +// against a protobuf message descriptor. +// +// Field-level validators are executed for each individual field in the message, +// allowing checks like discouraged field names or naming conventions. +// +// Message-level validators are run once per message, and have access to the +// full set of field names, enabling checks like required field presence. +func validateMessage(msg protoreflect.MessageDescriptor, fieldValidators []FieldValidator, messageValidators []MessageValidator) []ValidationError { + // missingFields := []string{} + existingFields := make(map[string]bool) fields := msg.Fields() + errors := []ValidationError{} for i := 0; i < fields.Len(); i++ { field := fields.Get(i) - fieldMap[string(field.Name())] = true + fieldName := string(field.Name()) + existingFields[string(fieldName)] = true + + for _, validator := range fieldValidators { + if err := validator(field); err != nil { + errors = append(errors, *err) + } + } } - for _, requiredField := range requiredFields { - if !fieldMap[requiredField] { - missingFields = append(missingFields, requiredField) + for _, validator := range messageValidators { + if err := validator(msg, existingFields); err != nil { + errors = append(errors, *err) } } - return missingFields + + return errors +} + +// preferredFieldNamesValidator returns a FieldValidator that checks +// if a given field name is discouraged and suggests the preferred one. +func preferredFieldNamesValidator(preferredFieldNames map[string]string) FieldValidator { + return func(field protoreflect.FieldDescriptor) *ValidationError { + fieldName := string(field.Name()) + if suggestion, ok := preferredFieldNames[fieldName]; ok && suggestion != fieldName { + return &ValidationError{ + Message: fmt.Sprintf("field %q is discouraged, use %q instead", fieldName, suggestion), + Descriptor: field, + } + } + return nil + } +} + +// missingFieldsValidator returns a MessageValidator that ensures a message +// contains all of the specified required fields. +func missingFieldsValidator(requiredFields []string) MessageValidator { + return func(message protoreflect.MessageDescriptor, messageFields map[string]bool) *ValidationError { + messageName := string(message.Name()) + missingFields := []string{} + for _, requiredField := range requiredFields { + if !messageFields[requiredField] { + missingFields = append(missingFields, requiredField) + } + } + if len(missingFields) > 0 { + return &ValidationError{ + Message: fmt.Sprintf("message %q is missing required fields: %v", messageName, missingFields), + Descriptor: message, + } + } + return nil + } } diff --git a/cmd/buf-plugin-required-fields/main_test.go b/cmd/buf-plugin-required-fields/main_test.go index 970a9e5..c829367 100644 --- a/cmd/buf-plugin-required-fields/main_test.go +++ b/cmd/buf-plugin-required-fields/main_test.go @@ -43,15 +43,37 @@ func TestSimpleFailureWithOption(t *testing.T) { ExpectedAnnotations: []checktest.ExpectedAnnotation{ { RuleID: requiredEntityFieldsRuleID, - Message: "\"BookCategory\" is missing required fields: [category]", + Message: "field \"updated_at\" is discouraged, use \"last_modified_at\" instead", FileLocation: &checktest.ExpectedFileLocation{ FileName: "simple.proto", - StartLine: 51, + StartLine: 50, + StartColumn: 4, + EndLine: 50, + EndColumn: 45, + }, + }, + { + RuleID: requiredEntityFieldsRuleID, + Message: "message \"BookCategory\" is missing required fields: [category]", + FileLocation: &checktest.ExpectedFileLocation{ + FileName: "simple.proto", + StartLine: 53, StartColumn: 0, - EndLine: 56, + EndLine: 60, EndColumn: 1, }, }, + { + RuleID: requiredEntityFieldsRuleID, + Message: "field \"last_updated_at\" is discouraged, use \"last_modified_at\" instead", + FileLocation: &checktest.ExpectedFileLocation{ + FileName: "simple.proto", + StartLine: 59, + StartColumn: 4, + EndLine: 59, + EndColumn: 50, + }, + }, }, }.Run(t) } @@ -70,29 +92,51 @@ func TestSimpleFailure(t *testing.T) { ExpectedAnnotations: []checktest.ExpectedAnnotation{ { RuleID: requiredEntityFieldsRuleID, - Message: "\"Book\" is missing required fields: [id account_id created_at]", + Message: "message \"Book\" is missing required fields: [id account_id created_at]", FileLocation: &checktest.ExpectedFileLocation{ FileName: "simple.proto", StartLine: 42, StartColumn: 0, - EndLine: 49, + EndLine: 51, EndColumn: 1, }, }, { RuleID: requiredEntityFieldsRuleID, - Message: "\"BookCategory\" is missing required fields: [name]", + Message: "field \"updated_at\" is discouraged, use \"last_modified_at\" instead", FileLocation: &checktest.ExpectedFileLocation{ FileName: "simple.proto", - StartLine: 51, + StartLine: 50, + StartColumn: 4, + EndLine: 50, + EndColumn: 45, + }, + }, + { + RuleID: requiredEntityFieldsRuleID, + Message: "message \"BookCategory\" is missing required fields: [name]", + FileLocation: &checktest.ExpectedFileLocation{ + FileName: "simple.proto", + StartLine: 53, StartColumn: 0, - EndLine: 56, + EndLine: 60, EndColumn: 1, }, }, + { + RuleID: requiredEntityFieldsRuleID, + Message: "field \"last_updated_at\" is discouraged, use \"last_modified_at\" instead", + FileLocation: &checktest.ExpectedFileLocation{ + FileName: "simple.proto", + StartLine: 59, + StartColumn: 4, + EndLine: 59, + EndColumn: 50, + }, + }, { RuleID: requiredRequestFieldsRuleID, - Message: "\"ListBooksRequest\" is missing required fields: [account_id]", + Message: "message \"ListBooksRequest\" is missing required fields: [account_id]", FileLocation: &checktest.ExpectedFileLocation{ FileName: "simple.proto", StartLine: 17, @@ -103,7 +147,7 @@ func TestSimpleFailure(t *testing.T) { }, { RuleID: requiredRequestFieldsRuleID, - Message: "\"GetBookRequest\" is missing required fields: [account_id]", + Message: "message \"GetBookRequest\" is missing required fields: [account_id]", FileLocation: &checktest.ExpectedFileLocation{ FileName: "simple.proto", StartLine: 25, diff --git a/cmd/buf-plugin-required-fields/testdata/simple_failure/simple.proto b/cmd/buf-plugin-required-fields/testdata/simple_failure/simple.proto index a38c6c4..45060ef 100644 --- a/cmd/buf-plugin-required-fields/testdata/simple_failure/simple.proto +++ b/cmd/buf-plugin-required-fields/testdata/simple_failure/simple.proto @@ -47,6 +47,8 @@ message Book { // missing `created_at` field BookCategory category = 2; Publisher publisher = 3; + // updated_at instead of last_modified_at + google.protobuf.Timestamp updated_at = 4; } message BookCategory { @@ -54,6 +56,8 @@ message BookCategory { // missing `name` field string account_id = 2; google.protobuf.Timestamp created_at = 3; + // last_updated_at instead of last_modified_at + google.protobuf.Timestamp last_updated_at = 4; } // this message does not have any related CRUD method, we don't consider it an entity and