Commit f5f4b3a5 authored by Cianciosa, Mark's avatar Cianciosa, Mark
Browse files

Merge branch 'sqrt_reduction' into 'main'

Sqrt reduction

See merge request !59
parents 77171741 196c46cc
Loading
Loading
Loading
Loading
+174 −0
Original line number Diff line number Diff line
@@ -1294,6 +1294,93 @@
				);
				LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
				MACOSX_DEPLOYMENT_TARGET = 14.5;
				OTHER_LDFLAGS = (
					"-lnetcdf",
					"-ld_classic",
					"-L/Users/m4c/Projects/graph_framework/build/_deps/llvm-build/lib",
					"-lz",
					"-lLLVMCoverage",
					"-lLLVMSupport",
					"-lLLVMDebugInfoCodeView",
					"-lLLVMRemarks",
					"-lLLVMJITLink",
					"-lLLVMLinker",
					"-lLLVMTextAPI",
					"-lLLVMRuntimeDyld",
					"-lLLVMOrcShared",
					"-lLLVMOrcDebugging",
					"-lLLVMOrcTargetProcess",
					"-lLLVMOrcJIT",
					"-lLLVMHipStdPar",
					"-lLLVMAggressiveInstCombine",
					"-lLLVMVectorize",
					"-lLLVMAsmParser",
					"-lLLVMOption",
					"-lLLVMLTO",
					"-lLLVMObject",
					"-lLLVMWindowsDriver",
					"-lLLVMDemangle",
					"-lLLVMIRReader",
					"-lLLVMIRPrinter",
					"-lLLVMInstCombine",
					"-lLLVMBinaryFormat",
					"-lLLVMCoroutines",
					"-lLLVMBitstreamReader",
					"-lLLVMBitReader",
					"-lLLVMBitWriter",
					"-lLLVMDebugInfoDWARF",
					"-lLLVMInstrumentation",
					"-lLLVMCFGuard",
					"-lLLVMObjCARCOpts",
					"-lLLVMipo",
					"-lLLVMGlobalISel",
					"-lLLVMExecutionEngine",
					"-lLLVMFrontendDriver",
					"-lLLVMFrontendHLSL",
					"-lLLVMFrontendOpenMP",
					"-lLLVMFrontendOffloading",
					"-lLLVMSelectionDAG",
					"-lLLVMProfileData",
					"-lLLVMAnalysis",
					"-lLLVMScalarOpts",
					"-lLLVMCodeGenTypes",
					"-lLLVMCodeGenData",
					"-lLLVMCodeGen",
					"-lLLVMTargetParser",
					"-lLLVMScalarOpts",
					"-lLLVMTarget",
					"-lLLVMTransformUtils",
					"-lLLVMPasses",
					"-lLLVMSupport",
					"-lLLVMMCParser",
					"-lLLVMMC",
					"-lLLVMCore",
					"-lLLVMAsmPrinter",
					"-lLLVMAArch64Utils",
					"-lLLVMAArch64Info",
					"-lLLVMAArch64Desc",
					"-lLLVMAArch64AsmParser",
					"-lLLVMAArch64CodeGen",
					"-lLLVMCGData",
					"-lLLVMSandboxIR",
					"-lLLVMFrontendAtomic",
					"-lclangFrontend",
					"-lclangBasic",
					"-lclangEdit",
					"-lclangLex",
					"-lclangDriver",
					"-lclangSerialization",
					"-lclangAST",
					"-lclangSema",
					"-lclangAnalysis",
					"-lclangASTMatchers",
					"-lclangSupport",
					"-lclangParse",
					"-lclangAPINotes",
					"-lclangCodeGen",
					"-rpath",
					/usr/local/lib,
				);
				PRODUCT_NAME = "$(TARGET_NAME)";
			};
			name = Debug;
