Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Some fixes for AWS/S3 #54

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions src/resty/aws/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -279,12 +279,13 @@ local function s3_patch(request, bucket)
end

request.host = bucket .. "." .. request.host
request.headers['Host'] = request.host

local path = request.path
if bucket and path then
path = path:sub(#bucket + 2)
if path == "/" then
path = ""
if path == "" then
path = "/"
end

request.path = path
Expand All @@ -307,9 +308,11 @@ local function generate_service_methods(service)
--print(require("pl.pretty").write(self.config))

-- validate parameters
local ok, err = validate_input(params, operation.input, "params")
if not ok then
return nil, operation_prefix .. " validation error: " .. tostring(err)
if operation.input then
local ok, err = validate_input(params, operation.input, "params")
if not ok then
return nil, operation_prefix .. " validation error: " .. tostring(err)
end
end

-- generate request data and format it according to the protocol
Expand Down Expand Up @@ -354,8 +357,18 @@ local function generate_service_methods(service)

--print(require("pl.pretty").write(signed_request))

local need_raw_reader = false
if operation.output then
for key, shape in pairs(operation.output.members) do
if shape.type == 'blob' or shape.streaming then
need_raw_reader = true
break
end
end
end

-- execute the request
local response, err = execute_request(signed_request)
local response, err = execute_request(signed_request, need_raw_reader)
if not response then
return nil, operation_prefix .. " " .. tostring(err)
end
Expand Down
74 changes: 38 additions & 36 deletions src/resty/aws/request/build.lua
Original file line number Diff line number Diff line change
Expand Up @@ -157,45 +157,47 @@ local function build_request(operation, config, params)

-- inject parameters in the right places; path/query/header/body
-- this assumes they all live on the top-level of the structure, is this correct??
for name, member_config in pairs(operation.input.members) do
local param_value = params[name]
-- TODO: date-time value should be properly formatted???
if param_value ~= nil then

-- a parameter value is provided
local location = member_config.location
local locationName = member_config.locationName
-- print(name," = ", param_value, ": ",location, " (", locationName,")")

if location == "uri" then
local place_holder = "{" .. locationName .. "%+?}"
local replacement = escape_uri(param_value):gsub("%%", "%%%%")
request.path = request.path:gsub(place_holder, replacement)

elseif location == "querystring" then
request.query[locationName] = param_value

elseif location == "header" then
request.headers[locationName] = param_value

elseif location == "headers" then
for k,v in pairs(param_value) do
request.headers[locationName .. k] = v
end
if operation.input then
for name, member_config in pairs(operation.input.members) do
local param_value = params[name]
-- TODO: date-time value should be properly formatted???
if param_value ~= nil then

-- a parameter value is provided
local location = member_config.location
local locationName = member_config.locationName
-- print(name," = ", param_value, ": ",location, " (", locationName,")")

if location == "uri" then
local place_holder = "{" .. locationName .. "%+?}"
local replacement = escape_uri(param_value):gsub("%%", "%%%%")
request.path = request.path:gsub(place_holder, replacement)

elseif location == "querystring" then
request.query[locationName] = param_value

elseif location == "header" then
request.headers[locationName] = param_value

elseif location == "headers" then
for k,v in pairs(param_value) do
request.headers[locationName .. k] = v
end

elseif location == nil then
if config.protocol == "query" then
-- no location specified, but protocol is query, so it goes into query
request.query[name] = param_value
elseif member_config.type == "blob" then
request.body = param_value
else
-- nowhere else to go, so put it in the body (for json and xml)
body_tbl[name] = param_value
end

elseif location == nil then
if config.protocol == "query" then
-- no location specified, but protocol is query, so it goes into query
request.query[name] = param_value
elseif member_config.type == "blob" then
request.body = param_value
else
-- nowhere else to go, so put it in the body (for json and xml)
body_tbl[name] = param_value
error("Unknown location: " .. location)
end

else
error("Unknown location: " .. location)
end
end
end
Expand Down
13 changes: 12 additions & 1 deletion src/resty/aws/request/execute.lua
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ local json_decode = require("cjson.safe").new().decode
--
-- Input parameters:
-- * signed_request table
local function execute_request(signed_request)
local function execute_request(signed_request, return_raw_body)

local httpc = http.new()
httpc:set_timeout(60000)
Expand Down Expand Up @@ -49,6 +49,17 @@ local function execute_request(signed_request)

local body do
if response.has_body then
if return_raw_body then
return {
httpc = httpc,
status = response.status,
reason = response.reason,
headers = response.headers,
has_body = response.has_body,
body_reader = response.body_reader,
read_body = response.read_body,
}
end
body, err = response:read_body()
if not body then
return nil, ("failed reading response body from '%s:%s': %s"):format(
Expand Down
3 changes: 3 additions & 0 deletions src/resty/aws/request/validate.lua
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,9 @@ local validators do
type = always_pass,
deprecated = always_pass,
box = always_pass,
locationName = always_pass,
location = always_pass,
deprecatedMessage = always_pass,
},ops_mt)


Expand Down