]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/blobdiff - src/allmydata/scripts/tahoe_cp.py
cp: trailing slash on source filename is an error, just like on targets
[tahoe-lafs/tahoe-lafs.git] / src / allmydata / scripts / tahoe_cp.py
index a5cfe16a1a25b7b65e9ca9d9856256457b7da9a4..7c3026218563b7ca085bcd8cb8c3f8e883048ec4 100644 (file)
@@ -20,6 +20,10 @@ class MissingSourceError(TahoeError):
     def __init__(self, name, quotefn=quote_output):
         TahoeError.__init__(self, "No such file or directory %s" % quotefn(name))
 
+class FilenameWithTrailingSlashError(TahoeError):
+    def __init__(self, name, quotefn=quote_output):
+        TahoeError.__init__(self, "source '%s' is not a directory, but ends with a slash" % quotefn(name))
+
 
 def GET_to_file(url):
     resp = do_http("GET", url)
@@ -501,7 +505,11 @@ class Copier:
 
         sources = [] # list of source objects
         for ss in source_specs:
-            si = self.get_source_info(ss)
+            try:
+                si = self.get_source_info(ss)
+            except FilenameWithTrailingSlashError:
+                self.to_stderr("source is not a directory, but ends with a slash")
+                return 1
             precondition(isinstance(si, FileSources + DirectorySources), si)
             sources.append(si)
 
@@ -622,6 +630,9 @@ class Copier:
         precondition(isinstance(source_spec, unicode), source_spec)
         rootcap, path_utf8 = get_alias(self.aliases, source_spec, None)
         path = path_utf8.decode("utf-8")
+        # any trailing slash is removed in abspath_expanduser_unicode(), so
+        # make a note of it here, to throw an error later
+        had_trailing_slash = path.endswith("/")
         if rootcap == DefaultAliasMarker:
             # no alias, so this is a local file
             pathname = abspath_expanduser_unicode(path)
@@ -631,6 +642,8 @@ class Copier:
             if os.path.isdir(pathname):
                 t = LocalDirectorySource(self.progress, pathname, name)
             else:
+                if had_trailing_slash:
+                    raise FilenameWithTrailingSlashError(source_spec)
                 assert os.path.isfile(pathname)
                 t = LocalFileSource(pathname, name) # non-empty
         else:
@@ -659,6 +672,8 @@ class Copier:
                                          self.progress, name)
                 t.init_from_parsed(parsed)
             else:
+                if had_trailing_slash:
+                    raise FilenameWithTrailingSlashError(source_spec)
                 writecap = to_str(d.get("rw_uri"))
                 readcap = to_str(d.get("ro_uri"))
                 mutable = d.get("mutable", False) # older nodes don't provide it