Commit 79fcd293 authored by Robert Schütz's avatar Robert Schütz
Browse files
parent f780f1f7
Loading
Loading
Loading
Loading
+7 −16
Original line number Diff line number Diff line
@@ -4,16 +4,12 @@
  cryptography,
  defusedxml,
  fetchFromGitHub,
  fetchpatch,
  paste,
  poetry-core,
  pyasn1,
  pymongo,
  pyopenssl,
  pytestCheckHook,
  python-dateutil,
  pythonOlder,
  pytz,
  repoze-who,
  requests,
  responses,
@@ -26,14 +22,14 @@

buildPythonPackage rec {
  pname = "pysaml2";
  version = "7.5.2";
  version = "7.5.4";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "IdentityPython";
    repo = "pysaml2";
    tag = "v${version}";
    hash = "sha256-2mvAXTruZqoSBUgfT2VEAnWQXVdviG0e49y7LPK5x00=";
    hash = "sha256-DDs0jWONZ78995p7bbyIyZTWHnCI93SsbECqyeo0se8=";
  };

  patches = [
@@ -41,10 +37,8 @@ buildPythonPackage rec {
      inherit xmlsec;
    })
    # Replaces usages of deprecated/removed pyopenssl APIs
    (fetchpatch {
      url = "https://github.com/IdentityPython/pysaml2/pull/977/commits/930a652a240c8cd1489429a7d70cf5fa7ef1606a.patch";
      hash = "sha256-kBNvGk5pwVmpW1wsIWVH9wapu6kjFavaTt4e3Llaw2c=";
    })
    # https://github.com/IdentityPython/pysaml2/pull/977
    ./replace-pyopenssl-with-cryptography.patch
  ];

  postPatch = ''
