Skip to content
42 changes: 42 additions & 0 deletions docs/examples/plot_types/07_sankey.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""
Layered Sankey diagram
======================

An example of UltraPlot's layered Sankey renderer for publication-ready
flow diagrams.

Why UltraPlot here?
-------------------
``sankey`` in layered mode handles node ordering, flow styling, and
label placement without manual geometry.

Key function: :py:meth:`ultraplot.axes.PlotAxes.sankey`.

See also
--------
* :doc:`2D plot types </2dplots>`
"""

import ultraplot as uplt

nodes = ["Budget", "Operations", "R&D", "Marketing", "Support", "Infra"]
flows = [
("Budget", "Operations", 5.0, "Ops"),
("Budget", "R&D", 3.0, "R&D"),
("Budget", "Marketing", 2.0, "Mkt"),
("Operations", "Support", 1.5, "Support"),
("Operations", "Infra", 2.0, "Infra"),
]

fig, ax = uplt.subplots(refwidth=3.6)
ax.sankey(
nodes=nodes,
flows=flows,
style="budget",
flow_labels=True,
value_format="{:.1f}",
node_label_box=True,
flow_label_pos=0.5,
)
ax.format(title="Budget allocation")
fig.show()
242 changes: 241 additions & 1 deletion ultraplot/axes/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import sys
from collections.abc import Callable, Iterable
from numbers import Integral, Number
from typing import Any, Iterable, Optional, Union
from typing import Any, Iterable, Mapping, Optional, Sequence, Union

import matplotlib as mpl
import matplotlib.artist as martist
Expand Down Expand Up @@ -205,6 +205,83 @@
"""

docstring._snippet_manager["plot.curved_quiver"] = _curved_quiver_docstring

_sankey_docstring = """
Draw a Sankey diagram.

Parameters
----------
flows : sequence of float or flow tuples
If a numeric sequence, use Matplotlib's Sankey implementation.
Otherwise, expect flow tuples or dicts describing (source, target, value).
nodes : sequence or dict, optional
Node identifiers or dicts with ``id``/``label``/``color`` keys. If omitted,
nodes are inferred from flow sources/targets.
labels : sequence of str, optional
Labels for each flow in Matplotlib's Sankey mode.
orientations : sequence of int, optional
Flow orientations (-1: down, 0: right, 1: up) for Matplotlib's Sankey.
pathlengths : float or sequence of float, optional
Path lengths for each flow in Matplotlib's Sankey.
trunklength : float, optional
Length of the trunk between the input and output flows.
patchlabel : str, optional
Label for the main patch in Matplotlib's Sankey mode.
scale, unit, format, gap, radius, shoulder, offset, head_angle, margin, tolerance : optional
Passed to `matplotlib.sankey.Sankey`.
prior : int, optional
Index of a prior diagram to connect to.
connect : (int, int), optional
Flow indices for the prior and current diagram connection.
rotation : float, optional
Rotation angle in degrees.
node_kw, flow_kw, label_kw : dict-like, optional
Style dictionaries for the layered Sankey renderer.
node_label_kw, flow_label_kw : dict-like, optional
Label style dictionaries for node and flow labels in layered mode.
node_label_box : bool or dict-like, optional
If ``True``, draw a rounded box behind node labels. If dict-like, used as
the ``bbox`` argument for node label styling.
style : {'budget', 'pastel', 'mono'}, optional
Built-in styling presets for layered mode.
node_order : sequence, optional
Explicit node ordering for layered mode.
layer_order : sequence, optional
Explicit layer ordering for layered mode.
group_cycle : sequence, optional
Cycle for flow group colors (defaults to flow cycle).
flow_other : float, optional
Aggregate flows below this threshold into a single ``other_label``.
other_label : str, optional
Label for the aggregated flow target.
value_format : str or callable, optional
Formatter for flow labels when not explicitly provided.
node_label_outside : {'auto', True, False}, optional
Place node labels outside narrow nodes.
node_label_offset : float, optional
Offset for outside node labels (axes-relative units).
flow_sort : bool, optional
Whether to sort flows by target position to reduce crossings.
flow_label_pos : float, optional
Horizontal placement for single flow labels (0 to 1 along the ribbon).
When flow labels overlap, positions are redistributed between 0.25 and 0.75.
node_labels, flow_labels : bool, optional
Whether to draw node or flow labels in layered mode.
align : {'center', 'top', 'bottom'}, optional
Vertical alignment for nodes within each layer in layered mode.
layers : dict-like, optional
Manual layer assignments for nodes in layered mode.
**kwargs
Patch properties passed to `matplotlib.sankey.Sankey.add` in Matplotlib mode.

Returns
-------
matplotlib.sankey.Sankey or list or SankeyDiagram
The Sankey diagram instance, or a list for multi-diagram usage. For layered
mode, returns a `~ultraplot.axes.plot_types.sankey.SankeyDiagram`.
"""

