1
1
#include "graph.h"
2
+ #include "utils/arr_rm_alloc.h"
2
3
3
4
RedisModuleType * RedisDL_GraphType = NULL ;
4
5
6
+ typedef struct RDL_Graph {
7
+ TF_Graph * graph ;
8
+ // TODO: use session pool? The ideal would be to use one session per client.
9
+ // If a client disconnects, we dispose the session or reuse it for
10
+ // another client.
11
+ void * session ;
12
+ size_t refCount ;
13
+ }RDL_Graph ;
14
+
15
+ typedef struct RDL_GraphCtxParam {
16
+ TF_Output name ;
17
+ RDL_Tensor * tensor ;
18
+ }RDL_GraphCtxParam ;
19
+
20
+ typedef struct RDL_GraphRunCtx {
21
+ RDL_Graph * graph ;
22
+ RDL_GraphCtxParam * inputs ;
23
+ RDL_GraphCtxParam * outputs ;
24
+ }RDL_GraphRunCtx ;
25
+
5
26
static void * Graph_RdbLoad (struct RedisModuleIO * io , int encver ){
6
27
//todo
7
28
return NULL ;
@@ -12,10 +33,10 @@ static void Graph_RdbSave(RedisModuleIO *rdb, void *value){
12
33
}
13
34
14
35
static void Graph_DTFree (void * value ){
15
- Graph_Free (value );
36
+ RDL_GraphFree (value );
16
37
}
17
38
18
- int Graph_Init (RedisModuleCtx * ctx ){
39
+ int RDL_GraphInit (RedisModuleCtx * ctx ){
19
40
RedisModuleTypeMethods tmGraph = {
20
41
.version = REDISMODULE_TYPE_METHOD_VERSION ,
21
42
.rdb_load = Graph_RdbLoad ,
@@ -30,7 +51,7 @@ int Graph_Init(RedisModuleCtx* ctx){
30
51
return RedisDL_GraphType != NULL ;
31
52
}
32
53
33
- RDL_Graph * Graph_Create (const char * prefix , const char * graphdef , size_t graphlen ){
54
+ RDL_Graph * RDL_GraphCreate (const char * prefix , const char * graphdef , size_t graphlen ){
34
55
TF_Graph * graph = TF_NewGraph ();
35
56
36
57
TF_ImportGraphDefOptions * options = TF_NewImportGraphDefOptions ();
@@ -74,7 +95,10 @@ RDL_Graph* Graph_Create(const char* prefix, const char* graphdef, size_t graphle
74
95
return ret ;
75
96
}
76
97
77
- void Graph_Free (RDL_Graph * graph ){
98
+ void RDL_GraphFree (RDL_Graph * graph ){
99
+ if (-- graph -> refCount > 0 ){
100
+ return ;
101
+ }
78
102
TF_Status * status = TF_NewStatus ();
79
103
TF_CloseSession (graph -> session , status );
80
104
@@ -100,3 +124,102 @@ void Graph_Free(RDL_Graph* graph){
100
124
101
125
RedisModule_Free (graph );
102
126
}
127
+
128
+ RDL_GraphRunCtx * RDL_RunCtxCreate (RDL_Graph * graph ){
129
+ #define PARAM_INITIAL_SIZE 10
130
+ RDL_GraphRunCtx * gctx = RedisModule_Alloc (sizeof (* gctx ));
131
+ gctx -> graph = RDL_GraphGetShallowCopy (graph );
132
+ gctx -> inputs = array_new (RDL_GraphCtxParam , PARAM_INITIAL_SIZE );
133
+ gctx -> outputs = array_new (RDL_GraphCtxParam , PARAM_INITIAL_SIZE );
134
+ return gctx ;
135
+ }
136
+
137
+ static int Graph_RunCtxAddParam (RDL_GraphRunCtx * gctx , RDL_GraphCtxParam * paramArr , const char * name , RDL_Tensor * tensor ){
138
+ TF_Output port ;
139
+ port .oper = TF_GraphOperationByName (gctx -> graph -> graph , name );
140
+ port .index = 0 ;
141
+ if (port .oper == NULL ){
142
+ return 0 ;
143
+ }
144
+ RDL_GraphCtxParam param = {
145
+ .name = port ,
146
+ .tensor = tensor ? RDL_TensorGetShallowCopy (tensor ): NULL ,
147
+ };
148
+ paramArr = array_append (paramArr , param );
149
+ return 1 ;
150
+ }
151
+
152
+ int RDL_RunCtxAddInput (RDL_GraphRunCtx * gctx , const char * inputName , RDL_Tensor * inputTensor ){
153
+ return Graph_RunCtxAddParam (gctx , gctx -> inputs , inputName , inputTensor );
154
+ }
155
+
156
+ int RDL_RunCtxAddOutput (RDL_GraphRunCtx * gctx , const char * outputName ){
157
+ return Graph_RunCtxAddParam (gctx , gctx -> outputs , outputName , NULL );
158
+ }
159
+
160
+ size_t RDL_RunCtxNumOutputs (RDL_GraphRunCtx * gctx ){
161
+ return array_len (gctx -> outputs );
162
+ }
163
+
164
+ RDL_Tensor * RDL_RunCtxOutputTensor (RDL_GraphRunCtx * gctx , size_t index ){
165
+ assert (RDL_RunCtxNumOutputs (gctx ) > index && index >= 0 );
166
+ return gctx -> outputs [index ].tensor ;
167
+ }
168
+
169
+ void RDL_RunCtxFree (RDL_GraphRunCtx * gctx ){
170
+ for (size_t i = 0 ; i < array_len (gctx -> inputs ) ; ++ i ){
171
+ RDL_TensorFree (gctx -> inputs [i ].tensor );
172
+ }
173
+ array_free (gctx -> inputs );
174
+
175
+ for (size_t i = 0 ; i < array_len (gctx -> outputs ) ; ++ i ){
176
+ if (gctx -> outputs [i ].tensor ){
177
+ RDL_TensorFree (gctx -> outputs [i ].tensor );
178
+ }
179
+ }
180
+ array_free (gctx -> outputs );
181
+
182
+ RDL_GraphFree (gctx -> graph );
183
+ }
184
+
185
+ int RDL_GraphRun (RDL_GraphRunCtx * gctx ){
186
+ TF_Status * status = TF_NewStatus ();
187
+
188
+ TF_Tensor * inputTensorsValues [array_len (gctx -> inputs )];
189
+ TF_Output inputs [array_len (gctx -> inputs )];
190
+ TF_Tensor * outputTensorsValues [array_len (gctx -> outputs )];
191
+ TF_Output outputs [array_len (gctx -> outputs )];
192
+
193
+ for (size_t i = 0 ; i < array_len (gctx -> inputs ) ; ++ i ){
194
+ inputTensorsValues [i ] = RDL_TensorGetTensor (gctx -> inputs [i ].tensor );
195
+ inputs [i ] = gctx -> inputs [i ].name ;
196
+ }
197
+
198
+ for (size_t i = 0 ; i < array_len (gctx -> outputs ) ; ++ i ){
199
+ outputs [i ] = gctx -> outputs [i ].name ;
200
+ }
201
+
202
+ TF_SessionRun (gctx -> graph -> session , NULL /* run_options */ ,
203
+ inputs , inputTensorsValues , array_len (gctx -> inputs ),
204
+ outputs , outputTensorsValues , array_len (gctx -> outputs ),
205
+ NULL /* target_opers */ , 0 /* ntargets */ ,
206
+ NULL /* run_Metadata */ ,
207
+ status );
208
+
209
+ if (TF_GetCode (status ) != TF_OK ) {
210
+ TF_DeleteStatus (status );
211
+ return 0 ;
212
+ }
213
+
214
+ for (size_t i = 0 ; i < array_len (gctx -> outputs ) ; ++ i ){
215
+ gctx -> outputs [i ].tensor = RDL_TensorCreateFromTensor (outputTensorsValues [i ]);
216
+ }
217
+
218
+ TF_DeleteStatus (status );
219
+ return 1 ;
220
+ }
221
+
222
+ RDL_Graph * RDL_GraphGetShallowCopy (RDL_Graph * graph ){
223
+ ++ graph -> refCount ;
224
+ return graph ;
225
+ }
0 commit comments