23
23
from pandas .core .dtypes .inference import is_integer , is_list_like
24
24
from pandas .io .formats .printing import pprint_thing
25
25
26
- from databricks .koalas .plot import TopNPlot , SampledPlot
27
- from pyspark .ml .feature import Bucketizer
26
+ from databricks .koalas .plot import TopNPlotBase , SampledPlotBase , HistogramPlotBase
28
27
from pyspark .mllib .stat import KernelDensity
29
28
from pyspark .sql import functions as F
30
29
61
60
_all_kinds = PlotAccessor ._all_kinds
62
61
63
62
64
- class KoalasBarPlot (PandasBarPlot , TopNPlot ):
63
+ class KoalasBarPlot (PandasBarPlot , TopNPlotBase ):
65
64
def __init__ (self , data , ** kwargs ):
66
65
super ().__init__ (self .get_top_n (data ), ** kwargs )
67
66
@@ -442,47 +441,23 @@ def _get_fliers(colname, outliers, min_val):
442
441
return fliers
443
442
444
443
445
- class KoalasHistPlot (PandasHistPlot ):
444
+ class KoalasHistPlot (PandasHistPlot , HistogramPlotBase ):
446
445
def _args_adjust (self ):
447
446
if is_list_like (self .bottom ):
448
447
self .bottom = np .array (self .bottom )
449
448
450
449
def _compute_plot_data (self ):
451
- # TODO: this logic is same with KdePlot. Might have to deduplicate it.
452
- from databricks .koalas .series import Series
453
-
454
- data = self .data
455
- if isinstance (data , Series ):
456
- data = data .to_frame ()
457
-
458
- numeric_data = data .select_dtypes (
459
- include = ["byte" , "decimal" , "integer" , "float" , "long" , "double" , np .datetime64 ]
460
- )
461
-
462
- # no empty frames or series allowed
463
- if len (numeric_data .columns ) == 0 :
464
- raise TypeError (
465
- "Empty {0!r}: no numeric data to " "plot" .format (numeric_data .__class__ .__name__ )
466
- )
467
-
468
- if is_integer (self .bins ):
469
- # computes boundaries for the column
470
- self .bins = self ._get_bins (data .to_spark (), self .bins )
471
-
472
- self .data = numeric_data
450
+ self .data , self .bins = HistogramPlotBase .prepare_hist_data (self .data , self .bins )
473
451
474
452
def _make_plot (self ):
475
453
# TODO: this logic is similar with KdePlot. Might have to deduplicate it.
476
454
# 'num_colors' requires to calculate `shape` which has to count all.
477
455
# Use 1 for now to save the computation.
478
456
colors = self ._get_colors (num_colors = 1 )
479
457
stacking_id = self ._get_stacking_id ()
458
+ output_series = HistogramPlotBase .compute_hist (self .data , self .bins )
480
459
481
- sdf = self .data ._internal .spark_frame
482
-
483
- for i , label in enumerate (self .data ._internal .column_labels ):
484
- # 'y' is a Spark DataFrame that selects one column.
485
- y = sdf .select (self .data ._internal .spark_column_for (label ))
460
+ for (i , label ), y in zip (enumerate (self .data ._internal .column_labels ), output_series ):
486
461
ax = self ._get_ax (i )
487
462
488
463
kwds = self .kwds .copy ()
@@ -494,11 +469,6 @@ def _make_plot(self):
494
469
if style is not None :
495
470
kwds ["style" ] = style
496
471
497
- # 'y' is a Spark DataFrame that selects one column.
498
- # here, we manually calculates the weights separately via Spark
499
- # and assign it directly to histogram plot.
500
- y = KoalasHistPlot ._compute_hist (y , self .bins ) # now y is a pandas Series.
501
-
502
472
kwds = self ._make_plot_keywords (kwds , y )
503
473
artists = self ._plot (ax , y , column_num = i , stacking_id = stacking_id , ** kwds )
504
474
self ._add_legend_handle (artists [0 ], label , index = i )
@@ -518,56 +488,8 @@ def _plot(cls, ax, y, style=None, bins=None, bottom=0, column_num=0, stacking_id
518
488
cls ._update_stacker (ax , stacking_id , n )
519
489
return patches
520
490
521
- @staticmethod
522
- def _get_bins (sdf , bins ):
523
- # 'data' is a Spark DataFrame that selects all columns.
524
- if len (sdf .columns ) > 1 :
525
- min_col = F .least (* map (F .min , sdf ))
526
- max_col = F .greatest (* map (F .max , sdf ))
527
- else :
528
- min_col = F .min (sdf .columns [- 1 ])
529
- max_col = F .max (sdf .columns [- 1 ])
530
- boundaries = sdf .select (min_col , max_col ).first ()
531
-
532
- # divides the boundaries into bins
533
- if boundaries [0 ] == boundaries [1 ]:
534
- boundaries = (boundaries [0 ] - 0.5 , boundaries [1 ] + 0.5 )
535
-
536
- return np .linspace (boundaries [0 ], boundaries [1 ], bins + 1 )
537
-
538
- @staticmethod
539
- def _compute_hist (sdf , bins ):
540
- # 'data' is a Spark DataFrame that selects one column.
541
- assert isinstance (bins , (np .ndarray , np .generic ))
542
-
543
- colname = sdf .columns [- 1 ]
544
-
545
- bucket_name = "__{}_bucket" .format (colname )
546
- # creates a Bucketizer to get corresponding bin of each value
547
- bucketizer = Bucketizer (
548
- splits = bins , inputCol = colname , outputCol = bucket_name , handleInvalid = "skip"
549
- )
550
- # after bucketing values, groups and counts them
551
- result = (
552
- bucketizer .transform (sdf )
553
- .select (bucket_name )
554
- .groupby (bucket_name )
555
- .agg (F .count ("*" ).alias ("count" ))
556
- .toPandas ()
557
- .sort_values (by = bucket_name )
558
- )
559
-
560
- # generates a pandas DF with one row for each bin
561
- # we need this as some of the bins may be empty
562
- indexes = pd .DataFrame ({bucket_name : np .arange (0 , len (bins ) - 1 ), "bucket" : bins [:- 1 ]})
563
- # merges the bins with counts on it and fills remaining ones with zeros
564
- pdf = indexes .merge (result , how = "left" , on = [bucket_name ]).fillna (0 )[["count" ]]
565
- pdf .columns = [bucket_name ]
566
-
567
- return pdf [bucket_name ]
568
-
569
491
570
- class KoalasPiePlot (PandasPiePlot , TopNPlot ):
492
+ class KoalasPiePlot (PandasPiePlot , TopNPlotBase ):
571
493
def __init__ (self , data , ** kwargs ):
572
494
super ().__init__ (self .get_top_n (data ), ** kwargs )
573
495
@@ -576,7 +498,7 @@ def _make_plot(self):
576
498
super ()._make_plot ()
577
499
578
500
579
- class KoalasAreaPlot (PandasAreaPlot , SampledPlot ):
501
+ class KoalasAreaPlot (PandasAreaPlot , SampledPlotBase ):
580
502
def __init__ (self , data , ** kwargs ):
581
503
super ().__init__ (self .get_sampled (data ), ** kwargs )
582
504
@@ -585,7 +507,7 @@ def _make_plot(self):
585
507
super ()._make_plot ()
586
508
587
509
588
- class KoalasLinePlot (PandasLinePlot , SampledPlot ):
510
+ class KoalasLinePlot (PandasLinePlot , SampledPlotBase ):
589
511
def __init__ (self , data , ** kwargs ):
590
512
super ().__init__ (self .get_sampled (data ), ** kwargs )
591
513
@@ -594,7 +516,7 @@ def _make_plot(self):
594
516
super ()._make_plot ()
595
517
596
518
597
- class KoalasBarhPlot (PandasBarhPlot , TopNPlot ):
519
+ class KoalasBarhPlot (PandasBarhPlot , TopNPlotBase ):
598
520
def __init__ (self , data , ** kwargs ):
599
521
super ().__init__ (self .get_top_n (data ), ** kwargs )
600
522
@@ -603,7 +525,7 @@ def _make_plot(self):
603
525
super ()._make_plot ()
604
526
605
527
606
- class KoalasScatterPlot (PandasScatterPlot , TopNPlot ):
528
+ class KoalasScatterPlot (PandasScatterPlot , TopNPlotBase ):
607
529
def __init__ (self , data , x , y , ** kwargs ):
608
530
super ().__init__ (self .get_top_n (data ), x , y , ** kwargs )
609
531
0 commit comments