Commit c7cc3db7 authored by Yakubov, Sergey's avatar Yakubov, Sergey
Browse files

add an option to use multiple IDPs

parent 9d5e1963
Loading
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -41,8 +41,8 @@ build-packages:
        - docker build -f server_side/dockerfiles/Dockerfile.opensuse -t opensuse --target package ./server_side
        - docker run --rm -v `pwd`:/tmp opensuse bash -c "cp /src/c/build/*.rpm /tmp"
        - fname=`ls *.rpm | head -n 1`
        - 'curl --header "JOB-TOKEN: $CI_JOB_TOKEN" --upload-file $fname "${CI_API_V4_URL}/projects/${CI_PROJECT_ID}/packages/generic/rpm/0.2.1/$fname"'
        - 'curl --header "JOB-TOKEN: $CI_JOB_TOKEN" --upload-file $fname "${CI_API_V4_URL}/projects/${CI_PROJECT_ID}/packages/generic/rpm/0.3.0/$fname"'
        - fname=`ls *.deb | head -n 1`
        - 'curl --header "JOB-TOKEN: $CI_JOB_TOKEN" --upload-file $fname "${CI_API_V4_URL}/projects/${CI_PROJECT_ID}/packages/generic/deb/0.2.1/$fname"'
        - 'curl --header "JOB-TOKEN: $CI_JOB_TOKEN" --upload-file $fname "${CI_API_V4_URL}/projects/${CI_PROJECT_ID}/packages/generic/deb/0.3.0/$fname"'
    tags:
        - rse-multi-builder
+2 −2
Original line number Diff line number Diff line
@@ -30,8 +30,8 @@ SET(CPACK_DEBIAN_PACKAGE_MAINTAINER "ORNL")
SET(CPACK_RPM_PACKAGE_MAINTAINER "ORNL")

set(CPACK_PACKAGE_VERSION_MAJOR "0")
set(CPACK_PACKAGE_VERSION_MINOR "2")
set(CPACK_PACKAGE_VERSION_PATCH "1")
set(CPACK_PACKAGE_VERSION_MINOR "3")
set(CPACK_PACKAGE_VERSION_PATCH "0")

set(CPACK_DEBIAN_PACKAGE_DEPENDS "curl")
set(CPACK_RPM_PACKAGE_DEPENDS "curl")
+58 −75
Original line number Diff line number Diff line
#include "auth.h"

#include <stdio.h>
#include <stdlib.h>
#include <memory.h>
#include <openssl/bio.h>
#include <openssl/err.h>
#include <openssl/pem.h>
@@ -13,6 +11,7 @@
#include <curl/curl.h>
#include "cjwt/cJSON.h"
#include "cjwt/cjwt.h"
#include "cjwt/base64.h"
#include "log.h"

static size_t
@@ -35,9 +34,37 @@ WriteMemoryCallback(void *contents, size_t size, size_t nmemb, void *userp) {
    return realsize;
}

