Commit da95ec72 authored by Turner, Sean's avatar Turner, Sean
Browse files

Add cascade network topology visualization script



Standalone script that parses any cascade_config.yaml and produces a
publication-quality PDF/SVG/PNG network diagram with trapezoid reservoir
shapes, hydropower plant indicators, Bezier edge routing, and lag labels.

Co-Authored-By: default avatarClaude Opus 4.6 <noreply@anthropic.com>
parent 1ba3307b
Loading
Loading
Loading
Loading
+62.3 KiB

File added.

No diff preview for this file type.

+207 KiB
Loading image diff...
+499 −0
Original line number Diff line number Diff line
"""
Publication-quality cascade network diagram from cascade_config.yaml.

Usage:
    python examples/plot_network.py examples/Cumberland/cascade_config.yaml
    python examples/plot_network.py examples/Cumberland/cascade_config.yaml -o network.svg
"""

import argparse
from collections import defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import yaml
from matplotlib.patches import FancyBboxPatch, Polygon
from matplotlib.path import Path as MplPath

# ---------------------------------------------------------------------------
# Style — single place to tune
# ---------------------------------------------------------------------------
STYLE = {
    # Reservoir colors
    "res_fill": "#4F7D9B",
    "res_stroke": "#345570",
    "res_text": "#FFFFFF",
    "cap_text": "#C8DCE8",
    "water_line": "#FFFFFF",
    "water_line_alpha": 0.30,
    # Powerhouse (hydropower indicator)
    "hp_color": "#D4A03C",
    "hp_stroke": "#B8862E",
    "hp_text": "#FFFFFF",
    # Confluence
    "conf_fill": "#C8CDD2",
    "conf_stroke": "#8E99A4",
    "conf_text": "#3D4F5F",
    # Edges
    "edge_color": "#6B7D8D",
    "direct_color": "#9EAAB5",
    "lag_text": "#455565",
    # Reservoir trapezoid geometry (data coords)
    "res_tw": 1.85,       # top width (wider — open water surface)
    "res_bw": 1.20,       # bottom width (narrower — dam base)
    "res_h": 0.65,
    "hp_bar_h": 0.07,     # powerhouse bar height at bottom of trapezoid
    # Confluence
    "conf_r": 0.10,
    # Line weights (points)
    "edge_lw": 0.75,
    "node_lw": 0.7,
    "hp_lw": 0.4,
    # Font sizes (points)
    "name_fs": 9.5,
    "cap_fs": 7,
    "lag_fs": 7.5,
    "legend_fs": 7.5,
    "hp_fs": 4.5,
    # Layout
    "fig_w": 7.5,
    "fig_h": 9.5,
    "dpi": 300,
    "layer_gap": 1.55,
    "col_gap": 2.8,
}


# ---------------------------------------------------------------------------
# 1. Config parsing
# ---------------------------------------------------------------------------
def load_cascade_topology(config_path):
    """Parse cascade_config.yaml into nodes and edges."""
    with open(config_path) as f:
        config = yaml.safe_load(f)

    reservoirs, rivers, confluences = {}, {}, {}
    for name, attrs in config.items():
        if not isinstance(attrs, dict) or "object_type" not in attrs:
            continue
        kind = attrs["object_type"]
        if kind == "reservoir":
            reservoirs[name] = attrs
        elif kind == "river":
            rivers[name] = attrs
        elif kind == "confluence":
            confluences[name] = attrs

    nodes = []
    for name, a in reservoirs.items():
        nodes.append({
            "name": name,
            "type": "reservoir",
            "capacity": a.get("capacity"),
            "hydropower": "min_power_pool" in a,
        })
    for name in confluences:
        nodes.append({"name": name, "type": "confluence"})
    node_names = {n["name"] for n in nodes}

    edges, seen = [], set()
    for name, attrs in {**reservoirs, **confluences}.items():
        ds = attrs.get("downstream_object")
        if ds is None or ds == "NA":
            continue
        if ds in rivers:
            target = rivers[ds]["downstream_object"]
            lag = rivers[ds].get("lag", 0)
            key = (name, target)
            if key not in seen:
                edges.append({"source": name, "target": target,
                              "lag": lag, "river": ds})
                seen.add(key)
        elif ds in node_names:
            key = (name, ds)
            if key not in seen:
                edges.append({"source": name, "target": ds,
                              "lag": 0, "river": None})
                seen.add(key)

    return nodes, edges


