From 7fc2b10ab53b60ab5a12151955e4c1c0287e7a80 Mon Sep 17 00:00:00 2001
From: John Chilton <jmchilton@gmail.com>
Date: Sat, 7 Mar 2015 00:01:46 -0500
Subject: [PATCH] Integration tests and fixes for rsync/scp transfer actions.

 - Rearrange things so ssh_key can alternatively be defined in file_actions.yml, and the other properties (ssh_user, ssh_host, ssh_port) can alternatively be defined in destination_params. This allows these actions to be used as the default file action and makes them much easier to test.
 - Fix construction of the scp command line (my version of Python didn't like a tuple in the middle of the argument list passed to subprocess.check_call).
 - Increase integration test timeout (rsync/scp transfers are slower than localhost RESTful transfers).
 - Add integration test for each of these commands.
---
 pulsar/client/action_mapper.py | 45 ++++++++++++++++++++++------------
 pulsar/client/client.py        |  3 ++-
 pulsar/client/test/check.py    | 11 ++++++++-
 pulsar/client/transport/ssh.py | 42 ++++++++++++++++++++++++++++---
 test/integration_test.py       | 25 ++++++++++++++++++-
 test/test_utils.py             |  6 +++++
 6 files changed, 109 insertions(+), 23 deletions(-)

diff --git a/pulsar/client/action_mapper.py b/pulsar/client/action_mapper.py
index 11aa1b62..8b0374ae 100644
--- a/pulsar/client/action_mapper.py
+++ b/pulsar/client/action_mapper.py
@@ -138,6 +138,9 @@ class FileActionMapper(object):
             config = self.__client_to_config(client)
         self.default_action = config.get("default_action", "transfer")
         self.ssh_key = config.get("ssh_key", None)
+        self.ssh_user = config.get("ssh_user", None)
+        self.ssh_host = config.get("ssh_host", None)
+        self.ssh_port = config.get("ssh_port", None)
         self.mappers = mappers_from_dicts(config.get("paths", []))
         self.files_endpoint = config.get("files_endpoint", None)
 
@@ -164,6 +167,9 @@ class FileActionMapper(object):
             default_action=self.default_action,
             files_endpoint=self.files_endpoint,
             ssh_key=self.ssh_key,
+            ssh_user=self.ssh_user,
+            ssh_port=self.ssh_port,
+            ssh_host=self.ssh_host,
             paths=map(lambda m: m.to_dict(), self.mappers)
         )
 
@@ -175,8 +181,9 @@ class FileActionMapper(object):
             config = dict()
         config["default_action"] = client.default_file_action
         config["files_endpoint"] = client.files_endpoint
-        if hasattr(client, 'ssh_key'):
-            config["ssh_key"] = client.ssh_key
+        for attr in ['ssh_key', 'ssh_user', 'ssh_port', 'ssh_host']:
+            if hasattr(client, attr):
+                config[attr] = getattr(client, attr)
         return config
 
     def __load_action_config(self, path):
