diff options
Diffstat (limited to 'streaming/index.js')
-rw-r--r-- | streaming/index.js | 50 |
1 files changed, 43 insertions, 7 deletions
diff --git a/streaming/index.js b/streaming/index.js index 55ecc3ba3..10df210a3 100644 --- a/streaming/index.js +++ b/streaming/index.js @@ -195,14 +195,14 @@ const startWorker = (workerId) => { next(); }; - const accountFromToken = (token, req, next) => { + const accountFromToken = (token, allowedScopes, req, next) => { pgPool.connect((err, client, done) => { if (err) { next(err); return; } - client.query('SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages FROM oauth_access_tokens INNER JOIN users ON oauth_access_tokens.resource_owner_id = users.id WHERE oauth_access_tokens.token = $1 AND oauth_access_tokens.revoked_at IS NULL LIMIT 1', [token], (err, result) => { + client.query('SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes FROM oauth_access_tokens INNER JOIN users ON oauth_access_tokens.resource_owner_id = users.id WHERE oauth_access_tokens.token = $1 AND oauth_access_tokens.revoked_at IS NULL LIMIT 1', [token], (err, result) => { done(); if (err) { @@ -218,18 +218,29 @@ const startWorker = (workerId) => { return; } + const scopes = result.rows[0].scopes.split(' '); + + if (allowedScopes.size > 0 && !scopes.some(scope => allowedScopes.includes(scope))) { + err = new Error('Access token does not cover required scopes'); + err.statusCode = 401; + + next(err); + return; + } + req.accountId = result.rows[0].account_id; req.chosenLanguages = result.rows[0].chosen_languages; + req.allowNotifications = scopes.some(scope => ['read', 'read:notifications'].includes(scope)); next(); }); }); }; - const accountFromRequest = (req, next, required = true) => { + const accountFromRequest = (req, next, required = true, allowedScopes = ['read']) => { const authorization = req.headers.authorization; const location = url.parse(req.url, true); - const accessToken = location.query.access_token; + const accessToken = location.query.access_token || req.headers['sec-websocket-protocol']; if (!authorization && !accessToken) { if (required) { @@ -246,7 +257,7 @@ const startWorker = (workerId) => { const token = authorization ? authorization.replace(/^Bearer /, '') : accessToken; - accountFromToken(token, req, next); + accountFromToken(token, allowedScopes, req, next); }; const PUBLIC_STREAMS = [ @@ -261,6 +272,16 @@ const startWorker = (workerId) => { const wsVerifyClient = (info, cb) => { const location = url.parse(info.req.url, true); const authRequired = !PUBLIC_STREAMS.some(stream => stream === location.query.stream); + const allowedScopes = []; + + if (authRequired) { + allowedScopes.push('read'); + if (location.query.stream === 'user:notification') { + allowedScopes.push('read:notifications'); + } else { + allowedScopes.push('read:statuses'); + } + } accountFromRequest(info.req, err => { if (!err) { @@ -269,7 +290,7 @@ const startWorker = (workerId) => { log.error(info.req.requestId, err.toString()); cb(false, 401, 'Unauthorized'); } - }, authRequired); + }, authRequired, allowedScopes); }; const PUBLIC_ENDPOINTS = [ @@ -286,7 +307,18 @@ const startWorker = (workerId) => { } const authRequired = !PUBLIC_ENDPOINTS.some(endpoint => endpoint === req.path); - accountFromRequest(req, next, authRequired); + const allowedScopes = []; + + if (authRequired) { + allowedScopes.push('read'); + if (req.path === '/api/v1/streaming/user/notification') { + allowedScopes.push('read:notifications'); + } else { + allowedScopes.push('read:statuses'); + } + } + + accountFromRequest(req, next, authRequired, allowedScopes); }; const errorMiddleware = (err, req, res, {}) => { @@ -339,6 +371,10 @@ const startWorker = (workerId) => { return; } + if (event === 'notification' && !req.allowNotifications) { + return; + } + // Only send local-only statuses to logged-in users if (payload.local_only && !req.accountId) { log.silly(req.requestId, `Message ${payload.id} filtered because it was local-only`); |