From 7de8652a02a19a421698e2cbf8c2460108552bf1 Mon Sep 17 00:00:00 2001 From: Xavier Wang Date: Mon, 13 Feb 2023 16:00:46 +0800 Subject: [PATCH] add memory check for all `pb_add*` routines in `pb.h` --- pb.c | 150 ++++++++++++++++++++++++++++++++++------------------------- pb.h | 20 ++++---- 2 files changed, 99 insertions(+), 71 deletions(-) diff --git a/pb.c b/pb.c index c428531..6a12928 100644 --- a/pb.c +++ b/pb.c @@ -33,12 +33,15 @@ PB_NS_BEGIN #define check_slice(L,idx) ((pb_Slice*)luaL_checkudata(L,idx,PB_SLICE)) #define test_slice(L,idx) ((pb_Slice*)luaL_testudata(L,idx,PB_SLICE)) #define push_slice(L,s) lua_pushlstring((L), (s).p, pb_len((s))) -#define return_self(L) { return lua_settop(L, 1), 1; } +#define lpb_returnself(L) { return lua_settop(L, 1), 1; } static int lpb_relindex(int idx, int offset) { return idx < 0 && idx > LUA_REGISTRYINDEX ? idx - offset : idx; } +static size_t lpb_checkmem(lua_State *L, size_t ret) +{ return ret ? ret : (size_t)luaL_error(L, "out of memory"); } + #if LUA_VERSION_NUM < 502 #include @@ -226,8 +229,10 @@ static int Lpb_state(lua_State *L) { /* protobuf util routines */ -static void lpb_addlength(lua_State *L, pb_Buffer *b, size_t len) -{ if (pb_addlength(b, len) == 0) luaL_error(L, "encode bytes fail"); } +static size_t lpb_addlength(lua_State *L, pb_Buffer *b, size_t len, size_t prealloc) { + size_t wlen = pb_addlength(b, len, prealloc); + return wlen ? wlen : (size_t)luaL_error(L, "encode bytes fail"); +} static int typeerror(lua_State *L, int idx, const char *type) { lua_pushfstring(L, "%s expected, got %s", type, luaL_typename(L, idx)); @@ -382,74 +387,73 @@ typedef union lpb_Value { } lpb_Value; static int lpb_addtype(lua_State *L, pb_Buffer *b, int idx, int type, size_t *plen) { - int ret = 0, expected = LUA_TNUMBER; + int ret = 0, has_data = 1, expected = LUA_TNUMBER; lpb_Value v; size_t len = 0; switch (type) { case PB_Tbool: len = pb_addvarint32(b, ret = lua_toboolean(L, idx)); - if (ret) len = 0; ret = 1; break; case PB_Tdouble: v.lnum = lua_tonumberx(L, idx, &ret); if (ret) len = pb_addfixed64(b, pb_encode_double((double)v.lnum)); - if (v.lnum != 0.0) len = 0; + if (v.lnum == 0.0) has_data = 0; break; case PB_Tfloat: v.lnum = lua_tonumberx(L, idx, &ret); if (ret) len = pb_addfixed32(b, pb_encode_float((float)v.lnum)); - if (v.lnum != 0.0) len = 0; + if (v.lnum == 0.0) has_data = 0; break; case PB_Tfixed32: v.u64 = lpb_tointegerx(L, idx, &ret); if (ret) len = pb_addfixed32(b, v.u32); - if (v.u64 != 0) len = 0; + if (v.u64 == 0) has_data = 0; break; case PB_Tsfixed32: v.u64 = lpb_tointegerx(L, idx, &ret); if (ret) len = pb_addfixed32(b, v.u32); - if (v.u64 != 0) len = 0; + if (v.u64 == 0) has_data = 0; break; case PB_Tint32: v.u64 = lpb_tointegerx(L, idx, &ret); if (ret) len = pb_addvarint64(b, pb_expandsig((uint32_t)v.u64)); - if (v.u64 != 0) len = 0; + if (v.u64 == 0) has_data = 0; break; case PB_Tuint32: v.u64 = lpb_tointegerx(L, idx, &ret); if (ret) len = pb_addvarint32(b, v.u32); - if (v.u64 != 0) len = 0; + if (v.u64 == 0) has_data = 0; break; case PB_Tsint32: v.u64 = lpb_tointegerx(L, idx, &ret); if (ret) len = pb_addvarint32(b, pb_encode_sint32(v.u32)); - if (v.u64 != 0) len = 0; + if (v.u64 == 0) has_data = 0; break; case PB_Tfixed64: v.u64 = lpb_tointegerx(L, idx, &ret); if (ret) len = pb_addfixed64(b, v.u64); - if (v.u64 != 0) len = 0; + if (v.u64 == 0) has_data = 0; break; case PB_Tsfixed64: v.u64 = lpb_tointegerx(L, idx, &ret); if (ret) len = pb_addfixed64(b, v.u64); - if (v.u64 != 0) len = 0; + if (v.u64 == 0) has_data = 0; break; case PB_Tint64: case PB_Tuint64: v.u64 = lpb_tointegerx(L, idx, &ret); if (ret) len = pb_addvarint64(b, v.u64); - if (v.u64 != 0) len = 0; + if (v.u64 == 0) has_data = 0; break; case PB_Tsint64: v.u64 = lpb_tointegerx(L, idx, &ret); if (ret) len = pb_addvarint64(b, pb_encode_sint64(v.u64)); - if (v.u64 != 0) len = 0; + if (v.u64 == 0) has_data = 0; break; case PB_Tbytes: case PB_Tstring: *v.s = lpb_toslice(L, idx); if ((ret = (v.s->p != NULL))) len = pb_addbytes(b, *v.s); - if (pb_len(*v.s) != 0) len = 0; + if (pb_len(*v.s) == 0) has_data = 0; expected = LUA_TSTRING; break; default: @@ -457,8 +461,8 @@ static int lpb_addtype(lua_State *L, pb_Buffer *b, int idx, int type, size_t *pl if (idx > 0) argcheck(L, 0, idx, lua_tostring(L, -1)); lua_error(L); } - if (plen) *plen = len; - return ret ? 0 : expected; + if (plen) *plen = has_data ? 0 : len; + return ret ? (lpb_checkmem(L, len), 0) : expected; } static void lpb_readtype(lua_State *L, lpb_State *LS, int type, pb_Slice *s) { @@ -687,21 +691,30 @@ static int lpb_typefmt(int fmt) { static int lpb_packfmt(lua_State *L, int idx, pb_Buffer *b, const char **pfmt, int level) { const char *fmt = *pfmt; int type, ltype; - size_t len; + size_t len, wlen; argcheck(L, level <= 100, 1, "format level overflow"); for (; *fmt != '\0'; ++fmt) { switch (*fmt) { - case 'v': pb_addvarint64(b, (uint64_t)lpb_checkinteger(L, idx++)); break; - case 'd': pb_addfixed32(b, (uint32_t)lpb_checkinteger(L, idx++)); break; - case 'q': pb_addfixed64(b, (uint64_t)lpb_checkinteger(L, idx++)); break; - case 'c': pb_addslice(b, lpb_checkslice(L, idx++)); break; - case 's': pb_addbytes(b, lpb_checkslice(L, idx++)); break; - case '#': lpb_addlength(L, b, (size_t)lpb_checkinteger(L, idx++)); break; + case '#': + wlen = lpb_addlength(L, b, (size_t)lpb_checkinteger(L, idx++), 0); + break; + case 'v': + wlen = pb_addvarint64(b, (uint64_t)lpb_checkinteger(L, idx++)); + break; + case 'd': + wlen = pb_addfixed32(b, (uint32_t)lpb_checkinteger(L, idx++)); + break; + case 'q': + wlen = pb_addfixed64(b, (uint64_t)lpb_checkinteger(L, idx++)); + break; + case 'c': wlen = pb_addslice(b, lpb_checkslice(L, idx++)); break; + case 's': wlen = pb_addbytes(b, lpb_checkslice(L, idx++)); break; case '(': + if ((wlen = pb_addvarint32(b, 0)) == 0) break; len = pb_bufflen(b); ++fmt; idx = lpb_packfmt(L, idx, b, &fmt, level+1); - lpb_addlength(L, b, len); + wlen = lpb_addlength(L, b, len, 1); break; case ')': if (level == 0) luaL_argerror(L, 1, "unexpected ')' in format"); @@ -716,7 +729,9 @@ static int lpb_packfmt(lua_State *L, int idx, pb_Buffer *b, const char **pfmt, i lua_typename(L, ltype), pb_typename(type, ""), luaL_typename(L, idx)); ++idx; + continue; } + lpb_checkmem(L, wlen); } if (level != 0) luaL_argerror(L, 2, "unmatch '(' in format"); *pfmt = fmt; @@ -782,7 +797,7 @@ static int Lbuf_new(lua_State *L) { pb_initbuffer(buf); luaL_setmetatable(L, PB_BUFFER); for (i = 1; i <= top; ++i) - pb_addslice(buf, lpb_checkslice(L, i)); + lpb_checkmem(L, pb_addslice(buf, lpb_checkslice(L, i))); return 1; } @@ -798,7 +813,7 @@ static int Lbuf_libcall(lua_State *L) { pb_initbuffer(buf); luaL_setmetatable(L, PB_BUFFER); for (i = 2; i <= top; ++i) - pb_addslice(buf, lpb_checkslice(L, i)); + lpb_checkmem(L, pb_addslice(buf, lpb_checkslice(L, i))); return 1; } @@ -813,8 +828,8 @@ static int Lbuf_reset(lua_State *L) { int i, top = lua_gettop(L); pb_bufflen(buf) = 0; for (i = 2; i <= top; ++i) - pb_addslice(buf, lpb_checkslice(L, i)); - return_self(L); + lpb_checkmem(L, pb_addslice(buf, lpb_checkslice(L, i))); + lpb_returnself(L); } static int Lbuf_len(lua_State *L) { @@ -1043,7 +1058,7 @@ static int Lslice_reset(lua_State *L) { lpb_resetslice(L, s, size); if (!lua_isnoneornil(L, 2)) lpb_initslice(L, 2, s, size); - return_self(L); + lpb_returnself(L); } static int Lslice_tostring(lua_State *L) { @@ -1103,7 +1118,7 @@ static int Lslice_enter(lua_State *L) { view.start = s->curr.p; lpb_enterview(L, s, view); } - return_self(L); + lpb_returnself(L); } static int Lslice_leave(lua_State *L) { @@ -1174,7 +1189,7 @@ static void lpb_newmsgtable(lua_State *L, const pb_Type *t) { lua_createtable(L, 0, fieldcnt > 0 ? fieldcnt : 0); } -LUALIB_API const pb_Type *lpb_type(lpb_State *LS, pb_Slice s) { +LUALIB_API const pb_Type *lpb_type(lua_State *L, lpb_State *LS, pb_Slice s) { const pb_Type *t; if (s.p == NULL || *s.p == '.') t = pb_type(lpbS_state(LS), lpb_name(LS, s)); @@ -1183,7 +1198,7 @@ LUALIB_API const pb_Type *lpb_type(lpb_State *LS, pb_Slice s) { pb_initbuffer(&b); *pb_prepbuffsize(&b, 1) = '.'; pb_addsize(&b, 1); - pb_addslice(&b, s); + lpb_checkmem(L, pb_addslice(&b, s)); t = pb_type(lpbS_state(LS), pb_name(lpbS_state(LS),pb_result(&b),NULL)); pb_resetbuffer(&b); } @@ -1277,7 +1292,7 @@ static int lpb_pushfield(lua_State *L, const pb_Type *t, const pb_Field *f) { static int Lpb_typesiter(lua_State *L) { lpb_State *LS = lpb_lstate(L); - const pb_Type *t = lpb_type(LS, lpb_toslice(L, 2)); + const pb_Type *t = lpb_type(L, LS, lpb_toslice(L, 2)); if ((t == NULL && !lua_isnoneornil(L, 2))) return 0; pb_nexttype(lpbS_state(LS), &t); @@ -1293,7 +1308,7 @@ static int Lpb_types(lua_State *L) { static int Lpb_fieldsiter(lua_State *L) { lpb_State *LS = lpb_lstate(L); - const pb_Type *t = lpb_type(LS, lpb_checkslice(L, 1)); + const pb_Type *t = lpb_type(L, LS, lpb_checkslice(L, 1)); const pb_Field *f = pb_fname(t, lpb_name(LS, lpb_toslice(L, 2))); if ((f == NULL && !lua_isnoneornil(L, 2)) || !pb_nextfield(t, &f)) return 0; @@ -1309,7 +1324,7 @@ static int Lpb_fields(lua_State *L) { static int Lpb_type(lua_State *L) { lpb_State *LS = lpb_lstate(L); - const pb_Type *t = lpb_type(LS, lpb_checkslice(L, 1)); + const pb_Type *t = lpb_type(L, LS, lpb_checkslice(L, 1)); if (t == NULL || t->is_dead) return 0; return lpb_pushtype(L, t); @@ -1317,13 +1332,13 @@ static int Lpb_type(lua_State *L) { static int Lpb_field(lua_State *L) { lpb_State *LS = lpb_lstate(L); - const pb_Type *t = lpb_type(LS, lpb_checkslice(L, 1)); + const pb_Type *t = lpb_type(L, LS, lpb_checkslice(L, 1)); return lpb_pushfield(L, t, lpb_field(L, 2, t)); } static int Lpb_enum(lua_State *L) { lpb_State *LS = lpb_lstate(L); - const pb_Type *t = lpb_type(LS, lpb_checkslice(L, 1)); + const pb_Type *t = lpb_type(L, LS, lpb_checkslice(L, 1)); const pb_Field *f = lpb_field(L, 2, t); if (f == NULL) return 0; if (lua_type(L, 2) == LUA_TNUMBER) @@ -1424,7 +1439,7 @@ static void lpb_cleardefmeta(lua_State *L, lpb_State *LS, const pb_Type *t) { static int Lpb_defaults(lua_State *L) { lpb_State *LS = lpb_lstate(L); - const pb_Type *t = lpb_type(LS, lpb_checkslice(L, 1)); + const pb_Type *t = lpb_type(L, LS, lpb_checkslice(L, 1)); int clear = lua_toboolean(L, 2); if (t == NULL) luaL_argerror(L, 1, "type not found"); lpb_pushdefmeta(L, LS, t); @@ -1434,7 +1449,7 @@ static int Lpb_defaults(lua_State *L) { static int Lpb_hook(lua_State *L) { lpb_State *LS = lpb_lstate(L); - const pb_Type *t = lpb_type(LS, lpb_checkslice(L, 1)); + const pb_Type *t = lpb_type(L, LS, lpb_checkslice(L, 1)); int type = lua_type(L, 2); if (t == NULL) luaL_argerror(L, 1, "type not found"); if (type != LUA_TNONE && type != LUA_TNIL && type != LUA_TFUNCTION) @@ -1451,7 +1466,7 @@ static int Lpb_hook(lua_State *L) { static int Lpb_encode_hook(lua_State *L) { lpb_State *LS = lpb_lstate(L); - const pb_Type *t = lpb_type(LS, lpb_checkslice(L, 1)); + const pb_Type *t = lpb_type(L, LS, lpb_checkslice(L, 1)); int type = lua_type(L, 2); if (t == NULL) luaL_argerror(L, 1, "type not found"); if (type != LUA_TNONE && type != LUA_TNIL && type != LUA_TFUNCTION) @@ -1481,7 +1496,7 @@ static int Lpb_clear(lua_State *L) { return 0; } LS->state = &LS->local; - t = (pb_Type*)lpb_type(LS, lpb_checkslice(L, 1)); + t = (pb_Type*)lpb_type(L, LS, lpb_checkslice(L, 1)); if (lua_isnoneornil(L, 2)) pb_deltype(&LS->local, t); else pb_delfield(&LS->local, t, (pb_Field*)lpb_field(L, 2, t)); LS->state = S; @@ -1496,7 +1511,7 @@ static int Lpb_typefmt(lua_State *L) { int type; if (pb_len(s) == 1) r = pb_typename(type = lpb_typefmt(*s.p), "!"); - else if (lpb_type(lpb_lstate(L), s)) + else if (lpb_type(L, lpb_lstate(L), s)) r = "message", type = PB_TBYTES; else if ((type = pb_typebyname(s.p, PB_Tmessage)) != PB_Tmessage) { switch (type) { @@ -1548,26 +1563,32 @@ static void lpb_useenchooks(lua_State *L, lpb_State *LS, const pb_Type *t) { lua_pop(L, 2); } -static void lpbE_enum(lpb_Env *e, const pb_Field *f, int idx) { +static void lpbE_enum(lpb_Env *e, const pb_Field *f, int idx, size_t *plen) { lua_State *L = e->L; pb_Buffer *b = e->b; const pb_Field *ev; - int type = lua_type(L, idx); - if (type == LUA_TNUMBER) - pb_addvarint64(b, (uint64_t)lua_tonumber(L, idx)); - else if ((ev = pb_fname(f->type, - lpb_name(e->LS, lpb_toslice(L, idx)))) != NULL) - pb_addvarint32(b, ev->number); - else if (type != LUA_TSTRING) + int type = lua_type(L, idx), has_data = 1; + size_t len; + if (type == LUA_TNUMBER) { + uint64_t v = (uint64_t)lua_tonumber(L, idx); + if (v == 0) has_data = 0; + len = lpb_checkmem(L, pb_addvarint64(b, v)); + } else if ((ev = pb_fname(f->type, + lpb_name(e->LS, lpb_toslice(L, idx)))) != NULL) { + if (ev->number == 0) has_data = 0; + len = lpb_checkmem(L, pb_addvarint32(b, ev->number)); + } else if (type != LUA_TSTRING) argcheck(L, 0, 2, "number/string expected at field '%s', got %s", (const char*)f->name, luaL_typename(L, idx)); else { uint64_t v = lpb_tointegerx(L, idx, &type); + if (v == 0) has_data = 0; if (!type) argcheck(L, 0, 2, "can not encode unknown enum '%s' at field '%s'", lua_tostring(L, -1), (const char*)f->name); - pb_addvarint64(b, v); + len = lpb_checkmem(L, pb_addvarint64(b, v)); } + if (plen) *plen = has_data ? 0 : len; } static void lpbE_field(lpb_Env *e, const pb_Field *f, size_t *plen, int idx) { @@ -1579,15 +1600,17 @@ static void lpbE_field(lpb_Env *e, const pb_Field *f, size_t *plen, int idx) { switch (f->type_id) { case PB_Tenum: if (e->LS->use_enc_hooks) lpb_useenchooks(L, e->LS, f->type); - lpbE_enum(e, f, idx); + lpbE_enum(e, f, idx, plen); break; case PB_Tmessage: if (e->LS->use_enc_hooks) lpb_useenchooks(L, e->LS, f->type); lpb_checktable(L, f, idx); + lpb_checkmem(L, pb_addvarint32(b, 0)); len = pb_bufflen(b); lpbE_encode(e, f->type, idx); - lpb_addlength(L, b, len); + if (plen && len == pb_bufflen(b)) *plen = 1; + lpb_addlength(L, b, len, 1); break; default: @@ -1618,11 +1641,11 @@ static void lpbE_map(lpb_Env *e, const pb_Field *f, int idx) { while (lua_next(L, lpb_relindex(idx, 1))) { size_t len; pb_addvarint32(e->b, pb_pair(f->number, PB_TBYTES)); + pb_addvarint32(e->b, 0); len = pb_bufflen(e->b); lpbE_tagfield(e, kf, 1, -2); lpbE_tagfield(e, vf, 1, -1); - lpb_addlength(L, e->b, len); - + lpb_addlength(L, e->b, len, 1); lua_pop(L, 1); } } @@ -1636,6 +1659,7 @@ static void lpbE_repeated(lpb_Env *e, const pb_Field *f, int idx) { if (f->packed) { unsigned len, bufflen = pb_bufflen(b); pb_addvarint32(b, pb_pair(f->number, PB_TBYTES)); + pb_addvarint32(b, 0); len = pb_bufflen(b); for (i = 1; lua53_rawgeti(L, idx, i) != LUA_TNIL; ++i) { lpbE_field(e, f, NULL, -1); @@ -1644,7 +1668,7 @@ static void lpbE_repeated(lpb_Env *e, const pb_Field *f, int idx) { if (i == 1 && !e->LS->encode_default_values) pb_bufflen(b) = bufflen; else - lpb_addlength(L, b, len); + lpb_addlength(L, b, len, 1); } else { for (i = 1; lua53_rawgeti(L, idx, i) != LUA_TNIL; ++i) { lpbE_tagfield(e, f, 0, -1); @@ -1688,7 +1712,7 @@ static void lpbE_encode(lpb_Env *e, const pb_Type *t, int idx) { static int Lpb_encode(lua_State *L) { lpb_State *LS = lpb_lstate(L); - const pb_Type *t = lpb_type(LS, lpb_checkslice(L, 1)); + const pb_Type *t = lpb_type(L, LS, lpb_checkslice(L, 1)); lpb_Env e; argcheck(L, t!=NULL, 1, "type '%s' does not exists", lua_tostring(L, 1)); luaL_checktype(L, 2, LUA_TTABLE); @@ -1721,7 +1745,7 @@ static int lpbE_pack(lpb_Env* e, const pb_Type* t, int idx) { static int Lpb_pack(lua_State* L) { lpb_State* LS = lpb_lstate(L); - const pb_Type* t = lpb_type(LS, lpb_checkslice(L, 1)); + const pb_Type* t = lpb_type(L, LS, lpb_checkslice(L, 1)); lpb_Env e; int idx = 3; e.L = L, e.LS = LS, e.b = test_buffer(L, 2); @@ -1913,7 +1937,7 @@ static int lpbD_message(lpb_Env *e, const pb_Type *t) { static int lpbD_decode(lua_State *L, pb_Slice s, int start) { lpb_State *LS = lpb_lstate(L); - const pb_Type *t = lpb_type(LS, lpb_checkslice(L, 1)); + const pb_Type *t = lpb_type(L, LS, lpb_checkslice(L, 1)); lpb_Env e; argcheck(L, t!=NULL, 1, "type '%s' does not exists", lua_tostring(L, 1)); lua_settop(L, start); @@ -1994,7 +2018,7 @@ static int lpbD_unpack(lpb_Env* e, const pb_Type* t) { static int Lpb_unpack(lua_State* L) { lpb_State* LS = lpb_lstate(L); - const pb_Type* t = lpb_type(LS, lpb_checkslice(L, 1)); + const pb_Type* t = lpb_type(L, LS, lpb_checkslice(L, 1)); pb_Slice s = lpb_checkslice(L, 2); lpb_Env e; e.L = L, e.LS = LS, e.s = &s; diff --git a/pb.h b/pb.h index 739f78b..4387912 100644 --- a/pb.h +++ b/pb.h @@ -194,7 +194,7 @@ PB_API size_t pb_addfixed64 (pb_Buffer *b, uint64_t v); PB_API size_t pb_addslice (pb_Buffer *b, pb_Slice s); PB_API size_t pb_addbytes (pb_Buffer *b, pb_Slice s); -PB_API size_t pb_addlength (pb_Buffer *b, size_t len); +PB_API size_t pb_addlength (pb_Buffer *b, size_t len, size_t prealloc); /* type info database state and name table */ @@ -743,18 +743,22 @@ PB_API size_t pb_addslice(pb_Buffer *b, pb_Slice s) { return len; } -PB_API size_t pb_addlength(pb_Buffer *b, size_t len) { +PB_API size_t pb_addlength(pb_Buffer *b, size_t len, size_t prealloc) { char buff[10], *s; - size_t bl, ml; + size_t bl, ml, rl = 0; if ((bl = pb_bufflen(b)) < len) return 0; ml = pb_write64(buff, bl - len); - if (pb_prepbuffsize(b, ml) == NULL) return 0; - s = pb_buffer(b) + len; - memmove(s+ml, s, bl - len); + s = pb_buffer(b) + len - prealloc; + assert(ml >= prealloc); + if (ml > prealloc) { + if (pb_prepbuffsize(b, (rl = ml - prealloc)) == NULL) return 0; + s = pb_buffer(b) + len - prealloc; + memmove(s+ml, s+prealloc, bl - len); + } memcpy(s, buff, ml); - pb_addsize(b, ml); - return ml; + pb_addsize(b, rl); + return ml + (bl - len); } PB_API size_t pb_addbytes(pb_Buffer *b, pb_Slice s) {