Commit d0195c12 authored by Cianciosa, Mark's avatar Cianciosa, Mark
Browse files

Add reductions that simplify the o-mode wave to simplist form.

parent f78a6fac
Loading
Loading
Loading
Loading
+8 −8
Original line number Diff line number Diff line
@@ -30,8 +30,8 @@ static base solution(const base t) {
///  @param[in] argv Array of commandline arguments.
//------------------------------------------------------------------------------
int main(int argc, const char * argv[]) {
    typedef std::complex<double> base;
    //typedef double base;
    //typedef std::complex<double> base;
    typedef double base;
    //typedef float base;
    //typedef std::complex<float> base;
    typedef backend::cpu<base> cpu;
@@ -39,8 +39,8 @@ int main(int argc, const char * argv[]) {
    const std::chrono::high_resolution_clock::time_point start = std::chrono::high_resolution_clock::now();

    const size_t num_times = 10000;
    const size_t num_rays = 1;
    //const size_t num_rays = 10000;
    //const size_t num_rays = 1;
    const size_t num_rays = 10000;
    
    std::vector<std::thread> threads(std::max(std::min(std::thread::hardware_concurrency(),
                                                       static_cast<unsigned int> (num_rays)),
@@ -69,7 +69,7 @@ int main(int argc, const char * argv[]) {

//  Inital conditions.
            for (size_t j = 0; j < local_num_rays; j++) {
                omega->set(j, 600.0);
                omega->set(j, 500.0);
            }

            x->set(backend::base_cast<cpu> (0.0));
@@ -85,10 +85,10 @@ int main(int argc, const char * argv[]) {
            //solver::split_simplextic<dispersion::bohm_gross<cpu>>
            //solver::rk4<dispersion::bohm_gross<cpu>>
            //solver::rk4<dispersion::simple<cpu>>
            //solver::rk4<dispersion::ordinary_wave<cpu>>
            solver::rk4<dispersion::extra_ordinary_wave<cpu>>
            solver::rk4<dispersion::ordinary_wave<cpu>>
            //solver::rk4<dispersion::extra_ordinary_wave<cpu>>
            //solver::rk4<dispersion::cold_plasma<cpu>>
                solve(omega, kx, ky, kz, x, y, z, t, 30.0/num_times, eq);
                solve(omega, kx, ky, kz, x, y, z, t, 60.0/num_times, eq);
            solve.init(kx);
            if (thread_number == 0) {
                solve.print_dispersion();
+2.39 KiB (172 KiB)

File changed.

No diff preview for this file type.

+7 −7
Original line number Diff line number Diff line
@@ -7,17 +7,17 @@
		<key>arithmetic_test.xcscheme_^#shared#^_</key>
		<dict>
			<key>orderHint</key>
			<integer>7</integer>
			<integer>8</integer>
		</dict>
		<key>backend_test.xcscheme_^#shared#^_</key>
		<dict>
			<key>orderHint</key>
			<integer>4</integer>
			<integer>9</integer>
		</dict>
		<key>dispersion_test.xcscheme_^#shared#^_</key>
		<dict>
			<key>orderHint</key>
			<integer>9</integer>
			<integer>3</integer>
		</dict>
		<key>graph_driver.xcscheme_^#shared#^_</key>
		<dict>
@@ -27,7 +27,7 @@
		<key>graph_framework.xcscheme_^#shared#^_</key>
		<dict>
			<key>orderHint</key>
			<integer>5</integer>
			<integer>6</integer>
		</dict>
		<key>graph_tests.xcscheme_^#shared#^_</key>
		<dict>
@@ -42,7 +42,7 @@
		<key>node_test.xcscheme_^#shared#^_</key>
		<dict>
			<key>orderHint</key>
			<integer>3</integer>
			<integer>4</integer>
		</dict>
		<key>physics_test.xcscheme_^#shared#^_</key>
		<dict>
@@ -52,12 +52,12 @@
		<key>solver_test.xcscheme_^#shared#^_</key>
		<dict>
			<key>orderHint</key>
			<integer>6</integer>
			<integer>7</integer>
		</dict>
		<key>vector_test.xcscheme_^#shared#^_</key>
		<dict>
			<key>orderHint</key>
			<integer>8</integer>
			<integer>5</integer>
		</dict>
	</dict>
	<key>SuppressBuildableAutocreation</key>
+122 −22
Original line number Diff line number Diff line
@@ -26,6 +26,43 @@ namespace graph {
               (pow_cast(a).get()  && variable_cast(pow_cast(a)->get_left()).get());
    }

//------------------------------------------------------------------------------
///  @brief Get the argument of a variable like object.
///
///  @param[in] a Expression to check.
///  @returns The agument of a.
//------------------------------------------------------------------------------
    template<typename N>
    std::shared_ptr<N> get_argument(std::shared_ptr<N> a) {
        if (variable_cast(a).get()) {
            return a;
        } else if (sqrt_cast(a).get() &&
                   variable_cast(sqrt_cast(a)->get_arg()).get()) {
            return sqrt_cast(a)->get_arg();
        } else if (pow_cast(a).get()  &&
                   variable_cast(pow_cast(a)->get_left()).get()) {
            return pow_cast(a)->get_left();
        }
        assert(false && "Should never reach this point.");
        return nullptr;
    }

//------------------------------------------------------------------------------
///  @brief Check variable like objects are the same.
///
///  @param[in] a Expression to check.
///  @param[in] b Expression to check.
///  @returns True if a is variable like.
//------------------------------------------------------------------------------
    template<typename N>
    bool is_same_variable_like(std::shared_ptr<N> a,
                               std::shared_ptr<N> b) {
        return is_variable_like(a) &&
               is_variable_like(b) &&
               get_argument(a)->is_match(get_argument(b));
        
    }

//******************************************************************************
//  Add node.
//******************************************************************************
@@ -114,6 +151,14 @@ namespace graph {
                } else if (lm->get_right()->is_match(rm->get_right())) {
                    return lm->get_right()*(lm->get_left() + rm->get_left());
                }

//  Change cases like c1*a + c2*b -> c1*(a + c2*b)
                auto lmc = constant_cast(lm->get_left());
                auto rmc = constant_cast(rm->get_left());
                if (lmc.get() && rmc.get()) {
                    return lm->get_left()*(lm->get_right() +
                                           (rm->get_left()/lm->get_left())*rm->get_right());
                }
            }

//  Common denominator reduction. If the left and right are both divide nodes
@@ -407,6 +452,14 @@ namespace graph {
                } else if (lm->get_right()->is_match(rm->get_right())) {
                    return lm->get_right()*(lm->get_left() - rm->get_left());
                }

//  Change cases like c1*a - c2*b -> c1*(a - c2*b)
                auto lmc = constant_cast(lm->get_left());
                auto rmc = constant_cast(rm->get_left());
                if (lmc.get() && rmc.get()) {
                    return lm->get_left()*(lm->get_right() -
                                           (rm->get_left()/lm->get_left())*rm->get_right());
                }
            }

//  Common denominator reduction. If the left and right are both divide nodes
@@ -670,20 +723,6 @@ namespace graph {
                return this->right*this->left;
            }

//  Reduce constants multiplied by fused multiply add nodes.
            auto rfma = fma_cast(this->right);
            if (l.get() && rfma.get()) {
                return fma(this->left*rfma->get_left(),
                           rfma->get_middle(),
                           this->left*rfma->get_right());
            }
            auto lfma = fma_cast(this->left);
            if (r.get() && lfma.get()) {
                return fma(this->right*lfma->get_left(),
                           lfma->get_middle(),
                           this->right*lfma->get_right());
            }

//  Reduce x*x to x^2
            if (this->left->is_match(this->right)) {
                return pow(this->left, constant<typename LN::backend> (2.0));
@@ -702,6 +741,12 @@ namespace graph {
                    return (this->right*lm->get_right())*lm->get_left();
                }

//  Promote constants before variables.
//  (c*v1)*v2 -> c*(v1*v2)
                if (constant_cast(lm->get_left()).get()) {
                    return lm->get_left()*(lm->get_right()*this->right);
                }

//  Assume variables, sqrt of variables, and powers of variables are on the
//  right.
//  (a*v)*b -> a*(v*b)
@@ -727,11 +772,8 @@ namespace graph {
            }

//  v1*(c*v2) -> c*(v1*v2)
//  (c*v1)*v2 -> c*(v1*v2)
            if (rm.get() && constant_cast(rm->get_left()).get()) {
                return rm->get_left()*(this->left*rm->get_right());
            } else if (lm.get() && constant_cast(lm->get_left()).get()) {
                return lm->get_left()*(lm->get_right()*this->right);
            }

//  Factor out common constants c*b*c*d -> c*c*b*d. c*c will get reduced to c on
@@ -1012,6 +1054,12 @@ namespace graph {
                return constant(this->evaluate());
            }

//  Reduce cases of a/c1 -> c2*a
            if (r.get()) {
                return (constant<typename LN::backend> (1)/this->right) *
                       this->left;
            }

//  Reduce fused multiply divided by constant nodes.
            auto lfma = fma_cast(this->left);
            if (r.get() && lfma.get()) {
@@ -1026,14 +1074,14 @@ namespace graph {

//  Assume constants are always on the left.
//  c1/(c2*v) -> c3/v
//  (c1*v)/c2 -> v/c3
//  (c1*v)/c2 -> c3*v
            if (rm.get() && l.get()) {
                if (constant_cast(rm->get_left()).get()) {
                    return (this->left/rm->get_left())/rm->get_right();
                }
            } else if (lm.get() && r.get()) {
                if (constant_cast(lm->get_left()).get()) {
                    return lm->get_right()/(this->right/lm->get_left());
                    return (lm->get_left()/this->right)*lm->get_right();
                }
            }

@@ -1295,7 +1343,6 @@ namespace graph {
//  Common factor reduction. If the left and right are both multiply nodes check
//  for a common factor. So you can change a*b + (a*c) -> a*(b + c).
            auto rm = multiply_cast(this->right);

            if (rm.get()) {
                if (rm->get_left()->is_match(this->left)) {
                    return this->left*(this->middle + rm->get_right());
@@ -1308,6 +1355,45 @@ namespace graph {
                }
            }

//  Handle cases like.
//  fma(c1*a,b,c2*d) -> c1*(a*b + c2/c1*d)
            auto lm = multiply_cast(this->left);
            if (lm.get() && rm.get()) {
                auto rmc = constant_cast(rm->get_left());
                if (rmc.get()) {
                    return lm->get_left()*fma(lm->get_right(),
                                              this->middle,
                                              (rm->get_left()/lm->get_left())*rm->get_right());
                }
            }
//  fma(c1*a,b,c2/d) -> c1*(a*b + c1/(c2*d))
//  fma(c1*a,b,d/c2) -> c1*(a*b + d/(c1*c2))
            auto rd = divide_cast(this->right);
            if (lm.get() && rd.get()) {
                if (constant_cast(rd->get_left()).get() ||
                    constant_cast(rd->get_right()).get()) {
                    return lm->get_left()*fma(lm->get_right(),
                                              this->middle,
                                              rd->get_left()/(lm->get_left()*rd->get_right()));
                }
            }

//  Handle cases like.
//  fma(a,v1,b*v2) -> (a + b*v1/v2)*v1
//  fma(a,v1,c*b*v2) -> (a + c*b*v1/v2)*v1
            if (rm.get()) {
                if (is_same_variable_like(this->middle, rm->get_right())) {
                    return (this->left + rm->get_left()*this->middle/rm->get_right()) *
                           this->middle*rm->get_right();
                }
                auto rmm = multiply_cast(rm->get_right());
                if (rmm.get() &&
                    is_same_variable_like(this->middle, rmm->get_right())) {
                    return (this->left + rm->get_left()*rmm->get_left()*this->middle/rmm->get_right()) *
                           this->middle;
                }
            }

//  Promote constants out to the left.
            if (l.get() && r.get()) {
                return this->left*(this->middle + this->right/this->left);
@@ -1365,9 +1451,23 @@ namespace graph {
//------------------------------------------------------------------------------
        virtual void to_latex() const final {
            std::cout << "\\left(";
            if (add_cast(this->left).get() ||
                subtract_cast(this->left).get()) {
                std::cout << "\\left(";
                this->left->to_latex();
                std::cout << "\\right)";
            } else {
                this->left->to_latex();
            }
            std::cout << " ";
            if (add_cast(this->right).get() ||
                subtract_cast(this->right).get()) {
                std::cout << "\\left(";
                this->middle->to_latex();
                std::cout << "\\right)";
            } else {
                this->middle->to_latex();
            }
            std::cout << "+";
            this->right->to_latex();
            std::cout << "\\right)";
+114 −51
Original line number Diff line number Diff line
@@ -203,6 +203,10 @@ template<typename BACKEND> void test_add() {
    assert(common_d_acast.get() && "Expected add node.");
    assert(graph::constant_cast(common_d_acast->get_left()).get() &&
           "Expected constant on the left.");

//  c1*a + c2*b -> c1*(a + c2*b)
    assert(graph::multiply_cast(three*variable + (one + one)*var_b).get() &&
           "Expected multilpy node.");
}

//------------------------------------------------------------------------------
@@ -374,11 +378,11 @@ template<typename BACKEND> void test_subtract() {
    auto negate_cast = add_cast(negate);
    assert(negate_cast.get() && "Expected addition node.");

//  (c1*v1 + c2) - (c3*v1 + c4) -> c5*v1 - c6
//  (c1*v1 + c2) - (c3*v1 + c4) -> c5*(v1 - c6)
    auto three = graph::constant<BACKEND> (3);
    auto subfma = graph::fma(three, var_a, two)
                + graph::fma(two, var_a, three);
    assert(graph::fma_cast(subfma).get() && "Expected fused multiply add node.");
    assert(graph::multiply_cast(subfma).get() && "Expected a multiply node.");

//  Test cases like
//  (c1 + c2/x) - c3/x -> c1 + c4/x
@@ -394,6 +398,10 @@ template<typename BACKEND> void test_subtract() {
    assert(common_d_scast.get() && "Expected subtract node.");
    assert(graph::constant_cast(common_d_scast->get_left()).get() &&
           "Expected constant on the left.");
    
//  c1*a - c2*b -> c1*(a - c2*b)
    assert(graph::multiply_cast(three*var_a - (one + one)*var_b).get() &&
           "Expected multilpy node.");
}

//------------------------------------------------------------------------------
@@ -626,21 +634,6 @@ template<typename BACKEND> void test_multiply() {
    assert(graph::constant_cast(gather_v5_cast->get_right())->is(2) &&
           "Expected power of 2.");

//  Test reduction of a constant*fma node.
    auto c = graph::constant<BACKEND> (3);
    auto fma = graph::fma(graph::variable<BACKEND> (1, ""),
                          graph::variable<BACKEND> (1, ""),
                          graph::variable<BACKEND> (1, ""));
    auto cfma = graph::fma_cast(c*fma);
    assert(cfma.get() && "Expected fma node.");
    assert(graph::multiply_cast(cfma->get_right()).get() &&
           "Expected multiply node in add branch.");

    auto fmac = graph::fma_cast(fma*c);
    assert(fmac.get() && "Expected fma node.");
    assert(graph::multiply_cast(fmac->get_right()).get() &&
           "Expected multiply node in add branch.");

//  Test gather of terms. This test is setup to trigger an infinite recursive
//  loop if a critical check is not in place no need to check the values.
    auto a = graph::variable<BACKEND> (1, "");
@@ -684,30 +677,30 @@ template<typename BACKEND> void test_multiply() {

//  Test c1*(v/c2) -> c6*v
    auto c6 = two*(a/three);
    auto c6_cast = graph::divide_cast(c6);
    assert(c6_cast.get() && "Expected divide node.");
    assert(graph::constant_cast(c6_cast->get_right()) &&
           "Expected constant in the denominator.");
    assert(graph::variable_cast(c6_cast->get_left()) &&
           "Expected variable in the numerator.");

//  Test (c2/v)*c1 -> c7*v
    auto c6_cast = graph::multiply_cast(c6);
    assert(c6_cast.get() && "Expected multiply node.");
    assert(graph::constant_cast(c6_cast->get_left()) &&
           "Expected constant for the left.");
    assert(graph::variable_cast(c6_cast->get_right()) &&
           "Expected variable for the right.");

//  Test (c2/v)*c1 -> c7/v
    auto c7 = (three/a)*two;
    auto c7_cast = graph::divide_cast(c7);
    assert(c7_cast.get() && "Expected divide node.");
    assert(graph::constant_cast(c7_cast->get_left()) &&
           "Expected constant in the numerator.");
           "Expected constant for the numerator.");
    assert(graph::variable_cast(c7_cast->get_right()) &&
           "Expected variable in the denominator.");
           "Expected variable for the denominator.");

//  Test (v/c2)*c1 -> c8*v
    auto c8 = two*(a/three);
    auto c8_cast = graph::divide_cast(c8);
    auto c8_cast = graph::multiply_cast(c8);
    assert(c8_cast.get() && "Expected divide node.");
    assert(graph::constant_cast(c8_cast->get_right()) &&
           "Expected constant in the denominator.");
    assert(graph::variable_cast(c8_cast->get_left()) &&
           "Expected variable in the numerator.");
    assert(graph::constant_cast(c8_cast->get_left()) &&
           "Expected constant for the left.");
    assert(graph::variable_cast(c8_cast->get_right()) &&
           "Expected variable for the right.");

//  Test v1*(c*v2) -> c*(v1*v2)
    auto c9 = a*(three*variable);
@@ -891,9 +884,10 @@ template<typename BACKEND> void test_divide() {
                                           backend::base_cast<BACKEND> (3.0) &&
           "Expected 2/3 for result.");

//  v/c1 -> (1/c1)*v -> c2*v
    auto var_divided_two = variable/two;
    assert(graph::divide_cast(var_divided_two).get() &&
           "Expected divide node.");
    assert(graph::multiply_cast(var_divided_two).get() &&
           "Expected multiply node.");
    const BACKEND var_divided_two_result = var_divided_two->evaluate();
    assert(var_divided_two_result.size() == 1 && "Expected single value.");
    assert(var_divided_two_result.at(0) == backend::base_cast<BACKEND> (3.0) /
@@ -922,8 +916,8 @@ template<typename BACKEND> void test_divide() {
           "Expected to recover numerator.");

    auto varvec_divided_two = varvec/two;
    assert(graph::divide_cast(varvec_divided_two).get() &&
           "Expect divide node.");
    assert(graph::multiply_cast(varvec_divided_two).get() &&
           "Expect mutliply node.");
    const BACKEND varvec_divided_two_result = varvec_divided_two->evaluate();
    assert(varvec_divided_two_result.size() == 2 && "Size mismatch in result.");
    assert(varvec_divided_two_result.at(0) == backend::base_cast<BACKEND> (1.0) &&
@@ -1071,23 +1065,23 @@ template<typename BACKEND> void test_divide() {
    assert(graph::variable_cast(c4_cast->get_right()).get() &&
           "Expected a variable in the denominator");

//  (c1*v)/c2 -> v/c5
//  (c1*v)/c2 -> c5*v
    auto c5 = (two*variable)/three;
    auto c5_cast = graph::divide_cast(c5);
    assert(c5_cast.get() && "Expected divide node");
    assert(graph::variable_cast(c5_cast->get_left()).get() &&
           "Expected a variable in numerator.");
    assert(graph::constant_cast(c5_cast->get_right()).get() &&
           "Expected a constant in the denominator");

//  (v*c1)/c2 -> v/c5
    auto c5_cast = graph::multiply_cast(c5);
    assert(c5_cast.get() && "Expected multiply node");
    assert(graph::constant_cast(c5_cast->get_left()).get() &&
           "Expected a constant in the numerator");
    assert(graph::variable_cast(c5_cast->get_right()).get() &&
           "Expected a variable in the denominator.");

//  (v*c1)/c2 -> c5*v
    auto c6 = (variable*two)/three;
    auto c6_cast = graph::divide_cast(c6);
    assert(c6_cast.get() && "Expected divide node");
    assert(graph::variable_cast(c6_cast->get_left()).get() &&
           "Expected a variable in numerator.");
    assert(graph::constant_cast(c6_cast->get_right()).get() &&
           "Expected a constant in the denominator");
    auto c6_cast = graph::multiply_cast(c6);
    assert(c6_cast.get() && "Expected multiply node");
    assert(graph::constant_cast(c6_cast->get_left()).get() &&
           "Expected a constant in the numerator");
    assert(graph::variable_cast(c6_cast->get_right()).get() &&
           "Expected a variable in the denominator.");

//  (c*v1)/v2 -> c*(v1/v2)
    auto a = graph::variable<BACKEND> (1, "");
@@ -1304,6 +1298,34 @@ template<typename BACKEND> void test_fma() {

    assert(graph::multiply_cast(graph::fma(two, var_a, one)).get() &&
           "Expected multiply node.");
    
//  fma(c1*a,b,c2*d) -> c1*(a*b + c2/c1*d)
    assert(graph::multiply_cast(graph::fma(two*var_b,
                                           var_a,
                                           two*two*var_b)).get() &&
           "Expected multiply node.");

//  fma(c1*a,b,c2/d) -> c1*(a*b + c1/(c2*d))
//  fma(c1*a,b,d/c2) -> c1*(a*b + d/(c1*c2))
    assert(graph::multiply_cast(graph::fma(two*var_b,
                                           var_a,
                                           two*two/var_b)).get() &&
           "Expected multiply node.");
    assert(graph::multiply_cast(graph::fma(two*var_b,
                                           var_a,
                                           var_b/(two*two))).get() &&
           "Expected multiply node.");

//  fma(a,v1,b*v2) -> (a + b*v1/v2)*v1
//  fma(a,v1,c*b*v2) -> (a + c*b*v1/v2)*v1
    assert(graph::multiply_cast(graph::fma(two,
                                           var_a,
                                           two*sqrt(var_a))).get() &&
           "Expected multiply node.");
    assert(graph::multiply_cast(graph::fma(two,
                                           var_a,
                                           two*(var_b*sqrt(var_a)))).get() &&
           "Expected multiply node.");
}

//------------------------------------------------------------------------------
@@ -1325,6 +1347,47 @@ template<typename BACKEND> void test_variable_like() {
           "Expected sqrt(c) to not be variable like.");
    assert(!graph::is_variable_like(graph::pow(c, a)) &&
           "Expected c^a to not be variable like.");
    
    assert(graph::get_argument(a)->is_match(a) &&
           "Expected argument of a.");
    assert(graph::get_argument(graph::sqrt(a))->is_match(a) &&
           "Expected argument of a.");
    assert(graph::get_argument(graph::pow(a, c))->is_match(a) &&
           "Expected argument of a.");
    
    assert(graph::is_same_variable_like(a, graph::sqrt(a)) &&
           "Expected same.");
    assert(graph::is_same_variable_like(graph::sqrt(a), a) &&
           "Expected same.");
    assert(graph::is_same_variable_like(a, graph::pow(a, c)) &&
           "Expected same.");
    assert(graph::is_same_variable_like(graph::pow(a, c), a) &&
           "Expected same.");
    assert(graph::is_same_variable_like(graph::sqrt(a),
                                        graph::pow(a, c)) &&
           "Expected same.");
    assert(graph::is_same_variable_like(graph::pow(a, c),
                                        graph::sqrt(a)) &&
           "Expected same.");
    assert(!graph::is_same_variable_like(graph::pow(c, a),
                                         graph::sqrt(a)) &&
           "Expected different.");
    
    auto b = graph::variable<BACKEND> (1, "");
    assert(!graph::is_same_variable_like(a, graph::sqrt(b)) &&
           "Expected different.");
    assert(!graph::is_same_variable_like(graph::sqrt(a), b) &&
           "Expected different.");
    assert(!graph::is_same_variable_like(a, graph::pow(b, c)) &&
           "Expected different.");
    assert(!graph::is_same_variable_like(graph::pow(a, c), b) &&
           "Expected different.");
    assert(!graph::is_same_variable_like(graph::sqrt(a),
                                         graph::pow(b, c)) &&
           "Expected different.");
    assert(!graph::is_same_variable_like(graph::pow(a, c),
                                         graph::sqrt(b)) &&
           "Expected different.");
}

//------------------------------------------------------------------------------
Loading