From d2e7bef7d86582b974616b6eb0cff010332e0c51 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Thu, 2 Jan 2025 16:10:47 -0500 Subject: [PATCH 1/6] Add initial implimentation for graph korc. --- CMakeLists.txt | 1 + graph_framework/equilibrium.hpp | 260 ++++++++++++++++++++++++-------- graph_framework/newton.hpp | 10 +- graph_framework/workflow.hpp | 31 ++++ 4 files changed, 237 insertions(+), 65 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 9e2865f..6e6f3d2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -328,6 +328,7 @@ endmacro () add_subdirectory (graph_driver) add_subdirectory (graph_benchmark) add_subdirectory (graph_playground) +add_subdirectory (graph_korc) #------------------------------------------------------------------------------- # Define macro function to register tests. diff --git a/graph_framework/equilibrium.hpp b/graph_framework/equilibrium.hpp index 4701f7b..f44369f 100644 --- a/graph_framework/equilibrium.hpp +++ b/graph_framework/equilibrium.hpp @@ -17,6 +17,7 @@ #include "piecewise.hpp" #include "math.hpp" #include "arithmetic.hpp" +#include "newton.hpp" namespace equilibrium { /// Lock to syncronize netcdf accross threads. @@ -156,6 +157,15 @@ namespace equilibrium { graph::shared_leaf y, graph::shared_leaf z) = 0; +//------------------------------------------------------------------------------ +/// @brief Get the characteristic field. +/// +/// The characteristic field is equilibrium dependent. +/// +/// @returns The characteristic field. +//------------------------------------------------------------------------------ + virtual graph::shared_leaf get_characteristic_field() = 0; + //------------------------------------------------------------------------------ /// @brief Get the contravariant basis vector in the x1 direction. /// @@ -358,6 +368,18 @@ namespace equilibrium { auto zero = graph::zero (); return graph::vector(zero, zero, zero); } + +//------------------------------------------------------------------------------ +/// @brief Get the characteristic field. +/// +/// To avoid divide by zeros use the value of 1. +/// +/// @returns The characteristic field. +//------------------------------------------------------------------------------ + virtual graph::shared_leaf + get_characteristic_field() final { + return graph::one (); + } }; //------------------------------------------------------------------------------ @@ -469,6 +491,18 @@ namespace equilibrium { graph::shared_leaf z) final { return graph::vector(0.0, 0.0, 0.1*x + 1.0); } + +//------------------------------------------------------------------------------ +/// @brief Get the characteristic field. +/// +/// Use the value at the y intercept. +/// +/// @returns The characteristic field. +//------------------------------------------------------------------------------ + virtual graph::shared_leaf + get_characteristic_field() final { + return graph::one (); + } }; //------------------------------------------------------------------------------ @@ -585,6 +619,18 @@ namespace equilibrium { auto zero = graph::zero (); return graph::vector(zero, zero, graph::one ()); } + +//------------------------------------------------------------------------------ +/// @brief Get the characteristic field. +/// +/// Use the value at the y intercept. +/// +/// @returns The characteristic field. +//------------------------------------------------------------------------------ + virtual graph::shared_leaf + get_characteristic_field() final { + return graph::one (); + } }; //------------------------------------------------------------------------------ @@ -700,6 +746,18 @@ namespace equilibrium { graph::shared_leaf z) final { return graph::vector(0.0, 0.0, 0.01*x + 1.0); } + +//------------------------------------------------------------------------------ +/// @brief Get the characteristic field. +/// +/// Use the value at the y intercept. +/// +/// @returns The characteristic field. +//------------------------------------------------------------------------------ + virtual graph::shared_leaf + get_characteristic_field() final { + return graph::one (); + } }; //------------------------------------------------------------------------------ @@ -813,6 +871,18 @@ namespace equilibrium { auto zero = graph::zero (); return graph::vector(graph::one (), zero, zero); } + +//------------------------------------------------------------------------------ +/// @brief Get the characteristic field. +/// +/// Use the value at the y intercept. +/// +/// @returns The characteristic field. +//------------------------------------------------------------------------------ + virtual graph::shared_leaf + get_characteristic_field() final { + return graph::one (); + } }; //------------------------------------------------------------------------------ @@ -959,6 +1029,57 @@ namespace equilibrium { /// Cached magnetic field vector. graph::shared_vector b_cache; +/// Cached magnetic field vector. + graph::shared_leaf psi_norm_cache; + +//------------------------------------------------------------------------------ +/// @brief Build psi. +/// +/// @param[in] r_norm The normalized radial position. +/// @param[in] z_norm The normalized z position. +/// @returns The psi value. +//------------------------------------------------------------------------------ + graph::shared_leaf + build_psi(graph::shared_leaf r_norm, + graph::shared_leaf z_norm) { + auto c00_temp = graph::piecewise_2D(c00, num_cols, r_norm, z_norm); + auto c01_temp = graph::piecewise_2D(c01, num_cols, r_norm, z_norm); + auto c02_temp = graph::piecewise_2D(c02, num_cols, r_norm, z_norm); + auto c03_temp = graph::piecewise_2D(c03, num_cols, r_norm, z_norm); + + auto c10_temp = graph::piecewise_2D(c10, num_cols, r_norm, z_norm); + auto c11_temp = graph::piecewise_2D(c11, num_cols, r_norm, z_norm); + auto c12_temp = graph::piecewise_2D(c12, num_cols, r_norm, z_norm); + auto c13_temp = graph::piecewise_2D(c13, num_cols, r_norm, z_norm); + + auto c20_temp = graph::piecewise_2D(c20, num_cols, r_norm, z_norm); + auto c21_temp = graph::piecewise_2D(c21, num_cols, r_norm, z_norm); + auto c22_temp = graph::piecewise_2D(c22, num_cols, r_norm, z_norm); + auto c23_temp = graph::piecewise_2D(c23, num_cols, r_norm, z_norm); + + auto c30_temp = graph::piecewise_2D(c30, num_cols, r_norm, z_norm); + auto c31_temp = graph::piecewise_2D(c31, num_cols, r_norm, z_norm); + auto c32_temp = graph::piecewise_2D(c32, num_cols, r_norm, z_norm); + auto c33_temp = graph::piecewise_2D(c33, num_cols, r_norm, z_norm); + + return c00_temp + + c01_temp*z_norm + + c02_temp*(z_norm*z_norm) + + c03_temp*(z_norm*z_norm*z_norm) + + c10_temp*r_norm + + c11_temp*r_norm*z_norm + + c12_temp*r_norm*(z_norm*z_norm) + + c13_temp*r_norm*(z_norm*z_norm*z_norm) + + c20_temp*(r_norm*r_norm) + + c21_temp*(r_norm*r_norm)*z_norm + + c22_temp*(r_norm*r_norm)*(z_norm*z_norm) + + c23_temp*(r_norm*r_norm)*(z_norm*z_norm*z_norm) + + c30_temp*(r_norm*r_norm*r_norm) + + c31_temp*(r_norm*r_norm*r_norm)*z_norm + + c32_temp*(r_norm*r_norm*r_norm)*(z_norm*z_norm) + + c33_temp*(r_norm*r_norm*r_norm)*(z_norm*z_norm*z_norm); + } + //------------------------------------------------------------------------------ /// @brief Set cache values. /// @@ -979,78 +1100,41 @@ namespace equilibrium { z_cache = z; auto r = graph::sqrt(x*x + y*y); - auto r_norm = (r - rmin)/dr; auto z_norm = (z - zmin)/dz; - auto c00_temp = graph::piecewise_2D(c00, num_cols, r_norm, z_norm); - auto c01_temp = graph::piecewise_2D(c01, num_cols, r_norm, z_norm); - auto c02_temp = graph::piecewise_2D(c02, num_cols, r_norm, z_norm); - auto c03_temp = graph::piecewise_2D(c03, num_cols, r_norm, z_norm); - - auto c10_temp = graph::piecewise_2D(c10, num_cols, r_norm, z_norm); - auto c11_temp = graph::piecewise_2D(c11, num_cols, r_norm, z_norm); - auto c12_temp = graph::piecewise_2D(c12, num_cols, r_norm, z_norm); - auto c13_temp = graph::piecewise_2D(c13, num_cols, r_norm, z_norm); - - auto c20_temp = graph::piecewise_2D(c20, num_cols, r_norm, z_norm); - auto c21_temp = graph::piecewise_2D(c21, num_cols, r_norm, z_norm); - auto c22_temp = graph::piecewise_2D(c22, num_cols, r_norm, z_norm); - auto c23_temp = graph::piecewise_2D(c23, num_cols, r_norm, z_norm); - - auto c30_temp = graph::piecewise_2D(c30, num_cols, r_norm, z_norm); - auto c31_temp = graph::piecewise_2D(c31, num_cols, r_norm, z_norm); - auto c32_temp = graph::piecewise_2D(c32, num_cols, r_norm, z_norm); - auto c33_temp = graph::piecewise_2D(c33, num_cols, r_norm, z_norm); - - auto psi = c00_temp - + c01_temp*z_norm - + c02_temp*(z_norm*z_norm) - + c03_temp*(z_norm*z_norm*z_norm) - + c10_temp*r_norm - + c11_temp*r_norm*z_norm - + c12_temp*r_norm*(z_norm*z_norm) - + c13_temp*r_norm*(z_norm*z_norm*z_norm) - + c20_temp*(r_norm*r_norm) - + c21_temp*(r_norm*r_norm)*z_norm - + c22_temp*(r_norm*r_norm)*(z_norm*z_norm) - + c23_temp*(r_norm*r_norm)*(z_norm*z_norm*z_norm) - + c30_temp*(r_norm*r_norm*r_norm) - + c31_temp*(r_norm*r_norm*r_norm)*z_norm - + c32_temp*(r_norm*r_norm*r_norm)*(z_norm*z_norm) - + c33_temp*(r_norm*r_norm*r_norm)*(z_norm*z_norm*z_norm); - - auto psi_norm = (psi - psimin)/dpsi; - - auto n0_temp = graph::piecewise_1D(ne_c0, psi_norm); - auto n1_temp = graph::piecewise_1D(ne_c1, psi_norm); - auto n2_temp = graph::piecewise_1D(ne_c2, psi_norm); - auto n3_temp = graph::piecewise_1D(ne_c3, psi_norm); + auto psi = build_psi(r_norm, z_norm); + psi_norm_cache = (psi - psimin)/dpsi; + + auto n0_temp = graph::piecewise_1D(ne_c0, psi_norm_cache); + auto n1_temp = graph::piecewise_1D(ne_c1, psi_norm_cache); + auto n2_temp = graph::piecewise_1D(ne_c2, psi_norm_cache); + auto n3_temp = graph::piecewise_1D(ne_c3, psi_norm_cache); ne_cache = ne_scale*(n0_temp + - n1_temp*psi_norm + - n2_temp*psi_norm*psi_norm + - n3_temp*psi_norm*psi_norm*psi_norm); + n1_temp*psi_norm_cache + + n2_temp*psi_norm_cache*psi_norm_cache + + n3_temp*psi_norm_cache*psi_norm_cache*psi_norm_cache); - auto t0_temp = graph::piecewise_1D(te_c0, psi_norm); - auto t1_temp = graph::piecewise_1D(te_c1, psi_norm); - auto t2_temp = graph::piecewise_1D(te_c2, psi_norm); - auto t3_temp = graph::piecewise_1D(te_c3, psi_norm); + auto t0_temp = graph::piecewise_1D(te_c0, psi_norm_cache); + auto t1_temp = graph::piecewise_1D(te_c1, psi_norm_cache); + auto t2_temp = graph::piecewise_1D(te_c2, psi_norm_cache); + auto t3_temp = graph::piecewise_1D(te_c3, psi_norm_cache); te_cache = te_scale*(t0_temp + - t1_temp*psi_norm + - t2_temp*psi_norm*psi_norm + - t3_temp*psi_norm*psi_norm*psi_norm); + t1_temp*psi_norm_cache + + t2_temp*psi_norm_cache*psi_norm_cache + + t3_temp*psi_norm_cache*psi_norm_cache*psi_norm_cache); - auto p0_temp = graph::piecewise_1D(pres_c0, psi_norm); - auto p1_temp = graph::piecewise_1D(pres_c1, psi_norm); - auto p2_temp = graph::piecewise_1D(pres_c2, psi_norm); - auto p3_temp = graph::piecewise_1D(pres_c3, psi_norm); + auto p0_temp = graph::piecewise_1D(pres_c0, psi_norm_cache); + auto p1_temp = graph::piecewise_1D(pres_c1, psi_norm_cache); + auto p2_temp = graph::piecewise_1D(pres_c2, psi_norm_cache); + auto p3_temp = graph::piecewise_1D(pres_c3, psi_norm_cache); auto pressure = pres_scale*(p0_temp + - p1_temp*psi_norm + - p2_temp*psi_norm*psi_norm + - p3_temp*psi_norm*psi_norm*psi_norm); + p1_temp*psi_norm_cache + + p2_temp*psi_norm_cache*psi_norm_cache + + p3_temp*psi_norm_cache*psi_norm_cache*psi_norm_cache); auto q = graph::constant (static_cast (1.60218E-19)); @@ -1271,6 +1355,44 @@ namespace equilibrium { set_cache(x, y, z); return b_cache; } + +//------------------------------------------------------------------------------ +/// @brief Get the characteristic field. +/// +/// Use the value at the y intercept. +/// +/// @returns The characteristic field. +//------------------------------------------------------------------------------ + virtual graph::shared_leaf + get_characteristic_field() final { + auto x_axis = graph::variable (1, "x"); + auto y_axis = graph::variable (1, "y"); + auto z_axis = graph::variable (1, "z"); + x_axis->set(static_cast (1.7)); + y_axis->set(static_cast (0.0)); + z_axis->set(static_cast (0.0)); + auto b_vec = get_magnetic_field(x_axis, y_axis, z_axis); + auto b_mod = b_vec->length(); + + graph::input_nodes inputs { + graph::variable_cast(x_axis), + graph::variable_cast(y_axis), + graph::variable_cast(z_axis) + }; + + workflow::manager work(0); + solver::newton(work, { + x_axis, z_axis + }, inputs, psi_norm_cache, static_cast (1.0E-30), 1000, static_cast (0.1)); + work.add_item(inputs, {b_mod}, {}, "bmod_at_axis"); + work.compile(); + work.run(); + + T result; + work.copy_to_host(b_mod, &result); + + return graph::constant (result); + } }; //------------------------------------------------------------------------------ @@ -2014,6 +2136,22 @@ namespace equilibrium { return bvec_cache; } +//------------------------------------------------------------------------------ +/// @brief Get the characteristic field. +/// +/// Use the value at the y intercept. +/// +/// @returns The characteristic field. +//------------------------------------------------------------------------------ + virtual graph::shared_leaf + get_characteristic_field() final { + auto s_axis = graph::zero (); + auto u_axis = graph::zero (); + auto v_axis = graph::zero (); + auto b_vec = get_magnetic_field(s_axis, u_axis, v_axis); + return b_vec->length(); + } + //------------------------------------------------------------------------------ /// @brief Get the x position. /// diff --git a/graph_framework/newton.hpp b/graph_framework/newton.hpp index b801cf3..825c471 100644 --- a/graph_framework/newton.hpp +++ b/graph_framework/newton.hpp @@ -24,9 +24,10 @@ namespace solver { /// @param[in] inputs Inputs for jit compile. /// @param[in] func Function to find the root of. /// @param[in] tolarance Tolarance to solve the dispersion function -/// to. +/// to. /// @param[in] max_iterations Maximum number of iterations before giving -/// up. +/// up. +/// @param[in] step Newton step size. //------------------------------------------------------------------------------ template void newton(workflow::manager &work, @@ -34,10 +35,11 @@ namespace solver { graph::input_nodes inputs, graph::shared_leaf func, const T tolarance = 1.0E-30, - const size_t max_iterations = 1000) { + const size_t max_iterations = 1000, + const T step = 1.0) { graph::map_nodes setters; for (auto x : vars) { - setters.push_back({x - func/func->df(x), + setters.push_back({x - step*func/func->df(x), graph::variable_cast(x)}); } diff --git a/graph_framework/workflow.hpp b/graph_framework/workflow.hpp index 7c39633..370d0e0 100644 --- a/graph_framework/workflow.hpp +++ b/graph_framework/workflow.hpp @@ -160,6 +160,8 @@ namespace workflow { private: /// JIT context. jit::context context; +/// List of prework items. + std::vector>> preitems; /// List of work items. std::vector>> items; /// Use reduction. @@ -173,6 +175,23 @@ namespace workflow { //------------------------------------------------------------------------------ manager(const size_t index) : context(index), add_reduction(false) {} +//------------------------------------------------------------------------------ +/// @brief Add a pre workflow item. +/// +/// @param[in] in Input variables. +/// @param[in] out Output nodes. +/// @param[in] maps Setter maps. +/// @param[in] name Name of the workitem. +//------------------------------------------------------------------------------ + void add_preitem(graph::input_nodes in, + graph::output_nodes out, + graph::map_nodes maps, + const std::string name) { + preitems.push_back(std::make_unique> (in, out, + maps, name, + context)); + } + //------------------------------------------------------------------------------ /// @brief Add a workflow item. /// @@ -219,11 +238,23 @@ namespace workflow { void compile() { context.compile(add_reduction); + for (auto &item : preitems) { + item->create_kernel_call(context); + } for (auto &item : items) { item->create_kernel_call(context); } } +//------------------------------------------------------------------------------ +/// @brief Run prework items. +//------------------------------------------------------------------------------ + void pre_run() { + for (auto &item : preitems) { + item->run(); + } + } + //------------------------------------------------------------------------------ /// @brief Run work items. //------------------------------------------------------------------------------ -- GitLab From 36bd4ae966f0b50810174397256a82cb1eb76472 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Mon, 6 Jan 2025 14:33:09 -0500 Subject: [PATCH 2/6] Added multithreading. --- graph_framework/equilibrium.hpp | 27 ++++-- graph_korc/CMakeLists.txt | 6 ++ graph_korc/xkorc.cpp | 150 ++++++++++++++++++++++++++++++++ 3 files changed, 174 insertions(+), 9 deletions(-) create mode 100644 graph_korc/CMakeLists.txt create mode 100644 graph_korc/xkorc.cpp diff --git a/graph_framework/equilibrium.hpp b/graph_framework/equilibrium.hpp index f44369f..e97c5c5 100644 --- a/graph_framework/equilibrium.hpp +++ b/graph_framework/equilibrium.hpp @@ -162,9 +162,11 @@ namespace equilibrium { /// /// The characteristic field is equilibrium dependent. /// +/// @params[in] device_number Device to use. /// @returns The characteristic field. //------------------------------------------------------------------------------ - virtual graph::shared_leaf get_characteristic_field() = 0; + virtual graph::shared_leaf + get_characteristic_field(const size_t device_number=0) = 0; //------------------------------------------------------------------------------ /// @brief Get the contravariant basis vector in the x1 direction. @@ -374,10 +376,11 @@ namespace equilibrium { /// /// To avoid divide by zeros use the value of 1. /// +/// @params[in] device_number Device to use. /// @returns The characteristic field. //------------------------------------------------------------------------------ virtual graph::shared_leaf - get_characteristic_field() final { + get_characteristic_field(const size_t device_number=0) final { return graph::one (); } }; @@ -497,10 +500,11 @@ namespace equilibrium { /// /// Use the value at the y intercept. /// +/// @params[in] device_number Device to use. /// @returns The characteristic field. //------------------------------------------------------------------------------ virtual graph::shared_leaf - get_characteristic_field() final { + get_characteristic_field(const size_t device_number=0) final { return graph::one (); } }; @@ -625,10 +629,11 @@ namespace equilibrium { /// /// Use the value at the y intercept. /// +/// @params[in] device_number Device to use. /// @returns The characteristic field. //------------------------------------------------------------------------------ virtual graph::shared_leaf - get_characteristic_field() final { + get_characteristic_field(const size_t device_number=0) final { return graph::one (); } }; @@ -752,10 +757,11 @@ namespace equilibrium { /// /// Use the value at the y intercept. /// +/// @params[in] device_number Device to use. /// @returns The characteristic field. //------------------------------------------------------------------------------ virtual graph::shared_leaf - get_characteristic_field() final { + get_characteristic_field(const size_t device_number=0) final { return graph::one (); } }; @@ -877,10 +883,11 @@ namespace equilibrium { /// /// Use the value at the y intercept. /// +/// @params[in] device_number Device to use. /// @returns The characteristic field. //------------------------------------------------------------------------------ virtual graph::shared_leaf - get_characteristic_field() final { + get_characteristic_field(const size_t device_number=0) final { return graph::one (); } }; @@ -1361,10 +1368,11 @@ namespace equilibrium { /// /// Use the value at the y intercept. /// +/// @params[in] device_number Device to use. /// @returns The characteristic field. //------------------------------------------------------------------------------ virtual graph::shared_leaf - get_characteristic_field() final { + get_characteristic_field(const size_t device_number=0) final { auto x_axis = graph::variable (1, "x"); auto y_axis = graph::variable (1, "y"); auto z_axis = graph::variable (1, "z"); @@ -1380,7 +1388,7 @@ namespace equilibrium { graph::variable_cast(z_axis) }; - workflow::manager work(0); + workflow::manager work(device_number); solver::newton(work, { x_axis, z_axis }, inputs, psi_norm_cache, static_cast (1.0E-30), 1000, static_cast (0.1)); @@ -2141,10 +2149,11 @@ namespace equilibrium { /// /// Use the value at the y intercept. /// +/// @params[in] device_number Device to use. /// @returns The characteristic field. //------------------------------------------------------------------------------ virtual graph::shared_leaf - get_characteristic_field() final { + get_characteristic_field(const size_t device_number=0) final { auto s_axis = graph::zero (); auto u_axis = graph::zero (); auto v_axis = graph::zero (); diff --git a/graph_korc/CMakeLists.txt b/graph_korc/CMakeLists.txt new file mode 100644 index 0000000..d9ab551 --- /dev/null +++ b/graph_korc/CMakeLists.txt @@ -0,0 +1,6 @@ +add_tool_target (xkorc) + +if (${USE_PCH}) + target_precompile_headers (xrays_bench REUSE_FROM xrays) +endif () + diff --git a/graph_korc/xkorc.cpp b/graph_korc/xkorc.cpp new file mode 100644 index 0000000..8869ef1 --- /dev/null +++ b/graph_korc/xkorc.cpp @@ -0,0 +1,150 @@ +#include "../graph_framework/equilibrium.hpp" +#include "../graph_framework/timing.hpp" + +//------------------------------------------------------------------------------ +/// @brief Main program of the driver. +/// +/// @param[in] argc Number of commandline arguments. +/// @param[in] argv Array of commandline arguments. +//------------------------------------------------------------------------------ +int main(int argc, const char * argv[]) { + START_GPU + (void)argc; + (void)argv; + + const timeing::measure_diagnostic t_total("Total Time"); + + const size_t num_particles = 1000000; + std::cout << "Num particles " << num_particles << std::endl; + std::vector threads(std::max(std::min(static_cast (jit::context::max_concurrency()), + static_cast (num_particles)), + static_cast (1))); + + const size_t batch = num_particles/threads.size(); + const size_t extra = num_particles%threads.size(); + + for (size_t i = 0, ie = threads.size(); i < ie; i++) { + threads[i] = std::thread([num_particles, batch, extra] (const size_t thread_number) -> void { + const size_t local_num_particles = batch + (extra > thread_number ? 1 : 0); + + const timeing::measure_diagnostic t_setup("Setup Time"); + + auto eq = equilibrium::make_efit (EFIT_FILE); + //auto eq = equilibrium::make_slab_density (); + auto b0 = eq->get_characteristic_field(thread_number); + const double q = 1.602176634E-19; + const double me = 9.1093837139E-31; + const double c = 299792458.0; + + auto gryo_period = me/(q*b0); + std::cout << "gryo_period " << gryo_period->evaluate().at(0) << std::endl; + auto larmor_radius = c*gryo_period; + std::cout << "larmor_radius " << larmor_radius->evaluate().at(0) << std::endl; + + std::cout << "Local num particles " << local_num_particles << std::endl; + + auto ux = graph::variable (local_num_particles, "u_{x}"); + auto uy = graph::variable (local_num_particles, "u_{y}"); + auto uz = graph::variable (local_num_particles, "u_{z}"); + + ux->set(0.99); + uy->set(0.0); + uz->set(0.0); + + auto x = graph::variable (local_num_particles, "x"); + auto y = graph::variable (local_num_particles, "y"); + auto z = graph::variable (local_num_particles, "z"); + auto pos = graph::vector(x, y, z); + + x->set(1.7); + y->set(0.0); + z->set(0.0); + + auto u_vec = graph::vector(ux, uy, uz); + + auto gamma = graph::variable (local_num_particles, "\\gamma"); + + auto dt = graph::constant (0.1); + + auto gamma_init = graph::sqrt(1.0 - ux*ux - uy*uy - uz*uz); + + auto u_init = gamma_init*u_vec; + + workflow::manager work(0); + work.add_preitem({ + graph::variable_cast(x), + graph::variable_cast(y), + graph::variable_cast(z), + graph::variable_cast(ux), + graph::variable_cast(uy), + graph::variable_cast(uz), + graph::variable_cast(gamma) + }, {}, { + {u_init->get_x(), graph::variable_cast(ux)}, + {u_init->get_y(), graph::variable_cast(uy)}, + {u_init->get_z(), graph::variable_cast(uz)}, + {gamma_init, graph::variable_cast(gamma)} + }, "initalize_gamma"); + + auto pos_next = pos + larmor_radius*dt*u_vec/gamma; + + auto b_vec = eq->get_magnetic_field(pos_next->get_x(), + pos_next->get_y(), + pos_next->get_z())/b0; + + auto u_prime = u_vec + dt*u_vec->cross(b_vec)/(2.0*gamma); + + auto tau = dt*0.5*b_vec; + auto tau_sq = tau->dot(tau); + auto speed_sq = u_vec->dot(u_vec); + auto sigma = 1.0 + speed_sq - tau_sq; + auto ustar = u_vec->dot(tau); + + auto gamma_next = graph::sqrt(0.5*(sigma + graph::sqrt(sigma*sigma + 4.0*(tau_sq + ustar*ustar)))); + auto t = tau/gamma_next; + + auto s = 1.0/(1.0 + t->dot(t)); + auto u_prime_dot_t = u_prime->dot(t); + + auto u_next = s*(u_prime + u_prime_dot_t*t + u_prime->cross(t)); + + work.add_item({ + graph::variable_cast(x), + graph::variable_cast(y), + graph::variable_cast(z), + graph::variable_cast(ux), + graph::variable_cast(uy), + graph::variable_cast(uz), + graph::variable_cast(gamma) + }, {}, { + {pos_next->get_x(), graph::variable_cast(x)}, + {pos_next->get_y(), graph::variable_cast(y)}, + {pos_next->get_z(), graph::variable_cast(z)}, + {u_next->get_x(), graph::variable_cast(ux)}, + {u_next->get_y(), graph::variable_cast(uy)}, + {u_next->get_z(), graph::variable_cast(uz)}, + {gamma_next, graph::variable_cast(gamma)} + }, "step"); + + work.compile(); + t_setup.print(); + + const timeing::measure_diagnostic t_run("Run Time"); + work.pre_run(); + for (size_t i = 0; i < 1000000; i++) { + work.run(); + } + work.wait(); + t_run.print(); + }, i); + } + + for (std::thread &t : threads) { + t.join(); + } + + std::cout << std::endl << "Timing:" << std::endl; + t_total.print(); + + END_GPU +} -- GitLab From 63a58c38d0918c706299564b61ea01250850c3a7 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Tue, 7 Jan 2025 15:39:24 -0500 Subject: [PATCH 3/6] Commit changes before merging in bug fix. --- graph_framework/vector.hpp | 18 ++++++++ graph_korc/CMakeLists.txt | 2 +- graph_korc/xkorc.cpp | 86 ++++++++++++++++++++++---------------- 3 files changed, 69 insertions(+), 37 deletions(-) diff --git a/graph_framework/vector.hpp b/graph_framework/vector.hpp index 87b4db7..fea3069 100644 --- a/graph_framework/vector.hpp +++ b/graph_framework/vector.hpp @@ -294,6 +294,24 @@ namespace graph { l->get_z() + r->get_z()); } +//------------------------------------------------------------------------------ +/// @brief Subtraction operator. +/// +/// @tparam T Base type of the calculation. +/// @tparam SAFE_MATH Use safe math operations. +/// +/// @param[in] l Left vector. +/// @param[in] r Right vector. +/// @returns The vector vector addition. +//------------------------------------------------------------------------------ + template + shared_vector operator-(shared_vector l, + shared_vector r) { + return vector(l->get_x() - r->get_x(), + l->get_y() - r->get_y(), + l->get_z() - r->get_z()); + } + //------------------------------------------------------------------------------ /// @brief Multiplication operator. /// diff --git a/graph_korc/CMakeLists.txt b/graph_korc/CMakeLists.txt index d9ab551..460f9c5 100644 --- a/graph_korc/CMakeLists.txt +++ b/graph_korc/CMakeLists.txt @@ -1,6 +1,6 @@ add_tool_target (xkorc) if (${USE_PCH}) - target_precompile_headers (xrays_bench REUSE_FROM xrays) + target_precompile_headers (xkorc REUSE_FROM xrays) endif () diff --git a/graph_korc/xkorc.cpp b/graph_korc/xkorc.cpp index 8869ef1..9b83317 100644 --- a/graph_korc/xkorc.cpp +++ b/graph_korc/xkorc.cpp @@ -14,7 +14,7 @@ int main(int argc, const char * argv[]) { const timeing::measure_diagnostic t_total("Total Time"); - const size_t num_particles = 1000000; + const size_t num_particles = 1; std::cout << "Num particles " << num_particles << std::endl; std::vector threads(std::max(std::min(static_cast (jit::context::max_concurrency()), static_cast (num_particles)), @@ -28,49 +28,53 @@ int main(int argc, const char * argv[]) { const size_t local_num_particles = batch + (extra > thread_number ? 1 : 0); const timeing::measure_diagnostic t_setup("Setup Time"); - + auto eq = equilibrium::make_efit (EFIT_FILE); //auto eq = equilibrium::make_slab_density (); auto b0 = eq->get_characteristic_field(thread_number); const double q = 1.602176634E-19; const double me = 9.1093837139E-31; const double c = 299792458.0; - + auto gryo_period = me/(q*b0); std::cout << "gryo_period " << gryo_period->evaluate().at(0) << std::endl; auto larmor_radius = c*gryo_period; std::cout << "larmor_radius " << larmor_radius->evaluate().at(0) << std::endl; std::cout << "Local num particles " << local_num_particles << std::endl; - + auto ux = graph::variable (local_num_particles, "u_{x}"); auto uy = graph::variable (local_num_particles, "u_{y}"); auto uz = graph::variable (local_num_particles, "u_{z}"); - + ux->set(0.99); uy->set(0.0); uz->set(0.0); - + auto x = graph::variable (local_num_particles, "x"); auto y = graph::variable (local_num_particles, "y"); auto z = graph::variable (local_num_particles, "z"); auto pos = graph::vector(x, y, z); - + x->set(1.7); y->set(0.0); z->set(0.0); - + auto u_vec = graph::vector(ux, uy, uz); - + auto gamma = graph::variable (local_num_particles, "\\gamma"); - - auto dt = graph::constant (0.1); - + + auto dt = graph::constant (0.01); + auto gamma_init = graph::sqrt(1.0 - ux*ux - uy*uy - uz*uz); - + auto u_init = gamma_init*u_vec; - - workflow::manager work(0); + + auto b_vec = eq->get_magnetic_field(pos->get_x(), + pos->get_y(), + pos->get_z())/b0; + + workflow::manager work(thread_number); work.add_preitem({ graph::variable_cast(x), graph::variable_cast(y), @@ -85,29 +89,25 @@ int main(int argc, const char * argv[]) { {u_init->get_z(), graph::variable_cast(uz)}, {gamma_init, graph::variable_cast(gamma)} }, "initalize_gamma"); - - auto pos_next = pos + larmor_radius*dt*u_vec/gamma; - - auto b_vec = eq->get_magnetic_field(pos_next->get_x(), - pos_next->get_y(), - pos_next->get_z())/b0; - - auto u_prime = u_vec + dt*u_vec->cross(b_vec)/(2.0*gamma); - - auto tau = dt*0.5*b_vec; + + auto u_prime = u_vec - dt*u_vec->cross(b_vec)/(2.0*gamma); + + auto tau = -0.5*dt*b_vec; auto tau_sq = tau->dot(tau); - auto speed_sq = u_vec->dot(u_vec); + auto speed_sq = u_prime->dot(u_prime); auto sigma = 1.0 + speed_sq - tau_sq; - auto ustar = u_vec->dot(tau); - + auto ustar = u_prime->dot(tau); + auto gamma_next = graph::sqrt(0.5*(sigma + graph::sqrt(sigma*sigma + 4.0*(tau_sq + ustar*ustar)))); auto t = tau/gamma_next; - - auto s = 1.0/(1.0 + t->dot(t)); + + auto s = 1.0 + t->dot(t); auto u_prime_dot_t = u_prime->dot(t); - - auto u_next = s*(u_prime + u_prime_dot_t*t + u_prime->cross(t)); - + + auto u_next = (u_prime + u_prime_dot_t*t + u_prime->cross(t))/s; + + auto pos_next = pos + larmor_radius*dt*u_next/gamma; + work.add_item({ graph::variable_cast(x), graph::variable_cast(y), @@ -116,7 +116,11 @@ int main(int argc, const char * argv[]) { graph::variable_cast(uy), graph::variable_cast(uz), graph::variable_cast(gamma) - }, {}, { + }, { + tau_sq, + tau->get_x()*tau->get_x(), + tau->get_y()*tau->get_y(), + }, { {pos_next->get_x(), graph::variable_cast(x)}, {pos_next->get_y(), graph::variable_cast(y)}, {pos_next->get_z(), graph::variable_cast(z)}, @@ -125,14 +129,24 @@ int main(int argc, const char * argv[]) { {u_next->get_z(), graph::variable_cast(uz)}, {gamma_next, graph::variable_cast(gamma)} }, "step"); - + tau->get_x()->to_latex(); + std::cout << "\\\\" << std::endl; + tau->get_y()->to_latex(); + std::cout << "\\\\" << std::endl; + (tau->get_x()*tau->get_x())->to_latex(); + std::cout << "\\\\" << std::endl; + (tau->get_y()*tau->get_y())->to_latex(); + std::cout << "\\\\" << std::endl; + work.compile(); t_setup.print(); - + const timeing::measure_diagnostic t_run("Run Time"); work.pre_run(); + work.print(0, {x, y, z, ux, uy, uz, gamma, tau_sq, tau->get_x()*tau->get_x(), tau->get_y()*tau->get_y()}); for (size_t i = 0; i < 1000000; i++) { work.run(); + work.print(0, {x, y, z, ux, uy, uz, gamma, tau_sq, tau->get_x()*tau->get_x(), tau->get_y()*tau->get_y()}); } work.wait(); t_run.print(); -- GitLab From ce0cfde6b3f345c2a200830fb80625087e73c307 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Tue, 7 Jan 2025 16:29:04 -0500 Subject: [PATCH 4/6] Correct gamma rebase bug fix that was causing unphysical damping. --- graph_korc/xkorc.cpp | 20 +++----------------- 1 file changed, 3 insertions(+), 17 deletions(-) diff --git a/graph_korc/xkorc.cpp b/graph_korc/xkorc.cpp index 9b83317..7aaea2d 100644 --- a/graph_korc/xkorc.cpp +++ b/graph_korc/xkorc.cpp @@ -64,9 +64,9 @@ int main(int argc, const char * argv[]) { auto gamma = graph::variable (local_num_particles, "\\gamma"); - auto dt = graph::constant (0.01); + auto dt = graph::constant (0.25); - auto gamma_init = graph::sqrt(1.0 - ux*ux - uy*uy - uz*uz); + auto gamma_init = 1.0/graph::sqrt(1.0 - u_vec->dot(u_vec)); auto u_init = gamma_init*u_vec; @@ -116,11 +116,7 @@ int main(int argc, const char * argv[]) { graph::variable_cast(uy), graph::variable_cast(uz), graph::variable_cast(gamma) - }, { - tau_sq, - tau->get_x()*tau->get_x(), - tau->get_y()*tau->get_y(), - }, { + }, {}, { {pos_next->get_x(), graph::variable_cast(x)}, {pos_next->get_y(), graph::variable_cast(y)}, {pos_next->get_z(), graph::variable_cast(z)}, @@ -129,24 +125,14 @@ int main(int argc, const char * argv[]) { {u_next->get_z(), graph::variable_cast(uz)}, {gamma_next, graph::variable_cast(gamma)} }, "step"); - tau->get_x()->to_latex(); - std::cout << "\\\\" << std::endl; - tau->get_y()->to_latex(); - std::cout << "\\\\" << std::endl; - (tau->get_x()*tau->get_x())->to_latex(); - std::cout << "\\\\" << std::endl; - (tau->get_y()*tau->get_y())->to_latex(); - std::cout << "\\\\" << std::endl; work.compile(); t_setup.print(); const timeing::measure_diagnostic t_run("Run Time"); work.pre_run(); - work.print(0, {x, y, z, ux, uy, uz, gamma, tau_sq, tau->get_x()*tau->get_x(), tau->get_y()*tau->get_y()}); for (size_t i = 0; i < 1000000; i++) { work.run(); - work.print(0, {x, y, z, ux, uy, uz, gamma, tau_sq, tau->get_x()*tau->get_x(), tau->get_y()*tau->get_y()}); } work.wait(); t_run.print(); -- GitLab From cf7a3be8a990c94947f76ca5509d19b83ff4c588 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Wed, 8 Jan 2025 14:15:36 -0500 Subject: [PATCH 5/6] Remove unused inputs from init kernel. --- graph_korc/xkorc.cpp | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/graph_korc/xkorc.cpp b/graph_korc/xkorc.cpp index 7aaea2d..95ddfdb 100644 --- a/graph_korc/xkorc.cpp +++ b/graph_korc/xkorc.cpp @@ -14,7 +14,7 @@ int main(int argc, const char * argv[]) { const timeing::measure_diagnostic t_total("Total Time"); - const size_t num_particles = 1; + const size_t num_particles = 10000000; std::cout << "Num particles " << num_particles << std::endl; std::vector threads(std::max(std::min(static_cast (jit::context::max_concurrency()), static_cast (num_particles)), @@ -76,9 +76,6 @@ int main(int argc, const char * argv[]) { workflow::manager work(thread_number); work.add_preitem({ - graph::variable_cast(x), - graph::variable_cast(y), - graph::variable_cast(z), graph::variable_cast(ux), graph::variable_cast(uy), graph::variable_cast(uz), -- GitLab From 1e2e89bd3ef619b04c7ce1a1c308ac413242e105 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Thu, 9 Jan 2025 11:30:19 -0500 Subject: [PATCH 6/6] Reduce common factors in adds of powers. --- graph_framework.xcodeproj/project.pbxproj | 175 ++++++++++++++++++++++ graph_framework/arithmetic.hpp | 35 ++++- graph_framework/jit.hpp | 2 +- graph_framework/node.hpp | 30 ++++ graph_tests/arithmetic_test.cpp | 41 +++++ 5 files changed, 279 insertions(+), 4 deletions(-) diff --git a/graph_framework.xcodeproj/project.pbxproj b/graph_framework.xcodeproj/project.pbxproj index ec83e30..fc44df7 100644 --- a/graph_framework.xcodeproj/project.pbxproj +++ b/graph_framework.xcodeproj/project.pbxproj @@ -1374,6 +1374,93 @@ ); LOCALIZATION_PREFERS_STRING_CATALOGS = YES; MACOSX_DEPLOYMENT_TARGET = 14.6; + OTHER_LDFLAGS = ( + "-lnetcdf", + "-ld_classic", + "-L/Users/m4c/Projects/graph_framework/build/_deps/llvm-build/lib", + "-lz", + "-lLLVMCoverage", + "-lLLVMSupport", + "-lLLVMDebugInfoCodeView", + "-lLLVMRemarks", + "-lLLVMJITLink", + "-lLLVMLinker", + "-lLLVMTextAPI", + "-lLLVMRuntimeDyld", + "-lLLVMOrcShared", + "-lLLVMOrcDebugging", + "-lLLVMOrcTargetProcess", + "-lLLVMOrcJIT", + "-lLLVMHipStdPar", + "-lLLVMAggressiveInstCombine", + "-lLLVMVectorize", + "-lLLVMAsmParser", + "-lLLVMOption", + "-lLLVMLTO", + "-lLLVMObject", + "-lLLVMWindowsDriver", + "-lLLVMDemangle", + "-lLLVMIRReader", + "-lLLVMIRPrinter", + "-lLLVMInstCombine", + "-lLLVMBinaryFormat", + "-lLLVMCoroutines", + "-lLLVMBitstreamReader", + "-lLLVMBitReader", + "-lLLVMBitWriter", + "-lLLVMDebugInfoDWARF", + "-lLLVMInstrumentation", + "-lLLVMCFGuard", + "-lLLVMObjCARCOpts", + "-lLLVMipo", + "-lLLVMGlobalISel", + "-lLLVMExecutionEngine", + "-lLLVMFrontendDriver", + "-lLLVMFrontendHLSL", + "-lLLVMFrontendOpenMP", + "-lLLVMFrontendOffloading", + "-lLLVMSelectionDAG", + "-lLLVMProfileData", + "-lLLVMAnalysis", + "-lLLVMScalarOpts", + "-lLLVMCodeGenTypes", + "-lLLVMCodeGenData", + "-lLLVMCodeGen", + "-lLLVMTargetParser", + "-lLLVMScalarOpts", + "-lLLVMTarget", + "-lLLVMTransformUtils", + "-lLLVMPasses", + "-lLLVMSupport", + "-lLLVMMCParser", + "-lLLVMMC", + "-lLLVMCore", + "-lLLVMAsmPrinter", + "-lLLVMAArch64Utils", + "-lLLVMAArch64Info", + "-lLLVMAArch64Desc", + "-lLLVMAArch64AsmParser", + "-lLLVMAArch64CodeGen", + "-lLLVMSandboxIR", + "-lLLVMFrontendAtomic", + "-lLLVMCGData", + "-lclangFrontend", + "-lclangBasic", + "-lclangEdit", + "-lclangLex", + "-lclangDriver", + "-lclangSerialization", + "-lclangAST", + "-lclangSema", + "-lclangAnalysis", + "-lclangASTMatchers", + "-lclangSupport", + "-lclangParse", + "-lclangAPINotes", + "-lclangCodeGen", + "-rpath", + /usr/local/lib, + ); PRODUCT_NAME = "$(TARGET_NAME)"; }; name = Debug; @@ -1388,6 +1475,93 @@ GCC_C_LANGUAGE_STANDARD = gnu17; LOCALIZATION_PREFERS_STRING_CATALOGS = YES; MACOSX_DEPLOYMENT_TARGET = 14.6; + OTHER_LDFLAGS = ( + "-lnetcdf", + "-ld_classic", + "-L/Users/m4c/Projects/graph_framework/build/_deps/llvm-build/lib", + "-lz", + "-lLLVMCoverage", + "-lLLVMSupport", + "-lLLVMDebugInfoCodeView", + "-lLLVMRemarks", + "-lLLVMJITLink", + "-lLLVMLinker", + "-lLLVMTextAPI", + "-lLLVMRuntimeDyld", + "-lLLVMOrcShared", + "-lLLVMOrcDebugging", + "-lLLVMOrcTargetProcess", + "-lLLVMOrcJIT", + "-lLLVMHipStdPar", + "-lLLVMAggressiveInstCombine", + "-lLLVMVectorize", + "-lLLVMAsmParser", + "-lLLVMOption", + "-lLLVMLTO", + "-lLLVMObject", + "-lLLVMWindowsDriver", + "-lLLVMDemangle", + "-lLLVMIRReader", + "-lLLVMIRPrinter", + "-lLLVMInstCombine", + "-lLLVMBinaryFormat", + "-lLLVMCoroutines", + "-lLLVMBitstreamReader", + "-lLLVMBitReader", + "-lLLVMBitWriter", + "-lLLVMDebugInfoDWARF", + "-lLLVMInstrumentation", + "-lLLVMCFGuard", + "-lLLVMObjCARCOpts", + "-lLLVMipo", + "-lLLVMGlobalISel", + "-lLLVMExecutionEngine", + "-lLLVMFrontendDriver", + "-lLLVMFrontendHLSL", + "-lLLVMFrontendOpenMP", + "-lLLVMFrontendOffloading", + "-lLLVMSelectionDAG", + "-lLLVMProfileData", + "-lLLVMAnalysis", + "-lLLVMScalarOpts", + "-lLLVMCodeGenTypes", + "-lLLVMCodeGenData", + "-lLLVMCodeGen", + "-lLLVMTargetParser", + "-lLLVMScalarOpts", + "-lLLVMTarget", + "-lLLVMTransformUtils", + "-lLLVMPasses", + "-lLLVMSupport", + "-lLLVMMCParser", + "-lLLVMMC", + "-lLLVMCore", + "-lLLVMAsmPrinter", + "-lLLVMAArch64Utils", + "-lLLVMAArch64Info", + "-lLLVMAArch64Desc", + "-lLLVMAArch64AsmParser", + "-lLLVMAArch64CodeGen", + "-lLLVMSandboxIR", + "-lLLVMFrontendAtomic", + "-lLLVMCGData", + "-lclangFrontend", + "-lclangBasic", + "-lclangEdit", + "-lclangLex", + "-lclangDriver", + "-lclangSerialization", + "-lclangAST", + "-lclangSema", + "-lclangAnalysis", + "-lclangASTMatchers", + "-lclangSupport", + "-lclangParse", + "-lclangAPINotes", + "-lclangCodeGen", + "-rpath", + /usr/local/lib, + ); PRODUCT_NAME = "$(TARGET_NAME)"; }; name = Release; @@ -1427,6 +1601,7 @@ GCC_PREPROCESSOR_DEFINITIONS = ( "DEBUG=1", "$(inherited)", + USE_INPUT_CACHE, ); MACOSX_DEPLOYMENT_TARGET = 13.3; OTHER_LDFLAGS = ( diff --git a/graph_framework/arithmetic.hpp b/graph_framework/arithmetic.hpp index 37d06b7..01d2ab3 100644 --- a/graph_framework/arithmetic.hpp +++ b/graph_framework/arithmetic.hpp @@ -490,12 +490,41 @@ namespace graph { } } -// Handle cases like: + auto pl = pow_cast(this->left); + auto pr = pow_cast(this->right); + +// (a*b)^c + (a*d)^c -> a^c*(b^c + d^c) +// (b*a)^c + (a*d)^c -> a^c*(b^c + d^c) +// (a*b)^c + (d*a)^c -> a^c*(b^c + d^c) +// (b*a)^c + (d*a)^c -> a^c*(b^c + d^c) + if (pl.get() && pr.get() && + pl->get_right()->is_match(pr->get_right())) { + auto plm = multiply_cast(pl->get_left()); + auto prm = multiply_cast(pr->get_left()); + if (plm.get() && prm.get()) { + if (plm->get_left()->is_match(prm->get_left())) { + return pow(plm->get_left(), pl->get_right())* + (pow(plm->get_right(), pl->get_right()) + + pow(prm->get_right(), pl->get_right())); + } else if (plm->get_left()->is_match(prm->get_right())) { + return pow(plm->get_left(), pl->get_right())* + (pow(plm->get_right(), pl->get_right()) + + pow(prm->get_left(), pl->get_right())); + } else if (plm->get_right()->is_match(prm->get_left())) { + return pow(plm->get_right(), pl->get_right())* + (pow(plm->get_left(), pl->get_right()) + + pow(prm->get_right(), pl->get_right())); + } else if (plm->get_right()->is_match(prm->get_right())) { + return pow(plm->get_right(), pl->get_right())* + (pow(plm->get_left(), pl->get_right()) + + pow(prm->get_left(), pl->get_right())); + } + } + } + // (a/y)^e + b/y^e -> (a^2 + b)/(y^e) // b/y^e + (a/y)^e -> (b + a^2)/(y^e) // (a/y)^e + (b/y)^e -> (a^2 + b^2)/(y^e) - auto pl = pow_cast(this->left); - auto pr = pow_cast(this->right); if (pl.get() && rd.get()) { auto rdp = pow_cast(rd->get_right()); if (rdp.get() && pl->get_right()->is_match(rdp->get_right())) { diff --git a/graph_framework/jit.hpp b/graph_framework/jit.hpp index 97cfcfd..83cd20c 100644 --- a/graph_framework/jit.hpp +++ b/graph_framework/jit.hpp @@ -133,7 +133,7 @@ namespace jit { for (auto &in : inputs) { if (usage.find(in.get()) == usage.end()) { - usage[in.get()] == 0; + usage[in.get()] = 0; } } diff --git a/graph_framework/node.hpp b/graph_framework/node.hpp index 9e14963..299be69 100644 --- a/graph_framework/node.hpp +++ b/graph_framework/node.hpp @@ -1112,6 +1112,36 @@ namespace graph { return constant (static_cast (this->is_match(x))); } +//------------------------------------------------------------------------------ +/// @brief Compile preamble. +/// +/// Some nodes require additions to the preamble however most don't so define a +/// generic method that does nothing. +/// +/// @param[in,out] stream String buffer stream. +/// @param[in,out] registers List of defined registers. +/// @param[in,out] visited List of visited nodes. +/// @param[in,out] usage List of register usage count. +/// @param[in,out] textures1d List of 1D textures. +/// @param[in,out] textures2d List of 2D textures. +/// @param[in,out] avail_const_mem Available constant memory. +//------------------------------------------------------------------------------ + virtual void compile_preamble(std::ostringstream &stream, + jit::register_map ®isters, + jit::visiter_map &visited, + jit::register_usage &usage, + jit::texture1d_list &textures1d, + jit::texture2d_list &textures2d, + int &avail_const_mem) { + if (usage.find(this) == usage.end()) { + usage[this] = 1; +#ifdef SHOW_USE_COUNT + } else { + ++usage[this]; +#endif + } + } + //------------------------------------------------------------------------------ /// @brief Compile the node. /// diff --git a/graph_tests/arithmetic_test.cpp b/graph_tests/arithmetic_test.cpp index 90915b9..3417578 100644 --- a/graph_tests/arithmetic_test.cpp +++ b/graph_tests/arithmetic_test.cpp @@ -392,6 +392,47 @@ template void test_add() { "Expected var_a"); assert(common_var5_cast->get_left()->is_match(2.0/var_b + 3.0/var_c) && "Expected 2/b + 3/c"); + +// (a*b)^c + (a*d)^c -> a^c*(b^c + d^c) + auto common_power_factor = graph::pow(var_a*var_b, 2.0) + + graph::pow(var_a*var_c, 2.0); + auto common_power_factor_cast = multiply_cast(common_power_factor); + assert(common_power_factor_cast.get() && "Expected a multiply node."); + assert(common_power_factor_cast->get_right()->is_match(var_a*var_a) && + "Expected a^2 on the right."); + assert(common_power_factor_cast->get_left()->is_match(var_b*var_b + + var_c*var_c) && + "Expected b^2 + c^2 on the left."); +// (a*b)^c + (d*a)^c -> a^c*(b^c + d^c) + auto common_power_factor2 = graph::pow(var_a*var_b, 2.0) + + graph::pow(var_c*var_a, 2.0); + auto common_power_factor2_cast = multiply_cast(common_power_factor2); + assert(common_power_factor2_cast.get() && "Expected a multiply node."); + assert(common_power_factor2_cast->get_right()->is_match(var_a*var_a) && + "Expected a^2 on the right."); + assert(common_power_factor2_cast->get_left()->is_match(var_b*var_b + + var_c*var_c) && + "Expected b^2 + c^2 on the left."); +// (b*a)^c + (a*d)^c -> a^c*(b^c + d^c) + auto common_power_factor3 = graph::pow(var_b*var_a, 2.0) + + graph::pow(var_a*var_c, 2.0); + auto common_power_factor3_cast = multiply_cast(common_power_factor3); + assert(common_power_factor3_cast.get() && "Expected a multiply node."); + assert(common_power_factor3_cast->get_right()->is_match(var_a*var_a) && + "Expected a^2 on the right."); + assert(common_power_factor3_cast->get_left()->is_match(var_b*var_b + + var_c*var_c) && + "Expected b^2 + c^2 on the left."); +// (b*a)^c + (d*a)^c -> a^c*(b^c + d^c) + auto common_power_factor4 = graph::pow(var_b*var_a, 2.0) + + graph::pow(var_c*var_a, 2.0); + auto common_power_factor4_cast = multiply_cast(common_power_factor4); + assert(common_power_factor4_cast.get() && "Expected a multiply node."); + assert(common_power_factor4_cast->get_right()->is_match(var_a*var_a) && + "Expected a^2 on the right."); + assert(common_power_factor4_cast->get_left()->is_match(var_b*var_b + + var_c*var_c) && + "Expected b^2 + c^2 on the left."); } //------------------------------------------------------------------------------ -- GitLab