diff --git a/fuzztest/internal/domains/BUILD b/fuzztest/internal/domains/BUILD index 545fef36..5682ffb8 100644 --- a/fuzztest/internal/domains/BUILD +++ b/fuzztest/internal/domains/BUILD @@ -187,9 +187,9 @@ cc_library( hdrs = ["flatbuffers_domain_impl.h"], deps = [ ":core_domains_impl", - "@abseil-cpp//absl/algorithm:container", "@abseil-cpp//absl/base:core_headers", "@abseil-cpp//absl/base:nullability", + "@abseil-cpp//absl/container:btree", "@abseil-cpp//absl/container:flat_hash_map", "@abseil-cpp//absl/container:flat_hash_set", "@abseil-cpp//absl/random:bit_gen_ref", diff --git a/fuzztest/internal/domains/flatbuffers_domain_impl.cc b/fuzztest/internal/domains/flatbuffers_domain_impl.cc index 90313441..c675ae00 100644 --- a/fuzztest/internal/domains/flatbuffers_domain_impl.cc +++ b/fuzztest/internal/domains/flatbuffers_domain_impl.cc @@ -38,13 +38,18 @@ namespace fuzztest::internal { FlatbuffersTableUntypedDomainImpl::FlatbuffersTableUntypedDomainImpl( const reflection::Schema* absl_nonnull schema, const reflection::Object* absl_nonnull table_object) - : schema_(schema), table_object_(table_object) {} + : schema_(schema), table_object_(table_object) { + for (const auto& field : *table_object_->fields()) { + fields_by_id_[field->id()] = field; + } +} FlatbuffersTableUntypedDomainImpl::FlatbuffersTableUntypedDomainImpl( const FlatbuffersTableUntypedDomainImpl& other) : DomainBase(other), schema_(other.schema_), - table_object_(other.table_object_) { + table_object_(other.table_object_), + fields_by_id_(other.fields_by_id_) { absl::MutexLock l_other(other.mutex_); absl::MutexLock l_this(mutex_); domains_ = other.domains_; @@ -55,6 +60,7 @@ FlatbuffersTableUntypedDomainImpl& FlatbuffersTableUntypedDomainImpl::operator=( DomainBase::operator=(other); schema_ = other.schema_; table_object_ = other.table_object_; + fields_by_id_ = other.fields_by_id_; absl::MutexLock l_other(other.mutex_); absl::MutexLock l_this(mutex_); domains_ = other.domains_; @@ -63,7 +69,9 @@ FlatbuffersTableUntypedDomainImpl& FlatbuffersTableUntypedDomainImpl::operator=( FlatbuffersTableUntypedDomainImpl::FlatbuffersTableUntypedDomainImpl( FlatbuffersTableUntypedDomainImpl&& other) - : schema_(other.schema_), table_object_(other.table_object_) { + : schema_(other.schema_), + table_object_(other.table_object_), + fields_by_id_(std::move(other.fields_by_id_)) { absl::MutexLock l_other(other.mutex_); absl::MutexLock l_this(mutex_); domains_ = std::move(other.domains_); @@ -74,6 +82,7 @@ FlatbuffersTableUntypedDomainImpl& FlatbuffersTableUntypedDomainImpl::operator=( FlatbuffersTableUntypedDomainImpl&& other) { schema_ = other.schema_; table_object_ = other.table_object_; + fields_by_id_ = std::move(other.fields_by_id_); absl::MutexLock l_other(other.mutex_); absl::MutexLock l_this(mutex_); domains_ = std::move(other.domains_); @@ -87,7 +96,7 @@ FlatbuffersTableUntypedDomainImpl::Init(absl::BitGenRef prng) { return *seed; } corpus_type val; - for (const auto* field : *table_object_->fields()) { + for (const auto& [_, field] : fields_by_id_) { VisitFlatbufferField(schema_, field, InitializeVisitor{*this, prng, val}); } return val; @@ -98,7 +107,7 @@ void FlatbuffersTableUntypedDomainImpl::Mutate( corpus_type& val, absl::BitGenRef prng, const domain_implementor::MutationMetadata& metadata, bool only_shrink) { uint64_t field_count = 0; - for (const auto* field : *table_object_->fields()) { + for (const auto& [_, field] : fields_by_id_) { VisitFlatbufferField(schema_, field, CountNumberOfMutableFieldsVisitor{*this, field_count, val, only_shrink}); @@ -112,7 +121,7 @@ void FlatbuffersTableUntypedDomainImpl::Mutate( uint64_t FlatbuffersTableUntypedDomainImpl::CountNumberOfFields( corpus_type& val) { uint64_t field_count = 0; - for (const auto* field : *table_object_->fields()) { + for (const auto& [_, field] : fields_by_id_) { VisitFlatbufferField( schema_, field, CountNumberOfMutableFieldsVisitor{*this, field_count, val}); @@ -130,29 +139,16 @@ uint64_t FlatbuffersTableUntypedDomainImpl::MutateSelectedField( return fields_count; } - for (const auto* field : *table_object_->fields()) { + for (const auto& [_, field] : fields_by_id_) { if (!IsSupportedField(field)) { if (only_shrink && !val.contains(field->id())) continue; } ++field_counter; - if (field_counter == selected_field_index) { - VisitFlatbufferField( - schema_, field, - MutateVisitor{*this, prng, metadata, only_shrink, val}); - return field_counter; - } - - if (field->type()->base_type() == reflection::BaseType::Obj) { - auto sub_object = schema_->objects()->Get(field->type()->index()); - if (!sub_object->is_struct()) { - field_counter += - GetCachedDomain(field).MutateSelectedField( - val[field->id()], prng, metadata, only_shrink, - selected_field_index - field_counter); - } - // TODO: Add support for structs. - } + VisitFlatbufferField( + schema_, field, + MutateSelectedFieldVisitor{*this, field_counter, val, prng, metadata, + only_shrink, selected_field_index}); if (field_counter >= selected_field_index) { return field_counter; @@ -163,7 +159,7 @@ uint64_t FlatbuffersTableUntypedDomainImpl::MutateSelectedField( absl::Status FlatbuffersTableUntypedDomainImpl::ValidateCorpusValue( const corpus_type& corpus_value) const { - for (const auto* field : *table_object_->fields()) { + for (const auto& [_, field] : fields_by_id_) { absl::Status result; GenericDomainCorpusType field_corpus; if (auto it = corpus_value.find(field->id()); it != corpus_value.end()) { @@ -183,7 +179,7 @@ FlatbuffersTableUntypedDomainImpl::FromValue(const value_type& value) const { return std::nullopt; } corpus_type ret; - for (const auto* field : *table_object_->fields()) { + for (const auto& [_, field] : fields_by_id_) { VisitFlatbufferField(schema_, field, FromValueVisitor{*this, value, ret}); } return ret; diff --git a/fuzztest/internal/domains/flatbuffers_domain_impl.h b/fuzztest/internal/domains/flatbuffers_domain_impl.h index 328b2e7b..92b71585 100644 --- a/fuzztest/internal/domains/flatbuffers_domain_impl.h +++ b/fuzztest/internal/domains/flatbuffers_domain_impl.h @@ -26,9 +26,9 @@ #include #include -#include "absl/algorithm/container.h" #include "absl/base/nullability.h" #include "absl/base/thread_annotations.h" +#include "absl/container/btree_map.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/random/bit_gen_ref.h" @@ -359,6 +359,7 @@ class FlatbuffersTableUntypedDomainImpl private: const reflection::Schema* absl_nonnull schema_; const reflection::Object* absl_nonnull table_object_; + absl::btree_map fields_by_id_; mutable absl::Mutex mutex_; mutable absl::flat_hash_map domains_ ABSL_GUARDED_BY(mutex_); @@ -396,10 +397,10 @@ class FlatbuffersTableUntypedDomainImpl const reflection::Field* absl_nullable GetFieldById( typename corpus_type::key_type id) const { - const auto it = - absl::c_find_if(*table_object_->fields(), - [id](const auto* field) { return field->id() == id; }); - return it != table_object_->fields()->end() ? *it : nullptr; + if (auto it = fields_by_id_.find(id); it != fields_by_id_.end()) { + return it->second; + } + return nullptr; } struct SerializeVisitor { @@ -618,22 +619,36 @@ class FlatbuffersTableUntypedDomainImpl } }; - struct MutateVisitor { + struct MutateSelectedFieldVisitor { FlatbuffersTableUntypedDomainImpl& self; + uint64_t& field_counter; + corpus_type& val; absl::BitGenRef prng; const domain_implementor::MutationMetadata& metadata; bool only_shrink; - corpus_type& corpus_value; + uint64_t selected_field_index; template void Visit(const reflection::Field* absl_nonnull field) { - auto& domain = self.GetCachedDomain(field); - auto it = corpus_value.find(field->id()); - if (it == corpus_value.end()) { - if (only_shrink) return; - it = corpus_value.try_emplace(field->id(), domain.Init(prng)).first; + if (!self.IsSupportedField(field)) return; + if (only_shrink && !val.contains(field->id())) return; + + if (field_counter == selected_field_index) { + auto& domain = self.GetCachedDomain(field); + auto it = val.find(field->id()); + if (it == val.end()) { + if (only_shrink) return; + it = val.try_emplace(field->id(), domain.Init(prng)).first; + } + domain.Mutate(it->second, prng, metadata, only_shrink); + } else { + auto& domain = self.GetCachedDomain(field); + if (auto it = val.find(field->id()); it != val.end()) { + field_counter += domain.MutateSelectedField( + it->second, prng, metadata, only_shrink, + selected_field_index - field_counter); + } } - domain.Mutate(it->second, prng, metadata, only_shrink); } }; @@ -764,6 +779,14 @@ class FlatbuffersTableDomainImpl return inner_->CountNumberOfFields(val.untyped_corpus); } + uint64_t MutateSelectedField( + corpus_type& val, absl::BitGenRef prng, + const domain_implementor::MutationMetadata& metadata, bool only_shrink, + uint64_t selected_field_index) { + return inner_->MutateSelectedField(val.untyped_corpus, prng, metadata, + only_shrink, selected_field_index); + } + // Mutates the given corpus value. void Mutate(corpus_type& val, absl::BitGenRef prng, const domain_implementor::MutationMetadata& metadata,