Commit 3972c494 authored by cianciosa's avatar cianciosa
Browse files

Add reductions of sqrts and powers. Refactor equilibrium to splimify expressions.

parent 8574c2ee
Loading
Loading
Loading
Loading
+52 −0
Original line number Diff line number Diff line
@@ -3753,6 +3753,30 @@ namespace graph {
                    }
                }

//  fma(a,b*c,b*d) -> b*fma(a,c,d)
//  fma(a,c*b,b*d) -> b*fma(a,c,d)
//  fma(a,b*c,d*b) -> b*fma(a,c,d)
//  fma(a,c*b,d*b) -> b*fma(a,c,d)
                if (mm.get()) {
                    if (mm->get_left()->is_match(rm->get_left())) {
                        return mm->get_left()*fma(this->left,
                                                  mm->get_right(),
                                                  rm->get_right());
                    } else if (mm->get_left()->is_match(rm->get_right())) {
                        return mm->get_left()*fma(this->left,
                                                  mm->get_right(),
                                                  rm->get_left());
                    } else if (mm->get_right()->is_match(rm->get_left())) {
                        return mm->get_right()*fma(this->left,
                                                   mm->get_left(),
                                                   rm->get_right());
                    } else if (mm->get_right()->is_match(rm->get_right())) {
                        return mm->get_right()*fma(this->left,
                                                   mm->get_left(),
                                                   rm->get_left());
                    }
                }

//  Convert fma(a*b,c,d*e) -> fma(d,e,a*b*c)
//  Convert fma(a,b*c,d*e) -> fma(d,e,a*b*c)
                if ((lm.get() || mm.get()) &&
@@ -3850,6 +3874,19 @@ namespace graph {
                }
            }

//  fma(a,b*c,b) -> b*fma(a,c,1)
            if (mm.get()) {
                if (mm->get_left()->is_match(this->right)) {
                    return mm->get_left()*fma(this->left,
                                              mm->get_right(),
                                              1.0);
                } else if (mm->get_right()->is_match(this->right)) {
                    return mm->get_right()*fma(this->left,
                                              mm->get_left(),
                                              1.0);
                }
            }

//  fma(c1,a,c2/b) -> c1*(a + c3/b)
//  fma(a,c1,c2/b) -> c1*(a + c3/b)
            auto rd = divide_cast(this->right);
@@ -4435,6 +4472,21 @@ namespace graph {
                                       rfma->get_middle()),
                                   rfma->get_right());
                    }
                    
//  fma(2,(a*b)^2,fma(3,a^2*b,c)) -> a^2*fma(2,b^2,fma(3,b,c))
                    auto rfmamm = multiply_cast(rfma->get_middle());
                    if (rfmamm.get()) {
                        if (is_variable_combineable(mplm->get_left(),
                                                    rfmamm->get_left())) {
                            auto temp = pow(mplm->get_left(),
                                            mp->get_right());
                            return temp*fma(this->left,
                                            this->middle/temp,
                                            fma(rfma->get_left(),
                                                rfma->get_middle()/temp,
                                                rfma->get_right()));
                        }
                    }
                }
            }

+12 −12
Original line number Diff line number Diff line
@@ -1073,18 +1073,18 @@ namespace equilibrium {
                   + c01_temp*z_norm
                   + c02_temp*(z_norm*z_norm)
                   + c03_temp*(z_norm*z_norm*z_norm)
                   + c10_temp*r_norm
                   + c11_temp*r_norm*z_norm
                   + c12_temp*r_norm*(z_norm*z_norm)
                   + c13_temp*r_norm*(z_norm*z_norm*z_norm)
                   + c20_temp*(r_norm*r_norm)
                   + c21_temp*(r_norm*r_norm)*z_norm
                   + c22_temp*(r_norm*r_norm)*(z_norm*z_norm)
                   + c23_temp*(r_norm*r_norm)*(z_norm*z_norm*z_norm)
                   + c30_temp*(r_norm*r_norm*r_norm)
                   + c31_temp*(r_norm*r_norm*r_norm)*z_norm
                   + c32_temp*(r_norm*r_norm*r_norm)*(z_norm*z_norm)
                   + c33_temp*(r_norm*r_norm*r_norm)*(z_norm*z_norm*z_norm);
                   + r_norm*(c10_temp +
                             c11_temp*z_norm +
                             c12_temp*(z_norm*z_norm) +
                               c13_temp*(z_norm*z_norm*z_norm))
                   + (r_norm*r_norm)*(c20_temp +
                                      c21_temp*z_norm +
                                      c22_temp*(z_norm*z_norm) +
                                      c23_temp*(z_norm*z_norm*z_norm))
                   + (r_norm*r_norm*r_norm)*(c30_temp +
                                             c31_temp*z_norm +
                                             c32_temp*(z_norm*z_norm) +
                                             c33_temp*(z_norm*z_norm*z_norm));
        }

