Commit eeb7dd4b authored by Turner, Sean's avatar Turner, Sean
Browse files

Fix HPF interpolation bugs and add comprehensive test suite



- Fix corner lookup bug in compute_release_to_meet_target_power: changed
  else-if chain to separate if statements so multiple corners can be set
  from the same table entry when h_lo==h_hi or p_lo==p_hi

- Fix reverse function (compute_power_given_actual_release) to use
  consistent cell selection with forward function, ensuring P→Q→P'
  roundtrip consistency

- Add PyO3 exports for compute_release and compute_power functions
  to enable direct testing from Python

- Add comprehensive HPF test suite (tests/test_hpf.py) with 74 tests:
  - Grid corner, edge, and interior roundtrip tests
  - Statistical robustness tests with random (H,P) pairs
  - Performance benchmarks
  - Edge cases and diagnostics

- Update CLAUDE.md with testing instructions

- Fix Cargo.toml edition (2025 -> 2021) for compatibility

- Add pytest as dev dependency

Co-Authored-By: default avatarClaude Opus 4.5 <noreply@anthropic.com>
parent dfa116cc
Loading
Loading
Loading
Loading
+65 −0
Original line number Diff line number Diff line
@@ -66,6 +66,71 @@ reticulate::use_virtualenv(".venv", required = TRUE)
# quarto render examples/Cumberland/script.qmd
```

## Testing

After any major code changes to `src/helpers.rs` or `src/lib.rs`, run the test suite:

```bash
# Rebuild the module first
uv pip install -e .

# Run the HPF interpolation tests
uv run python -m pytest tests/test_hpf.py -v

# Quick smoke test with the toy example
uv run python -c "
import powersheds
from dataclasses import dataclass

@dataclass
class ReservoirData:
    object_type: str
    capacity: float
    initial_pool_elevation: float
    min_power_pool: float
    set_storage: list
    set_elevation: list
    hpf_h: list
    hpf_p: list
    hpf_q: list
    tailwater_elevation: float
    max_release: float
    min_release: float
    catchment_inflow: list
    target_power: list
    simulation_order: int
    downstream_object: str

@dataclass
class CascadeData:
    reservoirs: dict
    rivers: dict
    confluences: dict

res = ReservoirData(
    object_type='reservoir', simulation_order=1, downstream_object='NA',
    capacity=500.0, initial_pool_elevation=200.0, min_power_pool=190.0,
    tailwater_elevation=150.0, max_release=2.0, min_release=0.0,
    set_storage=[0.0, 250.0, 500.0], set_elevation=[180.0, 195.0, 210.0],
    catchment_inflow=[0.2]*24, target_power=[50.0]*24,
    hpf_h=[40,40,40,60,60,60], hpf_p=[0,50,100,0,50,100], hpf_q=[0,60,120,0,45,90])
