From 3b1cc40a11b7e0cc33db35bf5758eec516905f49 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Fri, 24 May 2024 16:17:23 -0400 Subject: [PATCH 01/63] Initial setup to support multi buffers in piecewise nodes. This should reduce memory accesses by only indexing these buffers once. WIP --- graph_framework/backend.hpp | 37 ++++++++++++++++++++++++++++++++++- graph_framework/node.hpp | 4 ++-- graph_framework/piecewise.hpp | 20 +++++++++---------- 3 files changed, 48 insertions(+), 13 deletions(-) diff --git a/graph_framework/backend.hpp b/graph_framework/backend.hpp index e90dcda..68c861a 100644 --- a/graph_framework/backend.hpp +++ b/graph_framework/backend.hpp @@ -15,6 +15,37 @@ #include "register.hpp" namespace backend { +//****************************************************************************** +// Multi buffer. +//****************************************************************************** +//------------------------------------------------------------------------------ +/// @brief Struct containing multiple piecewise buffers. +/// +/// 1D and 2D splines come in multiples of 4. To reduce the impack of memory +/// reads we can index these buffers once for four buffers at the same time. +/// +/// @tparam T Base type of the buffer. +//------------------------------------------------------------------------------ + template + struct multi { +/// A coefficient. + const T a; +/// B coefficient. + const T b; +/// C coefficient. + const T c; +/// D coefficient. + const T d; + }; + +/// Multi scalar concept. + template + concept multi_scalar = jit::float_scalar || + std::same_as> || + std::same_as> || + std::same_as>> || + std::same_as>>; + //****************************************************************************** // Data buffer. //****************************************************************************** @@ -23,7 +54,7 @@ namespace backend { /// /// @tparam T Base type of the calculation. 
//------------------------------------------------------------------------------ - template + template class buffer { private: /// The data buffer to hold the data. @@ -646,6 +677,10 @@ namespace backend { } return x; } + +/// Convenience type alias for multi buffers. + template + using multi_buffer = buffer>; } #endif /* backend_h */ diff --git a/graph_framework/node.hpp b/graph_framework/node.hpp index 3609865..0496db1 100644 --- a/graph_framework/node.hpp +++ b/graph_framework/node.hpp @@ -257,8 +257,8 @@ namespace graph { std::shared_ptr>> cache; /// Cache for the backend buffers. inline thread_local static std::map> backend_cache; - + backend::multi_buffer> backend_cache; + /// Type def to retrieve the backend type. typedef T base; }; diff --git a/graph_framework/piecewise.hpp b/graph_framework/piecewise.hpp index f8a6069..e7e694b 100644 --- a/graph_framework/piecewise.hpp +++ b/graph_framework/piecewise.hpp @@ -57,7 +57,7 @@ namespace graph { /// @params[in] d Backend buffer. /// @return A string rep of the node. //------------------------------------------------------------------------------ - static std::string to_string(const backend::buffer &d) { + static std::string to_string(const backend::multi_buffer &d) { std::string temp; for (size_t i = 0, ie = d.size(); i < ie; i++) { temp += jit::format_to_string(d[i]); @@ -73,7 +73,7 @@ namespace graph { /// @params[in] x Argument. /// @return A string rep of the node. //------------------------------------------------------------------------------ - static std::string to_string(const backend::buffer &d, + static std::string to_string(const backend::multi_buffer &d, shared_leaf x) { return piecewise_1D_node::to_string(d) + jit::format_to_string(x->get_hash()); @@ -85,7 +85,7 @@ namespace graph { /// @params[in] d Backend buffer. /// @returns The hash the node is stored in. 
//------------------------------------------------------------------------------ - static size_t hash_data(const backend::buffer &d) { + static size_t hash_data(const backend::multi_buffer &d) { const size_t h = std::hash{} (piecewise_1D_node::to_string(d)); for (size_t i = h; i < std::numeric_limits::max(); i++) { if (leaf_node::backend_cache.find(i) == @@ -109,7 +109,7 @@ namespace graph { /// @params[in] d Data to initalize the piecewise constant. /// @params[in] x Argument. //------------------------------------------------------------------------------ - piecewise_1D_node(const backend::buffer &d, + piecewise_1D_node(const backend::multi_buffer &d, shared_leaf x) : straight_node (x, piecewise_1D_node::to_string(d, x)), data_hash(piecewise_1D_node::hash_data(d)) {} @@ -193,8 +193,8 @@ namespace graph { //------------------------------------------------------------------------------ /// @brief Compile the node. /// -/// This node first evaluates the value of the argument then chooses the correct -/// piecewise index. This assumes that the argument is +/// This node first evaluates the value of the argument then chooses the +/// correct piecewise index. This assumes that the argument is /// /// x' = (x - xmin)/dx (1) /// @@ -440,7 +440,7 @@ namespace graph { /// @params[in] d Backend buffer. /// @return A string rep of the node. //------------------------------------------------------------------------------ - static std::string to_string(const backend::buffer &d) { + static std::string to_string(const backend::multi_buffer &d) { std::string temp; for (size_t i = 0, ie = d.size(); i < ie; i++) { temp += jit::format_to_string(d[i]); @@ -457,7 +457,7 @@ namespace graph { /// @params[in] y Y argument. /// @return A string rep of the node. 
//------------------------------------------------------------------------------ - static std::string to_string(const backend::buffer &d, + static std::string to_string(const backend::multi_buffer &d, shared_leaf x, shared_leaf y) { return piecewise_2D_node::to_string(d) + @@ -471,7 +471,7 @@ namespace graph { /// @params[in] d Backend buffer. /// @returns The hash the node is stored in. //------------------------------------------------------------------------------ - static size_t hash_data(const backend::buffer &d) { + static size_t hash_data(const backend::multi_buffer &d) { const size_t h = std::hash{} (piecewise_2D_node::to_string(d)); for (size_t i = h; i < std::numeric_limits::max(); i++) { if (leaf_node::backend_cache.find(i) == @@ -499,7 +499,7 @@ namespace graph { /// @params[in] x X Argument. /// @params[in] y Y Argument. //------------------------------------------------------------------------------ - piecewise_2D_node(const backend::buffer &d, + piecewise_2D_node(const backend::multi_buffer &d, const size_t n, shared_leaf x, shared_leaf y) : -- GitLab From aff1b4e9f45e13aa54fb6bc1cd5f3b16d6c8f3d4 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Tue, 28 May 2024 12:15:31 -0400 Subject: [PATCH 02/63] Revert "Initial setup to support multi buffers in piecewise nodes. This should reduce memory accesses by only indexing these buffers once. WIP" This reverts commit 82902060358e12b6eb8943514e0d3f411696fb22. --- graph_framework/backend.hpp | 37 +---------------------------------- graph_framework/node.hpp | 4 ++-- graph_framework/piecewise.hpp | 20 +++++++++---------- 3 files changed, 13 insertions(+), 48 deletions(-) diff --git a/graph_framework/backend.hpp b/graph_framework/backend.hpp index 68c861a..e90dcda 100644 --- a/graph_framework/backend.hpp +++ b/graph_framework/backend.hpp @@ -15,37 +15,6 @@ #include "register.hpp" namespace backend { -//****************************************************************************** -// Multi buffer. 
-//****************************************************************************** -//------------------------------------------------------------------------------ -/// @brief Struct containing multiple piecewise buffers. -/// -/// 1D and 2D splines come in multiples of 4. To reduce the impack of memory -/// reads we can index these buffers once for four buffers at the same time. -/// -/// @tparam T Base type of the buffer. -//------------------------------------------------------------------------------ - template - struct multi { -/// A coefficient. - const T a; -/// B coefficient. - const T b; -/// C coefficient. - const T c; -/// D coefficient. - const T d; - }; - -/// Multi scalar concept. - template - concept multi_scalar = jit::float_scalar || - std::same_as> || - std::same_as> || - std::same_as>> || - std::same_as>>; - //****************************************************************************** // Data buffer. //****************************************************************************** @@ -54,7 +23,7 @@ namespace backend { /// /// @tparam T Base type of the calculation. //------------------------------------------------------------------------------ - template + template class buffer { private: /// The data buffer to hold the data. @@ -677,10 +646,6 @@ namespace backend { } return x; } - -/// Convenience type alias for multi buffers. - template - using multi_buffer = buffer>; } #endif /* backend_h */ diff --git a/graph_framework/node.hpp b/graph_framework/node.hpp index 0496db1..3609865 100644 --- a/graph_framework/node.hpp +++ b/graph_framework/node.hpp @@ -257,8 +257,8 @@ namespace graph { std::shared_ptr>> cache; /// Cache for the backend buffers. inline thread_local static std::map> backend_cache; - + backend::buffer> backend_cache; + /// Type def to retrieve the backend type. 
typedef T base; }; diff --git a/graph_framework/piecewise.hpp b/graph_framework/piecewise.hpp index e7e694b..f8a6069 100644 --- a/graph_framework/piecewise.hpp +++ b/graph_framework/piecewise.hpp @@ -57,7 +57,7 @@ namespace graph { /// @params[in] d Backend buffer. /// @return A string rep of the node. //------------------------------------------------------------------------------ - static std::string to_string(const backend::multi_buffer &d) { + static std::string to_string(const backend::buffer &d) { std::string temp; for (size_t i = 0, ie = d.size(); i < ie; i++) { temp += jit::format_to_string(d[i]); @@ -73,7 +73,7 @@ namespace graph { /// @params[in] x Argument. /// @return A string rep of the node. //------------------------------------------------------------------------------ - static std::string to_string(const backend::multi_buffer &d, + static std::string to_string(const backend::buffer &d, shared_leaf x) { return piecewise_1D_node::to_string(d) + jit::format_to_string(x->get_hash()); @@ -85,7 +85,7 @@ namespace graph { /// @params[in] d Backend buffer. /// @returns The hash the node is stored in. //------------------------------------------------------------------------------ - static size_t hash_data(const backend::multi_buffer &d) { + static size_t hash_data(const backend::buffer &d) { const size_t h = std::hash{} (piecewise_1D_node::to_string(d)); for (size_t i = h; i < std::numeric_limits::max(); i++) { if (leaf_node::backend_cache.find(i) == @@ -109,7 +109,7 @@ namespace graph { /// @params[in] d Data to initalize the piecewise constant. /// @params[in] x Argument. 
//------------------------------------------------------------------------------ - piecewise_1D_node(const backend::multi_buffer &d, + piecewise_1D_node(const backend::buffer &d, shared_leaf x) : straight_node (x, piecewise_1D_node::to_string(d, x)), data_hash(piecewise_1D_node::hash_data(d)) {} @@ -193,8 +193,8 @@ namespace graph { //------------------------------------------------------------------------------ /// @brief Compile the node. /// -/// This node first evaluates the value of the argument then chooses the -/// correct piecewise index. This assumes that the argument is +/// This node first evaluates the value of the argument then chooses the correct +/// piecewise index. This assumes that the argument is /// /// x' = (x - xmin)/dx (1) /// @@ -440,7 +440,7 @@ namespace graph { /// @params[in] d Backend buffer. /// @return A string rep of the node. //------------------------------------------------------------------------------ - static std::string to_string(const backend::multi_buffer &d) { + static std::string to_string(const backend::buffer &d) { std::string temp; for (size_t i = 0, ie = d.size(); i < ie; i++) { temp += jit::format_to_string(d[i]); @@ -457,7 +457,7 @@ namespace graph { /// @params[in] y Y argument. /// @return A string rep of the node. //------------------------------------------------------------------------------ - static std::string to_string(const backend::multi_buffer &d, + static std::string to_string(const backend::buffer &d, shared_leaf x, shared_leaf y) { return piecewise_2D_node::to_string(d) + @@ -471,7 +471,7 @@ namespace graph { /// @params[in] d Backend buffer. /// @returns The hash the node is stored in. 
//------------------------------------------------------------------------------ - static size_t hash_data(const backend::multi_buffer &d) { + static size_t hash_data(const backend::buffer &d) { const size_t h = std::hash{} (piecewise_2D_node::to_string(d)); for (size_t i = h; i < std::numeric_limits::max(); i++) { if (leaf_node::backend_cache.find(i) == @@ -499,7 +499,7 @@ namespace graph { /// @params[in] x X Argument. /// @params[in] y Y Argument. //------------------------------------------------------------------------------ - piecewise_2D_node(const backend::multi_buffer &d, + piecewise_2D_node(const backend::buffer &d, const size_t n, shared_leaf x, shared_leaf y) : -- GitLab From da863e1ef6531ee98cbd60225b805acb5438428c Mon Sep 17 00:00:00 2001 From: cianciosa Date: Thu, 30 May 2024 09:26:44 -0400 Subject: [PATCH 03/63] Increase theoretical occupancy by setting a maxthreads value. --- graph_framework/cuda_context.hpp | 6 +++--- graph_framework/metal_context.hpp | 20 +++++++++++++++++--- graph_framework/node.hpp | 2 +- 3 files changed, 21 insertions(+), 7 deletions(-) diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index ea0ce7f..e6a7a86 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -448,8 +448,8 @@ namespace gpu { const size_t size, jit::register_map ®isters) { source_buffer << std::endl; - source_buffer << "extern \"C\" __global__ void " << name << "(" - << std::endl; + source_buffer << "extern \"C\" __global__ __launch_bounds__(1024) void " + << name << "(" << std::endl; source_buffer << " "; jit::add_type (source_buffer); @@ -556,7 +556,7 @@ namespace gpu { void create_reduction(std::ostringstream &source_buffer, const size_t size) { source_buffer << std::endl; - source_buffer << "extern \"C\" __global__ void max_reduction(" << std::endl; + source_buffer << "extern \"C\" __global__ __launch_bounds__(1024) void max_reduction(" << std::endl; source_buffer << " "; jit::add_type 
(source_buffer); source_buffer << " *input," << std::endl; diff --git a/graph_framework/metal_context.hpp b/graph_framework/metal_context.hpp index 73d649e..9c25039 100644 --- a/graph_framework/metal_context.hpp +++ b/graph_framework/metal_context.hpp @@ -106,6 +106,7 @@ namespace gpu { MTLComputePipelineDescriptor *compute = [MTLComputePipelineDescriptor new]; compute.threadGroupSizeIsMultipleOfThreadExecutionWidth = YES; compute.computeFunction = function; + compute.maxTotalThreadsPerThreadgroup = 1024; id state = [device newComputePipelineStateWithDescriptor:compute options:MTLPipelineOptionNone @@ -140,13 +141,15 @@ namespace gpu { NSRange range = NSMakeRange(0, buffers.size()); NSUInteger threads_per_group = state.maxTotalThreadsPerThreadgroup; + NSUInteger thread_width = state.threadExecutionWidth; NSUInteger thread_groups = num_rays/threads_per_group + (num_rays%threads_per_group ? 1 : 0); if (jit::verbose) { std::cout << " Kernel name : " << kernel_name << std::endl; - std::cout << " Threads per group : " << threads_per_group << std::endl; - std::cout << " Number of groups : " << thread_groups << std::endl; - std::cout << " Total problem size : " << threads_per_group*thread_groups << std::endl; + std::cout << " Thread execution width : " << thread_width << std::endl; + std::cout << " Threads per group : " << threads_per_group << std::endl; + std::cout << " Number of groups : " << thread_groups << std::endl; + std::cout << " Total problem size : " << threads_per_group*thread_groups << std::endl; } return [this, state, buffers, offsets, range, thread_groups, threads_per_group] () mutable { @@ -178,6 +181,7 @@ namespace gpu { MTLComputePipelineDescriptor *compute = [MTLComputePipelineDescriptor new]; compute.threadGroupSizeIsMultipleOfThreadExecutionWidth = YES; compute.computeFunction = [library newFunctionWithName:@"max_reduction"]; + compute.maxTotalThreadsPerThreadgroup = 1024; NSError *error; id max_state = [device 
newComputePipelineStateWithDescriptor:compute @@ -194,6 +198,16 @@ namespace gpu { id buffer = kernel_arguments[argument.get()]; + NSUInteger threads_per_group = max_state.maxTotalThreadsPerThreadgroup; + NSUInteger thread_width = max_state.threadExecutionWidth; + if (jit::verbose) { + std::cout << " Kernel name : max_reduction" << std::endl; + std::cout << " Thread execution width : " << thread_width << std::endl; + std::cout << " Threads per group : " << threads_per_group << std::endl; + std::cout << " Number of groups : " << 1 << std::endl; + std::cout << " Total problem size : " << threads_per_group*1 << std::endl; + } + return [this, run, buffer, result, max_state] () mutable { run(); command_buffer = [queue commandBuffer]; diff --git a/graph_framework/node.hpp b/graph_framework/node.hpp index 3609865..5317399 100644 --- a/graph_framework/node.hpp +++ b/graph_framework/node.hpp @@ -258,7 +258,7 @@ namespace graph { /// Cache for the backend buffers. inline thread_local static std::map> backend_cache; - + /// Type def to retrieve the backend type. typedef T base; }; -- GitLab From c35d2ff72db41dfe614c6a8483a87b5069fb6875 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Fri, 31 May 2024 16:04:06 -0400 Subject: [PATCH 04/63] Mark kernel arguments that don't change was const. --- graph_framework/cpu_context.hpp | 13 +++++++++---- graph_framework/cuda_context.hpp | 10 +++++++++- graph_framework/jit.hpp | 13 ++++++++++++- graph_framework/metal_context.hpp | 7 +++++-- 4 files changed, 35 insertions(+), 8 deletions(-) diff --git a/graph_framework/cpu_context.hpp b/graph_framework/cpu_context.hpp index ca5a92b..264e976 100644 --- a/graph_framework/cpu_context.hpp +++ b/graph_framework/cpu_context.hpp @@ -353,13 +353,15 @@ namespace gpu { /// @params[in] inputs Input variables of the kernel. /// @params[in] outputs Output nodes of the graph to compute. /// @params[in] size Size of the input buffer. +/// @params[in] is_constant Flags if the input is read only. 
/// @params[in,out] registers Map of used registers. //------------------------------------------------------------------------------ void create_kernel_prefix(std::ostringstream &source_buffer, const std::string name, graph::input_nodes &inputs, graph::output_nodes &outputs, - const size_t size, + const size_t size, + const std::vector &is_constant, jit::register_map ®isters) { source_buffer << std::endl; source_buffer << "extern \"C\" void " << name << "(" << std::endl; @@ -368,11 +370,14 @@ namespace gpu { jit::add_type (source_buffer); source_buffer << " *> &args) {" << std::endl; - for (auto &input : inputs) { + for (size_t i = 0, ie = inputs.size(); i < ie; i++) { source_buffer << " "; + if (is_constant[i]) { + source_buffer << "const "; + } jit::add_type (source_buffer); - source_buffer << " *" << jit::to_string('v', input.get()) - << " = args[" << reinterpret_cast (input.get()) + source_buffer << " *" << jit::to_string('v', inputs[i].get()) + << " = args[" << reinterpret_cast (inputs[i].get()) << "];" << std::endl; } for (auto &output : outputs) { diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index e6a7a86..c7c6f40 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -439,24 +439,32 @@ namespace gpu { /// @params[in] inputs Input variables of the kernel. /// @params[in] outputs Output nodes of the graph to compute. /// @params[in] size Size of the input buffer. +/// @params[in] is_constant Flags if the input is read only. /// @params[in,out] registers Map of used registers. 
//------------------------------------------------------------------------------ void create_kernel_prefix(std::ostringstream &source_buffer, const std::string name, graph::input_nodes &inputs, graph::output_nodes &outputs, - const size_t size, + const size_t size, + const std::vector &is_constant, jit::register_map ®isters) { source_buffer << std::endl; source_buffer << "extern \"C\" __global__ __launch_bounds__(1024) void " << name << "(" << std::endl; source_buffer << " "; + if (is_constant[0]) { + source_buffer << "const " + } jit::add_type (source_buffer); source_buffer << " *" << jit::to_string('v', inputs[0].get()); for (size_t i = 1, ie = inputs.size(); i < ie; i++) { source_buffer << "," << std::endl; source_buffer << " "; + if (is_constant[0]) { + source_buffer << "const " + } jit::add_type (source_buffer); source_buffer << " *" << jit::to_string('v', inputs[i].get()); } diff --git a/graph_framework/jit.hpp b/graph_framework/jit.hpp index cfb3e29..56a9991 100644 --- a/graph_framework/jit.hpp +++ b/graph_framework/jit.hpp @@ -8,6 +8,9 @@ #ifndef jit_h #define jit_h +#include +#include + #ifdef USE_METAL #include "metal_context.hpp" #elif defined(USE_CUDA) @@ -97,8 +100,15 @@ namespace jit { const size_t size = inputs[0]->size(); + std::vector is_constant(inputs.size(), true); visiter_map visited; for (auto &[out, in] : setters) { + auto found = std::distance(inputs.begin(), + std::find(inputs.begin(), + inputs.end(), in)); + if (found < is_constant.size()) { + is_constant[found] = false; + } out->compile_preamble(source_buffer, registers, visited); } for (auto &out : outputs) { @@ -106,7 +116,8 @@ namespace jit { } gpu_context.create_kernel_prefix(source_buffer, - name, inputs, outputs, size, + name, inputs, outputs, + size, is_constant, registers); for (auto &[out, in] : setters) { diff --git a/graph_framework/metal_context.hpp b/graph_framework/metal_context.hpp index 9c25039..1530510 100644 --- a/graph_framework/metal_context.hpp +++ 
b/graph_framework/metal_context.hpp @@ -338,19 +338,22 @@ namespace gpu { /// @params[in] inputs Input variables of the kernel. /// @params[in] outputs Output nodes of the graph to compute. /// @params[in] size Size of the input buffer. +/// @params[in] is_constant Flags if the input is read only. /// @params[in,out] registers Map of used registers. //------------------------------------------------------------------------------ void create_kernel_prefix(std::ostringstream &source_buffer, const std::string name, graph::input_nodes &inputs, graph::output_nodes &outputs, - const size_t size, + const size_t size, + const std::vector &is_constant, jit::register_map ®isters) { source_buffer << std::endl; source_buffer << "kernel void " << name << "(" << std::endl; for (size_t i = 0, ie = inputs.size(); i < ie; i++) { - source_buffer << " device float *" + source_buffer << " " << (is_constant[i] ? "constant" : "device") + << " float *" << jit::to_string('v', inputs[i].get()) << " [[buffer(" << i << ")]]," << std::endl; } -- GitLab From dc34341791dd260a94de801cb334f70c4fb5d306 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Fri, 31 May 2024 16:25:22 -0400 Subject: [PATCH 05/63] Add missing ; --- graph_framework/cuda_context.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index c7c6f40..b619cb7 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -455,7 +455,7 @@ namespace gpu { source_buffer << " "; if (is_constant[0]) { - source_buffer << "const " + source_buffer << "const "; } jit::add_type (source_buffer); source_buffer << " *" << jit::to_string('v', inputs[0].get()); @@ -463,7 +463,7 @@ namespace gpu { source_buffer << "," << std::endl; source_buffer << " "; if (is_constant[0]) { - source_buffer << "const " + source_buffer << "const "; } jit::add_type (source_buffer); source_buffer << " *" << jit::to_string('v', inputs[i].get()); -- GitLab 
From 43aff470e7ea4e1f60cb05c9570328d632ea53bf Mon Sep 17 00:00:00 2001 From: cianciosa Date: Fri, 31 May 2024 16:42:13 -0400 Subject: [PATCH 06/63] Mark input to reduction kernel as const and fix is_constant index. --- graph_framework/cuda_context.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index b619cb7..5e03a2e 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -425,7 +425,7 @@ namespace gpu { void create_header(std::ostringstream &source_buffer) { if constexpr (jit::is_complex ()) { source_buffer << "#define CUDA_DEVICE_CODE" << std::endl; - source_buffer << "#define M_PI " << M_PI << std::endl; + source_buffer << "#define M_PI " << M_PI << std::endl; source_buffer << "#include " << std::endl; source_buffer << "#include " << std::endl; } @@ -462,7 +462,7 @@ namespace gpu { for (size_t i = 1, ie = inputs.size(); i < ie; i++) { source_buffer << "," << std::endl; source_buffer << " "; - if (is_constant[0]) { + if (is_constant[i]) { source_buffer << "const "; } jit::add_type (source_buffer); @@ -565,7 +565,7 @@ namespace gpu { const size_t size) { source_buffer << std::endl; source_buffer << "extern \"C\" __global__ __launch_bounds__(1024) void max_reduction(" << std::endl; - source_buffer << " "; + source_buffer << " const "; jit::add_type (source_buffer); source_buffer << " *input," << std::endl; source_buffer << " "; -- GitLab From 5af93888de42d85427ffb577c95d9aca24d4ad72 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Tue, 4 Jun 2024 13:53:55 -0400 Subject: [PATCH 07/63] Swhich piecewise arrays to use textures. 
--- graph_framework/cpu_context.hpp | 20 +++-- graph_framework/cuda_context.hpp | 34 +++++-- graph_framework/jit.hpp | 25 ++++-- graph_framework/metal_context.hpp | 81 ++++++++++++++--- graph_framework/node.hpp | 77 +++++++++++----- graph_framework/piecewise.hpp | 145 ++++++++++++++++++------------ graph_framework/register.hpp | 8 +- graph_tests/piecewise_test.cpp | 2 +- 8 files changed, 278 insertions(+), 114 deletions(-) diff --git a/graph_framework/cpu_context.hpp b/graph_framework/cpu_context.hpp index 264e976..572e246 100644 --- a/graph_framework/cpu_context.hpp +++ b/graph_framework/cpu_context.hpp @@ -187,16 +187,20 @@ namespace gpu { //------------------------------------------------------------------------------ /// @brief Create a kernel calling function. /// -/// @params[in] kernel_name Name of the kernel for later reference. -/// @params[in] inputs Input nodes of the kernel. -/// @params[in] outputs Output nodes of the kernel. -/// @params[in] num_rays Number of rays to trace. +/// @params[in] kernel_name Name of the kernel for later reference. +/// @params[in] inputs Input nodes of the kernel. +/// @params[in] outputs Output nodes of the kernel. +/// @params[in] num_rays Number of rays to trace. +/// @params[in] tex1d_list List of 1D textures. +/// @params[in] tex2d_list List of 1D textures. /// @returns A lambda function to run the kernel. //------------------------------------------------------------------------------ std::function create_kernel_call(const std::string kernel_name, graph::input_nodes inputs, graph::output_nodes outputs, - const size_t num_rays) { + const size_t num_rays, + const jit::texture1d_list &tex1d_list, + const jit::texture2d_list &tex2d_list) { auto entry = std::move(jit->lookup(kernel_name)).get(); auto kernel = entry.toPtr &)> (); @@ -355,6 +359,8 @@ namespace gpu { /// @params[in] size Size of the input buffer. /// @params[in] is_constant Flags if the input is read only. /// @params[in,out] registers Map of used registers. 
+/// @params[in] textures1d List of 1D kernel textures. +/// @params[in] textures2d List of 2D kernel textures. //------------------------------------------------------------------------------ void create_kernel_prefix(std::ostringstream &source_buffer, const std::string name, @@ -362,7 +368,9 @@ namespace gpu { graph::output_nodes &outputs, const size_t size, const std::vector &is_constant, - jit::register_map ®isters) { + jit::register_map ®isters, + jit::texture1d_list &textures1d, + jit::texture2d_list &textures2d) { source_buffer << std::endl; source_buffer << "extern \"C\" void " << name << "(" << std::endl; diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index 5e03a2e..a458cc9 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -253,13 +253,17 @@ namespace gpu { /// @params[in] kernel_name Name of the kernel for later reference. /// @params[in] inputs Input nodes of the kernel. /// @params[in] outputs Output nodes of the kernel. -/// @params[in] num_rays Number of rays to trace. +/// @params[in] num_rays Number of rays to trace.' +/// @params[in] tex1d_list List of 1D textures. +/// @params[in] tex2d_list List of 1D textures. /// @returns A lambda function to run the kernel. //------------------------------------------------------------------------------ std::function create_kernel_call(const std::string kernel_name, graph::input_nodes inputs, graph::output_nodes outputs, - const size_t num_rays) { + const size_t num_rays, + const jit::texture1d_list &tex1d_list, + const jit::texture2d_list &tex2d_list) { CUfunction function; check_error(cuModuleGetFunction(&function, module, kernel_name.c_str()), "cuModuleGetFunction"); @@ -297,11 +301,18 @@ namespace gpu { function), "cuFuncGetAttribute"); unsigned int threads_per_group = value; unsigned int thread_groups = num_rays/threads_per_group + (num_rays%threads_per_group ? 
1 : 0); + + int min_grid; + check_error(cuOccupancyMaxPotentialBlockSize(&min_grid, &value, function, 0, 0, 0), + "cuOccupancyMaxPotentialBlockSize"); + if (jit::verbose) { std::cout << " Kernel name : " << kernel_name << std::endl; - std::cout << " Threads per group : " << threads_per_group << std::endl; - std::cout << " Number of groups : " << thread_groups << std::endl; - std::cout << " Total problem size : " << threads_per_group*thread_groups << std::endl; + std::cout << " Threads per group : " << threads_per_group << std::endl; + std::cout << " Number of groups : " << thread_groups << std::endl; + std::cout << " Total problem size : " << threads_per_group*thread_groups << std::endl; + std::cout << " Min grid size : " << min_grid << std::endl; + std::cout << " Suggested Block size : " << value << std::endl; } return [this, function, thread_groups, threads_per_group, buffers] () mutable { @@ -334,8 +345,15 @@ namespace gpu { check_error(cuModuleGetFunction(&function, module, "max_reduction"), "cuModuleGetFunction"); + int value; + int min_grid; + check_error(cuOccupancyMaxPotentialBlockSize(&min_grid, &value, function, 0, 0, 0), + "cuOccupancyMaxPotentialBlockSize"); + if (jit::verbose) { std::cout << " Kernel name : max_reduction" << std::endl; + std::cout << " Min grid size : " << min_grid << std::endl; + std::cout << " Suggested Block size : " << value << std::endl; } return [this, function, run, buffers] () mutable { @@ -441,6 +459,8 @@ namespace gpu { /// @params[in] size Size of the input buffer. /// @params[in] is_constant Flags if the input is read only. /// @params[in,out] registers Map of used registers. +/// @params[in] textures1d List of 1D kernel textures. +/// @params[in] textures2d List of 2D kernel textures. 
//------------------------------------------------------------------------------ void create_kernel_prefix(std::ostringstream &source_buffer, const std::string name, @@ -448,7 +468,9 @@ namespace gpu { graph::output_nodes &outputs, const size_t size, const std::vector &is_constant, - jit::register_map ®isters) { + jit::register_map ®isters, + jit::texture1d_list &textures1d, + jit::texture2d_list &textures2d) { source_buffer << std::endl; source_buffer << "extern \"C\" __global__ __launch_bounds__(1024) void " << name << "(" << std::endl; diff --git a/graph_framework/jit.hpp b/graph_framework/jit.hpp index 56a9991..cf1487d 100644 --- a/graph_framework/jit.hpp +++ b/graph_framework/jit.hpp @@ -42,6 +42,10 @@ namespace jit { register_map registers; /// Kernel names. std::vector kernel_names; +/// Kernel textures. + std::map kernel_1dtextures; +/// Kernel textures. + std::map kernel_2dtextures; /// Type for the GPU context. using gpu_context_type = typename std::conditional (), @@ -102,6 +106,8 @@ namespace jit { std::vector is_constant(inputs.size(), true); visiter_map visited; + kernel_1dtextures[name] = texture1d_list(); + kernel_2dtextures[name] = texture2d_list(); for (auto &[out, in] : setters) { auto found = std::distance(inputs.begin(), std::find(inputs.begin(), @@ -109,16 +115,22 @@ namespace jit { if (found < is_constant.size()) { is_constant[found] = false; } - out->compile_preamble(source_buffer, registers, visited); + out->compile_preamble(source_buffer, registers, visited, + kernel_1dtextures[name], + kernel_2dtextures[name]); } for (auto &out : outputs) { - out->compile_preamble(source_buffer, registers, visited); + out->compile_preamble(source_buffer, registers, visited, + kernel_1dtextures[name], + kernel_2dtextures[name]); } gpu_context.create_kernel_prefix(source_buffer, name, inputs, outputs, size, is_constant, - registers); + registers, + kernel_1dtextures[name], + kernel_2dtextures[name]); for (auto &[out, in] : setters) { 
out->compile(source_buffer, registers); @@ -130,7 +142,7 @@ namespace jit { gpu_context.create_kernel_postfix(source_buffer, outputs, setters, registers); -// Delete the registers so that can be used again in other kernels. +// Delete the registers so that they can be used again in other kernels. std::vector removed_elements; for (auto &[key, value] : registers) { if (value[0] == 'r') { @@ -184,8 +196,9 @@ namespace jit { graph::input_nodes inputs, graph::output_nodes outputs, const size_t num_rays) { - return gpu_context.create_kernel_call(kernel_name, inputs, outputs, - num_rays); + return gpu_context.create_kernel_call(kernel_name, inputs, outputs, num_rays, + kernel_1dtextures[kernel_name], + kernel_2dtextures[kernel_name]); } //------------------------------------------------------------------------------ diff --git a/graph_framework/metal_context.hpp b/graph_framework/metal_context.hpp index 1530510..824429b 100644 --- a/graph_framework/metal_context.hpp +++ b/graph_framework/metal_context.hpp @@ -27,6 +27,8 @@ namespace gpu { id queue; /// Argument map. std::map *, id> kernel_arguments; +/// Textures. + std::map> texture_arguments; /// Max Buffer. id result; /// Metal command buffer. @@ -75,7 +77,7 @@ namespace gpu { encoding:NSUTF8StringEncoding] options:compile_options() error:&error]; - + if (error) { NSLog(@"%@", error); } @@ -88,16 +90,20 @@ namespace gpu { //------------------------------------------------------------------------------ /// @brief Create a kernel calling function. /// -/// @params[in] kernel_name Name of the kernel for later reference. -/// @params[in] inputs Input nodes of the kernel. -/// @params[in] outputs Output nodes of the kernel. -/// @params[in] num_rays Number of rays to trace. +/// @params[in] kernel_name Name of the kernel for later reference. +/// @params[in] inputs Input nodes of the kernel. +/// @params[in] outputs Output nodes of the kernel. +/// @params[in] num_rays Number of rays to trace. 
+/// @params[in] tex1d_list List of 1D textures. +/// @params[in] tex2d_list List of 1D textures. /// @returns A lambda function to run the kernel. //------------------------------------------------------------------------------ std::function create_kernel_call(const std::string kernel_name, graph::input_nodes inputs, graph::output_nodes outputs, - const size_t num_rays) { + const size_t num_rays, + const jit::texture1d_list &tex1d_list, + const jit::texture2d_list &tex2d_list) { NSError *error; id function = [library newFunctionWithName:[NSString stringWithCString:kernel_name.c_str() @@ -137,8 +143,44 @@ namespace gpu { buffers.push_back(kernel_arguments[output.get()]); } + std::vector> textures; + for (auto &[data, size] : tex1d_list) { + if (!texture_arguments.contains(data)) { + MTLTextureDescriptor *discriptor = [MTLTextureDescriptor new]; + discriptor.textureType = MTLTextureType1D; + discriptor.pixelFormat = MTLPixelFormatR32Float; + discriptor.width = size; + discriptor.resourceOptions = MTLResourceCPUCacheModeWriteCombined | MTLResourceStorageModeManaged; + discriptor.usage = MTLTextureUsageShaderRead; + texture_arguments[data] = [device newTextureWithDescriptor:discriptor]; + [texture_arguments[data] replaceRegion:MTLRegionMake1D(0, size) + mipmapLevel:0 + withBytes:reinterpret_cast (data) + bytesPerRow:4*size]; + } + textures.push_back(texture_arguments[data]); + } + for (auto &[data, size] : tex2d_list) { + if (!texture_arguments.contains(data)) { + MTLTextureDescriptor *discriptor = [MTLTextureDescriptor new]; + discriptor.textureType = MTLTextureType2D; + discriptor.pixelFormat = MTLPixelFormatR32Float; + discriptor.width = size[0]; + discriptor.height = size[1]; + discriptor.resourceOptions = MTLResourceCPUCacheModeWriteCombined | MTLResourceStorageModeManaged; + discriptor.usage = MTLTextureUsageShaderRead; + texture_arguments[data] = [device newTextureWithDescriptor:discriptor]; + [texture_arguments[data] replaceRegion:MTLRegionMake2D(0, 0, 
size[0], size[1]) + mipmapLevel:0 + withBytes:reinterpret_cast (data) + bytesPerRow:4*size[0]]; + } + textures.push_back(texture_arguments[data]); + } + std::vector offsets(buffers.size(), 0); NSRange range = NSMakeRange(0, buffers.size()); + NSRange tex_range = NSMakeRange(0, textures.size()); NSUInteger threads_per_group = state.maxTotalThreadsPerThreadgroup; NSUInteger thread_width = state.threadExecutionWidth; @@ -152,7 +194,7 @@ namespace gpu { std::cout << " Total problem size : " << threads_per_group*thread_groups << std::endl; } - return [this, state, buffers, offsets, range, thread_groups, threads_per_group] () mutable { + return [this, state, buffers, offsets, range, tex_range, thread_groups, threads_per_group, textures] () mutable { command_buffer = [queue commandBuffer]; id encoder = [command_buffer computeCommandEncoderWithDispatchType:MTLDispatchTypeSerial]; @@ -160,6 +202,8 @@ namespace gpu { [encoder setBuffers:buffers.data() offsets:offsets.data() withRange:range]; + [encoder setTextures:textures.data() + withRange:tex_range]; [encoder dispatchThreadgroups:MTLSizeMake(thread_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(threads_per_group, 1, 1)]; @@ -340,6 +384,8 @@ namespace gpu { /// @params[in] size Size of the input buffer. /// @params[in] is_constant Flags if the input is read only. /// @params[in,out] registers Map of used registers. +/// @params[in] textures1d List of 1D kernel textures. +/// @params[in] textures2d List of 2D kernel textures. 
//------------------------------------------------------------------------------ void create_kernel_prefix(std::ostringstream &source_buffer, const std::string name, @@ -347,24 +393,37 @@ namespace gpu { graph::output_nodes &outputs, const size_t size, const std::vector &is_constant, - jit::register_map ®isters) { + jit::register_map ®isters, + jit::texture1d_list &textures1d, + jit::texture2d_list &textures2d) { source_buffer << std::endl; source_buffer << "kernel void " << name << "(" << std::endl; - + for (size_t i = 0, ie = inputs.size(); i < ie; i++) { source_buffer << " " << (is_constant[i] ? "constant" : "device") << " float *" << jit::to_string('v', inputs[i].get()) << " [[buffer(" << i << ")]]," << std::endl; } - for (size_t i = 0, ie = outputs.size(); i < ie; i++) { source_buffer << " device float *" << jit::to_string('o', outputs[i].get()) << " [[buffer(" << i + inputs.size() << ")]]," << std::endl; } - + for (size_t i = 0, ie = textures1d.size(); i < ie; i++) { + source_buffer << " const texture1d " + << jit::to_string('a', textures1d[i].first) + << " [[texture(" << i << ")]]," + << std::endl; + } + for (size_t i = 0, ie = textures2d.size(); i < ie; i++) { + source_buffer << " const texture2d " + << jit::to_string('a', textures2d[i].first) + << " [[texture(" << i + textures1d.size() << ")]]," + << std::endl; + } + source_buffer << " uint index [[thread_position_in_grid]]) {" << std::endl; source_buffer << " if (index < " << size << ") {" << std::endl; diff --git a/graph_framework/node.hpp b/graph_framework/node.hpp index 5317399..b1ecd5a 100644 --- a/graph_framework/node.hpp +++ b/graph_framework/node.hpp @@ -88,13 +88,17 @@ namespace graph { /// Some nodes require additions to the preamble however most don't so define a /// generic method that does nothing. /// -/// @params[in,out] stream String buffer stream. -/// @params[in,out] registers List of defined registers. -/// @params[in,out] visited List of visited nodes. 
+/// @params[in,out] stream String buffer stream. +/// @params[in,out] registers List of defined registers. +/// @params[in,out] visited List of visited nodes. +/// @params[in,out] textures1d List of 1D textures. +/// @params[in,out] textures2d List of 2D textures. //------------------------------------------------------------------------------ virtual void compile_preamble(std::ostringstream &stream, jit::register_map ®isters, - jit::visiter_map &visited) {} + jit::visiter_map &visited, + jit::texture1d_list &textures1d, + jit::texture2d_list &textures2d) {} //------------------------------------------------------------------------------ /// @brief Compile the node. @@ -700,15 +704,22 @@ namespace graph { //------------------------------------------------------------------------------ /// @brief Compile preamble. /// -/// @params[in,out] stream String buffer stream. -/// @params[in,out] registers List of defined registers. +/// @params[in,out] stream String buffer stream. +/// @params[in,out] registers List of defined registers. +/// @params[in,out] visited List of visited nodes. +/// @params[in,out] textures1d List of 1D textures. +/// @params[in,out] textures2d List of 2D textures. //------------------------------------------------------------------------------ virtual void compile_preamble(std::ostringstream &stream, jit::register_map ®isters, - jit::visiter_map &visited) { + jit::visiter_map &visited, + jit::texture1d_list &textures1d, + jit::texture2d_list &textures2d) { if (visited.find(this) == visited.end()) { - this->arg->compile_preamble(stream, registers, visited); - visited[this] = 0; + this->arg->compile_preamble(stream, registers, + visited, textures1d, + textures2d); + visited.insert(this); } } @@ -816,17 +827,25 @@ namespace graph { //------------------------------------------------------------------------------ /// @brief Compile preamble. /// -/// @params[in,out] stream String buffer stream. 
-/// @params[in,out] registers List of defined registers. -/// @params[in,out] visited List of visited nodes. +/// @params[in,out] stream String buffer stream. +/// @params[in,out] registers List of defined registers. +/// @params[in,out] visited List of visited nodes. +/// @params[in,out] textures1d List of 1D textures. +/// @params[in,out] textures2d List of 2D textures. //------------------------------------------------------------------------------ virtual void compile_preamble(std::ostringstream &stream, jit::register_map ®isters, - jit::visiter_map &visited) { + jit::visiter_map &visited, + jit::texture1d_list &textures1d, + jit::texture2d_list &textures2d) { if (visited.find(this) == visited.end()) { - this->left->compile_preamble(stream, registers, visited); - this->right->compile_preamble(stream, registers, visited); - visited[this] = 0; + this->left->compile_preamble(stream, registers, + visited, textures1d, + textures2d); + this->right->compile_preamble(stream, registers, + visited, textures1d, + textures2d); + visited.insert(this); } } @@ -919,18 +938,28 @@ namespace graph { //------------------------------------------------------------------------------ /// @brief Compile preamble. /// -/// @params[in,out] stream String buffer stream. -/// @params[in,out] registers List of defined registers. -/// @params[in,out] visited List of visited nodes. +/// @params[in,out] stream String buffer stream. +/// @params[in,out] registers List of defined registers. +/// @params[in,out] visited List of visited nodes. +/// @params[in,out] textures1d List of 1D textures. +/// @params[in,out] textures2d List of 2D textures. 
//------------------------------------------------------------------------------ virtual void compile_preamble(std::ostringstream &stream, jit::register_map ®isters, - jit::visiter_map &visited) { + jit::visiter_map &visited, + jit::texture1d_list &textures1d, + jit::texture2d_list &textures2d) { if (visited.find(this) == visited.end()) { - this->left->compile_preamble(stream, registers, visited); - this->middle->compile_preamble(stream, registers, visited); - this->right->compile_preamble(stream, registers, visited); - visited[this] = 0; + this->left->compile_preamble(stream, registers, + visited, textures1d, + textures2d); + this->middle->compile_preamble(stream, registers, + visited, textures1d, + textures2d); + this->right->compile_preamble(stream, registers, + visited, textures1d, + textures2d); + visited.insert(this); } } diff --git a/graph_framework/piecewise.hpp b/graph_framework/piecewise.hpp index f8a6069..5338d6b 100644 --- a/graph_framework/piecewise.hpp +++ b/graph_framework/piecewise.hpp @@ -158,34 +158,40 @@ namespace graph { /// @params[in,out] stream String buffer stream. /// @params[in,out] registers List of defined registers. /// @params[in,out] visited List of visited nodes. +/// @params[in,out] textures1d List of 1D textures. +/// @params[in,out] textures2d List of 2D textures. 
//------------------------------------------------------------------------------ virtual void compile_preamble(std::ostringstream &stream, jit::register_map ®isters, - jit::visiter_map &visited) { + jit::visiter_map &visited, + jit::texture1d_list &textures1d, + jit::texture2d_list &textures2d) { if (visited.find(this) == visited.end()) { if (registers.find(leaf_node::backend_cache[data_hash].data()) == registers.end()) { registers[leaf_node::backend_cache[data_hash].data()] = jit::to_string('a', leaf_node::backend_cache[data_hash].data()); + const size_t length = leaf_node::backend_cache[data_hash].size(); if constexpr (jit::use_metal ()) { - stream << "constant "; - } - stream << "const "; - jit::add_type (stream); - stream << " " << registers[leaf_node::backend_cache[data_hash].data()] << "[] = {"; - if constexpr (jit::is_complex ()) { + textures1d.emplace_back(leaf_node::backend_cache[data_hash].data(), + length); + } else { + stream << "const "; jit::add_type (stream); - } - stream << leaf_node::backend_cache[data_hash][0]; - for (size_t i = 1, ie = leaf_node::backend_cache[data_hash].size(); - i < ie; i++) { - stream << ", "; + stream << " " << registers[leaf_node::backend_cache[data_hash].data()] << "[] = {"; if constexpr (jit::is_complex ()) { jit::add_type (stream); } - stream << leaf_node::backend_cache[data_hash][i]; + stream << leaf_node::backend_cache[data_hash][0]; + for (size_t i = 1; i < length; i++) { + stream << ", "; + if constexpr (jit::is_complex ()) { + jit::add_type (stream); + } + stream << leaf_node::backend_cache[data_hash][i]; + } + stream << "};" << std::endl; } - stream << "};" << std::endl; - visited[this] = 0; + visited.insert(this); } } } @@ -219,17 +225,22 @@ namespace graph { jit::add_type (stream); stream << " " << registers[this] << " = " << registers[leaf_node::backend_cache[data_hash].data()]; - stream << "[max(min((int)"; - if constexpr (jit::is_complex ()) { - stream << "real("; - } - stream << registers[a.get()]; - if 
constexpr (jit::is_complex ()) { - stream << ")"; + const size_t length = leaf_node::backend_cache[data_hash].size(); + if constexpr (jit::use_metal ()) { + stream << ".read(min(max((uint)" << registers[a.get()] + << ",0u)," << length - 1 << "u)).r;"; + } else { + stream << "[min(max((int)"; + if constexpr (jit::is_complex ()) { + stream << "real("; + } + stream << registers[a.get()]; + if constexpr (jit::is_complex ()) { + stream << ")"; + } + stream << ",0)," << length - 1 << ")];"; } - stream << ", " - << leaf_node::backend_cache[data_hash].size() - 1 << "), 0)];" - << std::endl; + stream << std::endl; } return this->shared_from_this(); @@ -561,35 +572,43 @@ namespace graph { //------------------------------------------------------------------------------ /// @brief Compile preamble. /// -/// @params[in,out] stream String buffer stream. -/// @params[in,out] registers List of defined registers. -/// @params[in,out] visited List of visited nodes. +/// @params[in,out] stream String buffer stream. +/// @params[in,out] registers List of defined registers. +/// @params[in,out] visited List of visited nodes. +/// @params[in,out] textures1d List of 1D textures. +/// @params[in,out] textures2d List of 2D textures. 
//------------------------------------------------------------------------------ virtual void compile_preamble(std::ostringstream &stream, jit::register_map ®isters, - jit::visiter_map &visited) { + jit::visiter_map &visited, + jit::texture1d_list &textures1d, + jit::texture2d_list &textures2d) { if (visited.find(this) == visited.end()) { if (registers.find(leaf_node::backend_cache[data_hash].data()) == registers.end()) { registers[leaf_node::backend_cache[data_hash].data()] = jit::to_string('a', leaf_node::backend_cache[data_hash].data()); + const size_t length = leaf_node::backend_cache[data_hash].size(); if constexpr (jit::use_metal ()) { - stream << "constant "; - } - stream << "const "; - jit::add_type (stream); - stream << " " << registers[leaf_node::backend_cache[data_hash].data()] << "[] = {"; - if constexpr (jit::is_complex ()) { + textures2d.emplace_back(leaf_node::backend_cache[data_hash].data(), + std::array ({length/num_columns, num_columns})); + } else { + stream << "const "; jit::add_type (stream); - } - stream << leaf_node::backend_cache[data_hash][0]; - for (size_t i = 1, ie = leaf_node::backend_cache[data_hash].size(); i < ie; i++) { - stream << ", "; + stream << " " << registers[leaf_node::backend_cache[data_hash].data()] << "[] = {"; if constexpr (jit::is_complex ()) { jit::add_type (stream); } - stream << leaf_node::backend_cache[data_hash][i]; + stream << leaf_node::backend_cache[data_hash][0]; + for (size_t i = 1; i < length; i++) { + stream << ", "; + if constexpr (jit::is_complex ()) { + jit::add_type (stream); + } + stream << leaf_node::backend_cache[data_hash][i]; + } + stream << "};" << std::endl; } - stream << "};" << std::endl; + visited.insert(this); } } } @@ -637,25 +656,33 @@ namespace graph { jit::add_type (stream); stream << " " << registers[this] << " = " << registers[leaf_node::backend_cache[data_hash].data()]; - stream << "[max(min((int)"; - if constexpr (jit::is_complex ()) { - stream << "real("; - } - stream << 
registers[x.get()]; - if constexpr (jit::is_complex ()) { - stream << ")"; - } - stream << "*" << num_columns << " + (int)"; - if constexpr (jit::is_complex ()) { - stream << "real("; - } - stream << registers[y.get()]; - if constexpr (jit::is_complex ()) { - stream << ")"; + const size_t length = leaf_node::backend_cache[data_hash].size(); + if constexpr (jit::use_metal ()) { + const size_t num_rows = length/num_columns; + stream << ".read(uint2(min(max((uint)" + << registers[x.get()] << ", 0u)," << num_rows + << "u),min(max((uint)" << registers[y.get()] + << ",0u)," << num_columns << "u)).yx).r;"; + } else { + stream << "[min(max((int)"; + if constexpr (jit::is_complex ()) { + stream << "real("; + } + stream << registers[x.get()]; + if constexpr (jit::is_complex ()) { + stream << ")"; + } + stream << "*" << num_columns << " + (int)"; + if constexpr (jit::is_complex ()) { + stream << "real("; + } + stream << registers[y.get()]; + if constexpr (jit::is_complex ()) { + stream << ")"; + } + stream << ",0), " << length - 1 << ")];"; } - stream << ", " - << leaf_node::backend_cache[data_hash].size() - 1 << "), 0)];" - << std::endl; + stream << std::endl; } return this->shared_from_this(); diff --git a/graph_framework/register.hpp b/graph_framework/register.hpp index 8a3feb6..f6f6a3b 100644 --- a/graph_framework/register.hpp +++ b/graph_framework/register.hpp @@ -12,12 +12,14 @@ #include #include #include +#include #include #include #include #include #include #include +#include namespace jit { /// Complex scalar concept. @@ -243,7 +245,11 @@ namespace jit { /// Type alias for mapping node pointers to register names. typedef std::map register_map; /// Type alias for listing visited nodes. - typedef std::map visiter_map; + typedef std::set visiter_map; +/// Type alias for indexing 1D textures. + typedef std::vector> texture1d_list; +/// Type alias for indexing 2D textures. 
+ typedef std::vector>> texture2d_list; //------------------------------------------------------------------------------ /// @brief Define a custom comparitor class. diff --git a/graph_tests/piecewise_test.cpp b/graph_tests/piecewise_test.cpp index fa8b4da..8f56135 100644 --- a/graph_tests/piecewise_test.cpp +++ b/graph_tests/piecewise_test.cpp @@ -27,7 +27,7 @@ /// @params[in] tolarance Test tolarance. //------------------------------------------------------------------------------ template void check(const T test, - const T tolarance) { + const T tolarance) { if constexpr (jit::is_complex ()) { assert(std::real(test) <= std::real(tolarance) && "Real GPU and CPU values differ."); -- GitLab From fd1e963ae86bc3975ad272da3968c9c26ccdeae1 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Wed, 5 Jun 2024 18:07:20 -0400 Subject: [PATCH 08/63] Add more buffer hint attributes to metal backend. --- graph_framework/metal_context.hpp | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/graph_framework/metal_context.hpp b/graph_framework/metal_context.hpp index 824429b..652c8b6 100644 --- a/graph_framework/metal_context.hpp +++ b/graph_framework/metal_context.hpp @@ -35,6 +35,8 @@ namespace gpu { id command_buffer; /// Metal library. id library; +/// Buffer mutability discriptor. 
+ std::map> bufferMutability; public: //------------------------------------------------------------------------------ @@ -113,6 +115,9 @@ namespace gpu { compute.threadGroupSizeIsMultipleOfThreadExecutionWidth = YES; compute.computeFunction = function; compute.maxTotalThreadsPerThreadgroup = 1024; + for (size_t i = 0, ie = bufferMutability[kernel_name].size(); i < ie; i++) { + compute.buffers[i].mutability = bufferMutability[kernel_name][i]; + } id state = [device newComputePipelineStateWithDescriptor:compute options:MTLPipelineOptionNone @@ -144,19 +149,25 @@ namespace gpu { } std::vector> textures; + command_buffer = [queue commandBuffer]; + id encoder = [command_buffer blitCommandEncoder]; for (auto &[data, size] : tex1d_list) { if (!texture_arguments.contains(data)) { MTLTextureDescriptor *discriptor = [MTLTextureDescriptor new]; discriptor.textureType = MTLTextureType1D; discriptor.pixelFormat = MTLPixelFormatR32Float; discriptor.width = size; - discriptor.resourceOptions = MTLResourceCPUCacheModeWriteCombined | MTLResourceStorageModeManaged; + discriptor.storageMode = MTLStorageModeManaged; + discriptor.cpuCacheMode = MTLCPUCacheModeWriteCombined; + discriptor.hazardTrackingMode = MTLHazardTrackingModeUntracked; discriptor.usage = MTLTextureUsageShaderRead; texture_arguments[data] = [device newTextureWithDescriptor:discriptor]; [texture_arguments[data] replaceRegion:MTLRegionMake1D(0, size) mipmapLevel:0 withBytes:reinterpret_cast (data) bytesPerRow:4*size]; + + [encoder optimizeContentsForGPUAccess:texture_arguments[data]]; } textures.push_back(texture_arguments[data]); } @@ -167,16 +178,22 @@ namespace gpu { discriptor.pixelFormat = MTLPixelFormatR32Float; discriptor.width = size[0]; discriptor.height = size[1]; - discriptor.resourceOptions = MTLResourceCPUCacheModeWriteCombined | MTLResourceStorageModeManaged; + discriptor.storageMode = MTLStorageModeManaged; + discriptor.cpuCacheMode = MTLCPUCacheModeWriteCombined; + discriptor.hazardTrackingMode = 
MTLHazardTrackingModeUntracked; discriptor.usage = MTLTextureUsageShaderRead; texture_arguments[data] = [device newTextureWithDescriptor:discriptor]; [texture_arguments[data] replaceRegion:MTLRegionMake2D(0, 0, size[0], size[1]) mipmapLevel:0 withBytes:reinterpret_cast (data) bytesPerRow:4*size[0]]; + + [encoder optimizeContentsForGPUAccess:texture_arguments[data]]; } textures.push_back(texture_arguments[data]); } + [encoder endEncoding]; + [command_buffer commit]; std::vector offsets(buffers.size(), 0); NSRange range = NSMakeRange(0, buffers.size()); @@ -226,13 +243,13 @@ namespace gpu { compute.threadGroupSizeIsMultipleOfThreadExecutionWidth = YES; compute.computeFunction = [library newFunctionWithName:@"max_reduction"]; compute.maxTotalThreadsPerThreadgroup = 1024; + compute.buffers[0].mutability = MTLMutabilityImmutable; NSError *error; id max_state = [device newComputePipelineStateWithDescriptor:compute options:MTLPipelineOptionNone reflection:NULL error:&error]; - if (error) { NSLog(@"%@", error); } @@ -399,13 +416,17 @@ namespace gpu { source_buffer << std::endl; source_buffer << "kernel void " << name << "(" << std::endl; + bufferMutability[name] = std::vector (); + for (size_t i = 0, ie = inputs.size(); i < ie; i++) { + bufferMutability[name].push_back(is_constant[i] ? MTLMutabilityMutable : MTLMutabilityImmutable); source_buffer << " " << (is_constant[i] ? "constant" : "device") << " float *" << jit::to_string('v', inputs[i].get()) << " [[buffer(" << i << ")]]," << std::endl; } for (size_t i = 0, ie = outputs.size(); i < ie; i++) { + bufferMutability[name].push_back(MTLMutabilityMutable); source_buffer << " device float *" << jit::to_string('o', outputs[i].get()) << " [[buffer(" << i + inputs.size() << ")]]," -- GitLab From 2b5cf4e80f9aa7e6782cbbb6d97b5cd0db4298c4 Mon Sep 17 00:00:00 2001 From: m4c Date: Tue, 11 Jun 2024 15:55:44 -0400 Subject: [PATCH 09/63] Do not force a maximum number of threads. This slows wall time. 
--- graph_framework/cuda_context.hpp | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index a458cc9..6bf8824 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -120,6 +120,7 @@ namespace gpu { check_error(cuDeviceGet(&device, index), "cuDeviceGet"); check_error(cuDevicePrimaryCtxRetain(&context, device), "cuDevicePrimaryCtxRetain"); check_error(cuCtxSetCurrent(context), "cuCtxSetCurrent"); + check_error(cuCtxSetCacheConfig(CU_FUNC_CACHE_PREFER_L1), "cuCtxSetCacheConfig"); check_error(cuStreamCreate(&stream, CU_STREAM_DEFAULT), "cuStreamCreate"); } @@ -203,13 +204,16 @@ namespace gpu { } const std::string temp = arch.str(); - std::array options({ + std::array options({ temp.c_str(), "--std=c++17", + "--relocatable-device-code=false", "--include-path=" CUDA_INCLUDE, "--include-path=" HEADER_DIR, "--extra-device-vectorization", - "--device-as-default-execution-space" + "--device-as-default-execution-space", + "--ptxas-options", + "-dlcm=cg" }); if (nvrtcCompileProgram(kernel_program, options.size(), options.data())) { @@ -466,13 +470,13 @@ namespace gpu { const std::string name, graph::input_nodes &inputs, graph::output_nodes &outputs, - const size_t size, + const size_t size, const std::vector &is_constant, jit::register_map ®isters, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d) { source_buffer << std::endl; - source_buffer << "extern \"C\" __global__ __launch_bounds__(1024) void " + source_buffer << "extern \"C\" __global__ void " << name << "(" << std::endl; source_buffer << " "; @@ -586,7 +590,7 @@ namespace gpu { void create_reduction(std::ostringstream &source_buffer, const size_t size) { source_buffer << std::endl; - source_buffer << "extern \"C\" __global__ __launch_bounds__(1024) void max_reduction(" << std::endl; + source_buffer << "extern \"C\" __global__ void max_reduction(" << std::endl; source_buffer 
<< " const "; jit::add_type (source_buffer); source_buffer << " *input," << std::endl; -- GitLab From 893cf17c6b1f4d9ed93cad7d4c13fb4f7a02f5e5 Mon Sep 17 00:00:00 2001 From: m4c Date: Wed, 12 Jun 2024 12:27:46 -0400 Subject: [PATCH 10/63] Restrict registers to fastest float wall time. --- graph_framework/cuda_context.hpp | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index 6bf8824..5d1f61e 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -246,7 +246,16 @@ namespace gpu { check_nvrtc_error(nvrtcDestroyProgram(&kernel_program), "nvrtcDestroyProgram"); - check_error(cuModuleLoadDataEx(&module, ptx, 0, NULL, NULL), "cuModuleLoadDataEx"); + std::array module_options = { + CU_JIT_MAX_REGISTERS + }; + std::array module_values = { + reinterpret_cast (168) + }; + + check_error(cuModuleLoadDataEx(&module, ptx, 1, + module_options.data(), + module_values.data()), "cuModuleLoadDataEx"); free(ptx); } -- GitLab From d989c7989ddb763694d4db878931f88d95f0bde8 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Tue, 11 Jun 2024 16:10:49 -0400 Subject: [PATCH 11/63] Add dependencies to avoid building before the pull. 
--- CMakeLists.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index aca1c3c..6e31b81 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -234,6 +234,9 @@ add_dependencies (cuda-resource-headers pull_llvm) add_dependencies (scan-build-py pull_llvm) add_dependencies (x86-resource-headers pull_llvm) add_dependencies (obj.clangSupport pull_llvm) +add_dependencies (arm-common-resource-headers pull_llvm) +add_dependencies (arm-resource-headers pull_llvm) +add_dependencies (aarch64-resource-headers pull_llvm) add_library (llvm_dep INTERFACE) target_include_directories (llvm_dep -- GitLab From 07c0997debeab23abf106eb9931d406d2840adc0 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Wed, 12 Jun 2024 14:57:11 -0400 Subject: [PATCH 12/63] Use texture objects in cuda. Not yet tested. --- graph_framework/cuda_context.hpp | 123 ++++++++++++++++++++++++++++++- graph_framework/piecewise.hpp | 96 ++++++++++++++++++------ 2 files changed, 195 insertions(+), 24 deletions(-) diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index 5d1f61e..334ca4b 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -72,6 +72,8 @@ namespace gpu { CUmodule module; /// Argument map. std::map *, CUdeviceptr> kernel_arguments; +/// Textures. + std::map texture_arguments; /// Result buffer. CUdeviceptr result_buffer; /// Cuda stream. 
@@ -309,6 +311,95 @@ namespace gpu { buffers.push_back(reinterpret_cast (&kernel_arguments[output.get()])); } + for (auto &[data, size] : tex1d_list) { + if (!texture_arguments.contains(data)) { + struct CUDA_RESOURCE_DESC resource_desc; + struct CUDA_TEXTURE_DESC texture_desc; + struct CUDA_RESOURCE_VIEW_DESC view_desc; + resource_desc.resType = CU_RESOURCE_TYPE_LINEAR; + texture_desc.flags = CU_TRSF_READ_AS_INTEGER; + view_desc.format = CU_RES_VIEW_FORMAT_NONE; + view_desc.width = size; + if constexpr (jit::is_float ()) { + resource_desc.format = CU_AD_FORMAT_FLOAT; + if constexpr (jit::is_complex ()) { + resource_desc.numChannel = 2; + resource_desc.sizeInBytes = 2*size*sizeof(float); + } else { + resourec_desc.numChannel = 1; + resource_desc.sizeInBytes = size*sizeof(float); + } + } else { + resource_desc.format = CU_AD_FORMAT_UNSIGNED_INT32; + if constexpr (jit::is_complex ()) { + resource_desc.numChannel = 4; + resource_desc.sizeInBytes = 2*size*sizeof(double); + } else { + resource_desc.numChannel = 2; + resource_desc.sizeInBytes = size*sizeof(double); + } + } + check_error(cuMemAllocManaged(&resource_desc.devPtr, + resource_desc.sizeInBytes, + CU_MEM_ATTACH_GLOBAL), + "cuMemAllocManaged"); + check_error(cuMemcpyHtoD(resource_desc.devPtr, + data, + resource_desc.sizeInBytes), + "cuMemcpyHtoD"); + + check_error(cuTexObjectCreate(&texture_arguments[data], + resource_desc,), + "cuTexObjectCreate"); + } + buffers.push_back(reinterpret_cast (&texture_arguments[data])); + } + for (auto &[data, size] : tex2d_list) { + if (!texture_arguments.contains(data)) { + struct CUDA_RESOURCE_DESC resource_desc; + struct CUDA_TEXTURE_DESC texture_desc; + struct CUDA_RESOURCE_VIEW_DESC view_desc; + resource_desc.resType = CU_RESOURCE_TYPE_LINEAR; + texture_desc.flags = CU_TRSF_READ_AS_INTEGER; + view_desc.format = CU_RES_VIEW_FORMAT_NONE; + view_desc.width = size[0]; + view_desc.height = size[1]; + const size_t total = size[0]*size[1]; + if constexpr (jit::is_float ()) { + 
resource_desc.format = CU_AD_FORMAT_FLOAT; + if constexpr (jit::is_complex ()) { + resource_desc.numChannel = 2; + resource_desc.sizeInBytes = 2*total*sizeof(float); + } else { + resourec_desc.numChannel = 1; + resource_desc.sizeInBytes = total*sizeof(float); + } + } else { + resource_desc.format = CU_AD_FORMAT_UNSIGNED_INT32; + if constexpr (jit::is_complex ()) { + resource_desc.numChannel = 4; + resource_desc.sizeInBytes = 2*total*sizeof(double); + } else { + resource_desc.numChannel = 2; + resource_desc.sizeInBytes = total*sizeof(double); + } + } + check_error(cuMemAllocManaged(&resource_desc.devPtr, + resource_desc.sizeInBytes, + CU_MEM_ATTACH_GLOBAL), + "cuMemAllocManaged"); + check_error(cuMemcpyHtoD(resource_desc.devPtr, + data, + resource_desc.sizeInBytes), + "cuMemcpyHtoD"); + + check_error(cuTexObjectCreate(&texture_arguments[data], + resource_desc,), + "cuTexObjectCreate"); + } + buffers.push_back(reinterpret_cast (&texture_arguments[data])); + } + int value; check_error(cuFuncGetAttribute(&value, CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, function), "cuFuncGetAttribute"); @@ -328,7 +419,7 @@ namespace gpu { std::cout << " Suggested Block size : " << value << std::endl; } - return [this, function, thread_groups, threads_per_group, buffers] () mutable { + return [this, function, thread_groups, threads_per_group, buffers, textures] () mutable { check_error_async(cuLaunchKernel(function, thread_groups, 1, 1, threads_per_group, 1, 1, 0, stream, buffers.data(), NULL), @@ -459,6 +550,23 @@ namespace gpu { source_buffer << "#define M_PI " << M_PI << std::endl; source_buffer << "#include " << std::endl; source_buffer << "#include " << std::endl; + if constexpr (jit::is_float ()) { + source_buffer << "static __inline__ __device__ complex to_cmp_float(float2 p) {" + << " return "; + jit::add_type (stream); + source_buffer << " (p.x, p.y);" + << "}"; + } else { + source_buffer << "static __inline__ __device__ complex to_cmp_double(uint4 p) {" + << " return "; + 
jit::add_type (stream); + source_buffer << " (__hiloint2double(p.y, p.x), __hiloint2double(p.w, p.z));" + << "}"; + } + } else if constexpr (jit::is_double ()) { + source_buffer << "static __inline__ __device__ double to_double(uint2 p) {" + << " return __hiloint2double(p.y, p.x);" + << "}"; } } @@ -503,13 +611,24 @@ namespace gpu { jit::add_type (source_buffer); source_buffer << " *" << jit::to_string('v', inputs[i].get()); } - for (size_t i = 0, ie = outputs.size(); i < ie; i++) { source_buffer << "," << std::endl; source_buffer << " "; jit::add_type (source_buffer); source_buffer << " *" << jit::to_string('o', outputs[i].get()); } + for (size_t i = 0, ie = textures1d.size(); i < ie; i++) { + source_buffer << "," << std::endl; + source_buffer << " "; + jit::add_type (source_buffer); + source_buffer << " *" << jit::to_string('a', textures1d[i].get()); + } + for (size_t i = 0, ie = textures2d.size(); i < ie; i++) { + source_buffer << "," << std::endl; + source_buffer << " "; + jit::add_type (source_buffer); + source_buffer << " *" << jit::to_string('a', textures2d[i].get()); + } source_buffer << ") {" << std::endl; source_buffer << " const int index = blockIdx.x*blockDim.x + threadIdx.x;" diff --git a/graph_framework/piecewise.hpp b/graph_framework/piecewise.hpp index 5338d6b..b104d81 100644 --- a/graph_framework/piecewise.hpp +++ b/graph_framework/piecewise.hpp @@ -11,6 +11,30 @@ #include "node.hpp" namespace graph { +//------------------------------------------------------------------------------ +/// @brief Compile an index. +/// +/// @tparam T Base type of the calculation. +/// +/// @params[in,out] stream String buffer stream. +/// @params[in] register_name Reister for the argument. +/// @params[in] length Dimension length of argument. 
+//------------------------------------------------------------------------------ +template +void compile_index(std::ostringstream &stream, + const std::string ®ister_name, + const size_t length) { + stream << "min(max((uint)"; + if constexpr (jit::is_complex ()) { + stream << "real("; + } + stream << register_name; + if constexpr (jit::is_complex ()) { + stream << ")"; + } + stream << ",0u)," << length - 1 << "u)"; +} + //****************************************************************************** // 1D Piecewise node. //****************************************************************************** @@ -171,7 +195,7 @@ namespace graph { registers[leaf_node::backend_cache[data_hash].data()] = jit::to_string('a', leaf_node::backend_cache[data_hash].data()); const size_t length = leaf_node::backend_cache[data_hash].size(); - if constexpr (jit::use_metal ()) { + if constexpr (jit::use_metal () || jit::use_cuda()) { textures1d.emplace_back(leaf_node::backend_cache[data_hash].data(), length); } else { @@ -223,22 +247,32 @@ namespace graph { registers[this] = jit::to_string('r', this); stream << " const "; jit::add_type (stream); - stream << " " << registers[this] << " = " - << registers[leaf_node::backend_cache[data_hash].data()]; + stream << " " << registers[this] << " = "; + if constexpr (jit::use_cuda()) { + if constexpr (jit::is_float ()) { + stream << "tex1D ("; + } else if constexpr (jit::is_double ()) { + stream << "to_double(tex1D ("; + } else if constexpr (jit::is_complex () && jit::is_float ()) { + stream << "to_cmp_float(tex1D ("; + } else { + stream << "to_cmp_double(tex1D ("; + } + } + stream << registers[leaf_node::backend_cache[data_hash].data()]; const size_t length = leaf_node::backend_cache[data_hash].size(); if constexpr (jit::use_metal ()) { - stream << ".read(min(max((uint)" << registers[a.get()] - << ",0u)," << length - 1 << "u)).r;"; + stream << ".read("; + compile_index (stream, registers[a.get()], length); + stream << ").r;"; + } else if 
constexpr (jit::use_cuda()) { + stream << ", "; + compile_index (stream, registers[a.get()], length); + stream << ");"; } else { - stream << "[min(max((int)"; - if constexpr (jit::is_complex ()) { - stream << "real("; - } - stream << registers[a.get()]; - if constexpr (jit::is_complex ()) { - stream << ")"; - } - stream << ",0)," << length - 1 << ")];"; + stream << "["; + compile_index (stream, registers[a.get()], length); + stream << "];"; } stream << std::endl; } @@ -588,7 +622,7 @@ namespace graph { registers[leaf_node::backend_cache[data_hash].data()] = jit::to_string('a', leaf_node::backend_cache[data_hash].data()); const size_t length = leaf_node::backend_cache[data_hash].size(); - if constexpr (jit::use_metal ()) { + if constexpr (jit::use_metal () || jit::use_cuda()) { textures2d.emplace_back(leaf_node::backend_cache[data_hash].data(), std::array ({length/num_columns, num_columns})); } else { @@ -654,15 +688,33 @@ namespace graph { registers[this] = jit::to_string('r', this); stream << " const "; jit::add_type (stream); - stream << " " << registers[this] << " = " - << registers[leaf_node::backend_cache[data_hash].data()]; + stream << " " << registers[this] << " = "; + if constexpr (jit::use_cuda()) { + if constexpr (jit::is_float ()) { + stream << "tex1D ("; + } else if constexpr (jit::is_double ()) { + stream << "to_double(tex1D ("; + } else if constexpr (jit::is_complex () && jit::is_float ()) { + stream << "to_cmp_float(tex1D ("; + } else { + stream << "to_cmp_double(tex1D ("; + } + } + stream << registers[leaf_node::backend_cache[data_hash].data()]; const size_t length = leaf_node::backend_cache[data_hash].size(); + const size_t num_rows = length/num_columns; if constexpr (jit::use_metal ()) { - const size_t num_rows = length/num_columns; - stream << ".read(uint2(min(max((uint)" - << registers[x.get()] << ", 0u)," << num_rows - << "u),min(max((uint)" << registers[y.get()] - << ",0u)," << num_columns << "u)).yx).r;"; + stream << ".read(uint2("; + 
compile_index (stream, registers[x.get()], num_rows); + stream << ","; + compile_index (stream, registers[y.get()], num_columns); + stream << ").yx).r;"; + } else if constexpr (jit::use_cuda()) { + stream << ", "; + compile_index (stream, registers[x.get()], num_rows); + stream << ", "; + compile_index (stream, registers[y.get()], num_columns); + stream << ");"; } else { stream << "[min(max((int)"; if constexpr (jit::is_complex ()) { -- GitLab From 2c78e7931661eaf60dba06de974eac8fb0120cb4 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Wed, 12 Jun 2024 15:04:49 -0400 Subject: [PATCH 13/63] Fix compile error by removing struct keyword. --- graph_framework/cuda_context.hpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index 334ca4b..7fe4346 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -313,9 +313,9 @@ namespace gpu { for (auto &[data, size] : tex1d_list) { if (!texture_arguments.contains(data)) { - struct CUDA_RESOURCE_DESC resource_desc; - struct CUDA_TEXTURE_DESC texture_desc; - struct CUDA_RESOURCE_VIEW_DESC view_desc; + CUDA_RESOURCE_DESC resource_desc; + CUDA_TEXTURE_DESC texture_desc; + CUDA_RESOURCE_VIEW_DESC view_desc; resource_desc.resType = CU_RESOURCE_TYPE_LINEAR; texture_desc.flags = CU_TRSF_READ_AS_INTEGER; view_desc.format = CU_RES_VIEW_FORMAT_NONE; @@ -356,9 +356,9 @@ namespace gpu { } for (auto &[data, size] : tex2d_list) { if (!texture_arguments.contains(data)) { - struct CUDA_RESOURCE_DESC resource_desc; - struct CUDA_TEXTURE_DESC texture_desc; - struct CUDA_RESOURCE_VIEW_DESC view_desc; + CUDA_RESOURCE_DESC resource_desc; + CUDA_TEXTURE_DESC texture_desc; + CUDA_RESOURCE_VIEW_DESC view_desc; resource_desc.resType = CU_RESOURCE_TYPE_LINEAR; texture_desc.flags = CU_TRSF_READ_AS_INTEGER; view_desc.format = CU_RES_VIEW_FORMAT_NONE; -- GitLab From 3d4916fe987e780da36d10205e4eae8ad852fed9 Mon Sep 17 00:00:00 2001 
From: cianciosa Date: Wed, 12 Jun 2024 15:07:17 -0400 Subject: [PATCH 14/63] Reference correct section of struct. --- graph_framework/cuda_context.hpp | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index 7fe4346..9a6be15 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -366,31 +366,31 @@ namespace gpu { view_desc.height = size[1]; const size_t total = size[0]*size[1]; if constexpr (jit::is_float ()) { - resource_desc.format = CU_AD_FORMAT_FLOAT; + resource_desc.res.format = CU_AD_FORMAT_FLOAT; if constexpr (jit::is_complex ()) { - resource_desc.numChannel = 2; - resource_desc.sizeInBytes = 2*total*sizeof(float); + resource_desc.res.numChannel = 2; + resource_desc.res.sizeInBytes = 2*total*sizeof(float); } else { - resourec_desc.numChannel = 1; - resource_desc.sizeInBytes = total*sizeof(float); + resourec_desc.res.numChannel = 1; + resource_desc.res.sizeInBytes = total*sizeof(float); } } else { - resource_desc.format = CU_AD_FORMAT_UNSIGNED_INT32; + resource_desc.res.format = CU_AD_FORMAT_UNSIGNED_INT32; if constexpr (jit::is_complex ()) { - resource_desc.numChannel = 4; - resource_desc.sizeInBytes = 2*total*sizeof(double); + resource_desc.res.numChannel = 4; + resource_desc.res.sizeInBytes = 2*total*sizeof(double); } else { - resource_desc.numChannel = 2; - resource_desc.sizeInBytes = total*sizeof(double); + resource_desc.res.numChannel = 2; + resource_desc.res.sizeInBytes = total*sizeof(double); } } - check_error(cuMemAllocManaged(&resource_desc.devPtr, - resource_desc.sizeInBytes, + check_error(cuMemAllocManaged(&resource_desc.res.devPtr, + resource_desc.res.sizeInBytes, CU_MEM_ATTACH_GLOBAL), "cuMemAllocManaged"); - check_error(cuMemcpyHtoD(resource_desc.devPtr, + check_error(cuMemcpyHtoD(resource_desc.res.devPtr, data, - resource_desc.sizeInBytes), + resource_desc.res.sizeInBytes), "cuMemcpyHtoD"); 
check_error(cuTexObjectCreate(&texture_arguments[data], -- GitLab From 2900f6de4c18f5a47ccaf61754c89ae49fb712b6 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Wed, 12 Jun 2024 15:08:23 -0400 Subject: [PATCH 15/63] Reference correct section of struct. --- graph_framework/cuda_context.hpp | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index 9a6be15..241aebb 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -321,31 +321,31 @@ namespace gpu { view_desc.format = CU_RES_VIEW_FORMAT_NONE; view_desc.width = size; if constexpr (jit::is_float ()) { - resource_desc.format = CU_AD_FORMAT_FLOAT; + resource_desc.res.format = CU_AD_FORMAT_FLOAT; if constexpr (jit::is_complex ()) { - resource_desc.numChannel = 2; - resource_desc.sizeInBytes = 2*size*sizeof(float); + resource_desc.res.numChannel = 2; + resource_desc.res.sizeInBytes = 2*size*sizeof(float); } else { - resourec_desc.numChannel = 1; - resource_desc.sizeInBytes = size*sizeof(float); + resourec_desc.res.numChannel = 1; + resource_desc.res.sizeInBytes = size*sizeof(float); } } else { - resource_desc.format = CU_AD_FORMAT_UNSIGNED_INT32; + resource_desc.res.format = CU_AD_FORMAT_UNSIGNED_INT32; if constexpr (jit::is_complex ()) { - resource_desc.numChannel = 4; - resource_desc.sizeInBytes = 2*size*sizeof(double); + resource_desc.res.numChannel = 4; + resource_desc.res.sizeInBytes = 2*size*sizeof(double); } else { - resource_desc.numChannel = 2; - resource_desc.sizeInBytes = size*sizeof(double); + resource_desc.res.numChannel = 2; + resource_desc.res.sizeInBytes = size*sizeof(double); } } - check_error(cuMemAllocManaged(&resource_desc.devPtr, - resource_desc.sizeInBytes, + check_error(cuMemAllocManaged(&resource_desc.res.devPtr, + resource_desc.res.sizeInBytes, CU_MEM_ATTACH_GLOBAL), "cuMemAllocManaged"); - check_error(cuMemcpyHtoD(resource_desc.devPtr, + 
check_error(cuMemcpyHtoD(resource_desc.res.devPtr, data, - resource_desc.sizeInBytes), + resource_desc.res.sizeInBytes), "cuMemcpyHtoD"); check_error(cuTexObjectCreate(&texture_arguments[data], -- GitLab From 57548839409a61de007960d2cef10e9d35e59bd1 Mon Sep 17 00:00:00 2001 From: m4c Date: Wed, 12 Jun 2024 15:35:16 -0400 Subject: [PATCH 16/63] Fix compile errors. --- graph_framework/cuda_context.hpp | 78 +++++++++++++++++--------------- 1 file changed, 41 insertions(+), 37 deletions(-) diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index 241aebb..7e5760a 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -321,35 +321,37 @@ namespace gpu { view_desc.format = CU_RES_VIEW_FORMAT_NONE; view_desc.width = size; if constexpr (jit::is_float ()) { - resource_desc.res.format = CU_AD_FORMAT_FLOAT; + resource_desc.res.linear.format = CU_AD_FORMAT_FLOAT; if constexpr (jit::is_complex ()) { - resource_desc.res.numChannel = 2; - resource_desc.res.sizeInBytes = 2*size*sizeof(float); + resource_desc.res.linear.numChannels = 2; + resource_desc.res.linear.sizeInBytes = 2*size*sizeof(float); } else { - resourec_desc.res.numChannel = 1; - resource_desc.res.sizeInBytes = size*sizeof(float); + resource_desc.res.linear.numChannels = 1; + resource_desc.res.linear.sizeInBytes = size*sizeof(float); } } else { - resource_desc.res.format = CU_AD_FORMAT_UNSIGNED_INT32; + resource_desc.res.linear.format = CU_AD_FORMAT_UNSIGNED_INT32; if constexpr (jit::is_complex ()) { - resource_desc.res.numChannel = 4; - resource_desc.res.sizeInBytes = 2*size*sizeof(double); + resource_desc.res.linear.numChannels = 4; + resource_desc.res.linear.sizeInBytes = 2*size*sizeof(double); } else { - resource_desc.res.numChannel = 2; - resource_desc.res.sizeInBytes = size*sizeof(double); + resource_desc.res.linear.numChannels = 2; + resource_desc.res.linear.sizeInBytes = size*sizeof(double); } } - 
check_error(cuMemAllocManaged(&resource_desc.res.devPtr, - resource_desc.res.sizeInBytes, + check_error(cuMemAllocManaged(&resource_desc.res.linear.devPtr, + resource_desc.res.linear.sizeInBytes, CU_MEM_ATTACH_GLOBAL), "cuMemAllocManaged"); - check_error(cuMemcpyHtoD(resource_desc.res.devPtr, + check_error(cuMemcpyHtoD(resource_desc.res.linear.devPtr, data, - resource_desc.res.sizeInBytes), + resource_desc.res.linear.sizeInBytes), "cuMemcpyHtoD"); - + check_error(cuTexObjectCreate(&texture_arguments[data], - resource_desc,), + &resource_desc, + &texture_desc, + &view_desc), "cuTexObjectCreate"); } buffers.push_back(reinterpret_cast (&texture_arguments[data])); @@ -366,35 +368,37 @@ namespace gpu { view_desc.height = size[1]; const size_t total = size[0]*size[1]; if constexpr (jit::is_float ()) { - resource_desc.res.format = CU_AD_FORMAT_FLOAT; + resource_desc.res.linear.format = CU_AD_FORMAT_FLOAT; if constexpr (jit::is_complex ()) { - resource_desc.res.numChannel = 2; - resource_desc.res.sizeInBytes = 2*total*sizeof(float); + resource_desc.res.linear.numChannels = 2; + resource_desc.res.linear.sizeInBytes = 2*total*sizeof(float); } else { - resourec_desc.res.numChannel = 1; - resource_desc.res.sizeInBytes = total*sizeof(float); + resource_desc.res.linear.numChannels = 1; + resource_desc.res.linear.sizeInBytes = total*sizeof(float); } } else { - resource_desc.res.format = CU_AD_FORMAT_UNSIGNED_INT32; + resource_desc.res.linear.format = CU_AD_FORMAT_UNSIGNED_INT32; if constexpr (jit::is_complex ()) { - resource_desc.res.numChannel = 4; - resource_desc.res.sizeInBytes = 2*total*sizeof(double); + resource_desc.res.linear.numChannels = 4; + resource_desc.res.linear.sizeInBytes = 2*total*sizeof(double); } else { - resource_desc.res.numChannel = 2; - resource_desc.res.sizeInBytes = total*sizeof(double); + resource_desc.res.linear.numChannels = 2; + resource_desc.res.linear.sizeInBytes = total*sizeof(double); } } - check_error(cuMemAllocManaged(&resource_desc.res.devPtr, 
- resource_desc.res.sizeInBytes, + check_error(cuMemAllocManaged(&resource_desc.res.linear.devPtr, + resource_desc.res.linear.sizeInBytes, CU_MEM_ATTACH_GLOBAL), "cuMemAllocManaged"); - check_error(cuMemcpyHtoD(resource_desc.res.devPtr, + check_error(cuMemcpyHtoD(resource_desc.res.linear.devPtr, data, - resource_desc.res.sizeInBytes), + resource_desc.res.linear.sizeInBytes), "cuMemcpyHtoD"); - + check_error(cuTexObjectCreate(&texture_arguments[data], - resource_desc,), + &resource_desc, + &texture_desc, + &view_desc), "cuTexObjectCreate"); } buffers.push_back(reinterpret_cast (&texture_arguments[data])); @@ -419,7 +423,7 @@ namespace gpu { std::cout << " Suggested Block size : " << value << std::endl; } - return [this, function, thread_groups, threads_per_group, buffers, textures] () mutable { + return [this, function, thread_groups, threads_per_group, buffers] () mutable { check_error_async(cuLaunchKernel(function, thread_groups, 1, 1, threads_per_group, 1, 1, 0, stream, buffers.data(), NULL), @@ -553,13 +557,13 @@ namespace gpu { if constexpr (jit::is_float ()) { source_buffer << "static __inline__ __device__ complex to_cmp_float(float2 p) {" << " return "; - jit::add_type (stream); + jit::add_type (source_buffer); source_buffer << " (p.x, p.y);" << "}"; } else { source_buffer << "static __inline__ __device__ complex to_cmp_double(uint4 p) {" << " return "; - jit::add_type (stream); + jit::add_type (source_buffer); source_buffer << " (__hiloint2double(p.y, p.x), __hiloint2double(p.w, p.z));" << "}"; } @@ -621,13 +625,13 @@ namespace gpu { source_buffer << "," << std::endl; source_buffer << " "; jit::add_type (source_buffer); - source_buffer << " *" << jit::to_string('a', textures1d[i].get()); + source_buffer << " *" << jit::to_string('a', textures1d[i].first); } for (size_t i = 0, ie = textures2d.size(); i < ie; i++) { source_buffer << "," << std::endl; source_buffer << " "; jit::add_type (source_buffer); - source_buffer << " *" << jit::to_string('a', 
textures2d[i].get()); + source_buffer << " *" << jit::to_string('a', textures2d[i].first); } source_buffer << ") {" << std::endl; -- GitLab From b187f002dd4a0548d695432599d484246855ed95 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Wed, 12 Jun 2024 15:43:31 -0400 Subject: [PATCH 17/63] Replace uint with unsigned int since uint is not defined for cuda kernels. --- graph_framework/piecewise.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graph_framework/piecewise.hpp b/graph_framework/piecewise.hpp index b104d81..b1cae98 100644 --- a/graph_framework/piecewise.hpp +++ b/graph_framework/piecewise.hpp @@ -24,7 +24,7 @@ template void compile_index(std::ostringstream &stream, const std::string &register_name, const size_t length) { - stream << "min(max((uint)"; + stream << "min(max((unsigned int)"; if constexpr (jit::is_complex ()) { stream << "real("; } -- GitLab From bffca1a0b13f5a7663d00666266d613ab94bd175 Mon Sep 17 00:00:00 2001 From: m4c Date: Wed, 12 Jun 2024 16:00:08 -0400 Subject: [PATCH 18/63] Set the correct types for cuda texture objects. 
--- graph_framework/cuda_context.hpp | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index 7e5760a..a20b2a6 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -623,15 +623,13 @@ namespace gpu { } for (size_t i = 0, ie = textures1d.size(); i < ie; i++) { source_buffer << "," << std::endl; - source_buffer << " "; - jit::add_type (source_buffer); - source_buffer << " *" << jit::to_string('a', textures1d[i].first); + source_buffer << " cudaTextureObject_t " + << jit::to_string('a', textures1d[i].first); } for (size_t i = 0, ie = textures2d.size(); i < ie; i++) { source_buffer << "," << std::endl; - source_buffer << " "; - jit::add_type (source_buffer); - source_buffer << " *" << jit::to_string('a', textures2d[i].first); + source_buffer << " cudaTextureObject_t " + << jit::to_string('a', textures2d[i].first); } source_buffer << ") {" << std::endl; -- GitLab From 77a29e3c60a0be12a263839ced0873e580b4bb0c Mon Sep 17 00:00:00 2001 From: m4c Date: Wed, 12 Jun 2024 16:34:07 -0400 Subject: [PATCH 19/63] Zero out descriptor structures. 
--- graph_framework/cuda_context.hpp | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index a20b2a6..e104c50 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -9,6 +9,7 @@ #define cuda_context_h #include +#include #include #include @@ -316,6 +317,11 @@ namespace gpu { CUDA_RESOURCE_DESC resource_desc; CUDA_TEXTURE_DESC texture_desc; CUDA_RESOURCE_VIEW_DESC view_desc; + + memset(&resource_desc, 0, sizeof(CUDA_RESOURCE_DESC)); + memset(&resource_desc, 0, sizeof(CUDA_TEXTURE_DESC)); + memset(&resource_desc, 0, sizeof(CUDA_RESOURCE_VIEW_DESC)); + resource_desc.resType = CU_RESOURCE_TYPE_LINEAR; texture_desc.flags = CU_TRSF_READ_AS_INTEGER; view_desc.format = CU_RES_VIEW_FORMAT_NONE; @@ -361,6 +367,11 @@ namespace gpu { CUDA_RESOURCE_DESC resource_desc; CUDA_TEXTURE_DESC texture_desc; CUDA_RESOURCE_VIEW_DESC view_desc; + + memset(&resource_desc, 0, sizeof(CUDA_RESOURCE_DESC)); + memset(&resource_desc, 0, sizeof(CUDA_TEXTURE_DESC)); + memset(&resource_desc, 0, sizeof(CUDA_RESOURCE_VIEW_DESC)); + resource_desc.resType = CU_RESOURCE_TYPE_LINEAR; texture_desc.flags = CU_TRSF_READ_AS_INTEGER; view_desc.format = CU_RES_VIEW_FORMAT_NONE; -- GitLab From ba35f194b6fecc26cac5bd1594cae49032b88007 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Wed, 12 Jun 2024 17:17:44 -0400 Subject: [PATCH 20/63] Use CUArray to create texture. NULL the texture descriptor. 
--- graph_framework/cuda_context.hpp | 106 +++++++++++++------------------ 1 file changed, 45 insertions(+), 61 deletions(-) diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index e104c50..de92e97 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -140,6 +140,15 @@ namespace gpu { check_error(cuMemFree(value), "cuMemFree"); } + for (auto &[key, value] : texture_arguments) { + CUDA_RESOURCE_DESC resource; + check_error(cuTexObjectGetResourceDesc(&resource, value), + "cuTexObjectGetResourceDesc"); + + check_error(cuMemFree(resource.res.linear.devPtr), "cuMemFree"); + check_error(cuTexObjectDestroy(value), "cuTexObjectDestroy"); + } + if (result_buffer) { check_error(cuMemFree(result_buffer), "cuMemFree"); result_buffer = 0; @@ -315,101 +324,76 @@ namespace gpu { for (auto &[data, size] : tex1d_list) { if (!texture_arguments.contains(data)) { CUDA_RESOURCE_DESC resource_desc; - CUDA_TEXTURE_DESC texture_desc; - CUDA_RESOURCE_VIEW_DESC view_desc; + CUDA_ARRAY_DESCRIPTOR array_desc; + + array_desc.width = size; + array_desc.height = 1; memset(&resource_desc, 0, sizeof(CUDA_RESOURCE_DESC)); - memset(&resource_desc, 0, sizeof(CUDA_TEXTURE_DESC)); - memset(&resource_desc, 0, sizeof(CUDA_RESOURCE_VIEW_DESC)); - resource_desc.resType = CU_RESOURCE_TYPE_LINEAR; - texture_desc.flags = CU_TRSF_READ_AS_INTEGER; - view_desc.format = CU_RES_VIEW_FORMAT_NONE; - view_desc.width = size; + array_desc.resType = CU_RESOURCE_TYPE_ARRAY; if constexpr (jit::is_float ()) { - resource_desc.res.linear.format = CU_AD_FORMAT_FLOAT; + array_desc.format = CU_AD_FORMAT_FLOAT; if constexpr (jit::is_complex ()) { - resource_desc.res.linear.numChannels = 2; - resource_desc.res.linear.sizeInBytes = 2*size*sizeof(float); + array_desc.numChannels = 2; } else { - resource_desc.res.linear.numChannels = 1; - resource_desc.res.linear.sizeInBytes = size*sizeof(float); + array_desc.numChannels = 1; } } else { - 
resource_desc.res.linear.format = CU_AD_FORMAT_UNSIGNED_INT32; + array_desc.format = CU_AD_FORMAT_UNSIGNED_INT32; if constexpr (jit::is_complex ()) { - resource_desc.res.linear.numChannels = 4; - resource_desc.res.linear.sizeInBytes = 2*size*sizeof(double); + array_desc.numChannels = 4; } else { - resource_desc.res.linear.numChannels = 2; - resource_desc.res.linear.sizeInBytes = size*sizeof(double); + array_desc.numChannels = 2; } } - check_error(cuMemAllocManaged(&resource_desc.res.linear.devPtr, - resource_desc.res.linear.sizeInBytes, - CU_MEM_ATTACH_GLOBAL), - "cuMemAllocManaged"); - check_error(cuMemcpyHtoD(resource_desc.res.linear.devPtr, - data, - resource_desc.res.linear.sizeInBytes), - "cuMemcpyHtoD"); + check_error(cuArrayCreate(&resource_desc.array, &array_desc), + "cuArrayCreate"); + check_error(cuMemcpyHtoA(resource_desc.array, 0, data, + size*sizeof(float)*array_desc.numChannels), + "cuMemcpyHtoA"); check_error(cuTexObjectCreate(&texture_arguments[data], - &resource_desc, - &texture_desc, - &view_desc), + &resource_desc, NULL, NULL), "cuTexObjectCreate"); } buffers.push_back(reinterpret_cast (&texture_arguments[data])); } for (auto &[data, size] : tex2d_list) { if (!texture_arguments.contains(data)) { - CUDA_RESOURCE_DESC resource_desc; - CUDA_TEXTURE_DESC texture_desc; - CUDA_RESOURCE_VIEW_DESC view_desc; + CUDA_RESOURCE_DESC *resource_desc; + CUDA_ARRAY_DESCRIPTOR array_desc; + + array_desc.width = size; + array_desc.height = 1; memset(&resource_desc, 0, sizeof(CUDA_RESOURCE_DESC)); - memset(&resource_desc, 0, sizeof(CUDA_TEXTURE_DESC)); - memset(&resource_desc, 0, sizeof(CUDA_RESOURCE_VIEW_DESC)); - - resource_desc.resType = CU_RESOURCE_TYPE_LINEAR; - texture_desc.flags = CU_TRSF_READ_AS_INTEGER; - view_desc.format = CU_RES_VIEW_FORMAT_NONE; - view_desc.width = size[0]; - view_desc.height = size[1]; + + resource_desc.resType = CU_RESOURCE_TYPE_ARRAY; const size_t total = size[0]*size[1]; if constexpr (jit::is_float ()) { - 
resource_desc.res.linear.format = CU_AD_FORMAT_FLOAT; + array_desc.format = CU_AD_FORMAT_FLOAT; if constexpr (jit::is_complex ()) { - resource_desc.res.linear.numChannels = 2; - resource_desc.res.linear.sizeInBytes = 2*total*sizeof(float); + array_desc.numChannels = 2; } else { - resource_desc.res.linear.numChannels = 1; - resource_desc.res.linear.sizeInBytes = total*sizeof(float); + array_desc.numChannels = 1; } } else { - resource_desc.res.linear.format = CU_AD_FORMAT_UNSIGNED_INT32; + array_desc.format = CU_AD_FORMAT_UNSIGNED_INT32; if constexpr (jit::is_complex ()) { - resource_desc.res.linear.numChannels = 4; - resource_desc.res.linear.sizeInBytes = 2*total*sizeof(double); + array_desc.numChannels = 4; } else { - resource_desc.res.linear.numChannels = 2; - resource_desc.res.linear.sizeInBytes = total*sizeof(double); + array_desc.numChannels = 2; } } - check_error(cuMemAllocManaged(&resource_desc.res.linear.devPtr, - resource_desc.res.linear.sizeInBytes, - CU_MEM_ATTACH_GLOBAL), - "cuMemAllocManaged"); - check_error(cuMemcpyHtoD(resource_desc.res.linear.devPtr, - data, - resource_desc.res.linear.sizeInBytes), - "cuMemcpyHtoD"); + check_error(cuArrayCreate(&resource_desc.array, &array_desc), + "cuArrayCreate"); + check_error(cuMemcpyHtoA(resource_desc.array, 0, data, + size[0]*size[1]*sizeof(float)*array_desc.numChannels), + "cuMemcpyHtoA"); check_error(cuTexObjectCreate(&texture_arguments[data], - &resource_desc, - &texture_desc, - &view_desc), + &resource_desc, NULL, NULL), "cuTexObjectCreate"); } buffers.push_back(reinterpret_cast (&texture_arguments[data])); -- GitLab From ca7e3003eb10acff3650fe3045287d89c4664ec4 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Wed, 12 Jun 2024 17:22:59 -0400 Subject: [PATCH 21/63] This shouldn't be a pointer and fix case on members. 
--- graph_framework/cuda_context.hpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index de92e97..8039433 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -326,25 +326,25 @@ namespace gpu { CUDA_RESOURCE_DESC resource_desc; CUDA_ARRAY_DESCRIPTOR array_desc; - array_desc.width = size; - array_desc.height = 1; + array_desc.Width = size; + array_desc.Height = 1; memset(&resource_desc, 0, sizeof(CUDA_RESOURCE_DESC)); array_desc.resType = CU_RESOURCE_TYPE_ARRAY; if constexpr (jit::is_float ()) { - array_desc.format = CU_AD_FORMAT_FLOAT; + array_desc.Format = CU_AD_FORMAT_FLOAT; if constexpr (jit::is_complex ()) { - array_desc.numChannels = 2; + array_desc.NumChannels = 2; } else { - array_desc.numChannels = 1; + array_desc.NumChannels = 1; } } else { - array_desc.format = CU_AD_FORMAT_UNSIGNED_INT32; + array_desc.Format = CU_AD_FORMAT_UNSIGNED_INT32; if constexpr (jit::is_complex ()) { - array_desc.numChannels = 4; + array_desc.NumChannels = 4; } else { - array_desc.numChannels = 2; + array_desc.NumChannels = 2; } } check_error(cuArrayCreate(&resource_desc.array, &array_desc), @@ -361,7 +361,7 @@ namespace gpu { } for (auto &[data, size] : tex2d_list) { if (!texture_arguments.contains(data)) { - CUDA_RESOURCE_DESC *resource_desc; + CUDA_RESOURCE_DESC resource_desc; CUDA_ARRAY_DESCRIPTOR array_desc; array_desc.width = size; -- GitLab From 51811ed163a81bf2e594f48f87f5b4d765522a63 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Wed, 12 Jun 2024 17:25:27 -0400 Subject: [PATCH 22/63] Fix case. 
--- graph_framework/cuda_context.hpp | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index 8039433..1a131f9 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -331,7 +331,7 @@ namespace gpu { memset(&resource_desc, 0, sizeof(CUDA_RESOURCE_DESC)); - array_desc.resType = CU_RESOURCE_TYPE_ARRAY; + resource_desc.resType = CU_RESOURCE_TYPE_ARRAY; if constexpr (jit::is_float ()) { array_desc.Format = CU_AD_FORMAT_FLOAT; if constexpr (jit::is_complex ()) { @@ -347,10 +347,10 @@ namespace gpu { array_desc.NumChannels = 2; } } - check_error(cuArrayCreate(&resource_desc.array, &array_desc), + check_error(cuArrayCreate(&resource_desc.res.array, &array_desc), "cuArrayCreate"); - check_error(cuMemcpyHtoA(resource_desc.array, 0, data, - size*sizeof(float)*array_desc.numChannels), + check_error(cuMemcpyHtoA(resource_desc.res.array, 0, data, + size*sizeof(float)*array_desc.NumChannels), "cuMemcpyHtoA"); check_error(cuTexObjectCreate(&texture_arguments[data], @@ -364,32 +364,32 @@ namespace gpu { CUDA_RESOURCE_DESC resource_desc; CUDA_ARRAY_DESCRIPTOR array_desc; - array_desc.width = size; - array_desc.height = 1; + array_desc.Width = size; + array_desc.Height = 1; memset(&resource_desc, 0, sizeof(CUDA_RESOURCE_DESC)); resource_desc.resType = CU_RESOURCE_TYPE_ARRAY; const size_t total = size[0]*size[1]; if constexpr (jit::is_float ()) { - array_desc.format = CU_AD_FORMAT_FLOAT; + array_desc.Format = CU_AD_FORMAT_FLOAT; if constexpr (jit::is_complex ()) { - array_desc.numChannels = 2; + array_desc.NumChannels = 2; } else { - array_desc.numChannels = 1; + array_desc.NumChannels = 1; } } else { - array_desc.format = CU_AD_FORMAT_UNSIGNED_INT32; + array_desc.Format = CU_AD_FORMAT_UNSIGNED_INT32; if constexpr (jit::is_complex ()) { - array_desc.numChannels = 4; + array_desc.NumChannels = 4; } else { - array_desc.numChannels = 2; + 
array_desc.NumChannels = 2; } } check_error(cuArrayCreate(&resource_desc.array, &array_desc), "cuArrayCreate"); - check_error(cuMemcpyHtoA(resource_desc.array, 0, data, - size[0]*size[1]*sizeof(float)*array_desc.numChannels), + check_error(cuMemcpyHtoA(resource_desc.res.array, 0, data, + size[0]*size[1]*sizeof(float)*array_desc.NumChannels), "cuMemcpyHtoA"); check_error(cuTexObjectCreate(&texture_arguments[data], -- GitLab From b42393465c9c4d043e3f1fc98b71eadc798a7713 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Wed, 12 Jun 2024 17:28:59 -0400 Subject: [PATCH 23/63] Do not dereference the cuarray. --- graph_framework/cuda_context.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index 1a131f9..1b12a80 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -347,7 +347,7 @@ namespace gpu { array_desc.NumChannels = 2; } } - check_error(cuArrayCreate(&resource_desc.res.array, &array_desc), + check_error(cuArrayCreate(resource_desc.res.array, &array_desc), "cuArrayCreate"); check_error(cuMemcpyHtoA(resource_desc.res.array, 0, data, size*sizeof(float)*array_desc.NumChannels), @@ -386,7 +386,7 @@ namespace gpu { array_desc.NumChannels = 2; } } - check_error(cuArrayCreate(&resource_desc.array, &array_desc), + check_error(cuArrayCreate(resource_desc.res.array, &array_desc), "cuArrayCreate"); check_error(cuMemcpyHtoA(resource_desc.res.array, 0, data, size[0]*size[1]*sizeof(float)*array_desc.NumChannels), -- GitLab From 1dedb02ecddf42b14164a8909bd8c6eaa094f71c Mon Sep 17 00:00:00 2001 From: cianciosa Date: Wed, 12 Jun 2024 17:33:30 -0400 Subject: [PATCH 24/63] Fix Array access name. 
--- graph_framework/cuda_context.hpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index 1b12a80..26f7be5 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -347,9 +347,9 @@ namespace gpu { array_desc.NumChannels = 2; } } - check_error(cuArrayCreate(resource_desc.res.array, &array_desc), + check_error(cuArrayCreate(&resource_desc.res.array.hArray, &array_desc), "cuArrayCreate"); - check_error(cuMemcpyHtoA(resource_desc.res.array, 0, data, + check_error(cuMemcpyHtoA(resource_desc.res.array.hArray, 0, data, size*sizeof(float)*array_desc.NumChannels), "cuMemcpyHtoA"); @@ -386,9 +386,9 @@ namespace gpu { array_desc.NumChannels = 2; } } - check_error(cuArrayCreate(resource_desc.res.array, &array_desc), + check_error(cuArrayCreate(&resource_desc.res.array.hArray, &array_desc), "cuArrayCreate"); - check_error(cuMemcpyHtoA(resource_desc.res.array, 0, data, + check_error(cuMemcpyHtoA(resource_desc.res.array.hArray, 0, data, size[0]*size[1]*sizeof(float)*array_desc.NumChannels), "cuMemcpyHtoA"); -- GitLab From 6dfa1069250485ca0960ef96a45c8e59d9ae0053 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Wed, 12 Jun 2024 17:36:33 -0400 Subject: [PATCH 25/63] Fis destructor and sizes in 2D array. 
--- graph_framework/cuda_context.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index 26f7be5..8723269 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -145,7 +145,7 @@ namespace gpu { check_error(cuTexObjectGetResourceDesc(&resource, value), "cuTexObjectGetResourceDesc"); - check_error(cuMemFree(resource.res.linear.devPtr), "cuMemFree"); + check_error(cuArrayDestroy(resource.res.array.harray), "cuArrayDestroy"); check_error(cuTexObjectDestroy(value), "cuTexObjectDestroy"); } @@ -364,8 +364,8 @@ namespace gpu { CUDA_RESOURCE_DESC resource_desc; CUDA_ARRAY_DESCRIPTOR array_desc; - array_desc.Width = size; - array_desc.Height = 1; + array_desc.Width = size[0]; + array_desc.Height = size[1]; memset(&resource_desc, 0, sizeof(CUDA_RESOURCE_DESC)); -- GitLab From 108d9532ebf4cde1ebd03a254a18d76938897a5b Mon Sep 17 00:00:00 2001 From: cianciosa Date: Wed, 12 Jun 2024 17:38:19 -0400 Subject: [PATCH 26/63] Fix case. 
--- graph_framework/cuda_context.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index 8723269..92eaeec 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -145,7 +145,7 @@ namespace gpu { check_error(cuTexObjectGetResourceDesc(&resource, value), "cuTexObjectGetResourceDesc"); - check_error(cuArrayDestroy(resource.res.array.harray), "cuArrayDestroy"); + check_error(cuArrayDestroy(resource.res.array.hArray), "cuArrayDestroy"); check_error(cuTexObjectDestroy(value), "cuTexObjectDestroy"); } -- GitLab From 4a1e69cdcf0dbdf4b097a571714f509edc11430b Mon Sep 17 00:00:00 2001 From: cianciosa Date: Wed, 12 Jun 2024 18:32:48 -0400 Subject: [PATCH 27/63] Add back texture discriptor --- graph_framework/cuda_context.hpp | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index 92eaeec..66e54f6 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -324,14 +324,19 @@ namespace gpu { for (auto &[data, size] : tex1d_list) { if (!texture_arguments.contains(data)) { CUDA_RESOURCE_DESC resource_desc; + CUDA_TEXTURE_DESC texture_desc; CUDA_ARRAY_DESCRIPTOR array_desc; array_desc.Width = size; array_desc.Height = 1; memset(&resource_desc, 0, sizeof(CUDA_RESOURCE_DESC)); + memset(&texture_desc, 0, sizeof(CUDA_TEXTURE_DESC)); resource_desc.resType = CU_RESOURCE_TYPE_ARRAY; + texture_desc.addressMode[0] = CU_TR_ADDRESS_MODE_BORDER; + texture_desc.addressMode[1] = CU_TR_ADDRESS_MODE_BORDER; + texture_desc.addressMode[2] = CU_TR_ADDRESS_MODE_BORDER; if constexpr (jit::is_float ()) { array_desc.Format = CU_AD_FORMAT_FLOAT; if constexpr (jit::is_complex ()) { @@ -354,7 +359,8 @@ namespace gpu { "cuMemcpyHtoA"); check_error(cuTexObjectCreate(&texture_arguments[data], - &resource_desc, NULL, NULL), + &resource_desc, &texture_desc, + NULL), 
"cuTexObjectCreate"); } buffers.push_back(reinterpret_cast (&texture_arguments[data])); @@ -362,14 +368,19 @@ namespace gpu { for (auto &[data, size] : tex2d_list) { if (!texture_arguments.contains(data)) { CUDA_RESOURCE_DESC resource_desc; + CUDA_TEXTURE_DESC texture_desc; CUDA_ARRAY_DESCRIPTOR array_desc; array_desc.Width = size[0]; array_desc.Height = size[1]; memset(&resource_desc, 0, sizeof(CUDA_RESOURCE_DESC)); + memset(&texture_desc, 0, sizeof(CUDA_TEXTURE_DESC)); resource_desc.resType = CU_RESOURCE_TYPE_ARRAY; + texture_desc.addressMode[0] = CU_TR_ADDRESS_MODE_BORDER; + texture_desc.addressMode[1] = CU_TR_ADDRESS_MODE_BORDER; + texture_desc.addressMode[2] = CU_TR_ADDRESS_MODE_BORDER; const size_t total = size[0]*size[1]; if constexpr (jit::is_float ()) { array_desc.Format = CU_AD_FORMAT_FLOAT; @@ -393,7 +404,8 @@ namespace gpu { "cuMemcpyHtoA"); check_error(cuTexObjectCreate(&texture_arguments[data], - &resource_desc, NULL, NULL), + &resource_desc, &texture_desc, + NULL), "cuTexObjectCreate"); } buffers.push_back(reinterpret_cast (&texture_arguments[data])); -- GitLab From 65d25fb6d637d7d608f87ddc414be2bbc15f7c09 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Thu, 13 Jun 2024 10:47:30 -0400 Subject: [PATCH 28/63] Make sure the object is created in the map before constructing the texture object. Add a error message to write out cuda errors. 
--- graph_framework/cuda_context.hpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index 66e54f6..56b130b 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -41,6 +41,9 @@ namespace gpu { #ifndef NDEBUG const char *error; cuGetErrorString(result, &error); + if (error !== CUDA_SUCCESS) { + std::cerr << name << " " << std::string(error) << std::endl; + } assert(result == CUDA_SUCCESS && error); #endif } @@ -323,6 +326,7 @@ namespace gpu { for (auto &[data, size] : tex1d_list) { if (!texture_arguments.contains(data)) { + texture_arguments.try_emplace(data); CUDA_RESOURCE_DESC resource_desc; CUDA_TEXTURE_DESC texture_desc; CUDA_ARRAY_DESCRIPTOR array_desc; @@ -367,6 +371,7 @@ namespace gpu { } for (auto &[data, size] : tex2d_list) { if (!texture_arguments.contains(data)) { + texture_arguments.try_emplace(data); CUDA_RESOURCE_DESC resource_desc; CUDA_TEXTURE_DESC texture_desc; CUDA_ARRAY_DESCRIPTOR array_desc; -- GitLab From 248fd0002b97d232fe55cf5608375f5fad8c2a88 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Thu, 13 Jun 2024 10:57:53 -0400 Subject: [PATCH 29/63] Fix typo inn not equals. --- graph_framework/cuda_context.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index 56b130b..51ca60d 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -41,7 +41,7 @@ namespace gpu { #ifndef NDEBUG const char *error; cuGetErrorString(result, &error); - if (error !== CUDA_SUCCESS) { + if (error != CUDA_SUCCESS) { std::cerr << name << " " << std::string(error) << std::endl; } assert(result == CUDA_SUCCESS && error); -- GitLab From 411f8f634070c17070799c4873b743f26b422432 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Thu, 13 Jun 2024 11:31:10 -0400 Subject: [PATCH 30/63] Result is the error number. 
--- graph_framework/cuda_context.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index 51ca60d..a690b80 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -41,7 +41,7 @@ namespace gpu { #ifndef NDEBUG const char *error; cuGetErrorString(result, &error); - if (error != CUDA_SUCCESS) { + if (result != CUDA_SUCCESS) { std::cerr << name << " " << std::string(error) << std::endl; } assert(result == CUDA_SUCCESS && error); -- GitLab From 026c2545a9766f5ebbc8bacdfe5e989f2f4244d1 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Thu, 13 Jun 2024 11:40:24 -0400 Subject: [PATCH 31/63] Use tex2D for 2D textures. --- graph_framework/piecewise.hpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/graph_framework/piecewise.hpp b/graph_framework/piecewise.hpp index b1cae98..f979d9f 100644 --- a/graph_framework/piecewise.hpp +++ b/graph_framework/piecewise.hpp @@ -691,13 +691,13 @@ void compile_index(std::ostringstream &stream, stream << " " << registers[this] << " = "; if constexpr (jit::use_cuda()) { if constexpr (jit::is_float ()) { - stream << "tex1D ("; + stream << "tex2D ("; } else if constexpr (jit::is_double ()) { - stream << "to_double(tex1D ("; + stream << "to_double(tex2D ("; } else if constexpr (jit::is_complex () && jit::is_float ()) { - stream << "to_cmp_float(tex1D ("; + stream << "to_cmp_float(tex2D ("; } else { - stream << "to_cmp_double(tex1D ("; + stream << "to_cmp_double(tex2D ("; } } stream << registers[leaf_node::backend_cache[data_hash].data()]; -- GitLab From 0d0d7fdff3647e1667e7bf9f5ca670c748ef5504 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Thu, 13 Jun 2024 12:19:48 -0400 Subject: [PATCH 32/63] Use 2D memcopy. 
--- graph_framework/cuda_context.hpp | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index a690b80..57ebcbd 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -404,9 +404,19 @@ namespace gpu { } check_error(cuArrayCreate(&resource_desc.res.array.hArray, &array_desc), "cuArrayCreate"); - check_error(cuMemcpyHtoA(resource_desc.res.array.hArray, 0, data, - size[0]*size[1]*sizeof(float)*array_desc.NumChannels), - "cuMemcpyHtoA"); + + CUDA_MEMCPY2D copy_desc; + copy_desc.srcPitch = size[0]*sizeof(float)*array_desc.NumChannels + copy_desc.srcMemoryType = CU_MEMORYTYPE_HOST; + copy_desc.srcHost = data; + + copy_desc.dstMemoryType = CU_MEMORYTYPE_HOST; + copy_desc.dstArray = resource_desc.res.array.hArray; + + copy_desc.WidthInBytes = copyParam.srcPitch; + copy_desc.Height = size[0] + + check_error(cuMemcpy2D(©_desc), "cuMemcpy2D"); check_error(cuTexObjectCreate(&texture_arguments[data], &resource_desc, &texture_desc, -- GitLab From 0c5953ce494a5744ab3956b29f51d8d26be09bd4 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Thu, 13 Jun 2024 12:22:14 -0400 Subject: [PATCH 33/63] Add missing semicolons. 
--- graph_framework/cuda_context.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index 57ebcbd..763bdc4 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -406,7 +406,7 @@ namespace gpu { "cuArrayCreate"); CUDA_MEMCPY2D copy_desc; - copy_desc.srcPitch = size[0]*sizeof(float)*array_desc.NumChannels + copy_desc.srcPitch = size[0]*sizeof(float)*array_desc.NumChannels; copy_desc.srcMemoryType = CU_MEMORYTYPE_HOST; copy_desc.srcHost = data; @@ -414,7 +414,7 @@ namespace gpu { copy_desc.dstArray = resource_desc.res.array.hArray; copy_desc.WidthInBytes = copyParam.srcPitch; - copy_desc.Height = size[0] + copy_desc.Height = size[0]; check_error(cuMemcpy2D(©_desc), "cuMemcpy2D"); -- GitLab From 97bab97a2536c2a239ea92a3681abcce50bb7ecf Mon Sep 17 00:00:00 2001 From: cianciosa Date: Thu, 13 Jun 2024 12:24:44 -0400 Subject: [PATCH 34/63] Fix copy and past error. --- graph_framework/cuda_context.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index 763bdc4..9815684 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -413,7 +413,7 @@ namespace gpu { copy_desc.dstMemoryType = CU_MEMORYTYPE_HOST; copy_desc.dstArray = resource_desc.res.array.hArray; - copy_desc.WidthInBytes = copyParam.srcPitch; + copy_desc.WidthInBytes = copy_desc.srcPitch; copy_desc.Height = size[0]; check_error(cuMemcpy2D(©_desc), "cuMemcpy2D"); -- GitLab From 36e9f0424d4b4d8dbd21736f7be1c3d627915973 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Thu, 13 Jun 2024 12:30:18 -0400 Subject: [PATCH 35/63] Fix dest memory type. 
--- graph_framework/cuda_context.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index 9815684..f89b3f7 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -410,7 +410,7 @@ namespace gpu { copy_desc.srcMemoryType = CU_MEMORYTYPE_HOST; copy_desc.srcHost = data; - copy_desc.dstMemoryType = CU_MEMORYTYPE_HOST; + copy_desc.dstMemoryType = CU_MEMORYTYPE_ARRAY; copy_desc.dstArray = resource_desc.res.array.hArray; copy_desc.WidthInBytes = copy_desc.srcPitch; -- GitLab From 4675d6b151cb94fcc2f1a7df8b843a59ec848248 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Thu, 13 Jun 2024 12:42:05 -0400 Subject: [PATCH 36/63] Zero inital parameters. --- graph_framework/cuda_context.hpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index f89b3f7..27ac75b 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -406,6 +406,8 @@ namespace gpu { "cuArrayCreate"); CUDA_MEMCPY2D copy_desc; + memset(©_desc, 0, sizeof(copy_desc)); + copy_desc.srcPitch = size[0]*sizeof(float)*array_desc.NumChannels; copy_desc.srcMemoryType = CU_MEMORYTYPE_HOST; copy_desc.srcHost = data; -- GitLab From af0d4f5bd058ed967705bfbb1b7e8b932432edda Mon Sep 17 00:00:00 2001 From: cianciosa Date: Thu, 13 Jun 2024 13:07:13 -0400 Subject: [PATCH 37/63] Transpose the 2D Array indicies. 
--- graph_framework/piecewise.hpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/graph_framework/piecewise.hpp b/graph_framework/piecewise.hpp index f979d9f..cfbe1e3 100644 --- a/graph_framework/piecewise.hpp +++ b/graph_framework/piecewise.hpp @@ -705,15 +705,15 @@ void compile_index(std::ostringstream &stream, const size_t num_rows = length/num_columns; if constexpr (jit::use_metal ()) { stream << ".read(uint2("; - compile_index (stream, registers[x.get()], num_rows); - stream << ","; compile_index (stream, registers[y.get()], num_columns); - stream << ").yx).r;"; - } else if constexpr (jit::use_cuda()) { - stream << ", "; + stream << ","; compile_index (stream, registers[x.get()], num_rows); + stream << ")).r;"; + } else if constexpr (jit::use_cuda()) { stream << ", "; compile_index (stream, registers[y.get()], num_columns); + stream << ", "; + compile_index (stream, registers[x.get()], num_rows); stream << ");"; } else { stream << "[min(max((int)"; -- GitLab From cd50c62fc7d0c4255f088be3c00fc383143b3b13 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Thu, 13 Jun 2024 13:14:26 -0400 Subject: [PATCH 38/63] Ensure brackets are closed for complex and double. 
--- graph_framework/piecewise.hpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/graph_framework/piecewise.hpp b/graph_framework/piecewise.hpp index cfbe1e3..071726e 100644 --- a/graph_framework/piecewise.hpp +++ b/graph_framework/piecewise.hpp @@ -714,6 +714,9 @@ void compile_index(std::ostringstream &stream, compile_index (stream, registers[y.get()], num_columns); stream << ", "; compile_index (stream, registers[x.get()], num_rows); + if constexpr (jit::is_complex () || jit::is_double ()) { + stream << ")"; + } stream << ");"; } else { stream << "[min(max((int)"; -- GitLab From 30fab2e9eb23eb433de9bd9901deaf790854eab7 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Thu, 13 Jun 2024 13:20:49 -0400 Subject: [PATCH 39/63] Ensure brackets are closed for complex and double on 1D arrays. --- graph_framework/piecewise.hpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/graph_framework/piecewise.hpp b/graph_framework/piecewise.hpp index 071726e..cc10002 100644 --- a/graph_framework/piecewise.hpp +++ b/graph_framework/piecewise.hpp @@ -268,6 +268,9 @@ void compile_index(std::ostringstream &stream, } else if constexpr (jit::use_cuda()) { stream << ", "; compile_index (stream, registers[a.get()], length); + if constexpr (jit::is_complex () || jit::is_double ()) { + stream << ")"; + } stream << ");"; } else { stream << "["; -- GitLab From 046b058841241891e1678ef7aed98b481af938ab Mon Sep 17 00:00:00 2001 From: cianciosa Date: Thu, 13 Jun 2024 13:31:12 -0400 Subject: [PATCH 40/63] Use the correct branches for complex. Fix generated code formatting line brakes. 
--- graph_framework/cuda_context.hpp | 13 ++++++++----- graph_framework/piecewise.hpp | 12 ++++++------ 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index 27ac75b..6b2c465 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -580,21 +580,24 @@ namespace gpu { source_buffer << "#include " << std::endl; if constexpr (jit::is_float ()) { source_buffer << "static __inline__ __device__ complex to_cmp_float(float2 p) {" + << std::endl << " return "; jit::add_type (source_buffer); - source_buffer << " (p.x, p.y);" - << "}"; + source_buffer << " (p.x, p.y);" << std::endl + << "}" << std::endl; } else { source_buffer << "static __inline__ __device__ complex to_cmp_double(uint4 p) {" + << std::endl << " return "; jit::add_type (source_buffer); source_buffer << " (__hiloint2double(p.y, p.x), __hiloint2double(p.w, p.z));" - << "}"; + << "}" << std::endl; } } else if constexpr (jit::is_double ()) { - source_buffer << "static __inline__ __device__ double to_double(uint2 p) {" + source_buffer << "static __inline__ __device__ double to_double(uint2 p) {" + << std::endl << " return __hiloint2double(p.y, p.x);" - << "}"; + << "}" << std::endl; } } diff --git a/graph_framework/piecewise.hpp b/graph_framework/piecewise.hpp index cc10002..079bea6 100644 --- a/graph_framework/piecewise.hpp +++ b/graph_framework/piecewise.hpp @@ -249,11 +249,11 @@ void compile_index(std::ostringstream &stream, jit::add_type (stream); stream << " " << registers[this] << " = "; if constexpr (jit::use_cuda()) { - if constexpr (jit::is_float ()) { + if constexpr (jit::is_float () && !jit::is_complex ()) { stream << "tex1D ("; - } else if constexpr (jit::is_double ()) { + } else if constexpr (jit::is_double () && !jit::is_complex ()) { stream << "to_double(tex1D ("; - } else if constexpr (jit::is_complex () && jit::is_float ()) { + } else if constexpr (jit::is_float ()) { stream << 
"to_cmp_float(tex1D ("; } else { stream << "to_cmp_double(tex1D ("; @@ -693,11 +693,11 @@ void compile_index(std::ostringstream &stream, jit::add_type (stream); stream << " " << registers[this] << " = "; if constexpr (jit::use_cuda()) { - if constexpr (jit::is_float ()) { + if constexpr (jit::is_float () && !jit::is_complex ()) { stream << "tex2D ("; - } else if constexpr (jit::is_double ()) { + } else if constexpr (jit::is_double () && !jit::is_complex ()) { stream << "to_double(tex2D ("; - } else if constexpr (jit::is_complex () && jit::is_float ()) { + } else if constexpr (jit::is_float ()) { stream << "to_cmp_float(tex2D ("; } else { stream << "to_cmp_double(tex2D ("; -- GitLab From 61a318b6d9153092ae89201bbf91aec48fa63b2d Mon Sep 17 00:00:00 2001 From: cianciosa Date: Thu, 13 Jun 2024 13:37:43 -0400 Subject: [PATCH 41/63] Fix newlines in generated kernel source. --- graph_framework/cuda_context.hpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index 6b2c465..6f34104 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -591,12 +591,14 @@ namespace gpu { << " return "; jit::add_type (source_buffer); source_buffer << " (__hiloint2double(p.y, p.x), __hiloint2double(p.w, p.z));" + << std::endl << "}" << std::endl; } } else if constexpr (jit::is_double ()) { source_buffer << "static __inline__ __device__ double to_double(uint2 p) {" << std::endl << " return __hiloint2double(p.y, p.x);" + << std::endl; << "}" << std::endl; } } -- GitLab From d37e77c918f83c4a31e6bdcbc65a116878496e9c Mon Sep 17 00:00:00 2001 From: cianciosa Date: Thu, 13 Jun 2024 13:42:11 -0400 Subject: [PATCH 42/63] Fix misplaced semicolon. 
--- graph_framework/cuda_context.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index 6f34104..38b4d9c 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -598,7 +598,7 @@ namespace gpu { source_buffer << "static __inline__ __device__ double to_double(uint2 p) {" << std::endl << " return __hiloint2double(p.y, p.x);" - << std::endl; + << std::endl << "}" << std::endl; } } -- GitLab From 88b10c86d8daf7b5dc6165681e4758654782b74a Mon Sep 17 00:00:00 2001 From: Mark Cianciosa Date: Fri, 14 Jun 2024 15:16:15 -0400 Subject: [PATCH 43/63] Make cuda textures optional. --- CMakeLists.txt | 3 +++ graph_framework/cuda_context.hpp | 24 +++++++++++++++++------- graph_framework/piecewise.hpp | 22 ++++++++++++++++++++-- 3 files changed, 40 insertions(+), 9 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 6e31b81..6d751fc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -60,9 +60,12 @@ else () find_package (CUDAToolkit REQUIRED) + option (USE_CUDA_TEXTURES "Enable the use of cuda textures" OFF) + target_compile_definitions (cuda_lib INTERFACE USE_CUDA + $<$:USE_CUDA_TEXTURES> CUDA_INCLUDE="${CUDAToolkit_INCLUDE_DIRS}" ) target_link_libraries (cuda_lib diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index 38b4d9c..fd2e011 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -76,8 +76,10 @@ namespace gpu { CUmodule module; /// Argument map. std::map *, CUdeviceptr> kernel_arguments; +#ifdef USE_CUDA_TEXTURES /// Textures. std::map texture_arguments; +#endif /// Result buffer. CUdeviceptr result_buffer; /// Cuda stream. 
@@ -324,6 +326,7 @@ namespace gpu { buffers.push_back(reinterpret_cast (&kernel_arguments[output.get()])); } +#ifdef USE_CUDA_TEXTURES for (auto &[data, size] : tex1d_list) { if (!texture_arguments.contains(data)) { texture_arguments.try_emplace(data); @@ -427,6 +430,7 @@ namespace gpu { } buffers.push_back(reinterpret_cast (&texture_arguments[data])); } +#endif int value; check_error(cuFuncGetAttribute(&value, CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, @@ -578,16 +582,17 @@ namespace gpu { source_buffer << "#define M_PI " << M_PI << std::endl; source_buffer << "#include " << std::endl; source_buffer << "#include " << std::endl; +#ifdef USE_CUDA_TEXTURES if constexpr (jit::is_float ()) { source_buffer << "static __inline__ __device__ complex to_cmp_float(float2 p) {" - << std::endl + << std::endl << " return "; jit::add_type (source_buffer); source_buffer << " (p.x, p.y);" << std::endl << "}" << std::endl; } else { source_buffer << "static __inline__ __device__ complex to_cmp_double(uint4 p) {" - << std::endl + << std::endl << " return "; jit::add_type (source_buffer); source_buffer << " (__hiloint2double(p.y, p.x), __hiloint2double(p.w, p.z));" @@ -595,12 +600,15 @@ namespace gpu { << "}" << std::endl; } } else if constexpr (jit::is_double ()) { - source_buffer << "static __inline__ __device__ double to_double(uint2 p) {" - << std::endl - << " return __hiloint2double(p.y, p.x);" - << std::endl - << "}" << std::endl; + source_buffer << "static __inline__ __device__ double to_double(uint2 p) {" + << std::endl + << " return __hiloint2double(p.y, p.x);" + << std::endl + << "}" << std::endl; } +#else + } +#endif } //------------------------------------------------------------------------------ @@ -650,6 +658,7 @@ namespace gpu { jit::add_type (source_buffer); source_buffer << " *" << jit::to_string('o', outputs[i].get()); } +#ifdef USE_CUDA_TEXTURES for (size_t i = 0, ie = textures1d.size(); i < ie; i++) { source_buffer << "," << std::endl; source_buffer << " 
cudaTextureObject_t " @@ -660,6 +669,7 @@ namespace gpu { source_buffer << " cudaTextureObject_t " << jit::to_string('a', textures2d[i].first); } +#endif source_buffer << ") {" << std::endl; source_buffer << " const int index = blockIdx.x*blockDim.x + threadIdx.x;" diff --git a/graph_framework/piecewise.hpp b/graph_framework/piecewise.hpp index 079bea6..37bf060 100644 --- a/graph_framework/piecewise.hpp +++ b/graph_framework/piecewise.hpp @@ -195,9 +195,14 @@ void compile_index(std::ostringstream &stream, registers[leaf_node::backend_cache[data_hash].data()] = jit::to_string('a', leaf_node::backend_cache[data_hash].data()); const size_t length = leaf_node::backend_cache[data_hash].size(); - if constexpr (jit::use_metal () || jit::use_cuda()) { + if constexpr (jit::use_metal ()) { textures1d.emplace_back(leaf_node::backend_cache[data_hash].data(), length); +#ifdef USE_CUDA_TEXTURES + } else if constexpr (jit::use_cuda()) { + textures1d.emplace_back(leaf_node::backend_cache[data_hash].data(), + length); +#endif } else { stream << "const "; jit::add_type (stream); @@ -248,6 +253,7 @@ void compile_index(std::ostringstream &stream, stream << " const "; jit::add_type (stream); stream << " " << registers[this] << " = "; +#ifdef USE_CUDA_TEXTURES if constexpr (jit::use_cuda()) { if constexpr (jit::is_float () && !jit::is_complex ()) { stream << "tex1D ("; @@ -259,12 +265,14 @@ void compile_index(std::ostringstream &stream, stream << "to_cmp_double(tex1D ("; } } +#endif stream << registers[leaf_node::backend_cache[data_hash].data()]; const size_t length = leaf_node::backend_cache[data_hash].size(); if constexpr (jit::use_metal ()) { stream << ".read("; compile_index (stream, registers[a.get()], length); stream << ").r;"; +#ifdef USE_CUDA_TEXTURES } else if constexpr (jit::use_cuda()) { stream << ", "; compile_index (stream, registers[a.get()], length); @@ -272,6 +280,7 @@ void compile_index(std::ostringstream &stream, stream << ")"; } stream << ");"; +#endif } else { stream 
<< "["; compile_index (stream, registers[a.get()], length); @@ -625,9 +634,14 @@ void compile_index(std::ostringstream &stream, registers[leaf_node::backend_cache[data_hash].data()] = jit::to_string('a', leaf_node::backend_cache[data_hash].data()); const size_t length = leaf_node::backend_cache[data_hash].size(); - if constexpr (jit::use_metal () || jit::use_cuda()) { + if constexpr (jit::use_metal ()) { + textures2d.emplace_back(leaf_node::backend_cache[data_hash].data(), + std::array ({length/num_columns, num_columns})); +#ifdef USE_CUDA_TEXTURES + } else if constexpr (jit::use_cuda()) { textures2d.emplace_back(leaf_node::backend_cache[data_hash].data(), std::array ({length/num_columns, num_columns})); +#endif } else { stream << "const "; jit::add_type (stream); @@ -692,6 +706,7 @@ void compile_index(std::ostringstream &stream, stream << " const "; jit::add_type (stream); stream << " " << registers[this] << " = "; +#ifdef USE_CUDA_TEXTURES if constexpr (jit::use_cuda()) { if constexpr (jit::is_float () && !jit::is_complex ()) { stream << "tex2D ("; @@ -703,6 +718,7 @@ void compile_index(std::ostringstream &stream, stream << "to_cmp_double(tex2D ("; } } +#endif stream << registers[leaf_node::backend_cache[data_hash].data()]; const size_t length = leaf_node::backend_cache[data_hash].size(); const size_t num_rows = length/num_columns; @@ -712,6 +728,7 @@ void compile_index(std::ostringstream &stream, stream << ","; compile_index (stream, registers[x.get()], num_rows); stream << ")).r;"; +#ifdef USE_CUDA_TEXTURES } else if constexpr (jit::use_cuda()) { stream << ", "; compile_index (stream, registers[y.get()], num_columns); @@ -721,6 +738,7 @@ void compile_index(std::ostringstream &stream, stream << ")"; } stream << ");"; +#endif } else { stream << "[min(max((int)"; if constexpr (jit::is_complex ()) { -- GitLab From 8ba5686c5ddc34b91ec78e515e995df9e5956553 Mon Sep 17 00:00:00 2001 From: Mark Cianciosa Date: Fri, 14 Jun 2024 16:01:14 -0400 Subject: [PATCH 44/63] Fix 
generator expression to use the correct variable. --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 6d751fc..9b1bc84 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -65,7 +65,7 @@ else () target_compile_definitions (cuda_lib INTERFACE USE_CUDA - $<$:USE_CUDA_TEXTURES> + $<$:USE_CUDA_TEXTURES> CUDA_INCLUDE="${CUDAToolkit_INCLUDE_DIRS}" ) target_link_libraries (cuda_lib -- GitLab From e9a22dafccb786d80b72c4f091d4167098c87b76 Mon Sep 17 00:00:00 2001 From: Mark Cianciosa Date: Fri, 14 Jun 2024 16:03:20 -0400 Subject: [PATCH 45/63] Wrap don't destruct the textures if they arent used. --- graph_framework/cuda_context.hpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index fd2e011..0a1bc21 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -145,6 +145,7 @@ namespace gpu { check_error(cuMemFree(value), "cuMemFree"); } +#ifdef USE_CUDA_TEXTURES for (auto &[key, value] : texture_arguments) { CUDA_RESOURCE_DESC resource; check_error(cuTexObjectGetResourceDesc(&resource, value), @@ -153,6 +154,7 @@ namespace gpu { check_error(cuArrayDestroy(resource.res.array.hArray), "cuArrayDestroy"); check_error(cuTexObjectDestroy(value), "cuTexObjectDestroy"); } +#endif if (result_buffer) { check_error(cuMemFree(result_buffer), "cuMemFree"); -- GitLab From 163e7acd2c2462f735849af7b14734f7b59aa00d Mon Sep 17 00:00:00 2001 From: cianciosa Date: Mon, 17 Jun 2024 16:31:11 -0400 Subject: [PATCH 46/63] Add useage counters. 
--- graph_framework/arithmetic.hpp | 87 +++++++++++++++++++++---------- graph_framework/cpu_context.hpp | 17 ++++-- graph_framework/cuda_context.hpp | 38 +++++++++----- graph_framework/jit.hpp | 25 ++++++--- graph_framework/math.hpp | 57 +++++++++++++------- graph_framework/metal_context.hpp | 19 +++++-- graph_framework/node.hpp | 70 ++++++++++++++++++------- graph_framework/piecewise.hpp | 53 ++++++++++++++----- graph_framework/register.hpp | 2 + graph_framework/trigonometry.hpp | 38 +++++++++----- 10 files changed, 285 insertions(+), 121 deletions(-) diff --git a/graph_framework/arithmetic.hpp b/graph_framework/arithmetic.hpp index a7d5322..91682c7 100644 --- a/graph_framework/arithmetic.hpp +++ b/graph_framework/arithmetic.hpp @@ -277,22 +277,28 @@ namespace graph { /// /// @params[in,out] stream String buffer stream. /// @params[in,out] registers List of defined registers. +/// @params[in] usage List of register usage count. /// @returns The current node. //------------------------------------------------------------------------------ virtual shared_leaf compile(std::ostringstream &stream, - jit::register_map ®isters) { + jit::register_map ®isters, + const jit::register_usage &usage) { if (registers.find(this) == registers.end()) { - shared_leaf l = this->left->compile(stream, registers); - shared_leaf r = this->right->compile(stream, registers); + shared_leaf l = this->left->compile(stream, + registers, + usage); + shared_leaf r = this->right->compile(stream, + registers, + usage); registers[this] = jit::to_string('r', this); stream << " const "; jit::add_type (stream); stream << " " << registers[this] << " = " << registers[l.get()] << " + " - << registers[r.get()] << ";" - << std::endl; + << registers[r.get()] << "; // used " + << usage.at(this) << std::endl; } return this->shared_from_this(); @@ -820,22 +826,28 @@ namespace graph { /// /// @params[in,out] stream String buffer stream. /// @params[in,out] registers List of defined registers. 
+/// @params[in] usage List of register usage count. /// @returns The current node. //------------------------------------------------------------------------------ virtual shared_leaf compile(std::ostringstream &stream, - jit::register_map ®isters) { + jit::register_map ®isters, + const jit::register_usage &usage) { if (registers.find(this) == registers.end()) { - shared_leaf l = this->left->compile(stream, registers); - shared_leaf r = this->right->compile(stream, registers); + shared_leaf l = this->left->compile(stream, + registers, + usage); + shared_leaf r = this->right->compile(stream, + registers, + usage); registers[this] = jit::to_string('r', this); stream << " const "; jit::add_type (stream); stream << " " << registers[this] << " = " << registers[l.get()] << " - " - << registers[r.get()] << ";" - << std::endl; + << registers[r.get()] << "; // used " + << usage.at(this) << std::endl; } return this->shared_from_this(); @@ -1543,14 +1555,20 @@ namespace graph { /// /// @params[in,out] stream String buffer stream. /// @params[in,out] registers List of defined registers. +/// @params[in] usage List of register usage count. /// @returns The current node. 
//------------------------------------------------------------------------------ virtual shared_leaf compile(std::ostringstream &stream, - jit::register_map ®isters) { + jit::register_map ®isters, + const jit::register_usage &usage) { if (registers.find(this) == registers.end()) { - shared_leaf l = this->left->compile(stream, registers); - shared_leaf r = this->right->compile(stream, registers); + shared_leaf l = this->left->compile(stream, + registers, + usage); + shared_leaf r = this->right->compile(stream, + registers, + usage); registers[this] = jit::to_string('r', this); stream << " const "; @@ -1581,8 +1599,8 @@ namespace graph { stream << " : "; } stream << registers[l.get()] << "*" - << registers[r.get()] << ";" - << std::endl; + << registers[r.get()] << "; // used " + << usage.at(this) << std::endl; } return this->shared_from_this(); @@ -2093,14 +2111,20 @@ namespace graph { /// /// @params[in,out] stream String buffer stream. /// @params[in,out] registers List of defined registers. +/// @params[in] usage List of register usage count. /// @returns The current node. 
//------------------------------------------------------------------------------ virtual shared_leaf compile(std::ostringstream &stream, - jit::register_map ®isters) { + jit::register_map ®isters, + const jit::register_usage &usage) { if (registers.find(this) == registers.end()) { - shared_leaf l = this->left->compile(stream, registers); - shared_leaf r = this->right->compile(stream, registers); + shared_leaf l = this->left->compile(stream, + registers, + usage); + shared_leaf r = this->right->compile(stream, + registers, + usage); registers[this] = jit::to_string('r', this); stream << " const "; @@ -2124,8 +2148,8 @@ namespace graph { stream << " : "; } stream << registers[l.get()] << "/" - << registers[r.get()] << ";" - << std::endl; + << registers[r.get()] << "; // usage " + << usage.at(this) << std::endl; } return this->shared_from_this(); } @@ -2920,15 +2944,23 @@ namespace graph { /// /// @params[in,out] stream String buffer stream. /// @params[in,out] registers List of defined registers. +/// @params[in] usage List of register usage count. /// @returns The current node. 
//------------------------------------------------------------------------------ virtual shared_leaf compile(std::ostringstream &stream, - jit::register_map ®isters) { + jit::register_map ®isters, + const jit::register_usage &usage) { if (registers.find(this) == registers.end()) { - shared_leaf l = this->left->compile(stream, registers); - shared_leaf m = this->middle->compile(stream, registers); - shared_leaf r = this->right->compile(stream, registers); + shared_leaf l = this->left->compile(stream, + registers, + usage); + shared_leaf m = this->middle->compile(stream, + registers, + usage); + shared_leaf r = this->right->compile(stream, + registers, + usage); registers[this] = jit::to_string('r', this); stream << " const "; @@ -2954,15 +2986,14 @@ namespace graph { if constexpr (jit::is_complex ()) { stream << registers[l.get()] << "*" << registers[m.get()] << " + " - << registers[r.get()] << ";" - << std::endl; + << registers[r.get()] << ";"; } else { stream << "fma(" << registers[l.get()] << ", " << registers[m.get()] << ", " - << registers[r.get()] << ");" - << std::endl; + << registers[r.get()] << ");"; } + stream << " // used " << usage.at(this) << std::endl; } return this->shared_from_this(); diff --git a/graph_framework/cpu_context.hpp b/graph_framework/cpu_context.hpp index 572e246..a70ffea 100644 --- a/graph_framework/cpu_context.hpp +++ b/graph_framework/cpu_context.hpp @@ -359,6 +359,7 @@ namespace gpu { /// @params[in] size Size of the input buffer. /// @params[in] is_constant Flags if the input is read only. /// @params[in,out] registers Map of used registers. +/// @params[in] usage List of register usage count. /// @params[in] textures1d List of 1D kernel textures. /// @params[in] textures2d List of 2D kernel textures. 
//------------------------------------------------------------------------------ @@ -369,6 +370,7 @@ namespace gpu { const size_t size, const std::vector &is_constant, jit::register_map ®isters, + const jit::register_usage &usage, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d) { source_buffer << std::endl; @@ -404,7 +406,8 @@ namespace gpu { jit::add_type (source_buffer); source_buffer << " " << registers[input.get()] << " = " << jit::to_string('v', input.get()) - << "[i]; //" << input->get_symbol() << std::endl; + << "[i]; // " << input->get_symbol() + << " used " << usage.at(input.get()) << std::endl; } } @@ -415,13 +418,17 @@ namespace gpu { /// @params[in] outputs Output nodes of the graph to compute. /// @params[in] setters Map outputs back to input values. /// @params[in,out] registers Map of used registers. +/// @params[in] usage List of register usage count. //------------------------------------------------------------------------------ void create_kernel_postfix(std::ostringstream &source_buffer, graph::output_nodes &outputs, graph::map_nodes &setters, - jit::register_map ®isters) { + jit::register_map ®isters, + const jit::register_usage &usage) { for (auto &[out, in] : setters) { - graph::shared_leaf a = out->compile(source_buffer, registers); + graph::shared_leaf a = out->compile(source_buffer, + registers, + usage); source_buffer << " " << jit::to_string('v', in.get()); source_buffer << "[i] = "; if constexpr (SAFE_MATH) { @@ -444,7 +451,9 @@ namespace gpu { } } for (auto &out : outputs) { - graph::shared_leaf a = out->compile(source_buffer, registers); + graph::shared_leaf a = out->compile(source_buffer, + registers, + usage); source_buffer << " " << jit::to_string('o', out.get()); source_buffer << "[i] = "; if constexpr (SAFE_MATH) { diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index 0a1bc21..41df697 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -223,16 
+223,14 @@ namespace gpu { } const std::string temp = arch.str(); - std::array options({ + std::array options({ temp.c_str(), "--std=c++17", "--relocatable-device-code=false", "--include-path=" CUDA_INCLUDE, "--include-path=" HEADER_DIR, "--extra-device-vectorization", - "--device-as-default-execution-space", - "--ptxas-options", - "-dlcm=cg" + "--device-as-default-execution-space" }); if (nvrtcCompileProgram(kernel_program, options.size(), options.data())) { @@ -265,11 +263,15 @@ namespace gpu { check_nvrtc_error(nvrtcDestroyProgram(&kernel_program), "nvrtcDestroyProgram"); - std::array module_options = { - CU_JIT_MAX_REGISTERS + std::array module_options = { + CU_JIT_MAX_REGISTERS, + CU_JIT_LTO, + CU_JIT_POSITION_INDEPENDENT_CODE }; - std::array module_values = { - reinterpret_cast (168) + std::array module_values = { + reinterpret_cast (168), + reinterpret_cast (1), + reinterpret_cast (0) }; check_error(cuModuleLoadDataEx(&module, ptx, 1, @@ -623,6 +625,7 @@ namespace gpu { /// @params[in] size Size of the input buffer. /// @params[in] is_constant Flags if the input is read only. /// @params[in,out] registers Map of used registers. +/// @params[in] usage List of register usage count. /// @params[in] textures1d List of 1D kernel textures. /// @params[in] textures2d List of 2D kernel textures. 
//------------------------------------------------------------------------------ @@ -633,6 +636,7 @@ namespace gpu { const size_t size, const std::vector &is_constant, jit::register_map ®isters, + const jit::register_usage &usage, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d) { source_buffer << std::endl; @@ -683,8 +687,9 @@ namespace gpu { source_buffer << " const "; jit::add_type (source_buffer); source_buffer << " " << registers[input.get()] << " = " - << jit::to_string('v', input.get()) << "[index];" - << std::endl; + << jit::to_string('v', input.get()) + << "[index]; // " << input->get_symbol() + << " used " << usage.at(input.get()) << std::endl; } } @@ -695,14 +700,17 @@ namespace gpu { /// @params[in] outputs Output nodes of the graph to compute. /// @params[in] setters Map outputs back to input values. /// @params[in,out] registers Map of used registers. - +/// @params[in] usage List of register usage count. //------------------------------------------------------------------------------ void create_kernel_postfix(std::ostringstream &source_buffer, graph::output_nodes &outputs, graph::map_nodes &setters, - jit::register_map ®isters) { + jit::register_map ®isters, + const jit::register_usage &usage) { for (auto &[out, in] : setters) { - graph::shared_leaf a = out->compile(source_buffer, registers); + graph::shared_leaf a = out->compile(source_buffer, + registers, + usage); source_buffer << " " << jit::to_string('v', in.get()) << "[index] = "; if constexpr (SAFE_MATH) { @@ -726,7 +734,9 @@ namespace gpu { } for (auto &out : outputs) { - graph::shared_leaf a = out->compile(source_buffer, registers); + graph::shared_leaf a = out->compile(source_buffer, + registers, + usage); source_buffer << " " << jit::to_string('o', out.get()) << "[index] = "; if constexpr (SAFE_MATH) { diff --git a/graph_framework/jit.hpp b/graph_framework/jit.hpp index cf1487d..2235542 100644 --- a/graph_framework/jit.hpp +++ b/graph_framework/jit.hpp @@ -101,11 +101,12 
@@ namespace jit { graph::output_nodes outputs, graph::map_nodes setters) { kernel_names.push_back(name); - + const size_t size = inputs[0]->size(); std::vector is_constant(inputs.size(), true); visiter_map visited; + register_usage usage; kernel_1dtextures[name] = texture1d_list(); kernel_2dtextures[name] = texture2d_list(); for (auto &[out, in] : setters) { @@ -115,32 +116,40 @@ namespace jit { if (found < is_constant.size()) { is_constant[found] = false; } - out->compile_preamble(source_buffer, registers, visited, + out->compile_preamble(source_buffer, registers, + visited, usage, kernel_1dtextures[name], kernel_2dtextures[name]); } for (auto &out : outputs) { - out->compile_preamble(source_buffer, registers, visited, + out->compile_preamble(source_buffer, registers, + visited, usage, kernel_1dtextures[name], kernel_2dtextures[name]); } + for (auto &in : inputs) { + if (usage.find(in.get()) == usage.end()) { + usage[in.get()] == 0; + } + } + gpu_context.create_kernel_prefix(source_buffer, name, inputs, outputs, size, is_constant, - registers, + registers, usage, kernel_1dtextures[name], kernel_2dtextures[name]); for (auto &[out, in] : setters) { - out->compile(source_buffer, registers); + out->compile(source_buffer, registers, usage); } for (auto &out : outputs) { - out->compile(source_buffer, registers); + out->compile(source_buffer, registers, usage); } gpu_context.create_kernel_postfix(source_buffer, outputs, - setters, registers); + setters, registers, usage); // Delete the registers so that they can be used again in other kernels. std::vector removed_elements; @@ -149,7 +158,7 @@ namespace jit { removed_elements.push_back(key); } } - + for (auto &key : removed_elements) { registers.erase(key); } diff --git a/graph_framework/math.hpp b/graph_framework/math.hpp index 60456ec..7fc26b3 100644 --- a/graph_framework/math.hpp +++ b/graph_framework/math.hpp @@ -158,21 +158,24 @@ namespace graph { /// /// @params[in,out] stream String buffer stream. 
/// @params[in,out] registers List of defined registers. +/// @params[in] usage List of register usage count. /// @returns The current node. //------------------------------------------------------------------------------ virtual shared_leaf compile(std::ostringstream &stream, - jit::register_map ®isters) { + jit::register_map ®isters, + const jit::register_usage &usage) { if (registers.find(this) == registers.end()) { shared_leaf a = this->arg->compile(stream, - registers); + registers, + usage); registers[this] = jit::to_string('r', this); stream << " const "; jit::add_type (stream); stream << " " << registers[this] << " = sqrt(" - << registers[a.get()] << ");" - << std::endl; + << registers[a.get()] << "); // used " + << usage.at(this) << std::endl; } return this->shared_from_this(); @@ -416,13 +419,17 @@ namespace graph { /// /// @params[in,out] stream String buffer stream. /// @params[in,out] registers List of defined registers. +/// @params[in] usage List of register usage count. /// @returns The current node. //------------------------------------------------------------------------------ virtual shared_leaf compile(std::ostringstream &stream, - jit::register_map ®isters) { + jit::register_map ®isters, + const jit::register_usage &usage) { if (registers.find(this) == registers.end()) { - shared_leaf a = this->arg->compile(stream, registers); + shared_leaf a = this->arg->compile(stream, + registers, + usage); registers[this] = jit::to_string('r', this); stream << " const "; @@ -454,7 +461,7 @@ namespace graph { stream << ")"; } } - stream << ";" << std::endl; + stream << "; // used " << usage.at(this) << std::endl; } return this->shared_from_this(); @@ -671,20 +678,24 @@ namespace graph { /// /// @params[in,out] stream String buffer stream. /// @params[in,out] registers List of defined registers. +/// @params[in] usage List of register usage count. /// @returns The current node. 
//------------------------------------------------------------------------------ virtual shared_leaf compile(std::ostringstream &stream, - jit::register_map ®isters) { + jit::register_map ®isters, + const jit::register_usage &usage) { if (registers.find(this) == registers.end()) { - shared_leaf a = this->arg->compile(stream, registers); + shared_leaf a = this->arg->compile(stream, + registers, + usage); registers[this] = jit::to_string('r', this); stream << " const "; jit::add_type (stream); stream << " " << registers[this] << " = log(" - << registers[a.get()] << ");" - << std::endl; + << registers[a.get()] << "); // used " + << usage.at(this) << std::endl; } return this->shared_from_this(); @@ -975,17 +986,21 @@ namespace graph { /// /// @params[in,out] stream String buffer stream. /// @params[in,out] registers List of defined registers. +/// @params[in] usage List of register usage count. /// @returns The current node. //------------------------------------------------------------------------------ virtual shared_leaf compile(std::ostringstream &stream, - jit::register_map ®isters) { + jit::register_map ®isters, + const jit::register_usage &usage) { if (registers.find(this) == registers.end()) { - shared_leaf l = this->left->compile(stream, registers); + shared_leaf l = this->left->compile(stream, + registers, + usage); shared_leaf r; auto temp = constant_cast(this->right); if (!temp.get() || !temp->is_integer()) { - r = this->right->compile(stream, registers); + r = this->right->compile(stream, registers, usage); } registers[this] = jit::to_string('r', this); @@ -1004,7 +1019,7 @@ namespace graph { << registers[l.get()] << ", " << registers[r.get()] << ");"; } - stream << std::endl; + stream << " // used " << usage.at(this) << std::endl; } return this->shared_from_this(); @@ -1266,20 +1281,24 @@ namespace graph { /// /// @params[in,out] stream String buffer stream. /// @params[in,out] registers List of defined registers. 
+/// @params[in] usage List of register usage count. /// @returns The current node. //------------------------------------------------------------------------------ virtual shared_leaf compile(std::ostringstream &stream, - jit::register_map ®isters) { + jit::register_map ®isters, + const jit::register_usage &usage) { if (registers.find(this) == registers.end()) { - shared_leaf a = this->arg->compile(stream, registers); + shared_leaf a = this->arg->compile(stream, + registers, + usage); registers[this] = jit::to_string('r', this); stream << " const "; jit::add_type (stream); stream << " " << registers[this] << " = special::erfi(" - << registers[a.get()] << ");" - << std::endl; + << registers[a.get()] << "); // usage " + << usage.at(this) << std::endl; } return this->shared_from_this(); diff --git a/graph_framework/metal_context.hpp b/graph_framework/metal_context.hpp index 652c8b6..4b8c586 100644 --- a/graph_framework/metal_context.hpp +++ b/graph_framework/metal_context.hpp @@ -401,6 +401,7 @@ namespace gpu { /// @params[in] size Size of the input buffer. /// @params[in] is_constant Flags if the input is read only. /// @params[in,out] registers Map of used registers. +/// @params[in] usage List of register usage count. /// @params[in] textures1d List of 1D kernel textures. /// @params[in] textures2d List of 2D kernel textures. 
//------------------------------------------------------------------------------ @@ -411,6 +412,7 @@ namespace gpu { const size_t size, const std::vector &is_constant, jit::register_map ®isters, + const jit::register_usage &usage, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d) { source_buffer << std::endl; @@ -453,8 +455,9 @@ namespace gpu { source_buffer << " const "; jit::add_type (source_buffer); source_buffer << " " << registers[input.get()] << " = " - << jit::to_string('v', input.get()) << "[index];" - << std::endl; + << jit::to_string('v', input.get()) + << "[index]; // " << input->get_symbol() + << " used " << usage.at(input.get()) << std::endl; } } @@ -465,13 +468,17 @@ namespace gpu { /// @params[in] outputs Output nodes of the graph to compute. /// @params[in] setters Map outputs back to input values. /// @params[in,out] registers Map of used registers. +/// @params[in] usage List of register usage count. //------------------------------------------------------------------------------ void create_kernel_postfix(std::ostringstream &source_buffer, graph::output_nodes &outputs, graph::map_nodes &setters, - jit::register_map ®isters) { + jit::register_map ®isters, + const jit::register_usage &usage) { for (auto &[out, in] : setters) { - graph::shared_leaf a = out->compile(source_buffer, registers); + graph::shared_leaf a = out->compile(source_buffer, + registers, + usage); source_buffer << " " << jit::to_string('v', in.get()) << "[index] = "; if constexpr (SAFE_MATH) { @@ -482,7 +489,9 @@ namespace gpu { } for (auto &out : outputs) { - graph::shared_leaf a = out->compile(source_buffer, registers); + graph::shared_leaf a = out->compile(source_buffer, + registers, + usage); source_buffer << " " << jit::to_string('o', out.get()) << "[index] = "; if constexpr (SAFE_MATH) { diff --git a/graph_framework/node.hpp b/graph_framework/node.hpp index b1ecd5a..3176968 100644 --- a/graph_framework/node.hpp +++ b/graph_framework/node.hpp @@ -91,25 +91,35 
@@ namespace graph { /// @params[in,out] stream String buffer stream. /// @params[in,out] registers List of defined registers. /// @params[in,out] visited List of visited nodes. +/// @params[in,out] usage List of register usage count. /// @params[in,out] textures1d List of 1D textures. /// @params[in,out] textures2d List of 2D textures. //------------------------------------------------------------------------------ virtual void compile_preamble(std::ostringstream &stream, jit::register_map ®isters, jit::visiter_map &visited, + jit::register_usage &usage, jit::texture1d_list &textures1d, - jit::texture2d_list &textures2d) {} + jit::texture2d_list &textures2d) { + if (usage.find(this) == usage.end()) { + usage[this] = 0; + } else { + ++usage[this]; + } + } //------------------------------------------------------------------------------ /// @brief Compile the node. /// /// @params[in,out] stream String buffer stream. /// @params[in,out] registers List of defined registers. +/// @params[in] usage List of register usage count. /// @returns The current node. //------------------------------------------------------------------------------ virtual std::shared_ptr> compile(std::ostringstream &stream, - jit::register_map ®isters) = 0; + jit::register_map ®isters, + const jit::register_usage &usage) = 0; //------------------------------------------------------------------------------ /// @brief Querey if the nodes match. @@ -371,11 +381,13 @@ namespace graph { /// /// @params[in,out] stream String buffer stream. /// @params[in,out] registers List of defined registers. +/// @params[in] usage List of register usage count. /// @returns The current node. 
//------------------------------------------------------------------------------ virtual shared_leaf compile(std::ostringstream &stream, - jit::register_map ®isters) { + jit::register_map ®isters, + const jit::register_usage &usage) { if (registers.find(this) == registers.end()) { registers[this] = jit::to_string('r', this); stream << " const "; @@ -386,7 +398,8 @@ namespace graph { if constexpr (jit::is_complex ()) { jit::add_type (stream); } - stream << temp << ";" << std::endl; + stream << temp << "; // used " + << usage.at(this) << std::endl; } return this->shared_from_this(); @@ -707,19 +720,24 @@ namespace graph { /// @params[in,out] stream String buffer stream. /// @params[in,out] registers List of defined registers. /// @params[in,out] visited List of visited nodes. +/// @params[in,out] usage List of register usage count. /// @params[in,out] textures1d List of 1D textures. /// @params[in,out] textures2d List of 2D textures. //------------------------------------------------------------------------------ virtual void compile_preamble(std::ostringstream &stream, jit::register_map ®isters, jit::visiter_map &visited, + jit::register_usage &usage, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d) { if (visited.find(this) == visited.end()) { this->arg->compile_preamble(stream, registers, - visited, textures1d, - textures2d); + visited, usage, + textures1d, textures2d); visited.insert(this); + usage[this] = 0; + } else { + ++usage[this]; } } @@ -728,12 +746,14 @@ namespace graph { /// /// @params[in,out] stream String buffer stream. /// @params[in,out] registers List of defined registers. +/// @params[in] usage List of register usage count. /// @returns The current node. 
//------------------------------------------------------------------------------ virtual shared_leaf compile(std::ostringstream &stream, - jit::register_map ®isters) { - return this->arg->compile(stream, registers); + jit::register_map ®isters, + const jit::register_usage &usage) { + return this->arg->compile(stream, registers, usage); } //------------------------------------------------------------------------------ @@ -830,22 +850,27 @@ namespace graph { /// @params[in,out] stream String buffer stream. /// @params[in,out] registers List of defined registers. /// @params[in,out] visited List of visited nodes. +/// @params[in,out] usage List of register usage count. /// @params[in,out] textures1d List of 1D textures. /// @params[in,out] textures2d List of 2D textures. //------------------------------------------------------------------------------ virtual void compile_preamble(std::ostringstream &stream, jit::register_map ®isters, jit::visiter_map &visited, + jit::register_usage &usage, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d) { if (visited.find(this) == visited.end()) { this->left->compile_preamble(stream, registers, - visited, textures1d, - textures2d); + visited, usage, + textures1d, textures2d); this->right->compile_preamble(stream, registers, - visited, textures1d, - textures2d); + visited, usage, + textures1d, textures2d); visited.insert(this); + usage[this] = 0; + } else { + ++usage[this]; } } @@ -941,25 +966,30 @@ namespace graph { /// @params[in,out] stream String buffer stream. /// @params[in,out] registers List of defined registers. /// @params[in,out] visited List of visited nodes. +/// @params[in,out] usage List of register usage count. /// @params[in,out] textures1d List of 1D textures. /// @params[in,out] textures2d List of 2D textures. 
//------------------------------------------------------------------------------ virtual void compile_preamble(std::ostringstream &stream, jit::register_map ®isters, jit::visiter_map &visited, + jit::register_usage &usage, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d) { if (visited.find(this) == visited.end()) { this->left->compile_preamble(stream, registers, - visited, textures1d, - textures2d); + visited, usage, + textures1d, textures2d); this->middle->compile_preamble(stream, registers, - visited, textures1d, - textures2d); + visited, usage, + textures1d, textures2d); this->right->compile_preamble(stream, registers, - visited, textures1d, - textures2d); + visited, usage, + textures1d, textures2d); visited.insert(this); + usage[this] = 0; + } else { + ++usage[this]; } } @@ -1101,11 +1131,13 @@ namespace graph { /// /// @params[in,out] stream String buffer stream. /// @params[in,out] registers List of defined registers. +/// @params[in] usage List of register usage count. /// @returns The current node. //------------------------------------------------------------------------------ virtual shared_leaf compile(std::ostringstream &stream, - jit::register_map ®isters) { + jit::register_map ®isters, + const jit::register_usage &usage) { return this->shared_from_this(); } diff --git a/graph_framework/piecewise.hpp b/graph_framework/piecewise.hpp index 37bf060..02546a5 100644 --- a/graph_framework/piecewise.hpp +++ b/graph_framework/piecewise.hpp @@ -179,18 +179,23 @@ void compile_index(std::ostringstream &stream, //------------------------------------------------------------------------------ /// @brief Compile preamble. /// -/// @params[in,out] stream String buffer stream. -/// @params[in,out] registers List of defined registers. -/// @params[in,out] visited List of visited nodes. +/// @params[in,out] stream String buffer stream. +/// @params[in,out] registers List of defined registers. +/// @params[in,out] visited List of visited nodes. 
+/// @params[in,out] usage List of register usage count. /// @params[in,out] textures1d List of 1D textures. /// @params[in,out] textures2d List of 2D textures. //------------------------------------------------------------------------------ virtual void compile_preamble(std::ostringstream &stream, jit::register_map ®isters, jit::visiter_map &visited, + jit::register_usage &usage, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d) { if (visited.find(this) == visited.end()) { + this->arg->compile_preamble(stream, registers, + visited, usage, + textures1d, textures2d); if (registers.find(leaf_node::backend_cache[data_hash].data()) == registers.end()) { registers[leaf_node::backend_cache[data_hash].data()] = jit::to_string('a', leaf_node::backend_cache[data_hash].data()); @@ -220,8 +225,11 @@ void compile_index(std::ostringstream &stream, } stream << "};" << std::endl; } - visited.insert(this); } + visited.insert(this); + usage[this] = 0; + } else { + ++usage[this]; } } @@ -242,13 +250,17 @@ void compile_index(std::ostringstream &stream, /// /// @params[in,out] stream String buffer stream. /// @params[in,out] registers List of defined registers. +/// @params[in] usage List of register usage count. /// @returns The current node. 
//------------------------------------------------------------------------------ virtual shared_leaf compile(std::ostringstream &stream, - jit::register_map ®isters) { + jit::register_map ®isters, + const jit::register_usage &usage) { if (registers.find(this) == registers.end()) { - shared_leaf a = this->arg->compile(stream, registers); + shared_leaf a = this->arg->compile(stream, + registers, + usage); registers[this] = jit::to_string('r', this); stream << " const "; jit::add_type (stream); @@ -286,7 +298,7 @@ void compile_index(std::ostringstream &stream, compile_index (stream, registers[a.get()], length); stream << "];"; } - stream << std::endl; + stream << " // used " << usage.at(this) <shared_from_this(); @@ -621,15 +633,23 @@ void compile_index(std::ostringstream &stream, /// @params[in,out] stream String buffer stream. /// @params[in,out] registers List of defined registers. /// @params[in,out] visited List of visited nodes. +/// @params[in,out] usage List of register usage count. /// @params[in,out] textures1d List of 1D textures. /// @params[in,out] textures2d List of 2D textures. 
//------------------------------------------------------------------------------ virtual void compile_preamble(std::ostringstream &stream, jit::register_map ®isters, jit::visiter_map &visited, + jit::register_usage &usage, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d) { if (visited.find(this) == visited.end()) { + this->left->compile_preamble(stream, registers, + visited, usage, + textures1d, textures2d); + this->right->compile_preamble(stream, registers, + visited, usage, + textures1d, textures2d); if (registers.find(leaf_node::backend_cache[data_hash].data()) == registers.end()) { registers[leaf_node::backend_cache[data_hash].data()] = jit::to_string('a', leaf_node::backend_cache[data_hash].data()); @@ -659,8 +679,11 @@ void compile_index(std::ostringstream &stream, } stream << "};" << std::endl; } - visited.insert(this); } + visited.insert(this); + usage[this] = 0; + } else { + ++usage[this]; } } @@ -694,14 +717,20 @@ void compile_index(std::ostringstream &stream, /// /// @params[in,out] stream String buffer stream. /// @params[in,out] registers List of defined registers. +/// @params[in] usage List of register usage count. /// @returns The current node. 
//------------------------------------------------------------------------------ virtual shared_leaf compile(std::ostringstream &stream, - jit::register_map ®isters) { + jit::register_map ®isters, + const jit::register_usage &usage) { if (registers.find(this) == registers.end()) { - shared_leaf x = this->left->compile(stream, registers); - shared_leaf y = this->right->compile(stream, registers); + shared_leaf x = this->left->compile(stream, + registers, + usage); + shared_leaf y = this->right->compile(stream, + registers, + usage); registers[this] = jit::to_string('r', this); stream << " const "; jit::add_type (stream); @@ -758,7 +787,7 @@ void compile_index(std::ostringstream &stream, } stream << ",0), " << length - 1 << ")];"; } - stream << std::endl; + stream << " // used " << usage.at(this) << std::endl; } return this->shared_from_this(); diff --git a/graph_framework/register.hpp b/graph_framework/register.hpp index f6f6a3b..87cea84 100644 --- a/graph_framework/register.hpp +++ b/graph_framework/register.hpp @@ -244,6 +244,8 @@ namespace jit { /// Type alias for mapping node pointers to register names. typedef std::map register_map; +/// Type alias for counting register usage. + typedef std::map register_usage; /// Type alias for listing visited nodes. typedef std::set visiter_map; /// Type alias for indexing 1D textures. diff --git a/graph_framework/trigonometry.hpp b/graph_framework/trigonometry.hpp index eefd020..7de3cc5 100644 --- a/graph_framework/trigonometry.hpp +++ b/graph_framework/trigonometry.hpp @@ -126,19 +126,23 @@ namespace graph { /// /// @params[in,out] stream String buffer stream. /// @params[in,out] registers List of defined registers. +/// @params[in] usage List of register usage count. /// @returns The current node. 
//------------------------------------------------------------------------------ virtual shared_leaf compile(std::ostringstream &stream, - jit::register_map ®isters) { + jit::register_map ®isters, + const jit::register_usage &usage) { if (registers.find(this) == registers.end()) { - shared_leaf a = this->arg->compile(stream, registers); + shared_leaf a = this->arg->compile(stream, + registers, + usage); registers[this] = jit::to_string('r', this); stream << " const "; jit::add_type (stream); stream << " " << registers[this] << " = sin(" - << registers[a.get()] << ");" - << std::endl; + << registers[a.get()] << "); // usage " + << usage.at(this) << std::endl; } return this->shared_from_this(); @@ -364,20 +368,24 @@ namespace graph { /// /// @params[in,out] stream String buffer stream. /// @params[in,out] registers List of defined registers. +/// @params[in] usage List of register usage count. /// @returns The current node. //------------------------------------------------------------------------------ virtual shared_leaf compile(std::ostringstream &stream, - jit::register_map ®isters) { + jit::register_map ®isters, + const jit::register_usage &usage) { if (registers.find(this) == registers.end()) { - shared_leaf a = this->arg->compile(stream, registers); + shared_leaf a = this->arg->compile(stream, + registers, + usage); registers[this] = jit::to_string('r', this); stream << " const "; jit::add_type (stream); stream << " " << registers[this] << " = cos(" - << registers[a.get()] << ");" - << std::endl; + << registers[a.get()] << "); // usage " + << usage.at(this) << std::endl; } return this->shared_from_this(); @@ -600,14 +608,20 @@ namespace graph { /// /// @params[in,out] stream String buffer stream. /// @params[in,out] registers List of defined registers. +/// @params[in] usage List of register usage count. /// @returns The current node. 
//------------------------------------------------------------------------------ virtual shared_leaf compile(std::ostringstream &stream, - jit::register_map ®isters) { + jit::register_map ®isters, + const jit::register_usage &usage) { if (registers.find(this) == registers.end()) { - shared_leaf l = this->left->compile(stream, registers); - shared_leaf r = this->right->compile(stream, registers); + shared_leaf l = this->left->compile(stream, + registers, + usage); + shared_leaf r = this->right->compile(stream, + registers, + usage); registers[this] = jit::to_string('r', this); stream << " const "; @@ -621,7 +635,7 @@ namespace graph { << registers[r.get()] << "," << registers[l.get()]; } - stream << ");" << std::endl; + stream << "); // used " << usage.at(this) << std::endl; } return this->shared_from_this(); -- GitLab From 9b43e8008832ab2f64955f394a2a7901b6486cfc Mon Sep 17 00:00:00 2001 From: cianciosa Date: Mon, 17 Jun 2024 17:24:07 -0400 Subject: [PATCH 47/63] Fix generated code comment for consistency. 
--- graph_framework/arithmetic.hpp | 2 +- graph_framework/math.hpp | 2 +- graph_framework/trigonometry.hpp | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/graph_framework/arithmetic.hpp b/graph_framework/arithmetic.hpp index 91682c7..a1af546 100644 --- a/graph_framework/arithmetic.hpp +++ b/graph_framework/arithmetic.hpp @@ -2148,7 +2148,7 @@ namespace graph { stream << " : "; } stream << registers[l.get()] << "/" - << registers[r.get()] << "; // usage " + << registers[r.get()] << "; // used " << usage.at(this) << std::endl; } return this->shared_from_this(); diff --git a/graph_framework/math.hpp b/graph_framework/math.hpp index 7fc26b3..d988cde 100644 --- a/graph_framework/math.hpp +++ b/graph_framework/math.hpp @@ -1297,7 +1297,7 @@ namespace graph { stream << " const "; jit::add_type (stream); stream << " " << registers[this] << " = special::erfi(" - << registers[a.get()] << "); // usage " + << registers[a.get()] << "); // used " << usage.at(this) << std::endl; } diff --git a/graph_framework/trigonometry.hpp b/graph_framework/trigonometry.hpp index 7de3cc5..91ff128 100644 --- a/graph_framework/trigonometry.hpp +++ b/graph_framework/trigonometry.hpp @@ -141,7 +141,7 @@ namespace graph { stream << " const "; jit::add_type (stream); stream << " " << registers[this] << " = sin(" - << registers[a.get()] << "); // usage " + << registers[a.get()] << "); // used " << usage.at(this) << std::endl; } @@ -384,7 +384,7 @@ namespace graph { stream << " const "; jit::add_type (stream); stream << " " << registers[this] << " = cos(" - << registers[a.get()] << "); // usage " + << registers[a.get()] << "); // used " << usage.at(this) << std::endl; } -- GitLab From 952683ca38b4c253c1b07e74d6693d8144a2283f Mon Sep 17 00:00:00 2001 From: cianciosa Date: Tue, 18 Jun 2024 10:49:06 -0400 Subject: [PATCH 48/63] Add option to save the kernel code. 
--- CMakeLists.txt | 1 + graph_framework/CMakeLists.txt | 1 + graph_framework/jit.hpp | 24 ++++++++++++++++++++++++ 3 files changed, 26 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 9b1bc84..cb68a96 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -7,6 +7,7 @@ project (rays CXX) #------------------------------------------------------------------------------- option (USE_PCH "Enable the use of precompiled headers" ON) option (USE_STATIC "Limits the dyamics for testing." OFF) +option (SAVE_KERNEL_SOURCE "Writes the kernel source code to a file." OFF) #------------------------------------------------------------------------------- # Set the cmake module path. diff --git a/graph_framework/CMakeLists.txt b/graph_framework/CMakeLists.txt index 3c06dac..16af63c 100644 --- a/graph_framework/CMakeLists.txt +++ b/graph_framework/CMakeLists.txt @@ -22,6 +22,7 @@ target_compile_definitions (rays VMEC_FILE="${CMAKE_CURRENT_SOURCE_DIR}/../graph_tests/vmec.nc" $<$:HEADER_DIR="$"> $<$:STATIC> + $<$:SAVE_KERNEL_SOURCE> ) target_include_directories (rays diff --git a/graph_framework/jit.hpp b/graph_framework/jit.hpp index 2235542..f94581d 100644 --- a/graph_framework/jit.hpp +++ b/graph_framework/jit.hpp @@ -10,6 +10,7 @@ #include #include +#include #ifdef USE_METAL #include "metal_context.hpp" @@ -180,6 +181,26 @@ namespace jit { std::cout << std::endl << source_buffer.str() << std::endl; } +//------------------------------------------------------------------------------ +/// @brief Save the kernel source code. 
+//------------------------------------------------------------------------------ + void save_source() { + std::string source = source_buffer.str(); + std::ostringstream filename; + filename << std::hash {} (source) + << std::hash{}(std::this_thread::get_id()); + if constexpr (use_cuda()) { + filename << ".cu"; + } else if constexpr (use_metal ()) { + filename << ".metal"; + } else { + filename << ".cpp"; + } + + std::ofstream outFile(filename.str()); + outFile << source; + } + //------------------------------------------------------------------------------ /// @brief Compile the kernel. /// @@ -187,6 +208,9 @@ namespace jit { /// kernel. //------------------------------------------------------------------------------ void compile(const bool add_reduction=false) { +#ifdef SAVE_KERNEL_SOURCE + save_source(); +#endif gpu_context.compile(source_buffer.str(), kernel_names, add_reduction); -- GitLab From 9cb46f428a45aa80780743f13d85f39bf255e707 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Tue, 18 Jun 2024 11:48:18 -0400 Subject: [PATCH 49/63] Simplify cases of fma(1,a,b) and fma(a,1,b). 
--- graph_framework/arithmetic.hpp | 8 ++++++-- graph_tests/arithmetic_test.cpp | 12 ++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/graph_framework/arithmetic.hpp b/graph_framework/arithmetic.hpp index a1af546..4b5de5a 100644 --- a/graph_framework/arithmetic.hpp +++ b/graph_framework/arithmetic.hpp @@ -2368,10 +2368,14 @@ namespace graph { return constant (this->evaluate()); } else if (l.get() && m.get()) { return this->left*this->middle + this->right; - } else if (l.get() && l->evaluate().is_none()) { + } else if (l.get() && l->is(-1)) { return this->right - this->middle; - } else if (m.get() && m->evaluate().is_none()) { + } else if (m.get() && m->is(-1)) { return this->right - this->left; + } else if (l.get() && l->is(1)) { + return this->middle + this->right; + } else if (m.get() && m->is(1)) { + return this->left + this->right; } auto pl1 = piecewise_1D_cast(this->left); diff --git a/graph_tests/arithmetic_test.cpp b/graph_tests/arithmetic_test.cpp index 124df71..7a8d443 100644 --- a/graph_tests/arithmetic_test.cpp +++ b/graph_tests/arithmetic_test.cpp @@ -2030,6 +2030,18 @@ template void test_fma() { auto var_b = graph::variable (1, ""); auto var_c = graph::variable (1, ""); +// fma(1,a,b) = a + b + auto one_times_vara_plus_varb = graph::fma(one, var_a, var_b); + auto one_times_vara_plus_varb_cast = + graph::add_cast(one_times_vara_plus_varb); + assert(one_times_vara_plus_varb_cast.get() && "Expected an add node."); + +// fma(a,1,b) = a + b + auto vara_times_one_plus_varb = graph::fma(var_a, one, var_b); + auto vara_times_one_plus_varb_cast = + graph::add_cast(vara_times_one_plus_varb); + assert(vara_times_one_plus_varb_cast.get() && "Expected an add node."); + auto reduce1 = graph::fma(var_a, var_b, var_a*var_c); auto reduce1_cast = graph::multiply_cast(reduce1); assert(reduce1_cast.get() && "Expected multiply node."); -- GitLab From 43e9c4852987b66360e4461e8ec51a841115bcbc Mon Sep 17 00:00:00 2001 From: cianciosa Date: Tue, 18 
Jun 2024 13:28:13 -0400 Subject: [PATCH 50/63] Simplify (-a) - b -> -(a + b). --- graph_framework/arithmetic.hpp | 11 +++++++++-- graph_framework/node.hpp | 8 ++++---- graph_framework/piecewise.hpp | 4 ++-- graph_tests/arithmetic_test.cpp | 5 +++++ 4 files changed, 20 insertions(+), 8 deletions(-) diff --git a/graph_framework/arithmetic.hpp b/graph_framework/arithmetic.hpp index 4b5de5a..cc3f9a9 100644 --- a/graph_framework/arithmetic.hpp +++ b/graph_framework/arithmetic.hpp @@ -574,16 +574,23 @@ namespace graph { // v1 - -c*v2 -> v1 + c*v2 if (rm.get()) { auto rmc = constant_cast(rm->get_left()); - if (rmc.get() && rmc->evaluate().is_none()) { + if (rmc.get() && rmc->is(-1)) { return this->left + rm->get_right(); } else if (rmc.get() && rmc->evaluate().is_negative()) { return this->left + (none ()*rm->get_left())*rm->get_right(); } } + if (lm.get()) { +// Assume constants are on the left. +// -a - b -> -(a + b) + auto lmc = constant_cast(lm->get_left()); + if (lmc.get() && lmc->is(-1)) { + return lm->get_left()*(lm->get_right() + this->right); + } + // a*v - v = (a - 1)*v // v*a - v = (a - 1)*v - if (lm.get()) { if (this->right->is_match(lm->get_right())) { return (lm->get_left() - one ())*this->right; } else if (this->right->is_match(lm->get_left())) { diff --git a/graph_framework/node.hpp b/graph_framework/node.hpp index 3176968..e63b49e 100644 --- a/graph_framework/node.hpp +++ b/graph_framework/node.hpp @@ -102,7 +102,7 @@ namespace graph { jit::texture1d_list &textures1d, jit::texture2d_list &textures2d) { if (usage.find(this) == usage.end()) { - usage[this] = 0; + usage[this] = 1; } else { ++usage[this]; } @@ -735,7 +735,7 @@ namespace graph { visited, usage, textures1d, textures2d); visited.insert(this); - usage[this] = 0; + usage[this] = 1; } else { ++usage[this]; } @@ -868,7 +868,7 @@ namespace graph { visited, usage, textures1d, textures2d); visited.insert(this); - usage[this] = 0; + usage[this] = 1; } else { ++usage[this]; } @@ -987,7 +987,7 @@ 
namespace graph { visited, usage, textures1d, textures2d); visited.insert(this); - usage[this] = 0; + usage[this] = 1; } else { ++usage[this]; } diff --git a/graph_framework/piecewise.hpp b/graph_framework/piecewise.hpp index 02546a5..2bd0aa7 100644 --- a/graph_framework/piecewise.hpp +++ b/graph_framework/piecewise.hpp @@ -227,7 +227,7 @@ void compile_index(std::ostringstream &stream, } } visited.insert(this); - usage[this] = 0; + usage[this] = 1; } else { ++usage[this]; } @@ -681,7 +681,7 @@ void compile_index(std::ostringstream &stream, } } visited.insert(this); - usage[this] = 0; + usage[this] = 1; } else { ++usage[this]; } diff --git a/graph_tests/arithmetic_test.cpp b/graph_tests/arithmetic_test.cpp index 7a8d443..8819649 100644 --- a/graph_tests/arithmetic_test.cpp +++ b/graph_tests/arithmetic_test.cpp @@ -600,6 +600,11 @@ template void test_subtract() { auto factor4 = var_b - (var_b*var_a); assert(graph::multiply_cast(factor4).get() && "Expected a multiply node."); + +// -1*a - b -> -1*(a + b) + auto neg_vara_minus_varb = (graph::none ()*var_a) - var_b; + assert(graph::multiply_cast(neg_vara_minus_varb).get() && + "Expected a multiply node."); } //------------------------------------------------------------------------------ -- GitLab From ea79db20fba06196c615af502bc5ae4afaa2fcc7 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Wed, 19 Jun 2024 11:23:36 -0400 Subject: [PATCH 51/63] Enable an option to disable caching of kernel inputs. --- CMakeLists.txt | 6 ++++ graph_framework/cuda_context.hpp | 46 +++++++++++++++++++++++-------- graph_framework/metal_context.hpp | 27 ++++++++++++------ 3 files changed, 59 insertions(+), 20 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index cb68a96..fe4adb1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -77,12 +77,18 @@ else () endif () endif () +option (USE_INPUT_CACHE "Cache the values kernel input values." 
OFF) + add_library (gpu_lib INTERFACE) target_link_libraries (gpu_lib INTERFACE $<$:metal_lib> $<$:cuda_lib> ) +target_compile_definitions (gpu_lib + INTERFACE + $<$:USE_INPUT_CACHE> +) #------------------------------------------------------------------------------- # Sanitizer options diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index 41df697..db9c663 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -16,6 +16,8 @@ #include "node.hpp" +#define MAX_REG 256 + namespace gpu { //------------------------------------------------------------------------------ /// @brief Check results of realtime compile. @@ -269,7 +271,7 @@ namespace gpu { CU_JIT_POSITION_INDEPENDENT_CODE }; std::array module_values = { - reinterpret_cast (168), + reinterpret_cast (MAX_REG), reinterpret_cast (1), reinterpret_cast (0) }; @@ -337,7 +339,7 @@ namespace gpu { CUDA_RESOURCE_DESC resource_desc; CUDA_TEXTURE_DESC texture_desc; CUDA_ARRAY_DESCRIPTOR array_desc; - + array_desc.Width = size; array_desc.Height = 1; @@ -411,7 +413,7 @@ namespace gpu { } check_error(cuArrayCreate(&resource_desc.res.array.hArray, &array_desc), "cuArrayCreate"); - + CUDA_MEMCPY2D copy_desc; memset(©_desc, 0, sizeof(copy_desc)); @@ -650,7 +652,11 @@ namespace gpu { jit::add_type (source_buffer); source_buffer << " *" << jit::to_string('v', inputs[0].get()); for (size_t i = 1, ie = inputs.size(); i < ie; i++) { - source_buffer << "," << std::endl; + source_buffer << ", // " << inputs[i - 1]->get_symbol() +#ifndef USE_INPUT_CACHE + << " used " << usage.at(inputs[i - 1].get()) +#endif + << std::endl; source_buffer << " "; if (is_constant[i]) { source_buffer << "const "; @@ -659,7 +665,17 @@ namespace gpu { source_buffer << " *" << jit::to_string('v', inputs[i].get()); } for (size_t i = 0, ie = outputs.size(); i < ie; i++) { - source_buffer << "," << std::endl; + source_buffer << ","; + if (i == 0) { + source_buffer << " // " + << inputs[inputs.size() 
- 1]->get_symbol(); +#ifndef USE_INPUT_CACHE + source_buffer << " used " + << usage.at(inputs[inputs.size() - 1].get()); +#endif + } + + source_buffer << std::endl; source_buffer << " "; jit::add_type (source_buffer); source_buffer << " *" << jit::to_string('o', outputs[i].get()); @@ -683,13 +699,19 @@ namespace gpu { source_buffer << " if (index < " << size << ") {" << std::endl; for (auto &input : inputs) { - registers[input.get()] = jit::to_string('r', input.get()); - source_buffer << " const "; - jit::add_type (source_buffer); - source_buffer << " " << registers[input.get()] << " = " - << jit::to_string('v', input.get()) - << "[index]; // " << input->get_symbol() - << " used " << usage.at(input.get()) << std::endl; +#ifdef USE_INPUT_CACHE + if (usage.at(input.get())) { + registers[input.get()] = jit::to_string('r', input.get()); + source_buffer << " const "; + jit::add_type (source_buffer); + source_buffer << " " << registers[input.get()] << " = " + << jit::to_string('v', input.get()) + << "[index]; // " << input->get_symbol() + << " used " << usage.at(input.get()) << std::endl; + } +#else + registers[input.get()] = jit::to_string('v', input.get()) + "[index]"; +#endif } } diff --git a/graph_framework/metal_context.hpp b/graph_framework/metal_context.hpp index 4b8c586..fc5cc4b 100644 --- a/graph_framework/metal_context.hpp +++ b/graph_framework/metal_context.hpp @@ -425,7 +425,12 @@ namespace gpu { source_buffer << " " << (is_constant[i] ? 
"constant" : "device") << " float *" << jit::to_string('v', inputs[i].get()) - << " [[buffer(" << i << ")]]," << std::endl; + << " [[buffer(" << i << ")]], // " + << inputs[i]->get_symbol() +#ifndef USE_INPUT_CACHE + << " used " << usage.at(inputs[i].get()) +#endif + << std::endl; } for (size_t i = 0, ie = outputs.size(); i < ie; i++) { bufferMutability[name].push_back(MTLMutabilityMutable); @@ -451,13 +456,19 @@ namespace gpu { source_buffer << " if (index < " << size << ") {" << std::endl; for (auto &input : inputs) { - registers[input.get()] = jit::to_string('r', input.get()); - source_buffer << " const "; - jit::add_type (source_buffer); - source_buffer << " " << registers[input.get()] << " = " - << jit::to_string('v', input.get()) - << "[index]; // " << input->get_symbol() - << " used " << usage.at(input.get()) << std::endl; +#ifdef USE_INPUT_CACHE + if (usage.at(input.get())) { + registers[input.get()] = jit::to_string('r', input.get()); + source_buffer << " const "; + jit::add_type (source_buffer); + source_buffer << " " << registers[input.get()] << " = " + << jit::to_string('v', input.get()) + << "[index]; // " << input->get_symbol() + << " used " << usage.at(input.get()) << std::endl; + } +#else + registers[input.get()] = jit::to_string('v', input.get()) + "[index]"; +#endif } } -- GitLab From 228dd9d30af358a1ec5fb77a2f146db0b43b5dd9 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Wed, 19 Jun 2024 15:57:13 -0400 Subject: [PATCH 52/63] Add a counter for constant memory to measure usage only used for cuda kernels. 
--- graph_framework/cpu_context.hpp | 3 ++ graph_framework/cuda_context.hpp | 7 +++ graph_framework/jit.hpp | 6 ++- graph_framework/metal_context.hpp | 3 ++ graph_framework/node.hpp | 82 ++++++++++++++++++------------- graph_framework/piecewise.hpp | 55 ++++++++++++++------- 6 files changed, 103 insertions(+), 53 deletions(-) diff --git a/graph_framework/cpu_context.hpp b/graph_framework/cpu_context.hpp index a70ffea..8dcd6cd 100644 --- a/graph_framework/cpu_context.hpp +++ b/graph_framework/cpu_context.hpp @@ -75,6 +75,9 @@ namespace gpu { std::map *, size_t> arg_index; public: +/// Remaining constant memory in bytes. NOT USED. + int remaining_const_memory; + //------------------------------------------------------------------------------ /// @brief Get the maximum number of concurrent instances. /// diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index db9c663..6cf5115 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -17,6 +17,7 @@ #include "node.hpp" #define MAX_REG 256 +#define MAX_CONSTANT_MEMORY namespace gpu { //------------------------------------------------------------------------------ @@ -103,6 +104,9 @@ namespace gpu { } public: +/// Remaining constant memory in bytes. + int remaining_const_memory; + //------------------------------------------------------------------------------ /// @brief Get the maximum number of concurrent instances. 
/// @@ -132,6 +136,9 @@ namespace gpu { check_error(cuCtxSetCurrent(context), "cuCtxSetCurrent"); check_error(cuCtxSetCacheConfig(CU_FUNC_CACHE_PREFER_L1), "cuCtxSetCacheConfig"); check_error(cuStreamCreate(&stream, CU_STREAM_DEFAULT), "cuStreamCreate"); + check_error(cuDeviceGetAttribute(&remaining_const_memory, + CU_DEVICE_ATTRIBUTE_TOTAL_CONSTANT_MEMORY, + device), "cuDeviceGetAttribute"); } //------------------------------------------------------------------------------ diff --git a/graph_framework/jit.hpp b/graph_framework/jit.hpp index f94581d..ba092c5 100644 --- a/graph_framework/jit.hpp +++ b/graph_framework/jit.hpp @@ -120,13 +120,15 @@ namespace jit { out->compile_preamble(source_buffer, registers, visited, usage, kernel_1dtextures[name], - kernel_2dtextures[name]); + kernel_2dtextures[name], + gpu_context.remaining_const_memory); } for (auto &out : outputs) { out->compile_preamble(source_buffer, registers, visited, usage, kernel_1dtextures[name], - kernel_2dtextures[name]); + kernel_2dtextures[name], + gpu_context.remaining_const_memory); } for (auto &in : inputs) { diff --git a/graph_framework/metal_context.hpp b/graph_framework/metal_context.hpp index fc5cc4b..326147f 100644 --- a/graph_framework/metal_context.hpp +++ b/graph_framework/metal_context.hpp @@ -39,6 +39,9 @@ namespace gpu { std::map> bufferMutability; public: +/// Remaining constant memory in bytes. NOT USED. + int remaining_const_memory; + //------------------------------------------------------------------------------ /// @brief Get the maximum number of concurrent instances. /// diff --git a/graph_framework/node.hpp b/graph_framework/node.hpp index e63b49e..35aa445 100644 --- a/graph_framework/node.hpp +++ b/graph_framework/node.hpp @@ -88,19 +88,21 @@ namespace graph { /// Some nodes require additions to the preamble however most don't so define a /// generic method that does nothing. /// -/// @params[in,out] stream String buffer stream. 
-/// @params[in,out] registers List of defined registers. -/// @params[in,out] visited List of visited nodes. -/// @params[in,out] usage List of register usage count. -/// @params[in,out] textures1d List of 1D textures. -/// @params[in,out] textures2d List of 2D textures. +/// @params[in,out] stream String buffer stream. +/// @params[in,out] registers List of defined registers. +/// @params[in,out] visited List of visited nodes. +/// @params[in,out] usage List of register usage count. +/// @params[in,out] textures1d List of 1D textures. +/// @params[in,out] textures2d List of 2D textures. +/// @params[in,out] avail_const_mem Available constant memory. //------------------------------------------------------------------------------ virtual void compile_preamble(std::ostringstream &stream, jit::register_map ®isters, jit::visiter_map &visited, jit::register_usage &usage, jit::texture1d_list &textures1d, - jit::texture2d_list &textures2d) { + jit::texture2d_list &textures2d, + int &avail_const_mem) { if (usage.find(this) == usage.end()) { usage[this] = 1; } else { @@ -717,23 +719,26 @@ namespace graph { //------------------------------------------------------------------------------ /// @brief Compile preamble. /// -/// @params[in,out] stream String buffer stream. -/// @params[in,out] registers List of defined registers. -/// @params[in,out] visited List of visited nodes. -/// @params[in,out] usage List of register usage count. -/// @params[in,out] textures1d List of 1D textures. -/// @params[in,out] textures2d List of 2D textures. +/// @params[in,out] stream String buffer stream. +/// @params[in,out] registers List of defined registers. +/// @params[in,out] visited List of visited nodes. +/// @params[in,out] usage List of register usage count. +/// @params[in,out] textures1d List of 1D textures. +/// @params[in,out] textures2d List of 2D textures. +/// @params[in,out] avail_const_mem Available constant memory. 
//------------------------------------------------------------------------------ virtual void compile_preamble(std::ostringstream &stream, jit::register_map ®isters, jit::visiter_map &visited, jit::register_usage &usage, jit::texture1d_list &textures1d, - jit::texture2d_list &textures2d) { + jit::texture2d_list &textures2d, + int &avail_const_mem) { if (visited.find(this) == visited.end()) { this->arg->compile_preamble(stream, registers, visited, usage, - textures1d, textures2d); + textures1d, textures2d, + avail_const_mem); visited.insert(this); usage[this] = 1; } else { @@ -847,26 +852,30 @@ namespace graph { //------------------------------------------------------------------------------ /// @brief Compile preamble. /// -/// @params[in,out] stream String buffer stream. -/// @params[in,out] registers List of defined registers. -/// @params[in,out] visited List of visited nodes. -/// @params[in,out] usage List of register usage count. -/// @params[in,out] textures1d List of 1D textures. -/// @params[in,out] textures2d List of 2D textures. +/// @params[in,out] stream String buffer stream. +/// @params[in,out] registers List of defined registers. +/// @params[in,out] visited List of visited nodes. +/// @params[in,out] usage List of register usage count. +/// @params[in,out] textures1d List of 1D textures. +/// @params[in,out] textures2d List of 2D textures. +/// @params[in,out] avail_const_mem Available constant memory. 
//------------------------------------------------------------------------------ virtual void compile_preamble(std::ostringstream &stream, jit::register_map ®isters, jit::visiter_map &visited, jit::register_usage &usage, jit::texture1d_list &textures1d, - jit::texture2d_list &textures2d) { + jit::texture2d_list &textures2d, + int &avail_const_mem) { if (visited.find(this) == visited.end()) { this->left->compile_preamble(stream, registers, visited, usage, - textures1d, textures2d); + textures1d, textures2d, + avail_const_mem); this->right->compile_preamble(stream, registers, visited, usage, - textures1d, textures2d); + textures1d, textures2d, + avail_const_mem); visited.insert(this); usage[this] = 1; } else { @@ -963,29 +972,34 @@ namespace graph { //------------------------------------------------------------------------------ /// @brief Compile preamble. /// -/// @params[in,out] stream String buffer stream. -/// @params[in,out] registers List of defined registers. -/// @params[in,out] visited List of visited nodes. -/// @params[in,out] usage List of register usage count. -/// @params[in,out] textures1d List of 1D textures. -/// @params[in,out] textures2d List of 2D textures. +/// @params[in,out] stream String buffer stream. +/// @params[in,out] registers List of defined registers. +/// @params[in,out] visited List of visited nodes. +/// @params[in,out] usage List of register usage count. +/// @params[in,out] textures1d List of 1D textures. +/// @params[in,out] textures2d List of 2D textures. +/// @params[in,out] avail_const_mem Available constant memory. 
//------------------------------------------------------------------------------ virtual void compile_preamble(std::ostringstream &stream, jit::register_map ®isters, jit::visiter_map &visited, jit::register_usage &usage, jit::texture1d_list &textures1d, - jit::texture2d_list &textures2d) { + jit::texture2d_list &textures2d, + int &avail_const_mem) { if (visited.find(this) == visited.end()) { this->left->compile_preamble(stream, registers, visited, usage, - textures1d, textures2d); + textures1d, textures2d, + avail_const_mem); this->middle->compile_preamble(stream, registers, visited, usage, - textures1d, textures2d); + textures1d, textures2d, + avail_const_mem); this->right->compile_preamble(stream, registers, visited, usage, - textures1d, textures2d); + textures1d, textures2d, + avail_const_mem); visited.insert(this); usage[this] = 1; } else { diff --git a/graph_framework/piecewise.hpp b/graph_framework/piecewise.hpp index 2bd0aa7..72e8294 100644 --- a/graph_framework/piecewise.hpp +++ b/graph_framework/piecewise.hpp @@ -179,23 +179,26 @@ void compile_index(std::ostringstream &stream, //------------------------------------------------------------------------------ /// @brief Compile preamble. /// -/// @params[in,out] stream String buffer stream. -/// @params[in,out] registers List of defined registers. -/// @params[in,out] visited List of visited nodes. -/// @params[in,out] usage List of register usage count. -/// @params[in,out] textures1d List of 1D textures. -/// @params[in,out] textures2d List of 2D textures. +/// @params[in,out] stream String buffer stream. +/// @params[in,out] registers List of defined registers. +/// @params[in,out] visited List of visited nodes. +/// @params[in,out] usage List of register usage count. +/// @params[in,out] textures1d List of 1D textures. +/// @params[in,out] textures2d List of 2D textures. +/// @params[in,out] avail_const_mem Available constant memory. 
//------------------------------------------------------------------------------ virtual void compile_preamble(std::ostringstream &stream, jit::register_map ®isters, jit::visiter_map &visited, jit::register_usage &usage, jit::texture1d_list &textures1d, - jit::texture2d_list &textures2d) { + jit::texture2d_list &textures2d, + int &avail_const_mem) { if (visited.find(this) == visited.end()) { this->arg->compile_preamble(stream, registers, visited, usage, - textures1d, textures2d); + textures1d, textures2d, + avail_const_mem); if (registers.find(leaf_node::backend_cache[data_hash].data()) == registers.end()) { registers[leaf_node::backend_cache[data_hash].data()] = jit::to_string('a', leaf_node::backend_cache[data_hash].data()); @@ -209,6 +212,13 @@ void compile_index(std::ostringstream &stream, length); #endif } else { + if constexpr (jit::use_cuda()) { + const int buffer_size = length*sizeof(T); + if (avail_const_mem - buffer_size > 0) { + avail_const_mem -= buffer_size; + stream << "__constant__ "; + } + } stream << "const "; jit::add_type (stream); stream << " " << registers[leaf_node::backend_cache[data_hash].data()] << "[] = {"; @@ -630,26 +640,30 @@ void compile_index(std::ostringstream &stream, //------------------------------------------------------------------------------ /// @brief Compile preamble. /// -/// @params[in,out] stream String buffer stream. -/// @params[in,out] registers List of defined registers. -/// @params[in,out] visited List of visited nodes. -/// @params[in,out] usage List of register usage count. -/// @params[in,out] textures1d List of 1D textures. -/// @params[in,out] textures2d List of 2D textures. +/// @params[in,out] stream String buffer stream. +/// @params[in,out] registers List of defined registers. +/// @params[in,out] visited List of visited nodes. +/// @params[in,out] usage List of register usage count. +/// @params[in,out] textures1d List of 1D textures. +/// @params[in,out] textures2d List of 2D textures. 
+/// @params[in,out] avail_const_mem Available constant memory. //------------------------------------------------------------------------------ virtual void compile_preamble(std::ostringstream &stream, jit::register_map ®isters, jit::visiter_map &visited, jit::register_usage &usage, jit::texture1d_list &textures1d, - jit::texture2d_list &textures2d) { + jit::texture2d_list &textures2d, + int &avail_const_mem) { if (visited.find(this) == visited.end()) { this->left->compile_preamble(stream, registers, visited, usage, - textures1d, textures2d); + textures1d, textures2d, + avail_const_mem); this->right->compile_preamble(stream, registers, visited, usage, - textures1d, textures2d); + textures1d, textures2d, + avail_const_mem); if (registers.find(leaf_node::backend_cache[data_hash].data()) == registers.end()) { registers[leaf_node::backend_cache[data_hash].data()] = jit::to_string('a', leaf_node::backend_cache[data_hash].data()); @@ -663,6 +677,13 @@ void compile_index(std::ostringstream &stream, std::array ({length/num_columns, num_columns})); #endif } else { + if constexpr (jit::use_cuda()) { + const int buffer_size = length*sizeof(T); + if (avail_const_mem - buffer_size > 0) { + avail_const_mem -= buffer_size; + stream << "__constant__ "; + } + } stream << "const "; jit::add_type (stream); stream << " " << registers[leaf_node::backend_cache[data_hash].data()] << "[] = {"; -- GitLab From dd4c50351b1ca0b7de645ce1401712a90cb99278 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Thu, 20 Jun 2024 13:33:22 -0400 Subject: [PATCH 53/63] Add restrict keyword. 
--- graph_framework/cuda_context.hpp | 56 +++++--------------------------- 1 file changed, 9 insertions(+), 47 deletions(-) diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index 6cf5115..ea520df 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -16,7 +16,7 @@ #include "node.hpp" -#define MAX_REG 256 +#define MAX_REG 128 #define MAX_CONSTANT_MEMORY namespace gpu { @@ -657,7 +657,8 @@ namespace gpu { source_buffer << "const "; } jit::add_type (source_buffer); - source_buffer << " *" << jit::to_string('v', inputs[0].get()); + source_buffer << " * __restrict__ " + << jit::to_string('v', inputs[0].get()); for (size_t i = 1, ie = inputs.size(); i < ie; i++) { source_buffer << ", // " << inputs[i - 1]->get_symbol() #ifndef USE_INPUT_CACHE @@ -669,7 +670,8 @@ namespace gpu { source_buffer << "const "; } jit::add_type (source_buffer); - source_buffer << " *" << jit::to_string('v', inputs[i].get()); + source_buffer << " * __restrict__ " + << jit::to_string('v', inputs[i].get()); } for (size_t i = 0, ie = outputs.size(); i < ie; i++) { source_buffer << ","; @@ -685,7 +687,8 @@ namespace gpu { source_buffer << std::endl; source_buffer << " "; jit::add_type (source_buffer); - source_buffer << " *" << jit::to_string('o', outputs[i].get()); + source_buffer << " * __restrict__ " + << jit::to_string('o', outputs[i].get()); } #ifdef USE_CUDA_TEXTURES for (size_t i = 0, ie = textures1d.size(); i < ie; i++) { @@ -803,10 +806,10 @@ namespace gpu { source_buffer << "extern \"C\" __global__ void max_reduction(" << std::endl; source_buffer << " const "; jit::add_type (source_buffer); - source_buffer << " *input," << std::endl; + source_buffer << " * __restruct__ input," << std::endl; source_buffer << " "; jit::add_type (source_buffer); - source_buffer << " *result) {" << std::endl; + source_buffer << " * __result__ result) {" << std::endl; source_buffer << " const unsigned int i = threadIdx.x;" << std::endl; 
source_buffer << " const unsigned int j = threadIdx.x/32;" << std::endl; source_buffer << " const unsigned int k = threadIdx.x%32;" << std::endl; @@ -840,47 +843,6 @@ namespace gpu { source_buffer << "}" << std::endl << std::endl; } -//------------------------------------------------------------------------------ -/// @brief Create a preamble. -/// -/// @params[in,out] source_buffer Source buffer stream. -//------------------------------------------------------------------------------ - void create_preamble(std::ostringstream &source_buffer) { - source_buffer << "extern \"C\" __global__ "; - } - -//------------------------------------------------------------------------------ -/// @brief Create arg prefix. -/// -/// @params[in,out] source_buffer Source buffer stream. -//------------------------------------------------------------------------------ - void create_argument_prefix(std::ostringstream &source_buffer) {} - -//------------------------------------------------------------------------------ -/// @brief Create arg postfix. -/// -/// @params[in,out] source_buffer Source buffer stream. -/// @params[in] index Argument index. -//------------------------------------------------------------------------------ - void create_argument_postfix(std::ostringstream &source_buffer, - const size_t index) {} - -//------------------------------------------------------------------------------ -/// @brief Create index argument. -/// -/// @params[in,out] source_buffer Source buffer stream. -//------------------------------------------------------------------------------ - void create_index_argument(std::ostringstream &source_buffer) {} - -//------------------------------------------------------------------------------ -/// @brief Create index. -/// -/// @params[in,out] source_buffer Source buffer stream. 
-//------------------------------------------------------------------------------ - void create_index(std::ostringstream &source_buffer) { - source_buffer << "blockIdx.x*blockDim.x + threadIdx.x;"; - } - //------------------------------------------------------------------------------ /// @brief Get the buffer for a node. /// -- GitLab From 8e0eb301e875826d088e96b8a5654b975bc48ce3 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Thu, 20 Jun 2024 13:37:22 -0400 Subject: [PATCH 54/63] __result__ -> __restruct__ --- graph_framework/cuda_context.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index ea520df..704951e 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -809,7 +809,7 @@ namespace gpu { source_buffer << " * __restruct__ input," << std::endl; source_buffer << " "; jit::add_type (source_buffer); - source_buffer << " * __result__ result) {" << std::endl; + source_buffer << " * __restruct__ result) {" << std::endl; source_buffer << " const unsigned int i = threadIdx.x;" << std::endl; source_buffer << " const unsigned int j = threadIdx.x/32;" << std::endl; source_buffer << " const unsigned int k = threadIdx.x%32;" << std::endl; -- GitLab From e0d06521d2d5e92e61a4895dc0d16f1746d602f4 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Thu, 20 Jun 2024 13:44:56 -0400 Subject: [PATCH 55/63] __restruct__ -> __restrict__ --- graph_framework/cuda_context.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index 704951e..e5eb9a9 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -806,10 +806,10 @@ namespace gpu { source_buffer << "extern \"C\" __global__ void max_reduction(" << std::endl; source_buffer << " const "; jit::add_type (source_buffer); - source_buffer << " * __restruct__ input," << std::endl; + source_buffer << " * 
__restrict__ input," << std::endl; source_buffer << " "; jit::add_type (source_buffer); - source_buffer << " * __restruct__ result) {" << std::endl; + source_buffer << " * __restrict__ result) {" << std::endl; source_buffer << " const unsigned int i = threadIdx.x;" << std::endl; source_buffer << " const unsigned int j = threadIdx.x/32;" << std::endl; source_buffer << " const unsigned int k = threadIdx.x%32;" << std::endl; -- GitLab From a76ab26090b02fe7e76aefc0c00ffafdb716639d Mon Sep 17 00:00:00 2001 From: cianciosa Date: Mon, 24 Jun 2024 18:04:12 -0400 Subject: [PATCH 56/63] Fix error where textures would be missing when multiple kernals were called in a workflow. --- graph_benchmark/xrays_bench.cpp | 6 ++--- graph_framework/cuda_context.hpp | 8 +++--- graph_framework/metal_context.hpp | 13 +++++----- graph_framework/piecewise.hpp | 42 +++++++++++++++++++++++++------ graph_framework/register.hpp | 4 +-- 5 files changed, 50 insertions(+), 23 deletions(-) diff --git a/graph_benchmark/xrays_bench.cpp b/graph_benchmark/xrays_bench.cpp index 9496977..bd447fc 100644 --- a/graph_benchmark/xrays_bench.cpp +++ b/graph_benchmark/xrays_bench.cpp @@ -114,9 +114,9 @@ int main(int argc, const char * argv[]) { (void)argv; bench_runner (); - bench_runner (); - bench_runner, 1000, 10, 100000> (); - bench_runner, 1000, 10, 100000> (); +// bench_runner (); +// bench_runner, 1000, 10, 100000> (); +// bench_runner, 1000, 10, 100000> (); END_GPU } diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index e5eb9a9..e0912b0 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -691,15 +691,15 @@ namespace gpu { << jit::to_string('o', outputs[i].get()); } #ifdef USE_CUDA_TEXTURES - for (size_t i = 0, ie = textures1d.size(); i < ie; i++) { + for (auto &[key, value] : textures1d) { source_buffer << "," << std::endl; source_buffer << " cudaTextureObject_t " - << jit::to_string('a', textures1d[i].first); + << jit::to_string('a', 
key); } - for (size_t i = 0, ie = textures2d.size(); i < ie; i++) { + for (auto &[key, value] : textures2d) { source_buffer << "," << std::endl; source_buffer << " cudaTextureObject_t " - << jit::to_string('a', textures2d[i].first); + << jit::to_string('a', key); } #endif source_buffer << ") {" << std::endl; diff --git a/graph_framework/metal_context.hpp b/graph_framework/metal_context.hpp index 326147f..03bde6b 100644 --- a/graph_framework/metal_context.hpp +++ b/graph_framework/metal_context.hpp @@ -442,16 +442,17 @@ namespace gpu { << " [[buffer(" << i + inputs.size() << ")]]," << std::endl; } - for (size_t i = 0, ie = textures1d.size(); i < ie; i++) { + size_t index = 0; + for (auto &[key, value] : textures1d) { source_buffer << " const texture1d " - << jit::to_string('a', textures1d[i].first) - << " [[texture(" << i << ")]]," + << jit::to_string('a', key) + << " [[texture(" << index++ << ")]]," << std::endl; } - for (size_t i = 0, ie = textures2d.size(); i < ie; i++) { + for (auto &[key, value] : textures2d) { source_buffer << " const texture2d " - << jit::to_string('a', textures2d[i].first) - << " [[texture(" << i + textures1d.size() << ")]]," + << jit::to_string('a', key) + << " [[texture(" << index++ << ")]]," << std::endl; } diff --git a/graph_framework/piecewise.hpp b/graph_framework/piecewise.hpp index 72e8294..93fd335 100644 --- a/graph_framework/piecewise.hpp +++ b/graph_framework/piecewise.hpp @@ -204,12 +204,12 @@ void compile_index(std::ostringstream &stream, jit::to_string('a', leaf_node::backend_cache[data_hash].data()); const size_t length = leaf_node::backend_cache[data_hash].size(); if constexpr (jit::use_metal ()) { - textures1d.emplace_back(leaf_node::backend_cache[data_hash].data(), - length); + textures1d.try_emplace(leaf_node::backend_cache[data_hash].data(), + length); #ifdef USE_CUDA_TEXTURES } else if constexpr (jit::use_cuda()) { - textures1d.emplace_back(leaf_node::backend_cache[data_hash].data(), - length); + 
textures1d.try_emplace(leaf_node::backend_cache[data_hash].data(), + length); #endif } else { if constexpr (jit::use_cuda()) { @@ -235,6 +235,19 @@ void compile_index(std::ostringstream &stream, } stream << "};" << std::endl; } + } else { +// When using textures, the register can be defined in a previous kernel. We +// need to add the textures again. + const size_t length = leaf_node::backend_cache[data_hash].size(); + if constexpr (jit::use_metal ()) { + textures1d.try_emplace(leaf_node::backend_cache[data_hash].data(), + length); +#ifdef USE_CUDA_TEXTURES + } else if constexpr (jit::use_cuda()) { + textures1d.try_emplace(leaf_node::backend_cache[data_hash].data(), + length); +#endif + } } visited.insert(this); usage[this] = 1; @@ -669,12 +682,12 @@ void compile_index(std::ostringstream &stream, jit::to_string('a', leaf_node::backend_cache[data_hash].data()); const size_t length = leaf_node::backend_cache[data_hash].size(); if constexpr (jit::use_metal ()) { - textures2d.emplace_back(leaf_node::backend_cache[data_hash].data(), - std::array ({length/num_columns, num_columns})); + textures2d.try_emplace(leaf_node::backend_cache[data_hash].data(), + std::array ({length/num_columns, num_columns})); #ifdef USE_CUDA_TEXTURES } else if constexpr (jit::use_cuda()) { - textures2d.emplace_back(leaf_node::backend_cache[data_hash].data(), - std::array ({length/num_columns, num_columns})); + textures2d.try_emplace(leaf_node::backend_cache[data_hash].data(), + std::array ({length/num_columns, num_columns})); #endif } else { if constexpr (jit::use_cuda()) { @@ -700,6 +713,19 @@ void compile_index(std::ostringstream &stream, } stream << "};" << std::endl; } + } else { +// When using textures, the register can be defined in a previous kernel. We +// need to add the textures again. 
+ const size_t length = leaf_node::backend_cache[data_hash].size(); + if constexpr (jit::use_metal ()) { + textures2d.try_emplace(leaf_node::backend_cache[data_hash].data(), + std::array ({length/num_columns, num_columns})); +#ifdef USE_CUDA_TEXTURES + } else if constexpr (jit::use_cuda()) { + textures2d.try_emplace(leaf_node::backend_cache[data_hash].data(), + std::array ({length/num_columns, num_columns})); +#endif + } } visited.insert(this); usage[this] = 1; diff --git a/graph_framework/register.hpp b/graph_framework/register.hpp index 87cea84..b1f18cc 100644 --- a/graph_framework/register.hpp +++ b/graph_framework/register.hpp @@ -249,9 +249,9 @@ namespace jit { /// Type alias for listing visited nodes. typedef std::set visiter_map; /// Type alias for indexing 1D textures. - typedef std::vector> texture1d_list; + typedef std::map texture1d_list; /// Type alias for indexing 2D textures. - typedef std::vector>> texture2d_list; + typedef std::map> texture2d_list; //------------------------------------------------------------------------------ /// @brief Define a custom comparitor class. -- GitLab From 91106450854003019273fc0eb56fcdbbca8d77ef Mon Sep 17 00:00:00 2001 From: cianciosa Date: Tue, 25 Jun 2024 09:41:52 -0400 Subject: [PATCH 57/63] Reset benchmark sizes. 
--- graph_benchmark/xrays_bench.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/graph_benchmark/xrays_bench.cpp b/graph_benchmark/xrays_bench.cpp index bd447fc..9496977 100644 --- a/graph_benchmark/xrays_bench.cpp +++ b/graph_benchmark/xrays_bench.cpp @@ -114,9 +114,9 @@ int main(int argc, const char * argv[]) { (void)argv; bench_runner (); -// bench_runner (); -// bench_runner, 1000, 10, 100000> (); -// bench_runner, 1000, 10, 100000> (); + bench_runner (); + bench_runner, 1000, 10, 100000> (); + bench_runner, 1000, 10, 100000> (); END_GPU } -- GitLab From df332d73b7f10fcd3df394c7c4a297b664a0a21b Mon Sep 17 00:00:00 2001 From: cianciosa Date: Fri, 5 Jul 2024 17:40:13 -0400 Subject: [PATCH 58/63] Add combinations for piecewise block and row reductions. Add reductions for nested fma nodes. Clean up code that checks for compatable constants and variables. --- CMakeLists.txt | 2 + graph_framework.xcodeproj/project.pbxproj | 313 ++++++- .../xcschemes/arithmetic_test.xcscheme | 2 +- .../xcschemes/graph_driver.xcscheme | 2 +- .../xcshareddata/xcschemes/jit_test.xcscheme | 2 +- .../xcshareddata/xcschemes/math_test.xcscheme | 2 +- .../xcschemes/physics_test.xcscheme | 2 +- graph_framework/arithmetic.hpp | 886 +++++++++++++++--- graph_framework/backend.hpp | 485 ++++++++++ graph_framework/cpu_context.hpp | 24 +- graph_framework/math.hpp | 146 +-- graph_framework/node.hpp | 95 +- graph_framework/piecewise.hpp | 59 +- graph_framework/trigonometry.hpp | 55 ++ graph_tests/arithmetic_test.cpp | 343 ++++++- graph_tests/backend_test.cpp | 24 + graph_tests/math_test.cpp | 22 +- graph_tests/node_test.cpp | 6 +- graph_tests/physics_test.cpp | 1 + graph_tests/piecewise_test.cpp | 283 +++++- 20 files changed, 2435 insertions(+), 319 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index fe4adb1..ddad9ce 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -272,6 +272,8 @@ target_link_libraries (llvm_dep clangCodeGen LLVM${LLVM_NATIVE_ARCH}CodeGen 
LLVMOrcJIT + LLVMOrcDebugging + LLVMOrcTargetProcess ) #------------------------------------------------------------------------------- diff --git a/graph_framework.xcodeproj/project.pbxproj b/graph_framework.xcodeproj/project.pbxproj index 4817ca9..b95c2df 100644 --- a/graph_framework.xcodeproj/project.pbxproj +++ b/graph_framework.xcodeproj/project.pbxproj @@ -886,7 +886,7 @@ isa = PBXProject; attributes = { BuildIndependentTargetsInParallel = YES; - LastUpgradeCheck = 1530; + LastUpgradeCheck = 1540; ORGANIZATIONNAME = "Cianciosa, Mark R."; TargetAttributes = { C73690302A38C498001733B0 = { @@ -1338,8 +1338,7 @@ "EFIT_FILE=\\\"/Users/m4c/Projects/graph_framework/graph_tests/efit.nc\\\"", "VMEC_FILE=\\\"/Users/m4c/Projects/graph_framework/graph_tests/vmec.nc\\\"", USE_METAL, - "CXX_FLAGS=\\\"-g\\\"", - "\"CXX_ARGS=\\\"-I/Users/m4c/Projects/graph_framework/graph_framework -std=gnu++2a\\\"\"", + "\"CXX_ARGS=\\\"-I/Users/m4c/Projects/graph_framework/graph_framework -I/usr/local/include -I/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/usr/include/c++/v1 -I/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/15.0.0/include -I/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/usr/include -I/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/include -I/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/System/Library/Frameworks -fgnuc-version=4.2.1 -std=gnu++2a\\\"\"", STATIC, "DEBUG=1", "$(inherited)", @@ -1366,9 +1365,69 @@ OTHER_LDFLAGS = ( "-lnetcdf", "-ld_classic", - "-rpath", - /usr/local/lib, - "-lLLVM", + "-L/Users/m4c/Projects/graph_framework/build/_deps/llvm-build/lib", + "-lz", + "-lLLVMCoverage", + "-lLLVMSupport", + "-lLLVMDebugInfoCodeView", + "-lLLVMRemarks", + "-lLLVMJITLink", + "-lLLVMLinker", + "-lLLVMTextAPI", + "-lLLVMRuntimeDyld", + 
"-lLLVMOrcShared", + "-lLLVMOrcDebugging", + "-lLLVMOrcTargetProcess", + "-lLLVMOrcJIT", + "-lLLVMHipStdPar", + "-lLLVMAggressiveInstCombine", + "-lLLVMVectorize", + "-lLLVMAsmParser", + "-lLLVMOption", + "-lLLVMLTO", + "-lLLVMObject", + "-lLLVMWindowsDriver", + "-lLLVMDemangle", + "-lLLVMIRReader", + "-lLLVMIRPrinter", + "-lLLVMInstCombine", + "-lLLVMBinaryFormat", + "-lLLVMCoroutines", + "-lLLVMBitstreamReader", + "-lLLVMBitReader", + "-lLLVMBitWriter", + "-lLLVMDebugInfoDWARF", + "-lLLVMInstrumentation", + "-lLLVMCFGuard", + "-lLLVMObjCARCOpts", + "-lLLVMipo", + "-lLLVMGlobalISel", + "-lLLVMExecutionEngine", + "-lLLVMFrontendDriver", + "-lLLVMFrontendHLSL", + "-lLLVMFrontendOpenMP", + "-lLLVMFrontendOffloading", + "-lLLVMSelectionDAG", + "-lLLVMProfileData", + "-lLLVMAnalysis", + "-lLLVMScalarOpts", + "-lLLVMCodeGenTypes", + "-lLLVMCodeGen", + "-lLLVMTargetParser", + "-lLLVMScalarOpts", + "-lLLVMTarget", + "-lLLVMTransformUtils", + "-lLLVMPasses", + "-lLLVMSupport", + "-lLLVMMCParser", + "-lLLVMMC", + "-lLLVMCore", + "-lLLVMAsmPrinter", + "-lLLVMAArch64Utils", + "-lLLVMAArch64Info", + "-lLLVMAArch64Desc", + "-lLLVMAArch64AsmParser", + "-lLLVMAArch64CodeGen", "-lclangFrontend", "-lclangBasic", "-lclangEdit", @@ -1383,6 +1442,8 @@ "-lclangParse", "-lclangAPINotes", "-lclangCodeGen", + "-rpath", + /usr/local/lib, ); SDKROOT = macosx; SYSTEM_HEADER_SEARCH_PATHS = ""; @@ -1441,7 +1502,7 @@ "EFIT_FILE=\\\"/Users/m4c/Projects/graph_framework/graph_tests/efit.nc\\\"", "VMEC_FILE=\\\"/Users/m4c/Projects/graph_framework/graph_tests/vmec.nc\\\"", USE_METAL, - "\"CXX_ARGS=\\\"-I/Users/m4c/Projects/graph_framework/graph_framework -std=gnu++2a\\\"\"", + "\"CXX_ARGS=\\\"-I/Users/m4c/Projects/graph_framework/graph_framework -I/usr/local/include -I/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/usr/include/c++/v1 -I/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/15.0.0/include 
-I/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/usr/include -I/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/include -I/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/System/Library/Frameworks -fgnuc-version=4.2.1 -std=gnu++2a\\\"\"", "$(inherited)", ); GCC_WARN_64_TO_32_BIT_CONVERSION = YES; @@ -1466,9 +1527,69 @@ OTHER_LDFLAGS = ( "-lnetcdf", "-ld_classic", - "-rpath", - /usr/local/lib, - "-lLLVM", + "-L/Users/m4c/Projects/graph_framework/build/_deps/llvm-build/lib", + "-lz", + "-lLLVMCoverage", + "-lLLVMSupport", + "-lLLVMDebugInfoCodeView", + "-lLLVMRemarks", + "-lLLVMJITLink", + "-lLLVMLinker", + "-lLLVMTextAPI", + "-lLLVMRuntimeDyld", + "-lLLVMOrcShared", + "-lLLVMOrcDebugging", + "-lLLVMOrcTargetProcess", + "-lLLVMOrcJIT", + "-lLLVMHipStdPar", + "-lLLVMAggressiveInstCombine", + "-lLLVMVectorize", + "-lLLVMAsmParser", + "-lLLVMOption", + "-lLLVMLTO", + "-lLLVMObject", + "-lLLVMWindowsDriver", + "-lLLVMDemangle", + "-lLLVMIRReader", + "-lLLVMIRPrinter", + "-lLLVMInstCombine", + "-lLLVMBinaryFormat", + "-lLLVMCoroutines", + "-lLLVMBitstreamReader", + "-lLLVMBitReader", + "-lLLVMBitWriter", + "-lLLVMDebugInfoDWARF", + "-lLLVMInstrumentation", + "-lLLVMCFGuard", + "-lLLVMObjCARCOpts", + "-lLLVMipo", + "-lLLVMGlobalISel", + "-lLLVMExecutionEngine", + "-lLLVMFrontendDriver", + "-lLLVMFrontendHLSL", + "-lLLVMFrontendOpenMP", + "-lLLVMFrontendOffloading", + "-lLLVMSelectionDAG", + "-lLLVMProfileData", + "-lLLVMAnalysis", + "-lLLVMScalarOpts", + "-lLLVMCodeGenTypes", + "-lLLVMCodeGen", + "-lLLVMTargetParser", + "-lLLVMScalarOpts", + "-lLLVMTarget", + "-lLLVMTransformUtils", + "-lLLVMPasses", + "-lLLVMSupport", + "-lLLVMMCParser", + "-lLLVMMC", + "-lLLVMCore", + "-lLLVMAsmPrinter", + "-lLLVMAArch64Utils", + "-lLLVMAArch64Info", + "-lLLVMAArch64Desc", + "-lLLVMAArch64AsmParser", + "-lLLVMAArch64CodeGen", "-lclangFrontend", "-lclangBasic", 
"-lclangEdit", @@ -1483,6 +1604,8 @@ "-lclangParse", "-lclangAPINotes", "-lclangCodeGen", + "-rpath", + /usr/local/lib, ); SDKROOT = macosx; SYSTEM_HEADER_SEARCH_PATHS = ""; @@ -1747,11 +1870,94 @@ GCC_PREPROCESSOR_DEFINITIONS = ( "EFIT_FILE=\\\"/Users/m4c/Projects/graph_framework/graph_tests/efit.nc\\\"", USE_METAL, - "CXX=\\\"c++\\\"", "DEBUG=1", "$(inherited)", ); MACOSX_DEPLOYMENT_TARGET = 13.3; + OTHER_LDFLAGS = ( + "-lnetcdf", + "-ld_classic", + "-L/Users/m4c/Projects/graph_framework/build/_deps/llvm-build/lib", + "-lz", + "-lLLVMCoverage", + "-lLLVMSupport", + "-lLLVMDebugInfoCodeView", + "-lLLVMRemarks", + "-lLLVMJITLink", + "-lLLVMLinker", + "-lLLVMTextAPI", + "-lLLVMRuntimeDyld", + "-lLLVMOrcShared", + "-lLLVMOrcTargetProcess", + "-lLLVMOrcDebugging", + "-lLLVMOrcTargetProcess", + "-lLLVMOrcJIT", + "-lLLVMHipStdPar", + "-lLLVMAggressiveInstCombine", + "-lLLVMVectorize", + "-lLLVMAsmParser", + "-lLLVMOption", + "-lLLVMLTO", + "-lLLVMObject", + "-lLLVMWindowsDriver", + "-lLLVMDemangle", + "-lLLVMIRReader", + "-lLLVMIRPrinter", + "-lLLVMInstCombine", + "-lLLVMBinaryFormat", + "-lLLVMCoroutines", + "-lLLVMBitstreamReader", + "-lLLVMBitReader", + "-lLLVMBitWriter", + "-lLLVMDebugInfoDWARF", + "-lLLVMInstrumentation", + "-lLLVMCFGuard", + "-lLLVMObjCARCOpts", + "-lLLVMipo", + "-lLLVMGlobalISel", + "-lLLVMExecutionEngine", + "-lLLVMFrontendDriver", + "-lLLVMFrontendHLSL", + "-lLLVMFrontendOpenMP", + "-lLLVMFrontendOffloading", + "-lLLVMSelectionDAG", + "-lLLVMProfileData", + "-lLLVMAnalysis", + "-lLLVMScalarOpts", + "-lLLVMCodeGenTypes", + "-lLLVMCodeGen", + "-lLLVMTargetParser", + "-lLLVMScalarOpts", + "-lLLVMTarget", + "-lLLVMTransformUtils", + "-lLLVMPasses", + "-lLLVMSupport", + "-lLLVMMCParser", + "-lLLVMMC", + "-lLLVMCore", + "-lLLVMAsmPrinter", + "-lLLVMAArch64Utils", + "-lLLVMAArch64Info", + "-lLLVMAArch64Desc", + "-lLLVMAArch64AsmParser", + "-lLLVMAArch64CodeGen", + "-lclangFrontend", + "-lclangBasic", + "-lclangEdit", + "-lclangLex", + "-lclangDriver", 
+ "-lclangSerialization", + "-lclangAST", + "-lclangSema", + "-lclangAnalysis", + "-lclangASTMatchers", + "-lclangSupport", + "-lclangParse", + "-lclangAPINotes", + "-lclangCodeGen", + "-rpath", + /usr/local/lib, + ); PRODUCT_NAME = "$(TARGET_NAME)"; }; name = Debug; @@ -1766,10 +1972,93 @@ GCC_PREPROCESSOR_DEFINITIONS = ( "EFIT_FILE=\\\"/Users/m4c/Projects/graph_framework/graph_tests/efit.nc\\\"", USE_METAL, - "CXX=\\\"c++\\\"", "$(inherited)", ); MACOSX_DEPLOYMENT_TARGET = 13.3; + OTHER_LDFLAGS = ( + "-lnetcdf", + "-ld_classic", + "-L/Users/m4c/Projects/graph_framework/build/_deps/llvm-build/lib", + "-lz", + "-lLLVMCoverage", + "-lLLVMSupport", + "-lLLVMDebugInfoCodeView", + "-lLLVMRemarks", + "-lLLVMJITLink", + "-lLLVMLinker", + "-lLLVMTextAPI", + "-lLLVMRuntimeDyld", + "-lLLVMOrcShared", + "-lLLVMOrcTargetProcess", + "-lLLVMOrcDebugging", + "-lLLVMOrcTargetProcess", + "-lLLVMOrcJIT", + "-lLLVMHipStdPar", + "-lLLVMAggressiveInstCombine", + "-lLLVMVectorize", + "-lLLVMAsmParser", + "-lLLVMOption", + "-lLLVMLTO", + "-lLLVMObject", + "-lLLVMWindowsDriver", + "-lLLVMDemangle", + "-lLLVMIRReader", + "-lLLVMIRPrinter", + "-lLLVMInstCombine", + "-lLLVMBinaryFormat", + "-lLLVMCoroutines", + "-lLLVMBitstreamReader", + "-lLLVMBitReader", + "-lLLVMBitWriter", + "-lLLVMDebugInfoDWARF", + "-lLLVMInstrumentation", + "-lLLVMCFGuard", + "-lLLVMObjCARCOpts", + "-lLLVMipo", + "-lLLVMGlobalISel", + "-lLLVMExecutionEngine", + "-lLLVMFrontendDriver", + "-lLLVMFrontendHLSL", + "-lLLVMFrontendOpenMP", + "-lLLVMFrontendOffloading", + "-lLLVMSelectionDAG", + "-lLLVMProfileData", + "-lLLVMAnalysis", + "-lLLVMScalarOpts", + "-lLLVMCodeGenTypes", + "-lLLVMCodeGen", + "-lLLVMTargetParser", + "-lLLVMScalarOpts", + "-lLLVMTarget", + "-lLLVMTransformUtils", + "-lLLVMPasses", + "-lLLVMSupport", + "-lLLVMMCParser", + "-lLLVMMC", + "-lLLVMCore", + "-lLLVMAsmPrinter", + "-lLLVMAArch64Utils", + "-lLLVMAArch64Info", + "-lLLVMAArch64Desc", + "-lLLVMAArch64AsmParser", + "-lLLVMAArch64CodeGen", + 
"-lclangFrontend", + "-lclangBasic", + "-lclangEdit", + "-lclangLex", + "-lclangDriver", + "-lclangSerialization", + "-lclangAST", + "-lclangSema", + "-lclangAnalysis", + "-lclangASTMatchers", + "-lclangSupport", + "-lclangParse", + "-lclangAPINotes", + "-lclangCodeGen", + "-rpath", + /usr/local/lib, + ); PRODUCT_NAME = "$(TARGET_NAME)"; }; name = Release; diff --git a/graph_framework.xcodeproj/xcshareddata/xcschemes/arithmetic_test.xcscheme b/graph_framework.xcodeproj/xcshareddata/xcschemes/arithmetic_test.xcscheme index 6ff6a04..42cdf87 100644 --- a/graph_framework.xcodeproj/xcshareddata/xcschemes/arithmetic_test.xcscheme +++ b/graph_framework.xcodeproj/xcshareddata/xcschemes/arithmetic_test.xcscheme @@ -1,6 +1,6 @@ + bool is_constant_combineable(shared_leaf a, + shared_leaf b) { + if (a->is_constant() && b->is_constant()) { + auto a1 = piecewise_1D_cast(a); + auto a2 = piecewise_2D_cast(a); + auto b2 = piecewise_2D_cast(b); + + return constant_cast(a).get() || + constant_cast(b).get() || + (a1.get() && a1->is_arg_match(b)) || + (a2.get() && a2->is_arg_match(b)) || + (a2.get() && (a2->is_row_match(b) || a2->is_col_match(b))) || + (b2.get() && (b2->is_row_match(a) || b2->is_col_match(a))); + } + return false; + } + +//------------------------------------------------------------------------------ +/// @brief Check if the constants are promotable. +/// +/// @tparam T Base type of the nodes. +/// @tparam SAFE_MATH Use safe math operations. +/// +/// @params[in] a Opperand A +/// @params[in] b Opperand B +/// @returns True if a is promoteable over b. 
+//------------------------------------------------------------------------------ + template + bool is_constant_promotable(shared_leaf a, + shared_leaf b) { + + auto b1 = piecewise_1D_cast(b); + auto b2 = piecewise_2D_cast(b); + + return a->is_constant() && + (!b->is_constant() || + (constant_cast(a).get() && (b1.get() || b2.get())) || + (piecewise_1D_cast(a).get() && b2.get())); + } + +//------------------------------------------------------------------------------ +/// @brief Check if the variable is combinable. +/// +/// @tparam T Base type of the nodes. +/// @tparam SAFE_MATH Use safe math operations. +/// +/// @params[in] a Opperand A +/// @params[in] b Opperand B +/// @returns True if a and b are combinable. +//------------------------------------------------------------------------------ + template + bool is_variable_combinable(shared_leaf a, + shared_leaf b) { + return a->get_power_base()->is_match(b->get_power_base()); + } + +//------------------------------------------------------------------------------ +/// @brief Check if the exponent is greater than the other. +/// +/// @tparam T Base type of the nodes. +/// @tparam SAFE_MATH Use safe math operations. +/// +/// @params[in] a Opperand A +/// @params[in] b Opperand B +/// @returns True if a and b are combinable. +//------------------------------------------------------------------------------ + template + bool is_greater_exponent(shared_leaf a, + shared_leaf b) { + auto ae = constant_cast(a->get_power_exponent()); + auto be = constant_cast(b->get_power_exponent()); + + return ae.get() && be.get() && + std::abs(ae->evaluate().at(0)) > std::abs(be->evaluate().at(0)); + } + //****************************************************************************** // Add node. //****************************************************************************** @@ -107,6 +194,37 @@ namespace graph { pr2->get_right()); } +// Combine 2D and 1D piecewise constants if a row or column matches. 
+ if (pr2.get() && pr2->is_row_match(this->left)) { + backend::buffer result = pl1->evaluate(); + result.add_row(pr2->evaluate()); + return piecewise_2D(result, + pr2->get_num_columns(), + pr2->get_left(), + pr2->get_right()); + } else if (pr2.get() && pr2->is_col_match(this->left)) { + backend::buffer result = pl1->evaluate(); + result.add_col(pr2->evaluate()); + return piecewise_2D(result, + pr2->get_num_columns(), + pr2->get_left(), + pr2->get_right()); + } else if (pl2.get() && pl2->is_row_match(this->right)) { + backend::buffer result = pl2->evaluate(); + result.add_row(pr1->evaluate()); + return piecewise_2D(result, + pl2->get_num_columns(), + pl2->get_left(), + pl2->get_right()); + } else if (pl2.get() && pl2->is_col_match(this->right)) { + backend::buffer result = pl2->evaluate(); + result.add_col(pr1->evaluate()); + return piecewise_2D(result, + pl2->get_num_columns(), + pl2->get_left(), + pl2->get_right()); + } + // Idenity reductions. if (this->left->is_match(this->right)) { return two ()*this->left; @@ -199,14 +317,40 @@ namespace graph { auto rfma = fma_cast(this->right); if (lfma.get()) { // fma(c,d,e) + a -> fma(c,d,e + a) - return fma(lfma->get_left(), lfma->get_middle(), + return fma(lfma->get_left(), + lfma->get_middle(), lfma->get_right() + this->right); } else if (rfma.get()) { // a + fma(c,d,e) -> fma(c,d,a + e) - return fma(rfma->get_left(), rfma->get_middle(), + return fma(rfma->get_left(), + rfma->get_middle(), this->left + rfma->get_right()); } +// fma(b,a,d) + fma(c,a,e) -> fma(a,b + c, d + e) +// fma(a,b,d) + fma(c,a,e) -> fma(a,b + c, d + e) +// fma(b,a,d) + fma(a,c,e) -> fma(a,b + c, d + e) +// fma(a,b,d) + fma(a,c,e) -> fma(a,b + c, d + e) + if (lfma.get() && rfma.get()) { + if (lfma->get_middle()->is_match(rfma->get_middle())) { + return fma(lfma->get_middle(), + lfma->get_left() + rfma->get_left(), + lfma->get_right() + rfma->get_right()); + } else if (lfma->get_left()->is_match(rfma->get_middle())) { + return fma(lfma->get_left(), + 
lfma->get_middle() + rfma->get_left(), + lfma->get_right() + rfma->get_right()); + } else if (lfma->get_middle()->is_match(rfma->get_left())) { + return fma(lfma->get_middle(), + lfma->get_left() + rfma->get_middle(), + lfma->get_right() + rfma->get_right()); + } else if (lfma->get_left()->is_match(rfma->get_left())) { + return fma(lfma->get_left(), + lfma->get_middle() + rfma->get_middle(), + lfma->get_right() + rfma->get_right()); + } + } + // Handle cases like: // (a/y)^e + b/y^e -> (a^2 + b)/(y^e) // b/y^e + (a/y)^e -> (b + a^2)/(y^e) @@ -564,6 +708,37 @@ namespace graph { pr2->get_right()); } +// Combine 2D and 1D piecewise constants if a row or column matches. + if (pr2.get() && pr2->is_row_match(this->left)) { + backend::buffer result = pl1->evaluate(); + result.subtract_row(pr2->evaluate()); + return piecewise_2D(result, + pr2->get_num_columns(), + pr2->get_left(), + pr2->get_right()); + } else if (pr2.get() && pr2->is_col_match(this->left)) { + backend::buffer result = pl1->evaluate(); + result.subtract_col(pr2->evaluate()); + return piecewise_2D(result, + pr2->get_num_columns(), + pr2->get_left(), + pr2->get_right()); + } else if (pl2.get() && pl2->is_row_match(this->right)) { + backend::buffer result = pl2->evaluate(); + result.subtract_row(pr1->evaluate()); + return piecewise_2D(result, + pl2->get_num_columns(), + pl2->get_left(), + pl2->get_right()); + } else if (pl2.get() && pl2->is_col_match(this->right)) { + backend::buffer result = pl2->evaluate(); + result.subtract_col(pr1->evaluate()); + return piecewise_2D(result, + pl2->get_num_columns(), + pl2->get_left(), + pl2->get_right()); + } + // Common factor reduction. If the left and right are both muliply nodes check // for a common factor. So you can change a*b - a*c -> a*(b - c). 
auto lm = multiply_cast(this->left); @@ -622,10 +797,11 @@ namespace graph { return lm->get_right()*(lm->get_left() - rm->get_left()); } -// Change cases like c1*a - c2*b -> c1*(a - c2*b) - auto lmc = constant_cast(lm->get_left()); - auto rmc = constant_cast(rm->get_left()); - if (lmc.get() && rmc.get()) { +// Change cases like c1*a - c2*b -> c1*(a - c2/c1*b) +// Note need to make sure c1 doesn't contain any zeros. + if (lm->get_left()->is_constant() && + rm->get_left()->is_constant() && + !lm->has_constant_zero()) { return lm->get_left()*(lm->get_right() - (rm->get_left()/lm->get_left())*rm->get_right()); } @@ -680,16 +856,20 @@ namespace graph { auto rmrd = divide_cast(rm->get_right()); if (lmld.get() && rmld.get() && lmld->get_right()->is_match(rmld->get_right())) { - return (lmld->get_left()*lm->get_right() - rmld->get_left()*rm->get_right())/lmld->get_right(); + return (lmld->get_left()*lm->get_right() - + rmld->get_left()*rm->get_right())/lmld->get_right(); } else if (lmld.get() && rmrd.get() && lmld->get_right()->is_match(rmrd->get_right())) { - return (lmld->get_left()*lm->get_right() - rmrd->get_left()*rm->get_left())/lmld->get_right(); + return (lmld->get_left()*lm->get_right() - + rmrd->get_left()*rm->get_left())/lmld->get_right(); } else if (lmrd.get() && rmld.get() && lmrd->get_right()->is_match(rmld->get_right())) { - return (lmrd->get_left()*lm->get_left() - rmld->get_left()*rm->get_right())/lmrd->get_right(); + return (lmrd->get_left()*lm->get_left() - + rmld->get_left()*rm->get_right())/lmrd->get_right(); } else if (lmrd.get() && rmrd.get() && lmrd->get_right()->is_match(rmrd->get_right())) { - return (lmrd->get_left()*lm->get_left() - rmrd->get_left()*rm->get_left())/lmrd->get_right(); + return (lmrd->get_left()*lm->get_left() - + rmrd->get_left()*rm->get_left())/lmrd->get_right(); } } @@ -701,19 +881,23 @@ namespace graph { if (lrm->get_left()->is_match(rm->get_left())) { // (a - c*b) - c*d -> a - (b + d)*c return ls->get_left() - - 
(lrm->get_right() + rm->get_right())*rm->get_left(); + (lrm->get_right() + + rm->get_right())*rm->get_left(); } else if (lrm->get_left()->is_match(rm->get_right())) { // (a - c*b) - d*c -> a - (b + d)*c return ls->get_left() - - (lrm->get_right() + rm->get_left())*rm->get_right(); + (lrm->get_right() + + rm->get_left())*rm->get_right(); } else if (lrm->get_right()->is_match(rm->get_left())) { // (a - c*b) - c*d -> a - (b + d)*c return ls->get_left() - - (lrm->get_left() + rm->get_right())*rm->get_left(); + (lrm->get_left() + + rm->get_right())*rm->get_left(); } else if (lrm->get_right()->is_match(rm->get_right())) { // (a - c*b) - d*c -> a - (b + d)*c return ls->get_left() - - (lrm->get_left() + rm->get_left())*rm->get_right(); + (lrm->get_left() + + rm->get_left())*rm->get_right(); } } } @@ -793,6 +977,13 @@ namespace graph { } } +// fma(c,d,e) - a -> fma(c,d,e - a) + if (lfma.get() && !this->right->is_all_variables()) { + return fma(lfma->get_left(), + lfma->get_middle(), + lfma->get_right() - this->right); + } + // Reduce cases chained subtract multiply divide. if (ls.get()) { // (a - b*c) - d*e -> a - (b*c + d*e) @@ -1111,20 +1302,44 @@ namespace graph { pr2->get_right()); } -// Move constants to the left. - if (r.get() && !l.get()) { - return this->right*this->left; +// Combine 2D and 1D piecewise constants if a row or column matches. 
+ if (pr2.get() && pr2->is_row_match(this->left)) { + backend::buffer result = pl1->evaluate(); + result.multiply_row(pr2->evaluate()); + return piecewise_2D(result, + pr2->get_num_columns(), + pr2->get_left(), + pr2->get_right()); + } else if (pr2.get() && pr2->is_col_match(this->left)) { + backend::buffer result = pl1->evaluate(); + result.multiply_col(pr2->evaluate()); + return piecewise_2D(result, + pr2->get_num_columns(), + pr2->get_left(), + pr2->get_right()); + } else if (pl2.get() && pl2->is_row_match(this->right)) { + backend::buffer result = pl2->evaluate(); + result.multiply_row(pr1->evaluate()); + return piecewise_2D(result, + pl2->get_num_columns(), + pl2->get_left(), + pl2->get_right()); + } else if (pl2.get() && pl2->is_col_match(this->right)) { + backend::buffer result = pl2->evaluate(); + result.multiply_col(pr1->evaluate()); + return piecewise_2D(result, + pl2->get_num_columns(), + pl2->get_left(), + pl2->get_right()); } -// Move piecewise constants to the left. - if ((pr1.get() || pr2.get()) && - (!pl1.get() && !pl2.get() && !l.get())) { +// Move constants to the left. + if (is_constant_promotable(this->right, this->left)) { return this->right*this->left; } // Move constant like to the left. - if (this->right->is_constant_like() && - !this->left->is_constant_like()) { + if (is_constant_promotable(this->right, this->left)) { return this->right*this->left; } @@ -1132,7 +1347,7 @@ namespace graph { // Disable if the left is a constant like to avoid an infinite loop. if (this->left->is_power_like() && !this->right->is_power_like() && - !this->left->is_constant_like()) { + !this->left->is_constant()) { return this->right*this->left; } @@ -1175,7 +1390,8 @@ namespace graph { // Promote constants before variables. 
// (c*v1)*v2 -> c*(v1*v2) - if (lm->get_left()->is_constant_like()) { + if (is_constant_promotable(lm->get_left(), + lm->get_right())) { return lm->get_left()*(lm->get_right()*this->right); } @@ -1198,7 +1414,8 @@ namespace graph { if (rm.get()) { // Assume constants are on the left. // c1*(c2*v) -> c3*v - if (constant_cast(rm->get_left()).get() && l.get()) { + if (is_constant_combineable(this->left, + rm->get_left())) { return (this->left*rm->get_left())*rm->get_right(); } @@ -1210,7 +1427,8 @@ namespace graph { } // v1*(c*v2) -> c*(v1*v2) - if (rm.get() && constant_cast(rm->get_left()).get()) { + if (rm.get() && + is_constant_promotable(rm->get_left(), this->left)) { return rm->get_left()*(this->left*rm->get_right()); } @@ -1228,27 +1446,27 @@ namespace graph { } else if (rm.get() && (sin_cast(rm->get_right()).get() || cos_cast(rm->get_right()).get()) && - !this->left->is_constant_like()) { + !this->left->is_constant()) { return (this->left*rm->get_left())*rm->get_right(); } // Factor out common constants c*b*c*d -> c*c*b*d. c*c will get reduced to c on // the second pass. 
if (lm.get() && rm.get()) { - if (constant_cast(lm->get_left()).get() && - constant_cast(rm->get_left()).get()) { + if (is_constant_combineable(lm->get_left(), + rm->get_left())) { return (lm->get_left()*rm->get_left()) * (lm->get_right()*rm->get_right()); - } else if (constant_cast(lm->get_left()).get() && - constant_cast(rm->get_right()).get()) { + } else if (is_constant_combineable(lm->get_left(), + rm->get_right())) { return (lm->get_left()*rm->get_right()) * (lm->get_right()*rm->get_left()); - } else if (constant_cast(lm->get_right()).get() && - constant_cast(rm->get_left()).get()) { + } else if (is_constant_combineable(lm->get_right(), + rm->get_left())) { return (lm->get_right()*rm->get_left()) * (lm->get_left()*rm->get_right()); - } else if (constant_cast(lm->get_right()).get() && - constant_cast(rm->get_right()).get()) { + } else if (is_constant_combineable(lm->get_right(), + rm->get_right())) { return (lm->get_right()*rm->get_right()) * (lm->get_left()*rm->get_left()); } @@ -1275,26 +1493,22 @@ namespace graph { if (ld.get()) { // (c/v1)*v2 -> c*(v2/v1) - if (constant_cast(ld->get_left()).get() || - piecewise_1D_cast(ld->get_left()).get() || - piecewise_2D_cast(ld->get_left()).get()) { + if (ld->get_left()->is_constant()) { return ld->get_left()*(this->right/ld->get_right()); } } // c1*(c2/v) -> c3/v -// c1*(v/c2) -> v/c3 - if (rd.get() && l.get()) { - if (constant_cast(rd->get_left()).get()) { - return (this->left*rd->get_left())/rd->get_right(); - } else if (constant_cast(rd->get_right()).get()) { - return rd->get_left()/(this->left*rd->get_right()); - } + if (rd.get() && this->left->is_constant() && + rd->get_left()->is_constant()) { + return (this->left*rd->get_left())/rd->get_right(); } +// (a/b)*(c/a) -> c/b +// (b/a)*(a/c) -> c/b if (ld.get() && rd.get()) { if (ld->get_left()->is_match(rd->get_right())) { - return ld->get_right()/rd->get_left(); + return rd->get_left()/ld->get_right(); } else if (ld->get_right()->is_match(rd->get_left())) { return 
ld->get_left()/rd->get_right(); } @@ -1860,18 +2074,44 @@ namespace graph { pr2->get_right()); } - if (this->left->is_match(this->right)) { - if (l.get() && l->is(1)) { - return this->left; - } +// Combine 2D and 1D piecewise constants if a row or column matches. + if (pr2.get() && pr2->is_row_match(this->left)) { + backend::buffer result = pl1->evaluate(); + result.divide_row(pr2->evaluate()); + return piecewise_2D(result, + pr2->get_num_columns(), + pr2->get_left(), + pr2->get_right()); + } else if (pr2.get() && pr2->is_col_match(this->left)) { + backend::buffer result = pl1->evaluate(); + result.divide_col(pr2->evaluate()); + return piecewise_2D(result, + pr2->get_num_columns(), + pr2->get_left(), + pr2->get_right()); + } else if (pl2.get() && pl2->is_row_match(this->right)) { + backend::buffer result = pl2->evaluate(); + result.divide_row(pr1->evaluate()); + return piecewise_2D(result, + pl2->get_num_columns(), + pl2->get_left(), + pl2->get_right()); + } else if (pl2.get() && pl2->is_col_match(this->right)) { + backend::buffer result = pl2->evaluate(); + result.divide_col(pr1->evaluate()); + return piecewise_2D(result, + pl2->get_num_columns(), + pl2->get_left(), + pl2->get_right()); + } + if (this->left->is_match(this->right)) { return one (); } // Reduce cases of a/c1 -> c2*a - if (r.get()) { - return (one ()/this->right) * - this->left; + if (this->right->is_constant()) { + return (one ()/this->right)*this->left; } // fma(a,d,c*d)/d -> a + c @@ -1902,35 +2142,46 @@ namespace graph { auto lm = multiply_cast(this->left); auto rm = multiply_cast(this->right); -// Assume constants are always on the left. 
// c1/(c2*v) -> c3/v -// (c1*v)/c2 -> c3*v - if (rm.get() && l.get()) { - if (constant_cast(rm->get_left()).get()) { +// c1/(c2*c3) -> c4/c3 + if (rm.get()) { + if (is_constant_combineable(rm->get_left(), + this->left)) { return (this->left/rm->get_left())/rm->get_right(); - } - } else if (lm.get() && r.get()) { - if (constant_cast(lm->get_left()).get()) { - return (lm->get_left()/this->right)*lm->get_right(); + } else if (is_constant_combineable(rm->get_left(), + this->left)) { + return (this->left/rm->get_right())/rm->get_left(); } } if (lm.get() && rm.get()) { // Test for constants that can be reduced out. - if (constant_cast(lm->get_left()).get() && - constant_cast(rm->get_left()).get()) { - return (lm->get_left()/rm->get_left())*(lm->get_right()/rm->get_right()); - } else if (constant_cast(lm->get_left()).get() && - constant_cast(rm->get_right()).get()) { - return (lm->get_left()/rm->get_right())*(lm->get_right()/rm->get_left()); - } else if (constant_cast(lm->get_right()).get() && - constant_cast(rm->get_left()).get()) { - return (lm->get_right()/rm->get_left())*(lm->get_left()/rm->get_right()); - } else if (constant_cast(lm->get_right()).get() && - constant_cast(rm->get_right()).get()) { - return (lm->get_right()/rm->get_right())*(lm->get_left()/rm->get_left()); - } - +// (c1*a)/(c2*b) -> c3*a/b +// (a*c1)/(c2*b) -> c3*a/b +// (c1*a)/(b*c2) -> c3*a/b +// (a*c1)/(b*c2) -> c3*a/b + if (is_constant_combineable(lm->get_left(), + rm->get_left())) { + return (lm->get_left()/rm->get_left()) * + (lm->get_right()/rm->get_right()); + } else if (is_constant_combineable(lm->get_left(), + rm->get_right())) { + return (lm->get_left()/rm->get_right()) * + (lm->get_right()/rm->get_left()); + } else if (is_constant_combineable(lm->get_right(), + rm->get_left())) { + return (lm->get_right()/rm->get_left()) * + (lm->get_left()/rm->get_right()); + } else if (is_constant_combineable(lm->get_right(), + rm->get_right())) { + return (lm->get_right()/rm->get_right()) * + 
(lm->get_left()/rm->get_left()); + } + +// (a*b)/(a*c) -> b/c +// (b*a)/(a*c) -> b/c +// (a*b)/(c*a) -> b/c +// (b*a)/(c*a) -> b/c if (lm->get_left()->is_match(rm->get_left())) { return lm->get_right()/rm->get_right(); } else if (lm->get_left()->is_match(rm->get_right())) { @@ -1975,11 +2226,8 @@ namespace graph { } // (c*v1)/v2 -> c*(v1/v2) - if (lm.get() && constant_cast(lm->get_left()).get()) { - return lm->get_left()*(lm->get_right()/this->right); - } - - if (lm.get() && lm->get_left()->is_constant_like()) { + if (lm.get() && lm->get_left()->is_constant() && + !lm->get_right()->is_constant()) { return lm->get_left()*(lm->get_right()/this->right); } @@ -2385,16 +2633,25 @@ namespace graph { return this->left + this->right; } - auto pl1 = piecewise_1D_cast(this->left); - auto pm1 = piecewise_1D_cast(this->middle); - auto pl2 = piecewise_2D_cast(this->left); - auto pm2 = piecewise_2D_cast(this->middle); +// Check if the left and middle are combinable. This will be constant merged in +// multiply reduction. + if (is_constant_combineable(this->left, this->middle) || + is_variable_combinable(this->left, this->middle)) { + return (this->left*this->middle) + this->right; + } - if ((pl1.get() && (m.get() || pl1->is_arg_match(this->middle))) || - (pm1.get() && (l.get() || pm1->is_arg_match(this->left))) || - (pl2.get() && (m.get() || pl2->is_arg_match(this->middle))) || - (pm2.get() && (l.get() || pm2->is_arg_match(this->left)))) { - return (this->left*this->middle) + this->right; +// fma(c2,c1,a) -> fma(c1,c2,a) + if (is_constant_promotable(this->middle, + this->left)) { + return fma(this->middle, this->left, this->right); + } + +// fma(a,b,a) -> a*(1 + b) +// fma(b,a,a) -> a*(1 + b) + if (this->left->is_match(this->right)) { + return this->left*(one () + this->middle); + } else if (this->middle->is_match(this->right)) { + return this->middle*(one () + this->left); } // Common factor reduction. 
If the left and right are both multiply nodes check @@ -2413,12 +2670,35 @@ namespace graph { return this->middle*(this->left + rm->get_left()); } -// Change cases like c1*a + c2*b -> c1*(c3*b + a) - auto rmc = constant_cast(rm->get_left()); - if (rmc.get() && l.get()) { +// Change cases like +// fma(c1,a,c2*b) -> c1*fma(c3,b,a) +// fma(a,c1,c2*b) -> c1*fma(c3,b,a) +// fma(c1,a,b*c2) -> c1*fma(c3,b,a) +// fma(a,c1,b*c2) -> c1*fma(c3,b,a) + if (is_constant_combineable(this->left, + rm->get_left()) && + !this->left->has_constant_zero()) { return this->left*fma(rm->get_left()/this->left, rm->get_right(), this->middle); + } else if (is_constant_combineable(this->middle, + rm->get_left()) && + !this->middle->has_constant_zero()) { + return this->middle*fma(rm->get_left()/this->middle, + rm->get_right(), + this->left); + } else if (is_constant_combineable(this->left, + rm->get_right()) && + !this->left->has_constant_zero()) { + return this->left*fma(rm->get_right()/this->left, + rm->get_left(), + this->middle); + } else if (is_constant_combineable(this->middle, + rm->get_right()) && + !this->middle->has_constant_zero()) { + return this->middle*fma(rm->get_right()/this->middle, + rm->get_left(), + this->left); } // Convert fma(a*b,c,d*e) -> fma(d,e,a*b*c) @@ -2433,48 +2713,82 @@ namespace graph { // Handle cases like. 
// fma(c1*a,b,c2*d) -> c1*(a*b + c2/c1*d) +// fma(a*c1,b,c2*d) -> c1*(a*b + c2/c1*d) +// fma(c1*a,b,d*c2*d) -> c1*(a*b + c2/c1*d) +// fma(a*c1,b,d*c2*d) -> c1*(a*b + c2/c1*d) if (lm.get() && rm.get()) { - auto rmc = constant_cast(rm->get_left()); - if (rmc.get()) { + if (is_constant_combineable(rm->get_left(), + lm->get_left()) && + !lm->get_left()->has_constant_zero()) { return lm->get_left()*fma(lm->get_right(), this->middle, (rm->get_left()/lm->get_left())*rm->get_right()); + } else if (is_constant_combineable(rm->get_left(), + lm->get_right()) && + !lm->get_right()->has_constant_zero()) { + return lm->get_right()*fma(lm->get_left(), + this->middle, + (rm->get_left()/lm->get_right())*rm->get_right()); + } else if (is_constant_combineable(rm->get_right(), + lm->get_left()) && + !lm->get_left()->has_constant_zero()) { + return lm->get_left()*fma(lm->get_right(), + this->middle, + (rm->get_right()/lm->get_left())*rm->get_left()); + } else if (is_constant_combineable(rm->get_right(), + lm->get_right()) && + !lm->get_right()->has_constant_zero()) { + return lm->get_right()*fma(lm->get_left(), + this->middle, + (rm->get_right()/lm->get_right())*rm->get_left()); } } // Move constant multiplies to the left. 
if (lm.get()) { - auto lmc = constant_cast(lm->get_left()); - if (lmc.get()) { +// fma(c1*a,b,c) -> fma(c1,a*b,c) + if (is_constant_promotable(lm->get_left(), + lm->get_right())) { return fma(lm->get_left(), lm->get_right()*this->middle, this->right); } } else if (mm.get()) { - auto mmc = constant_cast(mm->get_left()); - auto mmpw1c = piecewise_1D_cast(mm->get_left()); - auto mmpw2c = piecewise_2D_cast(mm->get_left()); - if (mmc.get() || mmpw1c.get() || mmpw2c.get()) { - if (l.get() || pl1.get() || pl2.get()) { - return fma(this->left*mm->get_left(), - mm->get_right(), - this->right); - } else { - return fma(mm->get_left(), - this->left*mm->get_right(), - this->right); - } +// fma(c1,c2*a,b) -> fma(c3,a,b) +// fma(c1,a*c2,b) -> fma(c3,a,b) +// fma(a,c1*b,c) -> fma(c1,a*b,c) + if (is_constant_combineable(this->left, + mm->get_left())) { + return fma(this->left*mm->get_left(), + mm->get_right(), + this->right); + } else if (is_constant_combineable(this->left, + mm->get_right())) { + return fma(this->left*mm->get_right(), + mm->get_left(), + this->right); + } else if (is_constant_promotable(mm->get_left(), + this->left)) { + return fma(mm->get_left(), + this->left*mm->get_right(), + this->right); } } -// fma(c1,a,c2/b) -> c1*(a + c1/(c2*b)) -// fma(c1,a,b/c2) -> c1*(a + b/(c1*c2)) +// fma(c1,a,c2/b) -> c1*(a + c3/b) +// fma(a,c1,c2/b) -> c1*(a + c3/b) auto rd = divide_cast(this->right); - if (l.get() && rd.get()) { - if (constant_cast(rd->get_left()).get() || - constant_cast(rd->get_right()).get()) { + if (rd.get()) { + if (is_constant_combineable(this->left, + rd->get_left()) && + !this->left->has_constant_zero()) { return this->left*(this->middle + rd->get_left()/(this->left*rd->get_right())); + } else if (is_constant_combineable(this->middle, + rd->get_left()) && + !this->middle->has_constant_zero()) { + return this->middle*(this->left + + rd->get_left()/(this->middle*rd->get_right())); } } @@ -2511,6 +2825,342 @@ namespace graph { // Chained fma reductions. 
auto rfma = fma_cast(this->right); if (rfma.get()) { +// fma(a, b, fma(c, b, d)) -> fma(b, a + c, d) +// fma(b, a, fma(c, b, d)) -> fma(b, a + c, d) +// fma(a, b, fma(b, c, d)) -> fma(b, a + c, d) +// fma(b, a, fma(b, c, d)) -> fma(b, a + c, d) + if (this->middle->is_match(rfma->get_middle())) { + return fma(this->middle, + this->left + rfma->get_left(), + rfma->get_right()); + } else if (this->left->is_match(rfma->get_middle())) { + return fma(this->left, + this->middle + rfma->get_left(), + rfma->get_right()); + } else if (this->middle->is_match(rfma->get_left())) { + return fma(this->middle, + this->left + rfma->get_middle(), + rfma->get_right()); + } else if (this->left->is_match(rfma->get_left())) { + return fma(this->left, + this->middle + rfma->get_middle(), + rfma->get_right()); + } + + if (mm.get()) { +// fma(a, e*b, fma(c, b, d)) -> fma(b, fma(a, e, c), d) +// fma(a, b*e, fma(c, b, d)) -> fma(b, fma(a, e, c), d) +// fma(a, e*b, fma(b, c, d)) -> fma(b, fma(a, e, c), d) +// fma(a, b*e, fma(b, c, d)) -> fma(b, fma(a, e, c), d) + if (mm->get_right()->is_match(rfma->get_middle())) { + return fma(mm->get_right(), + fma(this->left, + mm->get_left(), + rfma->get_left()), + rfma->get_right()); + } else if (mm->get_left()->is_match(rfma->get_middle())) { + return fma(mm->get_left(), + fma(this->left, + mm->get_right(), + rfma->get_left()), + rfma->get_right()); + } else if (mm->get_right()->is_match(rfma->get_left())) { + return fma(mm->get_right(), + fma(this->left, + mm->get_left(), + rfma->get_middle()), + rfma->get_right()); + } else if (mm->get_left()->is_match(rfma->get_left())) { + return fma(mm->get_left(), + fma(this->left, + mm->get_right(), + rfma->get_middle()), + rfma->get_right()); + } + } else if (lm.get()) { +// fma(e*b, a, fma(c, b, d)) -> fma(b, fma(a, e, c), d) +// fma(b*e, a, fma(c, b, d)) -> fma(b, fma(a, e, c), d) +// fma(e*b, a, fma(b, c, d)) -> fma(b, fma(a, e, c), d) +// fma(e*d, a, fma(b, c, d)) -> fma(b, fma(a, e, c), d) + if 
(lm->get_right()->is_match(rfma->get_middle())) { + return fma(lm->get_right(), + fma(this->middle, + lm->get_left(), + rfma->get_left()), + rfma->get_right()); + } else if (lm->get_left()->is_match(rfma->get_middle())) { + return fma(lm->get_left(), + fma(this->middle, + lm->get_right(), + rfma->get_left()), + rfma->get_right()); + } else if (lm->get_right()->is_match(rfma->get_left())) { + return fma(lm->get_right(), + fma(this->middle, + lm->get_left(), + rfma->get_middle()), + rfma->get_right()); + } else if (lm->get_left()->is_match(rfma->get_left())) { + return fma(lm->get_left(), + fma(this->middle, + lm->get_right(), + rfma->get_middle()), + rfma->get_right()); + } + } + + auto rfmamm = multiply_cast(rfma->get_middle()); + auto rfmalm = multiply_cast(rfma->get_left()); + if (rfmamm.get()) { +// fma(a, b, fma(c, e*b, d)) -> fma(b, fma(c, e, a), d) +// fma(b, a, fma(c, e*b, d)) -> fma(b, fma(c, e, a), d) +// fma(a, b, fma(c, b*e, d)) -> fma(b, fma(c, e, a), d) +// fma(b, a, fma(c, b*e, d)) -> fma(b, fma(c, e, a), d) + if (rfmamm->get_right()->is_match(this->middle)) { + return fma(this->middle, + fma(rfma->get_left(), + rfmamm->get_left(), + this->left), + rfma->get_right()); + } else if (rfmamm->get_right()->is_match(this->left)) { + return fma(this->left, + fma(rfma->get_left(), + rfmamm->get_left(), + this->middle), + rfma->get_right()); + } else if (rfmamm->get_left()->is_match(this->middle)) { + return fma(this->middle, + fma(rfma->get_left(), + rfmamm->get_right(), + this->left), + rfma->get_right()); + } else if (rfmamm->get_left()->is_match(this->left)) { + return fma(this->left, + fma(rfma->get_left(), + rfmamm->get_right(), + this->middle), + rfma->get_right()); + } + } else if (rfmalm.get()) { +// fma(a, b, fma(e*b, c, d)) -> fma(b, fma(c, e, a), d) +// fma(b, a, fma(e*b, c, d)) -> fma(b, fma(c, e, a), d) +// fma(a, b, fma(b*e, c, d)) -> fma(b, fma(c, e, a), d) +// fma(b, a, fma(b*e, c, d)) -> fma(b, fma(c, e, a), d) + if 
(rfmalm->get_right()->is_match(this->middle)) { + return fma(this->middle, + fma(rfma->get_middle(), + rfmalm->get_left(), + this->left), + rfma->get_right()); + } else if (rfmalm->get_right()->is_match(this->left)) { + return fma(this->left, + fma(rfma->get_middle(), + rfmalm->get_left(), + this->middle), + rfma->get_right()); + } else if (rfmalm->get_left()->is_match(this->middle)) { + return fma(this->middle, + fma(rfma->get_middle(), + rfmalm->get_right(), + this->left), + rfma->get_right()); + } else if (rfmalm->get_left()->is_match(this->left)) { + return fma(this->left, + fma(rfma->get_middle(), + rfmalm->get_right(), + this->middle), + rfma->get_right()); + } + } + + if (mm.get() && rfmamm.get()) { +// fma(a, f*b, fma(c, e*b, d)) -> fma(b, fma(a, f, c*e), d) +// fma(a, b*f, fma(c, e*b, d)) -> fma(b, fma(a, f, c*e), d) +// fma(a, f*b, fma(c, b*e, d)) -> fma(b, fma(a, f, c*e), d) +// fma(a, b*f, fma(c, b*e, d)) -> fma(b, fma(a, f, c*e), d) + if (mm->get_right()->is_match(rfmamm->get_right())) { + return fma(mm->get_right(), + fma(this->left, + mm->get_left(), + rfma->get_left()*rfmamm->get_left()), + rfma->get_right()); + } else if (mm->get_left()->is_match(rfmamm->get_right())) { + return fma(mm->get_left(), + fma(this->left, + mm->get_right(), + rfma->get_left()*rfmamm->get_left()), + rfma->get_right()); + } else if (mm->get_right()->is_match(rfmamm->get_left())) { + return fma(mm->get_right(), + fma(this->left, + mm->get_left(), + rfma->get_left()*rfmamm->get_right()), + rfma->get_right()); + } else if (mm->get_left()->is_match(rfmamm->get_left())) { + return fma(mm->get_left(), + fma(this->left, + mm->get_right(), + rfma->get_left()*rfmamm->get_right()), + rfma->get_right()); + } + } else if (lm.get() && rfmamm.get()) { +// fma(f*b, a, fma(c, e*b, d)) -> fma(b, fma(a, f, c*e), d) +// fma(b*f, a, fma(c, e*b, d)) -> fma(b, fma(a, f, c*e), d) +// fma(f*b, a, fma(c, b*e, d)) -> fma(b, fma(a, f, c*e), d) +// fma(b*f, a, fma(c, b*e, d)) -> fma(b, fma(a, f, 
c*e), d) + if (lm->get_right()->is_match(rfmamm->get_right())) { + return fma(lm->get_right(), + fma(this->middle, + lm->get_left(), + rfma->get_left()*rfmamm->get_left()), + rfma->get_right()); + } else if (lm->get_left()->is_match(rfmamm->get_right())) { + return fma(lm->get_left(), + fma(this->middle, + lm->get_right(), + rfma->get_left()*rfmamm->get_left()), + rfma->get_right()); + } else if (lm->get_right()->is_match(rfmamm->get_left())) { + return fma(lm->get_right(), + fma(this->middle, + lm->get_left(), + rfma->get_left()*rfmamm->get_right()), + rfma->get_right()); + } else if (lm->get_left()->is_match(rfmamm->get_left())) { + return fma(lm->get_left(), + fma(this->middle, + lm->get_right(), + rfma->get_left()*rfmamm->get_right()), + rfma->get_right()); + } + } else if (mm.get() && rfmalm.get()) { +// fma(a, f*b, fma(e*b, c, d)) -> fma(b, fma(a, f, c*e), d) +// fma(a, b*f, fma(e*b, c, d)) -> fma(b, fma(a, f, c*e), d) +// fma(a, f*b, fma(b*e, c, d)) -> fma(b, fma(a, f, c*e), d) +// fma(a, b*f, fma(b*e, c, d)) -> fma(b, fma(a, f, c*e), d) + if (mm->get_right()->is_match(rfmalm->get_right())) { + return fma(mm->get_right(), + fma(this->left, + mm->get_left(), + rfma->get_middle()*rfmalm->get_left()), + rfma->get_right()); + } else if (mm->get_left()->is_match(rfmalm->get_right())) { + return fma(mm->get_left(), + fma(this->left, + mm->get_right(), + rfma->get_middle()*rfmalm->get_left()), + rfma->get_right()); + } else if (mm->get_right()->is_match(rfmalm->get_left())) { + return fma(mm->get_right(), + fma(this->left, + mm->get_left(), + rfma->get_middle()*rfmalm->get_right()), + rfma->get_right()); + } else if (mm->get_left()->is_match(rfmalm->get_left())) { + return fma(mm->get_left(), + fma(this->left, + mm->get_right(), + rfma->get_middle()*rfmalm->get_right()), + rfma->get_right()); + } + } else if (lm.get() && rfmalm.get()) { +// fma(f*b, a, fma(e*b, c, d)) -> fma(b, fma(a, f, c*e), d) +// fma(b*f, a, fma(e*b, c, d)) -> fma(b, fma(a, f, c*e), d) +// 
fma(f*b, a, fma(b*e, c, d)) -> fma(b, fma(a, f, c*e), d) +// fma(b*f, a, fma(b*e, c, d)) -> fma(b, fma(a, f, c*e), d) + if (lm->get_right()->is_match(rfmalm->get_right())) { + return fma(lm->get_right(), + fma(this->middle, + lm->get_left(), + rfma->get_middle()*rfmalm->get_left()), + rfma->get_right()); + } else if (lm->get_left()->is_match(rfmalm->get_right())) { + return fma(lm->get_left(), + fma(this->middle, + lm->get_right(), + rfma->get_middle()*rfmalm->get_left()), + rfma->get_right()); + } else if (lm->get_right()->is_match(rfmalm->get_left())) { + return fma(lm->get_right(), + fma(this->middle, + lm->get_left(), + rfma->get_middle()*rfmalm->get_right()), + rfma->get_right()); + } else if (lm->get_left()->is_match(rfmalm->get_left())) { + return fma(lm->get_left(), + fma(this->middle, + lm->get_right(), + rfma->get_middle()*rfmalm->get_right()), + rfma->get_right()); + } + } + + if (is_variable_combinable(this->middle, rfma->get_middle())) { + if (is_greater_exponent(this->middle, rfma->get_middle())) { +// fma(a,x^b,fma(c,x^d,e)) -> fma(x^d,fma(x^(d-b),a,c),e) if b > d + return fma(rfma->get_middle(), + fma(this->middle/rfma->get_middle(), + this->left, + rfma->get_left()), + rfma->get_right()); + } else { +// fma(a,x^b,fma(c,x^d,e)) -> fma(x^b,fma(x^(d-b),c,a),e) if d > b + return fma(this->middle, + fma(rfma->get_middle()/this->middle, + rfma->get_left(), + this->left), + rfma->get_right()); + } + } else if (is_variable_combinable(this->left, rfma->get_middle())) { + if (is_greater_exponent(this->left, rfma->get_middle())) { +// fma(x^b,a,fma(c,x^d,e)) -> fma(x^d,fma(x^(d-b),a,c),e) if b > d + return fma(rfma->get_middle(), + fma(this->left/rfma->get_middle(), + this->middle, + rfma->get_left()), + rfma->get_right()); + } else { +// fma(x^b,a,fma(c,x^d,e)) -> fma(x^b,fma(x^(d-b),c,a),e) if d > b + return fma(this->left, + fma(rfma->get_middle()/this->left, + rfma->get_left(), + this->middle), + rfma->get_right()); + } + } else if 
(is_variable_combinable(this->middle, rfma->get_left())) { + if (is_greater_exponent(this->middle, rfma->get_left())) { +// fma(a,x^b,fma(x^d,c,e)) -> fma(x^d,fma(x^(d-b),a,c),e) if b > d + return fma(rfma->get_left(), + fma(this->middle/rfma->get_left(), + this->left, + rfma->get_middle()), + rfma->get_right()); + } else { +// fma(a,x^b,fma(x^d,c,e)) -> fma(x^b,fma(x^(d-b),c,a),e) if d > b + return fma(this->middle, + fma(rfma->get_left()/this->middle, + rfma->get_middle(), + this->left), + rfma->get_right()); + } + } else if (is_variable_combinable(this->left, rfma->get_left())) { + if (is_greater_exponent(this->left, rfma->get_left())) { +// fma(x^b,a,fma(x^d,c,e)) -> fma(x^d,fma(x^(d-b),a,c),e) if b > d + return fma(rfma->get_left(), + fma(this->left/rfma->get_left(), + this->middle, + rfma->get_middle()), + rfma->get_right()); + } else { +// fma(x^b,a,fma(x^d,c,e)) -> fma(x^b,fma(x^(d-b),c,a),e) if d > b + return fma(this->left, + fma(rfma->get_left()/this->left, + rfma->get_middle(), + this->middle), + rfma->get_right()); + } + } + // fma(a,b,fma(a,b,c)) -> fma(2*a,b,c) // fma(a,b,fma(b,a,c)) -> fma(2*a,b,c) if (this->left->is_match(rfma->get_left()) && @@ -2553,7 +3203,7 @@ namespace graph { } // Check to see if it is worth moving nodes out of a fma nodes. These should be -// restricted to variable like nodes. Only do this reduction is the complexity +// restricted to variable like nodes. Only do this reduction if the complexity // reduces. if (this->left->is_all_variables()) { auto rdl = this->right/this->left; @@ -2570,12 +3220,12 @@ namespace graph { } // Promote constants out to the left. - if (l.get() && r.get()) { + if (is_constant_combineable(this->left, this->right) && + !this->left->has_constant_zero()) { return this->left*(this->middle + this->right/this->left); } - -// Change negative eponents to divide so that can be factored out. +// Change negative exponents to divide so that can be factored out. 
// fma(a,b^-c,d) = a/b^c + d
// fma(b^-c,a,d) = a/b^c + d
 auto lp = pow_cast(this->left);
diff --git a/graph_framework/backend.hpp b/graph_framework/backend.hpp
index e90dcda..e3fff57 100644
--- a/graph_framework/backend.hpp
+++ b/graph_framework/backend.hpp
@@ -10,6 +10,7 @@
 #include
 #include
+#include
 #include "special_functions.hpp"
 #include "register.hpp"
@@ -161,7 +162,22 @@ namespace backend {
 return true;
 }
+
+//------------------------------------------------------------------------------
+/// @brief Is any element zero.
+///
+/// @returns Returns true if any element is zero.
+//------------------------------------------------------------------------------
+ bool has_zero() const {
+ for (T d : memory) {
+ if (d == static_cast (0.0)) {
+ return true;
+ }
+ }
+ return false;
+ }
+
 //------------------------------------------------------------------------------
 /// @brief Is every element negative.
 ///
@@ -256,6 +272,475 @@ namespace backend {
 return memory.data();
 }
 
+//------------------------------------------------------------------------------
+/// @brief Check for normal values.
+///
+/// @returns False if any NaN or Inf is found.
+//------------------------------------------------------------------------------
+ bool is_normal() const {
+ for (T x : memory) {
+ if constexpr (jit::is_complex ()) {
+ if (std::isnan(std::real(x)) || std::isinf(std::real(x)) ||
+ std::isnan(std::imag(x)) || std::isinf(std::imag(x))) {
+ return false;
+ }
+ } else {
+ if (std::isnan(x) || std::isinf(x)) {
+ return false;
+ }
+ }
+ }
+ return true;
+ }
+
+//------------------------------------------------------------------------------
+/// @brief Add row operation.
+///
+/// Adds m_ij + v_i or v_i + m_ij. This will resize the buffer if it needs to
+/// be.
+///
+/// @params[in] x The right operand. 
+//------------------------------------------------------------------------------ + void add_row(const buffer &x) { + if (size() > x.size()) { + assert(size()%x.size() == 0 && + "Vector operand size is not a multiple of matrix operand size"); + + const size_t num_colmns = size()/x.size(); + const size_t num_rows = x.size(); + for (size_t i = 0; i < num_colmns; i++) { + for (size_t j = 0; j < num_rows; j++) { + memory[i*num_rows + j] += x[j]; + } + } + } else { + assert(x.size()%size() == 0 && + "Vector operand size is not a multiple of matrix operand size"); + + std::vector m(x.size()); + const size_t num_colmns = x.size()/size(); + const size_t num_rows = size(); + for (size_t i = 0; i < num_colmns; i++) { + for (size_t j = 0; j < num_rows; j++) { + m[i*num_rows + j] = memory[j] + x[i*num_rows + j]; + } + } + memory = m; + } + } + +//------------------------------------------------------------------------------ +/// @brief Add col operation. +/// +/// Adds m_ij + v_j or v_j + m_ij. This will resize the buffer if it needs to +/// be. +/// +/// @params[in] x The other operand. 
+//------------------------------------------------------------------------------
+ void add_col(const buffer &x) {
+ if (size() > x.size()) {
+ assert(size()%x.size() == 0 &&
+ "Vector operand size is not a multiple of matrix operand size");
+
+ const size_t num_colmns = size()/x.size();
+ const size_t num_rows = x.size();
+ for (size_t i = 0; i < num_colmns; i++) {
+ for (size_t j = 0; j < num_rows; j++) {
+ memory[i*num_rows + j] += x[i];
+ }
+ }
+ } else {
+ assert(x.size()%size() == 0 &&
+ "Vector operand size is not a multiple of matrix operand size");
+
+ std::vector m(x.size());
+ const size_t num_colmns = x.size()/size();
+ const size_t num_rows = size();
+ for (size_t i = 0; i < num_colmns; i++) {
+ for (size_t j = 0; j < num_rows; j++) {
+ m[i*num_rows + j] = memory[i] + x[i*num_rows + j];
+ }
+ }
+ memory = m;
+ }
+ }
+
+//------------------------------------------------------------------------------
+/// @brief Subtract row operation.
+///
+/// Subtracts m_ij - v_i or v_i - m_ij. This will resize the buffer if it
+/// needs to be.
+///
+/// @params[in] x The right operand. 
+//------------------------------------------------------------------------------
+ void subtract_row(const buffer &x) {
+ if (size() > x.size()) {
+ assert(size()%x.size() == 0 &&
+ "Vector operand size is not a multiple of matrix operand size");
+
+ const size_t num_colmns = size()/x.size();
+ const size_t num_rows = x.size();
+ for (size_t i = 0; i < num_colmns; i++) {
+ for (size_t j = 0; j < num_rows; j++) {
+ memory[i*num_rows + j] -= x[j];
+ }
+ }
+ } else {
+ assert(x.size()%size() == 0 &&
+ "Vector operand size is not a multiple of matrix operand size");
+
+ std::vector m(x.size());
+ const size_t num_colmns = x.size()/size();
+ const size_t num_rows = size();
+ for (size_t i = 0; i < num_colmns; i++) {
+ for (size_t j = 0; j < num_rows; j++) {
+ m[i*num_rows + j] = memory[j] - x[i*num_rows + j];
+ }
+ }
+ memory = m;
+ }
+ }
+
+//------------------------------------------------------------------------------
+/// @brief Subtract col operation.
+///
+/// Subtracts m_ij - v_j or v_j - m_ij. This will resize the buffer if it
+/// needs to be.
+///
+/// @params[in] x The other operand. 
+//------------------------------------------------------------------------------ + void subtract_col(const buffer &x) { + if (size() > x.size()) { + assert(size()%x.size() == 0 && + "Vector operand size is not a multiple of matrix operand size"); + + const size_t num_colmns = size()/x.size(); + const size_t num_rows = x.size(); + for (size_t i = 0; i < num_colmns; i++) { + for (size_t j = 0; j < num_rows; j++) { + memory[i*num_rows + j] -= x[i]; + } + } + } else { + assert(x.size()%size() == 0 && + "Vector operand size is not a multiple of matrix operand size"); + + std::vector m(x.size()); + const size_t num_colmns = x.size()/size(); + const size_t num_rows = size(); + for (size_t i = 0; i < num_colmns; i++) { + for (size_t j = 0; j < num_rows; j++) { + m[i*num_rows + j] = memory[i] - x[i*num_rows + j]; + } + } + memory = m; + } + } + +//------------------------------------------------------------------------------ +/// @brief Multiply row operation. +/// +/// Multiplies m_ij * v_i or v_i * m_ij. This will resize the buffer if it +/// needs to be. +/// +/// @params[in] x The right operand. 
+//------------------------------------------------------------------------------ + void multiply_row(const buffer &x) { + if (size() > x.size()) { + assert(size()%x.size() == 0 && + "Vector operand size is not a multiple of matrix operand size"); + + const size_t num_colmns = size()/x.size(); + const size_t num_rows = x.size(); + for (size_t i = 0; i < num_colmns; i++) { + for (size_t j = 0; j < num_rows; j++) { + memory[i*num_rows + j] *= x[j]; + } + } + } else { + assert(x.size()%size() == 0 && + "Vector operand size is not a multiple of matrix operand size"); + + std::vector m(x.size()); + const size_t num_colmns = x.size()/size(); + const size_t num_rows = size(); + for (size_t i = 0; i < num_colmns; i++) { + for (size_t j = 0; j < num_rows; j++) { + m[i*num_rows + j] = memory[j]*x[i*num_rows + j]; + } + } + memory = m; + } + } + +//------------------------------------------------------------------------------ +/// @brief Multiply col operation. +/// +/// Multiplies m_ij * v_j or v_j * m_ij. This will resize the buffer if it +/// needs to be. +/// +/// @params[in] x The other operand. 
+//------------------------------------------------------------------------------ + void multiply_col(const buffer &x) { + if (size() > x.size()) { + assert(size()%x.size() == 0 && + "Vector operand size is not a multiple of matrix operand size"); + + const size_t num_colmns = size()/x.size(); + const size_t num_rows = x.size(); + for (size_t i = 0; i < num_colmns; i++) { + for (size_t j = 0; j < num_rows; j++) { + memory[i*num_rows + j] *= x[i]; + } + } + } else { + assert(x.size()%size() == 0 && + "Vector operand size is not a multiple of matrix operand size"); + + std::vector m(x.size()); + const size_t num_colmns = x.size()/size(); + const size_t num_rows = size(); + for (size_t i = 0; i < num_colmns; i++) { + for (size_t j = 0; j < num_rows; j++) { + m[i*num_rows + j] = memory[i]*x[i*num_rows + j]; + } + } + memory = m; + } + } + +//------------------------------------------------------------------------------ +/// @brief Divide row operation. +/// +/// Divides m_ij / v_i or v_i / m_ij. This will resize the buffer if it needs +/// to be. +/// +/// @params[in] x The right operand. 
+//------------------------------------------------------------------------------ + void divide_row(const buffer &x) { + if (size() > x.size()) { + assert(size()%x.size() == 0 && + "Vector operand size is not a multiple of matrix operand size"); + + const size_t num_colmns = size()/x.size(); + const size_t num_rows = x.size(); + for (size_t i = 0; i < num_colmns; i++) { + for (size_t j = 0; j < num_rows; j++) { + memory[i*num_rows + j] /= x[j]; + } + } + } else { + assert(x.size()%size() == 0 && + "Vector operand size is not a multiple of matrix operand size"); + + std::vector m(x.size()); + const size_t num_colmns = x.size()/size(); + const size_t num_rows = size(); + for (size_t i = 0; i < num_colmns; i++) { + for (size_t j = 0; j < num_rows; j++) { + m[i*num_rows + j] = memory[j]/x[i*num_rows + j]; + } + } + memory = m; + } + } + +//------------------------------------------------------------------------------ +/// @brief Divide col operation. +/// +/// Divides m_ij / v_j or v_j / m_ij. This will resize the buffer if it needs +/// to be. +/// +/// @params[in] x The other operand. 
+//------------------------------------------------------------------------------ + void divide_col(const buffer &x) { + if (size() > x.size()) { + assert(size()%x.size() == 0 && + "Vector operand size is not a multiple of matrix operand size"); + + const size_t num_colmns = size()/x.size(); + const size_t num_rows = x.size(); + for (size_t i = 0; i < num_colmns; i++) { + for (size_t j = 0; j < num_rows; j++) { + memory[i*num_rows + j] /= x[i]; + } + } + } else { + assert(x.size()%size() == 0 && + "Vector operand size is not a multiple of matrix operand size"); + + std::vector m(x.size()); + const size_t num_colmns = x.size()/size(); + const size_t num_rows = size(); + for (size_t i = 0; i < num_colmns; i++) { + for (size_t j = 0; j < num_rows; j++) { + m[i*num_rows + j] = memory[i]/x[i*num_rows + j]; + } + } + memory = m; + } + } + +//------------------------------------------------------------------------------ +/// @brief Atan row operation. +/// +/// Computes atan(m_ij, v_i) or atan(v_i, m_ij). This will resize the buffer if +/// it needs to be. +/// +/// @params[in] x The right operand. 
+//------------------------------------------------------------------------------ + void atan_row(const buffer &x) { + if (size() > x.size()) { + assert(size()%x.size() == 0 && + "Vector operand size is not a multiple of matrix operand size"); + + const size_t num_colmns = size()/x.size(); + const size_t num_rows = x.size(); + for (size_t i = 0; i < num_colmns; i++) { + for (size_t j = 0; j < num_rows; j++) { + if constexpr (jit::is_complex ()) { + memory[i*num_rows + j] = std::atan(x[j]/memory[i*num_rows + j]); + } else { + memory[i*num_rows + j] = std::atan2(x[j], memory[i*num_rows + j]); + } + } + } + } else { + assert(x.size()%size() == 0 && + "Vector operand size is not a multiple of matrix operand size"); + + std::vector m(x.size()); + const size_t num_colmns = x.size()/size(); + const size_t num_rows = size(); + for (size_t i = 0; i < num_colmns; i++) { + for (size_t j = 0; j < num_rows; j++) { + if constexpr (jit::is_complex ()) { + m[i*num_rows + j] = std::atan(x[i*num_rows + j]/memory[j]); + } else { + m[i*num_rows + j] = std::atan2(x[i*num_rows + j], memory[j]); + } + } + } + memory = m; + } + } + +//------------------------------------------------------------------------------ +/// @brief Atan col operation. +/// +/// Computes atan(m_ij, v_j) or atan(v_j, m_ij). This will resize the buffer if +/// it needs to be. +/// +/// @params[in] x The other operand. 
+//------------------------------------------------------------------------------ + void atan_col(const buffer &x) { + if (size() > x.size()) { + assert(size()%x.size() == 0 && + "Vector operand size is not a multiple of matrix operand size"); + + const size_t num_colmns = size()/x.size(); + const size_t num_rows = x.size(); + for (size_t i = 0; i < num_colmns; i++) { + for (size_t j = 0; j < num_rows; j++) { + if constexpr (jit::is_complex ()) { + memory[i*num_rows + j] = std::atan(x[i]/memory[i*num_rows + j]); + } else { + memory[i*num_rows + j] = std::atan2(x[i], memory[i*num_rows + j]); + } + } + } + } else { + assert(x.size()%size() == 0 && + "Vector operand size is not a multiple of matrix operand size"); + + std::vector m(x.size()); + const size_t num_colmns = x.size()/size(); + const size_t num_rows = size(); + for (size_t i = 0; i < num_colmns; i++) { + for (size_t j = 0; j < num_rows; j++) { + if constexpr (jit::is_complex ()) { + m[i*num_rows + j] = std::atan(x[i*num_rows + j]/memory[i]); + } else { + m[i*num_rows + j] = std::atan2(x[i*num_rows + j], memory[i]); + } + } + } + memory = m; + } + } + +//------------------------------------------------------------------------------ +/// @brief Pow row operation. +/// +/// Computes pow(m_ij, v_i) or pow(v_i, m_ij). This will resize the buffer if +/// it needs to be. +/// +/// @params[in] x The right operand. 
+//------------------------------------------------------------------------------ + void pow_row(const buffer &x) { + if (size() > x.size()) { + assert(size()%x.size() == 0 && + "Vector operand size is not a multiple of matrix operand size"); + + const size_t num_colmns = size()/x.size(); + const size_t num_rows = x.size(); + for (size_t i = 0; i < num_colmns; i++) { + for (size_t j = 0; j < num_rows; j++) { + memory[i*num_rows + j] = std::pow(memory[i*num_rows + j], x[j]); + } + } + } else { + assert(x.size()%size() == 0 && + "Vector operand size is not a multiple of matrix operand size"); + + std::vector m(x.size()); + const size_t num_colmns = x.size()/size(); + const size_t num_rows = size(); + for (size_t i = 0; i < num_colmns; i++) { + for (size_t j = 0; j < num_rows; j++) { + m[i*num_rows + j] = std::pow(memory[j], x[i*num_rows + j]); + } + } + memory = m; + } + } + +//------------------------------------------------------------------------------ +/// @brief Pow col operation. +/// +/// Computes pow(m_ij, v_j) or pow(v_j, m_ij). This will resize the buffer if +/// it needs to be. +/// +/// @params[in] x The other operand. 
+//------------------------------------------------------------------------------ + void pow_col(const buffer &x) { + if (size() > x.size()) { + assert(size()%x.size() == 0 && + "Vector operand size is not a multiple of matrix operand size"); + + const size_t num_colmns = size()/x.size(); + const size_t num_rows = x.size(); + for (size_t i = 0; i < num_colmns; i++) { + for (size_t j = 0; j < num_rows; j++) { + memory[i*num_rows + j] = std::pow(memory[i*num_rows + j], x[i]); + } + } + } else { + assert(x.size()%size() == 0 && + "Vector operand size is not a multiple of matrix operand size"); + + std::vector m(x.size()); + const size_t num_colmns = x.size()/size(); + const size_t num_rows = size(); + for (size_t i = 0; i < num_colmns; i++) { + for (size_t j = 0; j < num_rows; j++) { + m[i*num_rows + j] = std::pow(memory[i], x[i*num_rows + j]); + } + } + memory = m; + } + } + /// Type def to retrieve the backend T type. typedef T base; }; diff --git a/graph_framework/cpu_context.hpp b/graph_framework/cpu_context.hpp index 8dcd6cd..a41d484 100644 --- a/graph_framework/cpu_context.hpp +++ b/graph_framework/cpu_context.hpp @@ -24,6 +24,10 @@ #include "clang/Lex/PreprocessorOptions.h" #include "llvm/Support/TargetSelect.h" #include "llvm/ExecutionEngine/Orc/LLJIT.h" +#ifndef NDEBUG +#include "llvm/ExecutionEngine/Orc/Debugging/DebuggerSupport.h" +#include "llvm/ExecutionEngine/Orc/TargetProcess/JITLoaderGDB.h" +#endif #include "llvm/Support/raw_ostream.h" #include "llvm/ADT/IntrusiveRefCntPtr.h" #include "llvm/ADT/SmallVector.h" @@ -32,6 +36,16 @@ #include "node.hpp" +#ifndef NDEBUG +//------------------------------------------------------------------------------ +/// @brief This just exposes the functions so the debugger links. 
+//------------------------------------------------------------------------------ +LLVM_ATTRIBUTE_USED void linkComponents() { + llvm::errs() << (void *)&llvm_orc_registerJITLoaderGDBWrapper + << (void *)&llvm_orc_registerJITLoaderGDBAllocAction; +} +#endif + namespace gpu { //------------------------------------------------------------------------------ /// @brief Split a string by the space delimiter. @@ -135,6 +149,8 @@ namespace gpu { args.push_back(filename.c_str()); #ifdef NDEBUG args.push_back("-O3"); +#else + args.push_back("-debug-info-kind=standalone"); #endif if (jit::verbose) { for (auto &arg : args) { @@ -176,7 +192,13 @@ namespace gpu { auto ir_module = action.takeModule(); auto context = std::unique_ptr (action.takeLLVMContext()); - auto jit_try = llvm::orc::LLJITBuilder().create(); + auto jit_try = llvm::orc::LLJITBuilder() +#ifndef NDEBUG + .setPrePlatformSetup([](llvm::orc::LLJIT &J) { + return llvm::orc::enableDebuggerSupport(J); + }) +#endif + .create(); if (auto jiterror = jit_try.takeError()) { std::cerr << "Failed to build JIT : " << toString(std::move(jiterror)) << std::endl; exit(-1); diff --git a/graph_framework/math.hpp b/graph_framework/math.hpp index d988cde..a8373b3 100644 --- a/graph_framework/math.hpp +++ b/graph_framework/math.hpp @@ -99,14 +99,10 @@ namespace graph { // sqrt((x^a)*y). 
auto am = multiply_cast(this->arg); if (am.get()) { - if (pow_cast(am->get_left()).get() || - constant_cast(am->get_left()).get() || - piecewise_1D_cast(am->get_left()).get() || - piecewise_2D_cast(am->get_left()).get() || - pow_cast(am->get_right()).get() || - constant_cast(am->get_right()).get() || - piecewise_1D_cast(am->get_right()).get() || - piecewise_2D_cast(am->get_right()).get()) { + if (pow_cast(am->get_left()).get() || + am->get_left()->is_constant() || + pow_cast(am->get_right()).get() || + am->get_right()->is_constant()) { return sqrt(am->get_left()) * sqrt(am->get_right()); } @@ -116,14 +112,10 @@ namespace graph { // where c is a constant. auto ad = divide_cast(this->arg); if (ad.get()) { - if (pow_cast(ad->get_left()).get() || - constant_cast(ad->get_left()).get() || - piecewise_1D_cast(ad->get_left()).get() || - piecewise_2D_cast(ad->get_left()).get() || - pow_cast(ad->get_right()).get() || - constant_cast(ad->get_right()).get() || - piecewise_1D_cast(ad->get_right()).get() || - piecewise_2D_cast(ad->get_right()).get()) { + if (pow_cast(ad->get_left()).get() || + ad->get_left()->is_constant() || + pow_cast(ad->get_right()).get() || + ad->get_right()->is_constant()) { return sqrt(ad->get_left()) / sqrt(ad->get_right()); } @@ -865,39 +857,77 @@ namespace graph { /// @returns A reduced power node. 
//------------------------------------------------------------------------------ virtual shared_leaf reduce() { + auto lc = constant_cast(this->left); auto rc = constant_cast(this->right); - if (rc.get()) { - if (rc->is(0)) { - return one (); - } else if (rc->is(1)) { - return this->left; - } else if (rc->is(0.5)) { - return sqrt(this->left); - } else if (rc->is(2)){ - auto sq = sqrt_cast(this->left); - if (sq.get()) { - return sq->get_arg(); - } - } - - if (constant_cast(this->left).get()) { - return constant (this->evaluate()); + if (rc.get() && rc->is(0)) { + return one (); + } else if (rc.get() && rc->is(1)) { + return this->left; + } else if (rc.get() && rc->is(0.5)) { + return sqrt(this->left); + } else if (rc.get() && rc->is(2)){ + auto sq = sqrt_cast(this->left); + if (sq.get()) { + return sq->get_arg(); } + } - auto pl1 = piecewise_1D_cast(this->left); - if (pl1.get()) { - return piecewise_1D(this->evaluate(), - pl1->get_arg()); - } + if (lc.get() && rc.get()) { + return constant (this->evaluate()); + } - auto pl2 = piecewise_2D_cast(this->left); - if (pl2.get()) { - return piecewise_2D(this->evaluate(), - pl2->get_num_columns(), - pl2->get_left(), - pl2->get_right()); - } + auto pl1 = piecewise_1D_cast(this->left); + auto pr1 = piecewise_1D_cast(this->right); + if (pl1.get() && (rc.get() || pl1->is_arg_match(this->right))) { + return piecewise_1D(this->evaluate(), pl1->get_arg()); + } else if (pr1.get() && (lc.get() || pr1->is_arg_match(this->left))) { + return piecewise_1D(this->evaluate(), pr1->get_arg()); + } + + auto pl2 = piecewise_2D_cast(this->left); + auto pr2 = piecewise_2D_cast(this->right); + if (pl2.get() && (rc.get() || pl2->is_arg_match(this->right))) { + return piecewise_2D(this->evaluate(), + pl2->get_num_columns(), + pl2->get_left(), + pl2->get_right()); + } else if (pr2.get() && (lc.get() || pr2->is_arg_match(this->left))) { + return piecewise_2D(this->evaluate(), + pr2->get_num_columns(), + pr2->get_left(), + pr2->get_right()); + } + 
+// Combine 2D and 1D piecewise constants if a row or column matches. + if (pr2.get() && pr2->is_row_match(this->left)) { + backend::buffer result = pl1->evaluate(); + result.pow_row(pr2->evaluate()); + return piecewise_2D(result, + pr2->get_num_columns(), + pr2->get_left(), + pr2->get_right()); + } else if (pr2.get() && pr2->is_col_match(this->left)) { + backend::buffer result = pl1->evaluate(); + result.pow_col(pr2->evaluate()); + return piecewise_2D(result, + pr2->get_num_columns(), + pr2->get_left(), + pr2->get_right()); + } else if (pl2.get() && pl2->is_row_match(this->right)) { + backend::buffer result = pl2->evaluate(); + result.pow_row(pr1->evaluate()); + return piecewise_2D(result, + pl2->get_num_columns(), + pl2->get_left(), + pl2->get_right()); + } else if (pl2.get() && pl2->is_col_match(this->right)) { + backend::buffer result = pl2->evaluate(); + result.pow_col(pr1->evaluate()); + return piecewise_2D(result, + pl2->get_num_columns(), + pl2->get_left(), + pl2->get_right()); } auto lp = pow_cast(this->left); @@ -909,15 +939,11 @@ namespace graph { // Handle cases where (c*x)^a, (x*c)^a, (a*sqrt(b))^c and (a*b^c)^2. auto lm = multiply_cast(this->left); if (lm.get()) { - if (constant_cast(lm->get_left()).get() || - constant_cast(lm->get_right()).get() || - piecewise_1D_cast(lm->get_left()).get() || - piecewise_1D_cast(lm->get_right()).get() || - piecewise_2D_cast(lm->get_left()).get() || - piecewise_2D_cast(lm->get_right()).get() || - sqrt_cast(lm->get_left()).get() || - sqrt_cast(lm->get_right()).get() || - pow_cast(lm->get_left()).get() || + if (lm->get_left()->is_constant() || + lm->get_right()->is_constant() || + sqrt_cast(lm->get_left()).get() || + sqrt_cast(lm->get_right()).get() || + pow_cast(lm->get_left()).get() || pow_cast(lm->get_right()).get()) { return pow(lm->get_left(), this->right) * pow(lm->get_right(), this->right); @@ -927,15 +953,11 @@ namespace graph { // Handle cases where (c/x)^a, (x/c)^a, (a/sqrt(b))^c and (a/b^c)^2. 
auto ld = divide_cast(this->left); if (ld.get()) { - if (constant_cast(ld->get_left()).get() || - constant_cast(ld->get_right()).get() || - piecewise_1D_cast(ld->get_left()).get() || - piecewise_1D_cast(ld->get_right()).get() || - piecewise_2D_cast(ld->get_left()).get() || - piecewise_2D_cast(ld->get_right()).get() || - sqrt_cast(ld->get_left()).get() || - sqrt_cast(ld->get_right()).get() || - pow_cast(ld->get_left()).get() || + if (ld->get_left()->is_constant() || + ld->get_right()->is_constant() || + sqrt_cast(ld->get_left()).get() || + sqrt_cast(ld->get_right()).get() || + pow_cast(ld->get_left()).get() || pow_cast(ld->get_right()).get()) { return pow(ld->get_left(), this->right) / pow(ld->get_right(), this->right); diff --git a/graph_framework/node.hpp b/graph_framework/node.hpp index 35aa445..ccaae18 100644 --- a/graph_framework/node.hpp +++ b/graph_framework/node.hpp @@ -187,11 +187,22 @@ namespace graph { jit::register_map ®isters) = 0; //------------------------------------------------------------------------------ -/// @brief Test if node acts like a constant. +/// @brief Test if node is a constant. /// -/// @returns True if the node acts like a constant. +/// @returns True if the node is like a constant. //------------------------------------------------------------------------------ - virtual bool is_constant_like() const = 0; + virtual bool is_constant() const { + return false; + } + +//------------------------------------------------------------------------------ +/// @brief Test the constant node has a zero. +/// +/// @returns True the node has a zero constant value. +//------------------------------------------------------------------------------ + virtual bool has_constant_zero() const { + return false; + } //------------------------------------------------------------------------------ /// @brief Test if all the subnodes terminate in variables. 
@@ -203,7 +214,7 @@ namespace graph { //------------------------------------------------------------------------------ /// @brief Test if the node acts like a power of variable. /// -/// Most notes are not so default to false. +/// Most nodes are not so default to false. /// /// @returns True the node is power like and false otherwise. //------------------------------------------------------------------------------ @@ -346,6 +357,7 @@ namespace graph { constant_node(const backend::buffer &d) : leaf_node (constant_node::to_string(d.at(0)), 1, false), data(d) { assert(d.size() == 1 && "Constants need to be scalar functions."); + assert(d.is_normal() && "NaN or Inf value."); } //------------------------------------------------------------------------------ @@ -470,14 +482,23 @@ namespace graph { } //------------------------------------------------------------------------------ -/// @brief Test if node acts like a constant. +/// @brief Test if node is a constant. /// -/// @returns True if the node acts like a constant. +/// @returns True if the is a constant. //------------------------------------------------------------------------------ - virtual bool is_constant_like() const { + virtual bool is_constant() const { return true; } +//------------------------------------------------------------------------------ +/// @brief Test the constant node has a zero. +/// +/// @returns True the node has a zero constant value. +//------------------------------------------------------------------------------ + virtual bool has_constant_zero() const { + return data.has_zero(); + } + //------------------------------------------------------------------------------ /// @brief Test if node acts like a variable. /// @@ -768,15 +789,6 @@ namespace graph { return this->arg; } -//------------------------------------------------------------------------------ -/// @brief Test if node acts like a constant. -/// -/// @returns True if the node acts like a constant. 
-//------------------------------------------------------------------------------ - virtual bool is_constant_like() const { - return this->arg->is_constant_like(); - } - //------------------------------------------------------------------------------ /// @brief Test if node acts like a variable. /// @@ -897,16 +909,6 @@ namespace graph { return this->right; } -//------------------------------------------------------------------------------ -/// @brief Test if node acts like a constant. -/// -/// @returns True if the node acts like a constant. -//------------------------------------------------------------------------------ - virtual bool is_constant_like() const { - return this->left->is_constant_like() && - this->right->is_constant_like(); - } - //------------------------------------------------------------------------------ /// @brief Test if node acts like a variable. /// @@ -1014,17 +1016,6 @@ namespace graph { return this->middle; } -//------------------------------------------------------------------------------ -/// @brief Test if node acts like a constant. -/// -/// @returns True if the node acts like a constant. -//------------------------------------------------------------------------------ - virtual bool is_constant_like() const { - return this->left->is_constant_like() && - this->middle->is_constant_like() && - this->right->is_constant_like(); - } - //------------------------------------------------------------------------------ /// @brief Test if node acts like a variable. /// @@ -1086,7 +1077,9 @@ namespace graph { variable_node(const size_t s, const T d, const std::string &symbol) : leaf_node (variable_node::to_string(this), 1, false), - buffer(s, d), symbol(symbol) {} + buffer(s, d), symbol(symbol) { + assert(buffer.is_normal() && "NaN or Inf value."); + } //------------------------------------------------------------------------------ /// @brief Construct a variable node from a vector. 
@@ -1097,7 +1090,9 @@ namespace graph { variable_node(const std::vector &d, const std::string &symbol) : leaf_node (variable_node::to_string(this), 1, false), - buffer(d), symbol(symbol) {} + buffer(d), symbol(symbol) { + assert(buffer.is_normal() && "NaN or Inf value."); + } //------------------------------------------------------------------------------ /// @brief Construct a variable node from backend buffer. @@ -1108,7 +1103,9 @@ namespace graph { variable_node(const backend::buffer &d, const std::string &symbol) : leaf_node (variable_node::to_string(this), 1, false), - buffer(d), symbol(symbol) {} + buffer(d), symbol(symbol) { + assert(buffer.is_normal() && "NaN or Inf value."); + } //------------------------------------------------------------------------------ /// @brief Evaluate method. @@ -1251,15 +1248,6 @@ namespace graph { T *data() { return buffer.data(); } - -//------------------------------------------------------------------------------ -/// @brief Test if node acts like a constant. -/// -/// @returns True if the node acts like a constant. -//------------------------------------------------------------------------------ - virtual bool is_constant_like() const { - return false; - } //------------------------------------------------------------------------------ /// @brief Test if node acts like a variable. @@ -1459,15 +1447,6 @@ namespace graph { std::cout << "\\right)"; } -//------------------------------------------------------------------------------ -/// @brief Test if node acts like a constant. -/// -/// @returns True if the node acts like a constant. -//------------------------------------------------------------------------------ - virtual bool is_constant_like() const { - return false; - } - //------------------------------------------------------------------------------ /// @brief Test if node acts like a variable. 
/// diff --git a/graph_framework/piecewise.hpp b/graph_framework/piecewise.hpp index 93fd335..73c30f7 100644 --- a/graph_framework/piecewise.hpp +++ b/graph_framework/piecewise.hpp @@ -136,7 +136,9 @@ void compile_index(std::ostringstream &stream, piecewise_1D_node(const backend::buffer &d, shared_leaf x) : straight_node (x, piecewise_1D_node::to_string(d, x)), - data_hash(piecewise_1D_node::hash_data(d)) {} + data_hash(piecewise_1D_node::hash_data(d)) { + assert(d.is_normal() && "NaN or Inf value."); + } //------------------------------------------------------------------------------ /// @brief Evaluate the results of the piecewise constant. @@ -378,14 +380,23 @@ void compile_index(std::ostringstream &stream, } //------------------------------------------------------------------------------ -/// @brief Test if node acts like a constant. +/// @brief Test if node is a constant. /// -/// @returns True if the node acts like a constant. +/// @returns True if the node is a constant. //------------------------------------------------------------------------------ - virtual bool is_constant_like() const { + virtual bool is_constant() const { return true; } +//------------------------------------------------------------------------------ +/// @brief Test the constant node has a zero. +/// +/// @returns True the node has a zero constant value. +//------------------------------------------------------------------------------ + virtual bool has_constant_zero() const { + return leaf_node::backend_cache[data_hash].has_zero(); + } + //------------------------------------------------------------------------------ /// @brief Test if node acts like a variable. 
/// @@ -600,6 +611,7 @@ void compile_index(std::ostringstream &stream, num_columns(n) { assert(d.size()/n && "Expected the data buffer to be a multiple of the number of columns."); + assert(d.is_normal() && "NaN or Inf value."); } //------------------------------------------------------------------------------ @@ -895,14 +907,23 @@ void compile_index(std::ostringstream &stream, } //------------------------------------------------------------------------------ -/// @brief Test if node acts like a constant. +/// @brief Test if node is a constant. /// -/// @returns True if the node acts like a constant. +/// @returns True if the node is a constant. //------------------------------------------------------------------------------ - virtual bool is_constant_like() const { + virtual bool is_constant() const { return true; } +//------------------------------------------------------------------------------ +/// @brief Test the constant node has a zero. +/// +/// @returns True the node has a zero constant value. +//------------------------------------------------------------------------------ + virtual bool has_constant_zero() const { + return leaf_node::backend_cache[data_hash].has_zero(); + } + //------------------------------------------------------------------------------ /// @brief Test if node acts like a variable. /// @@ -942,7 +963,7 @@ void compile_index(std::ostringstream &stream, //------------------------------------------------------------------------------ /// @brief Check if the args match. /// -/// @param[in] x Node to match. +/// @params[in] x Node to match. /// @returns True if the arguments match. 
//------------------------------------------------------------------------------ bool is_arg_match(shared_leaf x) { @@ -952,6 +973,28 @@ void compile_index(std::ostringstream &stream, this->right->is_match(temp->get_right()) && (num_columns == this->get_num_columns()); } + +//------------------------------------------------------------------------------ +/// @brief Do the rows match. +/// +/// @params[in] x Node to match. +/// @returns True if the row arguments match. +//------------------------------------------------------------------------------ + bool is_row_match(shared_leaf x) { + auto temp = piecewise_1D_cast(x); + return temp.get() && this->left->is_match(temp->get_arg()); + } + +//------------------------------------------------------------------------------ +/// @brief Do the columns match. +/// +/// @params[in] x Node to match. +/// @returns True if the column arguments match. +//------------------------------------------------------------------------------ + bool is_col_match(shared_leaf x) { + auto temp = piecewise_1D_cast(x); + return temp.get() && this->right->is_match(temp->get_arg()); + } }; //------------------------------------------------------------------------------ diff --git a/graph_framework/trigonometry.hpp b/graph_framework/trigonometry.hpp index 91ff128..8a5672d 100644 --- a/graph_framework/trigonometry.hpp +++ b/graph_framework/trigonometry.hpp @@ -577,6 +577,61 @@ namespace graph { return constant (this->evaluate()); } + auto pl1 = piecewise_1D_cast(this->left); + auto pr1 = piecewise_1D_cast(this->right); + + if (pl1.get() && (r.get() || pl1->is_arg_match(this->right))) { + return piecewise_1D(this->evaluate(), pl1->get_arg()); + } else if (pr1.get() && (l.get() || pr1->is_arg_match(this->left))) { + return piecewise_1D(this->evaluate(), pr1->get_arg()); + } + + auto pl2 = piecewise_2D_cast(this->left); + auto pr2 = piecewise_2D_cast(this->right); + + if (pl2.get() && (r.get() || pl2->is_arg_match(this->right))) { + return 
piecewise_2D(this->evaluate(), + pl2->get_num_columns(), + pl2->get_left(), + pl2->get_right()); + } else if (pr2.get() && (l.get() || pr2->is_arg_match(this->left))) { + return piecewise_2D(this->evaluate(), + pr2->get_num_columns(), + pr2->get_left(), + pr2->get_right()); + } + +// Combine 2D and 1D piecewise constants if a row or column matches. + if (pr2.get() && pr2->is_row_match(this->left)) { + backend::buffer result = pl1->evaluate(); + result.atan_row(pr2->evaluate()); + return piecewise_2D(result, + pr2->get_num_columns(), + pr2->get_left(), + pr2->get_right()); + } else if (pr2.get() && pr2->is_col_match(this->left)) { + backend::buffer result = pl1->evaluate(); + result.atan_col(pr2->evaluate()); + return piecewise_2D(result, + pr2->get_num_columns(), + pr2->get_left(), + pr2->get_right()); + } else if (pl2.get() && pl2->is_row_match(this->right)) { + backend::buffer result = pl2->evaluate(); + result.atan_row(pr1->evaluate()); + return piecewise_2D(result, + pl2->get_num_columns(), + pl2->get_left(), + pl2->get_right()); + } else if (pl2.get() && pl2->is_col_match(this->right)) { + backend::buffer result = pl2->evaluate(); + result.atan_col(pr1->evaluate()); + return piecewise_2D(result, + pl2->get_num_columns(), + pl2->get_left(), + pl2->get_right()); + } + return this->shared_from_this(); } diff --git a/graph_tests/arithmetic_test.cpp b/graph_tests/arithmetic_test.cpp index 8819649..b4e5e78 100644 --- a/graph_tests/arithmetic_test.cpp +++ b/graph_tests/arithmetic_test.cpp @@ -150,7 +150,7 @@ template void test_add() { // (c1*v1 + c2) + (c3*v1 + c4) -> c5*v1 + c6 auto var_e = graph::variable (1, ""); auto addfma1 = graph::fma(var_b, var_a, var_d) - + graph::fma(var_c, var_a, var_e); + + graph::fma(var_c, var_a, var_e); assert(graph::fma_cast(addfma1).get() && "Expected fused multiply add node."); // (v1*c1 + c2) + (v1*c3 + c4) -> c5*v1 + c6 @@ -250,20 +250,20 @@ template void test_add() { assert(muliply_divide_factor_cast4.get() && "Expected divide 
node."); // Test node properties. - assert(three->is_constant_like() && "Expected a constant."); + assert(three->is_constant() && "Expected a constant."); assert(!three->is_all_variables() && "Did not expect a variable."); assert(three->is_power_like() && "Expected a power like."); auto constant_add = three + graph::piecewise_1D (std::vector ({static_cast (1.0), static_cast (2.0)}), var_a); - assert(constant_add->is_constant_like() && "Expected a constant."); + assert(constant_add->is_constant() && "Expected a constant."); assert(!constant_add->is_all_variables() && "Did not expect a variable."); assert(constant_add->is_power_like() && "Expected a power like."); auto constant_var_add = three + var_a; - assert(!constant_var_add->is_constant_like() && "Did not expect a constant."); + assert(!constant_var_add->is_constant() && "Did not expect a constant."); assert(!constant_var_add->is_all_variables() && "Did not expect a variable."); assert(!constant_var_add->is_power_like() && "Did not expect a power like."); auto var_var_add = var_a + variable; - assert(!var_var_add->is_constant_like() && "Did not expect a constant."); + assert(!var_var_add->is_constant() && "Did not expect a constant."); assert(var_var_add->is_all_variables() && "Expected a variable."); assert(!var_var_add->is_power_like() && "Did not expect a power like."); } @@ -553,20 +553,20 @@ template void test_subtract() { "Expected a fused multiply add node on the left."); // Test node properties. 
- assert(zero->is_constant_like() && "Expected a constant."); + assert(zero->is_constant() && "Expected a constant."); assert(!zero->is_all_variables() && "Did not expect a variable."); assert(zero->is_power_like() && "Expected a power like."); auto constant_sub = one - graph::piecewise_1D (std::vector ({static_cast (1.0), static_cast (2.0)}), var_a); - assert(constant_sub->is_constant_like() && "Expected a constant."); + assert(constant_sub->is_constant() && "Expected a constant."); assert(!constant_sub->is_all_variables() && "Did not expect a variable."); assert(constant_sub->is_power_like() && "Expected a power like."); auto constant_var_sub = one - var_a; - assert(!constant_var_sub->is_constant_like() && "Did not expect a constant."); + assert(!constant_var_sub->is_constant() && "Did not expect a constant."); assert(!constant_var_sub->is_all_variables() && "Did not expect a variable."); assert(!constant_var_sub->is_power_like() && "Did not expect a power like."); auto var_var_sub = var_a - var_b; - assert(!var_var_sub->is_constant_like() && "Did not expect a constant."); + assert(!var_var_sub->is_constant() && "Did not expect a constant."); assert(var_var_sub->is_all_variables() && "Expected a variable."); assert(!var_var_sub->is_power_like() && "Did not expect a power like."); @@ -1034,20 +1034,20 @@ template void test_multiply() { "Expected a divide node."); // Test node properties. 
- assert(two_times_three->is_constant_like() && "Expected a constant."); + assert(two_times_three->is_constant() && "Expected a constant."); assert(!two_times_three->is_all_variables() && "Did not expect a variable."); assert(two_times_three->is_power_like() && "Expected a power like."); auto constant_mul = three*graph::piecewise_1D (std::vector ({static_cast (1.0), static_cast (2.0)}), variable); - assert(constant_mul->is_constant_like() && "Expected a constant."); + assert(constant_mul->is_constant() && "Expected a constant."); assert(!constant_mul->is_all_variables() && "Did not expect a variable."); assert(constant_mul->is_power_like() && "Expected a power like."); auto constant_var_mul = three*variable; - assert(!constant_var_mul->is_constant_like() && "Did not expect a constant."); + assert(!constant_var_mul->is_constant() && "Did not expect a constant."); assert(!constant_var_mul->is_all_variables() && "Did not expect a variable."); assert(!constant_var_mul->is_power_like() && "Did not expect a power like."); auto var_var_mul = variable*a; - assert(!var_var_mul->is_constant_like() && "Did not expect a constant."); + assert(!var_var_mul->is_constant() && "Did not expect a constant."); assert(var_var_mul->is_all_variables() && "Expected a variable."); assert(!var_var_mul->is_power_like() && "Did not expect a power like."); @@ -1786,8 +1786,8 @@ template void test_divide() { assert(fma_divide_cast2.get() && "Expected an fma node."); // fma(d,a,c*d)/d -> a + c auto fma_divide3 = graph::fma(a, - graph::variable (1, ""), - graph::variable (1, "")*a)/a; + graph::variable (1, ""), + graph::variable (1, "")*a)/a; auto fma_divide_cast3 = graph::add_cast(fma_divide3); assert(fma_divide_cast3.get() && "Expected an add node."); // fma(d,a,c*d)/d -> a + c @@ -1797,6 +1797,15 @@ template void test_divide() { auto fma_divide_cast4 = graph::add_cast(fma_divide4); assert(fma_divide_cast4.get() && "Expected an add node."); +// fma(a,b,a)/a -> 1 + b + auto fma_divide5 = 
graph::fma(a, graph::variable (1, ""), a)/a; + auto fma_divide5_cast = graph::add_cast(fma_divide5); + assert(fma_divide5_cast.get() && "Expected an add node."); +// fma(b,a,a)/a -> 1 + b + auto fma_divide6 = graph::fma(graph::variable (1, ""), a, a)/a; + auto fma_divide6_cast = graph::add_cast(fma_divide6); + assert(fma_divide6_cast.get() && "Expected an add node."); + // (a*b^c)/b^d -> a*b^(c - d) auto common_power = (variable*graph::pow(a, three))/graph::pow(a, two); assert(graph::multiply_cast(common_power).get() && @@ -1807,20 +1816,20 @@ template void test_divide() { "Expected a multiply node."); // Test node properties. - assert(two_divided_three->is_constant_like() && "Expected a constant."); + assert(two_divided_three->is_constant() && "Expected a constant."); assert(!two_divided_three->is_all_variables() && "Did not expect a variable."); assert(two_divided_three->is_power_like() && "Expected a power like."); auto constant_div = two_divided_three/graph::piecewise_1D (std::vector ({static_cast (1.0), static_cast (2.0)}), variable); - assert(constant_div->is_constant_like() && "Expected a constant."); + assert(constant_div->is_constant() && "Expected a constant."); assert(!constant_div->is_all_variables() && "Did not expect a variable."); assert(constant_div->is_power_like() && "Expected a power like."); auto constant_var_div = two_divided_three/variable; - assert(!constant_var_div->is_constant_like() && "Did not expect a constant."); + assert(!constant_var_div->is_constant() && "Did not expect a constant."); assert(!constant_var_div->is_all_variables() && "Did not expect a variable."); assert(!constant_var_div->is_power_like() && "Did not expect a power like."); auto var_var_div = variable/a; - assert(!var_var_div->is_constant_like() && "Did not expect a constant."); + assert(!var_var_div->is_constant() && "Did not expect a constant."); assert(var_var_div->is_all_variables() && "Expected a variable."); assert(!var_var_div->is_power_like() && "Did not expect 
a power like."); @@ -2046,7 +2055,17 @@ template void test_fma() { auto vara_times_one_plus_varb_cast = graph::add_cast(vara_times_one_plus_varb); assert(vara_times_one_plus_varb_cast.get() && "Expected an add node."); - + +// fma(b,a,a) = a*(1 + b) + auto common1 = graph::fma(var_a, var_b, var_a); + auto common1_cast = graph::multiply_cast(common1); + assert(common1_cast.get() && "Expected multiply node."); +// fma(b,a,a) = a*(1 + b) + auto common2 = graph::fma(var_b, var_a, var_a); + auto common2_cast = graph::multiply_cast(common2); + assert(common2_cast.get() && "Expected multiply node."); + assert(common1->is_match(common2) && "Expected same graph"); + auto reduce1 = graph::fma(var_a, var_b, var_a*var_c); auto reduce1_cast = graph::multiply_cast(reduce1); assert(reduce1_cast.get() && "Expected multiply node."); @@ -2074,6 +2093,264 @@ template void test_fma() { assert(graph::multiply_cast(graph::fma(two, var_a, one)).get() && "Expected multiply node."); +// fma(a, b, fma(c, b, d)) -> fma(b, a + c, d) + auto var_d = graph::variable (1, ""); + auto match1 = graph::fma(var_b, var_a + var_c, var_d); + auto nested_fma1 = graph::fma(var_a, var_b, + graph::fma(var_c, var_b, var_d)); + assert(nested_fma1->is_match(match1) && "Expected match."); +// fma(b, a, fma(c, b, d)) -> fma(b, a + c, d) + auto nested_fma2 = graph::fma(var_b, var_a, + graph::fma(var_c, var_b, var_d)); + assert(nested_fma2->is_match(match1) && "Expected match."); +// fma(a, b, fma(b, c, d)) -> fma(b, a + c, d) + auto nested_fma3 = graph::fma(var_a, var_b, + graph::fma(var_b, var_c, var_d)); + assert(nested_fma3->is_match(match1) && "Expected match."); +// fma(b, a, fma(b, c, d)) -> fma(b, a + c, d) + auto nested_fma4 = graph::fma(var_b, var_a, + graph::fma(var_b, var_c, var_d)); + assert(nested_fma4->is_match(match1) && "Expected match."); + +// fma(a, e*b, fma(c, b, d)) -> fma(b, fma(a, e, c), d) + auto var_e = graph::variable (1, ""); + auto match2 = graph::fma(var_b, graph::fma(var_a, var_e, 
var_c), var_d); + auto nested_fma5 = graph::fma(var_a, + var_e*var_b, + graph::fma(var_c, var_b, var_d)); + assert(nested_fma5->is_match(match2) && "Expected match."); +// fma(a, b*e, fma(c, b, d)) -> fma(b, fma(a, e, c), d) + auto nested_fma6 = graph::fma(var_a, + var_b*var_e, + graph::fma(var_c, var_b, var_d)); + assert(nested_fma6->is_match(match2) && "Expected match."); + // fma(a, e*b, fma(b, c, d)) -> fma(b, fma(a, e, c), d) + auto nested_fma7 = graph::fma(var_a, + var_e*var_b, + graph::fma(var_b, var_c, var_d)); + assert(nested_fma7->is_match(match2) && "Expected match."); +// fma(a, b*e, fma(c, b, d)) -> fma(b, fma(a, e, c), d) + auto nested_fma8 = graph::fma(var_a, + var_b*var_e, + graph::fma(var_b, var_c, var_d)); + assert(nested_fma8->is_match(match2) && "Expected match."); + +// fma(e*b, a, fma(c, b, d)) -> fma(b, fma(a, e, c), d) + auto nested_fma9 = graph::fma(var_e*var_b, + var_a, + graph::fma(var_c, var_b, var_d)); + assert(nested_fma9->is_match(match2) && "Expected match."); +// fma(b*e, a, fma(c, b, d)) -> fma(b, fma(a, e, c), d) + auto nested_fma10 = graph::fma(var_b*var_e, + var_a, + graph::fma(var_c, var_b, var_d)); + assert(nested_fma10->is_match(match2) && "Expected match."); +// fma(e*b, a, fma(b, c, d)) -> fma(b, fma(a, e, c), d) + auto nested_fma11 = graph::fma(var_e*var_b, + var_a, + graph::fma(var_b, var_c, var_d)); + assert(nested_fma11->is_match(match2) && "Expected match."); +// fma(e*d, a, fma(b, c, d)) -> fma(b, fma(a, e, c), d) + auto nested_fma12 = graph::fma(var_a, + var_b*var_e, + graph::fma(var_b, var_c, var_d)); + assert(nested_fma12->is_match(match2) && "Expected match."); + +// fma(a, b, fma(c, e*b, d)) -> fma(b, fma(c, e, a), d) + auto match3 = graph::fma(var_b, graph::fma(var_c, var_e, var_a), var_d); + auto nested_fma13 = graph::fma(var_a, + var_b, + graph::fma(var_c, var_e*var_b, var_d)); + assert(nested_fma13->is_match(match3) && "Expected match."); +// fma(b, a, fma(c, e*b, d)) -> fma(b, fma(c, e, a), d) + auto 
nested_fma14 = graph::fma(var_b, + var_a, + graph::fma(var_c, var_e*var_b, var_d)); + assert(nested_fma14->is_match(match3) && "Expected match."); +// fma(a, b, fma(c, b*e, d)) -> fma(b, fma(c, e, a), d) + auto nested_fma15 = graph::fma(var_a, + var_b, + graph::fma(var_c, var_b*var_e, var_d)); + assert(nested_fma15->is_match(match3) && "Expected match."); +// fma(b, a, fma(c, b*e, d)) -> fma(b, fma(c, e, a), d) + auto nested_fma16 = graph::fma(var_b, + var_a, + graph::fma(var_c, var_b*var_e, var_d)); + assert(nested_fma16->is_match(match3) && "Expected match."); +// fma(a, b, fma(e*b, c, d)) -> fma(b, fma(c, e, a), d) + auto nested_fma17 = graph::fma(var_a, + var_b, + graph::fma(var_e*var_b, var_c, var_d)); + assert(nested_fma17->is_match(match3) && "Expected match."); +// fma(b, a, fma(e*b, c, d)) -> fma(b, fma(c, e, a), d) + auto nested_fma18 = graph::fma(var_b, + var_a, + graph::fma(var_e*var_b, var_c, var_d)); + assert(nested_fma18->is_match(match3) && "Expected match."); +// fma(a, b, fma(b*e, c, d)) -> fma(b, fma(c, e, a), d) + auto nested_fma19 = graph::fma(var_a, + var_b, + graph::fma(var_b*var_e, var_c, var_d)); + assert(nested_fma19->is_match(match3) && "Expected match."); +// fma(b, a, fma(b*e, c, d)) -> fma(b, fma(c, e, a), d) + auto nested_fma20 = graph::fma(var_b, + var_a, + graph::fma(var_b*var_e, var_c, var_d)); + assert(nested_fma20->is_match(match3) && "Expected match."); + +// fma(a, f*b, fma(c, e*b, d)) -> fma(b, fma(a, f, c*e), d) + auto var_f = graph::variable (1, ""); + auto match4 = graph::fma(var_b, graph::fma(var_a, var_f, var_c*var_e), var_d); + auto nested_fma21 = graph::fma(var_a, + var_f*var_b, + graph::fma(var_c, var_e*var_b, var_d)); + assert(nested_fma21->is_match(match4) && "Expected match."); +// fma(a, b*f, fma(c, e*b, d)) -> fma(b, fma(a, f, c*e), d) + auto nested_fma22 = graph::fma(var_a, + var_b*var_f, + graph::fma(var_c, var_e*var_b, var_d)); + assert(nested_fma22->is_match(match4) && "Expected match."); +// fma(a, f*b, 
fma(c, b*e, d)) -> fma(b, fma(a, f, c*e), d) + auto nested_fma23 = graph::fma(var_a, + var_f*var_b, + graph::fma(var_c, var_b*var_e, var_d)); + assert(nested_fma23->is_match(match4) && "Expected match."); +// fma(a, b*f, fma(c, b*e, d)) -> fma(b, fma(a, f, c*e), d) + auto nested_fma24 = graph::fma(var_a, + var_b*var_f, + graph::fma(var_c, var_b*var_e, var_d)); + assert(nested_fma24->is_match(match4) && "Expected match."); +// fma(f*b, a, fma(c, e*b, d)) -> fma(b, fma(a, f, c*e), d) + auto nested_fma25 = graph::fma(var_f*var_b, + var_a, + graph::fma(var_c, var_e*var_b, var_d)); + assert(nested_fma25->is_match(match4) && "Expected match."); +// fma(b*f, a, fma(c, e*b, d)) -> fma(b, fma(a, f, c*e), d) + auto nested_fma26 = graph::fma(var_b*var_f, + var_a, + graph::fma(var_c, var_e*var_b, var_d)); + assert(nested_fma26->is_match(match4) && "Expected match."); +// fma(f*b, a, fma(c, b*e, d)) -> fma(b, fma(a, f, c*e), d) + auto nested_fma27 = graph::fma(var_f*var_b, + var_a, + graph::fma(var_c, var_b*var_e, var_d)); + assert(nested_fma27->is_match(match4) && "Expected match."); +// fma(b*f, a, fma(c, b*e, d)) -> fma(b, fma(a, f, c*e), d) + auto nested_fma28 = graph::fma(var_b*var_f, + var_a, + graph::fma(var_c, var_b*var_e, var_d)); + assert(nested_fma28->is_match(match4) && "Expected match."); +// fma(a, f*b, fma(e*b, c, d)) -> fma(b, fma(a, f, c*e), d) + auto nested_fma29 = graph::fma(var_a, + var_f*var_b, + graph::fma(var_e*var_b, var_c, var_d)); + assert(nested_fma29->is_match(match4) && "Expected match."); +// fma(a, b*f, fma(e*b, c, d)) -> fma(b, fma(a, f, c*e), d) + auto nested_fma30 = graph::fma(var_a, + var_b*var_f, + graph::fma(var_e*var_b, var_c, var_d)); + assert(nested_fma30->is_match(match4) && "Expected match."); +// fma(a, f*b, fma(b*e, c, d)) -> fma(b, fma(a, f, c*e), d) + auto nested_fma31= graph::fma(var_a, + var_f*var_b, + graph::fma(var_b*var_e, var_c, var_d)); + assert(nested_fma31->is_match(match4) && "Expected match."); +// fma(a, b*f, fma(b*e, c, 
d)) -> fma(b, fma(a, f, c*e), d) + auto nested_fma32 = graph::fma(var_a, + var_b*var_f, + graph::fma(var_b*var_e, var_c, var_d)); + assert(nested_fma32->is_match(match4) && "Expected match."); +// fma(f*b, a, fma(e*b, c, d)) -> fma(b, fma(a, f, c*e), d) + auto nested_fma33 = graph::fma(var_f*var_b, + var_a, + graph::fma(var_e*var_b, var_c, var_d)); + assert(nested_fma33->is_match(match4) && "Expected match."); +// fma(b*f, a, fma(e*b, c, d)) -> fma(b, fma(a, f, c*e), d) + auto nested_fma34 = graph::fma(var_b*var_f, + var_a, + graph::fma(var_e*var_b, var_c, var_d)); + assert(nested_fma34->is_match(match4) && "Expected match."); +// fma(f*b, a, fma(b*e, c, d)) -> fma(b, fma(a, f, c*e), d) + auto nested_fma35 = graph::fma(var_f*var_b, + var_a, + graph::fma(var_b*var_e, var_c, var_d)); + assert(nested_fma35->is_match(match4) && "Expected match."); +// fma(b*f, a, fma(b*e, c, d)) -> fma(b, fma(a, f, c*e), d) + auto nested_fma36 = graph::fma(var_b*var_f, + var_a, + graph::fma(var_b*var_e, var_c, var_d)); + assert(nested_fma36->is_match(match4) && "Expected match."); + +// fma(a^b,a^c,d) -> a^(b+c) +d + assert(graph::fma(graph::pow(var_a, var_b), + graph::pow(var_a, var_c), + var_d)->is_match(graph::pow(var_a, + var_b + var_c) + var_d) && + "Expected match"); + +// fma(a,x^b,fma(c,x^d,e)) -> fma(x^d,fma(x^(d-b),a,c),e) if b > d + auto matchv1 = graph::fma(graph::pow(var_b, two), + fma(var_b, var_a, var_c), + var_d); + auto matchv2 = graph::fma(graph::pow(var_b, two), + fma(var_b, var_c, var_a), + var_d); + auto nested_fmav1 = graph::fma(var_a, + graph::pow(var_b, three), + fma(var_c, + graph::pow(var_b, two), + var_d)); + assert(nested_fmav1->is_match(matchv1) && "Expected match"); +// fma(a,x^b,fma(c,x^d,e)) -> fma(x^b,fma(x^(d-b),c,a),e) if d > b + auto nested_fmav2 = graph::fma(var_a, + graph::pow(var_b, two), + fma(var_c, + graph::pow(var_b, three), + var_d)); + assert(nested_fmav2->is_match(matchv2) && "Expected match"); +// fma(x^b,a,fma(c,x^d,e)) -> 
fma(x^d,fma(x^(d-b),a,c),e) if b > d + auto nested_fmav3 = graph::fma(graph::pow(var_b, three), + var_a, + fma(var_c, + graph::pow(var_b, two), + var_d)); + assert(nested_fmav3->is_match(matchv1) && "Expected match"); +// fma(x^b,a,fma(c,x^d,e)) -> fma(x^b,fma(x^(d-b),c,a),e) if d > b + auto nested_fmav4 = graph::fma(graph::pow(var_b, two), + var_a, + fma(var_c, + graph::pow(var_b, three), + var_d)); + assert(nested_fmav4->is_match(matchv2) && "Expected match"); +// fma(a,x^b,fma(x^d,c,e)) -> fma(x^d,fma(x^(d-b),a,c),e) if b > d + auto nested_fmav5 = graph::fma(var_a, + graph::pow(var_b, three), + fma(graph::pow(var_b, two), + var_c, + var_d)); + assert(nested_fmav5->is_match(matchv1) && "Expected match"); +// fma(a,x^b,fma(x^d,c,e)) -> fma(x^b,fma(x^(d-b),c,a),e) if d > b + auto nested_fmav6 = graph::fma(var_a, + graph::pow(var_b, two), + fma(graph::pow(var_b, three), + var_c, + var_d)); + assert(nested_fmav6->is_match(matchv2) && "Expected match"); +// fma(x^b,a,fma(x^d,c,e)) -> fma(x^d,fma(x^(d-b),a,c),e) if b > d + auto nested_fmav7 = graph::fma(graph::pow(var_b, three), + var_a, + fma(graph::pow(var_b, two), + var_c, + var_d)); + assert(nested_fmav7->is_match(matchv1) && "Expected match"); +// fma(x^b,a,fma(x^d,c,e)) -> fma(x^b,fma(x^(d-b),c,a),e) if d > b + auto nested_fmav8 = graph::fma(graph::pow(var_b, two), + var_a, + fma(graph::pow(var_b, three), + var_c, + var_d)); + assert(nested_fmav8->is_match(matchv2) && "Expected match"); + // fma(a, b, a*b) -> 2*a*b // fma(b, a, a*b) -> 2*a*b // fma(a, b, b*a) -> 2*a*b @@ -2126,8 +2403,6 @@ template void test_fma() { "Expected constant node."); // fma(a,b/c,fma(d,e/c,g)) -> (a*b + d*e)/c + g - auto var_d = graph::variable (1, ""); - auto var_e = graph::variable (1, ""); auto chained_fma3 = fma(var_a, var_b/var_c, fma(var_d, var_e/var_c, var)); assert(add_cast(chained_fma3).get() && "expected add node."); // fma(a,b/c,fma(e/c,f,g)) -> (a*b + e*f)/c + g @@ -2170,7 +2445,7 @@ template void test_fma() { "Expetced a 
divide node."); // Test node properties. - assert(one_two_three->is_constant_like() && "Expected a constant."); + assert(one_two_three->is_constant() && "Expected a constant."); assert(!one_two_three->is_all_variables() && "Did not expect a variable."); assert(one_two_three->is_power_like() && "Expected a power like."); auto constant_fma = graph::fma(one_two_three, @@ -2181,11 +2456,11 @@ template void test_fma() { assert(!constant_fma->is_all_variables() && "Did not expect a variable."); assert(constant_fma->is_power_like() && "Expected a power like."); auto constant_var_fma = graph::fma(var_a, var_b, one); - assert(!constant_var_fma->is_constant_like() && "Did not expect a constant."); + assert(!constant_var_fma->is_constant() && "Did not expect a constant."); assert(!constant_var_fma->is_all_variables() && "Did not expect a variable."); assert(!constant_var_fma->is_power_like() && "Did not expect a power like."); auto var_var_fma = graph::fma(var_a, var_b, var_c); - assert(!var_var_fma->is_constant_like() && "Did not expect a constant."); + assert(!var_var_fma->is_constant() && "Did not expect a constant."); assert(var_var_fma->is_all_variables() && "Expected a variable."); assert(!var_var_fma->is_power_like() && "Did not expect a power like."); @@ -2271,7 +2546,6 @@ template void test_fma() { // fma(a/c, b, d*((f/c)*e)) -> fma(a, b, f*e*d)/c // fma(a, b/c, d*(e*(f/c))) -> fma(a, b, f*e*d)/c // fma(a/c, b, d*(e*(f/c))) -> fma(a, b, f*e*d)/c - auto var_f = graph::variable (1, ""); auto exp_a = (one + var_a); auto exp_b = (one + var_b); auto exp_c = (one + var_c); @@ -2459,6 +2733,23 @@ template void test_fma() { assert(fmaexp21_cast.get() && "Expected an add node."); assert(graph::divide_cast(fmaexp21_cast->get_left()).get() && "Expected a dive node on the left."); + +// fma(p2,p1,a) -> fma(p1,p2,a) + auto p1 = graph::piecewise_1D (std::vector ({static_cast (1.0), + static_cast (2.0)}), + var_a); + auto p2 = graph::piecewise_2D (std::vector ({static_cast (1.0), + 
static_cast (2.0), + static_cast (3.0), + static_cast (4.0)}), + 2, var_b, var_c); + auto fma_promote = graph::fma(p2, p1, var_a); + auto fma_promote_cast = graph::fma_cast(fma_promote); + assert(fma_promote_cast.get() && "Expected a fma node."); + assert(graph::piecewise_1D_cast(fma_promote_cast->get_left()).get() && + "Expected a piecewise 1d node on the left."); + assert(graph::piecewise_2D_cast(fma_promote_cast->get_middle()).get() && + "Expected a piecewise 2d node in the middle."); } //------------------------------------------------------------------------------ diff --git a/graph_tests/backend_test.cpp b/graph_tests/backend_test.cpp index 79224c2..b4e303d 100644 --- a/graph_tests/backend_test.cpp +++ b/graph_tests/backend_test.cpp @@ -543,6 +543,30 @@ template void test_backend() { static_cast (2.0) })); assert(!base_vec.is_negative() && "Expected false."); + + backend::buffer has_zero_vec(std::vector ({ + static_cast (3.0), + static_cast (0.0) + })); + assert(has_zero_vec.has_zero() && "Expected zero."); + backend::buffer has_zero_vec2(std::vector ({ + static_cast (3.0), + static_cast (1.0) + })); + assert(!has_zero_vec2.has_zero() && "Expected zero."); + assert(has_zero_vec2.is_normal() && "Expected normal."); + + backend::buffer inf_vec(std::vector ({ + static_cast (3.0), + static_cast (INFINITY) + })); + assert(!inf_vec.is_normal() && "Expected a inf."); + + backend::buffer nan_vec(std::vector ({ + static_cast (3.0), + static_cast (NAN) + })); + assert(!nan_vec.is_normal() && "Expected a NaN."); } //------------------------------------------------------------------------------ diff --git a/graph_tests/math_test.cpp b/graph_tests/math_test.cpp index 0c5440b..6bca69f 100644 --- a/graph_tests/math_test.cpp +++ b/graph_tests/math_test.cpp @@ -110,10 +110,10 @@ void test_sqrt() { // Test node properties. 
auto sqrt_const = graph::sqrt(graph::piecewise_1D (std::vector ({static_cast (1.0), static_cast (2.0)}), var)); - assert(sqrt_const->is_constant_like() && "Expected a constant."); + assert(sqrt_const->is_constant() && "Expected a constant."); assert(!sqrt_const->is_all_variables() && "Did not expect a variable."); assert(sqrt_const->is_power_like() && "Expected a power like."); - assert(!sqrt_var->is_constant_like() && "Did not expect a constant."); + assert(!sqrt_var->is_constant() && "Did not expect a constant."); assert(sqrt_var->is_all_variables() && "Expected a variable."); assert(sqrt_var->is_power_like() && "Expected a power like."); } @@ -152,7 +152,7 @@ void test_exp() { assert(dexp_var->evaluate().at(0) == std::exp(static_cast (3.0))); // Test node properties. - assert(!exp_var->is_constant_like() && "Did not expect a constant."); + assert(!exp_var->is_constant() && "Did not expect a constant."); assert(exp_var->is_all_variables() && "Expected a variable."); assert(!exp_var->is_power_like() && "Did not expect a power like."); } @@ -239,7 +239,7 @@ void test_pow() { assert(sqrd_neg->evaluate().at(0) == static_cast (non_int_neg*non_int_neg) && "Expected x*x"); - auto three = graph::two (); + auto three = graph::constant (static_cast (3)); auto pow_pow1 = graph::pow(graph::pow(ten, three), two); auto pow_pow2 = graph::pow(ten, three*two); assert(pow_pow1->is_match(pow_pow2) && @@ -273,16 +273,16 @@ void test_pow() { auto var_a = graph::variable (1, ""); auto pow_const = graph::pow(three, graph::piecewise_1D (std::vector ({static_cast (1.0), static_cast (2.0)}), var_a)); - assert(pow_const->is_constant_like() && "Expected a constant."); + assert(pow_const->is_constant() && "Expected a constant."); assert(!pow_const->is_all_variables() && "Did not expect a variable."); assert(pow_const->is_power_like() && "Expected a power like."); auto pow_var = graph::pow(var_a, three); - assert(!pow_var->is_constant_like() && "Did not expect a constant."); + 
assert(!pow_var->is_constant() && "Did not expect a constant."); assert(pow_var->is_all_variables() && "Expected a variable."); assert(pow_var->is_power_like() && "Expected a power like."); auto var_b = graph::variable (1, ""); auto pow_var_var = graph::pow(var_a, var_b); - assert(!pow_var->is_constant_like() && "Did not expect a constant."); + assert(!pow_var->is_constant() && "Did not expect a constant."); assert(pow_var->is_all_variables() && "Expected a variable."); assert(pow_var->is_power_like() && "Expected a power like."); @@ -320,6 +320,10 @@ void test_pow() { auto powexp_float_cast = graph::pow_cast(powexp_float); assert(powexp_float_cast.get() && "Expected power cast."); + +// c1^c2 + assert(graph::constant_cast(graph::pow(two, three)).get() && + "Expected a constant node."); } //------------------------------------------------------------------------------ @@ -340,7 +344,7 @@ void test_log() { auto dlogy = logy->df(y); assert(graph::divide_cast(dlogy) && "Expected divide node."); - assert(!logy->is_constant_like() && "Did not expect a constant."); + assert(!logy->is_constant() && "Did not expect a constant."); assert(logy->is_all_variables() && "Expected a variable."); assert(!logy->is_power_like() && "Did not expect a power like."); } @@ -367,7 +371,7 @@ void test_erfi() { "Expected a constant node."); // Test node properties. - assert(!erfi->is_constant_like() && "Did not expect a constant."); + assert(!erfi->is_constant() && "Did not expect a constant."); assert(erfi->is_all_variables() && "Expected a variable."); assert(!erfi->is_power_like() && "Did not expect a power like."); } diff --git a/graph_tests/node_test.cpp b/graph_tests/node_test.cpp index 0eb9a68..a0bdc78 100644 --- a/graph_tests/node_test.cpp +++ b/graph_tests/node_test.cpp @@ -60,7 +60,7 @@ void test_constant() { assert(c1->is_match(c2) && "Expected match."); // Test node properties. 
- assert(c1->is_constant_like() && "Expected a constant."); + assert(c1->is_constant() && "Expected a constant."); assert(!c1->is_all_variables() && "Did not expect a variable."); assert(c1->is_power_like() && "Expected a power like."); } @@ -124,7 +124,7 @@ void test_variable() { assert(!v1->is_match(v2) && "Expected no match."); // Test node properties. - assert(!v1->is_constant_like() && "Did not expect a constant."); + assert(!v1->is_constant() && "Did not expect a constant."); assert(v1->is_all_variables() && "Expected a variable."); assert(v1->is_power_like() && "Expected a power like."); } @@ -157,7 +157,7 @@ void test_pseudo_variable() { "Expected constant node."); // Test node properties. - assert(!c->is_constant_like() && "Did not expect a constant."); + assert(!c->is_constant() && "Did not expect a constant."); assert(c->is_all_variables() && "Expected a variable."); assert(c->is_power_like() && "Expected a power like."); } diff --git a/graph_tests/physics_test.cpp b/graph_tests/physics_test.cpp index c9611e0..c46c1ba 100644 --- a/graph_tests/physics_test.cpp +++ b/graph_tests/physics_test.cpp @@ -606,6 +606,7 @@ template void test_efit() { for (size_t i = 0; i < 10000; i++) { solve.step(); + solve.sync_host(); } solve.sync_host(); diff --git a/graph_tests/piecewise_test.cpp b/graph_tests/piecewise_test.cpp index 8f56135..c092726 100644 --- a/graph_tests/piecewise_test.cpp +++ b/graph_tests/piecewise_test.cpp @@ -8,6 +8,7 @@ #undef NDEBUG #endif +#include #include #include "../graph_framework/arithmetic.hpp" @@ -84,6 +85,10 @@ template void piecewise_1D() { auto p2 = graph::piecewise_1D (std::vector ({static_cast (2.0), static_cast (4.0), static_cast (6.0)}), b); + auto p3 = graph::piecewise_1D (std::vector ({static_cast (2.0), + static_cast (4.0), + static_cast (6.0)}), a); + auto zero = graph::zero (); assert(graph::constant_cast(p1*zero).get() && @@ -95,6 +100,8 @@ template void piecewise_1D() { "Expected a piecewise_1D node."); 
assert(graph::multiply_cast(p1*p2).get() && "Expected a multiply node."); + assert(graph::piecewise_1D_cast(p1*p3).get() && + "Expected a piecewise_1D node."); assert(graph::piecewise_1D_cast(p1 + zero).get() && "Expected a piecewise_1D node."); @@ -102,6 +109,8 @@ template void piecewise_1D() { "Expected a piecewise_1D node."); assert(graph::add_cast(p1 + p2).get() && "Expected an add node."); + assert(graph::piecewise_1D_cast(p1 + p3).get() && + "Expected a piecewise_1D node."); assert(graph::piecewise_1D_cast(p1 - zero).get() && "Expected a piecewise_1D node."); @@ -109,20 +118,31 @@ template void piecewise_1D() { "Expected a piecewise_1D node."); assert(graph::subtract_cast(p1 - p2).get() && "Expected a subtract node."); + assert(graph::piecewise_1D_cast(p1 - p3).get() && + "Expected a piecewise_1D node."); assert(graph::constant_cast(zero/p1).get() && "Expected a constant node."); assert(graph::piecewise_1D_cast(p1/two).get() && "Expected a piecewise_1D node."); - assert(graph::divide_cast(p1/p2).get() && - "Expected a divide node."); + assert(graph::multiply_cast(p1/p2).get() && + "Expected a multiply node."); + assert(graph::constant_cast(p1/p3).get() && + "Expected a constant node."); assert(graph::piecewise_1D_cast(graph::fma(p1, two, zero)).get() && "Expected a piecewise_1D node."); assert(graph::add_cast(graph::fma(p1, two, p2)).get() && "Expected an add node."); - assert(graph::fma_cast(graph::fma(p1, p2, two)).get() && - "Expected a fma node."); + auto temp = graph::fma(p1, p2, two); + assert(graph::multiply_cast(graph::fma(p1, p2, two)).get() && + "Expected a multiply node."); + assert(graph::add_cast(graph::fma(p1, p3, p2)).get() && + "Expected an add node."); + assert(graph::piecewise_1D_cast(graph::fma(p1, p3, two)).get() && + "Expected a piecewise_1D node."); + assert(graph::piecewise_1D_cast(graph::fma(p1, p3, p1)).get() && + "Expected a piecewise_1D node."); assert(graph::piecewise_1D_cast(graph::sqrt(p1)).get() && "Expected a piecewise_1D 
node."); @@ -137,6 +157,8 @@ template void piecewise_1D() { "Expected a piecewise_1D node."); assert(graph::pow_cast(graph::pow(p1, p2)).get() && "Expected a pow constant."); + assert(graph::piecewise_1D_cast(graph::pow(p1, p3)).get() && + "Expected a piecewise_1D node."); assert(graph::piecewise_1D_cast(graph::sin(p1)).get() && "Expected a piecewise_1D node."); @@ -147,10 +169,12 @@ template void piecewise_1D() { assert(graph::piecewise_1D_cast(graph::tan(p1)).get() && "Expected a piecewise_1D node."); - assert(graph::atan_cast(graph::atan(p1, two)).get() && - "Expected an atan node."); + assert(graph::piecewise_1D_cast(graph::atan(p1, two)).get() && + "Expected a piecewise_1D node."); assert(graph::atan_cast(graph::atan(p1, p2)).get() && - "Expected a atan constant."); + "Expected an atan node."); + assert(graph::constant_cast(graph::atan(p1, p3)).get() && + "Expected a constant node."); a->set(static_cast (1.5)); compile ({graph::variable_cast(a)}, @@ -166,7 +190,42 @@ template void piecewise_1D() { compile ({graph::variable_cast(a)}, {p1}, {}, static_cast (3.0), 0.0); - + + a->set(static_cast (1.5)); + compile ({graph::variable_cast(a)}, + {p1 + p3}, {}, + static_cast (6.0), 0.0); + compile ({graph::variable_cast(a)}, + {p1 - p3}, {}, + static_cast (-2.0), 0.0); + compile ({graph::variable_cast(a)}, + {p1*p3}, {}, + static_cast (8.0), 0.0); + compile ({graph::variable_cast(a)}, + {p1/p3}, {}, + static_cast (0.5), 0.0); + compile ({graph::variable_cast(a), + graph::variable_cast(b)}, + {graph::fma(p1, p3, p2)}, {}, + static_cast (10.0), 0.0); + compile ({graph::variable_cast(a)}, + {graph::pow(p1, p3)}, {}, + static_cast (std::pow(static_cast (2.0), + static_cast (4.0))), 0.0); + if constexpr (jit::is_complex ()) { + compile ({graph::variable_cast(a)}, + {graph::atan(p1, p3)}, {}, + static_cast (std::atan(static_cast (4.0) / + static_cast (2.0))), + 0.0); + } else { + compile ({graph::variable_cast(a)}, + {graph::atan(p1, p3)}, {}, + static_cast 
(std::atan2(static_cast (4.0), + static_cast (2.0))), + 0.0); + } + auto pc = graph::piecewise_1D (std::vector ({static_cast (10.0), static_cast (10.0), static_cast (10.0)}), a); @@ -194,6 +253,17 @@ template void piecewise_2D() { static_cast (6.0), static_cast (10.0)}), 2, bx, by); + auto p3 = graph::piecewise_2D (std::vector ({static_cast (2.0), + static_cast (4.0), + static_cast (6.0), + static_cast (10.0)}), + 2, ax, ay); + auto p4 = graph::piecewise_1D (std::vector ({static_cast (2.0), + static_cast (4.0)}), + ax); + auto p5 = graph::piecewise_1D (std::vector ({static_cast (2.0), + static_cast (4.0)}), + ay); auto zero = graph::zero (); @@ -206,6 +276,12 @@ template void piecewise_2D() { "Expected a piecewise_2D node."); assert(graph::multiply_cast(p1*p2).get() && "Expected a multiply node."); + assert(graph::piecewise_2D_cast(p1*p3).get() && + "Expected a piecewise_2D node."); + assert(graph::piecewise_2D_cast(p1*p4).get() && + "Expected a piecewise_2D node."); + assert(graph::piecewise_2D_cast(p1*p5).get() && + "Expected a piecewise_2D node."); assert(graph::piecewise_2D_cast(p1 + zero).get() && "Expected a piecewise_2D node."); @@ -213,6 +289,12 @@ template void piecewise_2D() { "Expected a piecewise_2D node."); assert(graph::add_cast(p1 + p2).get() && "Expected an add node."); + assert(graph::piecewise_2D_cast(p1 + p3).get() && + "Expected a piecewise_2D node."); + assert(graph::piecewise_2D_cast(p1 + p4).get() && + "Expected a piecewise_2D node."); + assert(graph::piecewise_2D_cast(p1 + p5).get() && + "Expected a piecewise_2D node."); assert(graph::piecewise_2D_cast(p1 - zero).get() && "Expected a piecewise_2D node."); @@ -220,20 +302,42 @@ template void piecewise_2D() { "Expected a piecewise_2D node."); assert(graph::subtract_cast(p1 - p2).get() && "Expected a subtract node."); + assert(graph::piecewise_2D_cast(p1 - p3).get() && + "Expected a piecewise_2D node."); + assert(graph::piecewise_2D_cast(p1 - p4).get() && + "Expected a piecewise_2D node."); + 
assert(graph::piecewise_2D_cast(p1 - p5).get() && + "Expected a piecewise_2D node."); assert(graph::constant_cast(zero/p1).get() && "Expected a constant node."); assert(graph::piecewise_2D_cast(p1/two).get() && "Expected a piecewise_2D node."); - assert(graph::divide_cast(p1/p2).get() && - "Expected a divide node."); + assert(graph::multiply_cast(p1/p2).get() && + "Expected a multiply node."); + assert(graph::piecewise_2D_cast(p1/p3).get() && + "Expected a piecewise_2D node."); + assert(graph::piecewise_2D_cast(p1/p4).get() && + "Expected a piecewise_2D node."); + assert(graph::piecewise_2D_cast(p1/p5).get() && + "Expected a piecewise_2D node."); assert(graph::piecewise_2D_cast(graph::fma(p1, two, zero)).get() && "Expected a piecewise_2D node."); assert(graph::add_cast(graph::fma(p1, two, p2)).get() && "Expected an add node."); - assert(graph::fma_cast(graph::fma(p1, p2, two)).get() && - "Expected a fma node."); + assert(graph::multiply_cast(graph::fma(p1, p2, two)).get() && + "Expected a multiply node."); + assert(graph::add_cast(graph::fma(p1, p3, p2)).get() && + "Expected an add node."); + assert(graph::piecewise_2D_cast(graph::fma(p1, p3, two)).get() && + "Expected a piecewise_2D node."); + assert(graph::piecewise_2D_cast(graph::fma(p1, p3, p1)).get() && + "Expected a piecewise_2D node."); + assert(graph::add_cast(graph::fma(p1, p4, p2)).get() && + "Expected an add node."); + assert(graph::add_cast(graph::fma(p1, p5, p2)).get() && + "Expected an add node."); assert(graph::piecewise_2D_cast(graph::sqrt(p1)).get() && "Expected a piecewise_2D node."); @@ -247,7 +351,13 @@ template void piecewise_2D() { assert(graph::piecewise_2D_cast(graph::pow(p1, two)).get() && "Expected a piecewise_2D node."); assert(graph::pow_cast(graph::pow(p1, p2)).get() && - "Expected a pow constant."); + "Expected a pow node."); + assert(graph::piecewise_2D_cast(graph::pow(p1, p3)).get() && + "Expected a pow node."); + assert(graph::piecewise_2D_cast(graph::pow(p1, p4)).get() && + 
"Expected a piecewise_2D node."); + assert(graph::piecewise_2D_cast(graph::pow(p1, p5)).get() && + "Expected a piecewise_2D node."); assert(graph::piecewise_2D_cast(graph::sin(p1)).get() && "Expected a piecewise_2D node."); @@ -258,10 +368,16 @@ template void piecewise_2D() { assert(graph::piecewise_2D_cast(graph::tan(p1)).get() && "Expected a piecewise_2D node."); - assert(graph::atan_cast(graph::atan(p1, two)).get() && - "Expected an atan node."); + assert(graph::piecewise_2D_cast(graph::atan(p1, two)).get() && + "Expected a piecewise_2d node."); assert(graph::atan_cast(graph::atan(p1, p2)).get() && - "Expected a atan constant."); + "Expected an atan node."); + assert(graph::piecewise_2D_cast(graph::atan(p1, p3)).get() && + "Expected a piecewise_2d node."); + assert(graph::piecewise_2D_cast(graph::atan(p1, p4)).get() && + "Expected a piecewise_2d node."); + assert(graph::piecewise_2D_cast(graph::atan(p1, p5)).get() && + "Expected a piecewise_2d node."); ax->set(static_cast (1.5)); ay->set(static_cast (1.5)); @@ -290,7 +406,140 @@ template void piecewise_2D() { graph::variable_cast(ay)}, {p1}, {}, static_cast (3.0), 0.0); - + + ax->set(static_cast (0.5)); + ay->set(static_cast (1.5)); + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {p1 + p3}, {}, + static_cast (6.0), 0.0); + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {p1 - p3}, {}, + static_cast (-2.0), 0.0); + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {p1*p3}, {}, + static_cast (8.0), 0.0); + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {p1/p3}, {}, + static_cast (0.5), 0.0); + compile ({graph::variable_cast(ax), + graph::variable_cast(ay), + graph::variable_cast(bx), + graph::variable_cast(by)}, + {graph::fma(p1, p3, p2)}, {}, + static_cast (10.0), 0.0); + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {graph::pow(p1, p3)}, {}, + static_cast (std::pow(static_cast (2.0), + static_cast (4.0))), 0.0); + 
if constexpr (jit::is_complex ()) { + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {graph::atan(p1, p3)}, {}, + static_cast (std::atan(static_cast (4.0) / + static_cast (2.0))), + 0.0); + } else { + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {graph::atan(p1, p3)}, {}, + static_cast (std::atan2(static_cast (4.0), + static_cast (2.0))), + 0.0); + } + +// Test row combines. + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {p1 + p4}, {}, + static_cast (6.0), 0.0); + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {p1 - p4}, {}, + static_cast (-2.0), 0.0); + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {p1*p4}, {}, + static_cast (8.0), 0.0); + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {p1/p4}, {}, + static_cast (0.5), 0.0); + compile ({graph::variable_cast(ax), + graph::variable_cast(ay), + graph::variable_cast(bx), + graph::variable_cast(by)}, + {graph::fma(p1, p4, p2)}, {}, + static_cast (10.0), 0.0); + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {graph::pow(p1, p4)}, {}, + static_cast (std::pow(static_cast (2.0), + static_cast (4.0))), 0.0); + if constexpr (jit::is_complex ()) { + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {graph::atan(p1, p4)}, {}, + static_cast (std::atan(static_cast (4.0) / + static_cast (2.0))), + 0.0); + } else { + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {graph::atan(p1, p4)}, {}, + static_cast (std::atan2(static_cast (4.0), + static_cast (2.0))), + 0.0); + } + +// Test column combines. 
+ compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {p1 + p5}, {}, + static_cast (4.0), 0.0); + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {p1 - p5}, {}, + static_cast (0.0), 0.0); + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {p1*p5}, {}, + static_cast (4.0), 0.0); + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {p1/p5}, {}, + static_cast (1.0), 0.0); + compile ({graph::variable_cast(ax), + graph::variable_cast(ay), + graph::variable_cast(bx), + graph::variable_cast(by)}, + {graph::fma(p1, p5, p2)}, {}, + static_cast (6.0), 0.0); + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {graph::pow(p1, p5)}, {}, + static_cast (std::pow(static_cast (2.0), + static_cast (2.0))), 0.0); + if constexpr (jit::is_complex ()) { + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {graph::atan(p1, p5)}, {}, + static_cast (std::atan(static_cast (2.0) / + static_cast (2.0))), + 0.0); + } else { + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {graph::atan(p1, p5)}, {}, + static_cast (std::atan2(static_cast (2.0), + static_cast (2.0))), + 0.0); + } + auto pc = graph::piecewise_2D (std::vector ({static_cast (10.0), static_cast (10.0), static_cast (10.0), -- GitLab From b166360fa1f26334c8ed76ef8e94abe868eaefbf Mon Sep 17 00:00:00 2001 From: Mark Cianciosa Date: Mon, 8 Jul 2024 23:23:56 -0400 Subject: [PATCH 59/63] Avoid NaN and Inf by checking if constant combine causes these before doing the reduction. Add timing for setup, init, compile, and step. 
--- graph_benchmark/xrays_bench.cpp | 22 +- graph_framework.xcodeproj/project.pbxproj | 1 - graph_framework/arithmetic.hpp | 249 ++++++++++++++-------- graph_framework/backend.hpp | 10 +- graph_framework/node.hpp | 10 +- graph_framework/piecewise.hpp | 7 +- 6 files changed, 196 insertions(+), 103 deletions(-) diff --git a/graph_benchmark/xrays_bench.cpp b/graph_benchmark/xrays_bench.cpp index 9496977..0af790d 100644 --- a/graph_benchmark/xrays_bench.cpp +++ b/graph_benchmark/xrays_bench.cpp @@ -38,10 +38,14 @@ void bench_runner() { const size_t batch = NUM_RAYS/threads.size(); const size_t extra = NUM_RAYS%threads.size(); - timeing::measure_diagnostic_threaded timing; + timeing::measure_diagnostic_threaded time_setup("Setup Time"); + timeing::measure_diagnostic_threaded time_init("Init Time"); + timeing::measure_diagnostic_threaded time_compile("Compile Time"); + timeing::measure_diagnostic_threaded time_steps("Time Steps"); for (size_t i = 0, ie = threads.size(); i < ie; i++) { - threads[i] = std::thread([&timing, batch, extra] (const size_t thread_number) -> void { + threads[i] = std::thread([&time_setup, &time_init, &time_compile, &time_steps, batch, extra] (const size_t thread_number) -> void { + time_setup.start_time(thread_number); const size_t local_num_rays = batch + (extra > thread_number ? 
1 : 0); @@ -78,25 +82,33 @@ void bench_runner() { eq, "", local_num_rays, thread_number); + time_setup.end_time(thread_number); + time_init.start_time(thread_number); solve.init(kx); + time_init.end_time(thread_number); + time_compile.start_time(thread_number); solve.compile(); + time_compile.end_time(thread_number); - timing.start_time(thread_number); + time_steps.start_time(thread_number); for (size_t j = 0; j < num_steps; j++) { for (size_t k = 0; k < SUB_STEPS; k++) { solve.step(); } } solve.sync_host(); - timing.end_time(thread_number); + time_steps.end_time(thread_number); }, i); } for (std::thread &t : threads) { t.join(); } - timing.print(); + time_setup.print(); + time_init.print(); + time_compile.print(); + time_steps.print(); std::cout << "--------------------------------------------------------------------------------" << std::endl << std::endl; diff --git a/graph_framework.xcodeproj/project.pbxproj b/graph_framework.xcodeproj/project.pbxproj index b95c2df..855cfac 100644 --- a/graph_framework.xcodeproj/project.pbxproj +++ b/graph_framework.xcodeproj/project.pbxproj @@ -1282,7 +1282,6 @@ "VMEC_FILE=\\\"/Users/m4c/Projects/graph_framework/graph_tests/vmec.nc\\\"", "EFIT_FILE=\\\"/Users/m4c/Projects/graph_framework/graph_tests/efit.nc\\\"", USE_METAL, - "CXX=\\\"c++\\ -I/Users/m4c/Projects/graph_framework/graph_framework\\ -std=gnu++2a\\\"", "$(inherited)", ); MACOSX_DEPLOYMENT_TARGET = 13.3; diff --git a/graph_framework/arithmetic.hpp b/graph_framework/arithmetic.hpp index d2ab2e0..ad30bd8 100644 --- a/graph_framework/arithmetic.hpp +++ b/graph_framework/arithmetic.hpp @@ -801,7 +801,7 @@ namespace graph { // Note need to make sure c1 doesn't contain any zeros. 
if (lm->get_left()->is_constant() && rm->get_left()->is_constant() && - !lm->has_constant_zero()) { + !lm->get_left()->has_constant_zero()) { return lm->get_left()*(lm->get_right() - (rm->get_left()/lm->get_left())*rm->get_right()); } @@ -1416,7 +1416,10 @@ namespace graph { // c1*(c2*v) -> c3*v if (is_constant_combineable(this->left, rm->get_left())) { - return (this->left*rm->get_left())*rm->get_right(); + auto temp = this->left*rm->get_left(); + if (temp->is_normal()) { + return temp*rm->get_right(); + } } if (this->left->is_match(rm->get_left())) { @@ -1455,20 +1458,28 @@ namespace graph { if (lm.get() && rm.get()) { if (is_constant_combineable(lm->get_left(), rm->get_left())) { - return (lm->get_left()*rm->get_left()) * - (lm->get_right()*rm->get_right()); + auto temp = lm->get_left()*rm->get_left(); + if (temp->is_normal()) { + return temp*(lm->get_right()*rm->get_right()); + } } else if (is_constant_combineable(lm->get_left(), rm->get_right())) { - return (lm->get_left()*rm->get_right()) * - (lm->get_right()*rm->get_left()); + auto temp = lm->get_left()*rm->get_right(); + if (temp->is_normal()) { + return temp*(lm->get_right()*rm->get_left()); + } } else if (is_constant_combineable(lm->get_right(), rm->get_left())) { - return (lm->get_right()*rm->get_left()) * - (lm->get_left()*rm->get_right()); + auto temp = lm->get_right()*rm->get_left(); + if (temp->is_normal()) { + return temp*(lm->get_left()*rm->get_right()); + } } else if (is_constant_combineable(lm->get_right(), rm->get_right())) { - return (lm->get_right()*rm->get_right()) * - (lm->get_left()*rm->get_left()); + auto temp = lm->get_right()*rm->get_right(); + if (temp->is_normal()) { + return temp*(lm->get_left()*rm->get_left()); + } } // Gather common terms. This will help reduce sqrt(a)*sqrt(a). 
@@ -2147,10 +2158,17 @@ namespace graph { if (rm.get()) { if (is_constant_combineable(rm->get_left(), this->left)) { - return (this->left/rm->get_left())/rm->get_right(); - } else if (is_constant_combineable(rm->get_left(), - this->left)) { - return (this->left/rm->get_right())/rm->get_left(); + auto temp = this->left/rm->get_left(); + if (temp->is_normal()) { + return temp/rm->get_right(); + } + } + if (is_constant_combineable(rm->get_right(), + this->left)) { + auto temp = this->left/rm->get_right(); + if (temp->is_normal()) { + return temp/rm->get_left(); + } } } @@ -2162,20 +2180,31 @@ namespace graph { // (a*c1)/(b*c2) -> c3*a/b if (is_constant_combineable(lm->get_left(), rm->get_left())) { - return (lm->get_left()/rm->get_left()) * - (lm->get_right()/rm->get_right()); - } else if (is_constant_combineable(lm->get_left(), - rm->get_right())) { - return (lm->get_left()/rm->get_right()) * - (lm->get_right()/rm->get_left()); - } else if (is_constant_combineable(lm->get_right(), - rm->get_left())) { - return (lm->get_right()/rm->get_left()) * - (lm->get_left()/rm->get_right()); - } else if (is_constant_combineable(lm->get_right(), - rm->get_right())) { - return (lm->get_right()/rm->get_right()) * - (lm->get_left()/rm->get_left()); + auto temp = lm->get_left()/rm->get_left(); + if (temp->is_normal()) { + return temp*lm->get_right()/rm->get_right(); + } + } + if (is_constant_combineable(lm->get_left(), + rm->get_right())) { + auto temp = lm->get_left()/rm->get_right(); + if (temp->is_normal()) { + return temp*lm->get_right()/rm->get_left(); + } + } + if (is_constant_combineable(lm->get_right(), + rm->get_left())) { + auto temp = lm->get_right()/rm->get_left(); + if (temp->is_normal()) { + return temp*lm->get_left()/rm->get_right(); + } + } + if (is_constant_combineable(lm->get_right(), + rm->get_right())) { + auto temp = lm->get_right()/rm->get_right(); + if (temp->is_normal()) { + return temp*lm->get_left()/rm->get_left(); + } } // (a*b)/(a*c) -> b/c @@ -2678,27 
+2707,42 @@ namespace graph { if (is_constant_combineable(this->left, rm->get_left()) && !this->left->has_constant_zero()) { - return this->left*fma(rm->get_left()/this->left, - rm->get_right(), - this->middle); - } else if (is_constant_combineable(this->middle, - rm->get_left()) && - !this->middle->has_constant_zero()) { - return this->middle*fma(rm->get_left()/this->middle, - rm->get_right(), - this->left); - } else if (is_constant_combineable(this->left, - rm->get_right()) && - !this->left->has_constant_zero()) { - return this->left*fma(rm->get_right()/this->left, - rm->get_left(), - this->middle); - } else if (is_constant_combineable(this->middle, - rm->get_right()) && - !this->middle->has_constant_zero()) { - return this->middle*fma(rm->get_right()/this->middle, - rm->get_left(), - this->left); + auto temp = rm->get_left()/this->left; + if (temp->is_normal()) { + return this->left*fma(temp, + rm->get_right(), + this->middle); + } + } + if (is_constant_combineable(this->middle, + rm->get_left()) && + !this->middle->has_constant_zero()) { + auto temp = rm->get_left()/this->middle; + if (temp->is_normal()) { + return this->middle*fma(temp, + rm->get_right(), + this->left); + } + } + if (is_constant_combineable(this->left, + rm->get_right()) && + !this->left->has_constant_zero()) { + auto temp = rm->get_right()/this->left; + if (temp->is_normal()) { + return this->left*fma(temp, + rm->get_left(), + this->middle); + } + } + if (is_constant_combineable(this->middle, + rm->get_right()) && + !this->middle->has_constant_zero()) { + auto temp = rm->get_right()/this->middle; + if (temp->is_normal()) { + return this->middle*fma(temp, + rm->get_left(), + this->left); + } } // Convert fma(a*b,c,d*e) -> fma(d,e,a*b*c) @@ -2720,27 +2764,42 @@ namespace graph { if (is_constant_combineable(rm->get_left(), lm->get_left()) && !lm->get_left()->has_constant_zero()) { - return lm->get_left()*fma(lm->get_right(), - this->middle, - (rm->get_left()/lm->get_left())*rm->get_right()); - } 
else if (is_constant_combineable(rm->get_left(), - lm->get_right()) && - !lm->get_right()->has_constant_zero()) { - return lm->get_right()*fma(lm->get_left(), - this->middle, - (rm->get_left()/lm->get_right())*rm->get_right()); - } else if (is_constant_combineable(rm->get_right(), - lm->get_left()) && - !lm->get_left()->has_constant_zero()) { - return lm->get_left()*fma(lm->get_right(), - this->middle, - (rm->get_right()/lm->get_left())*rm->get_left()); - } else if (is_constant_combineable(rm->get_right(), - lm->get_right()) && - !lm->get_right()->has_constant_zero()) { - return lm->get_right()*fma(lm->get_left(), - this->middle, - (rm->get_right()/lm->get_right())*rm->get_left()); + auto temp = rm->get_left()/lm->get_left(); + if (temp->is_normal()){ + return lm->get_left()*fma(lm->get_right(), + this->middle, + temp*rm->get_right()); + } + } + if (is_constant_combineable(rm->get_left(), + lm->get_right()) && + !lm->get_right()->has_constant_zero()) { + auto temp = rm->get_left()/lm->get_right(); + if (temp->is_normal()){ + return lm->get_right()*fma(lm->get_left(), + this->middle, + temp*rm->get_right()); + } + } + if (is_constant_combineable(rm->get_right(), + lm->get_left()) && + !lm->get_left()->has_constant_zero()) { + auto temp = rm->get_right()/lm->get_left(); + if (temp->is_normal()) { + return lm->get_left()*fma(lm->get_right(), + this->middle, + temp*rm->get_left()); + } + } + if (is_constant_combineable(rm->get_right(), + lm->get_right()) && + !lm->get_right()->has_constant_zero()) { + auto temp = rm->get_right()/lm->get_right(); + if (temp->is_normal()) { + return lm->get_right()*fma(lm->get_left(), + this->middle, + temp*rm->get_left()); + } } } @@ -2759,16 +2818,24 @@ namespace graph { // fma(a,c1*b,c) -> fma(c1,a*b,c) if (is_constant_combineable(this->left, mm->get_left())) { - return fma(this->left*mm->get_left(), - mm->get_right(), - this->right); - } else if (is_constant_combineable(this->left, - mm->get_right())) { - return 
fma(this->left*mm->get_right(), - mm->get_left(), - this->right); - } else if (is_constant_promotable(mm->get_left(), - this->left)) { + auto temp = this->left*mm->get_left(); + if (temp->is_normal()) { + return fma(temp, + mm->get_right(), + this->right); + } + } + if (is_constant_combineable(this->left, + mm->get_right())) { + auto temp = this->left*mm->get_right(); + if (temp->is_normal()) { + return fma(temp, + mm->get_left(), + this->right); + } + } + if (is_constant_promotable(mm->get_left(), + this->left)) { return fma(mm->get_left(), this->left*mm->get_right(), this->right); @@ -2782,13 +2849,20 @@ namespace graph { if (is_constant_combineable(this->left, rd->get_left()) && !this->left->has_constant_zero()) { - return this->left*(this->middle + - rd->get_left()/(this->left*rd->get_right())); - } else if (is_constant_combineable(this->middle, + auto temp = rd->get_left()/this->left; + if (temp->is_normal()) { + return this->left*(this->middle + + temp/rd->get_right()); + } + } + if (is_constant_combineable(this->middle, rd->get_left()) && !this->middle->has_constant_zero()) { - return this->middle*(this->left + - rd->get_left()/(this->middle*rd->get_right())); + auto temp = rd->get_left()/this->middle; + if (temp->is_normal()) { + return this->middle*(this->left + + temp/rd->get_right()); + } } } @@ -3222,7 +3296,10 @@ namespace graph { // Promote constants out to the left. if (is_constant_combineable(this->left, this->right) && !this->left->has_constant_zero()) { - return this->left*(this->middle + this->right/this->left); + auto temp = this->right/this->left; + if (temp->is_normal()) { + return this->left*(this->middle + temp); + } } // Change negative exponents to divide so that can be factored out. diff --git a/graph_framework/backend.hpp b/graph_framework/backend.hpp index e3fff57..a5aa120 100644 --- a/graph_framework/backend.hpp +++ b/graph_framework/backend.hpp @@ -154,7 +154,7 @@ namespace backend { /// @returns Returns true if every element is zero. 
//------------------------------------------------------------------------------ bool is_zero() const { - for (T d : memory) { + for (const T &d : memory) { if (d != static_cast (0.0)) { return false; } @@ -169,7 +169,7 @@ namespace backend { /// @returns Returns true if every element is zero. //------------------------------------------------------------------------------ bool has_zero() const { - for (T d : memory) { + for (const T &d : memory) { if (d == static_cast (0.0)) { return true; } @@ -184,7 +184,7 @@ namespace backend { /// @returns Returns true if every element is negative. //------------------------------------------------------------------------------ bool is_negative() const { - for (T d : memory) { + for (const T &d : memory) { if (std::real(d) > std::real(static_cast (0.0))) { return false; } @@ -199,7 +199,7 @@ namespace backend { /// @returns Returns true if every element is negative one. //------------------------------------------------------------------------------ bool is_none() const { - for (T d : memory) { + for (const T &d : memory) { if (d != static_cast (-1.0)) { return false; } @@ -278,7 +278,7 @@ namespace backend { /// @returns False if any NaN or Inf is found. //------------------------------------------------------------------------------ bool is_normal() const { - for (T x : memory) { + for (const T &x : memory) { if constexpr (jit::is_complex ()) { if (std::isnan(std::real(x)) || std::isinf(std::real(x)) || std::isnan(std::imag(x)) || std::isinf(std::imag(x))) { diff --git a/graph_framework/node.hpp b/graph_framework/node.hpp index ccaae18..ef5c140 100644 --- a/graph_framework/node.hpp +++ b/graph_framework/node.hpp @@ -204,6 +204,15 @@ namespace graph { return false; } +//------------------------------------------------------------------------------ +/// @brief Test if the result is normal. +/// +/// @returns True if the node is normal. 
+//------------------------------------------------------------------------------ + bool is_normal() { + return this->evaluate().is_normal(); + } + //------------------------------------------------------------------------------ /// @brief Test if all the subnodes terminate in variables. /// @@ -357,7 +366,6 @@ namespace graph { constant_node(const backend::buffer &d) : leaf_node (constant_node::to_string(d.at(0)), 1, false), data(d) { assert(d.size() == 1 && "Constants need to be scalar functions."); - assert(d.is_normal() && "NaN or Inf value."); } //------------------------------------------------------------------------------ diff --git a/graph_framework/piecewise.hpp b/graph_framework/piecewise.hpp index 73c30f7..ad06663 100644 --- a/graph_framework/piecewise.hpp +++ b/graph_framework/piecewise.hpp @@ -136,9 +136,7 @@ void compile_index(std::ostringstream &stream, piecewise_1D_node(const backend::buffer &d, shared_leaf x) : straight_node (x, piecewise_1D_node::to_string(d, x)), - data_hash(piecewise_1D_node::hash_data(d)) { - assert(d.is_normal() && "NaN or Inf value."); - } + data_hash(piecewise_1D_node::hash_data(d)) {} //------------------------------------------------------------------------------ /// @brief Evaluate the results of the piecewise constant. @@ -609,9 +607,8 @@ void compile_index(std::ostringstream &stream, branch_node (x, y, piecewise_2D_node::to_string(d, x, y)), data_hash(piecewise_2D_node::hash_data(d)), num_columns(n) { - assert(d.size()/n && + assert(d.size()%n == 0 && "Expected the data buffer to be a multiple of the number of columns."); - assert(d.is_normal() && "NaN or Inf value."); } //------------------------------------------------------------------------------ -- GitLab From f786bb923d858f582210c716c1335e85b5e6b135 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Tue, 9 Jul 2024 16:30:34 -0400 Subject: [PATCH 60/63] Fix file headers. 
--- graph_framework/math.hpp | 1 - graph_framework/register.hpp | 11 ++++------- graph_framework/vector.hpp | 7 ++----- 3 files changed, 6 insertions(+), 13 deletions(-) diff --git a/graph_framework/math.hpp b/graph_framework/math.hpp index a8373b3..dca53cb 100644 --- a/graph_framework/math.hpp +++ b/graph_framework/math.hpp @@ -3,7 +3,6 @@ /// @brief Defined basic math functions. //------------------------------------------------------------------------------ - #ifndef math_h #define math_h diff --git a/graph_framework/register.hpp b/graph_framework/register.hpp index b1f18cc..4139551 100644 --- a/graph_framework/register.hpp +++ b/graph_framework/register.hpp @@ -1,10 +1,7 @@ -// -// register.hpp -// graph_framework -// -// Created by Cianciosa, Mark on 12/8/22. -// Copyright © 2022 Cianciosa, Mark R. All rights reserved. -// +//------------------------------------------------------------------------------ +/// @file register.hpp +/// @brief Utilities for writting jit source code. +//------------------------------------------------------------------------------ #ifndef register_h #define register_h diff --git a/graph_framework/vector.hpp b/graph_framework/vector.hpp index edafc7e..893ab00 100644 --- a/graph_framework/vector.hpp +++ b/graph_framework/vector.hpp @@ -1,9 +1,6 @@ //------------------------------------------------------------------------------ -/// vector.hpp -/// graph_framework -/// -/// Created by Cianciosa, Mark R. on 3/31/22. -/// Copyright © 2022 Cianciosa, Mark R. All rights reserved. +/// @file vector.hpp +/// @brief Defines vectors of graphs. //------------------------------------------------------------------------------ #ifndef vector_h -- GitLab From 08e7e108f389e8c8dc805c83830ebcf77aaedd34 Mon Sep 17 00:00:00 2001 From: Mark Cianciosa Date: Wed, 10 Jul 2024 12:46:24 -0400 Subject: [PATCH 61/63] Fix row and column constant reductions. Do not combine if the dimensions are not the same. Add unit tests. 
--- graph_framework/arithmetic.hpp | 8 ++ graph_framework/backend.hpp | 140 +++++++++++++++--------------- graph_framework/math.hpp | 2 + graph_framework/metal_context.hpp | 8 +- graph_framework/piecewise.hpp | 42 +++++++-- graph_framework/trigonometry.hpp | 2 + graph_tests/piecewise_test.cpp | 125 ++++++++++++++++++-------- 7 files changed, 209 insertions(+), 118 deletions(-) diff --git a/graph_framework/arithmetic.hpp b/graph_framework/arithmetic.hpp index ad30bd8..3724701 100644 --- a/graph_framework/arithmetic.hpp +++ b/graph_framework/arithmetic.hpp @@ -194,6 +194,7 @@ namespace graph { pr2->get_right()); } +#if 1 // Combine 2D and 1D piecewise constants if a row or column matches. if (pr2.get() && pr2->is_row_match(this->left)) { backend::buffer result = pl1->evaluate(); @@ -224,6 +225,7 @@ namespace graph { pl2->get_left(), pl2->get_right()); } +#endif // Idenity reductions. if (this->left->is_match(this->right)) { @@ -708,6 +710,7 @@ namespace graph { pr2->get_right()); } +#if 1 // Combine 2D and 1D piecewise constants if a row or column matches. if (pr2.get() && pr2->is_row_match(this->left)) { backend::buffer result = pl1->evaluate(); @@ -738,6 +741,7 @@ namespace graph { pl2->get_left(), pl2->get_right()); } +#endif // Common factor reduction. If the left and right are both muliply nodes check // for a common factor. So you can change a*b - a*c -> a*(b - c). @@ -1302,6 +1306,7 @@ namespace graph { pr2->get_right()); } +#if 1 // Combine 2D and 1D piecewise constants if a row or column matches. if (pr2.get() && pr2->is_row_match(this->left)) { backend::buffer result = pl1->evaluate(); @@ -1332,6 +1337,7 @@ namespace graph { pl2->get_left(), pl2->get_right()); } +#endif // Move constants to the left. if (is_constant_promotable(this->right, this->left)) { @@ -2085,6 +2091,7 @@ namespace graph { pr2->get_right()); } +#if 1 // Combine 2D and 1D piecewise constants if a row or column matches. 
if (pr2.get() && pr2->is_row_match(this->left)) { backend::buffer result = pl1->evaluate(); @@ -2115,6 +2122,7 @@ namespace graph { pl2->get_left(), pl2->get_right()); } +#endif if (this->left->is_match(this->right)) { return one (); diff --git a/graph_framework/backend.hpp b/graph_framework/backend.hpp index a5aa120..254f73a 100644 --- a/graph_framework/backend.hpp +++ b/graph_framework/backend.hpp @@ -308,9 +308,9 @@ namespace backend { const size_t num_colmns = size()/x.size(); const size_t num_rows = x.size(); - for (size_t i = 0; i < num_colmns; i++) { - for (size_t j = 0; j < num_rows; j++) { - memory[i*num_rows + j] += x[j]; + for (size_t i = 0; i < num_rows; i++) { + for (size_t j = 0; j < num_colmns; j++) { + memory[i*num_rows + j] += x[i]; } } } else { @@ -320,9 +320,9 @@ namespace backend { std::vector m(x.size()); const size_t num_colmns = x.size()/size(); const size_t num_rows = size(); - for (size_t i = 0; i < num_colmns; i++) { - for (size_t j = 0; j < num_rows; j++) { - m[i*num_rows + j] = memory[j] + x[i*num_rows + j]; + for (size_t i = 0; i < num_rows; i++) { + for (size_t j = 0; j < num_colmns; j++) { + m[i*num_colmns + j] = memory[i] + x[i*num_colmns + j]; } } memory = m; @@ -344,9 +344,9 @@ namespace backend { const size_t num_colmns = size()/x.size(); const size_t num_rows = x.size(); - for (size_t i = 0; i < num_colmns; i++) { - for (size_t j = 0; j < num_rows; j++) { - memory[i*num_rows + j] += x[i]; + for (size_t i = 0; i < num_rows; i++) { + for (size_t j = 0; j < num_colmns; j++) { + memory[i*num_colmns + j] += x[j]; } } } else { @@ -356,9 +356,9 @@ namespace backend { std::vector m(x.size()); const size_t num_colmns = x.size()/size(); const size_t num_rows = size(); - for (size_t i = 0; i < num_colmns; i++) { - for (size_t j = 0; j < num_rows; j++) { - m[i*num_rows + j] = memory[i] + x[i*num_rows + j]; + for (size_t i = 0; i < num_rows; i++) { + for (size_t j = 0; j < num_colmns; j++) { + m[i*num_colmns + j] = memory[j] + x[i*num_colmns 
+ j]; } } memory = m; @@ -380,9 +380,9 @@ namespace backend { const size_t num_colmns = size()/x.size(); const size_t num_rows = x.size(); - for (size_t i = 0; i < num_colmns; i++) { - for (size_t j = 0; j < num_rows; j++) { - memory[i*num_rows + j] -= x[j]; + for (size_t i = 0; i < num_rows; i++) { + for (size_t j = 0; j < num_colmns; j++) { + memory[i*num_colmns + j] -= x[i]; } } } else { @@ -394,7 +394,7 @@ namespace backend { const size_t num_rows = size(); for (size_t i = 0; i < num_colmns; i++) { for (size_t j = 0; j < num_rows; j++) { - m[i*num_rows + j] = memory[j] - x[i*num_rows + j]; + m[i*num_colmns + j] = memory[i] - x[i*num_colmns + j]; } } memory = m; @@ -416,9 +416,9 @@ namespace backend { const size_t num_colmns = size()/x.size(); const size_t num_rows = x.size(); - for (size_t i = 0; i < num_colmns; i++) { - for (size_t j = 0; j < num_rows; j++) { - memory[i*num_rows + j] -= x[i]; + for (size_t i = 0; i < num_rows; i++) { + for (size_t j = 0; j < num_colmns; j++) { + memory[i*num_colmns + j] -= x[j]; } } } else { @@ -428,9 +428,9 @@ namespace backend { std::vector m(x.size()); const size_t num_colmns = x.size()/size(); const size_t num_rows = size(); - for (size_t i = 0; i < num_colmns; i++) { - for (size_t j = 0; j < num_rows; j++) { - m[i*num_rows + j] = memory[i] - x[i*num_rows + j]; + for (size_t i = 0; i < num_rows; i++) { + for (size_t j = 0; j < num_colmns; j++) { + m[i*num_colmns + j] = memory[j] - x[i*num_colmns + j]; } } memory = m; @@ -452,9 +452,9 @@ namespace backend { const size_t num_colmns = size()/x.size(); const size_t num_rows = x.size(); - for (size_t i = 0; i < num_colmns; i++) { - for (size_t j = 0; j < num_rows; j++) { - memory[i*num_rows + j] *= x[j]; + for (size_t i = 0; i < num_rows; i++) { + for (size_t j = 0; j < num_colmns; j++) { + memory[i*num_colmns + j] *= x[i]; } } } else { @@ -464,9 +464,9 @@ namespace backend { std::vector m(x.size()); const size_t num_colmns = x.size()/size(); const size_t num_rows = size(); - 
for (size_t i = 0; i < num_colmns; i++) { - for (size_t j = 0; j < num_rows; j++) { - m[i*num_rows + j] = memory[j]*x[i*num_rows + j]; + for (size_t i = 0; i < num_rows; i++) { + for (size_t j = 0; j < num_colmns; j++) { + m[i*num_colmns + j] = memory[i]*x[i*num_colmns + j]; } } memory = m; @@ -488,9 +488,9 @@ namespace backend { const size_t num_colmns = size()/x.size(); const size_t num_rows = x.size(); - for (size_t i = 0; i < num_colmns; i++) { - for (size_t j = 0; j < num_rows; j++) { - memory[i*num_rows + j] *= x[i]; + for (size_t i = 0; i < num_rows; i++) { + for (size_t j = 0; j < num_colmns; j++) { + memory[i*num_colmns + j] *= x[j]; } } } else { @@ -500,9 +500,9 @@ namespace backend { std::vector m(x.size()); const size_t num_colmns = x.size()/size(); const size_t num_rows = size(); - for (size_t i = 0; i < num_colmns; i++) { - for (size_t j = 0; j < num_rows; j++) { - m[i*num_rows + j] = memory[i]*x[i*num_rows + j]; + for (size_t i = 0; i < num_rows; i++) { + for (size_t j = 0; j < num_colmns; j++) { + m[i*num_colmns + j] = memory[j]*x[i*num_colmns + j]; } } memory = m; @@ -524,9 +524,9 @@ namespace backend { const size_t num_colmns = size()/x.size(); const size_t num_rows = x.size(); - for (size_t i = 0; i < num_colmns; i++) { - for (size_t j = 0; j < num_rows; j++) { - memory[i*num_rows + j] /= x[j]; + for (size_t i = 0; i < num_rows; i++) { + for (size_t j = 0; j < num_colmns; j++) { + memory[i*num_colmns + j] /= x[i]; } } } else { @@ -536,9 +536,9 @@ namespace backend { std::vector m(x.size()); const size_t num_colmns = x.size()/size(); const size_t num_rows = size(); - for (size_t i = 0; i < num_colmns; i++) { - for (size_t j = 0; j < num_rows; j++) { - m[i*num_rows + j] = memory[j]/x[i*num_rows + j]; + for (size_t i = 0; i < num_rows; i++) { + for (size_t j = 0; j < num_colmns; j++) { + m[i*num_colmns + j] = memory[i]/x[i*num_colmns + j]; } } memory = m; @@ -560,9 +560,9 @@ namespace backend { const size_t num_colmns = size()/x.size(); const size_t 
num_rows = x.size(); - for (size_t i = 0; i < num_colmns; i++) { - for (size_t j = 0; j < num_rows; j++) { - memory[i*num_rows + j] /= x[i]; + for (size_t i = 0; i < num_rows; i++) { + for (size_t j = 0; j < num_colmns; j++) { + memory[i*num_colmns + j] /= x[j]; } } } else { @@ -572,9 +572,9 @@ namespace backend { std::vector m(x.size()); const size_t num_colmns = x.size()/size(); const size_t num_rows = size(); - for (size_t i = 0; i < num_colmns; i++) { - for (size_t j = 0; j < num_rows; j++) { - m[i*num_rows + j] = memory[i]/x[i*num_rows + j]; + for (size_t i = 0; i < num_rows; i++) { + for (size_t j = 0; j < num_colmns; j++) { + m[i*num_colmns + j] = memory[j]/x[i*num_colmns + j]; } } memory = m; @@ -596,12 +596,12 @@ namespace backend { const size_t num_colmns = size()/x.size(); const size_t num_rows = x.size(); - for (size_t i = 0; i < num_colmns; i++) { - for (size_t j = 0; j < num_rows; j++) { + for (size_t i = 0; i < num_rows; i++) { + for (size_t j = 0; j < num_colmns; j++) { if constexpr (jit::is_complex ()) { - memory[i*num_rows + j] = std::atan(x[j]/memory[i*num_rows + j]); + memory[i*num_colmns + j] = std::atan(x[i]/memory[i*num_colmns + j]); } else { - memory[i*num_rows + j] = std::atan2(x[j], memory[i*num_rows + j]); + memory[i*num_colmns + j] = std::atan2(x[i], memory[i*num_colmns + j]); } } } @@ -612,12 +612,12 @@ namespace backend { std::vector m(x.size()); const size_t num_colmns = x.size()/size(); const size_t num_rows = size(); - for (size_t i = 0; i < num_colmns; i++) { - for (size_t j = 0; j < num_rows; j++) { + for (size_t i = 0; i < num_rows; i++) { + for (size_t j = 0; j < num_colmns; j++) { if constexpr (jit::is_complex ()) { - m[i*num_rows + j] = std::atan(x[i*num_rows + j]/memory[j]); + m[i*num_colmns + j] = std::atan(x[i*num_colmns + j]/memory[i]); } else { - m[i*num_rows + j] = std::atan2(x[i*num_rows + j], memory[j]); + m[i*num_colmns + j] = std::atan2(x[i*num_colmns + j], memory[i]); } } } @@ -643,9 +643,9 @@ namespace backend { 
for (size_t i = 0; i < num_colmns; i++) { for (size_t j = 0; j < num_rows; j++) { if constexpr (jit::is_complex ()) { - memory[i*num_rows + j] = std::atan(x[i]/memory[i*num_rows + j]); + memory[i*num_colmns + j] = std::atan(x[j]/memory[i*num_colmns + j]); } else { - memory[i*num_rows + j] = std::atan2(x[i], memory[i*num_rows + j]); + memory[i*num_colmns + j] = std::atan2(x[j], memory[i*num_colmns + j]); } } } @@ -656,12 +656,12 @@ namespace backend { std::vector m(x.size()); const size_t num_colmns = x.size()/size(); const size_t num_rows = size(); - for (size_t i = 0; i < num_colmns; i++) { - for (size_t j = 0; j < num_rows; j++) { + for (size_t i = 0; i < num_rows; i++) { + for (size_t j = 0; j < num_colmns; j++) { if constexpr (jit::is_complex ()) { - m[i*num_rows + j] = std::atan(x[i*num_rows + j]/memory[i]); + m[i*num_colmns + j] = std::atan(x[i*num_colmns + j]/memory[j]); } else { - m[i*num_rows + j] = std::atan2(x[i*num_rows + j], memory[i]); + m[i*num_colmns + j] = std::atan2(x[i*num_colmns + j], memory[j]); } } } @@ -684,9 +684,9 @@ namespace backend { const size_t num_colmns = size()/x.size(); const size_t num_rows = x.size(); - for (size_t i = 0; i < num_colmns; i++) { - for (size_t j = 0; j < num_rows; j++) { - memory[i*num_rows + j] = std::pow(memory[i*num_rows + j], x[j]); + for (size_t i = 0; i < num_rows; i++) { + for (size_t j = 0; j < num_colmns; j++) { + memory[i*num_colmns + j] = std::pow(memory[i*num_colmns + j], x[i]); } } } else { @@ -698,7 +698,7 @@ namespace backend { const size_t num_rows = size(); for (size_t i = 0; i < num_colmns; i++) { for (size_t j = 0; j < num_rows; j++) { - m[i*num_rows + j] = std::pow(memory[j], x[i*num_rows + j]); + m[i*num_colmns + j] = std::pow(memory[i], x[i*num_colmns + j]); } } memory = m; @@ -720,9 +720,9 @@ namespace backend { const size_t num_colmns = size()/x.size(); const size_t num_rows = x.size(); - for (size_t i = 0; i < num_colmns; i++) { - for (size_t j = 0; j < num_rows; j++) { - memory[i*num_rows 
+ j] = std::pow(memory[i*num_rows + j], x[i]); + for (size_t i = 0; i < num_rows; i++) { + for (size_t j = 0; j < num_colmns; j++) { + memory[i*num_colmns + j] = std::pow(memory[i*num_colmns + j], x[j]); } } } else { @@ -732,9 +732,9 @@ namespace backend { std::vector m(x.size()); const size_t num_colmns = x.size()/size(); const size_t num_rows = size(); - for (size_t i = 0; i < num_colmns; i++) { - for (size_t j = 0; j < num_rows; j++) { - m[i*num_rows + j] = std::pow(memory[i], x[i*num_rows + j]); + for (size_t i = 0; i < num_rows; i++) { + for (size_t j = 0; j < num_colmns; j++) { + m[i*num_colmns + j] = std::pow(memory[j], x[i*num_colmns + j]); } } memory = m; diff --git a/graph_framework/math.hpp b/graph_framework/math.hpp index dca53cb..ed05de0 100644 --- a/graph_framework/math.hpp +++ b/graph_framework/math.hpp @@ -898,6 +898,7 @@ namespace graph { pr2->get_right()); } +#if 1 // Combine 2D and 1D piecewise constants if a row or column matches. if (pr2.get() && pr2->is_row_match(this->left)) { backend::buffer result = pl1->evaluate(); @@ -928,6 +929,7 @@ namespace graph { pl2->get_left(), pl2->get_right()); } +#endif auto lp = pow_cast(this->left); // Only run this reduction if the right is an integer constant value. 
diff --git a/graph_framework/metal_context.hpp b/graph_framework/metal_context.hpp index 03bde6b..f46f2be 100644 --- a/graph_framework/metal_context.hpp +++ b/graph_framework/metal_context.hpp @@ -179,17 +179,17 @@ namespace gpu { MTLTextureDescriptor *discriptor = [MTLTextureDescriptor new]; discriptor.textureType = MTLTextureType2D; discriptor.pixelFormat = MTLPixelFormatR32Float; - discriptor.width = size[0]; - discriptor.height = size[1]; + discriptor.width = size[1]; + discriptor.height = size[0]; discriptor.storageMode = MTLStorageModeManaged; discriptor.cpuCacheMode = MTLCPUCacheModeWriteCombined; discriptor.hazardTrackingMode = MTLHazardTrackingModeUntracked; discriptor.usage = MTLTextureUsageShaderRead; texture_arguments[data] = [device newTextureWithDescriptor:discriptor]; - [texture_arguments[data] replaceRegion:MTLRegionMake2D(0, 0, size[0], size[1]) + [texture_arguments[data] replaceRegion:MTLRegionMake2D(0, 0, size[1], size[0]) mipmapLevel:0 withBytes:reinterpret_cast (data) - bytesPerRow:4*size[0]]; + bytesPerRow:4*size[1]]; [encoder optimizeContentsForGPUAccess:texture_arguments[data]]; } diff --git a/graph_framework/piecewise.hpp b/graph_framework/piecewise.hpp index ad06663..ac610b9 100644 --- a/graph_framework/piecewise.hpp +++ b/graph_framework/piecewise.hpp @@ -439,7 +439,18 @@ void compile_index(std::ostringstream &stream, //------------------------------------------------------------------------------ bool is_arg_match(shared_leaf x) { auto temp = piecewise_1D_cast(x); - return temp.get() && this->arg->is_match(temp->get_arg()); + return temp.get() && + this->arg->is_match(temp->get_arg()) && + (temp->get_size() == this->get_size()); + } + +//------------------------------------------------------------------------------ +/// @brief Get the size of the buffer. +/// +/// @returns The size of the buffer. 
+//------------------------------------------------------------------------------ + size_t get_size() const { + return leaf_node::backend_cache[data_hash].size(); } }; @@ -620,6 +631,16 @@ void compile_index(std::ostringstream &stream, return num_columns; } +//------------------------------------------------------------------------------ +/// @brief Get the number of columns. +/// +/// @returns The number of columns in the constant. +//------------------------------------------------------------------------------ + size_t get_num_rows() const { + return leaf_node::backend_cache[data_hash].size() / + num_columns; + } + //------------------------------------------------------------------------------ /// @brief Evaluate the results of the piecewise constant. /// @@ -965,10 +986,11 @@ void compile_index(std::ostringstream &stream, //------------------------------------------------------------------------------ bool is_arg_match(shared_leaf x) { auto temp = piecewise_2D_cast(x); - return temp.get() && - this->left->is_match(temp->get_left()) && - this->right->is_match(temp->get_right()) && - (num_columns == this->get_num_columns()); + return temp.get() && + this->left->is_match(temp->get_left()) && + this->right->is_match(temp->get_right()) && + (temp->get_num_rows() == this->get_num_rows()) && + (temp->get_num_columns() == this->get_num_columns()); } //------------------------------------------------------------------------------ @@ -979,18 +1001,24 @@ void compile_index(std::ostringstream &stream, //------------------------------------------------------------------------------ bool is_row_match(shared_leaf x) { auto temp = piecewise_1D_cast(x); - return temp.get() && this->left->is_match(temp->get_arg()); + return temp.get() && + this->left->is_match(temp->get_arg()) && + (temp->get_size() == this->get_num_rows()); } //------------------------------------------------------------------------------ /// @brief Do the columns match. 
/// +/// The number of rows is the column dimension. +/// /// @params[in] x Node to match. /// @returns True if the column arguments match. //------------------------------------------------------------------------------ bool is_col_match(shared_leaf x) { auto temp = piecewise_1D_cast(x); - return temp.get() && this->right->is_match(temp->get_arg()); + return temp.get() && + this->right->is_match(temp->get_arg()) && + (temp->get_size() == this->get_num_columns()); } }; diff --git a/graph_framework/trigonometry.hpp b/graph_framework/trigonometry.hpp index 8a5672d..3f44986 100644 --- a/graph_framework/trigonometry.hpp +++ b/graph_framework/trigonometry.hpp @@ -601,6 +601,7 @@ namespace graph { pr2->get_right()); } +#if 1 // Combine 2D and 1D piecewise constants if a row or column matches. if (pr2.get() && pr2->is_row_match(this->left)) { backend::buffer result = pl1->evaluate(); @@ -631,6 +632,7 @@ namespace graph { pl2->get_left(), pl2->get_right()); } +#endif return this->shared_from_this(); } diff --git a/graph_tests/piecewise_test.cpp b/graph_tests/piecewise_test.cpp index c092726..222286d 100644 --- a/graph_tests/piecewise_test.cpp +++ b/graph_tests/piecewise_test.cpp @@ -243,27 +243,24 @@ template void piecewise_2D() { auto ay = graph::variable (1, ""); auto bx = graph::variable (1, ""); auto by = graph::variable (1, ""); - auto p1 = graph::piecewise_2D (std::vector ({static_cast (1.0), - static_cast (2.0), - static_cast (3.0), - static_cast (4.0)}), - 2, ax, ay); - auto p2 = graph::piecewise_2D (std::vector ({static_cast (2.0), - static_cast (4.0), - static_cast (6.0), - static_cast (10.0)}), - 2, bx, by); - auto p3 = graph::piecewise_2D (std::vector ({static_cast (2.0), - static_cast (4.0), - static_cast (6.0), - static_cast (10.0)}), - 2, ax, ay); - auto p4 = graph::piecewise_1D (std::vector ({static_cast (2.0), - static_cast (4.0)}), - ax); - auto p5 = graph::piecewise_1D (std::vector ({static_cast (2.0), - static_cast (4.0)}), - ay); + auto p1 = 
graph::piecewise_2D (std::vector ({ + static_cast (1.0), static_cast (2.0), + static_cast (3.0), static_cast (4.0) + }), 2, ax, ay); + auto p2 = graph::piecewise_2D (std::vector ({ + static_cast (2.0), static_cast (4.0), + static_cast (6.0), static_cast (10.0) + }), 2, bx, by); + auto p3 = graph::piecewise_2D (std::vector ({ + static_cast (2.0), static_cast (4.0), + static_cast (6.0), static_cast (10.0) + }), 2, ax, ay); + auto p4 = graph::piecewise_1D (std::vector ({ + static_cast (2.0), static_cast (4.0) + }), ax); + auto p5 = graph::piecewise_1D (std::vector ({ + static_cast (2.0), static_cast (4.0) + }), ay); auto zero = graph::zero (); @@ -425,12 +422,14 @@ template void piecewise_2D() { graph::variable_cast(ay)}, {p1/p3}, {}, static_cast (0.5), 0.0); + bx->set(static_cast (1.5)); + by->set(static_cast (0.5)); compile ({graph::variable_cast(ax), graph::variable_cast(ay), graph::variable_cast(bx), graph::variable_cast(by)}, {graph::fma(p1, p3, p2)}, {}, - static_cast (10.0), 0.0); + static_cast (14.0), 0.0); compile ({graph::variable_cast(ax), graph::variable_cast(ay)}, {graph::pow(p1, p3)}, {}, @@ -453,22 +452,29 @@ template void piecewise_2D() { } // Test row combines. 
+ compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {p1}, {}, + static_cast (2.0), 0.0); + compile ({graph::variable_cast(ax)}, + {p4}, {}, + static_cast (2.0), 0.0); compile ({graph::variable_cast(ax), graph::variable_cast(ay)}, {p1 + p4}, {}, - static_cast (6.0), 0.0); + static_cast (4.0), 0.0); compile ({graph::variable_cast(ax), graph::variable_cast(ay)}, {p1 - p4}, {}, - static_cast (-2.0), 0.0); + static_cast (0.0), 0.0); compile ({graph::variable_cast(ax), graph::variable_cast(ay)}, {p1*p4}, {}, - static_cast (8.0), 0.0); + static_cast (4.0), 0.0); compile ({graph::variable_cast(ax), graph::variable_cast(ay)}, {p1/p4}, {}, - static_cast (0.5), 0.0); + static_cast (1.0), 0.0); compile ({graph::variable_cast(ax), graph::variable_cast(ay), graph::variable_cast(bx), @@ -479,19 +485,19 @@ template void piecewise_2D() { graph::variable_cast(ay)}, {graph::pow(p1, p4)}, {}, static_cast (std::pow(static_cast (2.0), - static_cast (4.0))), 0.0); + static_cast (2.0))), 0.0); if constexpr (jit::is_complex ()) { compile ({graph::variable_cast(ax), graph::variable_cast(ay)}, {graph::atan(p1, p4)}, {}, - static_cast (std::atan(static_cast (4.0) / + static_cast (std::atan(static_cast (2.0) / static_cast (2.0))), 0.0); } else { compile ({graph::variable_cast(ax), graph::variable_cast(ay)}, {graph::atan(p1, p4)}, {}, - static_cast (std::atan2(static_cast (4.0), + static_cast (std::atan2(static_cast (2.0), static_cast (2.0))), 0.0); } @@ -500,42 +506,42 @@ template void piecewise_2D() { compile ({graph::variable_cast(ax), graph::variable_cast(ay)}, {p1 + p5}, {}, - static_cast (4.0), 0.0); + static_cast (6.0), 0.0); compile ({graph::variable_cast(ax), graph::variable_cast(ay)}, {p1 - p5}, {}, - static_cast (0.0), 0.0); + static_cast (-2.0), 0.0); compile ({graph::variable_cast(ax), graph::variable_cast(ay)}, {p1*p5}, {}, - static_cast (4.0), 0.0); + static_cast (8.0), 0.0); compile ({graph::variable_cast(ax), graph::variable_cast(ay)}, {p1/p5}, {}, - 
static_cast (1.0), 0.0); + static_cast (0.5), 0.0); compile ({graph::variable_cast(ax), graph::variable_cast(ay), graph::variable_cast(bx), graph::variable_cast(by)}, {graph::fma(p1, p5, p2)}, {}, - static_cast (6.0), 0.0); + static_cast (14.0), 0.0); compile ({graph::variable_cast(ax), graph::variable_cast(ay)}, {graph::pow(p1, p5)}, {}, static_cast (std::pow(static_cast (2.0), - static_cast (2.0))), 0.0); + static_cast (4.0))), 0.0); if constexpr (jit::is_complex ()) { compile ({graph::variable_cast(ax), graph::variable_cast(ay)}, {graph::atan(p1, p5)}, {}, - static_cast (std::atan(static_cast (2.0) / + static_cast (std::atan(static_cast (4.0) / static_cast (2.0))), 0.0); } else { compile ({graph::variable_cast(ax), graph::variable_cast(ay)}, {graph::atan(p1, p5)}, {}, - static_cast (std::atan2(static_cast (2.0), + static_cast (std::atan2(static_cast (4.0), static_cast (2.0))), 0.0); } @@ -547,6 +553,51 @@ template void piecewise_2D() { 2, ax, bx); assert(graph::constant_cast(pc).get() && "Expected a constant."); + + auto prc = graph::piecewise_1D (std::vector ({ + static_cast (1.0), + static_cast (2.0), + static_cast (3.0) + }), ax); + auto pcc = graph::piecewise_1D (std::vector ({ + static_cast (1.0), + static_cast (2.0), + static_cast (3.0) + }), ay); + auto p2Dc = graph::piecewise_2D (std::vector ({ + static_cast (1.0), static_cast (2.0), + static_cast (3.0), static_cast (4.0), + static_cast (5.0), static_cast (6.0) + }), 2, ax, ay); + + auto row_test = prc + p2Dc; + auto row_test_cast = graph::piecewise_2D_cast(row_test); + assert(row_test_cast.get() && "Expected a 2D piecewise node.."); + + auto col_test = pcc + p2Dc; + auto col_test_cast = graph::add_cast(col_test); + assert(col_test_cast.get() && "Expected an add node."); + + ax->set(static_cast (2.5)); + ay->set(static_cast (1.5)); + compile ({graph::variable_cast(ax)}, + {prc}, {}, + static_cast (3.0), 0.0); + compile ({graph::variable_cast(ay)}, + {pcc}, {}, + static_cast (2.0), 0.0); + compile 
({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {p2Dc}, {}, + static_cast (6.0), 0.0); + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {row_test}, {}, + static_cast (9.0), 0.0); + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {col_test}, {}, + static_cast (8.0), 0.0); } //------------------------------------------------------------------------------ -- GitLab From 1d5ca448690b0a0f7ab5ada91673b590f0b794a3 Mon Sep 17 00:00:00 2001 From: Mark Cianciosa Date: Wed, 10 Jul 2024 12:57:55 -0400 Subject: [PATCH 62/63] Remove left over macros to diaable the row column constant reduction. --- graph_framework/arithmetic.hpp | 8 -------- graph_framework/math.hpp | 2 -- graph_framework/trigonometry.hpp | 2 -- 3 files changed, 12 deletions(-) diff --git a/graph_framework/arithmetic.hpp b/graph_framework/arithmetic.hpp index 3724701..ad30bd8 100644 --- a/graph_framework/arithmetic.hpp +++ b/graph_framework/arithmetic.hpp @@ -194,7 +194,6 @@ namespace graph { pr2->get_right()); } -#if 1 // Combine 2D and 1D piecewise constants if a row or column matches. if (pr2.get() && pr2->is_row_match(this->left)) { backend::buffer result = pl1->evaluate(); @@ -225,7 +224,6 @@ namespace graph { pl2->get_left(), pl2->get_right()); } -#endif // Idenity reductions. if (this->left->is_match(this->right)) { @@ -710,7 +708,6 @@ namespace graph { pr2->get_right()); } -#if 1 // Combine 2D and 1D piecewise constants if a row or column matches. if (pr2.get() && pr2->is_row_match(this->left)) { backend::buffer result = pl1->evaluate(); @@ -741,7 +738,6 @@ namespace graph { pl2->get_left(), pl2->get_right()); } -#endif // Common factor reduction. If the left and right are both muliply nodes check // for a common factor. So you can change a*b - a*c -> a*(b - c). @@ -1306,7 +1302,6 @@ namespace graph { pr2->get_right()); } -#if 1 // Combine 2D and 1D piecewise constants if a row or column matches. 
if (pr2.get() && pr2->is_row_match(this->left)) { backend::buffer result = pl1->evaluate(); @@ -1337,7 +1332,6 @@ namespace graph { pl2->get_left(), pl2->get_right()); } -#endif // Move constants to the left. if (is_constant_promotable(this->right, this->left)) { @@ -2091,7 +2085,6 @@ namespace graph { pr2->get_right()); } -#if 1 // Combine 2D and 1D piecewise constants if a row or column matches. if (pr2.get() && pr2->is_row_match(this->left)) { backend::buffer result = pl1->evaluate(); @@ -2122,7 +2115,6 @@ namespace graph { pl2->get_left(), pl2->get_right()); } -#endif if (this->left->is_match(this->right)) { return one (); diff --git a/graph_framework/math.hpp b/graph_framework/math.hpp index ed05de0..dca53cb 100644 --- a/graph_framework/math.hpp +++ b/graph_framework/math.hpp @@ -898,7 +898,6 @@ namespace graph { pr2->get_right()); } -#if 1 // Combine 2D and 1D piecewise constants if a row or column matches. if (pr2.get() && pr2->is_row_match(this->left)) { backend::buffer result = pl1->evaluate(); @@ -929,7 +928,6 @@ namespace graph { pl2->get_left(), pl2->get_right()); } -#endif auto lp = pow_cast(this->left); // Only run this reduction if the right is an integer constant value. diff --git a/graph_framework/trigonometry.hpp b/graph_framework/trigonometry.hpp index 3f44986..8a5672d 100644 --- a/graph_framework/trigonometry.hpp +++ b/graph_framework/trigonometry.hpp @@ -601,7 +601,6 @@ namespace graph { pr2->get_right()); } -#if 1 // Combine 2D and 1D piecewise constants if a row or column matches. if (pr2.get() && pr2->is_row_match(this->left)) { backend::buffer result = pl1->evaluate(); @@ -632,7 +631,6 @@ namespace graph { pl2->get_left(), pl2->get_right()); } -#endif return this->shared_from_this(); } -- GitLab From 92881d34c4b48c260566e95b7e41f59eb3c74c1f Mon Sep 17 00:00:00 2001 From: Mark Cianciosa Date: Wed, 10 Jul 2024 15:14:04 -0400 Subject: [PATCH 63/63] Remove host sync that is not needed. 
--- graph_tests/physics_test.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/graph_tests/physics_test.cpp b/graph_tests/physics_test.cpp index c46c1ba..c9611e0 100644 --- a/graph_tests/physics_test.cpp +++ b/graph_tests/physics_test.cpp @@ -606,7 +606,6 @@ template void test_efit() { for (size_t i = 0; i < 10000; i++) { solve.step(); - solve.sync_host(); } solve.sync_host(); -- GitLab