Commit f78a6fac authored by Cianciosa, Mark's avatar Cianciosa, Mark
Browse files

Add reductions for fma nodes and reorder add/substract division nodes.

parent 4a7f1f4d
Loading
Loading
Loading
Loading
+4 −4
Original line number Diff line number Diff line
@@ -79,15 +79,15 @@ int main(int argc, const char * argv[]) {
            ky->set(backend::base_cast<cpu> (0.0));
            kz->set(backend::base_cast<cpu> (0.0));
            
            //auto eq = equilibrium::make_slab_density<cpu> ();
            auto eq = equilibrium::make_no_magnetic_field<cpu> ();
            auto eq = equilibrium::make_slab_density<cpu> ();
            //auto eq = equilibrium::make_no_magnetic_field<cpu> ();

            //solver::split_simplextic<dispersion::bohm_gross<cpu>>
            //solver::rk4<dispersion::bohm_gross<cpu>>
            //solver::rk4<dispersion::simple<cpu>>
            //solver::rk4<dispersion::ordinary_wave<cpu>>
            //solver::rk4<dispersion::extra_ordinary_wave<cpu>>
            solver::rk4<dispersion::cold_plasma<cpu>>
            solver::rk4<dispersion::extra_ordinary_wave<cpu>>
            //solver::rk4<dispersion::cold_plasma<cpu>>
                solve(omega, kx, ky, kz, x, y, z, t, 30.0/num_times, eq);
            solve.init(kx);
            if (thread_number == 0) {
−4.59 KiB (169 KiB)

File changed.

No diff preview for this file type.

+6 −6
Original line number Diff line number Diff line
@@ -7,17 +7,17 @@
		<key>arithmetic_test.xcscheme_^#shared#^_</key>
		<dict>
			<key>orderHint</key>
			<integer>8</integer>
			<integer>7</integer>
		</dict>
		<key>backend_test.xcscheme_^#shared#^_</key>
		<dict>
			<key>orderHint</key>
			<integer>7</integer>
			<integer>4</integer>
		</dict>
		<key>dispersion_test.xcscheme_^#shared#^_</key>
		<dict>
			<key>orderHint</key>
			<integer>6</integer>
			<integer>9</integer>
		</dict>
		<key>graph_driver.xcscheme_^#shared#^_</key>
		<dict>
@@ -42,7 +42,7 @@
		<key>node_test.xcscheme_^#shared#^_</key>
		<dict>
			<key>orderHint</key>
			<integer>4</integer>
			<integer>3</integer>
		</dict>
		<key>physics_test.xcscheme_^#shared#^_</key>
		<dict>
@@ -52,12 +52,12 @@
		<key>solver_test.xcscheme_^#shared#^_</key>
		<dict>
			<key>orderHint</key>
			<integer>3</integer>
			<integer>6</integer>
		</dict>
		<key>vector_test.xcscheme_^#shared#^_</key>
		<dict>
			<key>orderHint</key>
			<integer>9</integer>
			<integer>8</integer>
		</dict>
	</dict>
	<key>SuppressBuildableAutocreation</key>
+47 −4
Original line number Diff line number Diff line
@@ -126,6 +126,22 @@ namespace graph {
                return (ld->get_left() + rd->get_left())/ld->get_right();
            }

//  Move cases like
//  (c1 + c2/x) + c3/y -> c1 + (c2/x + c3/y)
//  (c1 - c2/x) + c3/y -> c1 + (c3/y - c2/x)
//  in case of common denominators.
            if (rd.get()) {
                auto la = add_cast(this->left);
                if (la.get() && divide_cast(la->get_right()).get()) {
                    return la->get_left() + (la->get_right() + this->right);
                }
                            
                auto ls = subtract_cast(this->left);
                if (ls.get() && divide_cast(ls->get_right()).get()) {
                    return ls->get_left() + (this->right - ls->get_right());
                }
            }

//  Fused multiply add reductions.
            auto m = multiply_cast(this->left);

@@ -135,7 +151,7 @@ namespace graph {
                           this->right);
            }

//  Handel cases like:
//  Handle cases like:
//  (a/y)^e + b/y^e -> (a^2 + b)/(y^e)
//  b/y^e + (a/y)^e -> (b + a^2)/(y^e)
//  (a/y)^e + (b/y)^e -> (a^2 + b^2)/(y^e)
@@ -187,6 +203,12 @@ namespace graph {
                }
            }

            if (lfma.get()) {
                return fma(lfma->get_left(),
                           lfma->get_middle(),
                           lfma->get_right() + this->right);
            }
            
            return this->shared_from_this();
        }

