Skip to content

Commit abb3772

Browse files
Lifannrhdong
authored andcommitted
[feat] Support for HKV by adding HKV source code
1 parent b33afb8 commit abb3772

22 files changed

+9018
-0
lines changed

tensorflow_recommenders_addons/dynamic_embedding/core/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,3 +114,5 @@ custom_op_library(
114114
"utils/utils.h",
115115
],
116116
)
117+
118+
# TODO: Add hkv targets.
Lines changed: 311 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,311 @@
1+
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "tensorflow/core/framework/common_shape_fns.h"
17+
#include "tensorflow/core/framework/op.h"
18+
#include "tensorflow/core/framework/op_def_builder.h"
19+
#include "tensorflow/core/framework/shape_inference.h"
20+
#include "tensorflow_recommenders_addons/dynamic_embedding/core/utils/utils.h"
21+
22+
namespace tensorflow {
23+
24+
using shape_inference::DimensionHandle;
25+
using shape_inference::InferenceContext;
26+
using shape_inference::ShapeAndType;
27+
using shape_inference::ShapeHandle;
28+
29+
namespace {
30+
31+
Status ScalarAndTwoElementVectorInputsAndScalarOutputs(InferenceContext* c) {
32+
ShapeHandle handle;
33+
DimensionHandle unused_handle;
34+
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
35+
for (int i = 1; i < c->num_inputs(); ++i) {
36+
TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &handle));
37+
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_handle));
38+
}
39+
for (int i = 0; i < c->num_outputs(); ++i) {
40+
c->set_output(i, c->Scalar());
41+
}
42+
return Status::OK();
43+
}
44+
45+
} // namespace
46+
47+
Status ValidateTableResourceHandle(InferenceContext* c, ShapeHandle keys,
48+
const string& key_dtype_attr,
49+
const string& value_dtype_attr,
50+
bool is_lookup,
51+
ShapeAndType* output_shape_and_type) {
52+
auto* handle_data = c->input_handle_shapes_and_types(0);
53+
if (handle_data == nullptr || handle_data->size() != 2) {
54+
output_shape_and_type->shape = c->UnknownShape();
55+
output_shape_and_type->dtype = DT_INVALID;
56+
} else {
57+
const ShapeAndType& key_shape_and_type = (*handle_data)[0];
58+
const ShapeAndType& value_shape_and_type = (*handle_data)[1];
59+
DataType key_dtype;
60+
TF_RETURN_IF_ERROR(c->GetAttr(key_dtype_attr, &key_dtype));
61+
if (key_shape_and_type.dtype != key_dtype) {
62+
return errors::InvalidArgument(
63+
"Trying to read value with wrong dtype. "
64+
"Expected ",
65+
DataTypeString(key_shape_and_type.dtype), " got ",
66+
DataTypeString(key_dtype));
67+
}
68+
DataType value_dtype;
69+
TF_RETURN_IF_ERROR(c->GetAttr(value_dtype_attr, &value_dtype));
70+
if (value_shape_and_type.dtype != value_dtype) {
71+
return errors::InvalidArgument(
72+
"Trying to read value with wrong dtype. "
73+
"Expected ",
74+
DataTypeString(value_shape_and_type.dtype), " got ",
75+
DataTypeString(value_dtype));
76+
}
77+
output_shape_and_type->dtype = value_shape_and_type.dtype;
78+
79+
if (is_lookup) {
80+
if (c->RankKnown(key_shape_and_type.shape) && c->RankKnown(keys)) {
81+
int keys_rank = c->Rank(keys);
82+
int key_suffix_rank = c->Rank(key_shape_and_type.shape);
83+
if (keys_rank < key_suffix_rank) {
84+
return errors::InvalidArgument(
85+
"Expected keys to have suffix ",
86+
c->DebugString(key_shape_and_type.shape),
87+
" but saw shape: ", c->DebugString(keys));
88+
}
89+
for (int d = 0; d < key_suffix_rank; d++) {
90+
// Ensure the suffix of keys match what's in the Table.
91+
DimensionHandle dim = c->Dim(key_shape_and_type.shape, d);
92+
TF_RETURN_IF_ERROR(
93+
c->ReplaceDim(keys, keys_rank - key_suffix_rank + d, dim, &keys));
94+
}
95+
std::vector<DimensionHandle> keys_prefix_vec;
96+
keys_prefix_vec.reserve(keys_rank - key_suffix_rank);
97+
for (int d = 0; d < keys_rank - key_suffix_rank; ++d) {
98+
keys_prefix_vec.push_back(c->Dim(keys, d));
99+
}
100+
ShapeHandle keys_prefix = c->MakeShape(keys_prefix_vec);
101+
TF_RETURN_IF_ERROR(c->Concatenate(keys_prefix,
102+
value_shape_and_type.shape,
103+
&output_shape_and_type->shape));
104+
} else {
105+
output_shape_and_type->shape = c->UnknownShape();
106+
}
107+
} else {
108+
TF_RETURN_IF_ERROR(c->Concatenate(keys, value_shape_and_type.shape,
109+
&output_shape_and_type->shape));
110+
}
111+
}
112+
return Status::OK();
113+
}
114+
115+
Status HkvHashTableShape(InferenceContext* c, const ShapeHandle& key,
116+
const ShapeHandle& value) {
117+
c->set_output(0, c->Scalar());
118+
119+
ShapeHandle key_s;
120+
TF_RETURN_IF_ERROR(c->WithRankAtMost(key, 1, &key_s));
121+
122+
DataType key_t;
123+
TF_RETURN_IF_ERROR(c->GetAttr("key_dtype", &key_t));
124+
125+
DataType value_t;
126+
TF_RETURN_IF_ERROR(c->GetAttr("value_dtype", &value_t));
127+
128+
c->set_output_handle_shapes_and_types(
129+
0, std::vector<ShapeAndType>{{key_s, key_t}, {value, value_t}});
130+
131+
return Status::OK();
132+
}
133+
134+
REGISTER_OP("TfraHkvHashTableFind")
135+
.Input("table_handle: resource")
136+
.Input("keys: Tin")
137+
.Input("default_value: Tout")
138+
.Output("values: Tout")
139+
.Attr("Tin: type")
140+
.Attr("Tout: type")
141+
.SetShapeFn([](InferenceContext* c) {
142+
ShapeHandle handle;
143+
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
144+
145+
ShapeAndType value_shape_and_type;
146+
TF_RETURN_IF_ERROR(ValidateTableResourceHandle(
147+
c,
148+
/*keys=*/c->input(1),
149+
/*key_dtype_attr=*/"Tin",
150+
/*value_dtype_attr=*/"Tout",
151+
/*is_lookup=*/true, &value_shape_and_type));
152+
c->set_output(0, value_shape_and_type.shape);
153+
154+
return Status::OK();
155+
});
156+
157+
REGISTER_OP("TfraHkvHashTableFindWithExists")
158+
.Input("table_handle: resource")
159+
.Input("keys: Tin")
160+
.Input("default_value: Tout")
161+
.Output("values: Tout")
162+
.Output("exists: bool")
163+
.Attr("Tin: type")
164+
.Attr("Tout: type")
165+
.SetShapeFn([](InferenceContext* c) {
166+
ShapeHandle handle;
167+
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
168+
169+
ShapeHandle keys = c->UnknownShapeOfRank(1);
170+
ShapeAndType value_shape_and_type;
171+
TF_RETURN_IF_ERROR(ValidateTableResourceHandle(
172+
c,
173+
/*keys=*/c->input(1),
174+
/*key_dtype_attr=*/"Tin",
175+
/*value_dtype_attr=*/"Tout",
176+
/*is_lookup=*/true, &value_shape_and_type));
177+
c->set_output(0, value_shape_and_type.shape);
178+
c->set_output(1, keys);
179+
180+
return Status::OK();
181+
});
182+
183+
REGISTER_OP("TfraHkvHashTableInsert")
184+
.Input("table_handle: resource")
185+
.Input("keys: Tin")
186+
.Input("values: Tout")
187+
.Attr("Tin: type")
188+
.Attr("Tout: type")
189+
.SetShapeFn([](InferenceContext* c) {
190+
ShapeHandle handle;
191+
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
192+
193+
// TODO: Validate keys and values shape.
194+
return Status::OK();
195+
});
196+
197+
REGISTER_OP("TfraHkvHashTableAccum")
198+
.Input("table_handle: resource")
199+
.Input("keys: key_dtype")
200+
.Input("values_or_deltas: value_dtype")
201+
.Input("exists: bool")
202+
.Attr("key_dtype: type")
203+
.Attr("value_dtype: type")
204+
.SetShapeFn([](InferenceContext* c) {
205+
ShapeHandle handle;
206+
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
207+
208+
// TODO: Validate keys and values shape.
209+
return Status::OK();
210+
});
211+
212+
REGISTER_OP("TfraHkvHashTableRemove")
213+
.Input("table_handle: resource")
214+
.Input("keys: Tin")
215+
.Attr("Tin: type")
216+
.SetShapeFn([](InferenceContext* c) {
217+
ShapeHandle handle;
218+
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
219+
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &handle));
220+
221+
// TODO(turboale): Validate keys shape.
222+
return Status::OK();
223+
});
224+
225+
REGISTER_OP("TfraHkvHashTableClear")
226+
.Input("table_handle: resource")
227+
.Attr("key_dtype: type")
228+
.Attr("value_dtype: type");
229+
230+
REGISTER_OP("TfraHkvHashTableSize")
231+
.Input("table_handle: resource")
232+
.Output("size: int64")
233+
.SetShapeFn(ScalarAndTwoElementVectorInputsAndScalarOutputs);
234+
235+
REGISTER_OP("TfraHkvHashTableExport")
236+
.Input("table_handle: resource")
237+
.Output("keys: Tkeys")
238+
.Output("values: Tvalues")
239+
.Attr("Tkeys: type")
240+
.Attr("Tvalues: type")
241+
.SetShapeFn([](InferenceContext* c) {
242+
ShapeHandle handle;
243+
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
244+
ShapeHandle keys = c->UnknownShapeOfRank(1);
245+
ShapeAndType value_shape_and_type;
246+
TF_RETURN_IF_ERROR(ValidateTableResourceHandle(
247+
c,
248+
/*keys=*/keys,
249+
/*key_dtype_attr=*/"Tkeys",
250+
/*value_dtype_attr=*/"Tvalues",
251+
/*is_lookup=*/false, &value_shape_and_type));
252+
c->set_output(0, keys);
253+
c->set_output(1, value_shape_and_type.shape);
254+
return Status::OK();
255+
});
256+
257+
REGISTER_OP("TfraHkvHashTableSaveToFileSystem")
258+
.Input("table_handle: resource")
259+
.Input("dirpath: string")
260+
.Input("file_name: string")
261+
.Attr("key_dtype: type")
262+
.Attr("value_dtype: type")
263+
.Attr("dirpath_env: string")
264+
.Attr("append_to_file: bool")
265+
.Attr("buffer_size: int >= 1");
266+
267+
REGISTER_OP("TfraHkvHashTableImport")
268+
.Input("table_handle: resource")
269+
.Input("keys: Tin")
270+
.Input("values: Tout")
271+
.Attr("Tin: type")
272+
.Attr("Tout: type")
273+
.SetShapeFn([](InferenceContext* c) {
274+
ShapeHandle handle;
275+
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
276+
277+
ShapeHandle keys;
278+
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys));
279+
TF_RETURN_IF_ERROR(c->Merge(keys, c->input(2), &keys));
280+
return Status::OK();
281+
});
282+
283+
REGISTER_OP("TfraHkvHashTableLoadFromFileSystem")
284+
.Input("table_handle: resource")
285+
.Input("dirpath: string")
286+
.Input("file_name: string")
287+
.Attr("key_dtype: type")
288+
.Attr("value_dtype: type")
289+
.Attr("dirpath_env: string")
290+
.Attr("load_entire_dir: bool")
291+
.Attr("buffer_size: int >= 1");
292+
293+
REGISTER_OP("TfraHkvHashTableOfTensors")
294+
.Output("table_handle: resource")
295+
.Attr("container: string = ''")
296+
.Attr("shared_name: string = ''")
297+
.Attr("use_node_name_sharing: bool = false")
298+
.Attr("key_dtype: type")
299+
.Attr("value_dtype: type")
300+
.Attr("value_shape: shape = {}")
301+
.Attr("init_capacity: int = 0")
302+
.Attr("max_capacity: int = 0")
303+
.SetIsStateful()
304+
.SetShapeFn([](InferenceContext* c) {
305+
PartialTensorShape value_p;
306+
TF_RETURN_IF_ERROR(c->GetAttr("value_shape", &value_p));
307+
ShapeHandle value_s;
308+
TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(value_p, &value_s));
309+
return HkvHashTableShape(c, /*key=*/c->Scalar(), /*value=*/value_s);
310+
});
311+
} // namespace tensorflow

0 commit comments

Comments
 (0)