[Modelzoo] Add serving for DIEN, DeepFM and WDL. #319
base: main
for #232
modelzoo/BST/pb_to_pbtxt.py
Outdated
@@ -0,0 +1,13 @@
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.lib.io import file_io
from tensorflow.python.platform import tf_logging as logging

useless file.
modelzoo/BST/prepare_savedmodel.py
Outdated
@@ -0,0 +1,738 @@
import time
import argparse
import tensorflow as tf

There is no need to rewrite the original file; just add the new code to it.
modelzoo/BST/start_serving.cc
Outdated
#include "serving/processor/serving/processor.h"
#include "serving/processor/serving/predict.pb.h"

static const char* model_config = "{ \

Besides the fake data, serving data should be generated from the actual training data, and the example should then be able to run process on that serving data.
CLA required.
I have finished exporting the SavedModel, extracting data from the input file to generate requests, and serving for DIEN, DeepFM and WDL under modelzoo/features/EmbeddingVariable. Please check.
cat_voc = os.path.join(data_location, "cat_voc.pkl")

def prepare_data(input, target, maxlen=None, return_neg=False):
# x: a list of sentences

DeepRec generally uses two-space indentation; please adjust all of these files accordingly.

f.close()

Remove the extra whitespace.
::tensorflow::eas::ArrayShape array_shape;
::tensorflow::eas::ArrayDataType dtype_f =
::tensorflow::eas::ArrayDataType::DT_FLOAT;
int num_elem = (int)cur_vector.size();

No need for the (int) cast; the conversion happens implicitly. Alternatively, declare it as size_t num_elem.
int num_elem = (int)cur_vector.size();

array_shape.add_dim(1);
if((int)cur_vector.size() < 0){

if((int)cur_vector.size() < 0): cur_vector.size() can never be less than 0; its type is size_t, which is unsigned and therefore never negative.
return input;
}
array_shape.add_dim((int)cur_vector.size());

cur_vector.size() was already assigned to num_elem above; use num_elem here.
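The three size-related comments above could be addressed together roughly as follows. This is only a minimal sketch: `Shape` is a hypothetical stand-in for `::tensorflow::eas::ArrayShape` (so the snippet compiles without the DeepRec headers), and `make_shape` is an illustrative helper name that does not appear in the PR.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for ::tensorflow::eas::ArrayShape; only add_dim is mimicked.
struct Shape {
  std::vector<long long> dims;
  void add_dim(long long d) { dims.push_back(d); }
};

// Hypothetical helper: builds a [1, N] shape for one row of parsed fields.
Shape make_shape(const std::vector<const char*>& cur_vector) {
  Shape array_shape;
  // size() returns an unsigned type, so a "< 0" check can never fire.
  // The value is read once and reused below.
  const std::size_t num_elem = cur_vector.size();
  array_shape.add_dim(1);
  // The explicit cast only documents the conversion; the implicit one also compiles.
  array_shape.add_dim(static_cast<long long>(num_elem));
  return array_shape;
}
```

If an empty row needs special handling, checking `cur_vector.empty()` expresses that intent directly.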
@@ -415,6 +415,9 @@ def main(tf_config=None, server=None):

if tf_config:
print('train steps : %d' % train_steps)
print("-----------")
print("-----------")

useless code
@@ -0,0 +1,58 @@
## How to use prepare_savedmodel.py to get savedmodel

Just name the file README.md.
}

::tensorflow::eas::ArrayProto get_proto_f(float char_input,int dim,::tensorflow::eas::ArrayDataType type){

This function is a duplicate.
while (record != NULL) {
// only 1 label and 39 feature
if (j >= 40) break;

Same as above.
struct input_format39 inputs;
inputs.I1 = (float)(atof(all_elems[start_idx]));
inputs.I2 = (float)(atof(all_elems[start_idx+1]));
inputs.I3 = (float)(atof(all_elems[start_idx+2]));

Use a loop for this.
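One way to follow this suggestion is to keep the dense features in an array and fill it in a loop. The sketch below is illustrative only: `DenseInputs`, `parse_dense` and the count `kNumDense = 13` are assumptions made for the example, not names or values taken from the PR.

```cpp
#include <cstddef>
#include <cstdlib>
#include <vector>

// Assumed number of dense features; adjust to the actual model.
constexpr std::size_t kNumDense = 13;

// Hypothetical replacement for per-field members such as I1, I2, I3, ...
struct DenseInputs {
  float I[kNumDense];
};

// Hypothetical helper: converts consecutive string fields to floats in one loop.
DenseInputs parse_dense(const std::vector<const char*>& all_elems,
                        std::size_t start_idx) {
  DenseInputs inputs{};  // zero-initialized, so missing fields stay 0.0f
  for (std::size_t i = 0;
       i < kNumDense && start_idx + i < all_elems.size(); ++i) {
    inputs.I[i] = static_cast<float>(std::atof(all_elems[start_idx + i]));
  }
  return inputs;
}
```

The bounds check on `all_elems.size()` also guards against short lines, which the unrolled assignments would read past.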
Please remove the unnecessary changes first.
Please tidy up all the files first.
b52d097 to 576652e
## How to use prepare_savedmodel.py to get savedmodel

- Current support model: \
BST, DBMTL, DeepFM, DIEN, DIN, DLRM, DSSM, ESMM, MMoE, SimpleMultiTask, WDL

Now supports DIEN, DeepFM and WDL.
maxlen,
data_location=data_location)

f = open("./test_data.csv","w")

How is test_data.csv obtained? Please explain this clearly in the README.
with tf.Session() as sess1:

model = Model_DIN_V2_Gru_Vec_attGru_Neg(n_uid, n_mid, n_cat,

Where is Model_DIN_V2_Gru_Vec_attGru_Neg imported from?
input.set_dtype(dtype_f);

switch(dtype_f){
case 1:

Most of the code in the two cases is duplicated. You could wrap only these lines
input.add_float_val((float)atof(cur_vector.back()));
input.add_int_val((int)atoi(cur_vector.back()));
in an if check.
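The suggested restructuring might look like the sketch below, where the shared logic is written once and only the value conversion branches on the type. `DType`, `Array` and `to_array` are hypothetical stand-ins so the snippet compiles without `predict.pb.h`; the real code would use `::tensorflow::eas::ArrayProto` and `::tensorflow::eas::ArrayDataType`. The sketch also iterates forward over the fields, which may differ from the order the original code consumes `cur_vector.back()`.

```cpp
#include <cstdlib>
#include <vector>

// Hypothetical stand-in for ::tensorflow::eas::ArrayDataType.
enum class DType { kFloat, kInt };

// Hypothetical stand-in for ::tensorflow::eas::ArrayProto.
struct Array {
  DType dtype = DType::kFloat;
  std::vector<float> float_val;
  std::vector<int> int_val;
  void set_dtype(DType t) { dtype = t; }
  void add_float_val(float v) { float_val.push_back(v); }
  void add_int_val(int v) { int_val.push_back(v); }
};

// Hypothetical helper: everything except the conversion is shared by both types.
Array to_array(const std::vector<const char*>& cur_vector, DType dtype) {
  Array input;
  input.set_dtype(dtype);
  for (const char* field : cur_vector) {
    if (dtype == DType::kFloat) {
      input.add_float_val(static_cast<float>(std::atof(field)));
    } else {
      input.add_int_val(std::atoi(field));
    }
  }
  return input;
}
```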
temp_ptrs.clear();

// traverse current line
record = strtok(line, delim);

Could the code below be turned into something like a split(...) function? The current way of writing it is rather hacky.
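A `split(...)` helper along these lines could replace the inline `strtok` loop. This is a sketch under the assumption that a single-character delimiter is enough; the function name and signature are illustrative, not taken from the PR.

```cpp
#include <string>
#include <vector>

// Splits `line` on a single-character delimiter and returns the fields.
// Unlike strtok, it leaves the input untouched and keeps empty fields.
std::vector<std::string> split(const std::string& line, char delim) {
  std::vector<std::string> fields;
  std::string::size_type start = 0;
  while (true) {
    const std::string::size_type end = line.find(delim, start);
    if (end == std::string::npos) {
      fields.push_back(line.substr(start));  // last (or only) field
      break;
    }
    fields.push_back(line.substr(start, end - start));
    start = end + 1;
  }
  return fields;
}
```

Keeping empty fields matters for CSV rows with missing values: `strtok` treats consecutive delimiters as one, which would shift every later column.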
// input setting
::tensorflow::eas::ArrayProto I1 = get_proto_cc(&inputs.I1_13[0],1,dtype_f);
::tensorflow::eas::ArrayProto I2 = get_proto_cc(&inputs.I1_13[1],1,dtype_f);

Store these in an array?
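Storing the per-feature protos in a container could look like the following sketch. `Proto` and `make_proto` are hypothetical stand-ins for `::tensorflow::eas::ArrayProto` and `get_proto_cc`, whose real definitions are not shown in this thread, and `make_dense_protos` is an illustrative name.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for ::tensorflow::eas::ArrayProto.
struct Proto {
  std::vector<float> float_val;
};

// Hypothetical stand-in for get_proto_cc: wraps `dim` values starting at `data`.
Proto make_proto(const float* data, int dim) {
  Proto p;
  p.float_val.assign(data, data + dim);
  return p;
}

// Builds one proto per dense feature in a loop instead of declaring I1, I2, ...
std::vector<Proto> make_dense_protos(const float* values, std::size_t count) {
  std::vector<Proto> protos;
  protos.reserve(count);
  for (std::size_t i = 0; i < count; ++i) {
    protos.push_back(make_proto(&values[i], 1));
  }
  return protos;
}
```

The same container can then be iterated when the inputs are attached to the request, so adding a feature no longer requires a new named variable.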
int main(int argc, char** argv) {

// PLEASE EDIT THIS LINE!!!!
char filepath[] = "/home/deeprec/DeepRec/modelzoo/features/EmbeddingVariable/WDL/test.csv";

How is this file generated?
For every model listed above, there is a prepare_savedmodel.py. To run this script please firstly ensure you have gotten the checkpoint file from training. To use prepare_savedmodel.py, please use:

``` | ||
cd [modelfolder]

The location of the training files needs to be stated clearly; users may not be able to find where the training files actually are.
Alternatively, you could copy the training files into your directory.
int main(int argc, char** argv) {

// PLEASE EDIT THIS LINE!!!!
char filepath[] = "/home/deeprec/DeepRec/modelzoo/features/EmbeddingVariable/WDL/test.csv";

How is this file generated?
::tensorflow::eas::ArrayDataType::DT_INT64;

// input setting
::tensorflow::eas::ArrayProto I1 = get_proto_cc(&inputs.I1_9[0],1,dtype_f);

Store these in an array.
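May I ask whether this code still has not been merged? Currently there is only training and no inference deployment, and the documentation does not mention it either.
DeepRec itself already has inference capability; this PR only provides inference examples for these three models, which you can use as a reference.
I have finished exporting SavedModels for the models in modelzoo. Under Deeprec/modelzoo there is a Get_SavedModel.md that shows new users how to export a model and where the resulting model will be located.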