Skip to content

Commit fb3731d

Browse files
committed
Updated unit tests
1 parent 016f126 commit fb3731d

File tree

5 files changed

+205
-107
lines changed

5 files changed

+205
-107
lines changed

src/models.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,6 @@ class TrainRideFilter(BaseModel):
5353
destination: str
5454

5555
departure_date: datetime
56-
57-
min_departure_hour: Optional[time] = None
58-
max_departure_hour: Optional[time] = None
5956
max_duration_minutes: Optional[int] = None
6057

6158
max_price: Optional[float] = None

tests/test_config.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import pytest
2+
from unittest.mock import patch, mock_open
3+
from src.config import init_bot, get_bot_token
4+
5+
def test_get_bot_token_no_config():
6+
with patch("builtins.open", mock_open(read_data="")):
7+
with patch("src.config.init_bot") as mock_init_bot:
8+
with pytest.raises(KeyError):
9+
get_bot_token()

tests/test_models.py

Lines changed: 2 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,6 @@ def test_filter_rides_origin_and_destination(sample_rides):
4646
origin="Madrid",
4747
destination="Barcelona",
4848
departure_date=datetime(2025, 1, 30),
49-
min_departure_hour=None,
50-
max_departure_hour=None,
5149
max_duration_minutes=None,
5250
max_price=None
5351
)
@@ -61,9 +59,7 @@ def test_filter_rides_by_min_departure_hour(sample_rides):
6159
filter = TrainRideFilter(
6260
origin="Madrid",
6361
destination="Barcelona",
64-
departure_date=datetime(2025, 1, 30),
65-
min_departure_hour=time(9, 0),
66-
max_departure_hour=None,
62+
departure_date=datetime(2025, 1, 30, 9, 0),
6763
max_duration_minutes=None,
6864
max_price=None
6965
)
@@ -73,29 +69,11 @@ def test_filter_rides_by_min_departure_hour(sample_rides):
7369
assert len(result) == 1 # Only the third ride should pass
7470

7571

76-
def test_filter_rides_by_max_departure_hour(sample_rides):
77-
filter = TrainRideFilter(
78-
origin="Madrid",
79-
destination="Barcelona",
80-
departure_date=datetime(2025, 1, 30),
81-
min_departure_hour=None,
82-
max_departure_hour=time(9, 0),
83-
max_duration_minutes=None,
84-
max_price=None
85-
)
86-
87-
result = filter.filter_rides(sample_rides)
88-
89-
assert len(result) == 1 # Only the first ride should pass
90-
91-
9272
def test_filter_rides_by_max_duration(sample_rides):
9373
filter = TrainRideFilter(
9474
origin="Madrid",
9575
destination="Barcelona",
9676
departure_date=datetime(2025, 1, 30),
97-
min_departure_hour=None,
98-
max_departure_hour=None,
9977
max_duration_minutes=180,
10078
max_price=None
10179
)
@@ -110,8 +88,6 @@ def test_filter_rides_by_max_price(sample_rides):
11088
origin="Madrid",
11189
destination="Barcelona",
11290
departure_date=datetime(2025, 1, 30),
113-
min_departure_hour=None,
114-
max_departure_hour=None,
11591
max_duration_minutes=None,
11692
max_price=55.0
11793
)
@@ -126,9 +102,7 @@ def test_filter_rides_no_results(sample_rides):
126102
origin="Madrid",
127103
destination="Barcelona",
128104
departure_date=datetime(2025, 1, 31), # Different date
129-
min_departure_hour=None,
130-
max_departure_hour=None,
131-
max_duration_minutes=None,
105+
max_duration_minutes=1,
132106
max_price=None
133107
)
134108

tests/test_scraper.py