//------------------------------------------------------------------------------
+3 −2
Original line number Diff line number Diff line
@@ -143,11 +143,11 @@ void run_korc() {
            const timeing::measure_diagnostic t_run("Run Time");
            work.pre_run();
            for (size_t i = 0; i < 1000000; i++) {
                sync.join();
/*                sync.join();
                work.wait();
                sync = std::thread([&file, &dataset] () -> void {
                    dataset.write(file);
                });
                });*/
                
                work.run();
            }
@@ -181,6 +181,7 @@ int main(int argc, const char * argv[]) {
    (void)argv;

    run_korc<double> ();
//    run_korc<float> ();

    END_GPU
}
+84 −8
Original line number Diff line number Diff line
@@ -2822,7 +2822,7 @@ template<jit::float_scalar T> void test_fma() {
           "Expected common var_b");

//  fma(a, b, fma(c, b, d)) -> fma(b, a + c, d)
    auto var_d = graph::variable<T> (1, "");
    auto var_d = graph::variable<T> (1, "d");
    auto match1 = graph::fma(var_b, var_a + var_c, var_d);
    auto nested_fma1 = graph::fma(var_a, var_b, 
                                  graph::fma(var_c, var_b, var_d));
@@ -3578,11 +3578,11 @@ template<jit::float_scalar T> void test_fma() {
    auto commom_power_cast = graph::multiply_cast(commom_power);
    assert(commom_power_cast.get() && "Expected a multiply node.");
    assert(commom_power_cast->get_right()->is_match(var_b) &&
           "Expced b");
           "Expeced b");
    assert(commom_power_cast->get_left()->is_match(graph::pow(var_a, 2.0) *
                                                   graph::fma(2.0, var_b, 1.0)) &&
           "Expced a^2*fma(2, b, 1)");
//  fma(2,(ba)^2,a^2b) -> a^2*fma(2, b^2, b)
           "Expeced a^2*fma(2, b, 1)");
//  fma(2,(ba)^2,a^2b) -> a^2*b*fma(2, b, 1)
    auto commom_power2 = graph::fma(2.0, graph::pow(var_b*var_a, 2.0), graph::pow(var_a, 2.0)*var_b);
    auto commom_power2_cast = graph::multiply_cast(commom_power2);
    assert(commom_power2_cast.get() && "Expected a multiply node.");