@@ -397,7 +419,23 @@ namespace graph {
                return (ld->get_left() - rd->get_left())/ld->get_right();
            }

//  Handel cases like:
//  Move cases like
//  (c1 + c2/x) - c3/y -> c1 + (c2/x - c3/y)
//  (c1 - c2/x) - c3/y -> c1 - (c2/x + c3/y)
//  in case of common denominators.
            if (rd.get()) {
                auto la = add_cast(this->left);
                if (la.get() && divide_cast(la->get_right()).get()) {
                    return la->get_left() + (la->get_right() - this->right);
                }
                
                auto ls = subtract_cast(this->left);
                if (ls.get() && divide_cast(ls->get_right()).get()) {
                    return ls->get_left() - (this->right + ls->get_right());
                }
            }

//  Handle cases like:
//  (a/y)^e - b/y^e -> (a^2 - b)/(y^e)
//  b/y^e - (a/y)^e -> (b - a^2)/(y^e)
//  (a/y)^e - (b/y)^e -> (a^2 - b^2)/(y^e)
@@ -1270,6 +1308,11 @@ namespace graph {
                }
            }

//  Promote constants out to the left.
            if (l.get() && r.get()) {
                return this->left*(this->middle + this->right/this->left);
            }
            
            return this->shared_from_this();
        }

+38 −6
Original line number Diff line number Diff line
@@ -184,10 +184,25 @@ template<typename BACKEND> void test_add() {
    assert(negate2_cast->get_left()->is_match(var_b) && "Expected var_b.");

//  (c1*v1 + c2) + (c3*v1 + c4) -> c5*v1 + c6
    auto addfma = graph::fma(three, var_a, one)
                + graph::fma(three, var_a, one);
    auto addfma_cast = graph::fma_cast(addfma);
    assert(addfma_cast.get() && "Expected fused multiply add node.");
    auto addfma = graph::fma(var_b, var_a, var_d)
                + graph::fma(var_c, var_a, var_d);
    assert(graph::fma_cast(addfma).get() &&
           "Expected fused multiply add node.");

//  Test cases like
//  (c1 + c2/x) + c3/x -> c1 + c4/x
//  (c1 - c2/x) + c3/x -> c1 + c4/x
    common_d = (one + three/var_a) + (one/var_a);
    auto common_d_acast = graph::add_cast(common_d);
    assert(common_d_acast.get() && "Expected add node.");
    assert(graph::constant_cast(common_d_acast->get_left()).get() &&
           "Expected constant on the left.");
    
    common_d = (one - three/var_a) + (one/var_a);
    common_d_acast = graph::add_cast(common_d);
    assert(common_d_acast.get() && "Expected add node.");
    assert(graph::constant_cast(common_d_acast->get_left()).get() &&
           "Expected constant on the left.");
}

//------------------------------------------------------------------------------
@@ -363,8 +378,22 @@ template<typename BACKEND> void test_subtract() {
    auto three = graph::constant<BACKEND> (3);
    auto subfma = graph::fma(three, var_a, two)
                + graph::fma(two, var_a, three);
    auto subfma_cast = graph::fma_cast(subfma);
        assert(subfma_cast.get() && "Expected fused multiply add node.");
    assert(graph::fma_cast(subfma).get() && "Expected fused multiply add node.");

//  Test cases like
//  (c1 + c2/x) - c3/x -> c1 + c4/x
//  (c1 - c2/x) - c3/x -> c1 - c4/x
    common_d = (one + three/var_a) - (one/var_a);
    auto common_d_acast = graph::add_cast(common_d);
    assert(common_d_acast.get() && "Expected add node.");
    assert(graph::constant_cast(common_d_acast->get_left()).get() &&
           "Expected constant on the left.");
        
    common_d = (one - three/var_a) - (one/var_a);
    auto common_d_scast = graph::subtract_cast(common_d);
    assert(common_d_scast.get() && "Expected subtract node.");
    assert(graph::constant_cast(common_d_scast->get_left()).get() &&
           "Expected constant on the left.");
}

//------------------------------------------------------------------------------
@@ -1272,6 +1301,9 @@ template<typename BACKEND> void test_fma() {
    assert(reduce4_cast.get() && "Expected multiply node.");
    assert(reduce4_cast->get_right()->is_match(var_b) &&
           "Expected common var_b");

    assert(graph::multiply_cast(graph::fma(two, var_a, one)).get() &&
           "Expected multiply node.");
}

//------------------------------------------------------------------------------
Loading