diff --git a/django_remote_submission/admin.py b/django_remote_submission/admin.py index d80d7fc16b0c69493a42db2edc34a81876e9ab06..07897f400dc130887fd572cf8367229814e1ea2e 100644 --- a/django_remote_submission/admin.py +++ b/django_remote_submission/admin.py @@ -7,9 +7,10 @@ from django.utils.translation import ngettext_lazy as _ from django.shortcuts import render from django.http.response import HttpResponseRedirect -from .models import Server, Job, Log, Interpreter, Result +from .models import Server, Job, Log, Interpreter, Result, IdentityFileModel from .tasks import submit_job_to_server + @admin.register(Interpreter) class InterpreterAdmin(admin.ModelAdmin): """Manage interpreters with default admin interface.""" @@ -138,3 +139,9 @@ class LogAdmin(admin.ModelAdmin): """Manage logs with the default admin interface.""" pass + + +@admin.register(IdentityFileModel) +class IdentityFileModelAdmin(admin.ModelAdmin): + """Manage interpreters with default admin interface.""" + pass diff --git a/django_remote_submission/models.py b/django_remote_submission/models.py index 3a4b8a6e90fa6b0064b8b934f397016d3f66c3a7..d05a05910f76a3bd8ad7d891af9c29509a633b9e 100644 --- a/django_remote_submission/models.py +++ b/django_remote_submission/models.py @@ -34,6 +34,9 @@ from model_utils import Choices from model_utils.fields import StatusField, AutoCreatedField from model_utils.models import TimeStampedModel +# standard imports +from pathlib import Path + # Thanks http://stackoverflow.com/a/7394475 class ListField(models.TextField): # noqa: D101 @@ -517,3 +520,9 @@ class IdentityFileModel(TimeStampedModel): r"""Override the save method to prevent updating the record, if extant""" kwargs["force_insert"] = True return super().save(*args, **kwargs) + + def delete(self, *args, **kwargs): + r"""Overwrite base delete method by first deleting the SSH key files""" + Path(self.public).unlink() + Path(self.private).unlink() + super().delete(*args, **kwargs) diff --git a/django_remote_submission/tasks.py b/django_remote_submission/tasks.py index 3cceb4b4b2492fc65ada04744247a039877a5cad..0b83da621feaf2e778d2a07b14edab6a301b9233 100644 --- a/django_remote_submission/tasks.py +++ b/django_remote_submission/tasks.py @@ -210,7 +210,7 @@ class LogContainer(object): @shared_task -def submit_job_to_server(job_pk, password=None, public_key_filename=None, username=None, +def submit_job_to_server(job_pk, password=None, key_filename=None, username=None, timeout=None, log_policy=LogPolicy.LOG_LIVE, store_results=None, remote=True): """Submit a job to the remote server. @@ -219,7 +219,7 @@ def submit_job_to_server(job_pk, password=None, public_key_filename=None, userna :param int job_pk: the primary key of the :class:`models.Job` to submit :param str password: the password of the user submitting the job - :param public_key_filename: the path where it is. + :param key_filename: the path to the private key file :param str username: the username of the user submitting, if it is different from the owner of the job :param datetime.timedelta timeout: the timeout for running the job @@ -228,9 +228,6 @@ def submit_job_to_server(job_pk, password=None, public_key_filename=None, userna :param bool remote: Either runs this task locally on the host or in a remote server. """ - - logger.debug("submit_job_to_server: %s", locals().keys()) - wrapper_cls = RemoteWrapper if remote else LocalWrapper job = Job.objects.get(pk=job_pk) @@ -249,7 +246,7 @@ def submit_job_to_server(job_pk, password=None, public_key_filename=None, userna log_policy=log_policy, ) - with wrapper.connect(password, public_key_filename): + with wrapper.connect(password, key_filename): wrapper.chdir(job.remote_directory) with wrapper.open(job.remote_filename, 'wt') as f: @@ -285,21 +282,14 @@ def submit_job_to_server(job_pk, password=None, public_key_filename=None, userna results = [] for attr in file_attrs: - # logger.debug('Listing directory: {!r}'.format(attr)) - if attr is script_attr: continue - if attr.st_mtime < script_mtime: continue - if not is_matching(attr.filename, store_results): - # logger.debug('Listing directory: is_matching: {}'.format(attr.filename)) continue else: - # logger.debug('Listing directory: not is_matching: {}'.format(attr.filename)) pass - result = Result.objects.create( remote_filename=attr.filename, job=job, @@ -314,7 +304,7 @@ def submit_job_to_server(job_pk, password=None, public_key_filename=None, userna @shared_task -def copy_job_to_server(job_pk, password=None, public_key_filename=None, username=None, +def copy_job_to_server(job_pk, password=None, key_filename=None, username=None, timeout=None, log_policy=LogPolicy.LOG_LIVE, store_results=None, remote=True): """Copy a job file to the remote server. @@ -323,7 +313,7 @@ def copy_job_to_server(job_pk, password=None, public_key_filename=None, username :param int job_pk: the primary key of the :class:`models.Job` to submit :param str password: the password of the user submitting the job - :param public_key_filename: the path where it is. + :param key_filename: the path to the private key file :param str username: the username of the user submitting, if it is different from the owner of the job :param datetime.timedelta timeout: the timeout for running the job @@ -348,7 +338,7 @@ def copy_job_to_server(job_pk, password=None, public_key_filename=None, username port=job.server.port, ) - with wrapper.connect(password, public_key_filename): + with wrapper.connect(password, key_filename): wrapper.chdir(job.remote_directory) with wrapper.open(job.remote_filename, 'wt') as f: @@ -405,7 +395,8 @@ def copy_key_to_server(public_key_filename: Optional[Union[PosixPath, str]], @shared_task def delete_key_from_server(public_key_filename: Optional[Union[PosixPath, str]], username: str, - password: str, + password: Optional[str], + key_filename: Optional[str], hostname: str, port: int = 22, remote: bool = True @@ -415,15 +406,16 @@ def delete_key_from_server(public_key_filename: Optional[Union[PosixPath, str]], This can be used as a Celery task, if the library is installed and running. - :param public_key_filename: the path where it is. + :param public_key_filename: the path to the public key file :param username: the username of the user submitting - :param password: the password of the user submitting the job + :param key_filename: the password of the user submitting the job + :param password: the path to the private key file :param hostname: The hostname used to connect to the server :param port: The port to connect to for SSH (usually 22) :param remote: Either runs this task locally on the host or in a remote server. """ wrapper_cls = RemoteWrapper if remote else LocalWrapper wrapper = wrapper_cls(hostname=hostname, username=username, port=port) - with wrapper.connect(password): + with wrapper.connect(password=password, key_filename=key_filename): wrapper.delete_key(public_key_filename) return None diff --git a/django_remote_submission/wrapper/remote.py b/django_remote_submission/wrapper/remote.py index cc2bb532d82b6a3e13d1dd4d5b092cdb096edb23..61a5337529e36dba8954668b3f0f35bfd645c748 100644 --- a/django_remote_submission/wrapper/remote.py +++ b/django_remote_submission/wrapper/remote.py @@ -19,7 +19,6 @@ except ImportError: from pipes import quote as cmd_quote # standard imports -import datetime import logging import os from pathlib import Path, PosixPath @@ -35,7 +34,7 @@ logger = logging.getLogger(__name__) class IdentityFile: r"""Provide temporary RSA SSH keys""" def __init__(self, - sshdir: PosixPath = Path.home() / ".ssh", + sshdir: Union[str, PosixPath] = Path.home() / ".ssh", persistent: bool = False ) -> None: r""" @@ -43,8 +42,9 @@ class IdentityFile: :param persistent: unless set to `True, RSA files will be automatically deleted when the object is about to be garbage collected. """ - sshdir.mkdir(mode=700, parents=True, exist_ok=True) - _, name = tempfile.mkstemp(prefix="id_rsa_", dir=sshdir) + dirpath = Path(sshdir) if isinstance(sshdir, str) else sshdir + dirpath.mkdir(mode=700, parents=True, exist_ok=True) + _, name = tempfile.mkstemp(prefix="id_rsa_", dir=dirpath) self._private = Path(name) self._public = Path(name + ".pub") self._persistent = persistent @@ -256,7 +256,6 @@ class RemoteWrapper(object): : raise ValueError: when no password and no stored public key are supplied :return SSHClient: an instace of the SSH tunnel """ - client = SSHClient() client.set_missing_host_key_policy(AutoAddPolicy()) @@ -283,8 +282,7 @@ class RemoteWrapper(object): logger.debug("Trying to connect with the public key") if self._key_filename is not None: try: - logger.info("Connecting to %s with public key.", - server_hostname) + logger.info("Connecting to %s with public key.", server_hostname) client.connect( server_hostname, port=server_port, @@ -302,7 +300,6 @@ class RemoteWrapper(object): def _make_command(self, args, timeout): command = ' '.join(cmd_quote(arg) for arg in args) - if timeout is not None: command = 'timeout {}s {}'.format(timeout.total_seconds(), command) return command @@ -330,6 +327,7 @@ class RemoteWrapper(object): key = f.read().strip() self._client.exec_command('mkdir -p ~/.ssh/') + self._client.exec_command('chmod 700 ~/.ssh/') self._client.exec_command('chmod 644 ~/.ssh/authorized_keys') diff --git a/tests/test_tasks.py b/tests/test_tasks.py index f67b3cbb5386aa35addc11f15e21075b6366cacf..5d6861f6990a8ab457cb27c9e5eacde45af8c6b2 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -526,7 +526,7 @@ def test_delete_key_old_way(env): @pytest.mark.remote_required @pytest.mark.django_db def test_deploy_and_delete_key(env): - r"""This is the new way of deploying and deleting the private key""" + r"""This is the new way of deploying and deleting the public key""" id_file = IdentityFile(sshdir=Path("/tmp")) # create temporary SSH identity file copy_key_to_server( @@ -549,10 +549,27 @@ def test_deploy_and_delete_key(env): with wrapper.connect(key_filename=id_file.private): pass + # delete the key from server passing the password for credentials delete_key_from_server( public_key_filename=id_file.public, username=env.remote_user, password=env.remote_password, + key_filename=None, + hostname=env.server_hostname, + port=env.server_port, + remote=runs_remotely, + ) + + # delete the key from server passing the private key file for credentials + copy_key_to_server(public_key_filename=id_file.public, username=env.remote_user, password=env.remote_password, + hostname=env.server_hostname, port=env.server_port, remote=runs_remotely) + with wrapper.connect(key_filename=id_file.private): + pass + delete_key_from_server( + public_key_filename=id_file.public, + username=env.remote_user, + password=None, + key_filename=id_file.private, hostname=env.server_hostname, port=env.server_port, remote=runs_remotely, diff --git a/tests/unit/test_models.py b/tests/unit/test_models.py index 7f4b08c3892b262c124623405ecb0cf892a77c79..4919dbbd5a88393923fd032624371b747d8805c8 100644 --- a/tests/unit/test_models.py +++ b/tests/unit/test_models.py @@ -8,8 +8,7 @@ test_django-remote-submission Tests for `django-remote-submission` models module. """ # package imports -from django_remote_submission.models import IdentityFileModel -from django_remote_submission.models import Interpreter, Log, Job, Result, Server +from django_remote_submission.models import IdentityFileModel, Interpreter, Log, Job, Result, Server from django_remote_submission.wrapper.remote import IdentityFile # third party imports @@ -122,6 +121,9 @@ class TestIdentityFileModel: assert record.recipient == user assert os.path.exists(record.private) assert os.path.exists(record.public) + record.delete() + assert os.path.exists(record.private) is False + assert os.path.exists(record.public) is False @pytest.mark.django_db(transaction=True) def test_create_from_username(self, user): @@ -134,9 +136,6 @@ class TestIdentityFileModel: record = IdentityFileModel.objects.create_from_username(user.username) assert record.recipient == user assert Path(record.private).parent == Path.home() / ".ssh" - for file in [record.private, record.public]: - assert os.path.exists(file) - os.remove(file) # return the newly created IdentityFileModel record, do not instantiate another one record_2 = IdentityFileModel.objects.create_from_username(user.username) @@ -156,3 +155,14 @@ class TestIdentityFileModel: record.save() assert str(e.value) == "UNIQUE constraint failed: django_remote_submission_identityfilemodel.recipient_id" [os.remove(file) for file in [id_file.private, id_file.public]] + + @pytest.mark.django_db(transaction=True) + def test_delete(self, user): + r"""delete the record, along with the SSH key files""" + username = user.username + record = IdentityFileModel.objects.create_from_username(user.username) + record.delete() + assert os.path.exists(record.private) is False + assert os.path.exists(record.public) is False + # the IdentityFileModel instance is deleted, but not its associated recipient user + assert user.username == username