Lines changed: 96 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,82 +1,102 @@
1-
import unittest
2-
from unittest.mock import patch, MagicMock
1+
import pytest
32
from datetime import datetime
4-
from scraper import Scraper, extract_dwr_token, create_cookiedict, create_search_id, create_session_script_id
5-
from models import StationRecord
6-
from errors import InvalidTrainRideFilter, InvalidDWRToken
3+
from unittest.mock import patch, MagicMock
4+
from src.scraper import Scraper, extract_dwr_token, extract_train_list, create_search_id, create_session_script_id, tokenify
5+
from models import StationRecord, TrainRideRecord
6+
from errors import InvalidDWRToken, InvalidTrainRideFilter
7+
8+
@pytest.fixture
9+
def scraper():
10+
origin = StationRecord(name="Madrid", code="MAD")
11+
destination = StationRecord(name="Barcelona", code="BCN")
12+
departure_date = datetime(2023, 12, 25)
13+
return Scraper(origin, destination, departure_date)
14+
15+
def test_create_search_id():
16+
search_id = create_search_id()
17+
assert len(search_id) == 5
18+
assert search_id.startswith("_")
19+
20+
def test_create_session_script_id():
21+
dwr_token = "test_token"
22+
session_script_id = create_session_script_id(dwr_token)
23+
assert session_script_id.startswith(dwr_token)
24+
25+
def test_tokenify():
26+
number = 123456
27+
token = tokenify(number)
28+
assert isinstance(token, str)
29+
30+
def test_extract_dwr_token():
31+
response_text = 'r.handleCallback("0","0","test_token")'
32+
token = extract_dwr_token(response_text)
33+
assert token == "test_token"
34+
35+
def test_extract_dwr_token_invalid():
36+
response_text = 'invalid response'
37+
with pytest.raises(InvalidDWRToken):
38+
extract_dwr_token(response_text)
39+
40+
def test_extract_train_list():
41+
response_text = 'r.handleCallback(0,0,{"listadoTrenes":[]});'
42+
train_list = extract_train_list(response_text)
43+
assert "listadoTrenes" in train_list
44+
45+
def test_invalid_return_date():
46+
origin = StationRecord(name="Madrid", code="MAD")
47+
destination = StationRecord(name="Barcelona", code="BCN")
48+
departure_date = datetime(2023, 12, 25)
49+
return_date = datetime(2023, 12, 24)
50+
with pytest.raises(InvalidTrainRideFilter):
51+
Scraper(origin, destination, departure_date, return_date)
52+
53+
@patch('src.scraper.requests.Session.post')
54+
def test_do_search(mock_post, scraper):
55+
mock_post.return_value.ok = True
56+
scraper._do_search()
57+
assert mock_post.called
58+
59+
@patch('src.scraper.requests.Session.post')
60+
def test_do_get_dwr_token(mock_post, scraper):
61+
mock_post.return_value.ok = True
62+
mock_post.return_value.text = 'r.handleCallback("0","0","test_token")'
63+
scraper._do_get_dwr_token()
64+
assert scraper.dwr_token == "test_token"
765

8-
class TestScraper(unittest.TestCase):
9-
def setUp(self):
10-
# Setup mock data for testing
11-
self.origin = StationRecord(name="Madrid", code="MAD")
12-
self.destination = StationRecord(name="Barcelona", code="BCN")
13-
self.departure_date = datetime(2025, 1, 31, 10, 0)
14-
self.return_date = datetime(2025, 2, 1, 10, 0)
15-
16-
self.scraper = Scraper(
17-
origin=self.origin,
18-
destination=self.destination,
19-
departure_date=self.departure_date,
20-
return_date=self.return_date
21-
)
22-
23-
@patch('scraper.requests.Session.post')
24-
def test_get_trainrides(self, mock_post):
25-
# Mock the HTTP responses for the scraper methods
26-
mock_post.return_value.ok = True
27-
mock_post.return_value.text = '{"listadoTrenes": []}' # mock response for getting train list
66+
@patch('src.scraper.requests.Session.post')
67+
def test_do_update_session_objects(mock_post, scraper):
68+
mock_post.return_value.ok = True
69+
scraper.dwr_token = "test_token"
70+
scraper.script_session_id = "test_script_session_id"
71+
scraper._do_update_session_objects()
72+
assert mock_post.called
2873

29-
# Mock necessary functions
30-
self.scraper._do_search = MagicMock()
31-
self.scraper._do_get_dwr_token = MagicMock()
32-
self.scraper._do_update_session_objects = MagicMock()
33-
self.scraper._do_get_train_list = MagicMock(return_value={"listadoTrenes": []})
74+
@patch('src.scraper.requests.Session.post')
75+
def test_do_get_train_list(mock_post, scraper):
76+
mock_post.return_value.ok = True
77+
mock_post.return_value.text = 'r.handleCallback(0,0,{"listadoTrenes":[]});'
78+
train_list = scraper._do_get_train_list()
79+
assert "listadoTrenes" in train_list
3480

35-
# Test if get_trainrides returns an empty list when no trains are found
36-
result = self.scraper.get_trainrides()
37-
self.assertEqual(result, [])
38-
39-
def test_invalid_trainride_filter(self):
40-
# Test the case where return_date is before departure_date
41-
with self.assertRaises(InvalidTrainRideFilter):
42-
Scraper(
43-
origin=self.origin,
44-
destination=self.destination,
45-
departure_date=self.departure_date,
46-
return_date=datetime(2025, 1, 30, 10, 0) # Invalid return date
47-
)
48-
49-
@patch('scraper.extract_dwr_token')
50-
def test_extract_dwr_token(self, mock_extract_dwr_token):
51-
# Test extracting the DWR token with a valid response
52-
mock_response = 'throw #DWR-REPLY\nr.handleCallback("1","0","12345");'
53-
mock_extract_dwr_token.return_value = '12345'
54-
token = extract_dwr_token(mock_response)
55-
self.assertEqual(token, '12345')
81+
def test_change_datetime_hour():
82+
date = datetime(2023, 12, 25)
83+
hour = "14:30"
84+
new_date = Scraper._change_datetime_hour(hour, date)
85+
assert new_date.hour == 14
86+
assert new_date.minute == 30
5687

