From 5804ee4b5068b8084ced284c80665538636fddca Mon Sep 17 00:00:00 2001 From: Iain Sproat <68657+iainsproat@users.noreply.github.com> Date: Wed, 12 Mar 2025 09:28:38 +0000 Subject: [PATCH 1/9] fix(preview service): handle shutdown via terminus --- packages/preview-service/package.json | 1 + packages/preview-service/src/main.ts | 34 ++++++++++++------- .../server/modules/previews/resultListener.ts | 2 +- yarn.lock | 10 ++++++ 4 files changed, 33 insertions(+), 14 deletions(-) diff --git a/packages/preview-service/package.json b/packages/preview-service/package.json index 23a39884b..3ee4919a4 100644 --- a/packages/preview-service/package.json +++ b/packages/preview-service/package.json @@ -27,6 +27,7 @@ "build": "tsc -p ./tsconfig.build.json" }, "dependencies": { + "@godaddy/terminus": "^4.12.1", "@speckle/shared": "workspace:^", "bull": "^4.16.4", "dotenv": "^16.4.7", diff --git a/packages/preview-service/src/main.ts b/packages/preview-service/src/main.ts index 641898fc3..7bfce860e 100644 --- a/packages/preview-service/src/main.ts +++ b/packages/preview-service/src/main.ts @@ -15,6 +15,7 @@ import { jobProcessor } from '@/jobProcessor.js' import { Redis, RedisOptions } from 'ioredis' import { jobPayload } from '@speckle/shared/dist/esm/previews/job.js' import { initMetrics, initPrometheusRegistry } from '@/metrics.js' +import { createTerminus } from '@godaddy/terminus' const app = express() const host = HOST @@ -122,25 +123,32 @@ const server = app.listen(port, host, async () => { }) }) -const shutdown = async () => { - // stop accepting new jobs +const beforeShutdown = async () => { + logger.info('🛑 Beginning shut down, pausing all jobs') + // stop accepting new jobs and kill any running jobs await jobQueue.pause( true, // just pausing this local worker of the queue true // do not wait for active jobs to finish ) - // if there is a job currently running, cancell it with an error if (jobDoneCallback) { - jobDoneCallback(new Error('Job cancelled due to perview-service shutdown')) + logger.warn('Cancelling job due to preview-service shutdown') + jobDoneCallback(new Error('Job cancelled due to preview-service shutdown')) } - - logger.info('Received signal to shut down') - server.close(() => { - logger.debug('Exiting the express server') - process.exit() - }) } -process.on('SIGINT', async () => await shutdown()) -process.on('SIGQUIT', async () => await shutdown()) -process.on('SIGABRT', async () => await shutdown()) +const onShutdown = async () => { + logger.info('👋 Completed shut down, now exiting') +} + +createTerminus(server, { + beforeShutdown, + onShutdown, + logger: (msg, err) => { + if (err) { + logger.error({ err }, msg) + return + } + logger.info(msg) + } +}) diff --git a/packages/server/modules/previews/resultListener.ts b/packages/server/modules/previews/resultListener.ts index 1d3ca199c..23b2f08da 100644 --- a/packages/server/modules/previews/resultListener.ts +++ b/packages/server/modules/previews/resultListener.ts @@ -84,7 +84,7 @@ export const consumePreviewResultFactory = switch (previewResult.status) { case 'error': - log.error(previewMessage) + log.error({ reason: previewResult.reason }, previewMessage) await upsertObjectPreview({ objectPreview: { objectId, diff --git a/yarn.lock b/yarn.lock index 0f1d5306a..77cc393fc 100644 --- a/yarn.lock +++ b/yarn.lock @@ -10526,6 +10526,15 @@ __metadata: languageName: node linkType: hard +"@godaddy/terminus@npm:^4.12.1": + version: 4.12.1 + resolution: "@godaddy/terminus@npm:4.12.1" + dependencies: + stoppable: "npm:^1.1.0" + checksum: 10/e1b6e0a079db5748c71211e22d2fc1ab76ee20763bb552e9e2b6003330c0b19fee10b20b1dd961218a402c6aadd3a8ecf7f03ecdfaa2c65ed072b2479eb9b517 + languageName: node + linkType: hard + "@godaddy/terminus@npm:^4.9.0": version: 4.10.2 resolution: "@godaddy/terminus@npm:4.10.2" @@ -16752,6 +16761,7 @@ __metadata: version: 0.0.0-use.local resolution: "@speckle/preview-service@workspace:packages/preview-service" dependencies: + "@godaddy/terminus": "npm:^4.12.1" "@speckle/shared": "workspace:^" "@swc/cli": "npm:^0.5.1" "@swc/core": "npm:^1.9.3" From 96bfedefe8cbea10342e95cef2892568f544dcf2 Mon Sep 17 00:00:00 2001 From: Alessandro Magionami Date: Tue, 18 Mar 2025 13:05:19 +0100 Subject: [PATCH 2/9] chore(gatekeeper): take end billing cycle date from stripe --- .../server/modules/gatekeeper/clients/stripe.ts | 3 ++- .../server/modules/gatekeeper/domain/billing.ts | 10 +++++----- .../server/modules/gatekeeper/services/checkout.ts | 13 +------------ packages/server/modules/gatekeeper/tests/helpers.ts | 7 +++++-- .../modules/workspaces/tests/helpers/creation.ts | 5 ++++- 5 files changed, 17 insertions(+), 21 deletions(-) diff --git a/packages/server/modules/gatekeeper/clients/stripe.ts b/packages/server/modules/gatekeeper/clients/stripe.ts index 6f8b26d28..c73127c89 100644 --- a/packages/server/modules/gatekeeper/clients/stripe.ts +++ b/packages/server/modules/gatekeeper/clients/stripe.ts @@ -66,6 +66,7 @@ export const parseSubscriptionData = ( cancelAt: stripeSubscription.cancel_at ? new Date(stripeSubscription.cancel_at * 1000) : null, + currentPeriodEnd: stripeSubscription.current_period_end, products: stripeSubscription.items.data.map((subscriptionItem) => { const productId = typeof subscriptionItem.price.product === 'string' @@ -84,7 +85,7 @@ export const parseSubscriptionData = ( } }) } - return subscriptionData + return SubscriptionData.parse(subscriptionData) } // this should be a reconcile subscriptions, we keep an accurate state in the DB diff --git a/packages/server/modules/gatekeeper/domain/billing.ts b/packages/server/modules/gatekeeper/domain/billing.ts index 9795b5822..7e3a2dd94 100644 --- a/packages/server/modules/gatekeeper/domain/billing.ts +++ b/packages/server/modules/gatekeeper/domain/billing.ts @@ -113,7 +113,7 @@ const subscriptionProduct = z.object({ export type SubscriptionProduct = z.infer -export const subscriptionData = z.object({ +export const SubscriptionData = z.object({ subscriptionId: z.string().min(1), customerId: z.string().min(1), cancelAt: z.date().nullable(), @@ -127,8 +127,11 @@ export const subscriptionData = z.object({ z.literal('unpaid'), z.literal('paused') ]), - products: subscriptionProduct.array() + products: subscriptionProduct.array(), + currentPeriodEnd: z.coerce.date() }) +// this abstracts the stripe sub data +export type SubscriptionData = z.infer export const calculateSubscriptionSeats = ({ subscriptionData, @@ -147,9 +150,6 @@ export const calculateSubscriptionSeats = ({ return { guest: guestProduct?.quantity || 0, plan: planProduct?.quantity || 0 } } -// this abstracts the stripe sub data -export type SubscriptionData = z.infer - export type UpsertWorkspaceSubscription = (args: { workspaceSubscription: WorkspaceSubscription }) => Promise diff --git a/packages/server/modules/gatekeeper/services/checkout.ts b/packages/server/modules/gatekeeper/services/checkout.ts index 1d6ab0f5b..f28030db1 100644 --- a/packages/server/modules/gatekeeper/services/checkout.ts +++ b/packages/server/modules/gatekeeper/services/checkout.ts @@ -66,18 +66,7 @@ export const completeCheckoutSessionFactory = const subscriptionData = await getSubscriptionData({ subscriptionId }) - const currentBillingCycleEnd = new Date() - switch (checkoutSession.billingInterval) { - case 'monthly': - currentBillingCycleEnd.setMonth(currentBillingCycleEnd.getMonth() + 1) - break - case 'yearly': - currentBillingCycleEnd.setMonth(currentBillingCycleEnd.getMonth() + 12) - break - - default: - throwUncoveredError(checkoutSession.billingInterval) - } + const currentBillingCycleEnd = subscriptionData.currentPeriodEnd const workspaceSubscription = { createdAt: new Date(), diff --git a/packages/server/modules/gatekeeper/tests/helpers.ts b/packages/server/modules/gatekeeper/tests/helpers.ts index b55bb5205..6fec9ef02 100644 --- a/packages/server/modules/gatekeeper/tests/helpers.ts +++ b/packages/server/modules/gatekeeper/tests/helpers.ts @@ -8,7 +8,9 @@ import { assign } from 'lodash' export const createTestSubscriptionData = ( overrides: Partial = {} ): SubscriptionData => { - const defaultValues: SubscriptionData = { + const aMonthFromNow = new Date() + aMonthFromNow.setMonth(new Date().getMonth() + 1) + const defaultValues = { cancelAt: null, customerId: cryptoRandomString({ length: 10 }), products: [ @@ -20,7 +22,8 @@ export const createTestSubscriptionData = ( } ], status: 'active', - subscriptionId: cryptoRandomString({ length: 10 }) + subscriptionId: cryptoRandomString({ length: 10 }), + currentPeriodEnd: aMonthFromNow.toISOString() } return assign(defaultValues, overrides) } diff --git a/packages/server/modules/workspaces/tests/helpers/creation.ts b/packages/server/modules/workspaces/tests/helpers/creation.ts index c71de513b..0fd4b7171 100644 --- a/packages/server/modules/workspaces/tests/helpers/creation.ts +++ b/packages/server/modules/workspaces/tests/helpers/creation.ts @@ -201,6 +201,8 @@ export const createTestWorkspace = async ( } if (addSubscription) { + const aMonthFromNow = new Date() + aMonthFromNow.setMonth(new Date().getMonth() + 1) await upsertSubscription({ workspaceSubscription: { workspaceId: newWorkspace.id, @@ -213,7 +215,8 @@ export const createTestWorkspace = async ( customerId: cryptoRandomString({ length: 10 }), cancelAt: null, status: 'active', - products: [] + products: [], + currentPeriodEnd: aMonthFromNow } } }) From 194a1fe6074d97b2e35d58b8fbcab8d71079a067 Mon Sep 17 00:00:00 2001 From: Alessandro Magionami Date: Thu, 20 Mar 2025 16:29:32 +0100 Subject: [PATCH 3/9] feat(gatekeeper): downscale new plans --- packages/server/modules/gatekeeper/index.ts | 47 +++- .../gatekeeper/repositories/billing.ts | 45 +++- .../gatekeeper/services/subscriptions.ts | 123 +-------- .../manageSubscriptionDownscale.ts | 240 ++++++++++++++++++ .../intergration/billingRepositories.spec.ts | 19 +- .../tests/unit/subscriptions.spec.ts | 204 +++++++++++++-- 6 files changed, 523 insertions(+), 155 deletions(-) create mode 100644 packages/server/modules/gatekeeper/services/subscriptions/manageSubscriptionDownscale.ts diff --git a/packages/server/modules/gatekeeper/index.ts b/packages/server/modules/gatekeeper/index.ts index 88b1d1fd6..6ee4a6a90 100644 --- a/packages/server/modules/gatekeeper/index.ts +++ b/packages/server/modules/gatekeeper/index.ts @@ -14,23 +14,23 @@ import { acquireTaskLockFactory, releaseTaskLockFactory } from '@/modules/core/repositories/scheduledTasks' -import { - downscaleWorkspaceSubscriptionFactory, - manageSubscriptionDownscaleFactory -} from '@/modules/gatekeeper/services/subscriptions' import { changeExpiredTrialWorkspacePlanStatusesFactory, getWorkspacePlanByProjectIdFactory, getWorkspacePlanFactory, getWorkspacesByPlanAgeFactory, - getWorkspaceSubscriptionsPastBillingCycleEndFactory, + getWorkspaceSubscriptionsPastBillingCycleEndFactoryNewPlans, + getWorkspaceSubscriptionsPastBillingCycleEndFactoryOldPlans, upsertWorkspaceSubscriptionFactory } from '@/modules/gatekeeper/repositories/billing' import { countWorkspaceRoleWithOptionalProjectRoleFactory, getWorkspaceCollaboratorsFactory } from '@/modules/workspaces/repositories/workspaces' -import { reconcileWorkspaceSubscriptionFactory } from '@/modules/gatekeeper/clients/stripe' +import { + getSubscriptionDataFactory, + reconcileWorkspaceSubscriptionFactory +} from '@/modules/gatekeeper/clients/stripe' import { ScheduleExecution } from '@/modules/core/domain/scheduledTasks/operations' import { EventBusEmit, getEventBus } from '@/modules/shared/services/eventBus' import { sendWorkspaceTrialExpiresEmailFactory } from '@/modules/gatekeeper/services/trialEmails' @@ -42,6 +42,13 @@ import coreModule from '@/modules/core/index' import { isProjectReadOnlyFactory } from '@/modules/gatekeeper/services/readOnly' import { WorkspaceReadOnlyError } from '@/modules/gatekeeper/errors/billing' import { InvalidLicenseError } from '@/modules/gatekeeper/errors/license' +import { + downscaleWorkspaceSubscriptionFactoryNew, + downscaleWorkspaceSubscriptionFactoryOld, + manageSubscriptionDownscaleFactoryNew, + manageSubscriptionDownscaleFactoryOld +} from '@/modules/gatekeeper/services/subscriptions/manageSubscriptionDownscale' +import { countSeatsByTypeInWorkspaceFactory } from '@/modules/gatekeeper/repositories/workspaceSeat' const { FF_GATEKEEPER_MODULE_ENABLED, FF_BILLING_INTEGRATION_ENABLED } = getFeatureFlags() @@ -58,16 +65,31 @@ const scheduleWorkspaceSubscriptionDownscale = ({ }) => { const stripe = getStripeClient() - const manageSubscriptionDownscale = manageSubscriptionDownscaleFactory({ - downscaleWorkspaceSubscription: downscaleWorkspaceSubscriptionFactory({ + const manageSubscriptionDownscaleOld = manageSubscriptionDownscaleFactoryOld({ + downscaleWorkspaceSubscription: downscaleWorkspaceSubscriptionFactoryOld({ countWorkspaceRole: countWorkspaceRoleWithOptionalProjectRoleFactory({ db }), getWorkspacePlan: getWorkspacePlanFactory({ db }), reconcileSubscriptionData: reconcileWorkspaceSubscriptionFactory({ stripe }), getWorkspacePlanProductId }), - getWorkspaceSubscriptions: getWorkspaceSubscriptionsPastBillingCycleEndFactory({ - db + getWorkspaceSubscriptions: + getWorkspaceSubscriptionsPastBillingCycleEndFactoryOldPlans({ + db + }), + updateWorkspaceSubscription: upsertWorkspaceSubscriptionFactory({ db }) + }) + const manageSubscriptionDownscaleNew = manageSubscriptionDownscaleFactoryNew({ + downscaleWorkspaceSubscription: downscaleWorkspaceSubscriptionFactoryNew({ + countSeatsByTypeInWorkspace: countSeatsByTypeInWorkspaceFactory({ db }), + getWorkspacePlan: getWorkspacePlanFactory({ db }), + reconcileSubscriptionData: reconcileWorkspaceSubscriptionFactory({ stripe }), + getWorkspacePlanProductId }), + getWorkspaceSubscriptions: + getWorkspaceSubscriptionsPastBillingCycleEndFactoryNewPlans({ + db + }), + getSubscriptionData: getSubscriptionDataFactory({ stripe }), updateWorkspaceSubscription: upsertWorkspaceSubscriptionFactory({ db }) }) @@ -76,7 +98,10 @@ const scheduleWorkspaceSubscriptionDownscale = ({ cronExpression, 'WorkspaceSubscriptionDownscale', async (_scheduledTime, { logger }) => { - await manageSubscriptionDownscale({ logger }) + await Promise.all([ + manageSubscriptionDownscaleOld({ logger }), // Only takes old plans subscriptions + manageSubscriptionDownscaleNew({ logger }) // Only takes new plans subscriptions + ]) } ) } diff --git a/packages/server/modules/gatekeeper/repositories/billing.ts b/packages/server/modules/gatekeeper/repositories/billing.ts index 23651121b..c7663eee1 100644 --- a/packages/server/modules/gatekeeper/repositories/billing.ts +++ b/packages/server/modules/gatekeeper/repositories/billing.ts @@ -27,6 +27,7 @@ import { WorkspacePlan } from '@/modules/gatekeeperCore/domain/billing' import { formatJsonArrayRecords } from '@/modules/shared/helpers/dbHelper' import { Workspace } from '@/modules/workspacesCore/domain/types' import { Workspaces } from '@/modules/workspacesCore/helpers/db' +import { PaidWorkspacePlansNew, PaidWorkspacePlansOld } from '@speckle/shared' import { Knex } from 'knex' import { omit } from 'lodash' @@ -212,15 +213,55 @@ export const getWorkspaceSubscriptionBySubscriptionIdFactory = return subscription ?? null } -export const getWorkspaceSubscriptionsPastBillingCycleEndFactory = +const newPlans = Object.values(PaidWorkspacePlansNew) +const oldPlans = Object.values(PaidWorkspacePlansOld) + +export const getWorkspaceSubscriptionsPastBillingCycleEndFactoryOldPlans = ({ db }: { db: Knex }): GetWorkspaceSubscriptions => async () => { const cycleEnd = new Date() cycleEnd.setMinutes(cycleEnd.getMinutes() + 5) return await tables .workspaceSubscriptions(db) - .select() + .join( + WorkspacePlans.name, + WorkspacePlans.col.workspaceId, + 'workspace_subscriptions.workspaceId' + ) + .whereIn(WorkspacePlans.col.name, oldPlans) .where('currentBillingCycleEnd', '<', cycleEnd) + .select([ + 'workspace_subscriptions.workspaceId', + 'workspace_subscriptions.createdAt', + 'workspace_subscriptions.updatedAt', + 'workspace_subscriptions.currentBillingCycleEnd', + 'workspace_subscriptions.billingInterval', + 'workspace_subscriptions.subscriptionData' + ]) + } + +export const getWorkspaceSubscriptionsPastBillingCycleEndFactoryNewPlans = + ({ db }: { db: Knex }): GetWorkspaceSubscriptions => + async () => { + const cycleEnd = new Date() + cycleEnd.setMinutes(cycleEnd.getMinutes() + 5) + return await tables + .workspaceSubscriptions(db) + .join( + WorkspacePlans.name, + WorkspacePlans.col.workspaceId, + 'workspace_subscriptions.workspaceId' + ) + .whereIn(WorkspacePlans.col.name, newPlans) + .where('currentBillingCycleEnd', '<', cycleEnd) + .select([ + 'workspace_subscriptions.workspaceId', + 'workspace_subscriptions.createdAt', + 'workspace_subscriptions.updatedAt', + 'workspace_subscriptions.currentBillingCycleEnd', + 'workspace_subscriptions.billingInterval', + 'workspace_subscriptions.subscriptionData' + ]) } export const getWorkspacePlanByProjectIdFactory = diff --git a/packages/server/modules/gatekeeper/services/subscriptions.ts b/packages/server/modules/gatekeeper/services/subscriptions.ts index a069545fd..6440c5e7d 100644 --- a/packages/server/modules/gatekeeper/services/subscriptions.ts +++ b/packages/server/modules/gatekeeper/services/subscriptions.ts @@ -1,18 +1,15 @@ -import type { Logger } from '@/observability/logging' import { GetWorkspacePlan, GetWorkspacePlanPriceId, GetWorkspacePlanProductId, GetWorkspaceSubscription, GetWorkspaceSubscriptionBySubscriptionId, - GetWorkspaceSubscriptions, ReconcileSubscriptionData, SubscriptionData, SubscriptionDataInput, UpsertPaidWorkspacePlan, UpsertWorkspaceSubscription, - WorkspaceSeatType, - WorkspaceSubscription + WorkspaceSeatType } from '@/modules/gatekeeper/domain/billing' import { WorkspacePlanMismatchError, @@ -27,9 +24,7 @@ import { throwUncoveredError, WorkspaceRoles } from '@speckle/shared' -import { cloneDeep, isEqual, sum } from 'lodash' -import { mutateSubscriptionDataWithNewValidSeatNumbers } from '@/modules/gatekeeper/services/subscriptions/mutateSubscriptionDataWithNewValidSeatNumbers' -import { calculateNewBillingCycleEnd } from '@/modules/gatekeeper/services/subscriptions/calculateNewBillingCycleEnd' +import { cloneDeep, sum } from 'lodash' import { CountSeatsByTypeInWorkspace } from '@/modules/gatekeeper/domain/operations' export const handleSubscriptionUpdateFactory = @@ -297,117 +292,3 @@ export const addWorkspaceSubscriptionSeatIfNeededFactoryOld = prorationBehavior: 'create_prorations' }) } - -type DownscaleWorkspaceSubscription = (args: { - workspaceSubscription: WorkspaceSubscription -}) => Promise - -export const downscaleWorkspaceSubscriptionFactory = - ({ - getWorkspacePlan, - countWorkspaceRole, - getWorkspacePlanProductId, - reconcileSubscriptionData - }: { - getWorkspacePlan: GetWorkspacePlan - countWorkspaceRole: CountWorkspaceRoleWithOptionalProjectRole - getWorkspacePlanProductId: GetWorkspacePlanProductId - reconcileSubscriptionData: ReconcileSubscriptionData - }): DownscaleWorkspaceSubscription => - async ({ workspaceSubscription }) => { - const workspaceId = workspaceSubscription.workspaceId - - const workspacePlan = await getWorkspacePlan({ workspaceId }) - if (!workspacePlan) throw new WorkspacePlanNotFoundError() - - switch (workspacePlan.name) { - case 'team': - case 'pro': - // Cause seat types matter, a future issue - throw new NotImplementedError() - case 'starter': - case 'plus': - case 'business': - break - case 'unlimited': - case 'academia': - case 'starterInvoiced': - case 'plusInvoiced': - case 'businessInvoiced': - case 'free': - throw new WorkspacePlanMismatchError() - default: - throwUncoveredError(workspacePlan) - } - - if (workspacePlan.status === 'canceled') return false - - // TODO: Guests will be able to have a paid seat - const [guestCount, memberCount, adminCount] = await Promise.all([ - countWorkspaceRole({ workspaceId, workspaceRole: 'workspace:guest' }), - countWorkspaceRole({ workspaceId, workspaceRole: 'workspace:member' }), - countWorkspaceRole({ workspaceId, workspaceRole: 'workspace:admin' }) - ]) - - const subscriptionData = cloneDeep(workspaceSubscription.subscriptionData) - - mutateSubscriptionDataWithNewValidSeatNumbers({ - seatCount: guestCount, - workspacePlan: 'guest', - getWorkspacePlanProductId, - subscriptionData - }) - mutateSubscriptionDataWithNewValidSeatNumbers({ - seatCount: memberCount + adminCount, - workspacePlan: workspacePlan.name, - getWorkspacePlanProductId, - subscriptionData - }) - - if (!isEqual(subscriptionData, workspaceSubscription.subscriptionData)) { - await reconcileSubscriptionData({ subscriptionData, prorationBehavior: 'none' }) - return true - } - return false - } - -export const manageSubscriptionDownscaleFactory = - ({ - getWorkspaceSubscriptions, - downscaleWorkspaceSubscription, - updateWorkspaceSubscription - }: { - getWorkspaceSubscriptions: GetWorkspaceSubscriptions - downscaleWorkspaceSubscription: DownscaleWorkspaceSubscription - updateWorkspaceSubscription: UpsertWorkspaceSubscription - }) => - async (context: { logger: Logger }) => { - const { logger } = context - const subscriptions = await getWorkspaceSubscriptions() - for (const workspaceSubscription of subscriptions) { - const log = logger.child({ workspaceId: workspaceSubscription.workspaceId }) - try { - const subDownscaled = await downscaleWorkspaceSubscription({ - workspaceSubscription - }) - if (subDownscaled) { - log.info( - 'Downscaled workspace subscription to match the current workspace team' - ) - } else { - log.info('Did not need to downscale the workspace subscription') - } - } catch (err) { - log.error({ err }, 'Failed to downscale workspace subscription') - } - const newBillingCycleEnd = calculateNewBillingCycleEnd({ workspaceSubscription }) - const updatedWorkspaceSubscription = { - ...workspaceSubscription, - currentBillingCycleEnd: newBillingCycleEnd - } - await updateWorkspaceSubscription({ - workspaceSubscription: updatedWorkspaceSubscription - }) - log.info({ updatedWorkspaceSubscription }, 'Updated workspace billing cycle end') - } - } diff --git a/packages/server/modules/gatekeeper/services/subscriptions/manageSubscriptionDownscale.ts b/packages/server/modules/gatekeeper/services/subscriptions/manageSubscriptionDownscale.ts new file mode 100644 index 000000000..6ccb12618 --- /dev/null +++ b/packages/server/modules/gatekeeper/services/subscriptions/manageSubscriptionDownscale.ts @@ -0,0 +1,240 @@ +import { + GetSubscriptionData, + GetWorkspacePlan, + GetWorkspacePlanProductId, + GetWorkspaceSubscriptions, + ReconcileSubscriptionData, + UpsertWorkspaceSubscription, + WorkspaceSubscription +} from '@/modules/gatekeeper/domain/billing' +import { CountSeatsByTypeInWorkspace } from '@/modules/gatekeeper/domain/operations' +import { + WorkspacePlanMismatchError, + WorkspacePlanNotFoundError +} from '@/modules/gatekeeper/errors/billing' +import { calculateNewBillingCycleEnd } from '@/modules/gatekeeper/services/subscriptions/calculateNewBillingCycleEnd' +import { mutateSubscriptionDataWithNewValidSeatNumbers } from '@/modules/gatekeeper/services/subscriptions/mutateSubscriptionDataWithNewValidSeatNumbers' +import { NotImplementedError } from '@/modules/shared/errors' +import { CountWorkspaceRoleWithOptionalProjectRole } from '@/modules/workspaces/domain/operations' +import { Logger } from '@/observability/logging' +import { throwUncoveredError } from '@speckle/shared' +import { cloneDeep, isEqual } from 'lodash' + +type DownscaleWorkspaceSubscription = (args: { + workspaceSubscription: WorkspaceSubscription +}) => Promise + +export const downscaleWorkspaceSubscriptionFactoryOld = + ({ + getWorkspacePlan, + countWorkspaceRole, + getWorkspacePlanProductId, + reconcileSubscriptionData + }: { + getWorkspacePlan: GetWorkspacePlan + countWorkspaceRole: CountWorkspaceRoleWithOptionalProjectRole + getWorkspacePlanProductId: GetWorkspacePlanProductId + reconcileSubscriptionData: ReconcileSubscriptionData + }): DownscaleWorkspaceSubscription => + async ({ workspaceSubscription }) => { + const workspaceId = workspaceSubscription.workspaceId + + const workspacePlan = await getWorkspacePlan({ workspaceId }) + if (!workspacePlan) throw new WorkspacePlanNotFoundError() + + switch (workspacePlan.name) { + case 'team': + case 'pro': + // Cause seat types matter, a future issue + throw new NotImplementedError() + case 'starter': + case 'plus': + case 'business': + break + case 'unlimited': + case 'academia': + case 'starterInvoiced': + case 'plusInvoiced': + case 'businessInvoiced': + case 'free': + throw new WorkspacePlanMismatchError() + default: + throwUncoveredError(workspacePlan) + } + + if (workspacePlan.status === 'canceled') return false + + // TODO: Guests will be able to have a paid seat + const [guestCount, memberCount, adminCount] = await Promise.all([ + countWorkspaceRole({ workspaceId, workspaceRole: 'workspace:guest' }), + countWorkspaceRole({ workspaceId, workspaceRole: 'workspace:member' }), + countWorkspaceRole({ workspaceId, workspaceRole: 'workspace:admin' }) + ]) + + const subscriptionData = cloneDeep(workspaceSubscription.subscriptionData) + + mutateSubscriptionDataWithNewValidSeatNumbers({ + seatCount: guestCount, + workspacePlan: 'guest', + getWorkspacePlanProductId, + subscriptionData + }) + mutateSubscriptionDataWithNewValidSeatNumbers({ + seatCount: memberCount + adminCount, + workspacePlan: workspacePlan.name, + getWorkspacePlanProductId, + subscriptionData + }) + + if (!isEqual(subscriptionData, workspaceSubscription.subscriptionData)) { + await reconcileSubscriptionData({ subscriptionData, prorationBehavior: 'none' }) + return true + } + return false + } + +export const downscaleWorkspaceSubscriptionFactoryNew = + ({ + getWorkspacePlan, + countSeatsByTypeInWorkspace, + getWorkspacePlanProductId, + reconcileSubscriptionData + }: { + getWorkspacePlan: GetWorkspacePlan + countSeatsByTypeInWorkspace: CountSeatsByTypeInWorkspace + getWorkspacePlanProductId: GetWorkspacePlanProductId + reconcileSubscriptionData: ReconcileSubscriptionData + }): DownscaleWorkspaceSubscription => + async ({ workspaceSubscription }) => { + const workspaceId = workspaceSubscription.workspaceId + + const workspacePlan = await getWorkspacePlan({ workspaceId }) + if (!workspacePlan) throw new WorkspacePlanNotFoundError() + + switch (workspacePlan.name) { + case 'team': + case 'pro': + break + case 'starter': + case 'plus': + case 'business': + case 'unlimited': + case 'academia': + case 'starterInvoiced': + case 'plusInvoiced': + case 'businessInvoiced': + case 'free': + throw new WorkspacePlanMismatchError() + default: + throwUncoveredError(workspacePlan) + } + + if (workspacePlan.status === 'canceled') return false + + const editorsCount = await countSeatsByTypeInWorkspace({ + workspaceId, + type: 'editor' + }) + + const subscriptionData = cloneDeep(workspaceSubscription.subscriptionData) + + mutateSubscriptionDataWithNewValidSeatNumbers({ + seatCount: editorsCount, + workspacePlan: workspacePlan.name, + getWorkspacePlanProductId, + subscriptionData + }) + + if (!isEqual(subscriptionData, workspaceSubscription.subscriptionData)) { + await reconcileSubscriptionData({ subscriptionData, prorationBehavior: 'none' }) + return true + } + return false + } + +export const manageSubscriptionDownscaleFactoryOld = + ({ + getWorkspaceSubscriptions, + downscaleWorkspaceSubscription, + updateWorkspaceSubscription + }: { + getWorkspaceSubscriptions: GetWorkspaceSubscriptions + downscaleWorkspaceSubscription: DownscaleWorkspaceSubscription + updateWorkspaceSubscription: UpsertWorkspaceSubscription + }) => + async (context: { logger: Logger }) => { + const { logger } = context + const subscriptions = await getWorkspaceSubscriptions() + for (const workspaceSubscription of subscriptions) { + const log = logger.child({ workspaceId: workspaceSubscription.workspaceId }) + try { + const subDownscaled = await downscaleWorkspaceSubscription({ + workspaceSubscription + }) + if (subDownscaled) { + log.info( + 'Downscaled workspace subscription to match the current workspace team' + ) + } else { + log.info('Did not need to downscale the workspace subscription') + } + } catch (err) { + log.error({ err }, 'Failed to downscale workspace subscription') + } + const newBillingCycleEnd = calculateNewBillingCycleEnd({ workspaceSubscription }) + const updatedWorkspaceSubscription = { + ...workspaceSubscription, + currentBillingCycleEnd: newBillingCycleEnd + } + await updateWorkspaceSubscription({ + workspaceSubscription: updatedWorkspaceSubscription + }) + log.info({ updatedWorkspaceSubscription }, 'Updated workspace billing cycle end') + } + } + +export const manageSubscriptionDownscaleFactoryNew = + ({ + getWorkspaceSubscriptions, + downscaleWorkspaceSubscription, + updateWorkspaceSubscription, + getSubscriptionData + }: { + getWorkspaceSubscriptions: GetWorkspaceSubscriptions + downscaleWorkspaceSubscription: DownscaleWorkspaceSubscription + updateWorkspaceSubscription: UpsertWorkspaceSubscription + getSubscriptionData: GetSubscriptionData + }) => + async (context: { logger: Logger }) => { + const { logger } = context + const subscriptions = await getWorkspaceSubscriptions() + for (const workspaceSubscription of subscriptions) { + const log = logger.child({ workspaceId: workspaceSubscription.workspaceId }) + try { + //TODO: + const subDownscaled = await downscaleWorkspaceSubscription({ + workspaceSubscription + }) + if (subDownscaled) { + log.info( + 'Downscaled workspace subscription to match the current workspace team' + ) + } else { + log.info('Did not need to downscale the workspace subscription') + } + } catch (err) { + log.error({ err }, 'Failed to downscale workspace subscription') + } + const subscriptionData = await getSubscriptionData( + workspaceSubscription.subscriptionData + ) + const updatedWorkspaceSubscription = { + ...workspaceSubscription, + currentBillingCycleEnd: subscriptionData.currentPeriodEnd + } + await updateWorkspaceSubscription({ + workspaceSubscription: updatedWorkspaceSubscription + }) + log.info({ updatedWorkspaceSubscription }, 'Updated workspace billing cycle end') + } + } diff --git a/packages/server/modules/gatekeeper/tests/intergration/billingRepositories.spec.ts b/packages/server/modules/gatekeeper/tests/intergration/billingRepositories.spec.ts index 10ed8c67e..e4793515f 100644 --- a/packages/server/modules/gatekeeper/tests/intergration/billingRepositories.spec.ts +++ b/packages/server/modules/gatekeeper/tests/intergration/billingRepositories.spec.ts @@ -10,10 +10,11 @@ import { upsertPaidWorkspacePlanFactory, getWorkspaceSubscriptionFactory, getWorkspaceSubscriptionBySubscriptionIdFactory, - getWorkspaceSubscriptionsPastBillingCycleEndFactory, changeExpiredTrialWorkspacePlanStatusesFactory, upsertTrialWorkspacePlanFactory, - getWorkspacesByPlanAgeFactory + getWorkspacesByPlanAgeFactory, + getWorkspaceSubscriptionsPastBillingCycleEndFactoryOldPlans, + upsertWorkspacePlanFactory } from '@/modules/gatekeeper/repositories/billing' import { createTestSubscriptionData, @@ -43,8 +44,8 @@ const getWorkspaceSubscription = getWorkspaceSubscriptionFactory({ db }) const getWorkspaceSubscriptionBySubscriptionId = getWorkspaceSubscriptionBySubscriptionIdFactory({ db }) -const getSubscriptionsAboutToEndBillingCycle = - getWorkspaceSubscriptionsPastBillingCycleEndFactory({ db }) +const getSubscriptionsAboutToEndBillingCycleOld = + getWorkspaceSubscriptionsPastBillingCycleEndFactoryOldPlans({ db }) const changeExpiredTrialWorkspacePlanStatuses = changeExpiredTrialWorkspacePlanStatusesFactory({ db }) @@ -526,10 +527,18 @@ describe('billing repositories @gatekeeper', () => { const workspace2Subscription = createTestWorkspaceSubscription({ workspaceId: workspace2Id }) + await upsertWorkspacePlanFactory({ db })({ + workspacePlan: { + workspaceId: workspace2Subscription.workspaceId, + name: 'plus', + status: 'valid', + createdAt: new Date() + } + }) await upsertWorkspaceSubscription({ workspaceSubscription: workspace2Subscription }) - const subscriptions = await getSubscriptionsAboutToEndBillingCycle() + const subscriptions = await getSubscriptionsAboutToEndBillingCycleOld() expect(subscriptions).deep.equalInAnyOrder([workspace2Subscription]) }) }) diff --git a/packages/server/modules/gatekeeper/tests/unit/subscriptions.spec.ts b/packages/server/modules/gatekeeper/tests/unit/subscriptions.spec.ts index e3f08d820..a630ea3f4 100644 --- a/packages/server/modules/gatekeeper/tests/unit/subscriptions.spec.ts +++ b/packages/server/modules/gatekeeper/tests/unit/subscriptions.spec.ts @@ -15,10 +15,14 @@ import { import { addWorkspaceSubscriptionSeatIfNeededFactoryNew, addWorkspaceSubscriptionSeatIfNeededFactoryOld, - downscaleWorkspaceSubscriptionFactory, - handleSubscriptionUpdateFactory, - manageSubscriptionDownscaleFactory + handleSubscriptionUpdateFactory } from '@/modules/gatekeeper/services/subscriptions' +import { + downscaleWorkspaceSubscriptionFactoryNew, + downscaleWorkspaceSubscriptionFactoryOld, + manageSubscriptionDownscaleFactoryOld +} from '@/modules/gatekeeper/services/subscriptions/manageSubscriptionDownscale' + import { createTestSubscriptionData, createTestWorkspaceSubscription @@ -1014,13 +1018,13 @@ describe('subscriptions @gatekeeper', () => { }) }) - describe('downscaleWorkspaceSubscriptionFactory creates a function, that', () => { + describe('downscaleWorkspaceSubscriptionFactoryOld creates a function, that', () => { it('throws an error if the workspace has no plan attached to it', async () => { const subscriptionData = createTestSubscriptionData() const workspaceSubscription = createTestWorkspaceSubscription({ subscriptionData }) - const downscaleSubscription = downscaleWorkspaceSubscriptionFactory({ + const downscaleSubscription = downscaleWorkspaceSubscriptionFactoryOld({ getWorkspacePlan: async () => null, countWorkspaceRole: async () => { expect.fail() @@ -1044,7 +1048,7 @@ describe('subscriptions @gatekeeper', () => { subscriptionData, workspaceId }) - const downscaleSubscription = downscaleWorkspaceSubscriptionFactory({ + const downscaleSubscription = downscaleWorkspaceSubscriptionFactoryOld({ getWorkspacePlan: async () => ({ name: 'unlimited', workspaceId, @@ -1073,7 +1077,7 @@ describe('subscriptions @gatekeeper', () => { subscriptionData, workspaceId }) - const downscaleSubscription = downscaleWorkspaceSubscriptionFactory({ + const downscaleSubscription = downscaleWorkspaceSubscriptionFactoryOld({ getWorkspacePlan: async () => ({ name: 'plus', workspaceId, @@ -1109,7 +1113,7 @@ describe('subscriptions @gatekeeper', () => { workspaceId }) const workspacePlanName = 'plus' - const downscaleSubscription = downscaleWorkspaceSubscriptionFactory({ + const downscaleSubscription = downscaleWorkspaceSubscriptionFactoryOld({ getWorkspacePlan: async () => ({ name: workspacePlanName, workspaceId, @@ -1164,7 +1168,7 @@ describe('subscriptions @gatekeeper', () => { const workspacePlanName = 'plus' let reconciledSub: SubscriptionDataInput | undefined = undefined - const downscaleSubscription = downscaleWorkspaceSubscriptionFactory({ + const downscaleSubscription = downscaleWorkspaceSubscriptionFactoryOld({ getWorkspacePlan: async () => ({ name: workspacePlanName, workspaceId, @@ -1193,14 +1197,178 @@ describe('subscriptions @gatekeeper', () => { ).to.be.equal(guestQuantity / 2) }) }) - describe('manageSubscriptionDownscaleFactory, creates a function, that', () => { + describe('downscaleWorkspaceSubscriptionFactoryNew creates a function, that', () => { + it('throws an error if the workspace has no plan attached to it', async () => { + const subscriptionData = createTestSubscriptionData() + const workspaceSubscription = createTestWorkspaceSubscription({ + subscriptionData + }) + const downscaleSubscription = downscaleWorkspaceSubscriptionFactoryNew({ + getWorkspacePlan: async () => null, + countSeatsByTypeInWorkspace: async () => { + expect.fail() + }, + getWorkspacePlanProductId: () => { + expect.fail() + }, + reconcileSubscriptionData: async () => { + expect.fail() + } + }) + const err = await expectToThrow(async () => { + await downscaleSubscription({ workspaceSubscription }) + }) + expect(err.message).to.equal(new WorkspacePlanNotFoundError().message) + }) + it('throws an error if workspacePlan is not a paid plan', async () => { + const workspaceId = cryptoRandomString({ length: 10 }) + const subscriptionData = createTestSubscriptionData() + const workspaceSubscription = createTestWorkspaceSubscription({ + subscriptionData, + workspaceId + }) + const downscaleSubscription = downscaleWorkspaceSubscriptionFactoryNew({ + getWorkspacePlan: async () => ({ + name: 'unlimited', + workspaceId, + createdAt: new Date(), + status: 'valid' + }), + countSeatsByTypeInWorkspace: async () => { + expect.fail() + }, + getWorkspacePlanProductId: () => { + expect.fail() + }, + reconcileSubscriptionData: async () => { + expect.fail() + } + }) + const err = await expectToThrow(async () => { + await downscaleSubscription({ workspaceSubscription }) + }) + expect(err.message).to.equal(new WorkspacePlanMismatchError().message) + }) + it('returns if the subscription is canceled', async () => { + const workspaceId = cryptoRandomString({ length: 10 }) + const subscriptionData = createTestSubscriptionData() + const workspaceSubscription = createTestWorkspaceSubscription({ + subscriptionData, + workspaceId + }) + const downscaleSubscription = downscaleWorkspaceSubscriptionFactoryNew({ + getWorkspacePlan: async () => ({ + name: 'pro', + workspaceId, + createdAt: new Date(), + status: 'canceled' + }), + countSeatsByTypeInWorkspace: async () => { + expect.fail() + }, + getWorkspacePlanProductId: () => { + expect.fail() + }, + reconcileSubscriptionData: async () => { + expect.fail() + } + }) + const hasDownscaled = await downscaleSubscription({ workspaceSubscription }) + expect(hasDownscaled).to.be.false + }) + it('does not reconcile the subscription seats did not change', async () => { + const workspaceId = cryptoRandomString({ length: 10 }) + const priceId = cryptoRandomString({ length: 10 }) + const productId = cryptoRandomString({ length: 10 }) + const quantity = 10 + const subscriptionItemId = cryptoRandomString({ length: 10 }) + const subscriptionData = createTestSubscriptionData({ + products: [{ priceId, productId, quantity, subscriptionItemId }] + }) + const workspaceSubscription = createTestWorkspaceSubscription({ + subscriptionData, + billingInterval: 'monthly', + currentBillingCycleEnd: new Date(2034, 11, 5), + workspaceId + }) + const workspacePlanName = 'pro' + const downscaleSubscription = downscaleWorkspaceSubscriptionFactoryNew({ + getWorkspacePlan: async () => ({ + name: workspacePlanName, + workspaceId, + createdAt: new Date(), + status: 'valid' + }), + countSeatsByTypeInWorkspace: async () => { + return 10 + }, + getWorkspacePlanProductId: ({ workspacePlan }) => { + return workspacePlan === workspacePlanName + ? productId + : cryptoRandomString({ length: 10 }) + }, + reconcileSubscriptionData: async () => { + expect.fail() + } + }) + await downscaleSubscription({ workspaceSubscription }) + }) + it('reconciles the subscription to the new seat values', async () => { + const workspaceId = cryptoRandomString({ length: 10 }) + const proPriceId = cryptoRandomString({ length: 10 }) + const proProductId = cryptoRandomString({ length: 10 }) + const proQuantity = 10 + const proSubscriptionItemId = cryptoRandomString({ length: 10 }) + + const subscriptionData = createTestSubscriptionData({ + products: [ + { + priceId: proPriceId, + productId: proProductId, + quantity: proQuantity, + subscriptionItemId: proSubscriptionItemId + } + ] + }) + const testWorkspaceSubscription = createTestWorkspaceSubscription({ + subscriptionData, + workspaceId + }) + const workspacePlanName = 'pro' + + let reconciledSub: SubscriptionDataInput | undefined = undefined + const downscaleSubscription = downscaleWorkspaceSubscriptionFactoryNew({ + getWorkspacePlan: async () => ({ + name: workspacePlanName, + workspaceId, + createdAt: new Date(), + status: 'valid' + }), + countSeatsByTypeInWorkspace: async () => { + return 5 + }, + getWorkspacePlanProductId: () => { + return proProductId + }, + reconcileSubscriptionData: async ({ subscriptionData }) => { + reconciledSub = subscriptionData + } + }) + await downscaleSubscription({ workspaceSubscription: testWorkspaceSubscription }) + + expect( + reconciledSub!.products.find((p) => p.productId === proProductId)?.quantity + ).to.be.equal(5) + }) + }) + describe('manageSubscriptionDownscaleFactoryOld, creates a function, that', () => { it('still updates the monthly billing cycle end, even if subscription reconciliation fails', async () => { const testWorkspaceSubscription = createTestWorkspaceSubscription({ billingInterval: 'monthly', currentBillingCycleEnd: new Date(2034, 11, 5) }) let updatedWorkspaceSubscription: WorkspaceSubscription | undefined = undefined - await manageSubscriptionDownscaleFactory({ + await manageSubscriptionDownscaleFactoryOld({ getWorkspaceSubscriptions: async () => [testWorkspaceSubscription], downscaleWorkspaceSubscription: async () => { throw new Error('kabumm') @@ -1222,7 +1390,7 @@ describe('subscriptions @gatekeeper', () => { currentBillingCycleEnd: new Date(2034, 11, 5) }) let updatedWorkspaceSubscription: WorkspaceSubscription | undefined = undefined - await manageSubscriptionDownscaleFactory({ + await manageSubscriptionDownscaleFactoryOld({ getWorkspaceSubscriptions: async () => [testWorkspaceSubscription], downscaleWorkspaceSubscription: async () => { throw new Error('kabumm') @@ -1593,7 +1761,8 @@ describe('subscriptions @gatekeeper', () => { customerId: cryptoRandomString({ length: 10 }), subscriptionId: cryptoRandomString({ length: 10 }), status: 'active', - products: [] + products: [], + currentPeriodEnd: new Date() } const workspaceSubscription = createTestWorkspaceSubscription({ subscriptionData @@ -1657,7 +1826,8 @@ describe('subscriptions @gatekeeper', () => { quantity: 20, subscriptionItemId: cryptoRandomString({ length: 10 }) } - ] + ], + currentPeriodEnd: new Date() } const workspaceSubscription = createTestWorkspaceSubscription({ subscriptionData, @@ -2044,7 +2214,8 @@ describe('subscriptions @gatekeeper', () => { customerId: cryptoRandomString({ length: 10 }), subscriptionId: cryptoRandomString({ length: 10 }), status: 'active', - products: [] + products: [], + currentPeriodEnd: new Date() } const workspaceSubscription = createTestWorkspaceSubscription({ subscriptionData @@ -2102,7 +2273,8 @@ describe('subscriptions @gatekeeper', () => { quantity: 10, subscriptionItemId: cryptoRandomString({ length: 10 }) } - ] + ], + currentPeriodEnd: new Date() } const workspaceSubscription = createTestWorkspaceSubscription({ subscriptionData, From 38fd761fe32ad458ab54fef56195b34fbdf1f385 Mon Sep 17 00:00:00 2001 From: Alessandro Magionami Date: Thu, 20 Mar 2025 18:57:52 +0100 Subject: [PATCH 4/9] fix(gatekeeper): fix date format in subscription parse --- packages/server/modules/gatekeeper/clients/stripe.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/server/modules/gatekeeper/clients/stripe.ts b/packages/server/modules/gatekeeper/clients/stripe.ts index c73127c89..5b6a256a2 100644 --- a/packages/server/modules/gatekeeper/clients/stripe.ts +++ b/packages/server/modules/gatekeeper/clients/stripe.ts @@ -66,7 +66,7 @@ export const parseSubscriptionData = ( cancelAt: stripeSubscription.cancel_at ? new Date(stripeSubscription.cancel_at * 1000) : null, - currentPeriodEnd: stripeSubscription.current_period_end, + currentPeriodEnd: stripeSubscription.current_period_end * 1000, // this value arrives as a UNIX timestamp products: stripeSubscription.items.data.map((subscriptionItem) => { const productId = typeof subscriptionItem.price.product === 'string' From b1c9d8b2d451feda42749b0bf53ec8eb8bdfa7f2 Mon Sep 17 00:00:00 2001 From: Alessandro Magionami Date: Fri, 21 Mar 2025 11:14:34 +0100 Subject: [PATCH 5/9] feat(gatekeeper): on invoice created trigger downscale --- .../server/modules/gatekeeper/rest/billing.ts | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/packages/server/modules/gatekeeper/rest/billing.ts b/packages/server/modules/gatekeeper/rest/billing.ts index 5aefe102e..1b681cd53 100644 --- a/packages/server/modules/gatekeeper/rest/billing.ts +++ b/packages/server/modules/gatekeeper/rest/billing.ts @@ -22,6 +22,7 @@ import { withTransaction } from '@/modules/shared/helpers/dbHelper' import { getStripeClient } from '@/modules/gatekeeper/stripe' import { handleSubscriptionUpdateFactory } from '@/modules/gatekeeper/services/subscriptions' import { getEventBus } from '@/modules/shared/services/eventBus' +import { SubscriptionData } from '@/modules/gatekeeper/domain/billing' export const getBillingRouter = (): Router => { const router = Router() @@ -144,6 +145,19 @@ export const getBillingRouter = (): Router => { })({ subscriptionData: parseSubscriptionData(event.data.object) }) break + case 'invoice.created': + const subscriptionData = await getSubscriptionFromEventFactory({ stripe })( + event + ) + if (!subscriptionData) break + await handleSubscriptionUpdateFactory({ + getWorkspacePlan: getWorkspacePlanFactory({ db }), + upsertPaidWorkspacePlan: upsertPaidWorkspacePlanFactory({ db }), + getWorkspaceSubscriptionBySubscriptionId: + getWorkspaceSubscriptionBySubscriptionIdFactory({ db }), + upsertWorkspaceSubscription: upsertWorkspaceSubscriptionFactory({ db }) + })({ subscriptionData }) + break default: break @@ -154,3 +168,18 @@ export const getBillingRouter = (): Router => { return router } + +const getSubscriptionFromEventFactory = + ({ stripe }: { stripe: Stripe }) => + async (event: Stripe.InvoiceCreatedEvent): Promise => { + const subscription = event.data.object.subscription + if (!subscription) { + return null + } + if (typeof subscription === 'string') { + return await getSubscriptionDataFactory({ stripe })({ + subscriptionId: subscription + }) + } + return parseSubscriptionData(subscription) + } From cd39e18d9baeffab35c9c29fd600e61524daeb9a Mon Sep 17 00:00:00 2001 From: Alessandro Magionami Date: Mon, 24 Mar 2025 15:31:02 +0100 Subject: [PATCH 6/9] chore(workspaces): fix linter --- packages/server/modules/gatekeeper/tests/unit/checkout.spec.ts | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/packages/server/modules/gatekeeper/tests/unit/checkout.spec.ts b/packages/server/modules/gatekeeper/tests/unit/checkout.spec.ts index 2efe6312a..d08988ea2 100644 --- a/packages/server/modules/gatekeeper/tests/unit/checkout.spec.ts +++ b/packages/server/modules/gatekeeper/tests/unit/checkout.spec.ts @@ -586,7 +586,8 @@ describe('checkout @gatekeeper', () => { } ], status: 'active', - cancelAt: null + cancelAt: null, + currentPeriodEnd: new Date() } let storedWorkspaceSubscriptionData: WorkspaceSubscription | undefined = From 800547309a12fd10e77eb08d523456317603c49b Mon Sep 17 00:00:00 2001 From: Alessandro Magionami Date: Mon, 24 Mar 2025 15:42:54 +0100 Subject: [PATCH 7/9] chore(workspaces): create table helper for subscriptions table --- .../gatekeeper/repositories/billing.ts | 26 +++++++------------ 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/packages/server/modules/gatekeeper/repositories/billing.ts b/packages/server/modules/gatekeeper/repositories/billing.ts index c7663eee1..4e5e749b4 100644 --- a/packages/server/modules/gatekeeper/repositories/billing.ts +++ b/packages/server/modules/gatekeeper/repositories/billing.ts @@ -38,6 +38,14 @@ const WorkspacePlans = buildTableHelper('workspace_plans', [ 'createdAt', 'updatedAt' ]) +const WorkspaceSubscriptions = buildTableHelper('workspace_subscriptions', [ + 'workspaceId', + 'createdAt', + 'updatedAt', + 'currentBillingCycleEnd', + 'billingInterval', + 'subscriptionData' +]) const tables = { workspaces: (db: Knex) => db('workspaces'), @@ -230,14 +238,7 @@ export const getWorkspaceSubscriptionsPastBillingCycleEndFactoryOldPlans = ) .whereIn(WorkspacePlans.col.name, oldPlans) .where('currentBillingCycleEnd', '<', cycleEnd) - .select([ - 'workspace_subscriptions.workspaceId', - 'workspace_subscriptions.createdAt', - 'workspace_subscriptions.updatedAt', - 'workspace_subscriptions.currentBillingCycleEnd', - 'workspace_subscriptions.billingInterval', - 'workspace_subscriptions.subscriptionData' - ]) + .select(WorkspaceSubscriptions.cols) } export const getWorkspaceSubscriptionsPastBillingCycleEndFactoryNewPlans = @@ -254,14 +255,7 @@ export const getWorkspaceSubscriptionsPastBillingCycleEndFactoryNewPlans = ) .whereIn(WorkspacePlans.col.name, newPlans) .where('currentBillingCycleEnd', '<', cycleEnd) - .select([ - 'workspace_subscriptions.workspaceId', - 'workspace_subscriptions.createdAt', - 'workspace_subscriptions.updatedAt', - 'workspace_subscriptions.currentBillingCycleEnd', - 'workspace_subscriptions.billingInterval', - 'workspace_subscriptions.subscriptionData' - ]) + .select(WorkspaceSubscriptions.cols) } export const getWorkspacePlanByProjectIdFactory = From a6a4ceee86b81257910353baeca61427cc3e2566 Mon Sep 17 00:00:00 2001 From: Kristaps Fabians Geikins Date: Tue, 25 Mar 2025 18:49:02 +0200 Subject: [PATCH 8/9] feat: true-myth result structures & other auth policy improvements (#4262) * fixing up typing * better dynamic loader mechanism * buildReqLoaders cleanup * added caching to loaders * ensuring all loaders are async * fe2 plugins error handling fix * feat(shared): true-myth result structures & other auth policy improvements * moving workspaceCore loaders to correct place --- ...{accessibility.ts => 001-accessibility.ts} | 0 .../plugins/{dayjs.ts => 001-dayjs.ts} | 0 .../plugins/{portal.ts => 001-portal.ts} | 0 .../plugins/{tippy.ts => 001-tippy.ts} | 0 .../plugins/{001-logger.ts => 010-logger.ts} | 0 .../plugins/{002-rum.ts => 020-rum.ts} | 0 ....server.ts => 030-healthMetrics.server.ts} | 0 ...04-redis.server.ts => 040-redis.server.ts} | 0 .../plugins/{005-cache.ts => 050-cache.ts} | 0 ...{006-dataPreload.ts => 060-dataPreload.ts} | 0 ...lient.client.ts => 070-mpClient.client.ts} | 0 ...erver.server.ts => 070-mpServer.server.ts} | 0 .../{008-mp.client.ts => 080-mp.client.ts} | 0 packages/server/app.ts | 6 +- .../core/{authz.ts => authz/loaders/index.ts} | 24 +- .../modules/core/graph/resolvers/projects.ts | 3 +- packages/server/modules/core/index.ts | 3 - packages/server/modules/index.ts | 96 +++++++- packages/server/modules/loaders.ts | 55 +---- .../modules/shared/helpers/typeHelper.ts | 2 +- .../server/modules/shared/middleware/index.ts | 54 ++--- .../{authz.ts => authz/loaders/index.ts} | 28 ++- packages/server/modules/workspaces/index.ts | 3 - .../server/modules/workspacesCore/authz.ts | 19 -- .../workspacesCore/authz/loaders/index.ts | 17 ++ .../server/modules/workspacesCore/index.ts | 2 - packages/server/package.json | 1 + packages/server/test/graphqlHelper.ts | 42 ++-- packages/server/tsconfig.json | 2 +- .../server/type-augmentations/true-myth.d.ts | 10 + packages/shared/eslint.config.mjs | 6 + packages/shared/package.json | 1 + .../shared/src/authz/checks/projects.spec.ts | 46 ++-- packages/shared/src/authz/checks/projects.ts | 18 +- .../src/authz/checks/serverRole.spec.ts | 8 +- .../shared/src/authz/checks/serverRole.ts | 7 +- .../src/authz/checks/workspaceRole.spec.ts | 14 +- .../shared/src/authz/checks/workspaceRole.ts | 13 +- .../src/authz/checks/workspaceSso.spec.ts | 29 ++- .../shared/src/authz/checks/workspaceSso.ts | 11 +- .../shared/src/authz/domain/authErrors.ts | 32 ++- .../shared/src/authz/domain/authResult.ts | 19 -- .../src/authz/domain/core/operations.ts | 6 +- packages/shared/src/authz/domain/errors.ts | 6 + packages/shared/src/authz/domain/loaders.ts | 52 ++++- packages/shared/src/authz/domain/policies.ts | 12 + .../src/authz/domain/projects/operations.ts | 15 +- .../src/authz/domain/workspaces/operations.ts | 19 +- packages/shared/src/authz/index.ts | 8 +- .../authz/policies/canQueryProject.spec.ts | 221 ++++++++++-------- .../src/authz/policies/canQueryProject.ts | 81 ++++--- packages/shared/src/authz/policies/index.ts | 6 +- packages/shared/src/tests/fakes.ts | 7 +- yarn.lock | 9 + 54 files changed, 615 insertions(+), 398 deletions(-) rename packages/frontend-2/plugins/{accessibility.ts => 001-accessibility.ts} (100%) rename packages/frontend-2/plugins/{dayjs.ts => 001-dayjs.ts} (100%) rename packages/frontend-2/plugins/{portal.ts => 001-portal.ts} (100%) rename packages/frontend-2/plugins/{tippy.ts => 001-tippy.ts} (100%) rename packages/frontend-2/plugins/{001-logger.ts => 010-logger.ts} (100%) rename packages/frontend-2/plugins/{002-rum.ts => 020-rum.ts} (100%) rename packages/frontend-2/plugins/{003-healthMetrics.server.ts => 030-healthMetrics.server.ts} (100%) rename packages/frontend-2/plugins/{004-redis.server.ts => 040-redis.server.ts} (100%) rename packages/frontend-2/plugins/{005-cache.ts => 050-cache.ts} (100%) rename packages/frontend-2/plugins/{006-dataPreload.ts => 060-dataPreload.ts} (100%) rename packages/frontend-2/plugins/{007-mpClient.client.ts => 070-mpClient.client.ts} (100%) rename packages/frontend-2/plugins/{007-mpServer.server.ts => 070-mpServer.server.ts} (100%) rename packages/frontend-2/plugins/{008-mp.client.ts => 080-mp.client.ts} (100%) rename packages/server/modules/core/{authz.ts => authz/loaders/index.ts} (56%) rename packages/server/modules/workspaces/{authz.ts => authz/loaders/index.ts} (51%) delete mode 100644 packages/server/modules/workspacesCore/authz.ts create mode 100644 packages/server/modules/workspacesCore/authz/loaders/index.ts create mode 100644 packages/server/type-augmentations/true-myth.d.ts delete mode 100644 packages/shared/src/authz/domain/authResult.ts diff --git a/packages/frontend-2/plugins/accessibility.ts b/packages/frontend-2/plugins/001-accessibility.ts similarity index 100% rename from packages/frontend-2/plugins/accessibility.ts rename to packages/frontend-2/plugins/001-accessibility.ts diff --git a/packages/frontend-2/plugins/dayjs.ts b/packages/frontend-2/plugins/001-dayjs.ts similarity index 100% rename from packages/frontend-2/plugins/dayjs.ts rename to packages/frontend-2/plugins/001-dayjs.ts diff --git a/packages/frontend-2/plugins/portal.ts b/packages/frontend-2/plugins/001-portal.ts similarity index 100% rename from packages/frontend-2/plugins/portal.ts rename to packages/frontend-2/plugins/001-portal.ts diff --git a/packages/frontend-2/plugins/tippy.ts b/packages/frontend-2/plugins/001-tippy.ts similarity index 100% rename from packages/frontend-2/plugins/tippy.ts rename to packages/frontend-2/plugins/001-tippy.ts diff --git a/packages/frontend-2/plugins/001-logger.ts b/packages/frontend-2/plugins/010-logger.ts similarity index 100% rename from packages/frontend-2/plugins/001-logger.ts rename to packages/frontend-2/plugins/010-logger.ts diff --git a/packages/frontend-2/plugins/002-rum.ts b/packages/frontend-2/plugins/020-rum.ts similarity index 100% rename from packages/frontend-2/plugins/002-rum.ts rename to packages/frontend-2/plugins/020-rum.ts diff --git a/packages/frontend-2/plugins/003-healthMetrics.server.ts b/packages/frontend-2/plugins/030-healthMetrics.server.ts similarity index 100% rename from packages/frontend-2/plugins/003-healthMetrics.server.ts rename to packages/frontend-2/plugins/030-healthMetrics.server.ts diff --git a/packages/frontend-2/plugins/004-redis.server.ts b/packages/frontend-2/plugins/040-redis.server.ts similarity index 100% rename from packages/frontend-2/plugins/004-redis.server.ts rename to packages/frontend-2/plugins/040-redis.server.ts diff --git a/packages/frontend-2/plugins/005-cache.ts b/packages/frontend-2/plugins/050-cache.ts similarity index 100% rename from packages/frontend-2/plugins/005-cache.ts rename to packages/frontend-2/plugins/050-cache.ts diff --git a/packages/frontend-2/plugins/006-dataPreload.ts b/packages/frontend-2/plugins/060-dataPreload.ts similarity index 100% rename from packages/frontend-2/plugins/006-dataPreload.ts rename to packages/frontend-2/plugins/060-dataPreload.ts diff --git a/packages/frontend-2/plugins/007-mpClient.client.ts b/packages/frontend-2/plugins/070-mpClient.client.ts similarity index 100% rename from packages/frontend-2/plugins/007-mpClient.client.ts rename to packages/frontend-2/plugins/070-mpClient.client.ts diff --git a/packages/frontend-2/plugins/007-mpServer.server.ts b/packages/frontend-2/plugins/070-mpServer.server.ts similarity index 100% rename from packages/frontend-2/plugins/007-mpServer.server.ts rename to packages/frontend-2/plugins/070-mpServer.server.ts diff --git a/packages/frontend-2/plugins/008-mp.client.ts b/packages/frontend-2/plugins/080-mp.client.ts similarity index 100% rename from packages/frontend-2/plugins/008-mp.client.ts rename to packages/frontend-2/plugins/080-mp.client.ts diff --git a/packages/server/app.ts b/packages/server/app.ts index 665abebe3..9d50a75e1 100644 --- a/packages/server/app.ts +++ b/packages/server/app.ts @@ -196,11 +196,7 @@ export function buildApolloSubscriptionServer(params: { // for subscriptions) try { const headers = getHeaders({ connContext, connectionParams }) - const buildCtx = await buildContext({ - req: null, - token, - cleanLoadersEarly: false - }) + const buildCtx = await buildContext({ token }) buildCtx.log.info( { userId: buildCtx.userId, diff --git a/packages/server/modules/core/authz.ts b/packages/server/modules/core/authz/loaders/index.ts similarity index 56% rename from packages/server/modules/core/authz.ts rename to packages/server/modules/core/authz/loaders/index.ts index f417a345c..47a4118e0 100644 --- a/packages/server/modules/core/authz.ts +++ b/packages/server/modules/core/authz/loaders/index.ts @@ -1,27 +1,31 @@ +import { defineModuleLoaders } from '@/modules/loaders' import { getStreamFactory } from '@/modules/core/repositories/streams' -import { defineLoaders } from '@/modules/loaders' import { getFeatureFlags } from '@/modules/shared/helpers/envHelper' import { db } from '@/db/knex' import { getUserServerRoleFactory } from '@/modules/shared/repositories/acl' +import { err, ok } from 'true-myth/result' +import { Authz } from '@speckle/shared' -export const defineModuleLoaders = () => { +export default defineModuleLoaders(async () => { const getStream = getStreamFactory({ db }) const getUserServerRole = getUserServerRoleFactory({ db }) - defineLoaders({ - getEnv: getFeatureFlags, + return { + getEnv: async () => ok(getFeatureFlags()), getProject: async ({ projectId }) => { const project = await getStream({ streamId: projectId }) - if (!project) return null - return { ...project, projectId: project.id } + if (!project) return err(Authz.ProjectNotFoundError) + return ok({ ...project, projectId: project.id }) }, getProjectRole: async ({ userId, projectId }) => { const project = await getStream({ streamId: projectId, userId }) - return project?.role ?? null + if (!project?.role) return err(Authz.ProjectRoleNotFoundError) + return ok(project.role) }, getServerRole: async ({ userId }) => { const role = await getUserServerRole({ userId }) - return role ?? null + if (!role) return err(Authz.ServerRoleNotFoundError) + return ok(role) } - }) -} + } +}) diff --git a/packages/server/modules/core/graph/resolvers/projects.ts b/packages/server/modules/core/graph/resolvers/projects.ts index fbc0e231a..affc1c941 100644 --- a/packages/server/modules/core/graph/resolvers/projects.ts +++ b/packages/server/modules/core/graph/resolvers/projects.ts @@ -184,7 +184,7 @@ export = { userId: context.userId }) - if (!canQuery.authorized) { + if (!canQuery.isOk) { switch (canQuery.error.code) { case Authz.ProjectNotFoundError.code: throw new StreamNotFoundError() @@ -199,6 +199,7 @@ export = { const project = await getStream({ streamId: args.id }) + // TODO: Should scopes & token resource access rules be checked in authz policy? if (!project?.isPublic && !project?.isDiscoverable) { await validateScopes(context.scopes, Scopes.Streams.Read) } diff --git a/packages/server/modules/core/index.ts b/packages/server/modules/core/index.ts index 0eace522d..8f68f4620 100644 --- a/packages/server/modules/core/index.ts +++ b/packages/server/modules/core/index.ts @@ -23,7 +23,6 @@ import { reportSubscriptionEventsFactory } from '@/modules/core/events/subscript import { getEventBus } from '@/modules/shared/services/eventBus' import { publish } from '@/modules/shared/utils/subscriptions' import { getStreamCollaboratorsFactory } from '@/modules/core/repositories/streams' -import { defineModuleLoaders } from '@/modules/core/authz' let stopTestSubs: (() => void) | undefined = undefined @@ -88,8 +87,6 @@ const coreModule: SpeckleModule<{ getStreamCollaborators: getStreamCollaboratorsFactory({ db }) })() } - - defineModuleLoaders() }, async shutdown() { await shutdownResultListener() diff --git a/packages/server/modules/index.ts b/packages/server/modules/index.ts index 37d28b4d6..432b1270e 100644 --- a/packages/server/modules/index.ts +++ b/packages/server/modules/index.ts @@ -4,14 +4,14 @@ import fs from 'fs' import path from 'path' import { appRoot, packageRoot } from '@/bootstrap' -import { values, merge, camelCase, reduce, intersection } from 'lodash' +import { values, merge, camelCase, reduce, intersection, difference } from 'lodash' import baseTypeDefs from '@/modules/core/graph/schema/baseTypeDefs' import { scalarResolvers } from '@/modules/core/graph/scalars' import { makeExecutableSchema } from '@graphql-tools/schema' import { moduleLogger } from '@/observability/logging' import { addMocksToSchema } from '@graphql-tools/mock' import { getFeatureFlags } from '@/modules/shared/helpers/envHelper' -import { isNonNullable } from '@speckle/shared' +import { isNonNullable, Optional, Authz } from '@speckle/shared' import { SpeckleModule } from '@/modules/shared/helpers/typeHelper' import type { Express } from 'express' import { RequestDataLoadersBuilder } from '@/modules/shared/helpers/graphqlHelper' @@ -22,9 +22,14 @@ import { } from '@/modules/core/graph/helpers/directiveHelper' import { AppMocksConfig } from '@/modules/mocks' import { SpeckleModuleMocksConfig } from '@/modules/shared/helpers/mocks' -import { LogicError } from '@/modules/shared/errors' +import { LoaderConfigurationError, LogicError } from '@/modules/shared/errors' import type { Registry } from 'prom-client' -import { validateLoaders } from '@/modules/loaders' +import type { defineModuleLoaders } from '@/modules/loaders' +import { + inMemoryCacheProviderFactory, + wrapWithCache +} from '@/modules/shared/utils/caching' +import TTLCache from '@isaacs/ttlcache' /** * Cached speckle module requires @@ -128,7 +133,8 @@ export const init = async (params: { app: Express; metricsRegister: Registry }) await module.finalize?.({ app, isInitial, metricsRegister }) } - validateLoaders() + // Validate & cache authz loaders + await moduleAuthLoaders() hasInitializationOccurred = true } @@ -148,10 +154,12 @@ export const shutdown = async () => { */ export const graphDataloadersBuilders = (): RequestDataLoadersBuilder[] => { let dataLoaders: RequestDataLoadersBuilder[] = [] + const enabledModuleNames = getEnabledModuleNames() // load code modules from /modules const codeModuleDirs = fs.readdirSync(`${appRoot}/modules`) codeModuleDirs.forEach((file) => { + if (!enabledModuleNames.includes(file)) return const fullPath = path.join(`${appRoot}/modules`, file) // load dataloaders @@ -169,13 +177,15 @@ export const graphDataloadersBuilders = (): RequestDataLoadersBuilder[] => } /** - * GQL components will be loaded even from disabled modules to avoid schema complexity, so ensure - * that resolvers return valid values even if the module is disabled + * GQL components - typedefs, resolvers, directives + * (assets & directives will be loaded from even disabled components cause the schema must be static) */ const graphComponents = (): Pick, 'resolvers'> & { directiveBuilders: Record typeDefs: string[] } => { + const enabledModuleNames = getEnabledModuleNames() + // Base query and mutation to allow for type extension by modules. const typeDefs = [baseTypeDefs] @@ -197,11 +207,12 @@ const graphComponents = (): Pick, 'resolvers'> & { // load code modules from /modules const codeModuleDirs = fs.readdirSync(`${appRoot}/modules`) codeModuleDirs.forEach((file) => { + const isEnabledModule = enabledModuleNames.includes(file) const fullPath = path.join(`${appRoot}/modules`, file) // first pass load of resolvers const resolversPath = path.join(fullPath, 'graph', 'resolvers') - if (fs.existsSync(resolversPath)) { + if (isEnabledModule && fs.existsSync(resolversPath)) { const newResolverObjs = values(autoloadFromDirectory(resolversPath)).map((o) => 'default' in o ? o.default : o ) @@ -304,3 +315,72 @@ export const moduleMockConfigs = ( return mockConfigs } + +export const moduleAuthLoaders = async () => { + const enabledModuleNames = getEnabledModuleNames() + + let loaders: Partial = {} + + // load auth loaders from /modules and in same order as the whitelist + const codeModuleDirs = fs.readdirSync(`${appRoot}/modules`) + const coreModuleDirsOrdered = intersection(enabledModuleNames, codeModuleDirs) + for (const moduleName of coreModuleDirsOrdered) { + const fullModulePath = path.join(`${appRoot}/modules`, moduleName) + const loadersFolderPath = path.join(fullModulePath, 'authz', 'loaders') + if (!fs.existsSync(loadersFolderPath)) continue + + // We only take the first loaders.ts file we find (for now) + const moduleLoadersBuilderFn = values(autoloadFromDirectory(loadersFolderPath)) + .map((l) => l.default) + .filter(isNonNullable)[0] as Optional> + + loaders = { + ...loaders, + ...(await moduleLoadersBuilderFn?.()) + } + } + + // validate that all were loaded + const notFoundKeys = difference( + Object.values(Authz.AuthCheckContextLoaderKeys), + Object.keys(loaders) + ) + if (notFoundKeys.length) { + throw new LoaderConfigurationError( + `Missing authz loaders found: ${notFoundKeys.join(', ')}` + ) + } + + const allLoaders = loaders as Authz.AuthCheckContextLoaders + + /** + * Add inmemory caching to all loaders. Since the loaders & their caches are scoped to each request and these checks + * occur before any mutations, we can safely cache them in memory with a long ttl. + * + * In edge cases - the caches can be cleared + */ + const cache = new TTLCache() + const loadersWithCache: Authz.AuthCheckContextLoaders = Object.entries( + allLoaders + ).reduce((acc, entry) => { + const key = entry[0] as Authz.AuthCheckContextLoaderKeys + const loader = entry[1] as Authz.AllAuthCheckContextLoaders[typeof key] + + const newLoader = wrapWithCache({ + resolver: loader, + name: `authzLoader:${key}`, + // since its the inmemory cache, we dont have to worry about true-myth results being + // serialized and deserialized as they would be with redis + cacheProvider: inMemoryCacheProviderFactory({ cache }), + ttlMs: 1000 * 60 * 60 // 1 hour (longer than any req will be) + }) + acc[key] = newLoader + + return acc + }, {} as Authz.AuthCheckContextLoaders) + + return { + loaders: loadersWithCache, + clearCache: () => cache.clear() + } +} diff --git a/packages/server/modules/loaders.ts b/packages/server/modules/loaders.ts index 4390f416a..a4705670d 100644 --- a/packages/server/modules/loaders.ts +++ b/packages/server/modules/loaders.ts @@ -1,51 +1,8 @@ -import { LoaderConfigurationError } from '@/modules/shared/errors' -import { Authz } from '@speckle/shared' +import { Authz, MaybeAsync } from '@speckle/shared' -let cachedLoaders: Partial = {} - -const loaderKeys: (keyof Authz.AuthCheckContextLoaders)[] = [ - 'getEnv', - 'getProject', - 'getProjectRole', - 'getServerRole', - 'getWorkspace', - 'getWorkspaceRole', - 'getWorkspaceSsoProvider', - 'getWorkspaceSsoSession' -] - -export const defineLoaders = ( - loaders: Partial -): void => { - for (const key of Object.keys(loaders)) { - if (!loaderKeys.includes(key as keyof Authz.AuthCheckContextLoaders)) { - throw new LoaderConfigurationError( - `Attempted to define loader with unknown key: ${key}` - ) - } - } - - cachedLoaders = { - ...cachedLoaders, - ...loaders - } -} - -const isValidLoaders = ( - loaders: Partial -): loaders is Authz.AuthCheckContextLoaders => { - return loaderKeys.every((key) => !!loaders[key]) -} - -export const validateLoaders = () => { - if (!isValidLoaders(cachedLoaders)) { - throw new LoaderConfigurationError() - } -} - -export const getLoaders = (): Authz.AuthCheckContextLoaders => { - if (!isValidLoaders(cachedLoaders)) { - throw new LoaderConfigurationError('Attempted to reference invalid loaders.') - } - return cachedLoaders +// define being an arg simplifes usage in export default calls +export const defineModuleLoaders = ( + define: () => MaybeAsync> +) => { + return async () => await define() } diff --git a/packages/server/modules/shared/helpers/typeHelper.ts b/packages/server/modules/shared/helpers/typeHelper.ts index 27b0a180e..1e88a41f8 100644 --- a/packages/server/modules/shared/helpers/typeHelper.ts +++ b/packages/server/modules/shared/helpers/typeHelper.ts @@ -53,7 +53,7 @@ export type SpeckleModule = Record void } /** * Request-scoped GraphQL dataloaders * @see https://github.com/graphql/dataloader diff --git a/packages/server/modules/shared/middleware/index.ts b/packages/server/modules/shared/middleware/index.ts index 85a3e0a12..1bd635649 100644 --- a/packages/server/modules/shared/middleware/index.ts +++ b/packages/server/modules/shared/middleware/index.ts @@ -24,13 +24,11 @@ import { MaybeNullOrUndefined, Nullable } from '@/modules/shared/helpers/typeHelper' -import { Authz, Optional, wait } from '@speckle/shared' +import { Authz, wait } from '@speckle/shared' import { mixpanel } from '@/modules/shared/utils/mixpanel' import * as Observability from '@speckle/shared/dist/commonjs/observability/index.js' -import { pino } from 'pino' import { getIpFromRequest } from '@/modules/shared/utils/ip' import { Netmask } from 'netmask' -import { Merge } from 'type-fest' import { resourceAccessRuleToIdentifier } from '@/modules/core/helpers/token' import { delayGraphqlResponsesBy } from '@/modules/shared/helpers/envHelper' import { subscriptionLogger } from '@/observability/logging' @@ -48,7 +46,7 @@ import { getTokenAppInfoFactory } from '@/modules/auth/repositories/apps' import { getUserRoleFactory } from '@/modules/core/repositories/users' import { UserInputError } from '@/modules/core/errors/userinput' import compression from 'compression' -import { getLoaders } from '@/modules/loaders' +import { moduleAuthLoaders } from '@/modules' export const authMiddlewareCreator = ( steps: AuthPipelineFunction[] @@ -170,28 +168,17 @@ export const authContextMiddleware: RequestHandler = async (req, res, next) => { next() } -export async function addLoadersToCtx( - ctx: Merge, { log?: Optional }>, - options?: Partial<{ cleanLoadersEarly: boolean }> -): Promise { - const log = - ctx.log || Observability.extendLoggerComponent(Observability.getLogger(), 'graphql') - const loaders = await buildRequestLoaders(ctx, options) - return { ...ctx, loaders, log } -} - /** * Build context for GQL operations */ -export async function buildContext({ - req, - token, - cleanLoadersEarly -}: { - req: MaybeNullOrUndefined +export async function buildContext(params?: { + req?: MaybeNullOrUndefined token?: Nullable + authContext?: AuthContext cleanLoadersEarly?: boolean }): Promise { + const { req, token, authContext, cleanLoadersEarly } = params || {} + const validateToken = validateTokenFactory({ revokeUserTokenById: revokeUserTokenByIdFactory({ db }), getApiTokenById: getApiTokenByIdFactory({ db }), @@ -207,6 +194,7 @@ export async function buildContext({ }) const ctx = + authContext || req?.context || (await createAuthContextFromToken(token ?? getTokenFromRequest(req), validateToken)) @@ -221,17 +209,23 @@ export async function buildContext({ await wait(delay) } - const authPolicies = Authz.authPoliciesFactory(getLoaders()) + const [authLoaders, dataLoaders] = await Promise.all([ + moduleAuthLoaders(), + buildRequestLoaders(ctx, { cleanLoadersEarly }) + ]) + const authPolicies = Authz.authPoliciesFactory(authLoaders.loaders) - // Adding request data loaders - return await addLoadersToCtx( - { - ...ctx, - log, - authPolicies - }, - { cleanLoadersEarly } - ) + return { + ...ctx, + loaders: dataLoaders, + log, + authPolicies: { + ...authPolicies, + clearCache: () => { + authLoaders.clearCache() + } + } + } } /** diff --git a/packages/server/modules/workspaces/authz.ts b/packages/server/modules/workspaces/authz/loaders/index.ts similarity index 51% rename from packages/server/modules/workspaces/authz.ts rename to packages/server/modules/workspaces/authz/loaders/index.ts index e5640f0e7..d2916962f 100644 --- a/packages/server/modules/workspaces/authz.ts +++ b/packages/server/modules/workspaces/authz/loaders/index.ts @@ -1,5 +1,5 @@ import { db } from '@/db/knex' -import { defineLoaders } from '@/modules/loaders' +import { defineModuleLoaders } from '@/modules/loaders' import { getUserSsoSessionFactory, getWorkspaceSsoProviderRecordFactory @@ -8,29 +8,39 @@ import { getWorkspaceFactory, getWorkspaceRoleForUserFactory } from '@/modules/workspaces/repositories/workspaces' +import { Authz } from '@speckle/shared' +import { err, ok } from 'true-myth/result' -export const defineModuleLoaders = () => { - defineLoaders({ - getWorkspace: getWorkspaceFactory({ db }), +export default defineModuleLoaders(async () => { + const getWorkspace = getWorkspaceFactory({ db }) + return { + getWorkspace: async ({ workspaceId }) => { + const workspace = await getWorkspace({ workspaceId }) + if (!workspace) return err(Authz.WorkspaceNotFoundError) + return ok(workspace) + }, getWorkspaceRole: async ({ userId, workspaceId }) => { const role = await getWorkspaceRoleForUserFactory({ db })({ userId, workspaceId }) - return role?.role ?? null + if (!role) return err(Authz.WorkspaceRoleNotFoundError) + return ok(role.role) }, getWorkspaceSsoSession: async ({ userId, workspaceId }) => { const ssoSession = await getUserSsoSessionFactory({ db })({ userId, workspaceId }) - return ssoSession ?? null + if (!ssoSession) return err(Authz.WorkspaceSsoSessionNotFoundError) + return ok(ssoSession) }, getWorkspaceSsoProvider: async ({ workspaceId }) => { const ssoProvider = await getWorkspaceSsoProviderRecordFactory({ db })({ workspaceId }) - return ssoProvider ?? null + if (!ssoProvider) return err(Authz.WorkspaceSsoProviderNotFoundError) + return ok(ssoProvider) } - }) -} + } +}) diff --git a/packages/server/modules/workspaces/index.ts b/packages/server/modules/workspaces/index.ts index ea8b69fcb..05cc72fb2 100644 --- a/packages/server/modules/workspaces/index.ts +++ b/packages/server/modules/workspaces/index.ts @@ -10,7 +10,6 @@ import { initializeEventListenersFactory } from '@/modules/workspaces/events/eve import { validateModuleLicense } from '@/modules/gatekeeper/services/validateLicense' import { getSsoRouter } from '@/modules/workspaces/rest/sso' import { InvalidLicenseError } from '@/modules/gatekeeper/errors/license' -import { defineModuleLoaders } from '@/modules/workspaces/authz' const { FF_WORKSPACES_MODULE_ENABLED, FF_WORKSPACES_SSO_ENABLED } = getFeatureFlags() @@ -45,8 +44,6 @@ const workspacesModule: SpeckleModule = { quitListeners = initializeEventListenersFactory({ db })() } await Promise.all([initScopes(), initRoles()]) - - defineModuleLoaders() }, shutdown() { if (!FF_WORKSPACES_MODULE_ENABLED) return diff --git a/packages/server/modules/workspacesCore/authz.ts b/packages/server/modules/workspacesCore/authz.ts deleted file mode 100644 index de435a3d0..000000000 --- a/packages/server/modules/workspacesCore/authz.ts +++ /dev/null @@ -1,19 +0,0 @@ -import { defineLoaders } from '@/modules/loaders' -import { LoaderUnsupportedError } from '@/modules/shared/errors' - -export const defineModuleLoaders = () => { - defineLoaders({ - getWorkspace: async () => { - throw new LoaderUnsupportedError() - }, - getWorkspaceRole: async () => { - throw new LoaderUnsupportedError() - }, - getWorkspaceSsoSession: async () => { - throw new LoaderUnsupportedError() - }, - getWorkspaceSsoProvider: async () => { - throw new LoaderUnsupportedError() - } - }) -} diff --git a/packages/server/modules/workspacesCore/authz/loaders/index.ts b/packages/server/modules/workspacesCore/authz/loaders/index.ts new file mode 100644 index 000000000..81af81928 --- /dev/null +++ b/packages/server/modules/workspacesCore/authz/loaders/index.ts @@ -0,0 +1,17 @@ +import { defineModuleLoaders } from '@/modules/loaders' +import { LoaderUnsupportedError } from '@/modules/shared/errors' + +export default defineModuleLoaders(() => ({ + getWorkspace: async () => { + throw new LoaderUnsupportedError() + }, + getWorkspaceRole: async () => { + throw new LoaderUnsupportedError() + }, + getWorkspaceSsoSession: async () => { + throw new LoaderUnsupportedError() + }, + getWorkspaceSsoProvider: async () => { + throw new LoaderUnsupportedError() + } +})) diff --git a/packages/server/modules/workspacesCore/index.ts b/packages/server/modules/workspacesCore/index.ts index 7f55f6749..e1ddc7ded 100644 --- a/packages/server/modules/workspacesCore/index.ts +++ b/packages/server/modules/workspacesCore/index.ts @@ -1,8 +1,6 @@ import { SpeckleModule } from '@/modules/shared/helpers/typeHelper' -import { defineModuleLoaders } from '@/modules/workspacesCore/authz' import { moduleLogger } from '@/observability/logging' export const init: SpeckleModule['init'] = () => { moduleLogger.info('⚒️ Init workspaces core module') - defineModuleLoaders() } diff --git a/packages/server/package.json b/packages/server/package.json index d4d2f9e2c..56a879fa0 100644 --- a/packages/server/package.json +++ b/packages/server/package.json @@ -124,6 +124,7 @@ "string-pixel-width": "^1.10.0", "stripe": "^17.1.0", "subscriptions-transport-ws": "^0.11.0", + "true-myth": "^8.5.0", "ua-parser-js": "^1.0.38", "undici": "^5.28.4", "verror": "^1.10.1", diff --git a/packages/server/test/graphqlHelper.ts b/packages/server/test/graphqlHelper.ts index 4183a4e97..49e449df7 100644 --- a/packages/server/test/graphqlHelper.ts +++ b/packages/server/test/graphqlHelper.ts @@ -3,11 +3,10 @@ import { DocumentNode, FormattedExecutionResult } from 'graphql' import { GraphQLContext } from '@/modules/shared/helpers/typeHelper' import { TypedDocumentNode } from '@graphql-typed-document-node/core' import { buildApolloServer, buildApolloSubscriptionServer } from '@/app' -import { addLoadersToCtx } from '@/modules/shared/middleware' +import { buildContext } from '@/modules/shared/middleware' import { Roles } from '@/modules/core/helpers/mainConstants' import { AllScopes, - Authz, buildManualPromise, ensureError, MaybeAsync, @@ -34,7 +33,6 @@ import { PingPongDocument } from '@/test/graphql/generated/graphql' import { BaseError } from '@/modules/shared/errors' import EventEmitter from 'eventemitter2' import { expectToThrow } from '@/test/assertionHelper' -import { getLoaders } from '@/modules/loaders' type TypedGraphqlResponse> = GraphQLResponse @@ -110,30 +108,32 @@ export async function executeOperation< export const createTestContext = async ( ctx?: Partial ): Promise => - addLoadersToCtx({ - auth: false, - userId: undefined, - role: undefined, - token: undefined, - scopes: [], - stream: undefined, - err: undefined, - authPolicies: Authz.authPoliciesFactory(getLoaders()), - ...(ctx || {}) + await buildContext({ + authContext: { + auth: false, + userId: undefined, + role: undefined, + token: undefined, + scopes: [], + stream: undefined, + err: undefined, + ...(ctx || {}) + } }) export const createAuthedTestContext = async ( userId: string, ctxOverrides?: Partial ): Promise => - addLoadersToCtx({ - auth: true, - userId, - role: Roles.Server.User, - token: 'asd', - scopes: AllScopes, - authPolicies: Authz.authPoliciesFactory(getLoaders()), - ...(ctxOverrides || {}) + await buildContext({ + authContext: { + auth: true, + userId, + role: Roles.Server.User, + token: 'asd', + scopes: AllScopes, + ...(ctxOverrides || {}) + } }) const buildMergedContext = async (params: { diff --git a/packages/server/tsconfig.json b/packages/server/tsconfig.json index 71b6b0644..67bba2e20 100644 --- a/packages/server/tsconfig.json +++ b/packages/server/tsconfig.json @@ -26,7 +26,7 @@ /* Modules */ "module": "commonjs" /* Specify what module code is generated. */, // "rootDir": "./", /* Specify the root folder within your source files. */ - // "moduleResolution": "node", /* Specify how TypeScript looks up a file from a given module specifier. */ + "moduleResolution": "node" /* Specify how TypeScript looks up a file from a given module specifier. */, "baseUrl": "./" /* Specify the base directory to resolve non-relative module names. */, "paths": { "@/*": ["./*"], diff --git a/packages/server/type-augmentations/true-myth.d.ts b/packages/server/type-augmentations/true-myth.d.ts new file mode 100644 index 000000000..688250b38 --- /dev/null +++ b/packages/server/type-augmentations/true-myth.d.ts @@ -0,0 +1,10 @@ +// Only need to do this because our CommonJS app does not support true-myth's export maps +declare module 'true-myth/result' { + import { Result } from 'true-myth/dist/cjs/result.cjs' + export * from 'true-myth/dist/cjs/result.cjs' + export default Result +} + +declare module 'true-myth' { + export * from 'true-myth/dist/cjs/index.cjs' +} diff --git a/packages/shared/eslint.config.mjs b/packages/shared/eslint.config.mjs index 77a8f94b7..813c816f0 100644 --- a/packages/shared/eslint.config.mjs +++ b/packages/shared/eslint.config.mjs @@ -36,6 +36,12 @@ const configs = [ rules: { '@typescript-eslint/no-explicit-any': 'off' } + }, + { + files: ['**/*.spec.ts'], + rules: { + '@typescript-eslint/require-await': 'off' // so we can easily make sync mocked loaders -> async + } } ] diff --git a/packages/shared/package.json b/packages/shared/package.json index 2cbb571df..aeb01b3aa 100644 --- a/packages/shared/package.json +++ b/packages/shared/package.json @@ -38,6 +38,7 @@ "dependencies": { "lodash": "^4.17.21", "lodash-es": "^4.17.21", + "true-myth": "^8.5.0", "type-fest": "^3.11.1" }, "peerDependencies": { diff --git a/packages/shared/src/authz/checks/projects.spec.ts b/packages/shared/src/authz/checks/projects.spec.ts index efe25a678..a4718b5ac 100644 --- a/packages/shared/src/authz/checks/projects.spec.ts +++ b/packages/shared/src/authz/checks/projects.spec.ts @@ -6,13 +6,15 @@ import { import cryptoRandomString from 'crypto-random-string' import { Project } from '../domain/projects/types.js' import { Roles, UncoveredError } from '../../core/index.js' +import { err, ok } from 'true-myth/result' +import { ProjectNotFoundError, ProjectRoleNotFoundError } from '../domain/authErrors.js' describe('project checks', () => { describe('requireExactProjectVisibilityFactory returns a function, that', () => { it('throws if project does not exist', async () => { const requireExactProjectVisibility = requireExactProjectVisibilityFactory({ loaders: { - getProject: () => Promise.resolve(null) + getProject: () => Promise.resolve(err(ProjectNotFoundError)) } }) await expect( @@ -26,9 +28,11 @@ describe('project checks', () => { const result = await requireExactProjectVisibilityFactory({ loaders: { getProject: () => - Promise.resolve({ - isDiscoverable: true - } as Project) + Promise.resolve( + ok({ + isDiscoverable: true + } as Project) + ) } })({ projectVisibility: 'linkShareable', @@ -40,9 +44,11 @@ describe('project checks', () => { const result = await requireExactProjectVisibilityFactory({ loaders: { getProject: () => - Promise.resolve({ - isPublic: true - } as Project) + Promise.resolve( + ok({ + isPublic: true + } as Project) + ) } })({ projectVisibility: 'public', @@ -54,10 +60,12 @@ describe('project checks', () => { const result = await requireExactProjectVisibilityFactory({ loaders: { getProject: () => - Promise.resolve({ - isDiscoverable: false, - isPublic: false - } as Project) + Promise.resolve( + ok({ + isDiscoverable: false, + isPublic: false + } as Project) + ) } })({ projectVisibility: 'private', @@ -70,10 +78,12 @@ describe('project checks', () => { requireExactProjectVisibilityFactory({ loaders: { getProject: () => - Promise.resolve({ - isDiscoverable: false, - isPublic: false - } as Project) + Promise.resolve( + ok({ + isDiscoverable: false, + isPublic: false + } as Project) + ) } })({ // @ts-expect-error this is what im testing here @@ -87,7 +97,7 @@ describe('project checks', () => { it('returns false, if there is no role for the user', async () => { const result = await requireMinimumProjectRoleFactory({ loaders: { - getProjectRole: () => Promise.resolve(null) + getProjectRole: () => Promise.resolve(err(ProjectRoleNotFoundError)) } })({ projectId: cryptoRandomString({ length: 10 }), @@ -99,7 +109,7 @@ describe('project checks', () => { it('returns false, if the role is not sufficient', async () => { const result = await requireMinimumProjectRoleFactory({ loaders: { - getProjectRole: () => Promise.resolve(Roles.Stream.Reviewer) + getProjectRole: () => Promise.resolve(ok(Roles.Stream.Reviewer)) } })({ projectId: cryptoRandomString({ length: 10 }), @@ -111,7 +121,7 @@ describe('project checks', () => { it('returns true, if the role is sufficient', async () => { const result = await requireMinimumProjectRoleFactory({ loaders: { - getProjectRole: () => Promise.resolve(Roles.Stream.Contributor) + getProjectRole: () => Promise.resolve(ok(Roles.Stream.Contributor)) } })({ projectId: cryptoRandomString({ length: 10 }), diff --git a/packages/shared/src/authz/checks/projects.ts b/packages/shared/src/authz/checks/projects.ts index 879e95a2d..79bd65d1e 100644 --- a/packages/shared/src/authz/checks/projects.ts +++ b/packages/shared/src/authz/checks/projects.ts @@ -1,11 +1,11 @@ import { StreamRoles, throwUncoveredError } from '../../core/index.js' import { ProjectNotFoundError } from '../domain/errors.js' -import { AuthCheckContext } from '../domain/loaders.js' +import { AuthCheckContext, AuthCheckContextLoaderKeys } from '../domain/loaders.js' import { isMinimumProjectRole } from '../domain/projects/logic.js' import { ProjectVisibility } from '../domain/projects/types.js' export const requireExactProjectVisibilityFactory = - ({ loaders }: AuthCheckContext<'getProject'>) => + ({ loaders }: AuthCheckContext) => async (args: { projectVisibility: ProjectVisibility projectId: string @@ -13,22 +13,22 @@ export const requireExactProjectVisibilityFactory = const { projectId, projectVisibility } = args const project = await loaders.getProject({ projectId }) - if (!project) throw new ProjectNotFoundError({ projectId }) + if (!project.isOk) throw new ProjectNotFoundError({ projectId }) switch (projectVisibility) { case 'linkShareable': - return project.isDiscoverable === true + return project.value.isDiscoverable === true case 'public': - return project.isPublic === true + return project.value.isPublic === true case 'private': - return project.isPublic !== true && project.isDiscoverable !== true + return project.value.isPublic !== true && project.value.isDiscoverable !== true default: throwUncoveredError(projectVisibility) } } export const requireMinimumProjectRoleFactory = - ({ loaders }: AuthCheckContext<'getProjectRole'>) => + ({ loaders }: AuthCheckContext) => async (args: { userId: string projectId: string @@ -37,7 +37,7 @@ export const requireMinimumProjectRoleFactory = const { userId, projectId, role: requiredProjectRole } = args const userProjectRole = await loaders.getProjectRole({ userId, projectId }) - return userProjectRole - ? isMinimumProjectRole(userProjectRole, requiredProjectRole) + return userProjectRole.isOk + ? isMinimumProjectRole(userProjectRole.value, requiredProjectRole) : false } diff --git a/packages/shared/src/authz/checks/serverRole.spec.ts b/packages/shared/src/authz/checks/serverRole.spec.ts index d4fed5699..db13883f3 100644 --- a/packages/shared/src/authz/checks/serverRole.spec.ts +++ b/packages/shared/src/authz/checks/serverRole.spec.ts @@ -1,12 +1,14 @@ import { describe, expect, it } from 'vitest' import { requireExactServerRole } from './serverRole.js' import cryptoRandomString from 'crypto-random-string' +import { err, ok } from 'true-myth/result' +import { ServerRoleNotFoundError } from '../domain/authErrors.js' describe('requireExactServerRole returns a function, that', () => { it('returns false for mismatch roles', async () => { const result = await requireExactServerRole({ loaders: { - getServerRole: () => Promise.resolve('server:user') + getServerRole: () => Promise.resolve(ok('server:user')) } })({ userId: cryptoRandomString({ length: 9 }), @@ -17,7 +19,7 @@ describe('requireExactServerRole returns a function, that', () => { it('returns false for users without roles', async () => { const result = await requireExactServerRole({ loaders: { - getServerRole: () => Promise.resolve(null) + getServerRole: () => Promise.resolve(err(ServerRoleNotFoundError)) } })({ userId: cryptoRandomString({ length: 9 }), @@ -28,7 +30,7 @@ describe('requireExactServerRole returns a function, that', () => { it('returns true for matching roles', async () => { const result = await requireExactServerRole({ loaders: { - getServerRole: () => Promise.resolve('server:admin') + getServerRole: () => Promise.resolve(ok('server:admin')) } })({ userId: cryptoRandomString({ length: 9 }), diff --git a/packages/shared/src/authz/checks/serverRole.ts b/packages/shared/src/authz/checks/serverRole.ts index 06236917d..11d8035a8 100644 --- a/packages/shared/src/authz/checks/serverRole.ts +++ b/packages/shared/src/authz/checks/serverRole.ts @@ -1,12 +1,13 @@ import { ServerRoles } from '../../core/constants.js' -import { AuthCheckContext } from '../domain/loaders.js' +import { AuthCheckContext, AuthCheckContextLoaderKeys } from '../domain/loaders.js' export const requireExactServerRole = - ({ loaders }: AuthCheckContext<'getServerRole'>) => + ({ loaders }: AuthCheckContext) => async (args: { userId: string; role: ServerRoles }): Promise => { const { userId, role: requiredServerRole } = args const userServerRole = await loaders.getServerRole({ userId }) + if (!userServerRole.isOk) return false - return userServerRole === requiredServerRole + return userServerRole.value === requiredServerRole } diff --git a/packages/shared/src/authz/checks/workspaceRole.spec.ts b/packages/shared/src/authz/checks/workspaceRole.spec.ts index 9d1a0e1a5..57028b7ce 100644 --- a/packages/shared/src/authz/checks/workspaceRole.spec.ts +++ b/packages/shared/src/authz/checks/workspaceRole.spec.ts @@ -4,12 +4,14 @@ import { requireMinimumWorkspaceRole } from './workspaceRole.js' import cryptoRandomString from 'crypto-random-string' +import { err, ok } from 'true-myth/result' +import { WorkspaceRoleNotFoundError } from '../domain/authErrors.js' describe('requireAnyWorkspaceRole returns a function, that', () => { it('returns false if the user has no role', async () => { const result = await requireAnyWorkspaceRole({ loaders: { - getWorkspaceRole: () => Promise.resolve(null) + getWorkspaceRole: () => Promise.resolve(err(WorkspaceRoleNotFoundError)) } })({ userId: cryptoRandomString({ length: 9 }), @@ -20,7 +22,7 @@ describe('requireAnyWorkspaceRole returns a function, that', () => { it('returns true if the user has a role', async () => { const result = await requireAnyWorkspaceRole({ loaders: { - getWorkspaceRole: () => Promise.resolve('workspace:member') + getWorkspaceRole: () => Promise.resolve(ok('workspace:member')) } })({ userId: cryptoRandomString({ length: 9 }), @@ -34,7 +36,7 @@ describe('requireMinimumWorkspaceRole returns a function, that', () => { it('returns false if user does not have a role', async () => { const result = await requireMinimumWorkspaceRole({ loaders: { - getWorkspaceRole: () => Promise.resolve(null) + getWorkspaceRole: () => Promise.resolve(err(WorkspaceRoleNotFoundError)) } })({ userId: cryptoRandomString({ length: 9 }), @@ -46,7 +48,7 @@ describe('requireMinimumWorkspaceRole returns a function, that', () => { it('returns false if user is below target role', async () => { const result = await requireMinimumWorkspaceRole({ loaders: { - getWorkspaceRole: () => Promise.resolve('workspace:member') + getWorkspaceRole: () => Promise.resolve(ok('workspace:member')) } })({ userId: cryptoRandomString({ length: 9 }), @@ -58,7 +60,7 @@ describe('requireMinimumWorkspaceRole returns a function, that', () => { it('returns true if user matches target role', async () => { const result = await requireMinimumWorkspaceRole({ loaders: { - getWorkspaceRole: () => Promise.resolve('workspace:member') + getWorkspaceRole: () => Promise.resolve(ok('workspace:member')) } })({ userId: cryptoRandomString({ length: 9 }), @@ -70,7 +72,7 @@ describe('requireMinimumWorkspaceRole returns a function, that', () => { it('returns true if user exceeds target role', async () => { const result = await requireMinimumWorkspaceRole({ loaders: { - getWorkspaceRole: () => Promise.resolve('workspace:admin') + getWorkspaceRole: () => Promise.resolve(ok('workspace:admin')) } })({ userId: cryptoRandomString({ length: 9 }), diff --git a/packages/shared/src/authz/checks/workspaceRole.ts b/packages/shared/src/authz/checks/workspaceRole.ts index 770311853..8014f3fa3 100644 --- a/packages/shared/src/authz/checks/workspaceRole.ts +++ b/packages/shared/src/authz/checks/workspaceRole.ts @@ -1,19 +1,18 @@ import { WorkspaceRoles } from '../../core/constants.js' -import { AuthCheckContext } from '../domain/loaders.js' +import { AuthCheckContext, AuthCheckContextLoaderKeys } from '../domain/loaders.js' import { isMinimumWorkspaceRole } from '../domain/workspaces/logic.js' export const requireAnyWorkspaceRole = - ({ loaders }: AuthCheckContext<'getWorkspaceRole'>) => + ({ loaders }: AuthCheckContext) => async (args: { userId: string; workspaceId: string }): Promise => { const { userId, workspaceId } = args const userWorkspaceRole = await loaders.getWorkspaceRole({ userId, workspaceId }) - - return userWorkspaceRole !== null + return userWorkspaceRole.isOk } export const requireMinimumWorkspaceRole = - ({ loaders }: AuthCheckContext<'getWorkspaceRole'>) => + ({ loaders }: AuthCheckContext) => async (args: { userId: string workspaceId: string @@ -23,7 +22,7 @@ export const requireMinimumWorkspaceRole = const userWorkspaceRole = await loaders.getWorkspaceRole({ userId, workspaceId }) - return userWorkspaceRole - ? isMinimumWorkspaceRole(userWorkspaceRole, requiredWorkspaceRole) + return userWorkspaceRole.isOk + ? isMinimumWorkspaceRole(userWorkspaceRole.value, requiredWorkspaceRole) : false } diff --git a/packages/shared/src/authz/checks/workspaceSso.spec.ts b/packages/shared/src/authz/checks/workspaceSso.spec.ts index e07d9d075..e65230d2e 100644 --- a/packages/shared/src/authz/checks/workspaceSso.spec.ts +++ b/packages/shared/src/authz/checks/workspaceSso.spec.ts @@ -1,12 +1,15 @@ import { describe, expect, it } from 'vitest' import { requireValidWorkspaceSsoSession } from './workspaceSso.js' import cryptoRandomString from 'crypto-random-string' +import { err, ok } from 'true-myth/result' +import { WorkspaceSsoSessionNotFoundError } from '../domain/authErrors.js' describe('requireValidWorkspaceSsoSession returns a function, that', () => { it('returns false if user does not have an SSO session', async () => { const result = await requireValidWorkspaceSsoSession({ loaders: { - getWorkspaceSsoSession: () => Promise.resolve(null) + getWorkspaceSsoSession: () => + Promise.resolve(err(WorkspaceSsoSessionNotFoundError)) } })({ userId: cryptoRandomString({ length: 9 }), @@ -25,11 +28,13 @@ describe('requireValidWorkspaceSsoSession returns a function, that', () => { const result = await requireValidWorkspaceSsoSession({ loaders: { getWorkspaceSsoSession: () => - Promise.resolve({ - userId, - providerId, - validUntil - }) + Promise.resolve( + ok({ + userId, + providerId, + validUntil + }) + ) } })({ userId, @@ -48,11 +53,13 @@ describe('requireValidWorkspaceSsoSession returns a function, that', () => { const result = await requireValidWorkspaceSsoSession({ loaders: { getWorkspaceSsoSession: () => - Promise.resolve({ - userId, - providerId, - validUntil - }) + Promise.resolve( + ok({ + userId, + providerId, + validUntil + }) + ) } })({ userId, diff --git a/packages/shared/src/authz/checks/workspaceSso.ts b/packages/shared/src/authz/checks/workspaceSso.ts index 4ed74274c..d3488dbfe 100644 --- a/packages/shared/src/authz/checks/workspaceSso.ts +++ b/packages/shared/src/authz/checks/workspaceSso.ts @@ -1,7 +1,9 @@ -import { AuthCheckContext } from '../domain/loaders.js' +import { AuthCheckContext, AuthCheckContextLoaderKeys } from '../domain/loaders.js' export const requireValidWorkspaceSsoSession = - ({ loaders }: AuthCheckContext<'getWorkspaceSsoSession'>) => + ({ + loaders + }: AuthCheckContext) => async (args: { userId: string; workspaceId: string }): Promise => { const { userId, workspaceId } = args @@ -9,9 +11,10 @@ export const requireValidWorkspaceSsoSession = userId, workspaceId }) + if (!workspaceSsoSession.isOk) return false const isExpiredSession = - new Date().getTime() > (workspaceSsoSession?.validUntil?.getTime() ?? 0) + new Date().getTime() > workspaceSsoSession.value.validUntil.getTime() - return !!workspaceSsoSession && !isExpiredSession + return !isExpiredSession } diff --git a/packages/shared/src/authz/domain/authErrors.ts b/packages/shared/src/authz/domain/authErrors.ts index 155ffb547..24fa03341 100644 --- a/packages/shared/src/authz/domain/authErrors.ts +++ b/packages/shared/src/authz/domain/authErrors.ts @@ -1,4 +1,4 @@ -type AuthError = { +export type AuthError = { code: ErrorCode message: string } @@ -25,12 +25,42 @@ export const ProjectNoAccessError = defineAuthError({ message: 'You do not have access to the project' }) +export const ProjectRoleNotFoundError = defineAuthError({ + code: 'ProjectRoleNotFound', + message: 'Could not resolve your project role' +}) + +export const WorkspaceNotFoundError = defineAuthError({ + code: 'WorkspaceNotFound', + message: 'Workspace not found' +}) + export const WorkspaceNoAccessError = defineAuthError({ code: 'WorkspaceNoAccess', message: 'You do not have access to the workspace' }) +export const WorkspaceSsoProviderNotFoundError = defineAuthError({ + code: 'WorkspaceSsoProviderNotFound', + message: 'The workspace SSO provider was not found' +}) + export const WorkspaceSsoSessionInvalidError = defineAuthError({ code: 'WorkspaceSsoSessionInvalid', message: 'Your workspace SSO session is invalid' }) + +export const WorkspaceSsoSessionNotFoundError = defineAuthError({ + code: 'WorkspaceSsoSessionNotFound', + message: 'Your workspace SSO session was not found' +}) + +export const WorkspaceRoleNotFoundError = defineAuthError({ + code: 'WorkspaceRoleNotFound', + message: 'The user does not have a role in the workspace' +}) + +export const ServerRoleNotFoundError = defineAuthError({ + code: 'ServerRoleNotFound', + message: 'Could not resolve your server role' +}) diff --git a/packages/shared/src/authz/domain/authResult.ts b/packages/shared/src/authz/domain/authResult.ts deleted file mode 100644 index 4fd888b08..000000000 --- a/packages/shared/src/authz/domain/authResult.ts +++ /dev/null @@ -1,19 +0,0 @@ -type AuthSuccess = { - authorized: true -} - -export type AuthFailure = { - authorized: false - error: T -} - -export type AuthResult = AuthSuccess | AuthFailure - -export const authorized = (): AuthSuccess => ({ - authorized: true -}) - -export const unauthorized = (error: T): AuthFailure => ({ - authorized: false, - error -}) diff --git a/packages/shared/src/authz/domain/core/operations.ts b/packages/shared/src/authz/domain/core/operations.ts index a44bb0b85..c9b657853 100644 --- a/packages/shared/src/authz/domain/core/operations.ts +++ b/packages/shared/src/authz/domain/core/operations.ts @@ -1,3 +1,7 @@ +import Result from 'true-myth/result' import { ServerRoles } from '../../../core/constants.js' +import { ServerRoleNotFoundError } from '../authErrors.js' -export type GetServerRole = (args: { userId: string }) => Promise +export type GetServerRole = (args: { + userId: string +}) => Promise> diff --git a/packages/shared/src/authz/domain/errors.ts b/packages/shared/src/authz/domain/errors.ts index c7a03ddd8..4c0501b8e 100644 --- a/packages/shared/src/authz/domain/errors.ts +++ b/packages/shared/src/authz/domain/errors.ts @@ -1,3 +1,9 @@ +export class LogicError extends Error { + constructor(message: string) { + super(message) + } +} + export class ProjectNotFoundError extends Error { constructor({ projectId }: { projectId: string }) { super(`Project with id ${projectId} not found`) diff --git a/packages/shared/src/authz/domain/loaders.ts b/packages/shared/src/authz/domain/loaders.ts index 485dfab25..d689ec5ee 100644 --- a/packages/shared/src/authz/domain/loaders.ts +++ b/packages/shared/src/authz/domain/loaders.ts @@ -1,3 +1,5 @@ +import { OverrideProperties } from 'type-fest' +import { MaybeAsync } from '../../core/index.js' import type { GetServerRole } from './core/operations.js' import type { GetProject, GetProjectRole } from './projects/operations.js' import type { @@ -8,11 +10,47 @@ import type { GetWorkspaceSsoSession } from './workspaces/operations.js' -export type AuthCheckContext = { - loaders: Pick +// utility type that ensures all properties functions that return promises +type PromiseAll = { + [K in keyof T]: T[K] extends (...args: infer Args) => MaybeAsync + ? (...args: Args) => Promise + : never } -export type AuthCheckContextLoaders = { +// wrapper type for AllAuthCheckContextLoaders that ensures loaders follow the expected schema +type AuthContextLoaderMappingDefinition< + Mapping extends { + [Key in keyof Mapping]: Key extends AuthCheckContextLoaderKeys + ? Mapping[Key] + : never + } +> = PromiseAll< + OverrideProperties< + { + [key in AuthCheckContextLoaderKeys]: unknown + }, + Mapping + > +> + +/** + * All loaders must be listed here for app startup validation to work properly + */ +export const AuthCheckContextLoaderKeys = { + getEnv: 'getEnv', + getProject: 'getProject', + getProjectRole: 'getProjectRole', + getServerRole: 'getServerRole', + getWorkspace: 'getWorkspace', + getWorkspaceRole: 'getWorkspaceRole', + getWorkspaceSsoProvider: 'getWorkspaceSsoProvider', + getWorkspaceSsoSession: 'getWorkspaceSsoSession' +} + +export type AuthCheckContextLoaderKeys = + (typeof AuthCheckContextLoaderKeys)[keyof typeof AuthCheckContextLoaderKeys] + +export type AllAuthCheckContextLoaders = AuthContextLoaderMappingDefinition<{ getEnv: GetEnv getProject: GetProject getProjectRole: GetProjectRole @@ -21,4 +59,12 @@ export type AuthCheckContextLoaders = { getWorkspaceRole: GetWorkspaceRole getWorkspaceSsoProvider: GetWorkspaceSsoProvider getWorkspaceSsoSession: GetWorkspaceSsoSession +}> + +export type AuthCheckContextLoaders< + LoaderKeys extends AuthCheckContextLoaderKeys = AuthCheckContextLoaderKeys +> = Pick + +export type AuthCheckContext = { + loaders: AuthCheckContextLoaders } diff --git a/packages/shared/src/authz/domain/policies.ts b/packages/shared/src/authz/domain/policies.ts index e666f1580..9633f6d9e 100644 --- a/packages/shared/src/authz/domain/policies.ts +++ b/packages/shared/src/authz/domain/policies.ts @@ -1,3 +1,15 @@ +import Result from 'true-myth/result' +import { AuthError } from './authErrors.js' +import { AuthCheckContextLoaderKeys, AuthCheckContextLoaders } from './loaders.js' + export type ProjectContext = { projectId: string } export type UserContext = { userId?: string } + +export type AuthPolicyFactory< + LoaderKeys extends AuthCheckContextLoaderKeys, + Args extends object, + ExpectedAuthErrors extends AuthError +> = ( + loaders: AuthCheckContextLoaders +) => (args: Args) => Promise> diff --git a/packages/shared/src/authz/domain/projects/operations.ts b/packages/shared/src/authz/domain/projects/operations.ts index d7b65a5e7..f488c8a89 100644 --- a/packages/shared/src/authz/domain/projects/operations.ts +++ b/packages/shared/src/authz/domain/projects/operations.ts @@ -1,10 +1,19 @@ +import { Result } from 'true-myth/result' import { StreamRoles } from '../../../core/constants.js' import { Project } from './types.js' +import { + ProjectNoAccessError, + ProjectNotFoundError, + ProjectRoleNotFoundError +} from '../authErrors.js' -// TODO: this should probably just throw an error if the project doesn't exist -export type GetProject = (args: { projectId: string }) => Promise +export type GetProject = (args: { + projectId: string +}) => Promise< + Result +> export type GetProjectRole = (args: { userId: string projectId: string -}) => Promise +}) => Promise> diff --git a/packages/shared/src/authz/domain/workspaces/operations.ts b/packages/shared/src/authz/domain/workspaces/operations.ts index 60db00294..355747fd4 100644 --- a/packages/shared/src/authz/domain/workspaces/operations.ts +++ b/packages/shared/src/authz/domain/workspaces/operations.ts @@ -1,21 +1,30 @@ +import Result from 'true-myth/result' import { WorkspaceRoles } from '../../../core/constants.js' import { FeatureFlags } from '../../../environment/index.js' import { Workspace, WorkspaceSsoProvider, WorkspaceSsoSession } from './types.js' +import { + WorkspaceNotFoundError, + WorkspaceRoleNotFoundError, + WorkspaceSsoProviderNotFoundError, + WorkspaceSsoSessionNotFoundError +} from '../authErrors.js' -export type GetWorkspace = (args: { workspaceId: string }) => Promise +export type GetWorkspace = (args: { + workspaceId: string +}) => Promise> export type GetWorkspaceRole = (args: { userId: string workspaceId: string -}) => Promise +}) => Promise> export type GetWorkspaceSsoProvider = (args: { workspaceId: string -}) => Promise +}) => Promise> export type GetWorkspaceSsoSession = (args: { userId: string workspaceId: string -}) => Promise +}) => Promise> -export type GetEnv = () => FeatureFlags +export type GetEnv = () => Result diff --git a/packages/shared/src/authz/index.ts b/packages/shared/src/authz/index.ts index 21db6af0b..307d759ad 100644 --- a/packages/shared/src/authz/index.ts +++ b/packages/shared/src/authz/index.ts @@ -1,4 +1,8 @@ -export { authPoliciesFactory, AuthPolices } from './policies/index.js' -export { AuthCheckContextLoaders } from './domain/loaders.js' +export { authPoliciesFactory, AuthPolicies } from './policies/index.js' +export { + AllAuthCheckContextLoaders, + AuthCheckContextLoaders, + AuthCheckContextLoaderKeys +} from './domain/loaders.js' export * from './domain/authErrors.js' diff --git a/packages/shared/src/authz/policies/canQueryProject.spec.ts b/packages/shared/src/authz/policies/canQueryProject.spec.ts index 15682461a..f81a7146e 100644 --- a/packages/shared/src/authz/policies/canQueryProject.spec.ts +++ b/packages/shared/src/authz/policies/canQueryProject.spec.ts @@ -3,8 +3,16 @@ import { canQueryProjectPolicyFactory } from './canQueryProject.js' import { parseFeatureFlags } from '../../environment/index.js' import crs from 'crypto-random-string' import { Roles } from '../../core/constants.js' -import { ProjectNoAccessError, ProjectNotFoundError } from '../domain/authErrors.js' +import { + ProjectNoAccessError, + ProjectNotFoundError, + ProjectRoleNotFoundError, + WorkspaceRoleNotFoundError, + WorkspaceSsoProviderNotFoundError, + WorkspaceSsoSessionNotFoundError +} from '../domain/authErrors.js' import { getProjectFake } from '../../tests/fakes.js' +import { err, ok } from 'true-myth/result' const canQueryProjectArgs = () => { const projectId = crs({ length: 10 }) @@ -16,8 +24,8 @@ describe('canQueryProjectPolicyFactory creates a function, that handles ', () => describe('project not found', () => { it('by returning project no access', async () => { const canQueryProject = canQueryProjectPolicyFactory({ - getEnv: () => parseFeatureFlags({}), - getProject: () => Promise.resolve(null), + getEnv: async () => ok(parseFeatureFlags({})), + getProject: () => Promise.resolve(err(ProjectNotFoundError)), getProjectRole: () => { assert.fail() }, @@ -36,8 +44,8 @@ describe('canQueryProjectPolicyFactory creates a function, that handles ', () => }) const canQuery = await canQueryProject(canQueryProjectArgs()) - expect(canQuery.authorized).toBe(false) - if (!canQuery.authorized) { + expect(canQuery.isOk).toBe(false) + if (!canQuery.isOk) { expect(canQuery.error.code).toBe(ProjectNotFoundError.code) } }) @@ -45,7 +53,7 @@ describe('canQueryProjectPolicyFactory creates a function, that handles ', () => describe('project visibility', () => { it('allows anyone on a public project', async () => { const canQueryProject = canQueryProjectPolicyFactory({ - getEnv: () => parseFeatureFlags({}), + getEnv: async () => ok(parseFeatureFlags({})), getProject: getProjectFake({ isPublic: true }), getProjectRole: () => { assert.fail() @@ -64,11 +72,11 @@ describe('canQueryProjectPolicyFactory creates a function, that handles ', () => } }) const canQuery = await canQueryProject(canQueryProjectArgs()) - expect(canQuery.authorized).toBe(true) + expect(canQuery.isOk).toBe(true) }) it('allows anyone on a linkShareable project', async () => { const canQueryProject = canQueryProjectPolicyFactory({ - getEnv: () => parseFeatureFlags({}), + getEnv: async () => ok(parseFeatureFlags({})), getProject: getProjectFake({ isDiscoverable: true }), getProjectRole: () => { assert.fail() @@ -87,7 +95,7 @@ describe('canQueryProjectPolicyFactory creates a function, that handles ', () => } }) const canQuery = await canQueryProject(canQueryProjectArgs()) - expect(canQuery.authorized).toBe(true) + expect(canQuery.isOk).toBe(true) }) }) @@ -96,12 +104,14 @@ describe('canQueryProjectPolicyFactory creates a function, that handles ', () => 'allows access to private projects with role %', async (role) => { const canQueryProject = canQueryProjectPolicyFactory({ - getEnv: () => - parseFeatureFlags({ - FF_WORKSPACES_MODULE_ENABLED: 'false' - }), + getEnv: async () => + ok( + parseFeatureFlags({ + FF_WORKSPACES_MODULE_ENABLED: 'false' + }) + ), getProject: getProjectFake({ isDiscoverable: false, isPublic: false }), - getProjectRole: () => Promise.resolve(role), + getProjectRole: () => Promise.resolve(ok(role)), getServerRole: () => { assert.fail() }, @@ -116,17 +126,19 @@ describe('canQueryProjectPolicyFactory creates a function, that handles ', () => } }) const canQuery = await canQueryProject(canQueryProjectArgs()) - expect(canQuery.authorized).toBe(true) + expect(canQuery.isOk).toBe(true) } ) it('does not allow access to private projects without a project role', async () => { const canQueryProject = canQueryProjectPolicyFactory({ - getEnv: () => - parseFeatureFlags({ - FF_WORKSPACES_MODULE_ENABLED: 'false' - }), + getEnv: async () => + ok( + parseFeatureFlags({ + FF_WORKSPACES_MODULE_ENABLED: 'false' + }) + ), getProject: getProjectFake({ isDiscoverable: false, isPublic: false }), - getProjectRole: () => Promise.resolve(null), + getProjectRole: () => Promise.resolve(err(ProjectRoleNotFoundError)), getServerRole: () => { assert.fail() }, @@ -141,8 +153,8 @@ describe('canQueryProjectPolicyFactory creates a function, that handles ', () => } }) const canQuery = await canQueryProject(canQueryProjectArgs()) - expect(canQuery.authorized).toBe(false) - if (!canQuery.authorized) { + expect(canQuery.isOk).toBe(false) + if (!canQuery.isOk) { expect(canQuery.error.code).toBe(ProjectNoAccessError.code) } }) @@ -150,9 +162,10 @@ describe('canQueryProjectPolicyFactory creates a function, that handles ', () => describe('admin override', () => { it('allows server admins without project roles on private projects if admin override is enabled', async () => { const canQueryProject = canQueryProjectPolicyFactory({ - getEnv: () => parseFeatureFlags({ FF_ADMIN_OVERRIDE_ENABLED: 'true' }), + getEnv: async () => + ok(parseFeatureFlags({ FF_ADMIN_OVERRIDE_ENABLED: 'true' })), getProject: getProjectFake({ isDiscoverable: false, isPublic: false }), - getServerRole: () => Promise.resolve(Roles.Server.Admin), + getServerRole: () => Promise.resolve(ok(Roles.Server.Admin)), getProjectRole: () => { assert.fail() }, @@ -167,20 +180,22 @@ describe('canQueryProjectPolicyFactory creates a function, that handles ', () => } }) const canQuery = await canQueryProject(canQueryProjectArgs()) - expect(canQuery.authorized).toBe(true) + expect(canQuery.isOk).toBe(true) }) it('does not allow server admins without project roles on private projects if admin override is disabled', async () => { const canQueryProject = canQueryProjectPolicyFactory({ - getEnv: () => - parseFeatureFlags({ - FF_ADMIN_OVERRIDE_ENABLED: 'false', - FF_WORKSPACES_MODULE_ENABLED: 'false' - }), + getEnv: async () => + ok( + parseFeatureFlags({ + FF_ADMIN_OVERRIDE_ENABLED: 'false', + FF_WORKSPACES_MODULE_ENABLED: 'false' + }) + ), getProject: getProjectFake({ isDiscoverable: false, isPublic: false }), - getServerRole: () => Promise.resolve(Roles.Server.Admin), + getServerRole: () => Promise.resolve(ok(Roles.Server.Admin)), getProjectRole: () => { - return Promise.resolve(null) + return Promise.resolve(err(ProjectRoleNotFoundError)) }, getWorkspaceRole: () => { assert.fail() @@ -193,8 +208,8 @@ describe('canQueryProjectPolicyFactory creates a function, that handles ', () => } }) const canQuery = await canQueryProject(canQueryProjectArgs()) - expect(canQuery.authorized).toBe(false) - if (!canQuery.authorized) { + expect(canQuery.isOk).toBe(false) + if (!canQuery.isOk) { expect(canQuery.error.code).toBe(ProjectNoAccessError.code) } }) @@ -202,13 +217,14 @@ describe('canQueryProjectPolicyFactory creates a function, that handles ', () => describe('the workspace world', () => { it('does not check workspace rules if the workspaces module is not enabled', async () => { const canQueryProject = canQueryProjectPolicyFactory({ - getEnv: () => parseFeatureFlags({ FF_WORKSPACES_MODULE_ENABLED: 'false' }), + getEnv: async () => + ok(parseFeatureFlags({ FF_WORKSPACES_MODULE_ENABLED: 'false' })), getProject: getProjectFake({ isDiscoverable: false, isPublic: false, workspaceId: crs({ length: 10 }) }), - getProjectRole: () => Promise.resolve('stream:contributor'), + getProjectRole: () => Promise.resolve(ok('stream:contributor')), getServerRole: () => { assert.fail() }, @@ -223,24 +239,26 @@ describe('canQueryProjectPolicyFactory creates a function, that handles ', () => } }) const canQuery = await canQueryProject(canQueryProjectArgs()) - expect(canQuery.authorized).toBe(true) + expect(canQuery.isOk).toBe(true) }) it('does not allow project access without a workspace role', async () => { const canQueryProject = canQueryProjectPolicyFactory({ - getEnv: () => - parseFeatureFlags({ - FF_WORKSPACES_MODULE_ENABLED: 'true' - }), + getEnv: async () => + ok( + parseFeatureFlags({ + FF_WORKSPACES_MODULE_ENABLED: 'true' + }) + ), getProject: getProjectFake({ isDiscoverable: false, isPublic: false, workspaceId: crs({ length: 10 }) }), - getProjectRole: () => Promise.resolve('stream:contributor'), + getProjectRole: () => Promise.resolve(ok('stream:contributor')), getServerRole: () => { assert.fail() }, - getWorkspaceRole: () => Promise.resolve(null), + getWorkspaceRole: () => Promise.resolve(err(WorkspaceRoleNotFoundError)), getWorkspaceSsoSession: () => { assert.fail() }, @@ -249,48 +267,53 @@ describe('canQueryProjectPolicyFactory creates a function, that handles ', () => } }) const canQuery = await canQueryProject(canQueryProjectArgs()) - expect(canQuery.authorized).toBe(false) + expect(canQuery.isOk).toBe(false) }) it('allows project access via workspace role if user does not have project role', async () => { const canQueryProject = canQueryProjectPolicyFactory({ - getEnv: () => - parseFeatureFlags({ - FF_WORKSPACES_MODULE_ENABLED: 'true' - }), + getEnv: async () => + ok( + parseFeatureFlags({ + FF_WORKSPACES_MODULE_ENABLED: 'true' + }) + ), getProject: getProjectFake({ isDiscoverable: false, isPublic: false, workspaceId: crs({ length: 10 }) }), - getProjectRole: () => Promise.resolve(null), + getProjectRole: () => Promise.resolve(err(ProjectRoleNotFoundError)), getServerRole: () => { assert.fail() }, - getWorkspaceRole: () => Promise.resolve('workspace:admin'), + getWorkspaceRole: () => Promise.resolve(ok('workspace:admin')), getWorkspaceSsoSession: () => { assert.fail() }, - getWorkspaceSsoProvider: () => Promise.resolve(null) + getWorkspaceSsoProvider: () => + Promise.resolve(err(WorkspaceSsoProviderNotFoundError)) }) const canQuery = await canQueryProject(canQueryProjectArgs()) - expect(canQuery.authorized).toBe(true) + expect(canQuery.isOk).toBe(true) }) it('does not check SSO sessions if user is workspace guest', async () => { const canQueryProject = canQueryProjectPolicyFactory({ - getEnv: () => - parseFeatureFlags({ - FF_WORKSPACES_MODULE_ENABLED: 'true' - }), + getEnv: async () => + ok( + parseFeatureFlags({ + FF_WORKSPACES_MODULE_ENABLED: 'true' + }) + ), getProject: getProjectFake({ isDiscoverable: false, isPublic: false, workspaceId: crs({ length: 10 }) }), - getProjectRole: () => Promise.resolve('stream:contributor'), + getProjectRole: () => Promise.resolve(ok('stream:contributor')), getServerRole: () => { assert.fail() }, - getWorkspaceRole: () => Promise.resolve('workspace:guest'), + getWorkspaceRole: () => Promise.resolve(ok('workspace:guest')), getWorkspaceSsoSession: () => { assert.fail() }, @@ -299,105 +322,115 @@ describe('canQueryProjectPolicyFactory creates a function, that handles ', () => } }) const canQuery = await canQueryProject(canQueryProjectArgs()) - expect(canQuery.authorized).toBe(true) + expect(canQuery.isOk).toBe(true) }) it('does not check SSO sessions if workspace does not have it enabled', async () => { const canQueryProject = canQueryProjectPolicyFactory({ - getEnv: () => - parseFeatureFlags({ - FF_WORKSPACES_MODULE_ENABLED: 'true' - }), + getEnv: async () => + ok( + parseFeatureFlags({ + FF_WORKSPACES_MODULE_ENABLED: 'true' + }) + ), getProject: getProjectFake({ isDiscoverable: false, isPublic: false, workspaceId: crs({ length: 10 }) }), - getProjectRole: () => Promise.resolve('stream:contributor'), + getProjectRole: () => Promise.resolve(ok('stream:contributor')), getServerRole: () => { assert.fail() }, - getWorkspaceRole: () => Promise.resolve('workspace:member'), + getWorkspaceRole: () => Promise.resolve(ok('workspace:member')), getWorkspaceSsoSession: () => { assert.fail() }, - getWorkspaceSsoProvider: () => Promise.resolve(null) + getWorkspaceSsoProvider: () => + Promise.resolve(err(WorkspaceSsoProviderNotFoundError)) }) const canQuery = await canQueryProject(canQueryProjectArgs()) - expect(canQuery.authorized).toBe(true) + expect(canQuery.isOk).toBe(true) }) it('does not allow project access if SSO session is missing', async () => { const canQueryProject = canQueryProjectPolicyFactory({ - getEnv: () => - parseFeatureFlags({ - FF_WORKSPACES_MODULE_ENABLED: 'true' - }), + getEnv: async () => + ok( + parseFeatureFlags({ + FF_WORKSPACES_MODULE_ENABLED: 'true' + }) + ), getProject: getProjectFake({ isDiscoverable: false, isPublic: false, workspaceId: crs({ length: 10 }) }), - getProjectRole: () => Promise.resolve('stream:contributor'), + getProjectRole: () => Promise.resolve(ok('stream:contributor')), getServerRole: () => { assert.fail() }, - getWorkspaceRole: () => Promise.resolve('workspace:member'), - getWorkspaceSsoSession: () => Promise.resolve(null), - getWorkspaceSsoProvider: () => Promise.resolve({ providerId: 'foo' }) + getWorkspaceRole: () => Promise.resolve(ok('workspace:member')), + getWorkspaceSsoSession: () => + Promise.resolve(err(WorkspaceSsoSessionNotFoundError)), + getWorkspaceSsoProvider: () => Promise.resolve(ok({ providerId: 'foo' })) }) const canQuery = await canQueryProject(canQueryProjectArgs()) - expect(canQuery.authorized).toBe(false) + expect(canQuery.isOk).toBe(false) }) it('does not allow project access if SSO session is expired or invalid', async () => { const date = new Date() date.setDate(date.getDate() - 1) const canQueryProject = canQueryProjectPolicyFactory({ - getEnv: () => - parseFeatureFlags({ - FF_WORKSPACES_MODULE_ENABLED: 'true' - }), + getEnv: async () => + ok( + parseFeatureFlags({ + FF_WORKSPACES_MODULE_ENABLED: 'true' + }) + ), getProject: getProjectFake({ isDiscoverable: false, isPublic: false, workspaceId: crs({ length: 10 }) }), - getProjectRole: () => Promise.resolve('stream:contributor'), + getProjectRole: () => Promise.resolve(ok('stream:contributor')), getServerRole: () => { assert.fail() }, - getWorkspaceRole: () => Promise.resolve('workspace:member'), + getWorkspaceRole: () => Promise.resolve(ok('workspace:member')), getWorkspaceSsoSession: () => - Promise.resolve({ validUntil: date, userId: 'foo', providerId: 'foo' }), - getWorkspaceSsoProvider: () => Promise.resolve({ providerId: 'foo' }) + Promise.resolve(ok({ validUntil: date, userId: 'foo', providerId: 'foo' })), + getWorkspaceSsoProvider: () => Promise.resolve(ok({ providerId: 'foo' })) }) const canQuery = await canQueryProject(canQueryProjectArgs()) - expect(canQuery.authorized).toBe(false) + expect(canQuery.isOk).toBe(false) }) it('allows project access if SSO session is valid', async () => { const date = new Date() date.setDate(date.getDate() + 1) const canQueryProject = canQueryProjectPolicyFactory({ - getEnv: () => - parseFeatureFlags({ - FF_WORKSPACES_MODULE_ENABLED: 'true' - }), + getEnv: async () => + ok( + parseFeatureFlags({ + FF_WORKSPACES_MODULE_ENABLED: 'true' + }) + ), getProject: getProjectFake({ isDiscoverable: false, isPublic: false, workspaceId: crs({ length: 10 }) }), - getProjectRole: () => Promise.resolve('stream:contributor'), + getProjectRole: () => Promise.resolve(ok('stream:contributor')), getServerRole: () => { assert.fail() }, - getWorkspaceRole: () => Promise.resolve('workspace:member'), + getWorkspaceRole: () => Promise.resolve(ok('workspace:member')), getWorkspaceSsoSession: () => - Promise.resolve({ validUntil: date, userId: 'foo', providerId: 'foo' }), - getWorkspaceSsoProvider: () => Promise.resolve({ providerId: 'foo' }) + Promise.resolve(ok({ validUntil: date, userId: 'foo', providerId: 'foo' })), + getWorkspaceSsoProvider: () => Promise.resolve(ok({ providerId: 'foo' })) }) const canQuery = await canQueryProject(canQueryProjectArgs()) - expect(canQuery.authorized).toBe(true) + expect(canQuery.isOk).toBe(true) }) }) }) diff --git a/packages/shared/src/authz/policies/canQueryProject.ts b/packages/shared/src/authz/policies/canQueryProject.ts index 2d6896c51..cb4ee24ab 100644 --- a/packages/shared/src/authz/policies/canQueryProject.ts +++ b/packages/shared/src/authz/policies/canQueryProject.ts @@ -2,13 +2,11 @@ import { requireAnyWorkspaceRole, requireMinimumWorkspaceRole } from '../checks/workspaceRole.js' -import { AuthResult, authorized, unauthorized } from '../domain/authResult.js' import { requireExactProjectVisibilityFactory, requireMinimumProjectRoleFactory } from '../checks/projects.js' -import { AuthCheckContextLoaders } from '../domain/loaders.js' -import { ProjectContext, UserContext } from '../domain/policies.js' +import { AuthPolicyFactory, ProjectContext, UserContext } from '../domain/policies.js' import { requireExactServerRole } from '../checks/serverRole.js' import { requireValidWorkspaceSsoSession } from '../checks/workspaceSso.js' import { Roles } from '../../core/constants.js' @@ -18,36 +16,37 @@ import { WorkspaceNoAccessError, WorkspaceSsoSessionInvalidError } from '../domain/authErrors.js' +import { err, isOk, ok } from 'true-myth/result' +import { AuthCheckContextLoaderKeys } from '../domain/loaders.js' +import { LogicError } from '../domain/errors.js' -export const canQueryProjectPolicyFactory = - ( - loaders: Pick< - AuthCheckContextLoaders, - | 'getEnv' - | 'getProject' - | 'getProjectRole' - | 'getServerRole' - | 'getWorkspaceRole' - | 'getWorkspaceSsoProvider' - | 'getWorkspaceSsoSession' - > - ) => - async ({ - userId, - projectId - }: UserContext & ProjectContext): Promise< - AuthResult< - | typeof ProjectNotFoundError - | typeof ProjectNoAccessError - | typeof WorkspaceNoAccessError - | typeof WorkspaceSsoSessionInvalidError - > - > => { - const { FF_ADMIN_OVERRIDE_ENABLED, FF_WORKSPACES_MODULE_ENABLED } = loaders.getEnv() +export const canQueryProjectPolicyFactory: AuthPolicyFactory< + | typeof AuthCheckContextLoaderKeys.getEnv + | typeof AuthCheckContextLoaderKeys.getProject + | typeof AuthCheckContextLoaderKeys.getProjectRole + | typeof AuthCheckContextLoaderKeys.getServerRole + | typeof AuthCheckContextLoaderKeys.getWorkspaceRole + | typeof AuthCheckContextLoaderKeys.getWorkspaceSsoProvider + | typeof AuthCheckContextLoaderKeys.getWorkspaceSsoSession, + UserContext & ProjectContext, + | typeof ProjectNotFoundError + | typeof ProjectNoAccessError + | typeof WorkspaceNoAccessError + | typeof WorkspaceSsoSessionInvalidError +> = + (loaders) => + async ({ userId, projectId }) => { + const env = await loaders.getEnv() + if (!isOk(env)) { + throw new LogicError('Failed to load environment variables') + } + + const { FF_ADMIN_OVERRIDE_ENABLED, FF_WORKSPACES_MODULE_ENABLED } = env.value const project = await loaders.getProject({ projectId }) - // hiding the project not found, to stop id brute force lookups - if (!project) return unauthorized(ProjectNotFoundError) + if (!isOk(project)) { + return err(project.error) + } // All users may read public projects const isPublicResult = await requireExactProjectVisibilityFactory({ loaders })({ @@ -55,7 +54,7 @@ export const canQueryProjectPolicyFactory = projectVisibility: 'public' }) if (isPublicResult) { - return authorized() + return ok(true) } // All users may read link-shareable projects @@ -66,11 +65,11 @@ export const canQueryProjectPolicyFactory = projectVisibility: 'linkShareable' }) if (isLinkShareableResult) { - return authorized() + return ok(true) } // From this point on, you cannot pass as an unknown user if (!userId) { - return unauthorized(ProjectNoAccessError) + return err(ProjectNoAccessError) } // When G O D M O D E is enabled @@ -81,11 +80,11 @@ export const canQueryProjectPolicyFactory = role: Roles.Server.Admin }) if (isServerAdminResult) { - return authorized() + return ok(true) } } - const { workspaceId } = project + const { workspaceId } = project.value // When a project belongs to a workspace if (FF_WORKSPACES_MODULE_ENABLED && !!workspaceId) { @@ -96,7 +95,7 @@ export const canQueryProjectPolicyFactory = }) if (!hasWorkspaceRoleResult) { // Should we hide the fact, the project is in a workspace? - return unauthorized(WorkspaceNoAccessError) + return err(WorkspaceNoAccessError) } const hasMinimumMemberRole = await requireMinimumWorkspaceRole({ @@ -111,7 +110,7 @@ export const canQueryProjectPolicyFactory = const workspaceSsoProvider = await loaders.getWorkspaceSsoProvider({ workspaceId }) - if (!!workspaceSsoProvider) { + if (workspaceSsoProvider.isOk) { // Member and admin user must have a valid SSO session to read project data const hasValidSsoSessionResult = await requireValidWorkspaceSsoSession({ loaders @@ -120,12 +119,12 @@ export const canQueryProjectPolicyFactory = workspaceId }) if (!hasValidSsoSessionResult) { - return unauthorized(WorkspaceSsoSessionInvalidError) + return err(WorkspaceSsoSessionInvalidError) } } // Workspace members get to go through without an explicit project role - return authorized() + return ok(true) } else { // just fall through to the generic project role check for workspace:guest-s } @@ -140,7 +139,7 @@ export const canQueryProjectPolicyFactory = role: 'stream:reviewer' }) if (hasMinimumProjectRoleResult) { - return authorized() + return ok(true) } - return unauthorized(ProjectNoAccessError) + return err(ProjectNoAccessError) } diff --git a/packages/shared/src/authz/policies/index.ts b/packages/shared/src/authz/policies/index.ts index d8b7fdec5..f7538c332 100644 --- a/packages/shared/src/authz/policies/index.ts +++ b/packages/shared/src/authz/policies/index.ts @@ -1,10 +1,10 @@ -import { AuthCheckContextLoaders } from '../domain/loaders.js' +import { AllAuthCheckContextLoaders } from '../domain/loaders.js' import { canQueryProjectPolicyFactory } from './canQueryProject.js' -export const authPoliciesFactory = (loaders: AuthCheckContextLoaders) => ({ +export const authPoliciesFactory = (loaders: AllAuthCheckContextLoaders) => ({ project: { canQuery: canQueryProjectPolicyFactory(loaders) } }) -export type AuthPolices = ReturnType +export type AuthPolicies = ReturnType diff --git a/packages/shared/src/tests/fakes.ts b/packages/shared/src/tests/fakes.ts index 10458a995..c44ec6e07 100644 --- a/packages/shared/src/tests/fakes.ts +++ b/packages/shared/src/tests/fakes.ts @@ -1,14 +1,15 @@ import { merge } from 'lodash' import { Project } from '../authz/domain/projects/types.js' +import { ok, Result } from 'true-myth/result' export const fakeGetFactory = >(defaults: T) => (overrides?: Partial) => - (): Promise => { + (): Promise> => { if (overrides) { - return Promise.resolve(merge(defaults, overrides)) + return Promise.resolve(ok(merge(defaults, overrides))) } - return Promise.resolve(defaults) + return Promise.resolve(ok(defaults)) } export const getProjectFake = fakeGetFactory({ diff --git a/yarn.lock b/yarn.lock index 2fae86476..e7637d2cb 100644 --- a/yarn.lock +++ b/yarn.lock @@ -16953,6 +16953,7 @@ __metadata: stripe: "npm:^17.1.0" subscriptions-transport-ws: "npm:^0.11.0" supertest: "npm:^4.0.2" + true-myth: "npm:^8.5.0" ts-node: "npm:^10.9.2" tsconfig-paths: "npm:^4.0.0" type-fest: "npm:^4.26.1" @@ -16993,6 +16994,7 @@ __metadata: mixpanel: "npm:^0.17.0" pino: "npm:^8.7.0" pino-http: "npm:^8.0.0" + true-myth: "npm:^8.5.0" tshy: "npm:^1.14.0" type-fest: "npm:^3.11.1" typescript: "npm:^4.5.4" @@ -50223,6 +50225,13 @@ __metadata: languageName: node linkType: hard +"true-myth@npm:^8.5.0": + version: 8.5.0 + resolution: "true-myth@npm:8.5.0" + checksum: 10/f3f96d96df8f0bfb5d26c379bbe029b947349b9dabdd00afa939ff75e979864c951c647ba35b21a13d676a9b34d721bffd2a1e9f0750d5d2a588988f16de5bd0 + languageName: node + linkType: hard + "ts-api-utils@npm:^1.3.0": version: 1.3.0 resolution: "ts-api-utils@npm:1.3.0" From 01cbd46939ff7b45101053445cec80b3dab39983 Mon Sep 17 00:00:00 2001 From: Alexandru Popovici Date: Wed, 26 Mar 2025 10:27:03 +0200 Subject: [PATCH 9/9] chore(viewer-lib): Updates to the export list (#4258) - Basit pass and pipeline have been renamed to Shaded - Viewer now exports all implemented passes - Viewer now exports the pass option types as well as defaults --- packages/viewer/src/index.ts | 41 +++++++++++++++---- .../src/modules/extensions/ViewModes.ts | 4 +- .../src/modules/pipeline/Passes/EdgesPass.ts | 13 +++--- .../Passes/{BasitPass.ts => ShadedPass.ts} | 2 +- .../pipeline/Pipelines/BasitViewPipeline.ts | 4 +- .../pipeline/Pipelines/EdgesPipeline.ts | 6 +-- .../Pipelines/MRT/MRTEdgesPipeline.ts | 6 +-- .../Pipelines/MRT/MRTPenViewPipeline.ts | 6 +-- .../Pipelines/MRT/MRTShadedViewPipeline.ts | 6 +-- .../pipeline/Pipelines/PenViewPipeline.ts | 6 +-- .../pipeline/Pipelines/ShadedViewPipeline.ts | 6 +-- .../Pipelines/TechnicalViewPipeline.ts | 10 ++--- 12 files changed, 69 insertions(+), 41 deletions(-) rename packages/viewer/src/modules/pipeline/Passes/{BasitPass.ts => ShadedPass.ts} (99%) diff --git a/packages/viewer/src/index.ts b/packages/viewer/src/index.ts index 6d9a8bd5c..693485d73 100644 --- a/packages/viewer/src/index.ts +++ b/packages/viewer/src/index.ts @@ -114,15 +114,26 @@ import { } from './modules/pipeline/Passes/GPass.js' import { Pipeline } from './modules/pipeline/Pipelines/Pipeline.js' import { ProgressivePipeline } from './modules/pipeline/Pipelines/ProgressivePipeline.js' -import { DepthPass } from './modules/pipeline/Passes/DepthPass.js' +import { DepthPass, DepthPassOptions } from './modules/pipeline/Passes/DepthPass.js' import { GeometryPass } from './modules/pipeline/Passes/GeometryPass.js' import { NormalsPass } from './modules/pipeline/Passes/NormalsPass.js' -import { InputType, OutputPass } from './modules/pipeline/Passes/OutputPass.js' -import { ViewportPass } from './modules/pipeline/Passes/ViewportPass.js' -import { BlendPass } from './modules/pipeline/Passes/BlendPass.js' +import { + InputType, + OutputPass, + OutputPassOptions +} from './modules/pipeline/Passes/OutputPass.js' +import { + ViewportPass, + ViewportPassOptions +} from './modules/pipeline/Passes/ViewportPass.js' +import { BlendPass, BlendPassOptions } from './modules/pipeline/Passes/BlendPass.js' import { DepthNormalPass } from './modules/pipeline/Passes/DepthNormalPass.js' -import { BasitPass } from './modules/pipeline/Passes/BasitPass.js' -import { ProgressiveAOPass } from './modules/pipeline/Passes/ProgressiveAOPass.js' +import { ShadedPass } from './modules/pipeline/Passes/ShadedPass.js' +import { + DefaultProgressiveAOPassOptions, + ProgressiveAOPass, + ProgressiveAOPassOptions +} from './modules/pipeline/Passes/ProgressiveAOPass.js' import { TAAPass } from './modules/pipeline/Passes/TAAPass.js' import { FilterMaterial, @@ -133,13 +144,18 @@ import { SpeckleOfflineLoader } from './modules/loaders/Speckle/SpeckleOfflineLo import { AccelerationStructure } from './modules/objects/AccelerationStructure.js' import { TopLevelAccelerationStructure } from './modules/objects/TopLevelAccelerationStructure.js' import { StencilPass } from './modules/pipeline/Passes/StencilPass.js' -import { StencilMaskPass } from './modules/pipeline/Passes/StencilMaskPass.js' import { SpeckleWebGLRenderer } from './modules/objects/SpeckleWebGLRenderer.js' import { InstancedMeshBatch } from './modules/batching/InstancedMeshBatch.js' import { ViewModeEvent, ViewModeEventPayload } from './modules/extensions/ViewModes.js' import { BasitPipeline } from './modules/pipeline/Pipelines/BasitViewPipeline.js' import SpeckleMesh from './modules/objects/SpeckleMesh.js' import SpeckleInstancedMesh from './modules/objects/SpeckleInstancedMesh.js' +import { StencilMaskPass } from './modules/pipeline/Passes/StencilMaskPass.js' +import { + DefaultEdgesPassOptions, + EdgesPass, + EdgesPassOptions +} from './modules/pipeline/Passes/EdgesPass.js' export { Viewer, @@ -207,12 +223,21 @@ export { ViewportPass, BlendPass, DepthNormalPass, - BasitPass, + ShadedPass as BasitPass, ProgressiveAOPass, TAAPass, StencilPass, StencilMaskPass, + EdgesPass, PassOptions, + EdgesPassOptions as EdgePassOptions, + BlendPassOptions, + DepthPassOptions, + OutputPassOptions, + ProgressiveAOPassOptions, + ViewportPassOptions, + DefaultEdgesPassOptions, + DefaultProgressiveAOPassOptions, ClearFlags, ObjectVisibility, InputType, diff --git a/packages/viewer/src/modules/extensions/ViewModes.ts b/packages/viewer/src/modules/extensions/ViewModes.ts index 70eaffd6a..6abfbb088 100644 --- a/packages/viewer/src/modules/extensions/ViewModes.ts +++ b/packages/viewer/src/modules/extensions/ViewModes.ts @@ -1,5 +1,5 @@ import { IViewer, UpdateFlags, ViewerEvent } from '../../IViewer.js' -import { BasitPass } from '../pipeline/Passes/BasitPass.js' +import { ShadedPass } from '../pipeline/Passes/ShadedPass.js' import { GPass } from '../pipeline/Passes/GPass.js' import { ArcticViewPipeline } from '../pipeline/Pipelines/ArcticViewPipeline.js' import { BasitPipeline } from '../pipeline/Pipelines/BasitViewPipeline.js' @@ -53,7 +53,7 @@ export class ViewModes extends Extension { .getRenderer() .pipeline.getPass('BASIT') .forEach((pass: GPass) => { - ;(pass as BasitPass).applyColorIndices() + ;(pass as ShadedPass).applyColorIndices() }) } }) diff --git a/packages/viewer/src/modules/pipeline/Passes/EdgesPass.ts b/packages/viewer/src/modules/pipeline/Passes/EdgesPass.ts index 0527aed56..f98c14186 100644 --- a/packages/viewer/src/modules/pipeline/Passes/EdgesPass.ts +++ b/packages/viewer/src/modules/pipeline/Passes/EdgesPass.ts @@ -15,7 +15,7 @@ import { speckleEdgesGeneratorFrag } from '../../materials/shaders/speckle-edges import { speckleEdgesGeneratorVert } from '../../materials/shaders/speckle-edges-generator-vert.js' import { Pipeline } from '../Pipelines/Pipeline.js' -export interface EdgePassOptions extends PassOptions { +export interface EdgesPassOptions extends PassOptions { depthMultiplier?: number depthBias?: number normalMultiplier?: number @@ -26,7 +26,7 @@ export interface EdgePassOptions extends PassOptions { backgroundTextureIntensity: number } -export const DefaultEdgePassOptions: Required = { +export const DefaultEdgesPassOptions: Required = { depthMultiplier: 1, depthBias: 0.001, normalMultiplier: 1, @@ -37,13 +37,16 @@ export const DefaultEdgePassOptions: Required = { backgroundTextureIntensity: 0 } -export class EdgePass extends BaseGPass { +export class EdgesPass extends BaseGPass { public edgesMaterial: ShaderMaterial private fsQuad: FullScreenQuad - public _options: Required = Object.assign({}, DefaultEdgePassOptions) + public _options: Required = Object.assign( + {}, + DefaultEdgesPassOptions + ) - public set options(value: EdgePassOptions) { + public set options(value: EdgesPassOptions) { super.options = value this.setBackground( this._options.backgroundTexture, diff --git a/packages/viewer/src/modules/pipeline/Passes/BasitPass.ts b/packages/viewer/src/modules/pipeline/Passes/ShadedPass.ts similarity index 99% rename from packages/viewer/src/modules/pipeline/Passes/BasitPass.ts rename to packages/viewer/src/modules/pipeline/Passes/ShadedPass.ts index 6925f2066..4db7b5a2c 100644 --- a/packages/viewer/src/modules/pipeline/Passes/BasitPass.ts +++ b/packages/viewer/src/modules/pipeline/Passes/ShadedPass.ts @@ -18,7 +18,7 @@ import SpeckleStandardColoredMaterial from '../../materials/SpeckleStandardColor import { Assets } from '../../../index.js' import SpeckleMesh from '../../objects/SpeckleMesh.js' -export class BasitPass extends BaseGPass { +export class ShadedPass extends BaseGPass { protected tree: WorldTree protected speckleRenderer: SpeckleRenderer protected materialMap: { diff --git a/packages/viewer/src/modules/pipeline/Pipelines/BasitViewPipeline.ts b/packages/viewer/src/modules/pipeline/Pipelines/BasitViewPipeline.ts index fbc25d2b0..d8dfb3cd8 100644 --- a/packages/viewer/src/modules/pipeline/Pipelines/BasitViewPipeline.ts +++ b/packages/viewer/src/modules/pipeline/Pipelines/BasitViewPipeline.ts @@ -2,7 +2,7 @@ import { ObjectLayers, WorldTree } from '../../../index.js' import SpeckleRenderer from '../../SpeckleRenderer.js' import { GeometryPass } from '../Passes/GeometryPass.js' import { Pipeline } from './Pipeline.js' -import { BasitPass } from '../Passes/BasitPass.js' +import { ShadedPass } from '../Passes/ShadedPass.js' import { ClearFlags, ObjectVisibility } from '../Passes/GPass.js' import { StencilPass } from '../Passes/StencilPass.js' import { StencilMaskPass } from '../Passes/StencilMaskPass.js' @@ -11,7 +11,7 @@ export class BasitPipeline extends Pipeline { constructor(speckleRenderer: SpeckleRenderer, tree: WorldTree) { super(speckleRenderer) - const basitPass = new BasitPass(tree, speckleRenderer) + const basitPass = new ShadedPass(tree, speckleRenderer) basitPass.setLayers([ObjectLayers.STREAM_CONTENT_MESH, ObjectLayers.PROPS]) basitPass.setClearColor(0x000000, 0) basitPass.setClearFlags(ClearFlags.COLOR) diff --git a/packages/viewer/src/modules/pipeline/Pipelines/EdgesPipeline.ts b/packages/viewer/src/modules/pipeline/Pipelines/EdgesPipeline.ts index a001c30ad..904070989 100644 --- a/packages/viewer/src/modules/pipeline/Pipelines/EdgesPipeline.ts +++ b/packages/viewer/src/modules/pipeline/Pipelines/EdgesPipeline.ts @@ -2,7 +2,7 @@ import SpeckleRenderer from '../../SpeckleRenderer.js' import { BlendPass } from '../Passes/BlendPass.js' import { GeometryPass } from '../Passes/GeometryPass.js' import { DepthPass } from '../Passes/DepthPass.js' -import { EdgePass } from '../Passes/EdgesPass.js' +import { EdgesPass } from '../Passes/EdgesPass.js' import { NormalsPass } from '../Passes/NormalsPass.js' import { ClearFlags, ObjectVisibility } from '../Passes/GPass.js' import { ProgressiveAOPass } from '../Passes/ProgressiveAOPass.js' @@ -71,11 +71,11 @@ export class EdgesPipeline extends ProgressivePipeline { progressiveAOPass.setClearColor(0xffffff, 1) progressiveAOPass.accumulationFrames = this.accumulationFrameCount - const edgesPass = new EdgePass() + const edgesPass = new EdgesPass() edgesPass.setTexture('tDepth', depthPass.outputTarget?.texture) edgesPass.setTexture('tNormal', normalPass.outputTarget?.texture) - const edgesPassDynamic = new EdgePass() + const edgesPassDynamic = new EdgesPass() edgesPassDynamic.setTexture('tDepth', depthPassDynamic.outputTarget?.texture) edgesPassDynamic.setTexture('tNormal', normalPassDynamic.outputTarget?.texture) diff --git a/packages/viewer/src/modules/pipeline/Pipelines/MRT/MRTEdgesPipeline.ts b/packages/viewer/src/modules/pipeline/Pipelines/MRT/MRTEdgesPipeline.ts index c80e0f5d9..6eecada97 100644 --- a/packages/viewer/src/modules/pipeline/Pipelines/MRT/MRTEdgesPipeline.ts +++ b/packages/viewer/src/modules/pipeline/Pipelines/MRT/MRTEdgesPipeline.ts @@ -1,7 +1,7 @@ import SpeckleRenderer from '../../../SpeckleRenderer.js' import { BlendPass } from '../../Passes/BlendPass.js' import { GeometryPass } from '../../Passes/GeometryPass.js' -import { EdgePass } from '../../Passes/EdgesPass.js' +import { EdgesPass } from '../../Passes/EdgesPass.js' import { ClearFlags, ObjectVisibility } from '../../Passes/GPass.js' import { ProgressiveAOPass } from '../../Passes/ProgressiveAOPass.js' import { TAAPass } from '../../Passes/TAAPass.js' @@ -56,12 +56,12 @@ export class MRTEdgesPipeline extends ProgressivePipeline { progressiveAOPass.setClearColor(0xffffff, 1) progressiveAOPass.accumulationFrames = this.accumulationFrameCount - const edgesPass = new EdgePass() + const edgesPass = new EdgesPass() edgesPass.setTexture('tDepth', depthNormalIdPass.depthTexture) edgesPass.setTexture('tNormal', depthNormalIdPass.normalTexture) edgesPass.setTexture('tId', depthNormalIdPass.idTexture) - const edgesPassDynamic = new EdgePass() + const edgesPassDynamic = new EdgesPass() edgesPassDynamic.setTexture('tDepth', depthPassNormalIdDynamic.depthTexture) edgesPassDynamic.setTexture('tNormal', depthPassNormalIdDynamic.normalTexture) edgesPassDynamic.setTexture('tId', depthPassNormalIdDynamic.idTexture) diff --git a/packages/viewer/src/modules/pipeline/Pipelines/MRT/MRTPenViewPipeline.ts b/packages/viewer/src/modules/pipeline/Pipelines/MRT/MRTPenViewPipeline.ts index ed9d57482..b889d9dfc 100644 --- a/packages/viewer/src/modules/pipeline/Pipelines/MRT/MRTPenViewPipeline.ts +++ b/packages/viewer/src/modules/pipeline/Pipelines/MRT/MRTPenViewPipeline.ts @@ -1,7 +1,7 @@ import { ObjectLayers } from '../../../../index.js' import SpeckleRenderer from '../../../SpeckleRenderer.js' import { GeometryPass } from '../../Passes/GeometryPass.js' -import { EdgePass } from '../../Passes/EdgesPass.js' +import { EdgesPass } from '../../Passes/EdgesPass.js' import { OutputPass } from '../../Passes/OutputPass.js' import { ObjectVisibility, ClearFlags } from '../../Passes/GPass.js' import { StencilMaskPass } from '../../Passes/StencilMaskPass.js' @@ -36,12 +36,12 @@ export class MRTPenViewPipeline extends ProgressivePipeline { depthNormalIdPassDynamic.setClearColor(0x000000, 1) depthNormalIdPassDynamic.setClearFlags(ClearFlags.COLOR | ClearFlags.DEPTH) - const edgesPass = new EdgePass() + const edgesPass = new EdgesPass() edgesPass.setTexture('tDepth', depthNormalIdPass.depthTexture) edgesPass.setTexture('tNormal', depthNormalIdPass.normalTexture) edgesPass.setTexture('tId', depthNormalIdPass.idTexture) - const edgesPassDynamic = new EdgePass() + const edgesPassDynamic = new EdgesPass() edgesPassDynamic.setTexture('tDepth', depthNormalIdPassDynamic.depthTexture) edgesPassDynamic.setTexture('tNormal', depthNormalIdPassDynamic.normalTexture) edgesPassDynamic.setTexture('tId', depthNormalIdPassDynamic.idTexture) diff --git a/packages/viewer/src/modules/pipeline/Pipelines/MRT/MRTShadedViewPipeline.ts b/packages/viewer/src/modules/pipeline/Pipelines/MRT/MRTShadedViewPipeline.ts index b4129f70c..2ed67db80 100644 --- a/packages/viewer/src/modules/pipeline/Pipelines/MRT/MRTShadedViewPipeline.ts +++ b/packages/viewer/src/modules/pipeline/Pipelines/MRT/MRTShadedViewPipeline.ts @@ -2,7 +2,7 @@ import { ObjectLayers, AssetType } from '../../../../index.js' import SpeckleRenderer from '../../../SpeckleRenderer.js' import { BlendPass } from '../../Passes/BlendPass.js' import { GeometryPass } from '../../Passes/GeometryPass.js' -import { EdgePass } from '../../Passes/EdgesPass.js' +import { EdgesPass } from '../../Passes/EdgesPass.js' import { ClearFlags, ObjectVisibility } from '../../Passes/GPass.js' import { StencilMaskPass } from '../../Passes/StencilMaskPass.js' import { StencilPass } from '../../Passes/StencilPass.js' @@ -50,12 +50,12 @@ export class MRTShadedViewPipeline extends ProgressivePipeline { const shadowcatcherPass = new GeometryPass() shadowcatcherPass.setLayers([ObjectLayers.SHADOWCATCHER]) - const edgesPass = new EdgePass() + const edgesPass = new EdgesPass() edgesPass.setTexture('tDepth', depthNormalIdPass.depthTexture) edgesPass.setTexture('tNormal', depthNormalIdPass.normalTexture) edgesPass.setTexture('tId', depthNormalIdPass.idTexture) - const edgesPassDynamic = new EdgePass() + const edgesPassDynamic = new EdgesPass() edgesPassDynamic.setTexture('tDepth', depthPassNormalIdDynamic.depthTexture) edgesPassDynamic.setTexture('tNormal', depthPassNormalIdDynamic.normalTexture) edgesPassDynamic.setTexture('tId', depthPassNormalIdDynamic.idTexture) diff --git a/packages/viewer/src/modules/pipeline/Pipelines/PenViewPipeline.ts b/packages/viewer/src/modules/pipeline/Pipelines/PenViewPipeline.ts index ef231872e..54614bdf3 100644 --- a/packages/viewer/src/modules/pipeline/Pipelines/PenViewPipeline.ts +++ b/packages/viewer/src/modules/pipeline/Pipelines/PenViewPipeline.ts @@ -1,6 +1,6 @@ import SpeckleRenderer from '../../SpeckleRenderer.js' import { DepthPass } from '../Passes/DepthPass.js' -import { EdgePass } from '../Passes/EdgesPass.js' +import { EdgesPass } from '../Passes/EdgesPass.js' import { NormalsPass } from '../Passes/NormalsPass.js' import { ClearFlags, ObjectVisibility } from '../Passes/GPass.js' import { TAAPass } from '../Passes/TAAPass.js' @@ -50,11 +50,11 @@ export class PenViewPipeline extends ProgressivePipeline { normalPassDynamic.setClearColor(0x000000, 1) normalPassDynamic.setClearFlags(ClearFlags.COLOR | ClearFlags.DEPTH) - const edgesPass = new EdgePass() + const edgesPass = new EdgesPass() edgesPass.setTexture('tDepth', depthPass.outputTarget?.texture) edgesPass.setTexture('tNormal', normalPass.outputTarget?.texture) - const edgesPassDynamic = new EdgePass() + const edgesPassDynamic = new EdgesPass() edgesPassDynamic.setTexture('tDepth', depthPassDynamic.outputTarget?.texture) edgesPassDynamic.setTexture('tNormal', normalPassDynamic.outputTarget?.texture) edgesPassDynamic.outputTarget = null diff --git a/packages/viewer/src/modules/pipeline/Pipelines/ShadedViewPipeline.ts b/packages/viewer/src/modules/pipeline/Pipelines/ShadedViewPipeline.ts index acb96d900..84a6a0c6d 100644 --- a/packages/viewer/src/modules/pipeline/Pipelines/ShadedViewPipeline.ts +++ b/packages/viewer/src/modules/pipeline/Pipelines/ShadedViewPipeline.ts @@ -1,7 +1,7 @@ import SpeckleRenderer from '../../SpeckleRenderer.js' import { BlendPass } from '../Passes/BlendPass.js' import { DepthPass } from '../Passes/DepthPass.js' -import { EdgePass } from '../Passes/EdgesPass.js' +import { EdgesPass } from '../Passes/EdgesPass.js' import { NormalsPass } from '../Passes/NormalsPass.js' import { TAAPass } from '../Passes/TAAPass.js' import { AssetType, ObjectLayers } from '../../../IViewer.js' @@ -60,11 +60,11 @@ export class ShadedViewPipeline extends ProgressivePipeline { const shadowcatcherPass = new GeometryPass() shadowcatcherPass.setLayers([ObjectLayers.SHADOWCATCHER]) - const edgesPass = new EdgePass() + const edgesPass = new EdgesPass() edgesPass.setTexture('tDepth', depthPass.outputTarget?.texture) edgesPass.setTexture('tNormal', normalPass.outputTarget?.texture) - const edgesPassDynamic = new EdgePass() + const edgesPassDynamic = new EdgesPass() edgesPassDynamic.setTexture('tDepth', depthPassDynamic.outputTarget?.texture) edgesPassDynamic.setTexture('tNormal', normalPassDynamic.outputTarget?.texture) diff --git a/packages/viewer/src/modules/pipeline/Pipelines/TechnicalViewPipeline.ts b/packages/viewer/src/modules/pipeline/Pipelines/TechnicalViewPipeline.ts index da66df40e..b6fcc58cb 100644 --- a/packages/viewer/src/modules/pipeline/Pipelines/TechnicalViewPipeline.ts +++ b/packages/viewer/src/modules/pipeline/Pipelines/TechnicalViewPipeline.ts @@ -1,7 +1,7 @@ import { BackSide, NoBlending, WebGLRenderTarget } from 'three' import SpeckleRenderer from '../../SpeckleRenderer.js' import { DepthPass } from '../Passes/DepthPass.js' -import { EdgePass } from '../Passes/EdgesPass.js' +import { EdgesPass } from '../Passes/EdgesPass.js' import { NormalsPass } from '../Passes/NormalsPass.js' import { ObjectVisibility } from '../Passes/GPass.js' import { TAAPass } from '../Passes/TAAPass.js' @@ -56,17 +56,17 @@ export class TechnicalViewPipeline extends ProgressivePipeline { normalPassBackDynamic.overrideMaterial.side = BackSide // normalPassBackDynamic.overrideMaterial.depthTest = false - const edgesPassFront = new EdgePass() + const edgesPassFront = new EdgesPass() edgesPassFront.setTexture('tDepth', depthPassFront.outputTarget?.texture) edgesPassFront.setTexture('tNormal', normalPassFront.outputTarget?.texture) - const edgesPassBack = new EdgePass() + const edgesPassBack = new EdgesPass() edgesPassBack.setTexture('tDepth', depthPassBack.outputTarget?.texture) edgesPassBack.setTexture('tNormal', normalPassBack.outputTarget?.texture) edgesPassBack.edgesMaterial.uniforms.uOutlineDensity.value = 0.25 edgesPassBack.edgesMaterial.needsUpdate = true - const edgesPassFrontDynamic = new EdgePass() + const edgesPassFrontDynamic = new EdgesPass() edgesPassFrontDynamic.setTexture( 'tDepth', depthPassFrontDynamic.outputTarget?.texture @@ -76,7 +76,7 @@ export class TechnicalViewPipeline extends ProgressivePipeline { normalPassFrontDynamic.outputTarget?.texture ) - const edgesPassBackDynamic = new EdgePass() + const edgesPassBackDynamic = new EdgesPass() edgesPassBackDynamic.setTexture( 'tDepth', depthPassBackDynamic.outputTarget?.texture