import urlparse import cgi import time import warnings from openid.message import Message, OPENID_NS, OPENID2_NS, IDENTIFIER_SELECT, \ OPENID1_NS, BARE_NS from openid import cryptutil, dh, oidutil, kvform from openid.store.nonce import mkNonce, split as splitNonce from openid.consumer.discover import OpenIDServiceEndpoint, OPENID_2_0_TYPE, \ OPENID_1_1_TYPE from openid.consumer.consumer import \ AuthRequest, GenericConsumer, SUCCESS, FAILURE, CANCEL, SETUP_NEEDED, \ SuccessResponse, FailureResponse, SetupNeededResponse, CancelResponse, \ DiffieHellmanSHA1ConsumerSession, Consumer, PlainTextConsumerSession, \ SetupNeededError, DiffieHellmanSHA256ConsumerSession, ServerError, \ ProtocolError, _httpResponseToMessage from openid import association from openid.server.server import \ PlainTextServerSession, DiffieHellmanSHA1ServerSession from openid.yadis.manager import Discovery from openid.yadis.discover import DiscoveryFailure from openid.dh import DiffieHellman from openid.fetchers import HTTPResponse, HTTPFetchingError from openid import fetchers from openid.store import memstore from support import CatchLogs assocs = [ ('another 20-byte key.', 'Snarky'), ('\x00' * 20, 'Zeros'), ] def mkSuccess(endpoint, q): """Convenience function to create a SuccessResponse with the given arguments, all signed.""" signed_list = ['openid.' + k for k in q.keys()] return SuccessResponse(endpoint, Message.fromOpenIDArgs(q), signed_list) def parseQuery(qs): q = {} for (k, v) in cgi.parse_qsl(qs): assert not q.has_key(k) q[k] = v return q def associate(qs, assoc_secret, assoc_handle): """Do the server's half of the associate call, using the given secret and handle.""" q = parseQuery(qs) assert q['openid.mode'] == 'associate' assert q['openid.assoc_type'] == 'HMAC-SHA1' reply_dict = { 'assoc_type':'HMAC-SHA1', 'assoc_handle':assoc_handle, 'expires_in':'600', } if q.get('openid.session_type') == 'DH-SHA1': assert len(q) == 6 or len(q) == 4 message = Message.fromPostArgs(q) session = DiffieHellmanSHA1ServerSession.fromMessage(message) reply_dict['session_type'] = 'DH-SHA1' else: assert len(q) == 2 session = PlainTextServerSession.fromQuery(q) reply_dict.update(session.answer(assoc_secret)) return kvform.dictToKV(reply_dict) GOODSIG = "[A Good Signature]" class GoodAssociation: expiresIn = 3600 handle = "-blah-" def getExpiresIn(self): return self.expiresIn def checkMessageSignature(self, message): return message.getArg(OPENID_NS, 'sig') == GOODSIG class GoodAssocStore(memstore.MemoryStore): def getAssociation(self, server_url, handle=None): return GoodAssociation() class TestFetcher(object): def __init__(self, user_url, user_page, (assoc_secret, assoc_handle)): self.get_responses = {user_url:self.response(user_url, 200, user_page)} self.assoc_secret = assoc_secret self.assoc_handle = assoc_handle self.num_assocs = 0 def response(self, url, status, body): return HTTPResponse( final_url=url, status=status, headers={}, body=body) def fetch(self, url, body=None, headers=None): if body is None: if url in self.get_responses: return self.get_responses[url] else: try: body.index('openid.mode=associate') except ValueError: pass # fall through else: assert body.find('DH-SHA1') != -1 response = associate( body, self.assoc_secret, self.assoc_handle) self.num_assocs += 1 return self.response(url, 200, response) return self.response(url, 404, 'Not found') def makeFastConsumerSession(): """ Create custom DH object so tests run quickly. """ dh = DiffieHellman(100389557, 2) return DiffieHellmanSHA1ConsumerSession(dh) def setConsumerSession(con): con.session_types = {'DH-SHA1': makeFastConsumerSession} def _test_success(server_url, user_url, delegate_url, links, immediate=False): store = memstore.MemoryStore() if immediate: mode = 'checkid_immediate' else: mode = 'checkid_setup' endpoint = OpenIDServiceEndpoint() endpoint.claimed_id = user_url endpoint.server_url = server_url endpoint.local_id = delegate_url endpoint.type_uris = [OPENID_1_1_TYPE] fetcher = TestFetcher(None, None, assocs[0]) fetchers.setDefaultFetcher(fetcher, wrap_exceptions=False) def run(): trust_root = consumer_url consumer = GenericConsumer(store) setConsumerSession(consumer) request = consumer.begin(endpoint) return_to = consumer_url m = request.getMessage(trust_root, return_to, immediate) redirect_url = request.redirectURL(trust_root, return_to, immediate) parsed = urlparse.urlparse(redirect_url) qs = parsed[4] q = parseQuery(qs) new_return_to = q['openid.return_to'] del q['openid.return_to'] assert q == { 'openid.mode':mode, 'openid.identity':delegate_url, 'openid.trust_root':trust_root, 'openid.assoc_handle':fetcher.assoc_handle, }, (q, user_url, delegate_url, mode) assert new_return_to.startswith(return_to) assert redirect_url.startswith(server_url) parsed = urlparse.urlparse(new_return_to) query = parseQuery(parsed[4]) query.update({ 'openid.mode':'id_res', 'openid.return_to':new_return_to, 'openid.identity':delegate_url, 'openid.assoc_handle':fetcher.assoc_handle, }) assoc = store.getAssociation(server_url, fetcher.assoc_handle) message = Message.fromPostArgs(query) message = assoc.signMessage(message) info = consumer.complete(message, request.endpoint, new_return_to) assert info.status == SUCCESS, info.message assert info.identity_url == user_url assert fetcher.num_assocs == 0 run() assert fetcher.num_assocs == 1 # Test that doing it again uses the existing association run() assert fetcher.num_assocs == 1 # Another association is created if we remove the existing one store.removeAssociation(server_url, fetcher.assoc_handle) run() assert fetcher.num_assocs == 2 # Test that doing it again uses the existing association run() assert fetcher.num_assocs == 2 import unittest http_server_url = 'http://server.example.com/' consumer_url = 'http://consumer.example.com/' https_server_url = 'https://server.example.com/' class TestSuccess(unittest.TestCase, CatchLogs): server_url = http_server_url user_url = 'http://www.example.com/user.html' delegate_url = 'http://consumer.example.com/user' def setUp(self): CatchLogs.setUp(self) self.links = '' % ( self.server_url,) self.delegate_links = ('' '') % ( self.server_url, self.delegate_url) def tearDown(self): CatchLogs.tearDown(self) def test_nodelegate(self): _test_success(self.server_url, self.user_url, self.user_url, self.links) def test_nodelegateImmediate(self): _test_success(self.server_url, self.user_url, self.user_url, self.links, True) def test_delegate(self): _test_success(self.server_url, self.user_url, self.delegate_url, self.delegate_links) def test_delegateImmediate(self): _test_success(self.server_url, self.user_url, self.delegate_url, self.delegate_links, True) class TestSuccessHTTPS(TestSuccess): server_url = https_server_url class TestConstruct(unittest.TestCase): def setUp(self): self.store_sentinel = object() def test_construct(self): oidc = GenericConsumer(self.store_sentinel) self.failUnless(oidc.store is self.store_sentinel) def test_nostore(self): self.failUnlessRaises(TypeError, GenericConsumer) class TestIdRes(unittest.TestCase, CatchLogs): consumer_class = GenericConsumer def setUp(self): CatchLogs.setUp(self) self.store = memstore.MemoryStore() self.consumer = self.consumer_class(self.store) self.return_to = "nonny" self.endpoint = OpenIDServiceEndpoint() self.endpoint.claimed_id = self.consumer_id = "consu" self.endpoint.server_url = self.server_url = "serlie" self.endpoint.local_id = self.server_id = "sirod" self.endpoint.type_uris = [OPENID_1_1_TYPE] def disableDiscoveryVerification(self): """Set the discovery verification to a no-op for test cases in which we don't care.""" def dummyVerifyDiscover(_, endpoint): return endpoint self.consumer._verifyDiscoveryResults = dummyVerifyDiscover def disableReturnToChecking(self): def checkReturnTo(unused1, unused2): return True self.consumer._checkReturnTo = checkReturnTo complete = self.consumer.complete def callCompleteWithoutReturnTo(message, endpoint): return complete(message, endpoint, None) self.consumer.complete = callCompleteWithoutReturnTo class TestIdResCheckSignature(TestIdRes): def setUp(self): TestIdRes.setUp(self) self.assoc = GoodAssociation() self.assoc.handle = "{not_dumb}" self.store.storeAssociation(self.endpoint.server_url, self.assoc) self.message = Message.fromPostArgs({ 'openid.mode': 'id_res', 'openid.identity': '=example', 'openid.sig': GOODSIG, 'openid.assoc_handle': self.assoc.handle, 'openid.signed': 'mode,identity,assoc_handle,signed', 'frobboz': 'banzit', }) def test_sign(self): # assoc_handle to assoc with good sig self.consumer._idResCheckSignature(self.message, self.endpoint.server_url) def test_signFailsWithBadSig(self): self.message.setArg(OPENID_NS, 'sig', 'BAD SIGNATURE') self.failUnlessRaises( ProtocolError, self.consumer._idResCheckSignature, self.message, self.endpoint.server_url) def test_stateless(self): # assoc_handle missing assoc, consumer._checkAuth returns goodthings self.message.setArg(OPENID_NS, "assoc_handle", "dumbHandle") self.consumer._processCheckAuthResponse = ( lambda response, server_url: True) self.consumer._makeKVPost = lambda args, server_url: {} self.consumer._idResCheckSignature(self.message, self.endpoint.server_url) def test_statelessRaisesError(self): # assoc_handle missing assoc, consumer._checkAuth returns goodthings self.message.setArg(OPENID_NS, "assoc_handle", "dumbHandle") self.consumer._checkAuth = lambda unused1, unused2: False self.failUnlessRaises( ProtocolError, self.consumer._idResCheckSignature, self.message, self.endpoint.server_url) def test_stateless_noStore(self): # assoc_handle missing assoc, consumer._checkAuth returns goodthings self.message.setArg(OPENID_NS, "assoc_handle", "dumbHandle") self.consumer.store = None self.consumer._processCheckAuthResponse = ( lambda response, server_url: True) self.consumer._makeKVPost = lambda args, server_url: {} self.consumer._idResCheckSignature(self.message, self.endpoint.server_url) def test_statelessRaisesError_noStore(self): # assoc_handle missing assoc, consumer._checkAuth returns goodthings self.message.setArg(OPENID_NS, "assoc_handle", "dumbHandle") self.consumer._checkAuth = lambda unused1, unused2: False self.consumer.store = None self.failUnlessRaises( ProtocolError, self.consumer._idResCheckSignature, self.message, self.endpoint.server_url) class TestQueryFormat(TestIdRes): def test_notAList(self): # XXX: should be a Message object test, not a consumer test # Value should be a single string. If it's a list, it should generate # an exception. query = {'openid.mode': ['cancel']} try: r = Message.fromPostArgs(query) except TypeError, err: self.failUnless(str(err).find('values') != -1, err) else: self.fail("expected TypeError, got this instead: %s" % (r,)) class TestComplete(TestIdRes): """Testing GenericConsumer.complete. Other TestIdRes subclasses test more specific aspects. """ def test_setupNeededIdRes(self): message = Message.fromOpenIDArgs({'mode': 'id_res'}) setup_url_sentinel = object() def raiseSetupNeeded(msg): self.failUnless(msg is message) raise SetupNeededError(setup_url_sentinel) self.consumer._checkSetupNeeded = raiseSetupNeeded response = self.consumer.complete(message, None, None) self.failUnlessEqual(SETUP_NEEDED, response.status) self.failUnless(setup_url_sentinel is response.setup_url) def test_cancel(self): message = Message.fromPostArgs({'openid.mode': 'cancel'}) self.disableReturnToChecking() r = self.consumer.complete(message, self.endpoint) self.failUnlessEqual(r.status, CANCEL) self.failUnless(r.identity_url == self.endpoint.claimed_id) def test_cancel_with_return_to(self): message = Message.fromPostArgs({'openid.mode': 'cancel'}) r = self.consumer.complete(message, self.endpoint, self.return_to) self.failUnlessEqual(r.status, CANCEL) self.failUnless(r.identity_url == self.endpoint.claimed_id) def test_error(self): msg = 'an error message' message = Message.fromPostArgs({'openid.mode': 'error', 'openid.error': msg, }) self.disableReturnToChecking() r = self.consumer.complete(message, self.endpoint) self.failUnlessEqual(r.status, FAILURE) self.failUnless(r.identity_url == self.endpoint.claimed_id) self.failUnlessEqual(r.message, msg) def test_errorWithNoOptionalKeys(self): msg = 'an error message' contact = 'some contact info here' message = Message.fromPostArgs({'openid.mode': 'error', 'openid.error': msg, 'openid.contact': contact, }) self.disableReturnToChecking() r = self.consumer.complete(message, self.endpoint) self.failUnlessEqual(r.status, FAILURE) self.failUnless(r.identity_url == self.endpoint.claimed_id) self.failUnless(r.contact == contact) self.failUnless(r.reference is None) self.failUnlessEqual(r.message, msg) def test_errorWithOptionalKeys(self): msg = 'an error message' contact = 'me' reference = 'support ticket' message = Message.fromPostArgs({'openid.mode': 'error', 'openid.error': msg, 'openid.reference': reference, 'openid.contact': contact, 'openid.ns': OPENID2_NS, }) r = self.consumer.complete(message, self.endpoint, None) self.failUnlessEqual(r.status, FAILURE) self.failUnless(r.identity_url == self.endpoint.claimed_id) self.failUnless(r.contact == contact) self.failUnless(r.reference == reference) self.failUnlessEqual(r.message, msg) def test_noMode(self): message = Message.fromPostArgs({}) r = self.consumer.complete(message, self.endpoint, None) self.failUnlessEqual(r.status, FAILURE) self.failUnless(r.identity_url == self.endpoint.claimed_id) def test_idResMissingField(self): # XXX - this test is passing, but not necessarily by what it # is supposed to test for. status in FAILURE, but it's because # *check_auth* failed, not because it's missing an arg, exactly. message = Message.fromPostArgs({'openid.mode': 'id_res'}) self.failUnlessRaises(ProtocolError, self.consumer._doIdRes, message, self.endpoint, None) def test_idResURLMismatch(self): class VerifiedError(Exception): pass def discoverAndVerify(claimed_id, _to_match_endpoints): raise VerifiedError self.consumer._discoverAndVerify = discoverAndVerify self.disableReturnToChecking() message = Message.fromPostArgs( {'openid.mode': 'id_res', 'openid.return_to': 'return_to (just anything)', 'openid.identity': 'something wrong (not self.consumer_id)', 'openid.assoc_handle': 'does not matter', 'openid.sig': GOODSIG, 'openid.signed': 'identity,return_to', }) self.consumer.store = GoodAssocStore() self.failUnlessRaises(VerifiedError, self.consumer.complete, message, self.endpoint) self.failUnlessLogMatches('Error attempting to use stored', 'Attempting discovery') class TestCompleteMissingSig(unittest.TestCase, CatchLogs): def setUp(self): self.store = GoodAssocStore() self.consumer = GenericConsumer(self.store) self.server_url = "http://idp.unittest/" CatchLogs.setUp(self) claimed_id = 'bogus.claimed' self.message = Message.fromOpenIDArgs( {'mode': 'id_res', 'return_to': 'return_to (just anything)', 'identity': claimed_id, 'assoc_handle': 'does not matter', 'sig': GOODSIG, 'response_nonce': mkNonce(), 'signed': 'identity,return_to,response_nonce,assoc_handle,claimed_id,op_endpoint', 'claimed_id': claimed_id, 'op_endpoint': self.server_url, 'ns':OPENID2_NS, }) self.endpoint = OpenIDServiceEndpoint() self.endpoint.server_url = self.server_url self.endpoint.claimed_id = claimed_id self.consumer._checkReturnTo = lambda unused1, unused2 : True def tearDown(self): CatchLogs.tearDown(self) def test_idResMissingNoSigs(self): def _vrfy(resp_msg, endpoint=None): return endpoint self.consumer._verifyDiscoveryResults = _vrfy r = self.consumer.complete(self.message, self.endpoint, None) self.failUnlessSuccess(r) def test_idResNoIdentity(self): self.message.delArg(OPENID_NS, 'identity') self.message.delArg(OPENID_NS, 'claimed_id') self.endpoint.claimed_id = None self.message.setArg(OPENID_NS, 'signed', 'return_to,response_nonce,assoc_handle,op_endpoint') r = self.consumer.complete(self.message, self.endpoint, None) self.failUnlessSuccess(r) def test_idResMissingIdentitySig(self): self.message.setArg(OPENID_NS, 'signed', 'return_to,response_nonce,assoc_handle,claimed_id') r = self.consumer.complete(self.message, self.endpoint, None) self.failUnlessEqual(r.status, FAILURE) def test_idResMissingReturnToSig(self): self.message.setArg(OPENID_NS, 'signed', 'identity,response_nonce,assoc_handle,claimed_id') r = self.consumer.complete(self.message, self.endpoint, None) self.failUnlessEqual(r.status, FAILURE) def test_idResMissingAssocHandleSig(self): self.message.setArg(OPENID_NS, 'signed', 'identity,response_nonce,return_to,claimed_id') r = self.consumer.complete(self.message, self.endpoint, None) self.failUnlessEqual(r.status, FAILURE) def test_idResMissingClaimedIDSig(self): self.message.setArg(OPENID_NS, 'signed', 'identity,response_nonce,return_to,assoc_handle') r = self.consumer.complete(self.message, self.endpoint, None) self.failUnlessEqual(r.status, FAILURE) def failUnlessSuccess(self, response): if response.status != SUCCESS: self.fail("Non-successful response: %s" % (response,)) class TestCheckAuthResponse(TestIdRes, CatchLogs): def setUp(self): CatchLogs.setUp(self) TestIdRes.setUp(self) def tearDown(self): CatchLogs.tearDown(self) def _createAssoc(self): issued = time.time() lifetime = 1000 assoc = association.Association( 'handle', 'secret', issued, lifetime, 'HMAC-SHA1') store = self.consumer.store store.storeAssociation(self.server_url, assoc) assoc2 = store.getAssociation(self.server_url) self.failUnlessEqual(assoc, assoc2) def test_goodResponse(self): """successful response to check_authentication""" response = Message.fromOpenIDArgs({'is_valid':'true',}) r = self.consumer._processCheckAuthResponse(response, self.server_url) self.failUnless(r) def test_missingAnswer(self): """check_authentication returns false when the server sends no answer""" response = Message.fromOpenIDArgs({}) r = self.consumer._processCheckAuthResponse(response, self.server_url) self.failIf(r) def test_badResponse(self): """check_authentication returns false when is_valid is false""" response = Message.fromOpenIDArgs({'is_valid':'false',}) r = self.consumer._processCheckAuthResponse(response, self.server_url) self.failIf(r) def test_badResponseInvalidate(self): """Make sure that the handle is invalidated when is_valid is false From "Verifying directly with the OpenID Provider":: If the OP responds with "is_valid" set to "true", and "invalidate_handle" is present, the Relying Party SHOULD NOT send further authentication requests with that handle. """ self._createAssoc() response = Message.fromOpenIDArgs({ 'is_valid':'false', 'invalidate_handle':'handle', }) r = self.consumer._processCheckAuthResponse(response, self.server_url) self.failIf(r) self.failUnless( self.consumer.store.getAssociation(self.server_url) is None) def test_invalidateMissing(self): """invalidate_handle with a handle that is not present""" response = Message.fromOpenIDArgs({ 'is_valid':'true', 'invalidate_handle':'missing', }) r = self.consumer._processCheckAuthResponse(response, self.server_url) self.failUnless(r) self.failUnlessLogMatches( 'Received "invalidate_handle"' ) def test_invalidateMissing_noStore(self): """invalidate_handle with a handle that is not present""" response = Message.fromOpenIDArgs({ 'is_valid':'true', 'invalidate_handle':'missing', }) self.consumer.store = None r = self.consumer._processCheckAuthResponse(response, self.server_url) self.failUnless(r) self.failUnlessLogMatches( 'Received "invalidate_handle"', 'Unexpectedly got invalidate_handle without a store') def test_invalidatePresent(self): """invalidate_handle with a handle that exists From "Verifying directly with the OpenID Provider":: If the OP responds with "is_valid" set to "true", and "invalidate_handle" is present, the Relying Party SHOULD NOT send further authentication requests with that handle. """ self._createAssoc() response = Message.fromOpenIDArgs({ 'is_valid':'true', 'invalidate_handle':'handle', }) r = self.consumer._processCheckAuthResponse(response, self.server_url) self.failUnless(r) self.failUnless( self.consumer.store.getAssociation(self.server_url) is None) class TestSetupNeeded(TestIdRes): def failUnlessSetupNeeded(self, expected_setup_url, message): try: self.consumer._checkSetupNeeded(message) except SetupNeededError, why: self.failUnlessEqual(expected_setup_url, why.user_setup_url) else: self.fail("Expected to find an immediate-mode response") def test_setupNeededOpenID1(self): """The minimum conditions necessary to trigger Setup Needed""" setup_url = 'http://unittest/setup-here' message = Message.fromPostArgs({ 'openid.mode': 'id_res', 'openid.user_setup_url': setup_url, }) self.failUnless(message.isOpenID1()) self.failUnlessSetupNeeded(setup_url, message) def test_setupNeededOpenID1_extra(self): """Extra stuff along with setup_url still trigger Setup Needed""" setup_url = 'http://unittest/setup-here' message = Message.fromPostArgs({ 'openid.mode': 'id_res', 'openid.user_setup_url': setup_url, 'openid.identity': 'bogus', }) self.failUnless(message.isOpenID1()) self.failUnlessSetupNeeded(setup_url, message) def test_noSetupNeededOpenID1(self): """When the user_setup_url is missing on an OpenID 1 message, we assume that it's not a cancel response to checkid_immediate""" message = Message.fromOpenIDArgs({'mode': 'id_res'}) self.failUnless(message.isOpenID1()) # No SetupNeededError raised self.consumer._checkSetupNeeded(message) def test_setupNeededOpenID2(self): message = Message.fromOpenIDArgs({ 'mode':'setup_needed', 'ns':OPENID2_NS, }) self.failUnless(message.isOpenID2()) response = self.consumer.complete(message, None, None) self.failUnlessEqual('setup_needed', response.status) self.failUnlessEqual(None, response.setup_url) def test_setupNeededDoesntWorkForOpenID1(self): message = Message.fromOpenIDArgs({ 'mode':'setup_needed', }) # No SetupNeededError raised self.consumer._checkSetupNeeded(message) response = self.consumer.complete(message, None, None) self.failUnlessEqual('failure', response.status) self.failUnless(response.message.startswith('Invalid openid.mode')) def test_noSetupNeededOpenID2(self): message = Message.fromOpenIDArgs({ 'mode':'id_res', 'game':'puerto_rico', 'ns':OPENID2_NS, }) self.failUnless(message.isOpenID2()) # No SetupNeededError raised self.consumer._checkSetupNeeded(message) class IdResCheckForFieldsTest(TestIdRes): def setUp(self): self.consumer = GenericConsumer(None) def mkSuccessTest(openid_args, signed_list): def test(self): message = Message.fromOpenIDArgs(openid_args) message.setArg(OPENID_NS, 'signed', ','.join(signed_list)) self.consumer._idResCheckForFields(message) return test test_openid1Success = mkSuccessTest( {'return_to':'return', 'assoc_handle':'assoc handle', 'sig':'a signature', 'identity':'someone', }, ['return_to', 'identity']) test_openid2Success = mkSuccessTest( {'ns':OPENID2_NS, 'return_to':'return', 'assoc_handle':'assoc handle', 'sig':'a signature', 'op_endpoint':'my favourite server', 'response_nonce':'use only once', }, ['return_to', 'response_nonce', 'assoc_handle', 'op_endpoint']) test_openid2Success_identifiers = mkSuccessTest( {'ns':OPENID2_NS, 'return_to':'return', 'assoc_handle':'assoc handle', 'sig':'a signature', 'claimed_id':'i claim to be me', 'identity':'my server knows me as me', 'op_endpoint':'my favourite server', 'response_nonce':'use only once', }, ['return_to', 'response_nonce', 'identity', 'claimed_id', 'assoc_handle', 'op_endpoint']) def mkMissingFieldTest(openid_args): def test(self): message = Message.fromOpenIDArgs(openid_args) try: self.consumer._idResCheckForFields(message) except ProtocolError, why: self.failUnless(why[0].startswith('Missing required')) else: self.fail('Expected an error, but none occurred') return test def mkMissingSignedTest(openid_args): def test(self): message = Message.fromOpenIDArgs(openid_args) try: self.consumer._idResCheckForFields(message) except ProtocolError, why: self.failUnless(why[0].endswith('not signed')) else: self.fail('Expected an error, but none occurred') return test test_openid1Missing_returnToSig = mkMissingSignedTest( {'return_to':'return', 'assoc_handle':'assoc handle', 'sig':'a signature', 'identity':'someone', 'signed':'identity', }) test_openid1Missing_identitySig = mkMissingSignedTest( {'return_to':'return', 'assoc_handle':'assoc handle', 'sig':'a signature', 'identity':'someone', 'signed':'return_to' }) test_openid2Missing_opEndpointSig = mkMissingSignedTest( {'ns':OPENID2_NS, 'return_to':'return', 'assoc_handle':'assoc handle', 'sig':'a signature', 'identity':'someone', 'op_endpoint':'the endpoint', 'signed':'return_to,identity,assoc_handle' }) test_openid1MissingReturnTo = mkMissingFieldTest( {'assoc_handle':'assoc handle', 'sig':'a signature', 'identity':'someone', }) test_openid1MissingAssocHandle = mkMissingFieldTest( {'return_to':'return', 'sig':'a signature', 'identity':'someone', }) # XXX: I could go on... class CheckAuthHappened(Exception): pass class CheckNonceVerifyTest(TestIdRes, CatchLogs): def setUp(self): CatchLogs.setUp(self) TestIdRes.setUp(self) self.consumer.openid1_nonce_query_arg_name = 'nonce' def tearDown(self): CatchLogs.tearDown(self) def test_openid1Success(self): """use consumer-generated nonce""" nonce_value = mkNonce() self.return_to = 'http://rt.unittest/?nonce=%s' % (nonce_value,) self.response = Message.fromOpenIDArgs({'return_to': self.return_to}) self.response.setArg(BARE_NS, 'nonce', nonce_value) self.consumer._idResCheckNonce(self.response, self.endpoint) self.failUnlessLogEmpty() def test_openid1Missing(self): """use consumer-generated nonce""" self.response = Message.fromOpenIDArgs({}) n = self.consumer._idResGetNonceOpenID1(self.response, self.endpoint) self.failUnless(n is None, n) self.failUnlessLogEmpty() def test_consumerNonceOpenID2(self): """OpenID 2 does not use consumer-generated nonce""" self.return_to = 'http://rt.unittest/?nonce=%s' % (mkNonce(),) self.response = Message.fromOpenIDArgs( {'return_to': self.return_to, 'ns':OPENID2_NS}) self.failUnlessRaises(ProtocolError, self.consumer._idResCheckNonce, self.response, self.endpoint) self.failUnlessLogEmpty() def test_serverNonce(self): """use server-generated nonce""" self.response = Message.fromOpenIDArgs( {'ns':OPENID2_NS, 'response_nonce': mkNonce(),}) self.consumer._idResCheckNonce(self.response, self.endpoint) self.failUnlessLogEmpty() def test_serverNonceOpenID1(self): """OpenID 1 does not use server-generated nonce""" self.response = Message.fromOpenIDArgs( {'ns':OPENID1_NS, 'return_to': 'http://return.to/', 'response_nonce': mkNonce(),}) self.failUnlessRaises(ProtocolError, self.consumer._idResCheckNonce, self.response, self.endpoint) self.failUnlessLogEmpty() def test_badNonce(self): """remove the nonce from the store From "Checking the Nonce":: When the Relying Party checks the signature on an assertion, the Relying Party SHOULD ensure that an assertion has not yet been accepted with the same value for "openid.response_nonce" from the same OP Endpoint URL. """ nonce = mkNonce() stamp, salt = splitNonce(nonce) self.store.useNonce(self.server_url, stamp, salt) self.response = Message.fromOpenIDArgs( {'response_nonce': nonce, 'ns':OPENID2_NS, }) self.failUnlessRaises(ProtocolError, self.consumer._idResCheckNonce, self.response, self.endpoint) def test_successWithNoStore(self): """When there is no store, checking the nonce succeeds""" self.consumer.store = None self.response = Message.fromOpenIDArgs( {'response_nonce': mkNonce(), 'ns':OPENID2_NS, }) self.consumer._idResCheckNonce(self.response, self.endpoint) self.failUnlessLogEmpty() def test_tamperedNonce(self): """Malformed nonce""" self.response = Message.fromOpenIDArgs( {'ns':OPENID2_NS, 'response_nonce':'malformed'}) self.failUnlessRaises(ProtocolError, self.consumer._idResCheckNonce, self.response, self.endpoint) def test_missingNonce(self): """no nonce parameter on the return_to""" self.response = Message.fromOpenIDArgs( {'return_to': self.return_to}) self.failUnlessRaises(ProtocolError, self.consumer._idResCheckNonce, self.response, self.endpoint) class CheckAuthDetectingConsumer(GenericConsumer): def _checkAuth(self, *args): raise CheckAuthHappened(args) def _idResCheckNonce(self, *args): """We're not testing nonce-checking, so just return success when it asks.""" return True class TestCheckAuthTriggered(TestIdRes, CatchLogs): consumer_class = CheckAuthDetectingConsumer def setUp(self): TestIdRes.setUp(self) CatchLogs.setUp(self) self.disableDiscoveryVerification() def test_checkAuthTriggered(self): message = Message.fromPostArgs({ 'openid.return_to':self.return_to, 'openid.identity':self.server_id, 'openid.assoc_handle':'not_found', 'openid.sig': GOODSIG, 'openid.signed': 'identity,return_to', }) self.disableReturnToChecking() try: result = self.consumer._doIdRes(message, self.endpoint, None) except CheckAuthHappened: pass else: self.fail('_checkAuth did not happen. Result was: %r %s' % (result, self.messages)) def test_checkAuthTriggeredWithAssoc(self): # Store an association for this server that does not match the # handle that is in the message issued = time.time() lifetime = 1000 assoc = association.Association( 'handle', 'secret', issued, lifetime, 'HMAC-SHA1') self.store.storeAssociation(self.server_url, assoc) self.disableReturnToChecking() message = Message.fromPostArgs({ 'openid.return_to':self.return_to, 'openid.identity':self.server_id, 'openid.assoc_handle':'not_found', 'openid.sig': GOODSIG, 'openid.signed': 'identity,return_to', }) try: result = self.consumer._doIdRes(message, self.endpoint, None) except CheckAuthHappened: pass else: self.fail('_checkAuth did not happen. Result was: %r' % (result,)) def test_expiredAssoc(self): # Store an expired association for the server with the handle # that is in the message issued = time.time() - 10 lifetime = 0 handle = 'handle' assoc = association.Association( handle, 'secret', issued, lifetime, 'HMAC-SHA1') self.failUnless(assoc.expiresIn <= 0) self.store.storeAssociation(self.server_url, assoc) message = Message.fromPostArgs({ 'openid.return_to':self.return_to, 'openid.identity':self.server_id, 'openid.assoc_handle':handle, 'openid.sig': GOODSIG, 'openid.signed': 'identity,return_to', }) self.disableReturnToChecking() self.failUnlessRaises(ProtocolError, self.consumer._doIdRes, message, self.endpoint, None) def test_newerAssoc(self): lifetime = 1000 good_issued = time.time() - 10 good_handle = 'handle' good_assoc = association.Association( good_handle, 'secret', good_issued, lifetime, 'HMAC-SHA1') self.store.storeAssociation(self.server_url, good_assoc) bad_issued = time.time() - 5 bad_handle = 'handle2' bad_assoc = association.Association( bad_handle, 'secret', bad_issued, lifetime, 'HMAC-SHA1') self.store.storeAssociation(self.server_url, bad_assoc) query = { 'return_to':self.return_to, 'identity':self.server_id, 'assoc_handle':good_handle, } message = Message.fromOpenIDArgs(query) message = good_assoc.signMessage(message) self.disableReturnToChecking() info = self.consumer._doIdRes(message, self.endpoint, None) self.failUnlessEqual(info.status, SUCCESS, info.message) self.failUnlessEqual(self.consumer_id, info.identity_url) class TestReturnToArgs(unittest.TestCase): """Verifying the Return URL paramaters. From the specification "Verifying the Return URL":: To verify that the "openid.return_to" URL matches the URL that is processing this assertion: - The URL scheme, authority, and path MUST be the same between the two URLs. - Any query parameters that are present in the "openid.return_to" URL MUST also be present with the same values in the accepting URL. XXX: So far we have only tested the second item on the list above. XXX: _verifyReturnToArgs is not invoked anywhere. """ def setUp(self): store = object() self.consumer = GenericConsumer(store) def test_returnToArgsOkay(self): query = { 'openid.mode': 'id_res', 'openid.return_to': 'http://example.com/?foo=bar', 'foo': 'bar', } # no return value, success is assumed if there are no exceptions. self.consumer._verifyReturnToArgs(query) def test_returnToArgsUnexpectedArg(self): query = { 'openid.mode': 'id_res', 'openid.return_to': 'http://example.com/', 'foo': 'bar', } # no return value, success is assumed if there are no exceptions. self.failUnlessRaises(ProtocolError, self.consumer._verifyReturnToArgs, query) def test_returnToMismatch(self): query = { 'openid.mode': 'id_res', 'openid.return_to': 'http://example.com/?foo=bar', } # fail, query has no key 'foo'. self.failUnlessRaises(ValueError, self.consumer._verifyReturnToArgs, query) query['foo'] = 'baz' # fail, values for 'foo' do not match. self.failUnlessRaises(ValueError, self.consumer._verifyReturnToArgs, query) def test_noReturnTo(self): query = {'openid.mode': 'id_res'} self.failUnlessRaises(ValueError, self.consumer._verifyReturnToArgs, query) def test_completeBadReturnTo(self): """Test GenericConsumer.complete()'s handling of bad return_to values. """ return_to = "http://some.url/path?foo=bar" # Scheme, authority, and path differences are checked by # GenericConsumer._checkReturnTo. Query args checked by # GenericConsumer._verifyReturnToArgs. bad_return_tos = [ # Scheme only "https://some.url/path?foo=bar", # Authority only "http://some.url.invalid/path?foo=bar", # Path only "http://some.url/path_extra?foo=bar", # Query args differ "http://some.url/path?foo=bar2", "http://some.url/path?foo2=bar", ] m = Message(OPENID1_NS) m.setArg(OPENID_NS, 'mode', 'cancel') m.setArg(BARE_NS, 'foo', 'bar') endpoint = None for bad in bad_return_tos: m.setArg(OPENID_NS, 'return_to', bad) self.failIf(self.consumer._checkReturnTo(m, return_to)) def test_completeGoodReturnTo(self): """Test GenericConsumer.complete()'s handling of good return_to values. """ return_to = "http://some.url/path" good_return_tos = [ (return_to, {}), (return_to + "?another=arg", {(BARE_NS, 'another'): 'arg'}), (return_to + "?another=arg#fragment", {(BARE_NS, 'another'): 'arg'}), ("HTTP"+return_to[4:], {}), (return_to.replace('url','URL'), {}), ("http://some.url:80/path", {}), ("http://some.url/p%61th", {}), ("http://some.url/./path", {}), ] endpoint = None for good, extra in good_return_tos: m = Message(OPENID1_NS) m.setArg(OPENID_NS, 'mode', 'cancel') for ns, key in extra: m.setArg(ns, key, extra[(ns, key)]) m.setArg(OPENID_NS, 'return_to', good) result = self.consumer.complete(m, endpoint, return_to) self.failUnless(isinstance(result, CancelResponse), \ "Expected CancelResponse, got %r for %s" % (result, good,)) class MockFetcher(object): def __init__(self, response=None): self.response = response or HTTPResponse() self.fetches = [] def fetch(self, url, body=None, headers=None): self.fetches.append((url, body, headers)) return self.response class ExceptionRaisingMockFetcher(object): class MyException(Exception): pass def fetch(self, url, body=None, headers=None): raise self.MyException('mock fetcher exception') class BadArgCheckingConsumer(GenericConsumer): def _makeKVPost(self, args, _): assert args == { 'openid.mode':'check_authentication', 'openid.signed':'foo', 'openid.ns':OPENID1_NS }, args return None class TestCheckAuth(unittest.TestCase, CatchLogs): consumer_class = GenericConsumer def setUp(self): CatchLogs.setUp(self) self.store = memstore.MemoryStore() self.consumer = self.consumer_class(self.store) self._orig_fetcher = fetchers.getDefaultFetcher() self.fetcher = MockFetcher() fetchers.setDefaultFetcher(self.fetcher) def tearDown(self): CatchLogs.tearDown(self) fetchers.setDefaultFetcher(self._orig_fetcher, wrap_exceptions=False) def test_error(self): self.fetcher.response = HTTPResponse( "http://some_url", 404, {'Hea': 'der'}, 'blah:blah\n') query = {'openid.signed': 'stuff', 'openid.stuff':'a value'} r = self.consumer._checkAuth(Message.fromPostArgs(query), http_server_url) self.failIf(r) self.failUnless(self.messages) def test_bad_args(self): query = { 'openid.signed':'foo', 'closid.foo':'something', } consumer = BadArgCheckingConsumer(self.store) consumer._checkAuth(Message.fromPostArgs(query), 'does://not.matter') def test_signedList(self): query = Message.fromOpenIDArgs({ 'mode': 'id_res', 'sig': 'rabbits', 'identity': '=example', 'assoc_handle': 'munchkins', 'ns.sreg': 'urn:sreg', 'sreg.email': 'bogus@example.com', 'signed': 'identity,mode,ns.sreg,sreg.email', 'foo': 'bar', }) args = self.consumer._createCheckAuthRequest(query) self.failUnless(args.isOpenID1()) for signed_arg in query.getArg(OPENID_NS, 'signed').split(','): self.failUnless(args.getAliasedArg(signed_arg), signed_arg) def test_112(self): args = {'openid.assoc_handle': 'fa1f5ff0-cde4-11dc-a183-3714bfd55ca8', 'openid.claimed_id': 'http://binkley.lan/user/test01', 'openid.identity': 'http://test01.binkley.lan/', 'openid.mode': 'id_res', 'openid.ns': 'http://specs.openid.net/auth/2.0', 'openid.ns.pape': 'http://specs.openid.net/extensions/pape/1.0', 'openid.op_endpoint': 'http://binkley.lan/server', 'openid.pape.auth_policies': 'none', 'openid.pape.auth_time': '2008-01-28T20:42:36Z', 'openid.pape.nist_auth_level': '0', 'openid.response_nonce': '2008-01-28T21:07:04Z99Q=', 'openid.return_to': 'http://binkley.lan:8001/process?janrain_nonce=2008-01-28T21%3A07%3A02Z0tMIKx', 'openid.sig': 'YJlWH4U6SroB1HoPkmEKx9AyGGg=', 'openid.signed': 'assoc_handle,identity,response_nonce,return_to,claimed_id,op_endpoint,pape.auth_time,ns.pape,pape.nist_auth_level,pape.auth_policies' } self.failUnlessEqual(OPENID2_NS, args['openid.ns']) incoming = Message.fromPostArgs(args) self.failUnless(incoming.isOpenID2()) car = self.consumer._createCheckAuthRequest(incoming) expected_args = args.copy() expected_args['openid.mode'] = 'check_authentication' expected =Message.fromPostArgs(expected_args) self.failUnless(expected.isOpenID2()) self.failUnlessEqual(expected, car) self.failUnlessEqual(expected_args, car.toPostArgs()) class TestFetchAssoc(unittest.TestCase, CatchLogs): consumer_class = GenericConsumer def setUp(self): CatchLogs.setUp(self) self.store = memstore.MemoryStore() self.fetcher = MockFetcher() fetchers.setDefaultFetcher(self.fetcher) self.consumer = self.consumer_class(self.store) def test_error_404(self): """404 from a kv post raises HTTPFetchingError""" self.fetcher.response = HTTPResponse( "http://some_url", 404, {'Hea': 'der'}, 'blah:blah\n') self.failUnlessRaises( fetchers.HTTPFetchingError, self.consumer._makeKVPost, Message.fromPostArgs({'mode':'associate'}), "http://server_url") def test_error_exception_unwrapped(self): """Ensure that exceptions are bubbled through from fetchers when making associations """ self.fetcher = ExceptionRaisingMockFetcher() fetchers.setDefaultFetcher(self.fetcher, wrap_exceptions=False) self.failUnlessRaises(self.fetcher.MyException, self.consumer._makeKVPost, Message.fromPostArgs({'mode':'associate'}), "http://server_url") # exception fetching returns no association e = OpenIDServiceEndpoint() e.server_url = 'some://url' self.failUnlessRaises(self.fetcher.MyException, self.consumer._getAssociation, e) self.failUnlessRaises(self.fetcher.MyException, self.consumer._checkAuth, Message.fromPostArgs({'openid.signed':''}), 'some://url') def test_error_exception_wrapped(self): """Ensure that openid.fetchers.HTTPFetchingError is caught by the association creation stuff. """ self.fetcher = ExceptionRaisingMockFetcher() # This will wrap exceptions! fetchers.setDefaultFetcher(self.fetcher) self.failUnlessRaises(fetchers.HTTPFetchingError, self.consumer._makeKVPost, Message.fromOpenIDArgs({'mode':'associate'}), "http://server_url") # exception fetching returns no association e = OpenIDServiceEndpoint() e.server_url = 'some://url' self.failUnless(self.consumer._getAssociation(e) is None) msg = Message.fromPostArgs({'openid.signed':''}) self.failIf(self.consumer._checkAuth(msg, 'some://url')) class TestSuccessResponse(unittest.TestCase): def setUp(self): self.endpoint = OpenIDServiceEndpoint() self.endpoint.claimed_id = 'identity_url' def test_extensionResponse(self): resp = mkSuccess(self.endpoint, { 'ns.sreg':'urn:sreg', 'ns.unittest':'urn:unittest', 'unittest.one':'1', 'unittest.two':'2', 'sreg.nickname':'j3h', 'return_to':'return_to', }) utargs = resp.extensionResponse('urn:unittest', False) self.failUnlessEqual(utargs, {'one':'1', 'two':'2'}) sregargs = resp.extensionResponse('urn:sreg', False) self.failUnlessEqual(sregargs, {'nickname':'j3h'}) def test_extensionResponseSigned(self): args = { 'ns.sreg':'urn:sreg', 'ns.unittest':'urn:unittest', 'unittest.one':'1', 'unittest.two':'2', 'sreg.nickname':'j3h', 'sreg.dob':'yesterday', 'return_to':'return_to', 'signed': 'sreg.nickname,unittest.one,sreg.dob', } signed_list = ['openid.sreg.nickname', 'openid.unittest.one', 'openid.sreg.dob',] # Don't use mkSuccess because it creates an all-inclusive # signed list. msg = Message.fromOpenIDArgs(args) resp = SuccessResponse(self.endpoint, msg, signed_list) # All args in this NS are signed, so expect all. sregargs = resp.extensionResponse('urn:sreg', True) self.failUnlessEqual(sregargs, {'nickname':'j3h', 'dob': 'yesterday'}) # Not all args in this NS are signed, so expect None when # asking for them. utargs = resp.extensionResponse('urn:unittest', True) self.failUnlessEqual(utargs, None) def test_noReturnTo(self): resp = mkSuccess(self.endpoint, {}) self.failUnless(resp.getReturnTo() is None) def test_returnTo(self): resp = mkSuccess(self.endpoint, {'return_to':'return_to'}) self.failUnlessEqual(resp.getReturnTo(), 'return_to') def test_displayIdentifierClaimedId(self): resp = mkSuccess(self.endpoint, {}) self.failUnlessEqual(resp.getDisplayIdentifier(), resp.endpoint.claimed_id) def test_displayIdentifierOverride(self): self.endpoint.display_identifier = "http://input.url/" resp = mkSuccess(self.endpoint, {}) self.failUnlessEqual(resp.getDisplayIdentifier(), "http://input.url/") class StubConsumer(object): def __init__(self): self.assoc = object() self.response = None self.endpoint = None def begin(self, service): auth_req = AuthRequest(service, self.assoc) self.endpoint = service return auth_req def complete(self, message, endpoint, return_to): assert endpoint is self.endpoint return self.response class ConsumerTest(unittest.TestCase): """Tests for high-level consumer.Consumer functions. Its GenericConsumer component is stubbed out with StubConsumer. """ def setUp(self): self.endpoint = OpenIDServiceEndpoint() self.endpoint.claimed_id = self.identity_url = 'http://identity.url/' self.store = None self.session = {} self.consumer = Consumer(self.session, self.store) self.consumer.consumer = StubConsumer() self.discovery = Discovery(self.session, self.identity_url, self.consumer.session_key_prefix) def test_setAssociationPreference(self): self.consumer.setAssociationPreference([]) self.failUnless(isinstance(self.consumer.consumer.negotiator, association.SessionNegotiator)) self.failUnlessEqual([], self.consumer.consumer.negotiator.allowed_types) self.consumer.setAssociationPreference([('HMAC-SHA1', 'DH-SHA1')]) self.failUnlessEqual([('HMAC-SHA1', 'DH-SHA1')], self.consumer.consumer.negotiator.allowed_types) def withDummyDiscovery(self, callable, dummy_getNextService): class DummyDisco(object): def __init__(self, *ignored): pass getNextService = dummy_getNextService import openid.consumer.consumer old_discovery = openid.consumer.consumer.Discovery try: openid.consumer.consumer.Discovery = DummyDisco callable() finally: openid.consumer.consumer.Discovery = old_discovery def test_beginHTTPError(self): """Make sure that the discovery HTTP failure case behaves properly """ def getNextService(self, ignored): raise HTTPFetchingError("Unit test") def test(): try: self.consumer.begin('unused in this test') except DiscoveryFailure, why: self.failUnless(why[0].startswith('Error fetching')) self.failIf(why[0].find('Unit test') == -1) else: self.fail('Expected DiscoveryFailure') self.withDummyDiscovery(test, getNextService) def test_beginNoServices(self): def getNextService(self, ignored): return None url = 'http://a.user.url/' def test(): try: self.consumer.begin(url) except DiscoveryFailure, why: self.failUnless(why[0].startswith('No usable OpenID')) self.failIf(why[0].find(url) == -1) else: self.fail('Expected DiscoveryFailure') self.withDummyDiscovery(test, getNextService) def test_beginWithoutDiscovery(self): # Does this really test anything non-trivial? result = self.consumer.beginWithoutDiscovery(self.endpoint) # The result is an auth request self.failUnless(isinstance(result, AuthRequest)) # Side-effect of calling beginWithoutDiscovery is setting the # session value to the endpoint attribute of the result self.failUnless(self.session[self.consumer._token_key] is result.endpoint) # The endpoint that we passed in is the endpoint on the auth_request self.failUnless(result.endpoint is self.endpoint) def test_completeEmptySession(self): text = "failed complete" def checkEndpoint(message, endpoint, return_to): self.failUnless(endpoint is None) return FailureResponse(endpoint, text) self.consumer.consumer.complete = checkEndpoint response = self.consumer.complete({}, None) self.failUnlessEqual(response.status, FAILURE) self.failUnlessEqual(response.message, text) self.failUnless(response.identity_url is None) def _doResp(self, auth_req, exp_resp): """complete a transaction, using the expected response from the generic consumer.""" # response is an attribute of StubConsumer, returned by # StubConsumer.complete. self.consumer.consumer.response = exp_resp # endpoint is stored in the session self.failUnless(self.session) resp = self.consumer.complete({}, None) # All responses should have the same identity URL, and the # session should be cleaned out if self.endpoint.claimed_id != IDENTIFIER_SELECT: self.failUnless(resp.identity_url is self.identity_url) self.failIf(self.consumer._token_key in self.session) # Expected status response self.failUnlessEqual(resp.status, exp_resp.status) return resp def _doRespNoDisco(self, exp_resp): """Set up a transaction without discovery""" auth_req = self.consumer.beginWithoutDiscovery(self.endpoint) resp = self._doResp(auth_req, exp_resp) # There should be nothing left in the session once we have completed. self.failIf(self.session) return resp def test_noDiscoCompleteSuccessWithToken(self): self._doRespNoDisco(mkSuccess(self.endpoint, {})) def test_noDiscoCompleteCancelWithToken(self): self._doRespNoDisco(CancelResponse(self.endpoint)) def test_noDiscoCompleteFailure(self): msg = 'failed!' resp = self._doRespNoDisco(FailureResponse(self.endpoint, msg)) self.failUnless(resp.message is msg) def test_noDiscoCompleteSetupNeeded(self): setup_url = 'http://setup.url/' resp = self._doRespNoDisco( SetupNeededResponse(self.endpoint, setup_url)) self.failUnless(resp.setup_url is setup_url) # To test that discovery is cleaned up, we need to initialize a # Yadis manager, and have it put its values in the session. def _doRespDisco(self, is_clean, exp_resp): """Set up and execute a transaction, with discovery""" self.discovery.createManager([self.endpoint], self.identity_url) auth_req = self.consumer.begin(self.identity_url) resp = self._doResp(auth_req, exp_resp) manager = self.discovery.getManager() if is_clean: self.failUnless(self.discovery.getManager() is None, manager) else: self.failIf(self.discovery.getManager() is None, manager) return resp # Cancel and success DO clean up the discovery process def test_completeSuccess(self): self._doRespDisco(True, mkSuccess(self.endpoint, {})) def test_completeCancel(self): self._doRespDisco(True, CancelResponse(self.endpoint)) # Failure and setup_needed don't clean up the discovery process def test_completeFailure(self): msg = 'failed!' resp = self._doRespDisco(False, FailureResponse(self.endpoint, msg)) self.failUnless(resp.message is msg) def test_completeSetupNeeded(self): setup_url = 'http://setup.url/' resp = self._doRespDisco( False, SetupNeededResponse(self.endpoint, setup_url)) self.failUnless(resp.setup_url is setup_url) def test_successDifferentURL(self): """ Be sure that the session gets cleaned up when the response is successful and has a different URL than the one in the request. """ # Set up a request endpoint describing an IDP URL self.identity_url = 'http://idp.url/' self.endpoint.claimed_id = self.endpoint.local_id = IDENTIFIER_SELECT # Use a response endpoint with a different URL (asserted by # the IDP) resp_endpoint = OpenIDServiceEndpoint() resp_endpoint.claimed_id = "http://user.url/" resp = self._doRespDisco( True, mkSuccess(resp_endpoint, {})) self.failUnless(self.discovery.getManager(force=True) is None) def test_begin(self): self.discovery.createManager([self.endpoint], self.identity_url) # Should not raise an exception auth_req = self.consumer.begin(self.identity_url) self.failUnless(isinstance(auth_req, AuthRequest)) self.failUnless(auth_req.endpoint is self.endpoint) self.failUnless(auth_req.endpoint is self.consumer.consumer.endpoint) self.failUnless(auth_req.assoc is self.consumer.consumer.assoc) class IDPDrivenTest(unittest.TestCase): def setUp(self): self.store = GoodAssocStore() self.consumer = GenericConsumer(self.store) self.endpoint = OpenIDServiceEndpoint() self.endpoint.server_url = "http://idp.unittest/" def test_idpDrivenBegin(self): # Testing here that the token-handling doesn't explode... self.consumer.begin(self.endpoint) def test_idpDrivenComplete(self): identifier = '=directed_identifier' message = Message.fromPostArgs({ 'openid.identity': '=directed_identifier', 'openid.return_to': 'x', 'openid.assoc_handle': 'z', 'openid.signed': 'identity,return_to', 'openid.sig': GOODSIG, }) discovered_endpoint = OpenIDServiceEndpoint() discovered_endpoint.claimed_id = identifier discovered_endpoint.server_url = self.endpoint.server_url discovered_endpoint.local_id = identifier iverified = [] def verifyDiscoveryResults(identifier, endpoint): self.failUnless(endpoint is self.endpoint) iverified.append(discovered_endpoint) return discovered_endpoint self.consumer._verifyDiscoveryResults = verifyDiscoveryResults self.consumer._idResCheckNonce = lambda *args: True self.consumer._checkReturnTo = lambda unused1, unused2 : True response = self.consumer._doIdRes(message, self.endpoint, None) self.failUnlessSuccess(response) self.failUnlessEqual(response.identity_url, "=directed_identifier") # assert that discovery attempt happens and returns good self.failUnlessEqual(iverified, [discovered_endpoint]) def test_idpDrivenCompleteFraud(self): # crap with an identifier that doesn't match discovery info message = Message.fromPostArgs({ 'openid.identity': '=directed_identifier', 'openid.return_to': 'x', 'openid.assoc_handle': 'z', 'openid.signed': 'identity,return_to', 'openid.sig': GOODSIG, }) def verifyDiscoveryResults(identifier, endpoint): raise DiscoveryFailure("PHREAK!", None) self.consumer._verifyDiscoveryResults = verifyDiscoveryResults self.consumer._checkReturnTo = lambda unused1, unused2 : True self.failUnlessRaises(DiscoveryFailure, self.consumer._doIdRes, message, self.endpoint, None) def failUnlessSuccess(self, response): if response.status != SUCCESS: self.fail("Non-successful response: %s" % (response,)) class TestDiscoveryVerification(unittest.TestCase): services = [] def setUp(self): self.store = GoodAssocStore() self.consumer = GenericConsumer(self.store) self.consumer._discover = self.discoveryFunc self.identifier = "http://idp.unittest/1337" self.server_url = "http://endpoint.unittest/" self.message = Message.fromPostArgs({ 'openid.ns': OPENID2_NS, 'openid.identity': self.identifier, 'openid.claimed_id': self.identifier, 'openid.op_endpoint': self.server_url, }) self.endpoint = OpenIDServiceEndpoint() self.endpoint.server_url = self.server_url def test_theGoodStuff(self): endpoint = OpenIDServiceEndpoint() endpoint.type_uris = [OPENID_2_0_TYPE] endpoint.claimed_id = self.identifier endpoint.server_url = self.server_url endpoint.local_id = self.identifier self.services = [endpoint] r = self.consumer._verifyDiscoveryResults(self.message, endpoint) self.failUnlessEqual(r, endpoint) def test_otherServer(self): text = "verify failed" def discoverAndVerify(claimed_id, to_match_endpoints): self.failUnlessEqual(claimed_id, self.identifier) for to_match in to_match_endpoints: self.failUnlessEqual(claimed_id, to_match.claimed_id) raise ProtocolError(text) self.consumer._discoverAndVerify = discoverAndVerify # a set of things without the stuff endpoint = OpenIDServiceEndpoint() endpoint.type_uris = [OPENID_2_0_TYPE] endpoint.claimed_id = self.identifier endpoint.server_url = "http://the-MOON.unittest/" endpoint.local_id = self.identifier self.services = [endpoint] try: r = self.consumer._verifyDiscoveryResults(self.message, endpoint) except ProtocolError, e: # Should we make more ProtocolError subclasses? self.failUnless(str(e), text) else: self.fail("expected ProtocolError, %r returned." % (r,)) def test_foreignDelegate(self): text = "verify failed" def discoverAndVerify(claimed_id, to_match_endpoints): self.failUnlessEqual(claimed_id, self.identifier) for to_match in to_match_endpoints: self.failUnlessEqual(claimed_id, to_match.claimed_id) raise ProtocolError(text) self.consumer._discoverAndVerify = discoverAndVerify # a set of things with the server stuff but other delegate endpoint = OpenIDServiceEndpoint() endpoint.type_uris = [OPENID_2_0_TYPE] endpoint.claimed_id = self.identifier endpoint.server_url = self.server_url endpoint.local_id = "http://unittest/juan-carlos" try: r = self.consumer._verifyDiscoveryResults(self.message, endpoint) except ProtocolError, e: self.failUnlessEqual(str(e), text) else: self.fail("Exepected ProtocolError, %r returned" % (r,)) def test_nothingDiscovered(self): # a set of no things. self.services = [] self.failUnlessRaises(DiscoveryFailure, self.consumer._verifyDiscoveryResults, self.message, self.endpoint) def discoveryFunc(self, identifier): return identifier, self.services class TestCreateAssociationRequest(unittest.TestCase): def setUp(self): class DummyEndpoint(object): use_compatibility = False def compatibilityMode(self): return self.use_compatibility self.endpoint = DummyEndpoint() self.consumer = GenericConsumer(store=None) self.assoc_type = 'HMAC-SHA1' def test_noEncryptionSendsType(self): session_type = 'no-encryption' session, args = self.consumer._createAssociateRequest( self.endpoint, self.assoc_type, session_type) self.failUnless(isinstance(session, PlainTextConsumerSession)) expected = Message.fromOpenIDArgs( {'ns':OPENID2_NS, 'session_type':session_type, 'mode':'associate', 'assoc_type':self.assoc_type, }) self.failUnlessEqual(expected, args) def test_noEncryptionCompatibility(self): self.endpoint.use_compatibility = True session_type = 'no-encryption' session, args = self.consumer._createAssociateRequest( self.endpoint, self.assoc_type, session_type) self.failUnless(isinstance(session, PlainTextConsumerSession)) self.failUnlessEqual(Message.fromOpenIDArgs({'mode':'associate', 'assoc_type':self.assoc_type, }), args) def test_dhSHA1Compatibility(self): # Set the consumer's session type to a fast session since we # need it here. setConsumerSession(self.consumer) self.endpoint.use_compatibility = True session_type = 'DH-SHA1' session, args = self.consumer._createAssociateRequest( self.endpoint, self.assoc_type, session_type) self.failUnless(isinstance(session, DiffieHellmanSHA1ConsumerSession)) # This is a random base-64 value, so just check that it's # present. self.failUnless(args.getArg(OPENID1_NS, 'dh_consumer_public')) args.delArg(OPENID1_NS, 'dh_consumer_public') # OK, session_type is set here and not for no-encryption # compatibility expected = Message.fromOpenIDArgs({'mode':'associate', 'session_type':'DH-SHA1', 'assoc_type':self.assoc_type, 'dh_modulus': 'BfvStQ==', 'dh_gen': 'Ag==', }) self.failUnlessEqual(expected, args) # XXX: test the other types class TestDiffieHellmanResponseParameters(object): session_cls = None message_namespace = None def setUp(self): # Pre-compute DH with small prime so tests run quickly. self.server_dh = DiffieHellman(100389557, 2) self.consumer_dh = DiffieHellman(100389557, 2) # base64(btwoc(g ^ xb mod p)) self.dh_server_public = cryptutil.longToBase64(self.server_dh.public) self.secret = cryptutil.randomString(self.session_cls.secret_size) self.enc_mac_key = oidutil.toBase64( self.server_dh.xorSecret(self.consumer_dh.public, self.secret, self.session_cls.hash_func)) self.consumer_session = self.session_cls(self.consumer_dh) self.msg = Message(self.message_namespace) def testExtractSecret(self): self.msg.setArg(OPENID_NS, 'dh_server_public', self.dh_server_public) self.msg.setArg(OPENID_NS, 'enc_mac_key', self.enc_mac_key) extracted = self.consumer_session.extractSecret(self.msg) self.failUnlessEqual(extracted, self.secret) def testAbsentServerPublic(self): self.msg.setArg(OPENID_NS, 'enc_mac_key', self.enc_mac_key) self.failUnlessRaises(KeyError, self.consumer_session.extractSecret, self.msg) def testAbsentMacKey(self): self.msg.setArg(OPENID_NS, 'dh_server_public', self.dh_server_public) self.failUnlessRaises(KeyError, self.consumer_session.extractSecret, self.msg) def testInvalidBase64Public(self): self.msg.setArg(OPENID_NS, 'dh_server_public', 'n o t b a s e 6 4.') self.msg.setArg(OPENID_NS, 'enc_mac_key', self.enc_mac_key) self.failUnlessRaises(ValueError, self.consumer_session.extractSecret, self.msg) def testInvalidBase64MacKey(self): self.msg.setArg(OPENID_NS, 'dh_server_public', self.dh_server_public) self.msg.setArg(OPENID_NS, 'enc_mac_key', 'n o t base 64') self.failUnlessRaises(ValueError, self.consumer_session.extractSecret, self.msg) class TestOpenID1SHA1(TestDiffieHellmanResponseParameters, unittest.TestCase): session_cls = DiffieHellmanSHA1ConsumerSession message_namespace = OPENID1_NS class TestOpenID2SHA1(TestDiffieHellmanResponseParameters, unittest.TestCase): session_cls = DiffieHellmanSHA1ConsumerSession message_namespace = OPENID2_NS if cryptutil.SHA256_AVAILABLE: class TestOpenID2SHA256(TestDiffieHellmanResponseParameters, unittest.TestCase): session_cls = DiffieHellmanSHA256ConsumerSession message_namespace = OPENID2_NS else: warnings.warn("Not running SHA256 association session tests.") class TestNoStore(unittest.TestCase): def setUp(self): self.consumer = GenericConsumer(None) def test_completeNoGetAssoc(self): """_getAssociation is never called when the store is None""" def notCalled(unused): self.fail('This method was unexpectedly called') endpoint = OpenIDServiceEndpoint() endpoint.claimed_id = 'identity_url' self.consumer._getAssociation = notCalled auth_request = self.consumer.begin(endpoint) # _getAssociation was not called class NonAnonymousAuthRequest(object): endpoint = 'unused' def setAnonymous(self, unused): raise ValueError('Should trigger ProtocolError') class TestConsumerAnonymous(unittest.TestCase): def test_beginWithoutDiscoveryAnonymousFail(self): """Make sure that ValueError for setting an auth request anonymous gets converted to a ProtocolError """ sess = {} consumer = Consumer(sess, None) def bogusBegin(unused): return NonAnonymousAuthRequest() consumer.consumer.begin = bogusBegin self.failUnlessRaises( ProtocolError, consumer.beginWithoutDiscovery, None) class TestDiscoverAndVerify(unittest.TestCase): def setUp(self): self.consumer = GenericConsumer(None) self.discovery_result = None def dummyDiscover(unused_identifier): return self.discovery_result self.consumer._discover = dummyDiscover self.to_match = OpenIDServiceEndpoint() def failUnlessDiscoveryFailure(self): self.failUnlessRaises( DiscoveryFailure, self.consumer._discoverAndVerify, 'http://claimed-id.com/', [self.to_match]) def test_noServices(self): """Discovery returning no results results in a DiscoveryFailure exception""" self.discovery_result = (None, []) self.failUnlessDiscoveryFailure() def test_noMatches(self): """If no discovered endpoint matches the values from the assertion, then we end up raising a ProtocolError """ self.discovery_result = (None, ['unused']) def raiseProtocolError(unused1, unused2): raise ProtocolError('unit test') self.consumer._verifyDiscoverySingle = raiseProtocolError self.failUnlessDiscoveryFailure() def test_matches(self): """If an endpoint matches, we return it """ # Discovery returns a single "endpoint" object matching_endpoint = 'matching endpoint' self.discovery_result = (None, [matching_endpoint]) # Make verifying discovery return True for this endpoint def returnTrue(unused1, unused2): return True self.consumer._verifyDiscoverySingle = returnTrue # Since _verifyDiscoverySingle returns True, we should get the # first endpoint that we passed in as a result. result = self.consumer._discoverAndVerify( 'http://claimed.id/', [self.to_match]) self.failUnlessEqual(matching_endpoint, result) from openid.extension import Extension class SillyExtension(Extension): ns_uri = 'http://silly.example.com/' ns_alias = 'silly' def getExtensionArgs(self): return {'i_am':'silly'} class TestAddExtension(unittest.TestCase): def test_SillyExtension(self): ext = SillyExtension() ar = AuthRequest(OpenIDServiceEndpoint(), None) ar.addExtension(ext) ext_args = ar.message.getArgs(ext.ns_uri) self.failUnlessEqual(ext.getExtensionArgs(), ext_args) class TestKVPost(unittest.TestCase): def setUp(self): self.server_url = 'http://unittest/%s' % (self.id(),) def test_200(self): from openid.fetchers import HTTPResponse response = HTTPResponse() response.status = 200 response.body = "foo:bar\nbaz:quux\n" r = _httpResponseToMessage(response, self.server_url) expected_msg = Message.fromOpenIDArgs({'foo':'bar','baz':'quux'}) self.failUnlessEqual(expected_msg, r) def test_400(self): response = HTTPResponse() response.status = 400 response.body = "error:bonk\nerror_code:7\n" try: r = _httpResponseToMessage(response, self.server_url) except ServerError, e: self.failUnlessEqual(e.error_text, 'bonk') self.failUnlessEqual(e.error_code, '7') else: self.fail("Expected ServerError, got return %r" % (r,)) def test_500(self): # 500 as an example of any non-200, non-400 code. response = HTTPResponse() response.status = 500 response.body = "foo:bar\nbaz:quux\n" self.failUnlessRaises(fetchers.HTTPFetchingError, _httpResponseToMessage, response, self.server_url) if __name__ == '__main__': unittest.main()