Skip to content

Commit 8794ecf

Browse files
authored
Merge pull request #24 from MeirShpilraien/refactoring_stage_2 Refactoring stage 2
2 parents 2bfacce + 21abddd commit 8794ecf

File tree

8 files changed

+573
-156
lines changed

8 files changed

+573
-156
lines changed

src/graph.c

Lines changed: 127 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,28 @@
11
#include "graph.h"
2+
#include "utils/arr_rm_alloc.h"
23

34
RedisModuleType *RedisDL_GraphType = NULL;
45

6+
typedef struct RDL_Graph{
7+
TF_Graph* graph;
8+
// TODO: use session pool? The ideal would be to use one session per client.
9+
// If a client disconnects, we dispose the session or reuse it for
10+
// another client.
11+
void *session;
12+
size_t refCount;
13+
}RDL_Graph;
14+
15+
typedef struct RDL_GraphCtxParam{
16+
TF_Output name;
17+
RDL_Tensor* tensor;
18+
}RDL_GraphCtxParam;
19+
20+
typedef struct RDL_GraphRunCtx{
21+
RDL_Graph* graph;
22+
RDL_GraphCtxParam* inputs;
23+
RDL_GraphCtxParam* outputs;
24+
}RDL_GraphRunCtx;
25+
526
static void* Graph_RdbLoad(struct RedisModuleIO *io, int encver){
627
//todo
728
return NULL;
@@ -12,10 +33,10 @@ static void Graph_RdbSave(RedisModuleIO *rdb, void *value){
1233
}
1334

1435
static void Graph_DTFree(void *value){
15-
Graph_Free(value);
36+
RDL_GraphFree(value);
1637
}
1738

18-
int Graph_Init(RedisModuleCtx* ctx){
39+
int RDL_GraphInit(RedisModuleCtx* ctx){
1940
RedisModuleTypeMethods tmGraph = {
2041
.version = REDISMODULE_TYPE_METHOD_VERSION,
2142
.rdb_load = Graph_RdbLoad,
@@ -30,7 +51,7 @@ int Graph_Init(RedisModuleCtx* ctx){
3051
return RedisDL_GraphType != NULL;
3152
}
3253

33-
RDL_Graph* Graph_Create(const char* prefix, const char* graphdef, size_t graphlen){
54+
RDL_Graph* RDL_GraphCreate(const char* prefix, const char* graphdef, size_t graphlen){
3455
TF_Graph* graph = TF_NewGraph();
3556

3657
TF_ImportGraphDefOptions* options = TF_NewImportGraphDefOptions();
@@ -74,7 +95,10 @@ RDL_Graph* Graph_Create(const char* prefix, const char* graphdef, size_t graphle
7495
return ret;
7596
}
7697

77-
void Graph_Free(RDL_Graph* graph){
98+
void RDL_GraphFree(RDL_Graph* graph){
99+
if(--graph->refCount > 0){
100+
return;
101+
}
78102
TF_Status *status = TF_NewStatus();
79103
TF_CloseSession(graph->session, status);
80104

@@ -100,3 +124,102 @@ void Graph_Free(RDL_Graph* graph){
100124

101125
RedisModule_Free(graph);
102126
}
127+
128+
RDL_GraphRunCtx* RDL_RunCtxCreate(RDL_Graph* graph){
129+
#define PARAM_INITIAL_SIZE 10
130+
RDL_GraphRunCtx* gctx = RedisModule_Alloc(sizeof(*gctx));
131+
gctx->graph = RDL_GraphGetShallowCopy(graph);
132+
gctx->inputs = array_new(RDL_GraphCtxParam, PARAM_INITIAL_SIZE);
133+
gctx->outputs = array_new(RDL_GraphCtxParam, PARAM_INITIAL_SIZE);
134+
return gctx;
135+
}
136+
137+
static int Graph_RunCtxAddParam(RDL_GraphRunCtx* gctx, RDL_GraphCtxParam* paramArr, const char* name, RDL_Tensor* tensor){
138+
TF_Output port;
139+
port.oper = TF_GraphOperationByName(gctx->graph->graph, name);
140+
port.index = 0;
141+
if(port.oper == NULL){
142+
return 0;
143+
}
144+
RDL_GraphCtxParam param = {
145+
.name = port,
146+
.tensor = tensor ? RDL_TensorGetShallowCopy(tensor): NULL,
147+
};
148+
paramArr = array_append(paramArr, param);
149+
return 1;
150+
}
151+
152+
int RDL_RunCtxAddInput(RDL_GraphRunCtx* gctx, const char* inputName, RDL_Tensor* inputTensor){
153+
return Graph_RunCtxAddParam(gctx, gctx->inputs, inputName, inputTensor);
154+
}
155+
156+
int RDL_RunCtxAddOutput(RDL_GraphRunCtx* gctx, const char* outputName){
157+
return Graph_RunCtxAddParam(gctx, gctx->outputs, outputName, NULL);
158+
}
159+
160+
size_t RDL_RunCtxNumOutputs(RDL_GraphRunCtx* gctx){
161+
return array_len(gctx->outputs);
162+
}
163+
164+
RDL_Tensor* RDL_RunCtxOutputTensor(RDL_GraphRunCtx* gctx, size_t index){
165+
assert(RDL_RunCtxNumOutputs(gctx) > index && index >= 0);
166+
return gctx->outputs[index].tensor;
167+
}
168+
169+
void RDL_RunCtxFree(RDL_GraphRunCtx* gctx){
170+
for(size_t i = 0 ; i < array_len(gctx->inputs) ; ++i){
171+
RDL_TensorFree(gctx->inputs[i].tensor);
172+
}
173+
array_free(gctx->inputs);
174+
175+
for(size_t i = 0 ; i < array_len(gctx->outputs) ; ++i){
176+
if(gctx->outputs[i].tensor){
177+
RDL_TensorFree(gctx->outputs[i].tensor);
178+
}
179+
}
180+
array_free(gctx->outputs);
181+
182+
RDL_GraphFree(gctx->graph);
183+
}
184+
185+
int RDL_GraphRun(RDL_GraphRunCtx* gctx){
186+
TF_Status *status = TF_NewStatus();
187+
188+
TF_Tensor* inputTensorsValues[array_len(gctx->inputs)];
189+
TF_Output inputs[array_len(gctx->inputs)];
190+
TF_Tensor* outputTensorsValues[array_len(gctx->outputs)];
191+
TF_Output outputs[array_len(gctx->outputs)];
192+
193+
for(size_t i = 0 ; i < array_len(gctx->inputs) ; ++i){
194+
inputTensorsValues[i] = RDL_TensorGetTensor(gctx->inputs[i].tensor);
195+
inputs[i] = gctx->inputs[i].name;
196+
}
197+
198+
for(size_t i = 0 ; i < array_len(gctx->outputs) ; ++i){
199+
outputs[i] = gctx->outputs[i].name;
200+
}
201+
202+
TF_SessionRun(gctx->graph->session, NULL /* run_options */,
203+
inputs, inputTensorsValues, array_len(gctx->inputs),
204+
outputs, outputTensorsValues, array_len(gctx->outputs),
205+
NULL /* target_opers */, 0 /* ntargets */,
206+
NULL /* run_Metadata */,
207+
status);
208+
209+
if (TF_GetCode(status) != TF_OK) {
210+
TF_DeleteStatus(status);
211+
return 0;
212+
}
213+
214+
for(size_t i = 0 ; i < array_len(gctx->outputs) ; ++i){
215+
gctx->outputs[i].tensor = RDL_TensorCreateFromTensor(outputTensorsValues[i]);
216+
}
217+
218+
TF_DeleteStatus(status);
219+
return 1;
220+
}
221+
222+
RDL_Graph* RDL_GraphGetShallowCopy(RDL_Graph* graph){
223+
++graph->refCount;
224+
return graph;
225+
}

src/graph.h

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,21 @@
1010

1111
#include "tensorflow/c/c_api.h"
1212
#include "redismodule.h"
13-
14-
typedef struct RDL_Graph{
15-
TF_Graph* graph;
16-
// TODO: use session pool? The ideal would be to use one session per client.
17-
// If a client disconnects, we dispose the session or reuse it for
18-
// another client.
19-
void *session;
20-
size_t refCount;
21-
}RDL_Graph;
13+
#include "tensor.h"
2214

2315
extern RedisModuleType *RedisDL_GraphType;
2416

25-
int Graph_Init(RedisModuleCtx* ctx);
26-
RDL_Graph* Graph_Create(const char* prefix, const char* graphdef, size_t graphlen);
27-
void Graph_Free(RDL_Graph* graph);
17+
int RDL_GraphInit(RedisModuleCtx* ctx);
18+
RDL_Graph* RDL_GraphCreate(const char* prefix, const char* graphdef, size_t graphlen);
19+
void RDL_GraphFree(RDL_Graph* graph);
20+
RDL_GraphRunCtx* RDL_RunCtxCreate(RDL_Graph* graph);
21+
int RDL_RunCtxAddInput(RDL_GraphRunCtx* gctx, const char* inputName, RDL_Tensor* inputTensor);
22+
int RDL_RunCtxAddOutput(RDL_GraphRunCtx* gctx, const char* outputName);
23+
size_t RDL_RunCtxNumOutputs(RDL_GraphRunCtx* gctx);
24+
RDL_Tensor* RDL_RunCtxOutputTensor(RDL_GraphRunCtx* gctx, size_t index);
25+
void RDL_RunCtxFree(RDL_GraphRunCtx* gctx);
26+
int RDL_GraphRun(RDL_GraphRunCtx* gctx);
27+
RDL_Graph* RDL_GraphGetShallowCopy(RDL_Graph* graph);
2828

2929

3030

0 commit comments

Comments
 (0)