From a6ca98ac533621bc094abafaf33665a3f7c4b3b2 Mon Sep 17 00:00:00 2001
From: Brian Warner <warner@allmydata.com>
Date: Mon, 14 Jan 2008 21:22:55 -0700
Subject: [PATCH] upload: add Encoder.abort(), to abandon the upload in
 progress. Add some debug hooks to enable unit tests.

---
 src/allmydata/encode.py            | 34 +++++++++++++++++++++++++++---
 src/allmydata/interfaces.py        |  5 +++++
 src/allmydata/storage.py           | 32 ++++++++++++++++++++++++++--
 src/allmydata/test/test_encode.py  |  5 ++++-
 src/allmydata/test/test_storage.py | 23 ++++++++++++++------
 src/allmydata/upload.py            | 32 ++++++++++++++++++++++++++--
 6 files changed, 116 insertions(+), 15 deletions(-)

diff --git a/src/allmydata/encode.py b/src/allmydata/encode.py
index 2467aaf6..fb35bd72 100644
--- a/src/allmydata/encode.py
+++ b/src/allmydata/encode.py
@@ -90,6 +90,7 @@ class Encoder(object):
         self._parent = parent
         if self._parent:
             self._log_number = self._parent.log("creating Encoder %s" % self)
+        self._aborted = False
 
     def __repr__(self):
         if hasattr(self, "_storage_index"):
@@ -263,6 +264,15 @@ class Encoder(object):
         d.addCallbacks(lambda res: self.done(), self.err)
         return d
 
+    def abort(self):
+        self.log("aborting upload")
+        assert self._codec, "don't call abort before start"
+        self._aborted = True
+        # the next segment read (in _gather_data inside _encode_segment) will
+        # raise UploadAborted(), which will bypass the rest of the upload
+        # chain. If we've sent the final segment's shares, it's too late to
+        # abort. TODO: allow abort any time up to close_all_shareholders.
+
     def _turn_barrier(self, res):
         # putting this method in a Deferred chain imposes a guaranteed
         # reactor turn between the pre- and post- portions of that chain.
@@ -341,11 +351,16 @@ class Encoder(object):
         with the combination of any 'previous_chunks' and the new chunks
         which were gathered."""
 
+        if self._aborted:
+            raise UploadAborted()
+
         if not num_chunks:
             return defer.succeed(previous_chunks)
 
         d = self._uploadable.read_encrypted(input_chunk_size)
         def _got(data):
+            if self._aborted:
+                raise UploadAborted()
             encrypted_pieces = []
             length = 0
             while data:
@@ -595,6 +610,19 @@ class Encoder(object):
 
     def err(self, f):
         self.log("UNUSUAL: %s: upload failed: %s" % (self, f))
-        if f.check(defer.FirstError):
-            return f.value.subFailure
-        return f
+        # we need to abort any remaining shareholders, so they'll delete the
+        # partial share, allowing someone else to upload it again.
+        self.log("aborting shareholders")
+        dl = []
+        for shareid in list(self.landlords.keys()):
+            d = self.landlords[shareid].abort()
+            d.addErrback(self._remove_shareholder, shareid, "abort")
+            dl.append(d)
+        d = self._gather_responses(dl)
+        def _done(res):
+            self.log("shareholders aborted")
+            if f.check(defer.FirstError):
+                return f.value.subFailure
+            return f
+        d.addCallback(_done)
+        return d
diff --git a/src/allmydata/interfaces.py b/src/allmydata/interfaces.py
index 1ffad5c5..bbc1c42c 100644
--- a/src/allmydata/interfaces.py
+++ b/src/allmydata/interfaces.py
@@ -78,6 +78,11 @@ class RIBucketWriter(RemoteInterface):
         """
         return None
 
+    def abort():
+        """Abandon all the data that has been written.
+        """
+        return None
+
 class RIBucketReader(RemoteInterface):
     def read(offset=int, length=int):
         return ShareData
diff --git a/src/allmydata/storage.py b/src/allmydata/storage.py
index 758d8bee..86c58864 100644
--- a/src/allmydata/storage.py
+++ b/src/allmydata/storage.py
@@ -164,11 +164,13 @@ class ShareFile:
 class BucketWriter(Referenceable):
     implements(RIBucketWriter)
 
-    def __init__(self, ss, incominghome, finalhome, size, lease_info):
+    def __init__(self, ss, incominghome, finalhome, size, lease_info, canary):
         self.ss = ss
         self.incominghome = incominghome
         self.finalhome = finalhome
         self._size = size
+        self._canary = canary
+        self._disconnect_marker = canary.notifyOnDisconnect(self._disconnected)
         self.closed = False
         self.throw_out_all_data = False
         # touch the file, so later callers will see that we're working on it.
@@ -196,6 +198,7 @@ class BucketWriter(Referenceable):
         fileutil.rename(self.incominghome, self.finalhome)
         self._sharefile = None
         self.closed = True
