diff options
-rw-r--r-- | streaming/index.js | 97 |
1 files changed, 70 insertions, 27 deletions
diff --git a/streaming/index.js b/streaming/index.js index c79a58671..31c597cf0 100644 --- a/streaming/index.js +++ b/streaming/index.js @@ -97,6 +97,8 @@ const startWorker = (workerId) => { }; const app = express(); + app.set('trusted proxy', process.env.TRUSTED_PROXY_IP || 'loopback,uniquelocal'); + const pgPool = new pg.Pool(Object.assign(pgConfigs[env], dbUrlToConfig(process.env.DATABASE_URL))); const server = http.createServer(app); const redisNamespace = process.env.REDIS_NAMESPACE || null; @@ -177,6 +179,12 @@ const startWorker = (workerId) => { next(); }; + const setRemoteAddress = (req, res, next) => { + req.remoteAddress = req.connection.remoteAddress; + + next(); + }; + const accountFromToken = (token, req, next) => { pgPool.connect((err, client, done) => { if (err) { @@ -208,17 +216,22 @@ const startWorker = (workerId) => { }); }; - const accountFromRequest = (req, next) => { + const accountFromRequest = (req, next, required = true) => { const authorization = req.headers.authorization; const location = url.parse(req.url, true); const accessToken = location.query.access_token; if (!authorization && !accessToken) { - const err = new Error('Missing access token'); - err.statusCode = 401; + if (required) { + const err = new Error('Missing access token'); + err.statusCode = 401; - next(err); - return; + next(err); + return; + } else { + next(); + return; + } } const token = authorization ? authorization.replace(/^Bearer /, '') : accessToken; @@ -226,7 +239,17 @@ const startWorker = (workerId) => { accountFromToken(token, req, next); }; + const PUBLIC_STREAMS = [ + 'public', + 'public:local', + 'hashtag', + 'hashtag:local', + ]; + const wsVerifyClient = (info, cb) => { + const location = url.parse(info.req.url, true); + const authRequired = !PUBLIC_STREAMS.some(stream => stream === location.query.stream); + accountFromRequest(info.req, err => { if (!err) { cb(true, undefined, undefined); @@ -234,16 +257,24 @@ const startWorker = (workerId) => { log.error(info.req.requestId, err.toString()); cb(false, 401, 'Unauthorized'); } - }); + }, authRequired); }; + const PUBLIC_ENDPOINTS = [ + '/api/v1/streaming/public', + '/api/v1/streaming/public/local', + '/api/v1/streaming/hashtag', + '/api/v1/streaming/hashtag/local', + ]; + const authenticationMiddleware = (req, res, next) => { if (req.method === 'OPTIONS') { next(); return; } - accountFromRequest(req, next); + const authRequired = !PUBLIC_ENDPOINTS.some(endpoint => endpoint === req.path); + accountFromRequest(req, next, authRequired); }; const errorMiddleware = (err, req, res, {}) => { @@ -275,8 +306,10 @@ const startWorker = (workerId) => { }; const streamFrom = (id, req, output, attachCloseHandler, needsFiltering = false, notificationOnly = false) => { + const accountId = req.accountId || req.remoteAddress; + const streamType = notificationOnly ? ' (notification)' : ''; - log.verbose(req.requestId, `Starting stream from ${id} for ${req.accountId}${streamType}`); + log.verbose(req.requestId, `Starting stream from ${id} for ${accountId}${streamType}`); const listener = message => { const { event, payload, queued_at } = JSON.parse(message); @@ -286,7 +319,7 @@ const startWorker = (workerId) => { const delta = now - queued_at; const encodedPayload = typeof payload === 'object' ? JSON.stringify(payload) : payload; - log.silly(req.requestId, `Transmitting for ${req.accountId}: ${event} ${encodedPayload} Delay: ${delta}ms`); + log.silly(req.requestId, `Transmitting for ${accountId}: ${event} ${encodedPayload} Delay: ${delta}ms`); output(event, encodedPayload); }; @@ -313,26 +346,30 @@ const startWorker = (workerId) => { return; } - const queries = [ - client.query(`SELECT 1 FROM blocks WHERE (account_id = $1 AND target_account_id IN (${placeholders(targetAccountIds, 2)})) OR (account_id = $2 AND target_account_id = $1) UNION SELECT 1 FROM mutes WHERE account_id = $1 AND target_account_id IN (${placeholders(targetAccountIds, 2)})`, [req.accountId, unpackedPayload.account.id].concat(targetAccountIds)), - ]; + if (!req.accountId) { + const queries = [ + client.query(`SELECT 1 FROM blocks WHERE (account_id = $1 AND target_account_id IN (${placeholders(targetAccountIds, 2)})) OR (account_id = $2 AND target_account_id = $1) UNION SELECT 1 FROM mutes WHERE account_id = $1 AND target_account_id IN (${placeholders(targetAccountIds, 2)})`, [req.accountId, unpackedPayload.account.id].concat(targetAccountIds)), + ]; - if (accountDomain) { - queries.push(client.query('SELECT 1 FROM account_domain_blocks WHERE account_id = $1 AND domain = $2', [req.accountId, accountDomain])); - } + if (accountDomain) { + queries.push(client.query('SELECT 1 FROM account_domain_blocks WHERE account_id = $1 AND domain = $2', [req.accountId, accountDomain])); + } - Promise.all(queries).then(values => { - done(); + Promise.all(queries).then(values => { + done(); - if (values[0].rows.length > 0 || (values.length > 1 && values[1].rows.length > 0)) { - return; - } + if (values[0].rows.length > 0 || (values.length > 1 && values[1].rows.length > 0)) { + return; + } + transmit(); + }).catch(err => { + done(); + log.error(err); + }); + } else { transmit(); - }).catch(err => { - done(); - log.error(err); - }); + } }); } else { transmit(); @@ -345,13 +382,15 @@ const startWorker = (workerId) => { // Setup stream output to HTTP const streamToHttp = (req, res) => { + const accountId = req.accountId || req.remoteAddress; + res.setHeader('Content-Type', 'text/event-stream'); res.setHeader('Transfer-Encoding', 'chunked'); const heartbeat = setInterval(() => res.write(':thump\n'), 15000); req.on('close', () => { - log.verbose(req.requestId, `Ending stream for ${req.accountId}`); + log.verbose(req.requestId, `Ending stream for ${accountId}`); clearInterval(heartbeat); }); @@ -383,8 +422,10 @@ const startWorker = (workerId) => { // Setup stream end for WebSockets const streamWsEnd = (req, ws, closeHandler = false) => (id, listener) => { + const accountId = req.accountId || req.remoteAddress; + ws.on('close', () => { - log.verbose(req.requestId, `Ending stream for ${req.accountId}`); + log.verbose(req.requestId, `Ending stream for ${accountId}`); unsubscribe(id, listener); if (closeHandler) { closeHandler(); @@ -392,7 +433,7 @@ const startWorker = (workerId) => { }); ws.on('error', () => { - log.verbose(req.requestId, `Ending stream for ${req.accountId}`); + log.verbose(req.requestId, `Ending stream for ${accountId}`); unsubscribe(id, listener); if (closeHandler) { closeHandler(); @@ -401,6 +442,7 @@ const startWorker = (workerId) => { }; app.use(setRequestId); + app.use(setRemoteAddress); app.use(allowCrossDomain); app.use(authenticationMiddleware); app.use(errorMiddleware); @@ -451,6 +493,7 @@ const startWorker = (workerId) => { const req = ws.upgradeReq; const location = url.parse(req.url, true); req.requestId = uuid.v4(); + req.remoteAddress = ws._socket.remoteAddress; ws.isAlive = true; |