From aa10200e58ccb340b6384532ccdba6b7fbac037a Mon Sep 17 00:00:00 2001 From: Eugen Rochko Date: Thu, 12 Nov 2020 23:05:24 +0100 Subject: Fix streaming API allowing connections to persist after access token invalidation (#15111) Fix #14816 --- streaming/index.js | 82 +++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 81 insertions(+), 1 deletion(-) (limited to 'streaming') diff --git a/streaming/index.js b/streaming/index.js index 791a26941..877f45d19 100644 --- a/streaming/index.js +++ b/streaming/index.js @@ -294,7 +294,7 @@ const startWorker = (workerId) => { return; } - client.query('SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes, devices.device_id FROM oauth_access_tokens INNER JOIN users ON oauth_access_tokens.resource_owner_id = users.id LEFT OUTER JOIN devices ON oauth_access_tokens.id = devices.access_token_id WHERE oauth_access_tokens.token = $1 AND oauth_access_tokens.revoked_at IS NULL LIMIT 1', [token], (err, result) => { + client.query('SELECT oauth_access_tokens.id, oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes, devices.device_id FROM oauth_access_tokens INNER JOIN users ON oauth_access_tokens.resource_owner_id = users.id LEFT OUTER JOIN devices ON oauth_access_tokens.id = devices.access_token_id WHERE oauth_access_tokens.token = $1 AND oauth_access_tokens.revoked_at IS NULL LIMIT 1', [token], (err, result) => { done(); if (err) { @@ -310,6 +310,7 @@ const startWorker = (workerId) => { return; } + req.accessTokenId = result.rows[0].id; req.scopes = result.rows[0].scopes.split(' '); req.accountId = result.rows[0].account_id; req.chosenLanguages = result.rows[0].chosen_languages; @@ -450,6 +451,55 @@ const startWorker = (workerId) => { }); }; + /** + * @typedef SystemMessageHandlers + * @property {function(): void} onKill + */ + + /** + * @param {any} req + * @param {SystemMessageHandlers} eventHandlers + * @return {function(string): void} + */ + const createSystemMessageListener = (req, eventHandlers) => { + return message => { + const json = parseJSON(message); + + if (!json) return; + + const { event } = json; + + log.silly(req.requestId, `System message for ${req.accountId}: ${event}`); + + if (event === 'kill') { + log.verbose(req.requestId, `Closing connection for ${req.accountId} due to expired access token`); + eventHandlers.onKill(); + } + } + }; + + /** + * @param {any} req + * @param {any} res + */ + const subscribeHttpToSystemChannel = (req, res) => { + const systemChannelId = `timeline:access_token:${req.accessTokenId}`; + + const listener = createSystemMessageListener(req, { + + onKill () { + res.end(); + }, + + }); + + res.on('close', () => { + unsubscribe(`${redisPrefix}${systemChannelId}`, listener); + }); + + subscribe(`${redisPrefix}${systemChannelId}`, listener); + }; + /** * @param {any} req * @param {any} res @@ -462,6 +512,8 @@ const startWorker = (workerId) => { } accountFromRequest(req, alwaysRequireAuth).then(() => checkScopes(req, channelNameFromPath(req))).then(() => { + subscribeHttpToSystemChannel(req, res); + }).then(() => { next(); }).catch(err => { next(err); @@ -536,7 +588,9 @@ const startWorker = (workerId) => { const listener = message => { const json = parseJSON(message); + if (!json) return; + const { event, payload, queued_at } = json; const transmit = () => { @@ -902,6 +956,28 @@ const startWorker = (workerId) => { socket.send(JSON.stringify({ error: err.toString() })); }); + /** + * @param {WebSocketSession} session + */ + const subscribeWebsocketToSystemChannel = ({ socket, request, subscriptions }) => { + const systemChannelId = `timeline:access_token:${request.accessTokenId}`; + + const listener = createSystemMessageListener(request, { + + onKill () { + socket.close(); + }, + + }); + + subscribe(`${redisPrefix}${systemChannelId}`, listener); + + subscriptions[systemChannelId] = { + listener, + stopHeartbeat: () => {}, + }; + }; + /** * @param {string|string[]} arrayOrString * @return {string} @@ -948,7 +1024,9 @@ const startWorker = (workerId) => { ws.on('message', data => { const json = parseJSON(data); + if (!json) return; + const { type, stream, ...params } = json; if (type === 'subscribe') { @@ -960,6 +1038,8 @@ const startWorker = (workerId) => { } }); + subscribeWebsocketToSystemChannel(session); + if (location.query.stream) { subscribeWebsocketToChannel(session, firstParam(location.query.stream), location.query); } -- cgit