Skip to content

Commit 7dab0e2

Browse files
authored
File upload improvements (CTFd#2451)
* Calculate a file's sha1sum on upload for future local change detection purposes * Allow clients to control the location of an uploaded file * Adds the optional path field to the Uploaders class to control where files get uploaded to * Closes CTFd#1595
1 parent ffeff9f commit 7dab0e2

File tree

6 files changed

+212
-15
lines changed

6 files changed

+212
-15
lines changed

CTFd/api/v1/files.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,25 @@ def post(self):
9090
# challenge_id
9191
# page_id
9292

93+
# Handle situation where users attempt to upload multiple files with a single location
94+
if len(files) > 1 and request.form.get("location"):
95+
return {
96+
"success": False,
97+
"errors": {
98+
"location": ["Location cannot be specified with multiple files"]
99+
},
100+
}, 400
101+
93102
objs = []
94103
for f in files:
95104
# uploads.upload_file(file=f, chalid=req.get('challenge'))
96-
obj = uploads.upload_file(file=f, **request.form.to_dict())
105+
try:
106+
obj = uploads.upload_file(file=f, **request.form.to_dict())
107+
except ValueError as e:
108+
return {
109+
"success": False,
110+
"errors": {"location": [str(e)]},
111+
}, 400
97112
objs.append(obj)
98113

99114
schema = FileSchema(many=True)

CTFd/models/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ class Files(db.Model):
288288
id = db.Column(db.Integer, primary_key=True)
289289
type = db.Column(db.String(80), default="standard")
290290
location = db.Column(db.Text)
291+
sha1sum = db.Column(db.String(40))
291292

292293
__mapper_args__ = {"polymorphic_identity": "standard", "polymorphic_on": type}
293294

CTFd/utils/uploads/__init__.py

+47-7
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import hashlib
12
import shutil
3+
from pathlib import Path
24

35
from CTFd.models import ChallengeFiles, Files, PageFiles, db
46
from CTFd.utils import get_app_config
@@ -16,8 +18,23 @@ def upload_file(*args, **kwargs):
1618
challenge_id = kwargs.get("challenge_id") or kwargs.get("challenge")
1719
page_id = kwargs.get("page_id") or kwargs.get("page")
1820
file_type = kwargs.get("type", "standard")
19-
20-
model_args = {"type": file_type, "location": None}
21+
location = kwargs.get("location")
22+
23+
# Validate location and default filename to uploaded file's name
24+
parent = None
25+
filename = file_obj.filename
26+
if location:
27+
path = Path(location)
28+
if len(path.parts) != 2:
29+
raise ValueError(
30+
"Location must contain two parts, a directory and a filename"
31+
)
32+
# Allow location to override the directory and filename
33+
parent = path.parts[0]
34+
filename = path.parts[1]
35+
location = parent + "/" + filename
36+
37+
model_args = {"type": file_type, "location": location}
2138

2239
model = Files
2340
if file_type == "challenge":
@@ -28,16 +45,39 @@ def upload_file(*args, **kwargs):
2845
model_args["page_id"] = page_id
2946

3047
uploader = get_uploader()
31-
location = uploader.upload(file_obj=file_obj, filename=file_obj.filename)
48+
location = uploader.upload(file_obj=file_obj, filename=filename, path=parent)
3249

33-
model_args["location"] = location
50+
sha1sum = hash_file(fp=file_obj)
3451

35-
file_row = model(**model_args)
36-
db.session.add(file_row)
37-
db.session.commit()
52+
model_args["location"] = location
53+
model_args["sha1sum"] = sha1sum
54+
55+
existing_file = Files.query.filter_by(location=location).first()
56+
if existing_file:
57+
for k, v in model_args.items():
58+
setattr(existing_file, k, v)
59+
db.session.commit()
60+
file_row = existing_file
61+
else:
62+
file_row = model(**model_args)
63+
db.session.add(file_row)
64+
db.session.commit()
3865
return file_row
3966

4067

68+
def hash_file(fp, algo="sha1"):
    """
    Return the hex digest of the full contents of a file-like object.

    The stream is rewound to the start before hashing and again afterwards,
    so the caller can read the object from the beginning once this returns.

    :param fp: seekable binary file-like object
    :param algo: name of a hashlib algorithm, defaults to "sha1"
    :return: hexadecimal digest string
    :raises NotImplementedError: if hashlib does not provide ``algo`` or the
        algorithm has no fixed digest length (e.g. the SHAKE family)
    """
    try:
        h = hashlib.new(algo)  # nosec
    except (TypeError, ValueError) as e:
        raise NotImplementedError(
            "Unsupported hash algorithm: {!r}".format(algo)
        ) from e

    # Variable-length digests report a digest_size of 0 and need an explicit
    # length for hexdigest(), which this helper does not expose
    if h.digest_size == 0:
        raise NotImplementedError("Unsupported hash algorithm: {!r}".format(algo))

    fp.seek(0)
    # https://stackoverflow.com/a/64730457
    while chunk := fp.read(65536):
        h.update(chunk)
    fp.seek(0)
    return h.hexdigest()
79+
80+
4181
def delete_file(file_id):
4282
f = Files.query.filter_by(id=file_id).first_or_404()
4383

CTFd/utils/uploads/uploaders.py

+22-7
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,20 @@ def store(self, fileobj, filename):
5454

5555
return filename
5656

57-
def upload(self, file_obj, filename):
57+
def upload(self, file_obj, filename, path=None):
    """
    Sanitize the target directory and file name, then store the file.

    :param file_obj: file object handed through to ``self.store``
    :param filename: requested file name, sanitized with ``secure_filename``
    :param path: optional directory name; when it is missing or sanitizes to
        nothing, a random name from ``hexencode(os.urandom(16))`` is used
    :return: the value returned by ``self.store`` for ``<path>/<filename>``
    :raises Exception: if the file name is empty before or after sanitizing
    """
    if not filename:
        raise Exception("Empty filenames cannot be used")

    # Sanitize directory name
    if path:
        path = secure_filename(path) or hexencode(os.urandom(16))
        path = path.replace(".", "")
    else:
        path = hexencode(os.urandom(16))

    # Sanitize file name
    filename = secure_filename(filename)
    # secure_filename can strip a name down to nothing (e.g. "..."), which
    # would otherwise produce a bare "<path>/" target
    if not filename:
        raise Exception("Empty filenames cannot be used")

    file_path = posixpath.join(path, filename)

    return self.store(file_obj, file_path)
6673

@@ -110,17 +117,25 @@ def store(self, fileobj, filename):
110117
self.s3.upload_fileobj(fileobj, self.bucket, filename)
111118
return filename
112119

113-
def upload(self, file_obj, filename):
120+
def upload(self, file_obj, filename, path=None):
    """
    Sanitize the target directory and file name, then upload to the bucket.

    :param file_obj: file object passed to ``self.s3.upload_fileobj``
    :param filename: requested file name
    :param path: optional directory name; when it is missing or sanitizes to
        nothing, a random name from ``hexencode(os.urandom(16))`` is used
    :return: the ``<path>/<filename>`` key that was uploaded, or ``False``
        if the file name is empty after sanitizing
    """
    # Sanitize directory name
    if path:
        path = secure_filename(path) or hexencode(os.urandom(16))
        path = path.replace(".", "")
        # Sanitize path. filter() returns an iterator, so it must be joined
        # back into a string before it can be concatenated below
        path = "".join(
            filter(self._clean_filename, secure_filename(path).replace(" ", "_"))
        )

    # Covers both a missing path and one that was filtered down to nothing
    if not path:
        path = hexencode(os.urandom(16))

    # Sanitize file name
    filename = "".join(
        filter(self._clean_filename, secure_filename(filename).replace(" ", "_"))
    )
    if len(filename) <= 0:
        return False

    dst = path + "/" + filename
    self.s3.upload_fileobj(file_obj, self.bucket, dst)
    return dst
126141

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
"""Add sha1sum field to Files
2+
3+
Revision ID: 5c4996aeb2cb
4+
Revises: 9e6f6578ca84
5+
Create Date: 2024-01-07 13:09:08.843903
6+
7+
"""
8+
import sqlalchemy as sa
9+
from alembic import op
10+
11+
# revision identifiers, used by Alembic.
12+
revision = "5c4996aeb2cb"
13+
down_revision = "9e6f6578ca84"
14+
branch_labels = None
15+
depends_on = None
16+
17+
18+
def upgrade():
    """Add the nullable 40-character ``sha1sum`` column to ``files``."""
    sha1sum_column = sa.Column("sha1sum", sa.String(length=40), nullable=True)
    op.add_column("files", sha1sum_column)
20+
21+
22+
def downgrade():
    """Remove the ``sha1sum`` column from ``files``."""
    table_name, column_name = "files", "sha1sum"
    op.drop_column(table_name, column_name)

tests/api/v1/test_files.py

+103
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# -*- coding: utf-8 -*-
33

44
import os
5+
import pathlib
56
import shutil
67
from io import BytesIO
78

@@ -75,6 +76,7 @@ def test_api_files_post_admin():
7576
)
7677
assert r.status_code == 200
7778
f = Files.query.filter_by(id=1).first()
79+
assert f.sha1sum == "9032bbc224ed8b39183cb93b9a7447727ce67f9d"
7880
os.remove(os.path.join(app.config["UPLOAD_FOLDER"] + "/" + f.location))
7981
destroy_ctfd(app)
8082