@@ -1312,6 +1399,93 @@
				);
				LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
				MACOSX_DEPLOYMENT_TARGET = 14.5;
				OTHER_LDFLAGS = (
					"-lnetcdf",
					"-ld_classic",
					"-L/Users/m4c/Projects/graph_framework/build/_deps/llvm-build/lib",
					"-lz",
					"-lLLVMCoverage",
					"-lLLVMSupport",
					"-lLLVMDebugInfoCodeView",
					"-lLLVMRemarks",
					"-lLLVMJITLink",
					"-lLLVMLinker",
					"-lLLVMTextAPI",
					"-lLLVMRuntimeDyld",
					"-lLLVMOrcShared",
					"-lLLVMOrcDebugging",
					"-lLLVMOrcTargetProcess",
					"-lLLVMOrcJIT",
					"-lLLVMHipStdPar",
					"-lLLVMAggressiveInstCombine",
					"-lLLVMVectorize",
					"-lLLVMAsmParser",
					"-lLLVMOption",
					"-lLLVMLTO",
					"-lLLVMObject",
					"-lLLVMWindowsDriver",
					"-lLLVMDemangle",
					"-lLLVMIRReader",
					"-lLLVMIRPrinter",
					"-lLLVMInstCombine",
					"-lLLVMBinaryFormat",
					"-lLLVMCoroutines",
					"-lLLVMBitstreamReader",
					"-lLLVMBitReader",
					"-lLLVMBitWriter",
					"-lLLVMDebugInfoDWARF",
					"-lLLVMInstrumentation",
					"-lLLVMCFGuard",
					"-lLLVMObjCARCOpts",
					"-lLLVMipo",
					"-lLLVMGlobalISel",
					"-lLLVMExecutionEngine",
					"-lLLVMFrontendDriver",
					"-lLLVMFrontendHLSL",
					"-lLLVMFrontendOpenMP",
					"-lLLVMFrontendOffloading",
					"-lLLVMSelectionDAG",
					"-lLLVMProfileData",
					"-lLLVMAnalysis",
					"-lLLVMScalarOpts",
					"-lLLVMCodeGenTypes",
					"-lLLVMCodeGenData",
					"-lLLVMCodeGen",
					"-lLLVMTargetParser",
					"-lLLVMScalarOpts",
					"-lLLVMTarget",
					"-lLLVMTransformUtils",
					"-lLLVMPasses",
					"-lLLVMSupport",
					"-lLLVMMCParser",
					"-lLLVMMC",
					"-lLLVMCore",
					"-lLLVMAsmPrinter",
					"-lLLVMAArch64Utils",
					"-lLLVMAArch64Info",
					"-lLLVMAArch64Desc",
					"-lLLVMAArch64AsmParser",
					"-lLLVMAArch64CodeGen",
					"-lLLVMCGData",
					"-lLLVMSandboxIR",
					"-lLLVMFrontendAtomic",
					"-lclangFrontend",
					"-lclangBasic",
					"-lclangEdit",
					"-lclangLex",
					"-lclangDriver",
					"-lclangSerialization",
					"-lclangAST",
					"-lclangSema",
					"-lclangAnalysis",
					"-lclangASTMatchers",
					"-lclangSupport",
					"-lclangParse",
					"-lclangAPINotes",
					"-lclangCodeGen",
					"-rpath",
					/usr/local/lib,
				);
				PRODUCT_NAME = "$(TARGET_NAME)";
			};
			name = Release;
+172 −2
Original line number Diff line number Diff line
@@ -2024,6 +2024,45 @@ namespace graph {
                    return this->right/pow(lp->get_left(), -lp->get_right());
                }
            }
//  a^b*c^b -> (a*c)^b
            if (lp.get() && rp.get()) {
                if (lp->get_right()->is_match(rp->get_right())) {
                    return pow(lp->get_left()*rp->get_left(), lp->get_right());
                }
            }