# ---------------------------------------------------------------------------
# 2. Layout
# ---------------------------------------------------------------------------
def compute_layout(nodes, edges):
    """Layered layout via topological longest-path + barycenter refinement."""
    children = defaultdict(list)
    parents = defaultdict(list)
    for e in edges:
        children[e["source"]].append(e["target"])
        parents[e["target"]].append(e["source"])

    all_names = [n["name"] for n in nodes]
    in_deg = {n: len(parents[n]) for n in all_names}
    layer = {n: 0 for n in all_names}

    queue = [n for n in all_names if in_deg[n] == 0]
    while queue:
        cur = queue.pop(0)
        for ch in children[cur]:
            layer[ch] = max(layer[ch], layer[cur] + 1)
            in_deg[ch] -= 1
            if in_deg[ch] == 0:
                queue.append(ch)

    # Pull headwater tributaries near their merge target
    for n in all_names:
        if not parents[n]:
            for ch in children[n]:
                if layer[ch] - layer[n] > 1:
                    layer[n] = layer[ch] - 1

    layers = defaultdict(list)
    for n, l in layer.items():
        layers[l].append(n)

    x = {}
    for l in sorted(layers):
        members = sorted(layers[l])
        nm = len(members)
        for i, name in enumerate(members):
            x[name] = (i - (nm - 1) / 2) * STYLE["col_gap"]

    gap = STYLE["col_gap"]
    for _ in range(6):
        for l in sorted(layers):
            for name in layers[l]:
                px = [x[p] for p in parents[name] if p in x]
                if px:
                    x[name] = np.mean(px)
            _separate(layers[l], x, gap)
        for l in sorted(layers, reverse=True):
            for name in layers[l]:
                cx = [x[c] for c in children[name] if c in x]
                if cx:
                    x[name] = np.mean(cx)
            _separate(layers[l], x, gap)

    return {n: (x[n], -layer[n] * STYLE["layer_gap"]) for n in x}


def _separate(members, x, gap):
    ordered = sorted(members, key=lambda n: x[n])
    for i in range(1, len(ordered)):
        if x[ordered[i]] - x[ordered[i - 1]] < gap:
            x[ordered[i]] = x[ordered[i - 1]] + gap


# ---------------------------------------------------------------------------
# 3. Geometry helpers
# ---------------------------------------------------------------------------
def _trap_width_at(frac):
    """Width of reservoir trapezoid at fractional height (0=bottom, 1=top)."""
    return STYLE["res_bw"] + (STYLE["res_tw"] - STYLE["res_bw"]) * frac


def _node_anchor(node, pos, side):
    """Return anchor point on node boundary. side='top' or 'bottom'."""
    x, y = pos
    if node["type"] == "reservoir":
        hh = STYLE["res_h"] / 2
        return (x, y + hh) if side == "top" else (x, y - hh)
    r = STYLE["conf_r"]
    return (x, y + r) if side == "top" else (x, y - r)


# ---------------------------------------------------------------------------
# 4. Drawing — edges
# ---------------------------------------------------------------------------
def _bezier_edge(ax, x0, y0, x1, y1, lw, color, dashes=None):
    """Smooth cubic Bezier from (x0,y0) to (x1,y1). Returns midpoint."""
    dy = abs(y1 - y0)
    tension = 0.45 * dy
    verts = [
        (x0, y0),
        (x0, y0 - tension),
        (x1, y1 + tension),
        (x1, y1),
    ]
    codes = [MplPath.MOVETO, MplPath.CURVE4, MplPath.CURVE4, MplPath.CURVE4]
    path = MplPath(verts, codes)

    kw = dict(facecolor="none", edgecolor=color, linewidth=lw,
              capstyle="round", joinstyle="round", zorder=1)
    if dashes:
        kw["linestyle"] = dashes
    from matplotlib.patches import PathPatch
    ax.add_patch(PathPatch(path, **kw))

    # Midpoint at t=0.5
    def _eval(t):
        m = 1 - t
        bx = m**3*verts[0][0] + 3*m**2*t*verts[1][0] + 3*m*t**2*verts[2][0] + t**3*verts[3][0]
        by = m**3*verts[0][1] + 3*m**2*t*verts[1][1] + 3*m*t**2*verts[2][1] + t**3*verts[3][1]
        return bx, by

    mx, my = _eval(0.50)

    # Arrowhead at t≈0.80
    ta = 0.80
    ax_pt, ay_pt = _eval(ta)
    ma = 1 - ta
    dtx = (3*ma**2*(verts[1][0]-verts[0][0]) + 6*ma*ta*(verts[2][0]-verts[1][0])
           + 3*ta**2*(verts[3][0]-verts[2][0]))
    dty = (3*ma**2*(verts[1][1]-verts[0][1]) + 6*ma*ta*(verts[2][1]-verts[1][1])
           + 3*ta**2*(verts[3][1]-verts[2][1]))
    mag = np.hypot(dtx, dty)
    if mag > 0:
        dtx, dty = dtx / mag, dty / mag
    px, py = -dty, dtx
    sz = 0.055
    tri = np.array([
        [ax_pt + dtx*sz*1.2, ay_pt + dty*sz*1.2],
        [ax_pt - dtx*sz*0.2 + px*sz*0.55, ay_pt - dty*sz*0.2 + py*sz*0.55],
        [ax_pt - dtx*sz*0.2 - px*sz*0.55, ay_pt - dty*sz*0.2 - py*sz*0.55],
    ])
    ax.add_patch(Polygon(tri, closed=True, facecolor=color, edgecolor="none", zorder=2))

    return mx, my


