Commit 4a7f1f4d authored by Cianciosa, Mark's avatar Cianciosa, Mark
Browse files

Add tests for addition and subtraction of fma nodes.

parent 63c3cba6
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -79,8 +79,8 @@ int main(int argc, const char * argv[]) {
            ky->set(backend::base_cast<cpu> (0.0));
            kz->set(backend::base_cast<cpu> (0.0));
            
            auto eq = equilibrium::make_slab_density<cpu> ();
            //auto eq = equilibrium::make_no_magnetic_field<cpu> ();
            //auto eq = equilibrium::make_slab_density<cpu> ();
            auto eq = equilibrium::make_no_magnetic_field<cpu> ();

            //solver::split_simplextic<dispersion::bohm_gross<cpu>>
            //solver::rk4<dispersion::bohm_gross<cpu>>
−1.15 KiB (174 KiB)

File changed.

No diff preview for this file type.

+31 −15
Original line number Diff line number Diff line
@@ -130,9 +130,7 @@ namespace graph {
            auto m = multiply_cast(this->left);

            if (m.get()) {
                return fma<leaf_node<typename LN::backend>,
                           leaf_node<typename LN::backend>,
                           leaf_node<typename RN::backend>> (m->get_left(),
                return fma(m->get_left(),
                           m->get_right(),
                           this->right);
            }
@@ -178,6 +176,17 @@ namespace graph {
                }
            }

            auto lfma = fma_cast(this->left);
            auto rfma = fma_cast(this->right);
            
            if (lfma.get() && rfma.get()) {
                if (lfma->get_middle()->is_match(rfma->get_middle())) {
                    return fma(lfma->get_left() + rfma->get_left(),
                               lfma->get_middle(),
                               lfma->get_right() + rfma->get_right());
                }
            }
            
            return this->shared_from_this();
        }

@@ -429,6 +438,17 @@ namespace graph {
                }
            }
            
            auto lfma = fma_cast(this->left);
            auto rfma = fma_cast(this->right);
            
            if (lfma.get() && rfma.get()) {
                if (lfma->get_middle()->is_match(rfma->get_middle())) {
                    return fma(lfma->get_left() - rfma->get_left(),
                               lfma->get_middle(),
                               lfma->get_right() - rfma->get_right());
                }
            }
            
            return this->shared_from_this();
        }

@@ -1267,15 +1287,11 @@ namespace graph {
                return constant<typename LN::backend> (1);
            }

            auto temp_right = fma<LN,
                                  leaf_node<typename MN::backend>,
                                  leaf_node<typename RN::backend>> (this->left,
            auto temp_right = fma(this->left,
                                  this->middle->df(x),
                                  this->right->df(x));

            return fma<leaf_node<typename LN::backend>,
                       MN,
                       leaf_node<typename RN::backend>> (this->left->df(x),
            return fma(this->left->df(x),
                       this->middle,
                       temp_right);
        }
+13 −0
Original line number Diff line number Diff line
@@ -182,6 +182,12 @@ template<typename BACKEND> void test_add() {
    auto negate2_cast = subtract_cast(negate2);
    assert(negate2_cast.get() && "Expected subtract node.");
    assert(negate2_cast->get_left()->is_match(var_b) && "Expected var_b.");

//  (c1*v1 + c2) + (c3*v1 + c4) -> c5*v1 + c6
    auto addfma = graph::fma(three, var_a, one)
                + graph::fma(three, var_a, one);
    auto addfma_cast = graph::fma_cast(addfma);
    assert(addfma_cast.get() && "Expected fused multiply add node.");
}

//------------------------------------------------------------------------------
@@ -352,6 +358,13 @@ template<typename BACKEND> void test_subtract() {
    auto negate = var_a - graph::constant<BACKEND> (-2)*var_b;
    auto negate_cast = add_cast(negate);
    assert(negate_cast.get() && "Expected addition node.");

//  (c1*v1 + c2) - (c3*v1 + c4) -> c5*v1 - c6
    auto three = graph::constant<BACKEND> (3);
    auto subfma = graph::fma(three, var_a, two)
                + graph::fma(two, var_a, three);
    auto subfma_cast = graph::fma_cast(subfma);
        assert(subfma_cast.get() && "Expected fused multiply add node.");
}

//------------------------------------------------------------------------------