Skip to content

Commit b0c6172

Browse files
authored
Merge pull request #328 from bennr01/fix_327
treq.testing.StubTreq: fix persisting twisted.web.server.Session objects between requests
2 parents 4be2f69 + bbf412c commit b0c6172

File tree

4 files changed

+75
-6
lines changed

4 files changed

+75
-6
lines changed

changelog.d/327.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
``treq.testing.StubTreq`` now persists ``twisted.web.server.Session`` instances between requests.

docs/testing.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,4 @@ This is superior to calling your resource's methods directly or passing mock obj
5757
Thus, the ``request`` object your code interacts with is a *real* :class:`twisted.web.server.Request` and behaves the same as it would in production.
5858

5959
Note that if your resource returns :data:`~twisted.web.server.NOT_DONE_YET` you must keep a reference to the :class:`~treq.testing.RequestTraversalAgent` and call its :meth:`~treq.testing.RequestTraversalAgent.flush()` method to spin the memory reactor once the server writes additional data before the client will receive it.
60+

src/treq/test/test_testing.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,34 @@ def render(self, request):
5252
return NOT_DONE_YET
5353

5454

55+
class _SessionIdTestResource(Resource):
56+
"""
57+
Resource that returns the current session ID.
58+
"""
59+
isLeaf = True
60+
61+
def __init__(self):
62+
super().__init__()
63+
# keep track of all sessions created, so we can manually expire them later
64+
self.sessions = []
65+
66+
def render(self, request):
67+
session = request.getSession()
68+
if session not in self.sessions:
69+
# new session, add to internal list
70+
self.sessions.append(session)
71+
uid = session.uid
72+
return uid
73+
74+
def expire_sessions(self):
75+
"""
76+
Manually expire all sessions created by this resource.
77+
"""
78+
for session in self.sessions:
79+
session.expire()
80+
self.sessions = []
81+
82+
5583
class StubbingTests(TestCase):
5684
"""
5785
Tests for :class:`StubTreq`.
@@ -242,6 +270,40 @@ def test_handles_successful_asynchronous_requests_with_streaming(self):
242270
stub.flush()
243271
self.successResultOf(d)
244272

273+
def test_session_persistence_between_requests(self):
274+
"""
275+
Calling request.getSession() in the wrapped resource will return
276+
a session with the same ID, until the sessions are cleaned.
277+
"""
278+
rsrc = _SessionIdTestResource()
279+
stub = StubTreq(rsrc)
280+
# request 1, getting original session ID
281+
d = stub.request("method", "http://example.com/")
282+
resp = self.successResultOf(d)
283+
cookies = resp.cookies()
284+
sid_1 = self.successResultOf(resp.content())
285+
# request 2, ensuring session ID stays the same
286+
d = stub.request("method", "http://example.com/", cookies=cookies)
287+
resp = self.successResultOf(d)
288+
sid_2 = self.successResultOf(resp.content())
289+
self.assertEqual(sid_1, sid_2)
290+
# request 3, ensuring the session IDs are different after cleaning
291+
# or expiring the sessions
292+
293+
# manually expire the sessions.
294+
rsrc.expire_sessions()
295+
296+
d = stub.request("method", "http://example.com/")
297+
resp = self.successResultOf(d)
298+
cookies = resp.cookies()
299+
sid_3 = self.successResultOf(resp.content())
300+
self.assertNotEqual(sid_1, sid_3)
301+
# request 4, ensuring that once again the session IDs are the same
302+
d = stub.request("method", "http://example.com/", cookies=cookies)
303+
resp = self.successResultOf(d)
304+
sid_4 = self.successResultOf(resp.content())
305+
self.assertEqual(sid_3, sid_4)
306+
245307

246308
class HasHeadersTests(TestCase):
247309
"""

src/treq/testing.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from twisted.web.error import SchemeNotSupported
2727
from twisted.web.iweb import IAgent, IAgentEndpointFactory, IBodyProducer
2828
from twisted.web.resource import Resource
29-
from twisted.web.server import Site
29+
from twisted.web.server import Session, Site
3030

3131
from zope.interface import directlyProvides, implementer
3232

@@ -88,6 +88,12 @@ def __init__(self, rootResource):
8888
reactor=self._memoryReactor,
8989
endpointFactory=_EndpointFactory(self._memoryReactor))
9090
self._rootResource = rootResource
91+
self._serverFactory = Site(self._rootResource, reactor=self._memoryReactor)
92+
self._serverFactory.sessionFactory = lambda site, uid: Session(
93+
site,
94+
uid,
95+
reactor=self._memoryReactor,
96+
)
9197
self._pumps = set()
9298

9399
def request(self, method, uri, headers=None, bodyProducer=None):
@@ -126,8 +132,7 @@ def check_already_called(r):
126132
# Create the protocol and fake transport for the client and server,
127133
# using the factory that was passed to the MemoryReactor for the
128134
# client, and a Site around our rootResource for the server.
129-
serverFactory = Site(self._rootResource, reactor=self._memoryReactor)
130-
serverProtocol = serverFactory.buildProtocol(clientAddress)
135+
serverProtocol = self._serverFactory.buildProtocol(clientAddress)
131136
serverTransport = iosim.FakeTransport(
132137
serverProtocol, isServer=True,
133138
hostAddress=serverAddress, peerAddress=clientAddress)
@@ -228,8 +233,8 @@ def __init__(self, resource):
228233
:param resource: A :obj:`Resource` object that provides the fake
229234
responses
230235
"""
231-
_agent = RequestTraversalAgent(resource)
232-
_client = HTTPClient(agent=_agent,
236+
self._agent = RequestTraversalAgent(resource)
237+
_client = HTTPClient(agent=self._agent,
233238
data_to_body_producer=_SynchronousProducer)
234239
for function_name in treq.__all__:
235240
function = getattr(_client, function_name, None)
@@ -239,7 +244,7 @@ def __init__(self, resource):
239244
function = _reject_files(function)
240245

241246
setattr(self, function_name, function)
242-
self.flush = _agent.flush
247+
self.flush = self._agent.flush
243248

244249

245250
class StringStubbingResource(Resource):

0 commit comments

Comments
 (0)