diff --git a/.circleci/config.yml b/.circleci/config.yml index 7157612f7..a01a41fb1 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -439,11 +439,12 @@ jobs: docker: - image: cimg/node:18.19.0 - image: cimg/redis:7.2.4 - - image: cimg/postgres:14.11 + - image: cimg/postgres:17.0 environment: POSTGRES_DB: speckle2_test POSTGRES_PASSWORD: speckle POSTGRES_USER: speckle + command: -c 'max_connections=1000' - image: 'minio/minio' command: server /data --console-address ":9001" # environment: @@ -453,6 +454,7 @@ jobs: NODE_ENV: test DATABASE_URL: 'postgres://speckle:speckle@127.0.0.1:5432/speckle2_test' PGDATABASE: speckle2_test + POSTGRES_MAX_CONNECTIONS_SERVER: 20 PGUSER: speckle SESSION_SECRET: 'keyboard cat' STRATEGY_LOCAL: 'true' diff --git a/packages/server/.vscode/launch.json b/packages/server/.vscode/launch.json index 19791fef4..e2c68a644 100644 --- a/packages/server/.vscode/launch.json +++ b/packages/server/.vscode/launch.json @@ -47,11 +47,11 @@ "console": "integratedTerminal" }, { - "args": ["-g='@ChunkInsertionObject'", "--timeout=10000", "--exit"], + "args": ["-f='should throw and preserve all roles'", "--timeout=0", "--exit"], // "envFile": "${workspaceFolder}/.env", "env": { - "PORT": "0", - "POSTGRES_URL": "postgresql://127.0.0.1:5432/speckle2_test" + "PORT": "0" + // "POSTGRES_URL": "postgresql://127.0.0.1:5432/speckle2_test" // "POSTGRES_USER": "speckle", // "POSTGRES_PASSWORD": "speckle", // "POSTGRES_DB": "speckle2_test", @@ -59,7 +59,7 @@ }, "internalConsoleOptions": "openOnSessionStart", "name": "Mocha Tests", - "program": "${workspaceFolder}/node_modules/mocha/bin/_mocha", + "program": "${workspaceFolder}../../../node_modules/mocha/bin/_mocha", "request": "launch", "skipFiles": ["/**"], "type": "node" diff --git a/packages/server/modules/activitystream/index.ts b/packages/server/modules/activitystream/index.ts index cb8a38d19..f893c6b2b 100644 --- a/packages/server/modules/activitystream/index.ts +++ b/packages/server/modules/activitystream/index.ts @@ -20,7 +20,10 @@ import { } from '@/modules/activitystream/services/accessRequestActivity' import { ScheduleExecution } from '@/modules/core/domain/scheduledTasks/operations' import { scheduleExecutionFactory } from '@/modules/core/services/taskScheduler' -import { acquireTaskLockFactory } from '@/modules/core/repositories/scheduledTasks' +import { + acquireTaskLockFactory, + releaseTaskLockFactory +} from '@/modules/core/repositories/scheduledTasks' let scheduledTask: ReturnType | null = null let quitEventListeners: Optional> = @@ -44,7 +47,8 @@ const initializeEventListeners = () => { const scheduleWeeklyActivityNotifications = () => { const scheduleExecution = scheduleExecutionFactory({ - acquireTaskLock: acquireTaskLockFactory({ db }) + acquireTaskLock: acquireTaskLockFactory({ db }), + releaseTaskLock: releaseTaskLockFactory({ db }) }) // just to test stuff diff --git a/packages/server/modules/core/domain/scheduledTasks/operations.ts b/packages/server/modules/core/domain/scheduledTasks/operations.ts index e1f51a858..906f51357 100644 --- a/packages/server/modules/core/domain/scheduledTasks/operations.ts +++ b/packages/server/modules/core/domain/scheduledTasks/operations.ts @@ -5,6 +5,8 @@ export type AcquireTaskLock = ( scheduledTask: ScheduledTask ) => Promise +export type ReleaseTaskLock = (args: { taskName: string }) => Promise + export type ScheduleExecution = ( cronExpression: string, taskName: string, diff --git a/packages/server/modules/core/domain/scheduledTasks/types.ts b/packages/server/modules/core/domain/scheduledTasks/types.ts index e2b6d11ad..4e5f62379 100644 --- a/packages/server/modules/core/domain/scheduledTasks/types.ts +++ b/packages/server/modules/core/domain/scheduledTasks/types.ts @@ -1,3 +1,4 @@ -import { ScheduledTaskRecord } from '@/modules/core/helpers/types' - -export type ScheduledTask = ScheduledTaskRecord +export type ScheduledTask = { + taskName: string + lockExpiresAt: Date +} diff --git a/packages/server/modules/core/helpers/types.ts b/packages/server/modules/core/helpers/types.ts index b5eac097f..0178da52b 100644 --- a/packages/server/modules/core/helpers/types.ts +++ b/packages/server/modules/core/helpers/types.ts @@ -125,11 +125,6 @@ export type BranchRecord = { updatedAt: Date } -export type ScheduledTaskRecord = { - taskName: string - lockExpiresAt: Date -} - export type ObjectRecord = { id: string speckleType: string diff --git a/packages/server/modules/core/repositories/scheduledTasks.ts b/packages/server/modules/core/repositories/scheduledTasks.ts index eb7f2ee96..56fe64cdf 100644 --- a/packages/server/modules/core/repositories/scheduledTasks.ts +++ b/packages/server/modules/core/repositories/scheduledTasks.ts @@ -1,22 +1,31 @@ import { ScheduledTasks } from '@/modules/core/dbSchema' -import { AcquireTaskLock } from '@/modules/core/domain/scheduledTasks/operations' -import { ScheduledTaskRecord } from '@/modules/core/helpers/types' +import { + AcquireTaskLock, + ReleaseTaskLock +} from '@/modules/core/domain/scheduledTasks/operations' +import { ScheduledTask } from '@/modules/core/domain/scheduledTasks/types' import { Knex } from 'knex' const tables = { - scheduledTasks: (db: Knex) => db(ScheduledTasks.name) + scheduledTasks: (db: Knex) => db(ScheduledTasks.name) } export const acquireTaskLockFactory = - (deps: { db: Knex }): AcquireTaskLock => - async (scheduledTask: ScheduledTaskRecord): Promise => { + ({ db }: { db: Knex }): AcquireTaskLock => + async (scheduledTask) => { const now = new Date() const [lock] = await tables - .scheduledTasks(deps.db) + .scheduledTasks(db) .insert(scheduledTask) .onConflict(ScheduledTasks.withoutTablePrefix.col.taskName) .merge() .where(ScheduledTasks.col.lockExpiresAt, '<', now) .returning('*') - return (lock as ScheduledTaskRecord) ?? null + return lock ?? null + } + +export const releaseTaskLockFactory = + ({ db }: { db: Knex }): ReleaseTaskLock => + async ({ taskName }) => { + await tables.scheduledTasks(db).where({ taskName }).delete() } diff --git a/packages/server/modules/core/services/taskScheduler.ts b/packages/server/modules/core/services/taskScheduler.ts index 3b0a7163b..7554c000f 100644 --- a/packages/server/modules/core/services/taskScheduler.ts +++ b/packages/server/modules/core/services/taskScheduler.ts @@ -1,9 +1,10 @@ import cron from 'node-cron' import { InvalidArgumentError } from '@/modules/shared/errors' import { ensureError } from '@/modules/shared/helpers/errorHelper' -import { activitiesLogger } from '@/logging/logging' +import { logger } from '@/logging/logging' import { AcquireTaskLock, + ReleaseTaskLock, ScheduleExecution } from '@/modules/core/domain/scheduledTasks/operations' @@ -12,22 +13,20 @@ export const scheduledCallbackWrapper = async ( taskName: string, lockTimeout: number, callback: (scheduledTime: Date) => Promise, - acquireLock: AcquireTaskLock + acquireLock: AcquireTaskLock, + releaseTaskLock: ReleaseTaskLock ) => { - const boundLogger = activitiesLogger.child({ taskName }) + const boundLogger = logger.child({ taskName }) // try to acquire the task lock with the function name and a new expiration date const lockExpiresAt = new Date(scheduledTime.getTime() + lockTimeout) + const lock = await acquireLock({ taskName, lockExpiresAt }) + + // if couldn't acquire it, stop execution + if (!lock) { + boundLogger.warn(`Could not acquire task lock for ${taskName}, stopping execution.`) + return + } try { - const lock = await acquireLock({ taskName, lockExpiresAt }) - - // if couldn't acquire it, stop execution - if (!lock) { - boundLogger.warn( - `Could not acquire task lock for ${taskName}, stopping execution.` - ) - return null - } - // else continue executing the callback... boundLogger.info(`Executing scheduled function ${taskName} at ${scheduledTime}`) await callback(scheduledTime) @@ -45,11 +44,19 @@ export const scheduledCallbackWrapper = async ( ensureError(error, 'unknown reason').message }` ) + } finally { + releaseTaskLock(lock) } } export const scheduleExecutionFactory = - (deps: { acquireTaskLock: AcquireTaskLock }): ScheduleExecution => + ({ + acquireTaskLock, + releaseTaskLock + }: { + acquireTaskLock: AcquireTaskLock + releaseTaskLock: ReleaseTaskLock + }): ScheduleExecution => ( cronExpression: string, taskName: string, @@ -67,7 +74,8 @@ export const scheduleExecutionFactory = taskName, lockTimeout, callback, - deps.acquireTaskLock + acquireTaskLock, + releaseTaskLock ) }) } diff --git a/packages/server/modules/core/tests/integration/scheduledTasks.spec.ts b/packages/server/modules/core/tests/integration/scheduledTasks.spec.ts new file mode 100644 index 000000000..2a54cbdd0 --- /dev/null +++ b/packages/server/modules/core/tests/integration/scheduledTasks.spec.ts @@ -0,0 +1,62 @@ +import { db } from '@/db/knex' +import { + acquireTaskLockFactory, + releaseTaskLockFactory +} from '@/modules/core/repositories/scheduledTasks' +import { expect } from 'chai' +import cryptoRandomString from 'crypto-random-string' + +describe('scheduledTasks repositories @core', () => { + describe('acquireTaskLockFactory creates a function, that', () => { + it('returns the inserted task lock', async () => { + const taskLock = { + taskName: cryptoRandomString({ length: 10 }), + lockExpiresAt: new Date() + } + const storedTaskLock = await acquireTaskLockFactory({ db })(taskLock) + expect(storedTaskLock).deep.equal(taskLock) + }) + it('acquires lock if the previous lock for the taskName has expired', async () => { + const taskLock = { + taskName: cryptoRandomString({ length: 10 }), + lockExpiresAt: new Date() + } + let storedTaskLock = await acquireTaskLockFactory({ db })(taskLock) + expect(storedTaskLock).deep.equal(taskLock) + taskLock.lockExpiresAt = new Date(2099, 12, 31) + + storedTaskLock = await acquireTaskLockFactory({ db })(taskLock) + expect(storedTaskLock).deep.equal(taskLock) + }) + it('returns null if the previous lock for the task name has not expired', async () => { + const taskLock = { + taskName: cryptoRandomString({ length: 10 }), + lockExpiresAt: new Date(2099, 12, 31) + } + let storedTaskLock = await acquireTaskLockFactory({ db })(taskLock) + expect(storedTaskLock).deep.equal(taskLock) + taskLock.lockExpiresAt = new Date(2199, 12, 31) + + storedTaskLock = await acquireTaskLockFactory({ db })(taskLock) + expect(storedTaskLock).to.be.null + }) + }) + describe('releaseTaskLockFactory creates a function, that', () => { + it('releases a lock by name', async () => { + const taskLock = { + taskName: cryptoRandomString({ length: 10 }), + lockExpiresAt: new Date(2099, 12, 31) + } + let storedTaskLock = await acquireTaskLockFactory({ db })(taskLock) + expect(storedTaskLock).deep.equal(taskLock) + taskLock.lockExpiresAt = new Date(2199, 12, 31) + + storedTaskLock = await acquireTaskLockFactory({ db })(taskLock) + expect(storedTaskLock).to.be.null + await releaseTaskLockFactory({ db })(taskLock) + + storedTaskLock = await acquireTaskLockFactory({ db })(taskLock) + expect(storedTaskLock).deep.equal(taskLock) + }) + }) +}) diff --git a/packages/server/modules/core/tests/scheduledTasks.spec.ts b/packages/server/modules/core/tests/unit/scheduledTasks.spec.ts similarity index 59% rename from packages/server/modules/core/tests/scheduledTasks.spec.ts rename to packages/server/modules/core/tests/unit/scheduledTasks.spec.ts index 3aa95fc28..7cc038f51 100644 --- a/packages/server/modules/core/tests/scheduledTasks.spec.ts +++ b/packages/server/modules/core/tests/unit/scheduledTasks.spec.ts @@ -1,61 +1,26 @@ import { describe } from 'mocha' -import { ScheduledTasks } from '@/modules/core/dbSchema' -import { truncateTables } from '@/test/hooks' import { ensureError } from '@/modules/shared/helpers/errorHelper' import { scheduledCallbackWrapper, scheduleExecutionFactory } from '@/modules/core/services/taskScheduler' import { expect } from 'chai' -import { sleep } from '@/test/helpers' import cryptoRandomString from 'crypto-random-string' -import { acquireTaskLockFactory } from '@/modules/core/repositories/scheduledTasks' -import { db } from '@/db/knex' - -const acquireTaskLock = acquireTaskLockFactory({ db }) -const scheduleExecution = scheduleExecutionFactory({ acquireTaskLock }) describe('Scheduled tasks @core', () => { - describe('Task lock repository', () => { - before(async () => { - await truncateTables([ScheduledTasks.name]) - }) - it('can acquire task lock for a new function name', async () => { - const taskName = cryptoRandomString({ length: 10 }) - const scheduledTask = { taskName, lockExpiresAt: new Date() } - const lock = await acquireTaskLock(scheduledTask) - expect(lock).to.be.deep.equal(scheduledTask) - }) - it('can acquire task lock if previous lock has expired', async () => { - const taskName = cryptoRandomString({ length: 10 }) - const oldTask = { taskName, lockExpiresAt: new Date() } - await acquireTaskLock(oldTask) - - await sleep(100) - const newTask = { taskName, lockExpiresAt: new Date() } - const lock = await acquireTaskLock(newTask) - expect(lock).to.be.deep.equal(newTask) - }) - it('returns an invalid lock (null), if there is another lock in place', async () => { - const taskName = cryptoRandomString({ length: 10 }) - const oldTask = { - taskName, - lockExpiresAt: new Date('2366-12-28 00:30:57.000+00') - } - await acquireTaskLock(oldTask) - const newTask = { taskName, lockExpiresAt: new Date() } - const lock = await acquireTaskLock(newTask) - expect(lock).to.be.null - }) - }) describe('Task scheduler', () => { describe('scheduled callback wrapper function', () => { let callbackExecuted = false + let lockReleased = false async function fakeCallback() { callbackExecuted = true } + async function releaseTaskLock() { + lockReleased = true + } beforeEach(() => { callbackExecuted = false + lockReleased = false }) it("doesn't invoke the callback if it aquires an invalid lock", async () => { expect(callbackExecuted).to.be.false @@ -66,9 +31,11 @@ describe('Scheduled tasks @core', () => { 100, fakeCallback, // fake lock aquire, always returning an invalid lock - async () => null + async () => null, + releaseTaskLock ) expect(callbackExecuted).to.be.false + expect(lockReleased).to.be.false }) it('invokes the callback if a task lock is acquired', async () => { expect(callbackExecuted).to.be.false @@ -79,9 +46,11 @@ describe('Scheduled tasks @core', () => { 100, fakeCallback, // fake lock aquire, always returning an invalid lock - async () => ({ taskName, lockExpiresAt: new Date() }) + async () => ({ taskName, lockExpiresAt: new Date() }), + releaseTaskLock ) expect(callbackExecuted).to.be.true + expect(lockReleased).to.be.true }) it('handles all callback errors gracefully', async () => { expect(callbackExecuted).to.be.false @@ -95,13 +64,19 @@ describe('Scheduled tasks @core', () => { throw 'catch this' }, // fake lock aquire, always returning an invalid lock - async () => ({ taskName, lockExpiresAt: new Date() }) + async () => ({ taskName, lockExpiresAt: new Date() }), + releaseTaskLock ) expect(callbackExecuted).to.be.true + expect(lockReleased).to.be.true }) }) describe('schedule execution', () => { - it('throws an InvalidArgimentError if the cron expression is not valid', async () => { + const scheduleExecution = scheduleExecutionFactory({ + acquireTaskLock: async () => null, + releaseTaskLock: async () => {} + }) + it('throws an InvalidArgumentError if the cron expression is not valid', async () => { const cronExpression = 'this is a borked cron expression' try { scheduleExecution(cronExpression, 'tick tick boom', async () => { diff --git a/packages/server/modules/gatekeeper/clients/stripe.ts b/packages/server/modules/gatekeeper/clients/stripe.ts index cb61b692f..371985760 100644 --- a/packages/server/modules/gatekeeper/clients/stripe.ts +++ b/packages/server/modules/gatekeeper/clients/stripe.ts @@ -2,8 +2,8 @@ import { CreateCheckoutSession, GetSubscriptionData, - SubscriptionData, - WorkspaceSubscription + ReconcileSubscriptionData, + SubscriptionData } from '@/modules/gatekeeper/domain/billing' import { WorkspacePlanBillingIntervals, @@ -163,19 +163,13 @@ export const parseSubscriptionData = ( // this should be a reconcile subscriptions, we keep an accurate state in the DB // on each change, we're reconciling that state to stripe export const reconcileWorkspaceSubscriptionFactory = - ({ stripe }: { stripe: Stripe }) => - async ({ - workspaceSubscription, - applyProrotation - }: { - workspaceSubscription: WorkspaceSubscription - applyProrotation: boolean - }) => { + ({ stripe }: { stripe: Stripe }): ReconcileSubscriptionData => + async ({ subscriptionData, applyProrotation }) => { const existingSubscriptionState = await getSubscriptionDataFactory({ stripe })({ - subscriptionId: workspaceSubscription.subscriptionData.subscriptionId + subscriptionId: subscriptionData.subscriptionId }) const items: Stripe.SubscriptionUpdateParams.Item[] = [] - for (const product of workspaceSubscription.subscriptionData.products) { + for (const product of subscriptionData.products) { const existingProduct = existingSubscriptionState.products.find( (p) => p.productId === product.productId ) @@ -187,13 +181,24 @@ export const reconcileWorkspaceSubscriptionFactory = items.push({ quantity: product.quantity, price: product.priceId }) items.push({ id: product.subscriptionItemId, deleted: true }) } else { - items.push({ quantity: product.quantity, id: product.subscriptionItemId }) + items.push({ + quantity: product.quantity, + id: existingProduct.subscriptionItemId + }) } } + // remove products from the sub + const productIds = subscriptionData.products.map((p) => p.productId) + const removedProducts = existingSubscriptionState.products.filter( + (p) => !productIds.includes(p.productId) + ) + for (const removedProduct of removedProducts) { + items.push({ id: removedProduct.subscriptionItemId, deleted: true }) + } // workspaceSubscription.subscriptionData.products. // const item = workspaceSubscription.subscriptionData.products.find(p => p.) - await stripe.subscriptions.update( - workspaceSubscription.subscriptionData.subscriptionId, - { items, proration_behavior: applyProrotation ? 'create_prorations' : 'none' } - ) + await stripe.subscriptions.update(subscriptionData.subscriptionId, { + items, + proration_behavior: applyProrotation ? 'create_prorations' : 'none' + }) } diff --git a/packages/server/modules/gatekeeper/domain/billing.ts b/packages/server/modules/gatekeeper/domain/billing.ts index f26c2c39d..3414c9f6d 100644 --- a/packages/server/modules/gatekeeper/domain/billing.ts +++ b/packages/server/modules/gatekeeper/domain/billing.ts @@ -5,6 +5,7 @@ import { WorkspacePlanBillingIntervals, WorkspacePricingPlans } from '@/modules/gatekeeper/domain/workspacePricing' +import { OverrideProperties } from 'type-fest' import { z } from 'zod' export type UnpaidWorkspacePlanStatuses = 'valid' @@ -109,6 +110,15 @@ export type WorkspaceSubscription = { billingInterval: WorkspacePlanBillingIntervals subscriptionData: SubscriptionData } +const subscriptionProduct = z.object({ + productId: z.string(), + subscriptionItemId: z.string(), + priceId: z.string(), + quantity: z.number() +}) + +export type SubscriptionProduct = z.infer + export const subscriptionData = z.object({ subscriptionId: z.string().min(1), customerId: z.string().min(1), @@ -123,15 +133,7 @@ export const subscriptionData = z.object({ z.literal('unpaid'), z.literal('paused') ]), - products: z - .object({ - // we're going to use the productId to match with our - productId: z.string(), - subscriptionItemId: z.string(), - priceId: z.string(), - quantity: z.number() - }) - .array() + products: subscriptionProduct.array() }) // this abstracts the stripe sub data @@ -145,6 +147,8 @@ export type GetWorkspaceSubscription = (args: { workspaceId: string }) => Promise +export type GetWorkspaceSubscriptions = () => Promise + export type GetWorkspaceSubscriptionBySubscriptionId = (args: { subscriptionId: string }) => Promise @@ -158,7 +162,18 @@ export type GetWorkspacePlanPrice = (args: { billingInterval: WorkspacePlanBillingIntervals }) => string -export type ReconcileWorkspaceSubscription = (args: { - workspaceSubscription: WorkspaceSubscription +export type GetWorkspacePlanProductId = (args: { + workspacePlan: WorkspacePricingPlans +}) => string + +export type SubscriptionDataInput = OverrideProperties< + SubscriptionData, + { + products: OverrideProperties[] + } +> + +export type ReconcileSubscriptionData = (args: { + subscriptionData: SubscriptionDataInput applyProrotation: boolean }) => Promise diff --git a/packages/server/modules/gatekeeper/events/eventListener.ts b/packages/server/modules/gatekeeper/events/eventListener.ts new file mode 100644 index 000000000..e2ae7fc97 --- /dev/null +++ b/packages/server/modules/gatekeeper/events/eventListener.ts @@ -0,0 +1,40 @@ +import { reconcileWorkspaceSubscriptionFactory } from '@/modules/gatekeeper/clients/stripe' +import { + getWorkspacePlanFactory, + getWorkspaceSubscriptionFactory +} from '@/modules/gatekeeper/repositories/billing' +import { addWorkspaceSubscriptionSeatIfNeededFactory } from '@/modules/gatekeeper/services/subscriptions' +import { + getWorkspacePlanPrice, + getWorkspacePlanProductId +} from '@/modules/gatekeeper/stripe' +import { getEventBus } from '@/modules/shared/services/eventBus' +import { countWorkspaceRoleWithOptionalProjectRoleFactory } from '@/modules/workspaces/repositories/workspaces' +import { WorkspaceEvents } from '@/modules/workspacesCore/domain/events' +import { Knex } from 'knex' +import Stripe from 'stripe' + +export const initializeEventListenersFactory = + ({ db, stripe }: { db: Knex; stripe: Stripe }) => + () => { + const eventBus = getEventBus() + const quitCbs = [ + eventBus.listen(WorkspaceEvents.RoleUpdated, async ({ payload }) => { + const addWorkspaceSubscriptionSeatIfNeeded = + addWorkspaceSubscriptionSeatIfNeededFactory({ + getWorkspacePlan: getWorkspacePlanFactory({ db }), + getWorkspaceSubscription: getWorkspaceSubscriptionFactory({ db }), + countWorkspaceRole: countWorkspaceRoleWithOptionalProjectRoleFactory({ + db + }), + getWorkspacePlanPrice, + getWorkspacePlanProductId, + reconcileSubscriptionData: reconcileWorkspaceSubscriptionFactory({ stripe }) + }) + + await addWorkspaceSubscriptionSeatIfNeeded(payload) + }) + ] + + return () => quitCbs.forEach((quit) => quit()) + } diff --git a/packages/server/modules/gatekeeper/index.ts b/packages/server/modules/gatekeeper/index.ts index ba6cf40c9..7b1de0c35 100644 --- a/packages/server/modules/gatekeeper/index.ts +++ b/packages/server/modules/gatekeeper/index.ts @@ -1,4 +1,5 @@ -import { moduleLogger } from '@/logging/logging' +import cron from 'node-cron' +import { logger, moduleLogger } from '@/logging/logging' import { SpeckleModule } from '@/modules/shared/helpers/typeHelper' import { getFeatureFlags } from '@/modules/shared/helpers/envHelper' import { validateModuleLicense } from '@/modules/gatekeeper/services/validateLicense' @@ -6,6 +7,24 @@ import { getBillingRouter } from '@/modules/gatekeeper/rest/billing' import { registerOrUpdateScopeFactory } from '@/modules/shared/repositories/scopes' import { db } from '@/db/knex' import { gatekeeperScopes } from '@/modules/gatekeeper/scopes' +import { initializeEventListenersFactory } from '@/modules/gatekeeper/events/eventListener' +import { getStripeClient, getWorkspacePlanProductId } from '@/modules/gatekeeper/stripe' +import { scheduleExecutionFactory } from '@/modules/core/services/taskScheduler' +import { + acquireTaskLockFactory, + releaseTaskLockFactory +} from '@/modules/core/repositories/scheduledTasks' +import { + downscaleWorkspaceSubscriptionFactory, + manageSubscriptionDownscaleFactory +} from '@/modules/gatekeeper/services/subscriptions' +import { + getWorkspacePlanFactory, + getWorkspaceSubscriptionsPastBillingCycleEndFactory, + upsertWorkspaceSubscriptionFactory +} from '@/modules/gatekeeper/repositories/billing' +import { countWorkspaceRoleWithOptionalProjectRoleFactory } from '@/modules/workspaces/repositories/workspaces' +import { reconcileWorkspaceSubscriptionFactory } from '@/modules/gatekeeper/clients/stripe' const { FF_GATEKEEPER_MODULE_ENABLED, FF_BILLING_INTEGRATION_ENABLED } = getFeatureFlags() @@ -15,6 +34,42 @@ const initScopes = async () => { await Promise.all(gatekeeperScopes.map((scope) => registerFunc({ scope }))) } +const scheduleWorkspaceSubscriptionDownscale = () => { + const scheduleExecution = scheduleExecutionFactory({ + acquireTaskLock: acquireTaskLockFactory({ db }), + releaseTaskLock: releaseTaskLockFactory({ db }) + }) + + const stripe = getStripeClient() + + const manageSubscriptionDownscale = manageSubscriptionDownscaleFactory({ + logger, + downscaleWorkspaceSubscription: downscaleWorkspaceSubscriptionFactory({ + countWorkspaceRole: countWorkspaceRoleWithOptionalProjectRoleFactory({ db }), + getWorkspacePlan: getWorkspacePlanFactory({ db }), + reconcileSubscriptionData: reconcileWorkspaceSubscriptionFactory({ stripe }), + getWorkspacePlanProductId + }), + getWorkspaceSubscriptions: getWorkspaceSubscriptionsPastBillingCycleEndFactory({ + db + }), + updateWorkspaceSubscription: upsertWorkspaceSubscriptionFactory({ db }) + }) + + const cronExpression = '*/10 * * * * *' + return scheduleExecution( + cronExpression, + 'WorkspaceSubscriptionDownscale', + async () => { + await manageSubscriptionDownscale() + // await cleanOrphanedWebhookConfigs() + } + ) +} + +let scheduledTask: cron.ScheduledTask | undefined = undefined +let quitListeners: (() => void) | undefined = undefined + const gatekeeperModule: SpeckleModule = { async init(app, isInitial) { await initScopes() @@ -35,6 +90,13 @@ const gatekeeperModule: SpeckleModule = { if (FF_BILLING_INTEGRATION_ENABLED) { app.use(getBillingRouter()) + scheduledTask = scheduleWorkspaceSubscriptionDownscale() + + quitListeners = initializeEventListenersFactory({ + db, + stripe: getStripeClient() + })() + const isLicenseValid = await validateModuleLicense({ requiredModules: ['billing'] }) @@ -45,6 +107,10 @@ const gatekeeperModule: SpeckleModule = { // TODO: create a cron job, that removes unused seats from the subscription at the beginning of each workspace plan's billing cycle } } + }, + async shutdown() { + if (quitListeners) quitListeners() + if (scheduledTask) scheduledTask.stop() } } export = gatekeeperModule diff --git a/packages/server/modules/gatekeeper/repositories/billing.ts b/packages/server/modules/gatekeeper/repositories/billing.ts index be31356fb..986805d30 100644 --- a/packages/server/modules/gatekeeper/repositories/billing.ts +++ b/packages/server/modules/gatekeeper/repositories/billing.ts @@ -12,7 +12,8 @@ import { DeleteCheckoutSession, GetWorkspaceCheckoutSession, GetWorkspaceSubscription, - GetWorkspaceSubscriptionBySubscriptionId + GetWorkspaceSubscriptionBySubscriptionId, + GetWorkspaceSubscriptions } from '@/modules/gatekeeper/domain/billing' import { Knex } from 'knex' @@ -127,3 +128,14 @@ export const getWorkspaceSubscriptionBySubscriptionIdFactory = .first() return subscription ?? null } + +export const getWorkspaceSubscriptionsPastBillingCycleEndFactory = + ({ db }: { db: Knex }): GetWorkspaceSubscriptions => + async () => { + const cycleEnd = new Date() + cycleEnd.setMinutes(cycleEnd.getMinutes() + 5) + return await tables + .workspaceSubscriptions(db) + .select() + .where('currentBillingCycleEnd', '<', cycleEnd) + } diff --git a/packages/server/modules/gatekeeper/services/subscriptions.ts b/packages/server/modules/gatekeeper/services/subscriptions.ts index 1d2553632..0bff65f34 100644 --- a/packages/server/modules/gatekeeper/services/subscriptions.ts +++ b/packages/server/modules/gatekeeper/services/subscriptions.ts @@ -1,17 +1,28 @@ +import { Logger } from '@/logging/logging' import { GetWorkspacePlan, + GetWorkspacePlanPrice, + GetWorkspacePlanProductId, + GetWorkspaceSubscription, GetWorkspaceSubscriptionBySubscriptionId, + GetWorkspaceSubscriptions, PaidWorkspacePlanStatuses, + ReconcileSubscriptionData, SubscriptionData, + SubscriptionDataInput, UpsertPaidWorkspacePlan, - UpsertWorkspaceSubscription + UpsertWorkspaceSubscription, + WorkspaceSubscription } from '@/modules/gatekeeper/domain/billing' +import { WorkspacePricingPlans } from '@/modules/gatekeeper/domain/workspacePricing' import { WorkspacePlanMismatchError, WorkspacePlanNotFoundError, WorkspaceSubscriptionNotFoundError } from '@/modules/gatekeeper/errors/billing' -import { throwUncoveredError } from '@speckle/shared' +import { CountWorkspaceRoleWithOptionalProjectRole } from '@/modules/workspaces/domain/operations' +import { throwUncoveredError, WorkspaceRoles } from '@speckle/shared' +import { cloneDeep, isEqual, sum } from 'lodash' export const handleSubscriptionUpdateFactory = ({ @@ -74,7 +85,248 @@ export const handleSubscriptionUpdateFactory = }) // if there is a status in the sub, we recognize, we need to update our state await upsertWorkspaceSubscription({ - workspaceSubscription: { ...subscription, subscriptionData } + workspaceSubscription: { + ...subscription, + updatedAt: new Date(), + subscriptionData + } }) } } + +export const addWorkspaceSubscriptionSeatIfNeededFactory = + ({ + getWorkspacePlan, + getWorkspaceSubscription, + countWorkspaceRole, + getWorkspacePlanProductId, + getWorkspacePlanPrice, + reconcileSubscriptionData + }: { + getWorkspacePlan: GetWorkspacePlan + getWorkspaceSubscription: GetWorkspaceSubscription + countWorkspaceRole: CountWorkspaceRoleWithOptionalProjectRole + getWorkspacePlanProductId: GetWorkspacePlanProductId + getWorkspacePlanPrice: GetWorkspacePlanPrice + reconcileSubscriptionData: ReconcileSubscriptionData + }) => + async ({ workspaceId, role }: { workspaceId: string; role: WorkspaceRoles }) => { + const workspacePlan = await getWorkspacePlan({ workspaceId }) + // if (!workspacePlan) throw new WorkspacePlanNotFoundError() + if (!workspacePlan) return + const workspaceSubscription = await getWorkspaceSubscription({ workspaceId }) + if (!workspaceSubscription) throw new WorkspaceSubscriptionNotFoundError() + + switch (workspacePlan.name) { + case 'team': + case 'pro': + case 'business': + break + case 'unlimited': + case 'academia': + throw new WorkspacePlanMismatchError() + default: + throwUncoveredError(workspacePlan) + } + + if (workspacePlan.status === 'canceled') return + + let productId: string + let priceId: string + let roleCount: number + switch (role) { + case 'workspace:guest': + roleCount = await countWorkspaceRole({ workspaceId, workspaceRole: role }) + productId = getWorkspacePlanProductId({ workspacePlan: 'guest' }) + priceId = getWorkspacePlanPrice({ + workspacePlan: 'guest', + billingInterval: workspaceSubscription.billingInterval + }) + break + case 'workspace:admin': + case 'workspace:member': + roleCount = sum( + await Promise.all([ + countWorkspaceRole({ workspaceId, workspaceRole: 'workspace:admin' }), + countWorkspaceRole({ workspaceId, workspaceRole: 'workspace:member' }) + ]) + ) + productId = getWorkspacePlanProductId({ workspacePlan: workspacePlan.name }) + priceId = getWorkspacePlanPrice({ + workspacePlan: workspacePlan.name, + billingInterval: workspaceSubscription.billingInterval + }) + break + default: + throwUncoveredError(role) + } + + const subscriptionData: SubscriptionDataInput = cloneDeep( + workspaceSubscription.subscriptionData + ) + + const currentPlanProduct = subscriptionData.products.find( + (product) => product.productId === productId + ) + if (!currentPlanProduct) { + subscriptionData.products.push({ productId, priceId, quantity: roleCount }) + } else { + // if there is enough seats, we do not have to do anything + if (currentPlanProduct.quantity >= roleCount) return + currentPlanProduct.quantity = roleCount + } + await reconcileSubscriptionData({ subscriptionData, applyProrotation: true }) + } + +const mutateSubscriptionDataWithNewValidSeatNumbers = ({ + seatCount, + workspacePlan, + getWorkspacePlanProductId, + subscriptionData +}: { + seatCount: number + workspacePlan: WorkspacePricingPlans + getWorkspacePlanProductId: GetWorkspacePlanProductId + subscriptionData: SubscriptionData +}): void => { + const productId = getWorkspacePlanProductId({ workspacePlan }) + const product = subscriptionData.products.find( + (product) => product.productId === productId + ) + if (seatCount < 0) throw new Error('Invalid seat count, cannot be negative') + + if (seatCount === 0 && product === undefined) return + if (seatCount === 0 && product !== undefined) { + const prodIndex = subscriptionData.products.indexOf(product) + subscriptionData.products.splice(prodIndex, 1) + } else if (product !== undefined && product.quantity >= seatCount) { + product.quantity = seatCount + } else { + throw new Error('Invalid subscription state') + } +} + +const calculateNewBillingCycleEnd = ({ + workspaceSubscription +}: { + workspaceSubscription: WorkspaceSubscription +}): Date => { + const newBillingCycleEnd = new Date(workspaceSubscription.currentBillingCycleEnd) + switch (workspaceSubscription.billingInterval) { + case 'monthly': + newBillingCycleEnd.setMonth(newBillingCycleEnd.getMonth() + 1) + break + case 'yearly': + newBillingCycleEnd.setFullYear(newBillingCycleEnd.getFullYear() + 1) + break + default: + throwUncoveredError(workspaceSubscription.billingInterval) + } + return newBillingCycleEnd +} + +type DownscaleWorkspaceSubscription = (args: { + workspaceSubscription: WorkspaceSubscription +}) => Promise + +export const downscaleWorkspaceSubscriptionFactory = + ({ + getWorkspacePlan, + countWorkspaceRole, + getWorkspacePlanProductId, + reconcileSubscriptionData + }: { + getWorkspacePlan: GetWorkspacePlan + countWorkspaceRole: CountWorkspaceRoleWithOptionalProjectRole + getWorkspacePlanProductId: GetWorkspacePlanProductId + reconcileSubscriptionData: ReconcileSubscriptionData + }): DownscaleWorkspaceSubscription => + async ({ workspaceSubscription }) => { + const workspaceId = workspaceSubscription.workspaceId + + const workspacePlan = await getWorkspacePlan({ workspaceId }) + if (!workspacePlan) throw new WorkspacePlanNotFoundError() + + switch (workspacePlan.name) { + case 'team': + case 'pro': + case 'business': + break + case 'unlimited': + case 'academia': + throw new WorkspacePlanMismatchError() + default: + throwUncoveredError(workspacePlan) + } + + if (workspacePlan.status === 'canceled') return false + + const [guestCount, memberCount, adminCount] = await Promise.all([ + countWorkspaceRole({ workspaceId, workspaceRole: 'workspace:guest' }), + countWorkspaceRole({ workspaceId, workspaceRole: 'workspace:member' }), + countWorkspaceRole({ workspaceId, workspaceRole: 'workspace:admin' }) + ]) + + const subscriptionData = cloneDeep(workspaceSubscription.subscriptionData) + + mutateSubscriptionDataWithNewValidSeatNumbers({ + seatCount: guestCount, + workspacePlan: 'guest', + getWorkspacePlanProductId, + subscriptionData + }) + mutateSubscriptionDataWithNewValidSeatNumbers({ + seatCount: memberCount + adminCount, + workspacePlan: workspacePlan.name, + getWorkspacePlanProductId, + subscriptionData + }) + + if (!isEqual(subscriptionData, workspaceSubscription.subscriptionData)) { + await reconcileSubscriptionData({ subscriptionData, applyProrotation: false }) + return true + } + return false + } + +export const manageSubscriptionDownscaleFactory = + ({ + logger, + getWorkspaceSubscriptions, + downscaleWorkspaceSubscription, + updateWorkspaceSubscription + }: { + getWorkspaceSubscriptions: GetWorkspaceSubscriptions + downscaleWorkspaceSubscription: DownscaleWorkspaceSubscription + updateWorkspaceSubscription: UpsertWorkspaceSubscription + logger: Logger + }) => + async () => { + const subscriptions = await getWorkspaceSubscriptions() + for (const workspaceSubscription of subscriptions) { + const log = logger.child({ workspaceId: workspaceSubscription.workspaceId }) + try { + const subDownscaled = await downscaleWorkspaceSubscription({ + workspaceSubscription + }) + if (subDownscaled) { + log.info( + 'Downscaled workspace subscription to match the current workspace team' + ) + } else { + log.info('Did not need to downscale the workspace subscription') + } + } catch (err) { + log.error({ err }, 'Failed to downscale workspace subscription') + } + const newBillingCycleEnd = calculateNewBillingCycleEnd({ workspaceSubscription }) + const updatedWorkspaceSubscription = { + ...workspaceSubscription, + currentBillingCycleEnd: newBillingCycleEnd + } + await updateWorkspaceSubscription({ + workspaceSubscription: updatedWorkspaceSubscription + }) + log.info({ updatedWorkspaceSubscription }, 'Updated workspace billing cycle end') + } + } diff --git a/packages/server/modules/gatekeeper/stripe.ts b/packages/server/modules/gatekeeper/stripe.ts index 4fa6cd597..d1ae3bb17 100644 --- a/packages/server/modules/gatekeeper/stripe.ts +++ b/packages/server/modules/gatekeeper/stripe.ts @@ -1,4 +1,7 @@ -import { GetWorkspacePlanPrice } from '@/modules/gatekeeper/domain/billing' +import { + GetWorkspacePlanPrice, + GetWorkspacePlanProductId +} from '@/modules/gatekeeper/domain/billing' import { WorkspacePlanBillingIntervals, WorkspacePricingPlans @@ -43,3 +46,7 @@ export const getWorkspacePlanPrice: GetWorkspacePlanPrice = ({ workspacePlan, billingInterval }) => workspacePlanPrices()[workspacePlan][billingInterval] + +export const getWorkspacePlanProductId: GetWorkspacePlanProductId = ({ + workspacePlan +}) => workspacePlanPrices()[workspacePlan].productId diff --git a/packages/server/modules/gatekeeper/tests/helpers.ts b/packages/server/modules/gatekeeper/tests/helpers.ts new file mode 100644 index 000000000..b55bb5205 --- /dev/null +++ b/packages/server/modules/gatekeeper/tests/helpers.ts @@ -0,0 +1,40 @@ +import { + SubscriptionData, + WorkspaceSubscription +} from '@/modules/gatekeeper/domain/billing' +import cryptoRandomString from 'crypto-random-string' +import { assign } from 'lodash' + +export const createTestSubscriptionData = ( + overrides: Partial = {} +): SubscriptionData => { + const defaultValues: SubscriptionData = { + cancelAt: null, + customerId: cryptoRandomString({ length: 10 }), + products: [ + { + priceId: cryptoRandomString({ length: 10 }), + productId: cryptoRandomString({ length: 10 }), + quantity: 3, + subscriptionItemId: cryptoRandomString({ length: 10 }) + } + ], + status: 'active', + subscriptionId: cryptoRandomString({ length: 10 }) + } + return assign(defaultValues, overrides) +} + +export const createTestWorkspaceSubscription = ( + overrides: Partial = {} +): WorkspaceSubscription => { + const defaultValues: WorkspaceSubscription = { + billingInterval: 'monthly', + createdAt: new Date(), + updatedAt: new Date(), + currentBillingCycleEnd: new Date(), + subscriptionData: createTestSubscriptionData(), + workspaceId: cryptoRandomString({ length: 10 }) + } + return assign(defaultValues, overrides) +} diff --git a/packages/server/modules/gatekeeper/tests/intergration/billingRepositories.spec.ts b/packages/server/modules/gatekeeper/tests/intergration/billingRepositories.spec.ts index 013cb2db9..4bdb178c4 100644 --- a/packages/server/modules/gatekeeper/tests/intergration/billingRepositories.spec.ts +++ b/packages/server/modules/gatekeeper/tests/intergration/billingRepositories.spec.ts @@ -1,5 +1,4 @@ import db from '@/db/knex' -import { WorkspaceSubscription } from '@/modules/gatekeeper/domain/billing' import { deleteCheckoutSessionFactory, getCheckoutSessionFactory, @@ -10,9 +9,15 @@ import { updateCheckoutSessionStatusFactory, upsertPaidWorkspacePlanFactory, getWorkspaceSubscriptionFactory, - getWorkspaceSubscriptionBySubscriptionIdFactory + getWorkspaceSubscriptionBySubscriptionIdFactory, + getWorkspaceSubscriptionsPastBillingCycleEndFactory } from '@/modules/gatekeeper/repositories/billing' +import { + createTestSubscriptionData, + createTestWorkspaceSubscription +} from '@/modules/gatekeeper/tests/helpers' import { upsertWorkspaceFactory } from '@/modules/workspaces/repositories/workspaces' +import { truncateTables } from '@/test/hooks' import { createAndStoreTestWorkspaceFactory } from '@/test/speckle-helpers/workspaces' import { expect } from 'chai' import cryptoRandomString from 'crypto-random-string' @@ -33,6 +38,9 @@ const getWorkspaceSubscription = getWorkspaceSubscriptionFactory({ db }) const getWorkspaceSubscriptionBySubscriptionId = getWorkspaceSubscriptionBySubscriptionIdFactory({ db }) +const getSubscriptionsAboutToEndBillingCycle = + getWorkspaceSubscriptionsPastBillingCycleEndFactory({ db }) + describe('billing repositories @gatekeeper', () => { describe('workspacePlans', () => { describe('upsertPaidWorkspacePlanFactory creates a function, that', () => { @@ -204,27 +212,21 @@ describe('billing repositories @gatekeeper', () => { it('saves and updates the subscription', async () => { const workspace = await createAndStoreTestWorkspace() const workspaceId = workspace.id - const workspaceSubscription: WorkspaceSubscription = { - billingInterval: 'monthly' as const, - createdAt: new Date(), - updatedAt: new Date(), - currentBillingCycleEnd: new Date(), - subscriptionData: { - customerId: cryptoRandomString({ length: 10 }), - status: 'active' as const, - cancelAt: null, - products: [ - { - priceId: cryptoRandomString({ length: 10 }), - quantity: 10, - productId: cryptoRandomString({ length: 10 }), - subscriptionItemId: cryptoRandomString({ length: 10 }) - } - ], - subscriptionId: cryptoRandomString({ length: 10 }) - }, - workspaceId - } + const subscriptionData = createTestSubscriptionData({ + products: [ + { + priceId: cryptoRandomString({ length: 10 }), + quantity: 10, + productId: cryptoRandomString({ length: 10 }), + subscriptionItemId: cryptoRandomString({ length: 10 }) + } + ] + }) + const workspaceSubscription = createTestWorkspaceSubscription({ + workspaceId, + billingInterval: 'monthly', + subscriptionData + }) await upsertWorkspaceSubscription({ workspaceSubscription }) let storedSubscription = await getWorkspaceSubscription({ workspaceId }) expect(storedSubscription).deep.equal(workspaceSubscription) @@ -255,27 +257,7 @@ describe('billing repositories @gatekeeper', () => { it('returns the sub', async () => { const workspace = await createAndStoreTestWorkspace() const workspaceId = workspace.id - const workspaceSubscription: WorkspaceSubscription = { - billingInterval: 'monthly' as const, - createdAt: new Date(), - updatedAt: new Date(), - currentBillingCycleEnd: new Date(), - subscriptionData: { - customerId: cryptoRandomString({ length: 10 }), - status: 'active' as const, - cancelAt: null, - products: [ - { - priceId: cryptoRandomString({ length: 10 }), - quantity: 10, - productId: cryptoRandomString({ length: 10 }), - subscriptionItemId: cryptoRandomString({ length: 10 }) - } - ], - subscriptionId: cryptoRandomString({ length: 10 }) - }, - workspaceId - } + const workspaceSubscription = createTestWorkspaceSubscription({ workspaceId }) await upsertWorkspaceSubscription({ workspaceSubscription }) const storedSubscription = await getWorkspaceSubscriptionBySubscriptionId({ subscriptionId: workspaceSubscription.subscriptionData.subscriptionId @@ -283,5 +265,34 @@ describe('billing repositories @gatekeeper', () => { expect(storedSubscription).deep.equal(workspaceSubscription) }) }) + describe('getWorkspaceSubscriptionsPastBillingCycleEndFactory', () => { + before(async () => { + await truncateTables(['workspace_subscriptions']) + }) + it('returns subs, that are about to end their billing cycle', async () => { + const workspace1 = await createAndStoreTestWorkspace() + const workspace1Id = workspace1.id + const workspace1Subscription = createTestWorkspaceSubscription({ + workspaceId: workspace1Id, + currentBillingCycleEnd: new Date(2099, 0, 1) + }) + await upsertWorkspaceSubscription({ + workspaceSubscription: workspace1Subscription + }) + + const workspace2 = await createAndStoreTestWorkspace() + const workspace2Id = workspace2.id + const currentBillingCycleEnd = new Date() + currentBillingCycleEnd.setMinutes(currentBillingCycleEnd.getMinutes() + 4) + const workspace2Subscription = createTestWorkspaceSubscription({ + workspaceId: workspace2Id + }) + await upsertWorkspaceSubscription({ + workspaceSubscription: workspace2Subscription + }) + const subscriptions = await getSubscriptionsAboutToEndBillingCycle() + expect(subscriptions).deep.equalInAnyOrder([workspace2Subscription]) + }) + }) }) }) diff --git a/packages/server/modules/gatekeeper/tests/unit/subscriptions.spec.ts b/packages/server/modules/gatekeeper/tests/unit/subscriptions.spec.ts index 1066bfadb..94bb7f3e5 100644 --- a/packages/server/modules/gatekeeper/tests/unit/subscriptions.spec.ts +++ b/packages/server/modules/gatekeeper/tests/unit/subscriptions.spec.ts @@ -1,5 +1,6 @@ +import { logger } from '@/logging/logging' import { - SubscriptionData, + SubscriptionDataInput, WorkspacePlan, WorkspaceSubscription } from '@/modules/gatekeeper/domain/billing' @@ -8,31 +9,21 @@ import { WorkspacePlanNotFoundError, WorkspaceSubscriptionNotFoundError } from '@/modules/gatekeeper/errors/billing' -import { handleSubscriptionUpdateFactory } from '@/modules/gatekeeper/services/subscriptions' +import { + addWorkspaceSubscriptionSeatIfNeededFactory, + downscaleWorkspaceSubscriptionFactory, + handleSubscriptionUpdateFactory, + manageSubscriptionDownscaleFactory +} from '@/modules/gatekeeper/services/subscriptions' +import { + createTestSubscriptionData, + createTestWorkspaceSubscription +} from '@/modules/gatekeeper/tests/helpers' import { expectToThrow } from '@/test/assertionHelper' +import { throwUncoveredError } from '@speckle/shared' import { expect } from 'chai' import cryptoRandomString from 'crypto-random-string' -import { merge } from 'lodash' - -const createTestSubscriptionData = ( - overrides: Partial = {} -): SubscriptionData => { - const defaultValues: SubscriptionData = { - cancelAt: null, - customerId: cryptoRandomString({ length: 10 }), - products: [ - { - priceId: cryptoRandomString({ length: 10 }), - productId: cryptoRandomString({ length: 10 }), - quantity: 3, - subscriptionItemId: cryptoRandomString({ length: 10 }) - } - ], - status: 'active', - subscriptionId: cryptoRandomString({ length: 10 }) - } - return merge(defaultValues, overrides) -} +import { omit } from 'lodash' describe('subscriptions @gatekeeper', () => { describe('handleSubscriptionUpdateFactory creates a function, that', () => { @@ -58,14 +49,8 @@ describe('subscriptions @gatekeeper', () => { const subscriptionData = createTestSubscriptionData() const err = await expectToThrow(async () => { await handleSubscriptionUpdateFactory({ - getWorkspaceSubscriptionBySubscriptionId: async () => ({ - subscriptionData, - billingInterval: 'monthly', - createdAt: new Date(), - updatedAt: new Date(), - currentBillingCycleEnd: new Date(), - workspaceId: cryptoRandomString({ length: 10 }) - }), + getWorkspaceSubscriptionBySubscriptionId: async () => + createTestWorkspaceSubscription({ subscriptionData }), getWorkspacePlan: async () => null, upsertWorkspaceSubscription: async () => { expect.fail() @@ -83,14 +68,11 @@ describe('subscriptions @gatekeeper', () => { const workspaceId = cryptoRandomString({ length: 10 }) const err = await expectToThrow(async () => { await handleSubscriptionUpdateFactory({ - getWorkspaceSubscriptionBySubscriptionId: async () => ({ - subscriptionData, - billingInterval: 'monthly', - createdAt: new Date(), - updatedAt: new Date(), - currentBillingCycleEnd: new Date(), - workspaceId - }), + getWorkspaceSubscriptionBySubscriptionId: async () => + createTestWorkspaceSubscription({ + subscriptionData, + workspaceId + }), getWorkspacePlan: async () => ({ name, workspaceId, status: 'valid' }), upsertWorkspaceSubscription: async () => { expect.fail() @@ -109,14 +91,10 @@ describe('subscriptions @gatekeeper', () => { cancelAt: new Date(2099, 12, 31) }) const workspaceId = cryptoRandomString({ length: 10 }) - const workspaceSubscription = { + const workspaceSubscription = createTestWorkspaceSubscription({ subscriptionData, - billingInterval: 'monthly' as const, - createdAt: new Date(), - updatedAt: new Date(), - currentBillingCycleEnd: new Date(), workspaceId - } + }) let updatedSubscription: WorkspaceSubscription | undefined = undefined let updatedPlan: WorkspacePlan | undefined = undefined @@ -132,7 +110,12 @@ describe('subscriptions @gatekeeper', () => { } })({ subscriptionData }) expect(updatedPlan!.status).to.be.equal('cancelationScheduled') - expect(updatedSubscription).deep.equal(workspaceSubscription) + expect(updatedSubscription!.updatedAt).to.be.greaterThanOrEqual( + workspaceSubscription.updatedAt + ) + expect(omit(updatedSubscription!, 'updatedAt')).deep.equal( + omit(workspaceSubscription, 'updatedAt') + ) }) it('sets the state to valid', async () => { const subscriptionData = createTestSubscriptionData({ @@ -163,21 +146,23 @@ describe('subscriptions @gatekeeper', () => { } })({ subscriptionData }) expect(updatedPlan!.status).to.be.equal('valid') - expect(updatedSubscription).deep.equal(workspaceSubscription) + expect(updatedSubscription!.updatedAt).to.be.greaterThanOrEqual( + workspaceSubscription.updatedAt + ) + expect(omit(updatedSubscription!, 'updatedAt')).deep.equal( + omit(workspaceSubscription, 'updatedAt') + ) }) it('sets the state to paymentFailed', async () => { const subscriptionData = createTestSubscriptionData({ status: 'past_due' }) const workspaceId = cryptoRandomString({ length: 10 }) - const workspaceSubscription = { + + const workspaceSubscription = createTestWorkspaceSubscription({ subscriptionData, - billingInterval: 'monthly' as const, - createdAt: new Date(), - updatedAt: new Date(), - currentBillingCycleEnd: new Date(), workspaceId - } + }) let updatedSubscription: WorkspaceSubscription | undefined = undefined let updatedPlan: WorkspacePlan | undefined = undefined @@ -193,7 +178,12 @@ describe('subscriptions @gatekeeper', () => { } })({ subscriptionData }) expect(updatedPlan!.status).to.be.equal('paymentFailed') - expect(updatedSubscription).deep.equal(workspaceSubscription) + expect(updatedSubscription!.updatedAt).to.be.greaterThanOrEqual( + workspaceSubscription.updatedAt + ) + expect(omit(updatedSubscription!, 'updatedAt')).deep.equal( + omit(workspaceSubscription, 'updatedAt') + ) }) it('sets the state to canceled', async () => { const subscriptionData = createTestSubscriptionData({ @@ -223,7 +213,12 @@ describe('subscriptions @gatekeeper', () => { } })({ subscriptionData }) expect(updatedPlan!.status).to.be.equal('canceled') - expect(updatedSubscription).deep.equal(workspaceSubscription) + expect(updatedSubscription!.updatedAt).to.be.greaterThanOrEqual( + workspaceSubscription.updatedAt + ) + expect(omit(updatedSubscription!, 'updatedAt')).deep.equal( + omit(workspaceSubscription, 'updatedAt') + ) }) ;( ['incomplete', 'incomplete_expired', 'trialing', 'unpaid', 'paused'] as const @@ -233,14 +228,11 @@ describe('subscriptions @gatekeeper', () => { status }) const workspaceId = cryptoRandomString({ length: 10 }) - const workspaceSubscription = { + + const workspaceSubscription = createTestWorkspaceSubscription({ subscriptionData, - billingInterval: 'monthly' as const, - createdAt: new Date(), - updatedAt: new Date(), - currentBillingCycleEnd: new Date(), workspaceId - } + }) await handleSubscriptionUpdateFactory({ getWorkspaceSubscriptionBySubscriptionId: async () => workspaceSubscription, @@ -259,4 +251,619 @@ describe('subscriptions @gatekeeper', () => { }) }) }) + describe('addWorkspaceSubscriptionSeatIfNeededFactory returns a function, that', () => { + it('just returns if the workspacePlan is not found', async () => { + const workspaceId = cryptoRandomString({ length: 10 }) + const addWorkspaceSubscriptionSeatIfNeeded = + addWorkspaceSubscriptionSeatIfNeededFactory({ + getWorkspacePlan: async () => null, + getWorkspaceSubscription: async () => { + expect.fail() + }, + countWorkspaceRole: async () => { + expect.fail() + }, + getWorkspacePlanPrice: () => { + expect.fail() + }, + getWorkspacePlanProductId: () => { + expect.fail() + }, + reconcileSubscriptionData: async () => { + expect.fail() + } + }) + await addWorkspaceSubscriptionSeatIfNeeded({ + workspaceId, + role: 'workspace:admin' + }) + expect(true).to.be.true + }) + it('throws if the workspaceSubscription is not found', async () => { + const workspaceId = cryptoRandomString({ length: 10 }) + const addWorkspaceSubscriptionSeatIfNeeded = + addWorkspaceSubscriptionSeatIfNeededFactory({ + getWorkspacePlan: async () => ({ + name: 'unlimited', + workspaceId, + status: 'valid' + }), + getWorkspaceSubscription: async () => null, + countWorkspaceRole: async () => { + expect.fail() + }, + getWorkspacePlanPrice: () => { + expect.fail() + }, + getWorkspacePlanProductId: () => { + expect.fail() + }, + reconcileSubscriptionData: async () => { + expect.fail() + } + }) + const err = await expectToThrow(async () => { + await addWorkspaceSubscriptionSeatIfNeeded({ + workspaceId, + role: 'workspace:admin' + }) + }) + expect(err.message).to.equal(new WorkspaceSubscriptionNotFoundError().message) + }) + it('throws if a non paid plan, has a subscription', async () => { + const workspaceId = cryptoRandomString({ length: 10 }) + const subscriptionData = createTestSubscriptionData({ products: [] }) + const workspaceSubscription = createTestWorkspaceSubscription({ + workspaceId, + subscriptionData + }) + const addWorkspaceSubscriptionSeatIfNeeded = + addWorkspaceSubscriptionSeatIfNeededFactory({ + getWorkspacePlan: async () => ({ + name: 'unlimited', + workspaceId, + status: 'valid' + }), + getWorkspaceSubscription: async () => workspaceSubscription, + countWorkspaceRole: async () => { + expect.fail() + }, + getWorkspacePlanPrice: () => { + expect.fail() + }, + getWorkspacePlanProductId: () => { + expect.fail() + }, + reconcileSubscriptionData: async () => { + expect.fail() + } + }) + const err = await expectToThrow(async () => { + await addWorkspaceSubscriptionSeatIfNeeded({ + workspaceId, + role: 'workspace:admin' + }) + }) + expect(err.message).to.equal(new WorkspacePlanMismatchError().message) + }) + it('returns without reconciliation if the subscription is canceled', async () => { + const workspaceId = cryptoRandomString({ length: 10 }) + const subscriptionData = createTestSubscriptionData({ products: [] }) + const workspaceSubscription = createTestWorkspaceSubscription({ + workspaceId, + subscriptionData + }) + const addWorkspaceSubscriptionSeatIfNeeded = + addWorkspaceSubscriptionSeatIfNeededFactory({ + getWorkspacePlan: async () => ({ + name: 'pro', + workspaceId, + status: 'canceled' + }), + getWorkspaceSubscription: async () => workspaceSubscription, + countWorkspaceRole: async () => { + expect.fail() + }, + getWorkspacePlanPrice: () => { + expect.fail() + }, + getWorkspacePlanProductId: () => { + expect.fail() + }, + reconcileSubscriptionData: async () => { + expect.fail() + } + }) + await addWorkspaceSubscriptionSeatIfNeeded({ + workspaceId, + role: 'workspace:admin' + }) + }) + it('uses the guest count, guest product and price id if the new role is workspace:guest', async () => { + const workspaceId = cryptoRandomString({ length: 10 }) + const subscriptionData = createTestSubscriptionData({ products: [] }) + const workspaceSubscription = createTestWorkspaceSubscription({ + workspaceId, + subscriptionData + }) + const workspacePlan: WorkspacePlan = { + name: 'team', + workspaceId, + status: 'valid' + } + const priceId = cryptoRandomString({ length: 10 }) + const productId = cryptoRandomString({ length: 10 }) + const roleCount = 10 + + let reconciledSubscriptionData: SubscriptionDataInput | undefined = undefined + const addWorkspaceSubscriptionSeatIfNeeded = + addWorkspaceSubscriptionSeatIfNeededFactory({ + getWorkspacePlan: async () => workspacePlan, + getWorkspaceSubscription: async () => workspaceSubscription, + countWorkspaceRole: async ({ workspaceRole }) => { + switch (workspaceRole) { + case 'workspace:admin': + case 'workspace:member': + expect.fail() + case 'workspace:guest': + return roleCount + } + }, + getWorkspacePlanPrice: ({ workspacePlan, billingInterval }) => { + if (billingInterval !== workspaceSubscription.billingInterval) expect.fail() + switch (workspacePlan) { + case 'business': + case 'team': + case 'pro': + expect.fail() + case 'guest': + return priceId + default: + throwUncoveredError(workspacePlan) + } + }, + getWorkspacePlanProductId: (args) => { + if (args.workspacePlan !== 'guest') expect.fail() + return productId + }, + reconcileSubscriptionData: async ({ applyProrotation, subscriptionData }) => { + if (!applyProrotation) expect.fail() + reconciledSubscriptionData = subscriptionData + } + }) + await addWorkspaceSubscriptionSeatIfNeeded({ + workspaceId, + role: 'workspace:guest' + }) + expect(reconciledSubscriptionData!.products).deep.equalInAnyOrder([ + { productId, priceId, quantity: roleCount } + ]) + }) + ;(['workspace:member', 'workspace:admin'] as const).forEach((role) => + it(`uses the admin + member count, workspacePlan product and price id if the new role is ${role}`, async () => { + const workspaceId = cryptoRandomString({ length: 10 }) + const subscriptionData = createTestSubscriptionData({ products: [] }) + const workspaceSubscription = createTestWorkspaceSubscription({ + workspaceId, + subscriptionData + }) + const workspacePlan: WorkspacePlan = { + name: 'team', + workspaceId, + status: 'valid' + } + const priceId = cryptoRandomString({ length: 10 }) + const productId = cryptoRandomString({ length: 10 }) + const roleCount = 10 + + let reconciledSubscriptionData: SubscriptionDataInput | undefined = undefined + const addWorkspaceSubscriptionSeatIfNeeded = + addWorkspaceSubscriptionSeatIfNeededFactory({ + getWorkspacePlan: async () => workspacePlan, + getWorkspaceSubscription: async () => workspaceSubscription, + countWorkspaceRole: async ({ workspaceRole }) => { + switch (workspaceRole) { + case 'workspace:admin': + case 'workspace:member': + return roleCount + case 'workspace:guest': + expect.fail() + } + }, + getWorkspacePlanPrice: ({ workspacePlan, billingInterval }) => { + if (billingInterval !== workspaceSubscription.billingInterval) + expect.fail() + switch (workspacePlan) { + case 'business': + case 'pro': + case 'guest': + expect.fail() + case 'team': + return priceId + default: + throwUncoveredError(workspacePlan) + } + }, + getWorkspacePlanProductId: (args) => { + if (args.workspacePlan !== workspacePlan.name) expect.fail() + return productId + }, + reconcileSubscriptionData: async ({ + applyProrotation, + subscriptionData + }) => { + if (!applyProrotation) expect.fail() + reconciledSubscriptionData = subscriptionData + } + }) + await addWorkspaceSubscriptionSeatIfNeeded({ + workspaceId, + role + }) + expect(reconciledSubscriptionData!.products).deep.equalInAnyOrder([ + { productId, priceId, quantity: 2 * roleCount } + ]) + }) + ) + it('updates the sub existing product quantity if the one matching the new role, does not have enough quantities', async () => { + const workspaceId = cryptoRandomString({ length: 10 }) + + const priceId = cryptoRandomString({ length: 10 }) + const productId = cryptoRandomString({ length: 10 }) + const subscriptionItemId = cryptoRandomString({ length: 10 }) + const subscriptionData = createTestSubscriptionData({ + products: [ + { + priceId, + productId, + quantity: 4, + subscriptionItemId + } + ] + }) + const workspaceSubscription = createTestWorkspaceSubscription({ + workspaceId, + subscriptionData + }) + const workspacePlan: WorkspacePlan = { + name: 'team', + workspaceId, + status: 'valid' + } + const roleCount = 10 + + let reconciledSubscriptionData: SubscriptionDataInput | undefined = undefined + const addWorkspaceSubscriptionSeatIfNeeded = + addWorkspaceSubscriptionSeatIfNeededFactory({ + getWorkspacePlan: async () => workspacePlan, + getWorkspaceSubscription: async () => workspaceSubscription, + countWorkspaceRole: async ({ workspaceRole }) => { + switch (workspaceRole) { + case 'workspace:admin': + case 'workspace:member': + return roleCount + case 'workspace:guest': + expect.fail() + } + }, + getWorkspacePlanPrice: ({ workspacePlan, billingInterval }) => { + if (billingInterval !== workspaceSubscription.billingInterval) expect.fail() + switch (workspacePlan) { + case 'business': + case 'pro': + case 'guest': + expect.fail() + case 'team': + return priceId + default: + throwUncoveredError(workspacePlan) + } + }, + getWorkspacePlanProductId: (args) => { + if (args.workspacePlan !== workspacePlan.name) expect.fail() + return productId + }, + reconcileSubscriptionData: async ({ applyProrotation, subscriptionData }) => { + if (!applyProrotation) expect.fail() + reconciledSubscriptionData = subscriptionData + } + }) + await addWorkspaceSubscriptionSeatIfNeeded({ + workspaceId, + role: 'workspace:member' + }) + expect(reconciledSubscriptionData!.products).deep.equalInAnyOrder([ + { productId, priceId, quantity: 2 * roleCount, subscriptionItemId } + ]) + }) + it('does not update the subscription if the product matching the new role, has enough quantities', async () => { + const workspaceId = cryptoRandomString({ length: 10 }) + + const priceId = cryptoRandomString({ length: 10 }) + const productId = cryptoRandomString({ length: 10 }) + const subscriptionItemId = cryptoRandomString({ length: 10 }) + const subscriptionData = createTestSubscriptionData({ + products: [ + { + priceId, + productId, + quantity: 2, + subscriptionItemId + } + ] + }) + const workspaceSubscription = createTestWorkspaceSubscription({ + workspaceId, + subscriptionData + }) + const workspacePlan: WorkspacePlan = { + name: 'team', + workspaceId, + status: 'valid' + } + const roleCount = 1 + + const addWorkspaceSubscriptionSeatIfNeeded = + addWorkspaceSubscriptionSeatIfNeededFactory({ + getWorkspacePlan: async () => workspacePlan, + getWorkspaceSubscription: async () => workspaceSubscription, + countWorkspaceRole: async ({ workspaceRole }) => { + switch (workspaceRole) { + case 'workspace:admin': + case 'workspace:member': + return roleCount + case 'workspace:guest': + expect.fail() + } + }, + getWorkspacePlanPrice: ({ workspacePlan, billingInterval }) => { + if (billingInterval !== workspaceSubscription.billingInterval) expect.fail() + switch (workspacePlan) { + case 'business': + case 'pro': + case 'guest': + expect.fail() + case 'team': + return priceId + default: + throwUncoveredError(workspacePlan) + } + }, + getWorkspacePlanProductId: (args) => { + if (args.workspacePlan !== workspacePlan.name) expect.fail() + return productId + }, + reconcileSubscriptionData: async () => { + expect.fail() + } + }) + await addWorkspaceSubscriptionSeatIfNeeded({ + workspaceId, + role: 'workspace:member' + }) + }) + }) + describe('downscaleWorkspaceSubscriptionFactory', () => { + it('throws an error if the workspace has no plan attached to it', async () => { + const subscriptionData = createTestSubscriptionData() + const workspaceSubscription = createTestWorkspaceSubscription({ + subscriptionData + }) + const downscaleSubscription = downscaleWorkspaceSubscriptionFactory({ + getWorkspacePlan: async () => null, + countWorkspaceRole: async () => { + expect.fail() + }, + getWorkspacePlanProductId: () => { + expect.fail() + }, + reconcileSubscriptionData: async () => { + expect.fail() + } + }) + const err = await expectToThrow(async () => { + await downscaleSubscription({ workspaceSubscription }) + }) + expect(err.message).to.equal(new WorkspacePlanNotFoundError().message) + }) + it('throws an error if workspacePlan is not a paid plan', async () => { + const workspaceId = cryptoRandomString({ length: 10 }) + const subscriptionData = createTestSubscriptionData() + const workspaceSubscription = createTestWorkspaceSubscription({ + subscriptionData, + workspaceId + }) + const downscaleSubscription = downscaleWorkspaceSubscriptionFactory({ + getWorkspacePlan: async () => ({ + name: 'unlimited', + workspaceId, + status: 'valid' + }), + countWorkspaceRole: async () => { + expect.fail() + }, + getWorkspacePlanProductId: () => { + expect.fail() + }, + reconcileSubscriptionData: async () => { + expect.fail() + } + }) + const err = await expectToThrow(async () => { + await downscaleSubscription({ workspaceSubscription }) + }) + expect(err.message).to.equal(new WorkspacePlanMismatchError().message) + }) + it('returns if the subscription is canceled', async () => { + const workspaceId = cryptoRandomString({ length: 10 }) + const subscriptionData = createTestSubscriptionData() + const workspaceSubscription = createTestWorkspaceSubscription({ + subscriptionData, + workspaceId + }) + const downscaleSubscription = downscaleWorkspaceSubscriptionFactory({ + getWorkspacePlan: async () => ({ + name: 'pro', + workspaceId, + status: 'canceled' + }), + countWorkspaceRole: async () => { + expect.fail() + }, + getWorkspacePlanProductId: () => { + expect.fail() + }, + reconcileSubscriptionData: async () => { + expect.fail() + } + }) + const hasDownscaled = await downscaleSubscription({ workspaceSubscription }) + expect(hasDownscaled).to.be.false + }) + it('does not reconcile the subscription seats did not change', async () => { + const workspaceId = cryptoRandomString({ length: 10 }) + const priceId = cryptoRandomString({ length: 10 }) + const productId = cryptoRandomString({ length: 10 }) + const quantity = 10 + const subscriptionItemId = cryptoRandomString({ length: 10 }) + const subscriptionData = createTestSubscriptionData({ + products: [{ priceId, productId, quantity, subscriptionItemId }] + }) + const workspaceSubscription = createTestWorkspaceSubscription({ + subscriptionData, + billingInterval: 'monthly', + currentBillingCycleEnd: new Date(2034, 11, 5), + workspaceId + }) + const workspacePlanName = 'pro' + const downscaleSubscription = downscaleWorkspaceSubscriptionFactory({ + getWorkspacePlan: async () => ({ + name: workspacePlanName, + workspaceId, + status: 'valid' + }), + countWorkspaceRole: async ({ workspaceRole }) => { + return workspaceRole === 'workspace:guest' ? 0 : 5 // 5+5 will be 10 as quantity + }, + getWorkspacePlanProductId: ({ workspacePlan }) => { + return workspacePlan === workspacePlanName + ? productId + : cryptoRandomString({ length: 10 }) + }, + reconcileSubscriptionData: async () => { + expect.fail() + } + }) + await downscaleSubscription({ workspaceSubscription }) + }) + it('reconciles the subscription to the new seat values', async () => { + const workspaceId = cryptoRandomString({ length: 10 }) + const proPriceId = cryptoRandomString({ length: 10 }) + const proProductId = cryptoRandomString({ length: 10 }) + const proQuantity = 10 + const proSubscriptionItemId = cryptoRandomString({ length: 10 }) + + const guestPriceId = cryptoRandomString({ length: 10 }) + const guestProductId = cryptoRandomString({ length: 10 }) + const guestQuantity = 10 + const guestSubscriptionItemId = cryptoRandomString({ length: 10 }) + const subscriptionData = createTestSubscriptionData({ + products: [ + { + priceId: proPriceId, + productId: proProductId, + quantity: proQuantity, + subscriptionItemId: proSubscriptionItemId + }, + { + priceId: guestPriceId, + productId: guestProductId, + quantity: guestQuantity, + subscriptionItemId: guestSubscriptionItemId + } + ] + }) + const testWorkspaceSubscription = createTestWorkspaceSubscription({ + subscriptionData, + workspaceId + }) + const workspacePlanName = 'pro' + + let reconciledSub: SubscriptionDataInput | undefined = undefined + const downscaleSubscription = downscaleWorkspaceSubscriptionFactory({ + getWorkspacePlan: async () => ({ + name: workspacePlanName, + workspaceId, + status: 'valid' + }), + countWorkspaceRole: async ({ workspaceRole }) => { + return workspaceRole === 'workspace:guest' + ? guestQuantity / 2 + : proQuantity / 2 //we're halving the guest seats, regulars stay the same + }, + getWorkspacePlanProductId: ({ workspacePlan }) => { + return workspacePlan === workspacePlanName ? proProductId : guestProductId + }, + reconcileSubscriptionData: async ({ subscriptionData }) => { + reconciledSub = subscriptionData + } + }) + await downscaleSubscription({ workspaceSubscription: testWorkspaceSubscription }) + + expect( + reconciledSub!.products.find((p) => p.productId === proProductId)?.quantity + ).to.be.equal(proQuantity) + expect( + reconciledSub!.products.find((p) => p.productId === guestProductId)?.quantity + ).to.be.equal(guestQuantity / 2) + }) + }) + describe('manageSubscriptionDownscaleFactory', () => { + it('still updates the monthly billing cycle end, even if subscription reconciliation fails', async () => { + const testWorkspaceSubscription = createTestWorkspaceSubscription({ + billingInterval: 'monthly', + currentBillingCycleEnd: new Date(2034, 11, 5) + }) + let updatedWorkspaceSubscription: WorkspaceSubscription | undefined = undefined + await manageSubscriptionDownscaleFactory({ + logger, + getWorkspaceSubscriptions: async () => [testWorkspaceSubscription], + downscaleWorkspaceSubscription: async () => { + throw 'kabumm' + }, + updateWorkspaceSubscription: async ({ workspaceSubscription }) => { + updatedWorkspaceSubscription = workspaceSubscription + } + })() + + const updatedBillingCycleEnd = new Date(2035, 0, 5) + expect(updatedWorkspaceSubscription).deep.equal({ + ...testWorkspaceSubscription, + currentBillingCycleEnd: updatedBillingCycleEnd + }) + }) + it('still updates the yearly billing cycle end, even if subscription reconciliation fails', async () => { + const testWorkspaceSubscription = createTestWorkspaceSubscription({ + billingInterval: 'yearly', + currentBillingCycleEnd: new Date(2034, 11, 5) + }) + let updatedWorkspaceSubscription: WorkspaceSubscription | undefined = undefined + await manageSubscriptionDownscaleFactory({ + logger, + getWorkspaceSubscriptions: async () => [testWorkspaceSubscription], + downscaleWorkspaceSubscription: async () => { + throw 'kabumm' + }, + updateWorkspaceSubscription: async ({ workspaceSubscription }) => { + updatedWorkspaceSubscription = workspaceSubscription + } + })() + + const updatedBillingCycleEnd = new Date(2035, 11, 5) + expect(updatedWorkspaceSubscription).deep.equal({ + ...testWorkspaceSubscription, + currentBillingCycleEnd: updatedBillingCycleEnd + }) + }) + }) }) diff --git a/packages/server/modules/shared/command.ts b/packages/server/modules/shared/command.ts new file mode 100644 index 000000000..1066094fb --- /dev/null +++ b/packages/server/modules/shared/command.ts @@ -0,0 +1,33 @@ +import { EmitArg, EventBus, EventBusEmit } from '@/modules/shared/services/eventBus' +import { Knex } from 'knex' + +export const commandFactory = + ) => ReturnType>({ + db, + eventBus, + operationFactory + }: { + db: Knex + eventBus: EventBus + operationFactory: (arg: { db: Knex; emit: EventBusEmit }) => TOperation + }) => + async (...args: Parameters): Promise>> => { + const events: EmitArg[] = [] + const emit: EventBusEmit = async ({ eventName, payload }) => { + events.push({ eventName, payload }) + } + + const trx = await db.transaction() + try { + const result = await operationFactory({ db, emit })(...args) + + await trx.commit() + for (const event of events) { + await eventBus.emit(event) + } + return result as Awaited> + } catch (err) { + trx.rollback() + throw err + } + } diff --git a/packages/server/modules/shared/services/eventBus.ts b/packages/server/modules/shared/services/eventBus.ts index 05e2c8b70..36f839373 100644 --- a/packages/server/modules/shared/services/eventBus.ts +++ b/packages/server/modules/shared/services/eventBus.ts @@ -86,9 +86,9 @@ export function initializeEventBus() { emit: async (args: { eventName: EventName payload: EventTypes[EventName] - }): Promise => { + }): Promise => { // curate the proper payload here and eventName object here, before emitting - return emitter.emitAsync(args.eventName, args) + await emitter.emitAsync(args.eventName, args) }, /** @@ -124,6 +124,7 @@ export function initializeEventBus() { export type EventBus = ReturnType export type EventBusPayloads = EventTypes export type EventBusEmit = EventBus['emit'] +export type EmitArg = Parameters[0] let eventBus: EventBus diff --git a/packages/server/modules/shared/test/unit/eventBus.spec.ts b/packages/server/modules/shared/test/unit/eventBus.spec.ts index 20987245f..6fb246498 100644 --- a/packages/server/modules/shared/test/unit/eventBus.spec.ts +++ b/packages/server/modules/shared/test/unit/eventBus.spec.ts @@ -62,22 +62,6 @@ describe('Event Bus', () => { await testEventBus.emit({ eventName: 'test.string', payload: 'fake event' }) expect(eventNumbers.sort((a, b) => a - b)).to.deep.equal([1, 1, 2]) }) - it('returns results from listeners to the emitter', async () => { - const testEventBus = initializeEventBus() - - testEventBus.listen('test.string', ({ payload }) => ({ - outcome: payload - })) - - const lookWhatHappened = 'echo this back to me' - const results = await testEventBus.emit({ - eventName: 'test.string', - payload: lookWhatHappened - }) - - expect(results.length).to.equal(1) - expect(results[0]).to.deep.equal({ outcome: lookWhatHappened }) - }) it('bubbles up listener exceptions to emitter', async () => { const testEventBus = initializeEventBus() diff --git a/packages/server/modules/webhooks/index.ts b/packages/server/modules/webhooks/index.ts index b5af0bc43..f1036dbf5 100644 --- a/packages/server/modules/webhooks/index.ts +++ b/packages/server/modules/webhooks/index.ts @@ -2,14 +2,18 @@ import cron from 'node-cron' import { SpeckleModule } from '@/modules/shared/helpers/typeHelper' import { activitiesLogger, moduleLogger } from '@/logging/logging' import { scheduleExecutionFactory } from '@/modules/core/services/taskScheduler' -import { acquireTaskLockFactory } from '@/modules/core/repositories/scheduledTasks' +import { + acquireTaskLockFactory, + releaseTaskLockFactory +} from '@/modules/core/repositories/scheduledTasks' import { cleanOrphanedWebhookConfigsFactory } from '@/modules/webhooks/repositories/cleanup' import { Knex } from 'knex' import { db } from '@/db/knex' const scheduleWebhookCleanupFactory = ({ db }: { db: Knex }) => { const scheduleExecution = scheduleExecutionFactory({ - acquireTaskLock: acquireTaskLockFactory({ db }) + acquireTaskLock: acquireTaskLockFactory({ db }), + releaseTaskLock: releaseTaskLockFactory({ db }) }) const cronExpression = '0 4 * * 1' diff --git a/packages/server/modules/workspaces/domain/operations.ts b/packages/server/modules/workspaces/domain/operations.ts index 56afe1175..010bd5880 100644 --- a/packages/server/modules/workspaces/domain/operations.ts +++ b/packages/server/modules/workspaces/domain/operations.ts @@ -217,7 +217,7 @@ export type UpdateWorkspaceProjectRole = ( export type EmitWorkspaceEvent = (args: { eventName: TEvent payload: EventBusPayloads[TEvent] -}) => Promise +}) => Promise export type CountProjectsVersionsByWorkspaceId = (args: { workspaceId: string diff --git a/packages/server/modules/workspaces/graph/resolvers/workspaces.ts b/packages/server/modules/workspaces/graph/resolvers/workspaces.ts index ad271e594..2dc5eff7d 100644 --- a/packages/server/modules/workspaces/graph/resolvers/workspaces.ts +++ b/packages/server/modules/workspaces/graph/resolvers/workspaces.ts @@ -42,7 +42,6 @@ import { import { createProjectInviteFactory } from '@/modules/serverinvites/services/projectInviteManagement' import { getInvitationTargetUsersFactory } from '@/modules/serverinvites/services/retrieval' import { authorizeResolver } from '@/modules/shared' -import { withTransaction } from '@/modules/shared/helpers/dbHelper' import { getFeatureFlags } from '@/modules/shared/helpers/envHelper' import { getEventBus } from '@/modules/shared/services/eventBus' import { WorkspaceInviteResourceType } from '@/modules/workspaces/domain/constants' @@ -149,7 +148,10 @@ import { publish } from '@/modules/shared/utils/subscriptions' import { updateStreamRoleAndNotifyFactory } from '@/modules/core/services/streams/management' import { getUserFactory, getUsersFactory } from '@/modules/core/repositories/users' import { getServerInfoFactory } from '@/modules/core/repositories/server' +import { commandFactory } from '@/modules/shared/command' +import { withTransaction } from '@/modules/shared/helpers/dbHelper' +const eventBus = getEventBus() const getServerInfo = getServerInfoFactory({ db }) const getUser = getUserFactory({ db }) const getUsers = getUsersFactory({ db }) @@ -456,36 +458,34 @@ export = FF_WORKSPACES_MODULE_ENABLED ) if (!role) { + // this is currently not working with the command factory + // TODO: include the onWorkspaceRoleDeletedFactory listener service const trx = await db.transaction() - const deleteWorkspaceRole = deleteWorkspaceRoleFactory({ deleteWorkspaceRole: repoDeleteWorkspaceRoleFactory({ db: trx }), getWorkspaceRoles: getWorkspaceRolesFactory({ db: trx }), emitWorkspaceEvent: getEventBus().emit }) - - await withTransaction(deleteWorkspaceRole(args.input), trx) + await withTransaction(deleteWorkspaceRole({ workspaceId, userId }), trx) } else { if (!isWorkspaceRole(role)) { throw new WorkspaceInvalidRoleError() } - - const trx = await db.transaction() - - const updateWorkspaceRole = updateWorkspaceRoleFactory({ - upsertWorkspaceRole: upsertWorkspaceRoleFactory({ db: trx }), - getWorkspaceWithDomains: getWorkspaceWithDomainsFactory({ db: trx }), - findVerifiedEmailsByUserId: findVerifiedEmailsByUserIdFactory({ - db: trx - }), - getWorkspaceRoles: getWorkspaceRolesFactory({ db: trx }), - emitWorkspaceEvent: getEventBus().emit + const updateWorkspaceRole = commandFactory({ + db, + eventBus, + operationFactory: ({ db, emit }) => + updateWorkspaceRoleFactory({ + upsertWorkspaceRole: upsertWorkspaceRoleFactory({ db }), + getWorkspaceWithDomains: getWorkspaceWithDomainsFactory({ db }), + findVerifiedEmailsByUserId: findVerifiedEmailsByUserIdFactory({ + db + }), + getWorkspaceRoles: getWorkspaceRolesFactory({ db }), + emitWorkspaceEvent: emit + }) }) - - await withTransaction( - updateWorkspaceRole({ userId, workspaceId, role }), - trx - ) + await updateWorkspaceRole({ userId, workspaceId, role }) } return await getWorkspaceFactory({ db })({ workspaceId }) @@ -559,19 +559,18 @@ export = FF_WORKSPACES_MODULE_ENABLED }) }, leave: async (_parent, args, context) => { + // this is currently not working with the command factory + // TODO: include the onWorkspaceRoleDeletedFactory listener service const trx = await db.transaction() - const deleteWorkspaceRole = deleteWorkspaceRoleFactory({ deleteWorkspaceRole: repoDeleteWorkspaceRoleFactory({ db: trx }), getWorkspaceRoles: getWorkspaceRolesFactory({ db: trx }), emitWorkspaceEvent: getEventBus().emit }) - await withTransaction( deleteWorkspaceRole({ workspaceId: args.id, userId: context.userId! }), trx ) - return true }, invites: () => ({}), @@ -770,33 +769,33 @@ export = FF_WORKSPACES_MODULE_ENABLED context.resourceAccessRules ) - const trx = await db.transaction() - - const moveProjectToWorkspace = moveProjectToWorkspaceFactory({ - getProject: getProjectFactory({ db }), - updateProject: updateProjectFactory({ db: trx }), - upsertProjectRole: upsertProjectRoleFactory({ db: trx }), - getProjectCollaborators: getProjectCollaboratorsFactory({ db }), - getWorkspaceRoles: getWorkspaceRolesFactory({ db: trx }), - getWorkspaceRoleToDefaultProjectRoleMapping: - getWorkspaceRoleToDefaultProjectRoleMappingFactory({ - getWorkspace: getWorkspaceFactory({ db }) - }), - updateWorkspaceRole: updateWorkspaceRoleFactory({ - getWorkspaceRoles: getWorkspaceRolesFactory({ db: trx }), - getWorkspaceWithDomains: getWorkspaceWithDomainsFactory({ db: trx }), - findVerifiedEmailsByUserId: findVerifiedEmailsByUserIdFactory({ - db: trx - }), - upsertWorkspaceRole: upsertWorkspaceRoleFactory({ db: trx }), - emitWorkspaceEvent: getEventBus().emit - }) + const moveProjectToWorkspace = commandFactory({ + db, + eventBus, + operationFactory: ({ db, emit }) => + moveProjectToWorkspaceFactory({ + getProject: getProjectFactory({ db }), + updateProject: updateProjectFactory({ db }), + upsertProjectRole: upsertProjectRoleFactory({ db }), + getProjectCollaborators: getProjectCollaboratorsFactory({ db }), + getWorkspaceRoles: getWorkspaceRolesFactory({ db }), + getWorkspaceRoleToDefaultProjectRoleMapping: + getWorkspaceRoleToDefaultProjectRoleMappingFactory({ + getWorkspace: getWorkspaceFactory({ db }) + }), + updateWorkspaceRole: updateWorkspaceRoleFactory({ + getWorkspaceRoles: getWorkspaceRolesFactory({ db }), + getWorkspaceWithDomains: getWorkspaceWithDomainsFactory({ db }), + findVerifiedEmailsByUserId: findVerifiedEmailsByUserIdFactory({ + db + }), + upsertWorkspaceRole: upsertWorkspaceRoleFactory({ db }), + emitWorkspaceEvent: emit + }) + }) }) - return await withTransaction( - moveProjectToWorkspace({ projectId, workspaceId }), - trx - ) + return await moveProjectToWorkspace({ projectId, workspaceId }) } }, Workspace: { diff --git a/packages/server/modules/workspaces/tests/unit/services/join.spec.ts b/packages/server/modules/workspaces/tests/unit/services/join.spec.ts index 2b3fdfdef..70143fbe1 100644 --- a/packages/server/modules/workspaces/tests/unit/services/join.spec.ts +++ b/packages/server/modules/workspaces/tests/unit/services/join.spec.ts @@ -123,7 +123,6 @@ describe('Workspace join services', () => { }, emitWorkspaceEvent: async ({ eventName }) => { firedEvents.push(eventName) - return [] } })({ userId, workspaceId }) diff --git a/packages/server/modules/workspaces/tests/unit/services/management.spec.ts b/packages/server/modules/workspaces/tests/unit/services/management.spec.ts index 74cff4e41..7a4803ed0 100644 --- a/packages/server/modules/workspaces/tests/unit/services/management.spec.ts +++ b/packages/server/modules/workspaces/tests/unit/services/management.spec.ts @@ -79,7 +79,6 @@ const buildCreateWorkspaceWithTestContext = ( context.eventData.isCalled = true context.eventData.eventName = eventName context.eventData.payload = payload - return [] }, ...dependencyOverrides } @@ -408,9 +407,7 @@ describe('Workspace services', () => { let newWorkspaceName await updateWorkspaceFactory({ getWorkspace: async () => workspace, - emitWorkspaceEvent: async () => { - return [] - }, + emitWorkspaceEvent: async () => {}, validateSlug: async () => {}, upsertWorkspace: async ({ workspace }) => { @@ -448,9 +445,7 @@ describe('Workspace services', () => { await updateWorkspaceFactory({ getWorkspace: async () => workspace, - emitWorkspaceEvent: async () => { - return [] - }, + emitWorkspaceEvent: async () => {}, validateSlug: async () => {}, upsertWorkspace: async ({ workspace }) => { updatedWorkspace = workspace @@ -544,8 +539,6 @@ const buildDeleteWorkspaceRoleAndTestContext = ( break } } - - return [] }, ...dependencyOverrides } @@ -622,8 +615,6 @@ const buildUpdateWorkspaceRoleAndTestContext = ( break } } - - return [] }, ...dependencyOverrides } @@ -1205,7 +1196,6 @@ describe('Workspace role services', () => { }, emitWorkspaceEvent: async ({ eventName }) => { omittedEventName = eventName - return [] }, storeWorkspaceDomain: async ({ workspaceDomain }) => { storedDomains = workspaceDomain @@ -1272,9 +1262,7 @@ describe('Workspace role services', () => { upsertWorkspace: async ({ workspace }) => { workspaceData = { ...workspaceData, ...workspace } }, - emitWorkspaceEvent: async () => { - return [] - }, + emitWorkspaceEvent: async () => {}, storeWorkspaceDomain: async ({ workspaceDomain }) => { insertedDomains.push(workspaceDomain) } diff --git a/workspace.code-workspace b/workspace.code-workspace index 75d24d382..48902a487 100644 --- a/workspace.code-workspace +++ b/workspace.code-workspace @@ -97,7 +97,8 @@ "Encryptor", "Insertable", "mjml", - "OIDC" + "OIDC", + "Prorotation" ], "tailwindCSS.experimental.configFile": { "packages/frontend-2/tailwind.config.mjs": "packages/frontend-2/**"