Commit e2c11097 authored by cianciosa's avatar cianciosa
Browse files

Reduced piecewise cases of p1*(c1 ± a) ± p2.

parent 64cb6892
Loading
Loading
Loading
Loading
+177 −1
Original line number Diff line number Diff line
@@ -1429,6 +1429,94 @@
					"$(inherited)",
				);
				MACOSX_DEPLOYMENT_TARGET = 13.3;
				OTHER_LDFLAGS = (
					"-lnetcdf",
					"-ld_classic",
					"-L/Users/m4c/Projects/graph_framework/build/_deps/llvm-build/lib",
					"-lz",
					"-lLLVMCoverage",
					"-lLLVMSupport",
					"-lLLVMDebugInfoCodeView",
					"-lLLVMRemarks",
					"-lLLVMJITLink",
					"-lLLVMLinker",
					"-lLLVMTextAPI",
					"-lLLVMRuntimeDyld",
					"-lLLVMOrcShared",
					"-lLLVMOrcDebugging",
					"-lLLVMOrcTargetProcess",
					"-lLLVMOrcJIT",
					"-lLLVMHipStdPar",
					"-lLLVMAggressiveInstCombine",
					"-lLLVMVectorize",
					"-lLLVMAsmParser",
					"-lLLVMOption",
					"-lLLVMLTO",
					"-lLLVMObject",
					"-lLLVMWindowsDriver",
					"-lLLVMDemangle",
					"-lLLVMIRReader",
					"-lLLVMIRPrinter",
					"-lLLVMInstCombine",
					"-lLLVMBinaryFormat",
					"-lLLVMCoroutines",
					"-lLLVMBitstreamReader",
					"-lLLVMBitReader",
					"-lLLVMBitWriter",
					"-lLLVMDebugInfoDWARF",
					"-lLLVMInstrumentation",
					"-lLLVMCFGuard",
					"-lLLVMObjCARCOpts",
					"-lLLVMipo",
					"-lLLVMGlobalISel",
					"-lLLVMExecutionEngine",
					"-lLLVMFrontendDriver",
					"-lLLVMFrontendHLSL",
					"-lLLVMFrontendOpenMP",
					"-lLLVMFrontendOffloading",
					"-lLLVMSelectionDAG",
					"-lLLVMProfileData",
					"-lLLVMAnalysis",
					"-lLLVMScalarOpts",
					"-lLLVMCodeGenTypes",
					"-lLLVMCodeGenData",
					"-lLLVMCodeGen",
					"-lLLVMTargetParser",
					"-lLLVMScalarOpts",
					"-lLLVMTarget",
					"-lLLVMTransformUtils",
					"-lLLVMPasses",
					"-lLLVMSupport",
					"-lLLVMMCParser",
					"-lLLVMMC",
					"-lLLVMCore",
					"-lLLVMAsmPrinter",
					"-lLLVMAArch64Utils",
					"-lLLVMAArch64Info",
					"-lLLVMAArch64Desc",
					"-lLLVMAArch64AsmParser",
					"-lLLVMAArch64CodeGen",
					"-lLLVMCodeGenData",
					"-lLLVMCGData",
					"-lLLVMSandboxIR",
					"-lLLVMFrontendAtomic",
					"-lclangFrontend",
					"-lclangBasic",
					"-lclangEdit",
					"-lclangLex",
					"-lclangDriver",
					"-lclangSerialization",
					"-lclangAST",
					"-lclangSema",
					"-lclangAnalysis",
					"-lclangASTMatchers",
					"-lclangSupport",
					"-lclangParse",
					"-lclangAPINotes",
					"-lclangCodeGen",
					"-rpath",
					/usr/local/lib,
				);
				PRODUCT_NAME = "$(TARGET_NAME)";
			};
			name = Debug;
