Commit 4098469d authored by Cianciosa, Mark's avatar Cianciosa, Mark
Browse files

Refactor JIT so CPU and GPU are more unified. Change precision of serialized...

Refactor JIT so CPU and GPU are more unified. Change precision of serialized constants. Fix solver for split_simplextic to produce correct kernel. Fix tests to account for differences in pow function between cpu and gpu.
parent 519d3ac7
Loading
Loading
Loading
Loading
+5 −0
Original line number Diff line number Diff line
@@ -65,6 +65,11 @@ target_link_libraries (gpu_lib
                       $<$<BOOL:${USE_METAL}>:metal_lib>
                       $<$<BOOL:${USE_CUDA}>:cuda_lib>
)
#-------------------------------------------------------------------------------
#  Define USE_GPU for every target that links against gpu_lib whenever either
#  the Metal or the CUDA backend is enabled. INTERFACE scope means the
#  definition is applied to consumers of gpu_lib rather than to gpu_lib itself.
#-------------------------------------------------------------------------------
target_compile_definitions (gpu_lib
                            INTERFACE
                            $<$<BOOL:${USE_METAL}>:USE_GPU>
                            $<$<BOOL:${USE_CUDA}>:USE_GPU>
)

#-------------------------------------------------------------------------------
#  Sanitizer options
+18 −58
Original line number Diff line number Diff line
@@ -43,10 +43,18 @@ int main(int argc, const char * argv[]) {
    //const size_t num_rays = 1;
    const size_t num_rays = 1000000;

    std::vector<std::thread> threads(1);
    //std::vector<std::thread> threads(std::max(std::min(std::thread::hardware_concurrency(),
    //                                                   static_cast<unsigned int> (num_rays)),
    //                                          static_cast<unsigned int> (1)));
    std::vector<std::thread> threads(0);
#if USE_GPU
    if constexpr (jit::can_jit<cpu> ()) {
        threads.resize(1);
    } else {
#endif
        threads.resize(std::max(std::min(std::thread::hardware_concurrency(),
                                         static_cast<unsigned int> (num_rays)),
                                static_cast<unsigned int> (1)));
#if USE_GPU
    }
#endif

    for (size_t i = 0, ie = threads.size(); i < ie; i++) {
        threads[i] = std::thread([num_times, num_rays] (const size_t thread_number,
@@ -92,7 +100,7 @@ int main(int argc, const char * argv[]) {
            //solver::rk4<dispersion::cold_plasma<cpu>>
                solve(omega, kx, ky, kz, x, y, z, t, 60.0/num_times, eq);
            solve.init(kx);
            solve.compile(num_times, num_rays);
            solve.compile(num_rays);
            if (thread_number == 0) {
                solve.print_dispersion();
                std::cout << std::endl;
@@ -109,45 +117,21 @@ int main(int argc, const char * argv[]) {
                solve.print_dzdt();
            }

            auto residule = solve.residule();

            const size_t sample = int_dist(engine);

            if (thread_number == 0) {
            if (thread_number == 0 && false) {
                std::cout << "Omega " << omega->evaluate().at(sample) << std::endl;
                std::cout << "t = " << 0.0 << " ";
                std::cout << solve.state.back().x.at(sample) << std::endl;
            }

            const timeing::measure_diagnostic cpu_time("CPU Time");
            for (size_t j = 0; j < num_times; j++) {
                if (thread_number == 0 && false) {
                    std::cout << "Time Step " << j << " Sample " << sample << " "
                              << solve.state.back().t.at(sample) << " "
                              << solve.state.back().x.at(sample) << " "
                              << solve.state.back().y.at(sample) << " "
                              << solve.state.back().z.at(sample) << " "
                              << solve.state.back().kx.at(sample) << " "
                              << solve.state.back().ky.at(sample) << " "
                              << solve.state.back().kz.at(sample) << " "
                              << residule->evaluate().at(sample)
                              << std::endl;
                    solve.print(sample);
                }
                //solve.step();
                solve.step();
            }
            cpu_time.stop();

            if (thread_number == 0 && false) {
                std::cout << "Time Step " << num_times << " Sample " << sample << " "
                          << solve.state.back().t.at(sample) << " "
                          << solve.state.back().x.at(sample) << " "
                          << solve.state.back().y.at(sample) << " "
                          << solve.state.back().z.at(sample) << " "
                          << solve.state.back().kx.at(sample) << " "
                          << solve.state.back().ky.at(sample) << " "
                          << solve.state.back().kz.at(sample) << " "
                          << residule->evaluate().at(sample)
                          << std::endl;
                solve.print(sample);
            }
        }, i, threads.size());
    }
@@ -159,27 +143,3 @@ int main(int argc, const char * argv[]) {
    std::cout << std::endl << "Timing:" << std::endl;
    total.stop();
}

/*
//------------------------------------------------------------------------------
///  @brief Print out timings.
///
///  @param[in] name Description of the times.
///  @param[in] time Elapsed time in nanoseconds.
//------------------------------------------------------------------------------
void write_time(const std::string &name, const std::chrono::nanoseconds time) {
    if (time.count() < 1000) {
        std::cout << name << time.count()               << " ns" << std::endl;
    } else if (time.count() < 1000000) {
        std::cout << name << time.count()/1000.0        << " μs" << std::endl;
    } else if (time.count() < 1000000000) {
        std::cout << name << time.count()/1000000.0     << " ms" << std::endl;
    } else if (time.count() < 60000000000) {
        std::cout << name << time.count()/1000000000.0  << " s" << std::endl;
    } else if (time.count() < 3600000000000) {
        std::cout << name << time.count()/60000000000.0 << " min" << std::endl;
    } else {
        std::cout << name << time.count()/3600000000000 << " h" << std::endl;
    }
}
*/
+7 −0
Original line number Diff line number Diff line
@@ -98,6 +98,13 @@ namespace backend {
//------------------------------------------------------------------------------
        virtual void cos() = 0;

//------------------------------------------------------------------------------
///  @brief Get a pointer to the basic memory buffer.
///
///  Pure virtual: every concrete backend must provide its own implementation.
///
///  @returns The pointer to the buffer memory.
//------------------------------------------------------------------------------
        virtual BASE *data() = 0;

///  Type def to retrieve the backend base type.
        typedef BASE base;
    };
+37 −23
Original line number Diff line number Diff line
@@ -26,7 +26,7 @@ namespace backend {
    class cpu final : public buffer<BASE> {
    protected:
///  The data buffer to hold the data.
        std::vector<BASE> data;
        std::vector<BASE> buffer;

    public:
//------------------------------------------------------------------------------
@@ -35,7 +35,7 @@ namespace backend {
///  @param[in] s Size of the data buffer.
//------------------------------------------------------------------------------
        cpu(const size_t s) :
        data(s) {}
        buffer(s) {}

//------------------------------------------------------------------------------
///  @brief Construct a cpu backend with a size.
@@ -44,7 +44,7 @@ namespace backend {
///  @param[in] d Scalar data to initialize.
//------------------------------------------------------------------------------
        cpu(const size_t s, const BASE d) :
        data(s, d) {}
        buffer(s, d) {}

//------------------------------------------------------------------------------
///  @brief Construct a cpu backend from a vector.
@@ -52,7 +52,7 @@ namespace backend {
///  @param[in] d Array buffer.
//------------------------------------------------------------------------------
        cpu(const std::vector<BASE> &d) :
        data(d) {}
        buffer(d) {}

//------------------------------------------------------------------------------
///  @brief Construct a cpu backend from a cpu backend.
@@ -60,27 +60,27 @@ namespace backend {
///  @param[in] d Backend buffer.
//------------------------------------------------------------------------------
        cpu(const cpu &d) :
        data(d.data) {}
        buffer(d.buffer) {}

//------------------------------------------------------------------------------
///  @brief Index operator.
//------------------------------------------------------------------------------
        virtual BASE &operator[] (const size_t index) final {
            return data[index];
            return buffer[index];
        }

//------------------------------------------------------------------------------
///  @brief Const index operator.
//------------------------------------------------------------------------------
        virtual const BASE &operator[] (const size_t index) const final {
            return data[index];
            return buffer[index];
        }

//------------------------------------------------------------------------------
///  @brief Get value at.
//------------------------------------------------------------------------------
        virtual const BASE at(const size_t index) const final {
            return data.at(index);
            return buffer.at(index);
        }

//------------------------------------------------------------------------------
@@ -89,7 +89,7 @@ namespace backend {
///  @param[in] d Scalar data to set.
//------------------------------------------------------------------------------
        virtual void set(const BASE d) final {
            data.assign(data.size(), d);
            buffer.assign(buffer.size(), d);
        }

//------------------------------------------------------------------------------
@@ -98,14 +98,14 @@ namespace backend {
///  @param[in] d Vector data to set.
//------------------------------------------------------------------------------
        virtual void set(const std::vector<BASE> &d) final {
            data.assign(d.cbegin(), d.cend());
            buffer.assign(d.cbegin(), d.cend());
        }

//------------------------------------------------------------------------------
///  @brief Get size of the buffer.
//------------------------------------------------------------------------------
        virtual size_t size() const final {
            return data.size();
            return buffer.size();
        }

//------------------------------------------------------------------------------
@@ -116,12 +116,12 @@ namespace backend {
        virtual BASE max() const final {
            if constexpr (std::is_same<BASE, std::complex<float>>::value ||
                          std::is_same<BASE, std::complex<double>>::value) {
                return *std::max_element(data.cbegin(), data.cend(),
                return *std::max_element(buffer.cbegin(), buffer.cend(),
                                         [] (const BASE a, const BASE b) {
                    return std::abs(a) < std::abs(b);
                });
            } else {
                return *std::max_element(data.cbegin(), data.cend());
                return *std::max_element(buffer.cbegin(), buffer.cend());
            }
        }

@@ -131,9 +131,9 @@ namespace backend {
///  @returns Returns true if every element is the same.
//------------------------------------------------------------------------------
        virtual bool is_same() const final {
            const BASE same = data.at(0);
            for (size_t i = 1, ie = data.size(); i < ie; i++) {
                if (data.at(i) != same) {
            const BASE same = buffer.at(0);
            for (size_t i = 1, ie = buffer.size(); i < ie; i++) {
                if (buffer.at(i) != same) {
                    return false;
                }
            }
@@ -147,7 +147,7 @@ namespace backend {
///  @returns Returns true if every element is zero.
//------------------------------------------------------------------------------
        virtual bool is_zero() const final {
            for (BASE d : data) {
            for (BASE d : buffer) {
                if (d != static_cast<BASE> (0.0)) {
                    return false;
                }
@@ -162,7 +162,7 @@ namespace backend {
///  @returns Returns true if every element is negative.
//------------------------------------------------------------------------------
        virtual bool is_negative() const final {
            for (BASE d : data) {
            for (BASE d : buffer) {
                if (std::real(d) > std::real(static_cast<BASE> (0.0))) {
                    return false;
                }
@@ -175,7 +175,7 @@ namespace backend {
///  @brief Take sqrt.
//------------------------------------------------------------------------------
        virtual void sqrt() final {
            for (BASE &d : data) {
            for (BASE &d : buffer) {
                d = std::sqrt(d);
            }
        }
@@ -184,7 +184,7 @@ namespace backend {
///  @brief Take exp.
//------------------------------------------------------------------------------
        virtual void exp() final {
            for (BASE &d : data) {
            for (BASE &d : buffer) {
                d = std::exp(d);
            }
        }
@@ -193,7 +193,7 @@ namespace backend {
///  @brief Take the natural log.
//------------------------------------------------------------------------------
        virtual void log() final {
            for (BASE &d : data) {
            for (BASE &d : buffer) {
                d = std::log(d);
            }
        }
@@ -202,7 +202,7 @@ namespace backend {
///  @brief Take sine.
//------------------------------------------------------------------------------
        virtual void sin() final {
            for (BASE &d : data) {
            for (BASE &d : buffer) {
                d = std::sin(d);
            }
        }
@@ -211,10 +211,19 @@ namespace backend {
///  @brief Take cosine.
//------------------------------------------------------------------------------
        virtual void cos() final {
            for (BASE &d : data) {
            for (BASE &d : buffer) {
                d = std::cos(d);
            }
        }

//------------------------------------------------------------------------------
///  @brief Get a pointer to the basic memory buffer.
///
///  The pointer is the one exposed by the underlying std::vector, so it may be
///  invalidated by any operation that makes that vector reallocate (for
///  example assigning a vector of a different size through set()).
///
///  @returns The pointer to the buffer memory.
//------------------------------------------------------------------------------
        virtual BASE *data() final {
            return buffer.data();
        }
    };

//------------------------------------------------------------------------------
@@ -517,6 +526,11 @@ namespace backend {
                    }
                    return base;
                }
            } else {
                for (size_t i = 0, ie = base.size(); i < ie; i++) {
                    base[i] = std::pow(base.at(i), right);
                }
                return base;
            }
        } else if (base.size() == 1) {
            const BASE left = base.at(0);
+50 −9
Original line number Diff line number Diff line
@@ -13,6 +13,7 @@

#include "vector.hpp"
#include "equilibrium.hpp"
#include "jit.hpp"

namespace dispersion {
//******************************************************************************
@@ -117,10 +118,14 @@ namespace dispersion {
///  This uses Newton's method to solve for D(x) = 0.
///
///  @param[in,out] x              The unknown to solve for.
///  @param[in]     inputs         Inputs for jit compile.
///  @param[in]     tolarance      Tolerance to solve the dispersion function to.
///  @param[in]     max_iterations Maximum number of iterations before giving up.
///  @returns The residual graph.
//------------------------------------------------------------------------------
        void solve(graph::shared_leaf<typename DISPERSION_FUNCTION::backend> x,
        graph::shared_leaf<typename DISPERSION_FUNCTION::backend>
        solve(graph::shared_leaf<typename DISPERSION_FUNCTION::backend> x,
              graph::input_nodes<typename DISPERSION_FUNCTION::backend> inputs,
              const typename DISPERSION_FUNCTION::base tolarance=1.0E-30,
              const size_t max_iterations = 1000) {
            auto loss = D*D;
@@ -128,15 +133,49 @@ namespace dispersion {
                        - loss/(loss->df(x) +
                                graph::constant<typename DISPERSION_FUNCTION::backend> (tolarance));

            typename DISPERSION_FUNCTION::base max_residule =
                loss->evaluate().max();
            typename DISPERSION_FUNCTION::base max_residule;
            size_t iterations = 0;
            std::unique_ptr<jit::kernel<typename DISPERSION_FUNCTION::backend>> source;
            if constexpr (jit::can_jit<typename DISPERSION_FUNCTION::backend> ()) {
                auto x_var = graph::variable_cast(x);
                inputs.push_back(x_var);
                
                graph::output_nodes<typename DISPERSION_FUNCTION::backend> outputs = {
                    loss
                };

                graph::map_nodes<typename DISPERSION_FUNCTION::backend> setters = {
                    {x_next, x_var}
                };
                
                source = std::make_unique<jit::kernel<typename DISPERSION_FUNCTION::backend>> ("loss_kernel",
                                                                                               inputs,
                                                                                               outputs,
                                                                                               setters);
                source->add_max_reduction(x_var);

                source->compile("loss_kernel", inputs, outputs, x_var->size());
                source->compile_max();

                max_residule = source->max_reduction();
            } else {
                max_residule = loss->evaluate().max();
            }

            while (std::abs(max_residule) > std::abs(tolarance) &&
                   iterations++ < max_iterations) {
                if constexpr (jit::can_jit<typename DISPERSION_FUNCTION::backend> ()) {
                    max_residule = source->max_reduction();
               } else {
                    x->set(x_next->evaluate());
                    max_residule = loss->evaluate().max();
                }
            }

            if constexpr (jit::can_jit<typename DISPERSION_FUNCTION::backend> ()) {
                source->copy_buffer(inputs.size() - 1,
                                    inputs.back()->data());
            }

//  In release mode asserts are disabled so write error to standard err. Need to
//  flip the comparison operator because we want the assert to trip if false.
@@ -148,6 +187,8 @@ namespace dispersion {
                std::cerr << "Minimum residule reached: " << max_residule
                          << std::endl;
            }

            return loss;
        }

//------------------------------------------------------------------------------
Loading