diff --git a/packages/server/app.ts b/packages/server/app.ts index 56f2812be..19e6e21be 100644 --- a/packages/server/app.ts +++ b/packages/server/app.ts @@ -54,7 +54,6 @@ import { } from '@/modules/shared/helpers/envHelper' import * as ModulesSetup from '@/modules' import { GraphQLContext, Optional } from '@/modules/shared/helpers/typeHelper' -import { createRateLimiterMiddleware } from '@/modules/core/services/ratelimiter' import { get, has, isString } from 'lodash' import { corsMiddlewareFactory } from '@/modules/core/configs/cors' @@ -87,6 +86,7 @@ import { initiateRequestContextMiddleware } from '@/logging/requestContext' import { randomUUID } from 'crypto' +import { createRateLimiterMiddleware } from '@/modules/core/rest/ratelimiter' const GRAPHQL_PATH = '/graphql' diff --git a/packages/server/modules/auth/strategies/local.ts b/packages/server/modules/auth/strategies/local.ts index 1bc363ea7..8ab18a339 100644 --- a/packages/server/modules/auth/strategies/local.ts +++ b/packages/server/modules/auth/strategies/local.ts @@ -1,5 +1,4 @@ import { - sendRateLimitResponse, getRateLimitResult, isRateLimitBreached } from '@/modules/core/services/ratelimiter' @@ -24,6 +23,9 @@ import { } from '@/modules/core/domain/users/operations' import { GetServerInfo } from '@/modules/core/domain/server/operations' import { UserValidationError } from '@/modules/core/errors/user' +import { RateLimitError } from '@/modules/core/errors/ratelimit' +import { isRateLimiterEnabled } from '@/modules/shared/helpers/envHelper' +import { addRateLimitHeadersToResponse } from '@/modules/core/rest/ratelimiter' const localStrategyBuilderFactory = (deps: { @@ -97,12 +99,16 @@ const localStrategyBuilderFactory = if (!req.body.password) throw new UserInputError('Password missing') const user = req.body - const ip = getIpFromRequest(req) - if (ip) user.ip = ip - const source = ip ? ip : 'unknown' - const rateLimitResult = await deps.getRateLimitResult('USER_CREATE', source) - if (isRateLimitBreached(rateLimitResult)) { - return sendRateLimitResponse(res, rateLimitResult) + + if (isRateLimiterEnabled()) { + const ip = getIpFromRequest(req) + if (ip) user.ip = ip + const source = ip ? ip : 'unknown' + const rateLimitResult = await deps.getRateLimitResult('USER_CREATE', source) + if (isRateLimitBreached(rateLimitResult)) { + addRateLimitHeadersToResponse(res, rateLimitResult) + return next(new RateLimitError(rateLimitResult)) + } } // 1. if the server is invite only you must have an invite diff --git a/packages/server/modules/core/graph/resolvers/commits.ts b/packages/server/modules/core/graph/resolvers/commits.ts index 1a8711637..795b40c28 100644 --- a/packages/server/modules/core/graph/resolvers/commits.ts +++ b/packages/server/modules/core/graph/resolvers/commits.ts @@ -75,6 +75,7 @@ import { import { LegacyUserCommit } from '@/modules/core/domain/commits/types' import coreModule from '@/modules/core' import { getEventBus } from '@/modules/shared/services/eventBus' +import { isRateLimiterEnabled } from '@/modules/shared/helpers/envHelper' const getStreams = getStreamsFactory({ db }) @@ -340,9 +341,14 @@ export = { context.resourceAccessRules ) - const rateLimitResult = await getRateLimitResult('COMMIT_CREATE', context.userId!) - if (isRateLimitBreached(rateLimitResult)) { - throw new RateLimitError(rateLimitResult) + if (isRateLimiterEnabled()) { + const rateLimitResult = await getRateLimitResult( + 'COMMIT_CREATE', + context.userId! + ) + if (isRateLimitBreached(rateLimitResult)) { + throw new RateLimitError(rateLimitResult) + } } const createCommitByBranchId = createCommitByBranchIdFactory({ diff --git a/packages/server/modules/core/graph/resolvers/projects.ts b/packages/server/modules/core/graph/resolvers/projects.ts index 34df2b857..b82fd62a5 100644 --- a/packages/server/modules/core/graph/resolvers/projects.ts +++ b/packages/server/modules/core/graph/resolvers/projects.ts @@ -91,6 +91,7 @@ import { createAndSendInviteFactory } from '@/modules/serverinvites/services/cre import { inviteUsersToProjectFactory } from '@/modules/serverinvites/services/projectInviteManagement' import { authorizeResolver, validateScopes } from '@/modules/shared' import { throwForNotHavingServerRole } from '@/modules/shared/authz' +import { isRateLimiterEnabled } from '@/modules/shared/helpers/envHelper' import { getEventBus } from '@/modules/shared/services/eventBus' import { filteredSubscribe, @@ -289,9 +290,14 @@ export = { }, // This one is only used outside of a workspace, so the project is always created in the main db async create(_parent, args, context) { - const rateLimitResult = await getRateLimitResult('STREAM_CREATE', context.userId!) - if (isRateLimitBreached(rateLimitResult)) { - throw new RateLimitError(rateLimitResult) + if (isRateLimiterEnabled()) { + const rateLimitResult = await getRateLimitResult( + 'STREAM_CREATE', + context.userId! + ) + if (isRateLimitBreached(rateLimitResult)) { + throw new RateLimitError(rateLimitResult) + } } const regionKey = await getValidDefaultProjectRegionKey() diff --git a/packages/server/modules/core/graph/resolvers/streams.ts b/packages/server/modules/core/graph/resolvers/streams.ts index c244be87b..b22983ee0 100644 --- a/packages/server/modules/core/graph/resolvers/streams.ts +++ b/packages/server/modules/core/graph/resolvers/streams.ts @@ -82,7 +82,10 @@ import { } from '@/modules/core/services/streams/favorite' import { getUserFactory, getUsersFactory } from '@/modules/core/repositories/users' import { getServerInfoFactory } from '@/modules/core/repositories/server' -import { adminOverrideEnabled } from '@/modules/shared/helpers/envHelper' +import { + adminOverrideEnabled, + isRateLimiterEnabled +} from '@/modules/shared/helpers/envHelper' const getServerInfo = getServerInfoFactory({ db }) const getUsers = getUsersFactory({ db }) @@ -435,9 +438,14 @@ export = { }, Mutation: { async streamCreate(_, args, context) { - const rateLimitResult = await getRateLimitResult('STREAM_CREATE', context.userId!) - if (isRateLimitBreached(rateLimitResult)) { - throw new RateLimitError(rateLimitResult) + if (isRateLimiterEnabled()) { + const rateLimitResult = await getRateLimitResult( + 'STREAM_CREATE', + context.userId! + ) + if (isRateLimitBreached(rateLimitResult)) { + throw new RateLimitError(rateLimitResult) + } } const { id } = await createStreamReturnRecord({ diff --git a/packages/server/modules/core/graph/resolvers/versions.ts b/packages/server/modules/core/graph/resolvers/versions.ts index 475a039fc..572df3316 100644 --- a/packages/server/modules/core/graph/resolvers/versions.ts +++ b/packages/server/modules/core/graph/resolvers/versions.ts @@ -5,7 +5,10 @@ import { filteredSubscribe, ProjectSubscriptions } from '@/modules/shared/utils/subscriptions' -import { getServerOrigin } from '@/modules/shared/helpers/envHelper' +import { + getServerOrigin, + isRateLimiterEnabled +} from '@/modules/shared/helpers/envHelper' import { batchDeleteCommitsFactory, batchMoveCommitsFactory @@ -169,9 +172,11 @@ export = { projectId: args.input.projectId }) - const rateLimitResult = await getRateLimitResult('COMMIT_CREATE', ctx.userId!) - if (isRateLimitBreached(rateLimitResult)) { - throw new RateLimitError(rateLimitResult) + if (isRateLimiterEnabled()) { + const rateLimitResult = await getRateLimitResult('COMMIT_CREATE', ctx.userId!) + if (isRateLimitBreached(rateLimitResult)) { + throw new RateLimitError(rateLimitResult) + } } const projectDb = await getProjectDbClient({ projectId: args.input.projectId }) diff --git a/packages/server/modules/core/rest/ratelimiter.ts b/packages/server/modules/core/rest/ratelimiter.ts new file mode 100644 index 000000000..fcb7848c3 --- /dev/null +++ b/packages/server/modules/core/rest/ratelimiter.ts @@ -0,0 +1,69 @@ +import type { RequestHandler, Response } from 'express' +import { + getActionForPath, + getRateLimitResult, + getSourceFromRequest, + isRateLimitBreached, + RATE_LIMITERS, + type RateLimitBreached, + type RateLimiterMapping +} from '@/modules/core/services/ratelimiter' +import { isRateLimiterEnabled } from '@/modules/shared/helpers/envHelper' +import { getRequestPath } from '@/modules/core/helpers/server' +import { RateLimitError } from '@/modules/core/errors/ratelimit' +import { ensureError } from '@speckle/shared' + +export const createRateLimiterMiddleware = ( + rateLimiterMapping: RateLimiterMapping = RATE_LIMITERS +): RequestHandler => { + return async (req, res, next) => { + if (!isRateLimiterEnabled()) return next() + const path = getRequestPath(req) || '' + const action = getActionForPath(path, req.method) + const source = getSourceFromRequest(req) + try { + const rateLimitResult = await getRateLimitResult( + action, + source, + rateLimiterMapping + ) + if (isRateLimitBreached(rateLimitResult)) { + addRateLimitHeadersToResponse(res, rateLimitResult) + return next(new RateLimitError(rateLimitResult)) + } else { + if (res.headersSent) return res + res.setHeader('X-RateLimit-Remaining', rateLimitResult.remainingPoints) + return next() + } + } catch (err) { + const e = !(err instanceof RateLimitError) + ? new RateLimitError( + { + isWithinLimits: false, + msBeforeNext: 0, + action + }, + 'Unknown rate limit error', + { cause: ensureError(err) } + ) + : err + + addRateLimitHeadersToResponse(res, e.rateLimitBreached) + return next(e) + } + } +} + +export const addRateLimitHeadersToResponse = ( + res: Response, + rateLimitBreached: RateLimitBreached +) => { + if (res.headersSent) return res + res.setHeader('Retry-After', rateLimitBreached.msBeforeNext / 1000) + res.removeHeader('X-RateLimit-Remaining') + res.setHeader( + 'X-RateLimit-Reset', + new Date(Date.now() + rateLimitBreached.msBeforeNext).toISOString() + ) + res.setHeader('X-Speckle-Meditation', 'https://http.cat/429') +} diff --git a/packages/server/modules/core/services/ratelimiter.ts b/packages/server/modules/core/services/ratelimiter.ts index fd490118f..792ca436f 100644 --- a/packages/server/modules/core/services/ratelimiter.ts +++ b/packages/server/modules/core/services/ratelimiter.ts @@ -1,9 +1,5 @@ import express from 'express' -import { - getRedisUrl, - getIntFromEnv, - getBooleanFromEnv -} from '@/modules/shared/helpers/envHelper' +import { getRedisUrl, getIntFromEnv } from '@/modules/shared/helpers/envHelper' import { BurstyRateLimiter, RateLimiterAbstract, @@ -13,10 +9,8 @@ import { } from 'rate-limiter-flexible' import { TIME } from '@speckle/shared' import { getIpFromRequest } from '@/modules/shared/utils/ip' -import { RateLimitError } from '@/modules/core/errors/ratelimit' import { rateLimiterLogger } from '@/logging/logging' import { createRedisClient } from '@/modules/shared/redis/redis' -import { getRequestPath } from '@/modules/core/helpers/server' import { getTokenFromRequest } from '@/modules/shared/middleware' export interface RateLimitResult { @@ -56,10 +50,6 @@ export type RateLimiterMapping = { export type RateLimitAction = keyof typeof LIMITS -export const isRateLimiterEnabled = (): boolean => { - return getBooleanFromEnv('RATELIMITER_ENABLED', true) -} - export const LIMITS = { ALL_REQUESTS: { regularOptions: { @@ -272,23 +262,6 @@ export const LIMITS = { export const allActions = Object.keys(LIMITS) as RateLimitAction[] -export const sendRateLimitResponse = ( - res: express.Response, - rateLimitBreached: RateLimitBreached -): express.Response => { - if (res.headersSent) return res - res.setHeader('Retry-After', rateLimitBreached.msBeforeNext / 1000) - res.removeHeader('X-RateLimit-Remaining') - res.setHeader( - 'X-RateLimit-Reset', - new Date(Date.now() + rateLimitBreached.msBeforeNext).toISOString() - ) - res.setHeader('X-Speckle-Meditation', 'https://http.cat/429') - return res.status(429).send({ - err: 'You are sending too many requests. You have been rate limited. Please try again later.' - }) -} - export const getActionForPath = (path: string, verb: string): RateLimitAction => { const maybeAction = `${verb} ${path}` as RateLimitAction const maybeActionNoVerb = path as RateLimitAction @@ -308,35 +281,6 @@ export const getSourceFromRequest = (req: express.Request): string => { return source } -export const createRateLimiterMiddleware = ( - rateLimiterMapping: RateLimiterMapping = RATE_LIMITERS -) => { - return async ( - req: express.Request, - res: express.Response, - next: express.NextFunction - ) => { - if (!isRateLimiterEnabled()) return next() - const path = getRequestPath(req) || '' - const action = getActionForPath(path, req.method) - const source = getSourceFromRequest(req) - - const rateLimitResult = await getRateLimitResult(action, source, rateLimiterMapping) - if (isRateLimitBreached(rateLimitResult)) { - return sendRateLimitResponse(res, rateLimitResult) - } else { - try { - if (res.headersSent) return res - res.setHeader('X-RateLimit-Remaining', rateLimitResult.remainingPoints) - return next() - } catch (err) { - if (!(err instanceof RateLimitError)) throw err - return sendRateLimitResponse(res, err.rateLimitBreached) - } - } - } -} - // we need to take the `BurstyRateLimiter` specific type because // its not considered as an RateLimiterAbstract in the rate-limiter-flexible package // This is just a rant comment, but why define the Abstract then if not diff --git a/packages/server/modules/core/tests/ratelimiter.spec.ts b/packages/server/modules/core/tests/ratelimiter.spec.ts index 557b7241f..0318075bb 100644 --- a/packages/server/modules/core/tests/ratelimiter.spec.ts +++ b/packages/server/modules/core/tests/ratelimiter.spec.ts @@ -1,11 +1,9 @@ /* istanbul ignore file */ import { TIME } from '@speckle/shared' import { - createRateLimiterMiddleware, getRateLimitResult, isRateLimitBreached, getActionForPath, - sendRateLimitResponse, RateLimitBreached, RateLimits, createConsumer, @@ -16,6 +14,10 @@ import { import { expect } from 'chai' import httpMocks from 'node-mocks-http' import { RateLimiterMemory } from 'rate-limiter-flexible' +import { + addRateLimitHeadersToResponse, + createRateLimiterMiddleware +} from '@/modules/core/rest/ratelimiter' type RateLimiterOptions = { [key in RateLimitAction]: RateLimits @@ -90,7 +92,7 @@ describe('Rate Limiting', () => { msBeforeNext: 4900 } const response = httpMocks.createResponse() - await sendRateLimitResponse(response, breached) + await addRateLimitHeadersToResponse(response, breached) assert429response(response) }) }) @@ -139,9 +141,16 @@ describe('Rate Limiting', () => { }) let response = httpMocks.createResponse() - let nextCalled = 0 - const next = () => { - nextCalled++ + let nextCalledWithErr = 0 + let nextCalledWithoutErr = 0 + const next = (err: unknown) => { + if (err) { + nextCalledWithErr++ + } else { + nextCalledWithoutErr++ + } + expect(err).to.not.be.undefined + expect(err).to.have.property('rateLimitBreached') } const SUT = createRateLimiterMiddleware(createTestRateLimiterMappings()) @@ -151,7 +160,8 @@ describe('Rate Limiting', () => { await SUT(request, response, next) }) - expect(nextCalled).to.equal(0) + expect(nextCalledWithErr).to.equal(1) + expect(nextCalledWithoutErr).to.equal(0) assert429response(response) }) }) @@ -170,5 +180,5 @@ const assert429response = (response: any) => { expect(response.getHeader('X-RateLimit-Remaining')).to.be.undefined expect(response.getHeader('Retry-After')).to.be.greaterThanOrEqual(4) expect(response.getHeader('X-RateLimit-Reset')).to.not.be.undefined - expect(response.statusCode).to.equal(429) + // expect(response.statusCode).to.equal(429) // response status code is added by the error handler, which is not part of this integration test } diff --git a/packages/server/modules/gendo/graph/resolvers/index.ts b/packages/server/modules/gendo/graph/resolvers/index.ts index 86bbc51e5..873420f09 100644 --- a/packages/server/modules/gendo/graph/resolvers/index.ts +++ b/packages/server/modules/gendo/graph/resolvers/index.ts @@ -36,7 +36,8 @@ import { getGendoAIKey, getGendoAICreditLimit, getServerOrigin, - getFeatureFlags + getFeatureFlags, + isRateLimiterEnabled } from '@/modules/shared/helpers/envHelper' import { getProjectObjectStorage } from '@/modules/multiregion/utils/blobStorageSelector' import { storeFileStreamFactory } from '@/modules/blobstorage/repositories/blobs' @@ -86,12 +87,14 @@ export = FF_GENDOAI_MODULE_ENABLED ctx.resourceAccessRules ) - const rateLimitResult = await getRateLimitResult( - 'GENDO_AI_RENDER_REQUEST', - ctx.userId as string - ) - if (isRateLimitBreached(rateLimitResult)) { - throw new RateLimitError(rateLimitResult) + if (isRateLimiterEnabled()) { + const rateLimitResult = await getRateLimitResult( + 'GENDO_AI_RENDER_REQUEST', + ctx.userId as string + ) + if (isRateLimitBreached(rateLimitResult)) { + throw new RateLimitError(rateLimitResult) + } } const userId = ctx.userId! diff --git a/packages/server/modules/shared/helpers/envHelper.ts b/packages/server/modules/shared/helpers/envHelper.ts index 72602e43b..d145b535c 100644 --- a/packages/server/modules/shared/helpers/envHelper.ts +++ b/packages/server/modules/shared/helpers/envHelper.ts @@ -451,3 +451,7 @@ export function enableImprovedKnexTelemetryStackTraces() { export function disablePreviews() { return getBooleanFromEnv('DISABLE_PREVIEWS') } + +export const isRateLimiterEnabled = (): boolean => { + return getBooleanFromEnv('RATELIMITER_ENABLED', true) +} diff --git a/packages/server/modules/workspaces/graph/resolvers/workspaces.ts b/packages/server/modules/workspaces/graph/resolvers/workspaces.ts index 83b37eef7..2c65ffa5f 100644 --- a/packages/server/modules/workspaces/graph/resolvers/workspaces.ts +++ b/packages/server/modules/workspaces/graph/resolvers/workspaces.ts @@ -41,7 +41,11 @@ import { import { createProjectInviteFactory } from '@/modules/serverinvites/services/projectInviteManagement' import { getInvitationTargetUsersFactory } from '@/modules/serverinvites/services/retrieval' import { authorizeResolver } from '@/modules/shared' -import { getFeatureFlags, getServerOrigin } from '@/modules/shared/helpers/envHelper' +import { + getFeatureFlags, + getServerOrigin, + isRateLimiterEnabled +} from '@/modules/shared/helpers/envHelper' import { getEventBus } from '@/modules/shared/services/eventBus' import { WorkspaceInviteResourceType } from '@/modules/workspacesCore/domain/constants' import { @@ -933,12 +937,14 @@ export = FF_WORKSPACES_MODULE_ENABLED }, WorkspaceProjectMutations: { create: async (_parent, args, context) => { - const rateLimitResult = await getRateLimitResult( - 'STREAM_CREATE', - context.userId! - ) - if (isRateLimitBreached(rateLimitResult)) { - throw new RateLimitError(rateLimitResult) + if (isRateLimiterEnabled()) { + const rateLimitResult = await getRateLimitResult( + 'STREAM_CREATE', + context.userId! + ) + if (isRateLimitBreached(rateLimitResult)) { + throw new RateLimitError(rateLimitResult) + } } await authorizeResolver(