diff --git a/src/lib/schema.ts b/src/lib/schema.ts index 95dc938..c8aae6e 100644 --- a/src/lib/schema.ts +++ b/src/lib/schema.ts @@ -7,10 +7,25 @@ const zodStringBool = z .transform(x => x === "true") .pipe(z.boolean()); -const zodStringUrl = z.string().url(); +const urlSchema = z + .string() + .url() + .refine( + val => { + try { + const url = new URL(val); + return url.protocol === "http:" || url.protocol === "https:"; + } catch (err) { + return false; + } + }, + { + message: "must start with http or https", + }, + ); export const PlainConfigSchema = z.object({ - url: zodStringUrl, + url: urlSchema, width: z.coerce.number().nullish(), height: z.coerce.number().nullish(), viewPortWidth: z.coerce.number().nullish(), diff --git a/src/middlewares/extract_query_params.ts b/src/middlewares/extract_query_params.ts index a216c9e..e710d7a 100644 --- a/src/middlewares/extract_query_params.ts +++ b/src/middlewares/extract_query_params.ts @@ -33,7 +33,15 @@ export function handleExtractQueryParamsMiddleware(encryptionService?: StringEnc const { validData, errors } = parseForm({ data: input, schema: PlainConfigSchema }); if (errors) { - throw new HTTPException(400, { message: "Invalid query parameters", cause: errors }); + let message: string = "Invalid query parameters: "; + + const specificErrors = Object.entries(errors).map(([key, value]) => `(${key} - ${value})`).join(" ") + + message = `${message} ${specificErrors}`; + + console.log(message); + + throw new HTTPException(400, { message, cause: errors }); } if (validData.width && validData.width > 1920) { diff --git a/tests/app.spec.ts b/tests/app.spec.ts index 98c57d7..c0ff1f0 100644 --- a/tests/app.spec.ts +++ b/tests/app.spec.ts @@ -71,6 +71,18 @@ suite("app", () => { expect(res.status).toBe(400); expect(await res.text()).toMatch(/Invalid query/gi); }); + + [ + "file:///etc/passwd&width=4000", + "view-source:file:///home/&width=4000", + "view-source:file:///home/ec2-user/url-to-png/.env", + ].forEach(invalidDomain => { + it(`throws when invalid protocol ${invalidDomain}`, async () => { + const res = await app.request(`/?url=${invalidDomain}`); + expect(res.status).toBe(400); + expect(await res.text()).toMatch(/url - must start with http or https/gi); + }); + }); }); describe("GET /?hash=", () => {