Skip to content

Commit

Permalink
Explicitly check authentication for v2 endpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep committed Mar 22, 2021
1 parent 72133ea commit ce04a68
Show file tree
Hide file tree
Showing 16 changed files with 101 additions and 87 deletions.
34 changes: 16 additions & 18 deletions sydent/http/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def tokenFromRequest(request):
return token


def authIfV2(sydent, request, requireTermsAgreed=True):
def authV2(sydent, request, requireTermsAgreed=True):
"""For v2 APIs check that the request has a valid access token associated with it
:param sydent: The Sydent instance to use.
Expand All @@ -67,25 +67,23 @@ def authIfV2(sydent, request, requireTermsAgreed=True):
:raises MatrixRestError: If the request is v2 but could not be authed or the user has
not accepted terms.
"""
if request.path.startswith(b'/_matrix/identity/v2'):
token = tokenFromRequest(request)
token = tokenFromRequest(request)

if token is None:
raise MatrixRestError(401, "M_UNAUTHORIZED", "Unauthorized")
if token is None:
raise MatrixRestError(401, "M_UNAUTHORIZED", "Unauthorized")

accountStore = AccountStore(sydent)
accountStore = AccountStore(sydent)

account = accountStore.getAccountByToken(token)
if account is None:
raise MatrixRestError(401, "M_UNAUTHORIZED", "Unauthorized")
account = accountStore.getAccountByToken(token)
if account is None:
raise MatrixRestError(401, "M_UNAUTHORIZED", "Unauthorized")

if requireTermsAgreed:
terms = get_terms(sydent)
if (
terms.getMasterVersion() is not None and
account.consentVersion != terms.getMasterVersion()
):
raise MatrixRestError(403, "M_TERMS_NOT_SIGNED", "Terms not signed")
if requireTermsAgreed:
terms = get_terms(sydent)
if (
terms.getMasterVersion() is not None and
account.consentVersion != terms.getMasterVersion()
):
raise MatrixRestError(403, "M_TERMS_NOT_SIGNED", "Terms not signed")

return account
return None
return account
58 changes: 28 additions & 30 deletions sydent/http/httpserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,30 +45,19 @@ def __init__(self, sydent):
v2 = self.sydent.servlets.v2

validate = Resource()
validate_v2 = Resource()
email = Resource()
email_v2 = Resource()
msisdn = Resource()
emailReqCode = self.sydent.servlets.emailRequestCode
emailValCode = self.sydent.servlets.emailValidate
msisdnReqCode = self.sydent.servlets.msisdnRequestCode
msisdnValCode = self.sydent.servlets.msisdnValidate
getValidated3pid = self.sydent.servlets.getValidated3pid

lookup = self.sydent.servlets.lookup
bulk_lookup = self.sydent.servlets.bulk_lookup

hash_details = self.sydent.servlets.hash_details
lookup_v2 = self.sydent.servlets.lookup_v2
msisdn_v2 = Resource()

threepid_v1 = Resource()
threepid_v2 = Resource()
bind = self.sydent.servlets.threepidBind
unbind = self.sydent.servlets.threepidUnbind

pubkey = Resource()
ephemeralPubkey = Resource()

pk_ed25519 = self.sydent.servlets.pubkey_ed25519

root.putChild(b'_matrix', matrix)
matrix.putChild(b'identity', identity)
identity.putChild(b'api', api)
Expand All @@ -78,33 +67,42 @@ def __init__(self, sydent):
validate.putChild(b'email', email)
validate.putChild(b'msisdn', msisdn)

validate_v2.putChild(b'email', email_v2)
validate_v2.putChild(b'msisdn', msisdn_v2)

v1.putChild(b'validate', validate)

v1.putChild(b'lookup', lookup)
v1.putChild(b'bulk_lookup', bulk_lookup)
v1.putChild(b'lookup', self.sydent.servlets.lookup)
v1.putChild(b'bulk_lookup', self.sydent.servlets.bulk_lookup)

v1.putChild(b'pubkey', pubkey)
pubkey.putChild(b'isvalid', self.sydent.servlets.pubkeyIsValid)
pubkey.putChild(b'ed25519:0', pk_ed25519)
pubkey.putChild(b'ed25519:0', self.sydent.servlets.pubkey_ed25519)
pubkey.putChild(b'ephemeral', ephemeralPubkey)
ephemeralPubkey.putChild(b'isvalid', self.sydent.servlets.ephemeralPubkeyIsValid)

threepid_v2.putChild(b'getValidated3pid', getValidated3pid)
threepid_v2.putChild(b'bind', bind)
threepid_v2.putChild(b'getValidated3pid', self.sydent.servlets.getValidated3pidV2)
threepid_v2.putChild(b'bind', self.sydent.servlets.threepidBindV2)
threepid_v2.putChild(b'unbind', unbind)

