From 71ad61f28aaa04472bb77da2c8fc7d144bbaab55 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Wed, 28 May 2025 14:58:14 -0400 Subject: [PATCH] Clean commented out cuda random code. Restruct reduction of random numbers. Repleace query functions with concepts. --- graph_framework/arithmetic.hpp | 17 +++-- graph_framework/backend.hpp | 23 ++++--- graph_framework/cpu_context.hpp | 10 +-- graph_framework/cuda_context.hpp | 28 ++++----- graph_framework/math.hpp | 14 ++--- graph_framework/node.hpp | 4 +- graph_framework/output.hpp | 18 +++--- graph_framework/piecewise.hpp | 56 ++++++++++------- graph_framework/random.hpp | 20 ++---- graph_framework/register.hpp | 103 ++++++++++--------------------- graph_framework/trigonometry.hpp | 6 +- graph_tests/jit_test.cpp | 4 +- graph_tests/math_test.cpp | 2 +- graph_tests/physics_test.cpp | 4 +- graph_tests/piecewise_test.cpp | 16 ++--- graph_tests/random_test.cpp | 94 ++++++++++++++++++++++++++-- 16 files changed, 229 insertions(+), 190 deletions(-) diff --git a/graph_framework/arithmetic.hpp b/graph_framework/arithmetic.hpp index 4b72eaf..51f0d90 100644 --- a/graph_framework/arithmetic.hpp +++ b/graph_framework/arithmetic.hpp @@ -893,7 +893,6 @@ namespace graph { // Idenity reductions. auto l = constant_cast(this->left); if (this->left->is_match(this->right)) { - auto l = constant_cast(this->left); if (l.get() && l->is(0)) { return this->left; } @@ -2502,21 +2501,21 @@ namespace graph { stream << " " << registers[this] << " = "; if constexpr (SAFE_MATH) { stream << "(" << registers[l.get()] << " == "; - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { jit::add_type (stream); stream << "(0, 0)"; } else { stream << "0"; } stream << " || " << registers[r.get()] << " == "; - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { jit::add_type (stream); stream << "(0, 0)"; } else { stream << "0"; } stream << ") ? "; - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { jit::add_type (stream); stream << "(0, 0)"; } else { @@ -3494,14 +3493,14 @@ namespace graph { stream << " " << registers[this] << " = "; if constexpr (SAFE_MATH) { stream << registers[l.get()] << " == "; - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { jit::add_type (stream); stream << "(0, 0)"; } else { stream << "0"; } stream << " ? "; - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { jit::add_type (stream); stream << "(0, 0)"; } else { @@ -5069,14 +5068,14 @@ namespace graph { stream << " " << registers[this] << " = "; if constexpr (SAFE_MATH) { stream << "(" << registers[l.get()] << " == "; - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { jit::add_type (stream); stream << "(0, 0)"; } else { stream << "0"; } stream << " || " << registers[m.get()] << " == "; - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { jit::add_type (stream); stream << "(0, 0)"; } else { @@ -5084,7 +5083,7 @@ namespace graph { } stream << ") ? " << registers[r.get()] << " : "; } - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { stream << registers[l.get()] << "*" << registers[m.get()] << " + " << registers[r.get()]; diff --git a/graph_framework/backend.hpp b/graph_framework/backend.hpp index aab71c5..0ad9222 100644 --- a/graph_framework/backend.hpp +++ b/graph_framework/backend.hpp @@ -255,9 +255,8 @@ namespace backend { //------------------------------------------------------------------------------ /// @brief Take erfi. //------------------------------------------------------------------------------ - template - typename std::enable_if (), void>::type erfi() { - for (D &d : memory) { + void erfi() requires(jit::complex_scalar) { + for (T &d : memory) { d = special::erfi(d); } } @@ -278,7 +277,7 @@ namespace backend { //------------------------------------------------------------------------------ bool is_normal() const { for (const T &x : memory) { - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { if (std::isnan(std::real(x)) || std::isinf(std::real(x)) || std::isnan(std::imag(x)) || std::isinf(std::imag(x))) { return false; @@ -597,7 +596,7 @@ namespace backend { const size_t num_rows = x.size(); for (size_t i = 0; i < num_rows; i++) { for (size_t j = 0; j < num_colmns; j++) { - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { memory[i*num_colmns + j] = std::atan(x[i]/memory[i*num_colmns + j]); } else { memory[i*num_colmns + j] = std::atan2(x[i], memory[i*num_colmns + j]); @@ -613,7 +612,7 @@ namespace backend { const size_t num_rows = size(); for (size_t i = 0; i < num_rows; i++) { for (size_t j = 0; j < num_colmns; j++) { - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { m[i*num_colmns + j] = std::atan(x[i*num_colmns + j]/memory[i]); } else { m[i*num_colmns + j] = std::atan2(x[i*num_colmns + j], memory[i]); @@ -641,7 +640,7 @@ namespace backend { const size_t num_rows = x.size(); for (size_t i = 0; i < num_colmns; i++) { for (size_t j = 0; j < num_rows; j++) { - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { memory[i*num_colmns + j] = std::atan(x[j]/memory[i*num_colmns + j]); } else { memory[i*num_colmns + j] = std::atan2(x[j], memory[i*num_colmns + j]); @@ -657,7 +656,7 @@ namespace backend { const size_t num_rows = size(); for (size_t i = 0; i < num_rows; i++) { for (size_t j = 0; j < num_colmns; j++) { - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { m[i*num_colmns + j] = std::atan(x[i*num_colmns + j]/memory[j]); } else { m[i*num_colmns + j] = std::atan2(x[i*num_colmns + j], memory[j]); @@ -918,7 +917,7 @@ namespace backend { inline buffer fma(buffer &a, buffer &b, buffer &c) { - constexpr bool use_fma = !jit::is_complex () && + constexpr bool use_fma = !jit::complex_scalar && #ifdef FP_FAST_FMA true; #else @@ -1100,7 +1099,7 @@ namespace backend { if (y.size() == 1) { const T right = y.at(0); for (size_t i = 0, ie = x.size(); i < ie; i++) { - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { x[i] = std::atan(right/x[i]); } else { x[i] = std::atan2(right, x[i]); @@ -1110,7 +1109,7 @@ namespace backend { } else if (x.size() == 1) { const T left = x.at(0); for (size_t i = 0, ie = y.size(); i < ie; i++) { - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { y[i] = std::atan(y[i]/left); } else { y[i] = std::atan2(y[i], left); @@ -1122,7 +1121,7 @@ namespace backend { assert(x.size() == y.size() && "Left and right sizes are incompatable."); for (size_t i = 0, ie = x.size(); i < ie; i++) { - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { x[i] = std::atan(y[i]/x[i]); } else { x[i] = std::atan2(y[i], x[i]); diff --git a/graph_framework/cpu_context.hpp b/graph_framework/cpu_context.hpp index c16d00e..66a3fd5 100644 --- a/graph_framework/cpu_context.hpp +++ b/graph_framework/cpu_context.hpp @@ -311,7 +311,7 @@ namespace gpu { return [run, begin, end] () mutable { run(); - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { return *std::max_element(begin, end, [] (const T a, const T b) { return std::abs(a) < std::abs(b); @@ -346,7 +346,7 @@ namespace gpu { const graph::output_nodes &nodes) { for (auto &out : nodes) { const T temp = kernel_arguments[out.get()][index]; - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { std::cout << std::real(temp) << " " << std::imag(temp) << " "; } else { std::cout << temp << " "; @@ -401,7 +401,7 @@ namespace gpu { void create_header(std::ostringstream &source_buffer) { source_buffer << "#include " << std::endl << "#include " << std::endl; - if (jit::is_complex ()) { + if (jit::complex_scalar) { source_buffer << "#include " << std::endl; source_buffer << "#include " << std::endl; } else { @@ -517,7 +517,7 @@ namespace gpu { source_buffer << " " << jit::to_string('v', in.get()); source_buffer << "[i] = "; if constexpr (SAFE_MATH) { - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { jit::add_type (source_buffer); source_buffer << " ("; source_buffer << "isnan(real(" << registers[a.get()] @@ -543,7 +543,7 @@ namespace gpu { source_buffer << " " << jit::to_string('o', out.get()); source_buffer << "[i] = "; if constexpr (SAFE_MATH) { - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { jit::add_type (source_buffer); source_buffer << " ("; source_buffer << "isnan(real(" << registers[a.get()] diff --git a/graph_framework/cuda_context.hpp b/graph_framework/cuda_context.hpp index 8b12d85..288b42b 100644 --- a/graph_framework/cuda_context.hpp +++ b/graph_framework/cuda_context.hpp @@ -387,16 +387,16 @@ namespace gpu { texture_desc.addressMode[0] = CU_TR_ADDRESS_MODE_BORDER; texture_desc.addressMode[1] = CU_TR_ADDRESS_MODE_BORDER; texture_desc.addressMode[2] = CU_TR_ADDRESS_MODE_BORDER; - if constexpr (jit::is_float ()) { + if constexpr (jit::float_base) { array_desc.Format = CU_AD_FORMAT_FLOAT; - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { array_desc.NumChannels = 2; } else { array_desc.NumChannels = 1; } } else { array_desc.Format = CU_AD_FORMAT_UNSIGNED_INT32; - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { array_desc.NumChannels = 4; } else { array_desc.NumChannels = 2; @@ -433,16 +433,16 @@ namespace gpu { texture_desc.addressMode[1] = CU_TR_ADDRESS_MODE_BORDER; texture_desc.addressMode[2] = CU_TR_ADDRESS_MODE_BORDER; const size_t total = size[0]*size[1]; - if constexpr (jit::is_float ()) { + if constexpr (jit::float_base) { array_desc.Format = CU_AD_FORMAT_FLOAT; - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { array_desc.NumChannels = 2; } else { array_desc.NumChannels = 1; } } else { array_desc.Format = CU_AD_FORMAT_UNSIGNED_INT32; - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { array_desc.NumChannels = 4; } else { array_desc.NumChannels = 2; @@ -582,7 +582,7 @@ namespace gpu { wait(); for (auto &out : nodes) { const T temp = reinterpret_cast (kernel_arguments[out.get()])[index]; - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { std::cout << std::real(temp) << " " << std::imag(temp) << " "; } else { std::cout << temp << " "; @@ -651,13 +651,13 @@ namespace gpu { << " return _buffer[index];" << std::endl << " }" << std::endl << "};" << std::endl; - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { source_buffer << "#define CUDA_DEVICE_CODE" << std::endl << "#define M_PI " << M_PI << std::endl << "#include " << std::endl << "#include " << std::endl; #ifdef USE_CUDA_TEXTURES - if constexpr (jit::is_float ()) { + if constexpr (jit::float_base) { source_buffer << "static __inline__ __device__ complex to_cmp_float(float2 p) {" << std::endl << " return "; @@ -673,7 +673,7 @@ namespace gpu { << std::endl << "}" << std::endl; } - } else if constexpr (jit::is_double ()) { + } else if constexpr (jit::double_base) { source_buffer << "static __inline__ __device__ double to_double(uint2 p) {" << std::endl << " return __hiloint2double(p.y, p.x);" @@ -858,7 +858,7 @@ namespace gpu { } source_buffer << "index] = "; if constexpr (SAFE_MATH) { - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { jit::add_type (source_buffer); source_buffer << " ("; source_buffer << "isnan(real(" << registers[a.get()] @@ -889,7 +889,7 @@ namespace gpu { } source_buffer << "index] = "; if constexpr (SAFE_MATH) { - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { jit::add_type (source_buffer); source_buffer << " ("; source_buffer << "isnan(real(" << registers[a.get()] @@ -932,13 +932,13 @@ namespace gpu { source_buffer << " const unsigned int k = threadIdx.x%32;" << std::endl; source_buffer << " if (i < " << size << ") {" << std::endl; source_buffer << " " << jit::type_to_string () << " sub_max = "; - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { source_buffer << "abs(input[i]);" << std::endl; } else { source_buffer << "input[i];" << std::endl; } source_buffer << " for (size_t index = i + 1024; index < " << size <<"; index += 1024) {" << std::endl; - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { source_buffer << " sub_max = max(sub_max, abs(input[index]));" << std::endl; } else { source_buffer << " sub_max = max(sub_max, input[index]);" << std::endl; diff --git a/graph_framework/math.hpp b/graph_framework/math.hpp index 1b2d982..9568ca5 100644 --- a/graph_framework/math.hpp +++ b/graph_framework/math.hpp @@ -440,11 +440,11 @@ namespace graph { jit::add_type (stream); stream << " " << registers[this] << " = "; if constexpr (SAFE_MATH) { - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { stream << "real("; } stream << registers[a.get()]; - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { stream << ")"; } stream << " < 709.8 ? "; @@ -452,16 +452,12 @@ namespace graph { stream << "exp(" << registers[a.get()] << ")"; if constexpr (SAFE_MATH) { stream << " : "; - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { jit::add_type (stream); stream << "("; } - if constexpr (jit::is_float ()) { - stream << std::numeric_limits::max(); - } else { - stream << std::numeric_limits::max(); - } - if constexpr (jit::is_complex ()) { + stream << jit::max_base (); + if constexpr (jit::complex_scalar) { stream << ")"; } } diff --git a/graph_framework/node.hpp b/graph_framework/node.hpp index 5ad0853..b6a1839 100644 --- a/graph_framework/node.hpp +++ b/graph_framework/node.hpp @@ -465,13 +465,13 @@ namespace graph { const T temp = this->evaluate().at(0); stream << " " << registers[this] << " = "; - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { jit::add_type (stream); } stream << temp; this->endline(stream, usage); #else - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { registers[this] = jit::get_type_string () + "(" + jit::format_to_string(this->evaluate().at(0)) + ")"; diff --git a/graph_framework/output.hpp b/graph_framework/output.hpp index ac93b81..a09cd9e 100644 --- a/graph_framework/output.hpp +++ b/graph_framework/output.hpp @@ -171,9 +171,9 @@ namespace output { /// Data sizes. std::array count; /// Get the ray dimension size. - const size_t ray_dim_size = 1 + jit::is_complex (); + static constexpr size_t ray_dim_size = 1 + jit::complex_scalar; /// The NetCDF type. - const nc_type type = jit::is_float () ? NC_FLOAT : NC_DOUBLE; + static constexpr nc_type type = jit::float_base ? NC_FLOAT : NC_DOUBLE; //------------------------------------------------------------------------------ /// @brief Struct to map variables to a gpu buffer. @@ -213,7 +213,7 @@ namespace output { //------------------------------------------------------------------------------ data_set(const result_file &result) { sync.lock(); - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { if (NC_NOERR != nc_inq_dimid(result.get_ncid(), "ray_dim_cplx", &ray_dim)) { @@ -367,8 +367,8 @@ namespace output { for (variable &var : variables) { sync.lock(); - if constexpr (jit::is_float ()) { - if constexpr (jit::is_complex ()) { + if constexpr (jit::float_base) { + if constexpr (jit::complex_scalar) { check_error(nc_put_vara_float(result.get_ncid(), var.id, start.data(), @@ -382,7 +382,7 @@ namespace output { var.buffer)); } } else { - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { check_error(nc_put_vara_double(result.get_ncid(), var.id, start.data(), @@ -428,8 +428,8 @@ namespace output { }; sync.lock(); - if constexpr (jit::is_float ()) { - if constexpr (jit::is_complex ()) { + if constexpr (jit::float_base) { + if constexpr (jit::complex_scalar) { check_error(nc_get_varm_float(result.get_ncid(), ref.id, ref_start.data(), @@ -447,7 +447,7 @@ namespace output { ref.buffer)); } } else { - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { check_error(nc_get_varm_double(result.get_ncid(), ref.id, ref_start.data(), diff --git a/graph_framework/piecewise.hpp b/graph_framework/piecewise.hpp index 651c6d9..612e4b2 100644 --- a/graph_framework/piecewise.hpp +++ b/graph_framework/piecewise.hpp @@ -32,19 +32,19 @@ void compile_index(std::ostringstream &stream, stream << "min(max((" << type << ")"; - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { stream << "real("; } stream << "((" << register_name << " - "; - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { stream << jit::get_type_string (); } stream << offset << ")/"; - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { stream << jit::get_type_string (); } stream << scale << ")"; - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { stream << ")"; } stream << ",(" << type << ")0),(" @@ -252,13 +252,13 @@ void compile_index(std::ostringstream &stream, stream << "const "; jit::add_type (stream); stream << " " << registers[leaf_node::caches.backends[data_hash].data()] << "[] = {"; - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { jit::add_type (stream); } stream << leaf_node::caches.backends[data_hash][0]; for (size_t i = 1; i < length; i++) { stream << ", "; - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { jit::add_type (stream); } stream << leaf_node::caches.backends[data_hash][i]; @@ -340,14 +340,18 @@ void compile_index(std::ostringstream &stream, stream << " " << registers[this] << " = "; #ifdef USE_CUDA_TEXTURES if constexpr (jit::use_cuda()) { - if constexpr (jit::is_float () && !jit::is_complex ()) { - stream << "tex1D ("; - } else if constexpr (jit::is_double () && !jit::is_complex ()) { - stream << "to_double(tex1D ("; - } else if constexpr (jit::is_float ()) { - stream << "to_cmp_float(tex1D ("; + if constexpr (float_base) { + if constexpr (complex_scalar) { + stream << "to_cmp_float(tex1D ("; + } else { + stream << "tex1D ("; + } } else { - stream << "to_cmp_double(tex1D ("; + if constexpr (complex_scalar) { + stream << "to_cmp_double(tex1D ("; + } else { + stream << "to_double(tex1D ("; + } } } #endif @@ -373,7 +377,7 @@ void compile_index(std::ostringstream &stream, compile_index (stream, registers[a.get()], length, scale, offset); #endif - if constexpr (jit::is_complex () || jit::is_double ()) { + if constexpr (jit::complex_scalar || jit::double_base) { stream << ")"; } stream << ")"; @@ -885,13 +889,13 @@ void compile_index(std::ostringstream &stream, stream << "const "; jit::add_type (stream); stream << " " << registers[leaf_node::caches.backends[data_hash].data()] << "[] = {"; - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { jit::add_type (stream); } stream << leaf_node::caches.backends[data_hash][0]; for (size_t i = 1; i < length; i++) { stream << ", "; - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { jit::add_type (stream); } stream << leaf_node::caches.backends[data_hash][i]; @@ -1022,14 +1026,18 @@ void compile_index(std::ostringstream &stream, stream << " " << registers[this] << " = "; #ifdef USE_CUDA_TEXTURES if constexpr (jit::use_cuda()) { - if constexpr (jit::is_float () && !jit::is_complex ()) { - stream << "tex2D ("; - } else if constexpr (jit::is_double () && !jit::is_complex ()) { - stream << "to_double(tex2D ("; - } else if constexpr (jit::is_float ()) { - stream << "to_cmp_float(tex2D ("; + if constexpr (float_base) { + if constexpr (complex_scalar) { + stream << "to_cmp_float(tex1D ("; + } else { + stream << "tex1D ("; + } } else { - stream << "to_cmp_double(tex2D ("; + if constexpr (complex_scalar) { + stream << "to_cmp_double(tex1D ("; + } else { + stream << "to_double(tex1D ("; + } } } #endif @@ -1068,7 +1076,7 @@ void compile_index(std::ostringstream &stream, compile_index (stream, registers[x.get()], num_rows, x_scale, x_offset); #endif - if constexpr (jit::is_complex () || jit::is_double ()) { + if constexpr (jit::complex_scalar || jit::double_base) { stream << ")"; } stream << ")"; diff --git a/graph_framework/random.hpp b/graph_framework/random.hpp index d405666..f049a93 100644 --- a/graph_framework/random.hpp +++ b/graph_framework/random.hpp @@ -28,11 +28,7 @@ namespace graph { //------------------------------------------------------------------------------ struct mt_state { /// State array. -//#ifdef USE_CUDA -// uint32_t array[624]; -//#else std::array array; -//#endif /// State index. uint16_t index; #ifdef USE_CUDA @@ -103,11 +99,7 @@ namespace graph { int &avail_const_mem) { if (visited.find(this) == visited.end()) { stream << "struct mt_state {" << std::endl -//#ifdef USE_CUDA -// << " uint32_t array[624];" << std::endl -//#else << " array array;" << std::endl -//#endif << " uint16_t index;" << std::endl #ifdef USE_CUDA << " uint16_t padding[3];" << std::endl @@ -243,11 +235,7 @@ namespace graph { mt_state initalize_state(const uint32_t seed) { mt_state state; state.array[0] = seed; -#ifdef USE_CUDA - for (uint16_t i = 1; i < 624; i++) { -#else for (uint16_t i = 1, ie = state.array.size(); i < ie; i++) { -#endif state.array[i] = 1812433253U*(state.array[i - 1]^(state.array[i - 1] >> 30)) + i; } state.index = 0; @@ -453,13 +441,15 @@ namespace graph { //------------------------------------------------------------------------------ /// @brief Querey if the nodes match. /// +/// Arithmetic and math operations on random number umber distributions have +/// the effect of changing the distribution. For instance rand1 + rand2 will +/// to a pyramid shaped distribution function. Assume random numbers never +/// match as a consequnce. +/// /// @param[in] x Other graph to check if it is a match. /// @returns True if the nodes are a match. //------------------------------------------------------------------------------ virtual bool is_match(shared_leaf x) { - if (this == x.get()) { - return true; - } return false; } diff --git a/graph_framework/register.hpp b/graph_framework/register.hpp index b3dcf5e..2201451 100644 --- a/graph_framework/register.hpp +++ b/graph_framework/register.hpp @@ -32,71 +32,18 @@ namespace jit { template concept scalar = float_scalar || std::integral; -/// Verbose output. - static bool verbose = USE_VERBOSE; - -//------------------------------------------------------------------------------ -/// @brief Test if a type is complex. -/// -/// @tparam BASE Base type. -/// @tparam T Type to check against. -/// -/// @returns A constant expression true or false type. -//------------------------------------------------------------------------------ - template - constexpr bool is_complex() { - return std::is_same>::value; - } - -//------------------------------------------------------------------------------ -/// @brief Test if the base type is float. -/// -/// @tparam BASE Base type. -/// @tparam T Type to check against. -/// -/// @returns A constant expression true or false type. -//------------------------------------------------------------------------------ - template - constexpr bool is_base() { - return is_complex () || std::is_same::value; - } - -//------------------------------------------------------------------------------ -/// @brief Test if the base type is float. -/// -/// @tparam T Base type of the calculation. -/// -/// @returns A constant expression true or false type. -//------------------------------------------------------------------------------ - template - constexpr bool is_float() { - return is_base (); - } +/// float base concept. + template + concept float_base = std::same_as || + std::same_as>; -//------------------------------------------------------------------------------ -/// @brief Test if the base type is double. -/// -/// @tparam T Base type of the calculation. -/// -/// @returns A constant expression true or false type. -//------------------------------------------------------------------------------ - template - constexpr bool is_double() { - return is_base (); - } +/// Double base concept. + template + concept double_base = std::same_as || + std::same_as>; -//------------------------------------------------------------------------------ -/// @brief Test if a type is complex. -/// -/// @tparam T Base type of the calculation. -/// -/// @returns A constant expression true or false type. -//------------------------------------------------------------------------------ - template - constexpr bool is_complex() { - return is_complex () || - is_complex (); - } +/// Verbose output. + static bool verbose = USE_VERBOSE; //------------------------------------------------------------------------------ /// @brief Convert a base type to a string. @@ -107,9 +54,9 @@ namespace jit { //------------------------------------------------------------------------------ template std::string type_to_string() { - if constexpr (is_float ()) { + if constexpr (float_base) { return "float"; - } else if constexpr (is_double ()) { + } else { return "double"; } } @@ -133,7 +80,7 @@ namespace jit { template constexpr bool use_metal() { #if USE_METAL - return is_float() && !is_complex (); + return float_base && !complex_scalar; #else return false; #endif @@ -195,7 +142,7 @@ namespace jit { //------------------------------------------------------------------------------ template std::string get_type_string() { - if constexpr (is_complex ()) { + if constexpr (complex_scalar) { if constexpr (use_cuda()) { return "cuda::std::complex<" + type_to_string () + ">"; } else { @@ -227,13 +174,29 @@ namespace jit { //------------------------------------------------------------------------------ template constexpr int max_digits10() { - if constexpr (is_float ()) { + if constexpr (float_base) { return std::numeric_limits::max_digits10; } else { return std::numeric_limits::max_digits10; } } +//------------------------------------------------------------------------------ +/// @brief The maximum value for a base type. +/// +/// @tparam T Base type of the calculation. +/// +/// @returns The maximum number of digits needed. +//------------------------------------------------------------------------------ + template + constexpr int max_base() { + if constexpr (float_base) { + return std::numeric_limits::max(); + } else { + return std::numeric_limits::max(); + } + } + //------------------------------------------------------------------------------ /// @brief Convert a value to a string while avoiding locale. /// @@ -252,7 +215,7 @@ namespace jit { end = std::to_chars(buffer.begin(), buffer.end(), value, 16).ptr; - } else if constexpr (is_complex ()) { + } else if constexpr (complex_scalar) { return format_to_string(std::real(value)) + "," + format_to_string(std::imag(value)); } else { @@ -314,7 +277,7 @@ namespace jit { /// @param[in] right Right hand side. //------------------------------------------------------------------------------ bool operator() (const T &left, const T &right) const { - if constexpr (is_complex ()) { + if constexpr (complex_scalar) { return std::abs(left) < std::abs(right); } else { return left < right; diff --git a/graph_framework/trigonometry.hpp b/graph_framework/trigonometry.hpp index 20cbb95..74153b6 100644 --- a/graph_framework/trigonometry.hpp +++ b/graph_framework/trigonometry.hpp @@ -704,7 +704,7 @@ namespace graph { registers[this] = jit::to_string('r', this); stream << " const "; jit::add_type (stream); - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { stream << " " << registers[this] << " = atan(" << registers[r.get()] << "/" << registers[l.get()]; @@ -831,7 +831,7 @@ namespace graph { //------------------------------------------------------------------------------ template shared_leaf atan(const L l, - shared_leaf r) { + shared_leaf r) { return atan(constant (static_cast (l)), r); } @@ -847,7 +847,7 @@ namespace graph { //------------------------------------------------------------------------------ template shared_leaf atan(shared_leaf l, - const R r) { + const R r) { return atan(l, constant (static_cast (r))); } diff --git a/graph_tests/jit_test.cpp b/graph_tests/jit_test.cpp index ad82843..1f2eb2e 100644 --- a/graph_tests/jit_test.cpp +++ b/graph_tests/jit_test.cpp @@ -24,7 +24,7 @@ //------------------------------------------------------------------------------ template void check(const T test, const T tolarance) { - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { assert(std::real(test) <= std::real(tolarance) && "Real GPU and CPU values differ."); assert(std::imag(test) <= std::imag(tolarance) && @@ -394,7 +394,7 @@ void run_dispersion_tests() { run_dispersion_test> (slab_eq, 1.4E10); } else if constexpr (jit::use_metal ()) { run_dispersion_test> (slab_eq, 5.0E9); - } else if constexpr (jit::is_complex ()){ + } else if constexpr (jit::complex_scalar){ run_dispersion_test> (slab_eq, 1.5E11); } else { run_dispersion_test> (slab_eq, 5.1E10); diff --git a/graph_tests/math_test.cpp b/graph_tests/math_test.cpp index 4f4c783..c92272f 100644 --- a/graph_tests/math_test.cpp +++ b/graph_tests/math_test.cpp @@ -584,7 +584,7 @@ template void run_tests() { test_exp (); test_pow (); test_log (); - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { test_erfi (); } } diff --git a/graph_tests/physics_test.cpp b/graph_tests/physics_test.cpp index 732cd6b..f8b6d55 100644 --- a/graph_tests/physics_test.cpp +++ b/graph_tests/physics_test.cpp @@ -470,7 +470,7 @@ void test_cold_plasma_cutoffs() { kx->set(1, static_cast (0.0)); t->set(0, static_cast (0.0)); t->set(1, static_cast (0.0)); - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { solve.init(x, 2.8E-29); } else { solve.init(x, 5.0E-30); @@ -486,7 +486,7 @@ void test_cold_plasma_cutoffs() { // Solve for X-Mode and O-Mode wave numbers. kx->set(0, static_cast (500.0)); // O-Mode kx->set(1, static_cast (1500.0)); // X-Mode - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { solve.init(kx, 2.2E-30); } else { solve.init(kx); diff --git a/graph_tests/piecewise_test.cpp b/graph_tests/piecewise_test.cpp index db736ec..d213a7f 100644 --- a/graph_tests/piecewise_test.cpp +++ b/graph_tests/piecewise_test.cpp @@ -29,7 +29,7 @@ //------------------------------------------------------------------------------ template void check(const T test, const T tolarance) { - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { assert(std::real(test) <= std::real(tolarance) && "Real GPU and CPU values differ."); assert(std::imag(test) <= std::imag(tolarance) && @@ -205,7 +205,7 @@ template void piecewise_1D() { graph::variable_cast(b)}, {graph::fma(p1, p3, p2)}, {}, static_cast (10.0), 0.0); - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { compile ({graph::variable_cast(a)}, {graph::pow(p1, p3)}, {}, static_cast (16.0), 2.0E-15); @@ -214,7 +214,7 @@ template void piecewise_1D() { {graph::pow(p1, p3)}, {}, static_cast (16.0), 0.0); } - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { compile ({graph::variable_cast(a)}, {graph::atan(p1, p3)}, {}, static_cast (std::atan(static_cast (4.0) / @@ -466,7 +466,7 @@ template void piecewise_2D() { graph::variable_cast(by)}, {graph::fma(p1, p3, p2)}, {}, static_cast (14.0), 0.0); - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { compile ({graph::variable_cast(ax), graph::variable_cast(ay)}, {graph::pow(p1, p3)}, {}, @@ -477,7 +477,7 @@ template void piecewise_2D() { {graph::pow(p1, p3)}, {}, static_cast (16.0), 0.0); } - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { compile ({graph::variable_cast(ax), graph::variable_cast(ay)}, {graph::atan(p1, p3)}, {}, @@ -528,7 +528,7 @@ template void piecewise_2D() { {graph::pow(p1, p4)}, {}, static_cast (std::pow(static_cast (2.0), static_cast (2.0))), 0.0); - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { compile ({graph::variable_cast(ax), graph::variable_cast(ay)}, {graph::atan(p1, p4)}, {}, @@ -567,7 +567,7 @@ template void piecewise_2D() { graph::variable_cast(by)}, {graph::fma(p1, p5, p2)}, {}, static_cast (14.0), 0.0); - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { compile ({graph::variable_cast(ax), graph::variable_cast(ay)}, {graph::pow(p1, p5)}, {}, @@ -578,7 +578,7 @@ template void piecewise_2D() { {graph::pow(p1, p5)}, {}, static_cast (16.0), 0.0); } - if constexpr (jit::is_complex ()) { + if constexpr (jit::complex_scalar) { compile ({graph::variable_cast(ax), graph::variable_cast(ay)}, {graph::atan(p1, p5)}, {}, diff --git a/graph_tests/random_test.cpp b/graph_tests/random_test.cpp index 2fb7f9a..346bc7b 100644 --- a/graph_tests/random_test.cpp +++ b/graph_tests/random_test.cpp @@ -39,8 +39,9 @@ T autocorrelation(const std::vector &sequence, /// @brief Run test with a specified backend. /// /// @tparam T Base type of the calculation. +/// @tparam N Number of random numbers to use. //------------------------------------------------------------------------------ -template void run_test() { +template void test_dist() { auto state = graph::random_state (jit::context::random_state_size, 0); auto random = graph::random (graph::random_state_cast(state)); const T max = 1.0; @@ -63,6 +64,89 @@ template void run_test() { } } +//------------------------------------------------------------------------------ +/// @brief Test graph properties of random numbers. +//------------------------------------------------------------------------------ +template void test_graph() { + auto state = graph::random_state (jit::context::random_state_size, 0); + auto random = graph::random (graph::random_state_cast(state)); + +// r + r -> r + r + assert(graph::add_cast(random + random).get() && "Expected add node."); +// r + 0.0 -> r + assert(graph::random_cast(random + 0.0).get() && "Expected random node."); +// r - r -> r - r + assert(graph::subtract_cast(random - random).get() && + "Expected subtract node."); +// r - 0.0 -> r + assert(graph::random_cast(random - 0.0).get() && "Expected random node."); +// r*r -> r*r + assert(graph::multiply_cast(random*random).get() && + "Expected multiply node."); +// 1*r -> r + assert(graph::random_cast(1.0*random).get() && "Expected random node."); +// r/r -> r/r + assert(graph::divide_cast(random/random).get() && "Expected divide node."); +// r/1 -> r + assert(graph::random_cast(random/1.0).get() && "Expected random node."); +// fma(r,r,1) -> fma(r,r,1) + assert(graph::fma_cast(graph::fma(random, random, 1.0)).get() && + "Expected fma node."); +// fma(r,2,r) -> fma(2,r,r) + assert(graph::fma_cast(graph::fma(random, 2.0, random)).get() && + "Expected fma node."); +// fma(2.0,r,r) -> fma(2.0,r,r) + assert(graph::fma_cast(graph::fma(2.0, random, random)).get() && + "Expected fma node."); +// fma(r,r,0.0) -> r*r + assert(graph::multiply_cast(graph::fma(random, random, 0.0)).get() && + "Expected multiply node."); +// fma(r,1.0,r) -> r + r + assert(graph::add_cast(graph::fma(random, 1.0, random)).get() && + "Expected add node."); +// sqrt(r) -> sqrt(r) + assert(graph::sqrt_cast(graph::sqrt(random)).get() && + "Expected sqrt node."); +// exp(r) -> exp(r) + assert(graph::exp_cast(graph::exp(random)).get() && + "Expected exp node."); +// ln(r) -> ln(r) + assert(graph::log_cast(graph::log(random)).get() && + "Expected log node."); +// pow(r,r) -> pow(r,r) + assert(graph::pow_cast(graph::pow(random, random)).get() && + "Expected pow node."); +// pow(r,1) -> r + assert(graph::random_cast(graph::pow(random, 1.0)).get() && + "Expected random node."); + + if constexpr(jit::complex_scalar) { +// efi(r) -> efi(r) + assert(graph::erfi_cast(graph::erfi(random)).get() && + "Expected erfi node."); + } +// sin(r) -> sin(r) + assert(graph::sin_cast(graph::sin(random)).get() && + "Expected sin node."); +// cos(r) -> cos(r) + assert(graph::cos_cast(graph::cos(random)).get() && + "Expected cos node."); +// atan(r,r) -> atan(r,r) + assert(graph::atan_cast(graph::atan(random, random)).get() && + "Expected atan node."); +} + +//------------------------------------------------------------------------------ +/// @brief Run tests. +/// +/// @tparam T Base type of the calculation. +/// @tparam N Number of random numbers to use. +//------------------------------------------------------------------------------ +template void run_tests() { + test_dist (); + test_graph (); +} + //------------------------------------------------------------------------------ /// @brief Main program of the test. /// @@ -74,10 +158,10 @@ int main(int argc, const char * argv[]) { (void)argc; (void)argv; - run_test (); - run_test (); - run_test, 1000000> (); - run_test, 1000000> (); + run_tests (); + run_tests (); + run_tests, 1000000> (); + run_tests, 1000000> (); END_GPU } -- GitLab