Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 104 additions & 22 deletions cmd/buf-plugin-required-fields/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package main

import (
"context"
"fmt"
"strings"

"buf.build/go/bufplugin/check"
Expand All @@ -37,6 +38,21 @@ const (
requiredRequestFieldsOptionKey = "required_request_fields"
)

// FieldValidator validates a single field.
// Returns an error message and false if validation fails.
type FieldValidator func(field protoreflect.FieldDescriptor) *ValidationError

// MessageValidator validates a message as a whole, based on the set of fields present in the message.
// Returns an error message and false if validation fails.
type MessageValidator func(message protoreflect.MessageDescriptor, messageFields map[string]bool) *ValidationError

// ValidationError represents a linting error and includes the error message and
// the descriptor where the linting issue was found.
type ValidationError struct {
Message string
Descriptor protoreflect.Descriptor
}

var (
requiredEntityFieldsRuleSpec = &check.RuleSpec{
ID: requiredEntityFieldsRuleID,
Expand Down Expand Up @@ -68,35 +84,50 @@ var (
crudMethodWithoutFullEntityPrefixes = []string{"List", "Get", "Delete"}
defaultRequiredFields = []string{"id", "name", "account_id", "created_at"}
defaultRequiredRequestFields = []string{"account_id"}
preferredEntityFieldNames = map[string]string{
"updated_at": "last_modified_at",
"last_updated_at": "last_modified_at",
"cloud_provider": "cloud_provider_id",
"cloud_provider_region": "cloud_provider_region_id",
"cloud_region": "cloud_provider_region_id",
"cloud_region_id": "cloud_provider_region_id",
}
)

func main() {
check.Main(spec)
}

// checkEntityFields validates all entity-related messages in a file descriptor.
// It applies:
// - Field-level validators (e.g. preferred naming).
// - Message-level validators (e.g. required fields).
func checkEntityFields(ctx context.Context, responseWriter check.ResponseWriter, request check.Request, fileDescriptor descriptor.FileDescriptor) error {
requiredFields, err := getRequiredEntityFields(request)
if err != nil {
return err
}

for entityName := range extractEntityNames(fileDescriptor) {
msg := fileDescriptor.ProtoreflectFileDescriptor().Messages().ByName(protoreflect.Name(entityName))
if msg == nil {
continue
}
missingFields := findMissingFields(msg, requiredFields)
if len(missingFields) > 0 {
responseWriter.AddAnnotation(
check.WithMessagef("%q is missing required fields: %v", entityName, missingFields),
check.WithDescriptor(msg),
)
errors := validateMessage(
msg,
[]FieldValidator{preferredFieldNamesValidator(preferredEntityFieldNames)},
[]MessageValidator{missingFieldsValidator(requiredFields)},
)

for _, err := range errors {
responseWriter.AddAnnotation(check.WithMessage(err.Message), check.WithDescriptor(err.Descriptor))
}
}

return nil
}

// checkRequestFields validates messages that end with "Request" and match a known
// CRUD pattern (e.g., ListClustersRequest). It ensures these messages include required fields.
func checkRequestFields(ctx context.Context, responseWriter check.ResponseWriter, request check.Request, messageDescriptor protoreflect.MessageDescriptor) error {
msgName := string(messageDescriptor.Name())
if !strings.HasSuffix(msgName, "Request") {
Expand All @@ -110,12 +141,11 @@ func checkRequestFields(ctx context.Context, responseWriter check.ResponseWriter
requiredFields = defaultRequiredRequestFields
}
}
missingFields := findMissingFields(messageDescriptor, requiredFields)
if len(missingFields) > 0 {
responseWriter.AddAnnotation(
check.WithMessagef("%q is missing required fields: %v", msgName, missingFields),
check.WithDescriptor(messageDescriptor),
)
errors := validateMessage(
messageDescriptor, []FieldValidator{}, []MessageValidator{missingFieldsValidator(requiredFields)},
)
for _, err := range errors {
responseWriter.AddAnnotation(check.WithMessage(err.Message), check.WithDescriptor(err.Descriptor))
}

return nil
Expand Down Expand Up @@ -163,21 +193,73 @@ func inferEntityFromMethodName(methodName string) string {
return ""
}

// findMissingFields checks if a message contains all required fields.
func findMissingFields(msg protoreflect.MessageDescriptor, requiredFields []string) []string {
missingFields := []string{}
fieldMap := make(map[string]bool)
// validateMessage runs a set of field-level and message-level validators
// against a protobuf message descriptor.
//
// Field-level validators are executed for each individual field in the message,
// allowing checks like discouraged field names or naming conventions.
//
// Message-level validators are run once per message, and have access to the
// full set of field names, enabling checks like required field presence.
func validateMessage(msg protoreflect.MessageDescriptor, fieldValidators []FieldValidator, messageValidators []MessageValidator) []ValidationError {
// missingFields := []string{}
existingFields := make(map[string]bool)
fields := msg.Fields()
errors := []ValidationError{}

for i := 0; i < fields.Len(); i++ {
field := fields.Get(i)
fieldMap[string(field.Name())] = true
fieldName := string(field.Name())
existingFields[string(fieldName)] = true

for _, validator := range fieldValidators {
if err := validator(field); err != nil {
errors = append(errors, *err)
}
}
}

for _, requiredField := range requiredFields {
if !fieldMap[requiredField] {
missingFields = append(missingFields, requiredField)
for _, validator := range messageValidators {
if err := validator(msg, existingFields); err != nil {
errors = append(errors, *err)
}
}
return missingFields

return errors
}

// preferredFieldNamesValidator returns a FieldValidator that checks
// if a given field name is discouraged and suggests the preferred one.
func preferredFieldNamesValidator(preferredFieldNames map[string]string) FieldValidator {
return func(field protoreflect.FieldDescriptor) *ValidationError {
fieldName := string(field.Name())
if suggestion, ok := preferredFieldNames[fieldName]; ok && suggestion != fieldName {
return &ValidationError{
Message: fmt.Sprintf("field %q is discouraged, use %q instead", fieldName, suggestion),
Descriptor: field,
}
}
return nil
}
}

// missingFieldsValidator returns a MessageValidator that ensures a message
// contains all of the specified required fields.
func missingFieldsValidator(requiredFields []string) MessageValidator {
return func(message protoreflect.MessageDescriptor, messageFields map[string]bool) *ValidationError {
messageName := string(message.Name())
missingFields := []string{}
for _, requiredField := range requiredFields {
if !messageFields[requiredField] {
missingFields = append(missingFields, requiredField)
}
}
if len(missingFields) > 0 {
return &ValidationError{
Message: fmt.Sprintf("message %q is missing required fields: %v", messageName, missingFields),
Descriptor: message,
}
}
return nil
}
}
64 changes: 54 additions & 10 deletions cmd/buf-plugin-required-fields/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,37 @@ func TestSimpleFailureWithOption(t *testing.T) {
ExpectedAnnotations: []checktest.ExpectedAnnotation{
{
RuleID: requiredEntityFieldsRuleID,
Message: "\"BookCategory\" is missing required fields: [category]",
Message: "field \"updated_at\" is discouraged, use \"last_modified_at\" instead",
FileLocation: &checktest.ExpectedFileLocation{
FileName: "simple.proto",
StartLine: 51,
StartLine: 50,
StartColumn: 4,
EndLine: 50,
EndColumn: 45,
},
},
{
RuleID: requiredEntityFieldsRuleID,
Message: "message \"BookCategory\" is missing required fields: [category]",
FileLocation: &checktest.ExpectedFileLocation{
FileName: "simple.proto",
StartLine: 53,
StartColumn: 0,
EndLine: 56,
EndLine: 60,
EndColumn: 1,
},
},
{
RuleID: requiredEntityFieldsRuleID,
Message: "field \"last_updated_at\" is discouraged, use \"last_modified_at\" instead",
FileLocation: &checktest.ExpectedFileLocation{
FileName: "simple.proto",
StartLine: 59,
StartColumn: 4,
EndLine: 59,
EndColumn: 50,
},
},
},
}.Run(t)
}
Expand All @@ -70,29 +92,51 @@ func TestSimpleFailure(t *testing.T) {
ExpectedAnnotations: []checktest.ExpectedAnnotation{
{
RuleID: requiredEntityFieldsRuleID,
Message: "\"Book\" is missing required fields: [id account_id created_at]",
Message: "message \"Book\" is missing required fields: [id account_id created_at]",
FileLocation: &checktest.ExpectedFileLocation{
FileName: "simple.proto",
StartLine: 42,
StartColumn: 0,
EndLine: 49,
EndLine: 51,
EndColumn: 1,
},
},
{
RuleID: requiredEntityFieldsRuleID,
Message: "\"BookCategory\" is missing required fields: [name]",
Message: "field \"updated_at\" is discouraged, use \"last_modified_at\" instead",
FileLocation: &checktest.ExpectedFileLocation{
FileName: "simple.proto",
StartLine: 51,
StartLine: 50,
StartColumn: 4,
EndLine: 50,
EndColumn: 45,
},
},
{
RuleID: requiredEntityFieldsRuleID,
Message: "message \"BookCategory\" is missing required fields: [name]",
FileLocation: &checktest.ExpectedFileLocation{
FileName: "simple.proto",
StartLine: 53,
StartColumn: 0,
EndLine: 56,
EndLine: 60,
EndColumn: 1,
},
},
{
RuleID: requiredEntityFieldsRuleID,
Message: "field \"last_updated_at\" is discouraged, use \"last_modified_at\" instead",
FileLocation: &checktest.ExpectedFileLocation{
FileName: "simple.proto",
StartLine: 59,
StartColumn: 4,
EndLine: 59,
EndColumn: 50,
},
},
{
RuleID: requiredRequestFieldsRuleID,
Message: "\"ListBooksRequest\" is missing required fields: [account_id]",
Message: "message \"ListBooksRequest\" is missing required fields: [account_id]",
FileLocation: &checktest.ExpectedFileLocation{
FileName: "simple.proto",
StartLine: 17,
Expand All @@ -103,7 +147,7 @@ func TestSimpleFailure(t *testing.T) {
},
{
RuleID: requiredRequestFieldsRuleID,
Message: "\"GetBookRequest\" is missing required fields: [account_id]",
Message: "message \"GetBookRequest\" is missing required fields: [account_id]",
FileLocation: &checktest.ExpectedFileLocation{
FileName: "simple.proto",
StartLine: 25,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,17 @@ message Book {
// missing `created_at` field
BookCategory category = 2;
Publisher publisher = 3;
// updated_at instead of last_modified_at
google.protobuf.Timestamp updated_at = 4;
}

message BookCategory {
string id = 1;
// missing `name` field
string account_id = 2;
google.protobuf.Timestamp created_at = 3;
// last_updated_at instead of last_modified_at
google.protobuf.Timestamp last_updated_at = 4;
}

// this message does not have any related CRUD method, we don't consider it an entity and
Expand Down
Loading