Skip to content

Commit 9ca2508

Browse files
committed
Add support for numeric x
1 parent 86e3b09 commit 9ca2508

File tree

2 files changed

+42
-9
lines changed

2 files changed

+42
-9
lines changed

galleries/examples/lines_bars_and_markers/grouped_bar_chart.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,3 +102,19 @@
102102
#
103103
# df = pd.DataFrame(data, index=x, columns=columns)
104104
# df.plot.bar()
105+
106+
# %%
107+
# Numeric x values
108+
# ----------------
109+
# In the most common case, one will want to pass categorical labels as *x*.
110+
# Additionally, we allow numeric values for *x*, as with `~.Axes.bar()`.
111+
# But for simplicity and clarity, we require that these are equidistant.
112+
113+
x = [0, 2, 4]
114+
data = {
115+
'data1': [1, 2, 3],
116+
'data2': [1.2, 2.2, 3.2],
117+
}
118+
119+
fig, ax = plt.subplots()
120+
ax.grouped_bar(x, data)

lib/matplotlib/axes/_axes.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3052,8 +3052,12 @@ def grouped_bar(self, x, heights, dataset_labels=None):
30523052
"""
30533053
Parameters
30543054
-----------
3055-
x : array-like of str
3056-
The labels.
3055+
x : array-like or list of str
3056+
The center positions of the bar groups. If these are numeric values,
3057+
they have to be equidistant. As with `~.Axes.bar`, you can provide
3058+
categorical labels, which will be used at integer numeric positions
3059+
``range(x)``.
3060+
30573061
heights : list of array-like or dict of array-like or 2D array
30583062
The heights for all x and groups. One of:
30593063
@@ -3112,27 +3116,40 @@ def grouped_bar(self, x, heights, dataset_labels=None):
31123116
elif hasattr(heights, 'shape'):
31133117
heights = heights.T
31143118

3115-
num_labels = len(x)
3119+
num_groups = len(x)
31163120
num_datasets = len(heights)
31173121

3118-
for dataset in heights:
3119-
assert len(dataset) == num_labels
3122+
if isinstance(x[0], str):
3123+
tick_labels = x
3124+
group_centers = np.arange(num_groups)
3125+
else:
3126+
if num_groups > 1:
3127+
d = np.diff(x)
3128+
if not np.allclose(d, d.mean()):
3129+
raise ValueError("'x' must be equidistant")
3130+
group_centers = np.asarray(x)
3131+
tick_labels = None
3132+
3133+
for i, dataset in enumerate(heights):
3134+
if len(dataset) != num_groups:
3135+
raise ValueError(
3136+
f"'x' indicates {num_groups} groups, but dataset {i} "
3137+
f"has {len(dataset)} groups"
3138+
)
31203139

31213140
margin = 0.1
31223141
bar_width = (1 - 2 * margin) / num_datasets
3123-
block_centers = np.arange(num_labels)
31243142

31253143
if dataset_labels is None:
31263144
dataset_labels = [None] * num_datasets
31273145
else:
31283146
assert len(dataset_labels) == num_datasets
31293147

31303148
for i, (hs, dataset_label) in enumerate(zip(heights, dataset_labels)):
3131-
lefts = block_centers - 0.5 + margin + i * bar_width
3132-
print(i, x, lefts, hs, dataset_label)
3149+
lefts = group_centers - 0.5 + margin + i * bar_width
31333150
self.bar(lefts, hs, width=bar_width, align="edge", label=dataset_label)
31343151

3135-
self.xaxis.set_ticks(block_centers, labels=x)
3152+
self.xaxis.set_ticks(group_centers, labels=tick_labels)
31363153

31373154
# TODO: does not return anything for now
31383155

0 commit comments

Comments
 (0)