fix(stripe): improve subscription cancellation handling and add callback endpoint

This commit is contained in:
Bereket Engida
2025-03-02 11:39:57 +03:00
parent 1d700f38f1
commit 1c91294e23
10 changed files with 254 additions and 83 deletions

View File

@@ -172,9 +172,16 @@ function Component(props: {
variant="destructive" variant="destructive"
className="w-full" className="w-full"
onClick={async () => { onClick={async () => {
await client.subscription.cancel({ await client.subscription.cancel(
returnUrl: "/dashboard", {
}); returnUrl: "/dashboard",
},
{
onError: (ctx) => {
toast.error(ctx.error.message);
},
},
);
}} }}
> >
Cancel Plan Cancel Plan

View File

@@ -26,7 +26,7 @@ export const client = createAuthClient({
oneTapClient({ oneTapClient({
clientId: process.env.NEXT_PUBLIC_GOOGLE_CLIENT_ID!, clientId: process.env.NEXT_PUBLIC_GOOGLE_CLIENT_ID!,
promptOptions: { promptOptions: {
maxAttempts: 2, maxAttempts: 1,
}, },
}), }),
oidcClient(), oidcClient(),

View File

@@ -20,7 +20,6 @@ import { nextCookies } from "better-auth/next-js";
import { passkey } from "better-auth/plugins/passkey"; import { passkey } from "better-auth/plugins/passkey";
import { stripe } from "@better-auth/stripe"; import { stripe } from "@better-auth/stripe";
import { Stripe } from "stripe"; import { Stripe } from "stripe";
import Database from "better-sqlite3";
const from = process.env.BETTER_AUTH_EMAIL || "delivered@resend.dev"; const from = process.env.BETTER_AUTH_EMAIL || "delivered@resend.dev";
const to = process.env.TEST_EMAIL || ""; const to = process.env.TEST_EMAIL || "";
@@ -51,7 +50,10 @@ const STARTER_PRICE_ID = {
export const auth = betterAuth({ export const auth = betterAuth({
appName: "Better Auth Demo", appName: "Better Auth Demo",
database: new Database("./stripe.db"), database: {
dialect,
type: process.env.USE_MYSQL ? "mysql" : "sqlite",
},
emailVerification: { emailVerification: {
async sendVerificationEmail({ user, url }) { async sendVerificationEmail({ user, url }) {
const res = await resend.emails.send({ const res = await resend.emails.send({
@@ -163,7 +165,6 @@ export const auth = betterAuth({
stripeWebhookSecret: process.env.STRIPE_WEBHOOK_SECRET!, stripeWebhookSecret: process.env.STRIPE_WEBHOOK_SECRET!,
subscription: { subscription: {
enabled: true, enabled: true,
requireEmailVerification: true,
plans: [ plans: [
{ {
name: "Starter", name: "Starter",

View File

@@ -230,7 +230,10 @@ function NewBadge({ isSelected }: { isSelected?: boolean }) {
return ( return (
<div className="flex items-center justify-end w-full"> <div className="flex items-center justify-end w-full">
<Badge <Badge
className=" pointer-events-none !no-underline border-dashed !decoration-transparent" className={cn(
" pointer-events-none !no-underline border-dashed !decoration-transparent",
isSelected && "!border-solid",
)}
variant={isSelected ? "default" : "outline"} variant={isSelected ? "default" : "outline"}
> >
New New

View File

@@ -62,11 +62,9 @@ export const multiSession = (options?: MultiSessionConfig) => {
).filter((v) => v !== null); ).filter((v) => v !== null);
const s = await getSessionFromCtx(ctx); const s = await getSessionFromCtx(ctx);
console.log({ sessionTokens, s });
if (!sessionTokens.length) return ctx.json([]); if (!sessionTokens.length) return ctx.json([]);
const sessions = const sessions =
await ctx.context.internalAdapter.findSessions(sessionTokens); await ctx.context.internalAdapter.findSessions(sessionTokens);
console.log({ sessions });
const validSessions = sessions.filter( const validSessions = sessions.filter(
(session) => session && session.session.expiresAt > new Date(), (session) => session && session.session.expiresAt > new Date(),
); );

View File

@@ -574,12 +574,16 @@ export const organization = <O extends OrganizationOptions>(options?: O) => {
defaultValue: "member", defaultValue: "member",
fieldName: options?.schema?.member?.fields?.role, fieldName: options?.schema?.member?.fields?.role,
}, },
teamId: { ...(teamSupport
type: "string", ? {
required: false, teamId: {
sortable: true, type: "string",
fieldName: options?.schema?.member?.fields?.teamId, required: false,
}, sortable: true,
fieldName: options?.schema?.member?.fields?.teamId,
},
}
: {}),
createdAt: { createdAt: {
type: "date", type: "date",
required: true, required: true,
@@ -611,12 +615,16 @@ export const organization = <O extends OrganizationOptions>(options?: O) => {
sortable: true, sortable: true,
fieldName: options?.schema?.invitation?.fields?.role, fieldName: options?.schema?.invitation?.fields?.role,
}, },
teamId: { ...(teamSupport
type: "string", ? {
required: false, teamId: {
sortable: true, type: "string",
fieldName: options?.schema?.invitation?.fields?.teamId, required: false,
}, sortable: true,
fieldName: options?.schema?.invitation?.fields?.teamId,
},
}
: {}),
status: { status: {
type: "string", type: "string",
required: true, required: true,

View File

@@ -93,59 +93,61 @@ export async function onSubscriptionUpdated(
const subscriptionUpdated = event.data.object as Stripe.Subscription; const subscriptionUpdated = event.data.object as Stripe.Subscription;
const priceId = subscriptionUpdated.items.data[0].price.id; const priceId = subscriptionUpdated.items.data[0].price.id;
const plan = await getPlanByPriceId(options, priceId); const plan = await getPlanByPriceId(options, priceId);
if (plan) { const stripeId = subscriptionUpdated.id;
const stripeId = subscriptionUpdated.id; const subscription = await ctx.context.adapter.findOne<Subscription>({
const subscription = await ctx.context.adapter.findOne<Subscription>({ model: "subscription",
model: "subscription", where: [
where: [ {
{ field: "stripeSubscriptionId",
field: "stripeSubscriptionId", value: stripeId,
value: stripeId,
},
],
});
if (!subscription) {
return;
}
const seats = subscriptionUpdated.items.data[0].quantity;
await ctx.context.adapter.update({
model: "subscription",
update: {
plan: plan.name.toLowerCase(),
limits: plan.limits,
updatedAt: new Date(),
status: subscriptionUpdated.status,
periodStart: new Date(
subscriptionUpdated.current_period_start * 1000,
),
periodEnd: new Date(subscriptionUpdated.current_period_end * 1000),
cancelAtPeriodEnd: subscriptionUpdated.cancel_at_period_end,
seats,
}, },
where: [ ],
{ });
field: "stripeSubscriptionId", if (!subscription) {
value: subscriptionUpdated.id, return;
}, }
], const seats = subscriptionUpdated.items.data[0].quantity;
}); await ctx.context.adapter.update({
const subscriptionCanceled = model: "subscription",
subscriptionUpdated.status === "active" && update: {
subscriptionUpdated.cancel_at_period_end; ...(plan
if (subscriptionCanceled) { ? {
await options.subscription.onSubscriptionCancel?.({ plan: plan.name.toLowerCase(),
subscription, limits: plan.limits,
cancellationDetails: }
subscriptionUpdated.cancellation_details || undefined, : {}),
stripeSubscription: subscriptionUpdated, updatedAt: new Date(),
event, status: subscriptionUpdated.status,
}); periodStart: new Date(subscriptionUpdated.current_period_start * 1000),
} periodEnd: new Date(subscriptionUpdated.current_period_end * 1000),
await options.subscription.onSubscriptionUpdate?.({ cancelAtPeriodEnd: subscriptionUpdated.cancel_at_period_end,
event, seats,
},
where: [
{
field: "stripeSubscriptionId",
value: subscriptionUpdated.id,
},
],
});
const subscriptionCanceled =
subscriptionUpdated.status === "active" &&
subscriptionUpdated.cancel_at_period_end &&
!subscription.cancelAtPeriodEnd; //if this is true, it means the subscription was canceled before the event was triggered
if (subscriptionCanceled) {
await options.subscription.onSubscriptionCancel?.({
subscription, subscription,
cancellationDetails:
subscriptionUpdated.cancellation_details || undefined,
stripeSubscription: subscriptionUpdated,
event,
}); });
}
await options.subscription.onSubscriptionUpdate?.({
event,
subscription,
});
if (plan) {
if ( if (
subscriptionUpdated.status === "active" && subscriptionUpdated.status === "active" &&
subscription.status === "trialing" && subscription.status === "trialing" &&

View File

@@ -335,6 +335,85 @@ export const stripe = <O extends StripeOptions>(options: O) => {
}); });
}, },
), ),
cancelSubscriptionCallback: createAuthEndpoint(
"/subscription/cancel/callback",
{
method: "GET",
query: z.record(z.string(), z.any()).optional(),
},
async (ctx) => {
if (!ctx.query || !ctx.query.callbackURL || !ctx.query.reference) {
throw ctx.redirect(getUrl(ctx, ctx.query?.callbackURL || "/"));
}
const session = await getSessionFromCtx<{ stripeCustomerId: string }>(
ctx,
);
if (!session) {
throw ctx.redirect(getUrl(ctx, ctx.query?.callbackURL || "/"));
}
const { user } = session;
const { callbackURL, reference } = ctx.query;
if (user?.stripeCustomerId) {
try {
const subscription =
await ctx.context.adapter.findOne<Subscription>({
model: "subscription",
where: [
{
field: "referenceId",
value: reference,
},
],
});
console.log({ subscription });
if (
!subscription ||
subscription.cancelAtPeriodEnd ||
subscription.status === "canceled"
) {
throw ctx.redirect(getUrl(ctx, callbackURL));
}
const stripeSubscription = await client.subscriptions.list({
customer: user.stripeCustomerId,
status: "active",
});
const currentSubscription = stripeSubscription.data.find(
(sub) => sub.id === subscription.stripeSubscriptionId,
);
console.log({ currentSubscription });
if (currentSubscription?.cancel_at_period_end === true) {
await ctx.context.adapter.update({
model: "subscription",
update: {
status: currentSubscription?.status,
cancelAtPeriodEnd: true,
},
where: [
{
field: "referenceId",
value: reference,
},
],
});
await options.subscription?.onSubscriptionCancel?.({
subscription,
cancellationDetails: currentSubscription.cancellation_details,
stripeSubscription: currentSubscription,
event: undefined,
});
}
} catch (error) {
ctx.context.logger.error(
"Error checking subscription status from Stripe",
error,
);
}
}
throw ctx.redirect(getUrl(ctx, callbackURL));
},
),
cancelSubscription: createAuthEndpoint( cancelSubscription: createAuthEndpoint(
"/subscription/cancel", "/subscription/cancel",
{ {
@@ -377,16 +456,50 @@ export const stripe = <O extends StripeOptions>(options: O) => {
message: STRIPE_ERROR_CODES.SUBSCRIPTION_NOT_FOUND, message: STRIPE_ERROR_CODES.SUBSCRIPTION_NOT_FOUND,
}); });
} }
const { url } = await client.billingPortal.sessions.create({ const { url } = await client.billingPortal.sessions
customer: subscription.stripeCustomerId, .create({
return_url: getUrl(ctx, ctx.body?.returnUrl || "/"), customer: subscription.stripeCustomerId,
flow_data: { return_url: getUrl(
type: "subscription_cancel", ctx,
subscription_cancel: { `${
subscription: activeSubscription.id, ctx.context.baseURL
}/subscription/cancel/callback?callbackURL=${encodeURIComponent(
ctx.body?.returnUrl || "/",
)}&reference=${encodeURIComponent(referenceId)}`,
),
flow_data: {
type: "subscription_cancel",
subscription_cancel: {
subscription: activeSubscription.id,
},
}, },
}, })
}); .catch(async (e) => {
if (e.message.includes("already set to be cancel")) {
/**
* incase we missed the event from stripe, we set it manually
* this is a rare case and should not happen
*/
if (!subscription.cancelAtPeriodEnd) {
await ctx.context.adapter.update({
model: "subscription",
update: {
cancelAtPeriodEnd: true,
},
where: [
{
field: "referenceId",
value: referenceId,
},
],
});
}
}
throw ctx.error("BAD_REQUEST", {
message: e.message,
code: e.code,
});
});
return { return {
url, url,
redirect: true, redirect: true,
@@ -509,6 +622,9 @@ export const stripe = <O extends StripeOptions>(options: O) => {
status: stripeSubscription.status, status: stripeSubscription.status,
seats: stripeSubscription.items.data[0]?.quantity || 1, seats: stripeSubscription.items.data[0]?.quantity || 1,
plan: plan.name.toLowerCase(), plan: plan.name.toLowerCase(),
periodEnd: stripeSubscription.current_period_end,
periodStart: stripeSubscription.current_period_start,
stripeSubscriptionId: stripeSubscription.id,
}, },
where: [ where: [
{ {

View File

@@ -600,6 +600,42 @@ describe("stripe", async () => {
}), }),
); );
const userCancelEvent = {
type: "customer.subscription.updated",
data: {
object: {
id: "sub_123",
customer: "cus_123",
status: "active",
cancel_at_period_end: true,
cancellation_details: {
reason: "cancellation_requested",
comment: "Customer canceled subscription",
},
items: {
data: [{ price: { id: process.env.STRIPE_PRICE_ID_1 } }],
},
current_period_start: Math.floor(Date.now() / 1000),
current_period_end: Math.floor(Date.now() / 1000) + 30 * 24 * 60 * 60,
},
},
};
const userCancelRequest = new Request(
"http://localhost:3000/api/auth/stripe/webhook",
{
method: "POST",
headers: {
"stripe-signature": "test_signature",
},
body: JSON.stringify(userCancelEvent),
},
);
mockStripeForEvents.webhooks.constructEvent.mockReturnValue(
userCancelEvent,
);
await eventTestAuth.handler(userCancelRequest);
const cancelEvent = { const cancelEvent = {
type: "customer.subscription.updated", type: "customer.subscription.updated",
data: { data: {

View File

@@ -247,10 +247,10 @@ export interface StripeOptions {
* @returns * @returns
*/ */
onSubscriptionCancel?: (data: { onSubscriptionCancel?: (data: {
event: Stripe.Event; event?: Stripe.Event;
subscription: Subscription; subscription: Subscription;
stripeSubscription: Stripe.Subscription; stripeSubscription: Stripe.Subscription;
cancellationDetails?: Stripe.Subscription.CancellationDetails; cancellationDetails?: Stripe.Subscription.CancellationDetails | null;
}) => Promise<void>; }) => Promise<void>;
/** /**
* A function to check if the reference id is valid * A function to check if the reference id is valid