Commit 8574c2ee authored by cianciosa's avatar cianciosa
Browse files

Save work in progress.

parent 1e3d6113
Loading
Loading
Loading
Loading
+54 −8
Original line number Diff line number Diff line
@@ -4341,8 +4341,10 @@ namespace graph {
                }
            } else if (this->middle->is_all_variables()) {
                auto rdm = this->right/this->middle;
                if (rdm->get_complexity() < this->middle->get_complexity() +
                                            this->right->get_complexity()) {
                auto rdmc = constant_cast(rdm->get_power_exponent());
                if ((rdm->get_complexity() < this->middle->get_complexity() +
                                             this->right->get_complexity()) &&
                    !(rdmc.get() && rdmc->evaluate().is_negative())) {
                    return (this->left + rdm)*this->middle;
                }
            }
@@ -4365,6 +4367,19 @@ namespace graph {
                    return this->left/pow(mp->get_left(), -mp->get_right()) +
                           this->right;
                }

//  fma(2,a^2,a) -> a*fma(2,a,1)
//  Note this case is handled eailer. fma(2,a,a^2) -> a*fma(2,1,a)
                if (is_variable_combineable(this->middle,
                                            this->right)) {
                    auto temp = this->right/this->middle;
                    auto temp_exponent = constant_cast(temp->get_power_exponent());
                    if (temp_exponent.get() && temp_exponent->evaluate().is_negative()) {
                        return this->right*fma(this->left,
                                               this->middle/this->right,
                                               1.0);
                    }
                }
            }

//  a^b*c^b + d -> (a*c)^b + d
@@ -4382,12 +4397,43 @@ namespace graph {
                if (mplm.get()) {
                    if (is_variable_combineable(mplm->get_left(),
                                                rm->get_left())) {
                        return pow(mplm->get_left(),
                                   mp->get_right()) *
                        auto temp = pow(mplm->get_left(),
                                        mp->get_right());
                        return temp*fma(this->left,
                                        this->middle/temp,
                                        this->right/temp);
                    } else if (is_variable_combineable(mplm->get_right(),
                                                       rm->get_left())) {
                        auto temp = pow(mplm->get_right(),
                                        mp->get_right());
                        return temp*fma(this->left,
                                        this->middle/temp,
                                        this->right/temp);
                    }
                }
            }
//  fma(2,(a*b)^2,fma(a^2,b,c)) -> fma(a^2,fma(2,b^2,b),c)
            if (rfma.get() && mp.get()) {
                auto mplm = multiply_cast(mp->get_left());
                if (mplm.get()) {
                    if (is_variable_combineable(mplm->get_left(),
                                                rfma->get_left())) {
                        auto temp = pow(mplm->get_left(),
                                        mp->get_right());
                        return fma(temp,
                                   fma(this->left,
                                   pow(mplm->get_right(),
                                       mp->get_right()),
                                   this->right/mplm->get_left());
                                       this->middle/temp,
                                       rfma->get_middle()),
                                   rfma->get_right());
                    } else if (is_variable_combineable(mplm->get_right(),
                                                       rfma->get_left())) {
                        auto temp = pow(mplm->get_right(),
                                        mp->get_right());
                        return fma(temp,
                                   fma(this->left,
                                       this->middle/temp,
                                       rfma->get_middle()),
                                   rfma->get_right());
                    }
                }
            }
+13 −0
Original line number Diff line number Diff line
@@ -660,6 +660,19 @@ namespace graph {
        return constant<T, SAFE_MATH> (static_cast<T> (1.0));
    }

//------------------------------------------------------------------------------
///  @brief Create a one constant.
///
///  @tparam T         Base type of the calculation.
///  @tparam SAFE_MATH Use safe math operations.
///
///  @returns A one constant.
//------------------------------------------------------------------------------
    template<jit::float_scalar T, bool SAFE_MATH>
    constexpr shared_leaf<T, SAFE_MATH> none() {
        return constant<T, SAFE_MATH> (static_cast<T> (-1.0));
    }

///  Convinece type for imaginary constant.
    template<jit::complex_scalar T>
    constexpr T i = T(0.0, 1.0);
+54 −5
Original line number Diff line number Diff line
@@ -2773,7 +2773,7 @@ template<jit::float_scalar T> void test_fma() {
//  Test reduction.
    auto var_a = graph::variable<T> (1, "a");
    auto var_b = graph::variable<T> (1, "b");
    auto var_c = graph::variable<T> (1, "");
    auto var_c = graph::variable<T> (1, "c");

//  fma(1,a,b) = a + b
    auto one_times_vara_plus_varb = graph::fma(one, var_a, var_b);
@@ -3573,13 +3573,62 @@ template<jit::float_scalar T> void test_fma() {
                                             var_b) + var_d) &&
           "Expected a power node.");

