Commit 1e2e89bd authored by cianciosa's avatar cianciosa
Browse files

Reduce common factors in adds of powers.

parent cf7a3be8
Loading
Loading
Loading
Loading
+175 −0
Original line number Diff line number Diff line
@@ -1374,6 +1374,93 @@
				);
				LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
				MACOSX_DEPLOYMENT_TARGET = 14.6;
				OTHER_LDFLAGS = (
					"-lnetcdf",
					"-ld_classic",
					"-L/Users/m4c/Projects/graph_framework/build/_deps/llvm-build/lib",
					"-lz",
					"-lLLVMCoverage",
					"-lLLVMSupport",
					"-lLLVMDebugInfoCodeView",
					"-lLLVMRemarks",
					"-lLLVMJITLink",
					"-lLLVMLinker",
					"-lLLVMTextAPI",
					"-lLLVMRuntimeDyld",
					"-lLLVMOrcShared",
					"-lLLVMOrcDebugging",
					"-lLLVMOrcTargetProcess",
					"-lLLVMOrcJIT",
					"-lLLVMHipStdPar",
					"-lLLVMAggressiveInstCombine",
					"-lLLVMVectorize",
					"-lLLVMAsmParser",
					"-lLLVMOption",
					"-lLLVMLTO",
					"-lLLVMObject",
					"-lLLVMWindowsDriver",
					"-lLLVMDemangle",
					"-lLLVMIRReader",
					"-lLLVMIRPrinter",
					"-lLLVMInstCombine",
					"-lLLVMBinaryFormat",
					"-lLLVMCoroutines",
					"-lLLVMBitstreamReader",
					"-lLLVMBitReader",
					"-lLLVMBitWriter",
					"-lLLVMDebugInfoDWARF",
					"-lLLVMInstrumentation",
					"-lLLVMCFGuard",
					"-lLLVMObjCARCOpts",
					"-lLLVMipo",
					"-lLLVMGlobalISel",
					"-lLLVMExecutionEngine",
					"-lLLVMFrontendDriver",
					"-lLLVMFrontendHLSL",
					"-lLLVMFrontendOpenMP",
					"-lLLVMFrontendOffloading",
					"-lLLVMSelectionDAG",
					"-lLLVMProfileData",
					"-lLLVMAnalysis",
					"-lLLVMScalarOpts",
					"-lLLVMCodeGenTypes",
					"-lLLVMCodeGenData",
					"-lLLVMCodeGen",
					"-lLLVMTargetParser",
					"-lLLVMScalarOpts",
					"-lLLVMTarget",
					"-lLLVMTransformUtils",
					"-lLLVMPasses",
					"-lLLVMSupport",
					"-lLLVMMCParser",
					"-lLLVMMC",
					"-lLLVMCore",
					"-lLLVMAsmPrinter",
					"-lLLVMAArch64Utils",
					"-lLLVMAArch64Info",
					"-lLLVMAArch64Desc",
					"-lLLVMAArch64AsmParser",
					"-lLLVMAArch64CodeGen",
					"-lLLVMSandboxIR",
					"-lLLVMFrontendAtomic",
					"-lLLVMCGData",
					"-lclangFrontend",
					"-lclangBasic",
					"-lclangEdit",
					"-lclangLex",
					"-lclangDriver",
					"-lclangSerialization",
					"-lclangAST",
					"-lclangSema",
					"-lclangAnalysis",
					"-lclangASTMatchers",
					"-lclangSupport",
					"-lclangParse",
					"-lclangAPINotes",
					"-lclangCodeGen",
					"-rpath",
					/usr/local/lib,
				);
				PRODUCT_NAME = "$(TARGET_NAME)";
			};
			name = Debug;