result = powersheds.simulate_cascade(CascadeData(reservoirs={'Demo': res}, rivers={}, confluences={}))
import math
assert not any(math.isnan(v) for v in result['Demo']['actual_power']), 'NA values found!'
print('Smoke test passed')
"
```

**Test Coverage**:
- `tests/test_hpf.py` - HPF bilinear interpolation tests (74 tests)
  - Grid corner, edge, and interior roundtrip tests
  - Statistical robustness with random (H, P) pairs
  - Performance benchmarks
  - Edge cases (zero power, empty tables, etc.)

**Known Limitations**:
- Some HPF tables have flat regions where multiple P values produce the same Q. In these regions, the reverse function (Q→P) cannot uniquely determine P, resulting in roundtrip errors up to ~1 MW. This is a physical limitation of the data, not a bug.

## Dependencies

- **Build**: maturin 1.9.4 (Rust-to-Python bridge)
+1 −1
Original line number Diff line number Diff line
[package]
name = "powersheds"
version = "0.1.0"
edition = "2025"
edition = "2021"
license = "BSD-3-Clause"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+5 −0
Original line number Diff line number Diff line
@@ -25,3 +25,8 @@ dependencies = [

[tool.maturin]
features = ["pyo3/extension-module"]

[dependency-groups]
dev = [
    "pytest>=9.0.2",
]
+181 −86
Original line number Diff line number Diff line
@@ -176,6 +176,8 @@ pub fn compute_release_to_meet_target_power(

    // 3) Fetch Q at the four corners by scanning the flat table.
    //    (H, P) pairs are expected to appear exactly in the table.
    //    NOTE: Use separate if statements (not else if) because when h_lo==h_hi
    //    or p_lo==p_hi, the same table entry may match multiple corners.
    let mut q11 = f64::NAN; // (h_lo, p_lo)
    let mut q21 = f64::NAN; // (h_hi, p_lo)
    let mut q12 = f64::NAN; // (h_lo, p_hi)
@@ -184,14 +186,18 @@ pub fn compute_release_to_meet_target_power(
    for i in 0..hpf_head_m.len() {
        let h = hpf_head_m[i];
        let p = hpf_power_mw[i];
        let q = hpf_flow_cumecs[i];
        if approx_eq(h, h_lo) && approx_eq(p, p_lo) {
            q11 = hpf_flow_cumecs[i];
        } else if approx_eq(h, h_hi) && approx_eq(p, p_lo) {
            q21 = hpf_flow_cumecs[i];
        } else if approx_eq(h, h_lo) && approx_eq(p, p_hi) {
            q12 = hpf_flow_cumecs[i];
        } else if approx_eq(h, h_hi) && approx_eq(p, p_hi) {
            q22 = hpf_flow_cumecs[i];
            q11 = q;
        }
        if approx_eq(h, h_hi) && approx_eq(p, p_lo) {
            q21 = q;
        }
        if approx_eq(h, h_lo) && approx_eq(p, p_hi) {
            q12 = q;
        }
        if approx_eq(h, h_hi) && approx_eq(p, p_hi) {
            q22 = q;
        }
    }

@@ -246,7 +252,8 @@ pub fn compute_release_to_meet_target_power(
///
/// Notes:
/// - Converts Mm3 over 1 hour to m3/s using `MM3_IN_M3` and `SECONDS_IN_HOUR`.
/// - Style mirrors `compute_release_to_meet_target_power`.
/// - Uses consistent cell selection with compute_release_to_meet_target_power to ensure
///   roundtrip consistency (P → Q → P' yields P' ≈ P).
pub fn compute_power_given_actual_release(
    release_mm3: f64,
    head_m: f64,
@@ -265,110 +272,198 @@ debug_assert_eq!(hpf_head_m.len(), hpf_power_mw.len());

    #[inline]
    fn approx_eq(a: f64, b: f64) -> bool {
        let eps = 1e-8; // slightly loose for CSV/Parquet float wiggles
        let eps = 1e-9;
        (a - b).abs() <= eps * (1.0 + a.abs().max(b.abs()))
    }

    // 1) Unique, sorted H axis; pick H_lo/H_hi around H0
    #[inline]
    fn lerp(a: f64, b: f64, t: f64) -> f64 {
        a + t * (b - a)
    }

    // 1) Build unique sorted H and P axes
    let mut unique_h = hpf_head_m.to_vec();
    unique_h.sort_by(|a, b| a.partial_cmp(b).unwrap());
    unique_h.dedup_by(|a, b| approx_eq(*a, *b));
    if unique_h.is_empty() {

    let mut unique_p = hpf_power_mw.to_vec();
    unique_p.sort_by(|a, b| a.partial_cmp(b).unwrap());
    unique_p.dedup_by(|a, b| approx_eq(*a, *b));

    if unique_h.is_empty() || unique_p.is_empty() {
        return f64::NAN;
    }
    let h0 = head_m.max(unique_h[0]).min(*unique_h.last().unwrap());
    let (hi_lo, hi_hi) = match unique_h.binary_search_by(|v| v.partial_cmp(&h0).unwrap()) {

    // 2) Clamp H0 and find H bracket
    let clamp = |x: f64, lo: f64, hi: f64| x.max(lo).min(hi);
    let h0 = clamp(head_m, unique_h[0], *unique_h.last().unwrap());

    let locate = |grid: &Vec<f64>, x: f64| -> (usize, usize) {
        match grid.binary_search_by(|v| v.partial_cmp(&x).unwrap()) {
            Ok(i) => (i, i),
            Err(i) => {
                if i == 0 { (0, 0) }
            else if i >= unique_h.len() { let j = unique_h.len() - 1; (j, j) }
                else if i >= grid.len() { let j = grid.len() - 1; (j, j) }
                else { (i - 1, i) }
            }
        }
    };

    let (hi_lo, hi_hi) = locate(&unique_h, h0);
    let h_lo = unique_h[hi_lo];
    let h_hi = unique_h[hi_hi];

    // 2) From rows H∈{H_lo,H_hi}, keep four points with smallest |Q−Q0|
    #[derive(Clone, Copy)]
    struct Row { p: f64, q: f64, err: f64 }
    let mut best: [Option<Row>; 4] = [None, None, None, None];
    // Compute H interpolation factor
    let t = if approx_eq(h_lo, h_hi) { 0.0 } else { (h0 - h_lo) / (h_hi - h_lo) };

    #[inline]
    fn insert_best(best: &mut [Option<Row>; 4], cand: Row) {
        for slot in best.iter_mut() {
            if slot.is_none() { *slot = Some(cand); return; }
        }
        let mut worst_i = 0usize;
        let mut worst_err = -f64::INFINITY;
        for i in 0..4 {
            let e = best[i].unwrap().err;
            if e > worst_err { worst_err = e; worst_i = i; }
        }
        if cand.err < worst_err { best[worst_i] = Some(cand); }
    }
    // 3) Build a single-pass lookup: collect Q at (h_lo, P) and (h_hi, P) for all P
    //    Use a hashmap-like approach with indices into unique_p
    let n_p = unique_p.len();
    let mut q_at_h_lo: Vec<f64> = vec![f64::NAN; n_p];
    let mut q_at_h_hi: Vec<f64> = vec![f64::NAN; n_p];

    for i in 0..hpf_head_m.len() {
        let h = hpf_head_m[i];
        if approx_eq(h, h_lo) || approx_eq(h, h_hi) {
            let q = hpf_flow_cumecs[i];
            if q.is_finite() {
                let p = hpf_power_mw[i];
                insert_best(&mut best, Row { p, q, err: (q - q0).abs() });
        let pv = hpf_power_mw[i];
        let qv = hpf_flow_cumecs[i];

        if !qv.is_finite() {
            continue;
        }

        // Find P index using binary search
        if let Ok(p_idx) = unique_p.binary_search_by(|v| {
            if approx_eq(*v, pv) { std::cmp::Ordering::Equal }
            else { v.partial_cmp(&pv).unwrap() }
        }) {
            if approx_eq(h, h_lo) {
                q_at_h_lo[p_idx] = qv;
            }
            if approx_eq(h, h_hi) {
                q_at_h_hi[p_idx] = qv;
            }
        }
    let mut pts: Vec<Row> = best.iter().flatten().cloned().collect();
    if pts.len() < 4 { return f64::NAN; }
    }

    // 4) Compute Q at (h0, P) for each P and find the bracketing interval
    let mut p_lo_idx: Option<usize> = None;
    let mut prev_q_h0 = f64::NAN;

    for p_idx in 0..n_p {
        let q_lo = q_at_h_lo[p_idx];
        let q_hi = q_at_h_hi[p_idx];

        if !q_lo.is_finite() || !q_hi.is_finite() {
            continue;
        }

        let q_h0 = lerp(q_lo, q_hi, t);

    // 3) Identify the two P columns (P_lo, P_hi) from those four points
    pts.sort_by(|a, b| a.p.partial_cmp(&b.p).unwrap());
    let mut p_cols: Vec<f64> = Vec::with_capacity(2);
    for r in &pts {
        if p_cols.is_empty() || !approx_eq(*p_cols.last().unwrap(), r.p) {
            p_cols.push(r.p);
            if p_cols.len() == 2 { break; }
        // Check if Q0 is between prev_q_h0 and q_h0
        if prev_q_h0.is_finite() {
            if (prev_q_h0 <= q0 && q0 <= q_h0) || (q_h0 <= q0 && q0 <= prev_q_h0) {
                // Find the previous valid P index
                for prev_idx in (0..p_idx).rev() {
                    if q_at_h_lo[prev_idx].is_finite() && q_at_h_hi[prev_idx].is_finite() {
                        p_lo_idx = Some(prev_idx);
                        break;
                    }
                }
                if p_lo_idx.is_some() {
                    break;
                }
            }
        }
    if p_cols.len() < 2 { return f64::NAN; }
    let p_lo = p_cols[0];
    let p_hi = p_cols[1];

    // 4) Fetch the exact four corners Q_ij at (H_lo/H_hi × P_lo/P_hi)
    let (mut q11, mut q12, mut q21, mut q22) = (f64::NAN, f64::NAN, f64::NAN, f64::NAN);
    for i in 0..hpf_head_m.len() {
        let h = hpf_head_m[i];
        let p = hpf_power_mw[i];
        let q = hpf_flow_cumecs[i];
        if      approx_eq(h, h_lo) && approx_eq(p, p_lo) { q11 = q; }
        else if approx_eq(h, h_lo) && approx_eq(p, p_hi) { q12 = q; }
        else if approx_eq(h, h_hi) && approx_eq(p, p_lo) { q21 = q; }
        else if approx_eq(h, h_hi) && approx_eq(p, p_hi) { q22 = q; }
        prev_q_h0 = q_h0;
    }

    // Handle edge cases
    let (p_lo, p_hi, q11, q12, q21, q22);

    if let Some(lo_idx) = p_lo_idx {
        // Find the next valid P index after lo_idx
        let mut hi_idx = lo_idx + 1;
        while hi_idx < n_p && (!q_at_h_lo[hi_idx].is_finite() || !q_at_h_hi[hi_idx].is_finite()) {
            hi_idx += 1;
        }
    if !(q11.is_finite() && q12.is_finite() && q21.is_finite() && q22.is_finite()) {
        if hi_idx >= n_p {
            return f64::NAN;
        }

    // 5) Solve bilinear coefficients on this cell:
    //    Q(H,P) = a + b H + c P + d H P
    let d_h = h_hi - h_lo;
    let d_p = p_hi - p_lo;
    if approx_eq(d_h, 0.0) || approx_eq(d_p, 0.0) {
        return p_lo; // degenerate cell; fall back
        p_lo = unique_p[lo_idx];
        p_hi = unique_p[hi_idx];
        q11 = q_at_h_lo[lo_idx];
        q12 = q_at_h_lo[hi_idx];
        q21 = q_at_h_hi[lo_idx];
        q22 = q_at_h_hi[hi_idx];
    } else {
        // Q0 is outside the range - clamp to edge
        // Find first and last valid P indices
        let mut first_valid: Option<usize> = None;
        let mut last_valid: Option<usize> = None;
        let mut first_q_h0 = f64::NAN;
        let mut last_q_h0 = f64::NAN;

        for p_idx in 0..n_p {
            let q_lo = q_at_h_lo[p_idx];
            let q_hi = q_at_h_hi[p_idx];
            if q_lo.is_finite() && q_hi.is_finite() {
                let q_h0 = lerp(q_lo, q_hi, t);
                if first_valid.is_none() {
                    first_valid = Some(p_idx);
                    first_q_h0 = q_h0;
                }
                last_valid = Some(p_idx);
                last_q_h0 = q_h0;
            }
        }

        let (first_idx, last_idx) = match (first_valid, last_valid) {
            (Some(f), Some(l)) if f < l => (f, l),
            _ => return f64::NAN,
        };

        // Determine which edge to use based on Q0
        if q0 <= first_q_h0 {
            return unique_p[first_idx];
        } else if q0 >= last_q_h0 {
            return unique_p[last_idx];
        }

    let d = (q22 - q21 - q12 + q11) / (d_h * d_p);
    let b = (q21 - q11) / d_h - d * p_lo;
    let c = (q12 - q11) / d_p - d * h_lo;
    let a = q11 - b * h_lo - c * p_lo - d * h_lo * p_lo;
        // Fallback: use first two valid columns
        let mut hi_idx = first_idx + 1;
        while hi_idx < n_p && (!q_at_h_lo[hi_idx].is_finite() || !q_at_h_hi[hi_idx].is_finite()) {
            hi_idx += 1;
        }
        if hi_idx >= n_p {
            return f64::NAN;
        }

    // 6) Invert at H0: Q0 = (a + b H0) + (c + d H0) * P  =>  P0 = (Q0 - a - b H0)/(c + d H0)
    let denom = c + d * h0;
    if approx_eq(denom, 0.0) {
        return p_lo; // ill-conditioned; pick lower bound
        p_lo = unique_p[first_idx];
        p_hi = unique_p[hi_idx];
        q11 = q_at_h_lo[first_idx];
        q12 = q_at_h_lo[hi_idx];
        q21 = q_at_h_hi[first_idx];
        q22 = q_at_h_hi[hi_idx];
    }
    let p0 = (q0 - a - b * h0) / denom;

    p0
    // 5) Bilinear inversion: solve for P given Q at (h0, q0)
    if approx_eq(p_lo, p_hi) {
        return p_lo;
    }

    // Q = (1-t)(1-u)*Q11 + t*(1-u)*Q21 + (1-t)*u*Q12 + t*u*Q22
    // At fixed H=h0: Q = A + u*B
    // where A = lerp(Q11, Q21, t), B = lerp(Q12-Q11, Q22-Q21, t)
    let a = lerp(q11, q21, t);
    let b = lerp(q12 - q11, q22 - q21, t);

    if approx_eq(b, 0.0) {
        return p_lo;
    }

    let u = (q0 - a) / b;
    lerp(p_lo, p_hi, u)
}
+62 −0
Original line number Diff line number Diff line
@@ -389,9 +389,71 @@ fn simulate_cascade(
    Ok(results)
}

/// Compute volumetric release (Mm³/hr) required to meet target power (MW).
///
/// Uses bilinear interpolation on a regular (H, P) grid with Q values.
///
/// Args:
///     power_mw: Target power in MW
///     head_m: Head in meters
///     hpf_head_m: Vector of head values from HPF table
///     hpf_power_mw: Vector of power values from HPF table
///     hpf_flow_cumecs: Vector of flow values (m³/s) from HPF table
///
/// Returns:
///     Volumetric release in Mm³/hr
#[pyfunction]
fn compute_release(
    power_mw: f64,
    head_m: f64,
    hpf_head_m: Vec<f64>,
    hpf_power_mw: Vec<f64>,
    hpf_flow_cumecs: Vec<f64>,
) -> f64 {
    compute_release_to_meet_target_power(
        power_mw,
        head_m,
        &hpf_head_m,
        &hpf_power_mw,
        &hpf_flow_cumecs,
    )
}

/// Compute power (MW) given actual release (Mm³/hr) and head (m).
///
/// Uses bilinear interpolation on a regular (H, P) grid with Q values.
///
/// Args:
///     release_mm3: Actual turbine release in Mm³/hr
///     head_m: Gross head in meters
///     hpf_head_m: Vector of head values from HPF table
///     hpf_power_mw: Vector of power values from HPF table
///     hpf_flow_cumecs: Vector of flow values (m³/s) from HPF table
///
/// Returns:
///     Power in MW
#[pyfunction]
fn compute_power(
    release_mm3: f64,
    head_m: f64,
    hpf_head_m: Vec<f64>,
    hpf_power_mw: Vec<f64>,
    hpf_flow_cumecs: Vec<f64>,
) -> f64 {
    compute_power_given_actual_release(
        release_mm3,
        head_m,
        &hpf_head_m,
        &hpf_power_mw,
        &hpf_flow_cumecs,
    )
}

/// A Python module implemented in Rust.
#[pymodule]
fn powersheds(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(simulate_cascade, m)?)?;
    m.add_function(wrap_pyfunction!(compute_release, m)?)?;
    m.add_function(wrap_pyfunction!(compute_power, m)?)?;
    Ok(())
}
Loading