Skip to content
Snippets Groups Projects
AlgorithmMPITest.h 6.69 KiB
Newer Older
#ifndef ALGORITHMMPITEST_H_
#define ALGORITHMMPITEST_H_

#include <cxxtest/TestSuite.h>

#include "MantidAPI/Algorithm.h"
#include "MantidAPI/HistogramValidator.h"
#include "MantidKernel/CompositeValidator.h"
#include "MantidAPI/AnalysisDataService.h"
#include "MantidKernel/Strings.h"


#include "MantidKernel/Property.h"
#include "MantidAPI/AlgorithmFactory.h"
#include "MantidTestHelpers/FakeObjects.h"
#include "MantidAPI/WorkspaceProperty.h"
#include "MantidAPI/FrameworkManager.h"
#ifdef MPI_EXPERIMENTAL
#include "MantidParallel/ParallelRunner.h"
#endif

using namespace Mantid::Kernel;
using namespace Mantid::API;

#ifdef MPI_EXPERIMENTAL
class NoParallelismAlgorithm : public Algorithm {
public:
  const std::string name() const override { return "NoParallelismAlgorithm"; }
  int version() const override { return 1; }
  const std::string category() const override { return ""; }
  const std::string summary() const override { return ""; }

  void init() override {
    declareProperty(make_unique<WorkspaceProperty<>>("InputWorkspace", "",
                                                     Direction::Input));
  }

  void exec() override {}
};
DECLARE_ALGORITHM(NoParallelismAlgorithm)

class AlgorithmWithBad_getParallelExecutionMode : public Algorithm {
public:
  const std::string name() const override {
    return "AlgorithmWithBad_getParallelExecutionMode ";
  }
  int version() const override { return 1; }
  const std::string category() const override { return ""; }
  const std::string summary() const override { return ""; }
  void init() override {}
  void exec() override {}

protected:
  Parallel::ExecutionMode getParallelExecutionMode(
      const std::map<std::string, Parallel::StorageMode> &storageModes)
      const override {
        static_cast<void>(storageModes);
        return Parallel::ExecutionMode::Serial;
      }
};
DECLARE_ALGORITHM(AlgorithmWithBad_getParallelExecutionMode)

class ParallelAlgorithm : public Algorithm {
public:
  const std::string name() const override {
    return "ParallelAlgorithm";
  }
  int version() const override { return 1; }
  const std::string category() const override { return ""; }
  const std::string summary() const override { return ""; }
  void init() override {
    auto wsValidator = boost::make_shared<Kernel::CompositeValidator>();
    wsValidator->add<HistogramValidator>();
    //wsValidator->add<RawCountValidator>();
    declareProperty(Kernel::make_unique<WorkspaceProperty<>>(
        "InputWorkspace", "", Kernel::Direction::Input, wsValidator));
    declareProperty(make_unique<WorkspaceProperty<Workspace>>(
        "OutputWorkspace", "", Direction::Output));
  }
  void exec() override {
    boost::shared_ptr<MatrixWorkspace> ws = getProperty("InputWorkspace");
    fprintf(stderr, "exec %d %s\n", communicator().rank(), Parallel::toString(ws->storageMode()).c_str());
    setProperty("OutputWorkspace", ws->clone());
  }

protected:
  Parallel::ExecutionMode getParallelExecutionMode(
      const std::map<std::string, Parallel::StorageMode> &storageModes)
      const override {
    return getCorrespondingExecutionMode(storageModes.at("InputWorkspace"));
  }
};
DECLARE_ALGORITHM(ParallelAlgorithm)

