diff --git a/demo/nextjs/app/dashboard/change-plan.tsx b/demo/nextjs/app/dashboard/change-plan.tsx index f4509660..376e955f 100644 --- a/demo/nextjs/app/dashboard/change-plan.tsx +++ b/demo/nextjs/app/dashboard/change-plan.tsx @@ -172,9 +172,16 @@ function Component(props: { variant="destructive" className="w-full" onClick={async () => { - await client.subscription.cancel({ - returnUrl: "/dashboard", - }); + await client.subscription.cancel( + { + returnUrl: "/dashboard", + }, + { + onError: (ctx) => { + toast.error(ctx.error.message); + }, + }, + ); }} > Cancel Plan diff --git a/demo/nextjs/lib/auth-client.ts b/demo/nextjs/lib/auth-client.ts index 1f1a7f8b..81365f20 100644 --- a/demo/nextjs/lib/auth-client.ts +++ b/demo/nextjs/lib/auth-client.ts @@ -26,7 +26,7 @@ export const client = createAuthClient({ oneTapClient({ clientId: process.env.NEXT_PUBLIC_GOOGLE_CLIENT_ID!, promptOptions: { - maxAttempts: 2, + maxAttempts: 1, }, }), oidcClient(), diff --git a/demo/nextjs/lib/auth.ts b/demo/nextjs/lib/auth.ts index bc72797c..0864f951 100644 --- a/demo/nextjs/lib/auth.ts +++ b/demo/nextjs/lib/auth.ts @@ -20,7 +20,6 @@ import { nextCookies } from "better-auth/next-js"; import { passkey } from "better-auth/plugins/passkey"; import { stripe } from "@better-auth/stripe"; import { Stripe } from "stripe"; -import Database from "better-sqlite3"; const from = process.env.BETTER_AUTH_EMAIL || "delivered@resend.dev"; const to = process.env.TEST_EMAIL || ""; @@ -51,7 +50,10 @@ const STARTER_PRICE_ID = { export const auth = betterAuth({ appName: "Better Auth Demo", - database: new Database("./stripe.db"), + database: { + dialect, + type: process.env.USE_MYSQL ? "mysql" : "sqlite", + }, emailVerification: { async sendVerificationEmail({ user, url }) { const res = await resend.emails.send({ @@ -163,7 +165,6 @@ export const auth = betterAuth({ stripeWebhookSecret: process.env.STRIPE_WEBHOOK_SECRET!, subscription: { enabled: true, - requireEmailVerification: true, plans: [ { name: "Starter", diff --git a/docs/components/side-bar.tsx b/docs/components/side-bar.tsx index 28fda945..31ae4b71 100644 --- a/docs/components/side-bar.tsx +++ b/docs/components/side-bar.tsx @@ -230,7 +230,10 @@ function NewBadge({ isSelected }: { isSelected?: boolean }) { return (
New diff --git a/packages/better-auth/src/plugins/multi-session/index.ts b/packages/better-auth/src/plugins/multi-session/index.ts index 40818e49..e3ebc194 100644 --- a/packages/better-auth/src/plugins/multi-session/index.ts +++ b/packages/better-auth/src/plugins/multi-session/index.ts @@ -62,11 +62,9 @@ export const multiSession = (options?: MultiSessionConfig) => { ).filter((v) => v !== null); const s = await getSessionFromCtx(ctx); - console.log({ sessionTokens, s }); if (!sessionTokens.length) return ctx.json([]); const sessions = await ctx.context.internalAdapter.findSessions(sessionTokens); - console.log({ sessions }); const validSessions = sessions.filter( (session) => session && session.session.expiresAt > new Date(), ); diff --git a/packages/better-auth/src/plugins/organization/organization.ts b/packages/better-auth/src/plugins/organization/organization.ts index cae60d6b..32eccb7b 100644 --- a/packages/better-auth/src/plugins/organization/organization.ts +++ b/packages/better-auth/src/plugins/organization/organization.ts @@ -574,12 +574,16 @@ export const organization = (options?: O) => { defaultValue: "member", fieldName: options?.schema?.member?.fields?.role, }, - teamId: { - type: "string", - required: false, - sortable: true, - fieldName: options?.schema?.member?.fields?.teamId, - }, + ...(teamSupport + ? { + teamId: { + type: "string", + required: false, + sortable: true, + fieldName: options?.schema?.member?.fields?.teamId, + }, + } + : {}), createdAt: { type: "date", required: true, @@ -611,12 +615,16 @@ export const organization = (options?: O) => { sortable: true, fieldName: options?.schema?.invitation?.fields?.role, }, - teamId: { - type: "string", - required: false, - sortable: true, - fieldName: options?.schema?.invitation?.fields?.teamId, - }, + ...(teamSupport + ? { + teamId: { + type: "string", + required: false, + sortable: true, + fieldName: options?.schema?.invitation?.fields?.teamId, + }, + } + : {}), status: { type: "string", required: true, diff --git a/packages/stripe/src/hooks.ts b/packages/stripe/src/hooks.ts index 7975f2fb..cb7ae1dc 100644 --- a/packages/stripe/src/hooks.ts +++ b/packages/stripe/src/hooks.ts @@ -93,59 +93,61 @@ export async function onSubscriptionUpdated( const subscriptionUpdated = event.data.object as Stripe.Subscription; const priceId = subscriptionUpdated.items.data[0].price.id; const plan = await getPlanByPriceId(options, priceId); - if (plan) { - const stripeId = subscriptionUpdated.id; - const subscription = await ctx.context.adapter.findOne({ - model: "subscription", - where: [ - { - field: "stripeSubscriptionId", - value: stripeId, - }, - ], - }); - if (!subscription) { - return; - } - const seats = subscriptionUpdated.items.data[0].quantity; - await ctx.context.adapter.update({ - model: "subscription", - update: { - plan: plan.name.toLowerCase(), - limits: plan.limits, - updatedAt: new Date(), - status: subscriptionUpdated.status, - periodStart: new Date( - subscriptionUpdated.current_period_start * 1000, - ), - periodEnd: new Date(subscriptionUpdated.current_period_end * 1000), - cancelAtPeriodEnd: subscriptionUpdated.cancel_at_period_end, - seats, + const stripeId = subscriptionUpdated.id; + const subscription = await ctx.context.adapter.findOne({ + model: "subscription", + where: [ + { + field: "stripeSubscriptionId", + value: stripeId, }, - where: [ - { - field: "stripeSubscriptionId", - value: subscriptionUpdated.id, - }, - ], - }); - const subscriptionCanceled = - subscriptionUpdated.status === "active" && - subscriptionUpdated.cancel_at_period_end; - if (subscriptionCanceled) { - await options.subscription.onSubscriptionCancel?.({ - subscription, - cancellationDetails: - subscriptionUpdated.cancellation_details || undefined, - stripeSubscription: subscriptionUpdated, - event, - }); - } - await options.subscription.onSubscriptionUpdate?.({ - event, + ], + }); + if (!subscription) { + return; + } + const seats = subscriptionUpdated.items.data[0].quantity; + await ctx.context.adapter.update({ + model: "subscription", + update: { + ...(plan + ? { + plan: plan.name.toLowerCase(), + limits: plan.limits, + } + : {}), + updatedAt: new Date(), + status: subscriptionUpdated.status, + periodStart: new Date(subscriptionUpdated.current_period_start * 1000), + periodEnd: new Date(subscriptionUpdated.current_period_end * 1000), + cancelAtPeriodEnd: subscriptionUpdated.cancel_at_period_end, + seats, + }, + where: [ + { + field: "stripeSubscriptionId", + value: subscriptionUpdated.id, + }, + ], + }); + const subscriptionCanceled = + subscriptionUpdated.status === "active" && + subscriptionUpdated.cancel_at_period_end && + !subscription.cancelAtPeriodEnd; //if this is true, it means the subscription was canceled before the event was triggered + if (subscriptionCanceled) { + await options.subscription.onSubscriptionCancel?.({ subscription, + cancellationDetails: + subscriptionUpdated.cancellation_details || undefined, + stripeSubscription: subscriptionUpdated, + event, }); - + } + await options.subscription.onSubscriptionUpdate?.({ + event, + subscription, + }); + if (plan) { if ( subscriptionUpdated.status === "active" && subscription.status === "trialing" && diff --git a/packages/stripe/src/index.ts b/packages/stripe/src/index.ts index e0379fac..19e7dc42 100644 --- a/packages/stripe/src/index.ts +++ b/packages/stripe/src/index.ts @@ -335,6 +335,85 @@ export const stripe = (options: O) => { }); }, ), + cancelSubscriptionCallback: createAuthEndpoint( + "/subscription/cancel/callback", + { + method: "GET", + query: z.record(z.string(), z.any()).optional(), + }, + async (ctx) => { + if (!ctx.query || !ctx.query.callbackURL || !ctx.query.reference) { + throw ctx.redirect(getUrl(ctx, ctx.query?.callbackURL || "/")); + } + const session = await getSessionFromCtx<{ stripeCustomerId: string }>( + ctx, + ); + if (!session) { + throw ctx.redirect(getUrl(ctx, ctx.query?.callbackURL || "/")); + } + const { user } = session; + const { callbackURL, reference } = ctx.query; + + if (user?.stripeCustomerId) { + try { + const subscription = + await ctx.context.adapter.findOne({ + model: "subscription", + where: [ + { + field: "referenceId", + value: reference, + }, + ], + }); + console.log({ subscription }); + if ( + !subscription || + subscription.cancelAtPeriodEnd || + subscription.status === "canceled" + ) { + throw ctx.redirect(getUrl(ctx, callbackURL)); + } + + const stripeSubscription = await client.subscriptions.list({ + customer: user.stripeCustomerId, + status: "active", + }); + const currentSubscription = stripeSubscription.data.find( + (sub) => sub.id === subscription.stripeSubscriptionId, + ); + console.log({ currentSubscription }); + if (currentSubscription?.cancel_at_period_end === true) { + await ctx.context.adapter.update({ + model: "subscription", + update: { + status: currentSubscription?.status, + cancelAtPeriodEnd: true, + }, + where: [ + { + field: "referenceId", + value: reference, + }, + ], + }); + await options.subscription?.onSubscriptionCancel?.({ + subscription, + cancellationDetails: currentSubscription.cancellation_details, + stripeSubscription: currentSubscription, + event: undefined, + }); + } + } catch (error) { + ctx.context.logger.error( + "Error checking subscription status from Stripe", + error, + ); + } + } + throw ctx.redirect(getUrl(ctx, callbackURL)); + }, + ), cancelSubscription: createAuthEndpoint( "/subscription/cancel", { @@ -377,16 +456,50 @@ export const stripe = (options: O) => { message: STRIPE_ERROR_CODES.SUBSCRIPTION_NOT_FOUND, }); } - const { url } = await client.billingPortal.sessions.create({ - customer: subscription.stripeCustomerId, - return_url: getUrl(ctx, ctx.body?.returnUrl || "/"), - flow_data: { - type: "subscription_cancel", - subscription_cancel: { - subscription: activeSubscription.id, + const { url } = await client.billingPortal.sessions + .create({ + customer: subscription.stripeCustomerId, + return_url: getUrl( + ctx, + `${ + ctx.context.baseURL + }/subscription/cancel/callback?callbackURL=${encodeURIComponent( + ctx.body?.returnUrl || "/", + )}&reference=${encodeURIComponent(referenceId)}`, + ), + flow_data: { + type: "subscription_cancel", + subscription_cancel: { + subscription: activeSubscription.id, + }, }, - }, - }); + }) + .catch(async (e) => { + if (e.message.includes("already set to be cancel")) { + /** + * incase we missed the event from stripe, we set it manually + * this is a rare case and should not happen + */ + if (!subscription.cancelAtPeriodEnd) { + await ctx.context.adapter.update({ + model: "subscription", + update: { + cancelAtPeriodEnd: true, + }, + where: [ + { + field: "referenceId", + value: referenceId, + }, + ], + }); + } + } + throw ctx.error("BAD_REQUEST", { + message: e.message, + code: e.code, + }); + }); return { url, redirect: true, @@ -509,6 +622,9 @@ export const stripe = (options: O) => { status: stripeSubscription.status, seats: stripeSubscription.items.data[0]?.quantity || 1, plan: plan.name.toLowerCase(), + periodEnd: stripeSubscription.current_period_end, + periodStart: stripeSubscription.current_period_start, + stripeSubscriptionId: stripeSubscription.id, }, where: [ { diff --git a/packages/stripe/src/stripe.test.ts b/packages/stripe/src/stripe.test.ts index de4dbdb0..db5b4098 100644 --- a/packages/stripe/src/stripe.test.ts +++ b/packages/stripe/src/stripe.test.ts @@ -600,6 +600,42 @@ describe("stripe", async () => { }), ); + const userCancelEvent = { + type: "customer.subscription.updated", + data: { + object: { + id: "sub_123", + customer: "cus_123", + status: "active", + cancel_at_period_end: true, + cancellation_details: { + reason: "cancellation_requested", + comment: "Customer canceled subscription", + }, + items: { + data: [{ price: { id: process.env.STRIPE_PRICE_ID_1 } }], + }, + current_period_start: Math.floor(Date.now() / 1000), + current_period_end: Math.floor(Date.now() / 1000) + 30 * 24 * 60 * 60, + }, + }, + }; + + const userCancelRequest = new Request( + "http://localhost:3000/api/auth/stripe/webhook", + { + method: "POST", + headers: { + "stripe-signature": "test_signature", + }, + body: JSON.stringify(userCancelEvent), + }, + ); + + mockStripeForEvents.webhooks.constructEvent.mockReturnValue( + userCancelEvent, + ); + await eventTestAuth.handler(userCancelRequest); const cancelEvent = { type: "customer.subscription.updated", data: { diff --git a/packages/stripe/src/types.ts b/packages/stripe/src/types.ts index b9784b58..2bab978b 100644 --- a/packages/stripe/src/types.ts +++ b/packages/stripe/src/types.ts @@ -247,10 +247,10 @@ export interface StripeOptions { * @returns */ onSubscriptionCancel?: (data: { - event: Stripe.Event; + event?: Stripe.Event; subscription: Subscription; stripeSubscription: Stripe.Subscription; - cancellationDetails?: Stripe.Subscription.CancellationDetails; + cancellationDetails?: Stripe.Subscription.CancellationDetails | null; }) => Promise; /** * A function to check if the reference id is valid