Commit 22bae737 authored by cianciosa's avatar cianciosa
Browse files

Add a basic commandline parser.

parent 867f2f8c
Loading
Loading
Loading
Loading
+24 −20
Original line number Diff line number Diff line
@@ -11,11 +11,7 @@
#include "../graph_framework/timing.hpp"
#include "../graph_framework/output.hpp"
#include "../graph_framework/absorption.hpp"

const bool print = false;
const bool write_step = true;
const bool print_expressions = false;
const bool verbose = true;
#include "../graph_framework/commandline_parser.hpp"

//------------------------------------------------------------------------------
///  @brief Initalize random rays for efit.
@@ -113,12 +109,14 @@ void init_vmec(graph::shared_leaf<T, SAFE_MATH> omega,
///  @tparam T         Base type of the calculation.
///  @tparam SAFE_MATH Use safe math operations.
///
///  @params[in] cl        Parsed commandline.
///  @params[in] num_times Total number of time steps.
///  @params[in] sub_steps Number of substeps to push the rays.
///  @params[in] num_rays  Number of rays to trace.
//------------------------------------------------------------------------------
template<std::floating_point T, bool SAFE_MATH=false>
void trace_ray(const size_t num_times,
void trace_ray(const commandline::parser &cl,
               const size_t num_times,
               const size_t sub_steps,
               const size_t num_rays) {
    const timeing::measure_diagnostic total("Total Ray Time");
@@ -131,7 +129,7 @@ void trace_ray(const size_t num_times,
    const size_t extra = num_rays%threads.size();

    for (size_t i = 0, ie = threads.size(); i < ie; i++) {
        threads[i] = std::thread([num_times, sub_steps, batch, extra] (const size_t thread_number) -> void {
        threads[i] = std::thread([&cl, num_times, sub_steps, batch, extra] (const size_t thread_number) -> void {

            const size_t num_steps = num_times/sub_steps;
            const size_t local_num_rays = batch
@@ -214,7 +212,7 @@ void trace_ray(const size_t num_times,
                      stream.str(), local_num_rays, thread_number);
            solve.init(kx);
            solve.compile();
            if (thread_number == 0 && print_expressions) {
            if (thread_number == 0 && cl.is_option_set("print_expressions")) {
                solve.print_dispersion();
                std::cout << std::endl;
                solve.print_dkxdt();
@@ -251,13 +249,12 @@ void trace_ray(const size_t num_times,
                std::cout << "Omega " << omega->evaluate().at(sample) << std::endl;
            }

            const bool print = cl.is_option_set("print");
            for (size_t j = 0; j < num_steps; j++) {
                if (thread_number == 0 && print) {
                    solve.print(sample);
                }
                if (write_step) {
                solve.write_step();
                }
                for (size_t k = 0; k < sub_steps; k++) {
                    solve.step();
                }
@@ -265,10 +262,8 @@ void trace_ray(const size_t num_times,

            if (thread_number == 0 && print) {
                solve.print(sample);
            } else if (write_step) {
                solve.write_step();
            } else {
                solve.sync_host();
                solve.write_step();
            }

        }, i);
@@ -491,12 +486,21 @@ int main(int argc, const char * argv[]) {
    (void)argv;
    const timeing::measure_diagnostic total("Total Time");

    jit::verbose = verbose;
    commandline::parser cl(argv[0]);
    cl.add_option("verbose",           false, "Show verbose output.");
    cl.add_option("num_times",         true,  "Number of times.");
    cl.add_option("sub_steps",         true,  "Number of substeps.");
    cl.add_option("num_rays",          true,  "Number of rays.");
    cl.add_option("print_expressions", false, "Print out rays expressions.");
    cl.add_option("print",             false, "Print sample rays to screen.");
    cl.parse(argc, argv);

    jit::verbose = cl.is_option_set("verbose");

    const size_t num_times = 100000;
    const size_t sub_steps = 100;
    const size_t num_times = cl.get_option_value<size_t> ("num_times");
    const size_t sub_steps = cl.get_option_value<size_t> ("sub_steps");
#ifndef STATIC
    const size_t num_rays = 100000;
    const size_t num_rays = cl.get_option_value<size_t> ("num_rays");
#else
    const size_t num_rays = 1;
#endif
@@ -505,7 +509,7 @@ int main(int argc, const char * argv[]) {

    typedef double base;

    trace_ray<base> (num_times, sub_steps, num_rays);
    trace_ray<base> (cl, num_times, sub_steps, num_rays);
    calculate_power<std::complex<base>, use_safe_math> (num_times,
                                                        sub_steps,
                                                        num_rays);
+2 −0
Original line number Diff line number Diff line
@@ -287,6 +287,7 @@
		C71C1FF827F621E7006997C2 /* vector.hpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.h; path = vector.hpp; sourceTree = "<group>"; };
		C721EA992833FF7800EAFB2D /* equilibrium.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = equilibrium.hpp; sourceTree = "<group>"; };
		C723210222DC0D0A006BBF13 /* arithmetic.hpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.h; path = arithmetic.hpp; sourceTree = "<group>"; };
		C72358F52C4027A10084A489 /* commandline_parser.hpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.h; path = commandline_parser.hpp; sourceTree = "<group>"; };
		C725CD792840088000D0EDE2 /* physics_test.cpp */ = {isa = PBXFileReference; explicitFileType = sourcecode.cpp.objcpp; path = physics_test.cpp; sourceTree = "<group>"; };
		C736902B2A38AC0E001733B0 /* erfi_test.cpp */ = {isa = PBXFileReference; explicitFileType = sourcecode.cpp.objcpp; path = erfi_test.cpp; sourceTree = "<group>"; };
		C73690312A38C498001733B0 /* erfi_test */ = {isa = PBXFileReference; explicitFileType = "compiled.mach-o.executable"; includeInIndex = 0; path = erfi_test; sourceTree = BUILT_PRODUCTS_DIR; };
@@ -526,6 +527,7 @@
			children = (
				C7931E7028073BE70033B488 /* CMakeLists.txt */,
				C79141AE22DA9C3000E0BA0D /* node.hpp */,
				C72358F52C4027A10084A489 /* commandline_parser.hpp */,
				C70B705629F4F86A00098AA0 /* piecewise.hpp */,
				C7922EEB29E0ABDF000BB9C7 /* workflow.hpp */,
				C71C1FF727F61DFA006997C2 /* math.hpp */,
+14 −0
Original line number Diff line number Diff line
@@ -59,6 +59,20 @@
            ReferencedContainer = "container:graph_framework.xcodeproj">
         </BuildableReference>
      </BuildableProductRunnable>
      <CommandLineArguments>
         <CommandLineArgument
            argument = "--num_rays=100000"
            isEnabled = "YES">
         </CommandLineArgument>
         <CommandLineArgument
            argument = "--num_times=100000"
            isEnabled = "YES">
         </CommandLineArgument>
         <CommandLineArgument
            argument = "--sub_steps=100"
            isEnabled = "YES">
         </CommandLineArgument>
      </CommandLineArguments>
   </LaunchAction>
   <ProfileAction
      buildConfiguration = "Release"
+1 −0
Original line number Diff line number Diff line
@@ -58,6 +58,7 @@ target_precompile_headers (rays
                           $<$<BOOL:${USE_PCH}>:$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/newton.hpp>>
                           $<$<BOOL:${USE_PCH}>:$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/special_functions.hpp>>
                           $<$<BOOL:${USE_PCH}>:$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/output.hpp>>
                           $<$<BOOL:${USE_PCH}>:$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/commandline_parser.hpp>>
                           $<$<BOOL:${USE_PCH}>:$<$<BOOL:${USE_METAL}>:$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/metal_context.hpp>>>
                           $<$<BOOL:${USE_PCH}>:$<$<BOOL:${USE_CUDA}>:$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/cuda_context.hpp>>>
)
+155 −0
Original line number Diff line number Diff line
//------------------------------------------------------------------------------
///  @file commandline\_parser.hpp
///  @brief Parsing routines for the command line.
//------------------------------------------------------------------------------

#ifndef commandline_parser_h
#define commandline_parser_h

#include <map>
#include <type_traits>
#include <utility>

namespace commandline {
//------------------------------------------------------------------------------
///  @brief Parser class
//------------------------------------------------------------------------------
    class parser {
    private:
///  Command line options with values.
        std::map<std::string, std::pair<bool, std::string>> options;
///  Parsed commands.
        std::map<std::string, std::string> parsed_options;
///  Command name.
        const std::string command;

//------------------------------------------------------------------------------
///  @brief Take end string.
///
///  @params[in] string    String to take.
///  @params[in] character Character to split at.
//------------------------------------------------------------------------------
        static std::string_view take_end(const char *string,
                                         const char character) {
            std::string_view view(string);
            return view.substr(view.find_last_of('/') + 1);
        }

    public:
//------------------------------------------------------------------------------
///  @brief Default constructor
//------------------------------------------------------------------------------
        parser(const char *name) : 
        command(take_end(name, '/')) {
            options.emplace(std::make_pair("help",
                                           std::make_pair(false,
                                                          "Show this help.")));
        }

//------------------------------------------------------------------------------
///  @brief Add commandline option.
///
///  Defines an option. If the option has no default, assume no
///
///  @params[in] option      The command option.
///  @params[in] takes_value Flag to indicate the option takes a value.
///  @params[in] help_text   The help text of the option.
//------------------------------------------------------------------------------
        void add_option(const std::string &option,
                        const bool takes_value,
                        const std::string &help_text) {
            options.try_emplace(option, takes_value, help_text);
        }

//------------------------------------------------------------------------------
///  @brief Display help.
///
///  @params[in] command Name of the program.
//------------------------------------------------------------------------------
        void show_help(const std::string &command) const {
            size_t longest = 0;
            for (auto &[option, value] : options) {
                longest = std::max(longest, option.size());
            }
            std::cout << "USAGE: " << command << " [--options] [--options=with_value]" << std::endl << std::endl;
            std::cout << "OPTIONS:" << std::endl;
            for (auto &[option, value] : options) {
                std::cout << "  --" << option
                          << (std::get<bool> (value) ? "= " : "  ");
                for (size_t i = option.size(); i < longest; i++) {
                    std::cout << " ";
                }
                std::cout << std::get<std::string> (value) << std::endl;
            }
            std::cout << std::endl;
            exit(0);
        }

//------------------------------------------------------------------------------
///  @brief Parse the command line.
///
///  @params[in] argc Number of commandline arguments.
///  @params[in] argv Array of commandline arguments.
//------------------------------------------------------------------------------
        void parse(const int argc, const char * argv[]) {
            for (int i = 1; i < argc; i++) {
                std::string_view view(argv[i]);
                const size_t option_end = view.find_first_of('=');
                std::string option(view.substr(2, option_end - 2));
                if (is_option_set(option)) {
                    std::cout << "Warning --" << option << " set more than once." << std::endl;
                    std::cout << "  Overwriting --" << option << std::endl;
                }
                if (options.find(option) == options.cend()) {
                    std::cout << "UNKNOWN OPTION: " << view << std::endl << std::endl;
                    show_help(std::string(argv[0]));
                }
                if (option_end != view.size()) {
                    parsed_options[option] = std::string(view.substr(option_end + 1));
                } else {
                    parsed_options[option] = "";
                }
            }
        }

//------------------------------------------------------------------------------
///  @brief Check if option is set.
///
///  @params[in] option The option to check.
//------------------------------------------------------------------------------
        bool is_option_set(const std::string &option) const {
            return parsed_options.find(option) != parsed_options.cend();
        }

//------------------------------------------------------------------------------
///  @brief Get the option value.
///
///  @tparam T Type of the value.
//------------------------------------------------------------------------------
        template<typename T>
        T get_option_value(const std::string &option) const {
            if (is_option_set(option)) {
                std::string value = parsed_options.at(option);
                if constexpr (std::is_same<T, std::string> ()) {
                    return value;
                } else if constexpr (std::is_same<T, float> ()) {
                    return std::stof(value);
                } else if constexpr (std::is_same<T, double> ()) {
                    return std::stod(value);
                } else if constexpr (std::is_same<T, int> ()) {
                    return std::stoi(value);
                } else if constexpr (std::is_same<T, long> ()) {
                    return std::stol(value);
                } else if constexpr (std::is_same<T, unsigned long> ()) {
                    return std::stoul(value);
                }
            } else {
                std::cout << "Expected option : --" << option << std::endl << std::endl;
                show_help(command);
            }
            return NULL;
        }
    };
}

#endif /* commandline_parser_h */