mirror of
https://github.com/LukeHagar/better-auth.git
synced 2025-12-07 20:37:44 +00:00
fix(stripe): improve subscription cancellation handling and add callback endpoint
This commit is contained in:
@@ -172,9 +172,16 @@ function Component(props: {
|
||||
variant="destructive"
|
||||
className="w-full"
|
||||
onClick={async () => {
|
||||
await client.subscription.cancel({
|
||||
returnUrl: "/dashboard",
|
||||
});
|
||||
await client.subscription.cancel(
|
||||
{
|
||||
returnUrl: "/dashboard",
|
||||
},
|
||||
{
|
||||
onError: (ctx) => {
|
||||
toast.error(ctx.error.message);
|
||||
},
|
||||
},
|
||||
);
|
||||
}}
|
||||
>
|
||||
Cancel Plan
|
||||
|
||||
@@ -26,7 +26,7 @@ export const client = createAuthClient({
|
||||
oneTapClient({
|
||||
clientId: process.env.NEXT_PUBLIC_GOOGLE_CLIENT_ID!,
|
||||
promptOptions: {
|
||||
maxAttempts: 2,
|
||||
maxAttempts: 1,
|
||||
},
|
||||
}),
|
||||
oidcClient(),
|
||||
|
||||
@@ -20,7 +20,6 @@ import { nextCookies } from "better-auth/next-js";
|
||||
import { passkey } from "better-auth/plugins/passkey";
|
||||
import { stripe } from "@better-auth/stripe";
|
||||
import { Stripe } from "stripe";
|
||||
import Database from "better-sqlite3";
|
||||
|
||||
const from = process.env.BETTER_AUTH_EMAIL || "delivered@resend.dev";
|
||||
const to = process.env.TEST_EMAIL || "";
|
||||
@@ -51,7 +50,10 @@ const STARTER_PRICE_ID = {
|
||||
|
||||
export const auth = betterAuth({
|
||||
appName: "Better Auth Demo",
|
||||
database: new Database("./stripe.db"),
|
||||
database: {
|
||||
dialect,
|
||||
type: process.env.USE_MYSQL ? "mysql" : "sqlite",
|
||||
},
|
||||
emailVerification: {
|
||||
async sendVerificationEmail({ user, url }) {
|
||||
const res = await resend.emails.send({
|
||||
@@ -163,7 +165,6 @@ export const auth = betterAuth({
|
||||
stripeWebhookSecret: process.env.STRIPE_WEBHOOK_SECRET!,
|
||||
subscription: {
|
||||
enabled: true,
|
||||
requireEmailVerification: true,
|
||||
plans: [
|
||||
{
|
||||
name: "Starter",
|
||||
|
||||
@@ -230,7 +230,10 @@ function NewBadge({ isSelected }: { isSelected?: boolean }) {
|
||||
return (
|
||||
<div className="flex items-center justify-end w-full">
|
||||
<Badge
|
||||
className=" pointer-events-none !no-underline border-dashed !decoration-transparent"
|
||||
className={cn(
|
||||
" pointer-events-none !no-underline border-dashed !decoration-transparent",
|
||||
isSelected && "!border-solid",
|
||||
)}
|
||||
variant={isSelected ? "default" : "outline"}
|
||||
>
|
||||
New
|
||||
|
||||
@@ -62,11 +62,9 @@ export const multiSession = (options?: MultiSessionConfig) => {
|
||||
).filter((v) => v !== null);
|
||||
|
||||
const s = await getSessionFromCtx(ctx);
|
||||
console.log({ sessionTokens, s });
|
||||
if (!sessionTokens.length) return ctx.json([]);
|
||||
const sessions =
|
||||
await ctx.context.internalAdapter.findSessions(sessionTokens);
|
||||
console.log({ sessions });
|
||||
const validSessions = sessions.filter(
|
||||
(session) => session && session.session.expiresAt > new Date(),
|
||||
);
|
||||
|
||||
@@ -574,12 +574,16 @@ export const organization = <O extends OrganizationOptions>(options?: O) => {
|
||||
defaultValue: "member",
|
||||
fieldName: options?.schema?.member?.fields?.role,
|
||||
},
|
||||
teamId: {
|
||||
type: "string",
|
||||
required: false,
|
||||
sortable: true,
|
||||
fieldName: options?.schema?.member?.fields?.teamId,
|
||||
},
|
||||
...(teamSupport
|
||||
? {
|
||||
teamId: {
|
||||
type: "string",
|
||||
required: false,
|
||||
sortable: true,
|
||||
fieldName: options?.schema?.member?.fields?.teamId,
|
||||
},
|
||||
}
|
||||
: {}),
|
||||
createdAt: {
|
||||
type: "date",
|
||||
required: true,
|
||||
@@ -611,12 +615,16 @@ export const organization = <O extends OrganizationOptions>(options?: O) => {
|
||||
sortable: true,
|
||||
fieldName: options?.schema?.invitation?.fields?.role,
|
||||
},
|
||||
teamId: {
|
||||
type: "string",
|
||||
required: false,
|
||||
sortable: true,
|
||||
fieldName: options?.schema?.invitation?.fields?.teamId,
|
||||
},
|
||||
...(teamSupport
|
||||
? {
|
||||
teamId: {
|
||||
type: "string",
|
||||
required: false,
|
||||
sortable: true,
|
||||
fieldName: options?.schema?.invitation?.fields?.teamId,
|
||||
},
|
||||
}
|
||||
: {}),
|
||||
status: {
|
||||
type: "string",
|
||||
required: true,
|
||||
|
||||
@@ -93,59 +93,61 @@ export async function onSubscriptionUpdated(
|
||||
const subscriptionUpdated = event.data.object as Stripe.Subscription;
|
||||
const priceId = subscriptionUpdated.items.data[0].price.id;
|
||||
const plan = await getPlanByPriceId(options, priceId);
|
||||
if (plan) {
|
||||
const stripeId = subscriptionUpdated.id;
|
||||
const subscription = await ctx.context.adapter.findOne<Subscription>({
|
||||
model: "subscription",
|
||||
where: [
|
||||
{
|
||||
field: "stripeSubscriptionId",
|
||||
value: stripeId,
|
||||
},
|
||||
],
|
||||
});
|
||||
if (!subscription) {
|
||||
return;
|
||||
}
|
||||
const seats = subscriptionUpdated.items.data[0].quantity;
|
||||
await ctx.context.adapter.update({
|
||||
model: "subscription",
|
||||
update: {
|
||||
plan: plan.name.toLowerCase(),
|
||||
limits: plan.limits,
|
||||
updatedAt: new Date(),
|
||||
status: subscriptionUpdated.status,
|
||||
periodStart: new Date(
|
||||
subscriptionUpdated.current_period_start * 1000,
|
||||
),
|
||||
periodEnd: new Date(subscriptionUpdated.current_period_end * 1000),
|
||||
cancelAtPeriodEnd: subscriptionUpdated.cancel_at_period_end,
|
||||
seats,
|
||||
const stripeId = subscriptionUpdated.id;
|
||||
const subscription = await ctx.context.adapter.findOne<Subscription>({
|
||||
model: "subscription",
|
||||
where: [
|
||||
{
|
||||
field: "stripeSubscriptionId",
|
||||
value: stripeId,
|
||||
},
|
||||
where: [
|
||||
{
|
||||
field: "stripeSubscriptionId",
|
||||
value: subscriptionUpdated.id,
|
||||
},
|
||||
],
|
||||
});
|
||||
const subscriptionCanceled =
|
||||
subscriptionUpdated.status === "active" &&
|
||||
subscriptionUpdated.cancel_at_period_end;
|
||||
if (subscriptionCanceled) {
|
||||
await options.subscription.onSubscriptionCancel?.({
|
||||
subscription,
|
||||
cancellationDetails:
|
||||
subscriptionUpdated.cancellation_details || undefined,
|
||||
stripeSubscription: subscriptionUpdated,
|
||||
event,
|
||||
});
|
||||
}
|
||||
await options.subscription.onSubscriptionUpdate?.({
|
||||
event,
|
||||
],
|
||||
});
|
||||
if (!subscription) {
|
||||
return;
|
||||
}
|
||||
const seats = subscriptionUpdated.items.data[0].quantity;
|
||||
await ctx.context.adapter.update({
|
||||
model: "subscription",
|
||||
update: {
|
||||
...(plan
|
||||
? {
|
||||
plan: plan.name.toLowerCase(),
|
||||
limits: plan.limits,
|
||||
}
|
||||
: {}),
|
||||
updatedAt: new Date(),
|
||||
status: subscriptionUpdated.status,
|
||||
periodStart: new Date(subscriptionUpdated.current_period_start * 1000),
|
||||
periodEnd: new Date(subscriptionUpdated.current_period_end * 1000),
|
||||
cancelAtPeriodEnd: subscriptionUpdated.cancel_at_period_end,
|
||||
seats,
|
||||
},
|
||||
where: [
|
||||
{
|
||||
field: "stripeSubscriptionId",
|
||||
value: subscriptionUpdated.id,
|
||||
},
|
||||
],
|
||||
});
|
||||
const subscriptionCanceled =
|
||||
subscriptionUpdated.status === "active" &&
|
||||
subscriptionUpdated.cancel_at_period_end &&
|
||||
!subscription.cancelAtPeriodEnd; //if this is true, it means the subscription was canceled before the event was triggered
|
||||
if (subscriptionCanceled) {
|
||||
await options.subscription.onSubscriptionCancel?.({
|
||||
subscription,
|
||||
cancellationDetails:
|
||||
subscriptionUpdated.cancellation_details || undefined,
|
||||
stripeSubscription: subscriptionUpdated,
|
||||
event,
|
||||
});
|
||||
|
||||
}
|
||||
await options.subscription.onSubscriptionUpdate?.({
|
||||
event,
|
||||
subscription,
|
||||
});
|
||||
if (plan) {
|
||||
if (
|
||||
subscriptionUpdated.status === "active" &&
|
||||
subscription.status === "trialing" &&
|
||||
|
||||
@@ -335,6 +335,85 @@ export const stripe = <O extends StripeOptions>(options: O) => {
|
||||
});
|
||||
},
|
||||
),
|
||||
cancelSubscriptionCallback: createAuthEndpoint(
|
||||
"/subscription/cancel/callback",
|
||||
{
|
||||
method: "GET",
|
||||
query: z.record(z.string(), z.any()).optional(),
|
||||
},
|
||||
async (ctx) => {
|
||||
if (!ctx.query || !ctx.query.callbackURL || !ctx.query.reference) {
|
||||
throw ctx.redirect(getUrl(ctx, ctx.query?.callbackURL || "/"));
|
||||
}
|
||||
const session = await getSessionFromCtx<{ stripeCustomerId: string }>(
|
||||
ctx,
|
||||
);
|
||||
if (!session) {
|
||||
throw ctx.redirect(getUrl(ctx, ctx.query?.callbackURL || "/"));
|
||||
}
|
||||
const { user } = session;
|
||||
const { callbackURL, reference } = ctx.query;
|
||||
|
||||
if (user?.stripeCustomerId) {
|
||||
try {
|
||||
const subscription =
|
||||
await ctx.context.adapter.findOne<Subscription>({
|
||||
model: "subscription",
|
||||
where: [
|
||||
{
|
||||
field: "referenceId",
|
||||
value: reference,
|
||||
},
|
||||
],
|
||||
});
|
||||
console.log({ subscription });
|
||||
if (
|
||||
!subscription ||
|
||||
subscription.cancelAtPeriodEnd ||
|
||||
subscription.status === "canceled"
|
||||
) {
|
||||
throw ctx.redirect(getUrl(ctx, callbackURL));
|
||||
}
|
||||
|
||||
const stripeSubscription = await client.subscriptions.list({
|
||||
customer: user.stripeCustomerId,
|
||||
status: "active",
|
||||
});
|
||||
const currentSubscription = stripeSubscription.data.find(
|
||||
(sub) => sub.id === subscription.stripeSubscriptionId,
|
||||
);
|
||||
console.log({ currentSubscription });
|
||||
if (currentSubscription?.cancel_at_period_end === true) {
|
||||
await ctx.context.adapter.update({
|
||||
model: "subscription",
|
||||
update: {
|
||||
status: currentSubscription?.status,
|
||||
cancelAtPeriodEnd: true,
|
||||
},
|
||||
where: [
|
||||
{
|
||||
field: "referenceId",
|
||||
value: reference,
|
||||
},
|
||||
],
|
||||
});
|
||||
await options.subscription?.onSubscriptionCancel?.({
|
||||
subscription,
|
||||
cancellationDetails: currentSubscription.cancellation_details,
|
||||
stripeSubscription: currentSubscription,
|
||||
event: undefined,
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
ctx.context.logger.error(
|
||||
"Error checking subscription status from Stripe",
|
||||
error,
|
||||
);
|
||||
}
|
||||
}
|
||||
throw ctx.redirect(getUrl(ctx, callbackURL));
|
||||
},
|
||||
),
|
||||
cancelSubscription: createAuthEndpoint(
|
||||
"/subscription/cancel",
|
||||
{
|
||||
@@ -377,16 +456,50 @@ export const stripe = <O extends StripeOptions>(options: O) => {
|
||||
message: STRIPE_ERROR_CODES.SUBSCRIPTION_NOT_FOUND,
|
||||
});
|
||||
}
|
||||
const { url } = await client.billingPortal.sessions.create({
|
||||
customer: subscription.stripeCustomerId,
|
||||
return_url: getUrl(ctx, ctx.body?.returnUrl || "/"),
|
||||
flow_data: {
|
||||
type: "subscription_cancel",
|
||||
subscription_cancel: {
|
||||
subscription: activeSubscription.id,
|
||||
const { url } = await client.billingPortal.sessions
|
||||
.create({
|
||||
customer: subscription.stripeCustomerId,
|
||||
return_url: getUrl(
|
||||
ctx,
|
||||
`${
|
||||
ctx.context.baseURL
|
||||
}/subscription/cancel/callback?callbackURL=${encodeURIComponent(
|
||||
ctx.body?.returnUrl || "/",
|
||||
)}&reference=${encodeURIComponent(referenceId)}`,
|
||||
),
|
||||
flow_data: {
|
||||
type: "subscription_cancel",
|
||||
subscription_cancel: {
|
||||
subscription: activeSubscription.id,
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
})
|
||||
.catch(async (e) => {
|
||||
if (e.message.includes("already set to be cancel")) {
|
||||
/**
|
||||
* incase we missed the event from stripe, we set it manually
|
||||
* this is a rare case and should not happen
|
||||
*/
|
||||
if (!subscription.cancelAtPeriodEnd) {
|
||||
await ctx.context.adapter.update({
|
||||
model: "subscription",
|
||||
update: {
|
||||
cancelAtPeriodEnd: true,
|
||||
},
|
||||
where: [
|
||||
{
|
||||
field: "referenceId",
|
||||
value: referenceId,
|
||||
},
|
||||
],
|
||||
});
|
||||
}
|
||||
}
|
||||
throw ctx.error("BAD_REQUEST", {
|
||||
message: e.message,
|
||||
code: e.code,
|
||||
});
|
||||
});
|
||||
return {
|
||||
url,
|
||||
redirect: true,
|
||||
@@ -509,6 +622,9 @@ export const stripe = <O extends StripeOptions>(options: O) => {
|
||||
status: stripeSubscription.status,
|
||||
seats: stripeSubscription.items.data[0]?.quantity || 1,
|
||||
plan: plan.name.toLowerCase(),
|
||||
periodEnd: stripeSubscription.current_period_end,
|
||||
periodStart: stripeSubscription.current_period_start,
|
||||
stripeSubscriptionId: stripeSubscription.id,
|
||||
},
|
||||
where: [
|
||||
{
|
||||
|
||||
@@ -600,6 +600,42 @@ describe("stripe", async () => {
|
||||
}),
|
||||
);
|
||||
|
||||
const userCancelEvent = {
|
||||
type: "customer.subscription.updated",
|
||||
data: {
|
||||
object: {
|
||||
id: "sub_123",
|
||||
customer: "cus_123",
|
||||
status: "active",
|
||||
cancel_at_period_end: true,
|
||||
cancellation_details: {
|
||||
reason: "cancellation_requested",
|
||||
comment: "Customer canceled subscription",
|
||||
},
|
||||
items: {
|
||||
data: [{ price: { id: process.env.STRIPE_PRICE_ID_1 } }],
|
||||
},
|
||||
current_period_start: Math.floor(Date.now() / 1000),
|
||||
current_period_end: Math.floor(Date.now() / 1000) + 30 * 24 * 60 * 60,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const userCancelRequest = new Request(
|
||||
"http://localhost:3000/api/auth/stripe/webhook",
|
||||
{
|
||||
method: "POST",
|
||||
headers: {
|
||||
"stripe-signature": "test_signature",
|
||||
},
|
||||
body: JSON.stringify(userCancelEvent),
|
||||
},
|
||||
);
|
||||
|
||||
mockStripeForEvents.webhooks.constructEvent.mockReturnValue(
|
||||
userCancelEvent,
|
||||
);
|
||||
await eventTestAuth.handler(userCancelRequest);
|
||||
const cancelEvent = {
|
||||
type: "customer.subscription.updated",
|
||||
data: {
|
||||
|
||||
@@ -247,10 +247,10 @@ export interface StripeOptions {
|
||||
* @returns
|
||||
*/
|
||||
onSubscriptionCancel?: (data: {
|
||||
event: Stripe.Event;
|
||||
event?: Stripe.Event;
|
||||
subscription: Subscription;
|
||||
stripeSubscription: Stripe.Subscription;
|
||||
cancellationDetails?: Stripe.Subscription.CancellationDetails;
|
||||
cancellationDetails?: Stripe.Subscription.CancellationDetails | null;
|
||||
}) => Promise<void>;
|
||||
/**
|
||||
* A function to check if the reference id is valid
|
||||
|
||||
Reference in New Issue
Block a user