Commit 1e3d6113 authored by cianciosa's avatar cianciosa
Browse files

Commit changes to power reductions. This is WIP becuase there are several regessions.

parent 77171741
Loading
Loading
Loading
Loading
+174 −0
Original line number Diff line number Diff line
@@ -1294,6 +1294,93 @@
				);
				LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
				MACOSX_DEPLOYMENT_TARGET = 14.5;
				OTHER_LDFLAGS = (
					"-lnetcdf",
					"-ld_classic",
					"-L/Users/m4c/Projects/graph_framework/build/_deps/llvm-build/lib",
					"-lz",
					"-lLLVMCoverage",
					"-lLLVMSupport",
					"-lLLVMDebugInfoCodeView",
					"-lLLVMRemarks",
					"-lLLVMJITLink",
					"-lLLVMLinker",
					"-lLLVMTextAPI",
					"-lLLVMRuntimeDyld",
					"-lLLVMOrcShared",
					"-lLLVMOrcDebugging",
					"-lLLVMOrcTargetProcess",
					"-lLLVMOrcJIT",
					"-lLLVMHipStdPar",
					"-lLLVMAggressiveInstCombine",
					"-lLLVMVectorize",
					"-lLLVMAsmParser",
					"-lLLVMOption",
					"-lLLVMLTO",
					"-lLLVMObject",
					"-lLLVMWindowsDriver",
					"-lLLVMDemangle",
					"-lLLVMIRReader",
					"-lLLVMIRPrinter",
					"-lLLVMInstCombine",
					"-lLLVMBinaryFormat",
					"-lLLVMCoroutines",
					"-lLLVMBitstreamReader",
					"-lLLVMBitReader",
					"-lLLVMBitWriter",
					"-lLLVMDebugInfoDWARF",
					"-lLLVMInstrumentation",
					"-lLLVMCFGuard",
					"-lLLVMObjCARCOpts",
					"-lLLVMipo",
					"-lLLVMGlobalISel",
					"-lLLVMExecutionEngine",
					"-lLLVMFrontendDriver",
					"-lLLVMFrontendHLSL",
					"-lLLVMFrontendOpenMP",
					"-lLLVMFrontendOffloading",
					"-lLLVMSelectionDAG",
					"-lLLVMProfileData",
					"-lLLVMAnalysis",
					"-lLLVMScalarOpts",
					"-lLLVMCodeGenTypes",
					"-lLLVMCodeGenData",
					"-lLLVMCodeGen",
					"-lLLVMTargetParser",
					"-lLLVMScalarOpts",
					"-lLLVMTarget",
					"-lLLVMTransformUtils",
					"-lLLVMPasses",
					"-lLLVMSupport",
					"-lLLVMMCParser",
					"-lLLVMMC",
					"-lLLVMCore",
					"-lLLVMAsmPrinter",
					"-lLLVMAArch64Utils",
					"-lLLVMAArch64Info",
					"-lLLVMAArch64Desc",
					"-lLLVMAArch64AsmParser",
					"-lLLVMAArch64CodeGen",
					"-lLLVMCGData",
					"-lLLVMSandboxIR",
					"-lLLVMFrontendAtomic",
					"-lclangFrontend",
					"-lclangBasic",
					"-lclangEdit",
					"-lclangLex",
					"-lclangDriver",
					"-lclangSerialization",
					"-lclangAST",
					"-lclangSema",
					"-lclangAnalysis",
					"-lclangASTMatchers",
					"-lclangSupport",
					"-lclangParse",
					"-lclangAPINotes",
					"-lclangCodeGen",
					"-rpath",
					/usr/local/lib,
				);
				PRODUCT_NAME = "$(TARGET_NAME)";
			};
			name = Debug;