def _draw_edges(ax, edges, positions, node_lookup):
    """Draw all edges with Bezier curves and lag labels."""
    for e in edges:
        s_node = node_lookup[e["source"]]
        t_node = node_lookup[e["target"]]
        sx, sy = _node_anchor(s_node, positions[e["source"]], "bottom")
        tx, ty = _node_anchor(t_node, positions[e["target"]], "top")

        is_direct = e["lag"] == 0 and e["river"] is None
        color = STYLE["direct_color"] if is_direct else STYLE["edge_color"]
        dashes = (0, (3.5, 2)) if is_direct else None

        mx, my = _bezier_edge(ax, sx, sy, tx, ty,
                              STYLE["edge_lw"], color, dashes)

        label = "direct" if is_direct else f"{e['lag']} h"
        ax.text(
            mx + 0.15, my, label,
            ha="left", va="center",
            fontsize=STYLE["lag_fs"],
            color=STYLE["lag_text"],
            bbox=dict(facecolor="white", edgecolor="none", pad=1.2,
                      alpha=0.92),
            zorder=6,
        )


# ---------------------------------------------------------------------------
# 5. Drawing — nodes
# ---------------------------------------------------------------------------
def _draw_reservoir(ax, node, x, y):
    """Draw a trapezoid reservoir with optional water lines and HP bar."""
    tw, bw, h = STYLE["res_tw"], STYLE["res_bw"], STYLE["res_h"]
    half_h = h / 2

    # Trapezoid vertices (wider at top = open water surface)
    trap_verts = [
        (x - bw/2, y - half_h),   # bottom-left
        (x + bw/2, y - half_h),   # bottom-right
        (x + tw/2, y + half_h),   # top-right
        (x - tw/2, y + half_h),   # top-left
    ]
    trap = Polygon(trap_verts, closed=True,
                   facecolor=STYLE["res_fill"],
                   edgecolor=STYLE["res_stroke"],
                   linewidth=STYLE["node_lw"], zorder=10)
    ax.add_patch(trap)

    # Subtle water-surface ripple lines
    for frac in [0.72, 0.84]:
        wy = y - half_h + h * frac
        w_at_y = _trap_width_at(frac)
        inset = 0.08
        wxs = np.linspace(x - w_at_y/2 + inset, x + w_at_y/2 - inset, 60)
        wys = wy + 0.012 * np.sin(wxs * 28)
        ax.plot(wxs, wys, color=STYLE["water_line"],
                alpha=STYLE["water_line_alpha"], lw=0.5, zorder=11,
                solid_capstyle="round")

    # Hydropower indicator: amber bar at base of trapezoid
    if node.get("hydropower"):
        bh = STYLE["hp_bar_h"]
        frac_bar = bh / h
        w_bar_top = _trap_width_at(frac_bar)
        hp_verts = [
            (x - bw/2,        y - half_h),
            (x + bw/2,        y - half_h),
            (x + w_bar_top/2, y - half_h + bh),
            (x - w_bar_top/2, y - half_h + bh),
        ]
        hp_bar = Polygon(hp_verts, closed=True,
                         facecolor=STYLE["hp_color"],
                         edgecolor=STYLE["hp_stroke"],
                         linewidth=STYLE["hp_lw"], zorder=11)
        ax.add_patch(hp_bar)

    # Name
    ax.text(x, y + 0.06, node["name"],
            ha="center", va="center",
            fontsize=STYLE["name_fs"], fontweight="bold",
            color=STYLE["res_text"], zorder=12)

    # Capacity
    if node.get("capacity"):
        ax.text(x, y - 0.14,
                f"{node['capacity']:,} Mm\u00b3",
                ha="center", va="center",
                fontsize=STYLE["cap_fs"],
                color=STYLE["cap_text"], zorder=12)


def _draw_confluence(ax, node, x, y):
    """Draw a small confluence circle with italic label."""
    r = STYLE["conf_r"]
    circ = plt.Circle((x, y), r,
                       facecolor=STYLE["conf_fill"],
                       edgecolor=STYLE["conf_stroke"],
                       linewidth=STYLE["node_lw"], zorder=10)
    ax.add_patch(circ)
    ax.text(x + r + 0.10, y, node["name"],
            ha="left", va="center",
            fontsize=STYLE["name_fs"], fontstyle="italic",
            color=STYLE["conf_text"], zorder=11)