docstring._snippet_manager["plot.sankey"] = _sankey_docstring
# Auto colorbar and legend docstring
_guide_docstring = """
colorbar : bool, int, or str, optional
Expand Down Expand Up @@ -1849,6 +1926,169 @@ def curved_quiver(
stream_container = CurvedQuiverSet(lc, ac)
return stream_container

@docstring._snippet_manager
def sankey(
self,
flows: Any,
labels: Sequence[str] | None = None,
orientations: Sequence[int] | None = None,
pathlengths: float | Sequence[float] = 0.25,
trunklength: float = 1.0,
patchlabel: str = "",
*,
nodes: Any = None,
links: Any = None,
node_kw: Mapping[str, Any] | None = None,
flow_kw: Mapping[str, Any] | None = None,
label_kw: Mapping[str, Any] | None = None,
node_label_kw: Mapping[str, Any] | None = None,
flow_label_kw: Mapping[str, Any] | None = None,
node_label_box: bool | Mapping[str, Any] | None = None,
style: str | None = None,
node_order: Sequence[Any] | None = None,
layer_order: Sequence[int] | None = None,
group_cycle: Sequence[Any] | None = None,
flow_other: float | None = None,
other_label: str = "Other",
value_format: str | Callable[[float], str] | None = None,
node_label_outside: bool | str = "auto",
node_label_offset: float = 0.01,
flow_sort: bool = True,
flow_label_pos: float = 0.5,
node_labels: bool = True,
flow_labels: bool = False,
align: str = "center",
layers: Mapping[Any, int] | None = None,
scale: float | None = None,
unit: str | None = None,
format: str | None = None,
gap: float | None = None,
radius: float | None = None,
shoulder: float | None = None,
offset: float | None = None,
head_angle: float | None = None,
margin: float | None = None,
tolerance: float | None = None,
prior: int | None = None,
connect: tuple[int, int] | None = (0, 0),
rotation: float = 0,
**kwargs: Any,
) -> Any:
"""
%(plot.sankey)s
"""

def _looks_like_links(values):
if values is None:
return False
if isinstance(values, np.ndarray) and values.ndim == 1:
return False
if isinstance(values, dict):
return True
if isinstance(values, (list, tuple)) and values:
first = values[0]
if isinstance(first, dict):
return True
if isinstance(first, (list, tuple)) and len(first) >= 3:
return True
return False

use_layered = nodes is not None or links is not None or _looks_like_links(flows)
if use_layered:
from .plot_types.sankey import sankey_diagram

node_kw = node_kw or {}
flow_kw = flow_kw or {}
label_kw = label_kw or {}
if links is None:
links = flows

cycle = rc["axes.prop_cycle"].by_key().get("color", [])
if not cycle:
cycle = [self._get_lines.get_next_color()]

return sankey_diagram(
self,
nodes=nodes,
flows=links,
layers=layers,
flow_cycle=cycle,
group_cycle=group_cycle,
node_order=node_order,
layer_order=layer_order,
style=style,
flow_other=flow_other,
other_label=other_label,
value_format=value_format,
node_kw=node_kw,
flow_kw=flow_kw,
label_kw=label_kw,
node_label_kw=node_label_kw,
flow_label_kw=flow_label_kw,
node_label_box=node_label_box,
node_label_outside=node_label_outside,
node_label_offset=node_label_offset,
flow_sort=flow_sort,
flow_label_pos=flow_label_pos,
node_labels=node_labels,
flow_labels=flow_labels,
align=align,
node_pad=rc["sankey.nodepad"],
node_width=rc["sankey.nodewidth"],
margin=rc["sankey.margin"],
flow_alpha=rc["sankey.flow.alpha"],
flow_curvature=rc["sankey.flow.curvature"],
node_facecolor=rc["sankey.node.facecolor"],
)

from matplotlib.sankey import Sankey

sankey_kw = {}
if scale is not None:
sankey_kw["scale"] = scale
if unit is not None:
sankey_kw["unit"] = unit
if format is not None:
sankey_kw["format"] = format
if gap is not None:
sankey_kw["gap"] = gap
if radius is not None:
sankey_kw["radius"] = radius
if shoulder is not None:
sankey_kw["shoulder"] = shoulder
if offset is not None:
sankey_kw["offset"] = offset
if head_angle is not None:
sankey_kw["head_angle"] = head_angle
if margin is not None:
sankey_kw["margin"] = margin
if tolerance is not None:
sankey_kw["tolerance"] = tolerance

if "facecolor" not in kwargs and "color" not in kwargs:
kwargs["facecolor"] = self._get_lines.get_next_color()

sankey = Sankey(ax=self, **sankey_kw)
add_kw = {
"flows": flows,
"trunklength": trunklength,
"patchlabel": patchlabel,
"rotation": rotation,
"pathlengths": pathlengths,
}
if labels is not None:
add_kw["labels"] = labels
if orientations is not None:
add_kw["orientations"] = orientations
if prior is not None:
add_kw["prior"] = prior
if connect is not None:
add_kw["connect"] = connect

sankey.add(**add_kw, **kwargs)
diagrams = sankey.finish()
return diagrams[0] if len(diagrams) == 1 else diagrams

def _call_native(self, name, *args, **kwargs):
"""
Call the plotting method and redirect internal calls to native methods.
Expand Down
Loading
Loading