From d20627d2b73c96299d4a7366c371fef059b47bcc Mon Sep 17 00:00:00 2001 From: cianciosa Date: Wed, 12 Feb 2025 19:27:25 -0500 Subject: [PATCH 1/3] Refactor chained fma reductions and expansion to handle arbitrary length chains. --- graph_framework.xcodeproj/project.pbxproj | 2 + graph_framework/arithmetic.hpp | 290 +++++++++++++--------- graph_framework/node.hpp | 12 + graph_tests/arithmetic_test.cpp | 57 +++++ 4 files changed, 237 insertions(+), 124 deletions(-) diff --git a/graph_framework.xcodeproj/project.pbxproj b/graph_framework.xcodeproj/project.pbxproj index 84ebb0b..79e3aa3 100644 --- a/graph_framework.xcodeproj/project.pbxproj +++ b/graph_framework.xcodeproj/project.pbxproj @@ -2530,6 +2530,7 @@ "-lLLVMCGData", "-lLLVMSandboxIR", "-lLLVMFrontendAtomic", + "-lLLVMObjectYAML", "-lLLVMAArch64CodeGen", "-lclangFrontend", "-lclangBasic", @@ -2630,6 +2631,7 @@ "-lLLVMCGData", "-lLLVMSandboxIR", "-lLLVMFrontendAtomic", + "-lLLVMObjectYAML", "-lLLVMAArch64CodeGen", "-lclangFrontend", "-lclangBasic", diff --git a/graph_framework/arithmetic.hpp b/graph_framework/arithmetic.hpp index e6efea1..6ba0725 100644 --- a/graph_framework/arithmetic.hpp +++ b/graph_framework/arithmetic.hpp @@ -1685,6 +1685,112 @@ namespace graph { class multiply_node final : public branch_node { private: //------------------------------------------------------------------------------ +/// @brief Try to reduce patterns of constant times nested fma nodes. +/// +/// c1*fma(...,x,c2) -> fma(...,x,c3) +/// +/// @param[in] trial The fma node to try to reduce. +/// @returns The reduced node or null if it could not reduce the node. 
+//------------------------------------------------------------------------------ + shared_leaf + reduce_nested_fma_times_constant(shared_leaf trial) { + auto temp = fma_cast(trial); + if (temp.get()) { + if (is_constant_combineable(this->left, temp->get_left()) && + is_constant_combineable(this->left, temp->get_right())) { + return fma(this->left*temp->get_left(), + temp->get_middle(), + this->left*temp->get_right()); + } else { + auto temp2 = reduce_nested_fma_times_constant(temp->get_left()); + if (temp2.get()) { + return fma(temp2, + temp->get_middle(), + this->left*temp->get_right()); + } + } + } + return null_leaf (); + } + +//------------------------------------------------------------------------------ +/// @brief Try to expand nested fma node. +/// +/// fma(...,x,c2)*(c3 + x) -> fma(...,x,c4) +/// +/// @param[in] trial The fma node to try to expand. +/// @param[in] add The add node to try to expand. +/// @returns The expanded node or null if it could not expand the node. +//------------------------------------------------------------------------------ + shared_leaf + expand_nested_fma_times_add(shared_leaf trial, + shared_add add) { + auto temp = fma_cast(trial); + if (temp.get()) { + if (add->get_right()->is_match(temp->get_middle()) && + is_constant_combineable(add->get_left(), temp->get_right())) { + auto temp2 = expand_nested_fma_times_add2(temp->get_left(), + temp, add); + if (temp2.get()) { + return fma(temp2, + add->get_right(), + temp->get_right()*add->get_left()); + } else if (is_constant_combineable(add->get_left(), temp->get_left())) { + return fma(fma(temp->get_left(), + add->get_right(), + add->get_left()*temp->get_left() + temp->get_right()), + add->get_right(), + temp->get_right()*add->get_left()); + } + } + } + return null_leaf (); + } + +//------------------------------------------------------------------------------ +/// @brief Try to expand nested fma node. 
+/// +/// fma(...,x,c2)*(c3 + x) -> fma(...,x,c4) +/// +/// @param[in] trial The fma node to try to reduce. +/// @param[in] last The last fma node. +/// @param[in] add The add node to try to expand. +/// @returns The expanded node or null if it could not expand the node. +//------------------------------------------------------------------------------ + shared_leaf + expand_nested_fma_times_add2(shared_leaf trial, + shared_leaf last, + shared_add add) { + auto temp = fma_cast(trial); + auto temp2 = fma_cast(last); + assert(temp2.get() && "Assumed a fma node."); + if (temp.get()) { + if (add->get_right()->is_match(temp->get_middle()) && + is_constant_combineable(add->get_left(), temp->get_left()) && + is_constant_combineable(add->get_left(), temp->get_right())) { + + return fma(fma(temp->get_left(), + add->get_right(), + add->get_left()*temp->get_left() + + temp->get_right()), + add->get_right(), + add->get_left()*temp->get_right() + + temp2->get_right()); + } else { + auto temp3 = expand_nested_fma_times_add2(temp->get_left(), + temp, add); + if (temp3.get()) { + return fma(temp3, + add->get_right(), + add->get_left()*temp->get_right() + + temp2->get_right()); + } + } + } + return null_leaf (); + } + +//------------------------------------------------------------------------------ /// @brief Convert node pointer to a string. /// /// @param[in] l Left node pointer. 
@@ -1904,31 +2010,13 @@ namespace graph { return (this->left*rm->get_left())*rm->get_right(); } - auto rmlfma = fma_cast(rm->get_left()); - if (rmlfma.get()) { - if (is_constant_combineable(this->left, - rmlfma->get_left()) && - is_constant_combineable(this->left, - rmlfma->get_right())) { - return fma(this->left*rmlfma->get_left(), - rmlfma->get_middle(), - this->left*rmlfma->get_right())*rm->get_right(); - } - - auto rmlfmalfma = fma_cast(rmlfma->get_left()); - if (rmlfmalfma.get()) { - if (is_constant_combineable(this->left, - rmlfmalfma->get_left()) && - is_constant_combineable(this->left, - rmlfmalfma->get_right()) && - is_constant_combineable(this->left, rmlfma->get_right())) { - return fma(fma(this->left*rmlfmalfma->get_left(), - rmlfmalfma->get_middle(), - this->left*rmlfmalfma->get_right()), - rmlfma->get_middle(), - this->left*rmlfma->get_right())*rm->get_right(); - } - } +// c1*(fma(c2,x,c3)*y)-> fma(c4,x,c5)*y +// c1*(fma(fma(c2,x,c3),x,c4)*y)-> fma(fma(c5,x,c6),x,c7)*y +// c1*(fma(fma(fma(c2,x,c3),x,c4),x,c5)*y)-> fma(fma(fma(c6,x,c7),x,c8),x,c9)*y +// etc... 
+ auto temp = this->reduce_nested_fma_times_constant(rm->get_left()); + if (temp.get()) { + return temp*rm->get_right(); } } @@ -2326,76 +2414,24 @@ namespace graph { } } -// c3*fma(c1,a,c2) -> fma(c4,a,c5) - auto rfma = fma_cast(this->right); - if (rfma.get()) { - if (is_constant_combineable(this->left, rfma->get_left()) && - is_constant_combineable(this->left, rfma->get_right())) { - return fma(this->left*rfma->get_left(), - rfma->get_middle(), - this->left*rfma->get_right()); - } - - auto rfmalfma = fma_cast(rfma->get_left()); - if (rfmalfma.get()) { - if (is_constant_combineable(this->left, rfmalfma->get_left()) && - is_constant_combineable(this->left, rfmalfma->get_right()) && - is_constant_combineable(this->left, rfma->get_right())) { - return fma(fma(this->left*rfmalfma->get_left(), - rfmalfma->get_middle(), - this->left*rfmalfma->get_right()), - rfma->get_middle(), - this->left*rfma->get_right()); - } - - auto rfmalfmalfma = fma_cast(rfmalfma->get_left()); - if (rfmalfmalfma.get()) { - if (is_constant_combineable(this->left, rfmalfmalfma->get_left()) && - is_constant_combineable(this->left, rfmalfmalfma->get_right()) && - is_constant_combineable(this->left, rfmalfma->get_right()) && - is_constant_combineable(this->left, rfma->get_right())) { - return fma(fma(fma(this->left*rfmalfmalfma->get_left(), - rfmalfmalfma->get_middle(), - this->left*rfmalfmalfma->get_right()), - rfmalfma->get_middle(), - this->left*rfmalfma->get_right()), - rfma->get_middle(), - this->left*rfma->get_right()); - } - } - } +// c1*fma(c2,x,c3) -> fma(c4,x,c5) +// c1*fma(fma(c2,x,c3),x,c4) -> fma(fma(c5,x,c6),x,c7) +// c1*fma(fma(fma(c2,x,c3),x,c4),x,c5) -> fma(fma(fma(c6,x,c7),x,c8),x,c9) +// etc... 
+ auto fma_reduce = this->reduce_nested_fma_times_constant(this->right); + if (fma_reduce.get()) { + return fma_reduce; } // fma(c1,x,c2)*(c3 + x) -> fma(fma(c1,x,c4),x,c5) - auto lfma = fma_cast(this->left); - auto ra = add_cast(this->right); - if (lfma.get() && ra.get()) { - if (ra->get_right()->is_match(lfma->get_middle()) && - is_constant_combineable(ra->get_left(), lfma->get_left()) && - is_constant_combineable(ra->get_left(), lfma->get_right())) { - return fma(fma(lfma->get_left(), - ra->get_right(), - ra->get_left()*lfma->get_left() + lfma->get_right()), - ra->get_right(), - lfma->get_right()*ra->get_left()); - } - // fma(fma(c1,x,c2),x,c3)*(c4 + x) -> fma(fma(fma(c1,x,c5),x,c6),x,c7) - auto lfmalfma = fma_cast(lfma->get_left()); - if (ra->get_right()->is_match(lfma->get_middle()) && - ra->get_right()->is_match(lfmalfma->get_middle()) && - is_constant_combineable(ra->get_left(), lfma->get_right()) && - is_constant_combineable(ra->get_left(), lfmalfma->get_right()) && - is_constant_combineable(ra->get_left(), lfmalfma->get_left())) { - return fma(fma(fma(lfmalfma->get_left(), - ra->get_right(), - ra->get_left()*lfmalfma->get_left() + - lfmalfma->get_right()), - ra->get_right(), - ra->get_left()*lfmalfma->get_right() + - lfma->get_right()), - ra->get_right(), - ra->get_left()*lfma->get_right()); +// etc... + auto ra = add_cast(this->right); + if (ra.get()) { + auto fma_expand = this->expand_nested_fma_times_add(this->left, + ra); + if (fma_expand.get()) { + return fma_expand; } } @@ -3661,6 +3697,42 @@ namespace graph { class fma_node final : public triple_node { private: //------------------------------------------------------------------------------ +/// @brief Reduced nested fma nodes. +/// +/// fma(...,a - c1,c2) -> fma(...,a,c3) +/// +/// @param[in] sub The sub node to try to expand. +/// @returns The reduced node or null if it could not reduce the node. 
+//------------------------------------------------------------------------------ + shared_leaf + reduce_nested_fma(shared_subtract sub) { + auto temp = fma_cast(this->left); + if (temp.get()) { + if (is_constant_combineable(sub->get_right(), temp->get_left()) && + is_constant_combineable(sub->get_right(), temp->get_right()) && + is_constant_combineable(this->right, temp->get_right()) && + temp->get_middle()->is_match(sub->get_left())) { + return fma(fma(temp->get_left(), + sub->get_left(), + temp->get_right() - temp->get_left()*sub->get_right()), + sub->get_left(), + this->right - temp->get_right()*sub->get_right()); + } else { + if (temp->get_middle()->is_match(sub->get_left()) && + is_constant_combineable(sub->get_right(), this->right)) { + auto temp2 = temp->reduce_nested_fma(sub); + if (temp2.get()) { + return fma(temp2, + sub->get_left(), + this->right - temp->get_right()*sub->get_right()); + } + } + } + } + return this->shared_from_this(); + } + +//------------------------------------------------------------------------------ /// @brief Convert node pointer to a string. /// /// @param[in] l Left node pointer. 
@@ -3787,39 +3859,9 @@ namespace graph { this->right - this->left*ms->get_right()); } - auto lfma = fma_cast(this->left); - if (lfma.get()) { - if (is_constant_combineable(ms->get_right(), lfma->get_left()) && - is_constant_combineable(ms->get_right(), lfma->get_right()) && - is_constant_combineable(this->right, lfma->get_right()) && - lfma->get_middle()->is_match(ms->get_left())) { - return fma(fma(lfma->get_left(), - ms->get_left(), - lfma->get_right() - lfma->get_left()*ms->get_right()), - ms->get_left(), - this->right - lfma->get_right()*ms->get_right()); - } - - auto lfmalfma = fma_cast(lfma->get_left()); - if (lfmalfma.get()) { - if (lfma->get_middle()->is_match(ms->get_left()) && - lfmalfma->get_middle()->is_match(ms->get_left()) && - is_constant_combineable(ms->get_right(), lfmalfma->get_left()) && - is_constant_combineable(ms->get_right(), lfmalfma->get_right()) && - is_constant_combineable(ms->get_right(), lfma->get_right()) && - is_constant_combineable(ms->get_right(), this->right)) { - return fma(fma(fma(lfmalfma->get_left(), - ms->get_left(), - lfmalfma->get_right() - - lfmalfma->get_left()*ms->get_right()), - ms->get_left(), - lfma->get_right() - - lfmalfma->get_right()*ms->get_right()), - ms->get_left(), - this->right - - lfma->get_right()*ms->get_right()); - } - } + auto temp = this->reduce_nested_fma(ms); + if (temp.get() != this) { + return temp; } } diff --git a/graph_framework/node.hpp b/graph_framework/node.hpp index 87b9d85..7d6c9f6 100644 --- a/graph_framework/node.hpp +++ b/graph_framework/node.hpp @@ -335,6 +335,18 @@ namespace graph { /// Convenience type alias for shared leaf nodes. template using shared_leaf = std::shared_ptr>; +//------------------------------------------------------------------------------ +/// @brief Create a null leaf. +/// +/// @tparam T Base type of the calculation. +/// @tparam SAFE_MATH Use safe math operations. +/// +/// @returns A null leaf. 
+//------------------------------------------------------------------------------ + template + constexpr shared_leaf null_leaf() { + return shared_leaf (); + } /// Convenience type alias for a vector of output nodes. template using output_nodes = std::vector>; diff --git a/graph_tests/arithmetic_test.cpp b/graph_tests/arithmetic_test.cpp index feb0ff2..24155a5 100644 --- a/graph_tests/arithmetic_test.cpp +++ b/graph_tests/arithmetic_test.cpp @@ -1967,6 +1967,17 @@ template void test_multiply() { v1, 1.0)) && "Expected fma(fma(fma(2,x,23),x,30,x,1))"); + +// c1*(fma(fma(fma(c2,x,c3),x,c4),x,c5)*y) -> fma(fma(fma(c6,x,c7),x,c8),x,c9)*y + auto consume2 = 10.0*(graph::fma(graph::fma(graph::fma(5.0,v1,0.4),v1,0.3),v1,0.3)*v2); + assert(consume2->is_match(graph::fma(graph::fma(graph::fma(50.0, + v1, + 4.0), + v1, + 3.0), + v1, + 3.0)*v2) && + "Expected fma(fma(fma(50,x,4),x,3),x,3)*y"); } //------------------------------------------------------------------------------ @@ -3777,6 +3788,52 @@ template void test_fma() { var_a, -10.0)) && "Expected fma(fma(fma(2,x,16),x,-10),x,-10)"); +/* +// fma(fma(c1,a,c2),b - c3,fma(c4,a,c5) -> fma(fma(c6,a,c8),b,fma(c9,a,c10)) + auto gather3 = graph::fma(graph::fma(2.0, + var_a, + 20.0), + var_b - 2.0, + graph::fma(2.0, + var_a, + 21.0)); + assert(gather3->is_match(graph::fma(graph::fma(2.0,var_a,20.0),var_b,graph::fma(2.0,var_a,-19.0))) && + "Expected fma(fma(2,x,20),y,fma(2,x,-19))"); + +// fma(fma(fma(fma(c1,a,c2),a,c3),a,c4),b - c5,fma(fma(fma(c6,a,c7),a,c8),a,c9)) -> +// fma(fma(fma(fma(c10,a,c11),a,c12),a,c13),b,fma(fma(fma(c14,a,c15),a,c16),a,c17)) + auto gather4 = graph::fma(graph::fma(graph::fma(graph::fma(2.0, + var_a, + 20.0), + var_a, + 30.0), + var_a, + 50.0), + var_b - 2.0, + graph::fma(graph::fma(graph::fma(2.0, + var_a, + 21.0), + var_a, + 31.0), + var_a, + 51.0)); + assert(gather3->is_match(graph::fma(graph::fma(graph::fma(graph::fma(2.0, + var_a, + 20.0), + var_a, + 30.0), + var_a, + 50.0), + var_b , + 
graph::fma(graph::fma(graph::fma(2.0, + var_a, + -19.0), + var_a, + -29.0), + var_a, + -49.0))) && + "Expected fma(fma(fma(fma(2,x,20),x,30),x,50),b,fma(fma(fma(2,x,-19),-29),-49)"); + */ } //------------------------------------------------------------------------------ -- GitLab From 596cf102f06e8d9d453fffe08cf817cf8d1a5d0c Mon Sep 17 00:00:00 2001 From: cianciosa Date: Fri, 14 Feb 2025 13:22:24 -0500 Subject: [PATCH 2/3] Avoid creating new piecewise constants by prefactoring the scale and offsets out. --- graph_framework/arithmetic.hpp | 120 +++++++------- graph_framework/equilibrium.hpp | 262 +++++++++++++++++-------------- graph_framework/math.hpp | 62 +++++--- graph_framework/piecewise.hpp | 226 ++++++++++++++++++++------ graph_framework/trigonometry.hpp | 46 +++--- graph_tests/arithmetic_test.cpp | 28 ++-- graph_tests/efit_test.cpp | 10 +- graph_tests/math_test.cpp | 6 +- graph_tests/piecewise_test.cpp | 30 ++-- 9 files changed, 490 insertions(+), 300 deletions(-) diff --git a/graph_framework/arithmetic.hpp b/graph_framework/arithmetic.hpp index 6ba0725..4b72eaf 100644 --- a/graph_framework/arithmetic.hpp +++ b/graph_framework/arithmetic.hpp @@ -193,9 +193,11 @@ namespace graph { auto pr1 = piecewise_1D_cast(this->right); if (pl1.get() && (r.get() || pl1->is_arg_match(this->right))) { - return piecewise_1D(this->evaluate(), pl1->get_arg()); + return piecewise_1D(this->evaluate(), pl1->get_arg(), + pl1->get_scale(), pl1->get_offset()); } else if (pr1.get() && (l.get() || pr1->is_arg_match(this->left))) { - return piecewise_1D(this->evaluate(), pr1->get_arg()); + return piecewise_1D(this->evaluate(), pr1->get_arg(), + pr1->get_scale(), pr1->get_offset()); } auto pl2 = piecewise_2D_cast(this->left); @@ -204,13 +206,13 @@ namespace graph { if (pl2.get() && (r.get() || pl2->is_arg_match(this->right))) { return piecewise_2D(this->evaluate(), pl2->get_num_columns(), - pl2->get_left(), - pl2->get_right()); + pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(), + 
pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset()); } else if (pr2.get() && (l.get() || pr2->is_arg_match(this->left))) { return piecewise_2D(this->evaluate(), pr2->get_num_columns(), - pr2->get_left(), - pr2->get_right()); + pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(), + pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset()); } // Combine 2D and 1D piecewise constants if a row or column matches. @@ -219,29 +221,29 @@ namespace graph { result.add_row(pr2->evaluate()); return piecewise_2D(result, pr2->get_num_columns(), - pr2->get_left(), - pr2->get_right()); + pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(), + pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset()); } else if (pr2.get() && pr2->is_col_match(this->left)) { backend::buffer result = pl1->evaluate(); result.add_col(pr2->evaluate()); return piecewise_2D(result, pr2->get_num_columns(), - pr2->get_left(), - pr2->get_right()); + pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(), + pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset()); } else if (pl2.get() && pl2->is_row_match(this->right)) { backend::buffer result = pl2->evaluate(); result.add_row(pr1->evaluate()); return piecewise_2D(result, pl2->get_num_columns(), - pl2->get_left(), - pl2->get_right()); + pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(), + pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset()); } else if (pl2.get() && pl2->is_col_match(this->right)) { backend::buffer result = pl2->evaluate(); result.add_col(pr1->evaluate()); return piecewise_2D(result, pl2->get_num_columns(), - pl2->get_left(), - pl2->get_right()); + pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(), + pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset()); } // Idenity reductions. 
@@ -916,9 +918,11 @@ namespace graph { auto pr1 = piecewise_1D_cast(this->right); if (pl1.get() && (r.get() || pl1->is_arg_match(this->right))) { - return piecewise_1D(this->evaluate(), pl1->get_arg()); + return piecewise_1D(this->evaluate(), pl1->get_arg(), + pl1->get_scale(), pl1->get_offset()); } else if (pr1.get() && (l.get() || pr1->is_arg_match(this->left))) { - return piecewise_1D(this->evaluate(), pr1->get_arg()); + return piecewise_1D(this->evaluate(), pr1->get_arg(), + pr1->get_scale(), pr1->get_offset()); } auto pl2 = piecewise_2D_cast(this->left); @@ -927,13 +931,13 @@ namespace graph { if (pl2.get() && (r.get() || pl2->is_arg_match(this->right))) { return piecewise_2D(this->evaluate(), pl2->get_num_columns(), - pl2->get_left(), - pl2->get_right()); + pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(), + pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset()); } else if (pr2.get() && (l.get() || pr2->is_arg_match(this->left))) { return piecewise_2D(this->evaluate(), pr2->get_num_columns(), - pr2->get_left(), - pr2->get_right()); + pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(), + pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset()); } // Combine 2D and 1D piecewise constants if a row or column matches. 
@@ -942,29 +946,29 @@ namespace graph { result.subtract_row(pr2->evaluate()); return piecewise_2D(result, pr2->get_num_columns(), - pr2->get_left(), - pr2->get_right()); + pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(), + pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset()); } else if (pr2.get() && pr2->is_col_match(this->left)) { backend::buffer result = pl1->evaluate(); result.subtract_col(pr2->evaluate()); return piecewise_2D(result, pr2->get_num_columns(), - pr2->get_left(), - pr2->get_right()); + pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(), + pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset()); } else if (pl2.get() && pl2->is_row_match(this->right)) { backend::buffer result = pl2->evaluate(); result.subtract_row(pr1->evaluate()); return piecewise_2D(result, pl2->get_num_columns(), - pl2->get_left(), - pl2->get_right()); + pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(), + pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset()); } else if (pl2.get() && pl2->is_col_match(this->right)) { backend::buffer result = pl2->evaluate(); result.subtract_col(pr1->evaluate()); return piecewise_2D(result, pl2->get_num_columns(), - pl2->get_left(), - pl2->get_right()); + pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(), + pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset()); } // (c1 + a) - c2 -> c3 + a // c1 - (c2 + a) -> c3 + a @@ -1866,9 +1870,11 @@ namespace graph { auto pr1 = piecewise_1D_cast(this->right); if (pl1.get() && (r.get() || pl1->is_arg_match(this->right))) { - return piecewise_1D(this->evaluate(), pl1->get_arg()); + return piecewise_1D(this->evaluate(), pl1->get_arg(), + pl1->get_scale(), pl1->get_offset()); } else if (pr1.get() && (l.get() || pr1->is_arg_match(this->left))) { - return piecewise_1D(this->evaluate(), pr1->get_arg()); + return piecewise_1D(this->evaluate(), pr1->get_arg(), + pr1->get_scale(), pr1->get_offset()); } auto pl2 = piecewise_2D_cast(this->left); @@ -1877,13 +1883,13 @@ 
namespace graph { if (pl2.get() && (r.get() || pl2->is_arg_match(this->right))) { return piecewise_2D(this->evaluate(), pl2->get_num_columns(), - pl2->get_left(), - pl2->get_right()); + pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(), + pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset()); } else if (pr2.get() && (l.get() || pr2->is_arg_match(this->left))) { return piecewise_2D(this->evaluate(), pr2->get_num_columns(), - pr2->get_left(), - pr2->get_right()); + pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(), + pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset()); } // Combine 2D and 1D piecewise constants if a row or column matches. @@ -1892,29 +1898,29 @@ namespace graph { result.multiply_row(pr2->evaluate()); return piecewise_2D(result, pr2->get_num_columns(), - pr2->get_left(), - pr2->get_right()); + pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(), + pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset()); } else if (pr2.get() && pr2->is_col_match(this->left)) { backend::buffer result = pl1->evaluate(); result.multiply_col(pr2->evaluate()); return piecewise_2D(result, pr2->get_num_columns(), - pr2->get_left(), - pr2->get_right()); + pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(), + pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset()); } else if (pl2.get() && pl2->is_row_match(this->right)) { backend::buffer result = pl2->evaluate(); result.multiply_row(pr1->evaluate()); return piecewise_2D(result, pl2->get_num_columns(), - pl2->get_left(), - pl2->get_right()); + pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(), + pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset()); } else if (pl2.get() && pl2->is_col_match(this->right)) { backend::buffer result = pl2->evaluate(); result.multiply_col(pr1->evaluate()); return piecewise_2D(result, pl2->get_num_columns(), - pl2->get_left(), - pl2->get_right()); + pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(), + pl2->get_right(), pl2->get_y_scale(), 
pl2->get_y_offset()); } // Move constants to the left. @@ -2798,9 +2804,11 @@ namespace graph { auto pr1 = piecewise_1D_cast(this->right); if (pl1.get() && (r.get() || pl1->is_arg_match(this->right))) { - return piecewise_1D(this->evaluate(), pl1->get_arg()); + return piecewise_1D(this->evaluate(), pl1->get_arg(), + pl1->get_scale(), pl1->get_offset()); } else if (pr1.get() && (l.get() || pr1->is_arg_match(this->left))) { - return piecewise_1D(this->evaluate(), pr1->get_arg()); + return piecewise_1D(this->evaluate(), pr1->get_arg(), + pr1->get_scale(), pr1->get_offset()); } auto pl2 = piecewise_2D_cast(this->left); @@ -2809,13 +2817,13 @@ namespace graph { if (pl2.get() && (r.get() || pl2->is_arg_match(this->right))) { return piecewise_2D(this->evaluate(), pl2->get_num_columns(), - pl2->get_left(), - pl2->get_right()); + pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(), + pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset()); } else if (pr2.get() && (l.get() || pr2->is_arg_match(this->left))) { return piecewise_2D(this->evaluate(), pr2->get_num_columns(), - pr2->get_left(), - pr2->get_right()); + pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(), + pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset()); } // Combine 2D and 1D piecewise constants if a row or column matches. 
@@ -2824,29 +2832,29 @@ namespace graph { result.divide_row(pr2->evaluate()); return piecewise_2D(result, pr2->get_num_columns(), - pr2->get_left(), - pr2->get_right()); + pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(), + pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset()); } else if (pr2.get() && pr2->is_col_match(this->left)) { backend::buffer result = pl1->evaluate(); result.divide_col(pr2->evaluate()); return piecewise_2D(result, pr2->get_num_columns(), - pr2->get_left(), - pr2->get_right()); + pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(), + pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset()); } else if (pl2.get() && pl2->is_row_match(this->right)) { backend::buffer result = pl2->evaluate(); result.divide_row(pr1->evaluate()); return piecewise_2D(result, pl2->get_num_columns(), - pl2->get_left(), - pl2->get_right()); + pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(), + pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset()); } else if (pl2.get() && pl2->is_col_match(this->right)) { backend::buffer result = pl2->evaluate(); result.divide_col(pr1->evaluate()); return piecewise_2D(result, pl2->get_num_columns(), - pl2->get_left(), - pl2->get_right()); + pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(), + pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset()); } if (this->left->is_match(this->right)) { diff --git a/graph_framework/equilibrium.hpp b/graph_framework/equilibrium.hpp index 4fd4621..6361f3b 100644 --- a/graph_framework/equilibrium.hpp +++ b/graph_framework/equilibrium.hpp @@ -905,6 +905,31 @@ namespace equilibrium { return std::make_shared> (); } +//------------------------------------------------------------------------------ +/// @brief Build a 1D spline. +/// +/// @tparam T Base type of the calculation. +/// @tparam SAFE_MATH Use safe math operations. +/// +/// @param[in] c Array of spline coeffiecents. +/// @param[in] x Spline argument. +/// @param[in] scale Scale factor for argument. 
+/// @param[in] offset Offset value for argument. +/// @returns The graph expression for a 1D spline. +//------------------------------------------------------------------------------ + template + graph::shared_leaf build_1D_spline(graph::output_nodes c, + graph::shared_leaf x, + const T scale, + const T offset) { + auto c3 = c[3]/(scale*scale*scale); + auto c2 = c[2]/(scale*scale) - static_cast (3.0)*offset*c[3]/(scale*scale*scale); + auto c1 = c[1]/scale - static_cast (2.0)*offset*c[2]/(scale*scale) + static_cast (3.0)*offset*offset*c[3]/(scale*scale*scale); + auto c0 = c[0] - offset*c[1]/scale + offset*offset*c[2]/(scale*scale) - offset*offset*offset*c[3]/(scale*scale*scale); + + return graph::fma(graph::fma(graph::fma(c3, x, c2), x, c1), x, c0); + } + //****************************************************************************** // 2D EFIT equilibrium. //****************************************************************************** @@ -921,9 +946,9 @@ namespace equilibrium { class efit final : public generic { private: /// Minimum psi. - graph::shared_leaf psimin; + const T psimin; /// Psi grid spacing. - graph::shared_leaf dpsi; + const T dpsi; // Temperature spline coefficients. /// Temperature c0. @@ -962,13 +987,13 @@ namespace equilibrium { graph::shared_leaf pres_scale; /// Minimum R. - graph::shared_leaf rmin; + const T rmin; /// R grid spacing. - graph::shared_leaf dr; + const T dr; /// Minimum Z. - graph::shared_leaf zmin; + const T zmin; /// Z grid spacing. - graph::shared_leaf dz; + const T dz; // Fpol spline coefficients. /// Fpol c0. @@ -1036,43 +1061,53 @@ namespace equilibrium { /// Cached magnetic field vector. graph::shared_vector b_cache; -/// Cached magnetic field vector. - graph::shared_leaf psi_norm_cache; +/// Cached magnetic flux. + graph::shared_leaf psi_cache; //------------------------------------------------------------------------------ /// @brief Build psi. /// -/// @param[in] r_norm The normalized radial position. 
-/// @param[in] z_norm The normalized z position. +/// @param[in] r The unnormalized radial position. +/// @param[in] r_scale Scale factor for r. +/// @param[in] r_offset Offset factor for r. +/// @param[in] z The unnormalized z position. +/// @param[in] z_scale Scale factor for z. +/// @param[in] z_offset Offset factor for z. /// @returns The psi value. //------------------------------------------------------------------------------ graph::shared_leaf - build_psi(graph::shared_leaf r_norm, - graph::shared_leaf z_norm) { - auto c00_temp = graph::piecewise_2D(c00, num_cols, r_norm, z_norm); - auto c01_temp = graph::piecewise_2D(c01, num_cols, r_norm, z_norm); - auto c02_temp = graph::piecewise_2D(c02, num_cols, r_norm, z_norm); - auto c03_temp = graph::piecewise_2D(c03, num_cols, r_norm, z_norm); - - auto c10_temp = graph::piecewise_2D(c10, num_cols, r_norm, z_norm); - auto c11_temp = graph::piecewise_2D(c11, num_cols, r_norm, z_norm); - auto c12_temp = graph::piecewise_2D(c12, num_cols, r_norm, z_norm); - auto c13_temp = graph::piecewise_2D(c13, num_cols, r_norm, z_norm); - - auto c20_temp = graph::piecewise_2D(c20, num_cols, r_norm, z_norm); - auto c21_temp = graph::piecewise_2D(c21, num_cols, r_norm, z_norm); - auto c22_temp = graph::piecewise_2D(c22, num_cols, r_norm, z_norm); - auto c23_temp = graph::piecewise_2D(c23, num_cols, r_norm, z_norm); - - auto c30_temp = graph::piecewise_2D(c30, num_cols, r_norm, z_norm); - auto c31_temp = graph::piecewise_2D(c31, num_cols, r_norm, z_norm); - auto c32_temp = graph::piecewise_2D(c32, num_cols, r_norm, z_norm); - auto c33_temp = graph::piecewise_2D(c33, num_cols, r_norm, z_norm); - - auto c0 = ((c03_temp*z_norm + c02_temp)*z_norm + c01_temp)*z_norm + c00_temp; - auto c1 = ((c13_temp*z_norm + c12_temp)*z_norm + c11_temp)*z_norm + c10_temp; - auto c2 = ((c23_temp*z_norm + c22_temp)*z_norm + c21_temp)*z_norm + c20_temp; - auto c3 = ((c33_temp*z_norm + c32_temp)*z_norm + c31_temp)*z_norm + c30_temp; + 
build_psi(graph::shared_leaf r, + const T r_scale, + const T r_offset, + graph::shared_leaf z, + const T z_scale, + const T z_offset) { + auto c00_temp = graph::piecewise_2D(c00, num_cols, r, r_scale, r_offset, z, z_scale, z_offset); + auto c01_temp = graph::piecewise_2D(c01, num_cols, r, r_scale, r_offset, z, z_scale, z_offset); + auto c02_temp = graph::piecewise_2D(c02, num_cols, r, r_scale, r_offset, z, z_scale, z_offset); + auto c03_temp = graph::piecewise_2D(c03, num_cols, r, r_scale, r_offset, z, z_scale, z_offset); + + auto c10_temp = graph::piecewise_2D(c10, num_cols, r, r_scale, r_offset, z, z_scale, z_offset); + auto c11_temp = graph::piecewise_2D(c11, num_cols, r, r_scale, r_offset, z, z_scale, z_offset); + auto c12_temp = graph::piecewise_2D(c12, num_cols, r, r_scale, r_offset, z, z_scale, z_offset); + auto c13_temp = graph::piecewise_2D(c13, num_cols, r, r_scale, r_offset, z, z_scale, z_offset); + + auto c20_temp = graph::piecewise_2D(c20, num_cols, r, r_scale, r_offset, z, z_scale, z_offset); + auto c21_temp = graph::piecewise_2D(c21, num_cols, r, r_scale, r_offset, z, z_scale, z_offset); + auto c22_temp = graph::piecewise_2D(c22, num_cols, r, r_scale, r_offset, z, z_scale, z_offset); + auto c23_temp = graph::piecewise_2D(c23, num_cols, r, r_scale, r_offset, z, z_scale, z_offset); + + auto c30_temp = graph::piecewise_2D(c30, num_cols, r, r_scale, r_offset, z, z_scale, z_offset); + auto c31_temp = graph::piecewise_2D(c31, num_cols, r, r_scale, r_offset, z, z_scale, z_offset); + auto c32_temp = graph::piecewise_2D(c32, num_cols, r, r_scale, r_offset, z, z_scale, z_offset); + auto c33_temp = graph::piecewise_2D(c33, num_cols, r, r_scale, r_offset, z, z_scale, z_offset); + + auto r_norm = (r - r_offset)/r_scale; + + auto c0 = build_1D_spline({c00_temp, c01_temp, c02_temp, c03_temp}, z, z_scale, z_offset); + auto c1 = build_1D_spline({c10_temp, c11_temp, c12_temp, c13_temp}, z, z_scale, z_offset); + auto c2 = build_1D_spline({c20_temp, c21_temp, c22_temp, 
c23_temp}, z, z_scale, z_offset); + auto c3 = build_1D_spline({c30_temp, c31_temp, c32_temp, c33_temp}, z, z_scale, z_offset); return ((c3*r_norm + c2)*r_norm + c1)*r_norm + c0; } @@ -1097,38 +1132,29 @@ namespace equilibrium { z_cache = z; auto r = graph::sqrt(x*x + y*y); - auto r_norm = (r - rmin)/dr; - auto z_norm = (z - zmin)/dz; - auto psi = build_psi(r_norm, z_norm); - psi_norm_cache = (psi - psimin)/dpsi; + psi_cache = build_psi(r, dr, rmin, z, dz, zmin); - auto n0_temp = graph::piecewise_1D(ne_c0, psi_norm_cache); - auto n1_temp = graph::piecewise_1D(ne_c1, psi_norm_cache); - auto n2_temp = graph::piecewise_1D(ne_c2, psi_norm_cache); - auto n3_temp = graph::piecewise_1D(ne_c3, psi_norm_cache); + auto n0_temp = graph::piecewise_1D(ne_c0, psi_cache, dpsi, psimin); + auto n1_temp = graph::piecewise_1D(ne_c1, psi_cache, dpsi, psimin); + auto n2_temp = graph::piecewise_1D(ne_c2, psi_cache, dpsi, psimin); + auto n3_temp = graph::piecewise_1D(ne_c3, psi_cache, dpsi, psimin); - ne_cache = ne_scale - * (((n3_temp*psi_norm_cache + n2_temp) * - psi_norm_cache + n1_temp)*psi_norm_cache + n0_temp); + ne_cache = ne_scale*build_1D_spline({n0_temp, n1_temp, n2_temp, n3_temp}, psi_cache, dpsi, psimin); - auto t0_temp = graph::piecewise_1D(te_c0, psi_norm_cache); - auto t1_temp = graph::piecewise_1D(te_c1, psi_norm_cache); - auto t2_temp = graph::piecewise_1D(te_c2, psi_norm_cache); - auto t3_temp = graph::piecewise_1D(te_c3, psi_norm_cache); + auto t0_temp = graph::piecewise_1D(te_c0, psi_cache, dpsi, psimin); + auto t1_temp = graph::piecewise_1D(te_c1, psi_cache, dpsi, psimin); + auto t2_temp = graph::piecewise_1D(te_c2, psi_cache, dpsi, psimin); + auto t3_temp = graph::piecewise_1D(te_c3, psi_cache, dpsi, psimin); - te_cache = te_scale - * (((t3_temp*psi_norm_cache + t2_temp) * - psi_norm_cache + t1_temp)*psi_norm_cache + t0_temp); + te_cache = te_scale*build_1D_spline({t0_temp, t1_temp, t2_temp, t3_temp}, psi_cache, dpsi, psimin); - auto p0_temp = 
graph::piecewise_1D(pres_c0, psi_norm_cache); - auto p1_temp = graph::piecewise_1D(pres_c1, psi_norm_cache); - auto p2_temp = graph::piecewise_1D(pres_c2, psi_norm_cache); - auto p3_temp = graph::piecewise_1D(pres_c3, psi_norm_cache); + auto p0_temp = graph::piecewise_1D(pres_c0, psi_cache, dpsi, psimin) - psimin; + auto p1_temp = graph::piecewise_1D(pres_c1, psi_cache, dpsi, psimin); + auto p2_temp = graph::piecewise_1D(pres_c2, psi_cache, dpsi, psimin); + auto p3_temp = graph::piecewise_1D(pres_c3, psi_cache, dpsi, psimin); - auto pressure = pres_scale - * (((p3_temp*psi_norm_cache + p2_temp) * - psi_norm_cache + p1_temp)*psi_norm_cache + p0_temp); + auto pressure = pres_scale*build_1D_spline({p0_temp, p1_temp, p2_temp, p3_temp}, psi_cache, dpsi, psimin); auto q = graph::constant (static_cast (1.60218E-19)); @@ -1137,17 +1163,16 @@ namespace equilibrium { auto phi = graph::atan(x, y); - auto br = psi->df(z)/r; + auto br = psi_cache->df(z)/r; - auto b0_temp = graph::piecewise_1D(fpol_c0, r_norm); - auto b1_temp = graph::piecewise_1D(fpol_c1, r_norm); - auto b2_temp = graph::piecewise_1D(fpol_c2, r_norm); - auto b3_temp = graph::piecewise_1D(fpol_c3, r_norm); + auto b0_temp = graph::piecewise_1D(fpol_c0, r, dr, rmin); + auto b1_temp = graph::piecewise_1D(fpol_c1, r, dr, rmin); + auto b2_temp = graph::piecewise_1D(fpol_c2, r, dr, rmin); + auto b3_temp = graph::piecewise_1D(fpol_c3, r, dr, rmin); - auto bp = (((b3_temp*r_norm + b2_temp) * - r_norm + b1_temp)*r_norm + b0_temp)/r; + auto bp = build_1D_spline({b0_temp, b1_temp, b2_temp, b3_temp}, r, dr, rmin)/r; - auto bz = -psi->df(r)/r; + auto bz = -psi_cache->df(r)/r; auto cos = graph::cos(phi); auto sin = graph::sin(phi); @@ -1205,8 +1230,8 @@ namespace equilibrium { /// @param[in] c32 Psi c32 spline coefficient. /// @param[in] c33 Psi c33 spline coefficient. 
//------------------------------------------------------------------------------ - efit(graph::shared_leaf psimin, - graph::shared_leaf dpsi, + efit(const T psimin, + const T dpsi, const backend::buffer te_c0, const backend::buffer te_c1, const backend::buffer te_c2, @@ -1222,10 +1247,10 @@ namespace equilibrium { const backend::buffer pres_c2, const backend::buffer pres_c3, graph::shared_leaf pres_scale, - graph::shared_leaf rmin, - graph::shared_leaf dr, - graph::shared_leaf zmin, - graph::shared_leaf dz, + const T rmin, + const T dr, + const T zmin, + const T dz, const backend::buffer fpol_c0, const backend::buffer fpol_c1, const backend::buffer fpol_c2, @@ -1376,7 +1401,7 @@ namespace equilibrium { workflow::manager work(device_number); solver::newton(work, { x_axis, z_axis - }, inputs, psi_norm_cache, static_cast (1.0E-30), 1000, static_cast (0.1)); + }, inputs, (psi_cache - psimin)/dpsi, static_cast (1.0E-30), 1000, static_cast (0.1)); work.add_item(inputs, {b_mod}, {}, "bmod_at_axis"); work.compile(); work.run(); @@ -1567,12 +1592,12 @@ namespace equilibrium { nc_close(ncid); sync.unlock(); - auto rmin = graph::constant (static_cast (rmin_value)); - auto dr = graph::constant (static_cast (dr_value)); - auto zmin = graph::constant (static_cast (zmin_value)); - auto dz = graph::constant (static_cast (dz_value)); - auto psimin = graph::constant (static_cast (psimin_value)); - auto dpsi = graph::constant (static_cast (dpsi_value)); + auto rmin = static_cast (rmin_value); + auto dr = static_cast (dr_value); + auto zmin = static_cast (zmin_value); + auto dz = static_cast (dz_value); + auto psimin = static_cast (psimin_value); + auto dpsi = static_cast (dpsi_value); auto pres_scale = graph::constant (static_cast (pres_scale_value)); auto ne_scale = graph::constant (static_cast (ne_scale_value)); auto te_scale = graph::constant (static_cast (te_scale_value)); @@ -1641,11 +1666,11 @@ namespace equilibrium { class vmec final : public generic { private: /// Minimum s 
on the half grid. - graph::shared_leaf sminh; + const T sminh; /// Minimum s on the full grid. - graph::shared_leaf sminf; + const T sminf; /// Change in s grid. - graph::shared_leaf ds; + const T ds; /// Sign of the jacobian. graph::shared_leaf signj; @@ -1814,13 +1839,13 @@ namespace equilibrium { /// @returns χ(s,u,v) //------------------------------------------------------------------------------ graph::shared_leaf - get_chi(graph::shared_leaf s_norm) { - auto c0_temp = graph::piecewise_1D(chi_c0, s_norm); - auto c1_temp = graph::piecewise_1D(chi_c1, s_norm); - auto c2_temp = graph::piecewise_1D(chi_c2, s_norm); - auto c3_temp = graph::piecewise_1D(chi_c3, s_norm); + get_chi(graph::shared_leaf s) { + auto c0_temp = graph::piecewise_1D(chi_c0, s, ds, sminf); + auto c1_temp = graph::piecewise_1D(chi_c1, s, ds, sminf); + auto c2_temp = graph::piecewise_1D(chi_c2, s, ds, sminf); + auto c3_temp = graph::piecewise_1D(chi_c3, s, ds, sminf); - return ((c3_temp*s_norm + c2_temp)*s_norm + c1_temp)*s_norm + c0_temp; + return build_1D_spline({c0_temp, c1_temp, c2_temp, c3_temp}, s, ds, sminf); } //------------------------------------------------------------------------------ @@ -1854,7 +1879,6 @@ namespace equilibrium { v_cache = v; auto s_norm_f = (s - sminf)/ds; - auto s_norm_h = (s - sminh)/ds; auto zero = graph::zero (); auto r = zero; @@ -1862,27 +1886,27 @@ namespace equilibrium { auto l = zero; for (size_t i = 0, ie = xm.size(); i < ie; i++) { - auto rmnc_c0_temp = graph::piecewise_1D(rmnc_c0[i], s_norm_f); - auto rmnc_c1_temp = graph::piecewise_1D(rmnc_c1[i], s_norm_f); - auto rmnc_c2_temp = graph::piecewise_1D(rmnc_c2[i], s_norm_f); - auto rmnc_c3_temp = graph::piecewise_1D(rmnc_c3[i], s_norm_f); - - auto zmns_c0_temp = graph::piecewise_1D(zmns_c0[i], s_norm_f); - auto zmns_c1_temp = graph::piecewise_1D(zmns_c1[i], s_norm_f); - auto zmns_c2_temp = graph::piecewise_1D(zmns_c2[i], s_norm_f); - auto zmns_c3_temp = graph::piecewise_1D(zmns_c3[i], s_norm_f); - - auto 
lmns_c0_temp = graph::piecewise_1D(lmns_c0[i], s_norm_h); - auto lmns_c1_temp = graph::piecewise_1D(lmns_c1[i], s_norm_h); - auto lmns_c2_temp = graph::piecewise_1D(lmns_c2[i], s_norm_h); - auto lmns_c3_temp = graph::piecewise_1D(lmns_c3[i], s_norm_h); - - auto rmnc = ((rmnc_c3_temp*s_norm_f + rmnc_c2_temp)*s_norm_f + - rmnc_c1_temp)*s_norm_f + rmnc_c0_temp; - auto zmns = ((zmns_c3_temp*s_norm_f + zmns_c2_temp)*s_norm_f + - zmns_c1_temp)*s_norm_f + zmns_c0_temp; - auto lmns = ((lmns_c3_temp*s_norm_h + lmns_c2_temp)*s_norm_h + - lmns_c1_temp)*s_norm_h + lmns_c0_temp; + auto rmnc_c0_temp = graph::piecewise_1D(rmnc_c0[i], s, ds, sminf); + auto rmnc_c1_temp = graph::piecewise_1D(rmnc_c1[i], s, ds, sminf); + auto rmnc_c2_temp = graph::piecewise_1D(rmnc_c2[i], s, ds, sminf); + auto rmnc_c3_temp = graph::piecewise_1D(rmnc_c3[i], s, ds, sminf); + + auto zmns_c0_temp = graph::piecewise_1D(zmns_c0[i], s, ds, sminf); + auto zmns_c1_temp = graph::piecewise_1D(zmns_c1[i], s, ds, sminf); + auto zmns_c2_temp = graph::piecewise_1D(zmns_c2[i], s, ds, sminf); + auto zmns_c3_temp = graph::piecewise_1D(zmns_c3[i], s, ds, sminf); + + auto lmns_c0_temp = graph::piecewise_1D(lmns_c0[i], s, ds, sminh); + auto lmns_c1_temp = graph::piecewise_1D(lmns_c1[i], s, ds, sminh); + auto lmns_c2_temp = graph::piecewise_1D(lmns_c2[i], s, ds, sminh); + auto lmns_c3_temp = graph::piecewise_1D(lmns_c3[i], s, ds, sminh); + + auto rmnc = build_1D_spline({rmnc_c0_temp, rmnc_c1_temp, rmnc_c2_temp, rmnc_c3_temp}, + s, ds, sminf); + auto zmns = build_1D_spline({zmns_c0_temp, zmns_c1_temp, zmns_c2_temp, zmns_c3_temp}, + s, ds, sminf); + auto lmns = build_1D_spline({lmns_c0_temp, lmns_c1_temp, lmns_c2_temp, lmns_c3_temp}, + s, ds, sminh); auto m = graph::constant (xm[i]); auto n = graph::constant (xn[i]); @@ -1954,9 +1978,9 @@ namespace equilibrium { /// @param[in] xm Poloidal mode numbers. /// @param[in] xn Toroidal mode numbers. 
//------------------------------------------------------------------------------ - vmec(graph::shared_leaf sminh, - graph::shared_leaf sminf, - graph::shared_leaf ds, + vmec(const T sminh, + const T sminf, + const T ds, graph::shared_leaf dphi, graph::shared_leaf signj, const backend::buffer chi_c0, @@ -2367,9 +2391,9 @@ namespace equilibrium { nc_close(ncid); sync.unlock(); - auto sminf = graph::constant (static_cast (sminf_value)); - auto sminh = graph::constant (static_cast (sminh_value)); - auto ds = graph::constant (static_cast (ds_value)); + auto sminf = static_cast (sminf_value); + auto sminh = static_cast (sminh_value); + auto ds = static_cast (ds_value); auto dphi = graph::constant (static_cast (dphi_value)); auto signj = graph::constant (static_cast (signj_value)); diff --git a/graph_framework/math.hpp b/graph_framework/math.hpp index bda00d4..1b2d982 100644 --- a/graph_framework/math.hpp +++ b/graph_framework/math.hpp @@ -75,15 +75,17 @@ namespace graph { auto ap1 = piecewise_1D_cast(this->arg); if (ap1.get()) { return piecewise_1D(this->evaluate(), - ap1->get_arg()); + ap1->get_arg(), + ap1->get_scale(), + ap1->get_offset()); } auto ap2 = piecewise_2D_cast(this->arg); if (ap2.get()) { return piecewise_2D(this->evaluate(), ap2->get_num_columns(), - ap2->get_left(), - ap2->get_right()); + ap2->get_left(), ap2->get_x_scale(), ap2->get_x_offset(), + ap2->get_right(), ap2->get_y_scale(), ap2->get_y_offset()); } // Handle casses like sqrt(c*x) where c is constant or cases like @@ -371,15 +373,17 @@ namespace graph { auto ap1 = piecewise_1D_cast(this->arg); if (ap1.get()) { return piecewise_1D(this->evaluate(), - ap1->get_arg()); + ap1->get_arg(), + ap1->get_scale(), + ap1->get_offset()); } auto ap2 = piecewise_2D_cast(this->arg); if (ap2.get()) { return piecewise_2D(this->evaluate(), ap2->get_num_columns(), - ap2->get_left(), - ap2->get_right()); + ap2->get_left(), ap2->get_x_scale(), ap2->get_x_offset(), + ap2->get_right(), ap2->get_y_scale(), 
ap2->get_y_offset()); } // Reduce exp(log(x)) -> x @@ -638,15 +642,17 @@ namespace graph { auto ap1 = piecewise_1D_cast(this->arg); if (ap1.get()) { return piecewise_1D(this->evaluate(), - ap1->get_arg()); + ap1->get_arg(), + ap1->get_scale(), + ap1->get_offset()); } auto ap2 = piecewise_2D_cast(this->arg); if (ap2.get()) { return piecewise_2D(this->evaluate(), ap2->get_num_columns(), - ap2->get_left(), - ap2->get_right()); + ap2->get_left(), ap2->get_x_scale(), ap2->get_x_offset(), + ap2->get_right(), ap2->get_y_scale(), ap2->get_y_offset()); } // Reduce log(exp(x)) -> x @@ -900,9 +906,11 @@ namespace graph { auto pl1 = piecewise_1D_cast(this->left); auto pr1 = piecewise_1D_cast(this->right); if (pl1.get() && (rc.get() || pl1->is_arg_match(this->right))) { - return piecewise_1D(this->evaluate(), pl1->get_arg()); + return piecewise_1D(this->evaluate(), pl1->get_arg(), + pl1->get_scale(), pl1->get_offset()); } else if (pr1.get() && (lc.get() || pr1->is_arg_match(this->left))) { - return piecewise_1D(this->evaluate(), pr1->get_arg()); + return piecewise_1D(this->evaluate(), pr1->get_arg(), + pr1->get_scale(), pr1->get_offset()); } auto pl2 = piecewise_2D_cast(this->left); @@ -910,13 +918,13 @@ namespace graph { if (pl2.get() && (rc.get() || pl2->is_arg_match(this->right))) { return piecewise_2D(this->evaluate(), pl2->get_num_columns(), - pl2->get_left(), - pl2->get_right()); + pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(), + pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset()); } else if (pr2.get() && (lc.get() || pr2->is_arg_match(this->left))) { return piecewise_2D(this->evaluate(), pr2->get_num_columns(), - pr2->get_left(), - pr2->get_right()); + pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(), + pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset()); } // Combine 2D and 1D piecewise constants if a row or column matches. 
@@ -925,29 +933,29 @@ namespace graph { result.pow_row(pr2->evaluate()); return piecewise_2D(result, pr2->get_num_columns(), - pr2->get_left(), - pr2->get_right()); + pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(), + pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset()); } else if (pr2.get() && pr2->is_col_match(this->left)) { backend::buffer result = pl1->evaluate(); result.pow_col(pr2->evaluate()); return piecewise_2D(result, pr2->get_num_columns(), - pr2->get_left(), - pr2->get_right()); + pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(), + pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset()); } else if (pl2.get() && pl2->is_row_match(this->right)) { backend::buffer result = pl2->evaluate(); result.pow_row(pr1->evaluate()); return piecewise_2D(result, pl2->get_num_columns(), - pl2->get_left(), - pl2->get_right()); + pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(), + pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset()); } else if (pl2.get() && pl2->is_col_match(this->right)) { backend::buffer result = pl2->evaluate(); result.pow_col(pr1->evaluate()); return piecewise_2D(result, pl2->get_num_columns(), - pl2->get_left(), - pl2->get_right()); + pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(), + pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset()); } auto lp = pow_cast(this->left); @@ -1472,15 +1480,17 @@ namespace graph { auto ap1 = piecewise_1D_cast(this->arg); if (ap1.get()) { return piecewise_1D(this->evaluate(), - ap1->get_arg()); + ap1->get_arg(), + ap1->get_scale(), + ap1->get_offset()); } auto ap2 = piecewise_2D_cast(this->arg); if (ap2.get()) { return piecewise_2D(this->evaluate(), ap2->get_num_columns(), - ap2->get_left(), - ap2->get_right()); + ap2->get_left(), ap2->get_x_scale(), ap2->get_x_offset(), + ap2->get_right(), ap2->get_y_scale(), ap2->get_y_offset()); } return this->shared_from_this(); diff --git a/graph_framework/piecewise.hpp b/graph_framework/piecewise.hpp index d956f6f..651c6d9 
100644 --- a/graph_framework/piecewise.hpp +++ b/graph_framework/piecewise.hpp @@ -19,11 +19,15 @@ namespace graph { /// @param[in,out] stream String buffer stream. /// @param[in] register_name Reister for the argument. /// @param[in] length Dimension length of argument. +/// @param[in] scale Argument scale factor. +/// @param[in] offset Argument offset factor. //------------------------------------------------------------------------------ template void compile_index(std::ostringstream &stream, const std::string &register_name, - const size_t length) { + const size_t length, + const T scale, + const T offset) { const std::string type = jit::smallest_int_type (length); stream << "min(max((" << type @@ -31,7 +35,15 @@ void compile_index(std::ostringstream &stream, if constexpr (jit::is_complex ()) { stream << "real("; } - stream << register_name; + stream << "((" << register_name << " - "; + if constexpr (jit::is_complex ()) { + stream << jit::get_type_string (); + } + stream << offset << ")/"; + if constexpr (jit::is_complex ()) { + stream << jit::get_type_string (); + } + stream << scale << ")"; if constexpr (jit::is_complex ()) { stream << ")"; } @@ -78,6 +90,11 @@ void compile_index(std::ostringstream &stream, //------------------------------------------------------------------------------ template class piecewise_1D_node final : public straight_node { +/// Scale factor for the argument. + const T scale; +/// Offset factor for the argument. + const T offset; + private: //------------------------------------------------------------------------------ /// @brief Convert node pointer to a string. @@ -138,13 +155,18 @@ void compile_index(std::ostringstream &stream, //------------------------------------------------------------------------------ /// @brief Construct 1D a piecewise constant node. /// -/// @param[in] d Data to initalize the piecewise constant. -/// @param[in] x Argument. +/// @param[in] d Data to initialize the piecewise constant.
+/// @param[in] x Argument. +/// @param[in] scale Scale factor for the argument. +/// @param[in] offset Offset factor for the argument. //------------------------------------------------------------------------------ piecewise_1D_node(const backend::buffer &d, - shared_leaf x) : + shared_leaf x, + const T scale, + const T offset) : straight_node (x, piecewise_1D_node::to_string(d, x)), - data_hash(piecewise_1D_node::hash_data(d)) {} + data_hash(piecewise_1D_node::hash_data(d)), scale(scale), + offset(offset) {} //------------------------------------------------------------------------------ /// @brief Evaluate the results of the piecewise constant. @@ -306,7 +328,8 @@ void compile_index(std::ostringstream &stream, stream << " const " << jit::smallest_int_type (length) << " " << indices[a.get()] << " = "; - compile_index (stream, registers[a.get()], length); + compile_index (stream, registers[a.get()], length, + scale, offset); a->endline(stream, usage); } #endif @@ -336,7 +359,8 @@ void compile_index(std::ostringstream &stream, << ").r"; #else stream << ".read("; - compile_index (stream, registers[a.get()], length); + compile_index (stream, registers[a.get()], length, + scale, offset); stream << ").r"; #endif #ifdef USE_CUDA_TEXTURES @@ -346,7 +370,8 @@ void compile_index(std::ostringstream &stream, << indices[this->arg.get()]; #else stream << ", "; - compile_index (stream, registers[a.get()], length); + compile_index (stream, registers[a.get()], length, + scale, offset); #endif if constexpr (jit::is_complex () || jit::is_double ()) { stream << ")"; @@ -360,7 +385,8 @@ void compile_index(std::ostringstream &stream, << "]"; #else stream << "["; - compile_index (stream, registers[a.get()], length); + compile_index (stream, registers[a.get()], length, + scale, offset); stream << "]"; #endif } @@ -482,9 +508,29 @@ void compile_index(std::ostringstream &stream, //------------------------------------------------------------------------------ bool is_arg_match(shared_leaf 
x) { auto temp = piecewise_1D_cast(x); - return temp.get() && - this->arg->is_match(temp->get_arg()) && - (temp->get_size() == this->get_size()); + return temp.get() && + this->arg->is_match(temp->get_arg()) && + (temp->get_size() == this->get_size()) && + (temp->get_scale() == this->scale) && + (temp->get_offset() == this->offset); + } + +//------------------------------------------------------------------------------ +/// @brief Get x argument scale. +/// +/// @returns The scale factor for x. +//------------------------------------------------------------------------------ + T get_scale() const { + return scale; + } + +//------------------------------------------------------------------------------ +/// @brief Get x argument offset. +/// +/// @returns The offset factor for x. +//------------------------------------------------------------------------------ + T get_offset() const { + return offset; } //------------------------------------------------------------------------------ @@ -509,8 +555,12 @@ void compile_index(std::ostringstream &stream, //------------------------------------------------------------------------------ template shared_leaf piecewise_1D(const backend::buffer &d, - shared_leaf x) { - auto temp = std::make_shared> (d, x)->reduce(); + shared_leaf x, + const T scale, + const T offset) { + auto temp = std::make_shared> (d, x, + scale, + offset)->reduce(); // Test for hash collisions. for (size_t i = temp->get_hash(); i < std::numeric_limits::max(); i++) { if (leaf_node::caches.nodes.find(i) == @@ -593,6 +643,15 @@ void compile_index(std::ostringstream &stream, template class piecewise_2D_node final : public branch_node { private: +/// Scale factor for the x argument. + const T x_scale; +/// Offset factor for the x argument. + const T x_offset; +/// Scale factor for the y argument. + const T y_scale; +/// Offset factor for the y argument. 
+ const T y_offset; + //------------------------------------------------------------------------------ /// @brief Convert node pointer to a string. /// @@ -657,18 +716,27 @@ void compile_index(std::ostringstream &stream, //------------------------------------------------------------------------------ /// @brief Construct 2D a piecewise constant node. /// -/// @param[in] d Data to initalize the piecewise constant. -/// @param[in] n Number of columns. -/// @param[in] x X Argument. -/// @param[in] y Y Argument. +/// @param[in] d Data to initialize the piecewise constant. +/// @param[in] n Number of columns. +/// @param[in] x X Argument. +/// @param[in] x_scale Scale factor for the x argument. +/// @param[in] x_offset Offset factor for the x argument. +/// @param[in] y Y Argument. +/// @param[in] y_scale Scale factor for the y argument. +/// @param[in] y_offset Offset factor for the y argument. //------------------------------------------------------------------------------ piecewise_2D_node(const backend::buffer &d, const size_t n, shared_leaf x, - shared_leaf y) : + const T x_scale, + const T x_offset, + shared_leaf y, + const T y_scale, + const T y_offset) : branch_node (x, y, piecewise_2D_node::to_string(d, x, y)), data_hash(piecewise_2D_node::hash_data(d)), - num_columns(n) { + num_columns(n), x_scale(x_scale), x_offset(x_offset), y_scale(y_scale), + y_offset(y_offset) { assert(d.size()%n == 0 && "Expected the data buffer to be a multiple of the number of columns."); } @@ -692,6 +760,42 @@ void compile_index(std::ostringstream &stream, num_columns; } +//------------------------------------------------------------------------------ +/// @brief Get x argument scale. +/// +/// @returns The scale factor for x. +//------------------------------------------------------------------------------ + T get_x_scale() const { + return x_scale; + } + +//------------------------------------------------------------------------------ +/// @brief Get x argument offset.
+/// +/// @returns The offset factor for x. +//------------------------------------------------------------------------------ + T get_x_offset() const { + return x_offset; + } + +//------------------------------------------------------------------------------ +/// @brief Get y argument scale. +/// +/// @returns The scale factor for y. +//------------------------------------------------------------------------------ + T get_y_scale() const { + return y_scale; + } + +//------------------------------------------------------------------------------ +/// @brief Get y argument offset. +/// +/// @returns The offset factor for y. +//------------------------------------------------------------------------------ + T get_y_offset() const { + return y_offset; + } + //------------------------------------------------------------------------------ /// @brief Evaluate the results of the piecewise constant. /// @@ -872,7 +976,8 @@ void compile_index(std::ostringstream &stream, stream << " const " << jit::smallest_int_type (num_rows) << " " << indices[x.get()] << " = "; - compile_index (stream, registers[x.get()], num_rows); + compile_index (stream, registers[x.get()], num_rows, + x_scale, x_offset); x->endline(stream, usage); } if (indices.find(this->right.get()) == indices.end()) { @@ -886,7 +991,8 @@ void compile_index(std::ostringstream &stream, stream << " const " << jit::smallest_int_type (num_columns) << " " << indices[y.get()] << " = "; - compile_index (stream, registers[y.get()], num_columns); + compile_index (stream, registers[y.get()], num_columns, + y_scale, y_offset); y->endline(stream, usage); } @@ -940,9 +1046,11 @@ void compile_index(std::ostringstream &stream, << ")).r"; #else stream << ".read(uint2("; - compile_index (stream, registers[y.get()], num_columns); + compile_index (stream, registers[y.get()], num_columns, + y_scale, y_offset); stream << ","; - compile_index (stream, registers[x.get()], num_rows); + compile_index (stream, registers[x.get()], num_rows, +
x_scale, x_offset); stream << ")).r"; #endif #ifdef USE_CUDA_TEXTURES @@ -954,9 +1062,11 @@ void compile_index(std::ostringstream &stream, << indices[this->left.get()]; #else stream << ", "; - compile_index (stream, registers[y.get()], num_columns); + compile_index (stream, registers[y.get()], num_columns, + y_scale, y_offset); stream << ", "; - compile_index (stream, registers[x.get()], num_rows); + compile_index (stream, registers[x.get()], num_rows, + x_scale, x_offset); #endif if constexpr (jit::is_complex () || jit::is_double ()) { stream << ")"; @@ -970,9 +1080,11 @@ void compile_index(std::ostringstream &stream, << "]"; #else stream << "["; - compile_index (stream, registers[x.get()], num_rows); + compile_index (stream, registers[x.get()], num_rows, + x_scale, x_offset); stream << "*" << num_columns << " + "; - compile_index (stream, registers[y.get()], num_columns); + compile_index (stream, registers[y.get()], num_columns, + y_scale, y_offset); stream << "]"; #endif } @@ -1098,11 +1210,15 @@ void compile_index(std::ostringstream &stream, //------------------------------------------------------------------------------ bool is_arg_match(shared_leaf x) { auto temp = piecewise_2D_cast(x); - return temp.get() && - this->left->is_match(temp->get_left()) && - this->right->is_match(temp->get_right()) && - (temp->get_num_rows() == this->get_num_rows()) && - (temp->get_num_columns() == this->get_num_columns()); + return temp.get() && + this->left->is_match(temp->get_left()) && + this->right->is_match(temp->get_right()) && + (temp->get_num_rows() == this->get_num_rows()) && + (temp->get_num_columns() == this->get_num_columns()) && + (temp->get_x_scale() == this->x_scale) && + (temp->get_x_offset() == this->x_offset) && + (temp->get_y_scale() == this->y_scale) && + (temp->get_y_offset() == this->y_offset); } //------------------------------------------------------------------------------ @@ -1113,9 +1229,11 @@ void compile_index(std::ostringstream &stream, 
//------------------------------------------------------------------------------ bool is_row_match(shared_leaf x) { auto temp = piecewise_1D_cast(x); - return temp.get() && - this->left->is_match(temp->get_arg()) && - (temp->get_size() == this->get_num_rows()); + return temp.get() && + this->left->is_match(temp->get_arg()) && + (temp->get_size() == this->get_num_rows()) && + (temp->get_scale() == this->x_scale) && + (temp->get_offset() == this->x_offset); } //------------------------------------------------------------------------------ @@ -1128,9 +1246,11 @@ void compile_index(std::ostringstream &stream, //------------------------------------------------------------------------------ bool is_col_match(shared_leaf x) { auto temp = piecewise_1D_cast(x); - return temp.get() && - this->right->is_match(temp->get_arg()) && - (temp->get_size() == this->get_num_columns()); + return temp.get() && + this->right->is_match(temp->get_arg()) && + (temp->get_size() == this->get_num_columns()) && + (temp->get_scale() == this->y_scale) && + (temp->get_offset() == this->y_offset); } }; @@ -1140,18 +1260,28 @@ void compile_index(std::ostringstream &stream, /// @tparam T Base type of the calculation. /// @tparam SAFE_MATH Use safe math operations. /// -/// @param[in] d Data to initalize the piecewise constant. -/// @param[in] n Number of columns. -/// @param[in] x Argument. -/// @param[in] y Argument. +/// @param[in] d Data to initialize the piecewise constant. +/// @param[in] n Number of columns. +/// @param[in] x X argument. +/// @param[in] x_scale Scale for x argument. +/// @param[in] x_offset Offset for x argument. +/// @param[in] y Y argument. +/// @param[in] y_scale Scale for y argument. +/// @param[in] y_offset Offset for y argument. /// @returns A reduced sqrt node.
//------------------------------------------------------------------------------ template shared_leaf piecewise_2D(const backend::buffer &d, - const size_t n, - shared_leaf x, - shared_leaf y) { - auto temp = std::make_shared> (d, n, x, y)->reduce(); + const size_t n, + shared_leaf x, + const T x_scale, + const T x_offset, + shared_leaf y, + const T y_scale, + const T y_offset) { + auto temp = std::make_shared> (d, n, + x, x_scale, x_offset, + y, y_scale, y_offset)->reduce(); // Test for hash collisions. for (size_t i = temp->get_hash(); i < std::numeric_limits::max(); i++) { if (leaf_node::caches.nodes.find(i) == diff --git a/graph_framework/trigonometry.hpp b/graph_framework/trigonometry.hpp index d5fd80e..20cbb95 100644 --- a/graph_framework/trigonometry.hpp +++ b/graph_framework/trigonometry.hpp @@ -69,15 +69,17 @@ namespace graph { auto ap1 = piecewise_1D_cast(this->arg); if (ap1.get()) { return piecewise_1D(this->evaluate(), - ap1->get_arg()); + ap1->get_arg(), + ap1->get_scale(), + ap1->get_offset()); } auto ap2 = piecewise_2D_cast(this->arg); if (ap2.get()) { return piecewise_2D(this->evaluate(), ap2->get_num_columns(), - ap2->get_left(), - ap2->get_right()); + ap2->get_left(), ap2->get_x_scale(), ap2->get_x_offset(), + ap2->get_right(), ap2->get_y_scale(), ap2->get_y_offset()); } // Sin(ArcTan(x, y)) -> y/Sqrt(x^2 + y^2) @@ -318,15 +320,17 @@ namespace graph { auto ap1 = piecewise_1D_cast(this->arg); if (ap1.get()) { return piecewise_1D(this->evaluate(), - ap1->get_arg()); + ap1->get_arg(), + ap1->get_scale(), + ap1->get_offset()); } auto ap2 = piecewise_2D_cast(this->arg); if (ap2.get()) { return piecewise_2D(this->evaluate(), ap2->get_num_columns(), - ap2->get_left(), - ap2->get_right()); + ap2->get_left(), ap2->get_x_scale(), ap2->get_x_offset(), + ap2->get_right(), ap2->get_y_scale(), ap2->get_y_offset()); } // Cos(ArcTan(x, y)) -> x/Sqrt(x^2 + y^2) @@ -595,9 +599,11 @@ namespace graph { auto pr1 = piecewise_1D_cast(this->right); if (pl1.get() && 
(r.get() || pl1->is_arg_match(this->right))) { - return piecewise_1D(this->evaluate(), pl1->get_arg()); + return piecewise_1D(this->evaluate(), pl1->get_arg(), + pl1->get_scale(), pl1->get_offset()); } else if (pr1.get() && (l.get() || pr1->is_arg_match(this->left))) { - return piecewise_1D(this->evaluate(), pr1->get_arg()); + return piecewise_1D(this->evaluate(), pr1->get_arg(), + pr1->get_scale(), pr1->get_offset()); } auto pl2 = piecewise_2D_cast(this->left); @@ -606,13 +612,13 @@ namespace graph { if (pl2.get() && (r.get() || pl2->is_arg_match(this->right))) { return piecewise_2D(this->evaluate(), pl2->get_num_columns(), - pl2->get_left(), - pl2->get_right()); + pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(), + pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset()); } else if (pr2.get() && (l.get() || pr2->is_arg_match(this->left))) { return piecewise_2D(this->evaluate(), pr2->get_num_columns(), - pr2->get_left(), - pr2->get_right()); + pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(), + pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset()); } // Combine 2D and 1D piecewise constants if a row or column matches. 
@@ -621,29 +627,29 @@ namespace graph { result.atan_row(pr2->evaluate()); return piecewise_2D(result, pr2->get_num_columns(), - pr2->get_left(), - pr2->get_right()); + pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(), + pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset()); } else if (pr2.get() && pr2->is_col_match(this->left)) { backend::buffer result = pl1->evaluate(); result.atan_col(pr2->evaluate()); return piecewise_2D(result, pr2->get_num_columns(), - pr2->get_left(), - pr2->get_right()); + pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(), + pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset()); } else if (pl2.get() && pl2->is_row_match(this->right)) { backend::buffer result = pl2->evaluate(); result.atan_row(pr1->evaluate()); return piecewise_2D(result, pl2->get_num_columns(), - pl2->get_left(), - pl2->get_right()); + pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(), + pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset()); } else if (pl2.get() && pl2->is_col_match(this->right)) { backend::buffer result = pl2->evaluate(); result.atan_col(pr1->evaluate()); return piecewise_2D(result, pl2->get_num_columns(), - pl2->get_left(), - pl2->get_right()); + pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(), + pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset()); } return this->shared_from_this(); diff --git a/graph_tests/arithmetic_test.cpp b/graph_tests/arithmetic_test.cpp index 24155a5..07278e7 100644 --- a/graph_tests/arithmetic_test.cpp +++ b/graph_tests/arithmetic_test.cpp @@ -262,7 +262,8 @@ template void test_add() { assert(!three->is_all_variables() && "Did not expect a variable."); assert(three->is_power_like() && "Expected a power like."); auto constant_add = three + graph::piecewise_1D (std::vector ({static_cast (1.0), - static_cast (2.0)}), var_a); + static_cast (2.0)}), + var_a, 1.0, 0.0); assert(constant_add->is_constant() && "Expected a constant."); assert(!constant_add->is_all_variables() && "Did not 
expect a variable."); assert(constant_add->is_power_like() && "Expected a power like."); @@ -736,7 +737,8 @@ template void test_subtract() { assert(!zero->is_all_variables() && "Did not expect a variable."); assert(zero->is_power_like() && "Expected a power like."); auto constant_sub = one - graph::piecewise_1D (std::vector ({static_cast (1.0), - static_cast (2.0)}), var_a); + static_cast (2.0)}), var_a, + 1.0, 0.0); assert(constant_sub->is_constant() && "Expected a constant."); assert(!constant_sub->is_all_variables() && "Did not expect a variable."); assert(constant_sub->is_power_like() && "Expected a power like."); @@ -1405,7 +1407,8 @@ template void test_multiply() { assert(!two_times_three->is_all_variables() && "Did not expect a variable."); assert(two_times_three->is_power_like() && "Expected a power like."); auto constant_mul = three*graph::piecewise_1D (std::vector ({static_cast (1.0), - static_cast (2.0)}), variable); + static_cast (2.0)}), + variable, 1.0, 0.0); assert(constant_mul->is_constant() && "Expected a constant."); assert(!constant_mul->is_all_variables() && "Did not expect a variable."); assert(constant_mul->is_power_like() && "Expected a power like."); @@ -2382,7 +2385,8 @@ template void test_divide() { assert(!two_divided_three->is_all_variables() && "Did not expect a variable."); assert(two_divided_three->is_power_like() && "Expected a power like."); auto constant_div = two_divided_three/graph::piecewise_1D (std::vector ({static_cast (1.0), - static_cast (2.0)}), variable); + static_cast (2.0)}), + variable, 1.0, 0.0); assert(constant_div->is_constant() && "Expected a constant."); assert(!constant_div->is_all_variables() && "Did not expect a variable."); assert(constant_div->is_power_like() && "Expected a power like."); @@ -3231,7 +3235,7 @@ template void test_fma() { auto constant_fma = graph::fma(one_two_three, graph::piecewise_1D (std::vector ({static_cast (1.0), static_cast (2.0)}), - var_a), + var_a, 1.0, 0.0), one); 
assert(!constant_fma->is_all_variables() && "Did not expect a variable."); assert(constant_fma->is_power_like() && "Expected a power like."); @@ -3261,7 +3265,7 @@ template void test_fma() { auto piecewise1 = graph::fma (2.0, graph::piecewise_1D (std::vector ({static_cast (1.0), static_cast (2.0)}), - var_a)*var_a, + var_a, 1.0, 0.0)*var_a, var_b); auto piecewise1_cast = graph::fma_cast(piecewise1); assert(piecewise1_cast.get() && "Expected a fma node."); @@ -3269,8 +3273,9 @@ template void test_fma() { "Expected a piecewise_1D node."); auto piecewise2 = graph::fma (2.0, graph::piecewise_2D (std::vector ({static_cast (1.0), - static_cast (2.0)}), - 1, var_a, var_b)*var_a, + static_cast (2.0)}), 1, + var_a, 1.0, 0.0, + var_b, 1.0, 0.0)*var_a, var_b); auto piecewise2_cast = graph::fma_cast(piecewise2); assert(piecewise2_cast.get() && "Expected a fma node."); @@ -3517,12 +3522,13 @@ template void test_fma() { // fma(p2,p1,a) -> fma(p1,p2,a) auto p1 = graph::piecewise_1D (std::vector ({static_cast (1.0), static_cast (2.0)}), - var_a); + var_a, 1.0, 0.0); auto p2 = graph::piecewise_2D (std::vector ({static_cast (1.0), static_cast (2.0), static_cast (3.0), - static_cast (4.0)}), - 2, var_b, var_c); + static_cast (4.0)}), 2, + var_b, 1.0, 0.0, + var_c, 1.0, 0.0); auto fma_promote = graph::fma(p2, p1, var_a); auto fma_promote_cast = graph::fma_cast(fma_promote); assert(fma_promote_cast.get() && "Expected a fma node."); diff --git a/graph_tests/efit_test.cpp b/graph_tests/efit_test.cpp index c565e9b..f0a11f2 100644 --- a/graph_tests/efit_test.cpp +++ b/graph_tests/efit_test.cpp @@ -153,15 +153,15 @@ void run_test() { work.run(); for (size_t i = 0, ie = gold.r_grid.size()*gold.z_grid.size(); i < ie; i++) { - check_error(work.check_value(i, bvec->get_x()), gold.bx_grid[i], 10.0E-11, + check_error(work.check_value(i, bvec->get_x()), gold.bx_grid[i], 4.0E-12, "Expected a match in bx."); - check_error(work.check_value(i, bvec->get_y()), gold.by_grid[i], 1.0E-20, + 
check_error(work.check_value(i, bvec->get_y()), gold.by_grid[i], 2.0E-30, "Expected a match in by."); - check_error(work.check_value(i, bvec->get_z()), gold.bz_grid[i], 5.0E-12, + check_error(work.check_value(i, bvec->get_z()), gold.bz_grid[i], 3.0E-13, "Expected a match in bz."); - check_error(work.check_value(i, ne), gold.ne_grid[i], 2.1E-12, + check_error(work.check_value(i, ne), gold.ne_grid[i], 3.0E-13, "Expected a match in ne."); - check_error(work.check_value(i, te), gold.te_grid[i], 2.1E-12, + check_error(work.check_value(i, te), gold.te_grid[i], 3.0E-13, "Expected a match in te."); } } diff --git a/graph_tests/math_test.cpp b/graph_tests/math_test.cpp index 972ff51..4f4c783 100644 --- a/graph_tests/math_test.cpp +++ b/graph_tests/math_test.cpp @@ -104,7 +104,8 @@ void test_sqrt() { #endif // Test node properties. auto sqrt_const = graph::sqrt(graph::piecewise_1D (std::vector ({static_cast (1.0), - static_cast (2.0)}), var)); + static_cast (2.0)}), + var, 1.0, 0.0)); assert(sqrt_const->is_constant() && "Expected a constant."); assert(!sqrt_const->is_all_variables() && "Did not expect a variable."); assert(sqrt_const->is_power_like() && "Expected a power like."); @@ -265,7 +266,8 @@ void test_pow() { // Test node properties. 
auto var_a = graph::variable (1, ""); auto pow_const = graph::pow(3.0, graph::piecewise_1D (std::vector ({static_cast (1.0), - static_cast (2.0)}), var_a)); + static_cast (2.0)}), + var_a, 1.0, 0.0)); assert(pow_const->is_constant() && "Expected a constant."); assert(!pow_const->is_all_variables() && "Did not expect a variable."); assert(pow_const->is_power_like() && "Expected a power like."); diff --git a/graph_tests/piecewise_test.cpp b/graph_tests/piecewise_test.cpp index b2e7b6c..c6cb5aa 100644 --- a/graph_tests/piecewise_test.cpp +++ b/graph_tests/piecewise_test.cpp @@ -80,13 +80,16 @@ template void piecewise_1D() { auto b = graph::variable (1, ""); auto p1 = graph::piecewise_1D (std::vector ({static_cast (1.0), static_cast (2.0), - static_cast (3.0)}), a); + static_cast (3.0)}), + a, 1.0, 0.0); auto p2 = graph::piecewise_1D (std::vector ({static_cast (2.0), static_cast (4.0), - static_cast (6.0)}), b); + static_cast (6.0)}), + b, 1.0, 0.0); auto p3 = graph::piecewise_1D (std::vector ({static_cast (2.0), static_cast (4.0), - static_cast (6.0)}), a); + static_cast (6.0)}), + a, 1.0, 0.0); assert(graph::constant_cast(p1*0.0).get() && "Expected a constant node."); @@ -225,7 +228,8 @@ template void piecewise_1D() { auto pc = graph::piecewise_1D (std::vector ({static_cast (10.0), static_cast (10.0), - static_cast (10.0)}), a); + static_cast (10.0)}), + a, 1.0, 0.0); assert(graph::constant_cast(pc).get() && "Expected a constant."); @@ -282,21 +286,21 @@ template void piecewise_2D() { auto p1 = graph::piecewise_2D (std::vector ({ static_cast (1.0), static_cast (2.0), static_cast (3.0), static_cast (4.0) - }), 2, ax, ay); + }), 2, ax, 1.0, 0.0, ay, 1.0, 0.0); auto p2 = graph::piecewise_2D (std::vector ({ static_cast (2.0), static_cast (4.0), static_cast (6.0), static_cast (10.0) - }), 2, bx, by); + }), 2, bx, 1.0, 0.0, by, 1.0, 0.0); auto p3 = graph::piecewise_2D (std::vector ({ static_cast (2.0), static_cast (4.0), static_cast (6.0), static_cast (10.0) - }), 2, ax, 
ay); + }), 2, ax, 1.0, 0.0, ay, 1.0, 0.0); auto p4 = graph::piecewise_1D (std::vector ({ static_cast (2.0), static_cast (4.0) - }), ax); + }), ax, 1.0, 0.0); auto p5 = graph::piecewise_1D (std::vector ({ static_cast (2.0), static_cast (4.0) - }), ay); + }), ay, 1.0, 0.0); assert(graph::constant_cast(p1*0.0).get() && "Expected a constant node."); @@ -592,7 +596,7 @@ template void piecewise_2D() { static_cast (10.0), static_cast (10.0), static_cast (10.0)}), - 2, ax, bx); + 2, ax, 1.0, 0.0, bx, 1.0, 0.0); assert(graph::constant_cast(pc).get() && "Expected a constant."); @@ -600,17 +604,17 @@ template void piecewise_2D() { static_cast (1.0), static_cast (2.0), static_cast (3.0) - }), ax); + }), ax, 1.0, 0.0); auto pcc = graph::piecewise_1D (std::vector ({ static_cast (1.0), static_cast (2.0), static_cast (3.0) - }), ay); + }), ay, 1.0, 0.0); auto p2Dc = graph::piecewise_2D (std::vector ({ static_cast (1.0), static_cast (2.0), static_cast (3.0), static_cast (4.0), static_cast (5.0), static_cast (6.0) - }), 2, ax, ay); + }), 2, ax, 1.0, 0.0, ay, 1.0, 0.0); auto row_test = prc + p2Dc; auto row_test_cast = graph::piecewise_2D_cast(row_test); -- GitLab From 531550d80d12105ba27d78fffe029ba7e20d5443 Mon Sep 17 00:00:00 2001 From: cianciosa Date: Fri, 14 Feb 2025 14:08:50 -0500 Subject: [PATCH 3/3] Add test for magnetic field divergence. --- graph_tests/efit_test.cpp | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/graph_tests/efit_test.cpp b/graph_tests/efit_test.cpp index f0a11f2..bf71592 100644 --- a/graph_tests/efit_test.cpp +++ b/graph_tests/efit_test.cpp @@ -110,6 +110,19 @@ void check_error(const T test, const T expected, const T tolarance, assert(error*error <= tolarance && name); } +//------------------------------------------------------------------------------ +/// @brief Check that the square of a value is within a tolerance. +/// +/// @param[in] test Test value. +/// @param[in] tolarance Upper bound on the square of the test value. +/// @param[in] name Name of the test.
+//------------------------------------------------------------------------------ +template +void check_error(const T test, const T tolarance, + const char *name) { + assert(test*test <= tolarance && name); +} + //------------------------------------------------------------------------------ /// @brief Run tests. /// @@ -141,13 +154,18 @@ void run_test() { auto ne = eq->get_electron_density(x, y, z); auto te = eq->get_electron_temperature(x, y, z); +// Test the divergence. + auto div = bvec->get_x()->df(x) + + bvec->get_y()->df(y) + + bvec->get_z()->df(z); + workflow::manager work(0); work.add_item({ graph::variable_cast(x), graph::variable_cast(y), graph::variable_cast(z) }, { - bvec->get_x(), bvec->get_y(), bvec->get_z(), ne, te + bvec->get_x(), bvec->get_y(), bvec->get_z(), ne, te, div }, {}, "test_kernel"); work.compile(); work.run(); @@ -163,6 +181,8 @@ void run_test() { "Expected a match in ne."); check_error(work.check_value(i, te), gold.te_grid[i], 3.0E-13, "Expected a match in te."); + check_error(work.check_value(i, div), 1.0E-20, + "Expected div(B)=0."); } } -- GitLab