diff options
Diffstat (limited to 'streaming/index.js')
-rw-r--r-- | streaming/index.js | 737 |
1 files changed, 502 insertions, 235 deletions
diff --git a/streaming/index.js b/streaming/index.js index f69064e33..f5c9b4224 100644 --- a/streaming/index.js +++ b/streaming/index.js @@ -1,3 +1,5 @@ +// @ts-check + const os = require('os'); const throng = require('throng'); const dotenv = require('dotenv'); @@ -12,7 +14,7 @@ const uuid = require('uuid'); const fs = require('fs'); const env = process.env.NODE_ENV || 'development'; -const alwaysRequireAuth = process.env.WHITELIST_MODE === 'true' || process.env.AUTHORIZED_FETCH === 'true'; +const alwaysRequireAuth = process.env.LIMITED_FEDERATION_MODE === 'true' || process.env.WHITELIST_MODE === 'true' || process.env.AUTHORIZED_FETCH === 'true'; dotenv.config({ path: env === 'production' ? '.env.production' : '.env', @@ -20,6 +22,10 @@ dotenv.config({ log.level = process.env.LOG_LEVEL || 'verbose'; +/** + * @param {string} dbUrl + * @return {Object.<string, any>} + */ const dbUrlToConfig = (dbUrl) => { if (!dbUrl) { return {}; @@ -53,6 +59,10 @@ const dbUrlToConfig = (dbUrl) => { return config; }; +/** + * @param {Object.<string, any>} defaultConfig + * @param {string} redisUrl + */ const redisUrlToClient = (defaultConfig, redisUrl) => { const config = defaultConfig; @@ -108,6 +118,7 @@ const startWorker = (workerId) => { } const app = express(); + app.set('trusted proxy', process.env.TRUSTED_PROXY_IP || 'loopback,uniquelocal'); const pgPool = new pg.Pool(Object.assign(pgConfigs[env], dbUrlToConfig(process.env.DATABASE_URL))); @@ -130,6 +141,9 @@ const startWorker = (workerId) => { const redisSubscribeClient = redisUrlToClient(redisParams, process.env.REDIS_URL); const redisClient = redisUrlToClient(redisParams, process.env.REDIS_URL); + /** + * @type {Object.<string, Array.<function(string): void>>} + */ const subs = {}; redisSubscribeClient.on('message', (channel, message) => { @@ -144,11 +158,11 @@ const startWorker = (workerId) => { callbacks.forEach(callback => callback(message)); }); + /** + * @param {string[]} channels + * @return {function(): void} + */ const subscriptionHeartbeat = channels => { - if (!Array.isArray(channels)) { - channels = [channels]; - } - const interval = 6 * 60; const tellSubscribed = () => { @@ -164,25 +178,65 @@ const startWorker = (workerId) => { }; }; + /** + * @param {string} channel + * @param {function(string): void} callback + */ const subscribe = (channel, callback) => { log.silly(`Adding listener for ${channel}`); subs[channel] = subs[channel] || []; + if (subs[channel].length === 0) { log.verbose(`Subscribe ${channel}`); redisSubscribeClient.subscribe(channel); } + subs[channel].push(callback); }; + /** + * @param {string} channel + * @param {function(string): void} callback + */ const unsubscribe = (channel, callback) => { log.silly(`Removing listener for ${channel}`); + + if (!subs[channel]) { + return; + } + subs[channel] = subs[channel].filter(item => item !== callback); + if (subs[channel].length === 0) { log.verbose(`Unsubscribe ${channel}`); redisSubscribeClient.unsubscribe(channel); } }; + const FALSE_VALUES = [ + false, + 0, + "0", + "f", + "F", + "false", + "FALSE", + "off", + "OFF" + ]; + + /** + * @param {any} value + * @return {boolean} + */ + const isTruthy = value => + value && !FALSE_VALUES.includes(value); + + /** + * @param {any} req + * @param {any} res + * @param {function(Error=): void} + */ const allowCrossDomain = (req, res, next) => { res.header('Access-Control-Allow-Origin', '*'); res.header('Access-Control-Allow-Headers', 'Authorization, Accept, Cache-Control'); @@ -191,6 +245,11 @@ const startWorker = (workerId) => { next(); }; + /** + * @param {any} req + * @param {any} res + * @param {function(Error=): void} + */ const setRequestId = (req, res, next) => { req.requestId = uuid.v4(); res.header('X-Request-Id', req.requestId); @@ -198,16 +257,26 @@ const startWorker = (workerId) => { next(); }; + /** + * @param {any} req + * @param {any} res + * @param {function(Error=): void} + */ const setRemoteAddress = (req, res, next) => { req.remoteAddress = req.connection.remoteAddress; next(); }; - const accountFromToken = (token, allowedScopes, req, next) => { + /** + * @param {string} token + * @param {any} req + * @return {Promise.<void>} + */ + const accountFromToken = (token, req) => new Promise((resolve, reject) => { pgPool.connect((err, client, done) => { if (err) { - next(err); + reject(err); return; } @@ -215,62 +284,89 @@ const startWorker = (workerId) => { done(); if (err) { - next(err); + reject(err); return; } if (result.rows.length === 0) { err = new Error('Invalid access token'); - err.statusCode = 401; + err.status = 401; - next(err); - return; - } - - const scopes = result.rows[0].scopes.split(' '); - - if (allowedScopes.size > 0 && !scopes.some(scope => allowedScopes.includes(scope))) { - err = new Error('Access token does not cover required scopes'); - err.statusCode = 401; - - next(err); + reject(err); return; } + req.scopes = result.rows[0].scopes.split(' '); req.accountId = result.rows[0].account_id; req.chosenLanguages = result.rows[0].chosen_languages; - req.allowNotifications = scopes.some(scope => ['read', 'read:notifications'].includes(scope)); + req.allowNotifications = req.scopes.some(scope => ['read', 'read:notifications'].includes(scope)); req.deviceId = result.rows[0].device_id; - next(); + resolve(); }); }); - }; + }); - const accountFromRequest = (req, next, required = true, allowedScopes = ['read']) => { + /** + * @param {any} req + * @param {boolean=} required + * @return {Promise.<void>} + */ + const accountFromRequest = (req, required = true) => new Promise((resolve, reject) => { const authorization = req.headers.authorization; - const location = url.parse(req.url, true); - const accessToken = location.query.access_token || req.headers['sec-websocket-protocol']; + const location = url.parse(req.url, true); + const accessToken = location.query.access_token || req.headers['sec-websocket-protocol']; if (!authorization && !accessToken) { if (required) { const err = new Error('Missing access token'); - err.statusCode = 401; + err.status = 401; - next(err); + reject(err); return; } else { - next(); + resolve(); return; } } const token = authorization ? authorization.replace(/^Bearer /, '') : accessToken; - accountFromToken(token, allowedScopes, req, next); + resolve(accountFromToken(token, req)); + }); + + /** + * @param {any} req + * @return {string} + */ + const channelNameFromPath = req => { + const { path, query } = req; + const onlyMedia = isTruthy(query.only_media); + const allowLocalOnly = isTruthy(query.allow_local_only); + + switch(path) { + case '/api/v1/streaming/user': + return 'user'; + case '/api/v1/streaming/user/notification': + return 'user:notification'; + case '/api/v1/streaming/public': + return onlyMedia ? 'public:media' : 'public'; + case '/api/v1/streaming/public/local': + return onlyMedia ? 'public:local:media' : 'public:local'; + case '/api/v1/streaming/public/remote': + return onlyMedia ? 'public:remote:media' : 'public:remote'; + case '/api/v1/streaming/hashtag': + return 'hashtag'; + case '/api/v1/streaming/hashtag/local': + return 'hashtag:local'; + case '/api/v1/streaming/direct': + return 'direct'; + case '/api/v1/streaming/list': + return 'list'; + } }; - const PUBLIC_STREAMS = [ + const PUBLIC_CHANNELS = [ 'public', 'public:media', 'public:local', @@ -281,96 +377,149 @@ const startWorker = (workerId) => { 'hashtag:local', ]; - const wsVerifyClient = (info, cb) => { - const location = url.parse(info.req.url, true); - const authRequired = alwaysRequireAuth || !PUBLIC_STREAMS.some(stream => stream === location.query.stream); - const allowedScopes = []; + /** + * @param {any} req + * @param {string} channelName + * @return {Promise.<void>} + */ + const checkScopes = (req, channelName) => new Promise((resolve, reject) => { + log.silly(req.requestId, `Checking OAuth scopes for ${channelName}`); + + // When accessing public channels, no scopes are needed + if (PUBLIC_CHANNELS.includes(channelName)) { + resolve(); + return; + } - if (authRequired) { - allowedScopes.push('read'); - if (location.query.stream === 'user:notification') { - allowedScopes.push('read:notifications'); - } else { - allowedScopes.push('read:statuses'); - } + // The `read` scope has the highest priority, if the token has it + // then it can access all streams + const requiredScopes = ['read']; + + // When accessing specifically the notifications stream, + // we need a read:notifications, while in all other cases, + // we can allow access with read:statuses. Mind that the + // user stream will not contain notifications unless + // the token has either read or read:notifications scope + // as well, this is handled separately. + if (channelName === 'user:notification') { + requiredScopes.push('read:notifications'); + } else { + requiredScopes.push('read:statuses'); } - accountFromRequest(info.req, err => { - if (!err) { - cb(true, undefined, undefined); - } else { - log.error(info.req.requestId, err.toString()); - cb(false, 401, 'Unauthorized'); - } - }, authRequired, allowedScopes); - }; + if (requiredScopes.some(requiredScope => req.scopes.includes(requiredScope))) { + resolve(); + return; + } - const PUBLIC_ENDPOINTS = [ - '/api/v1/streaming/public', - '/api/v1/streaming/public/allow_local_only', - '/api/v1/streaming/public/local', - '/api/v1/streaming/public/remote', - '/api/v1/streaming/hashtag', - '/api/v1/streaming/hashtag/local', - ]; + const err = new Error('Access token does not cover required scopes'); + err.status = 401; + + reject(err); + }); + + /** + * @param {any} info + * @param {function(boolean, number, string): void} callback + */ + const wsVerifyClient = (info, callback) => { + // When verifying the websockets connection, we no longer pre-emptively + // check OAuth scopes and drop the connection if they're missing. We only + // drop the connection if access without token is not allowed by environment + // variables. OAuth scope checks are moved to the point of subscription + // to a specific stream. + + accountFromRequest(info.req, alwaysRequireAuth).then(() => { + callback(true, undefined, undefined); + }).catch(err => { + log.error(info.req.requestId, err.toString()); + callback(false, 401, 'Unauthorized'); + }); + }; + /** + * @param {any} req + * @param {any} res + * @param {function(Error=): void} next + */ const authenticationMiddleware = (req, res, next) => { if (req.method === 'OPTIONS') { next(); return; } - const authRequired = alwaysRequireAuth || !PUBLIC_ENDPOINTS.some(endpoint => endpoint === req.path); - const allowedScopes = []; - - if (authRequired) { - allowedScopes.push('read'); - if (req.path === '/api/v1/streaming/user/notification') { - allowedScopes.push('read:notifications'); - } else { - allowedScopes.push('read:statuses'); - } - } - - accountFromRequest(req, next, authRequired, allowedScopes); + accountFromRequest(req, alwaysRequireAuth).then(() => checkScopes(req, channelNameFromPath(req))).then(() => { + next(); + }).catch(err => { + next(err); + }); }; - const errorMiddleware = (err, req, res, {}) => { + /** + * @param {Error} err + * @param {any} req + * @param {any} res + * @param {function(Error=): void} next + */ + const errorMiddleware = (err, req, res, next) => { log.error(req.requestId, err.toString()); - res.writeHead(err.statusCode || 500, { 'Content-Type': 'application/json' }); - res.end(JSON.stringify({ error: err.statusCode ? err.toString() : 'An unexpected error occurred' })); + + if (res.headersSent) { + return next(err); + } + + res.writeHead(err.status || 500, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ error: err.status ? err.toString() : 'An unexpected error occurred' })); }; + /** + * @param {array} + * @param {number=} shift + * @return {string} + */ const placeholders = (arr, shift = 0) => arr.map((_, i) => `$${i + 1 + shift}`).join(', '); - const authorizeListAccess = (id, req, next) => { + /** + * @param {string} listId + * @param {any} req + * @return {Promise.<void>} + */ + const authorizeListAccess = (listId, req) => new Promise((resolve, reject) => { + const { accountId } = req; + pgPool.connect((err, client, done) => { if (err) { - next(false); + reject(); return; } - client.query('SELECT id, account_id FROM lists WHERE id = $1 LIMIT 1', [id], (err, result) => { + client.query('SELECT id, account_id FROM lists WHERE id = $1 LIMIT 1', [listId], (err, result) => { done(); - if (err || result.rows.length === 0 || result.rows[0].account_id !== req.accountId) { - next(false); + if (err || result.rows.length === 0 || result.rows[0].account_id !== accountId) { + reject(); return; } - next(true); + resolve(); }); }); - }; + }); + /** + * @param {string[]} ids + * @param {any} req + * @param {function(string, string): void} output + * @param {function(string[], function(string): void): void} attachCloseHandler + * @param {boolean=} needsFiltering + * @param {boolean=} notificationOnly + * @param {boolean=} allowLocalOnly + * @return {function(string): void} + */ const streamFrom = (ids, req, output, attachCloseHandler, needsFiltering = false, notificationOnly = false, allowLocalOnly = false) => { const accountId = req.accountId || req.remoteAddress; const streamType = notificationOnly ? ' (notification)' : ''; - if (!Array.isArray(ids)) { - ids = [ids]; - } - log.verbose(req.requestId, `Starting stream from ${ids.join(', ')} for ${accountId}${streamType}`); const listener = message => { @@ -454,10 +603,18 @@ const startWorker = (workerId) => { subscribe(`${redisPrefix}${id}`, listener); }); - attachCloseHandler(ids.map(id => `${redisPrefix}${id}`), listener); + if (attachCloseHandler) { + attachCloseHandler(ids.map(id => `${redisPrefix}${id}`), listener); + } + + return listener; }; - // Setup stream output to HTTP + /** + * @param {any} req + * @param {any} res + * @return {function(string, string): void} + */ const streamToHttp = (req, res) => { const accountId = req.accountId || req.remoteAddress; @@ -480,12 +637,12 @@ const startWorker = (workerId) => { }; }; - // Setup stream end for HTTP - const streamHttpEnd = (req, closeHandler = false) => (ids, listener) => { - if (!Array.isArray(ids)) { - ids = [ids]; - } - + /** + * @param {any} req + * @param {function(): void} [closeHandler] + * @return {function(string[], function(string): void)} + */ + const streamHttpEnd = (req, closeHandler = undefined) => (ids, listener) => { req.on('close', () => { ids.forEach(id => { unsubscribe(id, listener); @@ -497,37 +654,24 @@ const startWorker = (workerId) => { }); }; - // Setup stream output to WebSockets - const streamToWs = (req, ws) => (event, payload) => { + /** + * @param {any} req + * @param {any} ws + * @param {string[]} streamName + * @return {function(string, string): void} + */ + const streamToWs = (req, ws, streamName) => (event, payload) => { if (ws.readyState !== ws.OPEN) { log.error(req.requestId, 'Tried writing to closed socket'); return; } - ws.send(JSON.stringify({ event, payload })); - }; - - // Setup stream end for WebSockets - const streamWsEnd = (req, ws, closeHandler = false) => (id, listener) => { - const accountId = req.accountId || req.remoteAddress; - - ws.on('close', () => { - log.verbose(req.requestId, `Ending stream for ${accountId}`); - unsubscribe(id, listener); - if (closeHandler) { - closeHandler(); - } - }); - - ws.on('error', () => { - log.verbose(req.requestId, `Ending stream for ${accountId}`); - unsubscribe(id, listener); - if (closeHandler) { - closeHandler(); - } - }); + ws.send(JSON.stringify({ stream: streamName, event, payload })); }; + /** + * @param {any} res + */ const httpNotFound = res => { res.writeHead(404, { 'Content-Type': 'application/json' }); res.end(JSON.stringify({ error: 'Not found' })); @@ -545,165 +689,281 @@ const startWorker = (workerId) => { app.use(authenticationMiddleware); app.use(errorMiddleware); - app.get('/api/v1/streaming/user', (req, res) => { - const channels = [`timeline:${req.accountId}`]; - - if (req.deviceId) { - channels.push(`timeline:${req.accountId}:${req.deviceId}`); - } - - streamFrom(channels, req, streamToHttp(req, res), streamHttpEnd(req, subscriptionHeartbeat(channels))); - }); - - app.get('/api/v1/streaming/user/notification', (req, res) => { - streamFrom(`timeline:${req.accountId}`, req, streamToHttp(req, res), streamHttpEnd(req), false, true); - }); - - app.get('/api/v1/streaming/public', (req, res) => { - const onlyMedia = req.query.only_media === '1' || req.query.only_media === 'true'; - const channel = onlyMedia ? 'timeline:public:media' : 'timeline:public'; - - const allowLocalOnly = req.query.allow_local_only === '1' || req.query.allow_local_only === 'true'; - - streamFrom(channel, req, streamToHttp(req, res), streamHttpEnd(req), true, false, allowLocalOnly); - }); - - app.get('/api/v1/streaming/public/local', (req, res) => { - const onlyMedia = req.query.only_media === '1' || req.query.only_media === 'true'; - const channel = onlyMedia ? 'timeline:public:local:media' : 'timeline:public:local'; - - streamFrom(channel, req, streamToHttp(req, res), streamHttpEnd(req), true, false, true); - }); - - app.get('/api/v1/streaming/public/remote', (req, res) => { - const onlyMedia = req.query.only_media === '1' || req.query.only_media === 'true'; - const channel = onlyMedia ? 'timeline:public:remote:media' : 'timeline:public:remote'; - - streamFrom(channel, req, streamToHttp(req, res), streamHttpEnd(req), true); - }); - - app.get('/api/v1/streaming/direct', (req, res) => { - const channel = `timeline:direct:${req.accountId}`; - streamFrom(channel, req, streamToHttp(req, res), streamHttpEnd(req, subscriptionHeartbeat(channel)), true); - }); - - app.get('/api/v1/streaming/hashtag', (req, res) => { - const { tag } = req.query; - - if (!tag || tag.length === 0) { - httpNotFound(res); - return; - } - - streamFrom(`timeline:hashtag:${tag.toLowerCase()}`, req, streamToHttp(req, res), streamHttpEnd(req), true); - }); - - app.get('/api/v1/streaming/hashtag/local', (req, res) => { - const { tag } = req.query; + app.get('/api/v1/streaming/*', (req, res) => { + channelNameToIds(req, channelNameFromPath(req), req.query).then(({ channelIds, options }) => { + const onSend = streamToHttp(req, res); + const onEnd = streamHttpEnd(req, subscriptionHeartbeat(channelIds)); - if (!tag || tag.length === 0) { + streamFrom(channelIds, req, onSend, onEnd, options.needsFiltering, options.notificationOnly, options.allowLocalOnly); + }).catch(err => { + log.verbose(req.requestId, 'Subscription error:', err.toString()); httpNotFound(res); - return; - } - - streamFrom(`timeline:hashtag:${tag.toLowerCase()}:local`, req, streamToHttp(req, res), streamHttpEnd(req), true); - }); - - app.get('/api/v1/streaming/list', (req, res) => { - const listId = req.query.list; - - authorizeListAccess(listId, req, authorized => { - if (!authorized) { - httpNotFound(res); - return; - } - - const channel = `timeline:list:${listId}`; - streamFrom(channel, req, streamToHttp(req, res), streamHttpEnd(req, subscriptionHeartbeat(channel))); }); }); const wss = new WebSocketServer({ server, verifyClient: wsVerifyClient }); - wss.on('connection', (ws, req) => { - const location = url.parse(req.url, true); - req.requestId = uuid.v4(); - req.remoteAddress = ws._socket.remoteAddress; - - let channel; - - switch(location.query.stream) { + /** + * @typedef StreamParams + * @property {string} [tag] + * @property {string} [list] + * @property {string} [only_media] + */ + + /** + * @param {any} req + * @param {string} name + * @param {StreamParams} params + * @return {Promise.<{ channelIds: string[], options: { needsFiltering: boolean, notificationOnly: boolean } }>} + */ + const channelNameToIds = (req, name, params) => new Promise((resolve, reject) => { + switch(name) { case 'user': - channel = [`timeline:${req.accountId}`]; - - if (req.deviceId) { - channel.push(`timeline:${req.accountId}:${req.deviceId}`); - } + resolve({ + channelIds: req.deviceId ? [`timeline:${req.accountId}`, `timeline:${req.accountId}:${req.deviceId}`] : [`timeline:${req.accountId}`], + options: { needsFiltering: false, notificationOnly: false, allowLocalOnly: true }, + }); - streamFrom(channel, req, streamToWs(req, ws), streamWsEnd(req, ws, subscriptionHeartbeat(channel))); break; case 'user:notification': - streamFrom(`timeline:${req.accountId}`, req, streamToWs(req, ws), streamWsEnd(req, ws), false, true); + resolve({ + channelIds: [`timeline:${req.accountId}`], + options: { needsFiltering: false, notificationOnly: true, allowLocalOnly: true }, + }); + break; case 'public': - streamFrom('timeline:public', req, streamToWs(req, ws), streamWsEnd(req, ws), true); + resolve({ + channelIds: ['timeline:public'], + options: { needsFiltering: true, notificationOnly: false, allowLocalOnly: isTruthy(params.allow_local_only) }, + }); + break; case 'public:allow_local_only': - streamFrom('timeline:public', req, streamToWs(req, ws), streamWsEnd(req, ws), true, false, true); + resolve({ + channelIds: ['timeline:public'], + options: { needsFiltering: true, notificationOnly: false, allowLocalOnly: true }, + }); + break; case 'public:local': - streamFrom('timeline:public:local', req, streamToWs(req, ws), streamWsEnd(req, ws), true, false, true); + resolve({ + channelIds: ['timeline:public:local'], + options: { needsFiltering: true, notificationOnly: false, allowLocalOnly: true }, + }); + break; case 'public:remote': - streamFrom('timeline:public:remote', req, streamToWs(req, ws), streamWsEnd(req, ws), true); + resolve({ + channelIds: ['timeline:public:remote'], + options: { needsFiltering: true, notificationOnly: false, allowLocalOnly: false }, + }); + break; case 'public:media': - streamFrom('timeline:public:media', req, streamToWs(req, ws), streamWsEnd(req, ws), true); + resolve({ + channelIds: ['timeline:public:media'], + options: { needsFiltering: true, notificationOnly: false, allowLocalOnly: isTruthy(query.allow_local_only) }, + }); + break; case 'public:allow_local_only:media': - streamFrom('timeline:public:media', req, streamToWs(req, ws), streamWsEnd(req, ws), true, false, true); + resolve({ + channelIds: ['timeline:public:media'], + options: { needsFiltering: true, notificationsOnly: false, allowLocalOnly: true }, + }); + break; case 'public:local:media': - streamFrom('timeline:public:local:media', req, streamToWs(req, ws), streamWsEnd(req, ws), true, false, true); + resolve({ + channelIds: ['timeline:public:local:media'], + options: { needsFiltering: true, notificationOnly: false, allowLocalOnly: true }, + }); + break; case 'public:remote:media': - streamFrom('timeline:public:remote:media', req, streamToWs(req, ws), streamWsEnd(req, ws), true); + resolve({ + channelIds: ['timeline:public:remote:media'], + options: { needsFiltering: true, notificationOnly: false, allowLocalOnly: false }, + }); + break; case 'direct': - channel = `timeline:direct:${req.accountId}`; - streamFrom(channel, req, streamToWs(req, ws), streamWsEnd(req, ws, subscriptionHeartbeat(channel)), true); + resolve({ + channelIds: [`timeline:direct:${req.accountId}`], + options: { needsFiltering: false, notificationOnly: false, allowLocalOnly: true }, + }); + break; case 'hashtag': - if (!location.query.tag || location.query.tag.length === 0) { - ws.close(); - return; + if (!params.tag || params.tag.length === 0) { + reject('No tag for stream provided'); + } else { + resolve({ + channelIds: [`timeline:hashtag:${params.tag.toLowerCase()}`], + options: { needsFiltering: true, notificationOnly: false, allowLocalOnly: true }, + }); } - streamFrom(`timeline:hashtag:${location.query.tag.toLowerCase()}`, req, streamToWs(req, ws), streamWsEnd(req, ws), true); break; case 'hashtag:local': - if (!location.query.tag || location.query.tag.length === 0) { - ws.close(); - return; + if (!params.tag || params.tag.length === 0) { + reject('No tag for stream provided'); + } else { + resolve({ + channelIds: [`timeline:hashtag:${params.tag.toLowerCase()}:local`], + options: { needsFiltering: true, notificationOnly: false, allowLocalOnly: true }, + }); } - streamFrom(`timeline:hashtag:${location.query.tag.toLowerCase()}:local`, req, streamToWs(req, ws), streamWsEnd(req, ws), true); break; case 'list': - const listId = location.query.list; - - authorizeListAccess(listId, req, authorized => { - if (!authorized) { - ws.close(); - return; - } - - channel = `timeline:list:${listId}`; - streamFrom(channel, req, streamToWs(req, ws), streamWsEnd(req, ws, subscriptionHeartbeat(channel))); + authorizeListAccess(params.list, req).then(() => { + resolve({ + channelIds: [`timeline:list:${params.list}`], + options: { needsFiltering: false, notificationOnly: false, allowLocalOnly: true }, + }); + }).catch(() => { + reject('Not authorized to stream this list'); }); + break; default: - ws.close(); + reject('Unknown stream type'); + } + }); + + /** + * @param {string} channelName + * @param {StreamParams} params + * @return {string[]} + */ + const streamNameFromChannelName = (channelName, params) => { + if (channelName === 'list') { + return [channelName, params.list]; + } else if (['hashtag', 'hashtag:local'].includes(channelName)) { + return [channelName, params.tag]; + } else { + return [channelName]; + } + }; + + /** + * @typedef WebSocketSession + * @property {any} socket + * @property {any} request + * @property {Object.<string, { listener: function(string): void, stopHeartbeat: function(): void }>} subscriptions + */ + + /** + * @param {WebSocketSession} session + * @param {string} channelName + * @param {StreamParams} params + */ + const subscribeWebsocketToChannel = ({ socket, request, subscriptions }, channelName, params) => + checkScopes(request, channelName).then(() => channelNameToIds(request, channelName, params)).then(({ channelIds, options }) => { + if (subscriptions[channelIds.join(';')]) { + return; + } + + const onSend = streamToWs(request, socket, streamNameFromChannelName(channelName, params)); + const stopHeartbeat = subscriptionHeartbeat(channelIds); + const listener = streamFrom(channelIds, request, onSend, undefined, options.needsFiltering, options.notificationOnly, options.allowLocalOnly); + + subscriptions[channelIds.join(';')] = { + listener, + stopHeartbeat, + }; + }).catch(err => { + log.verbose(request.requestId, 'Subscription error:', err.toString()); + socket.send(JSON.stringify({ error: err.toString() })); + }); + + /** + * @param {WebSocketSession} session + * @param {string} channelName + * @param {StreamParams} params + */ + const unsubscribeWebsocketFromChannel = ({ socket, request, subscriptions }, channelName, params) => + channelNameToIds(request, channelName, params).then(({ channelIds }) => { + log.verbose(request.requestId, `Ending stream from ${channelIds.join(', ')} for ${request.accountId}`); + + const { listener, stopHeartbeat } = subscriptions[channelIds.join(';')]; + + if (!listener) { + return; + } + + channelIds.forEach(channelId => { + unsubscribe(`${redisPrefix}${channelId}`, listener); + }); + + stopHeartbeat(); + + subscriptions[channelIds.join(';')] = undefined; + }).catch(err => { + log.verbose(request.requestId, 'Unsubscription error:', err); + socket.send(JSON.stringify({ error: err.toString() })); + }); + + /** + * @param {string|string[]} arrayOrString + * @return {string} + */ + const firstParam = arrayOrString => { + if (Array.isArray(arrayOrString)) { + return arrayOrString[0]; + } else { + return arrayOrString; + } + }; + + wss.on('connection', (ws, req) => { + const location = url.parse(req.url, true); + + req.requestId = uuid.v4(); + req.remoteAddress = ws._socket.remoteAddress; + + /** + * @type {WebSocketSession} + */ + const session = { + socket: ws, + request: req, + subscriptions: {}, + }; + + const onEnd = () => { + const keys = Object.keys(session.subscriptions); + + keys.forEach(channelIds => { + const { listener, stopHeartbeat } = session.subscriptions[channelIds]; + + channelIds.split(';').forEach(channelId => { + unsubscribe(`${redisPrefix}${channelId}`, listener); + }); + + stopHeartbeat(); + }); + }; + + ws.on('close', onEnd); + ws.on('error', onEnd); + + ws.on('message', data => { + const { type, stream, ...params } = JSON.parse(data); + + if (type === 'subscribe') { + subscribeWebsocketToChannel(session, firstParam(stream), params); + } else if (type === 'unsubscribe') { + unsubscribeWebsocketFromChannel(session, firstParam(stream), params) + } else { + // Unknown action type + } + }); + + if (location.query.stream) { + subscribeWebsocketToChannel(session, firstParam(location.query.stream), location.query); } }); @@ -731,6 +991,10 @@ const startWorker = (workerId) => { process.on('uncaughtException', onError); }; +/** + * @param {any} server + * @param {function(string): void} [onSuccess] + */ const attachServerWithConfig = (server, onSuccess) => { if (process.env.SOCKET || process.env.PORT && isNaN(+process.env.PORT)) { server.listen(process.env.SOCKET || process.env.PORT, () => { @@ -748,6 +1012,9 @@ const attachServerWithConfig = (server, onSuccess) => { } }; +/** + * @param {function(Error=): void} onSuccess + */ const onPortAvailable = onSuccess => { const testServer = http.createServer(); |