diff --git a/doc/interfaces/stream.md b/doc/interfaces/stream.md index a7197851..348025dd 100644 --- a/doc/interfaces/stream.md +++ b/doc/interfaces/stream.md @@ -93,6 +93,8 @@ Writes the string `str` to the stream and ends the stream. On error, returns `ni - `options` is a table containing: - `.file` (file) + - `.count` (positive integer): number of bytes of `file` to write + defaults to infinity (the whole file will be written) Writes the contents of file `file` to the stream and ends the stream. `file` will not be automatically seeked, so ensure it is at the correct offset before calling. On error, returns `nil`, an error message and an error number. diff --git a/http/stream_common.lua b/http/stream_common.lua index c7167e4f..d9faf507 100644 --- a/http/stream_common.lua +++ b/http/stream_common.lua @@ -154,18 +154,25 @@ end function stream_methods:write_body_from_file(options, timeout) local deadline = timeout and (monotime()+timeout) - local file + local file, count if io.type(options) then -- lua-http <= 0.2 took a file handle file = options else file = options.file + count = options.count end - -- Can't use :lines here as in Lua 5.1 it doesn't take a parameter - while true do - local chunk, err = file:read(CHUNK_SIZE) + if count == nil then + count = math.huge + elseif type(count) ~= "number" or count < 0 or count % 1 ~= 0 then + error("invalid .count parameter (expected positive integer)") + end + while count > 0 do + local chunk, err = file:read(math.min(CHUNK_SIZE, count)) if chunk == nil then if err then error(err) + elseif count ~= math.huge and count > 0 then + error("unexpected EOF") end break end @@ -173,6 +180,7 @@ function stream_methods:write_body_from_file(options, timeout) if not ok then return nil, err2, errno2 end + count = count - #chunk end return self:write_chunk("", true, deadline and (deadline-monotime())) end diff --git a/spec/stream_common_spec.lua b/spec/stream_common_spec.lua index 4cc19e3f..7b17f892 100644 --- a/spec/stream_common_spec.lua +++ b/spec/stream_common_spec.lua @@ -127,5 +127,28 @@ describe("http.stream_common", function() client:close() server:close() end) + it("limits number of bytes when using .count option", function() + local server, client = new_pair(1.1) + local cq = cqueues.new() + cq:wrap(function() + local file = io.tmpfile() + assert(file:write("hello world!")) + assert(file:seek("set")) + local stream = client:new_stream() + assert(stream:write_headers(new_request_headers(), false)) + assert(stream:write_body_from_file({ + file = file; + count = 5; + })) + end) + cq:wrap(function() + local stream = assert(server:get_next_incoming_stream()) + assert.same("hello", assert(stream:get_body_as_string())) + end) + assert_loop(cq, TEST_TIMEOUT) + assert.truthy(cq:empty()) + client:close() + server:close() + end) end) end)