threepid_v1.putChild(b'getValidated3pid', getValidated3pid)
threepid_v1.putChild(b'getValidated3pid', self.sydent.servlets.getValidated3pid)
threepid_v1.putChild(b'unbind', unbind)
if self.sydent.enable_v1_associations:
threepid_v1.putChild(b'bind', bind)
threepid_v1.putChild(b'bind', self.sydent.servlets.threepidBind)

v1.putChild(b'3pid', threepid_v1)

email.putChild(b'requestToken', emailReqCode)
email.putChild(b'submitToken', emailValCode)
email.putChild(b'requestToken', self.sydent.servlets.emailRequestCode)
email.putChild(b'submitToken', self.sydent.servlets.emailValidate)

email_v2.putChild(b'requestToken', self.sydent.servlets.emailRequestCodeV2)
email_v2.putChild(b'submitToken', self.sydent.servlets.emailValidateV2)

msisdn.putChild(b'requestToken', self.sydent.servlets.msisdnRequestCode)
msisdn.putChild(b'submitToken', self.sydent.servlets.msisdnValidate)

msisdn.putChild(b'requestToken', msisdnReqCode)
msisdn.putChild(b'submitToken', msisdnValCode)
msisdn_v2.putChild(b'requestToken', self.sydent.servlets.msisdnRequestCodeV2)
msisdn_v2.putChild(b'submitToken', self.sydent.servlets.msisdnValidateV2)

v1.putChild(b'store-invite', self.sydent.servlets.storeInviteServlet)

Expand All @@ -122,13 +120,13 @@ def __init__(self, sydent):
account.putChild(b'logout', self.sydent.servlets.logoutServlet)

# v2 versions of existing APIs
v2.putChild(b'validate', validate)
v2.putChild(b'validate', validate_v2)
v2.putChild(b'pubkey', pubkey)
v2.putChild(b'3pid', threepid_v2)
v2.putChild(b'store-invite', self.sydent.servlets.storeInviteServlet)
v2.putChild(b'sign-ed25519', self.sydent.servlets.blindlySignStuffServlet)
v2.putChild(b'lookup', lookup_v2)
v2.putChild(b'hash_details', hash_details)
v2.putChild(b'store-invite', self.sydent.servlets.storeInviteServletV2)
v2.putChild(b'sign-ed25519', self.sydent.servlets.blindlySignStuffServletV2)
v2.putChild(b'lookup', self.sydent.servlets.lookup_v2)
v2.putChild(b'hash_details', self.sydent.servlets.hash_details)

self.factory = Site(root)
self.factory.displayTracebacks = False
Expand Down
5 changes: 2 additions & 3 deletions sydent/http/servlets/accountservlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from twisted.web.resource import Resource

from sydent.http.servlets import jsonwrap, send_cors
from sydent.http.auth import authIfV2
from sydent.http.auth import authV2


class AccountServlet(Resource):
Expand All @@ -36,7 +36,7 @@ def render_GET(self, request):
"""
send_cors(request)

account = authIfV2(self.sydent, request)
account = authV2(self.sydent, request)

return {
"user_id": account.userId,
Expand All @@ -45,4 +45,3 @@ def render_GET(self, request):
def render_OPTIONS(self, request):
send_cors(request)
return b''

8 changes: 5 additions & 3 deletions sydent/http/servlets/blindlysignstuffservlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,26 @@
import signedjson.sign
from sydent.db.invite_tokens import JoinTokenStore
from sydent.http.servlets import get_args, jsonwrap, send_cors, MatrixRestError
from sydent.http.auth import authIfV2
from sydent.http.auth import authV2

logger = logging.getLogger(__name__)


class BlindlySignStuffServlet(Resource):
isLeaf = True

def __init__(self, syd):
def __init__(self, syd, require_auth=False):
self.sydent = syd
self.server_name = syd.server_name
self.tokenStore = JoinTokenStore(syd)
self.require_auth = require_auth

@jsonwrap
def render_POST(self, request):
send_cors(request)

authIfV2(self.sydent, request)
if self.require_auth:
authV2(self.sydent, request)

args = get_args(request, ("private_key", "token", "mxid"))

Expand Down
3 changes: 0 additions & 3 deletions sydent/http/servlets/bulklookupservlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import logging

from sydent.http.servlets import get_args, jsonwrap, send_cors, MatrixRestError
from sydent.http.auth import authIfV2


logger = logging.getLogger(__name__)
Expand All @@ -45,8 +44,6 @@ def render_POST(self, request):
"""
send_cors(request)