@@ -1440,6 +1528,94 @@
				"CODE_SIGN_IDENTITY[sdk=macosx*]" = "-";
				CODE_SIGN_STYLE = Automatic;
				MACOSX_DEPLOYMENT_TARGET = 13.3;
				OTHER_LDFLAGS = (
					"-lnetcdf",
					"-ld_classic",
					"-L/Users/m4c/Projects/graph_framework/build/_deps/llvm-build/lib",
					"-lz",
					"-lLLVMCoverage",
					"-lLLVMSupport",
					"-lLLVMDebugInfoCodeView",
					"-lLLVMRemarks",
					"-lLLVMJITLink",
					"-lLLVMLinker",
					"-lLLVMTextAPI",
					"-lLLVMRuntimeDyld",
					"-lLLVMOrcShared",
					"-lLLVMOrcDebugging",
					"-lLLVMOrcTargetProcess",
					"-lLLVMOrcJIT",
					"-lLLVMHipStdPar",
					"-lLLVMAggressiveInstCombine",
					"-lLLVMVectorize",
					"-lLLVMAsmParser",
					"-lLLVMOption",
					"-lLLVMLTO",
					"-lLLVMObject",
					"-lLLVMWindowsDriver",
					"-lLLVMDemangle",
					"-lLLVMIRReader",
					"-lLLVMIRPrinter",
					"-lLLVMInstCombine",
					"-lLLVMBinaryFormat",
					"-lLLVMCoroutines",
					"-lLLVMBitstreamReader",
					"-lLLVMBitReader",
					"-lLLVMBitWriter",
					"-lLLVMDebugInfoDWARF",
					"-lLLVMInstrumentation",
					"-lLLVMCFGuard",
					"-lLLVMObjCARCOpts",
					"-lLLVMipo",
					"-lLLVMGlobalISel",
					"-lLLVMExecutionEngine",
					"-lLLVMFrontendDriver",
					"-lLLVMFrontendHLSL",
					"-lLLVMFrontendOpenMP",
					"-lLLVMFrontendOffloading",
					"-lLLVMSelectionDAG",
					"-lLLVMProfileData",
					"-lLLVMAnalysis",
					"-lLLVMScalarOpts",
					"-lLLVMCodeGenTypes",
					"-lLLVMCodeGenData",
					"-lLLVMCodeGen",
					"-lLLVMTargetParser",
					"-lLLVMScalarOpts",
					"-lLLVMTarget",
					"-lLLVMTransformUtils",
					"-lLLVMPasses",
					"-lLLVMSupport",
					"-lLLVMMCParser",
					"-lLLVMMC",
					"-lLLVMCore",
					"-lLLVMAsmPrinter",
					"-lLLVMAArch64Utils",
					"-lLLVMAArch64Info",
					"-lLLVMAArch64Desc",
					"-lLLVMAArch64AsmParser",
					"-lLLVMAArch64CodeGen",
					"-lLLVMCodeGenData",
					"-lLLVMCGData",
					"-lLLVMSandboxIR",
					"-lLLVMFrontendAtomic",
					"-lclangFrontend",
					"-lclangBasic",
					"-lclangEdit",
					"-lclangLex",
					"-lclangDriver",
					"-lclangSerialization",
					"-lclangAST",
					"-lclangSema",
					"-lclangAnalysis",
					"-lclangASTMatchers",
					"-lclangSupport",
					"-lclangParse",
					"-lclangAPINotes",
					"-lclangCodeGen",
					"-rpath",
					/usr/local/lib,
				);
				PRODUCT_NAME = "$(TARGET_NAME)";
			};
			name = Release;
@@ -1526,7 +1702,7 @@
					"EFIT_FILE=\\\"/Users/m4c/Projects/graph_framework/graph_tests/efit.nc\\\"",
					"VMEC_FILE=\\\"/Users/m4c/Projects/graph_framework/graph_tests/vmec.nc\\\"",
					USE_METAL,
					"\"CXX_ARGS=\\\"-I/Users/m4c/Projects/graph_framework/graph_framework -I/usr/local/include -I/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/usr/include/c++/v1 -I/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/15.0.0/include -I/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/usr/include -I/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/include -I/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/System/Library/Frameworks -fgnuc-version=4.2.1 -std=gnu++2a\\\"\"",
					"\"CXX_ARGS=\\\"-I/Users/m4c/Projects/graph_framework/graph_framework -I/usr/local/include -I/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/usr/include/c++/v1 -I/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/16/include -I/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/usr/include -I/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/include -I/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/System/Library/Frameworks -fgnuc-version=4.2.1 -std=gnu++2a\\\"\"",
					STATIC,
					"DEBUG=1",
					"$(inherited)",