int verify_token(const char *token, oidc_token_content_t *token_info) {
unsigned char* base64_urlsafe_decode(const char *input, int length, int *out_length) {
    // Convert base64 URL-safe to base64 standard
    char *converted = malloc(length + 2); // extra bytes for potential padding
    strcpy(converted, input);
    for (int i = 0; i < length; ++i) {
        if (converted[i] == '-') converted[i] = '+';
        if (converted[i] == '_') converted[i] = '/';
    }
    // Add necessary padding
    while (strlen(converted) % 4) {
        strcat(converted, "=");
    }

    // Decode base64 standard
    BIO *bio, *b64;
    unsigned char *buffer = malloc(length); // decoded length will be <= encoded length
    bio = BIO_new_mem_buf(converted, -1);
    b64 = BIO_new(BIO_f_base64());
    bio = BIO_push(b64, bio);
    BIO_set_flags(bio, BIO_FLAGS_BASE64_NO_NL); // Do not use newlines to flush buffer
    *out_length = BIO_read(bio, buffer, strlen(converted));
    BIO_free_all(bio);

    free(converted);
    return buffer;
}

    cJSON *keyset_json = fetch_jwks();

int verify_token(const char *token, oidc_token_content_t *token_info, int auth_number) {
    
    cJSON *keyset_json = fetch_jwks(auth_number);

    if (keyset_json == NULL) {
        const char *error_ptr = cJSON_GetErrorPtr();
@@ -63,36 +90,43 @@ int verify_token(const char *token, oidc_token_content_t *token_info) {
    // Search for correct key
    const cJSON *keyset = cJSON_GetObjectItemCaseSensitive(keyset_json, "keys");

    const unsigned char *key = NULL;
    const  char *n = NULL;
    const  char *e = NULL;

    const cJSON *key_itr = NULL;
    const cJSON *kid = NULL;    
    cJSON_ArrayForEach(key_itr, keyset)
    {
        kid = cJSON_GetObjectItemCaseSensitive(key_itr, "kid");


        if (strcmp(kid->valuestring, token_kid) == 0)
        {
            key = (cJSON_GetObjectItemCaseSensitive(key_itr, "x5c")->child->valuestring);
            n = (cJSON_GetObjectItemCaseSensitive(key_itr, "n")->valuestring); // Modulus
            e = (cJSON_GetObjectItemCaseSensitive(key_itr, "e")->valuestring); // Exponent
        }

    }
    // Free header since it's no longer being used

    cjwt_header_destroy(jwt_header);

    // Handle not finding the correct key
    if (!key) {
    if (!n || !e) {
        logit("Could not find correct key in keyset. Token: %s\n", token);
        return 1;
    }

    // Load Cert here
    char* loaded_cert = load_cert(key);
    cJSON_Delete(keyset_json);
    if (loaded_cert == NULL) {
        logit("Loaded cert is null. Token: %s\n", token);
        return 1;
    }
    int n_length, e_length;
    unsigned char *n_bytes = base64_urlsafe_decode(n, strlen(n), &n_length);
    unsigned char *e_bytes = base64_urlsafe_decode(e, strlen(e), &e_length);
    BIGNUM *n_bn = BN_bin2bn(n_bytes, n_length, NULL);
    BIGNUM *e_bn = BN_bin2bn(e_bytes, e_length, NULL);

    // Create RSA key
    RSA *rsa = RSA_new();
    RSA_set0_key(rsa, n_bn, e_bn, NULL);
    BIO *pem_bio = BIO_new(BIO_s_mem());
    PEM_write_bio_RSA_PUBKEY(pem_bio, rsa);

    char *loaded_cert;
    BIO_get_mem_data(pem_bio, &loaded_cert);

    // Actually validate token now
    cjwt_t *jwt = NULL;
@@ -104,10 +138,11 @@ int verify_token(const char *token, oidc_token_content_t *token_info) {
        return 1;
    }

    const cJSON *user = cJSON_GetObjectItemCaseSensitive(jwt->private_claims, "preferred_username");
    const cJSON *user = cJSON_GetObjectItemCaseSensitive(jwt->private_claims, config.name_field[auth_number]);
    token_info->exp = *(jwt->exp);

    if (!cJSON_IsString(user) || (user->valuestring == NULL)) {
        logit("Could not find 'preferred_username' claim.\n");
        logit("Could not find %s claim.\n",config.name_field[auth_number]);
        cjwt_destroy(jwt);
        return 1;
    }
@@ -115,12 +150,13 @@ int verify_token(const char *token, oidc_token_content_t *token_info) {
    token_info->user = malloc(strlen(user->valuestring));
    strcpy(token_info->user, user->valuestring);

    cJSON_Delete(keyset_json);
    cjwt_destroy(jwt);

    return 0;
}

cJSON* fetch_jwks() {
cJSON* fetch_jwks(int auth_number) {
    memory_struct mem;

    mem.memory = malloc(1);  /* will be grown as needed by realloc above */
@@ -133,7 +169,7 @@ cJSON* fetch_jwks() {
    curl = curl_easy_init();
    long http_code = 0;
    if (curl) {
        curl_easy_setopt(curl, CURLOPT_URL, config.jwks_url);
        curl_easy_setopt(curl, CURLOPT_URL, config.jwks_url[auth_number]);
        curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteMemoryCallback);
        curl_easy_setopt(curl, CURLOPT_WRITEDATA, (void *) &mem);
        curl_easy_setopt(curl, CURLOPT_USERAGENT, "libcurl-agent/1.0");
@@ -158,56 +194,3 @@ cJSON* fetch_jwks() {
    free(mem.memory);
    return keyset_json;
}
 No newline at end of file

char *load_cert(const char* x509) {

    /* ---------------------------------------------------------- *
    * These function calls initialize openssl for correct work.  *
    * ---------------------------------------------------------- */
    OpenSSL_add_all_algorithms();
    ERR_load_BIO_strings();
    ERR_load_crypto_strings();

    const char *header = "-----BEGIN CERTIFICATE-----\n";
    const char *footer = "\n-----END CERTIFICATE-----";

    char *x509_formatted = malloc(strlen(x509) + strlen(header) + strlen(footer));
    strcpy(x509_formatted, header);
    strcat(x509_formatted, x509);
    strcat(x509_formatted, footer);

    X509 *cert = NULL;
    BIO *certbio = BIO_new(BIO_s_mem());
    BIO_write(certbio, x509_formatted, strlen(x509_formatted) + 1);
    if (! (cert = PEM_read_bio_X509(certbio, NULL, 0, NULL))) {
        logit("Error reading cert into memory. x509: %s\n", x509_formatted);
        BIO_free_all(certbio);
        free(x509_formatted);
        return NULL;
    }
    BIO_free_all(certbio);

    free(x509_formatted);

    EVP_PKEY *pkey = NULL;
    if ((pkey = X509_get_pubkey(cert)) == NULL) {
        logit("Error getting public key.\n");
        X509_free(cert);
        return NULL;
    }
    X509_free(cert);
    BIO *keybio = BIO_new(BIO_s_mem());
    if(!PEM_write_bio_PUBKEY(keybio, pkey)) {
        logit("Error writing public key data in PEM format\n");
        EVP_PKEY_free(pkey);
        return NULL;
    }
    char* key_buf = (char*) malloc(EVP_PKEY_bits(pkey) + 1);
    memset(key_buf, 0, EVP_PKEY_bits(pkey) + 1);
    BIO_read(keybio, key_buf, EVP_PKEY_bits(pkey));

    EVP_PKEY_free(pkey);
    BIO_free_all(keybio);

    return key_buf;
}
 No newline at end of file
+4 −6
Original line number Diff line number Diff line
@@ -2,6 +2,7 @@
#define OIDC_PAM_AUTH_H

#include "config.h"
#include <inttypes.h>

typedef struct MemoryStruct {
    char *memory;
@@ -11,15 +12,12 @@ typedef struct MemoryStruct {
typedef struct oidc_token_content_t
{
    char *user;
    // const char *session_attribute;
    // int active;
    int64_t exp;
} oidc_token_content_t;

int verify_token(const char* token, oidc_token_content_t *token_info);
int verify_token(const char* token, oidc_token_content_t *token_info, int auth_number);

cJSON* fetch_jwks();

char *load_cert(const char* x509);
cJSON* fetch_jwks(int auth_number);


#endif //OIDC_PAM_AUTH_H
+30 −4
Original line number Diff line number Diff line
@@ -36,22 +36,48 @@ int parse_config(const char* fname, json_config_t* config) {
        return 1;
    }

    const cJSON *jwks_url = cJSON_GetObjectItemCaseSensitive(config_json, "jwks_url");
    const cJSON *auth = cJSON_GetObjectItemCaseSensitive(config_json, "auth");
    config->n_auth = cJSON_GetArraySize(auth);
    const cJSON *auth_itr = NULL;
    int i = 0;
    config->jwks_url = malloc(sizeof(char*)*config->n_auth);
    config->name_field = malloc(sizeof(char*)*config->n_auth);
    config->name_separator = malloc(sizeof(char*)*config->n_auth);

    cJSON_ArrayForEach(auth_itr, auth)
    {
        const cJSON *jwks_url = cJSON_GetObjectItemCaseSensitive(auth_itr, "jwks_url");
        const cJSON *name_field = (cJSON_GetObjectItemCaseSensitive(auth_itr, "name_field"));
        const cJSON *name_separator = (cJSON_GetObjectItemCaseSensitive(auth_itr, "name_separator"));
        if (!cJSON_IsString(jwks_url) || (jwks_url->valuestring == NULL) ||
            !cJSON_IsString(name_field) || (name_field->valuestring == NULL) ||
            !cJSON_IsString(name_separator) || (name_separator->valuestring == NULL))
        {
            free(buffer);
            return 1;
        }
        config->jwks_url[i] = jwks_url->valuestring;
        config->name_field[i] = name_field->valuestring;
        config->name_separator[i] = name_separator->valuestring;
        i++;
    }
    const cJSON *check_2fa = cJSON_GetObjectItemCaseSensitive(config_json, "check_2fa");
    const cJSON *enable_log = cJSON_GetObjectItemCaseSensitive(config_json, "enable_log");
    const cJSON *log_file = cJSON_GetObjectItemCaseSensitive(config_json, "log_file");
    const cJSON *cache_folder = cJSON_GetObjectItemCaseSensitive(config_json, "cache_folder");

    if (!cJSON_IsString(jwks_url) || (jwks_url->valuestring == NULL) ||
        !cJSON_IsBool(check_2fa) || !cJSON_IsBool(enable_log))
    if (!cJSON_IsBool(check_2fa) || !cJSON_IsBool(enable_log)
    || !cJSON_IsString(cache_folder) || (cache_folder->valuestring == NULL))
    {
        free(buffer);
        return 1;
    }

    config->jwks_url = jwks_url->valuestring;

    config->enable_2fa = cJSON_IsFalse(check_2fa)?0:1;
    config->enable_log = cJSON_IsFalse(enable_log)?0:1;
    config->log_file = log_file->valuestring;
    config->cache_folder = cache_folder->valuestring;

    config->parsed_object = config_json;
    free(buffer);
Loading