Skip to content

Commit

Permalink
adding tests for visualizations
Browse files Browse the repository at this point in the history
  • Loading branch information
riddhibattu committed Apr 12, 2024
1 parent 8c6145a commit 4a3b393
Showing 1 changed file with 183 additions and 2 deletions.
185 changes: 183 additions & 2 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import pandas as pd
import matplotlib
import numpy as np
import pytest
from pynyairbnb.plotting import sns_plotting
from unittest.mock import patch, MagicMock
from pynyairbnb.plotting import sns_plotting, rank_correlations, plot_pynyairbnb

# Same Data to be used for all tests
data = pd.DataFrame({"price": [25, 75, 125, 175, 225, 275, 325, 375],
Expand All @@ -11,25 +13,204 @@

# Check to see if output type is correct given correct inputs
def test_sns_plotting_output_type():
"""
Test the output type of the sns_plotting function.
This function tests whether the output of the sns_plotting function is of type matplotlib.figure.Figure.
It calls the sns_plotting function with specific parameters and checks the type of the returned result.
If the type is not matplotlib.figure.Figure, the test fails.
Returns:
None
"""
result = sns_plotting('scatterplot', data, 'number_of_reviews', 'price', 20, 10)
assert type(result) == matplotlib.figure.Figure, "Test failed: Output type is incorrect."

# Check to see if exception raised for n/a plot type
def test_sns_plotting_plottype_error():
"""
Test case to check if an exception is raised when an invalid plot type is provided to sns_plotting function.
"""
with pytest.raises(Exception):
result = sns_plotting('barplot', data, 'number_of_reviews', 'price', 20, 10)

# Check to see if value error raised for x-variable not in data
def test_sns_plotting_x_error():
"""
Test case to check if ValueError is raised when an invalid x column is provided to sns_plotting function.
"""
with pytest.raises(ValueError):
sns_plotting('scatterplot', data, 'random_x', 'price', 20, 10)

# Check to see if value error raised for y-variable not in data
def test_sns_plotting_y_error():
"""
Test case to check if a ValueError is raised when plotting with an invalid y-axis column.
Raises:
ValueError: If the y-axis column is not present in the dataset.
"""
with pytest.raises(ValueError):
sns_plotting('scatterplot', data, 'number_of_reviews', 'random_y', 20, 10)

# Check to see the figlength and figheight are both <= 25 to avoid being too large
def test_sns_plotting_figsize_check():
"""
Test function to check the size of the plot generated by sns_plotting.
This function calls the sns_plotting function with the 'scatterplot' type, 'number_of_reviews' as x-axis, and 'price' as y-axis.
It then checks if the size of the generated plot is within the specified limits (<= 25 inches for both width and height).
If the plot size is larger than the specified limits, the test fails.
Returns:
None
"""
result = sns_plotting('scatterplot', data, 'number_of_reviews', 'price')
assert result.get_size_inches()[0] <= 25 and result.get_size_inches()[1] <= 25, "Test failed: Plot size is too large."
assert result.get_size_inches()[0] <= 25 and result.get_size_inches()[1] <= 25, "Test failed: Plot size is too large."

@pytest.mark.parametrize("plot_type", ['scatterplot', 'boxplot', 'histplot', 'heatmap'])
def test_sns_plotting_supported_types(plot_type):
"""
Test the sns_plotting function with supported plot types.
This function tests the sns_plotting function with different supported plot types.
It calls the sns_plotting function with specific parameters for each plot type and checks the type of the returned result.
If the type is not matplotlib.figure.Figure, the test fails.
Args:
plot_type (str): The type of plot to be tested.
Returns:
None
"""
result = sns_plotting(plot_type, data, 'number_of_reviews', 'price', 12, 6)
assert isinstance(result, matplotlib.figure.Figure), f"Failed for {plot_type}"

def test_sns_plotting_empty_data():
"""
Test case to check if ValueError is raised when plotting with empty data.
This function creates an empty DataFrame and calls the sns_plotting function with specific parameters.
It checks if a ValueError is raised, indicating that plotting with empty data is not allowed.
Returns:
None
"""
empty_data = pd.DataFrame(columns=['price', 'number_of_reviews'])
with pytest.raises(ValueError):
sns_plotting('scatterplot', empty_data, 'number_of_reviews', 'price', 12, 6)

def test_sns_plotting_with_nans():
"""
Test case to check if NaN values are handled correctly by sns_plotting function.
This function creates a copy of the original data DataFrame and replaces a value with NaN.
It then calls the sns_plotting function with specific parameters and checks the type of the returned result.
If the type is not matplotlib.figure.Figure, the test fails.
Returns:
None
"""
data_with_nans = data.copy()
data_with_nans.loc[0, 'price'] = np.nan
result = sns_plotting('scatterplot', data_with_nans, 'number_of_reviews', 'price', 12, 6)
assert isinstance(result, matplotlib.figure.Figure), "Test failed: Handling NaNs incorrectly."

def test_sns_plotting_titles():
"""
Test case to check if titles are set correctly by sns_plotting function.
This function calls the sns_plotting function with specific parameters and checks if a title is set for the plot.
If no title is set, the test fails.
Returns:
None
"""
result = sns_plotting('scatterplot', data, 'number_of_reviews', 'price', 12, 6)
assert result.axes[0].get_title() != '', "Test failed: Title not set."

def test_rank_correlations_empty():
"""
Test case to check if rank_correlations function returns an empty DataFrame for empty input.
This function creates an empty DataFrame and calls the rank_correlations function.
It checks if the returned result is empty, indicating that the function handles empty input correctly.
Returns:
None
"""
df = pd.DataFrame()
result = rank_correlations(df)
assert result.empty, "Should return an empty DataFrame for empty input"

def test_rank_correlations_all_zeros():
"""
Test case to check if rank_correlations function handles no correlation correctly.
This function creates a DataFrame with all zero values and calls the rank_correlations function.
It checks if the returned result is empty, indicating that the function handles no correlation correctly.
Returns:
None
"""
data = pd.DataFrame({"A": [0, 0, 0], "B": [0, 0, 0]})
result = rank_correlations(data.corr())
assert result.empty, "Should handle no correlation correctly"

def test_rank_correlations_with_nans():
"""
Test case to check if rank_correlations function handles NaNs without failing.
This function creates a DataFrame with NaN values and calls the rank_correlations function.
It checks if the returned result is not empty, indicating that the function handles NaNs correctly.
Returns:
None
"""
data = pd.DataFrame({"A": [1, np.nan, 3], "B": [3, 2, 1]})
result = rank_correlations(data.corr())
assert not result.empty, "Should handle NaNs without failing"


@pytest.fixture
def mock_data():
return pd.DataFrame({
"price": [10, 20, 30],
"number_of_reviews": [1, 2, 3],
"reviews_per_month": [0.5, 1.0, 1.5],
"longitude": [-74.00597, -73.98565, -73.97801],
"latitude": [40.71278, 40.75283, 40.73586],
"room_type": ["Private room", "Entire home/apt", "Shared room"],
"neighbourhood_group": ["Manhattan", "Brooklyn", "Queens"]
})

# Use specific MagicMock objects for savefig and to_csv to clearly control and assert their usage
savefig_mock = MagicMock()
to_csv_mock = MagicMock()

@patch("matplotlib.pyplot.Figure.savefig", savefig_mock)
@patch("pandas.DataFrame.to_csv", to_csv_mock)
@patch("os.path.join", MagicMock(return_value="mock/path/to/file.csv"))
@patch("os.makedirs", MagicMock()) # Mock directory creation if necessary
@patch("pandas.read_csv")
def test_plot_pynyairbnb(mock_read_csv, mock_data):
"""
Test function for the `plot_pynyairbnb` function in the `pynyairbnb.plotting` module.
Args:
mock_read_csv (MagicMock): Mock object for the `pandas.read_csv` function.
mock_data (MagicMock): Mock object for the data to be read from CSV.
Returns:
None
Raises:
AssertionError: If `Figure.savefig` or `DataFrame.to_csv` functions are not called.
"""
mock_read_csv.return_value = mock_data
from pynyairbnb.plotting import plot_pynyairbnb # Ensure to import within the test if needed
plot_pynyairbnb("dummy.csv", "viz_dir", "tbl_dir")

assert savefig_mock.called, "Figure.savefig was not called"
assert to_csv_mock.called, "DataFrame.to_csv was not called"

0 comments on commit 4a3b393

Please # to comment.