def _draw_nodes(ax, nodes, positions):
    """Draw all nodes."""
    for node in nodes:
        x, y = positions[node["name"]]
        if node["type"] == "reservoir":
            _draw_reservoir(ax, node, x, y)
        else:
            _draw_confluence(ax, node, x, y)


# ---------------------------------------------------------------------------
# 6. Legend
# ---------------------------------------------------------------------------
def _draw_legend(ax, has_hydropower):
    """Compact legend with proxy artists."""
    handles, labels = [], []

    # Reservoir
    handles.append(plt.Line2D(
        [0], [0], marker=(4, 0, 45), color="w",
        markerfacecolor=STYLE["res_fill"],
        markeredgecolor=STYLE["res_stroke"],
        markersize=8, markeredgewidth=0.5, linewidth=0))
    labels.append("Reservoir")

    # Hydropower
    if has_hydropower:
        handles.append(plt.Line2D(
            [0], [0], marker="s", color="w",
            markerfacecolor=STYLE["hp_color"],
            markeredgecolor=STYLE["hp_stroke"],
            markersize=5, markeredgewidth=0.4, linewidth=0))
        labels.append("Hydropower plant")

    # Confluence
    handles.append(plt.Line2D(
        [0], [0], marker="o", color="w",
        markerfacecolor=STYLE["conf_fill"],
        markeredgecolor=STYLE["conf_stroke"],
        markersize=5, markeredgewidth=0.5, linewidth=0))
    labels.append("Confluence")

    # River
    handles.append(plt.Line2D(
        [0], [0], color=STYLE["edge_color"], linewidth=STYLE["edge_lw"]))
    labels.append("River (lag in hours)")

    # Direct
    handles.append(plt.Line2D(
        [0], [0], color=STYLE["direct_color"],
        linewidth=STYLE["edge_lw"], linestyle="--"))
    labels.append("Direct connection")

    leg = ax.legend(
        handles, labels,
        loc="lower right", fontsize=STYLE["legend_fs"],
        frameon=True, framealpha=0.95,
        edgecolor="#D0D5DA", fancybox=False,
        handletextpad=0.5, borderpad=0.7,
        handlelength=1.3, labelspacing=0.4,
    )
    leg.get_frame().set_linewidth(0.4)


# ---------------------------------------------------------------------------
# 7. Orchestrator
# ---------------------------------------------------------------------------
def plot_cascade_network(config_path, output_path=None, title=None):
    """Load topology, compute layout, draw, and save."""
    config_path = Path(config_path)
    if output_path is None:
        output_path = config_path.parent / "cascade_network.pdf"
    else:
        output_path = Path(output_path)

    nodes, edges = load_cascade_topology(config_path)
    positions = compute_layout(nodes, edges)
    node_lookup = {n["name"]: n for n in nodes}
    has_hp = any(n.get("hydropower") for n in nodes)

    plt.rcParams.update({
        "font.family": "sans-serif",
        "font.sans-serif": ["Helvetica Neue", "Helvetica", "Arial",
                            "DejaVu Sans"],
        "font.size": 8,
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
        "figure.facecolor": "white",
        "savefig.facecolor": "white",
    })

    fig, ax = plt.subplots(figsize=(STYLE["fig_w"], STYLE["fig_h"]))

    _draw_edges(ax, edges, positions, node_lookup)
    _draw_nodes(ax, nodes, positions)
    _draw_legend(ax, has_hp)

    if title:
        ax.set_title(title, fontsize=10, fontweight="bold", pad=14)

    all_x = [p[0] for p in positions.values()]
    all_y = [p[1] for p in positions.values()]
    ax.set_xlim(min(all_x) - 1.8, max(all_x) + 1.8)
    ax.set_ylim(min(all_y) - 0.8, max(all_y) + 0.8)
    ax.set_aspect("equal")
    ax.axis("off")

    fig.tight_layout(pad=0.5)
    fig.savefig(output_path, bbox_inches="tight", dpi=STYLE["dpi"])
    print(f"Saved: {output_path}")

    if output_path.suffix in (".pdf", ".svg"):
        png_path = output_path.with_suffix(".png")
        fig.savefig(png_path, bbox_inches="tight", dpi=STYLE["dpi"])
        print(f"Saved: {png_path}")

    plt.close(fig)


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate a publication-quality cascade network diagram."
    )
    parser.add_argument("config", help="Path to cascade_config.yaml")
    parser.add_argument("-o", "--output", default=None,
                        help="Output file (pdf/svg/png).")
    parser.add_argument("--title", default=None, help="Optional figure title.")
    args = parser.parse_args()
    plot_cascade_network(args.config, args.output, args.title)