+47 −9
Original line number Diff line number Diff line
@@ -925,6 +925,32 @@ namespace graph {
            auto lm = multiply_cast(this->left);
            auto rm = multiply_cast(this->right);

//  c1*(c2 + a) - c3 -> fma(c1,a,c4)
            if (lm.get()) {
                auto lmra = add_cast(lm->get_right());
                if (lmra.get()) {
                    if (is_constant_combineable(lm->get_left(),
                                                lmra->get_left()) &&
                        is_constant_combineable(lm->get_left(),
                                                this->right)) {
                        return fma(lm->get_left(),
                                   lmra->get_right(),
                                   lm->get_left()*lmra->get_left() - this->right);
                    }
                }

                auto lmrs = subtract_cast(lm->get_right());
                if (lmrs.get()) {
                    if (is_constant_combineable(lm->get_left(),
                                                lmrs->get_left()) &&
                        is_constant_combineable(lm->get_left(),
                                                this->right)) {
                        return lm->get_left()*lmrs->get_left() - this->right -
                               lm->get_left()*lmrs->get_right();
                    }
                }
            }

//  Assume constants are on the left.
//  v1 - -c*v2 -> v1 + c*v2
            if (rm.get()                      &&
@@ -3491,6 +3517,27 @@ namespace graph {
                return this->middle*(1.0 + this->left);
            }

//  fma(c1,c2 + a,c3) -> fma(c4,a,c5)
            auto ma = add_cast(this->middle);
            if (ma.get()) {
                if (is_constant_combineable(this->left, ma->get_left()) &&
                    is_constant_combineable(this->left, this->right)) {
                    return fma(this->left,
                               ma->get_right(),
                               fma(this->left, ma->get_left(), this->right));
                }
            }

//  fma(c1,c2 - a,c3) -> c4 - c5*a
            auto ms = subtract_cast(this->middle);
            if (ms.get()) {
                if (is_constant_combineable(this->left, ms->get_left()) &&
                    is_constant_combineable(this->left, this->right)) {
                    return fma(this->left, ms->get_left(), this->right) -
                           this->left*ms->get_right();
                }
            }

//  Common factor reduction. If the left and right are both multiply nodes check
//  for a common factor. So you can change a*b + (a*c) -> a*(b + c).
            auto lm = multiply_cast(this->left);
@@ -4149,15 +4196,6 @@ namespace graph {
                }
            }

//  Promote constants out to the left.
            if (is_constant_combineable(this->left, this->right) &&
                !this->left->has_constant_zero()) {
                auto temp = this->right/this->left;
                if (temp->is_normal()) {
                    return this->left*(this->middle + temp);
                }
            }

//  Change negative exponents to divide so that can be factored out.
//  fma(a,b^-c,d) = a/b^c + d
//  fma(b^-c,a,d) = a/b^c + d
+2 −4
Original line number Diff line number Diff line
@@ -538,7 +538,8 @@ template<jit::float_scalar T> void test_subtract() {
//  (c1*v1 + c2) - (c3*v1 + c4) -> c5*(v1 - c6)
    auto subfma = graph::fma(3.0, var_a, 2.0)
                - graph::fma(2.0, var_a, 3.0);
    assert(graph::multiply_cast(subfma).get() && "Expected a multiply node.");
//  -1 + a
    assert(graph::add_cast(subfma).get() && "Expected an add node.");

//  Test cases like
//  (c1 + c2/x) - c3/x -> c1 + c4/x
@@ -2665,9 +2666,6 @@ template<jit::float_scalar T> void test_fma() {
    assert(reduce4_cast->get_right()->is_match(var_b) &&
           "Expected common var_b");

    assert(graph::multiply_cast(graph::fma(two, var_a, one)).get() &&
           "Expected multiply node.");

//  fma(a, b, fma(c, b, d)) -> fma(b, a + c, d)
    auto var_d = graph::variable<T> (1, "");
    auto match1 = graph::fma(var_b, var_a + var_c, var_d);
+74 −5
Original line number Diff line number Diff line
@@ -129,9 +129,6 @@ template<jit::float_scalar T> void piecewise_1D() {
           "Expected a piecewise_1D node.");
    assert(graph::add_cast(graph::fma(p1, 2.0, p2)).get() &&
           "Expected an add node.");
    auto temp = graph::fma(p1, p2, 2.0);
    assert(graph::multiply_cast(graph::fma(p1, p2, 2.0)).get() &&
           "Expected a multiply node.");
    assert(graph::add_cast(graph::fma(p1, p3, p2)).get() &&
           "Expected an add node.");
    assert(graph::piecewise_1D_cast(graph::fma(p1, p3, 2.0)).get() &&
@@ -231,6 +228,43 @@ template<jit::float_scalar T> void piecewise_1D() {
                                                       static_cast<T> (10.0)}), a);
    assert(graph::constant_cast(pc).get() &&
           "Expected a constant.");

//  fma(p1,c1 + a,p2) -> fma(p1,a,p3)
    auto fma_combine = fma(p1,1.0 + a,p3);
    auto fma_combine_cast = graph::fma_cast(fma_combine);
    assert(fma_combine_cast.get() && "Expected an fma node.");
    assert(fma_combine_cast->get_middle()->is_match(a) &&
           "Expected a in the middle.");
    assert(fma_combine_cast->get_left()->is_match(p1) &&
           "Expected p1 on the left.");
    assert(fma_combine_cast->get_right()->is_match(p1 + p3) &&
           "Expected p1 + p3 on the right.");
//  fma(p1,c1 - a,p2) -> p3 - p1*a
    auto fma_combine2 = fma(p1,1.0 - a,p3);
    auto fma_combine2_cast = graph::subtract_cast(fma_combine2);
    assert(fma_combine2_cast.get() && "Expected an subtract node.");
    assert(fma_combine2_cast->get_right()->is_match(p1*a) &&
           "Expected p1*a on the right.");
    assert(fma_combine2_cast->get_left()->is_match(p1 + p3) &&
           "Expected p1 + p3 on the left.");
//  p1*(c1 + a) - p2 -> fma(p1,a,p3)
    auto fma_combine3 = p1*(1.0 + a) - p3;
    auto fma_combine3_cast = graph::fma_cast(fma_combine3);
    assert(fma_combine3_cast.get() && "Expected a fma node.");
    assert(fma_combine3_cast->get_middle()->is_match(a) &&
           "Expected a in the middle.");
    assert(fma_combine3_cast->get_left()->is_match(p1) &&
           "Expected p1 on the left.");
    assert(fma_combine3_cast->get_right()->is_match(p1 - p3) &&
           "Expected p1 - p3 on the right.");
//  p1*(c1 - a) - p2 -> p3 - p1*a
    auto fma_combine4 = p1*(1.0 - a) - p3;
    auto fma_combine4_cast = graph::subtract_cast(fma_combine4);
    assert(fma_combine4_cast.get() && "Expected an subtract node.");
    assert(fma_combine4_cast->get_right()->is_match(p1*a) &&
           "Expected p1*a on the right.");
    assert(fma_combine4_cast->get_left()->is_match(p1 - p3) &&
           "Expected p1 - p3 on the left.");
}

//------------------------------------------------------------------------------
@@ -319,8 +353,6 @@ template<jit::float_scalar T> void piecewise_2D() {
           "Expected a piecewise_2D node.");
    assert(graph::add_cast(graph::fma(p1, 2.0, p2)).get() &&
           "Expected an add node.");
    assert(graph::multiply_cast(graph::fma(p1, p2, 2.0)).get() &&
           "Expected a multiply node.");
    assert(graph::add_cast(graph::fma(p1, p3, p2)).get() &&
           "Expected an add node.");
    assert(graph::piecewise_2D_cast(graph::fma(p1, p3, 2.0)).get() &&
@@ -606,6 +638,43 @@ template<jit::float_scalar T> void piecewise_2D() {
                 graph::variable_cast(ay)},
                {col_test}, {},
                static_cast<T> (8.0), 0.0);

//  fma(p1,c1 + a,p2) -> fma(p1,a,p3)
    auto fma_combine = fma(p1,1.0 + ax,p3);
    auto fma_combine_cast = graph::fma_cast(fma_combine);
    assert(fma_combine_cast.get() && "Expected an fma node.");
    assert(fma_combine_cast->get_middle()->is_match(ax) &&
           "Expected a in the middle.");
    assert(fma_combine_cast->get_left()->is_match(p1) &&
           "Expected p1 on the left.");
    assert(fma_combine_cast->get_right()->is_match(p1 + p3) &&
           "Expected p1 + p3 on the right.");
//  fma(p1,c1 - a,p2) -> p3 - p1*a
    auto fma_combine2 = fma(p1,1.0 - ax,p3);
    auto fma_combine2_cast = graph::subtract_cast(fma_combine2);
    assert(fma_combine2_cast.get() && "Expected an subtract node.");
    assert(fma_combine2_cast->get_right()->is_match(p1*ax) &&
           "Expected p1*a on the right.");
    assert(fma_combine2_cast->get_left()->is_match(p1 + p3) &&
           "Expected p1 + p3 on the left.");
//  p1*(c1 + a) - p2 -> fma(p1,a,p3)
    auto fma_combine3 = p1*(1.0 + ax) - p3;
    auto fma_combine3_cast = graph::fma_cast(fma_combine3);
    assert(fma_combine3_cast.get() && "Expected a fma node.");
    assert(fma_combine3_cast->get_middle()->is_match(ax) &&
           "Expected a in the middle.");
    assert(fma_combine3_cast->get_left()->is_match(p1) &&
           "Expected p1 on the left.");
    assert(fma_combine3_cast->get_right()->is_match(p1 - p3) &&
           "Expected p1 - p3 on the right.");
//  p1*(c1 - a) - p2 -> p3 - p1*a
    auto fma_combine4 = p1*(1.0 - ax) - p3;
    auto fma_combine4_cast = graph::subtract_cast(fma_combine4);
    assert(fma_combine4_cast.get() && "Expected an subtract node.");
    assert(fma_combine4_cast->get_right()->is_match(p1*ax) &&
           "Expected p1*a on the right.");
    assert(fma_combine4_cast->get_left()->is_match(p1 - p3) &&
           "Expected p1 - p3 on the left.");
}

//------------------------------------------------------------------------------