Loading graph_driver/xrays.cpp +8 −8 Original line number Diff line number Diff line Loading @@ -30,8 +30,8 @@ static base solution(const base t) { /// @param[in] argv Array of commandline arguments. //------------------------------------------------------------------------------ int main(int argc, const char * argv[]) { typedef std::complex<double> base; //typedef double base; //typedef std::complex<double> base; typedef double base; //typedef float base; //typedef std::complex<float> base; typedef backend::cpu<base> cpu; Loading @@ -39,8 +39,8 @@ int main(int argc, const char * argv[]) { const std::chrono::high_resolution_clock::time_point start = std::chrono::high_resolution_clock::now(); const size_t num_times = 10000; const size_t num_rays = 1; //const size_t num_rays = 10000; //const size_t num_rays = 1; const size_t num_rays = 10000; std::vector<std::thread> threads(std::max(std::min(std::thread::hardware_concurrency(), static_cast<unsigned int> (num_rays)), Loading Loading @@ -69,7 +69,7 @@ int main(int argc, const char * argv[]) { // Inital conditions. for (size_t j = 0; j < local_num_rays; j++) { omega->set(j, 600.0); omega->set(j, 500.0); } x->set(backend::base_cast<cpu> (0.0)); Loading @@ -85,10 +85,10 @@ int main(int argc, const char * argv[]) { //solver::split_simplextic<dispersion::bohm_gross<cpu>> //solver::rk4<dispersion::bohm_gross<cpu>> //solver::rk4<dispersion::simple<cpu>> //solver::rk4<dispersion::ordinary_wave<cpu>> solver::rk4<dispersion::extra_ordinary_wave<cpu>> solver::rk4<dispersion::ordinary_wave<cpu>> //solver::rk4<dispersion::extra_ordinary_wave<cpu>> //solver::rk4<dispersion::cold_plasma<cpu>> solve(omega, kx, ky, kz, x, y, z, t, 30.0/num_times, eq); solve(omega, kx, ky, kz, x, y, z, t, 60.0/num_times, eq); solve.init(kx); if (thread_number == 0) { solve.print_dispersion(); Loading graph_framework.xcodeproj/project.xcworkspace/xcuserdata/m4c.xcuserdatad/UserInterfaceState.xcuserstate +2.39 KiB (172 KiB) File changed.No diff preview for this file type. View original file View changed file graph_framework.xcodeproj/xcuserdata/m4c.xcuserdatad/xcschemes/xcschememanagement.plist +7 −7 Original line number Diff line number Diff line Loading @@ -7,17 +7,17 @@ <key>arithmetic_test.xcscheme_^#shared#^_</key> <dict> <key>orderHint</key> <integer>7</integer> <integer>8</integer> </dict> <key>backend_test.xcscheme_^#shared#^_</key> <dict> <key>orderHint</key> <integer>4</integer> <integer>9</integer> </dict> <key>dispersion_test.xcscheme_^#shared#^_</key> <dict> <key>orderHint</key> <integer>9</integer> <integer>3</integer> </dict> <key>graph_driver.xcscheme_^#shared#^_</key> <dict> Loading @@ -27,7 +27,7 @@ <key>graph_framework.xcscheme_^#shared#^_</key> <dict> <key>orderHint</key> <integer>5</integer> <integer>6</integer> </dict> <key>graph_tests.xcscheme_^#shared#^_</key> <dict> Loading @@ -42,7 +42,7 @@ <key>node_test.xcscheme_^#shared#^_</key> <dict> <key>orderHint</key> <integer>3</integer> <integer>4</integer> </dict> <key>physics_test.xcscheme_^#shared#^_</key> <dict> Loading @@ -52,12 +52,12 @@ <key>solver_test.xcscheme_^#shared#^_</key> <dict> <key>orderHint</key> <integer>6</integer> <integer>7</integer> </dict> <key>vector_test.xcscheme_^#shared#^_</key> <dict> <key>orderHint</key> <integer>8</integer> <integer>5</integer> </dict> </dict> <key>SuppressBuildableAutocreation</key> Loading graph_framework/arithmetic.hpp +122 −22 Original line number Diff line number Diff line Loading @@ -26,6 +26,43 @@ namespace graph { (pow_cast(a).get() && variable_cast(pow_cast(a)->get_left()).get()); } //------------------------------------------------------------------------------ /// @brief Get the argument of a variable like object. /// /// @param[in] a Expression to check. /// @returns The agument of a. //------------------------------------------------------------------------------ template<typename N> std::shared_ptr<N> get_argument(std::shared_ptr<N> a) { if (variable_cast(a).get()) { return a; } else if (sqrt_cast(a).get() && variable_cast(sqrt_cast(a)->get_arg()).get()) { return sqrt_cast(a)->get_arg(); } else if (pow_cast(a).get() && variable_cast(pow_cast(a)->get_left()).get()) { return pow_cast(a)->get_left(); } assert(false && "Should never reach this point."); return nullptr; } //------------------------------------------------------------------------------ /// @brief Check variable like objects are the same. /// /// @param[in] a Expression to check. /// @param[in] b Expression to check. /// @returns True if a is variable like. //------------------------------------------------------------------------------ template<typename N> bool is_same_variable_like(std::shared_ptr<N> a, std::shared_ptr<N> b) { return is_variable_like(a) && is_variable_like(b) && get_argument(a)->is_match(get_argument(b)); } //****************************************************************************** // Add node. //****************************************************************************** Loading Loading @@ -114,6 +151,14 @@ namespace graph { } else if (lm->get_right()->is_match(rm->get_right())) { return lm->get_right()*(lm->get_left() + rm->get_left()); } // Change cases like c1*a + c2*b -> c1*(a + c2*b) auto lmc = constant_cast(lm->get_left()); auto rmc = constant_cast(rm->get_left()); if (lmc.get() && rmc.get()) { return lm->get_left()*(lm->get_right() + (rm->get_left()/lm->get_left())*rm->get_right()); } } // Common denominator reduction. If the left and right are both divide nodes Loading Loading @@ -407,6 +452,14 @@ namespace graph { } else if (lm->get_right()->is_match(rm->get_right())) { return lm->get_right()*(lm->get_left() - rm->get_left()); } // Change cases like c1*a - c2*b -> c1*(a - c2*b) auto lmc = constant_cast(lm->get_left()); auto rmc = constant_cast(rm->get_left()); if (lmc.get() && rmc.get()) { return lm->get_left()*(lm->get_right() - (rm->get_left()/lm->get_left())*rm->get_right()); } } // Common denominator reduction. If the left and right are both divide nodes Loading Loading @@ -670,20 +723,6 @@ namespace graph { return this->right*this->left; } // Reduce constants multiplied by fused multiply add nodes. auto rfma = fma_cast(this->right); if (l.get() && rfma.get()) { return fma(this->left*rfma->get_left(), rfma->get_middle(), this->left*rfma->get_right()); } auto lfma = fma_cast(this->left); if (r.get() && lfma.get()) { return fma(this->right*lfma->get_left(), lfma->get_middle(), this->right*lfma->get_right()); } // Reduce x*x to x^2 if (this->left->is_match(this->right)) { return pow(this->left, constant<typename LN::backend> (2.0)); Loading @@ -702,6 +741,12 @@ namespace graph { return (this->right*lm->get_right())*lm->get_left(); } // Promote constants before variables. // (c*v1)*v2 -> c*(v1*v2) if (constant_cast(lm->get_left()).get()) { return lm->get_left()*(lm->get_right()*this->right); } // Assume variables, sqrt of variables, and powers of variables are on the // right. // (a*v)*b -> a*(v*b) Loading @@ -727,11 +772,8 @@ namespace graph { } // v1*(c*v2) -> c*(v1*v2) // (c*v1)*v2 -> c*(v1*v2) if (rm.get() && constant_cast(rm->get_left()).get()) { return rm->get_left()*(this->left*rm->get_right()); } else if (lm.get() && constant_cast(lm->get_left()).get()) { return lm->get_left()*(lm->get_right()*this->right); } // Factor out common constants c*b*c*d -> c*c*b*d. c*c will get reduced to c on Loading Loading @@ -1012,6 +1054,12 @@ namespace graph { return constant(this->evaluate()); } // Reduce cases of a/c1 -> c2*a if (r.get()) { return (constant<typename LN::backend> (1)/this->right) * this->left; } // Reduce fused multiply divided by constant nodes. auto lfma = fma_cast(this->left); if (r.get() && lfma.get()) { Loading @@ -1026,14 +1074,14 @@ namespace graph { // Assume constants are always on the left. // c1/(c2*v) -> c3/v // (c1*v)/c2 -> v/c3 // (c1*v)/c2 -> c3*v if (rm.get() && l.get()) { if (constant_cast(rm->get_left()).get()) { return (this->left/rm->get_left())/rm->get_right(); } } else if (lm.get() && r.get()) { if (constant_cast(lm->get_left()).get()) { return lm->get_right()/(this->right/lm->get_left()); return (lm->get_left()/this->right)*lm->get_right(); } } Loading Loading @@ -1295,7 +1343,6 @@ namespace graph { // Common factor reduction. If the left and right are both multiply nodes check // for a common factor. So you can change a*b + (a*c) -> a*(b + c). auto rm = multiply_cast(this->right); if (rm.get()) { if (rm->get_left()->is_match(this->left)) { return this->left*(this->middle + rm->get_right()); Loading @@ -1308,6 +1355,45 @@ namespace graph { } } // Handle cases like. // fma(c1*a,b,c2*d) -> c1*(a*b + c2/c1*d) auto lm = multiply_cast(this->left); if (lm.get() && rm.get()) { auto rmc = constant_cast(rm->get_left()); if (rmc.get()) { return lm->get_left()*fma(lm->get_right(), this->middle, (rm->get_left()/lm->get_left())*rm->get_right()); } } // fma(c1*a,b,c2/d) -> c1*(a*b + c1/(c2*d)) // fma(c1*a,b,d/c2) -> c1*(a*b + d/(c1*c2)) auto rd = divide_cast(this->right); if (lm.get() && rd.get()) { if (constant_cast(rd->get_left()).get() || constant_cast(rd->get_right()).get()) { return lm->get_left()*fma(lm->get_right(), this->middle, rd->get_left()/(lm->get_left()*rd->get_right())); } } // Handle cases like. // fma(a,v1,b*v2) -> (a + b*v1/v2)*v1 // fma(a,v1,c*b*v2) -> (a + c*b*v1/v2)*v1 if (rm.get()) { if (is_same_variable_like(this->middle, rm->get_right())) { return (this->left + rm->get_left()*this->middle/rm->get_right()) * this->middle*rm->get_right(); } auto rmm = multiply_cast(rm->get_right()); if (rmm.get() && is_same_variable_like(this->middle, rmm->get_right())) { return (this->left + rm->get_left()*rmm->get_left()*this->middle/rmm->get_right()) * this->middle; } } // Promote constants out to the left. if (l.get() && r.get()) { return this->left*(this->middle + this->right/this->left); Loading Loading @@ -1365,9 +1451,23 @@ namespace graph { //------------------------------------------------------------------------------ virtual void to_latex() const final { std::cout << "\\left("; if (add_cast(this->left).get() || subtract_cast(this->left).get()) { std::cout << "\\left("; this->left->to_latex(); std::cout << "\\right)"; } else { this->left->to_latex(); } std::cout << " "; if (add_cast(this->right).get() || subtract_cast(this->right).get()) { std::cout << "\\left("; this->middle->to_latex(); std::cout << "\\right)"; } else { this->middle->to_latex(); } std::cout << "+"; this->right->to_latex(); std::cout << "\\right)"; Loading graph_tests/arithmetic_test.cpp +114 −51 Original line number Diff line number Diff line Loading @@ -203,6 +203,10 @@ template<typename BACKEND> void test_add() { assert(common_d_acast.get() && "Expected add node."); assert(graph::constant_cast(common_d_acast->get_left()).get() && "Expected constant on the left."); // c1*a + c2*b -> c1*(a + c2*b) assert(graph::multiply_cast(three*variable + (one + one)*var_b).get() && "Expected multilpy node."); } //------------------------------------------------------------------------------ Loading Loading @@ -374,11 +378,11 @@ template<typename BACKEND> void test_subtract() { auto negate_cast = add_cast(negate); assert(negate_cast.get() && "Expected addition node."); // (c1*v1 + c2) - (c3*v1 + c4) -> c5*v1 - c6 // (c1*v1 + c2) - (c3*v1 + c4) -> c5*(v1 - c6) auto three = graph::constant<BACKEND> (3); auto subfma = graph::fma(three, var_a, two) + graph::fma(two, var_a, three); assert(graph::fma_cast(subfma).get() && "Expected fused multiply add node."); assert(graph::multiply_cast(subfma).get() && "Expected a multiply node."); // Test cases like // (c1 + c2/x) - c3/x -> c1 + c4/x Loading @@ -394,6 +398,10 @@ template<typename BACKEND> void test_subtract() { assert(common_d_scast.get() && "Expected subtract node."); assert(graph::constant_cast(common_d_scast->get_left()).get() && "Expected constant on the left."); // c1*a - c2*b -> c1*(a - c2*b) assert(graph::multiply_cast(three*var_a - (one + one)*var_b).get() && "Expected multilpy node."); } //------------------------------------------------------------------------------ Loading Loading @@ -626,21 +634,6 @@ template<typename BACKEND> void test_multiply() { assert(graph::constant_cast(gather_v5_cast->get_right())->is(2) && "Expected power of 2."); // Test reduction of a constant*fma node. auto c = graph::constant<BACKEND> (3); auto fma = graph::fma(graph::variable<BACKEND> (1, ""), graph::variable<BACKEND> (1, ""), graph::variable<BACKEND> (1, "")); auto cfma = graph::fma_cast(c*fma); assert(cfma.get() && "Expected fma node."); assert(graph::multiply_cast(cfma->get_right()).get() && "Expected multiply node in add branch."); auto fmac = graph::fma_cast(fma*c); assert(fmac.get() && "Expected fma node."); assert(graph::multiply_cast(fmac->get_right()).get() && "Expected multiply node in add branch."); // Test gather of terms. This test is setup to trigger an infinite recursive // loop if a critical check is not in place no need to check the values. auto a = graph::variable<BACKEND> (1, ""); Loading Loading @@ -684,30 +677,30 @@ template<typename BACKEND> void test_multiply() { // Test c1*(v/c2) -> c6*v auto c6 = two*(a/three); auto c6_cast = graph::divide_cast(c6); assert(c6_cast.get() && "Expected divide node."); assert(graph::constant_cast(c6_cast->get_right()) && "Expected constant in the denominator."); assert(graph::variable_cast(c6_cast->get_left()) && "Expected variable in the numerator."); // Test (c2/v)*c1 -> c7*v auto c6_cast = graph::multiply_cast(c6); assert(c6_cast.get() && "Expected multiply node."); assert(graph::constant_cast(c6_cast->get_left()) && "Expected constant for the left."); assert(graph::variable_cast(c6_cast->get_right()) && "Expected variable for the right."); // Test (c2/v)*c1 -> c7/v auto c7 = (three/a)*two; auto c7_cast = graph::divide_cast(c7); assert(c7_cast.get() && "Expected divide node."); assert(graph::constant_cast(c7_cast->get_left()) && "Expected constant in the numerator."); "Expected constant for the numerator."); assert(graph::variable_cast(c7_cast->get_right()) && "Expected variable in the denominator."); "Expected variable for the denominator."); // Test (v/c2)*c1 -> c8*v auto c8 = two*(a/three); auto c8_cast = graph::divide_cast(c8); auto c8_cast = graph::multiply_cast(c8); assert(c8_cast.get() && "Expected divide node."); assert(graph::constant_cast(c8_cast->get_right()) && "Expected constant in the denominator."); assert(graph::variable_cast(c8_cast->get_left()) && "Expected variable in the numerator."); assert(graph::constant_cast(c8_cast->get_left()) && "Expected constant for the left."); assert(graph::variable_cast(c8_cast->get_right()) && "Expected variable for the right."); // Test v1*(c*v2) -> c*(v1*v2) auto c9 = a*(three*variable); Loading Loading @@ -891,9 +884,10 @@ template<typename BACKEND> void test_divide() { backend::base_cast<BACKEND> (3.0) && "Expected 2/3 for result."); // v/c1 -> (1/c1)*v -> c2*v auto var_divided_two = variable/two; assert(graph::divide_cast(var_divided_two).get() && "Expected divide node."); assert(graph::multiply_cast(var_divided_two).get() && "Expected multiply node."); const BACKEND var_divided_two_result = var_divided_two->evaluate(); assert(var_divided_two_result.size() == 1 && "Expected single value."); assert(var_divided_two_result.at(0) == backend::base_cast<BACKEND> (3.0) / Loading Loading @@ -922,8 +916,8 @@ template<typename BACKEND> void test_divide() { "Expected to recover numerator."); auto varvec_divided_two = varvec/two; assert(graph::divide_cast(varvec_divided_two).get() && "Expect divide node."); assert(graph::multiply_cast(varvec_divided_two).get() && "Expect mutliply node."); const BACKEND varvec_divided_two_result = varvec_divided_two->evaluate(); assert(varvec_divided_two_result.size() == 2 && "Size mismatch in result."); assert(varvec_divided_two_result.at(0) == backend::base_cast<BACKEND> (1.0) && Loading Loading @@ -1071,23 +1065,23 @@ template<typename BACKEND> void test_divide() { assert(graph::variable_cast(c4_cast->get_right()).get() && "Expected a variable in the denominator"); // (c1*v)/c2 -> v/c5 // (c1*v)/c2 -> c5*v auto c5 = (two*variable)/three; auto c5_cast = graph::divide_cast(c5); assert(c5_cast.get() && "Expected divide node"); assert(graph::variable_cast(c5_cast->get_left()).get() && "Expected a variable in numerator."); assert(graph::constant_cast(c5_cast->get_right()).get() && "Expected a constant in the denominator"); // (v*c1)/c2 -> v/c5 auto c5_cast = graph::multiply_cast(c5); assert(c5_cast.get() && "Expected multiply node"); assert(graph::constant_cast(c5_cast->get_left()).get() && "Expected a constant in the numerator"); assert(graph::variable_cast(c5_cast->get_right()).get() && "Expected a variable in the denominator."); // (v*c1)/c2 -> c5*v auto c6 = (variable*two)/three; auto c6_cast = graph::divide_cast(c6); assert(c6_cast.get() && "Expected divide node"); assert(graph::variable_cast(c6_cast->get_left()).get() && "Expected a variable in numerator."); assert(graph::constant_cast(c6_cast->get_right()).get() && "Expected a constant in the denominator"); auto c6_cast = graph::multiply_cast(c6); assert(c6_cast.get() && "Expected multiply node"); assert(graph::constant_cast(c6_cast->get_left()).get() && "Expected a constant in the numerator"); assert(graph::variable_cast(c6_cast->get_right()).get() && "Expected a variable in the denominator."); // (c*v1)/v2 -> c*(v1/v2) auto a = graph::variable<BACKEND> (1, ""); Loading Loading @@ -1304,6 +1298,34 @@ template<typename BACKEND> void test_fma() { assert(graph::multiply_cast(graph::fma(two, var_a, one)).get() && "Expected multiply node."); // fma(c1*a,b,c2*d) -> c1*(a*b + c2/c1*d) assert(graph::multiply_cast(graph::fma(two*var_b, var_a, two*two*var_b)).get() && "Expected multiply node."); // fma(c1*a,b,c2/d) -> c1*(a*b + c1/(c2*d)) // fma(c1*a,b,d/c2) -> c1*(a*b + d/(c1*c2)) assert(graph::multiply_cast(graph::fma(two*var_b, var_a, two*two/var_b)).get() && "Expected multiply node."); assert(graph::multiply_cast(graph::fma(two*var_b, var_a, var_b/(two*two))).get() && "Expected multiply node."); // fma(a,v1,b*v2) -> (a + b*v1/v2)*v1 // fma(a,v1,c*b*v2) -> (a + c*b*v1/v2)*v1 assert(graph::multiply_cast(graph::fma(two, var_a, two*sqrt(var_a))).get() && "Expected multiply node."); assert(graph::multiply_cast(graph::fma(two, var_a, two*(var_b*sqrt(var_a)))).get() && "Expected multiply node."); } //------------------------------------------------------------------------------ Loading @@ -1325,6 +1347,47 @@ template<typename BACKEND> void test_variable_like() { "Expected sqrt(c) to not be variable like."); assert(!graph::is_variable_like(graph::pow(c, a)) && "Expected c^a to not be variable like."); assert(graph::get_argument(a)->is_match(a) && "Expected argument of a."); assert(graph::get_argument(graph::sqrt(a))->is_match(a) && "Expected argument of a."); assert(graph::get_argument(graph::pow(a, c))->is_match(a) && "Expected argument of a."); assert(graph::is_same_variable_like(a, graph::sqrt(a)) && "Expected same."); assert(graph::is_same_variable_like(graph::sqrt(a), a) && "Expected same."); assert(graph::is_same_variable_like(a, graph::pow(a, c)) && "Expected same."); assert(graph::is_same_variable_like(graph::pow(a, c), a) && "Expected same."); assert(graph::is_same_variable_like(graph::sqrt(a), graph::pow(a, c)) && "Expected same."); assert(graph::is_same_variable_like(graph::pow(a, c), graph::sqrt(a)) && "Expected same."); assert(!graph::is_same_variable_like(graph::pow(c, a), graph::sqrt(a)) && "Expected different."); auto b = graph::variable<BACKEND> (1, ""); assert(!graph::is_same_variable_like(a, graph::sqrt(b)) && "Expected different."); assert(!graph::is_same_variable_like(graph::sqrt(a), b) && "Expected different."); assert(!graph::is_same_variable_like(a, graph::pow(b, c)) && "Expected different."); assert(!graph::is_same_variable_like(graph::pow(a, c), b) && "Expected different."); assert(!graph::is_same_variable_like(graph::sqrt(a), graph::pow(b, c)) && "Expected different."); assert(!graph::is_same_variable_like(graph::pow(a, c), graph::sqrt(b)) && "Expected different."); } //------------------------------------------------------------------------------ Loading Loading
graph_driver/xrays.cpp +8 −8 Original line number Diff line number Diff line Loading @@ -30,8 +30,8 @@ static base solution(const base t) { /// @param[in] argv Array of commandline arguments. //------------------------------------------------------------------------------ int main(int argc, const char * argv[]) { typedef std::complex<double> base; //typedef double base; //typedef std::complex<double> base; typedef double base; //typedef float base; //typedef std::complex<float> base; typedef backend::cpu<base> cpu; Loading @@ -39,8 +39,8 @@ int main(int argc, const char * argv[]) { const std::chrono::high_resolution_clock::time_point start = std::chrono::high_resolution_clock::now(); const size_t num_times = 10000; const size_t num_rays = 1; //const size_t num_rays = 10000; //const size_t num_rays = 1; const size_t num_rays = 10000; std::vector<std::thread> threads(std::max(std::min(std::thread::hardware_concurrency(), static_cast<unsigned int> (num_rays)), Loading Loading @@ -69,7 +69,7 @@ int main(int argc, const char * argv[]) { // Inital conditions. for (size_t j = 0; j < local_num_rays; j++) { omega->set(j, 600.0); omega->set(j, 500.0); } x->set(backend::base_cast<cpu> (0.0)); Loading @@ -85,10 +85,10 @@ int main(int argc, const char * argv[]) { //solver::split_simplextic<dispersion::bohm_gross<cpu>> //solver::rk4<dispersion::bohm_gross<cpu>> //solver::rk4<dispersion::simple<cpu>> //solver::rk4<dispersion::ordinary_wave<cpu>> solver::rk4<dispersion::extra_ordinary_wave<cpu>> solver::rk4<dispersion::ordinary_wave<cpu>> //solver::rk4<dispersion::extra_ordinary_wave<cpu>> //solver::rk4<dispersion::cold_plasma<cpu>> solve(omega, kx, ky, kz, x, y, z, t, 30.0/num_times, eq); solve(omega, kx, ky, kz, x, y, z, t, 60.0/num_times, eq); solve.init(kx); if (thread_number == 0) { solve.print_dispersion(); Loading
graph_framework.xcodeproj/project.xcworkspace/xcuserdata/m4c.xcuserdatad/UserInterfaceState.xcuserstate +2.39 KiB (172 KiB) File changed.No diff preview for this file type. View original file View changed file
graph_framework.xcodeproj/xcuserdata/m4c.xcuserdatad/xcschemes/xcschememanagement.plist +7 −7 Original line number Diff line number Diff line Loading @@ -7,17 +7,17 @@ <key>arithmetic_test.xcscheme_^#shared#^_</key> <dict> <key>orderHint</key> <integer>7</integer> <integer>8</integer> </dict> <key>backend_test.xcscheme_^#shared#^_</key> <dict> <key>orderHint</key> <integer>4</integer> <integer>9</integer> </dict> <key>dispersion_test.xcscheme_^#shared#^_</key> <dict> <key>orderHint</key> <integer>9</integer> <integer>3</integer> </dict> <key>graph_driver.xcscheme_^#shared#^_</key> <dict> Loading @@ -27,7 +27,7 @@ <key>graph_framework.xcscheme_^#shared#^_</key> <dict> <key>orderHint</key> <integer>5</integer> <integer>6</integer> </dict> <key>graph_tests.xcscheme_^#shared#^_</key> <dict> Loading @@ -42,7 +42,7 @@ <key>node_test.xcscheme_^#shared#^_</key> <dict> <key>orderHint</key> <integer>3</integer> <integer>4</integer> </dict> <key>physics_test.xcscheme_^#shared#^_</key> <dict> Loading @@ -52,12 +52,12 @@ <key>solver_test.xcscheme_^#shared#^_</key> <dict> <key>orderHint</key> <integer>6</integer> <integer>7</integer> </dict> <key>vector_test.xcscheme_^#shared#^_</key> <dict> <key>orderHint</key> <integer>8</integer> <integer>5</integer> </dict> </dict> <key>SuppressBuildableAutocreation</key> Loading
graph_framework/arithmetic.hpp +122 −22 Original line number Diff line number Diff line Loading @@ -26,6 +26,43 @@ namespace graph { (pow_cast(a).get() && variable_cast(pow_cast(a)->get_left()).get()); } //------------------------------------------------------------------------------ /// @brief Get the argument of a variable like object. /// /// @param[in] a Expression to check. /// @returns The agument of a. //------------------------------------------------------------------------------ template<typename N> std::shared_ptr<N> get_argument(std::shared_ptr<N> a) { if (variable_cast(a).get()) { return a; } else if (sqrt_cast(a).get() && variable_cast(sqrt_cast(a)->get_arg()).get()) { return sqrt_cast(a)->get_arg(); } else if (pow_cast(a).get() && variable_cast(pow_cast(a)->get_left()).get()) { return pow_cast(a)->get_left(); } assert(false && "Should never reach this point."); return nullptr; } //------------------------------------------------------------------------------ /// @brief Check variable like objects are the same. /// /// @param[in] a Expression to check. /// @param[in] b Expression to check. /// @returns True if a is variable like. //------------------------------------------------------------------------------ template<typename N> bool is_same_variable_like(std::shared_ptr<N> a, std::shared_ptr<N> b) { return is_variable_like(a) && is_variable_like(b) && get_argument(a)->is_match(get_argument(b)); } //****************************************************************************** // Add node. //****************************************************************************** Loading Loading @@ -114,6 +151,14 @@ namespace graph { } else if (lm->get_right()->is_match(rm->get_right())) { return lm->get_right()*(lm->get_left() + rm->get_left()); } // Change cases like c1*a + c2*b -> c1*(a + c2*b) auto lmc = constant_cast(lm->get_left()); auto rmc = constant_cast(rm->get_left()); if (lmc.get() && rmc.get()) { return lm->get_left()*(lm->get_right() + (rm->get_left()/lm->get_left())*rm->get_right()); } } // Common denominator reduction. If the left and right are both divide nodes Loading Loading @@ -407,6 +452,14 @@ namespace graph { } else if (lm->get_right()->is_match(rm->get_right())) { return lm->get_right()*(lm->get_left() - rm->get_left()); } // Change cases like c1*a - c2*b -> c1*(a - c2*b) auto lmc = constant_cast(lm->get_left()); auto rmc = constant_cast(rm->get_left()); if (lmc.get() && rmc.get()) { return lm->get_left()*(lm->get_right() - (rm->get_left()/lm->get_left())*rm->get_right()); } } // Common denominator reduction. If the left and right are both divide nodes Loading Loading @@ -670,20 +723,6 @@ namespace graph { return this->right*this->left; } // Reduce constants multiplied by fused multiply add nodes. auto rfma = fma_cast(this->right); if (l.get() && rfma.get()) { return fma(this->left*rfma->get_left(), rfma->get_middle(), this->left*rfma->get_right()); } auto lfma = fma_cast(this->left); if (r.get() && lfma.get()) { return fma(this->right*lfma->get_left(), lfma->get_middle(), this->right*lfma->get_right()); } // Reduce x*x to x^2 if (this->left->is_match(this->right)) { return pow(this->left, constant<typename LN::backend> (2.0)); Loading @@ -702,6 +741,12 @@ namespace graph { return (this->right*lm->get_right())*lm->get_left(); } // Promote constants before variables. // (c*v1)*v2 -> c*(v1*v2) if (constant_cast(lm->get_left()).get()) { return lm->get_left()*(lm->get_right()*this->right); } // Assume variables, sqrt of variables, and powers of variables are on the // right. // (a*v)*b -> a*(v*b) Loading @@ -727,11 +772,8 @@ namespace graph { } // v1*(c*v2) -> c*(v1*v2) // (c*v1)*v2 -> c*(v1*v2) if (rm.get() && constant_cast(rm->get_left()).get()) { return rm->get_left()*(this->left*rm->get_right()); } else if (lm.get() && constant_cast(lm->get_left()).get()) { return lm->get_left()*(lm->get_right()*this->right); } // Factor out common constants c*b*c*d -> c*c*b*d. c*c will get reduced to c on Loading Loading @@ -1012,6 +1054,12 @@ namespace graph { return constant(this->evaluate()); } // Reduce cases of a/c1 -> c2*a if (r.get()) { return (constant<typename LN::backend> (1)/this->right) * this->left; } // Reduce fused multiply divided by constant nodes. auto lfma = fma_cast(this->left); if (r.get() && lfma.get()) { Loading @@ -1026,14 +1074,14 @@ namespace graph { // Assume constants are always on the left. // c1/(c2*v) -> c3/v // (c1*v)/c2 -> v/c3 // (c1*v)/c2 -> c3*v if (rm.get() && l.get()) { if (constant_cast(rm->get_left()).get()) { return (this->left/rm->get_left())/rm->get_right(); } } else if (lm.get() && r.get()) { if (constant_cast(lm->get_left()).get()) { return lm->get_right()/(this->right/lm->get_left()); return (lm->get_left()/this->right)*lm->get_right(); } } Loading Loading @@ -1295,7 +1343,6 @@ namespace graph { // Common factor reduction. If the left and right are both multiply nodes check // for a common factor. So you can change a*b + (a*c) -> a*(b + c). auto rm = multiply_cast(this->right); if (rm.get()) { if (rm->get_left()->is_match(this->left)) { return this->left*(this->middle + rm->get_right()); Loading @@ -1308,6 +1355,45 @@ namespace graph { } } // Handle cases like. // fma(c1*a,b,c2*d) -> c1*(a*b + c2/c1*d) auto lm = multiply_cast(this->left); if (lm.get() && rm.get()) { auto rmc = constant_cast(rm->get_left()); if (rmc.get()) { return lm->get_left()*fma(lm->get_right(), this->middle, (rm->get_left()/lm->get_left())*rm->get_right()); } } // fma(c1*a,b,c2/d) -> c1*(a*b + c1/(c2*d)) // fma(c1*a,b,d/c2) -> c1*(a*b + d/(c1*c2)) auto rd = divide_cast(this->right); if (lm.get() && rd.get()) { if (constant_cast(rd->get_left()).get() || constant_cast(rd->get_right()).get()) { return lm->get_left()*fma(lm->get_right(), this->middle, rd->get_left()/(lm->get_left()*rd->get_right())); } } // Handle cases like. // fma(a,v1,b*v2) -> (a + b*v1/v2)*v1 // fma(a,v1,c*b*v2) -> (a + c*b*v1/v2)*v1 if (rm.get()) { if (is_same_variable_like(this->middle, rm->get_right())) { return (this->left + rm->get_left()*this->middle/rm->get_right()) * this->middle*rm->get_right(); } auto rmm = multiply_cast(rm->get_right()); if (rmm.get() && is_same_variable_like(this->middle, rmm->get_right())) { return (this->left + rm->get_left()*rmm->get_left()*this->middle/rmm->get_right()) * this->middle; } } // Promote constants out to the left. if (l.get() && r.get()) { return this->left*(this->middle + this->right/this->left); Loading Loading @@ -1365,9 +1451,23 @@ namespace graph { //------------------------------------------------------------------------------ virtual void to_latex() const final { std::cout << "\\left("; if (add_cast(this->left).get() || subtract_cast(this->left).get()) { std::cout << "\\left("; this->left->to_latex(); std::cout << "\\right)"; } else { this->left->to_latex(); } std::cout << " "; if (add_cast(this->right).get() || subtract_cast(this->right).get()) { std::cout << "\\left("; this->middle->to_latex(); std::cout << "\\right)"; } else { this->middle->to_latex(); } std::cout << "+"; this->right->to_latex(); std::cout << "\\right)"; Loading
graph_tests/arithmetic_test.cpp +114 −51 Original line number Diff line number Diff line Loading @@ -203,6 +203,10 @@ template<typename BACKEND> void test_add() { assert(common_d_acast.get() && "Expected add node."); assert(graph::constant_cast(common_d_acast->get_left()).get() && "Expected constant on the left."); // c1*a + c2*b -> c1*(a + c2*b) assert(graph::multiply_cast(three*variable + (one + one)*var_b).get() && "Expected multilpy node."); } //------------------------------------------------------------------------------ Loading Loading @@ -374,11 +378,11 @@ template<typename BACKEND> void test_subtract() { auto negate_cast = add_cast(negate); assert(negate_cast.get() && "Expected addition node."); // (c1*v1 + c2) - (c3*v1 + c4) -> c5*v1 - c6 // (c1*v1 + c2) - (c3*v1 + c4) -> c5*(v1 - c6) auto three = graph::constant<BACKEND> (3); auto subfma = graph::fma(three, var_a, two) + graph::fma(two, var_a, three); assert(graph::fma_cast(subfma).get() && "Expected fused multiply add node."); assert(graph::multiply_cast(subfma).get() && "Expected a multiply node."); // Test cases like // (c1 + c2/x) - c3/x -> c1 + c4/x Loading @@ -394,6 +398,10 @@ template<typename BACKEND> void test_subtract() { assert(common_d_scast.get() && "Expected subtract node."); assert(graph::constant_cast(common_d_scast->get_left()).get() && "Expected constant on the left."); // c1*a - c2*b -> c1*(a - c2*b) assert(graph::multiply_cast(three*var_a - (one + one)*var_b).get() && "Expected multilpy node."); } //------------------------------------------------------------------------------ Loading Loading @@ -626,21 +634,6 @@ template<typename BACKEND> void test_multiply() { assert(graph::constant_cast(gather_v5_cast->get_right())->is(2) && "Expected power of 2."); // Test reduction of a constant*fma node. auto c = graph::constant<BACKEND> (3); auto fma = graph::fma(graph::variable<BACKEND> (1, ""), graph::variable<BACKEND> (1, ""), graph::variable<BACKEND> (1, "")); auto cfma = graph::fma_cast(c*fma); assert(cfma.get() && "Expected fma node."); assert(graph::multiply_cast(cfma->get_right()).get() && "Expected multiply node in add branch."); auto fmac = graph::fma_cast(fma*c); assert(fmac.get() && "Expected fma node."); assert(graph::multiply_cast(fmac->get_right()).get() && "Expected multiply node in add branch."); // Test gather of terms. This test is setup to trigger an infinite recursive // loop if a critical check is not in place no need to check the values. auto a = graph::variable<BACKEND> (1, ""); Loading Loading @@ -684,30 +677,30 @@ template<typename BACKEND> void test_multiply() { // Test c1*(v/c2) -> c6*v auto c6 = two*(a/three); auto c6_cast = graph::divide_cast(c6); assert(c6_cast.get() && "Expected divide node."); assert(graph::constant_cast(c6_cast->get_right()) && "Expected constant in the denominator."); assert(graph::variable_cast(c6_cast->get_left()) && "Expected variable in the numerator."); // Test (c2/v)*c1 -> c7*v auto c6_cast = graph::multiply_cast(c6); assert(c6_cast.get() && "Expected multiply node."); assert(graph::constant_cast(c6_cast->get_left()) && "Expected constant for the left."); assert(graph::variable_cast(c6_cast->get_right()) && "Expected variable for the right."); // Test (c2/v)*c1 -> c7/v auto c7 = (three/a)*two; auto c7_cast = graph::divide_cast(c7); assert(c7_cast.get() && "Expected divide node."); assert(graph::constant_cast(c7_cast->get_left()) && "Expected constant in the numerator."); "Expected constant for the numerator."); assert(graph::variable_cast(c7_cast->get_right()) && "Expected variable in the denominator."); "Expected variable for the denominator."); // Test (v/c2)*c1 -> c8*v auto c8 = two*(a/three); auto c8_cast = graph::divide_cast(c8); auto c8_cast = graph::multiply_cast(c8); assert(c8_cast.get() && "Expected divide node."); assert(graph::constant_cast(c8_cast->get_right()) && "Expected constant in the denominator."); assert(graph::variable_cast(c8_cast->get_left()) && "Expected variable in the numerator."); assert(graph::constant_cast(c8_cast->get_left()) && "Expected constant for the left."); assert(graph::variable_cast(c8_cast->get_right()) && "Expected variable for the right."); // Test v1*(c*v2) -> c*(v1*v2) auto c9 = a*(three*variable); Loading Loading @@ -891,9 +884,10 @@ template<typename BACKEND> void test_divide() { backend::base_cast<BACKEND> (3.0) && "Expected 2/3 for result."); // v/c1 -> (1/c1)*v -> c2*v auto var_divided_two = variable/two; assert(graph::divide_cast(var_divided_two).get() && "Expected divide node."); assert(graph::multiply_cast(var_divided_two).get() && "Expected multiply node."); const BACKEND var_divided_two_result = var_divided_two->evaluate(); assert(var_divided_two_result.size() == 1 && "Expected single value."); assert(var_divided_two_result.at(0) == backend::base_cast<BACKEND> (3.0) / Loading Loading @@ -922,8 +916,8 @@ template<typename BACKEND> void test_divide() { "Expected to recover numerator."); auto varvec_divided_two = varvec/two; assert(graph::divide_cast(varvec_divided_two).get() && "Expect divide node."); assert(graph::multiply_cast(varvec_divided_two).get() && "Expect mutliply node."); const BACKEND varvec_divided_two_result = varvec_divided_two->evaluate(); assert(varvec_divided_two_result.size() == 2 && "Size mismatch in result."); assert(varvec_divided_two_result.at(0) == backend::base_cast<BACKEND> (1.0) && Loading Loading @@ -1071,23 +1065,23 @@ template<typename BACKEND> void test_divide() { assert(graph::variable_cast(c4_cast->get_right()).get() && "Expected a variable in the denominator"); // (c1*v)/c2 -> v/c5 // (c1*v)/c2 -> c5*v auto c5 = (two*variable)/three; auto c5_cast = graph::divide_cast(c5); assert(c5_cast.get() && "Expected divide node"); assert(graph::variable_cast(c5_cast->get_left()).get() && "Expected a variable in numerator."); assert(graph::constant_cast(c5_cast->get_right()).get() && "Expected a constant in the denominator"); // (v*c1)/c2 -> v/c5 auto c5_cast = graph::multiply_cast(c5); assert(c5_cast.get() && "Expected multiply node"); assert(graph::constant_cast(c5_cast->get_left()).get() && "Expected a constant in the numerator"); assert(graph::variable_cast(c5_cast->get_right()).get() && "Expected a variable in the denominator."); // (v*c1)/c2 -> c5*v auto c6 = (variable*two)/three; auto c6_cast = graph::divide_cast(c6); assert(c6_cast.get() && "Expected divide node"); assert(graph::variable_cast(c6_cast->get_left()).get() && "Expected a variable in numerator."); assert(graph::constant_cast(c6_cast->get_right()).get() && "Expected a constant in the denominator"); auto c6_cast = graph::multiply_cast(c6); assert(c6_cast.get() && "Expected multiply node"); assert(graph::constant_cast(c6_cast->get_left()).get() && "Expected a constant in the numerator"); assert(graph::variable_cast(c6_cast->get_right()).get() && "Expected a variable in the denominator."); // (c*v1)/v2 -> c*(v1/v2) auto a = graph::variable<BACKEND> (1, ""); Loading Loading @@ -1304,6 +1298,34 @@ template<typename BACKEND> void test_fma() { assert(graph::multiply_cast(graph::fma(two, var_a, one)).get() && "Expected multiply node."); // fma(c1*a,b,c2*d) -> c1*(a*b + c2/c1*d) assert(graph::multiply_cast(graph::fma(two*var_b, var_a, two*two*var_b)).get() && "Expected multiply node."); // fma(c1*a,b,c2/d) -> c1*(a*b + c1/(c2*d)) // fma(c1*a,b,d/c2) -> c1*(a*b + d/(c1*c2)) assert(graph::multiply_cast(graph::fma(two*var_b, var_a, two*two/var_b)).get() && "Expected multiply node."); assert(graph::multiply_cast(graph::fma(two*var_b, var_a, var_b/(two*two))).get() && "Expected multiply node."); // fma(a,v1,b*v2) -> (a + b*v1/v2)*v1 // fma(a,v1,c*b*v2) -> (a + c*b*v1/v2)*v1 assert(graph::multiply_cast(graph::fma(two, var_a, two*sqrt(var_a))).get() && "Expected multiply node."); assert(graph::multiply_cast(graph::fma(two, var_a, two*(var_b*sqrt(var_a)))).get() && "Expected multiply node."); } //------------------------------------------------------------------------------ Loading @@ -1325,6 +1347,47 @@ template<typename BACKEND> void test_variable_like() { "Expected sqrt(c) to not be variable like."); assert(!graph::is_variable_like(graph::pow(c, a)) && "Expected c^a to not be variable like."); assert(graph::get_argument(a)->is_match(a) && "Expected argument of a."); assert(graph::get_argument(graph::sqrt(a))->is_match(a) && "Expected argument of a."); assert(graph::get_argument(graph::pow(a, c))->is_match(a) && "Expected argument of a."); assert(graph::is_same_variable_like(a, graph::sqrt(a)) && "Expected same."); assert(graph::is_same_variable_like(graph::sqrt(a), a) && "Expected same."); assert(graph::is_same_variable_like(a, graph::pow(a, c)) && "Expected same."); assert(graph::is_same_variable_like(graph::pow(a, c), a) && "Expected same."); assert(graph::is_same_variable_like(graph::sqrt(a), graph::pow(a, c)) && "Expected same."); assert(graph::is_same_variable_like(graph::pow(a, c), graph::sqrt(a)) && "Expected same."); assert(!graph::is_same_variable_like(graph::pow(c, a), graph::sqrt(a)) && "Expected different."); auto b = graph::variable<BACKEND> (1, ""); assert(!graph::is_same_variable_like(a, graph::sqrt(b)) && "Expected different."); assert(!graph::is_same_variable_like(graph::sqrt(a), b) && "Expected different."); assert(!graph::is_same_variable_like(a, graph::pow(b, c)) && "Expected different."); assert(!graph::is_same_variable_like(graph::pow(a, c), b) && "Expected different."); assert(!graph::is_same_variable_like(graph::sqrt(a), graph::pow(b, c)) && "Expected different."); assert(!graph::is_same_variable_like(graph::pow(a, c), graph::sqrt(b)) && "Expected different."); } //------------------------------------------------------------------------------ Loading