Commit 35b52589 authored by Blais, Chris's avatar Blais, Chris
Browse files

separate module for constraints

parent 95f04c4b
Loading
Loading
Loading
Loading
+295 −0
Original line number Diff line number Diff line
import sympy as sy
import numpy as np
import os
import sys
import scipy as sp
import copy
from PIL import Image
from constraint import Problem
    
from itertools import product

from sympy.core.numbers import int_valued

module_path = os.path.join(os.path.dirname(os.path.abspath('')))
if module_path not in sys.path:
    sys.path.insert(0, module_path)
else:
    print("path already in sys.path")

import matplotlib.pyplot as plt
from numerical_labelling.labelling import *


from math import gcd
class labelConstraint(sy.core.add.Add):
    """
    making a class for constraint, only additional
    field is the type
    """
    def __new__(cls, *args, type=None, **kwargs):
        obj = super().__new__(cls, *args, **kwargs)
        obj._type = type
        return obj

    @property
    def type(self):
        return getattr(self, '_type', None)
    
    @type.setter
    def type(self, value):
        self._type = value

    @property
    def add_expr(self):
        return getattr(self, '_add_expr', None)
    
    @add_expr.setter
    def add_expr(self, value):
        self._add_expr = value
    
    @classmethod
    def from_add(cls, add_expr, type=None):
        # Extract args from the existing Add expression
        new_obj = cls(*add_expr.args, type=type)
        new_obj.add_expr = add_expr
        return new_obj
    

def lcm(a, b):
    return abs(a*b) // gcd(a, b)

def get_consts(cfa, cfb):
    lcm_value = lcm(cfa, cfb)
    const_1 = lcm_value / cfa
    const_2 = lcm_value / cfb

    return const_1, const_2


def get_varbs_in_eq(equation_list, all_constr):
    """
    equation_list: list of ids in all constr you want to use
    """
    varb_list = []
    for iconstr in equation_list:
        constr = all_constr[iconstr]
        symbs = [symbol.name for symbol in list(constr.free_symbols)]
        for symb in symbs:
            if symb not in varb_list:
                varb_list.append(symb)
    varb_list = sorted(varb_list)
    return varb_list


def create_constraint_function(equation, condition_func=None, debug=False):
    """
    from gitlab duo, but it is pretty straightforward function factory
    Dynamically create a constraint function from a sympy equation
    
    Args:
        equation: sympy expression
        condition_func: function that takes the equation result and returns bool
                       defaults to checking if result == 0
    """
    # Get free variables and sort them for consistent ordering
    free_vars = sorted(equation.free_symbols, key=str)
    var_names = [str(var) for var in free_vars]
    
    if condition_func is None:
        condition_func = lambda result: abs(float(result)) <= 1e-10  # near zero
    
    def constraint_func(*values):
        if len(values) != len(free_vars):
            raise ValueError(f"Expected {len(free_vars)} values, got {len(values)}")
        
        substitution_dict = dict(zip(free_vars, values))
        try:
            result = equation.subs(substitution_dict)
            if debug:
                print(substitution_dict, result, condition_func(result))
            return condition_func(result)
        except Exception as e:
            print("Exception occurred:", e)

            return False
    
    # Store metadata for debugging
    constraint_func.equation = equation
    constraint_func.variables = var_names
    constraint_func.free_symbols = free_vars
    
    return constraint_func, var_names


def add_sltn(sltn_dict, varb_name, constr):
    """add a solved equation to sltn dict 
    example: 
    varb_name = x3
    constr = m1x1 + m2x2 - m3x3
    sltn_dict.update({'x3': (m1x1 + m2x2)/m3})
    """
    sltn = sy.solve(constr, varb_name)
    if len(sltn) !=1: 
        raise ValueError(f"multiple solutions when trying to solve for {varb_name}")
    
    if varb_name in sltn_dict.keys():
        print(f"variable {varb_name} in eq {constr} already solved for in eq {sltn_dict[varb_name]}")
        # raise KeyError(f"{varb_name} already solved for")
    sltn_dict.update({varb_name:sltn[0]})

    return sltn_dict


def add_constraint(constr_dict, eqn, type=None):
    """
    add constraint to constraint dict
    """
    indx = len(constr_dict)
    constr_dict[indx] = labelConstraint.from_add(eqn, type=type)
    return constr_dict

def constr_split(
    min, mout1, mout2, 
    xin, xout1, xout2,
    pin, pout1, pout2, 
    constr, sltns=None,
    ):
    """
    constraint for splitter unit
    """
    constr_hydraulic = min - mout1 - mout2
    constr_eq_conc1 = xin - xout1
    constr_eq_conc2 = xin - xout2
    constr_press1 = pin - pout1
    constr_press2 = pin - pout2
    constr = add_constraint(constr, constr_hydraulic, type="lin_hydraulic")
    constr = add_constraint(constr, constr_eq_conc1, type="bl_component")
    constr = add_constraint(constr, constr_eq_conc2, type="bl_component")
    constr = add_constraint(constr, constr_press1, type="nl_pressure")
    constr = add_constraint(constr, constr_press2, type="nl_pressure")
    if sltns is not None:
        sltns = add_sltn(sltns, xout1.name, constr_eq_conc1)
        sltns = add_sltn(sltns, xout2.name, constr_eq_conc2)
        sltns = add_sltn(sltns, pout1.name, constr_press1)
        sltns = add_sltn(sltns, pout2.name, constr_press2)
        return constr, sltns
    return constr