@@ -213,8 +220,8 @@ class FileActionMapper(object):
         """
         if getattr(action, "inject_url", False):
             self.__inject_url(action, file_type)
-        if getattr(action, "inject_ssh_key", False):
-            self.__inject_ssh_key(action)
+        if getattr(action, "inject_ssh_properties", False):
+            self.__inject_ssh_properties(action)
 
     def __inject_url(self, action, file_type):
         url_base = self.files_endpoint
@@ -226,15 +233,19 @@ class FileActionMapper(object):
         url = "%s&path=%s&file_type=%s" % (url_base, action.path, file_type)
         action.url = url
 
-    def __inject_ssh_key(self, action):
-        # Required, so no check for presence
-        ssh_key = self.ssh_key
-        if ssh_key is None:
+    def __inject_ssh_properties(self, action):
+        for attr in ["ssh_key", "ssh_host", "ssh_port", "ssh_user"]:
+            action_attr = getattr(action, attr)
+            if action_attr == UNSET_ACTION_KWD:
+                client_default_attr = getattr(self, attr, None)
+                setattr(action, attr, client_default_attr)
+
+        if action.ssh_key is None:
             raise Exception(MISSING_SSH_KEY_ERROR)
-        action.ssh_key = ssh_key
 
 
 REQUIRED_ACTION_KWD = object()
+UNSET_ACTION_KWD = "__UNSET__"
 
 
 class BaseAction(object):
@@ -403,17 +414,17 @@ class RemoteTransferAction(BaseAction):
 class PubkeyAuthenticatedTransferAction(BaseAction):
     """Base class for file transfers requiring an SSH public/private key
     """
-    inject_ssh_key = True
+    inject_ssh_properties = True
     action_spec = dict(
-        ssh_user=REQUIRED_ACTION_KWD,
-        ssh_host=REQUIRED_ACTION_KWD,
-        ssh_port=REQUIRED_ACTION_KWD,
+        ssh_key=UNSET_ACTION_KWD,
+        ssh_user=UNSET_ACTION_KWD,
+        ssh_host=UNSET_ACTION_KWD,
+        ssh_port=UNSET_ACTION_KWD,
     )
     staging = STAGING_ACTION_REMOTE
-    ssh_key = None
 
-    def __init__(self, path, file_lister=None, url=None, ssh_user=None,
-                 ssh_host=None, ssh_port=None, ssh_key=None):
+    def __init__(self, path, file_lister=None, url=None, ssh_user=UNSET_ACTION_KWD,
+                 ssh_host=UNSET_ACTION_KWD, ssh_port=UNSET_ACTION_KWD, ssh_key=UNSET_ACTION_KWD):
         super(PubkeyAuthenticatedTransferAction, self).__init__(path, file_lister=file_lister)
         self.url = url
         self.ssh_user = ssh_user
@@ -550,6 +561,8 @@ class BasePathMapper(object):
                 message_template = "action_type %s requires key word argument %s"
                 message = message_template % (action_type, key)
                 raise Exception(message)
+            else:
+                action_kwds[key] = value
         self.action_type = action_type
         self.action_kwds = action_kwds
         path_types_str = config.get('path_types', "*defaults*")
diff --git a/pulsar/client/client.py b/pulsar/client/client.py
index 90909591..e41d89f0 100644
--- a/pulsar/client/client.py
+++ b/pulsar/client/client.py
@@ -43,7 +43,8 @@ class BaseJobClient(object):
         else:
             job_directory = None
 
-        self.ssh_key = destination_params.get("ssh_key", None)
+        for attr in ["ssh_key", "ssh_user", "ssh_host", "ssh_port"]:
+            setattr(self, attr, destination_params.get(attr, None))
         self.env = destination_params.get("env", [])
         self.files_endpoint = destination_params.get("files_endpoint", None)
         self.job_directory = job_directory
diff --git a/pulsar/client/test/check.py b/pulsar/client/test/check.py
index 9f9ed701..5816f8d7 100644
--- a/pulsar/client/test/check.py
+++ b/pulsar/client/test/check.py
@@ -218,7 +218,7 @@ class Waiter(object):
 
             self.client_manager.ensure_has_status_update_callback(on_update)
 
-    def wait(self, seconds=5):
+    def wait(self, seconds=15):
         final_status = None
         if not self.async:
             i = 0
@@ -294,6 +294,15 @@ def __client(temp_directory, options):
         client_options["jobs_directory"] = getattr(options, "jobs_directory")
     if hasattr(options, "files_endpoint"):
         client_options["files_endpoint"] = getattr(options, "files_endpoint")
+    if default_file_action in ["remote_scp_transfer", "remote_rsync_transfer"]:
+        test_key = os.environ["PULSAR_TEST_KEY"]
+        if not test_key.startswith("----"):
+            test_key = open(test_key, "rb").read()
+        client_options["ssh_key"] = test_key
+        client_options["ssh_user"] = os.environ.get("USER")
+        client_options["ssh_port"] = 22
+        client_options["ssh_host"] = "localhost"
+
     user = getattr(options, 'user', None)
     if user:
         client_options["submit_user"] = user
diff --git a/pulsar/client/transport/ssh.py b/pulsar/client/transport/ssh.py
index 0dbf4cd1..ab480cd7 100644
--- a/pulsar/client/transport/ssh.py
+++ b/pulsar/client/transport/ssh.py
@@ -1,5 +1,7 @@
 import subprocess
-SSH_OPTIONS = ('-o', 'StrictHostKeyChecking=no', '-o', 'PreferredAuthentications=publickey', '-o', 'PubkeyAuthentication=yes')
+import os
+
+SSH_OPTIONS = ['-o', 'StrictHostKeyChecking=no', '-o', 'PreferredAuthentications=publickey', '-o', 'PubkeyAuthentication=yes']
 
 
 def rsync_get_file(uri_from, uri_to, user, host, port, key):
@@ -16,6 +18,22 @@ def rsync_get_file(uri_from, uri_to, user, host, port, key):
 
 
 def rsync_post_file(uri_from, uri_to, user, host, port, key):
+    directory = os.path.dirname(uri_to)
+    cmd = [
+        'ssh',
+        '-i',
+        key,
+        '-p',
+        str(port),
+    ] + SSH_OPTIONS + [
+        '%s@%s' % (user, host),
+        'mkdir',
+        '-p',
+        directory,
+    ]
+    exit_code = subprocess.check_call(cmd)
+    if exit_code != 0:
+        raise Exception("ssh exited with code %s" % exit_code)
     cmd = [
         'rsync',
         '-e',
@@ -32,8 +50,8 @@ def scp_get_file(uri_from, uri_to, user, host, port, key):
     cmd = [
         'scp',
         '-P', str(port),
-        '-i', key,
-        SSH_OPTIONS,
+        '-i', key
+    ] + SSH_OPTIONS + [
         '%s@%s:%s' % (user, host, uri_from),
         uri_to,
     ]
@@ -43,11 +61,27 @@ def scp_get_file(uri_from, uri_to, user, host, port, key):
 
 
 def scp_post_file(uri_from, uri_to, user, host, port, key):
+    directory = os.path.dirname(uri_to)
+    cmd = [
+        'ssh',
+        '-i',
+        key,
+        '-p',
+        str(port),
+    ] + SSH_OPTIONS + [
+        '%s@%s' % (user, host),
+        'mkdir',
+        '-p',
+        directory,
+    ]
+    exit_code = subprocess.check_call(cmd)
+    if exit_code != 0:
+        raise Exception("ssh exited with code %s" % exit_code)
     cmd = [
         'scp',
         '-P', str(port),
         '-i', key,
-        SSH_OPTIONS,
+    ] + SSH_OPTIONS + [
         uri_from,
         '%s@%s:%s' % (user, host, uri_to),
     ]
diff --git a/test/integration_test.py b/test/integration_test.py
index 170c82b5..c8845c56 100644
--- a/test/integration_test.py
+++ b/test/integration_test.py
@@ -6,7 +6,8 @@ from .test_utils import (
     TempDirectoryTestCase,
     skip_unless_executable,
     skip_unless_module,
-    skip_unless_any_module
+    skip_unless_any_module,
+    skip_unless_environ,
 )
 
 from .test_utils import test_pulsar_app
@@ -116,6 +117,28 @@ class IntegrationTests(BaseIntegrationTest):
             **self.default_kwargs
         )
 
+    @skip_unless_environ("PULSAR_TEST_KEY")
+    def test_integration_scp(self):
+        self._run(
+            app_conf=dict(message_queue_url="memory://test1"),
+            private_token=None,
+            default_file_action="remote_scp_transfer",
+            local_setup=True,
+            manager_url="memory://test1",
+            **self.default_kwargs
+        )
+
+    @skip_unless_environ("PULSAR_TEST_KEY")
+    def test_integration_rsync(self):
+        self._run(
+            app_conf=dict(message_queue_url="memory://test1"),
+            private_token=None,
+            default_file_action="remote_rsync_transfer",
+            local_setup=True,
+            manager_url="memory://test1",
+            **self.default_kwargs
+        )
+
     def test_integration_copy(self):
         self._run(private_token=None, default_file_action="copy", **self.default_kwargs)
 
diff --git a/test/test_utils.py b/test/test_utils.py
index 63cf493a..3edf53bc 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -250,6 +250,12 @@ def test_pulsar_app(global_conf={}, app_conf={}, test_conf={}):
                 pass
 
 
+def skip_unless_environ(var):
+    if var in environ:
+        return lambda func: func
+    return skip("Environment variable %s not found, dependent test skipped." % var)
+
+
 def skip_unless_executable(executable):
     if _which(executable):
         return lambda func: func
-- 
GitLab