Skip to content

Commit ae4da82

Browse files
committed
Fix
1 parent 1c16686 commit ae4da82

File tree

1 file changed

+45
-128
lines changed

1 file changed

+45
-128
lines changed

label_studio/io_storages/tests/test_get_bytes_stream.py

+45 -128
Original file line number | Diff line number | Diff line change
@@ -1,131 +1,43 @@
1-
import io
21
import unittest
3-
from unittest.mock import MagicMock
4-
from urllib.parse import urlparse
2+
from unittest.mock import MagicMock, patch
53

4+
# Add Django models import
5+
from django.db import models
6+
from io_storages.azure_blob.models import AzureBlobStorageMixin
7+
from io_storages.gcs.models import GCSStorageMixin
8+
from io_storages.s3.models import S3StorageMixin
69

7-
# Simple implementation of get_bytes_stream from S3StorageMixin, without Django dependencies
8-
class S3StorageTester:
9-
def __init__(self):
10-
self.client = None
1110

12-
def get_client(self):
13-
return self.client
14-
15-
def get_bytes_stream(self, uri):
16-
"""Get file bytes from S3 storage as a stream and content type.
11+
# Define concrete classes inheriting from the mixins
12+
# Abstract models cannot be instantiated directly, so we create
13+
# simple concrete models for testing purposes.
14+
class ConcreteS3Storage(S3StorageMixin, models.Model):
15+
class Meta:
16+
app_label = 'tests'
1717

18-
Args:
19-
uri: The S3 URI of the file to retrieve
2018

21-
Returns:
22-
Tuple of (BytesIO stream, content_type)
23-
"""
24-
# Parse URI to get bucket and key
25-
parsed_uri = urlparse(uri, allow_fragments=False)
26-
bucket_name = parsed_uri.netloc
27-
key = parsed_uri.path.lstrip('/')
19+
class ConcreteAzureBlobStorage(AzureBlobStorageMixin, models.Model):
20+
class Meta:
21+
app_label = 'tests'
2822

29-
# Get S3 client
30-
client = self.get_client()
3123

32-
try:
33-
# Get the object from S3
34-
object_response = client.get_object(Bucket=bucket_name, Key=key)
35-
content_type = object_response.get('ContentType')
36-
data = io.BytesIO(object_response['Body'].read())
37-
return data, content_type
38-
39-
except Exception:
40-
return None, None
41-
42-
43-
# Simple implementation of get_bytes_stream from AzureBlobStorageMixin, without Django dependencies
44-
class AzureBlobStorageTester:
45-
def __init__(self):
46-
self.client = None
47-
self.container = None
48-
49-
def get_client_and_container(self):
50-
return self.client, self.container
51-
52-
def get_bytes_stream(self, uri):
53-
"""Get file bytes from Azure Blob storage as a stream and content type.
54-
55-
Args:
56-
uri: The Azure URI of the file to retrieve
57-
58-
Returns:
59-
Tuple of (BytesIO stream, content_type)
60-
"""
61-
# Parse URI to get container and blob name
62-
parsed_uri = urlparse(uri, allow_fragments=False)
63-
container_name = parsed_uri.netloc
64-
blob_name = parsed_uri.path.lstrip('/')
65-
66-
try:
67-
# Get the Azure client
68-
client, _ = self.get_client_and_container()
69-
70-
# Get a blob client for the requested blob
71-
blob_client = client.get_blob_client(container=container_name, blob=blob_name)
72-
73-
# Download the blob
74-
download_stream = blob_client.download_blob()
75-
content_type = download_stream.properties.content_settings.content_type
76-
data = io.BytesIO(download_stream.readall())
77-
78-
return data, content_type
79-
80-
except Exception:
81-
return None, None
82-
83-
84-
# Simple implementation of get_bytes_stream from GCSStorageMixin, without Django dependencies
85-
class GCSStorageTester:
86-
def __init__(self):
87-
self.client = None
88-
89-
def get_client(self):
90-
return self.client
91-
92-
def get_bytes_stream(self, uri):
93-
"""Get file bytes from GCS storage as a stream and content type.
94-
95-
Args:
96-
uri: The GCS URI of the file to retrieve
97-
98-
Returns:
99-
Tuple of (BytesIO stream, content_type)
100-
"""
101-
# Parse URI to get bucket and key
102-
parsed_uri = urlparse(uri, allow_fragments=False)
103-
bucket_name = parsed_uri.netloc
104-
blob_name = parsed_uri.path.lstrip('/')
105-
106-
try:
107-
# Get client and bucket using existing methods
108-
client = self.get_client()
109-
bucket = client.get_bucket(bucket_name)
110-
blob = bucket.blob(blob_name)
111-
blob.reload()
112-
content_type = blob.content_type or 'application/octet-stream'
113-
data = io.BytesIO(blob.download_as_bytes())
114-
return data, content_type
115-
except Exception:
116-
return None, None
24+
class ConcreteGCSStorage(GCSStorageMixin, models.Model):
25+
class Meta:
26+
app_label = 'tests'
11727

11828

11929
class TestS3StorageMixinGetBytesStream(unittest.TestCase):
12030
"""Test the get_bytes_stream method in S3StorageMixin"""
12131

12232
def setUp(self):
123-
# Create an instance of our tester class
124-
self.storage = S3StorageTester()
33+
# Create an instance of the concrete class
34+
self.storage = ConcreteS3Storage()
12535
# Setup mock client
12636
self.mock_client = MagicMock()
127-
# Set the mock client
128-
self.storage.client = self.mock_client
37+
# Patch the get_client method to return our mock client
38+
self.get_client_patcher = patch.object(self.storage, 'get_client', return_value=self.mock_client)
39+
self.get_client_patcher.start()
40+
self.addCleanup(self.get_client_patcher.stop)
12941

