Skip to content

Commit 1feaf7b

Browse files
committed
BUG: Fixes issue pandas-dev#3334: brittle margin in pivot_table.
Adds support for margin computation when all columns are used in rows and cols
1 parent 527db38 commit 1feaf7b

File tree

4 files changed

+127
-82
lines changed

4 files changed

+127
-82
lines changed

doc/source/release.rst

+1
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ pandas 0.13
102102
set _ref_locs (:issue:`4403`)
103103
- Fixed an issue where hist subplots were being overwritten when they were
104104
called using the top level matplotlib API (:issue:`4408`)
105+
- Fixed a bug in ``pivot_table`` where margins were not computed when all columns are used in ``rows`` and ``cols``, leaving no values column to aggregate (:issue:`3334`)
105106

106107
pandas 0.12
107108
===========

doc/source/v0.13.0.txt

+2-38
Original file line numberDiff line numberDiff line change
@@ -11,35 +11,12 @@ API changes
1111

1212
- ``read_excel`` now supports an integer in its ``sheetname`` argument giving
1313
the index of the sheet to read in (:issue:`4301`).
14-
- Text parser now treats anything that reads like inf ("inf", "Inf", "-Inf",
15-
"iNf", etc.) as infinity. (:issue:`4220`, :issue:`4219`), affecting
16-
``read_table``, ``read_csv``, etc.
17-
- ``pandas`` now is Python 2/3 compatible without the need for 2to3 thanks to
18-
@jtratner. As a result, pandas now uses iterators more extensively. This
19-
also led to the introduction of substantive parts of the Benjamin
20-
Peterson's ``six`` library into compat. (:issue:`4384`, :issue:`4375`,
21-
:issue:`4372`)
22-
- ``pandas.util.compat`` and ``pandas.util.py3compat`` have been merged into
23-
``pandas.compat``. ``pandas.compat`` now includes many functions allowing
24-
2/3 compatibility. It contains both list and iterator versions of range,
25-
filter, map and zip, plus other necessary elements for Python 3
26-
compatibility. ``lmap``, ``lzip``, ``lrange`` and ``lfilter`` all produce
27-
lists instead of iterators, for compatibility with ``numpy``, subscripting
28-
and ``pandas`` constructors.(:issue:`4384`, :issue:`4375`, :issue:`4372`)
29-
- deprecated ``iterkv``, which will be removed in a future release (was just
30-
an alias of iteritems used to get around ``2to3``'s changes).
31-
(:issue:`4384`, :issue:`4375`, :issue:`4372`)
32-
- ``Series.get`` with negative indexers now returns the same as ``[]`` (:issue:`4390`)
3314

3415
Enhancements
3516
~~~~~~~~~~~~
3617

3718
- ``read_html`` now raises a ``URLError`` instead of catching and raising a
3819
``ValueError`` (:issue:`4303`, :issue:`4305`)
39-
- Added a test for ``read_clipboard()`` and ``to_clipboard()`` (:issue:`4282`)
40-
- Clipboard functionality now works with PySide (:issue:`4282`)
41-
- Added a more informative error message when plot arguments contain
42-
overlapping color and style arguments (:issue:`4402`)
4320

4421
Bug Fixes
4522
~~~~~~~~~
@@ -52,22 +29,9 @@ Bug Fixes
5229

5330
- Fixed bug in ``PeriodIndex.map`` where using ``str`` would return the str
5431
representation of the index (:issue:`4136`)
32+
33+
- Fixed a bug in ``pivot_table`` where margins were not computed when all columns are used in ``rows`` and ``cols``, leaving no values column to aggregate (:issue:`3334`)
5534

56-
- Fixed test failure ``test_time_series_plot_color_with_empty_kwargs`` when
57-
using custom matplotlib default colors (:issue:`4345`)
58-
59-
- Fix running of stata IO tests. Now uses temporary files to write
60-
(:issue:`4353`)
61-
62-
- Fixed an issue where ``DataFrame.sum`` was slower than ``DataFrame.mean``
63-
for integer valued frames (:issue:`4365`)
64-
65-
- ``read_html`` tests now work with Python 2.6 (:issue:`4351`)
66-
67-
- Fixed bug where ``network`` testing was throwing ``NameError`` because a
68-
local variable was undefined (:issue:`4381`)
69-
70-
- Suppressed DeprecationWarning associated with internal calls issued by repr() (:issue:`4391`)
7135

7236
See the :ref:`full release notes
7337
<release>` or issue tracker

pandas/tools/pivot.py

+91-27
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,8 @@
22

33
from pandas import Series, DataFrame
44
from pandas.core.index import MultiIndex
5-
from pandas.core.reshape import _unstack_multiple
65
from pandas.tools.merge import concat
76
from pandas.tools.util import cartesian_product
8-
from pandas.compat import range, lrange, zip
9-
from pandas import compat
107
import pandas.core.common as com
118
import numpy as np
129

@@ -149,17 +146,64 @@ def pivot_table(data, values=None, rows=None, cols=None, aggfunc='mean',
149146
DataFrame.pivot_table = pivot_table
150147

151148

152-
def _add_margins(table, data, values, rows=None, cols=None, aggfunc=np.mean):
153-
grand_margin = {}
154-
for k, v in compat.iteritems(data[values]):
155-
try:
156-
if isinstance(aggfunc, compat.string_types):
157-
grand_margin[k] = getattr(v, aggfunc)()
158-
else:
159-
grand_margin[k] = aggfunc(v)
160-
except TypeError:
161-
pass
149+
def _add_margins(table, data, values, rows, cols, aggfunc):
    """
    Append the 'All' margins to an already-pivoted table.

    Parameters
    ----------
    table : DataFrame or Series
        The pivoted result that the margins are added to.
    data : DataFrame
        The un-pivoted source data.
    values : list
        Labels of the value columns; empty when there are none.
    rows, cols : list
        Labels used for the row index and the columns of ``table``.
    aggfunc : function or string
        Aggregation passed through to the margin helpers.

    Returns
    -------
    DataFrame or Series
        ``table`` with an extra 'All' row (and whatever column margins the
        helper functions add).
    """
    grand_margin = _compute_grand_margin(data, values, aggfunc)

    # Label of the margin row: padded with '' so that it has one entry per
    # row level when there are several row keys.
    if len(rows) > 1:
        key = ('All',) + ('',) * (len(rows) - 1)
    else:
        key = 'All'

    if not values and isinstance(table, Series):
        # No value columns and a Series result: the only margin to add is
        # the grand margin itself.
        return table.append(Series({key: grand_margin['All']}))

    if values:
        marginal_result_set = _generate_marginal_results(
            table, data, values, rows, cols, aggfunc, grand_margin)
    else:
        marginal_result_set = _generate_marginal_results_without_values(
            table, data, rows, cols, aggfunc)

    # The helpers hand back a finished table (rather than a 3-tuple) when
    # there is nothing left to append here.
    if not isinstance(marginal_result_set, tuple):
        return marginal_result_set
    result, margin_keys, row_margin = marginal_result_set

    row_margin = row_margin.reindex(result.columns)
    # populate grand margin
    for k in margin_keys:
        # Tuple keys carry the grand-margin label in their first element;
        # any scalar key (string, integer, ...) is the label itself.  Testing
        # for strings only would try to index non-string scalar labels.
        row_margin[k] = grand_margin[k[0] if isinstance(k, tuple) else k]

    margin_dummy = DataFrame(row_margin, columns=[key]).T

    # append() does not keep the index names, so restore them afterwards.
    row_names = result.index.names
    result = result.append(margin_dummy)
    result.index.names = row_names

    return result
187+
188+
189+
def _compute_grand_margin(data, values, aggfunc):
190+
191+
if values:
192+
grand_margin = {}
193+
for k, v in data[values].iteritems():
194+
try:
195+
if isinstance(aggfunc, basestring):
196+
grand_margin[k] = getattr(v, aggfunc)()
197+
else:
198+
grand_margin[k] = aggfunc(v)
199+
except TypeError:
200+
pass
201+
return grand_margin
202+
else:
203+
return {'All': aggfunc(data.index)}
204+
205+
206+
def _generate_marginal_results(table, data, values, rows, cols, aggfunc, grand_margin):
163207
if len(cols) > 0:
164208
# need to "interleave" the margins
165209
table_pieces = []
@@ -198,28 +242,48 @@ def _all_key(key):
198242
row_margin = row_margin.stack()
199243

200244
# slight hack
201-
new_order = [len(cols)] + lrange(len(cols))
245+
new_order = [len(cols)] + range(len(cols))
202246
row_margin.index = row_margin.index.reorder_levels(new_order)
203247
else:
204248
row_margin = Series(np.nan, index=result.columns)
205249

206-
key = ('All',) + ('',) * (len(rows) - 1) if len(rows) > 1 else 'All'
250+
return result, margin_keys, row_margin
207251

208-
row_margin = row_margin.reindex(result.columns)
209-
# populate grand margin
210-
for k in margin_keys:
211-
if len(cols) > 0:
212-
row_margin[k] = grand_margin[k[0]]
213-
else:
214-
row_margin[k] = grand_margin[k]
215252

216-
margin_dummy = DataFrame(row_margin, columns=[key]).T
253+
def _generate_marginal_results_without_values(table, data, rows, cols, aggfunc):
254+
if len(cols) > 0:
255+
# need to "interleave" the margins
256+
margin_keys = []
217257

218-
row_names = result.index.names
219-
result = result.append(margin_dummy)
220-
result.index.names = row_names
258+
def _all_key():
259+
if len(cols) == 1:
260+
return 'All'
261+
return ('All', ) + ('', ) * (len(cols) - 1)
221262

222-
return result
263+
if len(rows) > 0:
264+
margin = data[rows].groupby(rows).apply(aggfunc)
265+
all_key = _all_key()
266+
table[all_key] = margin
267+
result = table
268+
margin_keys.append(all_key)
269+
270+
else:
271+
margin = data.groupby(level=0, axis=0).apply(aggfunc)
272+
all_key = _all_key()
273+
table[all_key] = margin
274+
result = table
275+
margin_keys.append(all_key)
276+
return result
277+
else:
278+
result = table
279+
margin_keys = table.columns
280+
281+
if len(cols):
282+
row_margin = data[cols].groupby(cols).apply(aggfunc)
283+
else:
284+
row_margin = Series(np.nan, index=result.columns)
285+
286+
return result, margin_keys, row_margin
223287

224288

225289
def _convert_by(by):

pandas/tools/tests/test_pivot.py

+33-17
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
1-
import datetime
21
import unittest
32

43
import numpy as np
54
from numpy.testing import assert_equal
65

7-
import pandas
86
from pandas import DataFrame, Series, Index, MultiIndex
97
from pandas.tools.merge import concat
108
from pandas.tools.pivot import pivot_table, crosstab
11-
from pandas.compat import range, u, product
129
import pandas.util.testing as tm
1310

1411

@@ -75,18 +72,9 @@ def test_pivot_table_dropna(self):
7572
pv_col = df.pivot_table('quantity', 'month', ['customer', 'product'], dropna=False)
7673
pv_ind = df.pivot_table('quantity', ['customer', 'product'], 'month', dropna=False)
7774

78-
m = MultiIndex.from_tuples([(u('A'), u('a')),
79-
(u('A'), u('b')),
80-
(u('A'), u('c')),
81-
(u('A'), u('d')),
82-
(u('B'), u('a')),
83-
(u('B'), u('b')),
84-
(u('B'), u('c')),
85-
(u('B'), u('d')),
86-
(u('C'), u('a')),
87-
(u('C'), u('b')),
88-
(u('C'), u('c')),
89-
(u('C'), u('d'))])
75+
m = MultiIndex.from_tuples([(u'A', u'a'), (u'A', u'b'), (u'A', u'c'), (u'A', u'd'),
76+
(u'B', u'a'), (u'B', u'b'), (u'B', u'c'), (u'B', u'd'),
77+
(u'C', u'a'), (u'C', u'b'), (u'C', u'c'), (u'C', u'd')])
9078

9179
assert_equal(pv_col.columns.values, m.values)
9280
assert_equal(pv_ind.index.values, m.values)
@@ -211,17 +199,20 @@ def _check_output(res, col, rows=['A', 'B'], cols=['C']):
211199
# no rows
212200
rtable = self.data.pivot_table(cols=['AA', 'BB'], margins=True,
213201
aggfunc=np.mean)
214-
tm.assert_isinstance(rtable, Series)
202+
self.assert_(isinstance(rtable, Series))
215203
for item in ['DD', 'EE', 'FF']:
216204
gmarg = table[item]['All', '']
217205
self.assertEqual(gmarg, self.data[item].mean())
218206

219207
def test_pivot_integer_columns(self):
220208
# caused by upstream bug in unstack
209+
from pandas.util.compat import product
210+
import datetime
211+
import pandas
221212

222213
d = datetime.date.min
223214
data = list(product(['foo', 'bar'], ['A', 'B', 'C'], ['x1', 'x2'],
224-
[d + datetime.timedelta(i) for i in range(20)], [1.0]))
215+
[d + datetime.timedelta(i) for i in xrange(20)], [1.0]))
225216
df = pandas.DataFrame(data)
226217
table = df.pivot_table(values=4, rows=[0, 1, 3], cols=[2])
227218