@@ -1312,6 +1399,93 @@
				);
				LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
				MACOSX_DEPLOYMENT_TARGET = 14.5;
				OTHER_LDFLAGS = (
					"-lnetcdf",
					"-ld_classic",
					"-L/Users/m4c/Projects/graph_framework/build/_deps/llvm-build/lib",
					"-lz",
					"-lLLVMCoverage",
					"-lLLVMSupport",
					"-lLLVMDebugInfoCodeView",
					"-lLLVMRemarks",
					"-lLLVMJITLink",
					"-lLLVMLinker",
					"-lLLVMTextAPI",
					"-lLLVMRuntimeDyld",
					"-lLLVMOrcShared",
					"-lLLVMOrcDebugging",
					"-lLLVMOrcTargetProcess",
					"-lLLVMOrcJIT",
					"-lLLVMHipStdPar",
					"-lLLVMAggressiveInstCombine",
					"-lLLVMVectorize",
					"-lLLVMAsmParser",
					"-lLLVMOption",
					"-lLLVMLTO",
					"-lLLVMObject",
					"-lLLVMWindowsDriver",
					"-lLLVMDemangle",
					"-lLLVMIRReader",
					"-lLLVMIRPrinter",
					"-lLLVMInstCombine",
					"-lLLVMBinaryFormat",
					"-lLLVMCoroutines",
					"-lLLVMBitstreamReader",
					"-lLLVMBitReader",
					"-lLLVMBitWriter",
					"-lLLVMDebugInfoDWARF",
					"-lLLVMInstrumentation",
					"-lLLVMCFGuard",
					"-lLLVMObjCARCOpts",
					"-lLLVMipo",
					"-lLLVMGlobalISel",
					"-lLLVMExecutionEngine",
					"-lLLVMFrontendDriver",
					"-lLLVMFrontendHLSL",
					"-lLLVMFrontendOpenMP",
					"-lLLVMFrontendOffloading",
					"-lLLVMSelectionDAG",
					"-lLLVMProfileData",
					"-lLLVMAnalysis",
					"-lLLVMScalarOpts",
					"-lLLVMCodeGenTypes",
					"-lLLVMCodeGenData",
					"-lLLVMCodeGen",
					"-lLLVMTargetParser",
					"-lLLVMScalarOpts",
					"-lLLVMTarget",
					"-lLLVMTransformUtils",
					"-lLLVMPasses",
					"-lLLVMSupport",
					"-lLLVMMCParser",
					"-lLLVMMC",
					"-lLLVMCore",
					"-lLLVMAsmPrinter",
					"-lLLVMAArch64Utils",
					"-lLLVMAArch64Info",
					"-lLLVMAArch64Desc",
					"-lLLVMAArch64AsmParser",
					"-lLLVMAArch64CodeGen",
					"-lLLVMCGData",
					"-lLLVMSandboxIR",
					"-lLLVMFrontendAtomic",
					"-lclangFrontend",
					"-lclangBasic",
					"-lclangEdit",
					"-lclangLex",
					"-lclangDriver",
					"-lclangSerialization",
					"-lclangAST",
					"-lclangSema",
					"-lclangAnalysis",
					"-lclangASTMatchers",
					"-lclangSupport",
					"-lclangParse",
					"-lclangAPINotes",
					"-lclangCodeGen",
					"-rpath",
					/usr/local/lib,
				);
				PRODUCT_NAME = "$(TARGET_NAME)";
			};
			name = Release;
+72 −0
Original line number Diff line number Diff line
@@ -2024,6 +2024,45 @@ namespace graph {
                    return this->right/pow(lp->get_left(), -lp->get_right());
                }
            }
//  a^b*c^b -> (a*c)^b
            if (lp.get() && rp.get()) {
                if (lp->get_right()->is_match(rp->get_right())) {
                    return pow(lp->get_left()*rp->get_left(), lp->get_right());
                }
            }