@@ -1388,6 +1475,93 @@
				GCC_C_LANGUAGE_STANDARD = gnu17;
				LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
				MACOSX_DEPLOYMENT_TARGET = 14.6;
				OTHER_LDFLAGS = (
					"-lnetcdf",
					"-ld_classic",
					"-L/Users/m4c/Projects/graph_framework/build/_deps/llvm-build/lib",
					"-lz",
					"-lLLVMCoverage",
					"-lLLVMSupport",
					"-lLLVMDebugInfoCodeView",
					"-lLLVMRemarks",
					"-lLLVMJITLink",
					"-lLLVMLinker",
					"-lLLVMTextAPI",
					"-lLLVMRuntimeDyld",
					"-lLLVMOrcShared",
					"-lLLVMOrcDebugging",
					"-lLLVMOrcTargetProcess",
					"-lLLVMOrcJIT",
					"-lLLVMHipStdPar",
					"-lLLVMAggressiveInstCombine",
					"-lLLVMVectorize",
					"-lLLVMAsmParser",
					"-lLLVMOption",
					"-lLLVMLTO",
					"-lLLVMObject",
					"-lLLVMWindowsDriver",
					"-lLLVMDemangle",
					"-lLLVMIRReader",
					"-lLLVMIRPrinter",
					"-lLLVMInstCombine",
					"-lLLVMBinaryFormat",
					"-lLLVMCoroutines",
					"-lLLVMBitstreamReader",
					"-lLLVMBitReader",
					"-lLLVMBitWriter",
					"-lLLVMDebugInfoDWARF",
					"-lLLVMInstrumentation",
					"-lLLVMCFGuard",
					"-lLLVMObjCARCOpts",
					"-lLLVMipo",
					"-lLLVMGlobalISel",
					"-lLLVMExecutionEngine",
					"-lLLVMFrontendDriver",
					"-lLLVMFrontendHLSL",
					"-lLLVMFrontendOpenMP",
					"-lLLVMFrontendOffloading",
					"-lLLVMSelectionDAG",
					"-lLLVMProfileData",
					"-lLLVMAnalysis",
					"-lLLVMScalarOpts",
					"-lLLVMCodeGenTypes",
					"-lLLVMCodeGenData",
					"-lLLVMCodeGen",
					"-lLLVMTargetParser",
					"-lLLVMScalarOpts",
					"-lLLVMTarget",
					"-lLLVMTransformUtils",
					"-lLLVMPasses",
					"-lLLVMSupport",
					"-lLLVMMCParser",
					"-lLLVMMC",
					"-lLLVMCore",
					"-lLLVMAsmPrinter",
					"-lLLVMAArch64Utils",
					"-lLLVMAArch64Info",
					"-lLLVMAArch64Desc",
					"-lLLVMAArch64AsmParser",
					"-lLLVMAArch64CodeGen",
					"-lLLVMSandboxIR",
					"-lLLVMFrontendAtomic",
					"-lLLVMCGData",
					"-lclangFrontend",
					"-lclangBasic",
					"-lclangEdit",
					"-lclangLex",
					"-lclangDriver",
					"-lclangSerialization",
					"-lclangAST",
					"-lclangSema",
					"-lclangAnalysis",
					"-lclangASTMatchers",
					"-lclangSupport",
					"-lclangParse",
					"-lclangAPINotes",
					"-lclangCodeGen",
					"-rpath",
					/usr/local/lib,
				);
				PRODUCT_NAME = "$(TARGET_NAME)";
			};
			name = Release;
@@ -1427,6 +1601,7 @@
				GCC_PREPROCESSOR_DEFINITIONS = (
					"DEBUG=1",
					"$(inherited)",
					USE_INPUT_CACHE,
				);
				MACOSX_DEPLOYMENT_TARGET = 13.3;
				OTHER_LDFLAGS = (
+32 −3
Original line number Diff line number Diff line
@@ -490,12 +490,41 @@ namespace graph {
                }
            }

//  Handle cases like:
            auto pl = pow_cast(this->left);
            auto pr = pow_cast(this->right);

//  (a*b)^c + (a*d)^c -> a^c*(b^c + d^c)
//  (b*a)^c + (a*d)^c -> a^c*(b^c + d^c)
//  (a*b)^c + (d*a)^c -> a^c*(b^c + d^c)
//  (b*a)^c + (d*a)^c -> a^c*(b^c + d^c)
            if (pl.get() && pr.get() &&
                pl->get_right()->is_match(pr->get_right())) {
                auto plm = multiply_cast(pl->get_left());
                auto prm = multiply_cast(pr->get_left());
                if (plm.get() && prm.get()) {
                    if (plm->get_left()->is_match(prm->get_left())) {
                        return pow(plm->get_left(), pl->get_right())*
                               (pow(plm->get_right(), pl->get_right()) +
                                pow(prm->get_right(), pl->get_right()));
                    } else if (plm->get_left()->is_match(prm->get_right())) {
                        return pow(plm->get_left(), pl->get_right())*
                               (pow(plm->get_right(), pl->get_right()) +
                                pow(prm->get_left(), pl->get_right()));
                    } else if (plm->get_right()->is_match(prm->get_left())) {
                        return pow(plm->get_right(), pl->get_right())*
                               (pow(plm->get_left(), pl->get_right()) +
                                pow(prm->get_right(), pl->get_right()));
                    } else if (plm->get_right()->is_match(prm->get_right())) {
                        return pow(plm->get_right(), pl->get_right())*
                               (pow(plm->get_left(), pl->get_right()) +
                                pow(prm->get_left(), pl->get_right()));
                    }
                }
            }