// (a*b^c)*d^c -> a*(b*d)^c
// (a^c*b)*d^c -> b*(a*d)^c
// a^c*(b*d^c) -> b*(a*d)^c
// a^c*(b^c*d) -> d*(a*b)^c
            if (lm.get() && rp.get()) {
                auto lmlp = pow_cast(lm->get_left());
                auto lmrp = pow_cast(lm->get_right());
                if (lmrp.get()) {
                    if (lmrp->get_right()->is_match(rp->get_right())) {
                        return lm->get_left()*pow(lmrp->get_left()*rp->get_left(),
                                                  rp->get_right());
                    }
                } else if (lmlp.get()) {
                    if (lmlp->get_right()->is_match(rp->get_right())) {
                        return lm->get_right()*pow(lmlp->get_left()*rp->get_left(),
                                                   rp->get_right());
                    }
                }
            } else if (rm.get() && lp.get()) {
                auto rmlp = pow_cast(rm->get_left());
                auto rmrp = pow_cast(rm->get_right());
                if (rmrp.get()) {
                    if (rmrp->get_right()->is_match(lp->get_right())) {
                        return rm->get_left()*pow(lp->get_left()*rmrp->get_left(),
                                                  lp->get_right());
                    }
                } else if (rmlp.get()) {
                    if (rmlp->get_right()->is_match(lp->get_right())) {
                        return rm->get_right()*pow(lp->get_left()*rmlp->get_left(),
                                                   lp->get_right());
                    }
                }
            }

//  (b*a)^c*a^d -> b^c*a^(c + d)
//  (a*b)^c*a^d -> b^c*a^(c + d)
@@ -2931,6 +2970,14 @@ namespace graph {
                    }
                }
            }

//  a^b/c^b -> (a/c)^b
            if (lp.get() && rp.get()) {
                if (lp->get_right()->is_match(rp->get_right())) {
                    return pow(lp->get_left()/rp->get_left(), lp->get_right());
                }
            }

//  (a*b)^c/((a^d)*e) = a^(c - d)*b^c/e
//  (b*a)^c/((a^d)*e) = a^(c - d)*b^c/e
//  (a*b)^c/(e*(a^d)) = a^(c - d)*b^c/e
@@ -3706,6 +3753,30 @@ namespace graph {
                    }
                }

//  fma(a,b*c,b*d) -> b*fma(a,c,d)
//  fma(a,c*b,b*d) -> b*fma(a,c,d)
//  fma(a,b*c,d*b) -> b*fma(a,c,d)
//  fma(a,c*b,d*b) -> b*fma(a,c,d)
                if (mm.get()) {
                    if (mm->get_left()->is_match(rm->get_left())) {
                        return mm->get_left()*fma(this->left,
                                                  mm->get_right(),
                                                  rm->get_right());
                    } else if (mm->get_left()->is_match(rm->get_right())) {
                        return mm->get_left()*fma(this->left,
                                                  mm->get_right(),
                                                  rm->get_left());
                    } else if (mm->get_right()->is_match(rm->get_left())) {
                        return mm->get_right()*fma(this->left,
                                                   mm->get_left(),
                                                   rm->get_right());
                    } else if (mm->get_right()->is_match(rm->get_right())) {
                        return mm->get_right()*fma(this->left,
                                                   mm->get_left(),
                                                   rm->get_left());
                    }
                }

//  Convert fma(a*b,c,d*e) -> fma(d,e,a*b*c)
//  Convert fma(a,b*c,d*e) -> fma(d,e,a*b*c)
                if ((lm.get() || mm.get()) &&
@@ -3803,6 +3874,19 @@ namespace graph {
                }
            }

//  fma(a,b*c,b) -> b*fma(a,c,1)
            if (mm.get()) {
                if (mm->get_left()->is_match(this->right)) {
                    return mm->get_left()*fma(this->left,
                                              mm->get_right(),
                                              1.0);
                } else if (mm->get_right()->is_match(this->right)) {
                    return mm->get_right()*fma(this->left,
                                              mm->get_left(),
                                              1.0);
                }
            }

//  fma(c1,a,c2/b) -> c1*(a + c3/b)
//  fma(a,c1,c2/b) -> c1*(a + c3/b)
            auto rd = divide_cast(this->right);