def constr_mix(
        min1, min2, mout, 
        xin1, xin2, xout,
        pin1, pin2, pout, 
        constr, sltns=None
        ):
    """ 
    constraint for 2 pipes coming together
    """
    constr_hydraulic = min1 + min2 - mout
    constr_component = min1*xin1 + min2*xin2 - mout*xout
    constr_press1 = pin1 - pout
    constr_press2 = pin2 - pout
    constr = add_constraint(constr, constr_hydraulic, type="lin_hydraulic")
    constr = add_constraint(constr, constr_component, type="bl_component")
    constr = add_constraint(constr, constr_press1, type="nl_pressure")
    constr = add_constraint(constr, constr_press2, type="nl_pressure")

    if sltns is not None:
        sltns = add_sltn(sltns, xout.name, constr_component)
        sltns = add_sltn(sltns, pout.name, constr_press1)
        sltns = add_sltn(sltns, pin2.name, constr_press2)
        return constr, sltns
    return constr


def constr_mex(
        mfeed, mperm, mret, 
        xfeed,xperm, xret, 
        pfeed, pperm, pret,
        a, ap, app, 
        b, bp, 
        c, cp,
        constr, sltns = None
    ):
    """
    a, ap, app are A, A', and A'' for the osmotic pressure constraint
    b, bp are B, B' for the diffusion conc constraint
    c is constant in major loss constraint. don't multiply by masses b/c they 
    are already squared
    """

    constr_hyd = mfeed - mperm  - mret
    constr_component = xfeed*mfeed -xperm*mperm -xret*mret
    constr_osm_press = a*mperm - ap*((pfeed - pperm) - app*(xfeed-xperm))
    constr_diff_conc = b*mperm*xperm - bp*(xfeed-xperm) # xperm may be very small, but for integer example maybe it's better to keep
    constr_major_loss = c*(pfeed - pret) - cp*(mfeed+mret)**2 # major loss from feed to concentrate side. 
    constr = add_constraint(constr, constr_hyd, type="lin_hydraulic")
    constr = add_constraint(constr, constr_component, type="bl_component")
    constr = add_constraint(constr, constr_osm_press, type="nl_pressure")
    constr = add_constraint(constr, constr_diff_conc, type="nl_component")
    constr = add_constraint(constr, constr_major_loss, type="nl_pressure")

    # return solutions for the membrane constraint equations
    if sltns is not None:
        sltns = add_sltn(sltns, xperm.name, constr_diff_conc)
        sltns = add_sltn(sltns, pperm.name, constr_osm_press)
        sltns = add_sltn(sltns, pret.name, constr_major_loss)
        return constr, sltns
    
    return constr


def constr_pump(
        min, mout,
        xin, xout,
        pin, pout,
        E, H, G, J,
        constr, sltns=None,
    ):
    """
    E: second order coefficient (Ex^2)
    H: first order coefficient (Hx)
    G: intercept for pump curve
    J: unit conversion for pressure (arbitrary)
    """
    constr_mass = mout - min
    constr_conc = xout - xin
    constr_pump = J*(pout - pin) - (E*min**2 + H*min + G)
    constr = add_constraint(constr, constr_mass, type="lin_hydraulic")
    constr = add_constraint(constr, constr_conc, type="bl_component")
    constr = add_constraint(constr, constr_pump, type="nl_pressure")

    if sltns is not None:
        sltns = add_sltn(sltns, xout.name, constr_conc)
        sltns = add_sltn(sltns, pout.name, constr_pump)
        return constr, sltns
    return constr
    


def constr_valve(
        min, mout,
        xin, xout,
        pin, pout,
        K, L, 
        constr, sltns= None,
    ):
    # # see if multiple a, b, c, and ds defined
    # A, B, C, D = sy.symbols("A, B, C, D")
    # varb_symbs = [A, B, C, D]
    # labels = [ "A", "B", "C", "D"]
    # varbs.update(dict(zip(labels, varb_symbs)))

    constr_mass = mout - min
    constr_conc = xout - xin
    constr_major_loss = L*(pin - pout) - K*(min)**2 # major loss from feed to concentrate side.
    
    constr = add_constraint(constr, constr_mass, type="lin_hydraulic")
    constr = add_constraint(constr, constr_conc, type="bl_component")
    constr = add_constraint(constr, constr_major_loss, type="nl_pressure")

    if sltns is not None:
        sltns = add_sltn(sltns, xout.name, constr_conc)
        sltns = add_sltn(sltns, pout.name, constr_major_loss)
        return constr, sltns
    return constr