]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/blobdiff - src/allmydata/scripts/tahoe_cp.py
cp: error on target-filename collisions, rather than overwrite
[tahoe-lafs/tahoe-lafs.git] / src / allmydata / scripts / tahoe_cp.py
index 850a1b37df0470bdd1538016305e1daa8cf6c1c1..351b6a331717db90bda955fed187998cd8f2a4e1 100644 (file)
@@ -20,6 +20,14 @@ class MissingSourceError(TahoeError):
     def __init__(self, name, quotefn=quote_output):
         TahoeError.__init__(self, "No such file or directory %s" % quotefn(name))
 
+class FilenameWithTrailingSlashError(TahoeError):
+    def __init__(self, name, quotefn=quote_output):
+        TahoeError.__init__(self, "source '%s' is not a directory, but ends with a slash" % quotefn(name))
+
+class WeirdSourceError(TahoeError):
+    def __init__(self, absname):
+        quoted = quote_local_unicode_path(absname)
+        TahoeError.__init__(self, "source '%s' is neither a file nor a directory, I can't handle it" % quoted)
 
 def GET_to_file(url):
     resp = do_http("GET", url)
@@ -116,7 +124,7 @@ class LocalDirectorySource:
                 child = LocalDirectorySource(self.progressfunc, pn, n)
                 self.children[n] = child
                 if recurse:
-                    child.populate(True)
+                    child.populate(recurse=True)
             elif os.path.isfile(pn):
                 self.children[n] = LocalFileSource(pn, n)
             else:
@@ -144,15 +152,16 @@ class LocalDirectoryTarget:
                 child = LocalDirectoryTarget(self.progressfunc, pn)
                 self.children[n] = child
                 if recurse:
-                    child.populate(True)
+                    child.populate(recurse=True)
             else:
                 assert os.path.isfile(pn)
                 self.children[n] = LocalFileTarget(pn)
 
     def get_child_target(self, name):
         precondition(isinstance(name, unicode), name)
+        precondition(len(name), name) # don't want ""
         if self.children is None:
-            self.populate(False)
+            self.populate(recurse=False)
         if name in self.children:
             return self.children[name]
         pathname = os.path.join(self.pathname, name)
@@ -280,7 +289,7 @@ class TahoeDirectorySource:
                     if readcap:
                         self.cache[readcap] = child
                     if recurse:
-                        child.populate(True)
+                        child.populate(recurse=True)
                 self.children[name] = child
             else:
                 # TODO: there should be an option to skip unknown nodes.
@@ -380,7 +389,7 @@ class TahoeDirectoryTarget:
                     if readcap:
                         self.cache[readcap] = child
                     if recurse:
-                        child.populate(True)
+                        child.populate(recurse=True)
                 self.children[name] = child
             else:
                 # TODO: there should be an option to skip unknown nodes.
@@ -392,7 +401,7 @@ class TahoeDirectoryTarget:
         # return a new target for a named subdirectory of this dir
         precondition(isinstance(name, unicode), name)
         if self.children is None:
-            self.populate(False)
+            self.populate(recurse=False)
         if name in self.children:
             return self.children[name]
         writecap = make_tahoe_subdirectory(self.nodeurl, self.writecap, name)
@@ -409,7 +418,7 @@ class TahoeDirectoryTarget:
             inf = inf.read()
 
         if self.children is None:
-            self.populate(False)
+            self.populate(recurse=False)
 
         # Check to see if we already have a mutable file by this name.
         # If so, overwrite that file in place.
@@ -487,8 +496,9 @@ class Copier:
 
     def try_copy(self):
         """
-        All usage errors are caught here, not in a subroutine. This bottoms
-        out in copy_file_to_file() or copy_things_to_directory().
+        All usage errors (except for target filename collisions) are caught
+        here, not in a subroutine. This bottoms out in copy_file_to_file() or
+        copy_things_to_directory().
         """
         source_specs = self.options.sources
         destination_spec = self.options.destination
@@ -500,7 +510,11 @@ class Copier:
 
         sources = [] # list of source objects
         for ss in source_specs:
-            si = self.get_source_info(ss)
+            try:
+                si = self.get_source_info(ss)
+            except FilenameWithTrailingSlashError as e:
+                self.to_stderr(str(e))
+                return 1
             precondition(isinstance(si, FileSources + DirectorySources), si)
             sources.append(si)
 
@@ -531,7 +545,7 @@ class Copier:
                 target_is_file = False
 
         if target_is_file and target_has_trailing_slash:
-            self.to_stderr("target is not a directory, but has a slash")
+            self.to_stderr("target is not a directory, but ends with a slash")
             return 1
 
         if len(sources) > 1 and target_is_file:
@@ -559,7 +573,7 @@ class Copier:
         _assert(isinstance(target, DirectoryTargets + MissingTargets), target)
 
         for source in sources:
-            if isinstance(source, FileSources) and not source.basename():
+            if isinstance(source, FileSources) and source.basename() is None:
                 self.to_stderr("when copying into a directory, all source files must have names, but %s is unnamed" % quote_output(source_specs[0]))
                 return 1
         return self.copy_things_to_directory(sources, target)