@@ -4294,8 +4378,10 @@ namespace graph {
                }
            } else if (this->middle->is_all_variables()) {
                auto rdm = this->right/this->middle;
                if (rdm->get_complexity() < this->middle->get_complexity() +
                                            this->right->get_complexity()) {
                auto rdmc = constant_cast(rdm->get_power_exponent());
                if ((rdm->get_complexity() < this->middle->get_complexity() +
                                             this->right->get_complexity()) &&
                    !(rdmc.get() && rdmc->evaluate().is_negative())) {
                    return (this->left + rdm)*this->middle;
                }
            }
@@ -4318,6 +4404,90 @@ namespace graph {
                    return this->left/pow(mp->get_left(), -mp->get_right()) +
                           this->right;
                }

//  fma(2,a^2,a) -> a*fma(2,a,1)
//  Note this case is handled eailer. fma(2,a,a^2) -> a*fma(2,1,a)
                if (is_variable_combineable(this->middle,
                                            this->right)) {
                    auto temp = this->right/this->middle;
                    auto temp_exponent = constant_cast(temp->get_power_exponent());
                    if (temp_exponent.get() && temp_exponent->evaluate().is_negative()) {
                        return this->right*fma(this->left,
                                               this->middle/this->right,
                                               1.0);
                    }
                }
            }

//  a^b*c^b + d -> (a*c)^b + d
            if (lp.get() && mp.get()) {
                if (lp->get_right()->is_match(mp->get_right())) {
                    return pow(lp->get_left()*mp->get_left(),
                               lp->get_right()) +
                           this->right;
                }
            }

//  fma(2,(ab)^2,a^2b) -> a^2*fma(2, b^2, b)
            if (rm.get() && mp.get()) {
                auto mplm = multiply_cast(mp->get_left());
                if (mplm.get()) {
                    if (is_variable_combineable(mplm->get_left(),
                                                rm->get_left())) {
                        auto temp = pow(mplm->get_left(),
                                        mp->get_right());
                        return temp*fma(this->left,
                                        this->middle/temp,
                                        this->right/temp);
                    } else if (is_variable_combineable(mplm->get_right(),
                                                       rm->get_left())) {
                        auto temp = pow(mplm->get_right(),
                                        mp->get_right());
                        return temp*fma(this->left,
                                        this->middle/temp,
                                        this->right/temp);
                    }
                }
            }
//  fma(2,(a*b)^2,fma(a^2,b,c)) -> fma(a^2,fma(2,b^2,b),c)
            if (rfma.get() && mp.get()) {
                auto mplm = multiply_cast(mp->get_left());
                if (mplm.get()) {
                    if (is_variable_combineable(mplm->get_left(),
                                                rfma->get_left())) {
                        auto temp = pow(mplm->get_left(),
                                        mp->get_right());
                        return fma(temp,
                                   fma(this->left,
                                       this->middle/temp,
                                       rfma->get_middle()),
                                   rfma->get_right());
                    } else if (is_variable_combineable(mplm->get_right(),
                                                       rfma->get_left())) {
                        auto temp = pow(mplm->get_right(),
                                        mp->get_right());
                        return fma(temp,
                                   fma(this->left,
                                       this->middle/temp,
                                       rfma->get_middle()),
                                   rfma->get_right());
                    }
                    
//  fma(2,(a*b)^2,fma(3,a^2*b,c)) -> a^2*fma(2,b^2,fma(3,b,c))
                    auto rfmamm = multiply_cast(rfma->get_middle());
                    if (rfmamm.get()) {
                        if (is_variable_combineable(mplm->get_left(),
                                                    rfmamm->get_left())) {
                            auto temp = pow(mplm->get_left(),
                                            mp->get_right());
                            return temp*fma(this->left,
                                            this->middle/temp,
                                            fma(rfma->get_left(),
                                                rfma->get_middle()/temp,
                                                rfma->get_right()));
                        }
                    }
                }
            }