authIfV2(self.sydent, request)

args = get_args(request, ('threepids',))

threepids = args['threepids']
Expand Down
14 changes: 9 additions & 5 deletions sydent/http/servlets/emailservlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,22 @@


from sydent.http.servlets import get_args, jsonwrap, send_cors
from sydent.http.auth import authIfV2
from sydent.http.auth import authV2


class EmailRequestCodeServlet(Resource):
isLeaf = True

def __init__(self, syd):
def __init__(self, syd, require_auth=False):
self.sydent = syd
self.require_auth = require_auth

@jsonwrap
def render_POST(self, request):
send_cors(request)

authIfV2(self.sydent, request)
if self.require_auth:
authV2(self.sydent, request)

args = get_args(request, ('email', 'client_secret', 'send_attempt'))

Expand Down Expand Up @@ -85,8 +87,9 @@ def render_OPTIONS(self, request):
class EmailValidateCodeServlet(Resource):
isLeaf = True

def __init__(self, syd):
def __init__(self, syd, require_auth=False):
self.sydent = syd
self.require_auth = require_auth

def render_GET(self, request):
args = get_args(request, ('nextLink',), required=False)
Expand Down Expand Up @@ -121,7 +124,8 @@ def render_GET(self, request):
def render_POST(self, request):
send_cors(request)

authIfV2(self.sydent, request)
if self.require_auth:
authV2(self.sydent, request)

return self.do_validate_request(request)

Expand Down
8 changes: 5 additions & 3 deletions sydent/http/servlets/getvalidated3pidservlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from twisted.web.resource import Resource

from sydent.http.servlets import jsonwrap, get_args
from sydent.http.auth import authIfV2
from sydent.http.auth import authV2
from sydent.db.valsession import ThreePidValSessionStore
from sydent.util.stringutils import is_valid_client_secret
from sydent.validators import (
Expand All @@ -32,12 +32,14 @@
class GetValidated3pidServlet(Resource):
isLeaf = True

def __init__(self, syd):
def __init__(self, syd, require_auth=False):
self.sydent = syd
self.require_auth = require_auth

@jsonwrap
def render_GET(self, request):
authIfV2(self.sydent, request)
if self.require_auth:
authV2(self.sydent, request)

args = get_args(request, ('sid', 'client_secret'))

Expand Down
4 changes: 2 additions & 2 deletions sydent/http/servlets/hashdetailsservlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from __future__ import absolute_import

from twisted.web.resource import Resource
from sydent.http.auth import authIfV2
from sydent.http.auth import authV2

import logging

Expand Down Expand Up @@ -48,7 +48,7 @@ def render_GET(self, request):
"""
send_cors(request)

authIfV2(self.sydent, request)
authV2(self.sydent, request)

return {
"algorithms": self.known_algorithms,
Expand Down
4 changes: 2 additions & 2 deletions sydent/http/servlets/logoutservlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from sydent.http.servlets import jsonwrap, send_cors
from sydent.db.accounts import AccountStore
from sydent.http.auth import authIfV2, tokenFromRequest
from sydent.http.auth import authV2, tokenFromRequest


logger = logging.getLogger(__name__)
Expand All @@ -40,7 +40,7 @@ def render_POST(self, request):
"""
send_cors(request)

authIfV2(self.sydent, request, False)
authV2(self.sydent, request, False)

token = tokenFromRequest(request)

Expand Down
3 changes: 0 additions & 3 deletions sydent/http/servlets/lookupservlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import signedjson.sign

from sydent.http.servlets import get_args, jsonwrap, send_cors, MatrixRestError
from sydent.http.auth import authIfV2


logger = logging.getLogger(__name__)
Expand All @@ -49,8 +48,6 @@ def render_GET(self, request):
"""
send_cors(request)

authIfV2(self.sydent, request)

args = get_args(request, ('medium', 'address'))

medium = args['medium']
Expand Down
4 changes: 2 additions & 2 deletions sydent/http/servlets/lookupv2servlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from sydent.http.servlets import get_args, jsonwrap, send_cors
from sydent.db.threepid_associations import GlobalAssociationStore
from sydent.http.auth import authIfV2
from sydent.http.auth import authV2
from sydent.http.servlets.hashdetailsservlet import HashDetailsServlet

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -62,7 +62,7 @@ def render_POST(self, request):
"""
send_cors(request)

authIfV2(self.sydent, request)
authV2(self.sydent, request)

args = get_args(request, ('addresses', 'algorithm', 'pepper'))

Expand Down
Loading

0 comments on commit ce04a68

Please sign in to comment.