+        self._canary.dontNotifyOnDisconnect(self._disconnect_marker)
 
         filelen = os.stat(self.finalhome)[stat.ST_SIZE]
         self.ss.bucket_writer_closed(self, filelen)
@@ -206,6 +209,28 @@ class BucketWriter(Referenceable):
         if not os.listdir(parentdir):
             os.rmdir(parentdir)
 
+    def _disconnected(self):
+        if not self.closed:
+            self._abort()
+
+    def remote_abort(self):
+        log.msg("storage: aborting sharefile %s" % self.incominghome,
+                facility="tahoe.storage", level=log.UNUSUAL)
+        if not self.closed:
+            self._canary.dontNotifyOnDisconnect(self._disconnect_marker)
+        self._abort()
+
+    def _abort(self):
+        if self.closed:
+            return
+        os.remove(self.incominghome)
+        # if ours was the last share left in it, remove the incoming/
+        # directory that was our parent
+        parentdir = os.path.split(self.incominghome)[0]
+        if not os.listdir(parentdir):
+            os.rmdir(parentdir)
+
+
 
 class BucketReader(Referenceable):
     implements(RIBucketReader)
@@ -721,7 +746,7 @@ class StorageServer(service.MultiService, Referenceable):
                 # ok! we need to create the new share file.
                 fileutil.make_dirs(os.path.join(self.incomingdir, si_s))
                 bw = BucketWriter(self, incominghome, finalhome,
-                                  space_per_bucket, lease_info)
+                                  space_per_bucket, lease_info, canary)
                 if self.no_storage:
                     bw.throw_out_all_data = True
                 bucketwriters[shnum] = bw
@@ -1110,6 +1135,9 @@ class WriteBucketProxy:
     def close(self):
         return self._rref.callRemote("close")
 
+    def abort(self):
+        return self._rref.callRemote("abort")
+
 class ReadBucketProxy:
     implements(IStorageBucketReader)
     def __init__(self, rref):
diff --git a/src/allmydata/test/test_encode.py b/src/allmydata/test/test_encode.py
index 9562f7c7..db514d2c 100644
--- a/src/allmydata/test/test_encode.py
+++ b/src/allmydata/test/test_encode.py
@@ -84,6 +84,9 @@ class FakeBucketWriterProxy:
             self.closed = True
         return defer.maybeDeferred(_try)
 
+    def abort(self):
+        return defer.succeed(None)
+
     def get_block(self, blocknum):
         def _try():
             assert isinstance(blocknum, (int, long))
@@ -621,7 +624,7 @@ class Roundtrip(unittest.TestCase):
         d = self.send_and_recover((4,8,10), bucket_modes=modemap)
         def _done(res):
             self.failUnless(isinstance(res, Failure))
-            self.failUnless(res.check(encode.NotEnoughPeersError))
+            self.failUnless(res.check(encode.NotEnoughPeersError), res)
         d.addBoth(_done)
         return d
 
diff --git a/src/allmydata/test/test_storage.py b/src/allmydata/test/test_storage.py
index 4538261c..97762003 100644
--- a/src/allmydata/test/test_storage.py
+++ b/src/allmydata/test/test_storage.py
@@ -12,6 +12,12 @@ from allmydata.storage import BucketWriter, BucketReader, \
 from allmydata.interfaces import BadWriteEnablerError
 from allmydata.test.common import LoggingServiceParent
 
+class FakeCanary:
+    def notifyOnDisconnect(self, *args, **kwargs):
+        pass
+    def dontNotifyOnDisconnect(self, marker):
+        pass
+
 class Bucket(unittest.TestCase):
     def make_workdir(self, name):
         basedir = os.path.join("storage", "Bucket", name)
@@ -33,7 +39,8 @@ class Bucket(unittest.TestCase):
 
     def test_create(self):
         incoming, final = self.make_workdir("test_create")
-        bw = BucketWriter(self, incoming, final, 200, self.make_lease())
+        bw = BucketWriter(self, incoming, final, 200, self.make_lease(),
+                          FakeCanary())
         bw.remote_write(0, "a"*25)
         bw.remote_write(25, "b"*25)
         bw.remote_write(50, "c"*25)
@@ -42,7 +49,8 @@ class Bucket(unittest.TestCase):
 
     def test_readwrite(self):
         incoming, final = self.make_workdir("test_readwrite")
-        bw = BucketWriter(self, incoming, final, 200, self.make_lease())
+        bw = BucketWriter(self, incoming, final, 200, self.make_lease(),
+                          FakeCanary())
         bw.remote_write(0, "a"*25)
         bw.remote_write(25, "b"*25)
         bw.remote_write(50, "c"*7) # last block may be short
@@ -69,7 +77,8 @@ class BucketProxy(unittest.TestCase):
         final = os.path.join(basedir, "bucket")
         fileutil.make_dirs(basedir)
         fileutil.make_dirs(os.path.join(basedir, "tmp"))
