Skip to content

Commit 8c25259

Browse files
candyzonelixy9474
authored andcommitted
[Embedding] Support immutable EmbeddingVariable in inference mode. (#425)
1 parent cec4417 commit 8c25259

File tree

6 files changed

+56
-1
lines changed

6 files changed

+56
-1
lines changed

Diff for: tensorflow/core/framework/embedding/embedding_filter.h

+12
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,18 @@ class EmbeddingFilter {
4545
public:
4646
virtual void LookupOrCreate(K key, V* val, const V* default_value_ptr,
4747
ValuePtr<V>** value_ptr, int count) = 0;
48+
49+
virtual void Lookup(EV* ev, K key, V* val, const V* default_value_ptr) {
50+
ValuePtr<V>* value_ptr = nullptr;
51+
Status s = ev->LookupKey(key, &value_ptr);
52+
if (s.ok()) {
53+
V* mem_val = ev->LookupPrimaryEmb(value_ptr);
54+
memcpy(val, mem_val, sizeof(V) * ev->ValueLen());
55+
} else {
56+
memcpy(val, default_value_ptr, sizeof(V) * ev->ValueLen());
57+
}
58+
}
59+
4860
virtual Status LookupOrCreateKey(K key, ValuePtr<V>** val, bool* is_filter) = 0;
4961

5062
virtual int64 GetFreq(K key, ValuePtr<V>* value_ptr) = 0;

Diff for: tensorflow/core/framework/embedding/embedding_var.h

+9
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,10 @@ class EmbeddingVar : public ResourceBase {
9999
return is_initialized_;
100100
}
101101

102+
Status LookupKey(K key, ValuePtr<V>** value_ptr) {
103+
return storage_manager_->Get(key, value_ptr);
104+
}
105+
102106
Status LookupOrCreateKey(K key, ValuePtr<V>** value_ptr, bool* is_filter) {
103107
return filter_->LookupOrCreateKey(key, value_ptr, is_filter);
104108
}
@@ -127,6 +131,11 @@ class EmbeddingVar : public ResourceBase {
127131
return filter_->GetFreq(key);
128132
}
129133

134+
void Lookup(K key, V* val, V* default_v) {
135+
const V* default_value_ptr = (default_v == nullptr) ? default_value_ : default_v;
136+
filter_->Lookup(this, key, val, default_value_ptr);
137+
}
138+
130139
void LookupOrCreate(K key, V* val, V* default_v, int count = 1) {
131140
const V* default_value_ptr = (default_v == nullptr) ? default_value_ : default_v;
132141
ValuePtr<V>* value_ptr = nullptr;

Diff for: tensorflow/core/framework/embedding/multilevel_embedding.h

+12
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,18 @@ class StorageManager {
223223
}
224224
}
225225

226+
Status Get(K key, ValuePtr<V>** value_ptr) {
227+
Status s;
228+
int level = 0;
229+
for (; level < hash_table_count_; ++level) {
230+
s = kvs_[level].first->Lookup(key, value_ptr);
231+
if (s.ok()) {
232+
break;
233+
}
234+
}
235+
return s;
236+
}
237+
226238
Status GetOrCreate(K key, ValuePtr<V>** value_ptr, size_t size) {
227239
bool found = false;
228240
int level = 0;

Diff for: tensorflow/core/kernels/kv_variable_ops.cc

+20-1
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ using GPUDevice = Eigen::GpuDevice;
5252
namespace {
5353
const int64 kEmbeddingVarUseDB = -214;
5454
const int64 kInitializableEmbeddingVarUseDB = -215;
55+
const char* kInferenceMode = "INFERENCE_MODE";
5556
}
5657

5758
#define REGISTER_KV_VAR_HANDLE(ktype, vtype) \
@@ -370,6 +371,10 @@ template <typename TKey, typename TValue>
370371
class KvResourceGatherOp : public OpKernel {
371372
public:
372373
explicit KvResourceGatherOp(OpKernelConstruction* c) : OpKernel(c) {
374+
OP_REQUIRES_OK(c, c->GetAttr("is_inference", &is_inference_));
375+
bool is_inference;
376+
TF_CHECK_OK(ReadBoolFromEnvVar(kInferenceMode, false, &is_inference));
377+
is_inference_ |= is_inference;
373378
OP_REQUIRES_OK(c,
374379
c->GetAttr("is_use_default_value_tensor",
375380
&is_use_default_value_tensor_));
@@ -393,6 +398,17 @@ class KvResourceGatherOp : public OpKernel {
393398
return 1;
394399
};
395400
}
401+
if (!is_inference_) {
402+
lookup_fn_ = [](EmbeddingVar<TKey, TValue>* ev, TKey key,
403+
TValue* val, TValue* default_v, int count) {
404+
ev->LookupOrCreate(key, val, default_v, count);
405+
};
406+
} else {
407+
lookup_fn_ = [](EmbeddingVar<TKey, TValue>* ev, TKey key,
408+
TValue* val, TValue* default_v, int count) {
409+
ev->Lookup(key, val, default_v);
410+
};
411+
}
396412
}
397413

398414
void Compute(OpKernelContext* c) override {
@@ -443,7 +459,7 @@ class KvResourceGatherOp : public OpKernel {
443459
default_v, indices_flat(i), i, ev->GetDefaultValueDim(),
444460
ev->ValueLen());
445461
int32 count = get_count_fn_(counts, i);
446-
ev->LookupOrCreate(indices_flat(i),
462+
lookup_fn_(ev, indices_flat(i),
447463
out_base + i * slice_elems, default_v_ptr, count);
448464
}
449465
};
@@ -463,9 +479,12 @@ class KvResourceGatherOp : public OpKernel {
463479

464480
private:
465481
bool is_use_default_value_tensor_;
482+
bool is_inference_;
466483
std::function<
467484
TValue*(TValue*, TKey, int64, int64, int64)> get_default_v_fn_;
468485
std::function<int32(int32*, int64)> get_count_fn_;
486+
std::function<void(EmbeddingVar<TKey, TValue>* ev,
487+
TKey key, TValue* val, TValue* default_v, int count)> lookup_fn_;
469488
};
470489

471490
#define REGISTER_GATHER_FULL(dev, ktype, vtype) \

Diff for: tensorflow/core/ops/kv_variable_ops.cc

+2
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@ REGISTER_OP("KvResourceGatherV1")
231231
.Input("counts: counts_type")
232232
.Attr("validate_indices: bool = true")
233233
.Attr("is_use_default_value_tensor: bool = false")
234+
.Attr("is_inference: bool = false")
234235
.Output("output: dtype")
235236
.Attr("dtype: type")
236237
.Attr("Tkeys: {int64,int32,string}")
@@ -281,6 +282,7 @@ REGISTER_OP("KvResourceGather")
281282
.Output("output: dtype")
282283
.Attr("dtype: type")
283284
.Attr("Tkeys: {int64,int32,string}")
285+
.Attr("is_inference: bool = false")
284286
.SetShapeFn([](InferenceContext* c) {
285287
ShapeAndType handle_shape_and_type;
286288
TF_RETURN_IF_ERROR(

Diff for: tensorflow/python/ops/embedding_variable_ops_test.py

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from six.moves import xrange # pylint: disable=redefined-builtin
1818

19+
from tensorflow.core.framework import attr_value_pb2
1920
from tensorflow.python.framework import ops
2021
from tensorflow.python.framework import test_util
2122
from tensorflow.python.ops import string_ops

0 commit comments

Comments
 (0)