//  fma(2,(ab)^2,a^2b) -> a^2*fma(2, b^2, b)
//  fma(2,(ab)^2,a^2b) -> a^2*b*fma(2, b, 1)
    auto commom_power = graph::fma(2.0, graph::pow(var_a*var_b, 2.0), graph::pow(var_a, 2.0)*var_b);
    commom_power->to_latex();
    std::cout << std::endl << std::endl;
    auto commom_power_cast = graph::multiply_cast(commom_power);
    assert(commom_power_cast.get() && "Expected a multiply node.");
//  fma(2,(a*b)^2,fma())
    assert(commom_power_cast->get_right()->is_match(var_b) &&
           "Expced b");
    assert(commom_power_cast->get_left()->is_match(graph::pow(var_a, 2.0) *
                                                   graph::fma(2.0, var_b, 1.0)) &&
           "Expced a^2*fma(2, b, 1)");
//  fma(2,(ba)^2,a^2b) -> a^2*fma(2, b^2, b)
    auto commom_power2 = graph::fma(2.0, graph::pow(var_b*var_a, 2.0), graph::pow(var_a, 2.0)*var_b);
    auto commom_power2_cast = graph::multiply_cast(commom_power2);
    assert(commom_power2_cast.get() && "Expected a multiply node.");
    assert(commom_power2_cast->get_right()->is_match(var_b) &&
           "Expced b");
    assert(commom_power2_cast->get_left()->is_match(graph::pow(var_a, 2.0) *
                                                    graph::fma(2.0, var_b, 1.0)) &&
           "Expced a^2*fma(2, b, 1)");
//  fma(2,(a*b)^2,fma(a^2,b,c)) -> fma(a^2,fma(2,b^2,b),c)
    auto commom_power3 = graph::fma(2.0,
                                    graph::pow(var_a*var_b, 2.0),
                                    graph::fma(graph::pow(var_a, 2.0),
                                               var_b,
                                               var_c));
    auto commom_power3_cast = graph::fma_cast(commom_power3);
    assert(commom_power3_cast.get() && "Expected a fma node.");
    assert(commom_power3_cast->get_left()->is_match(graph::pow(var_a, 2.0)) &&
           "Expected a^2");
//  fma(2,(a*b)^2,fma(a^2,b,c)) -> fma(a^2,fma(2,b^2,b),c)
    auto commom_power4 = graph::fma(2.0,
                                    graph::pow(var_b*var_a, 2.0),
                                    graph::fma(graph::pow(var_a, 2.0),
                                               var_b,
                                               var_c));
    auto commom_power4_cast = graph::fma_cast(commom_power4);
    assert(commom_power4_cast.get() && "Expected a fma node.");
    assert(commom_power4_cast->get_left()->is_match(graph::pow(var_a, 2.0)) &&
           "Expected a^2");

//  fma(2,a^2,a) -> a*fma(2,a,1)
    auto common_power5 = graph::fma(2.0,var_a*var_a,var_a);
    auto commom_power5_cast = graph::multiply_cast(common_power5);
    assert(commom_power5_cast.get() && "Expected a multiply node.");
    assert(commom_power5_cast->get_left()->is_match(graph::fma(2.0,var_a,1.0)) &&
           "Expected fma(2,a,1).");
    assert(commom_power5_cast->get_right()->is_match(var_a) &&
           "Expected a.");
//  fma(2,a,a^2) -> a*(2 + a)
    auto temp = var_a*var_a;
    auto common_power6 = graph::fma(2.0,var_a,temp);
    auto commom_power6_cast = graph::multiply_cast(common_power6);
    assert(commom_power6_cast.get() && "Expected a multiply node.");
    assert(commom_power6_cast->get_left()->is_match(2.0 + var_a) &&
           "Expected (2 + a).");
    assert(commom_power6_cast->get_right()->is_match(var_a) &&
           "Expected a.");
}

//------------------------------------------------------------------------------
+2 −0
Original line number Diff line number Diff line
@@ -139,6 +139,8 @@ void run_test() {

    auto bvec = eq->get_magnetic_field(x, y, z);
    auto ne = eq->get_electron_density(x, y, z);
    ne->to_latex();
    std::cout << std::endl << std::endl;
    auto te = eq->get_electron_temperature(x, y, z);

    workflow::manager<T> work(0);