@@ -137,3 +139,104 @@ def test_api_file_delete_admin():
137139
shutil.rmtree(os.path.dirname(path), ignore_errors=True)
138140

139141
destroy_ctfd(app)
142+
143+
144+
def test_api_file_custom_location():
    """
    Upload a file to a client-chosen location, read it back, delete it and
    check that malformed locations are rejected
    """
    expected_sha1 = "9032bbc224ed8b39183cb93b9a7447727ce67f9d"
    wanted_location = "testing/asdf.txt"

    app = create_ctfd()
    with app.app_context(), login_as_user(app, name="admin") as client:
        with client.session_transaction() as sess:
            nonce = sess.get("nonce")

        def post_upload(location):
            # Every upload in this test sends the same payload and only
            # varies the requested location
            return client.post(
                "/api/v1/files",
                content_type="multipart/form-data",
                data={
                    "file": (BytesIO(b"test file content"), "test.txt"),
                    "location": location,
                    "nonce": nonce,
                },
            )

        r = post_upload(wanted_location)
        assert r.status_code == 200
        f = Files.query.filter_by(id=1).first()
        assert f.sha1sum == expected_sha1
        assert f.location == wanted_location
        r = client.get("/files/" + f.location)
        assert r.get_data(as_text=True) == "test file content"

        r = client.get("/api/v1/files/1")
        payload = r.get_json()["data"]
        assert payload["sha1sum"] == expected_sha1
        assert payload["location"] == wanted_location

        # Test deletion
        r = client.delete("/api/v1/files/1", json="")
        assert r.status_code == 200
        assert Files.query.count() == 0

        stored_file = pathlib.Path(app.config["UPLOAD_FOLDER"]) / f.location
        assert stored_file.exists() is False

        # Test invalid locations
        for bad_location in (
            "testing/prefix/asdf.txt",
            "/testing/asdf.txt",
            "asdf.txt",
        ):
            r = post_upload(bad_location)
            assert r.status_code == 400
    destroy_ctfd(app)
