mirror of
https://github.com/LukeHagar/better-auth.git
synced 2025-12-07 20:37:44 +00:00
fix(stripe): improve subscription cancellation handling and add callback endpoint
This commit is contained in:
@@ -172,9 +172,16 @@ function Component(props: {
|
|||||||
variant="destructive"
|
variant="destructive"
|
||||||
className="w-full"
|
className="w-full"
|
||||||
onClick={async () => {
|
onClick={async () => {
|
||||||
await client.subscription.cancel({
|
await client.subscription.cancel(
|
||||||
returnUrl: "/dashboard",
|
{
|
||||||
});
|
returnUrl: "/dashboard",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
onError: (ctx) => {
|
||||||
|
toast.error(ctx.error.message);
|
||||||
|
},
|
||||||
|
},
|
||||||
|
);
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
Cancel Plan
|
Cancel Plan
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ export const client = createAuthClient({
|
|||||||
oneTapClient({
|
oneTapClient({
|
||||||
clientId: process.env.NEXT_PUBLIC_GOOGLE_CLIENT_ID!,
|
clientId: process.env.NEXT_PUBLIC_GOOGLE_CLIENT_ID!,
|
||||||
promptOptions: {
|
promptOptions: {
|
||||||
maxAttempts: 2,
|
maxAttempts: 1,
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
oidcClient(),
|
oidcClient(),
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ import { nextCookies } from "better-auth/next-js";
|
|||||||
import { passkey } from "better-auth/plugins/passkey";
|
import { passkey } from "better-auth/plugins/passkey";
|
||||||
import { stripe } from "@better-auth/stripe";
|
import { stripe } from "@better-auth/stripe";
|
||||||
import { Stripe } from "stripe";
|
import { Stripe } from "stripe";
|
||||||
import Database from "better-sqlite3";
|
|
||||||
|
|
||||||
const from = process.env.BETTER_AUTH_EMAIL || "delivered@resend.dev";
|
const from = process.env.BETTER_AUTH_EMAIL || "delivered@resend.dev";
|
||||||
const to = process.env.TEST_EMAIL || "";
|
const to = process.env.TEST_EMAIL || "";
|
||||||
@@ -51,7 +50,10 @@ const STARTER_PRICE_ID = {
|
|||||||
|
|
||||||
export const auth = betterAuth({
|
export const auth = betterAuth({
|
||||||
appName: "Better Auth Demo",
|
appName: "Better Auth Demo",
|
||||||
database: new Database("./stripe.db"),
|
database: {
|
||||||
|
dialect,
|
||||||
|
type: process.env.USE_MYSQL ? "mysql" : "sqlite",
|
||||||
|
},
|
||||||
emailVerification: {
|
emailVerification: {
|
||||||
async sendVerificationEmail({ user, url }) {
|
async sendVerificationEmail({ user, url }) {
|
||||||
const res = await resend.emails.send({
|
const res = await resend.emails.send({
|
||||||
@@ -163,7 +165,6 @@ export const auth = betterAuth({
|
|||||||
stripeWebhookSecret: process.env.STRIPE_WEBHOOK_SECRET!,
|
stripeWebhookSecret: process.env.STRIPE_WEBHOOK_SECRET!,
|
||||||
subscription: {
|
subscription: {
|
||||||
enabled: true,
|
enabled: true,
|
||||||
requireEmailVerification: true,
|
|
||||||
plans: [
|
plans: [
|
||||||
{
|
{
|
||||||
name: "Starter",
|
name: "Starter",
|
||||||
|
|||||||
@@ -230,7 +230,10 @@ function NewBadge({ isSelected }: { isSelected?: boolean }) {
|
|||||||
return (
|
return (
|
||||||
<div className="flex items-center justify-end w-full">
|
<div className="flex items-center justify-end w-full">
|
||||||
<Badge
|
<Badge
|
||||||
className=" pointer-events-none !no-underline border-dashed !decoration-transparent"
|
className={cn(
|
||||||
|
" pointer-events-none !no-underline border-dashed !decoration-transparent",
|
||||||
|
isSelected && "!border-solid",
|
||||||
|
)}
|
||||||
variant={isSelected ? "default" : "outline"}
|
variant={isSelected ? "default" : "outline"}
|
||||||
>
|
>
|
||||||
New
|
New
|
||||||
|
|||||||
@@ -62,11 +62,9 @@ export const multiSession = (options?: MultiSessionConfig) => {
|
|||||||
).filter((v) => v !== null);
|
).filter((v) => v !== null);
|
||||||
|
|
||||||
const s = await getSessionFromCtx(ctx);
|
const s = await getSessionFromCtx(ctx);
|
||||||
console.log({ sessionTokens, s });
|
|
||||||
if (!sessionTokens.length) return ctx.json([]);
|
if (!sessionTokens.length) return ctx.json([]);
|
||||||
const sessions =
|
const sessions =
|
||||||
await ctx.context.internalAdapter.findSessions(sessionTokens);
|
await ctx.context.internalAdapter.findSessions(sessionTokens);
|
||||||
console.log({ sessions });
|
|
||||||
const validSessions = sessions.filter(
|
const validSessions = sessions.filter(
|
||||||
(session) => session && session.session.expiresAt > new Date(),
|
(session) => session && session.session.expiresAt > new Date(),
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -574,12 +574,16 @@ export const organization = <O extends OrganizationOptions>(options?: O) => {
|
|||||||
defaultValue: "member",
|
defaultValue: "member",
|
||||||
fieldName: options?.schema?.member?.fields?.role,
|
fieldName: options?.schema?.member?.fields?.role,
|
||||||
},
|
},
|
||||||
teamId: {
|
...(teamSupport
|
||||||
type: "string",
|
? {
|
||||||
required: false,
|
teamId: {
|
||||||
sortable: true,
|
type: "string",
|
||||||
fieldName: options?.schema?.member?.fields?.teamId,
|
required: false,
|
||||||
},
|
sortable: true,
|
||||||
|
fieldName: options?.schema?.member?.fields?.teamId,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
: {}),
|
||||||
createdAt: {
|
createdAt: {
|
||||||
type: "date",
|
type: "date",
|
||||||
required: true,
|
required: true,
|
||||||
@@ -611,12 +615,16 @@ export const organization = <O extends OrganizationOptions>(options?: O) => {
|
|||||||
sortable: true,
|
sortable: true,
|
||||||
fieldName: options?.schema?.invitation?.fields?.role,
|
fieldName: options?.schema?.invitation?.fields?.role,
|
||||||
},
|
},
|
||||||
teamId: {
|
...(teamSupport
|
||||||
type: "string",
|
? {
|
||||||
required: false,
|
teamId: {
|
||||||
sortable: true,
|
type: "string",
|
||||||
fieldName: options?.schema?.invitation?.fields?.teamId,
|
required: false,
|
||||||
},
|
sortable: true,
|
||||||
|
fieldName: options?.schema?.invitation?.fields?.teamId,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
: {}),
|
||||||
status: {
|
status: {
|
||||||
type: "string",
|
type: "string",
|
||||||
required: true,
|
required: true,
|
||||||
|
|||||||
@@ -93,59 +93,61 @@ export async function onSubscriptionUpdated(
|
|||||||
const subscriptionUpdated = event.data.object as Stripe.Subscription;
|
const subscriptionUpdated = event.data.object as Stripe.Subscription;
|
||||||
const priceId = subscriptionUpdated.items.data[0].price.id;
|
const priceId = subscriptionUpdated.items.data[0].price.id;
|
||||||
const plan = await getPlanByPriceId(options, priceId);
|
const plan = await getPlanByPriceId(options, priceId);
|
||||||
if (plan) {
|
const stripeId = subscriptionUpdated.id;
|
||||||
const stripeId = subscriptionUpdated.id;
|
const subscription = await ctx.context.adapter.findOne<Subscription>({
|
||||||
const subscription = await ctx.context.adapter.findOne<Subscription>({
|
model: "subscription",
|
||||||
model: "subscription",
|
where: [
|
||||||
where: [
|
{
|
||||||
{
|
field: "stripeSubscriptionId",
|
||||||
field: "stripeSubscriptionId",
|
value: stripeId,
|
||||||
value: stripeId,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
});
|
|
||||||
if (!subscription) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
const seats = subscriptionUpdated.items.data[0].quantity;
|
|
||||||
await ctx.context.adapter.update({
|
|
||||||
model: "subscription",
|
|
||||||
update: {
|
|
||||||
plan: plan.name.toLowerCase(),
|
|
||||||
limits: plan.limits,
|
|
||||||
updatedAt: new Date(),
|
|
||||||
status: subscriptionUpdated.status,
|
|
||||||
periodStart: new Date(
|
|
||||||
subscriptionUpdated.current_period_start * 1000,
|
|
||||||
),
|
|
||||||
periodEnd: new Date(subscriptionUpdated.current_period_end * 1000),
|
|
||||||
cancelAtPeriodEnd: subscriptionUpdated.cancel_at_period_end,
|
|
||||||
seats,
|
|
||||||
},
|
},
|
||||||
where: [
|
],
|
||||||
{
|
});
|
||||||
field: "stripeSubscriptionId",
|
if (!subscription) {
|
||||||
value: subscriptionUpdated.id,
|
return;
|
||||||
},
|
}
|
||||||
],
|
const seats = subscriptionUpdated.items.data[0].quantity;
|
||||||
});
|
await ctx.context.adapter.update({
|
||||||
const subscriptionCanceled =
|
model: "subscription",
|
||||||
subscriptionUpdated.status === "active" &&
|
update: {
|
||||||
subscriptionUpdated.cancel_at_period_end;
|
...(plan
|
||||||
if (subscriptionCanceled) {
|
? {
|
||||||
await options.subscription.onSubscriptionCancel?.({
|
plan: plan.name.toLowerCase(),
|
||||||
subscription,
|
limits: plan.limits,
|
||||||
cancellationDetails:
|
}
|
||||||
subscriptionUpdated.cancellation_details || undefined,
|
: {}),
|
||||||
stripeSubscription: subscriptionUpdated,
|
updatedAt: new Date(),
|
||||||
event,
|
status: subscriptionUpdated.status,
|
||||||
});
|
periodStart: new Date(subscriptionUpdated.current_period_start * 1000),
|
||||||
}
|
periodEnd: new Date(subscriptionUpdated.current_period_end * 1000),
|
||||||
await options.subscription.onSubscriptionUpdate?.({
|
cancelAtPeriodEnd: subscriptionUpdated.cancel_at_period_end,
|
||||||
event,
|
seats,
|
||||||
|
},
|
||||||
|
where: [
|
||||||
|
{
|
||||||
|
field: "stripeSubscriptionId",
|
||||||
|
value: subscriptionUpdated.id,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
const subscriptionCanceled =
|
||||||
|
subscriptionUpdated.status === "active" &&
|
||||||
|
subscriptionUpdated.cancel_at_period_end &&
|
||||||
|
!subscription.cancelAtPeriodEnd; //if this is true, it means the subscription was canceled before the event was triggered
|
||||||
|
if (subscriptionCanceled) {
|
||||||
|
await options.subscription.onSubscriptionCancel?.({
|
||||||
subscription,
|
subscription,
|
||||||
|
cancellationDetails:
|
||||||
|
subscriptionUpdated.cancellation_details || undefined,
|
||||||
|
stripeSubscription: subscriptionUpdated,
|
||||||
|
event,
|
||||||
});
|
});
|
||||||
|
}
|
||||||
|
await options.subscription.onSubscriptionUpdate?.({
|
||||||
|
event,
|
||||||
|
subscription,
|
||||||
|
});
|
||||||
|
if (plan) {
|
||||||
if (
|
if (
|
||||||
subscriptionUpdated.status === "active" &&
|
subscriptionUpdated.status === "active" &&
|
||||||
subscription.status === "trialing" &&
|
subscription.status === "trialing" &&
|
||||||
|
|||||||
@@ -335,6 +335,85 @@ export const stripe = <O extends StripeOptions>(options: O) => {
|
|||||||
});
|
});
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
|
cancelSubscriptionCallback: createAuthEndpoint(
|
||||||
|
"/subscription/cancel/callback",
|
||||||
|
{
|
||||||
|
method: "GET",
|
||||||
|
query: z.record(z.string(), z.any()).optional(),
|
||||||
|
},
|
||||||
|
async (ctx) => {
|
||||||
|
if (!ctx.query || !ctx.query.callbackURL || !ctx.query.reference) {
|
||||||
|
throw ctx.redirect(getUrl(ctx, ctx.query?.callbackURL || "/"));
|
||||||
|
}
|
||||||
|
const session = await getSessionFromCtx<{ stripeCustomerId: string }>(
|
||||||
|
ctx,
|
||||||
|
);
|
||||||
|
if (!session) {
|
||||||
|
throw ctx.redirect(getUrl(ctx, ctx.query?.callbackURL || "/"));
|
||||||
|
}
|
||||||
|
const { user } = session;
|
||||||
|
const { callbackURL, reference } = ctx.query;
|
||||||
|
|
||||||
|
if (user?.stripeCustomerId) {
|
||||||
|
try {
|
||||||
|
const subscription =
|
||||||
|
await ctx.context.adapter.findOne<Subscription>({
|
||||||
|
model: "subscription",
|
||||||
|
where: [
|
||||||
|
{
|
||||||
|
field: "referenceId",
|
||||||
|
value: reference,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
console.log({ subscription });
|
||||||
|
if (
|
||||||
|
!subscription ||
|
||||||
|
subscription.cancelAtPeriodEnd ||
|
||||||
|
subscription.status === "canceled"
|
||||||
|
) {
|
||||||
|
throw ctx.redirect(getUrl(ctx, callbackURL));
|
||||||
|
}
|
||||||
|
|
||||||
|
const stripeSubscription = await client.subscriptions.list({
|
||||||
|
customer: user.stripeCustomerId,
|
||||||
|
status: "active",
|
||||||
|
});
|
||||||
|
const currentSubscription = stripeSubscription.data.find(
|
||||||
|
(sub) => sub.id === subscription.stripeSubscriptionId,
|
||||||
|
);
|
||||||
|
console.log({ currentSubscription });
|
||||||
|
if (currentSubscription?.cancel_at_period_end === true) {
|
||||||
|
await ctx.context.adapter.update({
|
||||||
|
model: "subscription",
|
||||||
|
update: {
|
||||||
|
status: currentSubscription?.status,
|
||||||
|
cancelAtPeriodEnd: true,
|
||||||
|
},
|
||||||
|
where: [
|
||||||
|
{
|
||||||
|
field: "referenceId",
|
||||||
|
value: reference,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
await options.subscription?.onSubscriptionCancel?.({
|
||||||
|
subscription,
|
||||||
|
cancellationDetails: currentSubscription.cancellation_details,
|
||||||
|
stripeSubscription: currentSubscription,
|
||||||
|
event: undefined,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
ctx.context.logger.error(
|
||||||
|
"Error checking subscription status from Stripe",
|
||||||
|
error,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
throw ctx.redirect(getUrl(ctx, callbackURL));
|
||||||
|
},
|
||||||
|
),
|
||||||
cancelSubscription: createAuthEndpoint(
|
cancelSubscription: createAuthEndpoint(
|
||||||
"/subscription/cancel",
|
"/subscription/cancel",
|
||||||
{
|
{
|
||||||
@@ -377,16 +456,50 @@ export const stripe = <O extends StripeOptions>(options: O) => {
|
|||||||
message: STRIPE_ERROR_CODES.SUBSCRIPTION_NOT_FOUND,
|
message: STRIPE_ERROR_CODES.SUBSCRIPTION_NOT_FOUND,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
const { url } = await client.billingPortal.sessions.create({
|
const { url } = await client.billingPortal.sessions
|
||||||
customer: subscription.stripeCustomerId,
|
.create({
|
||||||
return_url: getUrl(ctx, ctx.body?.returnUrl || "/"),
|
customer: subscription.stripeCustomerId,
|
||||||
flow_data: {
|
return_url: getUrl(
|
||||||
type: "subscription_cancel",
|
ctx,
|
||||||
subscription_cancel: {
|
`${
|
||||||
subscription: activeSubscription.id,
|
ctx.context.baseURL
|
||||||
|
}/subscription/cancel/callback?callbackURL=${encodeURIComponent(
|
||||||
|
ctx.body?.returnUrl || "/",
|
||||||
|
)}&reference=${encodeURIComponent(referenceId)}`,
|
||||||
|
),
|
||||||
|
flow_data: {
|
||||||
|
type: "subscription_cancel",
|
||||||
|
subscription_cancel: {
|
||||||
|
subscription: activeSubscription.id,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
})
|
||||||
});
|
.catch(async (e) => {
|
||||||
|
if (e.message.includes("already set to be cancel")) {
|
||||||
|
/**
|
||||||
|
* incase we missed the event from stripe, we set it manually
|
||||||
|
* this is a rare case and should not happen
|
||||||
|
*/
|
||||||
|
if (!subscription.cancelAtPeriodEnd) {
|
||||||
|
await ctx.context.adapter.update({
|
||||||
|
model: "subscription",
|
||||||
|
update: {
|
||||||
|
cancelAtPeriodEnd: true,
|
||||||
|
},
|
||||||
|
where: [
|
||||||
|
{
|
||||||
|
field: "referenceId",
|
||||||
|
value: referenceId,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
throw ctx.error("BAD_REQUEST", {
|
||||||
|
message: e.message,
|
||||||
|
code: e.code,
|
||||||
|
});
|
||||||
|
});
|
||||||
return {
|
return {
|
||||||
url,
|
url,
|
||||||
redirect: true,
|
redirect: true,
|
||||||
@@ -509,6 +622,9 @@ export const stripe = <O extends StripeOptions>(options: O) => {
|
|||||||
status: stripeSubscription.status,
|
status: stripeSubscription.status,
|
||||||
seats: stripeSubscription.items.data[0]?.quantity || 1,
|
seats: stripeSubscription.items.data[0]?.quantity || 1,
|
||||||
plan: plan.name.toLowerCase(),
|
plan: plan.name.toLowerCase(),
|
||||||
|
periodEnd: stripeSubscription.current_period_end,
|
||||||
|
periodStart: stripeSubscription.current_period_start,
|
||||||
|
stripeSubscriptionId: stripeSubscription.id,
|
||||||
},
|
},
|
||||||
where: [
|
where: [
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -600,6 +600,42 @@ describe("stripe", async () => {
|
|||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const userCancelEvent = {
|
||||||
|
type: "customer.subscription.updated",
|
||||||
|
data: {
|
||||||
|
object: {
|
||||||
|
id: "sub_123",
|
||||||
|
customer: "cus_123",
|
||||||
|
status: "active",
|
||||||
|
cancel_at_period_end: true,
|
||||||
|
cancellation_details: {
|
||||||
|
reason: "cancellation_requested",
|
||||||
|
comment: "Customer canceled subscription",
|
||||||
|
},
|
||||||
|
items: {
|
||||||
|
data: [{ price: { id: process.env.STRIPE_PRICE_ID_1 } }],
|
||||||
|
},
|
||||||
|
current_period_start: Math.floor(Date.now() / 1000),
|
||||||
|
current_period_end: Math.floor(Date.now() / 1000) + 30 * 24 * 60 * 60,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
const userCancelRequest = new Request(
|
||||||
|
"http://localhost:3000/api/auth/stripe/webhook",
|
||||||
|
{
|
||||||
|
method: "POST",
|
||||||
|
headers: {
|
||||||
|
"stripe-signature": "test_signature",
|
||||||
|
},
|
||||||
|
body: JSON.stringify(userCancelEvent),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
mockStripeForEvents.webhooks.constructEvent.mockReturnValue(
|
||||||
|
userCancelEvent,
|
||||||
|
);
|
||||||
|
await eventTestAuth.handler(userCancelRequest);
|
||||||
const cancelEvent = {
|
const cancelEvent = {
|
||||||
type: "customer.subscription.updated",
|
type: "customer.subscription.updated",
|
||||||
data: {
|
data: {
|
||||||
|
|||||||
@@ -247,10 +247,10 @@ export interface StripeOptions {
|
|||||||
* @returns
|
* @returns
|
||||||
*/
|
*/
|
||||||
onSubscriptionCancel?: (data: {
|
onSubscriptionCancel?: (data: {
|
||||||
event: Stripe.Event;
|
event?: Stripe.Event;
|
||||||
subscription: Subscription;
|
subscription: Subscription;
|
||||||
stripeSubscription: Stripe.Subscription;
|
stripeSubscription: Stripe.Subscription;
|
||||||
cancellationDetails?: Stripe.Subscription.CancellationDetails;
|
cancellationDetails?: Stripe.Subscription.CancellationDetails | null;
|
||||||
}) => Promise<void>;
|
}) => Promise<void>;
|
||||||
/**
|
/**
|
||||||
* A function to check if the reference id is valid
|
* A function to check if the reference id is valid
|
||||||
|
|||||||
Reference in New Issue
Block a user