Commit 0675759d authored by Cianciosa, Mark's avatar Cianciosa, Mark
Browse files

Add support for mlx models.

parent 18413c8f
Loading
Loading
Loading
Loading
+19 −0
Original line number Diff line number Diff line
@@ -27,6 +27,25 @@ ml_model mlce_init_custom_pytorch(const char *weights_filename,
                                                                       std::string(config_filename)));
}

//------------------------------------------------------------------------------
///  @brief Initalize a mlx model.
///
///  @param[in] filename Path to the mlx model.
///  @param[in] in_shape Input dimension shapes.
///  @param[in] num_dims Number of input dimensions.
//------------------------------------------------------------------------------
    ml_model mlce_init_mlx(const char *filename,
                           const int *in_shape,
                           const size_t num_dims) {
        mlx::core::SmallVector<int> shape(num_dims);
        for (size_t i = 0; i < num_dims; i++) {
            shape[i] = in_shape[i];
        }

        return reinterpret_cast<ml_model> (new ml_embedder::mlx_model(std::string(filename),
                                                                      shape));
    }

//------------------------------------------------------------------------------
///  @brief Finalize a model.
///
+11 −0
Original line number Diff line number Diff line
@@ -33,6 +33,17 @@ extern "C" {
    ml_model mlce_init_custom_pytorch(const char *weights_filename,
                                      const char *config_filename);

//------------------------------------------------------------------------------
///  @brief Initalize a mlx model.
///
///  @param[in] filename Path to the mlx model.
///  @param[in] in_shape Input dimension shapes.
///  @param[in] num_dims Number of input dimensions.
//------------------------------------------------------------------------------
    ml_model mlce_init_mlx(const char *filename,
                           const int *in_shape,
                           const size_t num_dims);

//------------------------------------------------------------------------------
///  @brief Finalize a model.
///
+3.13 KiB

File added.

No diff preview for this file type.

+43 −1
Original line number Diff line number Diff line
@@ -28,7 +28,8 @@
!*******************************************************************************
      INTERFACE ml_context
         MODULE PROCEDURE ml_context_construct_keras,                          &
                          ml_context_construct_custom_pytorch
                          ml_context_construct_custom_pytorch,                 &
                          ml_context_construct_mlx
      END INTERFACE

!*******************************************************************************
@@ -62,6 +63,22 @@
         CHARACTER(kind=C_CHAR), DIMENSION(*) :: config_filename
         END FUNCTION

!-------------------------------------------------------------------------------
!>  @brief Initialize a custom mlx model.
!>
!>  @param[in] filename Path to the mlx model.
!>  @param[in] in_shape Input dimension shapes.
!>  @param[in] num_dims Number of input dimensions.
!-------------------------------------------------------------------------------
         TYPE(C_PTR) FUNCTION mlce_init_mlx(filename, in_shape, num_dims)      &
         BIND(C, NAME='mlce_init_mlx')
         USE, INTRINSIC :: ISO_C_BINDING
         IMPLICIT NONE
         CHARACTER(kind=C_CHAR), DIMENSION(*) :: filename
         INTEGER(C_INTPTR_T), VALUE           :: in_shape
         INTEGER(C_LONG), VALUE               :: num_dims
         END FUNCTION

!-------------------------------------------------------------------------------
!>  @brief Finalize a model.
!>
@@ -176,6 +193,31 @@

      END FUNCTION

!-------------------------------------------------------------------------------
!>  @brief Construct a @ref ml_f_embedder::ml_context object.
!>
!>  @param[in] filename Path to the mlx model.
!>  @param[in] in_shape Input dimension shapes.
!>  @returns A @ref ml_f_embedder::ml_context object instance.
!-------------------------------------------------------------------------------
      FUNCTION ml_context_construct_mlx(filename, in_shape)

      IMPLICIT NONE

!  Declare Arguments
      CLASS(ml_context), POINTER               ::                              &
         ml_context_construct_mlx
      CHARACTER(kind=C_CHAR,len=*), INTENT(IN) :: filename
      INTEGER(C_INT), DIMENSION(:), INTENT(IN) :: in_shape

!  Start of executable.
      ALLOCATE (ml_context_construct_mlx)
      ml_context_construct_mlx%m =                                             &
         mlce_init_mlx(filename, LOC(in_shape),                                &
                       SIZE(in_shape, KIND=C_LONG))

      END FUNCTION

!*******************************************************************************
!  DESTRUCTION SUBROUTINES
!*******************************************************************************
+3 −1
Original line number Diff line number Diff line
@@ -325,9 +325,10 @@
			buildSettings = {
				CODE_SIGN_STYLE = Automatic;
				GCC_PREPROCESSOR_DEFINITIONS = (
					"KERAS_MODEL=\\\"/Users/m4c/Projects/ml_model_embeder/example_models/EPED/saved_model.keras\\\"",
					"KERAS_MODEL=\\\"/Users/m4c/Projects/ml_model_embedder/example_models/EPED/saved_model.keras\\\"",
					"TGLF_MODEL_CONFIG=\\\"/Users/m4c/Projects/ml_model_embedder/example_models/TGLF/MyModel_manual.json\\\"",
					"TGLF_MODEL_WEIGHTS=\\\"/Users/m4c/Projects/ml_model_embedder/example_models/TGLF/keras_model_weights.weights.h5\\\"",
					"MLX_MODEL=\\\"/Users/m4c/Projects/ml_model_embedder/example_models/MLX/test.mlxfn\\\"",
					"DEBUG=1",
					"$(inherited)",
				);
@@ -365,6 +366,7 @@
				GCC_PREPROCESSOR_DEFINITIONS = (
					"TGLF_MODEL_WEIGHTS=\\\"/Users/m4c/Projects/ml_model_embedder/example_models/TGLF/keras_model_weights.weights.h5\\\"",
					"TGLF_MODEL_CONFIG=\\\"/Users/m4c/Projects/ml_model_embedder/example_models/TGLF/MyModel_manual.json\\\"",
					"MLX_MODEL=\\\"/Users/m4c/Projects/ml_model_embedder/example_models/MLX/test.mlxfn\\\"",
					"KERAS_MODEL=\\\"/Users/m4c/Projects/ml_model_embeder/example_models/saved_model.keras\\\"",
				);
				HEADER_SEARCH_PATHS = (
Loading