
Bar Plot

This module provides flexible functionality for creating bar plots from pandas DataFrames or Series.

It allows you to create bar plots with optional grouping, sorting, orientation, and data labels. The module supports both single and grouped bar plots: passing several value columns (value_col) produces grouped bars, while x_col defines the x-axis labels or categories.

Features

  • Single or Grouped Bar Plots: Plot one or more value columns (value_col) as bars. The x_col defines the categories or groups on the x-axis (e.g., products, categories, or regions). Passing a list of columns as value_col creates grouped bars, with one bar per column in each x_col category.
  • Sorting and Orientation: Sort the bars in ascending or descending order of the first value column, and choose between vertical ("v", "vertical") or horizontal ("h", "horizontal") bar orientations.
  • Data Labels: Add data labels to bars, showing either absolute values or percentages (computed per bar group or per series).
  • Hatching Patterns: Apply hatch patterns to the bars for enhanced visual differentiation.
  • Legend Customization: Move the legend outside the plot for better visibility, especially when dealing with grouped bars or multiple value columns.
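The orientation and sorting rules above can be sketched in plain Python. This is a minimal, stdlib-only sketch and not the library's API: resolve_plot_kind and sort_bars are hypothetical helpers that mimic the validation in plot, the mapping of orientations to the pandas plot kinds "bar" and "barh", and sorting by the first value column only. Rows are modeled as dictionaries instead of a DataFrame.

```python
from __future__ import annotations

from typing import Any

# Accepted values, mirroring the validation lists in `plot`.
VALID_ORIENTATIONS = ["horizontal", "h", "vertical", "v"]
VALID_SORT_ORDERS = ["ascending", "descending", None]


def resolve_plot_kind(orientation: str) -> str:
    """Return the pandas plot kind for an orientation (hypothetical helper)."""
    if orientation not in VALID_ORIENTATIONS:
        raise ValueError(f"Invalid orientation: {orientation}. Expected one of {VALID_ORIENTATIONS}")
    # "vertical"/"v" -> vertical bars ("bar"); "horizontal"/"h" -> horizontal bars ("barh").
    return "bar" if orientation in ("vertical", "v") else "barh"


def sort_bars(rows: list[dict[str, Any]], value_cols: list[str], sort_order: str | None) -> list[dict[str, Any]]:
    """Sort rows by the FIRST value column only, as `plot` does (hypothetical helper)."""
    if sort_order not in VALID_SORT_ORDERS:
        raise ValueError(f"Invalid sort_order: {sort_order}. Expected one of {VALID_SORT_ORDERS}")
    if sort_order is None:
        return list(rows)  # None keeps the input order
    return sorted(rows, key=lambda row: row[value_cols[0]], reverse=sort_order == "descending")
```

With grouped bars, only the first entry of value_cols drives the order; the other columns simply follow their rows.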

Use Cases

  • Sales and Revenue Analysis: Visualize sales or revenue across different products or categories by creating grouped bar plots (e.g., revenue across quarters or regions). The x_col will define the products or categories displayed on the x-axis.
  • Comparative Analysis: Compare multiple metrics simultaneously by plotting grouped bars. For instance, you can compare product sales for different periods side by side, with x_col defining the x-axis categories.
  • Distribution Analysis: Visualize the distribution of categorical data (e.g., product sales) across different categories, where x_col defines the x-axis labels.

Limitations and Handling of Data

  • Series Support: The module can also handle pandas Series, though x_col cannot be provided when plotting a Series. In this case, the index of the Series will define the x-axis labels.
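How a Series becomes plottable data can be sketched as follows. This is a stdlib-only illustration under stated assumptions: normalize_value_col mirrors the one-line value_col handling in plot, while series_to_rows is a hypothetical stand-in for Series.to_frame(name=value_col[0]) that represents the resulting frame as a list of dictionaries keyed by the Series index.

```python
from __future__ import annotations


def normalize_value_col(value_col: str | list[str] | None) -> list[str]:
    """Mirrors the value_col handling in `plot`: always return a list of column names."""
    if isinstance(value_col, str):
        return [value_col]
    if value_col is None:
        # A Series plotted without value_col is named with the placeholder "Value".
        return ["Value"]
    return list(value_col)


def series_to_rows(index: list, values: list, value_col: str | list[str] | None = None) -> list[dict]:
    """Rough stand-in for `Series.to_frame(name=value_col[0])` (hypothetical helper).

    The Series index supplies the x-axis labels; the values land in one column.
    """
    name = normalize_value_col(value_col)[0]
    return [{"index": label, name: value} for label, value in zip(index, values)]
```

Only the first entry of value_col names the column, so passing a list when plotting a Series does not produce grouped bars.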

plot(df, value_col=None, x_col=None, title=None, x_label=None, y_label=None, legend_title=None, ax=None, source_text=None, move_legend_outside=False, orientation='vertical', sort_order=None, data_label_format=None, use_hatch=False, num_digits=3, **kwargs)

Creates a customizable bar plot from a DataFrame or Series with optional features like sorting, orientation, and data labels. Passing multiple value columns creates grouped bars.

Parameters:

  • df (DataFrame | Series, required): The input DataFrame or Series containing the data to be plotted.
  • value_col (str | list[str] | None, default None): The column(s) whose values are plotted as bars. Multiple value columns create grouped bars. When a Series is plotted without value_col, its values are labeled "Value".
  • x_col (str | None, default None): The column containing the category labels for the bars. If None, the index is used. Must not be set when df is a Series.
  • title (str | None, default None): The title of the plot.
  • x_label (str | None, default None): The label for the x-axis.
  • y_label (str | None, default None): The label for the y-axis.
  • legend_title (str | None, default None): The title for the legend.
  • ax (Axes | None, default None): The Matplotlib Axes object to plot on.
  • source_text (str | None, default None): Text displayed as a source note at the bottom of the plot.
  • move_legend_outside (bool, default False): Whether to move the legend outside the plot area.
  • orientation (Literal['horizontal', 'h', 'vertical', 'v'], default 'vertical'): Orientation of the bars. "vertical" or "v" draws vertical bars; "horizontal" or "h" draws horizontal bars.
  • sort_order (Literal['ascending', 'descending'] | None, default None): Sorts the bars by the first value column. None keeps the input order.
  • data_label_format (Literal['absolute', 'percentage_by_bar_group', 'percentage_by_series'] | None, default None): Format of the data labels. "absolute" shows raw values; "percentage_by_bar_group" and "percentage_by_series" show percentages computed per bar group or per series. None adds no labels. A UserWarning is issued if a percentage format is used with negative values.
  • use_hatch (bool, default False): Whether to apply hatch patterns to the bars.
  • num_digits (int, default 3): The number of digits to display in the data labels.
  • **kwargs (dict[str, Any], default {}): Additional keyword arguments passed to the pandas DataFrame.plot method. The bar width is taken from width (default 0.8), and stacked is also used when placing data labels.
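The handling of **kwargs can be sketched as below. split_plot_kwargs is a hypothetical helper that reproduces two lines from plot: width is popped from kwargs with a default of 0.8, while stacked is only read (default False) for the data labels, so it stays in kwargs and is still forwarded to the pandas plot call.

```python
from __future__ import annotations

from typing import Any


def split_plot_kwargs(**kwargs: Any) -> tuple[float, bool, dict[str, Any]]:
    """Sketch of how `plot` treats extra keyword arguments (hypothetical helper)."""
    # `width` is removed from kwargs; `plot` passes it to pandas explicitly.
    width = kwargs.pop("width", 0.8)
    # `stacked` is only read, so it remains in kwargs and is forwarded unchanged.
    is_stacked = kwargs.get("stacked", False)
    return width, is_stacked, kwargs
```

For example, split_plot_kwargs(width=0.5, stacked=True, alpha=0.7) yields a width of 0.5, a stacked flag of True, and forwarded kwargs of {"stacked": True, "alpha": 0.7}.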

Returns:

  • SubplotBase: The Matplotlib Axes object with the generated plot.

Source code in pyretailscience/plots/bar.py
def plot(
    df: pd.DataFrame | pd.Series,
    value_col: str | list[str] | None = None,
    x_col: str | None = None,
    title: str | None = None,
    x_label: str | None = None,
    y_label: str | None = None,
    legend_title: str | None = None,
    ax: Axes | None = None,
    source_text: str | None = None,
    move_legend_outside: bool = False,
    orientation: Literal["horizontal", "h", "vertical", "v"] = "vertical",
    sort_order: Literal["ascending", "descending"] | None = None,
    data_label_format: Literal["absolute", "percentage_by_bar_group", "percentage_by_series"] | None = None,
    use_hatch: bool = False,
    num_digits: int = 3,
    **kwargs: dict[str, Any],
) -> SubplotBase:
    """Creates a customizable bar plot from a DataFrame or Series with optional features like sorting, orientation, and adding data labels. Passing multiple value columns creates grouped bars.

    Args:
        df (pd.DataFrame | pd.Series): The input DataFrame or Series containing the data to be plotted.
        value_col (str | list[str], optional): The column(s) containing values to plot as bars. Multiple value columns
                                               create grouped bars. Defaults to None.
        x_col (str, optional): The column containing the category labels for the bars. If None, the index is used.
                               Defaults to None.
        title (str, optional): The title of the plot. Defaults to None.
        x_label (str, optional): The label for the x-axis. Defaults to None.
        y_label (str, optional): The label for the y-axis. Defaults to None.
        legend_title (str, optional): The title for the legend. Defaults to None.
        ax (Axes, optional): The Matplotlib Axes object to plot on. Defaults to None.
        source_text (str, optional): Text to be displayed as a source at the bottom of the plot. Defaults to None.
        move_legend_outside (bool, optional): Whether to move the legend outside the plot area. Defaults to False.
        orientation (Literal["horizontal", "h", "vertical", "v"], optional): Orientation of the bars. Can be
                                                                             "horizontal", "h", "vertical", or "v".
                                                                             Defaults to "vertical".
        sort_order (Literal["ascending", "descending"] | None, optional): Sorting order for the bars. Can be
                                                                          "ascending" or "descending". Defaults to None.
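        data_label_format (Literal["absolute", "percentage_by_bar_group", "percentage_by_series"] | None, optional):
            Format for displaying data labels. "absolute" shows raw values, while "percentage_by_bar_group" and
            "percentage_by_series" show percentages computed per bar group or per series. Defaults to None
            (no data labels).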
        use_hatch (bool, optional): Whether to apply hatch patterns to the bars. Defaults to False.
        num_digits (int, optional): The number of digits to display in the data labels. Defaults to 3.
        **kwargs (dict[str, Any]): Additional keyword arguments for the pandas `DataFrame.plot` method.

    Returns:
        SubplotBase: The Matplotlib Axes object with the generated plot.
    """
    if df.empty:
        raise ValueError("Cannot plot with empty DataFrame")

    # Check if x_col exists in the DataFrame, if provided
    if x_col is not None and x_col not in df.columns:
        msg = f"x_col '{x_col}' not found in DataFrame"
        raise KeyError(msg)

    valid_orientations = ["horizontal", "h", "vertical", "v"]
    if orientation not in valid_orientations:
        error_msg = f"Invalid orientation: {orientation}. Expected one of {valid_orientations}"
        raise ValueError(error_msg)

    # Validate the sort_order value
    valid_sort_orders = ["ascending", "descending", None]
    if sort_order not in valid_sort_orders:
        error_msg = f"Invalid sort_order: {sort_order}. Expected one of {valid_sort_orders}"
        raise ValueError(error_msg)

    # Validate the data_label_format value
    valid_data_label_formats = ["absolute", "percentage_by_bar_group", "percentage_by_series", None]
    if data_label_format not in valid_data_label_formats:
        error_msg = f"Invalid data_label_format: {data_label_format}. Expected one of {valid_data_label_formats}"
        raise ValueError(error_msg)

    width = kwargs.pop("width", 0.8)

    value_col = [value_col] if isinstance(value_col, str) else (["Value"] if value_col is None else value_col)

    df = df.to_frame(name=value_col[0]) if isinstance(df, pd.Series) else df

    if data_label_format in ["percentage_by_bar_group", "percentage_by_series"] and (df[value_col] < 0).any().any():
        warnings.warn(
            f"Negative values detected in {value_col}. This may lead to unexpected behavior in terms of the data "
            f"label format '{data_label_format}'.",
            UserWarning,
            stacklevel=2,
        )

    df = df.sort_values(by=value_col[0], ascending=sort_order == "ascending") if sort_order is not None else df

    color_gen_threshold = 4
    cmap = get_single_color_cmap() if len(value_col) < color_gen_threshold else get_multi_color_cmap()

    plot_kind = "bar" if orientation in ["vertical", "v"] else "barh"

    ax = df.plot(
        kind=plot_kind,
        y=value_col,
        x=x_col,
        ax=ax,
        width=width,
        color=[next(cmap) for _ in range(len(value_col))],
        legend=(len(value_col) > 1),
        **kwargs,
    )

    ax = gu.standard_graph_styles(
        ax=ax,
        title=title,
        x_label=x_label,
        y_label=y_label,
        legend_title=legend_title,
        move_legend_outside=move_legend_outside,
    )

    if use_hatch:
        ax = gu.apply_hatches(ax=ax, num_segments=len(value_col))

    # Add data labels
    if data_label_format:
        _generate_bar_labels(
            ax=ax,
            plot_kind=plot_kind,
            value_col=value_col,
            df=df,
            data_label_format=data_label_format,
            x_col=x_col if x_col is not None else df.index,
            is_stacked=kwargs.get("stacked", False),
            num_digits=num_digits,
        )

    if source_text:
        gu.add_source_text(ax=ax, source_text=source_text)

    return gu.standard_tick_styles(ax=ax)
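The code above only validates data_label_format and delegates label drawing to _generate_bar_labels, which is not shown. The sketch below is therefore an assumption about what the two percentage formats compute, inferred from their names: "percentage_by_series" divides each value by the total of its own column, and "percentage_by_bar_group" divides it by the total of its x-axis category across all value columns. percentage_labels is a hypothetical helper that returns shares as fractions between 0 and 1 and reproduces the negative-value UserWarning from plot.

```python
from __future__ import annotations

import warnings


def percentage_labels(table: dict[str, list[float]], data_label_format: str) -> dict[str, list[float]]:
    """Assumed semantics of the percentage label formats (hypothetical helper).

    `table` maps each value column to its values, one per x-axis category.
    """
    if any(v < 0 for values in table.values() for v in values):
        # Same idea as the check in `plot`: negative values make shares hard to interpret.
        warnings.warn(
            f"Negative values detected in {list(table)}. This may lead to unexpected behavior "
            f"in terms of the data label format '{data_label_format}'.",
            UserWarning,
            stacklevel=2,
        )
    if data_label_format == "percentage_by_series":
        shares = {}
        for col, values in table.items():
            total = sum(values)  # total of this series (column)
            shares[col] = [v / total for v in values]
        return shares
    if data_label_format == "percentage_by_bar_group":
        # One total per x-axis category, summed across all value columns.
        group_totals = [sum(group) for group in zip(*table.values())]
        return {col: [v / t for v, t in zip(values, group_totals)] for col, values in table.items()}
    raise ValueError(f"Unsupported percentage format: {data_label_format}")
```

With table = {"q1": [10, 30], "q2": [10, 50]}, the by-group shares are q1 = [0.5, 0.375] and q2 = [0.5, 0.625]. A zero total raises ZeroDivisionError in this sketch; the library may handle that case differently.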