diff --git a/WORKSPACE b/WORKSPACE
index dc59b6b7c..3d35cd3ea 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -54,6 +54,16 @@ http_archive(
     url = "https://github.com/sewenew/redis-plus-plus/archive/refs/tags/1.2.3.zip",
 )
 
+http_archive(
+    name = "rocksdb",
+    build_file = "//build_deps/toolchains/rocksdb:rocksdb.BUILD",
+    sha256 = "2df8f34a44eda182e22cf84dee7a14f17f55d305ff79c06fb3cd1e5f8831e00d",
+    strip_prefix = "rocksdb-6.22.1",
+    urls = [
+        "https://github.com/facebook/rocksdb/archive/refs/tags/v6.22.1.tar.gz",
+    ],
+)
+
 http_archive(
     name = "hadoop",
     build_file = "//third_party:hadoop.BUILD",
diff --git a/build_deps/toolchains/rocksdb/BUILD b/build_deps/toolchains/rocksdb/BUILD
new file mode 100644
index 000000000..e69de29bb
diff --git a/build_deps/toolchains/rocksdb/rocksdb.BUILD b/build_deps/toolchains/rocksdb/rocksdb.BUILD
new file mode 100644
index 000000000..dd207e19b
--- /dev/null
+++ b/build_deps/toolchains/rocksdb/rocksdb.BUILD
@@ -0,0 +1,38 @@
+load("@rules_foreign_cc//foreign_cc:defs.bzl", "make")
+
+package(
+    default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"])  # BSD
+
+filegroup(
+    name = "all_srcs",
+    srcs = glob(["**"]),
+    visibility = ["//visibility:public"],
+)
+
+# Enable this to compile RocksDB from source instead.
+#make(
+#    name = "rocksdb",
+#    args = [
+#        "EXTRA_CXXFLAGS=\"-fPIC -D_GLIBCXX_USE_CXX11_ABI=0\"",
+#        "-j6",
+#    ],
+#    targets = ["static_lib", "install-static"],
+#    lib_source = "@rocksdb//:all_srcs",
+#    out_static_libs = ["librocksdb.a"],
+#)
+
+# Enable this to use the precompiled library in our image.
+cc_library(
+    name = "rocksdb",
+    includes = ["./include"],
+    hdrs = glob(["rocksdb/*.h"]),
+    visibility = ["//visibility:public"],
+)
+
+cc_import(
+    name = "rocksdb_precompiled",
+    static_library = "librocksdb.a",
+    visibility = ["//visibility:public"],
+)
diff --git a/tensorflow_recommenders_addons/dynamic_embedding/__init__.py b/tensorflow_recommenders_addons/dynamic_embedding/__init__.py
index 950bea68f..18495877a 100644
--- a/tensorflow_recommenders_addons/dynamic_embedding/__init__.py
+++ b/tensorflow_recommenders_addons/dynamic_embedding/__init__.py
@@ -21,6 +21,9 @@
     'RedisTable',
     'RedisTableConfig',
     'RedisTableCreator',
+    'RocksDBTable',
+    'RocksDBTableConfig',
+    'RocksDBTableCreator',
     'Variable',
     'TrainableWrapper',
     'DynamicEmbeddingOptimizer',
@@ -50,11 +53,15 @@
     CuckooHashTableCreator,
     RedisTableConfig,
     RedisTableCreator,
+    RocksDBTableConfig,
+    RocksDBTableCreator,
 )
 from tensorflow_recommenders_addons.dynamic_embedding.python.ops.cuckoo_hashtable_ops import (
     CuckooHashTable,)
 from tensorflow_recommenders_addons.dynamic_embedding.python.ops.redis_table_ops import (
     RedisTable,)
+from tensorflow_recommenders_addons.dynamic_embedding.python.ops.rocksdb_table_ops import (
+    RocksDBTable,)
 from tensorflow_recommenders_addons.dynamic_embedding.python.ops.dynamic_embedding_ops import (
     embedding_lookup,)
 from tensorflow_recommenders_addons.dynamic_embedding.python.ops.dynamic_embedding_ops import (
diff --git a/tensorflow_recommenders_addons/dynamic_embedding/core/BUILD b/tensorflow_recommenders_addons/dynamic_embedding/core/BUILD
index 9a53530b6..18b01ae47 100644
--- a/tensorflow_recommenders_addons/dynamic_embedding/core/BUILD
+++ b/tensorflow_recommenders_addons/dynamic_embedding/core/BUILD
@@ -71,6 +71,30 @@ custom_op_library(
     ],
 )
 
+custom_op_library(
+    name = "_rocksdb_table_ops.so",
+    srcs = [
+        "kernels/rocksdb_table_op.cc",
+        "kernels/rocksdb_table_op.h",
+        "ops/rocksdb_table_ops.cc",
+        "utils/types.h",
+        "utils/utils.h",
+    ],
+    # Hack: To allow locating headers via paths relative to this directory.
+    includes = [
+        ".",
+    ],
+    linkopts = [
+        "-L/usr/local/lib",
+        "-lbz2",
+        "-llz4",
+        "-lzstd",
+    ],
+    deps = [
+        "@rocksdb",
+    ],
+)
+
 custom_op_library(
     name = "_math_ops.so",
     srcs = [
diff --git a/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/rocksdb_table_op.cc b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/rocksdb_table_op.cc
new file mode 100644
index 000000000..5696c3bf9
--- /dev/null
+++ b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/rocksdb_table_op.cc
@@ -0,0 +1,1583 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <algorithm>
+#include <cstring>
+#include <fstream>
+#include <functional>
+#include <limits>
+#include <memory>
+#include <sstream>
+#include <unordered_map>
+#include <vector>
+#if __cplusplus >= 201703L
+#include <filesystem>
+#else
+#include <sys/stat.h>
+#endif
+#include "../utils/utils.h"
+#include "rocksdb/db.h"
+#include "rocksdb_table_op.h"
+#include "tensorflow/core/util/env_var.h"
+
+namespace tensorflow {
+namespace recommenders_addons {
+namespace lookup_rocksdb {
+
+static const size_t BATCH_SIZE_MIN = 2;
+static const size_t BATCH_SIZE_MAX = 128;
+
+static const uint32_t FILE_MAGIC =
+    (  // TODO: Little endian / big endian conversion?
+        (static_cast<uint32_t>('R') << 0) | (static_cast<uint32_t>('O') << 8) |
+        (static_cast<uint32_t>('C') << 16) |
+        (static_cast<uint32_t>('K') << 24));
+static const uint32_t FILE_VERSION = 1;
+
+typedef uint16_t KEY_SIZE_TYPE;
+typedef uint32_t VALUE_SIZE_TYPE;
+typedef uint32_t STRING_SIZE_TYPE;
+
+#define ROCKSDB_OK(EXPR)                                                    \
+  do {                                                                      \
+    const ROCKSDB_NAMESPACE::Status s = (EXPR);                             \
+    if (!s.ok()) {                                                          \
+      std::ostringstream msg;                                               \
+      msg << "RocksDB error " << s.code() << "; reason: " << s.getState()   \
+          << "; expr: " << #EXPR;                                           \
+      throw std::runtime_error(msg.str());                                  \
+    }                                                                       \
+  } while (0)
+
+namespace _if {
+
+template <class T>
+inline void put_key(ROCKSDB_NAMESPACE::Slice &dst, const T *src) {
+  dst.data_ = reinterpret_cast<const char *>(src);
+  dst.size_ = sizeof(T);
+}
+
+template <>
+inline void put_key<tstring>(ROCKSDB_NAMESPACE::Slice &dst,
+                             const tstring *src) {
+  dst.data_ = src->data();
+  dst.size_ = src->size();
+}
+
+template <class T>
+inline void get_value(T *dst, const std::string &src, const size_t &n) {
+  const size_t dst_size = n * sizeof(T);
+
+  if (src.size() < dst_size) {
+    std::ostringstream msg;
+    msg << "Expected " << n * sizeof(T) << " bytes, but only " << src.size()
+        << " bytes were returned by the database.";
+    throw std::runtime_error(msg.str());
+  } else if (src.size() > dst_size) {
+    LOG(WARNING) << "Expected " << dst_size << " bytes. The database returned "
+                 << src.size() << ", which is more. Truncating!";
+  }
+
+  std::memcpy(dst, src.data(), dst_size);
+}
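+
+// Value encoding used by the helpers below: scalar values of type T are
+// stored as contiguous raw sizeof(T)-byte blocks, one per dimension of the
+// embedding vector; tstring values are stored as concatenated
+// (STRING_SIZE_TYPE length, raw bytes) records.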
Truncating!"; + } + + std::memcpy(dst, src.data(), dst_size); +} + +template <> +inline void get_value(tstring *dst, const std::string &src_, + const size_t &n) { + const char *src = src_.data(); + const char *const src_end = &src[src_.size()]; + const tstring *const dst_end = &dst[n]; + + for (; dst != dst_end; ++dst) { + const char *const src_size = src; + src += sizeof(STRING_SIZE_TYPE); + if (src > src_end) { + throw std::out_of_range("String value is malformed!"); + } + const auto &size = *reinterpret_cast(src_size); + + const char *const src_data = src; + src += size; + if (src > src_end) { + throw std::out_of_range("String value is malformed!"); + } + dst->assign(src_data, size); + } + + if (src != src_end) { + throw std::runtime_error( + "Database returned more values than the destination tensor could " + "absorb."); + } +} + +template +inline void put_value(ROCKSDB_NAMESPACE::PinnableSlice &dst, const T *src, + const size_t &n) { + dst.data_ = reinterpret_cast(src); + dst.size_ = sizeof(T) * n; +} + +template <> +inline void put_value(ROCKSDB_NAMESPACE::PinnableSlice &dst_, + const tstring *src, const size_t &n) { + std::string &dst = *dst_.GetSelf(); + dst.clear(); + + // Concatenate the strings. + const tstring *const src_end = &src[n]; + for (; src != src_end; ++src) { + if (src->size() > std::numeric_limits::max()) { + throw std::runtime_error("String value is too large."); + } + const auto size = static_cast(src->size()); + dst.append(reinterpret_cast(&size), sizeof(size)); + dst.append(src->data(), size); + } + + dst_.PinSelf(); +} + +template +inline void add_value(ROCKSDB_NAMESPACE::PinnableSlice &dst, const T *src, const size_t &n) { + const T *acc = reinterpret_cast(dst.data()); + const T *const acc_end = &acc[n]; + for (; acc != acc_end; acc++, src++) { + *acc += *src; + } +} + +template <> +inline void add_value(ROCKSDB_NAMESPACE::PinnableSlice &dst, const tstring *src, const size_t &n) { + throw std::runtime_error("String vectors cannot be accumulated!"); +} + +} // namespace _if + +namespace _io { + +template +inline void read(std::istream &src, T &dst) { + if (!src.read(reinterpret_cast(&dst), sizeof(T))) { + throw std::overflow_error("Unexpected end of file!"); + } +} + +template +inline T read(std::istream &src) { + T tmp; + read(src, tmp); + return tmp; +} + +template +inline void write(std::ostream &dst, const T &src) { + if (!dst.write(reinterpret_cast(&src), sizeof(T))) { + throw std::runtime_error("Writing file failed!"); + } +} + +template +inline void read_key(std::istream &src, std::string *dst) { + dst->resize(sizeof(T)); + if (!src.read(&dst->front(), sizeof(T))) { + throw std::overflow_error("Unexpected end of file!"); + } +} + +template <> +inline void read_key(std::istream &src, std::string *dst) { + const auto size = read(src); + dst->resize(size); + if (!src.read(&dst->front(), size)) { + throw std::overflow_error("Unexpected end of file!"); + } +} + +template +inline void write_key(std::ostream &dst, const ROCKSDB_NAMESPACE::Slice &src) { + write(dst, *reinterpret_cast(src.data())); +} + +template <> +inline void write_key(std::ostream &dst, + const ROCKSDB_NAMESPACE::Slice &src) { + if (src.size() > std::numeric_limits::max()) { + throw std::overflow_error("String key is too long for RDB_KEY_SIZE_TYPE."); + } + const auto size = static_cast(src.size()); + write(dst, size); + if (!dst.write(src.data(), size)) { + throw std::runtime_error("Writing file failed!"); + } +} + +inline void read_value(std::istream &src, std::string *dst) { + const auto 
+inline void read_value(std::istream &src, std::string *dst) {
+  const auto size = read<VALUE_SIZE_TYPE>(src);
+  dst->resize(size);
+  if (!src.read(&dst->front(), size)) {
+    throw std::overflow_error("Unexpected end of file!");
+  }
+}
+
+inline void write_value(std::ostream &dst,
+                        const ROCKSDB_NAMESPACE::Slice &src) {
+  const auto size = static_cast<VALUE_SIZE_TYPE>(src.size());
+  write(dst, size);
+  if (!dst.write(src.data(), size)) {
+    throw std::runtime_error("Writing file failed!");
+  }
+}
+
+}  // namespace _io
+
+namespace _it {
+
+template <class T>
+inline void read_key(std::vector<T> &dst,
+                     const ROCKSDB_NAMESPACE::Slice &src) {
+  if (src.size() != sizeof(T)) {
+    std::ostringstream msg;
+    msg << "Key size is out of bounds [ " << src.size() << " != " << sizeof(T)
+        << " ].";
+    throw std::out_of_range(msg.str());
+  }
+  dst.emplace_back(*reinterpret_cast<const T *>(src.data()));
+}
+
+template <>
+inline void read_key<tstring>(std::vector<tstring> &dst,
+                              const ROCKSDB_NAMESPACE::Slice &src) {
+  if (src.size() > std::numeric_limits<KEY_SIZE_TYPE>::max()) {
+    std::ostringstream msg;
+    msg << "Key size is out of bounds [ " << src.size() << " > "
+        << std::numeric_limits<KEY_SIZE_TYPE>::max() << " ].";
+    throw std::out_of_range(msg.str());
+  }
+  dst.emplace_back(src.data(), src.size());
+}
+
+template <class T>
+inline size_t read_value(std::vector<T> &dst,
+                         const ROCKSDB_NAMESPACE::Slice &src_,
+                         const size_t &n_limit) {
+  const size_t n = src_.size() / sizeof(T);
+
+  if (n * sizeof(T) != src_.size()) {
+    std::ostringstream msg;
+    msg << "Vector value is out of bounds [ " << n * sizeof(T)
+        << " != " << src_.size() << " ].";
+    throw std::out_of_range(msg.str());
+  } else if (n < n_limit) {
+    throw std::underflow_error("Database entry violates n_limit.");
+  }
+
+  const T *const src = reinterpret_cast<const T *>(src_.data());
+  dst.insert(dst.end(), src, &src[n_limit]);
+  return n;
+}
+
+template <>
+inline size_t read_value<tstring>(std::vector<tstring> &dst,
+                                  const ROCKSDB_NAMESPACE::Slice &src_,
+                                  const size_t &n_limit) {
+  size_t n = 0;
+
+  const char *src = src_.data();
+  const char *const src_end = &src[src_.size()];
+
+  for (; src < src_end; ++n) {
+    const char *const src_size = src;
+    src += sizeof(STRING_SIZE_TYPE);
+    if (src > src_end) {
+      throw std::out_of_range("String value is malformed!");
+    }
+    const auto &size = *reinterpret_cast<const STRING_SIZE_TYPE *>(src_size);
+
+    const char *const src_data = src;
+    src += size;
+    if (src > src_end) {
+      throw std::out_of_range("String value is malformed!");
+    }
+    if (n < n_limit) {
+      dst.emplace_back(src_data, size);
+    }
+  }
+
+  if (src != src_end) {
+    throw std::out_of_range("String value is malformed!");
+  } else if (n < n_limit) {
+    throw std::underflow_error("Database entry violates n_limit.");
+  }
+  return n;
+}
+
+}  // namespace _it
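+
+// Overview of the helper namespaces defined above: _if converts between
+// tensor memory and RocksDB slices, _io (de-)serializes entries for file
+// export/import, and _it rebuilds key/value buffers from iterator output
+// for tensor export.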
+
+class DBWrapper final {
+ public:
+  DBWrapper(const std::string &path, const bool &read_only)
+      : path_(path), read_only_(read_only), database_(nullptr) {
+    ROCKSDB_NAMESPACE::Options options;
+    options.create_if_missing = !read_only;
+    options.manual_wal_flush = false;
+
+    // Create or connect to the RocksDB database.
+    std::vector<std::string> column_names;
+#if __cplusplus >= 201703L
+    if (!std::filesystem::exists(path)) {
+      column_names.push_back(ROCKSDB_NAMESPACE::kDefaultColumnFamilyName);
+    } else if (std::filesystem::is_directory(path)) {
+      ROCKSDB_OK(ROCKSDB_NAMESPACE::DB::ListColumnFamilies(options, path,
+                                                           &column_names));
+    } else {
+      throw std::runtime_error("Provided database path is invalid.");
+    }
+#else
+    struct stat db_path_stat {};
+    if (stat(path.c_str(), &db_path_stat) == 0) {
+      if (S_ISDIR(db_path_stat.st_mode)) {
+        ROCKSDB_OK(ROCKSDB_NAMESPACE::DB::ListColumnFamilies(options, path,
+                                                             &column_names));
+      } else {
+        throw std::runtime_error("Provided database path is invalid.");
+      }
+    } else {
+      column_names.push_back(ROCKSDB_NAMESPACE::kDefaultColumnFamilyName);
+    }
+#endif
+
+    ROCKSDB_NAMESPACE::ColumnFamilyOptions column_options;
+    std::vector<ROCKSDB_NAMESPACE::ColumnFamilyDescriptor> column_descriptors;
+    for (const auto &column_name : column_names) {
+      column_descriptors.emplace_back(column_name, column_options);
+    }
+
+    ROCKSDB_NAMESPACE::DB *db;
+    std::vector<ROCKSDB_NAMESPACE::ColumnFamilyHandle *> column_handles;
+    if (read_only) {
+      ROCKSDB_OK(ROCKSDB_NAMESPACE::DB::OpenForReadOnly(
+          options, path, column_descriptors, &column_handles, &db));
+    } else {
+      ROCKSDB_OK(ROCKSDB_NAMESPACE::DB::Open(options, path, column_descriptors,
+                                             &column_handles, &db));
+    }
+    database_.reset(db);
+
+    // Maintain a map of the available column handles for quick access.
+    for (const auto &column_handle : column_handles) {
+      column_handles_[column_handle->GetName()] = column_handle;
+    }
+
+    LOG(INFO) << "Connected to database \'" << path_ << "\'.";
+  }
+
+  ~DBWrapper() {
+    if (!read_only_) {
+      database_->FlushWAL(true);
+    }
+    for (const auto &column_handle : column_handles_) {
+      database_->DestroyColumnFamilyHandle(column_handle.second);
+    }
+    column_handles_.clear();
+    database_.reset();
+    LOG(INFO) << "Disconnected from database \'" << path_ << "\'.";
+  }
+
+  inline ROCKSDB_NAMESPACE::DB *database() { return database_.get(); }
+
+  inline const std::string &path() const { return path_; }
+
+  inline bool read_only() const { return read_only_; }
+
+  void DeleteColumn(const std::string &column_name) {
+    mutex_lock guard(lock_);
+
+    // Try to locate the column handle; if it does not exist, we are done.
+    const auto &item = column_handles_.find(column_name);
+    if (item == column_handles_.end()) {
+      return;
+    }
+
+    // Since this requires a modification, make sure we are not in read-only
+    // mode.
+    if (read_only_) {
+      throw std::runtime_error("Cannot delete a column in read-only mode.");
+    }
+
+    // Perform the actual removal.
+    ROCKSDB_NAMESPACE::ColumnFamilyHandle *column_handle = item->second;
+    ROCKSDB_OK(database_->DropColumnFamily(column_handle));
+    ROCKSDB_OK(database_->DestroyColumnFamilyHandle(column_handle));
+    column_handles_.erase(column_name);
+  }
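+
+  // Usage sketch: callers name the result type explicitly, e.g.
+  //
+  //   db->WithColumn<size_t>(
+  //       "some_embedding",
+  //       [&](ROCKSDB_NAMESPACE::DB *db,
+  //           ROCKSDB_NAMESPACE::ColumnFamilyHandle *ch) -> size_t { ... });
+  //
+  // The column family is created on demand, except in read-only mode, where
+  // the callback instead receives a null handle it must check for.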
+
+  template <class T>
+  T WithColumn(
+      const std::string &column_name,
+      std::function<T(ROCKSDB_NAMESPACE::DB *,
+                      ROCKSDB_NAMESPACE::ColumnFamilyHandle *)> fn) {
+    mutex_lock guard(lock_);
+
+    ROCKSDB_NAMESPACE::ColumnFamilyHandle *column_handle;
+
+    // Try to locate the column handle.
+    const auto &item = column_handles_.find(column_name);
+    if (item != column_handles_.end()) {
+      column_handle = item->second;
+    }
+    // Do not create an actual column handle in read-only mode.
+    else if (read_only_) {
+      column_handle = nullptr;
+    }
+    // Create a new column handle.
+    else {
+      ROCKSDB_NAMESPACE::ColumnFamilyOptions column_options;
+      ROCKSDB_OK(database_->CreateColumnFamily(column_options, column_name,
+                                               &column_handle));
+      column_handles_[column_name] = column_handle;
+    }
+
+    return fn(database_.get(), column_handle);
+  }
+
+  // inline ROCKSDB_NAMESPACE::DB *operator->() { return database_.get(); }
+
+ private:
+  const std::string path_;
+  const bool read_only_;
+  std::unique_ptr<ROCKSDB_NAMESPACE::DB> database_;
+
+  mutex lock_;
+  std::unordered_map<std::string, ROCKSDB_NAMESPACE::ColumnFamilyHandle *>
+      column_handles_;
+};
+
+class DBWrapperRegistry final {
+ public:
+  static DBWrapperRegistry &instance() {
+    static DBWrapperRegistry instance;
+    return instance;
+  }
+
+ private:
+  DBWrapperRegistry() = default;
+
+  ~DBWrapperRegistry() = default;
+
+ public:
+  std::shared_ptr<DBWrapper> connect(const std::string &databasePath,
+                                     const bool &readOnly) {
+    mutex_lock guard(lock);
+
+    // Try to find the database, or open it if it is not open yet.
+    std::shared_ptr<DBWrapper> db;
+    auto pos = wrappers.find(databasePath);
+    if (pos != wrappers.end()) {
+      db = pos->second.lock();
+    } else {
+      db.reset(new DBWrapper(databasePath, readOnly), deleter);
+      wrappers[databasePath] = db;
+    }
+
+    // Fail if the requested access level is more permissive than the one the
+    // wrapper was opened with.
+    if (readOnly < db->read_only()) {
+      throw std::runtime_error(
+          "Cannot simultaneously open database in read + write mode.");
+    }
+
+    return db;
+  }
+
+ private:
+  static void deleter(DBWrapper *wrapper) {
+    static std::default_delete<DBWrapper> default_deleter;
+
+    DBWrapperRegistry &registry = instance();
+    const std::string path = wrapper->path();
+
+    // Make sure we are alone.
+    mutex_lock guard(registry.lock);
+
+    // Destroy the wrapper.
+    default_deleter(wrapper);
+    // LOG(INFO) << "Database wrapper " << path << " has been deleted.";
+
+    // Locate the corresponding weak_ptr and evict it.
+    auto pos = registry.wrappers.find(path);
+    if (pos == registry.wrappers.end()) {
+      LOG(ERROR) << "Unknown database wrapper. How?";
+    } else if (pos->second.expired()) {
+      registry.wrappers.erase(pos);
+      // LOG(INFO) << "Database wrapper " << path << " evicted.";
+    } else {
+      LOG(ERROR) << "Registry is in an inconsistent state. This is very "
+                    "bad...";
+    }
+  }
+
+ private:
+  mutex lock;
+  std::unordered_map<std::string, std::weak_ptr<DBWrapper>> wrappers;
+};
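+
+// Each lookup table is backed by exactly one column family, named after the
+// embedding, inside a shared database. The registry above guarantees that a
+// process holds at most one DBWrapper per database path.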
This is very bad..."; + } + } + + private: + mutex lock; + std::unordered_map> wrappers; +}; + +template +class RocksDBTableOfTensors final : public PersistentStorageLookupInterface { + public: + /* --- BASE INTERFACE ----------------------------------------------------- */ + RocksDBTableOfTensors(OpKernelContext *ctx, OpKernel *kernel) + : read_only_(false), estimate_size_(false), dirty_count_(0) { + OP_REQUIRES_OK(ctx, + GetNodeAttr(kernel->def(), "value_shape", &value_shape_)); + OP_REQUIRES( + ctx, TensorShapeUtils::IsVector(value_shape_), + errors::InvalidArgument("Default value must be a vector, got shape ", + value_shape_.DebugString())); + OP_REQUIRES_OK( + ctx, GetNodeAttr(kernel->def(), "database_path", &database_path_)); + OP_REQUIRES_OK( + ctx, GetNodeAttr(kernel->def(), "embedding_name", &embedding_name_)); + OP_REQUIRES_OK(ctx, GetNodeAttr(kernel->def(), "read_only", &read_only_)); + OP_REQUIRES_OK( + ctx, GetNodeAttr(kernel->def(), "estimate_size", &estimate_size_)); + flush_interval_ = 1; + OP_REQUIRES_OK( + ctx, GetNodeAttr(kernel->def(), "export_path", &default_export_path_)); + + db_ = DBWrapperRegistry::instance().connect(database_path_, read_only_); + LOG(INFO) << "Acquired reference to database wrapper " << db_->path() + << " [ #refs = " << db_.use_count() << " ]."; + } + + ~RocksDBTableOfTensors() override { + LOG(INFO) << "Dropping reference to database wrapper " << db_->path() + << " [ #refs = " << db_.use_count() << " ]."; + } + + DataType key_dtype() const override { return DataTypeToEnum::v(); } + TensorShape key_shape() const override { return TensorShape{}; } + + DataType value_dtype() const override { return DataTypeToEnum::v(); } + TensorShape value_shape() const override { return value_shape_; } + + int64_t MemoryUsed() const override { + size_t mem_size = 0; + + mem_size += sizeof(RocksDBTableOfTensors); + mem_size += sizeof(ROCKSDB_NAMESPACE::DB); + + db_->WithColumn(embedding_name_, [&](ROCKSDB_NAMESPACE::DB* const db, ROCKSDB_NAMESPACE::ColumnFamilyHandle *const column_handle) { + uint64_t tmp; + + if (db->GetIntProperty(column_handle, ROCKSDB_NAMESPACE::DB::Properties::kBlockCacheUsage, &tmp)) { + mem_size += tmp; + } + + + if (db->GetIntProperty(column_handle, ROCKSDB_NAMESPACE::DB::Properties::kEstimateTableReadersMem, &tmp)) { + mem_size += tmp; + } + + if (db->GetIntProperty(column_handle, ROCKSDB_NAMESPACE::DB::Properties::kCurSizeAllMemTables, &tmp)) { + mem_size += tmp; + } + + if (db->GetIntProperty(column_handle, ROCKSDB_NAMESPACE::DB::Properties::kBlockCachePinnedUsage, &tmp)) { + mem_size += tmp; + } + }); + + return static_cast(mem_size); + } + + size_t size() const override { + return db_->WithColumn(embedding_name_, [&](ROCKSDB_NAMESPACE::DB* const db, ROCKSDB_NAMESPACE::ColumnFamilyHandle *const column_handle) + -> size_t { + // Empty database. + if (!column_handle) { + return 0; + } + + // If allowed, try to just estimate of the number of keys. + if (estimate_size_) { + uint64_t num_keys; + if (db->GetIntProperty( + column_handle, + ROCKSDB_NAMESPACE::DB::Properties::kEstimateNumKeys, + &num_keys)) { + return num_keys; + } + } + + // Alternative method, walk the entire database column and count the keys. 
+
+  /* --- LOOKUP ------------------------------------------------------------- */
+  Status Accum(OpKernelContext *ctx, const Tensor &keys,
+               const Tensor &values_or_delta, const Tensor &exists) {
+    if (keys.dtype() != key_dtype() ||
+        values_or_delta.dtype() != value_dtype()) {
+      return errors::InvalidArgument("The tensor dtypes are incompatible.");
+    }
+    if (keys.dims() > std::min(values_or_delta.dims(), exists.dims())) {
+      return errors::InvalidArgument("The tensor sizes are incompatible.");
+    }
+    for (int i = 0; i < keys.dims(); ++i) {
+      if (keys.dim_size(i) != values_or_delta.dim_size(i) ||
+          keys.dim_size(i) != exists.dim_size(i)) {
+        return errors::InvalidArgument("The tensor sizes are incompatible.");
+      }
+    }
+    if (keys.NumElements() == 0) {
+      return Status::OK();
+    }
+
+    const size_t num_keys = keys.NumElements();
+    const size_t num_values = values_or_delta.NumElements();
+    const size_t values_per_key = num_values / std::max(num_keys, 1UL);
+
+    const K *const k = static_cast<const K *>(keys.data());
+    const V *const v = static_cast<const V *>(values_or_delta.data());
+    // The exists tensor is an output of this op; shed constness to write it.
+    auto exists_flat = const_cast<Tensor &>(exists).flat<bool>();
+
+    return db_->WithColumn<Status>(
+        embedding_name_,
+        [&](ROCKSDB_NAMESPACE::DB *const db,
+            ROCKSDB_NAMESPACE::ColumnFamilyHandle *const column_handle)
+            -> Status {
+          if (!column_handle) {
+            // No column family means the table is empty; nothing to do.
+          } else if (num_keys < BATCH_SIZE_MIN) {
+            ROCKSDB_NAMESPACE::Slice k_slice;
+
+            ROCKSDB_NAMESPACE::PinnableSlice v_slice;
+            for (size_t i = 0, offset = 0; i < num_keys;
+                 ++i, offset += values_per_key) {
+              _if::put_key(k_slice, &k[i]);
+
+              // Release the previous pin before reusing the slice.
+              v_slice.Reset();
+              const auto &status =
+                  db->Get(read_options_, column_handle, k_slice, &v_slice);
+
+              if (status.ok()) {
+                _if::add_value(v_slice, &v[offset], values_per_key);
+                ROCKSDB_OK(
+                    db->Put(write_options_, column_handle, k_slice, v_slice));
+                exists_flat(i) = true;
+              } else if (status.IsNotFound()) {
+                exists_flat(i) = false;
+              } else {
+                throw std::runtime_error(status.getState());
+              }
+            }
+          } else {
+            // There is no point in refilling this vector every time as long
+            // as it is big enough.
+            if (!column_handle_cache_.empty() &&
+                column_handle_cache_.front() != column_handle) {
+              std::fill(column_handle_cache_.begin(),
+                        column_handle_cache_.end(), column_handle);
+            }
+            if (column_handle_cache_.size() < num_keys) {
+              column_handle_cache_.insert(
+                  column_handle_cache_.end(),
+                  num_keys - column_handle_cache_.size(), column_handle);
+            }
+
+            // Query all keys using a single MultiGet.
+            std::vector<ROCKSDB_NAMESPACE::Slice> k_slices{num_keys};
+            for (size_t i = 0; i < num_keys; ++i) {
+              _if::put_key(k_slices[i], &k[i]);
+            }
+            std::vector<std::string> v_slices;
+
+            const auto &s = db->MultiGet(read_options_, column_handle_cache_,
+                                         k_slices, &v_slices);
+            if (s.size() != num_keys) {
+              std::ostringstream msg;
+              msg << "Requested " << num_keys << " keys, but only got "
+                  << s.size() << " responses.";
+              throw std::runtime_error(msg.str());
+            }
+
+            // Process results.
+            for (size_t i = 0, offset = 0; i < num_keys;
+                 ++i, offset += values_per_key) {
+              const auto &status = s[i];
+              const auto &v_slice = v_slices[i];
+
+              if (status.ok()) {
+                _if::add_value(v_slice, &v[offset], values_per_key);
+                ROCKSDB_OK(db->Put(write_options_, column_handle, k_slices[i],
+                                   v_slice));
+                exists_flat(i) = true;
+              } else if (status.IsNotFound()) {
+                exists_flat(i) = false;
+              } else {
+                throw std::runtime_error(status.getState());
+              }
+            }
+          }
+
+          return Status::OK();
+        });
+  }
+
+  Status Clear(OpKernelContext *ctx) override {
+    if (read_only_) {
+      return errors::PermissionDenied("Cannot clear in read_only mode.");
+    }
+    db_->DeleteColumn(embedding_name_);
+    return Status::OK();
+  }
+
+  Status Find(OpKernelContext *ctx, const Tensor &keys, Tensor *values,
+              const Tensor &default_value) override {
+    if (keys.dtype() != key_dtype() || values->dtype() != value_dtype() ||
+        default_value.dtype() != value_dtype()) {
+      return errors::InvalidArgument("The tensor dtypes are incompatible.");
+    }
+    if (keys.dims() > values->dims()) {
+      return errors::InvalidArgument("The tensor sizes are incompatible.");
+    }
+    for (int i = 0; i < keys.dims(); ++i) {
+      if (keys.dim_size(i) != values->dim_size(i)) {
+        return errors::InvalidArgument("The tensor sizes are incompatible.");
+      }
+    }
+    if (keys.NumElements() == 0) {
+      return Status::OK();
+    }
+
+    const size_t num_keys = keys.NumElements();
+    const size_t num_values = values->NumElements();
+    const size_t values_per_key = num_values / std::max(num_keys, 1UL);
+    const size_t default_size = default_value.NumElements();
+    if (default_size % values_per_key != 0) {
+      std::ostringstream msg;
+      msg << "The shapes of the 'values' and 'default_value' tensors are "
+             "incompatible"
+          << " (" << default_size << " % " << values_per_key << " != 0).";
+      return errors::InvalidArgument(msg.str());
+    }
+
+    const K *k = static_cast<const K *>(keys.data());
+    V *const v = static_cast<V *>(values->data());
+    const V *const d = static_cast<const V *>(default_value.data());
+
+    return db_->WithColumn<Status>(
+        embedding_name_,
+        [&](ROCKSDB_NAMESPACE::DB *const db,
+            ROCKSDB_NAMESPACE::ColumnFamilyHandle *const column_handle)
+            -> Status {
+          if (!column_handle) {
+            const K *const k_end = &k[num_keys];
+            for (size_t offset = 0; k != k_end;
+                 ++k, offset += values_per_key) {
+              std::copy_n(&d[offset % default_size], values_per_key,
+                          &v[offset]);
+            }
+          } else if (num_keys < BATCH_SIZE_MIN) {
+            ROCKSDB_NAMESPACE::Slice k_slice;
+
+            std::string v_slice;
+            for (size_t i = 0, offset = 0; i < num_keys;
+                 ++i, offset += values_per_key) {
+              _if::put_key(k_slice, &k[i]);
+
+              v_slice.clear();
+              const auto &status =
+                  db->Get(read_options_, column_handle, k_slice, &v_slice);
+
+              if (status.ok()) {
+                _if::get_value(&v[offset], v_slice, values_per_key);
+              } else if (status.IsNotFound()) {
+                std::copy_n(&d[offset % default_size], values_per_key,
+                            &v[offset]);
+              } else {
+                throw std::runtime_error(status.getState());
+              }
+            }
+          } else {
+            // There is no point in refilling this vector every time as long
+            // as it is big enough.
+            if (!column_handle_cache_.empty() &&
+                column_handle_cache_.front() != column_handle) {
+              std::fill(column_handle_cache_.begin(),
+                        column_handle_cache_.end(), column_handle);
+            }
+            if (column_handle_cache_.size() < num_keys) {
+              column_handle_cache_.insert(
+                  column_handle_cache_.end(),
+                  num_keys - column_handle_cache_.size(), column_handle);
+            }
+
+            // Query all keys using a single MultiGet.
+            std::vector<ROCKSDB_NAMESPACE::Slice> k_slices{num_keys};
+            for (size_t i = 0; i < num_keys; ++i) {
+              _if::put_key(k_slices[i], &k[i]);
+            }
+            std::vector<std::string> v_slices;
+
+            const auto &s = db->MultiGet(read_options_, column_handle_cache_,
+                                         k_slices, &v_slices);
+            if (s.size() != num_keys) {
+              std::ostringstream msg;
+              msg << "Requested " << num_keys << " keys, but only got "
+                  << s.size() << " responses.";
+              throw std::runtime_error(msg.str());
+            }
+
+            // Process results.
+            for (size_t i = 0, offset = 0; i < num_keys;
+                 ++i, offset += values_per_key) {
+              const auto &status = s[i];
+              const auto &v_slice = v_slices[i];
+
+              if (status.ok()) {
+                _if::get_value(&v[offset], v_slice, values_per_key);
+              } else if (status.IsNotFound()) {
+                std::copy_n(&d[offset % default_size], values_per_key,
+                            &v[offset]);
+              } else {
+                throw std::runtime_error(status.getState());
+              }
+            }
+          }
+
+          return Status::OK();
+        });
+  }
+
+  Status FindWithExists(OpKernelContext *ctx, const Tensor &keys,
+                        Tensor *values, const Tensor &default_value,
+                        Tensor &exists) {
+    if (keys.dtype() != key_dtype() || values->dtype() != value_dtype() ||
+        default_value.dtype() != value_dtype()) {
+      return errors::InvalidArgument("The tensor dtypes are incompatible.");
+    }
+    if (keys.dims() > std::min(values->dims(), exists.dims())) {
+      return errors::InvalidArgument("The tensor sizes are incompatible.");
+    }
+    for (int i = 0; i < keys.dims(); ++i) {
+      if (keys.dim_size(i) != values->dim_size(i) ||
+          keys.dim_size(i) != exists.dim_size(i)) {
+        return errors::InvalidArgument("The tensor sizes are incompatible.");
+      }
+    }
+    if (keys.NumElements() == 0) {
+      return Status::OK();
+    }
+
+    const size_t num_keys = keys.NumElements();
+    const size_t num_values = values->NumElements();
+    const size_t values_per_key = num_values / std::max(num_keys, 1UL);
+    const size_t default_size = default_value.NumElements();
+    if (default_size % values_per_key != 0) {
+      std::ostringstream msg;
+      msg << "The shapes of the 'values' and 'default_value' tensors are "
+             "incompatible"
+          << " (" << default_size << " % " << values_per_key << " != 0).";
+      return errors::InvalidArgument(msg.str());
+    }
+
+    const K *k = static_cast<const K *>(keys.data());
+    V *const v = static_cast<V *>(values->data());
+    const V *const d = static_cast<const V *>(default_value.data());
+    auto exists_flat = exists.flat<bool>();
+
+    return db_->WithColumn<Status>(
+        embedding_name_,
+        [&](ROCKSDB_NAMESPACE::DB *const db,
+            ROCKSDB_NAMESPACE::ColumnFamilyHandle *const column_handle)
+            -> Status {
+          if (!column_handle) {
+            const K *const k_end = &k[num_keys];
+            for (size_t offset = 0; k != k_end;
+                 ++k, offset += values_per_key) {
+              std::copy_n(&d[offset % default_size], values_per_key,
+                          &v[offset]);
+            }
+          } else if (num_keys < BATCH_SIZE_MIN) {
+            ROCKSDB_NAMESPACE::Slice k_slice;
+
+            std::string v_slice;
+            for (size_t i = 0, offset = 0; i < num_keys;
+                 ++i, offset += values_per_key) {
+              _if::put_key(k_slice, &k[i]);
+
+              v_slice.clear();
+              const auto &status =
+                  db->Get(read_options_, column_handle, k_slice, &v_slice);
+
+              if (status.ok()) {
+                _if::get_value(&v[offset], v_slice, values_per_key);
+                exists_flat(i) = true;
+              } else if (status.IsNotFound()) {
+                std::copy_n(&d[offset % default_size], values_per_key,
+                            &v[offset]);
+                exists_flat(i) = false;
+              } else {
+                throw std::runtime_error(status.getState());
+              }
+            }
+          } else {
+            // There is no point in refilling this vector every time as long
+            // as it is big enough.
+            if (!column_handle_cache_.empty() &&
+                column_handle_cache_.front() != column_handle) {
+              std::fill(column_handle_cache_.begin(),
+                        column_handle_cache_.end(), column_handle);
+            }
+            if (column_handle_cache_.size() < num_keys) {
+              column_handle_cache_.insert(
+                  column_handle_cache_.end(),
+                  num_keys - column_handle_cache_.size(), column_handle);
+            }
+
+            // Query all keys using a single MultiGet.
+            std::vector<ROCKSDB_NAMESPACE::Slice> k_slices{num_keys};
+            for (size_t i = 0; i < num_keys; ++i) {
+              _if::put_key(k_slices[i], &k[i]);
+            }
+            std::vector<std::string> v_slices;
+
+            const auto &s = db->MultiGet(read_options_, column_handle_cache_,
+                                         k_slices, &v_slices);
+            if (s.size() != num_keys) {
+              std::ostringstream msg;
+              msg << "Requested " << num_keys << " keys, but only got "
+                  << s.size() << " responses.";
+              throw std::runtime_error(msg.str());
+            }
+
+            // Process results.
+            for (size_t i = 0, offset = 0; i < num_keys;
+                 ++i, offset += values_per_key) {
+              const auto &status = s[i];
+              const auto &v_slice = v_slices[i];
+
+              if (status.ok()) {
+                _if::get_value(&v[offset], v_slice, values_per_key);
+                exists_flat(i) = true;
+              } else if (status.IsNotFound()) {
+                std::copy_n(&d[offset % default_size], values_per_key,
+                            &v[offset]);
+                exists_flat(i) = false;
+              } else {
+                throw std::runtime_error(status.getState());
+              }
+            }
+          }
+
+          return Status::OK();
+        });
+  }
+
+  Status Insert(OpKernelContext *ctx, const Tensor &keys,
+                const Tensor &values) override {
+    if (keys.dtype() != key_dtype() || values.dtype() != value_dtype()) {
+      return errors::InvalidArgument("The tensor dtypes are incompatible!");
+    }
+    if (keys.dims() <= values.dims()) {
+      for (int i = 0; i < keys.dims(); ++i) {
+        if (keys.dim_size(i) != values.dim_size(i)) {
+          return errors::InvalidArgument("The tensor sizes are incompatible!");
+        }
+      }
+    } else {
+      return errors::InvalidArgument("The tensor sizes are incompatible!");
+    }
+
+    const size_t num_keys = keys.NumElements();
+    const size_t num_values = values.NumElements();
+    const size_t values_per_key = num_values / std::max(num_keys, 1UL);
+    if (values_per_key != static_cast<size_t>(value_shape_.num_elements())) {
+      LOG(WARNING)
+          << "The number of values provided does not match the signature ("
+          << values_per_key << " != " << value_shape_.num_elements() << ").";
+    }
+
+    const K *k = static_cast<const K *>(keys.data());
+    const V *v = static_cast<const V *>(values.data());
+
+    return db_->WithColumn<Status>(
+        embedding_name_,
+        [&](ROCKSDB_NAMESPACE::DB *const db,
+            ROCKSDB_NAMESPACE::ColumnFamilyHandle *const column_handle)
+            -> Status {
+          if (read_only_ || !column_handle) {
+            return errors::PermissionDenied("Cannot insert in read_only mode.");
+          }
+
+          const K *const k_end = &k[num_keys];
+          ROCKSDB_NAMESPACE::Slice k_slice;
+          ROCKSDB_NAMESPACE::PinnableSlice v_slice;
+
+          if (num_keys < BATCH_SIZE_MIN) {
+            for (; k != k_end; ++k, v += values_per_key) {
+              _if::put_key(k_slice, k);
+              _if::put_value(v_slice, v, values_per_key);
+              ROCKSDB_OK(
+                  db->Put(write_options_, column_handle, k_slice, v_slice));
+            }
+          } else {
+            ROCKSDB_NAMESPACE::WriteBatch batch;
+            for (; k != k_end; ++k, v += values_per_key) {
+              _if::put_key(k_slice, k);
+              _if::put_value(v_slice, v, values_per_key);
+              ROCKSDB_OK(batch.Put(column_handle, k_slice, v_slice));
+            }
+            ROCKSDB_OK(db->Write(write_options_, &batch));
+          }
+
+          // Handle interval flushing.
+          dirty_count_ += 1;
+          if (dirty_count_ % flush_interval_ == 0) {
+            ROCKSDB_OK(db->FlushWAL(true));
+          }
+
+          return Status::OK();
+        });
+  }
+
+  Status Remove(OpKernelContext *ctx, const Tensor &keys) override {
+    if (keys.dtype() != key_dtype()) {
+      return errors::InvalidArgument("Tensor dtypes are incompatible!");
+    }
+
+    const size_t num_keys = keys.dim_size(0);
+    const K *k = static_cast<const K *>(keys.data());
+
+    return db_->WithColumn<Status>(
+        embedding_name_,
+        [&](ROCKSDB_NAMESPACE::DB *const db,
+            ROCKSDB_NAMESPACE::ColumnFamilyHandle *const column_handle)
+            -> Status {
+          if (read_only_ || !column_handle) {
+            return errors::PermissionDenied("Cannot remove in read_only mode.");
+          }
+
+          const K *const k_end = &k[num_keys];
+          ROCKSDB_NAMESPACE::Slice k_slice;
+
+          if (num_keys < BATCH_SIZE_MIN) {
+            for (; k != k_end; ++k) {
+              _if::put_key(k_slice, k);
+              ROCKSDB_OK(db->Delete(write_options_, column_handle, k_slice));
+            }
+          } else {
+            ROCKSDB_NAMESPACE::WriteBatch batch;
+            for (; k != k_end; ++k) {
+              _if::put_key(k_slice, k);
+              ROCKSDB_OK(batch.Delete(column_handle, k_slice));
+            }
+            ROCKSDB_OK(db->Write(write_options_, &batch));
+          }
+
+          // Handle interval flushing.
+          dirty_count_ += 1;
+          if (dirty_count_ % flush_interval_ == 0) {
+            ROCKSDB_OK(db->FlushWAL(true));
+          }
+
+          return Status::OK();
+        });
+  }
+
+  /* --- IMPORT / EXPORT ---------------------------------------------------- */
+  Status ExportValues(OpKernelContext *ctx) override {
+    if (default_export_path_.empty()) {
+      return ExportValuesToTensor(ctx);
+    } else {
+      return ExportValuesToFile(ctx, default_export_path_);
+    }
+  }
+
+  Status ImportValues(OpKernelContext *ctx, const Tensor &keys,
+                      const Tensor &values) override {
+    if (default_export_path_.empty()) {
+      return ImportValuesFromTensor(ctx, keys, values);
+    } else {
+      return ImportValuesFromFile(ctx, default_export_path_);
+    }
+  }
+
+  Status ExportValuesToFile(OpKernelContext *ctx, const std::string &path) {
+    const auto &status = db_->WithColumn<Status>(
+        embedding_name_,
+        [&](ROCKSDB_NAMESPACE::DB *const db,
+            ROCKSDB_NAMESPACE::ColumnFamilyHandle *const column_handle)
+            -> Status {
+          std::ofstream file(path + "/" + embedding_name_ + ".rock",
+                             std::ofstream::binary);
+          if (!file) {
+            return errors::Unknown("Could not open dump file.");
+          }
+
+          // Create the file header.
+          _io::write(file, FILE_MAGIC);
+          _io::write(file, FILE_VERSION);
+          _io::write(file, key_dtype());
+          _io::write(file, value_dtype());
+
+          // Iterate through entries one-by-one and append them to the file.
+          if (column_handle) {
+            std::unique_ptr<ROCKSDB_NAMESPACE::Iterator> iter(
+                db->NewIterator(read_options_, column_handle));
+            iter->SeekToFirst();
+
+            for (; iter->Valid(); iter->Next()) {
+              _io::write_key<K>(file, iter->key());
+              _io::write_value(file, iter->value());
+            }
+          }
+
+          return Status::OK();
+        });
+    if (!status.ok()) {
+      return status;
+    }
+
+    // Create dummy tensors.
+    Tensor *k_tensor;
+    TF_RETURN_IF_ERROR(
+        ctx->allocate_output("keys", TensorShape({0}), &k_tensor));
+
+    Tensor *v_tensor;
+    TF_RETURN_IF_ERROR(ctx->allocate_output(
+        "values", TensorShape({0, value_shape_.num_elements()}), &v_tensor));
+
+    return status;
+  }
+
+  Status ImportValuesFromFile(OpKernelContext *ctx, const std::string &path) {
+    // Make sure the column family is clean.
+    const auto &clear_status = Clear(ctx);
+    if (!clear_status.ok()) {
+      return clear_status;
+    }
+
+    return db_->WithColumn<Status>(
+        embedding_name_,
+        [&](ROCKSDB_NAMESPACE::DB *const db,
+            ROCKSDB_NAMESPACE::ColumnFamilyHandle *const column_handle)
+            -> Status {
+          if (read_only_ || !column_handle) {
+            return errors::PermissionDenied("Cannot import in read_only mode.");
+          }
+
+          std::ifstream file(path + "/" + embedding_name_ + ".rock",
+                             std::ifstream::binary);
+          if (!file) {
+            return errors::NotFound("Accessing file system failed.");
+          }
+
+          // Parse the header.
+          const auto magic = _io::read<uint32_t>(file);
+          if (magic != FILE_MAGIC) {
+            return errors::Unknown("Not a RocksDB export file.");
+          }
+          const auto version = _io::read<uint32_t>(file);
+          if (version != FILE_VERSION) {
+            return errors::Unimplemented("File version ", version,
+                                         " is not supported");
+          }
+          const auto k_dtype = _io::read<DataType>(file);
+          const auto v_dtype = _io::read<DataType>(file);
+          if (k_dtype != key_dtype() || v_dtype != value_dtype()) {
+            return errors::Internal("The DataTypes of the file [k=", k_dtype,
+                                    ", v=", v_dtype, "] ",
+                                    "do not match the module DataTypes [k=",
+                                    key_dtype(), ", v=", value_dtype(), "].");
+          }
+
+          // Read the payload and subsequently populate the column family.
+          ROCKSDB_NAMESPACE::WriteBatch batch;
+
+          ROCKSDB_NAMESPACE::PinnableSlice k_slice;
+          ROCKSDB_NAMESPACE::PinnableSlice v_slice;
+
+          while (file.peek() != EOF) {
+            _io::read_key<K>(file, k_slice.GetSelf());
+            k_slice.PinSelf();
+            _io::read_value(file, v_slice.GetSelf());
+            v_slice.PinSelf();
+
+            ROCKSDB_OK(batch.Put(column_handle, k_slice, v_slice));
+
+            // If the batch reached its target size, write it to the database.
+            if (batch.Count() >= BATCH_SIZE_MAX) {
+              ROCKSDB_OK(db->Write(write_options_, &batch));
+              batch.Clear();
+            }
+          }
+
+          // Write remaining entries, if any.
+          if (batch.Count()) {
+            ROCKSDB_OK(db->Write(write_options_, &batch));
+          }
+
+          // Handle interval flushing.
+          dirty_count_ += 1;
+          if (dirty_count_ % flush_interval_ == 0) {
+            ROCKSDB_OK(db->FlushWAL(true));
+          }
+
+          return Status::OK();
+        });
+  }
+
+  Status ExportValuesToTensor(OpKernelContext *ctx) {
+    // Fetch the data from the database.
+    std::vector<K> k_buffer;
+    std::vector<V> v_buffer;
+    const size_t value_size = value_shape_.num_elements();
+    size_t value_count = std::numeric_limits<size_t>::max();
+
+    const auto &status = db_->WithColumn<Status>(
+        embedding_name_,
+        [&](ROCKSDB_NAMESPACE::DB *const db,
+            ROCKSDB_NAMESPACE::ColumnFamilyHandle *const column_handle)
+            -> Status {
+          if (column_handle) {
+            std::unique_ptr<ROCKSDB_NAMESPACE::Iterator> iter(
+                db->NewIterator(read_options_, column_handle));
+            iter->SeekToFirst();
+
+            for (; iter->Valid(); iter->Next()) {
+              const auto &k_slice = iter->key();
+              _it::read_key(k_buffer, k_slice);
+
+              const auto v_slice = iter->value();
+              const size_t v_count =
+                  _it::read_value(v_buffer, v_slice, value_size);
+
+              // Make sure all rows of the values tensor have the same size.
+              if (value_count == std::numeric_limits<size_t>::max()) {
+                value_count = v_count;
+              } else if (v_count != value_count) {
+                return errors::Internal("The returned tensor sizes differ.");
+              }
+            }
+          }
+
+          return Status::OK();
+        });
+    if (!status.ok()) {
+      return status;
+    }
+
+    if (value_count != value_size) {
+      LOG(WARNING) << "The retrieved values differ from the signature size ("
+                   << value_count << " != " << value_size << ").";
+    }
+    const auto num_keys = static_cast<int64>(k_buffer.size());
+
+    // Populate the keys tensor.
+    Tensor *k_tensor;
+    TF_RETURN_IF_ERROR(
+        ctx->allocate_output("keys", TensorShape({num_keys}), &k_tensor));
+    K *const k = reinterpret_cast<K *>(k_tensor->data());
+    std::copy(k_buffer.begin(), k_buffer.end(), k);
+
+    // Populate the values tensor.
+    Tensor *v_tensor;
+    TF_RETURN_IF_ERROR(ctx->allocate_output(
+        "values", TensorShape({num_keys, static_cast<int64>(value_size)}),
+        &v_tensor));
+    V *const v = reinterpret_cast<V *>(v_tensor->data());
+    std::copy(v_buffer.begin(), v_buffer.end(), v);
+
+    return status;
+  }
+
+  Status ImportValuesFromTensor(OpKernelContext *ctx, const Tensor &keys,
+                                const Tensor &values) {
+    // Make sure the column family is clean.
+    const auto &clear_status = Clear(ctx);
+    if (!clear_status.ok()) {
+      return clear_status;
+    }
+
+    // Just call the regular insertion function.
+    return Insert(ctx, keys, values);
+  }
+
+ protected:
+  TensorShape value_shape_;
+  std::string database_path_;
+  std::string embedding_name_;
+  bool read_only_;
+  bool estimate_size_;
+  size_t flush_interval_;
+  std::string default_export_path_;
+
+  std::shared_ptr<DBWrapper> db_;
+  ROCKSDB_NAMESPACE::ReadOptions read_options_;
+  ROCKSDB_NAMESPACE::WriteOptions write_options_;
+  size_t dirty_count_;
+
+  std::vector<ROCKSDB_NAMESPACE::ColumnFamilyHandle *> column_handle_cache_;
+};
+
+#undef ROCKSDB_OK
+
+/* --- KERNEL REGISTRATION -------------------------------------------------- */
+#define ROCKSDB_REGISTER_KERNEL_BUILDER(key_dtype, value_dtype)        \
+  REGISTER_KERNEL_BUILDER(                                             \
+      Name(PREFIX_OP_NAME(RocksdbTableOfTensors))                      \
+          .Device(DEVICE_CPU)                                          \
+          .TypeConstraint<key_dtype>("key_dtype")                      \
+          .TypeConstraint<value_dtype>("value_dtype"),                 \
+      RocksDBTableOp<RocksDBTableOfTensors<key_dtype, value_dtype>,    \
+                     key_dtype, value_dtype>)
+
+ROCKSDB_REGISTER_KERNEL_BUILDER(int32, bool);
+ROCKSDB_REGISTER_KERNEL_BUILDER(int32, int8);
+ROCKSDB_REGISTER_KERNEL_BUILDER(int32, int16);
+ROCKSDB_REGISTER_KERNEL_BUILDER(int32, int32);
+ROCKSDB_REGISTER_KERNEL_BUILDER(int32, int64);
+ROCKSDB_REGISTER_KERNEL_BUILDER(int32, Eigen::half);
+ROCKSDB_REGISTER_KERNEL_BUILDER(int32, float);
+ROCKSDB_REGISTER_KERNEL_BUILDER(int32, double);
+ROCKSDB_REGISTER_KERNEL_BUILDER(int32, tstring);
+
+ROCKSDB_REGISTER_KERNEL_BUILDER(int64, bool);
+ROCKSDB_REGISTER_KERNEL_BUILDER(int64, int8);
+ROCKSDB_REGISTER_KERNEL_BUILDER(int64, int16);
+ROCKSDB_REGISTER_KERNEL_BUILDER(int64, int32);
+ROCKSDB_REGISTER_KERNEL_BUILDER(int64, int64);
+ROCKSDB_REGISTER_KERNEL_BUILDER(int64, Eigen::half);
+ROCKSDB_REGISTER_KERNEL_BUILDER(int64, float);
+ROCKSDB_REGISTER_KERNEL_BUILDER(int64, double);
+ROCKSDB_REGISTER_KERNEL_BUILDER(int64, tstring);
+
+ROCKSDB_REGISTER_KERNEL_BUILDER(tstring, bool);
+ROCKSDB_REGISTER_KERNEL_BUILDER(tstring, int8);
+ROCKSDB_REGISTER_KERNEL_BUILDER(tstring, int16);
+ROCKSDB_REGISTER_KERNEL_BUILDER(tstring, int32);
+ROCKSDB_REGISTER_KERNEL_BUILDER(tstring, int64);
+ROCKSDB_REGISTER_KERNEL_BUILDER(tstring, Eigen::half);
+ROCKSDB_REGISTER_KERNEL_BUILDER(tstring, float);
+ROCKSDB_REGISTER_KERNEL_BUILDER(tstring, double);
+ROCKSDB_REGISTER_KERNEL_BUILDER(tstring, tstring);
+
+#undef ROCKSDB_REGISTER_KERNEL_BUILDER
+}  // namespace lookup_rocksdb
+
+/* --- OP KERNELS ----------------------------------------------------------- */
+class RocksDBTableOpKernel : public OpKernel {
+ public:
+  explicit RocksDBTableOpKernel(OpKernelConstruction *ctx)
+      : OpKernel(ctx),
+        expected_input_0_(ctx->input_type(0) == DT_RESOURCE ? DT_RESOURCE
+                                                            : DT_STRING_REF) {}
+
+ protected:
+  Status LookupResource(OpKernelContext *ctx, const ResourceHandle &p,
+                        LookupInterface **value) {
+    return ctx->resource_manager()->Lookup<LookupInterface, false>(
+        p.container(), p.name(), value);
+  }
+
+  Status GetTableHandle(StringPiece input_name, OpKernelContext *ctx,
+                        tstring *container, tstring *table_handle) {
+    {
+      mutex *guard;
+      TF_RETURN_IF_ERROR(ctx->input_ref_mutex(input_name, &guard));
+      mutex_lock lock(*guard);
+      Tensor tensor;
+      TF_RETURN_IF_ERROR(ctx->mutable_input(input_name, &tensor, true));
+      if (tensor.NumElements() != 2) {
+        return errors::InvalidArgument(
+            "Lookup table handle must be scalar, but had shape: ",
+            tensor.shape().DebugString());
+      }
+      auto h = tensor.flat<tstring>();
+      *container = h(0);
+      *table_handle = h(1);
+    }
+    return Status::OK();
+  }
+
+  Status GetResourceHashTable(StringPiece input_name, OpKernelContext *ctx,
+                              LookupInterface **table) {
+    const Tensor *handle_tensor;
+    TF_RETURN_IF_ERROR(ctx->input(input_name, &handle_tensor));
+    const auto &handle = handle_tensor->scalar<ResourceHandle>()();
+    return LookupResource(ctx, handle, table);
+  }
+
+  Status GetReferenceLookupTable(StringPiece input_name, OpKernelContext *ctx,
+                                 LookupInterface **table) {
+    tstring container;
+    tstring table_handle;
+    TF_RETURN_IF_ERROR(
+        GetTableHandle(input_name, ctx, &container, &table_handle));
+    return ctx->resource_manager()->Lookup(container, table_handle, table);
+  }
+
+  Status GetTable(OpKernelContext *ctx, LookupInterface **table) {
+    if (expected_input_0_ == DT_RESOURCE) {
+      return GetResourceHashTable("table_handle", ctx, table);
+    } else {
+      return GetReferenceLookupTable("table_handle", ctx, table);
+    }
+  }
+
+ protected:
+  const DataType expected_input_0_;
+};
+
+class RocksDBTableClear : public RocksDBTableOpKernel {
+ public:
+  using RocksDBTableOpKernel::RocksDBTableOpKernel;
+
+  void Compute(OpKernelContext *ctx) override {
+    LookupInterface *table;
+    OP_REQUIRES_OK(ctx, GetTable(ctx, &table));
+    core::ScopedUnref unref_me(table);
+
+    auto *rocks_table =
+        dynamic_cast<PersistentStorageLookupInterface *>(table);
+
+    int64 memory_used_before = 0;
+    if (ctx->track_allocations()) {
+      memory_used_before = table->MemoryUsed();
+    }
+    OP_REQUIRES_OK(ctx, rocks_table->Clear(ctx));
+    if (ctx->track_allocations()) {
+      ctx->record_persistent_memory_allocation(table->MemoryUsed() -
+                                               memory_used_before);
+    }
+  }
+};
+
+class RocksDBTableExport : public RocksDBTableOpKernel {
+ public:
+  using RocksDBTableOpKernel::RocksDBTableOpKernel;
+
+  void Compute(OpKernelContext *ctx) override {
+    LookupInterface *table;
+    OP_REQUIRES_OK(ctx, GetTable(ctx, &table));
+    core::ScopedUnref unref_me(table);
+
+    OP_REQUIRES_OK(ctx, table->ExportValues(ctx));
+  }
+};
+
+class RocksDBTableFind : public RocksDBTableOpKernel {
+ public:
+  using RocksDBTableOpKernel::RocksDBTableOpKernel;
+
+  void Compute(OpKernelContext *ctx) override {
+    LookupInterface *table;
+    OP_REQUIRES_OK(ctx, GetTable(ctx, &table));
+    core::ScopedUnref unref_me(table);
+
+    DataTypeVector expected_inputs = {expected_input_0_, table->key_dtype(),
+                                      table->value_dtype()};
+    DataTypeVector expected_outputs = {table->value_dtype()};
+    OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs));
+
+    const Tensor &key = ctx->input(1);
+    const Tensor &default_value = ctx->input(2);
+
+    TensorShape output_shape = key.shape();
+    output_shape.RemoveLastDims(table->key_shape().dims());
+    output_shape.AppendShape(table->value_shape());
+    Tensor *out;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output("values", output_shape, &out));
+    OP_REQUIRES_OK(ctx, table->Find(ctx, key, out, default_value));
+  }
+};
+
+class RocksDBTableImport : public RocksDBTableOpKernel {
+ public:
+  using RocksDBTableOpKernel::RocksDBTableOpKernel;
+
+  void Compute(OpKernelContext *ctx) override {
+    LookupInterface *table;
+    OP_REQUIRES_OK(ctx, GetTable(ctx, &table));
+    core::ScopedUnref unref_me(table);
+
+    DataTypeVector expected_inputs = {expected_input_0_, table->key_dtype(),
+                                      table->value_dtype()};
+    OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {}));
+
+    const Tensor &keys = ctx->input(1);
+    const Tensor &values = ctx->input(2);
+    OP_REQUIRES_OK(ctx, table->CheckKeyAndValueTensorsForImport(keys, values));
+
+    int64 memory_used_before = 0;
+    if (ctx->track_allocations()) {
+      memory_used_before = table->MemoryUsed();
+    }
+    OP_REQUIRES_OK(ctx, table->ImportValues(ctx, keys, values));
+    if (ctx->track_allocations()) {
+      ctx->record_persistent_memory_allocation(table->MemoryUsed() -
+                                               memory_used_before);
+    }
+  }
+};
+
+class RocksDBTableInsert : public RocksDBTableOpKernel {
+ public:
+  using RocksDBTableOpKernel::RocksDBTableOpKernel;
+
+  void Compute(OpKernelContext *ctx) override {
+    LookupInterface *table;
+    OP_REQUIRES_OK(ctx, GetTable(ctx, &table));
+    core::ScopedUnref unref_me(table);
+
+    DataTypeVector expected_inputs = {expected_input_0_, table->key_dtype(),
+                                      table->value_dtype()};
+    OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {}));
+
+    const Tensor &keys = ctx->input(1);
+    const Tensor &values = ctx->input(2);
+    OP_REQUIRES_OK(ctx, table->CheckKeyAndValueTensorsForInsert(keys, values));
+
+    int64 memory_used_before = 0;
+    if (ctx->track_allocations()) {
+      memory_used_before = table->MemoryUsed();
+    }
+    OP_REQUIRES_OK(ctx, table->Insert(ctx, keys, values));
+    if (ctx->track_allocations()) {
+      ctx->record_persistent_memory_allocation(table->MemoryUsed() -
+                                               memory_used_before);
+    }
+  }
+};
+
+class RocksDBTableRemove : public RocksDBTableOpKernel {
+ public:
+  using RocksDBTableOpKernel::RocksDBTableOpKernel;
+
+  void Compute(OpKernelContext *ctx) override {
+    LookupInterface *table;
+    OP_REQUIRES_OK(ctx, GetTable(ctx, &table));
+    core::ScopedUnref unref_me(table);
+
+    DataTypeVector expected_inputs = {expected_input_0_, table->key_dtype()};
+    OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {}));
+
+    const Tensor &key = ctx->input(1);
+    OP_REQUIRES_OK(ctx, table->CheckKeyTensorForRemove(key));
+
+    int64 memory_used_before = 0;
+    if (ctx->track_allocations()) {
+      memory_used_before = table->MemoryUsed();
+    }
+    OP_REQUIRES_OK(ctx, table->Remove(ctx, key));
+    if (ctx->track_allocations()) {
+      ctx->record_persistent_memory_allocation(table->MemoryUsed() -
+                                               memory_used_before);
+    }
+  }
+};
+
+class RocksDBTableSize : public RocksDBTableOpKernel {
+ public:
+  using RocksDBTableOpKernel::RocksDBTableOpKernel;
+
+  void Compute(OpKernelContext *ctx) override {
+    LookupInterface *table;
+    OP_REQUIRES_OK(ctx, GetTable(ctx, &table));
+    core::ScopedUnref unref_me(table);
+
+    Tensor *out;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output("size", TensorShape({}), &out));
+    out->flat<int64>().setConstant(static_cast<int64>(table->size()));
+  }
+};
+
+REGISTER_KERNEL_BUILDER(
+    Name(PREFIX_OP_NAME(RocksdbTableClear)).Device(DEVICE_CPU),
+    RocksDBTableClear);
+REGISTER_KERNEL_BUILDER(
+    Name(PREFIX_OP_NAME(RocksdbTableExport)).Device(DEVICE_CPU),
+    RocksDBTableExport);
+REGISTER_KERNEL_BUILDER(
+    Name(PREFIX_OP_NAME(RocksdbTableFind)).Device(DEVICE_CPU),
+    RocksDBTableFind);
+REGISTER_KERNEL_BUILDER(
+    Name(PREFIX_OP_NAME(RocksdbTableImport)).Device(DEVICE_CPU),
+    RocksDBTableImport);
+REGISTER_KERNEL_BUILDER(
+    Name(PREFIX_OP_NAME(RocksdbTableInsert)).Device(DEVICE_CPU),
+    RocksDBTableInsert);
+REGISTER_KERNEL_BUILDER(
+    Name(PREFIX_OP_NAME(RocksdbTableRemove)).Device(DEVICE_CPU),
+    RocksDBTableRemove);
+REGISTER_KERNEL_BUILDER(
+    Name(PREFIX_OP_NAME(RocksdbTableSize)).Device(DEVICE_CPU),
+    RocksDBTableSize);
+
+}  // namespace recommenders_addons
+}  // namespace tensorflow
diff --git a/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/rocksdb_table_op.h b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/rocksdb_table_op.h
new file mode 100644
index 000000000..53b73eafe
--- /dev/null
+++ b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/rocksdb_table_op.h
@@ -0,0 +1,125 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TFRA_CORE_KERNELS_ROCKSDB_TABLE_H_
+#define TFRA_CORE_KERNELS_ROCKSDB_TABLE_H_
+
+#include "tensorflow/core/kernels/lookup_table_op.h"
+
+namespace tensorflow {
+namespace recommenders_addons {
+
+using tensorflow::lookup::LookupInterface;
+
+class PersistentStorageLookupInterface : public LookupInterface {
+ public:
+  virtual Status Clear(OpKernelContext *ctx) = 0;
+};
+
+template <class Container, class key_dtype, class value_dtype>
+class RocksDBTableOp : public OpKernel {
+ public:
+  explicit RocksDBTableOp(OpKernelConstruction *ctx)
+      : OpKernel(ctx), table_handle_set_(false) {
+    if (ctx->output_type(0) == DT_RESOURCE) {
+      OP_REQUIRES_OK(ctx, ctx->allocate_temp(tensorflow::DT_RESOURCE,
+                                             tensorflow::TensorShape({}),
+                                             &table_handle_));
+    } else {
+      OP_REQUIRES_OK(ctx, ctx->allocate_temp(tensorflow::DT_STRING,
+                                             tensorflow::TensorShape({2}),
+                                             &table_handle_));
+    }
+
+    OP_REQUIRES_OK(
+        ctx, ctx->GetAttr("use_node_name_sharing", &use_node_name_sharing_));
+  }
+
+  void Compute(OpKernelContext *ctx) override {
+    mutex_lock l(mu_);
+
+    if (!table_handle_set_) {
+      OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def(),
+                                      use_node_name_sharing_));
+    }
+
+    auto creator =
+        [ctx, this](LookupInterface **ret) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+          LookupInterface *container = new Container(ctx, this);
+          if (!ctx->status().ok()) {
+            container->Unref();
+            return ctx->status();
+          }
+          if (ctx->track_allocations()) {
+            ctx->record_persistent_memory_allocation(
+                container->MemoryUsed() + table_handle_.AllocatedBytes());
+          }
+          *ret = container;
+          return Status::OK();
+        };
+
+    LookupInterface *table = nullptr;
+    OP_REQUIRES_OK(ctx,
+                   cinfo_.resource_manager()->LookupOrCreate<LookupInterface>(
+                       cinfo_.container(), cinfo_.name(), &table, creator));
+    core::ScopedUnref unref_me(table);
+
+    OP_REQUIRES_OK(ctx, CheckTableDataTypes(
+                            *table, DataTypeToEnum<key_dtype>::v(),
+                            DataTypeToEnum<value_dtype>::v(), cinfo_.name()));
+
+    if (ctx->expected_output_dtype(0) == DT_RESOURCE) {
+      if (!table_handle_set_) {
+        auto h = table_handle_.template scalar<ResourceHandle>();
+        h() = MakeResourceHandle<LookupInterface>(ctx, cinfo_.container(),
+                                                  cinfo_.name());
+      }
+      ctx->set_output(0, table_handle_);
+    } else {
+      if (!table_handle_set_) {
+        auto h = table_handle_.template flat<tstring>();
+        h(0) = cinfo_.container();
+        h(1) = cinfo_.name();
+      }
+      ctx->set_output_ref(0, &mu_, &table_handle_);
+    }
+
+    table_handle_set_ = true;
+  }
+
+  ~RocksDBTableOp() override {
+    if (table_handle_set_ && cinfo_.resource_is_private_to_kernel()) {
+      if (!cinfo_.resource_manager()
+               ->template Delete<LookupInterface>(cinfo_.container(),
+                                                  cinfo_.name())
+               .ok()) {
+        // Mirrors the other table kernels: errors during resource deletion
+        // at destruction time are intentionally ignored.
+      }
+    }
+  }
+
+ private:
+  mutex mu_;
+  Tensor table_handle_ TF_GUARDED_BY(mu_);
+  bool table_handle_set_ TF_GUARDED_BY(mu_);
+  ContainerInfo cinfo_;
+  bool use_node_name_sharing_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(RocksDBTableOp);
+};
+
+}  // namespace recommenders_addons
+}  // namespace tensorflow
+
+#endif  // TFRA_CORE_KERNELS_ROCKSDB_TABLE_H_
diff --git a/tensorflow_recommenders_addons/dynamic_embedding/core/ops/rocksdb_table_ops.cc b/tensorflow_recommenders_addons/dynamic_embedding/core/ops/rocksdb_table_ops.cc
new file mode 100644
index 000000000..3c9a140c9
--- /dev/null
+++ b/tensorflow_recommenders_addons/dynamic_embedding/core/ops/rocksdb_table_ops.cc
@@ -0,0 +1,258 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow_recommenders_addons/dynamic_embedding/core/utils/utils.h" + +namespace tensorflow { + +using shape_inference::DimensionHandle; +using shape_inference::InferenceContext; +using shape_inference::ShapeAndType; +using shape_inference::ShapeHandle; + +namespace { + +Status ScalarAndTwoElementVectorInputsAndScalarOutputs(InferenceContext *c) { + ShapeHandle handle; + DimensionHandle unused_handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); + for (int i = 1; i < c->num_inputs(); ++i) { + TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &handle)); + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_handle)); + } + for (int i = 0; i < c->num_outputs(); ++i) { + c->set_output(i, c->Scalar()); + } + return Status::OK(); +} + +} // namespace + +Status ValidateTableResourceHandle(InferenceContext *c, ShapeHandle keys, + const string &key_dtype_attr, + const string &value_dtype_attr, + bool is_lookup, + ShapeAndType *output_shape_and_type) { + auto *handle_data = c->input_handle_shapes_and_types(0); + if (handle_data == nullptr || handle_data->size() != 2) { + output_shape_and_type->shape = c->UnknownShape(); + output_shape_and_type->dtype = DT_INVALID; + } else { + const ShapeAndType &key_shape_and_type = (*handle_data)[0]; + const ShapeAndType &value_shape_and_type = (*handle_data)[1]; + DataType key_dtype; + TF_RETURN_IF_ERROR(c->GetAttr(key_dtype_attr, &key_dtype)); + if (key_shape_and_type.dtype != key_dtype) { + return errors::InvalidArgument( + "Trying to read value with wrong dtype. " + "Expected ", + DataTypeString(key_shape_and_type.dtype), " got ", + DataTypeString(key_dtype)); + } + + DataType value_dtype; + TF_RETURN_IF_ERROR(c->GetAttr(value_dtype_attr, &value_dtype)); + if (value_shape_and_type.dtype != value_dtype) { + return errors::InvalidArgument( + "Trying to read value with wrong dtype. " + "Expected ", + DataTypeString(value_shape_and_type.dtype), " got ", + DataTypeString(value_dtype)); + } + output_shape_and_type->dtype = value_shape_and_type.dtype; + + if (is_lookup) { + if (c->RankKnown(key_shape_and_type.shape) && c->RankKnown(keys)) { + int keys_rank = c->Rank(keys); + int key_suffix_rank = c->Rank(key_shape_and_type.shape); + if (keys_rank < key_suffix_rank) { + return errors::InvalidArgument( + "Expected keys to have suffix ", + c->DebugString(key_shape_and_type.shape), + " but saw shape: ", c->DebugString(keys)); + } + for (int d = 0; d < key_suffix_rank; ++d) { + // Ensure the suffix of keys match what's in the Table. 
+          DimensionHandle dim = c->Dim(key_shape_and_type.shape, d);
+          TF_RETURN_IF_ERROR(
+              c->ReplaceDim(keys, keys_rank - key_suffix_rank + d, dim, &keys));
+        }
+
+        std::vector<DimensionHandle> keys_prefix_vec;
+        keys_prefix_vec.reserve(keys_rank - key_suffix_rank);
+        for (int d = 0; d < keys_rank - key_suffix_rank; ++d) {
+          keys_prefix_vec.push_back(c->Dim(keys, d));
+        }
+
+        ShapeHandle keys_prefix = c->MakeShape(keys_prefix_vec);
+        TF_RETURN_IF_ERROR(c->Concatenate(keys_prefix,
+                                          value_shape_and_type.shape,
+                                          &output_shape_and_type->shape));
+
+      } else {
+        output_shape_and_type->shape = c->UnknownShape();
+      }
+    } else {
+      TF_RETURN_IF_ERROR(c->Concatenate(keys, value_shape_and_type.shape,
+                                        &output_shape_and_type->shape));
+    }
+  }
+  return Status::OK();
+}
+
+REGISTER_OP(PREFIX_OP_NAME(RocksdbTableFind))
+    .Input("table_handle: resource")
+    .Input("keys: Tin")
+    .Input("default_value: Tout")
+    .Output("values: Tout")
+    .Attr("Tin: type")
+    .Attr("Tout: type")
+    .SetShapeFn([](InferenceContext *c) {
+      ShapeHandle handle;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
+
+      ShapeAndType value_shape_and_type;
+      TF_RETURN_IF_ERROR(ValidateTableResourceHandle(
+          c,
+          /*keys=*/c->input(1),
+          /*key_dtype_attr=*/"Tin",
+          /*value_dtype_attr=*/"Tout",
+          /*is_lookup=*/true, &value_shape_and_type));
+      c->set_output(0, value_shape_and_type.shape);
+
+      return Status::OK();
+    });
+
+REGISTER_OP(PREFIX_OP_NAME(RocksdbTableInsert))
+    .Input("table_handle: resource")
+    .Input("keys: Tin")
+    .Input("values: Tout")
+    .Attr("Tin: type")
+    .Attr("Tout: type")
+    .SetShapeFn([](InferenceContext *c) {
+      ShapeHandle handle;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
+
+      // TODO: Validate keys and values shape.
+      return Status::OK();
+    });
+
+REGISTER_OP(PREFIX_OP_NAME(RocksdbTableRemove))
+    .Input("table_handle: resource")
+    .Input("keys: Tin")
+    .Attr("Tin: type")
+    .SetShapeFn([](InferenceContext *c) {
+      ShapeHandle handle;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
+      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &handle));
+
+      // TODO(turboale): Validate keys shape.
+      return Status::OK();
+    });
+
+REGISTER_OP(PREFIX_OP_NAME(RocksdbTableClear))
+    .Input("table_handle: resource")
+    .Attr("key_dtype: type")
+    .Attr("value_dtype: type");
+
+REGISTER_OP(PREFIX_OP_NAME(RocksdbTableSize))
+    .Input("table_handle: resource")
+    .Output("size: int64")
+    .SetShapeFn(ScalarAndTwoElementVectorInputsAndScalarOutputs);
+
+REGISTER_OP(PREFIX_OP_NAME(RocksdbTableExport))
+    .Input("table_handle: resource")
+    .Output("keys: Tkeys")
+    .Output("values: Tvalues")
+    .Attr("Tkeys: type")
+    .Attr("Tvalues: type")
+    .SetShapeFn([](InferenceContext *c) {
+      ShapeHandle handle;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
+      ShapeHandle keys = c->UnknownShapeOfRank(1);
+      ShapeAndType value_shape_and_type;
+      TF_RETURN_IF_ERROR(ValidateTableResourceHandle(
+          c,
+          /*keys=*/keys,
+          /*key_dtype_attr=*/"Tkeys",
+          /*value_dtype_attr=*/"Tvalues",
+          /*is_lookup=*/false, &value_shape_and_type));
+      c->set_output(0, keys);
+      c->set_output(1, value_shape_and_type.shape);
+      return Status::OK();
+    });
+
+REGISTER_OP(PREFIX_OP_NAME(RocksdbTableImport))
+    .Input("table_handle: resource")
+    .Input("keys: Tin")
+    .Input("values: Tout")
+    .Attr("Tin: type")
+    .Attr("Tout: type")
+    .SetShapeFn([](InferenceContext *c) {
+      ShapeHandle handle;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
+
+      ShapeHandle keys;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys));
+      TF_RETURN_IF_ERROR(c->Merge(keys, c->input(2), &keys));
+      return Status::OK();
+    });
+
+Status RocksDBTableShape(InferenceContext *c, const ShapeHandle &key,
+                         const ShapeHandle &value) {
+  c->set_output(0, c->Scalar());
+
+  ShapeHandle key_s;
+  TF_RETURN_IF_ERROR(c->WithRankAtMost(key, 1, &key_s));
+
+  DataType key_t;
+  TF_RETURN_IF_ERROR(c->GetAttr("key_dtype", &key_t));
+
+  DataType value_t;
+  TF_RETURN_IF_ERROR(c->GetAttr("value_dtype", &value_t));
+
+  c->set_output_handle_shapes_and_types(
+      0, std::vector<ShapeAndType>{{key_s, key_t}, {value, value_t}});
+
+  return Status::OK();
+}
+
+REGISTER_OP(PREFIX_OP_NAME(RocksdbTableOfTensors))
+    .Output("table_handle: resource")
+    .Attr("container: string = ''")
+    .Attr("shared_name: string = ''")
+    .Attr("use_node_name_sharing: bool = false")
+    .Attr("key_dtype: type")
+    .Attr("value_dtype: type")
+    .Attr("value_shape: shape = {}")
+    .Attr("database_path: string = ''")
+    .Attr("embedding_name: string = ''")
+    .Attr("read_only: bool = false")
+    .Attr("estimate_size: bool = false")
+    .Attr("export_path: string = ''")
+    .SetIsStateful()
+    .SetShapeFn([](InferenceContext *c) {
+      PartialTensorShape valueP;
+      TF_RETURN_IF_ERROR(c->GetAttr("value_shape", &valueP));
+      ShapeHandle valueS;
+      TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(valueP, &valueS));
+      return RocksDBTableShape(c, /*key=*/c->Scalar(), /*value=*/valueS);
+    });
+
+}  // namespace tensorflow
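These registered ops surface in Python through de.RocksDBTable together with the RocksDBTableConfig/RocksDBTableCreator classes added later in this diff, and the test file below exercises them via de.get_variable. For orientation, a minimal usage sketch following the pattern of those tests (database path and names are illustrative placeholders):

    import tensorflow as tf
    from tensorflow_recommenders_addons import dynamic_embedding as de

    # Illustrative configuration; mirrors the keys used by the tests below.
    config = de.RocksDBTableConfig({
        'database_path': '/tmp/demo_rocksdb',
        'embedding_name': 'demo_embedding',
        'read_only': False,
        'estimate_size': False,
        'export_path': None,
    })

    # A dynamic embedding variable backed by the RocksDB table ops.
    table = de.get_variable(
        'demo_table',
        key_dtype=tf.dtypes.int64,
        value_dtype=tf.dtypes.float32,
        initializer=-1.0,
        dim=4,
        kv_creator=de.RocksDBTableCreator(config),
    )

diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/rocksdb_table_ops_test.py b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/rocksdb_table_ops_test.py
new file mode 100644
index 000000000..e2daed424
--- /dev/null
+++ b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/rocksdb_table_ops_test.py
@@ -0,0 +1,1693 @@
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.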
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Unit tests of RocksDB-backed variables (adapted from the Redis test code)."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import glob
+import json
+
+import itertools
+import math
+import shutil
+
+import numpy as np
+import os
+import six
+import tempfile
+
+from tensorflow_recommenders_addons import dynamic_embedding as de
+
+import tensorflow as tf
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.client import session
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras import layers
+from tensorflow.python.keras import optimizer_v2
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import script_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import adam
+from tensorflow.python.training import saver
+from tensorflow.python.training import server_lib
+from tensorflow.python.util import compat
+
+
+# pylint: disable=missing-class-docstring
+# pylint: disable=missing-function-docstring
+def _type_converter(tf_type):
+  # Use the exact NumPy scalar types; the bare np.float/np.str/np.bool
+  # aliases are deprecated and removed in recent NumPy releases.
+  mapper = {
+      dtypes.int32: np.int32,
+      dtypes.int64: np.int64,
+      dtypes.float32: np.float32,
+      dtypes.float64: np.float64,
+      dtypes.string: np.str_,
+      dtypes.half: np.float16,
+      dtypes.int8: np.int8,
+      dtypes.bool: np.bool_,
+  }
+  return mapper[tf_type]
+
+
+def _get_devices():
+  return [
+      "/gpu:0" if len(tf.config.list_physical_devices('GPU')) > 0 else "/cpu:0"
+  ]
+
+
+def _check_device(op, expected_device="gpu"):
+  return expected_device.upper() in op.device
+
+
+def embedding_result(params, id_vals, weight_vals=None):
+  if weight_vals is None:
+    weight_vals = np.copy(id_vals)
+    weight_vals.fill(1)
+
+  values = []
+  weights = []
+  weights_squared = []
+
+  for pms, ids, wts in zip(params, id_vals, weight_vals):
+    value_aggregation = None
+    weight_aggregation = None
+    squared_weight_aggregation = None
+
+    if isinstance(ids, compat.integral_types):
+      pms = [pms]
+      ids = [ids]
+      wts = [wts]
+
+    for val, i, weight_value in zip(pms, ids, wts):
+      if value_aggregation is None:
+        assert weight_aggregation is None
+        assert squared_weight_aggregation is None
+        value_aggregation = val * weight_value
+        weight_aggregation = weight_value
+        squared_weight_aggregation = weight_value * weight_value
+      else:
+        assert weight_aggregation is not None
+        assert squared_weight_aggregation is not None
+        value_aggregation += val * weight_value
+        weight_aggregation += weight_value
+        squared_weight_aggregation += weight_value * weight_value
+
+    
values.append(value_aggregation) + weights.append(weight_aggregation) + weights_squared.append(squared_weight_aggregation) + + values = np.array(values).astype(np.float32) + weights = np.array(weights).astype(np.float32) + weights_squared = np.array(weights_squared).astype(np.float32) + + return values, weights, weights_squared + + +def data_fn(shape, maxval): + return random_ops.random_uniform(shape, maxval=maxval, dtype=dtypes.int64) + + +def model_fn(sparse_vars, embed_dim, feature_inputs): + embedding_weights = [] + embedding_trainables = [] + for sp in sparse_vars: + for inp_tensor in feature_inputs: + embed_w, trainable = de.embedding_lookup(sp, + inp_tensor, + return_trainable=True) + embedding_weights.append(embed_w) + embedding_trainables.append(trainable) + + def layer_fn(entry, dimension, activation=False): + entry = array_ops.reshape(entry, (-1, dimension, embed_dim)) + dnn_fn = layers.Dense(dimension, use_bias=False) + batch_normal_fn = layers.BatchNormalization() + dnn_result = dnn_fn(entry) + if activation: + return batch_normal_fn(nn.selu(dnn_result)) + return dnn_result + + def dnn_fn(entry, dimension, activation=False): + hidden = layer_fn(entry, dimension, activation) + output = layer_fn(hidden, 1) + logits = math_ops.reduce_mean(output) + return logits + + logits_sum = sum(dnn_fn(w, 16, activation=True) for w in embedding_weights) + labels = 0.0 + err_prob = nn.sigmoid_cross_entropy_with_logits(logits=logits_sum, + labels=labels) + loss = math_ops.reduce_mean(err_prob) + return labels, embedding_trainables, loss + + +def ids_and_weights_2d(embed_dim=4): + # Each row demonstrates a test case: + # Row 0: multiple valid ids, 1 invalid id, weighted mean + # Row 1: all ids are invalid (leaving no valid ids after pruning) + # Row 2: no ids to begin with + # Row 3: single id + # Row 4: all ids have <=0 weight + indices = [[0, 0], [0, 1], [0, 2], [1, 0], [3, 0], [4, 0], [4, 1]] + ids = [0, 1, -1, -1, 2, 0, 1] + weights = [1.0, 2.0, 1.0, 1.0, 3.0, 0.0, -0.5] + shape = [5, embed_dim] + + sparse_ids = sparse_tensor.SparseTensor( + constant_op.constant(indices, dtypes.int64), + constant_op.constant(ids, dtypes.int64), + constant_op.constant(shape, dtypes.int64), + ) + + sparse_weights = sparse_tensor.SparseTensor( + constant_op.constant(indices, dtypes.int64), + constant_op.constant(weights, dtypes.float32), + constant_op.constant(shape, dtypes.int64), + ) + + return sparse_ids, sparse_weights + + +def ids_and_weights_3d(embed_dim=4): + # Each (2-D) index demonstrates a test case: + # Index 0, 0: multiple valid ids, 1 invalid id, weighted mean + # Index 0, 1: all ids are invalid (leaving no valid ids after pruning) + # Index 0, 2: no ids to begin with + # Index 1, 0: single id + # Index 1, 1: all ids have <=0 weight + # Index 1, 2: no ids to begin with + indices = [ + [0, 0, 0], + [0, 0, 1], + [0, 0, 2], + [0, 1, 0], + [1, 0, 0], + [1, 1, 0], + [1, 1, 1], + ] + ids = [0, 1, -1, -1, 2, 0, 1] + weights = [1.0, 2.0, 1.0, 1.0, 3.0, 0.0, -0.5] + shape = [2, 3, embed_dim] + + sparse_ids = sparse_tensor.SparseTensor( + constant_op.constant(indices, dtypes.int64), + constant_op.constant(ids, dtypes.int64), + constant_op.constant(shape, dtypes.int64), + ) + + sparse_weights = sparse_tensor.SparseTensor( + constant_op.constant(indices, dtypes.int64), + constant_op.constant(weights, dtypes.float32), + constant_op.constant(shape, dtypes.int64), + ) + + return sparse_ids, sparse_weights + + +def _random_weights( + key_dtype=dtypes.int64, + value_dtype=dtypes.float32, + vocab_size=4, + 
embed_dim=4, + num_shards=1, +): + assert vocab_size > 0 + assert embed_dim > 0 + assert num_shards > 0 + assert num_shards <= vocab_size + + initializer = init_ops.truncated_normal_initializer(mean=0.0, + stddev=1.0 / + math.sqrt(vocab_size), + dtype=dtypes.float32) + embedding_weights = de.get_variable( + key_dtype=key_dtype, + value_dtype=value_dtype, + devices=_get_devices() * num_shards, + name="embedding_weights", + initializer=initializer, + dim=embed_dim, + ) + return embedding_weights + + +def _test_dir(temp_dir, test_name): + """Create an empty dir to use for tests. + + Args: + temp_dir: Tmp directory path. + test_name: Name of the test. + + Returns: + Absolute path to the test directory. + """ + test_dir = os.path.join(temp_dir, test_name) + if os.path.isdir(test_dir): + for f in glob.glob(f"{test_dir}/*"): + os.remove(f) + else: + os.makedirs(test_dir) + return test_dir + + +def _create_dynamic_shape_tensor( + max_len=100, + min_len=2, + min_val=0x0000_F000_0000_0001, + max_val=0x0000_F000_0000_0020, + dtype=np.int64, +): + + def _func(): + length = np.random.randint(min_len, max_len) + tensor = np.random.randint(min_val, max_val, max_len, dtype=dtype) + tensor = np.array(tensor[0:length], dtype=dtype) + return tensor + + return _func + + +default_config = config_pb2.ConfigProto( + allow_soft_placement=False, + gpu_options=config_pb2.GPUOptions(allow_growth=True)) + +ROCKSDB_CONFIG_PATH = os.path.join(tempfile.gettempdir(), + 'test_rocksdb_config.json') +ROCKSDB_CONFIG_PARAMS = { + 'database_path': os.path.join(tempfile.gettempdir(), 'test_rocksdb_4711'), + 'embedding_name': None, + 'read_only': False, + 'estimate_size': False, + 'export_path': None, +} + + +def conf_with(**kwargs): + config = {k: v for k, v in ROCKSDB_CONFIG_PARAMS.items()} + for k, v in kwargs.items(): + config[k] = v + return de.RocksDBTableConfig(config) + + +DELETE_DATABASE_AT_STARTUP = False + +SKIP_PASSING = False +SKIP_PASSING_WITH_QUESTIONS = False +SKIP_FAILING = True +SKIP_FAILING_WITH_QUESTIONS = True + + +@test_util.run_all_in_graph_and_eager_modes +class RocksDBVariableTest(test.TestCase): + + def __init__(self, method_name='runTest'): + super().__init__(method_name) + self.gpu_available = len(tf.config.list_physical_devices('GPU')) > 0 + + @test_util.skip_if(SKIP_PASSING) + def test_basic(self): + with self.session(config=default_config, use_gpu=False): + table = de.get_variable( + "t0-test_basic", + dtypes.int64, + dtypes.int32, + initializer=0, + dim=8, + kv_creator=de.RocksDBTableCreator( + conf_with(embedding_name='t0_test_basic')), + ) + self.evaluate(table.clear()) + self.evaluate(table.size()) + + @test_util.skip_if(SKIP_PASSING) + def test_variable(self): + if self.gpu_available: + dim_list = [1, 2, 4, 8, 10, 16, 32, 64, 100, 200] + kv_list = [ + [dtypes.int64, dtypes.int8], + [dtypes.int64, dtypes.int32], + [dtypes.int64, dtypes.half], + [dtypes.int64, dtypes.float32], + ] + else: + dim_list = [1, 8, 16, 128] + kv_list = [ + [dtypes.int32, dtypes.int32], + [dtypes.int32, dtypes.float32], + [dtypes.int32, dtypes.double], + [dtypes.int64, dtypes.int8], + [dtypes.int64, dtypes.int32], + [dtypes.int64, dtypes.int64], + [dtypes.int64, dtypes.half], + [dtypes.int64, dtypes.float32], + [dtypes.int64, dtypes.double], + [dtypes.int64, dtypes.string], + [dtypes.string, dtypes.int8], + [dtypes.string, dtypes.int32], + [dtypes.string, dtypes.int64], + [dtypes.string, dtypes.half], + [dtypes.string, dtypes.float32], + [dtypes.string, dtypes.double], + ] + + def _convert(v, t): + return 
np.array(v).astype(_type_converter(t)) + + for _id, ((key_dtype, value_dtype), + dim) in enumerate(itertools.product(kv_list, dim_list)): + + with self.session(config=default_config, use_gpu=self.gpu_available): + keys = constant_op.constant( + np.array([0, 1, 2, 3]).astype(_type_converter(key_dtype)), + key_dtype) + values = constant_op.constant( + _convert([[0] * dim, [1] * dim, [2] * dim, [3] * dim], value_dtype), + value_dtype) + + table = de.get_variable( + f't1-{_id}_test_variable', + key_dtype=key_dtype, + value_dtype=value_dtype, + initializer=np.array([-1]).astype(_type_converter(value_dtype)), + dim=dim, + kv_creator=de.RocksDBTableCreator( + conf_with(embedding_name='t1_test_variable')), + ) + self.evaluate(table.clear()) + + self.assertAllEqual(0, self.evaluate(table.size())) + + self.evaluate(table.upsert(keys, values)) + self.assertAllEqual(4, self.evaluate(table.size())) + + remove_keys = constant_op.constant(_convert([1, 5], key_dtype), + key_dtype) + self.evaluate(table.remove(remove_keys)) + self.assertAllEqual(3, self.evaluate(table.size())) + + remove_keys = constant_op.constant(_convert([0, 1, 5], key_dtype), + key_dtype) + output = table.lookup(remove_keys) + self.assertAllEqual([3, dim], output.get_shape()) + + result = self.evaluate(output) + self.assertAllEqual( + _convert([[0] * dim, [-1] * dim, [-1] * dim], value_dtype), + _convert(result, value_dtype)) + + exported_keys, exported_values = table.export() + + # exported data is in the order of the internal map, i.e. undefined + sorted_keys = np.sort(self.evaluate(exported_keys)) + sorted_values = np.sort(self.evaluate(exported_values), axis=0) + self.assertAllEqual(_convert([0, 2, 3], key_dtype), + _convert(sorted_keys, key_dtype)) + self.assertAllEqual( + _convert([[0] * dim, [2] * dim, [3] * dim], value_dtype), + _convert(sorted_values, value_dtype)) + + self.evaluate(table.clear()) + del table + + @test_util.skip_if(SKIP_PASSING) + def test_empty_kvs(self): + dim_list = [1, 8, 16] + kv_list = [ + [dtypes.int32, dtypes.int32], + [dtypes.int32, dtypes.float32], + [dtypes.int32, dtypes.double], + [dtypes.int64, dtypes.int8], + [dtypes.int64, dtypes.int32], + [dtypes.int64, dtypes.int64], + [dtypes.int64, dtypes.half], + [dtypes.int64, dtypes.float32], + [dtypes.int64, dtypes.double], + [dtypes.int64, dtypes.string], + [dtypes.string, dtypes.int8], + [dtypes.string, dtypes.int32], + [dtypes.string, dtypes.int64], + [dtypes.string, dtypes.half], + [dtypes.string, dtypes.float32], + [dtypes.string, dtypes.double], + ] + + def _convert(v, t): + return np.array(v).astype(_type_converter(t)) + + for _id, ((key_dtype, value_dtype), + dim) in enumerate(itertools.product(kv_list, dim_list)): + with self.session(config=default_config, use_gpu=self.gpu_available): + keys = constant_op.constant( + np.array([]).astype(_type_converter(key_dtype)), key_dtype) + values = constant_op.constant(_convert([], value_dtype), value_dtype) + table = de.get_variable( + 't1-' + str(_id) + '_test_empty_kvs', + key_dtype=key_dtype, + value_dtype=value_dtype, + initializer=np.array([-1]).astype(_type_converter(value_dtype)), + dim=dim, + kv_creator=de.RocksDBTableCreator( + conf_with(embedding_name='t1_test_empty_kvs')), + ) + self.evaluate(table.clear()) + + self.assertAllEqual(0, self.evaluate(table.size())) + + with self.assertRaisesOpError("Expected shape"): + self.evaluate(table.upsert(keys, values)) + self.assertAllEqual(0, self.evaluate(table.size())) + + output = table.lookup(keys) + self.assertAllEqual([0, dim], output.get_shape()) + 
+ result = self.evaluate(output) + self.assertAllEqual(np.reshape(_convert([], value_dtype), (0, dim)), + _convert(result, value_dtype)) + + self.evaluate(table.clear()) + del table + + @test_util.skip_if(SKIP_PASSING) + def test_variable_initializer(self): + for _id, (initializer, target_mean, target_stddev) in enumerate([ + (-1.0, -1.0, 0.0), + (init_ops.random_normal_initializer(0.0, 0.01, seed=2), 0.0, 0.01), + ]): + with self.session(config=default_config, use_gpu=self.gpu_available): + keys = constant_op.constant(list(range(2**16)), dtypes.int64) + table = de.get_variable( + f't2-{_id}_test_variable_initializer', + key_dtype=dtypes.int64, + value_dtype=dtypes.float32, + initializer=initializer, + dim=10, + kv_creator=de.RocksDBTableCreator( + conf_with(embedding_name='t2_test_variable_initializer')), + ) + self.evaluate(table.clear()) + + vals_op = table.lookup(keys) + mean = self.evaluate(math_ops.reduce_mean(vals_op)) + stddev = self.evaluate(math_ops.reduce_std(vals_op)) + + atol = rtol = 2e-5 + self.assertAllClose(target_mean, mean, rtol, atol) + self.assertAllClose(target_stddev, stddev, rtol, atol) + + self.evaluate(table.clear()) + del table + + @test_util.skip_if(SKIP_FAILING) + def test_save_restore(self): + save_dir = os.path.join(self.get_temp_dir(), "save_restore") + save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash") + + with self.session(config=default_config, graph=ops.Graph()) as sess: + v0 = variables.Variable(10.0, name="v0") + v1 = variables.Variable(20.0, name="v1") + + keys = constant_op.constant([0, 1, 2], dtypes.int64) + values = constant_op.constant([[0.0], [1.0], [2.0]], dtypes.float32) + table = de.Variable( + key_dtype=dtypes.int64, + value_dtype=dtypes.float32, + initializer=-1.0, + name='t1', + dim=1, + kv_creator=de.RocksDBTableCreator( + conf_with(embedding_name='t3_test_save_restore')), + ) + self.evaluate(table.clear()) + + save = saver.Saver(var_list=[v0, v1, table]) + self.evaluate(variables.global_variables_initializer()) + + # Check that the parameter nodes have been initialized. + self.assertEqual(10.0, self.evaluate(v0)) + self.assertEqual(20.0, self.evaluate(v1)) + + self.assertAllEqual(0, self.evaluate(table.size())) + self.evaluate(table.upsert(keys, values)) + self.assertAllEqual(3, self.evaluate(table.size())) + + val = save.save(sess, save_path) + self.assertIsInstance(val, six.string_types) + self.assertEqual(save_path, val) + + self.evaluate(table.clear()) + del table + + with self.session(config=default_config, graph=ops.Graph()) as sess: + v0 = variables.Variable(-1.0, name="v0") + v1 = variables.Variable(-1.0, name="v1") + table = de.Variable( + name="t1", + key_dtype=dtypes.int64, + value_dtype=dtypes.float32, + initializer=-1.0, + dim=1, + checkpoint=True, + ) + self.evaluate(table.clear()) + + self.evaluate( + table.upsert( + constant_op.constant([0, 1], dtypes.int64), + constant_op.constant([[12.0], [24.0]], dtypes.float32), + )) + size_op = table.size() + self.assertAllEqual(2, self.evaluate(size_op)) + + save = saver.Saver(var_list=[v0, v1, table]) + + # Restore the saved values in the parameter nodes. + save.restore(sess, save_path) + # Check that the parameter nodes have been restored. 
+ self.assertEqual([10.0], self.evaluate(v0)) + self.assertEqual([20.0], self.evaluate(v1)) + + self.assertAllEqual(3, self.evaluate(table.size())) + + remove_keys = constant_op.constant([5, 0, 1, 2, 6], dtypes.int64) + output = table.lookup(remove_keys) + self.assertAllEqual([[-1.0], [0.0], [1.0], [2.0], [-1.0]], + self.evaluate(output)) + + self.evaluate(table.clear()) + del table + + @test_util.skip_if(SKIP_FAILING) + def test_save_restore_only_table(self): + save_dir = os.path.join(self.get_temp_dir(), "save_restore") + save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash") + + with self.session( + config=default_config, + graph=ops.Graph(), + use_gpu=self.gpu_available, + ) as sess: + v0 = variables.Variable(10.0, name="v0") + v1 = variables.Variable(20.0, name="v1") + + default_val = -1 + keys = constant_op.constant([0, 1, 2], dtypes.int64) + values = constant_op.constant([[0], [1], [2]], dtypes.int32) + table = de.Variable( + dtypes.int64, + dtypes.int32, + name="t1", + initializer=default_val, + checkpoint=True, + kv_creator=de.RocksDBTableCreator( + conf_with(embedding_name='t4_save_restore_only_table')), + ) + self.evaluate(table.clear()) + + save = saver.Saver([table]) + self.evaluate(variables.global_variables_initializer()) + + # Check that the parameter nodes have been initialized. + self.assertEqual(10.0, self.evaluate(v0)) + self.assertEqual(20.0, self.evaluate(v1)) + + self.assertAllEqual(0, self.evaluate(table.size())) + self.evaluate(table.upsert(keys, values)) + self.assertAllEqual(3, self.evaluate(table.size())) + + val = save.save(sess, save_path) + self.assertIsInstance(val, six.string_types) + self.assertEqual(save_path, val) + + self.evaluate(table.clear()) + del table + + with self.session( + config=default_config, + graph=ops.Graph(), + use_gpu=self.gpu_available, + ) as sess: + default_val = -1 + table = de.Variable( + dtypes.int64, + dtypes.int32, + name="t1", + initializer=default_val, + checkpoint=True, + kv_creator=de.RocksDBTableCreator( + conf_with(embedding_name='t6_save_restore_only_table')), + ) + self.evaluate(table.clear()) + + self.evaluate( + table.upsert( + constant_op.constant([0, 2], dtypes.int64), + constant_op.constant([[12], [24]], dtypes.int32), + )) + self.assertAllEqual(2, self.evaluate(table.size())) + + save = saver.Saver([table._tables[0]]) + + # Restore the saved values in the parameter nodes. + save.restore(sess, save_path) + + # Check that the parameter nodes have been restored. 
+ self.assertAllEqual(3, self.evaluate(table.size())) + + remove_keys = constant_op.constant([0, 1, 2, 3, 4], dtypes.int64) + output = table.lookup(remove_keys) + self.assertAllEqual([[0], [1], [2], [-1], [-1]], self.evaluate(output)) + + self.evaluate(table.clear()) + del table + + @test_util.skip_if(SKIP_FAILING) + def test_training_save_restore(self): + opt = de.DynamicEmbeddingOptimizer(adam.AdamOptimizer(0.3)) + if self.gpu_available: + dim_list = [1, 2, 4, 8, 10, 16, 32, 64, 100, 200] + else: + dim_list = [10] + + for _id, (key_dtype, value_dtype, dim, step) in enumerate( + itertools.product( + [dtypes.int64], + [dtypes.float32], + dim_list, + [10], + )): + save_dir = os.path.join(self.get_temp_dir(), "save_restore") + save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash") + + ids = script_ops.py_func( + _create_dynamic_shape_tensor(), + inp=[], + Tout=key_dtype, + stateful=True, + ) + + params = de.get_variable( + name=f'params-test-0915-{_id}_test_training_save_restore', + key_dtype=key_dtype, + value_dtype=value_dtype, + initializer=init_ops.random_normal_initializer(0.0, 0.01), + dim=dim, + kv_creator=de.RocksDBTableCreator( + conf_with(embedding_name='t5_training_save_restore')), + ) + self.evaluate(params.clear()) + + _, var0 = de.embedding_lookup(params, + ids, + name="emb", + return_trainable=True) + + def loss(): + return var0 * var0 + + params_keys, params_vals = params.export() + mini = opt.minimize(loss, var_list=[var0]) + opt_slots = [opt.get_slot(var0, _s) for _s in opt.get_slot_names()] + _saver = saver.Saver([params] + [_s.params for _s in opt_slots]) + + with self.session(config=default_config, + use_gpu=self.gpu_available) as sess: + self.evaluate(variables.global_variables_initializer()) + for _i in range(step): + self.evaluate([mini]) + size_before_saved = self.evaluate(params.size()) + np_params_keys_before_saved = self.evaluate(params_keys) + np_params_vals_before_saved = self.evaluate(params_vals) + opt_slots_kv_pairs = [_s.params.export() for _s in opt_slots] + np_slots_kv_pairs_before_saved = [ + self.evaluate(_kv) for _kv in opt_slots_kv_pairs + ] + params_size = self.evaluate(params.size()) + _saver.save(sess, save_path) + + with self.session(config=default_config, + use_gpu=self.gpu_available) as sess: + self.evaluate(variables.global_variables_initializer()) + self.assertAllEqual(params_size, self.evaluate(params.size())) + + _saver.restore(sess, save_path) + params_keys_restored, params_vals_restored = params.export() + size_after_restored = self.evaluate(params.size()) + np_params_keys_after_restored = self.evaluate(params_keys_restored) + np_params_vals_after_restored = self.evaluate(params_vals_restored) + + opt_slots_kv_pairs_restored = [_s.params.export() for _s in opt_slots] + np_slots_kv_pairs_after_restored = [ + self.evaluate(_kv) for _kv in opt_slots_kv_pairs_restored + ] + self.assertAllEqual(size_before_saved, size_after_restored) + self.assertAllEqual( + np.sort(np_params_keys_before_saved), + np.sort(np_params_keys_after_restored), + ) + self.assertAllEqual( + np.sort(np_params_vals_before_saved, axis=0), + np.sort(np_params_vals_after_restored, axis=0), + ) + for pairs_before, pairs_after in zip(np_slots_kv_pairs_before_saved, + np_slots_kv_pairs_after_restored): + self.assertAllEqual( + np.sort(pairs_before[0], axis=0), + np.sort(pairs_after[0], axis=0), + ) + self.assertAllEqual( + np.sort(pairs_before[1], axis=0), + np.sort(pairs_after[1], axis=0), + ) + if self.gpu_available: + self.assertTrue('GPU' in 
params.tables[0].resource_handle.device) + + self.evaluate(params.clear()) + del params + + @test_util.skip_if(SKIP_PASSING) + def test_training_save_restore_by_files(self): + opt = de.DynamicEmbeddingOptimizer(adam.AdamOptimizer(0.3)) + for _id, (key_dtype, value_dtype, dim, step) in enumerate( + itertools.product( + [dtypes.int64], + [dtypes.float32], + [10], + [10], + )): + save_dir = os.path.join(self.get_temp_dir(), "save_restore") + save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash") + + os.makedirs(save_path) + + ids = script_ops.py_func(_create_dynamic_shape_tensor(), + inp=[], + Tout=key_dtype, + stateful=True) + + params = de.get_variable( + name=f'params-test-0916-{_id}_test_training_save_restore_by_files', + key_dtype=key_dtype, + value_dtype=value_dtype, + initializer=0, + dim=dim, + kv_creator=de.RocksDBTableCreator( + conf_with(embedding_name='t5_training_save_restore', + export_path=save_path)), + ) + self.evaluate(params.clear()) + + _, var0 = de.embedding_lookup(params, + ids, + name="emb", + return_trainable=True) + + def loss(): + return var0 * var0 + + mini = opt.minimize(loss, var_list=[var0]) + opt_slots = [opt.get_slot(var0, _s) for _s in opt.get_slot_names()] + _saver = saver.Saver([params] + [_s.params for _s in opt_slots]) + + keys = np.random.randint(1, 100, dim) + values = np.random.rand(keys.shape[0], dim) + + with self.session(config=default_config, + use_gpu=self.gpu_available) as sess: + self.evaluate(variables.global_variables_initializer()) + self.evaluate(params.upsert(keys, values)) + params_vals = params.lookup(keys) + for _i in range(step): + self.evaluate([mini]) + size_before_saved = self.evaluate(params.size()) + np_params_vals_before_saved = self.evaluate(params_vals) + params_size = self.evaluate(params.size()) + _saver.save(sess, save_path) + + with self.session(config=default_config, + use_gpu=self.gpu_available) as sess: + _saver.restore(sess, save_path) + self.evaluate(variables.global_variables_initializer()) + self.assertAllEqual(params_size, self.evaluate(params.size())) + params_vals_restored = params.lookup(keys) + size_after_restored = self.evaluate(params.size()) + np_params_vals_after_restored = self.evaluate(params_vals_restored) + + self.assertAllEqual(size_before_saved, size_after_restored) + self.assertAllEqual( + np.sort(np_params_vals_before_saved, axis=0), + np.sort(np_params_vals_after_restored, axis=0), + ) + + self.evaluate(params.clear()) + del params + + @test_util.skip_if(SKIP_PASSING) + def test_get_variable(self): + with self.session( + config=default_config, + graph=ops.Graph(), + use_gpu=self.gpu_available, + ): + default_val = -1 + with variable_scope.variable_scope("embedding", reuse=True): + table1 = de.get_variable( + 't1_test_get_variable', + dtypes.int64, + dtypes.int32, + initializer=default_val, + dim=2, + kv_creator=de.RocksDBTableCreator( + conf_with(embedding_name='t7_get_variable')), + ) + table2 = de.get_variable( + 't1_test_get_variable', + dtypes.int64, + dtypes.int32, + initializer=default_val, + dim=2, + kv_creator=de.RocksDBTableCreator( + conf_with(embedding_name='t7_get_variable')), + ) + table3 = de.get_variable( + 't3_test_get_variable', + dtypes.int64, + dtypes.int32, + initializer=default_val, + dim=2, + kv_creator=de.RocksDBTableCreator( + conf_with(embedding_name='t7_get_variable')), + ) + self.evaluate(table1.clear()) + self.evaluate(table2.clear()) + self.evaluate(table3.clear()) + + self.assertAllEqual(table1, table2) + self.assertNotEqual(table1, table3) + + 
@test_util.skip_if(SKIP_PASSING) + def test_get_variable_reuse_error(self): + ops.disable_eager_execution() + with self.session( + config=default_config, + graph=ops.Graph(), + use_gpu=self.gpu_available, + ): + with variable_scope.variable_scope('embedding', reuse=False): + _ = de.get_variable( + 't900', + initializer=-1, + dim=2, + kv_creator=de.RocksDBTableCreator( + conf_with(embedding_name='t8_get_variable_reuse_error')), + ) + with self.assertRaisesRegexp(ValueError, + 'Variable embedding/t900 already exists'): + _ = de.get_variable( + 't900', + initializer=-1, + dim=2, + kv_creator=de.RocksDBTableCreator( + conf_with(embedding_name='t8_get_variable_reuse_error')), + ) + + @test_util.skip_if(SKIP_PASSING) + @test_util.run_v1_only("Multiple sessions") + def test_sharing_between_multi_sessions(self): + ops.disable_eager_execution() + + # Start a server to store the table state + server = server_lib.Server({'local0': ['localhost:0']}, + protocol='grpc', + start=True) + + # Create two sessions sharing the same state + session1 = session.Session(server.target, config=default_config) + session2 = session.Session(server.target, config=default_config) + + table = de.get_variable( + 'tx100_test_sharing_between_multi_sessions', + dtypes.int64, + dtypes.int32, + initializer=0, + dim=1, + kv_creator=de.RocksDBTableCreator( + conf_with(embedding_name='t9_sharing_between_multi_sessions')), + ) + self.evaluate(table.clear()) + + # Populate the table in the first session + with session1: + with ops.device(_get_devices()[0]): + self.evaluate(variables.global_variables_initializer()) + self.evaluate(variables.local_variables_initializer()) + self.assertAllEqual(0, table.size().eval()) + + keys = constant_op.constant([11, 12], dtypes.int64) + values = constant_op.constant([[11], [12]], dtypes.int32) + table.upsert(keys, values).run() + self.assertAllEqual(2, table.size().eval()) + + output = table.lookup(constant_op.constant([11, 12, 13], dtypes.int64)) + self.assertAllEqual([[11], [12], [0]], output.eval()) + + # Verify that we can access the shared data from the second session + with session2: + with ops.device(_get_devices()[0]): + self.assertAllEqual(2, table.size().eval()) + + output = table.lookup(constant_op.constant([10, 11, 12], dtypes.int64)) + self.assertAllEqual([[0], [11], [12]], output.eval()) + + @test_util.skip_if(SKIP_PASSING) + def test_dynamic_embedding_variable(self): + with self.session(config=default_config, use_gpu=self.gpu_available): + default_val = constant_op.constant([-1, -2], dtypes.int64) + keys = constant_op.constant([0, 1, 2, 3], dtypes.int64) + values = constant_op.constant([ + [0, 1], + [2, 3], + [4, 5], + [6, 7], + ], dtypes.int32) + + table = de.get_variable( + 't10_test_dynamic_embedding_variable', + dtypes.int64, + dtypes.int32, + initializer=default_val, + dim=2, + kv_creator=de.RocksDBTableCreator( + conf_with(embedding_name='t10_dynamic_embedding_variable')), + ) + self.evaluate(table.clear()) + + self.assertAllEqual(0, self.evaluate(table.size())) + + self.evaluate(table.upsert(keys, values)) + self.assertAllEqual(4, self.evaluate(table.size())) + + remove_keys = constant_op.constant([3, 4], dtypes.int64) + self.evaluate(table.remove(remove_keys)) + self.assertAllEqual(3, self.evaluate(table.size())) + + remove_keys = constant_op.constant([0, 1, 4], dtypes.int64) + output = table.lookup(remove_keys) + self.assertAllEqual([3, 2], output.get_shape()) + + result = self.evaluate(output) + self.assertAllEqual([ + [0, 1], + [2, 3], + [-1, -2], + ], result) + + 
exported_keys, exported_values = table.export() + # exported data is in the order of the internal map, i.e. undefined + sorted_keys = np.sort(self.evaluate(exported_keys)) + sorted_values = np.sort(self.evaluate(exported_values), axis=0) + self.assertAllEqual([0, 1, 2], sorted_keys) + sorted_expected_values = np.sort([[4, 5], [2, 3], [0, 1]], axis=0) + self.assertAllEqual(sorted_expected_values, sorted_values) + + self.evaluate(table.clear()) + del table + + @test_util.skip_if(SKIP_PASSING) + def test_dynamic_embedding_variable_export_insert(self): + with self.session(config=default_config, use_gpu=self.gpu_available): + default_val = constant_op.constant([-1, -1], dtypes.int64) + keys = constant_op.constant([0, 1, 2], dtypes.int64) + values = constant_op.constant([ + [0, 1], + [2, 3], + [4, 5], + ], dtypes.int32) + + table1 = de.get_variable( + 't101_test_dynamic_embedding_variable_export_insert', + dtypes.int64, + dtypes.int32, + initializer=default_val, + dim=2, + kv_creator=de.RocksDBTableCreator( + conf_with(embedding_name= + 't101_dynamic_embedding_variable_export_insert_a')), + ) + self.evaluate(table1.clear()) + + self.assertAllEqual(0, self.evaluate(table1.size())) + self.evaluate(table1.upsert(keys, values)) + self.assertAllEqual(3, self.evaluate(table1.size())) + + input_keys = constant_op.constant([0, 1, 3], dtypes.int64) + expected_output = [[0, 1], [2, 3], [-1, -1]] + output1 = table1.lookup(input_keys) + self.assertAllEqual(expected_output, self.evaluate(output1)) + + exported_keys, exported_values = table1.export() + self.assertAllEqual(3, self.evaluate(exported_keys).size) + self.assertAllEqual(6, self.evaluate(exported_values).size) + + # Populate a second table from the exported data + table2 = de.get_variable( + 't102_test_dynamic_embedding_variable_export_insert', + dtypes.int64, + dtypes.int32, + initializer=default_val, + dim=2, + kv_creator=de.RocksDBTableCreator( + conf_with( + embedding_name='t10_dynamic_embedding_variable_export_insert_b' + )), + ) + self.evaluate(table2.clear()) + + self.assertAllEqual(0, self.evaluate(table2.size())) + self.evaluate(table2.upsert(exported_keys, exported_values)) + self.assertAllEqual(3, self.evaluate(table2.size())) + + # Verify lookup result is still the same + output2 = table2.lookup(input_keys) + self.assertAllEqual(expected_output, self.evaluate(output2)) + + @test_util.skip_if(SKIP_PASSING) + def test_dynamic_embedding_variable_invalid_shape(self): + with self.session(config=default_config, use_gpu=self.gpu_available): + default_val = constant_op.constant([-1, -1], dtypes.int64) + keys = constant_op.constant([0, 1, 2], dtypes.int64) + + table = de.get_variable( + 't110_test_dynamic_embedding_variable_invalid_shape', + dtypes.int64, + dtypes.int32, + initializer=default_val, + dim=2, + kv_creator=de.RocksDBTableCreator( + conf_with( + embedding_name='t110_dynamic_embedding_variable_invalid_shape' + )), + ) + self.evaluate(table.clear()) + + # Shape [6] instead of [3, 2] + values = constant_op.constant([0, 1, 2, 3, 4, 5], dtypes.int32) + with self.assertRaisesOpError("Expected shape"): + self.evaluate(table.upsert(keys, values)) + + # Shape [2,3] instead of [3, 2] + values = constant_op.constant([[0, 1, 2], [3, 4, 5]], dtypes.int32) + with self.assertRaisesOpError("Expected shape"): + self.evaluate(table.upsert(keys, values)) + + # Shape [2, 2] instead of [3, 2] + values = constant_op.constant([[0, 1], [2, 3]], dtypes.int32) + with self.assertRaisesOpError("Expected shape"): + self.evaluate(table.upsert(keys, values)) + + # 
Shape [3, 1] instead of [3, 2] + values = constant_op.constant([[0], [2], [4]], dtypes.int32) + with self.assertRaisesOpError("Expected shape"): + self.evaluate(table.upsert(keys, values)) + + # Valid Insert + values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int32) + self.evaluate(table.upsert(keys, values)) + self.assertAllEqual(3, self.evaluate(table.size())) + + @test_util.skip_if(SKIP_PASSING) + def test_dynamic_embedding_variable_duplicate_insert(self): + with self.session(config=default_config, use_gpu=self.gpu_available): + default_val = -1 + keys = constant_op.constant([0, 1, 2, 2], dtypes.int64) + values = constant_op.constant([[0.0], [1.0], [2.0], [3.0]], + dtypes.float32) + + table = de.get_variable( + 't130_test_dynamic_embedding_variable_duplicate_insert', + dtypes.int64, + dtypes.float32, + initializer=default_val, + kv_creator=de.RocksDBTableCreator( + conf_with(embedding_name= + 't130_dynamic_embedding_variable_duplicate_insert')), + ) + self.evaluate(table.clear()) + + self.assertAllEqual(0, self.evaluate(table.size())) + + self.evaluate(table.upsert(keys, values)) + self.assertAllEqual(3, self.evaluate(table.size())) + + input_keys = constant_op.constant([0, 1, 2], dtypes.int64) + output = table.lookup(input_keys) + + result = self.evaluate(output) + self.assertTrue( + list(result) in [[[0.0], [1.0], [3.0]], [[0.0], [1.0], [2.0]]]) + + @test_util.skip_if(SKIP_PASSING) + def test_dynamic_embedding_variable_find_high_rank(self): + with self.session(config=default_config, use_gpu=self.gpu_available): + default_val = -1 + keys = constant_op.constant([0, 1, 2], dtypes.int64) + values = constant_op.constant([[0], [1], [2]], dtypes.int32) + + table = de.get_variable( + 't140_test_dynamic_embedding_variable_find_high_rank', + dtypes.int64, + dtypes.int32, + initializer=default_val, + kv_creator=de.RocksDBTableCreator( + conf_with( + embedding_name='t140_dynamic_embedding_variable_find_high_rank' + )), + ) + self.evaluate(table.clear()) + + self.evaluate(table.upsert(keys, values)) + self.assertAllEqual(3, self.evaluate(table.size())) + + input_keys = constant_op.constant([[0, 1], [2, 4]], dtypes.int64) + output = table.lookup(input_keys) + self.assertAllEqual([2, 2, 1], output.get_shape()) + + result = self.evaluate(output) + self.assertAllEqual([[[0], [1]], [[2], [-1]]], result) + + @test_util.skip_if(SKIP_PASSING) + def test_dynamic_embedding_variable_insert_low_rank(self): + with self.session(config=default_config, use_gpu=self.gpu_available): + default_val = -1 + keys = constant_op.constant([[0, 1], [2, 3]], dtypes.int64) + values = constant_op.constant([[[0], [1]], [[2], [3]]], dtypes.int32) + + table = de.get_variable( + 't150_test_dynamic_embedding_variable_insert_low_rank', + dtypes.int64, + dtypes.int32, + initializer=default_val, + kv_creator=de.RocksDBTableCreator( + conf_with(embedding_name= + 't150_dynamic_embedding_variable_insert_low_rank')), + ) + self.evaluate(table.clear()) + + self.evaluate(table.upsert(keys, values)) + self.assertAllEqual(4, self.evaluate(table.size())) + + remove_keys = constant_op.constant([0, 1, 3, 4], dtypes.int64) + output = table.lookup(remove_keys) + + result = self.evaluate(output) + self.assertAllEqual([[0], [1], [3], [-1]], result) + + @test_util.skip_if(SKIP_PASSING) + def test_dynamic_embedding_variable_remove_low_rank(self): + with self.session(config=default_config, use_gpu=self.gpu_available): + default_val = -1 + keys = constant_op.constant([[0, 1], [2, 3]], dtypes.int64) + values = constant_op.constant([[[0], 
[1]], [[2], [3]]], dtypes.int32) + + table = de.get_variable( + 't160_test_dynamic_embedding_variable_remove_low_rank', + dtypes.int64, + dtypes.int32, + initializer=default_val, + kv_creator=de.RocksDBTableCreator( + conf_with(embedding_name= + 't160_dynamic_embedding_variable_remove_low_rank')), + ) + self.evaluate(table.clear()) + + self.evaluate(table.upsert(keys, values)) + self.assertAllEqual(4, self.evaluate(table.size())) + + remove_keys = constant_op.constant([1, 4], dtypes.int64) + self.evaluate(table.remove(remove_keys)) + self.assertAllEqual(3, self.evaluate(table.size())) + + remove_keys = constant_op.constant([0, 1, 3, 4], dtypes.int64) + output = table.lookup(remove_keys) + + result = self.evaluate(output) + self.assertAllEqual([[0], [-1], [3], [-1]], result) + + @test_util.skip_if(SKIP_PASSING) + def test_dynamic_embedding_variable_insert_high_rank(self): + with self.session(config=default_config, use_gpu=self.gpu_available): + default_val = constant_op.constant([-1, -1, -1], dtypes.int32) + keys = constant_op.constant([0, 1, 2], dtypes.int64) + values = constant_op.constant([ + [0, 1, 2], + [2, 3, 4], + [4, 5, 6], + ], dtypes.int32) + + table = de.get_variable( + 't170_test_dynamic_embedding_variable_insert_high_rank', + dtypes.int64, + dtypes.int32, + initializer=default_val, + dim=3, + kv_creator=de.RocksDBTableCreator( + conf_with(embedding_name= + 't170_dynamic_embedding_variable_insert_high_rank')), + ) + self.evaluate(table.clear()) + + self.evaluate(table.upsert(keys, values)) + self.assertAllEqual(3, self.evaluate(table.size())) + + remove_keys = constant_op.constant([[0, 1], [3, 4]], dtypes.int64) + output = table.lookup(remove_keys) + self.assertAllEqual([2, 2, 3], output.get_shape()) + + result = self.evaluate(output) + self.assertAllEqual( + [[[0, 1, 2], [2, 3, 4]], [[-1, -1, -1], [-1, -1, -1]]], + result, + ) + + @test_util.skip_if(SKIP_PASSING) + def test_dynamic_embedding_variable_remove_high_rank(self): + with self.session(config=default_config, use_gpu=self.gpu_available): + default_val = constant_op.constant([-1, -1, -1], dtypes.int32) + keys = constant_op.constant([0, 1, 2], dtypes.int64) + values = constant_op.constant([ + [0, 1, 2], + [2, 3, 4], + [4, 5, 6], + ], dtypes.int32) + + table = de.get_variable( + 't180_test_dynamic_embedding_variable_remove_high_rank', + dtypes.int64, + dtypes.int32, + initializer=default_val, + dim=3, + kv_creator=de.RocksDBTableCreator( + conf_with(embedding_name= + 't180_dynamic_embedding_variable_remove_high_rank')), + ) + self.evaluate(table.clear()) + + self.evaluate(table.upsert(keys, values)) + self.assertAllEqual(3, self.evaluate(table.size())) + + remove_keys = constant_op.constant([[0, 3]], dtypes.int64) + self.evaluate(table.remove(remove_keys)) + self.assertAllEqual(2, self.evaluate(table.size())) + + remove_keys = constant_op.constant([[0, 1], [2, 3]], dtypes.int64) + output = table.lookup(remove_keys) + self.assertAllEqual([2, 2, 3], output.get_shape()) + + result = self.evaluate(output) + self.assertAllEqual( + [[[-1, -1, -1], [2, 3, 4]], [[4, 5, 6], [-1, -1, -1]]], + result, + ) + + @test_util.skip_if(SKIP_PASSING) + def test_dynamic_embedding_variables(self): + with self.session(config=default_config, use_gpu=self.gpu_available): + default_val = -1 + keys = constant_op.constant([0, 1, 2], dtypes.int64) + values = constant_op.constant([[0], [1], [2]], dtypes.int32) + + table1 = de.get_variable( + 't191_test_dynamic_embedding_variables', + dtypes.int64, + dtypes.int32, + initializer=default_val, + 
kv_creator=de.RocksDBTableCreator( + conf_with(embedding_name='t191_dynamic_embedding_variables')), + ) + table2 = de.get_variable( + 't192_test_dynamic_embedding_variables', + dtypes.int64, + dtypes.int32, + initializer=default_val, + kv_creator=de.RocksDBTableCreator( + conf_with(embedding_name='t192_dynamic_embedding_variables')), + ) + table3 = de.get_variable( + 't193_test_dynamic_embedding_variables', + dtypes.int64, + dtypes.int32, + initializer=default_val, + kv_creator=de.RocksDBTableCreator( + conf_with(embedding_name='t193_dynamic_embedding_variables')), + ) + self.evaluate(table1.clear()) + self.evaluate(table2.clear()) + self.evaluate(table3.clear()) + + self.evaluate(table1.upsert(keys, values)) + self.evaluate(table2.upsert(keys, values)) + self.evaluate(table3.upsert(keys, values)) + + self.assertAllEqual(3, self.evaluate(table1.size())) + self.assertAllEqual(3, self.evaluate(table2.size())) + self.assertAllEqual(3, self.evaluate(table3.size())) + + remove_keys = constant_op.constant([0, 1, 3], dtypes.int64) + output1 = table1.lookup(remove_keys) + output2 = table2.lookup(remove_keys) + output3 = table3.lookup(remove_keys) + + out1, out2, out3 = self.evaluate([output1, output2, output3]) + self.assertAllEqual([[0], [1], [-1]], out1) + self.assertAllEqual([[0], [1], [-1]], out2) + self.assertAllEqual([[0], [1], [-1]], out3) + + @test_util.skip_if(SKIP_PASSING) + def test_dynamic_embedding_variable_with_tensor_default(self): + with self.session(config=default_config, use_gpu=self.gpu_available): + default_val = constant_op.constant(-1, dtypes.int32) + keys = constant_op.constant([0, 1, 2], dtypes.int64) + values = constant_op.constant([[0], [1], [2]], dtypes.int32) + + table = de.get_variable( + 't200_test_dynamic_embedding_variable_with_tensor_default', + dtypes.int64, + dtypes.int32, + initializer=default_val, + kv_creator=de.RocksDBTableCreator( + conf_with(embedding_name= + 't200_dynamic_embedding_variable_with_tensor_default')), + ) + self.evaluate(table.clear()) + + self.evaluate(table.upsert(keys, values)) + self.assertAllEqual(3, self.evaluate(table.size())) + + remove_keys = constant_op.constant([0, 1, 3], dtypes.int64) + output = table.lookup(remove_keys) + + result = self.evaluate(output) + self.assertAllEqual([[0], [1], [-1]], result) + + @test_util.skip_if(SKIP_PASSING) + def test_signature_mismatch(self): + config = config_pb2.ConfigProto() + config.allow_soft_placement = True + config.gpu_options.allow_growth = True + + with self.session(config=config, use_gpu=self.gpu_available): + default_val = -1 + keys = constant_op.constant([0, 1, 2], dtypes.int64) + values = constant_op.constant([[0], [1], [2]], dtypes.int32) + + table = de.get_variable( + 't210_test_signature_mismatch', + dtypes.int64, + dtypes.int32, + initializer=default_val, + kv_creator=de.RocksDBTableCreator( + conf_with(embedding_name='t210_signature_mismatch')), + ) + self.evaluate(table.clear()) + + # upsert with keys of the wrong type + with self.assertRaises(ValueError): + self.evaluate( + table.upsert(constant_op.constant([4.0, 5.0, 6.0], dtypes.float32), + values)) + + # upsert with values of the wrong type + with self.assertRaises(ValueError): + self.evaluate(table.upsert(keys, constant_op.constant(["a", "b", "c"]))) + + self.assertAllEqual(0, self.evaluate(table.size())) + + self.evaluate(table.upsert(keys, values)) + self.assertAllEqual(3, self.evaluate(table.size())) + + remove_keys_ref = variables.Variable(0, dtype=dtypes.int64) + input_int64_ref = variables.Variable([-1], 
dtype=dtypes.int32) + self.evaluate(variables.global_variables_initializer()) + + # Ref types do not produce an upsert signature mismatch. + self.evaluate(table.upsert(remove_keys_ref, input_int64_ref)) + self.assertAllEqual(3, self.evaluate(table.size())) + + # Ref types do not produce a lookup signature mismatch. + self.assertEqual([-1], self.evaluate(table.lookup(remove_keys_ref))) + + # lookup with keys of the wrong type + remove_keys = constant_op.constant([1, 2, 3], dtypes.int32) + with self.assertRaises(ValueError): + self.evaluate(table.lookup(remove_keys)) + + @test_util.skip_if(SKIP_PASSING) + def test_dynamic_embedding_variable_int_float(self): + with self.session(config=default_config, use_gpu=self.gpu_available): + default_val = -1.0 + keys = constant_op.constant([3, 7, 0], dtypes.int64) + values = constant_op.constant([[7.5], [-1.2], [9.9]], dtypes.float32) + table = de.get_variable( + 't220_test_dynamic_embedding_variable_int_float', + dtypes.int64, + dtypes.float32, + initializer=default_val, + kv_creator=de.RocksDBTableCreator( + conf_with( + embedding_name='t220_dynamic_embedding_variable_int_float')), + ) + self.evaluate(table.clear()) + + self.assertAllEqual(0, self.evaluate(table.size())) + + self.evaluate(table.upsert(keys, values)) + self.assertAllEqual(3, self.evaluate(table.size())) + + remove_keys = constant_op.constant([7, 0, 11], dtypes.int64) + output = table.lookup(remove_keys) + + result = self.evaluate(output) + self.assertAllClose([[-1.2], [9.9], [default_val]], result) + + @test_util.skip_if(SKIP_PASSING) + def test_dynamic_embedding_variable_with_random_init(self): + with self.session(config=default_config, use_gpu=self.gpu_available): + keys = constant_op.constant([0, 1, 2], dtypes.int64) + values = constant_op.constant([[0.0], [1.0], [2.0]], dtypes.float32) + default_val = init_ops.random_uniform_initializer() + + table = de.get_variable( + 't230_test_dynamic_embedding_variable_with_random_init', + dtypes.int64, + dtypes.float32, + initializer=default_val, + kv_creator=de.RocksDBTableCreator( + conf_with(embedding_name= + 't230_dynamic_embedding_variable_with_random_init')), + ) + self.evaluate(table.clear()) + + self.evaluate(table.upsert(keys, values)) + self.assertAllEqual(3, self.evaluate(table.size())) + + remove_keys = constant_op.constant([0, 1, 3], dtypes.int64) + output = table.lookup(remove_keys) + + result = self.evaluate(output) + self.assertNotEqual([-1.0], result[2]) + + @test_util.skip_if(SKIP_FAILING_WITH_QUESTIONS) + def test_dynamic_embedding_variable_with_restrict_v1(self): + if context.executing_eagerly(): + self.skipTest('skip eager test when using legacy optimizers.') + + optmz = de.DynamicEmbeddingOptimizer(adam.AdamOptimizer(0.1)) + data_len = 32 + maxval = 256 + num_reserved = 100 + trigger = 150 + embed_dim = 8 + + # TODO: Should these use the same embedding or independent embeddings? + # TODO: These tests do something odd. They write 32 byte entries to the table, but + # then expect the responses to be 4 bytes. Is there a bug in TFRA? + # >> See LOG(WARNING) outputs I added. + # TODO: Will fail with TF2. 
+ var_guard_by_tstp = de.get_variable( + 'tstp_guard' + '_test_dynamic_embedding_variable_with_restrict_v1', + key_dtype=dtypes.int64, + value_dtype=dtypes.float32, + initializer=-1., + dim=embed_dim, + init_size=256, + restrict_policy=de.TimestampRestrictPolicy, + kv_creator=de.RocksDBTableCreator( + conf_with( + embedding_name='dynamic_embedding_variable_with_restrict_v1')), + ) + self.evaluate(var_guard_by_tstp.clear()) + + var_guard_by_freq = de.get_variable( + 'freq_guard' + '_test_dynamic_embedding_variable_with_restrict_v1', + key_dtype=dtypes.int64, + value_dtype=dtypes.float32, + initializer=-1., + dim=embed_dim, + init_size=256, + restrict_policy=de.FrequencyRestrictPolicy, + kv_creator=de.RocksDBTableCreator( + conf_with( + embedding_name='dynamic_embedding_variable_with_restrict_v1')), + ) + self.evaluate(var_guard_by_freq.clear()) + + sparse_vars = [var_guard_by_tstp, var_guard_by_freq] + + indices = [data_fn((data_len, 1), maxval) for _ in range(3)] + _, trainables, loss = model_fn(sparse_vars, embed_dim, indices) + train_op = optmz.minimize(loss, var_list=trainables) + + var_sizes = [0, 0] + self.evaluate(variables.global_variables_initializer()) + + while not all(sz > trigger for sz in var_sizes): + self.evaluate(train_op) + var_sizes = self.evaluate([spv.size() for spv in sparse_vars]) + + self.assertTrue(all(sz >= trigger for sz in var_sizes)) + tstp_restrict_op = var_guard_by_tstp.restrict(num_reserved, trigger=trigger) + if tstp_restrict_op != None: + self.evaluate(tstp_restrict_op) + freq_restrict_op = var_guard_by_freq.restrict(num_reserved, trigger=trigger) + if freq_restrict_op != None: + self.evaluate(freq_restrict_op) + var_sizes = self.evaluate([spv.size() for spv in sparse_vars]) + self.assertAllEqual(var_sizes, [num_reserved, num_reserved]) + + slot_params = [] + for _trainable in trainables: + slot_params += [ + optmz.get_slot(_trainable, name).params + for name in optmz.get_slot_names() + ] + slot_params = list(set(slot_params)) + + for sp in slot_params: + self.assertAllEqual(self.evaluate(sp.size()), num_reserved) + tstp_size = self.evaluate(var_guard_by_tstp.restrict_policy.status.size()) + self.assertAllEqual(tstp_size, num_reserved) + freq_size = self.evaluate(var_guard_by_freq.restrict_policy.status.size()) + self.assertAllEqual(freq_size, num_reserved) + + # @test_util.skip_if(SKIP_PASSING_WITH_QUESTIONS) + @test_util.skip_if(SKIP_FAILING) + def test_dynamic_embedding_variable_with_restrict_v2(self): + if not context.executing_eagerly(): + self.skipTest('Test in eager mode only.') + + optmz = de.DynamicEmbeddingOptimizer(optimizer_v2.adam.Adam(0.1)) + data_len = 32 + maxval = 256 + num_reserved = 100 + trigger = 150 + embed_dim = 8 + trainables = [] + + # TODO: Should these use the same embedding or independent embeddings? + # TODO: These tests do something odd. They write 32 byte entries to the table, but + # then expect the responses to be 4 bytes. Is there a bug in TFRA? + # >> See LOG(WARNING) outputs I added. 
+ var_guard_by_tstp = de.get_variable( + 'tstp_guard' + '_test_dynamic_embedding_variable_with_restrict_v2', + key_dtype=dtypes.int64, + value_dtype=dtypes.float32, + initializer=-1., + dim=embed_dim, + restrict_policy=de.TimestampRestrictPolicy, + kv_creator=de.RocksDBTableCreator( + conf_with( + embedding_name='dynamic_embedding_variable_with_restrict_v2')), + ) + self.evaluate(var_guard_by_tstp.clear()) + + var_guard_by_freq = de.get_variable( + 'freq_guard' + '_test_dynamic_embedding_variable_with_restrict_v2', + key_dtype=dtypes.int64, + value_dtype=dtypes.float32, + initializer=-1., + dim=embed_dim, + restrict_policy=de.FrequencyRestrictPolicy, + kv_creator=de.RocksDBTableCreator( + conf_with( + embedding_name='dynamic_embedding_variable_with_restrict_v2')), + ) + self.evaluate(var_guard_by_freq.clear()) + + sparse_vars = [var_guard_by_tstp, var_guard_by_freq] + + def loss_fn(sparse_vars, trainables): + indices = [data_fn((data_len, 1), maxval) for _ in range(3)] + _, tws, loss = model_fn(sparse_vars, embed_dim, indices) + trainables.clear() + trainables.extend(tws) + return loss + + def var_fn(): + return trainables + + var_sizes = [0, 0] + + while not all(sz > trigger for sz in var_sizes): + optmz.minimize(lambda: loss_fn(sparse_vars, trainables), var_fn) + var_sizes = [spv.size() for spv in sparse_vars] + + self.assertTrue(all(sz >= trigger for sz in var_sizes)) + var_guard_by_tstp.restrict(num_reserved, trigger=trigger) + var_guard_by_freq.restrict(num_reserved, trigger=trigger) + var_sizes = [spv.size() for spv in sparse_vars] + self.assertAllEqual(var_sizes, [num_reserved, num_reserved]) + + slot_params = [] + for _trainable in trainables: + slot_params += [ + optmz.get_slot(_trainable, name).params + for name in optmz.get_slot_names() + ] + slot_params = list(set(slot_params)) + + for sp in slot_params: + self.assertAllEqual(sp.size(), num_reserved) + self.assertAllEqual(var_guard_by_tstp.restrict_policy.status.size(), + num_reserved) + self.assertAllEqual(var_guard_by_freq.restrict_policy.status.size(), + num_reserved) + + +if __name__ == "__main__": + if DELETE_DATABASE_AT_STARTUP: + shutil.rmtree(ROCKSDB_CONFIG_PARAMS['database_path'], ignore_errors=True) + test.main() diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/BUILD b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/BUILD index add01f623..5c898af00 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/BUILD +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/BUILD @@ -15,6 +15,7 @@ py_library( "//tensorflow_recommenders_addons/dynamic_embedding/core:_cuckoo_hashtable_ops.so", "//tensorflow_recommenders_addons/dynamic_embedding/core:_math_ops.so", "//tensorflow_recommenders_addons/dynamic_embedding/core:_redis_table_ops.so", + "//tensorflow_recommenders_addons/dynamic_embedding/core:_rocksdb_table_ops.so", ], srcs_version = "PY2AND3", ) diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_creator.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_creator.py index 3b7656ea7..9f18576f6 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_creator.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_creator.py @@ -1,4 +1,4 @@ -# Copyright 2020 The TensorFlow Recommenders-Addons Authors. +# Copyright 2021 The TensorFlow Recommenders-Addons Authors. 
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@
 from tensorflow.python.eager import context
 from tensorflow.python.ops import gen_parsing_ops
 from tensorflow_recommenders_addons import dynamic_embedding as de
+import json
 
 
 class KVCreator(object, metaclass=ABCMeta):
@@ -194,3 +195,64 @@ def create(
       checkpoint=checkpoint,
       config=self.config,
     )
+
+
+class RocksDBTableConfig(object):
+  """
+  RocksDBTableConfig loads the configuration of a RocksDB table from a JSON
+  file (or an equivalent dict). An example of a configuration file is shown
+  below:
+
+  ```
+  {
+    // Filesystem path under which the database is stored.
+    "database_path": "/tmp/path_to_the_database",
+
+    // Name of this embedding. We use RocksDB column families for this.
+    "embedding_name": "name_of_this_embedding",
+
+    // If true, the database is opened in read-only mode. Having multiple
+    // read-only connections to the same database is possible.
+    "read_only": false,
+
+    // If true, size() will only return estimates, which is orders of
+    // magnitude faster but could be inaccurate.
+    "estimate_size": false,
+
+    // If set, export/import will dump/restore the database to/from the
+    // filesystem.
+    "export_path": "/tmp/some_path"
+  }
+  ```
+  """
+
+  def __init__(
+      self,
+      src="/tmp/rocksdb_config.json",
+  ):
+    if isinstance(src, str):
+      with open(src, 'r', encoding='utf-8') as src_file:
+        self.params = json.load(src_file)
+    elif isinstance(src, dict):
+      self.params = dict(src)
+    else:
+      raise ValueError(
+          f"src must be a filename (str) or a dict, but got {type(src)}.")
+
+
+class RocksDBTableCreator(KVCreator):
+  """
+  RocksDBTableCreator will create an object to pass itself to the other
+  classes for creating a real RocksDB client instance which can interact
+  with TF.
+  """
+
+  def create(
+      self,
+      key_dtype=None,
+      value_dtype=None,
+      default_value=None,
+      name=None,
+      checkpoint=None,
+      init_size=None,
+      config=None,
+  ):
+    real_config = config if config is not None else self.config
+    if not isinstance(real_config, RocksDBTableConfig):
+      raise TypeError("config should be an instance of RocksDBTableConfig, "
+                      "but got " + str(type(real_config)))
+    return de.RocksDBTable(
+        key_dtype=key_dtype,
+        value_dtype=value_dtype,
+        default_value=default_value,
+        name=name,
+        checkpoint=checkpoint,
+        config=real_config,
+    )
diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/rocksdb_table_ops.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/rocksdb_table_ops.py
new file mode 100644
index 000000000..da3c72a2b
--- /dev/null
+++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/rocksdb_table_ops.py
@@ -0,0 +1,352 @@
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""RocksDB Lookup operations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+
+from tensorflow.python.eager import context
+from tensorflow.python.framework import ops
+from tensorflow.python.ops.lookup_ops import LookupInterface
+from tensorflow.python.training.saver import BaseSaverBuilder
+
+from tensorflow_recommenders_addons.utils.resource_loader import LazySO
+from tensorflow_recommenders_addons.utils.resource_loader import prefix_op_name
+
+rocksdb_table_ops = LazySO("dynamic_embedding/core/_rocksdb_table_ops.so").ops
+
+
+class RocksDBTable(LookupInterface):
+  """
+  Transparently redirects lookups to a RocksDB database.
+
+  Data can be inserted by calling the insert method and removed by calling the
+  remove method. Initialization via the init method is not supported.
+
+  Example usage:
+
+  ```python
+  config = tfra.dynamic_embedding.RocksDBTableConfig("/tmp/rocksdb_config.json")
+  table = tfra.dynamic_embedding.RocksDBTable(key_dtype=tf.string,
+                                              value_dtype=tf.int64,
+                                              default_value=-1,
+                                              config=config)
+  sess.run(table.insert(keys, values))
+  out = table.lookup(query_keys)
+  print(out.eval())
+  ```
+  """
+
+  default_rocksdb_params = {"model_lib_abs_dir": "/tmp/"}
+
+  def __init__(
+      self,
+      key_dtype,
+      value_dtype,
+      default_value,
+      name="RocksDBTable",
+      checkpoint=False,
+      config=None,
+  ):
+    """
+    Creates an empty `RocksDBTable` object.
+
+    The database connection is specified through `config`; the types of the
+    keys and values are specified by `key_dtype` and `value_dtype`,
+    respectively.
+
+    Args:
+      key_dtype: the type of the key tensors.
+      value_dtype: the type of the value tensors.
+      default_value: The value to use if a key is missing in the table.
+      name: A name for the operation (optional, usually the embedding table
+        name).
+      checkpoint: if True, the contents of the table are saved to and restored
+        from RocksDB binary dump files under the path
+        "[model_lib_abs_dir]/[model_tag]/[name].rdb". If `shared_name` is
+        empty for a checkpointed table, it is shared using the table node name.
+      config: A `RocksDBTableConfig` object holding the database settings.
+
+    Returns:
+      A `RocksDBTable` object.
+
+    Raises:
+      ValueError: If checkpoint is True and no name was specified.
+    """
+
+    self._default_value = ops.convert_to_tensor(default_value,
+                                                dtype=value_dtype)
+    self._value_shape = self._default_value.get_shape()
+    self._checkpoint = checkpoint
+    self._key_dtype = key_dtype
+    self._value_dtype = value_dtype
+    self._name = name
+
+    self._database_path = config.params['database_path']
+    self._embedding_name = config.params['embedding_name']
+    if not self._embedding_name:
+      self._embedding_name = self._name.split('_mht_', 1)[0]
+    self._read_only = config.params['read_only']
+    self._estimate_size = config.params['estimate_size']
+    self._export_path = config.params['export_path']
+
+    self._shared_name = None
+    if context.executing_eagerly():
+      # TODO(allenl): This will leak memory due to kernel caching by the
+      # shared_name attribute value (but is better than the alternative of
+      # sharing everything by default when executing eagerly; hopefully
+      # creating tables in a loop is uncommon).
+      # TODO(rohanj): Use context.shared_name() instead.
+ self._shared_name = "table_%d" % (ops.uid(),) + super().__init__(key_dtype, value_dtype) + + self._resource_handle = self._create_resource() + if checkpoint: + _ = self._Saveable(self, name) + if not context.executing_eagerly(): + self.saveable = self._Saveable( + self, + name=self._resource_handle.op.name, + full_name=self._resource_handle.op.name, + ) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self.saveable) + else: + self.saveable = self._Saveable(self, name=name, full_name=name) + + def _create_resource(self): + # The table must be shared if checkpointing is requested for multi-worker + # training to work correctly. Use the node name if no shared_name has been + # explicitly specified. + use_node_name_sharing = self._checkpoint and self._shared_name is None + + table_ref = rocksdb_table_ops.tfra_rocksdb_table_of_tensors( + shared_name=self._shared_name, + use_node_name_sharing=use_node_name_sharing, + key_dtype=self._key_dtype, + value_dtype=self._value_dtype, + value_shape=self._default_value.get_shape(), + database_path=self._database_path, + embedding_name=self._embedding_name, + read_only=self._read_only, + estimate_size=self._estimate_size, + export_path=self._export_path, + ) + + if context.executing_eagerly(): + self._table_name = None + else: + self._table_name = table_ref.op.name.split("/")[-1] + return table_ref + + @property + def name(self): + return self._table_name + + def size(self, name=None): + """ + Compute the number of elements in this table. + + Args: + name: A name for the operation (optional). + + Returns: + A scalar tensor containing the number of elements in this table. + """ + with ops.name_scope(name, f"{self.name}_Size", (self.resource_handle,)): + with ops.colocate_with(self.resource_handle): + size = rocksdb_table_ops.tfra_rocksdb_table_size(self.resource_handle) + + return size + + def remove(self, keys, name=None): + """ + Removes `keys` and its associated values from the table. + + If a key is not present in the table, it is silently ignored. + + Args: + keys: Keys to remove. Can be a tensor of any shape. Must match the table's key type. + name: A name for the operation (optional). + + Returns: + The created Operation. + + Raises: + TypeError: when `keys` do not match the table data types. + """ + if keys.dtype != self._key_dtype: + raise TypeError( + f"Signature mismatch. Keys must be dtype {self._key_dtype}, got {keys.dtype}." + ) + + with ops.name_scope( + name, + f"{self.name}_lookup_table_remove", + (self.resource_handle, keys, self._default_value), + ): + op = rocksdb_table_ops.tfra_rocksdb_table_remove(self.resource_handle, + keys) + + return op + + def clear(self, name=None): + """ + Clear all keys and values in the table. + + Args: + name: A name for the operation (optional). + + Returns: + The created Operation. + """ + with ops.name_scope(name, f"{self.name}_lookup_table_clear", + (self.resource_handle, self._default_value)): + op = rocksdb_table_ops.tfra_rocksdb_table_clear( + self.resource_handle, + key_dtype=self._key_dtype, + value_dtype=self._value_dtype) + + return op + + def lookup(self, + keys, + dynamic_default_values=None, + return_exists=False, + name=None): + """ + Looks up `keys` in a table, outputs the corresponding values. + + The `default_value` is used for keys not present in the table. + + Args: + keys: Keys to look up. Can be a tensor of any shape. Must match the + table's key_dtype. + dynamic_default_values: The values to use if a key is missing in the table. 
If None (by
+          default), the static default_value `self._default_value` will be
+          used.
+      return_exists: if True, will return an additional Tensor indicating
+        whether each key exists in the table.
+      name: A name for the operation (optional).
+
+    Returns:
+      A tensor containing the values in the same shape as `keys` using the
+        table's value type.
+
+    Raises:
+      TypeError: when `keys` do not match the table data types.
+    """
+    with ops.name_scope(name, f"{self.name}_lookup_table_find",
+                        (self.resource_handle, keys, self._default_value)):
+      keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
+      with ops.colocate_with(self.resource_handle):
+        if return_exists:
+          values, exists = rocksdb_table_ops.tfra_rocksdb_table_find_with_exists(
+              self.resource_handle,
+              keys,
+              dynamic_default_values
+              if dynamic_default_values is not None else self._default_value,
+          )
+        else:
+          values = rocksdb_table_ops.tfra_rocksdb_table_find(
+              self.resource_handle,
+              keys,
+              dynamic_default_values
+              if dynamic_default_values is not None else self._default_value,
+          )
+    return (values, exists) if return_exists else values
+
+  def insert(self, keys, values, name=None):
+    """
+    Associates `keys` with `values`.
+
+    Args:
+      keys: Keys to insert. Can be a tensor of any shape. Must match the
+        table's key type.
+      values: Values to be associated with keys. Must be a tensor of the same
+        shape as `keys` and match the table's value type.
+      name: A name for the operation (optional).
+
+    Returns:
+      The created Operation.
+
+    Raises:
+      TypeError: when `keys` or `values` doesn't match the table data types.
+    """
+    with ops.name_scope(name, f"{self.name}_lookup_table_insert",
+                        (self.resource_handle, keys, values)):
+      keys = ops.convert_to_tensor(keys, self._key_dtype, name="keys")
+      values = ops.convert_to_tensor(values, self._value_dtype, name="values")
+
+      with ops.colocate_with(self.resource_handle):
+        op = rocksdb_table_ops.tfra_rocksdb_table_insert(
+            self.resource_handle, keys, values)
+
+    return op
+
+  def export(self, name=None):
+    """
+    Exports all keys and values in the table. In the RocksDB implementation,
+    if `export_path` is set, the database is instead dumped to binary files
+    on the filesystem, and the returned tensors may be empty.
+
+    Args:
+      name: A name for the operation (optional).
+
+    Returns:
+      A pair of tensors with the first tensor containing all keys and the
+        second tensor containing all values in the table.
+    """
+    with ops.name_scope(name, f"{self.name}_lookup_table_export_values",
+                        (self.resource_handle,)):
+      with ops.colocate_with(self.resource_handle):
+        exported_keys, exported_values = rocksdb_table_ops.tfra_rocksdb_table_export(
+            self.resource_handle, self._key_dtype, self._value_dtype)
+
+    return exported_keys, exported_values
+
+  def _gather_saveables_for_checkpoint(self):
+    """For object-based checkpointing."""
+    # full_name helps to figure out the name-based Saver's name for this
+    # saveable.
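+    # NOTE: `_table_name` is None when executing eagerly (see
+    #   `_create_resource`), so this name-based full_name is only meaningful
+    #   in graph mode.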
+ # if context.executing_eagerly(): + # full_name = self._table_name + # else: + # full_name = self._resource_handle.op.name + full_name = self._table_name + return { + "table": + functools.partial( + RocksDBTable._Saveable, + table=self, + name=self._name, + full_name=full_name, + ) + } + + class _Saveable(BaseSaverBuilder.SaveableObject): + """SaveableObject implementation for RocksDBTable.""" + + def __init__(self, table, name, full_name=""): + tensors = table.export() + specs = [ + BaseSaverBuilder.SaveSpec(tensors[0], "", name + "-keys"), + BaseSaverBuilder.SaveSpec(tensors[1], "", name + "-values"), + ] + super().__init__(table, specs, name) + self._restore_name = table._name + + def restore(self, restored_tensors, restored_shapes, name=None): + del restored_shapes # unused + # pylint: disable=protected-access + with ops.name_scope(name, f"{self._restore_name}_table_restore"): + with ops.colocate_with(self.op.resource_handle): + return rocksdb_table_ops.tfra_rocksdb_table_import( + self.op.resource_handle, + restored_tensors[0], + restored_tensors[1], + ) + + +ops.NotDifferentiable(prefix_op_name("RocksDBTableOfTensors")) diff --git a/tools/docker/install/install_rocksdb.sh b/tools/docker/install/install_rocksdb.sh index 5c17ed887..41d33bf7c 100755 --- a/tools/docker/install/install_rocksdb.sh +++ b/tools/docker/install/install_rocksdb.sh @@ -51,7 +51,6 @@ cd /tmp/rocksdb-$ROCKSDB_VERSION DEBUG_LEVEL=0 make static_lib -j \ EXTRA_CXXFLAGS="-fPIC -D_GLIBCXX_USE_CXX11_ABI=0" \ EXTRA_CFLAGS="-fPIC -D_GLIBCXX_USE_CXX11_ABI=0" -chmod -R 777 /tmp/rocksdb-$ROCKSDB_VERSION/librocksdb* -cp /tmp/rocksdb-$ROCKSDB_VERSION/librocksdb* ${install_dir} +make install rm -f /tmp/$ROCKSDB_VERSION.tar.gz rm -rf /tmp/rocksdb-${ROCKSDB_VERSION} diff --git a/tools/docker/sanity_check.Dockerfile b/tools/docker/sanity_check.Dockerfile index dd6a4b753..c7e0588dc 100644 --- a/tools/docker/sanity_check.Dockerfile +++ b/tools/docker/sanity_check.Dockerfile @@ -24,7 +24,7 @@ RUN --mount=type=cache,id=cache_pip,target=/root/.cache/pip \ -r typedapi.txt \ -r pytest.txt -RUN apt-get update && apt-get install -y sudo rsync cmake +RUN apt-get update && apt-get install -y sudo rsync cmake libbz2-dev liblz4-dev libzstd-dev COPY tools/docker/install/install_bazel.sh ./ RUN ./install_bazel.sh $USE_BAZEL_VERSION @@ -105,7 +105,7 @@ RUN pip install -r requirements.txt COPY tools/install_deps/doc_requirements.txt ./ RUN pip install -r doc_requirements.txt -RUN apt-get update && apt-get install -y sudo rsync cmake +RUN apt-get update && apt-get install -y sudo rsync cmake libbz2-dev liblz4-dev libzstd-dev COPY tools/docker/install/install_bazel.sh ./ RUN ./install_bazel.sh $USE_BAZEL_VERSION
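
For reviewers, here is a minimal end-to-end sketch of the API this patch adds.
All concrete values below (database path, embedding name, dim) are
illustrative, and the assumption that an empty "export_path" behaves as "not
set" (no filesystem dump/restore) is mine, not part of the patch:

```python
import tensorflow as tf
from tensorflow_recommenders_addons import dynamic_embedding as de

# Illustrative config; the keys mirror the documented rocksdb_config.json
# schema. A filename can be passed to RocksDBTableConfig instead of a dict.
config = de.RocksDBTableConfig({
    "database_path": "/tmp/de_rocksdb_demo",  # hypothetical scratch path
    "embedding_name": "demo_embedding",
    "read_only": False,
    "estimate_size": False,  # size() stays exact, but slower
    "export_path": "",  # assumed: empty means no filesystem dump/restore
})

table = de.get_variable(
    "demo_embedding",
    key_dtype=tf.int64,
    value_dtype=tf.float32,
    initializer=-1.0,
    dim=8,
    kv_creator=de.RocksDBTableCreator(config),
)

keys = tf.constant([1, 2, 3], dtype=tf.int64)
values = tf.random.normal([3, 8])
table.upsert(keys, values)  # writes through to the RocksDB column family
print(table.lookup(keys))   # reads the stored embeddings back
print(table.size())         # 3
```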