@@ -621,6 +635,9 @@ class Copier:
         precondition(isinstance(source_spec, unicode), source_spec)
         rootcap, path_utf8 = get_alias(self.aliases, source_spec, None)
         path = path_utf8.decode("utf-8")
+        # any trailing slash is removed in abspath_expanduser_unicode(), so
+        # make a note of it here, to throw an error later
+        had_trailing_slash = path.endswith("/")
         if rootcap == DefaultAliasMarker:
             # no alias, so this is a local file
             pathname = abspath_expanduser_unicode(path)
@@ -630,13 +647,19 @@ class Copier:
             if os.path.isdir(pathname):
                 t = LocalDirectorySource(self.progress, pathname, name)
             else:
-                assert os.path.isfile(pathname)
+                if had_trailing_slash:
+                    raise FilenameWithTrailingSlashError(source_spec,
+                                                         quotefn=quote_local_unicode_path)
+                if not os.path.isfile(pathname):
+                    raise WeirdSourceError(pathname)
                 t = LocalFileSource(pathname, name) # non-empty
         else:
             # this is a tahoe object
             url = self.nodeurl + "uri/%s" % urllib.quote(rootcap)
             name = None
             if path:
+                if path.endswith("/"):
+                    path = path[:-1]
                 url += "/" + escape_path(path)
                 last_slash = path.rfind(u"/")
                 name = path
@@ -656,16 +679,11 @@ class Copier:
                                          self.progress, name)
                 t.init_from_parsed(parsed)
             else:
+                if had_trailing_slash:
+                    raise FilenameWithTrailingSlashError(source_spec)
                 writecap = to_str(d.get("rw_uri"))
                 readcap = to_str(d.get("ro_uri"))
                 mutable = d.get("mutable", False) # older nodes don't provide it
-
-                last_slash = source_spec.rfind(u"/")
-                if last_slash != -1:
-                    # TODO: this looks funny and redundant with the 'name'
-                    # assignment above. cf #2329
-                    name = source_spec[last_slash+1:]
-
                 t = TahoeFileSource(self.nodeurl, mutable, writecap, readcap, name)
         return t
 
@@ -702,12 +720,12 @@ class Copier:
     def copy_things_to_directory(self, sources, target):
         # step one: if the target is missing, we should mkdir it
         target = self.maybe_create_target(target)
-        target.populate(False)
+        target.populate(recurse=False)
 
         # step two: scan any source dirs, recursively, to find children
         for s in sources:
             if isinstance(s, DirectorySources):
-                s.populate(True)
+                s.populate(recurse=True)
             if isinstance(s, FileSources):
                 # each source must have a name, or be a directory
                 _assert(s.basename() is not None, s)
@@ -718,6 +736,23 @@ class Copier:
         # sourceobject) dicts for all the files that need to wind up there.
         targetmap = self.build_targetmap(sources, target)
 
+        # target name collisions are an error
+        collisions = []
+        for target, sources in targetmap.items():
+            target_names = {}
+            for source in sources:
+                name = source.basename()
+                if name in target_names:
+                    collisions.append((target, source, target_names[name]))
+                else:
+                    target_names[name] = source
+        if collisions:
+            self.to_stderr("cannot copy multiple files with the same name into the same target directory")
+            # I'm not sure how to show where the collisions are coming from
+            #for (target, source1, source2) in collisions:
+            #    self.to_stderr(source1.basename())
+            return 1
+
         # step four: walk through the list of targets. For each one, copy all
         # the files. If the target is a TahoeDirectory, upload and create
         # read-caps, then do a set_children to the target directory.
@@ -778,7 +813,7 @@ class Copier:
                 subtarget = target.get_child_target(name)
                 self.assign_targets(targetmap, child, subtarget)
             else:
-                precondition(isinstance(child, FileSources), child)
+                _assert(isinstance(child, FileSources), child)
                 targetmap[target].append(child)
 
     def copy_to_targetmap(self, targetmap):
@@ -789,9 +824,9 @@ class Copier:
         targets_finished = 0
 
         for target, sources in targetmap.items():
-            precondition(isinstance(target, DirectoryTargets), target)
+            _assert(isinstance(target, DirectoryTargets), target)
             for source in sources:
-                precondition(isinstance(source, FileSources), source)
+                _assert(isinstance(source, FileSources), source)
                 self.copy_file_into_dir(source, source.basename(), target)
                 files_copied += 1
                 self.progress("%d/%d files, %d/%d directories" %
@@ -803,8 +838,7 @@ class Copier:
                           (targets_finished, len(targetmap)))
 
     def count_files_to_copy(self, targetmap):
-        files_to_copy = sum([len(sources) for sources in targetmap.values()])
-        return files_to_copy
+        return sum([len(sources) for sources in targetmap.values()])
 
     def copy_file_into_dir(self, source, name, target):
         precondition(isinstance(source, FileSources), source)