Make stuff to do things

Testing events in Twisted’s Trial

Published on June 29th, 2009 in Comments

Because Twisted is an event-based framework, I’ve found myself needing to test whether an event has occurred. For example, let’s say I want to confirm that a reconnecting factory is indeed reconnecting. To confirm this I need to know whether connectionMade() is called on the protocol after a connection has been dropped. I have detailed how I solved this problem below.

I will be implementing a reconnecting factory for txAMQP. Twisted provides a ReconnectingClientFactory class, which we’ll be subclassing. The concept behind the reconnecting factory is simple:

  1. Connect to transport
  2. On a lost or failed connection, clientConnectionLost() or clientConnectionFailed() is called on the factory
    * When an established connection drops, connectionLost() is also called on the protocol; for the reconnecting factory this method is unimportant
  3. The ReconnectingClientFactory class overrides these factory events and issues a retry()
  4. The method retry() will attempt to re-establish the connection after a suitable delay
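
To get a feel for what "a suitable delay" means, here is a rough model of an exponential backoff. It is only a sketch, not Twisted's code: retry_delays() and its parameters are made-up names for illustration, and the random jitter that ReconnectingClientFactory adds on top is left out.

```python
def retry_delays(initial_delay, factor, max_delay, attempts):
    """Return the wait (in seconds) before each of `attempts` retries.

    Simplified model: each retry multiplies the previous delay by `factor`
    and caps the result at `max_delay`. Jitter is deliberately left out.
    """
    delays = []
    delay = initial_delay
    for _ in range(attempts):
        delay = min(delay * factor, max_delay)
        delays.append(delay)
    return delays


# Doubling from one second with a five-second cap.
print(retry_delays(1.0, 2.0, 5.0, 4))   # [2.0, 4.0, 5.0, 5.0]
# A factor of 1 keeps the wait constant, which suits a fast test.
print(retry_delays(0.1, 1, 3600, 3))    # [0.1, 0.1, 0.1]
```

A test that has to sit through several reconnects will want a small starting delay and a factor of 1, otherwise the waits grow quickly.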

To test for a successful reconnect we need to do two things:

  1. Force a loss of connection
  2. Detect a reconnect attempt

To force a loss of connection we can do the following from within connectionMade() on the protocol:

def connectionMade(self):
    self.transport.loseConnection()

To detect a reconnection we can count how many times retry() is called on the factory or how many times connectionMade() is called on the protocol. I’m going to work with connectionMade(), since we need to manipulate this event to force the loss of connection anyway.

Additional constraint: I want to test the actual classes rather than creating test classes that are subclasses with overridden methods.

To solve this problem I’ve created a Hijack class that will relay calls from a particular event to a supplied method. A user instantiates a Hijack object with a class, a method name, and the method to call. Here is the code:

class Hijack:
    def __init__(self, cls, method, new):
        # Remember what was patched so release() can undo it.
        self.cls = cls
        self.method = method
        self.old = old = getattr(cls, method)
        def new_method(obj, *args, **kwargs):
            # Run the original method first, then relay the call to `new`.
            old(obj, *args, **kwargs)
            new(obj, *args, **kwargs)
        setattr(cls, method, new_method)

    def release(self, ignore=None):
        # Restore the original method.
        # ignore is useful when this is used as a callback
        setattr(self.cls, self.method, self.old)
        return ignore
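
To see the Hijack class in action without a network connection, here is a small sketch that patches a toy class. Greeter and spy() are hypothetical names invented for this example, and Hijack is repeated only so the snippet runs on its own.

```python
class Hijack:
    # Same idea as the class above, repeated so this snippet is self-contained.
    def __init__(self, cls, method, new):
        self.cls = cls
        self.method = method
        self.old = old = getattr(cls, method)
        def new_method(obj, *args, **kwargs):
            old(obj, *args, **kwargs)
            new(obj, *args, **kwargs)
        setattr(cls, method, new_method)

    def release(self, ignore=None):
        setattr(self.cls, self.method, self.old)
        return ignore


class Greeter:
    """Hypothetical stand-in for a protocol class."""
    def __init__(self):
        self.log = []

    def greet(self, name):
        self.log.append('hello %s' % name)


seen = []

def spy(obj, name):
    # Runs after the original greet(), with the same arguments.
    seen.append(name)

h = Hijack(Greeter, 'greet', spy)
g = Greeter()
g.greet('alice')   # the original runs, then spy()
h.release()
g.greet('bob')     # only the original runs now

print(g.log)   # ['hello alice', 'hello bob']
print(seen)    # ['alice']
```

Because the patch is applied to the class rather than to an instance, it also affects instances created later by a factory, which is what makes it usable with buildProtocol().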

We can write our reconnection test as follows, assuming our protocol and factory classes are named AmqpProtocol and AmqpFactory. The test passes once connectionMade() has fired four times: the initial connection plus three reconnects. You’ll also notice a TimedDeferred, which I use to fail the test if it has not completed before a set timeout (3 seconds by default).

from twisted.trial import unittest
from twisted.internet import reactor, defer, protocol
from txamqp.protocol import AMQClient
from txamqp.client import TwistedDelegate
import txamqp

class Hijack:
    def __init__(self, cls, method, new):
        self.cls = cls
        self.method = method
        self.old = old = getattr(cls, method)
        def new_method(obj, *args, **kwargs):
            old(obj, *args, **kwargs)
            new(obj, *args, **kwargs)
        setattr(cls, method, new_method)
       
    def release(self, ignore=None):
        # ignore is useful when this is used as a callback
        setattr(self.cls, self.method, self.old)
        return ignore


class TimeoutException(Exception):
    pass


class TimedDeferred(defer.Deferred):
    def __init__(self, timeout=3, msg=None):
        defer.Deferred.__init__(self)

        if not msg:
            msg = 'Deferred timed out after %s seconds' % timeout

        def onTimeout(deferred):
            deferred.errback(TimeoutException(msg))

        def callback(ignore, delayedCall):
            if delayedCall.active():
                delayedCall.cancel()

        def errback(failure, delayedCall):
            if delayedCall.active():
                delayedCall.cancel()
            return failure

        try:
            delayedCall = reactor.callLater(timeout, onTimeout, self)
            self.addCallback(callback, delayedCall)
            self.addErrback(errback, delayedCall)
        except AssertionError as e:
            self.errback(e)


class TestReconnectingFactory(unittest.TestCase):
    def test_reconnection(self):
        dq = defer.DeferredQueue()
        def new(obj, *args, **kwargs):
            dq.put(1)
            obj.transport.loseConnection()
        h = Hijack(AmqpProtocol, 'connectionMade', new)
        d = TimedDeferred()
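        def count_reconnections(increment, count, dq):
            # Each connectionMade() puts a 1 on the queue; tally them here.
            count += increment
            if count > 3:
                # Fourth connection seen (the initial one plus three reconnects).
                d.callback(None)
            else:
                dq.get().addCallback(count_reconnections, count, dq)
        dq.get().addCallback(count_reconnections, 0, dq)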
        f = AmqpFactory()
        f.factor = 1
        f.delay = .1
        def clean_up(ignore, factory, hijacked):
            factory.stopTrying()
            hijacked.release()
            return ignore # continues chain of failure if this is an errback
        d.addBoth(clean_up, f, h)
        reactor.connectTCP('localhost', 5672, f)
        return d

Now we can add our reconnecting factory and protocol to the same file, below the test case:

class AmqpProtocol(AMQClient):
    def connectionMade(self):
        """Called when a connection has been made."""
        AMQClient.connectionMade(self)
        deferred = self.start({"LOGIN": self.factory.user,
                          "PASSWORD": self.factory.password})
        deferred.addCallback(self._authenticated)

    def _authenticated(self, ignore):
        """Called when the connection has been authenticated."""
        pass

               
class AmqpFactory(protocol.ReconnectingClientFactory):
    protocol = AmqpProtocol
   
    def __init__(self, spec_file=None, vhost=None, user=None, password=None):
        spec_file = spec_file or '../amqp0-8.xml'
        self.spec = txamqp.spec.load(spec_file)
        self.user = user or 'guest'
        self.password = password or 'guest'
        self.vhost = vhost or '/'
        self.delegate = TwistedDelegate()
       
    def buildProtocol(self, addr):
        """
        buildProtocol is the method required by protocol.ClientFactory.
        It builds the protocol instance for a new connection and attaches
        this factory to it.
        :arguments:
            addr
                the address of the connection (unused here)
        """

        p = self.protocol(self.delegate, self.vhost, self.spec)
        p.factory = self
        return p
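
The constructor leans on the `value or default` idiom to fill in missing arguments. Python's `or` returns its first truthy operand, so the order of the operands matters. The sketch below shows the difference; pick_vhost() and pick_vhost_reversed() are hypothetical helpers written just for this illustration.

```python
def pick_vhost(vhost=None):
    # The caller's value wins; '/' is used only when vhost is None or empty.
    return vhost or '/'

def pick_vhost_reversed(vhost=None):
    # Operands swapped: '/' is truthy, so vhost is never consulted.
    return '/' or vhost

print(pick_vhost('/staging'))            # /staging
print(pick_vhost())                      # /
print(pick_vhost_reversed('/staging'))   # /
```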

To run the test we type the following:

    trial reconnecting.py