From fd7372754bd9e7310b96592ba30819cfd2b82533 Mon Sep 17 00:00:00 2001 From: Fergus Cooper Date: Wed, 10 Jan 2024 22:02:42 +0000 Subject: [PATCH] Use plotly instead of altair Much more flexibility in plotting --- distribution_zoo/cont_uni/normal.py | 47 ++++++++++++++++------------- pyproject.toml | 2 +- 2 files changed, 27 insertions(+), 22 deletions(-) diff --git a/distribution_zoo/cont_uni/normal.py b/distribution_zoo/cont_uni/normal.py index e1e021d..28097ea 100644 --- a/distribution_zoo/cont_uni/normal.py +++ b/distribution_zoo/cont_uni/normal.py @@ -1,6 +1,6 @@ from distribution_zoo import BaseDistribution -import altair as alt +import plotly.graph_objects as go import numpy as np import pandas as pd import scipy.stats as stats @@ -45,13 +45,12 @@ def update_range(self): mean = st.session_state['normal_mean'] if 'normal_mean' in st.session_state else self.param_mean std = st.session_state['normal_std'] if 'normal_std' in st.session_state else self.param_std - new_lower = round(stats.norm(loc=mean, scale=std).ppf(0.0001), 1) - new_upper = round(stats.norm(loc=mean, scale=std).ppf(0.9999), 1) + new_lower = max(round(stats.norm(loc=mean, scale=std).ppf(0.0001), 1), self.range_min) + new_upper = min(round(stats.norm(loc=mean, scale=std).ppf(0.9999), 1), self.range_max) st.session_state['normal_range'] = (new_lower, new_upper) def plot(self): - - x = np.linspace(self.range_min, self.range_max, 1000) + x = np.linspace(self.param_range_start, self.param_range_end, 1000) chart_data = pd.DataFrame( { @@ -61,31 +60,37 @@ def plot(self): } ) - # Define the initial x-axis range for the view - initial_x_range = [self.param_range_start, self.param_range_end] + line_data = pd.DataFrame( + { + 'x': [self.param_mean, self.param_mean], + 'z': [0.0, 0.0], + 'pdf': [0.0, max(chart_data['pdf'])], + 'cdf': [0.0, max(chart_data['cdf'])], + } + ) - # Create an Altair chart for the PDF - pdf_chart = alt.Chart(chart_data).mark_line().encode( - x=alt.X('x:Q', scale=alt.Scale(domain=initial_x_range)), - y='pdf:Q', - tooltip=['x', 'pdf'] - ).interactive() + # Create Plotly chart for the PDF + pdf_chart = go.Figure(go.Scatter(x=chart_data['x'], y=chart_data['pdf'], mode='lines', name='PDF')) + pdf_chart.add_trace(go.Scatter(x=line_data['x'], y=line_data['pdf'], mode='lines', name=f'Mean ({self.param_mean})', line=dict(color='orange', width=2))) + pdf_chart.add_trace(go.Scatter(x=line_data['z'], y=line_data['pdf'], mode='lines', name='Zero', line=dict(color='green', width=2, dash='dot'))) + pdf_chart.update_layout(xaxis_title='x', yaxis_title='pdf', margin=dict(l=20, r=20, t=20, b=20)) - # Create an Altair chart for the CDF - cdf_chart = alt.Chart(chart_data).mark_line().encode( - x=alt.X('x:Q', scale=alt.Scale(domain=initial_x_range)), - y='cdf:Q', - tooltip=['x', 'cdf'] - ).interactive() + # Create Plotly chart for the CDF + cdf_chart = go.Figure(go.Scatter(x=chart_data['x'], y=chart_data['cdf'], mode='lines', name='CDF')) + cdf_chart.add_trace(go.Scatter(x=line_data['x'], y=line_data['cdf'], mode='lines', name=f'Mean ({self.param_mean})', line=dict(color='orange', width=2))) + cdf_chart.add_trace(go.Scatter(x=line_data['z'], y=line_data['cdf'], mode='lines', name=r'Zero', line=dict(color='green', width=2, dash='dot'))) + cdf_chart.update_layout(xaxis_title='x', yaxis_title='cdf', margin=dict(l=20, r=20, t=20, b=20)) + # Streamlit columns for displaying the charts pdf_col, cdf_col = st.columns(2) with pdf_col: st.subheader('Probability density function') - st.altair_chart(pdf_chart, use_container_width=True) + st.plotly_chart(pdf_chart, use_container_width=True) + with cdf_col: st.subheader('Cumulative distribution function') - st.altair_chart(cdf_chart, use_container_width=True) + st.plotly_chart(cdf_chart, use_container_width=True) def update_code_substitutions(self): self.code_substitutions.add(r'{{{mean}}}', str(self.param_mean)) diff --git a/pyproject.toml b/pyproject.toml index 9168a5f..4403d84 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,8 +11,8 @@ numpy = "^1.26.2" scipy = "^1.11.4" inflection = "^0.5.1" pandas = "^2.1.4" -altair = "^5.2.0" requests = "^2.31.0" +plotly = "^5.18.0" [tool.poetry.group.dev.dependencies] flake8 = "^6.1.0"