@@ -3591,7 +3591,7 @@ template<jit::float_scalar T> void test_fma() {
    assert(commom_power2_cast->get_left()->is_match(graph::pow(var_a, 2.0) *
                                                    graph::fma(2.0, var_b, 1.0)) &&
           "Expced a^2*fma(2, b, 1)");
//  fma(2,(a*b)^2,fma(a^2,b,c)) -> fma(a^2,fma(2,b^2,b),c)
//  fma(2,(a*b)^2,fma(a^2,b,c)) -> fma(a^2*b,fma(2,b,1),c)
    auto commom_power3 = graph::fma(2.0,
                                    graph::pow(var_a*var_b, 2.0),
                                    graph::fma(graph::pow(var_a, 2.0),
@@ -3601,7 +3601,12 @@ template<jit::float_scalar T> void test_fma() {
    assert(commom_power3_cast.get() && "Expected a fma node.");
    assert(commom_power3_cast->get_left()->is_match(graph::pow(var_a, 2.0)) &&
           "Expected a^2");
//  fma(2,(a*b)^2,fma(a^2,b,c)) -> fma(a^2,fma(2,b^2,b),c)
    assert(commom_power3_cast->get_middle()->is_match(var_b*graph::fma(2.0,
                                                                       var_b,
                                                                       1.0)) &&
           "Expected b*fma(2,b,1)");
    assert(commom_power3_cast->get_right()->is_match(var_c) && "Expected c");
//  fma(2,(a*b)^2,fma(a^2,b,c)) -> fma(a^2*b,fma(2,b,1),c)
    auto commom_power4 = graph::fma(2.0,
                                    graph::pow(var_b*var_a, 2.0),
                                    graph::fma(graph::pow(var_a, 2.0),
@@ -3611,6 +3616,11 @@ template<jit::float_scalar T> void test_fma() {
    assert(commom_power4_cast.get() && "Expected a fma node.");
    assert(commom_power4_cast->get_left()->is_match(graph::pow(var_a, 2.0)) &&
           "Expected a^2");
    assert(commom_power4_cast->get_middle()->is_match(var_b*graph::fma(2.0,
                                                                       var_b,
                                                                       1.0)) &&
           "Expected b*fma(2,b,1)");
    assert(commom_power4_cast->get_right()->is_match(var_c) && "Expected c");

//  fma(2,a^2,a) -> a*fma(2,a,1)
    auto common_power5 = graph::fma(2.0,var_a*var_a,var_a);
@@ -3621,14 +3631,80 @@ template<jit::float_scalar T> void test_fma() {
    assert(commom_power5_cast->get_right()->is_match(var_a) &&
           "Expected a.");
//  fma(2,a,a^2) -> a*(2 + a)
    auto temp = var_a*var_a;
    auto common_power6 = graph::fma(2.0,var_a,temp);
    auto common_power6 = graph::fma(2.0,var_a,var_a*var_a);
    auto commom_power6_cast = graph::multiply_cast(common_power6);
    assert(commom_power6_cast.get() && "Expected a multiply node.");
    assert(commom_power6_cast->get_left()->is_match(2.0 + var_a) &&
           "Expected (2 + a).");
    assert(commom_power6_cast->get_right()->is_match(var_a) &&
           "Expected a.");

//  fma(2,(a*b)^2,fma(3,a^2*b,c)) -> fma(a^2*b,fma(2,b,3),c)
    auto common_power7 = graph::fma(2.0,
                                    graph::pow(var_a*var_b,
                                               2.0),
                                    graph::fma(3.0,
                                               var_a*var_a*var_b,
                                               var_c));
    auto common_power7_cast = graph::multiply_cast(common_power7);
    assert(common_power7_cast.get() && "Expected a multiply node.");
    assert(common_power7_cast->get_left()->is_match(graph::fma(var_b,
                                                               graph::fma(2.0,
                                                                          var_b,
                                                                          3.0),
                                                               var_c)) &&
           "Expected fma(b,fma(2,b,3),c)");
    assert(common_power7_cast->get_right()->is_match(var_a*var_a) &&
           "Expected a^2");

//  fma(a,b*c,b) -> b*fma(a,c,1)
    auto factorize = graph::fma(var_a,var_b*var_c,var_b);
    auto factorize_cast = multiply_cast(factorize);
    assert(factorize_cast.get() && "Expected a multiply node.");
    assert(factorize_cast->get_right()->is_match(var_b) &&
           "Expected b.");
    assert(factorize_cast->get_left()->is_match(graph::fma(var_a,var_c,1.0)) &&
           "Expected a*c + 1.");
//  fma(a,c*b,b) -> b*fma(a,c,1)
    auto factorize2 = graph::fma(var_a,var_c*var_b,var_b);
    auto factorize2_cast = multiply_cast(factorize2);
    assert(factorize2_cast.get() && "Expected a multiply node.");
    assert(factorize2_cast->get_right()->is_match(var_b) &&
           "Expected b.");
    assert(factorize2_cast->get_left()->is_match(graph::fma(var_a,var_c,1.0)) &&
           "Expected a*c + 1.");
//  fma(a,b*c,b*d) -> b*fma(a,c,d)
    auto factorize3 = graph::fma(var_a,var_b*var_c,var_b*var_d);
    auto factorize3_cast = multiply_cast(factorize3);
    assert(factorize3_cast.get() && "Expected a multiply node.");
    assert(factorize3_cast->get_right()->is_match(var_b) &&
           "Expected b.");
    assert(factorize3_cast->get_left()->is_match(graph::fma(var_a,var_c,var_d)) &&
           "Expected a*c + d.");
//  fma(a,c*b,b*d) -> b*fma(a,c,d)
    auto factorize4 = graph::fma(var_a,var_c*var_b,var_b*var_d);
    auto factorize4_cast = multiply_cast(factorize4);
    assert(factorize4_cast.get() && "Expected a multiply node.");
    assert(factorize4_cast->get_right()->is_match(var_b) &&
           "Expected b.");
    assert(factorize4_cast->get_left()->is_match(graph::fma(var_a,var_c,var_d)) &&
           "Expected a*c + d.");
//  fma(a,b*c,d*b) -> b*fma(a,c,d)
    auto factorize5 = graph::fma(var_a,var_b*var_c,var_d*var_b);
    auto factorize5_cast = multiply_cast(factorize5);
    assert(factorize5_cast.get() && "Expected a multiply node.");
    assert(factorize5_cast->get_right()->is_match(var_b) &&
           "Expected b.");
    assert(factorize5_cast->get_left()->is_match(graph::fma(var_a,var_c,var_d)) &&
           "Expected a*c + d.");
//  fma(a,c*b,d*b) -> b*fma(a,c,d)
    auto factorize6 = graph::fma(var_a,var_c*var_b,var_d*var_b);
    auto factorize6_cast = multiply_cast(factorize6);
    assert(factorize6_cast.get() && "Expected a multiply node.");
    assert(factorize6_cast->get_right()->is_match(var_b) &&
           "Expected b.");
    assert(factorize6_cast->get_left()->is_match(graph::fma(var_a,var_c,var_d)) &&
           "Expected a*c + d.");
}

//------------------------------------------------------------------------------
+4 −6
Original line number Diff line number Diff line
@@ -139,8 +139,6 @@ void run_test() {

    auto bvec = eq->get_magnetic_field(x, y, z);
    auto ne = eq->get_electron_density(x, y, z);
    ne->to_latex();
    std::cout << std::endl << std::endl;
    auto te = eq->get_electron_temperature(x, y, z);

    workflow::manager<T> work(0);
@@ -155,15 +153,15 @@ void run_test() {
    work.run();

    for (size_t i = 0, ie = gold.r_grid.size()*gold.z_grid.size(); i < ie; i++) {
        check_error(work.check_value(i, bvec->get_x()), gold.bx_grid[i], 4.0E-11,
        check_error(work.check_value(i, bvec->get_x()), gold.bx_grid[i], 9.0E-12,
                    "Expected a match in bx.");
        check_error(work.check_value(i, bvec->get_y()), gold.by_grid[i], 1.0E-20,
                    "Expected a match in by.");
        check_error(work.check_value(i, bvec->get_z()), gold.bz_grid[i], 3.0E-12,
        check_error(work.check_value(i, bvec->get_z()), gold.bz_grid[i], 4.0E-12,
                    "Expected a match in bz.");
        check_error(work.check_value(i, ne), gold.ne_grid[i], 5.0E-13,
        check_error(work.check_value(i, ne), gold.ne_grid[i], 8.0E-13,
                    "Expected a match in ne.");
        check_error(work.check_value(i, te), gold.te_grid[i], 5.0E-13,
        check_error(work.check_value(i, te), gold.te_grid[i], 8.0E-13,
                    "Expected a match in te.");
    }
}