Commit 939c3ebe authored by Turner, Sean's avatar Turner, Sean
Browse files

enforce simulation orders

parent faf965e2
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -169,7 +169,7 @@ CenterHill_OldHickory:
# Old Hickory to Cheatham
OldHickory_Cheatham:
  object_type: river
  simulation_order: 13
  simulation_order: 12
  downstream_object: Cheatham
  lag: 5

+77 −3
Original line number Diff line number Diff line
@@ -165,6 +165,28 @@ fn validate_reservoir_data(name: &str, reservoir: &ReservoirData, n: usize) -> P
    Ok(())
}

fn register_object(
    name: &str,
    simulation_order: i32,
    object_orders: &mut HashMap<String, i32>,
) -> PyResult<()> {
    if object_orders.contains_key(name) {
        return Err(PyValueError::new_err(format!(
            "Object name '{}' is duplicated across cascade objects; names must be globally unique",
            name
        )));
    }

    if object_orders.values().any(|&existing_order| existing_order == simulation_order) {
        return Err(PyValueError::new_err(format!(
            "simulation_order {} is duplicated; simulation_order values must be globally unique and sequential from 1..N",
            simulation_order
        )));
    }

    Ok(())
}

// Define structure for reservoir state
struct ReservoirState {
    storage: f64,
@@ -488,9 +510,12 @@ fn simulate_cascade(
    cascade_data: CascadeData,
) -> PyResult<HashMap<String, CascadeResults>> {

    let first_reservoir = cascade_data.reservoirs.values().next()
    cascade_data.reservoirs.values().next()
        .ok_or_else(|| PyValueError::new_err("No reservoirs provided"))?;
    let n = first_reservoir.catchment_inflow.len();
    let n = cascade_data.reservoirs.values()
        .map(|reservoir| reservoir.catchment_inflow.len())
        .max()
        .unwrap_or(0);

    // Validate reservoir input shapes before allocating outputs or entering
    // the timestep loop so malformed user inputs raise Python errors instead
@@ -502,22 +527,71 @@ fn simulate_cascade(
    let mut results = HashMap::new();

    let mut ordered_objects = Vec::new();
    let mut object_orders: HashMap<String, i32> = HashMap::new();
    for (name, obj) in &cascade_data.reservoirs {
        let obj_type = obj.object_type.parse::<ObjectType>()
            .map_err(|e| PyValueError::new_err(e))?;
        register_object(name, obj.simulation_order, &mut object_orders)?;
        object_orders.insert(name.to_string(), obj.simulation_order);
        ordered_objects.push((name, obj.simulation_order, obj_type));
    }
    for (name, obj) in &cascade_data.rivers {
        let obj_type = obj.object_type.parse::<ObjectType>()
            .map_err(|e| PyValueError::new_err(e))?;
        register_object(name, obj.simulation_order, &mut object_orders)?;
        object_orders.insert(name.to_string(), obj.simulation_order);
        ordered_objects.push((name, obj.simulation_order, obj_type));
    }
    for (name, obj) in &cascade_data.confluences {
        let obj_type = obj.object_type.parse::<ObjectType>()
            .map_err(|e| PyValueError::new_err(e))?;
        register_object(name, obj.simulation_order, &mut object_orders)?;
        object_orders.insert(name.to_string(), obj.simulation_order);
        ordered_objects.push((name, obj.simulation_order, obj_type));
    }
    ordered_objects.sort_by_key(|&(_, order, _)| order);

    let total_objects = ordered_objects.len();
    for expected_order in 1..=(total_objects as i32) {
        if !object_orders.values().any(|&order| order == expected_order) {
            return Err(PyValueError::new_err(format!(
                "simulation_order values must be sequential integers from 1..{}; missing {}",
                total_objects, expected_order
            )));
        }
    }

    for (upstream_name, upstream_order, upstream_type) in &ordered_objects {
        let downstream_name = match upstream_type {
            ObjectType::Reservoir => &cascade_data.reservoirs.get(*upstream_name)
                .unwrap().downstream_object,
            ObjectType::River => &cascade_data.rivers.get(*upstream_name)
                .unwrap().downstream_object,
            ObjectType::Confluence => &cascade_data.confluences.get(*upstream_name)
                .unwrap().downstream_object,
        };

        if downstream_name == "NA" {
            continue;
        }

        let downstream_order = object_orders.get(downstream_name).ok_or_else(|| {
            PyValueError::new_err(format!(
                "Object '{}' references unknown downstream_object '{}'",
                upstream_name, downstream_name
            ))
        })?;

        if downstream_order <= upstream_order {
            return Err(PyValueError::new_err(format!(
                "Object '{}' (simulation_order={}) must have downstream_object '{}' with a greater simulation_order; found {}",
                upstream_name, upstream_order, downstream_name, downstream_order
            )));
        }
    }

    ordered_objects.sort_by(|(name_a, order_a, _), (name_b, order_b, _)| {
        order_a.cmp(order_b).then_with(|| name_a.cmp(name_b))
    });

    for (name, _, obj_type) in &ordered_objects {
        match obj_type {
+65 −0
Original line number Diff line number Diff line
@@ -139,6 +139,71 @@ class TestInputValidation:
            powersheds.simulate_cascade(cascade)


class TestTopologyValidation:
    """Tests that cascade topology errors fail fast with clear messages."""

    def test_unknown_downstream_object_raises(self):
        res = make_reservoir(n_hours=1, downstream_object="MissingNode")
        cascade = make_cascade(reservoirs={"Res": res})
        with pytest.raises(Exception, match="unknown downstream_object 'MissingNode'"):
            powersheds.simulate_cascade(cascade)

    def test_duplicate_simulation_order_raises(self):
        res_a = make_reservoir(n_hours=1, simulation_order=1, downstream_object="NA")
        res_b = make_reservoir(
            n_hours=1,
            simulation_order=1,
            downstream_object="NA",
            catchment_inflow=[0.0],
            target_power=[0.0],
        )
        cascade = make_cascade(reservoirs={"ResA": res_a, "ResB": res_b})
        with pytest.raises(Exception, match="simulation_order 1 is duplicated"):
            powersheds.simulate_cascade(cascade)

    def test_simulation_order_must_be_sequential_without_gaps(self):
        res_a = make_reservoir(n_hours=1, simulation_order=1, downstream_object="NA")
        res_b = make_reservoir(
            n_hours=1,
            simulation_order=3,
            downstream_object="NA",
            catchment_inflow=[0.0],
            target_power=[0.0],
        )
        cascade = make_cascade(reservoirs={"ResA": res_a, "ResB": res_b})
        with pytest.raises(Exception, match="must be sequential integers from 1..2; missing 2"):
            powersheds.simulate_cascade(cascade)

    def test_downstream_must_have_greater_simulation_order(self):
        res_a = make_reservoir(n_hours=1, simulation_order=2, downstream_object="ResB")
        res_b = make_reservoir(
            n_hours=1,
            simulation_order=1,
            downstream_object="NA",
            catchment_inflow=[0.0],
            target_power=[0.0],
        )
        cascade = make_cascade(reservoirs={"ResA": res_a, "ResB": res_b})
        with pytest.raises(Exception, match="must have downstream_object 'ResB' with a greater simulation_order"):
            powersheds.simulate_cascade(cascade)

    def test_duplicate_object_names_across_types_raise(self):
        res = make_reservoir(n_hours=1, downstream_object="NA")
        riv = RiverData(
            object_type="river",
            simulation_order=2,
            downstream_object="NA",
            lag=0,
            legacy_flows=[],
        )
        cascade = make_cascade(
            reservoirs={"SharedName": res},
            rivers={"SharedName": riv},
        )
        with pytest.raises(Exception, match="Object name 'SharedName' is duplicated"):
            powersheds.simulate_cascade(cascade)


class TestTwoReservoirChain:
    """Tests for upstream-downstream routing through a river."""