Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions checker/internal/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ cc_library(
hdrs = ["type_check_env.h"],
deps = [
":descriptor_pool_type_introspector",
":proto_type_mask",
":proto_type_mask_registry",
"//common:constant",
"//common:container",
"//common:decl",
Expand All @@ -76,6 +78,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:string_view",
Expand Down Expand Up @@ -129,6 +132,7 @@ cc_library(
deps = [
":format_type_name",
":namespace_generator",
":proto_type_mask",
":type_check_env",
":type_inference_context",
"//checker:checker_options",
Expand All @@ -153,6 +157,7 @@ cc_library(
"@com_google_absl//absl/base:no_destructor",
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log:absl_check",
Expand Down Expand Up @@ -225,6 +230,7 @@ cc_test(
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:status_matchers",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:optional",
"@com_google_protobuf//:protobuf",
Expand Down
5 changes: 5 additions & 0 deletions checker/internal/type_check_env.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include <cstddef>
#include <cstdint>
#include <optional>

#include "absl/base/nullability.h"
#include "absl/status/statusor.h"
Expand Down Expand Up @@ -96,6 +97,10 @@ absl::StatusOr<std::optional<VariableDecl>> TypeCheckEnv::LookupTypeConstant(

absl::StatusOr<std::optional<StructTypeField>> TypeCheckEnv::LookupStructField(
absl::string_view type_name, absl::string_view field_name) const {
if (proto_type_mask_registry_ != nullptr &&
!proto_type_mask_registry_->FieldIsVisible(type_name, field_name)) {
return absl::nullopt;
}
// Check the type providers in registration order.
// Note: this doesn't allow for shadowing a type with a subset type of the
// same name -- the later type provider will still be considered when
Expand Down
14 changes: 14 additions & 0 deletions checker/internal/type_check_env.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,20 @@
#include "absl/container/flat_hash_map.h"
#include "absl/log/absl_check.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "checker/internal/descriptor_pool_type_introspector.h"
#include "checker/internal/proto_type_mask.h"
#include "checker/internal/proto_type_mask_registry.h"
#include "common/constant.h"
#include "common/container.h"
#include "common/decl.h"
#include "common/type.h"
#include "common/type_introspector.h"
#include "internal/status_macros.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"

Expand Down Expand Up @@ -154,6 +158,14 @@ class TypeCheckEnv {
variables_[decl.name()] = std::move(decl);
}

absl::Status CreateProtoTypeMaskRegistry(
const std::vector<ProtoTypeMask>& proto_type_masks) {
CEL_ASSIGN_OR_RETURN(proto_type_mask_registry_,
ProtoTypeMaskRegistry::Create(descriptor_pool_.get(),
proto_type_masks));
return absl::OkStatus();
}

const absl::flat_hash_map<std::string, FunctionDecl>& functions() const {
return functions_;
}
Expand Down Expand Up @@ -224,6 +236,8 @@ class TypeCheckEnv {
absl::flat_hash_map<std::string, VariableDecl> variables_;
absl::flat_hash_map<std::string, FunctionDecl> functions_;

std::shared_ptr<ProtoTypeMaskRegistry> proto_type_mask_registry_;

// Type providers for custom types.
std::vector<std::shared_ptr<const TypeIntrospector>> type_providers_;

Expand Down
56 changes: 50 additions & 6 deletions checker/internal/type_checker_builder_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,24 @@

#include <cstddef>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/base/no_destructor.h"
#include "absl/base/nullability.h"
#include "absl/cleanup/cleanup.h"
#include "absl/container/btree_set.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/absl_log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "checker/internal/proto_type_mask.h"
#include "checker/internal/type_check_env.h"
#include "checker/internal/type_checker_impl.h"
#include "checker/type_checker.h"
Expand Down Expand Up @@ -86,10 +90,19 @@ absl::Status CheckStdMacroOverlap(const FunctionDecl& decl) {
}

absl::Status AddWellKnownContextDeclarationVariables(
const google::protobuf::Descriptor* absl_nonnull descriptor, TypeCheckEnv& env,
bool use_json_name) {
const google::protobuf::Descriptor* absl_nonnull descriptor,
const absl::flat_hash_map<absl::string_view,
absl::btree_set<absl::string_view>>&
context_type_fields,
TypeCheckEnv& env, bool use_json_name) {
for (int i = 0; i < descriptor->field_count(); ++i) {
const google::protobuf::FieldDescriptor* field = descriptor->field(i);
// Skip fields that are hidden because of a proto type mask.
auto map_iterator = context_type_fields.find(descriptor->full_name());
if (map_iterator != context_type_fields.end() &&
!map_iterator->second.contains(field->name())) {
continue;
}
Type type = MessageTypeField(field).GetType();
if (type.IsEnum()) {
type = IntType();
Expand All @@ -109,11 +122,15 @@ absl::Status AddWellKnownContextDeclarationVariables(
}

absl::Status AddContextDeclarationVariables(
const google::protobuf::Descriptor* absl_nonnull descriptor, TypeCheckEnv& env) {
const google::protobuf::Descriptor* absl_nonnull descriptor,
const absl::flat_hash_map<absl::string_view,
absl::btree_set<absl::string_view>>&
context_type_fields,
TypeCheckEnv& env) {
const bool use_json_name = env.proto_type_introspector().use_json_name();
if (IsWellKnownMessageType(descriptor)) {
return AddWellKnownContextDeclarationVariables(descriptor, env,
use_json_name);
return AddWellKnownContextDeclarationVariables(
descriptor, context_type_fields, env, use_json_name);
}
CEL_ASSIGN_OR_RETURN(auto fields,
env.proto_type_introspector().ListFieldsForStructType(
Expand All @@ -131,6 +148,13 @@ absl::Status AddContextDeclarationVariables(

absl::string_view name = field_entry.name;

// Skip fields that are hidden because of a proto type mask.
auto map_iterator = context_type_fields.find(descriptor->full_name());
if (map_iterator != context_type_fields.end() &&
!map_iterator->second.contains(name)) {
continue;
}

if (!env.InsertVariableIfAbsent(MakeVariableDecl(name, type))) {
return absl::AlreadyExistsError(
absl::StrCat("variable '", name,
Expand Down Expand Up @@ -317,7 +341,8 @@ absl::Status TypeCheckerBuilderImpl::ApplyConfig(
}

for (const google::protobuf::Descriptor* context_type : config.context_types) {
CEL_RETURN_IF_ERROR(AddContextDeclarationVariables(context_type, env));
CEL_RETURN_IF_ERROR(AddContextDeclarationVariables(
context_type, config.context_type_fields, env));
}

for (VariableDeclRecord& var : config.variables) {
Expand All @@ -339,6 +364,8 @@ absl::Status TypeCheckerBuilderImpl::ApplyConfig(
}
}

CEL_RETURN_IF_ERROR(env.CreateProtoTypeMaskRegistry(config.proto_type_masks));

return absl::OkStatus();
}

Expand Down Expand Up @@ -462,6 +489,23 @@ absl::Status TypeCheckerBuilderImpl::AddContextDeclaration(
return absl::OkStatus();
}

absl::Status TypeCheckerBuilderImpl::AddContextDeclarationWithProtoTypeMask(
absl::string_view type, std::vector<std::string> field_paths) {
if (field_paths.empty()) {
return absl::InvalidArgumentError("field paths cannot be the empty set");
}

ProtoTypeMask proto_type_mask(std::string(type), field_paths);
target_config_->proto_type_masks.push_back(proto_type_mask);

CEL_RETURN_IF_ERROR(AddContextDeclaration(type));
CEL_ASSIGN_OR_RETURN(
absl::btree_set<absl::string_view> field_names,
proto_type_mask.GetFieldNames(template_env_.descriptor_pool()));
target_config_->context_type_fields.insert({type, std::move(field_names)});
return absl::OkStatus();
}

absl::Status TypeCheckerBuilderImpl::AddFunction(const FunctionDecl& decl) {
CEL_RETURN_IF_ERROR(
ValidateFunctionDecl(decl, options_.enable_type_parameter_name_validation,
Expand Down
9 changes: 9 additions & 0 deletions checker/internal/type_checker_builder_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@
#include <vector>

#include "absl/base/nullability.h"
#include "absl/container/btree_set.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "checker/checker_options.h"
#include "checker/internal/proto_type_mask.h"
#include "checker/internal/type_check_env.h"
#include "checker/type_checker.h"
#include "checker/type_checker_builder.h"
Expand Down Expand Up @@ -76,6 +78,8 @@ class TypeCheckerBuilderImpl : public TypeCheckerBuilder {
absl::Status AddVariable(const VariableDecl& decl) override;
absl::Status AddOrReplaceVariable(const VariableDecl& decl) override;
absl::Status AddContextDeclaration(absl::string_view type) override;
absl::Status AddContextDeclarationWithProtoTypeMask(
absl::string_view type, std::vector<std::string> field_paths) override;

absl::Status AddFunction(const FunctionDecl& decl) override;
absl::Status MergeFunction(const FunctionDecl& decl) override;
Expand Down Expand Up @@ -130,6 +134,11 @@ class TypeCheckerBuilderImpl : public TypeCheckerBuilder {
std::vector<FunctionDeclRecord> functions;
std::vector<std::shared_ptr<const TypeIntrospector>> type_providers;
std::vector<const google::protobuf::Descriptor*> context_types;
// Maps context type names to fields names to add as variables.
// Only includes context types that are defined with proto type masks.
absl::flat_hash_map<absl::string_view, absl::btree_set<absl::string_view>>
context_type_fields;
std::vector<ProtoTypeMask> proto_type_masks;
};

absl::Status BuildLibraryConfig(const CheckerLibrary& library,
Expand Down
Loading