Commit 93680dff authored by Cianciosa, Mark's avatar Cianciosa, Mark
Browse files

Fix slow generation of the kernel source code and update the tests so they pass with cuda.

parent 600d1957
Loading
Loading
Loading
Loading
+3 −0
Original line number Diff line number Diff line
@@ -132,7 +132,10 @@ int main(int argc, const char * argv[]) {

            if (thread_number == 0 && false) {
                solve.print(sample);
            } else {
                solve.sync();
            }

        }, i, threads.size());
    }

+17 −17
Original line number Diff line number Diff line
@@ -283,10 +283,10 @@ namespace graph {
//------------------------------------------------------------------------------
        virtual shared_leaf<typename LN::backend> compile(std::stringstream &stream,
                                                          jit::register_map<LN> &registers) final {
            if (registers.find(this) == registers.end()) {
                shared_leaf<typename LN::backend> l = this->left->compile(stream, registers);
                shared_leaf<typename RN::backend> r = this->right->compile(stream, registers);

            if (registers.find(this) == registers.end()) {
                registers[this] = jit::to_string('r', this);
                stream << "        const ";
                jit::add_type<typename LN::backend> (stream);
@@ -594,10 +594,10 @@ namespace graph {
//------------------------------------------------------------------------------
        virtual shared_leaf<typename LN::backend> compile(std::stringstream &stream,
                                                          jit::register_map<LN> &registers) final {
            if (registers.find(this) == registers.end()) {
                shared_leaf<typename LN::backend> l = this->left->compile(stream, registers);
                shared_leaf<typename RN::backend> r = this->right->compile(stream, registers);

            if (registers.find(this) == registers.end()) {
                registers[this] = jit::to_string('r', this);
                stream << "        const ";
                jit::add_type<typename LN::backend> (stream);
@@ -970,10 +970,10 @@ namespace graph {
//------------------------------------------------------------------------------
        virtual shared_leaf<typename LN::backend> compile(std::stringstream &stream,
                                                          jit::register_map<LN> &registers) final {
            if (registers.find(this) == registers.end()) {
                shared_leaf<typename LN::backend> l = this->left->compile(stream, registers);
                shared_leaf<typename RN::backend> r = this->right->compile(stream, registers);

            if (registers.find(this) == registers.end()) {
                registers[this] = jit::to_string('r', this);
                stream << "        const ";
                jit::add_type<typename LN::backend> (stream);
@@ -1287,10 +1287,10 @@ namespace graph {
//------------------------------------------------------------------------------
        virtual shared_leaf<typename LN::backend> compile(std::stringstream &stream,
                                                          jit::register_map<LN> &registers) final {
            if (registers.find(this) == registers.end()) {
                shared_leaf<typename LN::backend> l = this->left->compile(stream, registers);
                shared_leaf<typename RN::backend> r = this->right->compile(stream, registers);

            if (registers.find(this) == registers.end()) {
                registers[this] = jit::to_string('r', this);
                stream << "        const ";
                jit::add_type<typename LN::backend> (stream);
@@ -1537,11 +1537,11 @@ namespace graph {
//------------------------------------------------------------------------------
        virtual shared_leaf<typename LN::backend> compile(std::stringstream &stream,
                                                          jit::register_map<LN> &registers) final {
            if (registers.find(this) == registers.end()) {
                shared_leaf<typename LN::backend> l = this->left->compile(stream, registers);
                shared_leaf<typename MN::backend> m = this->middle->compile(stream, registers);
                shared_leaf<typename RN::backend> r = this->right->compile(stream, registers);

            if (registers.find(this) == registers.end()) {
                registers[this] = jit::to_string('r', this);
                stream << "        const ";
                jit::add_type<typename LN::backend> (stream);
+6 −1
Original line number Diff line number Diff line
@@ -149,7 +149,12 @@ namespace jit {
            source_buffer << "    const unsigned int k = threadIdx.x%32;" << std::endl;
#endif
            source_buffer << "    if (i < " << input->size() << ") {" << std::endl;
            source_buffer << "        " << jit::type_to_string<typename BACKEND::base> () << " sub_max = input[i];" << std::endl;
            source_buffer << "        " << jit::type_to_string<typename BACKEND::base> () << " sub_max = ";
            if constexpr (jit::is_complex<typename BACKEND::base> ()) {
                source_buffer << "abs(input[i]);" << std::endl;
            } else {
                source_buffer << "input[i];" << std::endl;
            }
            source_buffer << "        for (size_t index = i + 1024; index < " << input->size() << "; index += 1024) {" << std::endl;
            if constexpr (jit::is_complex<typename BACKEND::base> ()) {
                source_buffer << "            sub_max = max(abs(sub_max), abs(input[index]));" << std::endl;
+9 −9
Original line number Diff line number Diff line
@@ -125,9 +125,9 @@ namespace graph {
//------------------------------------------------------------------------------
        virtual shared_leaf<typename N::backend> compile(std::stringstream &stream,
                                                         jit::register_map<N> &registers) final {
            if (registers.find(this) == registers.end()) {
                shared_leaf<typename N::backend> a = this->arg->compile(stream, registers);

            if (registers.find(this) == registers.end()) {
                registers[this] = jit::to_string('r', this);
                stream << "        const ";
                jit::add_type<typename N::backend> (stream);
@@ -265,9 +265,9 @@ namespace graph {
//------------------------------------------------------------------------------
        virtual shared_leaf<typename N::backend> compile(std::stringstream &stream,
                                                         jit::register_map<N> &registers) final {
            if (registers.find(this) == registers.end()) {
                shared_leaf<typename N::backend> a = this->arg->compile(stream, registers);

            if (registers.find(this) == registers.end()) {
                registers[this] = jit::to_string('r', this);
                stream << "        const ";
                jit::add_type<typename N::backend> (stream);
@@ -401,9 +401,9 @@ namespace graph {
//------------------------------------------------------------------------------
        virtual shared_leaf<typename N::backend> compile(std::stringstream &stream,
                                                         jit::register_map<N> &registers) final {
            if (registers.find(this) == registers.end()) {
                shared_leaf<typename N::backend> a = this->arg->compile(stream, registers);

            if (registers.find(this) == registers.end()) {
                registers[this] = jit::to_string('r', this);
                stream << "        const ";
                jit::add_type<typename N::backend> (stream);
@@ -599,10 +599,10 @@ namespace graph {
//------------------------------------------------------------------------------
        virtual shared_leaf<typename LN::backend> compile(std::stringstream &stream,
                                                          jit::register_map<LN> &registers) final {
            if (registers.find(this) == registers.end()) {
                shared_leaf<typename LN::backend> l = this->left->compile(stream, registers);
                shared_leaf<typename RN::backend> r = this->right->compile(stream, registers);

            if (registers.find(this) == registers.end()) {
                registers[this] = jit::to_string('r', this);
                stream << "        const ";
                jit::add_type<typename LN::backend> (stream);
+4 −4
Original line number Diff line number Diff line
@@ -82,9 +82,9 @@ namespace graph {
//------------------------------------------------------------------------------
        virtual shared_leaf<typename N::backend> compile(std::stringstream &stream,
                                                         jit::register_map<N> &registers) final {
            if (registers.find(this) == registers.end()) {
                shared_leaf<typename N::backend> a = this->arg->compile(stream, registers);

            if (registers.find(this) == registers.end()) {
                registers[this] = jit::to_string('r', this);
                stream << "        const ";
                jit::add_type<typename N::backend> (stream);
@@ -220,9 +220,9 @@ namespace graph {
//------------------------------------------------------------------------------
        virtual shared_leaf<typename N::backend> compile(std::stringstream &stream,
                                                         jit::register_map<N> &registers) final {
            if (registers.find(this) == registers.end()) {
                shared_leaf<typename N::backend> a = this->arg->compile(stream, registers);

            if (registers.find(this) == registers.end()) {
                registers[this] = jit::to_string('r', this);
                stream << "        const ";
                jit::add_type<typename N::backend> (stream);
Loading