//  (a/y)^e + b/y^e -> (a^e + b)/(y^e)
//  b/y^e + (a/y)^e -> (b + a^e)/(y^e)
//  (a/y)^e + (b/y)^e -> (a^e + b^e)/(y^e)
            auto pl = pow_cast(this->left);
            auto pr = pow_cast(this->right);
            if (pl.get() && rd.get()) {
                auto rdp = pow_cast(rd->get_right());
                if (rdp.get() && pl->get_right()->is_match(rdp->get_right())) {
+1 −1
Original line number Diff line number Diff line
@@ -133,7 +133,7 @@ namespace jit {

            for (auto &in : inputs) {
                if (usage.find(in.get()) == usage.end()) {
                    usage[in.get()] == 0;
                    usage[in.get()] = 0;
                }
            }

+30 −0
Original line number Diff line number Diff line
@@ -1112,6 +1112,36 @@ namespace graph {
            return constant<T, SAFE_MATH> (static_cast<T> (this->is_match(x)));
        }

//------------------------------------------------------------------------------
///  @brief Compile preamble.
///
///  Default preamble hook. This generic version writes nothing to the stream;
///  it only records this node in the register usage map. The first time the
///  node is encountered its entry is created with a count of one. When
///  SHOW_USE_COUNT is defined, every later encounter increments that count;
///  otherwise later encounters leave the entry untouched.
///
///  @param[in,out] stream          String buffer stream (unused here).
///  @param[in,out] registers       List of defined registers (unused here).
///  @param[in,out] visited         List of visited nodes (unused here).
///  @param[in,out] usage           List of register usage count.
///  @param[in,out] textures1d      List of 1D textures (unused here).
///  @param[in,out] textures2d      List of 2D textures (unused here).
///  @param[in,out] avail_const_mem Available constant memory (unused here).
//------------------------------------------------------------------------------
        virtual void compile_preamble(std::ostringstream &stream,
                                      jit::register_map &registers,
                                      jit::visiter_map &visited,
                                      jit::register_usage &usage,
                                      jit::texture1d_list &textures1d,
                                      jit::texture2d_list &textures2d,
                                      int &avail_const_mem) {
            const bool first_visit = usage.find(this) == usage.end();

//  First encounter of this node, start its usage count at one.
            if (first_visit) {
                usage[this] = 1;
            }
#ifdef SHOW_USE_COUNT
//  Repeat encounters are only tallied when use counts are being reported.
            if (!first_visit) {
                ++usage[this];
            }
#endif
        }

//------------------------------------------------------------------------------
///  @brief Compile the node.
///
+41 −0
Original line number Diff line number Diff line
@@ -392,6 +392,47 @@ template<jit::float_scalar T> void test_add() {
           "Expected var_a");
    assert(common_var5_cast->get_left()->is_match(2.0/var_b + 3.0/var_c) &&
           "Expected 2/b + 3/c");

//  (a*b)^c + (a*d)^c -> a^c*(b^c + d^c)
    auto common_power_factor = graph::pow(var_a*var_b, 2.0)
                             + graph::pow(var_a*var_c, 2.0);
    auto common_power_factor_cast = multiply_cast(common_power_factor);
    assert(common_power_factor_cast.get() && "Expected a multiply node.");
    assert(common_power_factor_cast->get_right()->is_match(var_a*var_a) &&
           "Expected a^2 on the right.");
    assert(common_power_factor_cast->get_left()->is_match(var_b*var_b +
                                                          var_c*var_c) &&
           "Expected b^2 + c^2 on the left.");
//  (a*b)^c + (d*a)^c -> a^c*(b^c + d^c)
    auto common_power_factor2 = graph::pow(var_a*var_b, 2.0)
                              + graph::pow(var_c*var_a, 2.0);
    auto common_power_factor2_cast = multiply_cast(common_power_factor2);
    assert(common_power_factor2_cast.get() && "Expected a multiply node.");
    assert(common_power_factor2_cast->get_right()->is_match(var_a*var_a) &&
           "Expected a^2 on the right.");
    assert(common_power_factor2_cast->get_left()->is_match(var_b*var_b +
                                                           var_c*var_c) &&
           "Expected b^2 + c^2 on the left.");
//  (b*a)^c + (a*d)^c -> a^c*(b^c + d^c)
    auto common_power_factor3 = graph::pow(var_b*var_a, 2.0)
                              + graph::pow(var_a*var_c, 2.0);
    auto common_power_factor3_cast = multiply_cast(common_power_factor3);
    assert(common_power_factor3_cast.get() && "Expected a multiply node.");
    assert(common_power_factor3_cast->get_right()->is_match(var_a*var_a) &&
           "Expected a^2 on the right.");
    assert(common_power_factor3_cast->get_left()->is_match(var_b*var_b +
                                                           var_c*var_c) &&
           "Expected b^2 + c^2 on the left.");
//  (b*a)^c + (d*a)^c -> a^c*(b^c + d^c)
    auto common_power_factor4 = graph::pow(var_b*var_a, 2.0)
                              + graph::pow(var_c*var_a, 2.0);
    auto common_power_factor4_cast = multiply_cast(common_power_factor4);
    assert(common_power_factor4_cast.get() && "Expected a multiply node.");
    assert(common_power_factor4_cast->get_right()->is_match(var_a*var_a) &&
           "Expected a^2 on the right.");
    assert(common_power_factor4_cast->get_left()->is_match(var_b*var_b +
                                                           var_c*var_c) &&
           "Expected b^2 + c^2 on the left.");
}

//------------------------------------------------------------------------------