//  fma(a,b/c,b/d) -> b*(a/c + 1/d)
+24 −48
Original line number Diff line number Diff line
@@ -1069,22 +1069,12 @@ namespace equilibrium {
            auto c32_temp = graph::piecewise_2D(c32, num_cols, r_norm, z_norm);
            auto c33_temp = graph::piecewise_2D(c33, num_cols, r_norm, z_norm);

            return   c00_temp
                   + c01_temp*z_norm
                   + c02_temp*(z_norm*z_norm)
                   + c03_temp*(z_norm*z_norm*z_norm)
                   + c10_temp*r_norm
                   + c11_temp*r_norm*z_norm
                   + c12_temp*r_norm*(z_norm*z_norm)
                   + c13_temp*r_norm*(z_norm*z_norm*z_norm)
                   + c20_temp*(r_norm*r_norm)
                   + c21_temp*(r_norm*r_norm)*z_norm
                   + c22_temp*(r_norm*r_norm)*(z_norm*z_norm)
                   + c23_temp*(r_norm*r_norm)*(z_norm*z_norm*z_norm)
                   + c30_temp*(r_norm*r_norm*r_norm)
                   + c31_temp*(r_norm*r_norm*r_norm)*z_norm
                   + c32_temp*(r_norm*r_norm*r_norm)*(z_norm*z_norm)
                   + c33_temp*(r_norm*r_norm*r_norm)*(z_norm*z_norm*z_norm);
            auto c0 = ((c03_temp*z_norm + c02_temp)*z_norm + c01_temp)*z_norm + c00_temp;
            auto c1 = ((c13_temp*z_norm + c12_temp)*z_norm + c11_temp)*z_norm + c10_temp;
            auto c2 = ((c23_temp*z_norm + c22_temp)*z_norm + c21_temp)*z_norm + c20_temp;
            auto c3 = ((c33_temp*z_norm + c32_temp)*z_norm + c31_temp)*z_norm + c30_temp;

            return ((c3*r_norm + c2)*r_norm + c1)*r_norm + c0;
        }

