fix(stripe): improve subscription cancellation handling and add callback endpoint

This commit is contained in:
Bereket Engida
2025-03-02 11:39:57 +03:00
parent 1d700f38f1
commit 1c91294e23
10 changed files with 254 additions and 83 deletions

View File

@@ -172,9 +172,16 @@ function Component(props: {
variant="destructive"
className="w-full"
onClick={async () => {
await client.subscription.cancel({
returnUrl: "/dashboard",
});
await client.subscription.cancel(
{
returnUrl: "/dashboard",
},
{
onError: (ctx) => {
toast.error(ctx.error.message);
},
},
);
}}
>
Cancel Plan

View File

@@ -26,7 +26,7 @@ export const client = createAuthClient({
oneTapClient({
clientId: process.env.NEXT_PUBLIC_GOOGLE_CLIENT_ID!,
promptOptions: {
maxAttempts: 2,
maxAttempts: 1,
},
}),
oidcClient(),

View File

@@ -20,7 +20,6 @@ import { nextCookies } from "better-auth/next-js";
import { passkey } from "better-auth/plugins/passkey";
import { stripe } from "@better-auth/stripe";
import { Stripe } from "stripe";
import Database from "better-sqlite3";
const from = process.env.BETTER_AUTH_EMAIL || "delivered@resend.dev";
const to = process.env.TEST_EMAIL || "";
@@ -51,7 +50,10 @@ const STARTER_PRICE_ID = {
export const auth = betterAuth({
appName: "Better Auth Demo",
database: new Database("./stripe.db"),
database: {
dialect,
type: process.env.USE_MYSQL ? "mysql" : "sqlite",
},
emailVerification: {
async sendVerificationEmail({ user, url }) {
const res = await resend.emails.send({
@@ -163,7 +165,6 @@ export const auth = betterAuth({
stripeWebhookSecret: process.env.STRIPE_WEBHOOK_SECRET!,
subscription: {
enabled: true,
requireEmailVerification: true,
plans: [
{
name: "Starter",

View File

@@ -230,7 +230,10 @@ function NewBadge({ isSelected }: { isSelected?: boolean }) {
return (
<div className="flex items-center justify-end w-full">
<Badge
className=" pointer-events-none !no-underline border-dashed !decoration-transparent"
className={cn(
" pointer-events-none !no-underline border-dashed !decoration-transparent",
isSelected && "!border-solid",
)}
variant={isSelected ? "default" : "outline"}
>
New

View File

@@ -62,11 +62,9 @@ export const multiSession = (options?: MultiSessionConfig) => {
).filter((v) => v !== null);
const s = await getSessionFromCtx(ctx);
console.log({ sessionTokens, s });
if (!sessionTokens.length) return ctx.json([]);
const sessions =
await ctx.context.internalAdapter.findSessions(sessionTokens);
console.log({ sessions });
const validSessions = sessions.filter(
(session) => session && session.session.expiresAt > new Date(),
);

View File

@@ -574,12 +574,16 @@ export const organization = <O extends OrganizationOptions>(options?: O) => {
defaultValue: "member",
fieldName: options?.schema?.member?.fields?.role,
},
teamId: {
type: "string",
required: false,
sortable: true,
fieldName: options?.schema?.member?.fields?.teamId,
},
...(teamSupport
? {
teamId: {
type: "string",
required: false,
sortable: true,
fieldName: options?.schema?.member?.fields?.teamId,
},
}
: {}),
createdAt: {
type: "date",
required: true,
@@ -611,12 +615,16 @@ export const organization = <O extends OrganizationOptions>(options?: O) => {
sortable: true,
fieldName: options?.schema?.invitation?.fields?.role,
},
teamId: {
type: "string",
required: false,
sortable: true,
fieldName: options?.schema?.invitation?.fields?.teamId,
},
...(teamSupport
? {
teamId: {
type: "string",
required: false,
sortable: true,
fieldName: options?.schema?.invitation?.fields?.teamId,
},
}
: {}),
status: {
type: "string",
required: true,

View File

@@ -93,59 +93,61 @@ export async function onSubscriptionUpdated(
const subscriptionUpdated = event.data.object as Stripe.Subscription;
const priceId = subscriptionUpdated.items.data[0].price.id;
const plan = await getPlanByPriceId(options, priceId);
if (plan) {
const stripeId = subscriptionUpdated.id;
const subscription = await ctx.context.adapter.findOne<Subscription>({
model: "subscription",
where: [
{
field: "stripeSubscriptionId",
value: stripeId,
},
],
});
if (!subscription) {
return;
}
const seats = subscriptionUpdated.items.data[0].quantity;
await ctx.context.adapter.update({
model: "subscription",
update: {
plan: plan.name.toLowerCase(),
limits: plan.limits,
updatedAt: new Date(),
status: subscriptionUpdated.status,
periodStart: new Date(
subscriptionUpdated.current_period_start * 1000,
),
periodEnd: new Date(subscriptionUpdated.current_period_end * 1000),
cancelAtPeriodEnd: subscriptionUpdated.cancel_at_period_end,
seats,
const stripeId = subscriptionUpdated.id;
const subscription = await ctx.context.adapter.findOne<Subscription>({
model: "subscription",
where: [
{
field: "stripeSubscriptionId",
value: stripeId,
},
where: [
{
field: "stripeSubscriptionId",
value: subscriptionUpdated.id,
},
],
});
const subscriptionCanceled =
subscriptionUpdated.status === "active" &&
subscriptionUpdated.cancel_at_period_end;
if (subscriptionCanceled) {
await options.subscription.onSubscriptionCancel?.({
subscription,
cancellationDetails:
subscriptionUpdated.cancellation_details || undefined,
stripeSubscription: subscriptionUpdated,
event,
});
}
await options.subscription.onSubscriptionUpdate?.({
event,
],
});
if (!subscription) {
return;
}
const seats = subscriptionUpdated.items.data[0].quantity;
await ctx.context.adapter.update({
model: "subscription",
update: {
...(plan
? {
plan: plan.name.toLowerCase(),
limits: plan.limits,
}
: {}),
updatedAt: new Date(),
status: subscriptionUpdated.status,
periodStart: new Date(subscriptionUpdated.current_period_start * 1000),
periodEnd: new Date(subscriptionUpdated.current_period_end * 1000),
cancelAtPeriodEnd: subscriptionUpdated.cancel_at_period_end,
seats,
},
where: [
{
field: "stripeSubscriptionId",
value: subscriptionUpdated.id,
},
],
});
const subscriptionCanceled =
subscriptionUpdated.status === "active" &&
subscriptionUpdated.cancel_at_period_end &&
!subscription.cancelAtPeriodEnd; //if this is true, it means the subscription was canceled before the event was triggered
if (subscriptionCanceled) {
await options.subscription.onSubscriptionCancel?.({
subscription,
cancellationDetails:
subscriptionUpdated.cancellation_details || undefined,
stripeSubscription: subscriptionUpdated,
event,
});
}
await options.subscription.onSubscriptionUpdate?.({
event,
subscription,
});
if (plan) {
if (
subscriptionUpdated.status === "active" &&
subscription.status === "trialing" &&

View File

@@ -335,6 +335,85 @@ export const stripe = <O extends StripeOptions>(options: O) => {
});
},
),
cancelSubscriptionCallback: createAuthEndpoint(
"/subscription/cancel/callback",
{
method: "GET",
query: z.record(z.string(), z.any()).optional(),
},
async (ctx) => {
if (!ctx.query || !ctx.query.callbackURL || !ctx.query.reference) {
throw ctx.redirect(getUrl(ctx, ctx.query?.callbackURL || "/"));
}
const session = await getSessionFromCtx<{ stripeCustomerId: string }>(
ctx,
);
if (!session) {
throw ctx.redirect(getUrl(ctx, ctx.query?.callbackURL || "/"));
}
const { user } = session;
const { callbackURL, reference } = ctx.query;
if (user?.stripeCustomerId) {
try {
const subscription =
await ctx.context.adapter.findOne<Subscription>({
model: "subscription",
where: [
{
field: "referenceId",
value: reference,
},
],
});
console.log({ subscription });
if (
!subscription ||
subscription.cancelAtPeriodEnd ||
subscription.status === "canceled"
) {
throw ctx.redirect(getUrl(ctx, callbackURL));
}
const stripeSubscription = await client.subscriptions.list({
customer: user.stripeCustomerId,
status: "active",
});
const currentSubscription = stripeSubscription.data.find(
(sub) => sub.id === subscription.stripeSubscriptionId,
);
console.log({ currentSubscription });
if (currentSubscription?.cancel_at_period_end === true) {
await ctx.context.adapter.update({
model: "subscription",
update: {
status: currentSubscription?.status,
cancelAtPeriodEnd: true,
},
where: [
{
field: "referenceId",
value: reference,
},
],
});
await options.subscription?.onSubscriptionCancel?.({
subscription,
cancellationDetails: currentSubscription.cancellation_details,
stripeSubscription: currentSubscription,
event: undefined,
});
}
} catch (error) {
ctx.context.logger.error(
"Error checking subscription status from Stripe",
error,
);
}
}
throw ctx.redirect(getUrl(ctx, callbackURL));
},
),
cancelSubscription: createAuthEndpoint(
"/subscription/cancel",
{
@@ -377,16 +456,50 @@ export const stripe = <O extends StripeOptions>(options: O) => {
message: STRIPE_ERROR_CODES.SUBSCRIPTION_NOT_FOUND,
});
}
const { url } = await client.billingPortal.sessions.create({
customer: subscription.stripeCustomerId,
return_url: getUrl(ctx, ctx.body?.returnUrl || "/"),
flow_data: {
type: "subscription_cancel",
subscription_cancel: {
subscription: activeSubscription.id,
const { url } = await client.billingPortal.sessions
.create({
customer: subscription.stripeCustomerId,
return_url: getUrl(
ctx,
`${
ctx.context.baseURL
}/subscription/cancel/callback?callbackURL=${encodeURIComponent(
ctx.body?.returnUrl || "/",
)}&reference=${encodeURIComponent(referenceId)}`,
),
flow_data: {
type: "subscription_cancel",
subscription_cancel: {
subscription: activeSubscription.id,
},
},
},
});
})
.catch(async (e) => {
if (e.message.includes("already set to be cancel")) {
/**
* incase we missed the event from stripe, we set it manually
* this is a rare case and should not happen
*/
if (!subscription.cancelAtPeriodEnd) {
await ctx.context.adapter.update({
model: "subscription",
update: {
cancelAtPeriodEnd: true,
},
where: [
{
field: "referenceId",
value: referenceId,
},
],
});
}
}
throw ctx.error("BAD_REQUEST", {
message: e.message,
code: e.code,
});
});
return {
url,
redirect: true,
@@ -509,6 +622,9 @@ export const stripe = <O extends StripeOptions>(options: O) => {
status: stripeSubscription.status,
seats: stripeSubscription.items.data[0]?.quantity || 1,
plan: plan.name.toLowerCase(),
periodEnd: stripeSubscription.current_period_end,
periodStart: stripeSubscription.current_period_start,
stripeSubscriptionId: stripeSubscription.id,
},
where: [
{

View File

@@ -600,6 +600,42 @@ describe("stripe", async () => {
}),
);
const userCancelEvent = {
type: "customer.subscription.updated",
data: {
object: {
id: "sub_123",
customer: "cus_123",
status: "active",
cancel_at_period_end: true,
cancellation_details: {
reason: "cancellation_requested",
comment: "Customer canceled subscription",
},
items: {
data: [{ price: { id: process.env.STRIPE_PRICE_ID_1 } }],
},
current_period_start: Math.floor(Date.now() / 1000),
current_period_end: Math.floor(Date.now() / 1000) + 30 * 24 * 60 * 60,
},
},
};
const userCancelRequest = new Request(
"http://localhost:3000/api/auth/stripe/webhook",
{
method: "POST",
headers: {
"stripe-signature": "test_signature",
},
body: JSON.stringify(userCancelEvent),
},
);
mockStripeForEvents.webhooks.constructEvent.mockReturnValue(
userCancelEvent,
);
await eventTestAuth.handler(userCancelRequest);
const cancelEvent = {
type: "customer.subscription.updated",
data: {

View File

@@ -247,10 +247,10 @@ export interface StripeOptions {
* @returns
*/
onSubscriptionCancel?: (data: {
event: Stripe.Event;
event?: Stripe.Event;
subscription: Subscription;
stripeSubscription: Stripe.Subscription;
cancellationDetails?: Stripe.Subscription.CancellationDetails;
cancellationDetails?: Stripe.Subscription.CancellationDetails | null;
}) => Promise<void>;
/**
* A function to check if the reference id is valid