Skip to content

Commit 96fc416

Browse files
authored
[Feat] Try to be compatible with the Keras 3 optimizer design and support CUDNN 9.0+. (#392)
* [fix] self.params may not have a saveable attribute. * [feat] Update the config file for the new CUDNN release. Now supports CUDNN 9.0+. * [fix] Suppress nodiscard warnings and unused warnings by adding a new LOG_IF_ERROR macro and using the proper API function. * [feat] Compatible with the Keras 3 optimizer style.
1 parent ba467f4 commit 96fc416

File tree

7 files changed

+57
-42
lines changed

7 files changed

+57
-42
lines changed

build_deps/toolchains/gpu/cuda_configure.bzl

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -960,19 +960,28 @@ def _create_local_cuda_repository(repository_ctx):
960960
# Copy cudnn.h if cuDNN was not installed to CUDA_TOOLKIT_PATH.
961961
included_files = _read_dir(repository_ctx, cuda_include_path)
962962
if not any([file.endswith("cudnn.h") for file in included_files]):
963-
if [int(x) for x in cuda_config.cudnn_version.split(".")] < [8, 0]:
964-
cudnn_headers = ["cudnn.h"]
965-
else:
966-
cudnn_headers = [
963+
cudnn_headers = ["cudnn.h"]
964+
if cuda_config.cudnn_version.rsplit("_", 1)[-1] >= "9":
965+
cudnn_headers += [
966+
"cudnn_adv.h",
967+
"cudnn_backend.h",
968+
"cudnn_cnn.h",
969+
"cudnn_graph.h",
970+
"cudnn_ops.h",
971+
"cudnn_version.h",
972+
]
973+
elif cuda_config.cudnn_version.rsplit("_", 1)[-1] >= "8":
974+
cudnn_headers += [
975+
"cudnn_backend.h",
967976
"cudnn_adv_infer.h",
968977
"cudnn_adv_train.h",
969978
"cudnn_cnn_infer.h",
970979
"cudnn_cnn_train.h",
971980
"cudnn_ops_infer.h",
972981
"cudnn_ops_train.h",
973-
"cudnn.h",
974982
"cudnn_version.h",
975983
]
984+
976985
cudnn_srcs = []
977986
cudnn_outs = []
978987
for header in cudnn_headers:

tensorflow_recommenders_addons/dynamic_embedding/core/kernels/dynamic_partition_op_gpu.cu.cc

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -119,11 +119,7 @@ void MoveValues(const GPUDevice& d, int32* keys, int32* values, int32* num_runs,
119119
values, num_runs, out_size, out));
120120
}
121121

122-
struct IdentityOp {
123-
__device__ int32 __forceinline__ operator()(const int32& a) const {
124-
return a;
125-
}
126-
};
122+
struct IdentityOp {};
127123

128124
// Define an output iterator that only allows assignment to
129125
// positions between [base, base + limit).
@@ -162,27 +158,10 @@ class BoundedOutputIterator
162158
IdentityOp op, int32 size)
163159
: TransformOutputIterator(ptr, op), limit(size), base(base) {}
164160

165-
// Indirection
166-
__host__ __device__ __forceinline__ reference operator*() const {
167-
return BoundedReference(ptr, base, conversion_op, limit);
168-
}
169-
170161
// Array subscript
171162
__host__ __device__ __forceinline__ reference operator[](int32 n) const {
172163
return BoundedReference(ptr + n, base, conversion_op, limit);
173164
}
174-
175-
// Addition
176-
__host__ __device__ __forceinline__ self_type operator+(int32 n) const {
177-
self_type retval(ptr + n, base, conversion_op, limit);
178-
return retval;
179-
}
180-
181-
// Subtraction
182-
__host__ __device__ __forceinline__ self_type operator-(int32 n) const {
183-
self_type retval(ptr - n, base, conversion_op, limit);
184-
return retval;
185-
}
186165
};
187166

188167
} // namespace

