From 4a3b3937d3f572e75a82c70e4485b179f67d0c7f Mon Sep 17 00:00:00 2001 From: Riddhi Battu Date: Thu, 11 Apr 2024 19:05:20 -0700 Subject: [PATCH] adding tests for visualizations --- tests/test_plotting.py | 185 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 183 insertions(+), 2 deletions(-) diff --git a/tests/test_plotting.py b/tests/test_plotting.py index 069c75a..d6b8e87 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -1,7 +1,9 @@ import pandas as pd import matplotlib +import numpy as np import pytest -from pynyairbnb.plotting import sns_plotting +from unittest.mock import patch, MagicMock +from pynyairbnb.plotting import sns_plotting, rank_correlations, plot_pynyairbnb # Same Data to be used for all tests data = pd.DataFrame({"price": [25, 75, 125, 175, 225, 275, 325, 375], @@ -11,25 +13,204 @@ # Check to see if output type is correct given correct inputs def test_sns_plotting_output_type(): + """ + Test the output type of the sns_plotting function. + + This function tests whether the output of the sns_plotting function is of type matplotlib.figure.Figure. + It calls the sns_plotting function with specific parameters and checks the type of the returned result. + If the type is not matplotlib.figure.Figure, the test fails. + + Returns: + None + """ result = sns_plotting('scatterplot', data, 'number_of_reviews', 'price', 20, 10) assert type(result) == matplotlib.figure.Figure, "Test failed: Output type is incorrect." # Check to see if exception raised for n/a plot type def test_sns_plotting_plottype_error(): + """ + Test case to check if an exception is raised when an invalid plot type is provided to sns_plotting function. + """ with pytest.raises(Exception): result = sns_plotting('barplot', data, 'number_of_reviews', 'price', 20, 10) # Check to see if value error raised for x-variable not in data def test_sns_plotting_x_error(): + """ + Test case to check if ValueError is raised when an invalid x column is provided to sns_plotting function. + """ with pytest.raises(ValueError): sns_plotting('scatterplot', data, 'random_x', 'price', 20, 10) # Check to see if value error raised for y-variable not in data def test_sns_plotting_y_error(): + """ + Test case to check if a ValueError is raised when plotting with an invalid y-axis column. + + Raises: + ValueError: If the y-axis column is not present in the dataset. + + """ with pytest.raises(ValueError): sns_plotting('scatterplot', data, 'number_of_reviews', 'random_y', 20, 10) # Check to see the figlength and figheight are both <= 25 to avoid being too large def test_sns_plotting_figsize_check(): + """ + Test function to check the size of the plot generated by sns_plotting. + + This function calls the sns_plotting function with the 'scatterplot' type, 'number_of_reviews' as x-axis, and 'price' as y-axis. + It then checks if the size of the generated plot is within the specified limits (<= 25 inches for both width and height). + If the plot size is larger than the specified limits, the test fails. + + Returns: + None + """ result = sns_plotting('scatterplot', data, 'number_of_reviews', 'price') - assert result.get_size_inches()[0] <= 25 and result.get_size_inches()[1] <= 25, "Test failed: Plot size is too large." \ No newline at end of file + assert result.get_size_inches()[0] <= 25 and result.get_size_inches()[1] <= 25, "Test failed: Plot size is too large." + +@pytest.mark.parametrize("plot_type", ['scatterplot', 'boxplot', 'histplot', 'heatmap']) +def test_sns_plotting_supported_types(plot_type): + """ + Test the sns_plotting function with supported plot types. + + This function tests the sns_plotting function with different supported plot types. + It calls the sns_plotting function with specific parameters for each plot type and checks the type of the returned result. + If the type is not matplotlib.figure.Figure, the test fails. + + Args: + plot_type (str): The type of plot to be tested. + + Returns: + None + """ + result = sns_plotting(plot_type, data, 'number_of_reviews', 'price', 12, 6) + assert isinstance(result, matplotlib.figure.Figure), f"Failed for {plot_type}" + +def test_sns_plotting_empty_data(): + """ + Test case to check if ValueError is raised when plotting with empty data. + + This function creates an empty DataFrame and calls the sns_plotting function with specific parameters. + It checks if a ValueError is raised, indicating that plotting with empty data is not allowed. + + Returns: + None + """ + empty_data = pd.DataFrame(columns=['price', 'number_of_reviews']) + with pytest.raises(ValueError): + sns_plotting('scatterplot', empty_data, 'number_of_reviews', 'price', 12, 6) + +def test_sns_plotting_with_nans(): + """ + Test case to check if NaN values are handled correctly by sns_plotting function. + + This function creates a copy of the original data DataFrame and replaces a value with NaN. + It then calls the sns_plotting function with specific parameters and checks the type of the returned result. + If the type is not matplotlib.figure.Figure, the test fails. + + Returns: + None + """ + data_with_nans = data.copy() + data_with_nans.loc[0, 'price'] = np.nan + result = sns_plotting('scatterplot', data_with_nans, 'number_of_reviews', 'price', 12, 6) + assert isinstance(result, matplotlib.figure.Figure), "Test failed: Handling NaNs incorrectly." + +def test_sns_plotting_titles(): + """ + Test case to check if titles are set correctly by sns_plotting function. + + This function calls the sns_plotting function with specific parameters and checks if a title is set for the plot. + If no title is set, the test fails. + + Returns: + None + """ + result = sns_plotting('scatterplot', data, 'number_of_reviews', 'price', 12, 6) + assert result.axes[0].get_title() != '', "Test failed: Title not set." + +def test_rank_correlations_empty(): + """ + Test case to check if rank_correlations function returns an empty DataFrame for empty input. + + This function creates an empty DataFrame and calls the rank_correlations function. + It checks if the returned result is empty, indicating that the function handles empty input correctly. + + Returns: + None + """ + df = pd.DataFrame() + result = rank_correlations(df) + assert result.empty, "Should return an empty DataFrame for empty input" + +def test_rank_correlations_all_zeros(): + """ + Test case to check if rank_correlations function handles no correlation correctly. + + This function creates a DataFrame with all zero values and calls the rank_correlations function. + It checks if the returned result is empty, indicating that the function handles no correlation correctly. + + Returns: + None + """ + data = pd.DataFrame({"A": [0, 0, 0], "B": [0, 0, 0]}) + result = rank_correlations(data.corr()) + assert result.empty, "Should handle no correlation correctly" + +def test_rank_correlations_with_nans(): + """ + Test case to check if rank_correlations function handles NaNs without failing. + + This function creates a DataFrame with NaN values and calls the rank_correlations function. + It checks if the returned result is not empty, indicating that the function handles NaNs correctly. + + Returns: + None + """ + data = pd.DataFrame({"A": [1, np.nan, 3], "B": [3, 2, 1]}) + result = rank_correlations(data.corr()) + assert not result.empty, "Should handle NaNs without failing" + + +@pytest.fixture +def mock_data(): + return pd.DataFrame({ + "price": [10, 20, 30], + "number_of_reviews": [1, 2, 3], + "reviews_per_month": [0.5, 1.0, 1.5], + "longitude": [-74.00597, -73.98565, -73.97801], + "latitude": [40.71278, 40.75283, 40.73586], + "room_type": ["Private room", "Entire home/apt", "Shared room"], + "neighbourhood_group": ["Manhattan", "Brooklyn", "Queens"] + }) + +# Use specific MagicMock objects for savefig and to_csv to clearly control and assert their usage +savefig_mock = MagicMock() +to_csv_mock = MagicMock() + +@patch("matplotlib.pyplot.Figure.savefig", savefig_mock) +@patch("pandas.DataFrame.to_csv", to_csv_mock) +@patch("os.path.join", MagicMock(return_value="mock/path/to/file.csv")) +@patch("os.makedirs", MagicMock()) # Mock directory creation if necessary +@patch("pandas.read_csv") +def test_plot_pynyairbnb(mock_read_csv, mock_data): + """ + Test function for the `plot_pynyairbnb` function in the `pynyairbnb.plotting` module. + + Args: + mock_read_csv (MagicMock): Mock object for the `pandas.read_csv` function. + mock_data (MagicMock): Mock object for the data to be read from CSV. + + Returns: + None + + Raises: + AssertionError: If `Figure.savefig` or `DataFrame.to_csv` functions are not called. + """ + mock_read_csv.return_value = mock_data + from pynyairbnb.plotting import plot_pynyairbnb # Ensure to import within the test if needed + plot_pynyairbnb("dummy.csv", "viz_dir", "tbl_dir") + + assert savefig_mock.called, "Figure.savefig was not called" + assert to_csv_mock.called, "DataFrame.to_csv was not called"