// (a*b^c)*d^c -> a*(b*d)^c
// (a^c*b)*d^c -> b*(a*d)^c
// a^c*(b*d^c) -> b*(a*d)^c
// a^c*(b^c*d) -> d*(a*b)^c
            if (lm.get() && rp.get()) {
                auto lmlp = pow_cast(lm->get_left());
                auto lmrp = pow_cast(lm->get_right());
                if (lmrp.get()) {
                    if (lmrp->get_right()->is_match(rp->get_right())) {
                        return lm->get_left()*pow(lmrp->get_left()*rp->get_left(),
                                                  rp->get_right());
                    }
                } else if (lmlp.get()) {
                    if (lmlp->get_right()->is_match(rp->get_right())) {
                        return lm->get_right()*pow(lmlp->get_left()*rp->get_left(),
                                                   rp->get_right());
                    }
                }
            } else if (rm.get() && lp.get()) {
                auto rmlp = pow_cast(rm->get_left());
                auto rmrp = pow_cast(rm->get_right());
                if (rmrp.get()) {
                    if (rmrp->get_right()->is_match(lp->get_right())) {
                        return rm->get_left()*pow(lp->get_left()*rmrp->get_left(),
                                                  lp->get_right());
                    }
                } else if (rmlp.get()) {
                    if (rmlp->get_right()->is_match(lp->get_right())) {
                        return rm->get_right()*pow(lp->get_left()*rmlp->get_left(),
                                                   lp->get_right());
                    }
                }
            }

//  (b*a)^c*a^d -> b^c*a^(c + d)
//  (a*b)^c*a^d -> b^c*a^(c + d)
@@ -2931,6 +2970,14 @@ namespace graph {
                    }
                }
            }

//  a^b/c^b -> (a/c)^b
            if (lp.get() && rp.get()) {
                if (lp->get_right()->is_match(rp->get_right())) {
                    return pow(lp->get_left()/rp->get_left(), lp->get_right());
                }
            }

//  (a*b)^c/((a^d)*e) = a^(c - d)*b^c/e
//  (b*a)^c/((a^d)*e) = a^(c - d)*b^c/e
//  (a*b)^c/(e*(a^d)) = a^(c - d)*b^c/e
@@ -4320,6 +4367,31 @@ namespace graph {
                }
            }

//  a^b*c^b + d -> (a*c)^b + d
            if (lp.get() && mp.get()) {
                if (lp->get_right()->is_match(mp->get_right())) {
                    return pow(lp->get_left()*mp->get_left(),
                               lp->get_right()) +
                           this->right;
                }
            }

//  fma(2,(ab)^2,a^2b) -> a^2*fma(2, b^2, b)
            if (rm.get() && mp.get()) {
                auto mplm = multiply_cast(mp->get_left());
                if (mplm.get()) {
                    if (is_variable_combineable(mplm->get_left(),
                                                rm->get_left())) {
                        return pow(mplm->get_left(),
                                   mp->get_right()) *
                               fma(this->left,
                                   pow(mplm->get_right(),
                                       mp->get_right()),
                                   this->right/mplm->get_left());
                    }
                }
            }