203+
204+
205+
def test_api_file_overwrite_by_location():
    """
    Test file overwriting with a specific location

    Uploading twice to the same location must replace the stored content and
    update the existing Files row instead of creating a second one.
    """
    app = create_ctfd()
    with app.app_context():
        with login_as_user(app, name="admin") as client:
            with client.session_transaction() as sess:
                nonce = sess.get("nonce")
            r = client.post(
                "/api/v1/files",
                content_type="multipart/form-data",
                data={
                    "file": (BytesIO(b"test file content"), "test.txt"),
                    "location": "testing/asdf.txt",
                    "nonce": nonce,
                },
            )
            assert r.status_code == 200
            f = Files.query.filter_by(id=1).first()
            r = client.get("/files/" + f.location)
            assert r.get_data(as_text=True) == "test file content"

            r = client.post(
                "/api/v1/files",
                content_type="multipart/form-data",
                data={
                    "file": (BytesIO(b"testing new uploaded file content"), "test.txt"),
                    "location": "testing/asdf.txt",
                    "nonce": nonce,
                },
            )
            assert r.status_code == 200
            # The overwrite must reuse the existing row rather than add one
            assert Files.query.count() == 1
            f = Files.query.filter_by(id=1).first()
            assert f.location == "testing/asdf.txt"
            r = client.get("/files/" + f.location)
            assert f.sha1sum == "0ee7eb85ac0b8d8ae03f3080589157cde553b13f"
            assert r.get_data(as_text=True) == "testing new uploaded file content"

            # Remove the uploaded file so it does not leak into other tests
            target = pathlib.Path(app.config["UPLOAD_FOLDER"]) / f.location
            shutil.rmtree(target.parent, ignore_errors=True)
    destroy_ctfd(app)

0 commit comments

Comments
 (0)