-        bw = BucketWriter(self, incoming, final, size, self.make_lease())
+        bw = BucketWriter(self, incoming, final, size, self.make_lease(),
+                          FakeCanary())
         rb = RemoteBucket()
         rb.target = bw
         return bw, rb, final
@@ -201,7 +210,7 @@ class Server(unittest.TestCase):
         cancel_secret = hashutil.tagged_hash("blah", "%d" % self._lease_secret.next())
         return ss.remote_allocate_buckets(storage_index,
                                           renew_secret, cancel_secret,
-                                          sharenums, size, Referenceable())
+                                          sharenums, size, FakeCanary())
 
     def test_remove_incoming(self):
         ss = self.create("test_remove_incoming")
@@ -219,7 +228,7 @@ class Server(unittest.TestCase):
 
         self.failUnlessEqual(ss.remote_get_buckets("vid"), {})
 
-        canary = Referenceable()
+        canary = FakeCanary()
         already,writers = self.allocate(ss, "vid", [0,1,2], 75)
         self.failUnlessEqual(already, set())
         self.failUnlessEqual(set(writers.keys()), set([0,1,2]))
@@ -253,7 +262,7 @@ class Server(unittest.TestCase):
 
     def test_sizelimits(self):
         ss = self.create("test_sizelimits", 5000)
-        canary = Referenceable()
+        canary = FakeCanary()
         # a newly created and filled share incurs this much overhead, beyond
         # the size we request.
         OVERHEAD = 3*4
@@ -336,7 +345,7 @@ class Server(unittest.TestCase):
 
     def test_leases(self):
         ss = self.create("test_leases")
-        canary = Referenceable()
+        canary = FakeCanary()
         sharenums = range(5)
         size = 100
 
diff --git a/src/allmydata/upload.py b/src/allmydata/upload.py
index 3036d04b..71d20e09 100644
--- a/src/allmydata/upload.py
+++ b/src/allmydata/upload.py
@@ -436,6 +436,7 @@ class CHKUploader:
         self._client = client
         self._options = options
         self._log_number = self._client.log("CHKUploader starting")
+        self._encoder = None
 
     def set_params(self, encoding_parameters):
         self._encoding_parameters = encoding_parameters
@@ -465,10 +466,19 @@ class CHKUploader:
         d.addCallback(_uploaded)
         return d
 
+    def abort(self):
+        """Call this if the upload must be abandoned before it completes.
+        This will tell the shareholders to delete their partial shares. I
+        return a Deferred that fires when these messages have been acked."""
+        if not self._encoder:
+            # how did you call abort() before calling start() ?
+            return defer.succeed(None)
+        return self._encoder.abort()
+
     def start_encrypted(self, encrypted):
         eu = IEncryptedUploadable(encrypted)
 
-        e = encode.Encoder(self._options, self)
+        self._encoder = e = encode.Encoder(self._options, self)
         e.set_params(self._encoding_parameters)
         d = e.set_encrypted_uploadable(eu)
         d.addCallback(self.locate_all_shareholders)
@@ -562,6 +572,9 @@ class RemoteEncryptedUploabable(Referenceable):
     def __init__(self, encrypted_uploadable):
         self._eu = IEncryptedUploadable(encrypted_uploadable)
         self._offset = 0
+        self._bytes_read = 0
+        self._cutoff = None # set by debug options
+        self._cutoff_cb = None
 
     def remote_get_size(self):
         return self._eu.get_size()
@@ -570,9 +583,13 @@ class RemoteEncryptedUploabable(Referenceable):
     def remote_read_encrypted(self, offset, length):
         # we don't yet implement seek
         assert offset == self._offset, "%d != %d" % (offset, self._offset)
+        if self._cutoff is not None and offset+length > self._cutoff:
+            self._cutoff_cb()
         d = self._eu.read_encrypted(length)
         def _read(strings):
-            self._offset += sum([len(data) for data in strings])
+            size = sum([len(data) for data in strings])
+            self._bytes_read += size
+            self._offset += size
             return strings
         d.addCallback(_read)
         return d
@@ -636,6 +653,17 @@ class AssistedUploader:
             self.log("helper says we need to upload")
             # we need to upload the file
             reu = RemoteEncryptedUploabable(self._encuploadable)
+            if "debug_stash_RemoteEncryptedUploadable" in self._options:
+                self._options["RemoteEncryptedUploabable"] = reu
+            if "debug_interrupt" in self._options:
+                reu._cutoff = self._options["debug_interrupt"]
+                def _cutoff():
+                    # simulate the loss of the connection to the helper
+                    self.log("debug_interrupt killing connection to helper",
+                             level=log.WEIRD)
+                    upload_helper.tracker.broker.transport.loseConnection()
+                    return
+                reu._cutoff_cb = _cutoff
             d = upload_helper.callRemote("upload", reu)
             # this Deferred will fire with the upload results
             return d
-- 
2.45.2