//  fma(a,b/c,b/d) -> b*(a/c + 1/d)
//  fma(a,c/b,d/b) -> (a*c + d)/b
            if (md.get() && rd.get()) {
+64 −2
Original line number Diff line number Diff line
@@ -957,8 +957,9 @@ namespace graph {
            }

//  Handle cases where (c*x)^a, (x*c)^a, (a*sqrt(b))^c and (a*b^c)^2.
//  These reductions only make sense if the power is constant.
            auto lm = multiply_cast(this->left);
            if (lm.get()) {
            if (lm.get() && rc.get()) {
                if (lm->get_left()->is_constant()    ||
                    lm->get_right()->is_constant()   ||
                    sqrt_cast(lm->get_left()).get()  ||
@@ -986,8 +987,9 @@ namespace graph {
                }
            }

//  These reductions only make sense if the power is constant.
            auto ld = divide_cast(this->left);
            if (ld.get()) {
            if (ld.get() && rc.get()) {
//  For even exponents e.
//  (-a/b)^e -> (a/b)^e
                auto ldlm = multiply_cast(ld->get_left());
@@ -1010,6 +1012,36 @@ namespace graph {
                               pow(ldlm->get_right(), this->right)/
                               pow(ld->get_right(), this->right);
                    }

                    auto ldlmlm = multiply_cast(ldlm->get_left());
                    if (ldlmlm.get()) {
                        if (ldlmlm->get_left()->is_constant()    ||
                            ldlmlm->get_right()->is_constant()   ||
                            sqrt_cast(ldlmlm->get_left()).get()  ||
                            sqrt_cast(ldlmlm->get_right()).get() ||
                            pow_cast(ldlmlm->get_left()).get()   ||
                            pow_cast(ldlmlm->get_right()).get()) {
                            return (pow(ldlmlm->get_left(), this->right)  *
                                    pow(ldlmlm->get_right(), this->right) *
                                    pow(ldlm->get_right(), this->right)) /
                                   pow(ld->get_right(), this->right);
                        }
                    }

                    auto ldlmrm = multiply_cast(ldlm->get_right());
                    if (ldlmrm.get()) {
                        if (ldlmrm->get_left()->is_constant()    ||
                            ldlmrm->get_right()->is_constant()   ||
                            sqrt_cast(ldlmrm->get_left()).get()  ||
                            sqrt_cast(ldlmrm->get_right()).get() ||
                            pow_cast(ldlmrm->get_left()).get()   ||
                            pow_cast(ldlmrm->get_right()).get()) {
                            return (pow(ldlmrm->get_left(), this->right)  *
                                    pow(ldlmrm->get_right(), this->right) *
                                    pow(ldlm->get_left(), this->right)) /
                                   pow(ld->get_right(), this->right);
                        }
                    }
                }
                
//  Handle cases where (c/x)^a, (x/c)^a, (a/sqrt(b))^c and (a/b^c)^2.
@@ -1036,6 +1068,36 @@ namespace graph {
                               (pow(ldrm->get_left(), this->right) *
                                pow(ldrm->get_right(), this->right));
                    }

                    auto ldrmlm = multiply_cast(ldrm->get_left());
                    if (ldrmlm.get()) {
                        if (ldrmlm->get_left()->is_constant()    ||
                            ldrmlm->get_right()->is_constant()   ||
                            sqrt_cast(ldrmlm->get_left()).get()  ||
                            sqrt_cast(ldrmlm->get_right()).get() ||
                            pow_cast(ldrmlm->get_left()).get()   ||
                            pow_cast(ldrmlm->get_right()).get()) {
                            return pow(ld->get_left(), this->right) /
                                   (pow(ldrmlm->get_left(), this->right)  *
                                    pow(ldrmlm->get_right(), this->right) *
                                    pow(ldrm->get_right(), this->right));
                        }
                    }

                    auto ldrmrm = multiply_cast(ldrm->get_right());
                    if (ldrmrm.get()) {
                        if (ldrmrm->get_left()->is_constant()    ||
                            ldrmrm->get_right()->is_constant()   ||
                            sqrt_cast(ldrmrm->get_left()).get()  ||
                            sqrt_cast(ldrmrm->get_right()).get() ||
                            pow_cast(ldrmrm->get_left()).get()   ||
                            pow_cast(ldrmrm->get_right()).get()) {
                            return pow(ld->get_left(), this->right) /
                                   (pow(ldrmrm->get_left(), this->right)  *
                                    pow(ldrmrm->get_right(), this->right) *
                                    pow(ldrm->get_left(), this->right));
                        }
                    }
                }

                if (is_variable_combineable(ld->get_left(),
+83 −12
Original line number Diff line number Diff line
@@ -537,8 +537,8 @@ template<jit::float_scalar T> void test_subtract() {
           "Expected to reduce to a constant minus one.");

//  Test common factors.
    auto var_a = graph::variable<T> (1, "");
    auto var_b = graph::variable<T> (1, "");
    auto var_a = graph::variable<T> (1, "a");
    auto var_b = graph::variable<T> (1, "b");
    auto var_c = graph::variable<T> (1, "");
    auto common_a = var_a*var_b - var_a*var_c;
    assert(graph::add_cast(common_a).get() == nullptr &&
@@ -1850,34 +1850,81 @@ template<jit::float_scalar T> void test_multiply() {
    auto common_pow2 = graph::pow(var_b*var_a, 2.0)*graph::pow(var_a, 2.0);
    auto common_pow2_cast = graph::multiply_cast(common_pow2);
    assert(common_pow2_cast.get() && "Expected a multiply node.");
    assert(common_pow2_cast->get_left()->is_match(graph::pow(var_b, 2.0)) &&
    assert(common_pow2_cast->get_right()->is_match(graph::pow(var_b, 2.0)) &&
           "Expected b^2.");
    assert(common_pow2_cast->get_right()->is_match(graph::pow(var_a, 4.0)) &&
    assert(common_pow2_cast->get_left()->is_match(graph::pow(var_a, 4.0)) &&
           "Expected a^4.");
//  (a*b)^2*a^2 -> b^2*a^4
    auto common_pow3 = graph::pow(var_a*var_b, 2.0)*graph::pow(var_a, 2.0);
    auto common_pow3_cast = graph::multiply_cast(common_pow3);
    assert(common_pow3_cast.get() && "Expected a multiply node.");
    assert(common_pow3_cast->get_left()->is_match(graph::pow(var_b, 2.0)) &&
    assert(common_pow3_cast->get_right()->is_match(graph::pow(var_b, 2.0)) &&
           "Expected b^2.");
    assert(common_pow3_cast->get_right()->is_match(graph::pow(var_a, 4.0)) &&
    assert(common_pow3_cast->get_left()->is_match(graph::pow(var_a, 4.0)) &&
           "Expected a^4.");
//  a^2*(b*a)^2 -> b^2*a^4
    auto common_pow4 = graph::pow(var_a, 2.0)*graph::pow(var_b*var_a, 2.0);
    auto common_pow4_cast = graph::multiply_cast(common_pow4);
    assert(common_pow4_cast.get() && "Expected a multiply node.");
    assert(common_pow4_cast->get_left()->is_match(graph::pow(var_b, 2.0)) &&
    assert(common_pow4_cast->get_right()->is_match(graph::pow(var_b, 2.0)) &&
           "Expected b^2.");
    assert(common_pow4_cast->get_right()->is_match(graph::pow(var_a, 4.0)) &&
    assert(common_pow4_cast->get_left()->is_match(graph::pow(var_a, 4.0)) &&
           "Expected a^4.");
//  a^2*(b*a)^2 -> b^2*a^4
    auto common_pow5 = graph::pow(var_a, 2.0)*graph::pow(var_a*var_b, 2.0);
    auto common_pow5_cast = graph::multiply_cast(common_pow5);
    assert(common_pow5_cast.get() && "Expected a multiply node.");
    assert(common_pow5_cast->get_left()->is_match(graph::pow(var_b, 2.0)) &&
    assert(common_pow5_cast->get_right()->is_match(graph::pow(var_b, 2.0)) &&
           "Expected b^2.");
    assert(common_pow5_cast->get_right()->is_match(graph::pow(var_a, 4.0)) &&
    assert(common_pow5_cast->get_left()->is_match(graph::pow(var_a, 4.0)) &&
           "Expected a^4.");

//  a^b*c^b -> (a*c)^b
    auto gather_power = graph::pow(var_a, var_b)*graph::pow(var_c, var_b);
    auto gather_power_cast = graph::pow_cast(gather_power);
    assert(gather_power_cast.get() && "Expected a power node.");
    assert(gather_power_cast->get_left()->is_match(var_a*var_c) &&
           "Expected a*c.");
    assert(gather_power_cast->get_right()->is_match(var_b) &&
           "Expected b.");

// (a*b^c)*d^c -> a*(b*d)^c
    auto var_d = (3.0 + graph::variable<T> (1, ""));
    auto gather_power2 = (var_a*graph::pow(var_b, var_c))*graph::pow(var_d, var_c);
    auto gather_power2_cast = graph::multiply_cast(gather_power2);
    assert(gather_power2_cast.get() && "Expected a mutliply node.");
    assert(gather_power2_cast->get_left()->is_match(var_a) &&
           "Expected a.");
    assert(gather_power2_cast->get_right()->is_match(graph::pow(var_b*var_d,
                                                                var_c)) &&
           "Expected (b*d)^c.");
// (a^c*b)*d^c -> b*(a*d)^c
    auto gather_power3 = (graph::pow(var_a, var_c)*var_b)*graph::pow(var_d, var_c);
    auto gather_power3_cast = graph::multiply_cast(gather_power3);
    assert(gather_power3_cast.get() && "Expected a mutliply node.");
    assert(gather_power3_cast->get_left()->is_match(var_b) &&
           "Expected b.");
    assert(gather_power3_cast->get_right()->is_match(graph::pow(var_a*var_d,
                                                                var_c)) &&
           "Expected (a*d)^c.");
// a^c*(b*d^c) -> b*(a*d)^c
    auto gather_power4 = graph::pow(var_a, var_c)*(var_b*graph::pow(var_d, var_c));
    auto gather_power4_cast = graph::multiply_cast(gather_power4);
    assert(gather_power4_cast.get() && "Expected a mutliply node.");
    assert(gather_power4_cast->get_left()->is_match(var_b) &&
           "Expected b.");
    assert(gather_power4_cast->get_right()->is_match(graph::pow(var_a*var_d,
                                                                var_c)) &&
           "Expected (a*d)^c.");
// a^c*(b^c*d) -> d*(a*b)^c
    auto gather_power5 = graph::pow(var_a, var_c)*(graph::pow(var_b, var_c)*var_d);
    auto gather_power5_cast = graph::multiply_cast(gather_power5);
    assert(gather_power5_cast.get() && "Expected a mutliply node.");
    assert(gather_power5_cast->get_left()->is_match(var_d) &&
           "Expected d.");
    assert(gather_power5_cast->get_right()->is_match(graph::pow(var_a*var_b,
                                                                var_c)) &&
           "Expected (a*b)^c.");
}

//------------------------------------------------------------------------------
@@ -2600,6 +2647,14 @@ template<jit::float_scalar T> void test_divide() {
//  a/(b/c) -> a*c/b
    assert((a/(b/c))->is_match(a*c/b) && "Expected a*b/c");

//  a^b/c^b -> (a/c)^b
    auto gather_power = graph::pow(a, b)/graph::pow(c, b);
    auto gather_power_cast = graph::pow_cast(gather_power);
    assert(gather_power_cast.get() && "Expected a power node.");
    assert(gather_power_cast->get_left()->is_match(a/c) &&
           "Expected a*c.");
    assert(gather_power_cast->get_right()->is_match(b) &&
           "Expected b.");
//  (a*b*c)^2/a^2 -> (b*c)^2
//  (a*b*c)^2/(a^2*d) -> (b*c)^2/d
//  (e*(a*b*c)^2)/(a^2*d) -> e*(b*c)^2/d
@@ -2716,8 +2771,8 @@ template<jit::float_scalar T> void test_fma() {
           "Expected a value of one.");

//  Test reduction.
    auto var_a = graph::variable<T> (1, "");
    auto var_b = graph::variable<T> (1, "");
    auto var_a = graph::variable<T> (1, "a");
    auto var_b = graph::variable<T> (1, "b");
    auto var_c = graph::variable<T> (1, "");

//  fma(1,a,b) = a + b
@@ -3509,6 +3564,22 @@ template<jit::float_scalar T> void test_fma() {
                                                               var_d*var_b,
                                                               var_e)) &&
           "Expected fma(a,d*b,e) as numerator.");

//  a^b*c^b + d -> (a*c)^b + d
    auto gather_power = graph::fma(graph::pow(var_a, var_b),
                                   graph::pow(var_c, var_b),
                                   var_d);
    assert(gather_power->is_match(graph::pow(var_a*var_c,
                                             var_b) + var_d) &&
           "Expected a power node.");

//  fma(2,(ab)^2,a^2b) -> a^2*fma(2, b^2, b)
    auto commom_power = graph::fma(2.0, graph::pow(var_a*var_b, 2.0), graph::pow(var_a, 2.0)*var_b);
    commom_power->to_latex();
    std::cout << std::endl << std::endl;
    auto commom_power_cast = graph::multiply_cast(commom_power);
    assert(commom_power_cast.get() && "Expected a multiply node.");
//  fma(2,(a*b)^2,fma())
}

//------------------------------------------------------------------------------
+124 −45

File changed.

Preview size limit exceeded, changes collapsed.