Because Twisted is an event-based framework, I've found myself needing to test whether an event has occurred. For example, let's say I wanted to confirm that a reconnecting factory was indeed reconnecting. To do so, I need to know whether connectionMade() is called on the protocol again after a connection has been dropped. I have detailed how I solved this problem below.
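I will be implementing a reconnecting factory for txAMQP. Twisted provides a ReconnectingClientFactory class, which we'll be subclassing. The concept behind the reconnecting factory is simple. Whenever a connection attempt fails or an established connection is lost, the factory waits for a delay and then tries to connect again. After each attempt it multiplies the delay by its `factor` attribute, up to `maxDelay`.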
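To test for a successful reconnect, we need to do two things: (1) force the connection to be lost, and (2) detect that the factory makes a new connection afterwards.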
To force a loss of connection, we can do the following from within connectionMade() on the protocol:
def connectionMade(self):
    """Drop the freshly established connection straight away."""
    transport = self.transport
    transport.loseConnection()
To detect a reconnection, we can count how many times retry() is called on the factory or how many times connectionMade() is called on the protocol. I'm going to work with connectionMade(), since we'll need to hook into that event to force the loss of connection anyway.
Additional constraint: I want to test the actual classes rather than test-only subclasses with overridden methods.
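To solve this problem, I've created a Hijack class. It wraps a method on a class so that every time the method is called, a supplied function is called right after it with the same arguments. You instantiate a Hijack object with a class, a method name, and the function to call; calling release() restores the original method. Here is the code: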
class Hijack(object):
    """Temporarily hook a method on a class so a callback runs after each call.

    While active, every call to ``cls.<method>`` -- on existing and new
    instances alike -- runs the original method and then ``new`` with the
    same arguments. ``release()`` puts the class back the way it was. This
    lets a test observe events on the real classes instead of on test-only
    subclasses.

    :param cls: the class whose method is hooked.
    :param method: name of the method to hook.
    :param new: callable invoked as ``new(obj, *args, **kwargs)`` after the
        original method; its return value is ignored.
    """

    # Sentinel meaning "the method was inherited, not defined on cls itself".
    _INHERITED = object()

    def __init__(self, cls, method, new):
        self.cls = cls
        self.method = method
        # The (possibly inherited) attribute that the wrapper delegates to.
        self.old = old = getattr(cls, method)
        # The raw entry in cls's own namespace. release() restores it exactly
        # instead of copying an inherited method onto cls.
        self._own = vars(cls).get(method, self._INHERITED)

        def new_method(obj, *args, **kwargs):
            result = old(obj, *args, **kwargs)
            new(obj, *args, **kwargs)
            # Hand back the original result so the hooked class behaves as
            # before, apart from the extra callback.
            return result

        setattr(cls, method, new_method)

    def release(self, ignore=None):
        """Restore the original method.

        :param ignore: returned unchanged, so this can be added directly as a
            Deferred callback or errback without breaking the chain.
        """
        if self._own is self._INHERITED:
            # Remove our override so lookup falls back to the base class.
            # The membership check keeps a second release() harmless.
            if self.method in vars(self.cls):
                delattr(self.cls, self.method)
        else:
            setattr(self.cls, self.method, self._own)
        return ignore
We can write our reconnection test as follows, assuming our protocol and factory classes are named AmqpProtocol and AmqpFactory. You'll also notice a TimedDeferred. This is a Deferred that fails with a TimeoutException if it has not fired within a set timeout (3 seconds by default); I use it to make sure the test completes in time.
from twisted.trial import unittest
from twisted.internet import reactor, defer, protocol
from txamqp.protocol import AMQClient
from txamqp.client import TwistedDelegate
import txamqp


class Hijack(object):
    """Temporarily hook a method on a class so a callback runs after each call.

    While active, every call to ``cls.<method>`` runs the original method and
    then ``new`` with the same arguments; ``release()`` undoes the hook.

    :param cls: the class whose method is hooked.
    :param method: name of the method to hook.
    :param new: callable invoked as ``new(obj, *args, **kwargs)`` after the
        original method; its return value is ignored.
    """

    # Sentinel meaning "the method was inherited, not defined on cls itself".
    _INHERITED = object()

    def __init__(self, cls, method, new):
        self.cls = cls
        self.method = method
        # The (possibly inherited) attribute that the wrapper delegates to.
        self.old = old = getattr(cls, method)
        # The raw entry in cls's own namespace, restored exactly by release().
        self._own = vars(cls).get(method, self._INHERITED)

        def new_method(obj, *args, **kwargs):
            result = old(obj, *args, **kwargs)
            new(obj, *args, **kwargs)
            # Keep the original return value so the class behaves as before.
            return result

        setattr(cls, method, new_method)

    def release(self, ignore=None):
        """Restore the original method; ``ignore`` is returned unchanged."""
        # ignore is useful when this is used as a callback
        if self._own is self._INHERITED:
            # Remove our override so lookup falls back to the base class.
            if self.method in vars(self.cls):
                delattr(self.cls, self.method)
        else:
            setattr(self.cls, self.method, self._own)
        return ignore


class TimeoutException(Exception):
    """Raised (as a Failure) by TimedDeferred when it is not fired in time."""
    pass


class TimedDeferred(defer.Deferred):
    """A Deferred that errbacks with TimeoutException if not fired in time.

    :param timeout: seconds to wait before failing the Deferred.
    :param msg: message for the TimeoutException; a default naming the
        timeout is used when empty.
    """

    def __init__(self, timeout=3, msg=None):
        defer.Deferred.__init__(self)
        if not msg:
            msg = 'Deferred timed out after %s seconds' % timeout

        def onTimeout(deferred):
            deferred.errback(TimeoutException(msg))

        def cancelTimer(result, delayedCall):
            # Stop the pending timeout once the Deferred fires either way,
            # then pass the result (or Failure) along unchanged.
            if delayedCall.active():
                delayedCall.cancel()
            return result

        try:
            delayedCall = reactor.callLater(timeout, onTimeout, self)
        except AssertionError as e:
            # Presumably guards against callLater rejecting bad arguments
            # (e.g. a negative timeout) via assertions.
            self.errback(e)
        else:
            self.addBoth(cancelTimer, delayedCall)


class TestReconnectingFactory(unittest.TestCase):
    # NOTE: AmqpProtocol and AmqpFactory must be defined in (or imported
    # into) this module for this test to run.

    def test_reconnection(self):
        """The factory reconnects repeatedly after the connection is dropped.

        Needs a broker reachable at localhost:5672.
        """
        dq = defer.DeferredQueue()

        def new(obj, *args, **kwargs):
            # Record that a connection was made, then drop it immediately so
            # the factory has to reconnect.
            dq.put(1)
            obj.transport.loseConnection()

        h = Hijack(AmqpProtocol, 'connectionMade', new)
        d = TimedDeferred()

        def count_reconnections(increment, count, dq):
            count += increment
            if count > 3:
                # Four connections seen: the initial one plus three reconnects.
                d.callback(None)
            else:
                dq.get().addCallback(count_reconnections, count, dq)

        dq.get().addCallback(count_reconnections, 0, dq)

        f = AmqpFactory()
        # No growth between retries and a short delay, so the reconnects fit
        # inside TimedDeferred's default timeout.
        f.factor = 1
        f.delay = .1

        def clean_up(ignore, factory, hijacked):
            factory.stopTrying()
            hijacked.release()
            return ignore  # continues chain of failure if this is an errback

        d.addBoth(clean_up, f, h)
        reactor.connectTCP('localhost', 5672, f)
        return d
Finally, we add our reconnecting factory and protocol to the same module:
from twisted.trial import unittest
from twisted.internet import reactor, defer, protocol
from txamqp.protocol import AMQClient
from txamqp.client import TwistedDelegate
import txamqp
import txamqp.spec  # AmqpFactory uses txamqp.spec.load explicitly


class Hijack(object):
    """Temporarily hook a method on a class so a callback runs after each call.

    While active, every call to ``cls.<method>`` runs the original method and
    then ``new`` with the same arguments; ``release()`` undoes the hook.

    :param cls: the class whose method is hooked.
    :param method: name of the method to hook.
    :param new: callable invoked as ``new(obj, *args, **kwargs)`` after the
        original method; its return value is ignored.
    """

    # Sentinel meaning "the method was inherited, not defined on cls itself".
    _INHERITED = object()

    def __init__(self, cls, method, new):
        self.cls = cls
        self.method = method
        # The (possibly inherited) attribute that the wrapper delegates to.
        self.old = old = getattr(cls, method)
        # The raw entry in cls's own namespace, restored exactly by release().
        self._own = vars(cls).get(method, self._INHERITED)

        def new_method(obj, *args, **kwargs):
            result = old(obj, *args, **kwargs)
            new(obj, *args, **kwargs)
            # Keep the original return value so the class behaves as before.
            return result

        setattr(cls, method, new_method)

    def release(self, ignore=None):
        """Restore the original method; ``ignore`` is returned unchanged."""
        # ignore is useful when this is used as a callback
        if self._own is self._INHERITED:
            # Remove our override so lookup falls back to the base class.
            if self.method in vars(self.cls):
                delattr(self.cls, self.method)
        else:
            setattr(self.cls, self.method, self._own)
        return ignore


class TimeoutException(Exception):
    """Raised (as a Failure) by TimedDeferred when it is not fired in time."""
    pass


class TimedDeferred(defer.Deferred):
    """A Deferred that errbacks with TimeoutException if not fired in time.

    :param timeout: seconds to wait before failing the Deferred.
    :param msg: message for the TimeoutException; a default naming the
        timeout is used when empty.
    """

    def __init__(self, timeout=3, msg=None):
        defer.Deferred.__init__(self)
        if not msg:
            msg = 'Deferred timed out after %s seconds' % timeout

        def onTimeout(deferred):
            deferred.errback(TimeoutException(msg))

        def cancelTimer(result, delayedCall):
            # Stop the pending timeout once the Deferred fires either way,
            # then pass the result (or Failure) along unchanged.
            if delayedCall.active():
                delayedCall.cancel()
            return result

        try:
            delayedCall = reactor.callLater(timeout, onTimeout, self)
        except AssertionError as e:
            # Presumably guards against callLater rejecting bad arguments
            # (e.g. a negative timeout) via assertions.
            self.errback(e)
        else:
            self.addBoth(cancelTimer, delayedCall)


class TestReconnectingFactory(unittest.TestCase):

    def test_reconnection(self):
        """The factory reconnects repeatedly after the connection is dropped.

        Needs a broker reachable at localhost:5672.
        """
        dq = defer.DeferredQueue()

        def new(obj, *args, **kwargs):
            # Record that a connection was made, then drop it immediately so
            # the factory has to reconnect.
            dq.put(1)
            obj.transport.loseConnection()

        h = Hijack(AmqpProtocol, 'connectionMade', new)
        d = TimedDeferred()

        def count_reconnections(increment, count, dq):
            count += increment
            if count > 3:
                # Four connections seen: the initial one plus three reconnects.
                d.callback(None)
            else:
                dq.get().addCallback(count_reconnections, count, dq)

        dq.get().addCallback(count_reconnections, 0, dq)

        f = AmqpFactory()
        # No growth between retries and a short delay, so the reconnects fit
        # inside TimedDeferred's default timeout.
        f.factor = 1
        f.delay = .1

        def clean_up(ignore, factory, hijacked):
            factory.stopTrying()
            hijacked.release()
            return ignore  # continues chain of failure if this is an errback

        d.addBoth(clean_up, f, h)
        reactor.connectTCP('localhost', 5672, f)
        return d


class AmqpProtocol(AMQClient):

    def connectionMade(self):
        """Called when a connection has been made; starts the AMQP login."""
        AMQClient.connectionMade(self)
        # Credentials come from the owning factory (set in buildProtocol).
        deferred = self.start({"LOGIN": self.factory.user,
                               "PASSWORD": self.factory.password})
        # NOTE(review): no errback is attached, so a failed login is only
        # reported as an unhandled Deferred error -- confirm this is intended.
        deferred.addCallback(self._authenticated)

    def _authenticated(self, ignore):
        """Called when the connection has been authenticated."""
        pass


class AmqpFactory(protocol.ReconnectingClientFactory):
    """Reconnecting client factory that builds AmqpProtocol instances.

    :param spec_file: path to the AMQP spec XML; defaults to
        '../amqp0-8.xml'.
    :param vhost: AMQP virtual host; defaults to '/'.
    :param user: login name; defaults to 'guest'.
    :param password: login password; defaults to 'guest'.
    """

    protocol = AmqpProtocol

    def __init__(self, spec_file=None, vhost=None, user=None, password=None):
        # Relative to the current working directory. trial runs tests from
        # its _trial_temp directory, which is presumably why the default
        # points one level up -- confirm.
        spec_file = spec_file or '../amqp0-8.xml'
        self.spec = txamqp.spec.load(spec_file)
        self.user = user or 'guest'
        self.password = password or 'guest'
        # Fall back to the root vhost only when none was given (the previous
        # `'/' or vhost` always produced '/').
        self.vhost = vhost or '/'
        self.delegate = TwistedDelegate()

    def buildProtocol(self, addr):
        """Build the protocol instance for a new connection.

        Creates an AmqpProtocol from this factory's delegate, vhost and spec,
        and points its ``factory`` attribute back at this factory so the
        protocol can read the login credentials.

        :param addr: the address Twisted passes in (not used here).
        :returns: the new protocol instance.
        """
        # NOTE(review): resetDelay() is not called on connect, so the retry
        # delay never returns to initialDelay; the test above depends on that
        # because it shrinks `delay` directly.
        p = self.protocol(self.delegate, self.vhost, self.spec)
        p.factory = self
        return p
To run the test, save the module as reconnecting.py and run it with Twisted's trial test runner:
# Run the test module with Twisted's trial test runner.
trial reconnecting.py