@@ -54,18 +48,14 @@ buildPythonPackage rec {

  pythonRelaxDeps = [ "xmlschema" ];

  nativeBuildInputs = [
  build-system = [
    poetry-core
  ];

  propagatedBuildInputs = [
  dependencies = [
    cryptography
    defusedxml
    pyopenssl
    python-dateutil
    pytz
    requests
    setuptools
    xmlschema
  ];

@@ -81,6 +71,7 @@ buildPythonPackage rec {
    pyasn1
    pymongo
    pytestCheckHook
    python-dateutil
    responses
  ];

+334 −0
Original line number Diff line number Diff line
diff --git a/pyproject.toml b/pyproject.toml
index 87068b18..03f30491 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,8 +22,6 @@ requires-python = ">= 3.9"
 dependencies = [
   "cryptography >=3.1",
   "defusedxml",
-  "pyopenssl <24.3.0",
-  "python-dateutil",
   "requests >=2.0.0,<3.0.0",  # ^2 means compatible with 2.x
   "xmlschema >=2.0.0,<3.0.0"
 ]
diff --git a/src/saml2/cert.py b/src/saml2/cert.py
index e90651e4..926aec8a 100644
--- a/src/saml2/cert.py
+++ b/src/saml2/cert.py
@@ -3,11 +3,13 @@ __author__ = "haho0032"
 import base64
 from os import remove
 from os.path import join
-from datetime import datetime
-from datetime import timezone
+from datetime import datetime, timedelta, timezone
 
-from OpenSSL import crypto
-import dateutil.parser
+from cryptography import x509
+from cryptography.exceptions import InvalidSignature
+from cryptography.hazmat.primitives import hashes, serialization
+from cryptography.hazmat.primitives.asymmetric import rsa
+from cryptography.x509.oid import NameOID
 
 import saml2.cryptography.pki
 
@@ -36,7 +38,6 @@ class OpenSSLWrapper:
         valid_to=315360000,
         sn=1,
         key_length=1024,
-        hash_alg="sha256",
         write_to_file=False,
         cert_dir="",
         cipher_passphrase=None,
@@ -87,8 +88,6 @@ class OpenSSLWrapper:
                                   is 1.
         :param key_length:        Length of the key to be generated. Defaults
                                   to 1024.
-        :param hash_alg:          Hash algorithm to use for the key. Default
-                                  is sha256.
         :param write_to_file:     True if you want to write the certificate
                                   to a file. The method will then return
                                   a tuple with path to certificate file and
@@ -131,49 +130,68 @@ class OpenSSLWrapper:
             k_f = join(cert_dir, key_file)
 
         # create a key pair
-        k = crypto.PKey()
-        k.generate_key(crypto.TYPE_RSA, key_length)
+        k = rsa.generate_private_key(
+            public_exponent=65537,
+            key_size=key_length,
+        )
 
         # create a self-signed cert
-        cert = crypto.X509()
+        builder = x509.CertificateBuilder()
 
         if request:
-            cert = crypto.X509Req()
+            builder = x509.CertificateSigningRequestBuilder()
 
         if len(cert_info["country_code"]) != 2:
             raise WrongInput("Country code must be two letters!")
-        cert.get_subject().C = cert_info["country_code"]
-        cert.get_subject().ST = cert_info["state"]
-        cert.get_subject().L = cert_info["city"]
-        cert.get_subject().O = cert_info["organization"]  # noqa: E741
-        cert.get_subject().OU = cert_info["organization_unit"]
-        cert.get_subject().CN = cn
+        subject_name = x509.Name([
+            x509.NameAttribute(NameOID.COUNTRY_NAME,
+                               cert_info["country_code"]),
+            x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME,
+                               cert_info["state"]),
+            x509.NameAttribute(NameOID.LOCALITY_NAME,
+                               cert_info["city"]),
+            x509.NameAttribute(NameOID.ORGANIZATION_NAME,
+                               cert_info["organization"]),
+            x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME,
+                               cert_info["organization_unit"]),
+            x509.NameAttribute(NameOID.COMMON_NAME, cn),
+        ])
+        builder = builder.subject_name(subject_name)
         if not request:
-            cert.set_serial_number(sn)
-            cert.gmtime_adj_notBefore(valid_from)  # Valid before present time
-            cert.gmtime_adj_notAfter(valid_to)  # 3 650 days
-            cert.set_issuer(cert.get_subject())
-        cert.set_pubkey(k)
-        cert.sign(k, hash_alg)
+            now = datetime.now(timezone.utc)
+            builder = builder.serial_number(
+                sn,
+            ).not_valid_before(
+                now + timedelta(seconds=valid_from),
+            ).not_valid_after(
+                now + timedelta(seconds=valid_to),
+            ).issuer_name(
+                subject_name,
+            ).public_key(
+                k.public_key(),
+            )
+        cert = builder.sign(k, hashes.SHA256())
 
         try:
-            if request:
-                tmp_cert = crypto.dump_certificate_request(crypto.FILETYPE_PEM, cert)
-            else:
-                tmp_cert = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)
-            tmp_key = None
+            tmp_cert = cert.public_bytes(serialization.Encoding.PEM)
+            key_encryption = None
             if cipher_passphrase is not None:
                 passphrase = cipher_passphrase["passphrase"]
                 if isinstance(cipher_passphrase["passphrase"], str):
                     passphrase = passphrase.encode("utf-8")
-                tmp_key = crypto.dump_privatekey(crypto.FILETYPE_PEM, k, cipher_passphrase["cipher"], passphrase)
+                key_encryption = serialization.BestAvailableEncryption(passphrase)
             else:
-                tmp_key = crypto.dump_privatekey(crypto.FILETYPE_PEM, k)
+                key_encryption = serialization.NoEncryption()
+            tmp_key = k.private_bytes(
+                encoding=serialization.Encoding.PEM,
+                format=serialization.PrivateFormat.TraditionalOpenSSL,
+                encryption_algorithm=key_encryption,
+            )
             if write_to_file:
-                with open(c_f, "w") as fc:
-                    fc.write(tmp_cert.decode("utf-8"))
-                with open(k_f, "w") as fk:
-                    fk.write(tmp_key.decode("utf-8"))
+                with open(c_f, "wb") as fc:
+                    fc.write(tmp_cert)
+                with open(k_f, "wb") as fk:
+                    fk.write(tmp_key)
                 return c_f, k_f
             return tmp_cert, tmp_key
         except Exception as ex:
@@ -198,7 +216,6 @@ class OpenSSLWrapper:
         sign_cert_str,
         sign_key_str,
         request_cert_str,
-        hash_alg="sha256",
         valid_from=0,
         valid_to=315360000,
         sn=1,
@@ -222,8 +239,6 @@ class OpenSSLWrapper:
                                   the requested certificate. If you only have
                                   a file use the method read_str_from_file
                                   to get a string representation.
-        :param hash_alg:          Hash algorithm to use for the key. Default
-                                  is sha256.
         :param valid_from:        When the certificate starts to be valid.
                                   Amount of seconds from when the
                                   certificate is generated.
@@ -237,27 +252,29 @@ class OpenSSLWrapper:
         :return:                  String representation of the signed
                                   certificate.
         """
-        ca_cert = crypto.load_certificate(crypto.FILETYPE_PEM, sign_cert_str)
-        ca_key = None
-        if passphrase is not None:
-            ca_key = crypto.load_privatekey(crypto.FILETYPE_PEM, sign_key_str, passphrase)
-        else:
-            ca_key = crypto.load_privatekey(crypto.FILETYPE_PEM, sign_key_str)
-        req_cert = crypto.load_certificate_request(crypto.FILETYPE_PEM, request_cert_str)
-
-        cert = crypto.X509()
-        cert.set_subject(req_cert.get_subject())
-        cert.set_serial_number(sn)
-        cert.gmtime_adj_notBefore(valid_from)
-        cert.gmtime_adj_notAfter(valid_to)
-        cert.set_issuer(ca_cert.get_subject())
-        cert.set_pubkey(req_cert.get_pubkey())
-        cert.sign(ca_key, hash_alg)
-
-        cert_dump = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)
-        if isinstance(cert_dump, str):
-            return cert_dump
-        return cert_dump.decode("utf-8")
+        if isinstance(sign_cert_str, str):
+            sign_cert_str = sign_cert_str.encode("utf-8")
+        ca_cert = x509.load_pem_x509_certificate(sign_cert_str)
+        ca_key = serialization.load_pem_private_key(
+            sign_key_str, password=passphrase)
+        req_cert = x509.load_pem_x509_csr(request_cert_str)
+
+        now = datetime.now(timezone.utc)
+        cert = x509.CertificateBuilder().subject_name(
+            req_cert.subject,
+        ).serial_number(
+            sn,
+        ).not_valid_before(
+            now + timedelta(seconds=valid_from),
+        ).not_valid_after(
+            now + timedelta(seconds=valid_to),
+        ).issuer_name(
+            ca_cert.subject,
+        ).public_key(
+            req_cert.public_key(),
+        ).sign(ca_key, hashes.SHA256())
+
+        return cert.public_bytes(serialization.Encoding.PEM).decode("utf-8")
 
     def verify_chain(self, cert_chain_str_list, cert_str):
         """
@@ -276,13 +293,6 @@ class OpenSSLWrapper:
                 cert_str = tmp_cert_str
             return (True, "Signed certificate is valid and correctly signed by CA " "certificate.")
 
-    def certificate_not_valid_yet(self, cert):
-        starts_to_be_valid = dateutil.parser.parse(cert.get_notBefore())
-        now = datetime.now(timezone.utc)
-        if starts_to_be_valid < now:
-            return False
-        return True
-
     def verify(self, signing_cert_str, cert_str):
         """
         Verifies if a certificate is valid and signed by a given certificate.
@@ -303,34 +313,34 @@ class OpenSSLWrapper:
                                  Message = Why the validation failed.
         """
         try:
-            ca_cert = crypto.load_certificate(crypto.FILETYPE_PEM, signing_cert_str)
-            cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert_str)
-
-            if self.certificate_not_valid_yet(ca_cert):
+            if isinstance(signing_cert_str, str):
+                signing_cert_str = signing_cert_str.encode("utf-8")
+            if isinstance(cert_str, str):
+                cert_str = cert_str.encode("utf-8")
+            ca_cert = x509.load_pem_x509_certificate(signing_cert_str)
+            cert = x509.load_pem_x509_certificate(cert_str)
+            now = datetime.now(timezone.utc)
+
+            if ca_cert.not_valid_before_utc >= now:
                 return False, "CA certificate is not valid yet."
 
-            if ca_cert.has_expired() == 1:
+            if ca_cert.not_valid_after_utc < now:
                 return False, "CA certificate is expired."
 
-            if cert.has_expired() == 1:
+            if cert.not_valid_after_utc < now:
                 return False, "The signed certificate is expired."
 
-            if self.certificate_not_valid_yet(cert):
+            if cert.not_valid_before_utc >= now:
                 return False, "The signed certificate is not valid yet."
 
-            if ca_cert.get_subject().CN == cert.get_subject().CN:
+            if ca_cert.subject.get_attributes_for_oid(NameOID.COMMON_NAME) == \
+               cert.subject.get_attributes_for_oid(NameOID.COMMON_NAME):
                 return False, ("CN may not be equal for CA certificate and the " "signed certificate.")
 
-            cert_algorithm = cert.get_signature_algorithm()
-            cert_algorithm = cert_algorithm.decode("ascii")
-            cert_str = cert_str.encode("ascii")
-
-            cert_crypto = saml2.cryptography.pki.load_pem_x509_certificate(cert_str)
-
             try:
-                crypto.verify(ca_cert, cert_crypto.signature, cert_crypto.tbs_certificate_bytes, cert_algorithm)
+                cert.verify_directly_issued_by(ca_cert)
                 return True, "Signed certificate is valid and correctly signed by CA certificate."
-            except crypto.Error as e:
+            except (ValueError, TypeError, InvalidSignature) as e:
                 return False, f"Certificate is incorrectly signed: {str(e)}"
         except Exception as e:
             return False, f"Certificate is not valid for an unknown reason. {str(e)}"
@@ -352,8 +362,14 @@ def read_cert_from_file(cert_file, cert_type="pem"):
         data = fp.read()
 
     try:
-        cert = saml2.cryptography.pki.load_x509_certificate(data, cert_type)
-        pem_data = saml2.cryptography.pki.get_public_bytes_from_cert(cert)
+        cert = None
+        if cert_type == "pem":
+            cert = x509.load_pem_x509_certificate(data)
+        elif cert_type == "der":
+            cert = x509.load_der_x509_certificate(data)
+        else:
+            raise ValueError(f"cert-type {cert_type} not supported")
+        pem_data = cert.public_bytes(serialization.Encoding.PEM).decode("utf-8")
     except Exception as e:
         raise CertificateError(e)
 
diff --git a/src/saml2/sigver.py b/src/saml2/sigver.py
index 738ac04b..60c83718 100644
--- a/src/saml2/sigver.py
+++ b/src/saml2/sigver.py
@@ -18,8 +18,9 @@ from time import mktime
 from urllib import parse
 from uuid import uuid4 as gen_random_key
 
-from OpenSSL import crypto
-import dateutil
+from urllib import parse
+
+from cryptography import x509
 
 from saml2 import ExtensionElement
 from saml2 import SamlBase
@@ -373,14 +374,14 @@ def active_cert(key):
     """
     try:
         cert_str = pem_format(key)
-        cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert_str)
+        cert = x509.load_pem_x509_certificate(cert_str)
     except AttributeError:
         return False
 
-    now = datetime.now(timezone.utc)
-    valid_from = dateutil.parser.parse(cert.get_notBefore())
-    valid_to = dateutil.parser.parse(cert.get_notAfter())
-    active = not cert.has_expired() and valid_from <= now < valid_to
+    now = datetime.datetime.now(datetime.timezone.utc)
+    valid_from = cert.not_valid_before_utc
+    valid_to = cert.not_valid_after_utc
+    active = valid_from <= now < valid_to
     return active