
Commit ea6ad98

Authored Jan 12, 2021
Refactor and extract hist calculation logic from matplotlib (#1998)
This PR extracts the histogram calculation logic from `matplotlib.py` into `core.py`. It depends on #1997.
1 parent f5f88bd commit ea6ad98

File tree

5 files changed: +231 -219 lines changed

 

databricks/koalas/plot/core.py

+177 -8
@@ -16,13 +16,18 @@
 
 import importlib
 
+import pandas as pd
+import numpy as np
+from pyspark.ml.feature import Bucketizer
+from pyspark.sql import functions as F
 from pandas.core.base import PandasObject
+from pandas.core.dtypes.inference import is_integer
 
 from databricks.koalas.missing import unsupported_function
 from databricks.koalas.config import get_option
 
 
-class TopNPlot:
+class TopNPlotBase:
     def get_top_n(self, data):
         from databricks.koalas import DataFrame, Series
 

@@ -56,7 +61,7 @@ def set_result_text(self, ax):
         )
 
 
-class SampledPlot:
+class SampledPlotBase:
     def get_sampled(self, data):
         from databricks.koalas import DataFrame, Series
 

@@ -89,6 +94,170 @@ def set_result_text(self, ax):
         )
 
 
+class HistogramPlotBase:
+    @staticmethod
+    def prepare_hist_data(data, bins):
+        # TODO: this logic is same with KdePlot. Might have to deduplicate it.
+        from databricks.koalas.series import Series
+
+        if isinstance(data, Series):
+            data = data.to_frame()
+
+        numeric_data = data.select_dtypes(
+            include=["byte", "decimal", "integer", "float", "long", "double", np.datetime64]
+        )
+
+        # no empty frames or series allowed
+        if len(numeric_data.columns) == 0:
+            raise TypeError(
+                "Empty {0!r}: no numeric data to " "plot".format(numeric_data.__class__.__name__)
+            )
+
+        if is_integer(bins):
+            # computes boundaries for the column
+            bins = HistogramPlotBase.get_bins(data.to_spark(), bins)
+
+        return numeric_data, bins
+
+    @staticmethod
+    def get_bins(sdf, bins):
+        # 'data' is a Spark DataFrame that selects all columns.
+        if len(sdf.columns) > 1:
+            min_col = F.least(*map(F.min, sdf))
+            max_col = F.greatest(*map(F.max, sdf))
+        else:
+            min_col = F.min(sdf.columns[-1])
+            max_col = F.max(sdf.columns[-1])
+        boundaries = sdf.select(min_col, max_col).first()
+
+        # divides the boundaries into bins
+        if boundaries[0] == boundaries[1]:
+            boundaries = (boundaries[0] - 0.5, boundaries[1] + 0.5)
+
+        return np.linspace(boundaries[0], boundaries[1], bins + 1)
+
+    @staticmethod
+    def compute_hist(kdf, bins):
+        # 'data' is a Spark DataFrame that selects one column.
+        assert isinstance(bins, (np.ndarray, np.generic))
+
+        sdf = kdf._internal.spark_frame
+        scols = []
+        for label in kdf._internal.column_labels:
+            scols.append(kdf._internal.spark_column_for(label))
+        sdf = sdf.select(*scols)
+
+        # 1. Make the bucket output flat to:
+        #     +----------+-------+
+        #     |__group_id|buckets|
+        #     +----------+-------+
+        #     |0         |0.0    |
+        #     |0         |0.0    |
+        #     |0         |1.0    |
+        #     |0         |2.0    |
+        #     |0         |3.0    |
+        #     |0         |3.0    |
+        #     |1         |0.0    |
+        #     |1         |1.0    |
+        #     |1         |1.0    |
+        #     |1         |2.0    |
+        #     |1         |1.0    |
+        #     |1         |0.0    |
+        #     +----------+-------+
+        colnames = sdf.columns
+        bucket_names = ["__{}_bucket".format(colname) for colname in colnames]
+
+        output_df = None
+        for group_id, (colname, bucket_name) in enumerate(zip(colnames, bucket_names)):
+            # creates a Bucketizer to get corresponding bin of each value
+            bucketizer = Bucketizer(
+                splits=bins, inputCol=colname, outputCol=bucket_name, handleInvalid="skip"
+            )
+
+            bucket_df = bucketizer.transform(sdf)
+
+            if output_df is None:
+                output_df = bucket_df.select(
+                    F.lit(group_id).alias("__group_id"), F.col(bucket_name).alias("__bucket")
+                )
+            else:
+                output_df = output_df.union(
+                    bucket_df.select(
+                        F.lit(group_id).alias("__group_id"), F.col(bucket_name).alias("__bucket")
+                    )
+                )
+
+        # 2. Calculate the count based on each group and bucket.
+        #     +----------+-------+------+
+        #     |__group_id|buckets| count|
+        #     +----------+-------+------+
+        #     |0         |0.0    |2     |
+        #     |0         |1.0    |1     |
+        #     |0         |2.0    |1     |
+        #     |0         |3.0    |2     |
+        #     |1         |0.0    |2     |
+        #     |1         |1.0    |3     |
+        #     |1         |2.0    |1     |
+        #     +----------+-------+------+
+        result = (
+            output_df.groupby("__group_id", "__bucket")
+            .agg(F.count("*").alias("count"))
+            .toPandas()
+            .sort_values(by=["__group_id", "__bucket"])
+        )
+
+        # 3. Fill empty bins and calculate based on each group id. From:
+        #     +----------+--------+------+
+        #     |__group_id|__bucket| count|
+        #     +----------+--------+------+
+        #     |0         |0.0     |2     |
+        #     |0         |1.0     |1     |
+        #     |0         |2.0     |1     |
+        #     |0         |3.0     |2     |
+        #     +----------+--------+------+
+        #     +----------+--------+------+
+        #     |__group_id|__bucket| count|
+        #     +----------+--------+------+
+        #     |1         |0.0     |2     |
+        #     |1         |1.0     |3     |
+        #     |1         |2.0     |1     |
+        #     +----------+--------+------+
+        #
+        # to:
+        #     +-----------------+
+        #     |__values1__bucket|
+        #     +-----------------+
+        #     |2                |
+        #     |1                |
+        #     |1                |
+        #     |2                |
+        #     |0                |
+        #     +-----------------+
+        #     +-----------------+
+        #     |__values2__bucket|
+        #     +-----------------+
+        #     |2                |
+        #     |3                |
+        #     |1                |
+        #     |0                |
+        #     |0                |
+        #     +-----------------+
+        output_series = []
+        for i, bucket_name in enumerate(bucket_names):
+            current_bucket_result = result[result["__group_id"] == i]
+            # generates a pandas DF with one row for each bin
+            # we need this as some of the bins may be empty
+            indexes = pd.DataFrame({"__bucket": np.arange(0, len(bins) - 1)})
+            # merges the bins with counts on it and fills remaining ones with zeros
+            pdf = indexes.merge(current_bucket_result, how="left", on=["__bucket"]).fillna(0)[
+                ["count"]
+            ]
+            pdf.columns = [bucket_name]
+            output_series.append(pdf[bucket_name])
+
+        return output_series
+
+
 class KoalasPlotAccessor(PandasObject):
     """
     Series/Frames plotting accessor and method.
@@ -188,12 +357,12 @@ def __call__(self, kind="line", backend=None, **kwargs):
 
         if plot_backend.__name__ != "databricks.koalas.plot":
             data_preprocessor_map = {
-                "pie": TopNPlot().get_top_n,
-                "bar": TopNPlot().get_top_n,
-                "barh": TopNPlot().get_top_n,
-                "scatter": TopNPlot().get_top_n,
-                "area": SampledPlot().get_sampled,
-                "line": SampledPlot().get_sampled,
+                "pie": TopNPlotBase().get_top_n,
+                "bar": TopNPlotBase().get_top_n,
+                "barh": TopNPlotBase().get_top_n,
+                "scatter": TopNPlotBase().get_top_n,
+                "area": SampledPlotBase().get_sampled,
+                "line": SampledPlotBase().get_sampled,
             }
             if not data_preprocessor_map[kind]:
                 raise NotImplementedError(
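
For orientation (this is not part of the commit), here is a minimal sketch of how the extracted helpers compose; the frame contents and column names below are made up:

    import databricks.koalas as ks
    from databricks.koalas.plot.core import HistogramPlotBase

    kdf = ks.DataFrame({"values1": [1, 2, 2, 3, 3, 3], "values2": [5, 6, 6, 7, 7, 8]})

    # An integer `bins` is resolved by get_bins into bins + 1 evenly spaced
    # edges (np.linspace) over the global min/max across all columns.
    data, bins = HistogramPlotBase.prepare_hist_data(kdf, 4)

    # compute_hist returns one pandas Series of per-bin counts per column,
    # with empty bins zero-filled, ready for a plotting backend to draw.
    for counts in HistogramPlotBase.compute_hist(data, bins):
        print(counts)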

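At its core, compute_hist is a bucket-then-count pattern over plain PySpark; a standalone sketch of steps 1 and 2 follows, with made-up data and split points:

    from pyspark.sql import SparkSession, functions as F
    from pyspark.ml.feature import Bucketizer

    spark = SparkSession.builder.getOrCreate()
    sdf = spark.createDataFrame([(1.0,), (2.0,), (2.0,), (3.0,), (5.0,), (8.0,)], ["values1"])

    # Step 1: map each value to its bin index. handleInvalid="skip" drops rows
    # that cannot be bucketed (e.g. nulls), as compute_hist does per column.
    bucketizer = Bucketizer(
        splits=[1.0, 3.5, 6.0, 8.5], inputCol="values1",
        outputCol="__values1_bucket", handleInvalid="skip",
    )

    # Step 2: count rows per bin, then pull the small aggregate to the driver.
    counts = (
        bucketizer.transform(sdf)
        .groupby("__values1_bucket")
        .agg(F.count("*").alias("count"))
        .toPandas()
        .sort_values(by="__values1_bucket")
    )

Bucketizer's buckets are half-open [x, y) except the last, which also includes its upper edge, so the maximum value always lands in the final bin.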