diff --git a/pulsar/client/action_mapper.py b/pulsar/client/action_mapper.py index 11aa1b62de356666c0387de44687229f1877237b..8b0374aead9957125ceab68921345e35a388447e 100644 --- a/pulsar/client/action_mapper.py +++ b/pulsar/client/action_mapper.py @@ -138,6 +138,9 @@ class FileActionMapper(object): config = self.__client_to_config(client) self.default_action = config.get("default_action", "transfer") self.ssh_key = config.get("ssh_key", None) + self.ssh_user = config.get("ssh_user", None) + self.ssh_host = config.get("ssh_host", None) + self.ssh_port = config.get("ssh_port", None) self.mappers = mappers_from_dicts(config.get("paths", [])) self.files_endpoint = config.get("files_endpoint", None) @@ -164,6 +167,9 @@ class FileActionMapper(object): default_action=self.default_action, files_endpoint=self.files_endpoint, ssh_key=self.ssh_key, + ssh_user=self.ssh_user, + ssh_port=self.ssh_port, + ssh_host=self.ssh_host, paths=map(lambda m: m.to_dict(), self.mappers) ) @@ -175,8 +181,9 @@ class FileActionMapper(object): config = dict() config["default_action"] = client.default_file_action config["files_endpoint"] = client.files_endpoint - if hasattr(client, 'ssh_key'): - config["ssh_key"] = client.ssh_key + for attr in ['ssh_key', 'ssh_user', 'ssh_port', 'ssh_host']: + if hasattr(client, attr): + config[attr] = getattr(client, attr) return config def __load_action_config(self, path): @@ -213,8 +220,8 @@ class FileActionMapper(object): """ if getattr(action, "inject_url", False): self.__inject_url(action, file_type) - if getattr(action, "inject_ssh_key", False): - self.__inject_ssh_key(action) + if getattr(action, "inject_ssh_properties", False): + self.__inject_ssh_properties(action) def __inject_url(self, action, file_type): url_base = self.files_endpoint @@ -226,15 +233,19 @@ class FileActionMapper(object): url = "%s&path=%s&file_type=%s" % (url_base, action.path, file_type) action.url = url - def __inject_ssh_key(self, action): - # Required, so no check for presence - ssh_key = self.ssh_key - if ssh_key is None: + def __inject_ssh_properties(self, action): + for attr in ["ssh_key", "ssh_host", "ssh_port", "ssh_user"]: + action_attr = getattr(action, attr) + if action_attr == UNSET_ACTION_KWD: + client_default_attr = getattr(self, attr, None) + setattr(action, attr, client_default_attr) + + if action.ssh_key is None: raise Exception(MISSING_SSH_KEY_ERROR) - action.ssh_key = ssh_key REQUIRED_ACTION_KWD = object() +UNSET_ACTION_KWD = "__UNSET__" class BaseAction(object): @@ -403,17 +414,17 @@ class RemoteTransferAction(BaseAction): class PubkeyAuthenticatedTransferAction(BaseAction): """Base class for file transfers requiring an SSH public/private key """ - inject_ssh_key = True + inject_ssh_properties = True action_spec = dict( - ssh_user=REQUIRED_ACTION_KWD, - ssh_host=REQUIRED_ACTION_KWD, - ssh_port=REQUIRED_ACTION_KWD, + ssh_key=UNSET_ACTION_KWD, + ssh_user=UNSET_ACTION_KWD, + ssh_host=UNSET_ACTION_KWD, + ssh_port=UNSET_ACTION_KWD, ) staging = STAGING_ACTION_REMOTE - ssh_key = None - def __init__(self, path, file_lister=None, url=None, ssh_user=None, - ssh_host=None, ssh_port=None, ssh_key=None): + def __init__(self, path, file_lister=None, url=None, ssh_user=UNSET_ACTION_KWD, + ssh_host=UNSET_ACTION_KWD, ssh_port=UNSET_ACTION_KWD, ssh_key=UNSET_ACTION_KWD): super(PubkeyAuthenticatedTransferAction, self).__init__(path, file_lister=file_lister) self.url = url self.ssh_user = ssh_user @@ -550,6 +561,8 @@ class BasePathMapper(object): message_template = "action_type %s requires key word argument %s" message = message_template % (action_type, key) raise Exception(message) + else: + action_kwds[key] = value self.action_type = action_type self.action_kwds = action_kwds path_types_str = config.get('path_types', "*defaults*") diff --git a/pulsar/client/client.py b/pulsar/client/client.py index 909095919aae89b5f0e9c411702706adff5f48e3..e41d89f00bf3c17522d76b7148513075bc81af9b 100644 --- a/pulsar/client/client.py +++ b/pulsar/client/client.py @@ -43,7 +43,8 @@ class BaseJobClient(object): else: job_directory = None - self.ssh_key = destination_params.get("ssh_key", None) + for attr in ["ssh_key", "ssh_user", "ssh_host", "ssh_port"]: + setattr(self, attr, destination_params.get(attr, None)) self.env = destination_params.get("env", []) self.files_endpoint = destination_params.get("files_endpoint", None) self.job_directory = job_directory diff --git a/pulsar/client/test/check.py b/pulsar/client/test/check.py index 9f9ed701d755ea9c35883d202dfe3c67c2438d7f..5816f8d7d7ccbb43b0c82dd3cbd3b9b3798a1ce6 100644 --- a/pulsar/client/test/check.py +++ b/pulsar/client/test/check.py @@ -218,7 +218,7 @@ class Waiter(object): self.client_manager.ensure_has_status_update_callback(on_update) - def wait(self, seconds=5): + def wait(self, seconds=15): final_status = None if not self.async: i = 0 @@ -294,6 +294,15 @@ def __client(temp_directory, options): client_options["jobs_directory"] = getattr(options, "jobs_directory") if hasattr(options, "files_endpoint"): client_options["files_endpoint"] = getattr(options, "files_endpoint") + if default_file_action in ["remote_scp_transfer", "remote_rsync_transfer"]: + test_key = os.environ["PULSAR_TEST_KEY"] + if not test_key.startswith("----"): + test_key = open(test_key, "rb").read() + client_options["ssh_key"] = test_key + client_options["ssh_user"] = os.environ.get("USER") + client_options["ssh_port"] = 22 + client_options["ssh_host"] = "localhost" + user = getattr(options, 'user', None) if user: client_options["submit_user"] = user diff --git a/pulsar/client/transport/ssh.py b/pulsar/client/transport/ssh.py index 0dbf4cd144e683e1ee1a9ae5a16866c79eb3b31b..ab480cd7a1d1ce09c669b8fc785968147a384cd0 100644 --- a/pulsar/client/transport/ssh.py +++ b/pulsar/client/transport/ssh.py @@ -1,5 +1,7 @@ import subprocess -SSH_OPTIONS = ('-o', 'StrictHostKeyChecking=no', '-o', 'PreferredAuthentications=publickey', '-o', 'PubkeyAuthentication=yes') +import os + +SSH_OPTIONS = ['-o', 'StrictHostKeyChecking=no', '-o', 'PreferredAuthentications=publickey', '-o', 'PubkeyAuthentication=yes'] def rsync_get_file(uri_from, uri_to, user, host, port, key): @@ -16,6 +18,22 @@ def rsync_get_file(uri_from, uri_to, user, host, port, key): def rsync_post_file(uri_from, uri_to, user, host, port, key): + directory = os.path.dirname(uri_to) + cmd = [ + 'ssh', + '-i', + key, + '-p', + str(port), + ] + SSH_OPTIONS + [ + '%s@%s' % (user, host), + 'mkdir', + '-p', + directory, + ] + exit_code = subprocess.check_call(cmd) + if exit_code != 0: + raise Exception("ssh exited with code %s" % exit_code) cmd = [ 'rsync', '-e', @@ -32,8 +50,8 @@ def scp_get_file(uri_from, uri_to, user, host, port, key): cmd = [ 'scp', '-P', str(port), - '-i', key, - SSH_OPTIONS, + '-i', key + ] + SSH_OPTIONS + [ '%s@%s:%s' % (user, host, uri_from), uri_to, ] @@ -43,11 +61,27 @@ def scp_get_file(uri_from, uri_to, user, host, port, key): def scp_post_file(uri_from, uri_to, user, host, port, key): + directory = os.path.dirname(uri_to) + cmd = [ + 'ssh', + '-i', + key, + '-p', + str(port), + ] + SSH_OPTIONS + [ + '%s@%s' % (user, host), + 'mkdir', + '-p', + directory, + ] + exit_code = subprocess.check_call(cmd) + if exit_code != 0: + raise Exception("ssh exited with code %s" % exit_code) cmd = [ 'scp', '-P', str(port), '-i', key, - SSH_OPTIONS, + ] + SSH_OPTIONS + [ uri_from, '%s@%s:%s' % (user, host, uri_to), ] diff --git a/test/integration_test.py b/test/integration_test.py index 170c82b5b31750fab322f473b4758a41734d2237..c8845c56e348c2327760d4324c92b2273c73768b 100644 --- a/test/integration_test.py +++ b/test/integration_test.py @@ -6,7 +6,8 @@ from .test_utils import ( TempDirectoryTestCase, skip_unless_executable, skip_unless_module, - skip_unless_any_module + skip_unless_any_module, + skip_unless_environ, ) from .test_utils import test_pulsar_app @@ -116,6 +117,28 @@ class IntegrationTests(BaseIntegrationTest): **self.default_kwargs ) + @skip_unless_environ("PULSAR_TEST_KEY") + def test_integration_scp(self): + self._run( + app_conf=dict(message_queue_url="memory://test1"), + private_token=None, + default_file_action="remote_scp_transfer", + local_setup=True, + manager_url="memory://test1", + **self.default_kwargs + ) + + @skip_unless_environ("PULSAR_TEST_KEY") + def test_integration_rsync(self): + self._run( + app_conf=dict(message_queue_url="memory://test1"), + private_token=None, + default_file_action="remote_rsync_transfer", + local_setup=True, + manager_url="memory://test1", + **self.default_kwargs + ) + def test_integration_copy(self): self._run(private_token=None, default_file_action="copy", **self.default_kwargs) diff --git a/test/test_utils.py b/test/test_utils.py index 63cf493aee03e80836c5c0ee240d05989b7dcff4..3edf53bcfb27455511e8b6bed0a16de3cd7693a9 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -250,6 +250,12 @@ def test_pulsar_app(global_conf={}, app_conf={}, test_conf={}): pass +def skip_unless_environ(var): + if var in environ: + return lambda func: func + return skip("Environment variable %s not found, dependent test skipped." % var) + + def skip_unless_executable(executable): if _which(executable): return lambda func: func