void run_NoParallelismAlgorithm(const Parallel::Communicator &comm) {
  for (auto storageMode :
       {Parallel::StorageMode::Cloned, Parallel::StorageMode::Distributed,
        Parallel::StorageMode::MasterOnly}) {
    NoParallelismAlgorithm alg;
    alg.setCommunicator(comm);
    auto in = boost::make_shared<WorkspaceTester>(storageMode);
    alg.initialize();
    alg.setProperty("InputWorkspace", in);
    if (comm.size() == 1) {
      TS_ASSERT_THROWS_NOTHING(alg.execute());
    } else {
      TS_ASSERT_THROWS_EQUALS(
          alg.execute(), const std::runtime_error &e, std::string(e.what()),
          "Algorithm does not support execution with input workspaces of the "
          "following storage types: \nInputWorkspace " +
              Parallel::toString(storageMode) + "\n.");
    }
  }
}

void run_AlgorithmWithBad_getParallelExecutionMode(const Parallel::Communicator &comm) {
  AlgorithmWithBad_getParallelExecutionMode alg;
  alg.setCommunicator(comm);
  alg.initialize();
  if(comm.size() == 1) {
    TS_ASSERT_THROWS_NOTHING(alg.execute());
  } else {
    TS_ASSERT_THROWS_EQUALS(alg.execute(), const std::runtime_error &e,
                            std::string(e.what()),
                            "Parallel::ExecutionMode::Serial is not a valid "
                            "*parallel* execution mode.");
  }
}

void run_ParallelAlgorithm(const Parallel::Communicator &comm) {
  for (auto storageMode :
       {//Parallel::StorageMode::Cloned, Parallel::StorageMode::Distributed,
        Parallel::StorageMode::MasterOnly}) {
    ParallelAlgorithm alg;
    alg.setCommunicator(comm);
    alg.initialize();
    auto in = boost::make_shared<WorkspaceTester>(storageMode);
    if (storageMode != Parallel::StorageMode::MasterOnly || comm.rank() == 0) {
      alg.setProperty("InputWorkspace", in);
    } else {
      alg.setProperty("InputWorkspace", in->cloneEmpty());
    }
    std::string outName("out" + Strings::toString(comm.rank()));
    alg.setProperty("OutputWorkspace", outName);
    TS_ASSERT_THROWS_NOTHING(alg.execute());
    auto out = AnalysisDataService::Instance().retrieve(outName);
    TS_ASSERT_EQUALS(out->storageMode(), storageMode);
  }
}
#endif

class AlgorithmMPITest : public CxxTest::TestSuite {
public:
  // This pair of boilerplate methods prevent the suite being created statically
  // This means the constructor isn't called when running other tests
  static AlgorithmMPITest *createSuite() { return new AlgorithmMPITest(); }
  static void destroySuite(AlgorithmMPITest *suite) { delete suite; }

  AlgorithmMPITest() {
    Mantid::API::FrameworkManager::Instance();
    AnalysisDataService::Instance();
  }

  void testNoParallelismAlgorithm() {
#ifdef MPI_EXPERIMENTAL
    Parallel::ParallelRunner serial(1);
    serial.run(run_NoParallelismAlgorithm);
    runParallel(run_NoParallelismAlgorithm);
    runParallel(run_NoParallelismAlgorithm);
    runParallel(run_NoParallelismAlgorithm);
    Mantid::API::AnalysisDataService::Instance().clear();
#endif
  }

  void testAlgorithmWithBad_getParallelExecutionMode() {
#ifdef MPI_EXPERIMENTAL
    Parallel::ParallelRunner serial(1);
    serial.run(run_AlgorithmWithBad_getParallelExecutionMode);
    runParallel(run_AlgorithmWithBad_getParallelExecutionMode);
    Mantid::API::AnalysisDataService::Instance().clear();
#endif
  }

  void testParallelAlgorithm() {
#ifdef MPI_EXPERIMENTAL
    Parallel::ParallelRunner serial(1);
    serial.run(run_ParallelAlgorithm);
    runParallel(run_ParallelAlgorithm);
    Mantid::API::AnalysisDataService::Instance().clear();
#endif
  }
};

#endif /*ALGORITHMMPITEST_H_*/