13042
def test_get_bytes_stream_success(self):
13143
# Create a mock response for get_object
@@ -135,7 +47,7 @@ def test_get_bytes_stream_success(self):
13547
# Set up the mock get_object response
13648
self.mock_client.get_object.return_value = {'Body': mock_body, 'ContentType': 'text/plain'}
13749

138-
# Call the get_bytes_stream method
50+
# Call the real get_bytes_stream method
13951
uri = 's3://test-bucket/test-file.txt'
14052
result_stream, result_content_type = self.storage.get_bytes_stream(uri)
14153

@@ -148,7 +60,7 @@ def test_get_bytes_stream_exception(self):
14860
# Set up the mock to raise an exception
14961
self.mock_client.get_object.side_effect = Exception('Connection error')
15062

151-
# Call the get_bytes_stream method
63+
# Call the real get_bytes_stream method
15264
uri = 's3://test-bucket/test-file.txt'
15365
result_stream, result_content_type = self.storage.get_bytes_stream(uri)
15466

@@ -162,14 +74,17 @@ class TestAzureBlobStorageMixinGetBytesStream(unittest.TestCase):
16274
"""Test the get_bytes_stream method in AzureBlobStorageMixin"""
16375

16476
def setUp(self):
165-
# Create an instance of our tester class
166-
self.storage = AzureBlobStorageTester()
77+
# Create an instance of the concrete class
78+
self.storage = ConcreteAzureBlobStorage()
16779
# Setup mock client and container
16880
self.mock_client = MagicMock()
16981
self.mock_container = MagicMock()
170-
# Set the mock client and container
171-
self.storage.client = self.mock_client
172-
self.storage.container = self.mock_container
82+
# Patch the get_client_and_container method
83+
self.get_client_patcher = patch.object(
84+
self.storage, 'get_client_and_container', return_value=(self.mock_client, self.mock_container)
85+
)
86+
self.get_client_patcher.start()
87+
self.addCleanup(self.get_client_patcher.stop)
17388

17489
def test_get_bytes_stream_success(self):
17590
# Mock the blob client and download_blob
@@ -182,7 +97,7 @@ def test_get_bytes_stream_success(self):
18297
mock_download_stream.properties.content_settings.content_type = 'image/jpeg'
18398
mock_download_stream.readall.return_value = b'fake image data'
18499

185-
# Call the get_bytes_stream method
100+
# Call the real get_bytes_stream method
186101
uri = 'azure-blob://test-container/test-image.jpg'
187102
result_stream, result_content_type = self.storage.get_bytes_stream(uri)
188103

@@ -196,7 +111,7 @@ def test_get_bytes_stream_exception(self):
196111
# Set up mock client to raise an exception
197112
self.mock_client.get_blob_client.side_effect = Exception('Azure connection error')
198113

199-
# Call the get_bytes_stream method
114+
# Call the real get_bytes_stream method
200115
uri = 'azure-blob://test-container/test-image.jpg'
201116
result_stream, result_content_type = self.storage.get_bytes_stream(uri)
202117

@@ -209,12 +124,14 @@ class TestGCSStorageMixinGetBytesStream(unittest.TestCase):
209124
"""Test the get_bytes_stream method in GCSStorageMixin"""
210125

211126
def setUp(self):
212-
# Create an instance of our tester class
213-
self.storage = GCSStorageTester()
127+
# Create an instance of the concrete class
128+
self.storage = ConcreteGCSStorage()
214129
# Setup mock client
215130
self.mock_client = MagicMock()
216-
# Set the mock client
217-
self.storage.client = self.mock_client
131+
# Patch the get_client method
132+
self.get_client_patcher = patch.object(self.storage, 'get_client', return_value=self.mock_client)
133+
self.get_client_patcher.start()
134+
self.addCleanup(self.get_client_patcher.stop)
218135

219136
def test_get_bytes_stream_success(self):
220137
# Mock bucket and blob
@@ -228,7 +145,7 @@ def test_get_bytes_stream_success(self):
228145
mock_blob.content_type = 'application/pdf'
229146
mock_blob.download_as_bytes.return_value = b'fake pdf data'
230147

231-
# Call the get_bytes_stream method
148+
# Call the real get_bytes_stream method
232149
uri = 'gs://test-bucket/test-document.pdf'
233150
result_stream, result_content_type = self.storage.get_bytes_stream(uri)
234151

@@ -243,7 +160,7 @@ def test_get_bytes_stream_exception(self):
243160
# Set up mock client to raise an exception
244161
self.mock_client.get_bucket.side_effect = Exception('GCS connection error')
245162

246-
# Call the get_bytes_stream method
163+
# Call the real get_bytes_stream method
247164
uri = 'gs://test-bucket/test-document.pdf'
248165
result_stream, result_content_type = self.storage.get_bytes_stream(uri)
249166

@@ -263,7 +180,7 @@ def test_get_bytes_stream_with_default_content_type(self):
263180
mock_blob.content_type = None
264181
mock_blob.download_as_bytes.return_value = b'test data'
265182

266-
# Call the get_bytes_stream method
183+
# Call the real get_bytes_stream method
267184
uri = 'gs://test-bucket/test-file'
268185
result_stream, result_content_type = self.storage.get_bytes_stream(uri)
269186

0 commit comments

Comments
 (0)