From c68b9b3c00b57808d975b4274b7ef9a63eec1073 Mon Sep 17 00:00:00 2001
From: John Chilton <jmchilton@gmail.com>
Date: Mon, 13 May 2019 09:05:01 -0400
Subject: [PATCH] Push path->source abstraction up and into staging/up.py.

---
 pulsar/client/staging/up.py | 54 +++++++++++++++++++++++--------------
 1 file changed, 34 insertions(+), 20 deletions(-)

diff --git a/pulsar/client/staging/up.py b/pulsar/client/staging/up.py
index 119fe719..2207cd02 100644
--- a/pulsar/client/staging/up.py
+++ b/pulsar/client/staging/up.py
@@ -203,11 +203,11 @@ class FileStager(object):
 
     def __upload_tool_files(self):
         for referenced_tool_file in self.referenced_tool_files:
-            self.transfer_tracker.handle_transfer(referenced_tool_file, path_type.TOOL)
+            self.transfer_tracker.handle_transfer_path(referenced_tool_file, path_type.TOOL)
 
     def __upload_arbitrary_files(self):
         for path, name in self.arbitrary_files.items():
-            self.transfer_tracker.handle_transfer(path, path_type.UNSTRUCTURED, name=name)
+            self.transfer_tracker.handle_transfer_path(path, path_type.UNSTRUCTURED, name=name)
 
     def __upload_input_files(self):
         handled_inputs = set()
@@ -231,7 +231,7 @@ class FileStager(object):
     def __upload_input_file(self, input_file):
         if self.__stage_input(input_file):
             if exists(input_file):
-                self.transfer_tracker.handle_transfer(input_file, path_type.INPUT)
+                self.transfer_tracker.handle_transfer_path(input_file, path_type.INPUT)
             else:
                 message = "Pulsar: __upload_input_file called on empty or missing dataset." + \
                           " No such file: [%s]" % input_file
@@ -242,13 +242,13 @@ class FileStager(object):
             for extra_file_name in directory_files(files_path):
                 extra_file_path = join(files_path, extra_file_name)
                 remote_name = self.path_helper.remote_name(relpath(extra_file_path, dirname(files_path)))
-                self.transfer_tracker.handle_transfer(extra_file_path, path_type.INPUT, name=remote_name)
+                self.transfer_tracker.handle_transfer_path(extra_file_path, path_type.INPUT, name=remote_name)
 
     def __upload_input_metadata_file(self, path):
         if self.__stage_input(path):
             # Name must match what is generated in remote_input_path_rewrite in path_mapper.
             remote_name = "metadata_%s" % basename(path)
-            self.transfer_tracker.handle_transfer(path, path_type.INPUT, name=remote_name)
+            self.transfer_tracker.handle_transfer_path(path, path_type.INPUT, name=remote_name)
 
     def __upload_working_directory_files(self):
         # Task manager stages files into working directory, these need to be
@@ -256,13 +256,13 @@ class FileStager(object):
         working_directory_files = self.__working_directory_files()
         for working_directory_file in working_directory_files:
             path = join(self.working_directory, working_directory_file)
-            self.transfer_tracker.handle_transfer(path, path_type.WORKDIR)
+            self.transfer_tracker.handle_transfer_path(path, path_type.WORKDIR)
 
     def __upload_metadata_directory_files(self):
         metadata_directory_files = self.__metadata_directory_files()
         for metadata_directory_file in metadata_directory_files:
             path = join(self.metadata_directory, metadata_directory_file)
-            self.transfer_tracker.handle_transfer(path, path_type.METADATA)
+            self.transfer_tracker.handle_transfer_path(path, path_type.METADATA)
 
     def __working_directory_files(self):
         return self.__list_files(self.working_directory)
@@ -308,7 +308,7 @@ class FileStager(object):
 
     def __upload_rewritten_config_files(self):
         for config_file, new_config_contents in self.job_inputs.config_files.items():
-            self.transfer_tracker.handle_transfer(config_file, type=path_type.CONFIG, contents=new_config_contents)
+            self.transfer_tracker.handle_transfer_path(config_file, type=path_type.CONFIG, contents=new_config_contents)
 
     def get_command_line(self):
         """
@@ -437,20 +437,28 @@ class TransferTracker(object):
         self.file_renames = {}
         self.remote_staging_actions = []
 
-    def handle_transfer(self, path, type, name=None, contents=None):
-        action = self.__action_for_transfer(path, type, contents)
+    def handle_transfer_path(self, path, type, name=None, contents=None):
+        source = {"path": path}
+        return self.handle_transfer_source(source, type, name=name, contents=contents)
+
+    def handle_transfer_source(self, source, type, name=None, contents=None):
+        action = self.__action_for_transfer(source, type, contents)
 
         if action.staging_needed:
             local_action = action.staging_action_local
             if local_action:
+                path = source['path']
                 response = self.client.put_file(path, type, name=name, contents=contents, action_type=action.action_type)
 
                 def get_path():
                     return response['path']
             else:
+                path = source['path']
                 job_directory = self.job_directory
                 assert job_directory, "job directory required for action %s" % action
                 if not name:
+                    # TODO: consider deriving the name from source so an actual input
+                    # path isn't needed. At least the path is only used for its basename here.
                     name = basename(path)
                 self.__add_remote_staging_input(action, name, type)
 
@@ -458,11 +466,11 @@ class TransferTracker(object):
                     return job_directory.calculate_path(name, type)
             register = self.rewrite_paths or type == 'tool'  # Even if inputs not rewritten, tool must be.
             if register:
-                self.register_rewrite(path, get_path(), type, force=True)
+                self.register_rewrite_action(action, get_path(), force=True)
         elif self.rewrite_paths:
             path_rewrite = action.path_rewrite(self.path_helper)
             if path_rewrite:
-                self.register_rewrite(path, path_rewrite, type, force=True)
+                self.register_rewrite_action(action, path_rewrite, force=True)
 
         # else: # No action for this file
 
@@ -474,23 +482,29 @@ class TransferTracker(object):
         )
         self.remote_staging_actions.append(input_dict)
 
-    def __action_for_transfer(self, path, type, contents):
+    def __action_for_transfer(self, source, type, contents):
         if contents:
             # If contents loaded in memory, no need to write out file and copy,
             # just transfer.
             action = MessageAction(contents=contents, client=self.client)
         else:
-            if not exists(path):
-                message = "handle_transfer called on non-existent file - [%s]" % path
+            path = source.get("path")
+            if path is not None and not exists(path):
+                message = "handle_transfer_path called on non-existent file - [%s]" % path
                 log.warn(message)
                 raise Exception(message)
-            action = self.__action(path, type)
+            action = self.__action(source, type)
         return action
 
     def register_rewrite(self, local_path, remote_path, type, force=False):
-        action = self.__action(local_path, type)
+        action = self.__action({"path": local_path}, type)
+        self.register_rewrite_action(action, remote_path, force=force)
+
+    def register_rewrite_action(self, action, remote_path, force=False):
         if action.staging_needed or force:
-            self.file_renames[local_path] = remote_path
+            path = getattr(action, 'path', None)
+            if path:
+                self.file_renames[path] = remote_path
 
     def rewrite_input_paths(self):
         """
@@ -500,8 +514,8 @@ class TransferTracker(object):
         for local_path, remote_path in self.file_renames.items():
             self.job_inputs.rewrite_paths(local_path, remote_path)
 
-    def __action(self, path, type):
-        return self.action_mapper.action({"path": path}, type)
+    def __action(self, source, type):
+        return self.action_mapper.action(source, type)
 
 
 def _read(path):
-- 
GitLab