tensorflow_recommenders_addons/dynamic_embedding/core/kernels/hkv_hashtable_op_gpu.cu.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -839,7 +839,7 @@ class HashTableExportKeysAndScoresGpuOp : public OpKernel {
839839
public:
840840
explicit HashTableExportKeysAndScoresGpuOp(OpKernelConstruction* ctx)
841841
: OpKernel(ctx) {
842-
ctx->GetAttr("split_size", &split_size_i64_);
842+
OP_REQUIRES_OK(ctx, ctx->GetAttr("split_size", &split_size_i64_));
843843
}
844844

845845
void Compute(OpKernelContext* ctx) override {

tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_hkv.h

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -207,10 +207,10 @@ class RandomKVFile : public nv::merlin::BaseKVFile<K, V, S> {
207207

208208
void close() {
209209
if (key_writer_) {
210-
key_writer_->Flush();
210+
TFRA_LOG_IF_ERROR(key_writer_->Flush());
211211
}
212212
if (value_writer_) {
213-
value_writer_->Flush();
213+
TFRA_LOG_IF_ERROR(value_writer_->Flush());
214214
}
215215
}
216216

@@ -222,8 +222,9 @@ class RandomKVFile : public nv::merlin::BaseKVFile<K, V, S> {
222222
key_buffer_.reserve(key_read_byte);
223223
value_buffer_.reserve(value_read_byte);
224224

225-
key_reader_->ReadNBytes(key_read_byte, &key_buffer_);
226-
value_reader_->ReadNBytes(value_read_byte, &value_buffer_);
225+
TFRA_LOG_IF_ERROR(key_reader_->ReadNBytes(key_read_byte, &key_buffer_));
226+
TFRA_LOG_IF_ERROR(
227+
value_reader_->ReadNBytes(value_read_byte, &value_buffer_));
227228

228229
memcpy((char*)keys, key_buffer_.data(), key_buffer_.size());
229230
memcpy((char*)vectors, value_buffer_.data(), value_buffer_.size());
@@ -237,8 +238,10 @@ class RandomKVFile : public nv::merlin::BaseKVFile<K, V, S> {
237238
size_t key_write_byte = n * sizeof(K);
238239
size_t value_write_byte = n * sizeof(V) * value_dim_;
239240

240-
key_writer_->Append(StringPiece((char*)keys, key_write_byte));
241-
value_writer_->Append(StringPiece((char*)vectors, value_write_byte));
241+
TFRA_LOG_IF_ERROR(
242+
key_writer_->Append(StringPiece((char*)keys, key_write_byte)));
243+
TFRA_LOG_IF_ERROR(
244+
value_writer_->Append(StringPiece((char*)vectors, value_write_byte)));
242245

243246
return n;
244247
}
@@ -552,8 +555,8 @@ class TableWrapper {
552555
} else {
553556
wfile.reset(new RandomKVFile<K, V, uint64_t>(
554557
fs, filepath, dim, buffer_size, append_to_file));
555-
status = reinterpret_cast<RandomKVFile<K, V, uint64_t>*>(wfile.get())
556-
->open(keyfile, valuefile, "wb");
558+
status.Update(reinterpret_cast<RandomKVFile<K, V, uint64_t>*>(wfile.get())
559+
->open(keyfile, valuefile, "wb"));
557560
}
558561
if (!status.ok()) {
559562
std::string error_msg = "Failed to dump to file to " + keyfile + ", " +
@@ -603,8 +606,8 @@ class TableWrapper {
603606
} else {
604607
rfile.reset(
605608
new RandomKVFile<K, V, uint64_t>(fs, filepath, dim, buffer_size));
606-
status = reinterpret_cast<RandomKVFile<K, V, uint64_t>*>(rfile.get())
607-
->open(keyfile, valuefile, "rb");
609+
status.Update(reinterpret_cast<RandomKVFile<K, V, uint64_t>*>(rfile.get())
610+
->open(keyfile, valuefile, "rb"));
608611
}
609612
if (!status.ok()) {
610613
std::string error_msg = "Failed to load from file " + keyfile + ", " +

tensorflow_recommenders_addons/dynamic_embedding/core/utils/utils.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,24 @@ This code is for compatibility.*/
4242
} // namespace recommenders_addons
4343
} // namespace tensorflow
4444

45+
// For propagating errors when calling a function but not return status.
46+
#if TF_VERSION_INTEGER >= 2130
47+
#define TFRA_LOG_IF_ERROR(...) \
48+
do { \
49+
const auto _status = (__VA_ARGS__); \
50+
if (TF_PREDICT_FALSE(!_status.ok())) { \
51+
MAYBE_ADD_SOURCE_LOCATION(_status) \
52+
LOG(ERROR) << _status.message(); \
53+
} \
54+
} while (0)
55+
#else
56+
#define TFRA_LOG_IF_ERROR(...) \
57+
do { \
58+
const auto _status = (__VA_ARGS__); \
59+
if (TF_PREDICT_FALSE(!_status.ok())) { \
60+
LOG(ERROR) << _status.error_message(); \
61+
} \
62+
} while (0)
63+
#endif
64+
4565
#endif // TFRA_UTILS_H_

tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/embedding.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -559,7 +559,9 @@ def __init__(self,
559559
else:
560560
self._mpi_size = mpi_size
561561
super(HvdAllToAllEmbedding, self).__init__(*args, **kwargs)
562-
if type(self.params.saveable).__name__ not in de_fs_saveable_class_names:
562+
try:
563+
assert type(self.params.saveable).__name__ in de_fs_saveable_class_names
564+
except:
563565
tf_logging.warning(
564566
"Please use FileSystemSaver in KVCreator when use HvdAllToAllEmbedding. "
565567
"It will allow TFRA save and restore KV files when Embedding tensor parallel in distributed training. "

tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_optimizer.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ def _update_step_fn(var, grad):
303303
if self.jit_compile:
304304
return self._update_step_xla(grad, var, id(self._var_key(var)))
305305
else:
306-
return script_ops.py_func_common(self._update_step, [grad, var], [])
306+
return self._update_step(grad, var)
307307

308308
if not isinstance(var, de.TrainableWrapper):
309309
return _update_step_fn(var, grad)
@@ -327,9 +327,9 @@ def _update_step_fn(var, grad):
327327
_before = [v0] + s0
328328

329329
with ops.control_dependencies(_before):
330-
_apply_op = _update_step_fn(var, grad)
330+
_update_step_fn(var, grad)
331331

332-
with ops.control_dependencies([_apply_op]):
332+
with ops.control_dependencies([var]):
333333
_after = control_flow_ops.group(
334334
[var.update_op(v0=v0)] +
335335
[_s.update_op(v0=s0[si]) for si, _s in enumerate(_slots)])
@@ -514,6 +514,8 @@ def _zeros_slot(var, slot_name, op_name):
514514
def _hvd_aggregate_gradients(hvd_handle,
515515
grads_and_vars_in,
516516
sparse_as_dense=True):
517+
if hvd_handle.size() <= 1:
518+
return grads_and_vars_in
517519
var_list = []
518520
aggregated_grad = []
519521
for grad, var in grads_and_vars_in:

0 commit comments

Comments
 (0)