Commit 39990f7
Fix run info memleaks (#412)
* [add] added mobilenet model use case to quickly check for leaks on modelruns with large tensors (dagrun and modelrun variations)
* Fix run info memleaks
* [fix] fixed reference count on ai.dagrun and ai.dagrunro for tensor structure. Added AI_dictType AI_dictTypeTensorVals with proper valDestructor
* [fix] removed WIP unit test folder (not for this issue)

Co-authored-by: filipecosta90 <filipecosta.90@gmail.com>
1 parent 8a87333 commit 39990f7

14 files changed: 308 additions and 51 deletions
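
The changes below all follow one ownership rule: RAI_Tensor is reference counted, each holder of a tensor (a run context, a DAG dict entry) accounts for one reference, and RAI_TensorFree drops a reference rather than freeing unconditionally, releasing the buffer only when the last reference is gone. A minimal toy sketch of that rule, for orientation only; the struct and helper names are stand-ins, not RedisAI's actual definitions:

#include <stdlib.h>

/* Toy stand-in for a reference-counted tensor (illustration only). */
typedef struct {
    long refcount;
    void *data;
} toy_tensor;

/* "Shallow copy": share the buffer and bump the count. */
static toy_tensor *toy_get_shallow_copy(toy_tensor *t) {
    t->refcount++;
    return t;
}

/* "Free": drop one reference; the buffer is released only on the last drop.
 * The leaks fixed in this commit were unmatched bumps: references taken
 * before a DAG model run, or held by the DAG tensor dict, that nothing
 * ever dropped. */
static void toy_free(toy_tensor *t) {
    if (--t->refcount == 0) {
        free(t->data);
        free(t);
    }
}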

src/dag.c

Lines changed: 10 additions & 1 deletion
@@ -91,6 +91,16 @@ void *RedisAI_DagRunSession(RedisAI_RunInfo *rinfo) {
           currentOp->result = REDISMODULE_ERR;
         }
       }
+      // since we've increased the reference count prior modelrun we need to decrease it
+      const size_t ninputs = RAI_ModelRunCtxNumInputs(currentOp->mctx);
+      for (size_t inputNumber = 0; inputNumber < ninputs; inputNumber++) {
+        RAI_Tensor *tensor =
+            RAI_ModelRunCtxInputTensor(currentOp->mctx, inputNumber);
+        if (tensor) {
+          RAI_TensorFree(tensor);
+        }
+      }
+
     } else {
       currentOp->result = REDISMODULE_ERR;
     }

@@ -195,7 +205,6 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv,
         }
         RedisModule_CloseKey(key);
         RedisAI_ReplicateTensorSet(ctx, tensor_keyname, tensor);
-        // TODO: free Tensor
       } else {
         RedisModule_ReplyWithError(
             ctx, "ERR specified persistent key that was not used on DAG");

src/model.c

Lines changed: 11 additions & 8 deletions
@@ -389,17 +389,20 @@ RAI_Tensor* RAI_ModelRunCtxOutputTensor(RAI_ModelRunCtx* mctx, size_t index) {
   return mctx->outputs[index].tensor;
 }
 
-void RAI_ModelRunCtxFree(RAI_ModelRunCtx* mctx) {
-  for (size_t i=0; i<array_len(mctx->inputs); ++i) {
-    RAI_TensorFree(mctx->inputs[i].tensor);
-  }
-  array_free(mctx->inputs);
+void RAI_ModelRunCtxFree(RAI_ModelRunCtx* mctx, int freeTensors) {
+  if (freeTensors) {
+    for (size_t i=0; i<array_len(mctx->inputs); ++i) {
+      RAI_TensorFree(mctx->inputs[i].tensor);
+    }
 
-  for (size_t i = 0 ; i < array_len(mctx->outputs) ; ++i) {
-    if (mctx->outputs[i].tensor) {
-      RAI_TensorFree(mctx->outputs[i].tensor);
+    for (size_t i = 0 ; i < array_len(mctx->outputs) ; ++i) {
+      if (mctx->outputs[i].tensor) {
+        RAI_TensorFree(mctx->outputs[i].tensor);
+      }
     }
   }
+
+  array_free(mctx->inputs);
   array_free(mctx->outputs);
 
   RAI_Error err = {0};

src/model.h

Lines changed: 2 additions & 1 deletion
@@ -79,8 +79,9 @@ RAI_ModelRunCtx* RAI_ModelRunCtxCreate(RAI_Model* model);
  * work
  *
  * @param mctx
+ * @param freeTensors free input and output tensors or leave them allocated
  */
-void RAI_ModelRunCtxFree(RAI_ModelRunCtx* mctx);
+void RAI_ModelRunCtxFree(RAI_ModelRunCtx* mctx, int freeTensors);
 
 /**
  * Allocates a RAI_ModelCtxParam data structure, and enforces a shallow copy of
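
A short usage note on the new signature, inferred from how this commit calls it: pass a non-zero freeTensors when the run context is the sole owner of its tensors, and zero when the tensors are owned elsewhere, as they now are for DAG ops whose tensors live in the DAG tensor dict.

/* Sketch of the two call sites this commit introduces (see run_info.c below). */
RAI_ModelRunCtxFree(rinfo->mctx, true);    /* standalone run: ctx owns its tensors     */
RAI_ModelRunCtxFree(dagOp->mctx, false);   /* DAG op: dict valDestructor frees tensors */

The same flag is added to RAI_ScriptRunCtxFree in script.h and script.c further down.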

src/run_info.c

Lines changed: 64 additions & 12 deletions
@@ -16,6 +16,40 @@
 #include "util/arr_rm_alloc.h"
 #include "util/dict.h"
 
+
+static uint64_t RAI_TensorDictKeyHashFunction(const void *key){
+    return AI_dictGenHashFunction(key, strlen((char*)key));
+}
+
+static int RAI_TensorDictKeyStrcmp(void *privdata, const void *key1, const void *key2){
+    const char* strKey1 = key1;
+    const char* strKey2 = key2;
+    return strcmp(strKey1, strKey2) == 0;
+}
+
+static void RAI_TensorDictKeyFree(void *privdata, void *key){
+    RedisModule_Free(key);
+}
+
+static void* RAI_TensorDictKeyDup(void *privdata, const void *key){
+    return RedisModule_Strdup((char*)key);
+}
+
+static void RAI_TensorDictValFree(void *privdata, const void *obj){
+    return RAI_TensorFree((RAI_Tensor*)obj);
+}
+
+
+AI_dictType AI_dictTypeTensorVals = {
+        .hashFunction = RAI_TensorDictKeyHashFunction,
+        .keyDup = RAI_TensorDictKeyDup,
+        .valDup = NULL,
+        .keyCompare = RAI_TensorDictKeyStrcmp,
+        .keyDestructor = RAI_TensorDictKeyFree,
+        .valDestructor = RAI_TensorDictValFree,
+};
+
+
 /**
  * Allocate the memory and initialise the RAI_DagOp.
  * @param result Output parameter to capture allocated RAI_DagOp.

@@ -76,7 +110,7 @@ int RAI_InitRunInfo(RedisAI_RunInfo **result) {
     return REDISMODULE_ERR;
   }
   rinfo->use_local_context = 0;
-  rinfo->dagTensorsContext = AI_dictCreate(&AI_dictTypeHeapStrings, NULL);
+  rinfo->dagTensorsContext = AI_dictCreate(&AI_dictTypeTensorVals, NULL);
   if (!(rinfo->dagTensorsContext)) {
     return REDISMODULE_ERR;
   }

@@ -116,6 +150,13 @@ void RAI_FreeDagOp(RedisModuleCtx *ctx, RAI_DagOp *dagOp) {
     }
     array_free(dagOp->outTensors);
 
+    if (dagOp->mctx) {
+      RAI_ModelRunCtxFree(dagOp->mctx, false);
+    }
+    if (dagOp->sctx) {
+      RAI_ScriptRunCtxFree(dagOp->sctx, false);
+    }
+
     RedisModule_Free(dagOp);
   }
 }

@@ -125,37 +166,48 @@ void RAI_FreeRunInfo(RedisModuleCtx *ctx, struct RedisAI_RunInfo *rinfo) {
     return;
   }
   if (rinfo->mctx) {
-    RAI_ModelRunCtxFree(rinfo->mctx);
+    RAI_ModelRunCtxFree(rinfo->mctx, true);
   }
   if (rinfo->sctx) {
-    RAI_ScriptRunCtxFree(rinfo->sctx);
+    RAI_ScriptRunCtxFree(rinfo->sctx, true);
   }
   RAI_FreeError(rinfo->err);
 
   if (rinfo->dagTensorsContext) {
     AI_dictIterator *iter = AI_dictGetSafeIterator(rinfo->dagTensorsContext);
-    AI_dictEntry *stats_entry = AI_dictNext(iter);
+    AI_dictEntry *entry = AI_dictNext(iter);
     RAI_Tensor *tensor = NULL;
 
-    while (stats_entry) {
-      tensor = AI_dictGetVal(stats_entry);
-      char *key = (char *)AI_dictGetKey(stats_entry);
+    while (entry) {
+      tensor = AI_dictGetVal(entry);
+      char *key = (char *)AI_dictGetKey(entry);
 
-      if (tensor&&key!=NULL) {
+      if (tensor && key != NULL) {
         // if the key is persistent then we should not delete it
         AI_dictEntry *persistent_entry =
            AI_dictFind(rinfo->dagTensorsPersistentContext, key);
-        // if the key was loaded from the keyspace then we should not delete
-        // it
+        // if the key was loaded from the keyspace then we should not delete it
        AI_dictEntry *loaded_entry =
            AI_dictFind(rinfo->dagTensorsLoadedContext, key);
+
        if (persistent_entry == NULL && loaded_entry == NULL) {
-          RAI_TensorFree(tensor);
+          AI_dictDelete(rinfo->dagTensorsContext, key);
+        }
+
+        if (persistent_entry) {
+          AI_dictDelete(rinfo->dagTensorsPersistentContext, key);
+        }
+        if (loaded_entry) {
+          AI_dictDelete(rinfo->dagTensorsLoadedContext, key);
        }
      }
-      stats_entry = AI_dictNext(iter);
+      entry = AI_dictNext(iter);
    }
    AI_dictReleaseIterator(iter);
+
+    RedisModule_Free(rinfo->dagTensorsContext);
+    RedisModule_Free(rinfo->dagTensorsLoadedContext);
+    RedisModule_Free(rinfo->dagTensorsPersistentContext);
  }
 
  if (rinfo->dagOps) {
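
With AI_dictTypeTensorVals installed as the type of dagTensorsContext, the dict owns its tensor values: deleting an entry (or releasing the dict) invokes RAI_TensorDictValFree, which calls RAI_TensorFree, so the cleanup code above no longer frees tensors by hand. The fragment below sketches the producer side for illustration only; AI_dictAdd is assumed to follow the usual Redis dict API and is not part of this diff.

/* Sketch: handing a tensor to the DAG context transfers ownership to the dict. */
AI_dictAdd(rinfo->dagTensorsContext, (void *)key, (void *)tensor);
/* ...later, during run-info cleanup... */
AI_dictDelete(rinfo->dagTensorsContext, key);   /* valDestructor frees the tensor */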

src/script.c

Lines changed: 11 additions & 8 deletions
@@ -182,17 +182,20 @@ RAI_Tensor* RAI_ScriptRunCtxOutputTensor(RAI_ScriptRunCtx* sctx, size_t index) {
   return sctx->outputs[index].tensor;
 }
 
-void RAI_ScriptRunCtxFree(RAI_ScriptRunCtx* sctx) {
-  for (size_t i = 0; i < array_len(sctx->inputs); ++i) {
-    RAI_TensorFree(sctx->inputs[i].tensor);
-  }
-  array_free(sctx->inputs);
+void RAI_ScriptRunCtxFree(RAI_ScriptRunCtx* sctx, int freeTensors) {
+  if (freeTensors) {
+    for (size_t i = 0; i < array_len(sctx->inputs); ++i) {
+      RAI_TensorFree(sctx->inputs[i].tensor);
+    }
 
-  for (size_t i = 0; i < array_len(sctx->outputs); ++i) {
-    if (sctx->outputs[i].tensor) {
-      RAI_TensorFree(sctx->outputs[i].tensor);
+    for (size_t i = 0; i < array_len(sctx->outputs); ++i) {
+      if (sctx->outputs[i].tensor) {
+        RAI_TensorFree(sctx->outputs[i].tensor);
+      }
     }
   }
+
+  array_free(sctx->inputs);
   array_free(sctx->outputs);
 
   RedisModule_Free(sctx->fnname);

src/script.h

Lines changed: 2 additions & 1 deletion
@@ -103,8 +103,9 @@ RAI_Tensor* RAI_ScriptRunCtxOutputTensor(RAI_ScriptRunCtx* sctx, size_t index);
  * work
  *
  * @param sctx
+ * @param freeTensors free input and output tensors or leave them allocated
  */
-void RAI_ScriptRunCtxFree(RAI_ScriptRunCtx* sctx);
+void RAI_ScriptRunCtxFree(RAI_ScriptRunCtx* sctx, int freeTensors);
 
 /**
  * Given the input script context, run associated script

test/includes.py

Lines changed: 53 additions & 4 deletions
@@ -16,6 +16,7 @@
 except:
     pass
 
+MAX_ITERATIONS = 2 if os.environ.get("MAX_ITERATIONS") == None else os.environ.get("MAX_ITERATIONS")
 TEST_TF = os.environ.get("TEST_TF") != "0" and os.environ.get("WITH_TF") != "0"
 TEST_TFLITE = os.environ.get("TEST_TFLITE") != "0" and os.environ.get("WITH_TFLITE") != "0"
 TEST_PT = os.environ.get("TEST_PT") != "0" and os.environ.get("WITH_PT") != "0"

@@ -24,7 +25,7 @@
 DEVICE = os.environ.get('DEVICE', 'CPU').upper().encode('utf-8', 'ignore').decode('utf-8')
 VALGRIND = os.environ.get("VALGRIND") == "1"
 print(f"Running tests on {DEVICE}\n")
-
+print(f"Using a max of {MAX_ITERATIONS} iterations per test\n")
 # change this to make inference tests longer
 MAX_TRANSACTIONS=100
 

@@ -67,11 +68,59 @@ def info_to_dict(info):
     return dict(zip(info[::2], info[1::2]))
 
 
-def load_mobilenet_test_data():
+def load_resnet_test_data():
+    test_data_path = os.path.join(os.path.dirname(__file__), 'test_data/imagenet')
+    labels_filename = os.path.join(test_data_path, 'imagenet_class_index.json')
+    image_filename = os.path.join(test_data_path, 'dog.jpg')
+    model_filename = os.path.join(test_data_path, 'resnet50.pb')
+    script_filename = os.path.join(test_data_path, 'data_processing_script.txt')
+
+    with open(script_filename, 'rb') as f:
+        script = f.read()
+
+    with open(model_filename, 'rb') as f:
+        model_pb = f.read()
+
+    with open(labels_filename, 'r') as f:
+        labels = json.load(f)
+
+    img_height, img_width = 224, 224
+
+    img = imread(image_filename)
+    img = resize(img, (img_height, img_width), mode='constant', anti_aliasing=True)
+    img = img.astype(np.uint8)
+
+    return model_pb, script, labels, img
+
+def load_mobilenet_v1_test_data():
+    test_data_path = os.path.join(os.path.dirname(__file__), 'test_data')
+    labels_filename = os.path.join(test_data_path, 'imagenet_class_index.json')
+    image_filename = os.path.join(test_data_path, 'panda.jpg')
+    model_filename = os.path.join(test_data_path, 'mobilenet/mobilenet_v1_100_224_cpu_NxHxWxC.pb')
+    input_var = 'input'
+    output_var = 'MobilenetV1/Predictions/Reshape_1'
+
+    with open(model_filename, 'rb') as f:
+        model_pb = f.read()
+
+    with open(labels_filename, 'r') as f:
+        labels = json.load(f)
+
+    img_height, img_width = 224, 224
+
+    img = imread(image_filename)
+    img = resize(img, (img_height, img_width), mode='constant', anti_aliasing=True)
+    img = img.astype(np.float32)
+
+    return model_pb, input_var, output_var, labels, img
+
+def load_mobilenet_v2_test_data():
     test_data_path = os.path.join(os.path.dirname(__file__), 'test_data')
     labels_filename = os.path.join(test_data_path, 'imagenet_class_index.json')
     image_filename = os.path.join(test_data_path, 'panda.jpg')
-    model_filename = os.path.join(test_data_path, 'mobilenet_v2_1.4_224_frozen.pb')
+    model_filename = os.path.join(test_data_path, 'mobilenet/mobilenet_v2_1.4_224_frozen.pb')
+    input_var = 'input'
+    output_var = 'MobilenetV2/Predictions/Reshape_1'
 
     with open(model_filename, 'rb') as f:
         model_pb = f.read()

@@ -85,7 +134,7 @@ def load_mobilenet_test_data():
     img = resize(img, (img_height, img_width), mode='constant', anti_aliasing=True)
     img = img.astype(np.float32)
 
-    return model_pb, labels, img
+    return model_pb, input_var, output_var, labels, img
 
 def load_creditcardfraud_data(env,max_tensors=10000):
     test_data_path = os.path.join(os.path.dirname(__file__), 'test_data')
Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bbb2752038ff1749d2b55988bb5f6e999a799c19413a0691b82d29f7aec0bab3
+size 17198345

Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f1fe206dfd3cff261cf403b5757abec886da445a80056e55310ddac0b2805a3b
+size 17198345

Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fd925f4b59d8d5035ccb2ecdfbf9b0f47a5ba3acfa81bd5a18536f69021df74a
+size 34277746

Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:111479258f3841c93d0a7a377c976c24e8281077818991931429d2277dd88590
+size 24508794
Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+import tensorflow as tf
+import tensorflow_hub as hub
+import ml2rt
+import argparse
+import sys
+
+url = 'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/quantops/classification/3'
+model_name = 'mobilenet_v1_100_224'
+module = hub.Module(url)
+batch_size = 1
+number_channels = 3
+height, width = hub.get_expected_image_size(module)
+input_var = 'input'
+output_var = 'MobilenetV1/Predictions/Reshape_1'
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--gpu', action="store_true", default=False)
+parser.add_argument('--input-shape', default="NxHxWxC", type=str)
+args = parser.parse_args()
+device = 'gpu' if args.gpu else 'cpu'
+
+gpu_available = tf.test.is_gpu_available(
+    cuda_only=True, min_cuda_compute_capability=None
+)
+
+if gpu_available is False and args.gpu:
+    print("No CUDA GPUs found. Exiting...")
+    sys.exit(1)
+
+var_converter = tf.compat.v1.graph_util.convert_variables_to_constants
+
+if args.input_shape == "NxHxWxC":
+    print("Saving N x H x W x C (1, 224, 224, 3) (with channels_last data format)")
+    images = tf.compat.v1.placeholder(tf.float32, shape=(
+        batch_size, height, width, number_channels), name=input_var)
+elif args.input_shape == "NxHxWxC":
+    print("Saving N x C x H x W (1, 3, 224, 224)")
+    images = tf.placeholder(tf.float32, shape=(
+        batch_size, number_channels, height, width), name=input_var)
+else:
+    print("inputs shape is either NxHxWxC or NxCxHxW. Exiting...")
+    sys.exit(1)
+
+logits = module(images)
+logits = tf.identity(logits, output_var)
+with tf.compat.v1.Session() as sess:
+    sess.run([tf.compat.v1.global_variables_initializer()])
+    ml2rt.save_tensorflow(sess, '{model_name}_{device}_{input_shape}.pb'.format(
+        model_name=model_name, device=device, input_shape=args.input_shape), output=[output_var])
