
Commit 5227ef9

Merge branch 'master' of https://github.com/databricks/koalas into f_lookup

2 parents: db4b6d6 + ea6ad98

File tree: 4 files changed (+232 -112 lines)


databricks/koalas/plot/core.py (+177 -8)
@@ -16,13 +16,18 @@

 import importlib

+import pandas as pd
+import numpy as np
+from pyspark.ml.feature import Bucketizer
+from pyspark.sql import functions as F
 from pandas.core.base import PandasObject
+from pandas.core.dtypes.inference import is_integer

 from databricks.koalas.missing import unsupported_function
 from databricks.koalas.config import get_option


-class TopNPlot:
+class TopNPlotBase:
     def get_top_n(self, data):
         from databricks.koalas import DataFrame, Series

@@ -56,7 +61,7 @@ def set_result_text(self, ax):
             )


-class SampledPlot:
+class SampledPlotBase:
     def get_sampled(self, data):
         from databricks.koalas import DataFrame, Series

@@ -89,6 +94,170 @@ def set_result_text(self, ax):
             )


+class HistogramPlotBase:
+    @staticmethod
+    def prepare_hist_data(data, bins):
+        # TODO: this logic is same with KdePlot. Might have to deduplicate it.
+        from databricks.koalas.series import Series
+
+        if isinstance(data, Series):
+            data = data.to_frame()
+
+        numeric_data = data.select_dtypes(
+            include=["byte", "decimal", "integer", "float", "long", "double", np.datetime64]
+        )
+
+        # no empty frames or series allowed
+        if len(numeric_data.columns) == 0:
+            raise TypeError(
+                "Empty {0!r}: no numeric data to " "plot".format(numeric_data.__class__.__name__)
+            )
+
+        if is_integer(bins):
+            # computes boundaries for the column
+            bins = HistogramPlotBase.get_bins(data.to_spark(), bins)
+
+        return numeric_data, bins
+
+    @staticmethod
+    def get_bins(sdf, bins):
+        # 'data' is a Spark DataFrame that selects all columns.
+        if len(sdf.columns) > 1:
+            min_col = F.least(*map(F.min, sdf))
+            max_col = F.greatest(*map(F.max, sdf))
+        else:
+            min_col = F.min(sdf.columns[-1])
+            max_col = F.max(sdf.columns[-1])
+        boundaries = sdf.select(min_col, max_col).first()
+
+        # divides the boundaries into bins
+        if boundaries[0] == boundaries[1]:
+            boundaries = (boundaries[0] - 0.5, boundaries[1] + 0.5)
+
+        return np.linspace(boundaries[0], boundaries[1], bins + 1)
+
+    @staticmethod
+    def compute_hist(kdf, bins):
+        # 'data' is a Spark DataFrame that selects one column.
+        assert isinstance(bins, (np.ndarray, np.generic))
+
+        sdf = kdf._internal.spark_frame
+        scols = []
+        for label in kdf._internal.column_labels:
+            scols.append(kdf._internal.spark_column_for(label))
+        sdf = sdf.select(*scols)
+
+        # 1. Make the bucket output flat to:
+        #     +----------+-------+
+        #     |__group_id|buckets|
+        #     +----------+-------+
+        #     |0         |0.0    |
+        #     |0         |0.0    |
+        #     |0         |1.0    |
+        #     |0         |2.0    |
+        #     |0         |3.0    |
+        #     |0         |3.0    |
+        #     |1         |0.0    |
+        #     |1         |1.0    |
+        #     |1         |1.0    |
+        #     |1         |2.0    |
+        #     |1         |1.0    |
+        #     |1         |0.0    |
+        #     +----------+-------+
+        colnames = sdf.columns
+        bucket_names = ["__{}_bucket".format(colname) for colname in colnames]
+
+        output_df = None
+        for group_id, (colname, bucket_name) in enumerate(zip(colnames, bucket_names)):
+            # creates a Bucketizer to get corresponding bin of each value
+            bucketizer = Bucketizer(
+                splits=bins, inputCol=colname, outputCol=bucket_name, handleInvalid="skip"
+            )
+
+            bucket_df = bucketizer.transform(sdf)
+
+            if output_df is None:
+                output_df = bucket_df.select(
+                    F.lit(group_id).alias("__group_id"), F.col(bucket_name).alias("__bucket")
+                )
+            else:
+                output_df = output_df.union(
+                    bucket_df.select(
+                        F.lit(group_id).alias("__group_id"), F.col(bucket_name).alias("__bucket")
+                    )
+                )
+
+        # 2. Calculate the count based on each group and bucket.
+        #     +----------+-------+------+
+        #     |__group_id|buckets| count|
+        #     +----------+-------+------+
+        #     |0         |0.0    |2     |
+        #     |0         |1.0    |1     |
+        #     |0         |2.0    |1     |
+        #     |0         |3.0    |2     |
+        #     |1         |0.0    |2     |
+        #     |1         |1.0    |3     |
+        #     |1         |2.0    |1     |
+        #     +----------+-------+------+
+        result = (
+            output_df.groupby("__group_id", "__bucket")
+            .agg(F.count("*").alias("count"))
+            .toPandas()
+            .sort_values(by=["__group_id", "__bucket"])
+        )
+
+        # 3. Fill empty bins and calculate based on each group id. From:
+        #     +----------+--------+------+
+        #     |__group_id|__bucket| count|
+        #     +----------+--------+------+
+        #     |0         |0.0     |2     |
+        #     |0         |1.0     |1     |
+        #     |0         |2.0     |1     |
+        #     |0         |3.0     |2     |
+        #     +----------+--------+------+
+        #     +----------+--------+------+
+        #     |__group_id|__bucket| count|
+        #     +----------+--------+------+
+        #     |1         |0.0     |2     |
+        #     |1         |1.0     |3     |
+        #     |1         |2.0     |1     |
+        #     +----------+--------+------+
+        #
+        # to:
+        #     +-----------------+
+        #     |__values1__bucket|
+        #     +-----------------+
+        #     |2                |
+        #     |1                |
+        #     |1                |
+        #     |2                |
+        #     |0                |
+        #     +-----------------+
+        #     +-----------------+
+        #     |__values2__bucket|
+        #     +-----------------+
+        #     |2                |
+        #     |3                |
+        #     |1                |
+        #     |0                |
+        #     |0                |
+        #     +-----------------+
+        output_series = []
+        for i, bucket_name in enumerate(bucket_names):
+            current_bucket_result = result[result["__group_id"] == i]
+            # generates a pandas DF with one row for each bin
+            # we need this as some of the bins may be empty
+            indexes = pd.DataFrame({"__bucket": np.arange(0, len(bins) - 1)})
+            # merges the bins with counts on it and fills remaining ones with zeros
+            pdf = indexes.merge(current_bucket_result, how="left", on=["__bucket"]).fillna(0)[
+                ["count"]
+            ]
+            pdf.columns = [bucket_name]
+            output_series.append(pdf[bucket_name])
+
+        return output_series
+
+
 class KoalasPlotAccessor(PandasObject):
     """
     Series/Frames plotting accessor and method.
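The heart of this hunk is HistogramPlotBase.compute_hist: rather than running one Spark job per plotted column, every column is bucketed with a Bucketizer, the per-column outputs are unioned into a single (__group_id, __bucket) frame, and only the small per-bucket counts are collected to the driver. Below is a minimal standalone sketch of the same three steps, assuming a local Spark session; the example data and column names are illustrative and not part of the commit.

import numpy as np
import pandas as pd
from pyspark.ml.feature import Bucketizer
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.master("local[1]").getOrCreate()

# Hypothetical numeric input; any numeric Spark DataFrame behaves the same way.
sdf = spark.createDataFrame(
    [(0.1, 1.2), (0.4, 2.5), (1.9, 2.6), (3.5, 0.1)], ["values1", "values2"]
)
bins = np.linspace(0.0, 4.0, 5)  # 5 edges -> 4 equal-width buckets

# Step 1: bucketize each column, then flatten everything into
# (__group_id, __bucket) rows, mirroring compute_hist above.
output_df = None
for group_id, colname in enumerate(sdf.columns):
    bucketizer = Bucketizer(
        splits=list(bins), inputCol=colname, outputCol="__bucket", handleInvalid="skip"
    )
    flat = bucketizer.transform(sdf).select(
        F.lit(group_id).alias("__group_id"), F.col("__bucket")
    )
    output_df = flat if output_df is None else output_df.union(flat)

# Step 2: count per (column, bucket) on the cluster; only this small
# aggregated frame is collected to pandas.
counts = (
    output_df.groupby("__group_id", "__bucket")
    .agg(F.count("*").alias("count"))
    .toPandas()
    .sort_values(by=["__group_id", "__bucket"])
)

# Step 3: left-join against the full bucket index so empty bins become zeros.
# A float64 index matches Bucketizer's double-typed bucket ids for the merge.
all_buckets = pd.DataFrame({"__bucket": np.arange(len(bins) - 1, dtype="float64")})
for group_id, colname in enumerate(sdf.columns):
    per_col = counts[counts["__group_id"] == group_id]
    filled = all_buckets.merge(per_col, how="left", on="__bucket").fillna(0)
    print(colname, filled["count"].tolist())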
@@ -188,12 +357,12 @@ def __call__(self, kind="line", backend=None, **kwargs):

         if plot_backend.__name__ != "databricks.koalas.plot":
             data_preprocessor_map = {
-                "pie": TopNPlot().get_top_n,
-                "bar": TopNPlot().get_top_n,
-                "barh": TopNPlot().get_top_n,
-                "scatter": TopNPlot().get_top_n,
-                "area": SampledPlot().get_sampled,
-                "line": SampledPlot().get_sampled,
+                "pie": TopNPlotBase().get_top_n,
+                "bar": TopNPlotBase().get_top_n,
+                "barh": TopNPlotBase().get_top_n,
+                "scatter": TopNPlotBase().get_top_n,
+                "area": SampledPlotBase().get_sampled,
+                "line": SampledPlotBase().get_sampled,
             }
             if not data_preprocessor_map[kind]:
                 raise NotImplementedError(
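A note on the get_bins helper added above: when bins is an integer, the bucket splits come from a single Spark aggregation (F.least(*map(F.min, sdf)) / F.greatest(*map(F.max, sdf))) followed by driver-side np.linspace. The same arithmetic in isolation, with illustrative boundary values:

import numpy as np

# Boundaries as one Spark pass over all columns would return them
# (the values here are made up for illustration).
boundaries = (0.1, 3.5)

# A constant column yields min == max; widen by 0.5 on each side so the
# resulting buckets stay non-degenerate.
if boundaries[0] == boundaries[1]:
    boundaries = (boundaries[0] - 0.5, boundaries[1] + 0.5)

bins = 4
splits = np.linspace(boundaries[0], boundaries[1], bins + 1)
print(splits)  # [0.1, 0.95, 1.8, 2.65, 3.5] -- bins + 1 edges for 4 buckets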

databricks/koalas/plot/matplotlib.py (+11 -89)
@@ -23,8 +23,7 @@
 from pandas.core.dtypes.inference import is_integer, is_list_like
 from pandas.io.formats.printing import pprint_thing

-from databricks.koalas.plot import TopNPlot, SampledPlot
-from pyspark.ml.feature import Bucketizer
+from databricks.koalas.plot import TopNPlotBase, SampledPlotBase, HistogramPlotBase
 from pyspark.mllib.stat import KernelDensity
 from pyspark.sql import functions as F

@@ -61,7 +60,7 @@
 _all_kinds = PlotAccessor._all_kinds


-class KoalasBarPlot(PandasBarPlot, TopNPlot):
+class KoalasBarPlot(PandasBarPlot, TopNPlotBase):
     def __init__(self, data, **kwargs):
         super().__init__(self.get_top_n(data), **kwargs)

@@ -442,47 +441,23 @@ def _get_fliers(colname, outliers, min_val):
         return fliers


-class KoalasHistPlot(PandasHistPlot):
+class KoalasHistPlot(PandasHistPlot, HistogramPlotBase):
     def _args_adjust(self):
         if is_list_like(self.bottom):
             self.bottom = np.array(self.bottom)

     def _compute_plot_data(self):
-        # TODO: this logic is same with KdePlot. Might have to deduplicate it.
-        from databricks.koalas.series import Series
-
-        data = self.data
-        if isinstance(data, Series):
-            data = data.to_frame()
-
-        numeric_data = data.select_dtypes(
-            include=["byte", "decimal", "integer", "float", "long", "double", np.datetime64]
-        )
-
-        # no empty frames or series allowed
-        if len(numeric_data.columns) == 0:
-            raise TypeError(
-                "Empty {0!r}: no numeric data to " "plot".format(numeric_data.__class__.__name__)
-            )
-
-        if is_integer(self.bins):
-            # computes boundaries for the column
-            self.bins = self._get_bins(data.to_spark(), self.bins)
-
-        self.data = numeric_data
+        self.data, self.bins = HistogramPlotBase.prepare_hist_data(self.data, self.bins)

     def _make_plot(self):
         # TODO: this logic is similar with KdePlot. Might have to deduplicate it.
         # 'num_colors' requires to calculate `shape` which has to count all.
         # Use 1 for now to save the computation.
         colors = self._get_colors(num_colors=1)
         stacking_id = self._get_stacking_id()
+        output_series = HistogramPlotBase.compute_hist(self.data, self.bins)

-        sdf = self.data._internal.spark_frame
-
-        for i, label in enumerate(self.data._internal.column_labels):
-            # 'y' is a Spark DataFrame that selects one column.
-            y = sdf.select(self.data._internal.spark_column_for(label))
+        for (i, label), y in zip(enumerate(self.data._internal.column_labels), output_series):
             ax = self._get_ax(i)

             kwds = self.kwds.copy()
@@ -494,11 +469,6 @@ def _make_plot(self):
             if style is not None:
                 kwds["style"] = style

-            # 'y' is a Spark DataFrame that selects one column.
-            # here, we manually calculates the weights separately via Spark
-            # and assign it directly to histogram plot.
-            y = KoalasHistPlot._compute_hist(y, self.bins)  # now y is a pandas Series.
-
             kwds = self._make_plot_keywords(kwds, y)
             artists = self._plot(ax, y, column_num=i, stacking_id=stacking_id, **kwds)
             self._add_legend_handle(artists[0], label, index=i)
@@ -518,56 +488,8 @@ def _plot(cls, ax, y, style=None, bins=None, bottom=0, column_num=0, stacking_id
         cls._update_stacker(ax, stacking_id, n)
         return patches

-    @staticmethod
-    def _get_bins(sdf, bins):
-        # 'data' is a Spark DataFrame that selects all columns.
-        if len(sdf.columns) > 1:
-            min_col = F.least(*map(F.min, sdf))
-            max_col = F.greatest(*map(F.max, sdf))
-        else:
-            min_col = F.min(sdf.columns[-1])
-            max_col = F.max(sdf.columns[-1])
-        boundaries = sdf.select(min_col, max_col).first()
-
-        # divides the boundaries into bins
-        if boundaries[0] == boundaries[1]:
-            boundaries = (boundaries[0] - 0.5, boundaries[1] + 0.5)
-
-        return np.linspace(boundaries[0], boundaries[1], bins + 1)
-
-    @staticmethod
-    def _compute_hist(sdf, bins):
-        # 'data' is a Spark DataFrame that selects one column.
-        assert isinstance(bins, (np.ndarray, np.generic))
-
-        colname = sdf.columns[-1]
-
-        bucket_name = "__{}_bucket".format(colname)
-        # creates a Bucketizer to get corresponding bin of each value
-        bucketizer = Bucketizer(
-            splits=bins, inputCol=colname, outputCol=bucket_name, handleInvalid="skip"
-        )
-        # after bucketing values, groups and counts them
-        result = (
-            bucketizer.transform(sdf)
-            .select(bucket_name)
-            .groupby(bucket_name)
-            .agg(F.count("*").alias("count"))
-            .toPandas()
-            .sort_values(by=bucket_name)
-        )
-
-        # generates a pandas DF with one row for each bin
-        # we need this as some of the bins may be empty
-        indexes = pd.DataFrame({bucket_name: np.arange(0, len(bins) - 1), "bucket": bins[:-1]})
-        # merges the bins with counts on it and fills remaining ones with zeros
-        pdf = indexes.merge(result, how="left", on=[bucket_name]).fillna(0)[["count"]]
-        pdf.columns = [bucket_name]
-
-        return pdf[bucket_name]
-

-class KoalasPiePlot(PandasPiePlot, TopNPlot):
+class KoalasPiePlot(PandasPiePlot, TopNPlotBase):
     def __init__(self, data, **kwargs):
         super().__init__(self.get_top_n(data), **kwargs)

@@ -576,7 +498,7 @@ def _make_plot(self):
         super()._make_plot()


-class KoalasAreaPlot(PandasAreaPlot, SampledPlot):
+class KoalasAreaPlot(PandasAreaPlot, SampledPlotBase):
     def __init__(self, data, **kwargs):
         super().__init__(self.get_sampled(data), **kwargs)

@@ -585,7 +507,7 @@ def _make_plot(self):
         super()._make_plot()


-class KoalasLinePlot(PandasLinePlot, SampledPlot):
+class KoalasLinePlot(PandasLinePlot, SampledPlotBase):
     def __init__(self, data, **kwargs):
         super().__init__(self.get_sampled(data), **kwargs)

@@ -594,7 +516,7 @@ def _make_plot(self):
         super()._make_plot()


-class KoalasBarhPlot(PandasBarhPlot, TopNPlot):
+class KoalasBarhPlot(PandasBarhPlot, TopNPlotBase):
     def __init__(self, data, **kwargs):
         super().__init__(self.get_top_n(data), **kwargs)

@@ -603,7 +525,7 @@ def _make_plot(self):
         super()._make_plot()


-class KoalasScatterPlot(PandasScatterPlot, TopNPlot):
+class KoalasScatterPlot(PandasScatterPlot, TopNPlotBase):
     def __init__(self, data, x, y, **kwargs):
         super().__init__(self.get_top_n(data), x, y, **kwargs)

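End to end, KoalasHistPlot now hands matplotlib precomputed bin counts instead of raw column values. A hedged usage sketch: it assumes a running Spark session and the koalas package as of this commit, and the column names are made up for illustration.

import databricks.koalas as ks

kdf = ks.DataFrame({"values1": [0.1, 0.4, 1.9, 3.5], "values2": [1.2, 2.5, 2.6, 0.1]})

# _compute_plot_data -> HistogramPlotBase.prepare_hist_data resolves bins=4
# into 5 edges via get_bins; _make_plot -> compute_hist then counts both
# columns in one Spark-side pass and plots only the per-bin counts.
ax = kdf.plot.hist(bins=4)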