//------------------------------------------------------------------------------
@@ -1118,30 +1108,27 @@ namespace equilibrium {
                auto n2_temp = graph::piecewise_1D(ne_c2, psi_norm_cache);
                auto n3_temp = graph::piecewise_1D(ne_c3, psi_norm_cache);

                ne_cache = ne_scale*(n0_temp +
                                     n1_temp*psi_norm_cache +
                                     n2_temp*psi_norm_cache*psi_norm_cache +
                                     n3_temp*psi_norm_cache*psi_norm_cache*psi_norm_cache);
                ne_cache = ne_scale
                         * (((n3_temp*psi_norm_cache + n2_temp) *
                             psi_norm_cache + n1_temp)*psi_norm_cache + n0_temp);

                auto t0_temp = graph::piecewise_1D(te_c0, psi_norm_cache);
                auto t1_temp = graph::piecewise_1D(te_c1, psi_norm_cache);
                auto t2_temp = graph::piecewise_1D(te_c2, psi_norm_cache);
                auto t3_temp = graph::piecewise_1D(te_c3, psi_norm_cache);

                te_cache = te_scale*(t0_temp +
                                     t1_temp*psi_norm_cache +
                                     t2_temp*psi_norm_cache*psi_norm_cache +
                                     t3_temp*psi_norm_cache*psi_norm_cache*psi_norm_cache);
                te_cache = te_scale
                         * (((t3_temp*psi_norm_cache + t2_temp) *
                             psi_norm_cache + t1_temp)*psi_norm_cache + t0_temp);

                auto p0_temp = graph::piecewise_1D(pres_c0, psi_norm_cache);
                auto p1_temp = graph::piecewise_1D(pres_c1, psi_norm_cache);
                auto p2_temp = graph::piecewise_1D(pres_c2, psi_norm_cache);
                auto p3_temp = graph::piecewise_1D(pres_c3, psi_norm_cache);

                auto pressure = pres_scale*(p0_temp +
                                            p1_temp*psi_norm_cache +
                                            p2_temp*psi_norm_cache*psi_norm_cache +
                                            p3_temp*psi_norm_cache*psi_norm_cache*psi_norm_cache);
                auto pressure = pres_scale
                              * (((p3_temp*psi_norm_cache + p2_temp) *
                                  psi_norm_cache + p1_temp)*psi_norm_cache + p0_temp);

                auto q = graph::constant<T, SAFE_MATH> (static_cast<T> (1.60218E-19));

@@ -1157,10 +1144,8 @@ namespace equilibrium {
                auto b2_temp = graph::piecewise_1D(fpol_c2, r_norm);
                auto b3_temp = graph::piecewise_1D(fpol_c3, r_norm);

                auto bp = (b0_temp +
                           b1_temp*r_norm +
                           b2_temp*r_norm*r_norm +
                           b3_temp*r_norm*r_norm*r_norm)/r;
                auto bp = (((b3_temp*r_norm + b2_temp) *
                            r_norm + b1_temp)*r_norm + b0_temp)/r;

                auto bz = -psi->df(r)/r;

@@ -1835,10 +1820,7 @@ namespace equilibrium {
            auto c2_temp = graph::piecewise_1D(chi_c2, s_norm);
            auto c3_temp = graph::piecewise_1D(chi_c3, s_norm);

            return c0_temp +
                   c1_temp*s_norm +
                   c2_temp*s_norm*s_norm +
                   c3_temp*s_norm*s_norm*s_norm;
            return ((c3_temp*s_norm + c2_temp)*s_norm + c1_temp)*s_norm + c0_temp;
        }

//------------------------------------------------------------------------------
@@ -1895,18 +1877,12 @@ namespace equilibrium {
                    auto lmns_c2_temp = graph::piecewise_1D(lmns_c2[i], s_norm_h);
                    auto lmns_c3_temp = graph::piecewise_1D(lmns_c3[i], s_norm_h);

                    auto rmnc = rmnc_c0_temp
                              + rmnc_c1_temp*s_norm_f
                              + rmnc_c2_temp*s_norm_f*s_norm_f
                              + rmnc_c3_temp*s_norm_f*s_norm_f*s_norm_f;
                    auto zmns = zmns_c0_temp
                              + zmns_c1_temp*s_norm_f
                              + zmns_c2_temp*s_norm_f*s_norm_f
                              + zmns_c3_temp*s_norm_f*s_norm_f*s_norm_f;
                    auto lmns = lmns_c0_temp
                              + lmns_c1_temp*s_norm_h
                              + lmns_c2_temp*s_norm_h*s_norm_h
                              + lmns_c3_temp*s_norm_h*s_norm_h*s_norm_h;
                    auto rmnc = ((rmnc_c3_temp*s_norm_f + rmnc_c2_temp)*s_norm_f +
                                 rmnc_c1_temp)*s_norm_f + rmnc_c0_temp;
                    auto zmns = ((zmns_c3_temp*s_norm_f + zmns_c2_temp)*s_norm_f +
                                 zmns_c1_temp)*s_norm_f + zmns_c0_temp;
                    auto lmns = ((lmns_c3_temp*s_norm_h + lmns_c2_temp)*s_norm_h +
                                 lmns_c1_temp)*s_norm_h + lmns_c0_temp;

                    auto m = graph::constant<T, SAFE_MATH> (xm[i]);
                    auto n = graph::constant<T, SAFE_MATH> (xn[i]);
+64 −2

File changed.

Preview size limit exceeded, changes collapsed.

+13 −0
Original line number Diff line number Diff line
@@ -660,6 +660,19 @@ namespace graph {
        return constant<T, SAFE_MATH> (static_cast<T> (1.0));
    }

//------------------------------------------------------------------------------
///  @brief Create a one constant.
///
///  @tparam T         Base type of the calculation.
///  @tparam SAFE_MATH Use safe math operations.
///
///  @returns A one constant.
//------------------------------------------------------------------------------
    template<jit::float_scalar T, bool SAFE_MATH>
    constexpr shared_leaf<T, SAFE_MATH> none() {
        return constant<T, SAFE_MATH> (static_cast<T> (-1.0));
    }

///  Convinece type for imaginary constant.
    template<jit::complex_scalar T>
    constexpr T i = T(0.0, 1.0);
Loading