@@ -245,6 +236,9 @@ def test_pivot_no_level_overlap(self):
245236
tm.assert_frame_equal(table, expected)
246237

247238
def test_pivot_columns_lexsorted(self):
239+
import datetime
240+
import numpy as np
241+
import pandas
248242

249243
n = 10000
250244

@@ -296,6 +290,28 @@ def test_pivot_complex_aggfunc(self):
296290

297291
tm.assert_frame_equal(result, expected)
298292

293+
def test_margins_no_values_no_cols(self):
    """Regression: with only rows given, the margin equals the sum of the groups."""
    frame = self.data[['A', 'B']]
    pivoted = frame.pivot_table(rows=['A', 'B'], aggfunc=len, margins=True)
    counts = pivoted.tolist()
    # The margin is the last entry; everything before it is a group count.
    grand_total = counts[-1]
    self.assertEqual(sum(counts[:-1]), grand_total)
298+
299+
def test_margins_no_values_two_rows(self):
    """Regression: no values passed, rows form a multi-index, one column key."""
    frame = self.data[['A', 'B', 'C']]
    pivoted = frame.pivot_table(rows=['A', 'B'], cols='C',
                                aggfunc=len, margins=True)
    expected_all = [3.0, 1.0, 4.0, 3.0, 11.0]
    self.assertEqual(pivoted.All.tolist(), expected_all)
303+
304+
def test_margins_no_values_one_row_one_col(self):
    """Regression: no values passed, a single row key and a single column key."""
    frame = self.data[['A', 'B']]
    pivoted = frame.pivot_table(rows='A', cols='B', aggfunc=len,
                                margins=True)
    expected_all = [4.0, 7.0, 11.0]
    self.assertEqual(pivoted.All.tolist(), expected_all)
308+
309+
def test_margins_no_values_two_row_two_cols(self):
    """Regression: no values passed, rows and columns are both multi-indexed."""
    # Adds an extra column 'D' to the fixture so there are two column keys.
    self.data['D'] = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k']
    frame = self.data[['A', 'B', 'C', 'D']]
    pivoted = frame.pivot_table(rows=['A', 'B'], cols=['C', 'D'],
                                aggfunc=len, margins=True)
    expected_all = [3.0, 1.0, 4.0, 3.0, 11.0]
    self.assertEqual(pivoted.All.tolist(), expected_all)
314+
299315

300316
class TestCrosstab(unittest.TestCase):
301317

0 commit comments

Comments
 (0)