diff --git a/handlersocket/hstcpsvr_worker.cpp b/handlersocket/hstcpsvr_worker.cpp index 3aff490..9cca751 100644 --- a/handlersocket/hstcpsvr_worker.cpp +++ b/handlersocket/hstcpsvr_worker.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #if __linux__ #include #endif @@ -35,24 +36,80 @@ namespace dena { +struct hstcpsvr_conn; + +enum dbrequest_cmd { + dbrequest_cmd_none, + dbrequest_cmd_open, + dbrequest_cmd_exec, + dbrequest_cmd_auth, +}; + +struct dbrequest : public dbcallback_i { + hstcpsvr_conn *conn_backref; + string_buffer *respbuf; + size_t resp_begin_pos; + dbrequest_cmd cmd; + bool executed; + cmd_open_args open_args; + cmd_exec_args exec_args; + std::vector work_flds; + std::vector work_uflds; + std::vector work_filters; + dbrequest(hstcpsvr_conn *conn) + : conn_backref(conn), respbuf(0), resp_begin_pos(0), + cmd(dbrequest_cmd_none), executed(false) { } + void reset_cmd_none(string_buffer *resp) { + cmd = dbrequest_cmd_none; + respbuf = resp; + resp_begin_pos = 0; + executed = false; + } + void reset_cmd_open(string_buffer *resp) { + cmd = dbrequest_cmd_open; + open_args = cmd_open_args(); + respbuf = resp; + resp_begin_pos = 0; + executed = false; + } + void reset_cmd_exec(string_buffer *resp) { + cmd = dbrequest_cmd_exec; + exec_args = cmd_exec_args(); + respbuf = resp; + resp_begin_pos = 0; + executed = false; + } + void reset_cmd_auth(string_buffer *resp) { + cmd = dbrequest_cmd_auth; + respbuf = resp; + resp_begin_pos = 0; + executed = false; + } + virtual void dbcb_set_prep_stmt(size_t pst_id, const prep_stmt& v); + virtual const prep_stmt *dbcb_get_prep_stmt(size_t pst_id) const; + virtual void dbcb_resp_short(uint32_t code, const char *msg); + virtual void dbcb_resp_short_num(uint32_t code, uint32_t value); + virtual void dbcb_resp_short_num64(uint32_t code, uint64_t value); + virtual void dbcb_resp_begin(size_t num_flds); + virtual void dbcb_resp_entry(const char *fld, size_t fldlen); + virtual void dbcb_resp_end(); + virtual void dbcb_resp_cancel(); +}; + struct dbconnstate { string_buffer readbuf; string_buffer writebuf; std::vector prep_stmts; - size_t resp_begin_pos; void reset() { readbuf.clear(); writebuf.clear(); prep_stmts.clear(); - resp_begin_pos = 0; } - dbconnstate() : resp_begin_pos(0) { } }; -struct hstcpsvr_conn; typedef auto_ptrcontainer< std::list > hstcpsvr_conns_type; -struct hstcpsvr_conn : public dbcallback_i { +struct hstcpsvr_conn { public: auto_file fd; sockaddr_storage addr; @@ -66,6 +123,7 @@ struct hstcpsvr_conn : public dbcallback_i { time_t nb_last_io; hstcpsvr_conns_type::iterator conns_iter; bool authorized; + dbrequest cur_request; public: bool closed() const; bool ok_to_close() const; @@ -73,20 +131,10 @@ struct hstcpsvr_conn : public dbcallback_i { int accept(const hstcpsvr_shared_c& cshared); bool write_more(bool *more_r = 0); bool read_more(bool *more_r = 0); - public: - virtual void dbcb_set_prep_stmt(size_t pst_id, const prep_stmt& v); - virtual const prep_stmt *dbcb_get_prep_stmt(size_t pst_id) const; - virtual void dbcb_resp_short(uint32_t code, const char *msg); - virtual void dbcb_resp_short_num(uint32_t code, uint32_t value); - virtual void dbcb_resp_short_num64(uint32_t code, uint64_t value); - virtual void dbcb_resp_begin(size_t num_flds); - virtual void dbcb_resp_entry(const char *fld, size_t fldlen); - virtual void dbcb_resp_end(); - virtual void dbcb_resp_cancel(); public: hstcpsvr_conn() : addr_len(sizeof(addr)), readsize(4096), nonblocking(false), read_finished(false), write_finished(false), - nb_last_io(0), authorized(false) { } + nb_last_io(0), authorized(false), cur_request(this) { } }; bool @@ -166,87 +214,92 @@ hstcpsvr_conn::read_more(bool *more_r) } void -hstcpsvr_conn::dbcb_set_prep_stmt(size_t pst_id, const prep_stmt& v) +dbrequest::dbcb_set_prep_stmt(size_t pst_id, const prep_stmt& v) { - if (cstate.prep_stmts.size() <= pst_id) { - cstate.prep_stmts.resize(pst_id + 1); + if (conn_backref->cstate.prep_stmts.size() <= pst_id) { + conn_backref->cstate.prep_stmts.resize(pst_id + 1); } - cstate.prep_stmts[pst_id] = v; + conn_backref->cstate.prep_stmts[pst_id] = v; } const prep_stmt * -hstcpsvr_conn::dbcb_get_prep_stmt(size_t pst_id) const +dbrequest::dbcb_get_prep_stmt(size_t pst_id) const { - if (cstate.prep_stmts.size() <= pst_id) { + if (conn_backref->cstate.prep_stmts.size() <= pst_id) { return 0; } - return &cstate.prep_stmts[pst_id]; + return &conn_backref->cstate.prep_stmts[pst_id]; } void -hstcpsvr_conn::dbcb_resp_short(uint32_t code, const char *msg) +dbrequest::dbcb_resp_short(uint32_t code, const char *msg) { - write_ui32(cstate.writebuf, code); + write_ui32(*respbuf, code); const size_t msglen = strlen(msg); if (msglen != 0) { - cstate.writebuf.append_literal("\t1\t"); - cstate.writebuf.append(msg, msg + msglen); + respbuf->append_literal("\t1\t"); + respbuf->append(msg, msg + msglen); } else { - cstate.writebuf.append_literal("\t1"); + respbuf->append_literal("\t1"); } - cstate.writebuf.append_literal("\n"); + respbuf->append_literal("\n"); + executed = true; } void -hstcpsvr_conn::dbcb_resp_short_num(uint32_t code, uint32_t value) +dbrequest::dbcb_resp_short_num(uint32_t code, uint32_t value) { - write_ui32(cstate.writebuf, code); - cstate.writebuf.append_literal("\t1\t"); - write_ui32(cstate.writebuf, value); - cstate.writebuf.append_literal("\n"); + write_ui32(*respbuf, code); + respbuf->append_literal("\t1\t"); + write_ui32(*respbuf, value); + respbuf->append_literal("\n"); + executed = true; } void -hstcpsvr_conn::dbcb_resp_short_num64(uint32_t code, uint64_t value) +dbrequest::dbcb_resp_short_num64(uint32_t code, uint64_t value) { - write_ui32(cstate.writebuf, code); - cstate.writebuf.append_literal("\t1\t"); - write_ui64(cstate.writebuf, value); - cstate.writebuf.append_literal("\n"); + write_ui32(*respbuf, code); + respbuf->append_literal("\t1\t"); + write_ui64(*respbuf, value); + respbuf->append_literal("\n"); + executed = true; } void -hstcpsvr_conn::dbcb_resp_begin(size_t num_flds) +dbrequest::dbcb_resp_begin(size_t num_flds) { - cstate.resp_begin_pos = cstate.writebuf.size(); - cstate.writebuf.append_literal("0\t"); - write_ui32(cstate.writebuf, num_flds); + resp_begin_pos = respbuf->size(); + respbuf->append_literal("0\t"); + write_ui32(*respbuf, num_flds); } void -hstcpsvr_conn::dbcb_resp_entry(const char *fld, size_t fldlen) +dbrequest::dbcb_resp_entry(const char *fld, size_t fldlen) { if (fld != 0) { - cstate.writebuf.append_literal("\t"); - escape_string(cstate.writebuf, fld, fld + fldlen); + respbuf->append_literal("\t"); + escape_string(*respbuf, fld, fld + fldlen); } else { static const char t[] = "\t\0"; - cstate.writebuf.append(t, t + 2); + respbuf->append(t, t + 2); } } void -hstcpsvr_conn::dbcb_resp_end() +dbrequest::dbcb_resp_end() { - cstate.writebuf.append_literal("\n"); - cstate.resp_begin_pos = 0; + respbuf->append_literal("\n"); + resp_begin_pos = 0; + executed = true; } void -hstcpsvr_conn::dbcb_resp_cancel() +dbrequest::dbcb_resp_cancel() { - cstate.writebuf.resize(cstate.resp_begin_pos); - cstate.resp_begin_pos = 0; + /* TODO: test case */ + respbuf->resize(resp_begin_pos); + resp_begin_pos = 0; } struct hstcpsvr_worker : public hstcpsvr_worker_i, private noncopyable { @@ -266,16 +319,18 @@ struct hstcpsvr_worker : public hstcpsvr_worker_i, private noncopyable { #endif bool accept_enabled; int accept_balance; - std::vector filters_work; private: int run_one_nb(); int run_one_ep(); - void execute_lines(hstcpsvr_conn& conn); - void execute_line(char *start, char *finish, hstcpsvr_conn& conn); - void do_open_index(char *start, char *finish, hstcpsvr_conn& conn); + void execute_lines(dbrequest& req, hstcpsvr_conn& conn); + void execute_line(char *start, char *finish, dbrequest& req, + hstcpsvr_conn& conn); + void do_open_index(char *start, char *finish, dbrequest& req, + hstcpsvr_conn& conn); void do_exec_on_index(char *cmd_begin, char *cmd_end, char *start, - char *finish, hstcpsvr_conn& conn); - void do_authorization(char *start, char *finish, hstcpsvr_conn& conn); + char *finish, dbrequest& req, hstcpsvr_conn& conn); + void do_authorization(char *start, char *finish, dbrequest& req, + hstcpsvr_conn& conn); }; hstcpsvr_worker::hstcpsvr_worker(const hstcpsvr_worker_arg& arg) @@ -410,7 +465,7 @@ hstcpsvr_worker::run_one_nb() if ((pfd.revents & mask_in) == 0 || (*i)->cstate.readbuf.size() == 0) { continue; } - execute_lines(**i); + execute_lines((*i)->cur_request, **i); } /* COMMIT */ dbctx->unlock_tables_if(); @@ -545,7 +600,7 @@ hstcpsvr_worker::run_one_ep() conn->read_finished = true; conn->write_finished = true; } else { - execute_lines(*conn); + execute_lines(conn->cur_request, *conn); } } /* COMMIT */ @@ -646,7 +701,7 @@ hstcpsvr_worker::run_one_ep() #endif void -hstcpsvr_worker::execute_lines(hstcpsvr_conn& conn) +hstcpsvr_worker::execute_lines(dbrequest& req, hstcpsvr_conn& conn) { dbconnstate& cstate = conn.cstate; char *buf_end = cstate.readbuf.end(); @@ -657,14 +712,15 @@ hstcpsvr_worker::execute_lines(hstcpsvr_conn& conn) break; } char *const lf = (line_begin != nl && nl[-1] == '\r') ? nl - 1 : nl; - execute_line(line_begin, lf, conn); + execute_line(line_begin, lf, req, conn); line_begin = nl + 1; } cstate.readbuf.erase_front(line_begin - cstate.readbuf.begin()); } void -hstcpsvr_worker::execute_line(char *start, char *finish, hstcpsvr_conn& conn) +hstcpsvr_worker::execute_line(char *start, char *finish, dbrequest& req, + hstcpsvr_conn& conn) { /* safe to modify, safe to dereference 'finish' */ char *const cmd_begin = start; @@ -672,31 +728,38 @@ hstcpsvr_worker::execute_line(char *start, char *finish, hstcpsvr_conn& conn) char *const cmd_end = start; skip_one(start, finish); if (cmd_begin == cmd_end) { - return conn.dbcb_resp_short(2, "cmd"); + req.reset_cmd_none(&conn.cstate.writebuf); + return req.dbcb_resp_short(2, "cmd"); } if (cmd_begin + 1 == cmd_end) { if (cmd_begin[0] == 'P') { if (cshared.require_auth && !conn.authorized) { - return conn.dbcb_resp_short(3, "unauth"); + req.reset_cmd_none(&conn.cstate.writebuf); + return req.dbcb_resp_short(3, "unauth"); } - return do_open_index(start, finish, conn); + return do_open_index(start, finish, req, conn); } if (cmd_begin[0] == 'A') { - return do_authorization(start, finish, conn); + return do_authorization(start, finish, req, conn); } } if (cmd_begin[0] >= '0' && cmd_begin[0] <= '9') { if (cshared.require_auth && !conn.authorized) { - return conn.dbcb_resp_short(3, "unauth"); + req.reset_cmd_none(&conn.cstate.writebuf); + return req.dbcb_resp_short(3, "unauth"); } - return do_exec_on_index(cmd_begin, cmd_end, start, finish, conn); + return do_exec_on_index(cmd_begin, cmd_end, start, finish, req, conn); } - return conn.dbcb_resp_short(2, "cmd"); + req.reset_cmd_none(&conn.cstate.writebuf); + return req.dbcb_resp_short(2, "cmd"); } void -hstcpsvr_worker::do_open_index(char *start, char *finish, hstcpsvr_conn& conn) +hstcpsvr_worker::do_open_index(char *start, char *finish, dbrequest& req, + hstcpsvr_conn& conn) { + req.reset_cmd_open(&conn.cstate.writebuf); + cmd_open_args& args = req.open_args; const size_t pst_id = read_ui32(start, finish); skip_one(start, finish); /* dbname */ @@ -728,24 +791,24 @@ hstcpsvr_worker::do_open_index(char *start, char *finish, hstcpsvr_conn& conn) idxname_end[0] = 0; retflds_end[0] = 0; filflds_end[0] = 0; - cmd_open_args args; args.pst_id = pst_id; args.dbn = dbname_begin; args.tbl = tblname_begin; args.idx = idxname_begin; args.retflds = retflds_begin; args.filflds = filflds_begin; - return dbctx->cmd_open(conn, args); + return dbctx->cmd_open(req, args); } void hstcpsvr_worker::do_exec_on_index(char *cmd_begin, char *cmd_end, char *start, - char *finish, hstcpsvr_conn& conn) + char *finish, dbrequest& req, hstcpsvr_conn& conn) { - cmd_exec_args args; + req.reset_cmd_exec(&conn.cstate.writebuf); + cmd_exec_args& args = req.exec_args; const size_t pst_id = read_ui32(cmd_begin, cmd_end); if (pst_id >= conn.cstate.prep_stmts.size()) { - return conn.dbcb_resp_short(2, "stmtnum"); + return req.dbcb_resp_short(2, "stmtnum"); } args.pst = &conn.cstate.prep_stmts[pst_id]; char *const op_begin = start; @@ -754,8 +817,10 @@ hstcpsvr_worker::do_exec_on_index(char *cmd_begin, char *cmd_end, char *start, args.op = string_ref(op_begin, op_end); skip_one(start, finish); const uint32_t fldnum = read_ui32(start, finish); - string_ref *const flds = DENA_ALLOCA_ALLOCATE(string_ref, fldnum); - auto_alloca_free flds_autofree(flds); + if (req.work_flds.size() < fldnum) { + req.work_flds.resize(fldnum); + } + string_ref *const flds = &req.work_flds[0]; args.kvals = flds; args.kvalslen = fldnum; for (size_t i = 0; i < fldnum; ++i) { @@ -779,7 +844,7 @@ hstcpsvr_worker::do_exec_on_index(char *cmd_begin, char *cmd_end, char *start, args.skip = read_ui32(start, finish); if (start == finish) { /* simple query */ - return dbctx->cmd_exec(conn, args); + return dbctx->cmd_exec(req, args); } /* has filters or modops */ skip_one(start, finish); @@ -800,18 +865,18 @@ hstcpsvr_worker::do_exec_on_index(char *cmd_begin, char *cmd_end, char *start, read_token(start, finish); char *const filter_val_end = start; skip_one(start, finish); - if (filters_work.size() <= filters_count) { - filters_work.resize(filters_count + 1); + if (req.work_filters.size() <= filters_count) { + req.work_filters.resize(filters_count + 1); /* +1 for sentinel */ } - record_filter& fi = filters_work[filters_count]; + record_filter& fi = req.work_filters[filters_count]; if (filter_type_end != filter_type_begin + 1) { - return conn.dbcb_resp_short(2, "filtertype"); + return req.dbcb_resp_short(2, "filtertype"); } fi.filter_type = (filter_type_begin[0] == 'W') ? record_filter_type_break : record_filter_type_skip; const uint32_t num_filflds = args.pst->get_filter_fields().size(); if (ff_offset >= num_filflds) { - return conn.dbcb_resp_short(2, "filterfld"); + return req.dbcb_resp_short(2, "filterfld"); } fi.op = string_ref(filter_op_begin, filter_op_end); fi.ff_offset = ff_offset; @@ -827,17 +892,17 @@ hstcpsvr_worker::do_exec_on_index(char *cmd_begin, char *cmd_end, char *start, ++filters_count; } if (filters_count > 0) { - if (filters_work.size() <= filters_count) { - filters_work.resize(filters_count + 1); + if (req.work_filters.size() <= filters_count) { + req.work_filters.resize(filters_count + 1); } - filters_work[filters_count].op = string_ref(); /* sentinel */ - args.filters = &filters_work[0]; + req.work_filters[filters_count].op = string_ref(); /* sentinel */ + args.filters = &req.work_filters[0]; } else { args.filters = 0; } if (start == finish) { /* no modops */ - return dbctx->cmd_exec(conn, args); + return dbctx->cmd_exec(req, args); } /* has modops */ char *const mod_op_begin = start; @@ -845,8 +910,10 @@ hstcpsvr_worker::do_exec_on_index(char *cmd_begin, char *cmd_end, char *start, char *const mod_op_end = start; args.mod_op = string_ref(mod_op_begin, mod_op_end); const size_t num_uvals = args.pst->get_ret_fields().size(); - string_ref *const uflds = DENA_ALLOCA_ALLOCATE(string_ref, num_uvals); - auto_alloca_free uflds_autofree(uflds); + if (req.work_uflds.size() < num_uvals) { + req.work_uflds.resize(num_uvals); + } + string_ref *const uflds = &req.work_uflds[0]; for (size_t i = 0; i < num_uvals; ++i) { skip_one(start, finish); char *const f_begin = start; @@ -863,13 +930,14 @@ hstcpsvr_worker::do_exec_on_index(char *cmd_begin, char *cmd_end, char *start, } } args.uvals = uflds; - return dbctx->cmd_exec(conn, args); + return dbctx->cmd_exec(req, args); } void hstcpsvr_worker::do_authorization(char *start, char *finish, - hstcpsvr_conn& conn) + dbrequest& req, hstcpsvr_conn& conn) { + req.reset_cmd_auth(&conn.cstate.writebuf); /* auth type */ char *const authtype_begin = start; read_token(start, finish); @@ -886,7 +954,7 @@ hstcpsvr_worker::do_authorization(char *start, char *finish, char *wp = key_begin; unescape_string(wp, key_begin, key_end); if (authtype_len != 1 || authtype_begin[0] != '1') { - return conn.dbcb_resp_short(3, "authtype"); + return req.dbcb_resp_short(3, "authtype"); } if (cshared.plain_secret.size() == key_len && memcmp(cshared.plain_secret.data(), key_begin, key_len) == 0) { @@ -895,9 +963,9 @@ hstcpsvr_worker::do_authorization(char *start, char *finish, conn.authorized = false; } if (!conn.authorized) { - return conn.dbcb_resp_short(3, "unauth"); + return req.dbcb_resp_short(3, "unauth"); } else { - return conn.dbcb_resp_short(0, ""); + return req.dbcb_resp_short(0, ""); } }