57-
def test_extract_dwr_token_invalid(self):
58-
# Test extracting the DWR token with an invalid response
59-
with self.assertRaises(InvalidDWRToken):
60-
extract_dwr_token('Invalid response')
61-
62-
def test_create_cookiedict(self):
63-
# Test the creation of cookies for the search
64-
cookies = create_cookiedict(self.origin, self.destination)
65-
self.assertIn("Search", cookies["name"])
66-
self.assertIn(self.origin.code, cookies["value"])
67-
68-
def test_create_search_id(self):
69-
# Test the creation of search_id
70-
search_id = create_search_id()
71-
self.assertTrue(search_id.startswith('_'))
72-
self.assertEqual(len(search_id), 5)
73-
74-
def test_create_session_script_id(self):
75-
# Test the creation of session script ID
76-
with patch('scraper.tokenify', return_value='test_token'):
77-
script_id = create_session_script_id('dwr_token')
78-
self.assertTrue(script_id.startswith('dwr_token/'))
79-
self.assertIn('test_token', script_id)
88+
def test_is_train_available():
89+
train = {
90+
"completo": False,
91+
"razonNoDisponible": "",
92+
"tarifaMinima": "10.00"
93+
}
94+
assert Scraper._is_train_available(train)
8095

81-
if __name__ == "__main__":
82-
unittest.main()
96+
def test_is_train_not_available():
97+
train = {
98+
"completo": True,
99+
"razonNoDisponible": "1",
100+
"tarifaMinima": None
101+
}
102+
assert not Scraper._is_train_available(train)

tests/test_validators.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import pytest
2+
from datetime import datetime
3+
from dataclasses import dataclass
4+
from validators import validate_station, validate_date, validate_float, StationValidationResult, DateValidationResult, FloatValidationResult
5+
from errors import StationNotFound
6+
from models import StationRecord
7+
from storage import StationsStorage
8+
from messages import user_messages as msg
9+
10+
@dataclass
11+
class Message:
12+
text: str | None = None
13+
14+
@pytest.fixture
15+
def mock_message():
16+
return Message()
17+
18+
def test_validate_station_valid(mock_message, mocker):
19+
mock_message.text = "Madrid"
20+
mock_station = StationRecord(name="Madrid", code="MAD")
21+
mocker.patch.object(StationsStorage, 'get_station', return_value=mock_station)
22+
23+
result = validate_station(mock_message)
24+
25+
assert result.is_valid
26+
assert result.station == mock_station
27+
assert result.error_message == ""
28+
29+
def test_validate_station_not_found(mock_message, mocker):
30+
mock_message.text = "Unknown"
31+
mocker.patch.object(StationsStorage, 'get_station', side_effect=StationNotFound)
32+
mocker.patch.object(StationsStorage, 'find_station', return_value=["Madrid", "Barcelona"])
33+
34+
result = validate_station(mock_message)
35+
36+
assert not result.is_valid
37+
assert result.station is None
38+
assert result.error_message == msg["station_not_found"].format("Unknown", "Madrid\nBarcelona")
39+
40+
def test_validate_station_no_text(mock_message):
41+
mock_message.text = None
42+
43+
result = validate_station(mock_message)
44+
45+
assert not result.is_valid
46+
assert result.station is None
47+
assert result.error_message == msg["station_invalid"]
48+
49+
def test_validate_date_valid(mock_message):
50+
mock_message.text = "Hoy a las 07:00"
51+
52+
result = validate_date(mock_message)
53+
54+
assert result.is_valid
55+
assert result.date == datetime.now().replace(hour=7, minute=0, second=0, microsecond=0)
56+
assert result.error_message == ""
57+
58+
def test_validate_date_invalid(mock_message):
59+
mock_message.text = "invalid date"
60+
61+
result = validate_date(mock_message)
62+
63+
assert not result.is_valid
64+
assert result.date is None
65+
assert result.error_message == msg["wrong_date"]
66+
67+
def test_validate_date_no_text(mock_message):
68+
mock_message.text = None
69+
70+
result = validate_date(mock_message)
71+
72+
assert not result.is_valid
73+
assert result.date is None
74+
assert result.error_message == msg["wrong_date"]
75+
76+
def test_validate_float_valid(mock_message):
77+
mock_message.text = "123.45"
78+
79+
result = validate_float(mock_message)
80+
81+
assert result.is_valid
82+
assert result.number == 123.45
83+
assert result.error_message == ""
84+
85+
def test_validate_float_invalid(mock_message):
86+
mock_message.text = "invalid number"
87+
88+
with pytest.raises(ValueError):
89+
validate_float(mock_message)
90+
91+
def test_validate_float_no_text(mock_message):
92+
mock_message.text = None
93+
94+
result = validate_float(mock_message)
95+
96+
assert not result.is_valid
97+
assert result.number is None
98+
assert result.error_message == msg["wrong_number"]

0 commit comments

Comments
 (0)