diff --git a/demo/nextjs/app/(auth)/sign-in/page.tsx b/demo/nextjs/app/(auth)/sign-in/page.tsx index b2779a6f..7c6551e6 100644 --- a/demo/nextjs/app/(auth)/sign-in/page.tsx +++ b/demo/nextjs/app/(auth)/sign-in/page.tsx @@ -4,11 +4,24 @@ import SignIn from "@/components/sign-in"; import { SignUp } from "@/components/sign-up"; import { Tabs } from "@/components/ui/tabs2"; import { client } from "@/lib/auth-client"; +import { useRouter } from "next/navigation"; import { useEffect } from "react"; +import { toast } from "sonner"; export default function Page() { + const router = useRouter(); useEffect(() => { - client.oneTap(); + client.oneTap({ + fetchOptions: { + onError: ({ error }) => { + toast.error(error.message || "An error occurred"); + }, + onSuccess: () => { + toast.success("Successfully signed in"); + router.push("/dashboard"); + }, + }, + }); }, []); return ( diff --git a/demo/nextjs/app/dashboard/change-plan.tsx b/demo/nextjs/app/dashboard/change-plan.tsx new file mode 100644 index 00000000..f4509660 --- /dev/null +++ b/demo/nextjs/app/dashboard/change-plan.tsx @@ -0,0 +1,190 @@ +import { Button } from "@/components/ui/button"; +import { + Dialog, + DialogContent, + DialogDescription, + DialogHeader, + DialogTitle, + DialogTrigger, +} from "@/components/ui/dialog"; +import { Label } from "@/components/ui/label"; +import { RadioGroup, RadioGroupItem } from "@/components/ui/radio-group"; +import { client } from "@/lib/auth-client"; +import { cn } from "@/lib/utils"; +import { ArrowUpFromLine, CreditCard, RefreshCcw } from "lucide-react"; +import { useId, useState } from "react"; +import { toast } from "sonner"; + +function Component(props: { + currentPlan?: string; + isTrial?: boolean; +}) { + const [selectedPlan, setSelectedPlan] = useState("starter"); + const id = useId(); + return ( + + + + + +
+ + + + {!props.currentPlan ? "Upgrade" : "Change"} your plan + + + Pick one of the following plans. + + +
+ +
+ setSelectedPlan(value)} + > +
+ +
+ +

+ $50/month +

+
+
+
+ +
+ +

+ $99/month +

+
+
+
+ +
+ +

+ Contact our sales team +

+
+
+
+ +
+

+ note: all upgrades takes effect immediately and you'll be charged + the new amount on your next billing cycle. +

+
+ +
+ + {props.currentPlan && ( + + )} +
+
+
+
+ ); +} + +export { Component }; diff --git a/demo/nextjs/app/dashboard/page.tsx b/demo/nextjs/app/dashboard/page.tsx index b79c72e1..5d24aa37 100644 --- a/demo/nextjs/app/dashboard/page.tsx +++ b/demo/nextjs/app/dashboard/page.tsx @@ -6,7 +6,7 @@ import { OrganizationCard } from "./organization-card"; import AccountSwitcher from "@/components/account-switch"; export default async function DashboardPage() { - const [session, activeSessions, deviceSessions, organization] = + const [session, activeSessions, deviceSessions, organization, subscriptions] = await Promise.all([ auth.api.getSession({ headers: await headers(), @@ -20,7 +20,11 @@ export default async function DashboardPage() { auth.api.getFullOrganization({ headers: await headers(), }), + auth.api.listActiveSubscriptions({ + headers: await headers(), + }), ]).catch((e) => { + console.log(e); throw redirect("/sign-in"); }); return ( @@ -32,6 +36,9 @@ export default async function DashboardPage() { sub.status === "active" || sub.status === "trialing", + )} /> setIsHovered(true)} + onHoverEnd={() => setIsHovered(false)} + whileHover={{ scale: 1.05 }} + whileTap={{ scale: 0.95 }} + > + + + Upgrade to Pro + + + + + ); +} diff --git a/demo/nextjs/app/dashboard/user-card.tsx b/demo/nextjs/app/dashboard/user-card.tsx index caea58b1..9f402163 100644 --- a/demo/nextjs/app/dashboard/user-card.tsx +++ b/demo/nextjs/app/dashboard/user-card.tsx @@ -54,10 +54,16 @@ import { } from "@/components/ui/table"; import QRCode from "react-qr-code"; import CopyButton from "@/components/ui/copy-button"; +import { Badge } from "@/components/ui/badge"; +import { useQuery } from "@tanstack/react-query"; +import { SubscriptionTierLabel } from "@/components/tier-labels"; +import { Component } from "./change-plan"; +import { Subscription } from "@better-auth/stripe"; export default function UserCard(props: { session: Session | null; activeSessions: Session["session"][]; + subscription?: Subscription; }) { const router = useRouter(); const { data, isPending } = useSession(); @@ -70,31 +76,75 @@ export default function UserCard(props: { const [isSignOut, setIsSignOut] = useState(false); const [emailVerificationPending, setEmailVerificationPending] = useState(false); - + const { data: subscription } = useQuery({ + queryKey: ["subscriptions"], + initialData: props.subscription ? props.subscription : null, + queryFn: async () => { + const res = await client.subscription.list({ + fetchOptions: { + throw: true, + }, + }); + return res.length ? res[0] : null; + }, + }); return ( User -
-
- - - {session?.user.name.charAt(0)} - -
-

- {session?.user.name} -

-

{session?.user.email}

+
+
+
+ + + {session?.user.name.charAt(0)} + +
+
+

+ {session?.user.name} +

+ {!!subscription && ( + + + + + + )} +
+

{session?.user.email}

+
+ +
+
+
+ +
+
-
{session?.user.emailVerified ? null : ( diff --git a/demo/nextjs/app/pricing/page.tsx b/demo/nextjs/app/pricing/page.tsx new file mode 100644 index 00000000..0d06e346 --- /dev/null +++ b/demo/nextjs/app/pricing/page.tsx @@ -0,0 +1,59 @@ +import { Pricing } from "@/components/blocks/pricing"; + +const demoPlans = [ + { + name: "STARTER", + price: "50", + yearlyPrice: "40", + period: "per month", + features: [ + "Up to 10 projects", + "Basic analytics", + "48-hour support response time", + "Limited API access", + ], + description: "Perfect for individuals and small projects", + buttonText: "Start Free Trial", + href: "/sign-up", + isPopular: false, + }, + { + name: "PROFESSIONAL", + price: "99", + yearlyPrice: "79", + period: "per month", + features: [ + "Unlimited projects", + "Advanced analytics", + "24-hour support response time", + "Full API access", + "Priority support", + ], + description: "Ideal for growing teams and businesses", + buttonText: "Get Started", + href: "/sign-up", + isPopular: true, + }, + { + name: "ENTERPRISE", + price: "299", + yearlyPrice: "239", + period: "per month", + features: [ + "Everything in Professional", + "Custom solutions", + "Dedicated account manager", + "1-hour support response time", + "SSO Authentication", + "Advanced security", + ], + description: "For large organizations with specific needs", + buttonText: "Contact Sales", + href: "/contact", + isPopular: false, + }, +]; + +export default function Page() { + return ; +} diff --git a/demo/nextjs/components/blocks/pricing.tsx b/demo/nextjs/components/blocks/pricing.tsx new file mode 100644 index 00000000..2330719a --- /dev/null +++ b/demo/nextjs/components/blocks/pricing.tsx @@ -0,0 +1,235 @@ +"use client"; + +import { Button, buttonVariants } from "@/components/ui/button"; +import { Label } from "@/components/ui/label"; +import { Switch } from "@/components/ui/switch"; + +import { cn } from "@/lib/utils"; +import { motion } from "framer-motion"; +import { Star } from "lucide-react"; +import { useState, useRef, useEffect } from "react"; +import confetti from "canvas-confetti"; +import NumberFlow from "@number-flow/react"; +import { CheckIcon } from "@radix-ui/react-icons"; +import { client } from "@/lib/auth-client"; + +function useMediaQuery(query: string) { + const [matches, setMatches] = useState(false); + + useEffect(() => { + const media = window.matchMedia(query); + if (media.matches !== matches) { + setMatches(media.matches); + } + + const listener = () => setMatches(media.matches); + media.addListener(listener); + + return () => media.removeListener(listener); + }, [query]); + + return matches; +} + +interface PricingPlan { + name: string; + price: string; + yearlyPrice: string; + period: string; + features: string[]; + description: string; + buttonText: string; + href: string; + isPopular: boolean; +} + +interface PricingProps { + plans: PricingPlan[]; + title?: string; + description?: string; +} + +export function Pricing({ + plans, + title = "Simple, Transparent Pricing", + description = "Choose the plan that works for you", +}: PricingProps) { + const [isMonthly, setIsMonthly] = useState(true); + const isDesktop = useMediaQuery("(min-width: 768px)"); + const switchRef = useRef(null); + + const handleToggle = (checked: boolean) => { + setIsMonthly(!checked); + if (checked && switchRef.current) { + const rect = switchRef.current.getBoundingClientRect(); + const x = rect.left + rect.width / 2; + const y = rect.top + rect.height / 2; + + confetti({ + particleCount: 50, + spread: 60, + origin: { + x: x / window.innerWidth, + y: y / window.innerHeight, + }, + colors: [ + "hsl(var(--primary))", + "hsl(var(--accent))", + "hsl(var(--secondary))", + "hsl(var(--muted))", + ], + ticks: 200, + gravity: 1.2, + decay: 0.94, + startVelocity: 30, + shapes: ["circle"], + }); + } + }; + + return ( +
+
+

+ {title} +

+

+ {description} +

+
+ +
+ + + Annual billing (Save 20%) + +
+ +
+ {plans.map((plan, index) => ( + + {plan.isPopular && ( +
+ + + Popular + +
+ )} +
+

+ {plan.name} +

+
+ + + + {plan.period !== "Next 3 months" && ( + + / {plan.period} + + )} +
+ +

+ {isMonthly ? "billed monthly" : "billed annually"} +

+ +
    + {plan.features.map((feature, idx) => ( +
  • + + {feature} +
  • + ))} +
+ +
+ +

+ {plan.description} +

+
+
+ ))} +
+
+ ); +} diff --git a/demo/nextjs/components/tier-labels.tsx b/demo/nextjs/components/tier-labels.tsx new file mode 100644 index 00000000..7d7e41cd --- /dev/null +++ b/demo/nextjs/components/tier-labels.tsx @@ -0,0 +1,38 @@ +import type React from "react"; +import { cva, type VariantProps } from "class-variance-authority"; +import { cn } from "@/lib/utils"; + +const tierVariants = cva( + "inline-flex items-center rounded-full px-3 py-1 text-xs font-semibold ring-1 ring-inset transition-all duration-300 ease-in-out", + { + variants: { + variant: { + free: "bg-gray-500 text-white ring-gray-400 hover:bg-gray-600", + starter: "bg-lime-700/40 text-white ring-lime-200/40 hover:bg-lime-600", + professional: "bg-purple-800/80 ring-purple-400 hover:bg-purple-700", + enterprise: "bg-amber-500 text-black ring-amber-400 hover:bg-amber-600", + }, + }, + defaultVariants: { + variant: "free", + }, + }, +); + +export interface SubscriptionTierLabelProps + extends React.HTMLAttributes, + VariantProps { + tier?: "free" | "starter" | "professional" | "enterprise"; +} + +export const SubscriptionTierLabel: React.FC = ({ + tier = "free", + className, + ...props +}) => { + return ( + + {tier.charAt(0).toUpperCase() + tier.slice(1)} + + ); +}; diff --git a/demo/nextjs/lib/auth-client.ts b/demo/nextjs/lib/auth-client.ts index ee848db8..1f1a7f8b 100644 --- a/demo/nextjs/lib/auth-client.ts +++ b/demo/nextjs/lib/auth-client.ts @@ -10,6 +10,7 @@ import { genericOAuthClient, } from "better-auth/client/plugins"; import { toast } from "sonner"; +import { stripeClient } from "@better-auth/stripe/client"; export const client = createAuthClient({ plugins: [ @@ -30,6 +31,9 @@ export const client = createAuthClient({ }), oidcClient(), genericOAuthClient(), + stripeClient({ + subscription: true, + }), ], fetchOptions: { onError(e) { diff --git a/demo/nextjs/lib/auth.ts b/demo/nextjs/lib/auth.ts index c6b9eb87..bc72797c 100644 --- a/demo/nextjs/lib/auth.ts +++ b/demo/nextjs/lib/auth.ts @@ -18,6 +18,9 @@ import { MysqlDialect } from "kysely"; import { createPool } from "mysql2/promise"; import { nextCookies } from "better-auth/next-js"; import { passkey } from "better-auth/plugins/passkey"; +import { stripe } from "@better-auth/stripe"; +import { Stripe } from "stripe"; +import Database from "better-sqlite3"; const from = process.env.BETTER_AUTH_EMAIL || "delivered@resend.dev"; const to = process.env.TEST_EMAIL || ""; @@ -37,12 +40,18 @@ if (!dialect) { throw new Error("No dialect found"); } +const PROFESSION_PRICE_ID = { + default: "price_1QxWZ5LUjnrYIrml5Dnwnl0X", + annual: "price_1QxWZTLUjnrYIrmlyJYpwyhz", +}; +const STARTER_PRICE_ID = { + default: "price_1QxWWtLUjnrYIrmleljPKszG", + annual: "price_1QxWYqLUjnrYIrmlonqPThVF", +}; + export const auth = betterAuth({ appName: "Better Auth Demo", - database: { - dialect, - type: "sqlite", - }, + database: new Database("./stripe.db"), emailVerification: { async sendVerificationEmail({ user, url }) { const res = await resend.emails.send({ @@ -149,5 +158,31 @@ export const auth = betterAuth({ loginPage: "/sign-in", }), oneTap(), + stripe({ + stripeClient: new Stripe(process.env.STRIPE_KEY!), + stripeWebhookSecret: process.env.STRIPE_WEBHOOK_SECRET!, + subscription: { + enabled: true, + requireEmailVerification: true, + plans: [ + { + name: "Starter", + priceId: STARTER_PRICE_ID.default, + annualDiscountPriceId: STARTER_PRICE_ID.annual, + freeTrial: { + days: 7, + }, + }, + { + name: "Professional", + priceId: PROFESSION_PRICE_ID.default, + annualDiscountPriceId: PROFESSION_PRICE_ID.annual, + }, + { + name: "Enterprise", + }, + ], + }, + }), ], }); diff --git a/demo/nextjs/package.json b/demo/nextjs/package.json index 3410170b..ad758f16 100644 --- a/demo/nextjs/package.json +++ b/demo/nextjs/package.json @@ -11,10 +11,12 @@ "lint": "next lint" }, "dependencies": { + "@better-auth/stripe": "workspace:*", "@better-fetch/fetch": "catalog:", "@hookform/resolvers": "^3.9.1", "@libsql/client": "^0.12.0", "@libsql/kysely-libsql": "^0.4.1", + "@number-flow/react": "^0.5.5", "@prisma/adapter-libsql": "^5.22.0", "@prisma/client": "^5.22.0", "@radix-ui/react-accordion": "^1.2.1", @@ -52,6 +54,7 @@ "better-auth": "workspace:*", "better-call": "catalog:", "better-sqlite3": "^11.6.0", + "canvas-confetti": "^1.9.3", "class-variance-authority": "^0.7.1", "clsx": "^2.1.1", "cmdk": "1.0.0", @@ -87,6 +90,7 @@ "zod": "^3.23.8" }, "devDependencies": { + "@types/canvas-confetti": "^1.9.0", "@types/node": "^20.17.9", "@types/react": "^18.3.14", "@types/react-dom": "^18.3.2", diff --git a/docs/components/sidebar-content.tsx b/docs/components/sidebar-content.tsx index aabba24a..d9f4fa08 100644 --- a/docs/components/sidebar-content.tsx +++ b/docs/components/sidebar-content.tsx @@ -1363,11 +1363,32 @@ export const contents: Content[] = [ href: "/docs/plugins/jwt", }, { - title: "Other", + title: "3d party", group: true, href: "/docs/plugins/1st-party-plugins", icon: () => , }, + { + title: "Stripe", + href: "/docs/plugins/stripe", + icon: () => ( + + + + ), + }, { title: "Community Plugins", href: "/docs/plugins/community-plugins", diff --git a/docs/content/docs/plugins/stripe.mdx b/docs/content/docs/plugins/stripe.mdx new file mode 100644 index 00000000..1ec7cb8b --- /dev/null +++ b/docs/content/docs/plugins/stripe.mdx @@ -0,0 +1,707 @@ +--- +title: Stripe +description: Stripe plugin for Better Auth to manage subscriptions and payments. +--- + +The Stripe plugin integrates Stripe's payment and subscription functionality with Better Auth. Since payment and authentication are often tightly coupled, this plugin simplifies the integration of stripe into your application, handling customer creation, subscription management, and webhook processing. + + +This plugin is currently in beta. We're actively collecting feedback and exploring additional features. If you have feature requests or suggestions, please join our [Discord community](https://discord.com/invite/Mh3DaacaFs) to discuss them. + + +## Features + +- Create Stripe Customers automatically when users sign up +- Manage subscription plans and pricing +- Process subscription lifecycle events (creation, updates, cancellations) +- Handle Stripe webhooks securely with signature verification +- Expose subscription data to your application +- Support for trial periods and subscription upgrades +- Flexible reference system to associate subscriptions with users or organizations +- Team subscription support with seats management + +## Installation + + + + ### Install the plugin + + First, install the plugin: + + ```package-install + @better-auth/stripe + ``` + + If you're using a separate client and server setup, make sure to install the plugin in both parts of your project. + + + + ### Install the Stripe SDK + + Next, install the Stripe SDK on your server: + + ```package-install + stripe + ``` + + + ### Add the plugin to your auth config + + ```ts title="auth.ts" + import { betterAuth } from "better-auth" + import { stripe } from "@better-auth/stripe" + import Stripe from "stripe" + + const stripeClient = new Stripe(process.env.STRIPE_SECRET_KEY!) + + export const auth = betterAuth({ + // ... your existing config + plugins: [ + stripe({ + stripeClient, + stripeWebhookSecret: process.env.STRIPE_WEBHOOK_SECRET!, + createCustomerOnSignUp: true, + }) + ] + }) + ``` + + + ### Add the client plugin + + ```ts title="auth-client.ts" + import { createAuthClient } from "better-auth/client" + import { stripeClient } from "@better-auth/stripe/client" + + export const client = createAuthClient({ + // ... your existing config + plugins: [ + stripeClient({ + subscription: true //if you want to enable subscription management + }) + ] + }) + ``` + + + ### Migrate the database + + Run the migration or generate the schema to add the necessary tables to the database. + + + + ```bash + npx @better-auth/cli migrate + ``` + + + ```bash + npx @better-auth/cli generate + ``` + + + See the [Schema](#schema) section to add the tables manually. + + + ### Set up Stripe webhooks + + Create a webhook endpoint in your Stripe dashboard pointing to: + + ``` + https://your-domain.com/api/auth/stripe/webhook + ``` + `/api/auth` is the default path for the auth server. + + Make sure to select at least these events: + - `checkout.session.completed` + - `customer.subscription.updated` + - `customer.subscription.deleted` + + Save the webhook signing secret provided by Stripe and add it to your environment variables as `STRIPE_WEBHOOK_SECRET`. + + + +## Usage + +### Customer Management + +You can use this plugin solely for customer management without enabling subscriptions. This is useful if you just want to link Stripe customers to your users. + +By default, when a user signs up, a Stripe customer is automatically created if you set `createCustomerOnSignUp: true`. This customer is linked to the user in your database. +You can customize the customer creation process: + +```ts title="auth.ts" +stripe({ + // ... other options + createCustomerOnSignUp: true, + onCustomerCreate: async ({ customer, stripeCustomer, user }, request) => { + // Do something with the newly created customer + console.log(`Customer ${customer.id} created for user ${user.id}`); + }, + getCustomerCreateParams: async ({ user, session }, request) => { + // Customize the Stripe customer creation parameters + return { + metadata: { + referralSource: user.metadata?.referralSource + } + }; + } +}) +``` + +### Subscription Management + +#### Defining Plans + +You can define your subscription plans either statically or dynamically: + +```ts title="auth.ts" +// Static plans +subscription: { + enabled: true, + plans: [ + { + name: "basic", // the name of the plan, it'll be automatically lower cased when stored in the database + priceId: "price_1234567890", // the price id from stripe + limits: { + projects: 5, + storage: 10 + } + }, + { + name: "pro", + priceId: "price_0987654321", + limits: { + projects: 20, + storage: 50 + }, + freeTrial: { + days: 14, + forNewUsersOnly: true + } + } + ] +} + +// Dynamic plans (fetched from database or API) +subscription: { + enabled: true, + plans: async () => { + const plans = await db.query("SELECT * FROM plans"); + return plans.map(plan => ({ + name: plan.name, + priceId: plan.stripe_price_id, + limits: JSON.parse(plan.limits) + })); + } +} +``` + +see [plan configuration](#plan-configuration) for more. + +#### Creating a Subscription + +To create a subscription, use the `subscription.upgrade` method: + +```ts title="client.ts" +await client.subscription.upgrade({ + plan: "pro", + successUrl: "/dashboard", + cancelUrl: "/pricing", + referenceId: "org_123" // Optional: defaults to the current logged in user id + seats: 5 // Optional: for team plans +}); +``` + +This will create a Checkout Session and redirect the user to the Stripe Checkout page. + +> **Important:** The `successUrl` parameter will be internally modified to handle race conditions between checkout completion and webhook processing. The plugin creates an intermediate redirect that ensures subscription status is properly updated before redirecting to your success page. + +```ts +const { error } = await client.subscription.upgrade({ + plan: "pro", + successUrl: "/dashboard", + cancelUrl: "/pricing", +}); +if(error) { + alert(error.message); +} +``` + + +For each reference ID (user or organization), only one active or trialing subscription is supported at a time. The plugin doesn't currently support multiple concurrent active subscriptions for the same reference ID. + + + +#### Listing Active Subscriptions + +To get the user's active subscriptions: + +```ts title="client.ts" +const { data: subscriptions } = await client.subscription.list(); + +// get the active subscription +const activeSubscription = subscriptions.find( + sub => sub.status === "active" || sub.status === "trialing" +); + +// Check subscription limits +const projectLimit = subscriptions?.limits?.projects || 0; +``` + +#### Canceling a Subscription + +To cancel a subscription: + +```ts title="client.ts" +const { data } = await client.subscription.cancel({ + returnUrl: "/account", + referenceId: "org_123" // optional defaults to userId +}); +``` + +This will redirect the user to the Stripe Billing Portal where they can cancel their subscription. + +### Reference System + +By default, subscriptions are associated with the user ID. However, you can use a custom reference ID to associate subscriptions with other entities, such as organizations: + +```ts title="client.ts" +// Create a subscription for an organization +await client.subscription.upgrade({ + plan: "pro", + referenceId: "org_123456", + successUrl: "/dashboard", + cancelUrl: "/pricing", + seats: 5 // Number of seats for team plans +}); + +// List subscriptions for an organization +const { data: subscriptions } = await client.subscription.list({ + referenceId: "org_123456" +}); +``` + +#### Team Subscriptions with Seats + +For team or organization plans, you can specify the number of seats: + +```ts +await client.subscription.upgrade({ + plan: "team", + referenceId: "org_123456", + seats: 10, // 10 team members + successUrl: "/org/billing/success", + cancelUrl: "/org/billing" +}); +``` + +The `seats` parameter is passed to Stripe as the quantity for the subscription item. You can use this value in your application logic to limit the number of members in a team or organization. + +To authorize reference IDs, implement the `authorizeReference` function: + +```ts title="auth.ts" +subscription: { + // ... other options + authorizeReference: async ({ user, session, referenceId, action }) => { + // Check if the user has permission to manage subscriptions for this reference + if (action === "upgrade-subscription" || action === "cancel-subscription") { + const org = await db.member.findFirst({ + where: { + organizationId: referenceId, + userId: user.id + } + }); + return org?.role === "owner" + } + return true; + } +} +``` + +### Webhook Handling + +The plugin automatically handles common webhook events: + +- `checkout.session.completed`: Updates subscription status after checkout +- `customer.subscription.updated`: Updates subscription details when changed +- `customer.subscription.deleted`: Marks subscription as canceled + +You can also handle custom events: + +```ts title="auth.ts" +stripe({ + // ... other options + onEvent: async (event) => { + // Handle any Stripe event + switch (event.type) { + case "invoice.paid": + // Handle paid invoice + break; + case "payment_intent.succeeded": + // Handle successful payment + break; + } + } +}) +``` + +### Subscription Lifecycle Hooks + +You can hook into various subscription lifecycle events: + +```ts title="auth.ts" +subscription: { + // ... other options + onSubscriptionComplete: async ({ event, subscription, stripeSubscription, plan }) => { + // Called when a subscription is successfully created + await sendWelcomeEmail(subscription.referenceId, plan.name); + }, + onSubscriptionUpdate: async ({ event, subscription }) => { + // Called when a subscription is updated + console.log(`Subscription ${subscription.id} updated`); + }, + onSubscriptionCancel: async ({ event, subscription, stripeSubscription, cancellationDetails }) => { + // Called when a subscription is canceled + await sendCancellationEmail(subscription.referenceId); + }, + onSubscriptionDeleted: async ({ event, subscription, stripeSubscription }) => { + // Called when a subscription is deleted + console.log(`Subscription ${subscription.id} deleted`); + } +} +``` + +### Trial Periods + +You can configure trial periods for your plans: + +```ts title="auth.ts" +{ + name: "pro", + priceId: "price_0987654321", + freeTrial: { + days: 14, + forNewUsersOnly: true, // only new users can start a trial + onTrialStart: async (subscription) => { + // Called when a trial starts + await sendTrialStartEmail(subscription.referenceId); + }, + onTrialEnd: async ({ subscription, user }, request) => { + // Called when a trial ends + await sendTrialEndEmail(user.email); + }, + onTrialExpired: async (subscription) => { + // Called when a trial expires without conversion + await sendTrialExpiredEmail(subscription.referenceId); + } + } +} +``` + +## Schema + +The Stripe plugin adds the following tables to your database: + +### Customer + +Table Name: `customer` + + + +### Subscription + +Table Name: `subscription` + + + +### Customizing the Schema + +To change the schema table names or fields, you can pass a `schema` option to the Stripe plugin: + +```ts title="auth.ts" +stripe({ + // ... other options + schema: { + customer: { + modelName: "stripeCustomers", // map the customer table to stripeCustomers + fields: { + stripeCustomerId: "externalId" // map the stripeCustomerId field to externalId + } + }, + subscription: { + modelName: "stripeSubscriptions", // map the subscription table to stripeSubscriptions + fields: { + plan: "planName" // map the plan field to planName + } + } + } +}) +``` + +## Options + +### Main Options + +**stripeClient**: `Stripe` - The Stripe client instance. Required. + +**stripeWebhookSecret**: `string` - The webhook signing secret from Stripe. Required. + +**createCustomerOnSignUp**: `boolean` - Whether to automatically create a Stripe customer when a user signs up. Default: `false`. + +**onCustomerCreate**: `(data: { customer: Customer, stripeCustomer: Stripe.Customer, user: User }, request?: Request) => Promise` - A function called after a customer is created. + +**getCustomerCreateParams**: `(data: { user: User, session: Session }, request?: Request) => Promise<{}>` - A function to customize the Stripe customer creation parameters. + +**onEvent**: `(event: Stripe.Event) => Promise` - A function called for any Stripe webhook event. + +### Subscription Options + +**enabled**: `boolean` - Whether to enable subscription functionality. Required. + +**plans**: `Plan[] | (() => Promise)` - An array of subscription plans or a function that returns plans. Required if subscriptions are enabled. + +**requireEmailVerification**: `boolean` - Whether to require email verification before allowing subscription upgrades. Default: `false`. + +**authorizeReference**: `(data: { user: User, session: Session, referenceId: string, action: "upgrade-subscription" | "list-subscription" | "cancel-subscription" }, request?: Request) => Promise` - A function to authorize reference IDs. + +### Plan Configuration + +Each plan can have the following properties: + +**name**: `string` - The name of the plan. Required. + +**priceId**: `string` - The Stripe price ID. Required unless using `lookupKey`. + +**lookupKey**: `string` - The Stripe price lookup key. Alternative to `priceId`. + +**annualDiscountPriceId**: `string` - A price ID for annual billing with a discount. + +**limits**: `Record` - Limits associated with the plan (e.g., `{ projects: 10, storage: 5 }`). + +**group**: `string` - A group name for the plan, useful for categorizing plans. + +**freeTrial**: Object containing trial configuration: + - **days**: `number` - Number of trial days. + - **forNewUsersOnly**: `boolean` - Whether the trial is only for new users. Default: `true`. + - **onTrialStart**: `(subscription: Subscription) => Promise` - Called when a trial starts. + - **onTrialEnd**: `(data: { subscription: Subscription, user: User }, request?: Request) => Promise` - Called when a trial ends. + - **onTrialExpired**: `(subscription: Subscription) => Promise` - Called when a trial expires without conversion. + +## Advanced Usage + +### Using with Organizations + +The Stripe plugin works well with the organization plugin. You can associate subscriptions with organizations instead of individual users: + +```ts title="client.ts" +// Get the active organization +const { data: activeOrg } = client.useActiveOrganization(); + +// Create a subscription for the organization +await client.subscription.upgrade({ + plan: "team", + referenceId: activeOrg.id, + seats: 10, + successUrl: "/org/billing/success", + cancelUrl: "/org/billing" +}); +``` + +Make sure to implement the `authorizeReference` function to verify that the user has permission to manage subscriptions for the organization: + +```ts title="auth.ts" +authorizeReference: async ({ user, referenceId, action }) => { + const member = await db.members.findFirst({ + where: { + userId: user.id, + organizationId: referenceId + } + }); + + return member?.role === "owner" || member?.role === "admin"; +} +``` + +### Custom Checkout Session Parameters + +You can customize the Stripe Checkout session with additional parameters: + +```ts title="auth.ts" +getCheckoutSessionParams: async ({ user, session, plan, subscription }, request) => { + return { + params: { + allow_promotion_codes: true, + tax_id_collection: { + enabled: true + }, + billing_address_collection: "required", + custom_text: { + submit: { + message: "We'll start your subscription right away" + } + }, + metadata: { + planType: "business", + referralCode: user.metadata?.referralCode + } + }, + options: { + idempotencyKey: `sub_${user.id}_${plan.name}_${Date.now()}` + } + }; +} +``` + +### Tax Collection + +To enable tax collection: + +```ts title="auth.ts" +subscription: { + // ... other options + getCheckoutSessionParams: async ({ user, session, plan, subscription }, request) => { + return { + params: { + tax_id_collection: { + enabled: true + } + } + }; + } +} +``` + +## Troubleshooting + +### Webhook Issues + +If webhooks aren't being processed correctly: + +1. Check that your webhook URL is correctly configured in the Stripe dashboard +2. Verify that the webhook signing secret is correct +3. Ensure you've selected all the necessary events in the Stripe dashboard +4. Check your server logs for any errors during webhook processing + +### Subscription Status Issues + +If subscription statuses aren't updating correctly: + +1. Make sure the webhook events are being received and processed +2. Check that the `stripeCustomerId` and `stripeSubscriptionId` fields are correctly populated +3. Verify that the reference IDs match between your application and Stripe + +### Testing Webhooks Locally + +For local development, you can use the Stripe CLI to forward webhooks to your local environment: + +```bash +stripe listen --forward-to localhost:3000/api/auth/stripe/webhook +``` + +This will provide you with a webhook signing secret that you can use in your local environment. diff --git a/packages/better-auth/src/__snapshots__/init.test.ts.snap b/packages/better-auth/src/__snapshots__/init.test.ts.snap index 095fb7d4..c00585d2 100644 --- a/packages/better-auth/src/__snapshots__/init.test.ts.snap +++ b/packages/better-auth/src/__snapshots__/init.test.ts.snap @@ -120,6 +120,7 @@ exports[`init > should match config 1`] = ` "storage": "memory", "window": 10, }, + "runMigrations": [Function], "secondaryStorage": undefined, "secret": "better-auth-secret-123456789", "session": null, diff --git a/packages/better-auth/src/api/middlewares/origin-check.ts b/packages/better-auth/src/api/middlewares/origin-check.ts index 4f4b9916..cf42b0cc 100644 --- a/packages/better-auth/src/api/middlewares/origin-check.ts +++ b/packages/better-auth/src/api/middlewares/origin-check.ts @@ -70,7 +70,7 @@ export const originCheckMiddleware = createAuthMiddleware(async (ctx) => { }); export const originCheck = ( - getValue: (ctx: GenericEndpointContext) => string, + getValue: (ctx: GenericEndpointContext) => string | string[], ) => createAuthMiddleware(async (ctx) => { if (!ctx.request) { @@ -117,5 +117,8 @@ export const originCheck = ( throw new APIError("FORBIDDEN", { message: `Invalid ${label}` }); } }; - callbackURL && validateURL(callbackURL, "callbackURL"); + const callbacks = Array.isArray(callbackURL) ? callbackURL : [callbackURL]; + for (const url of callbacks) { + validateURL(url, "callbackURL"); + } }); diff --git a/packages/better-auth/src/cookies/cookie-utils.ts b/packages/better-auth/src/cookies/cookie-utils.ts index 2aa92344..244b6fde 100644 --- a/packages/better-auth/src/cookies/cookie-utils.ts +++ b/packages/better-auth/src/cookies/cookie-utils.ts @@ -73,3 +73,37 @@ export function parseSetCookieHeader( return cookies; } + +export function setCookieToHeader(headers: Headers) { + return (context: { + response: Response; + }) => { + const setCookieHeader = context.response.headers.get("set-cookie"); + if (!setCookieHeader) { + return; + } + + const cookieMap = new Map(); + + const existingCookiesHeader = headers.get("cookie") || ""; + existingCookiesHeader.split(";").forEach((cookie) => { + const [name, ...rest] = cookie.trim().split("="); + if (name && rest.length > 0) { + cookieMap.set(name, rest.join("=")); + } + }); + + const setCookieHeaders = setCookieHeader.split(","); + setCookieHeaders.forEach((header) => { + const cookies = parseSetCookieHeader(header); + cookies.forEach((value, name) => { + cookieMap.set(name, value.value); + }); + }); + + const updatedCookies = Array.from(cookieMap.entries()) + .map(([name, value]) => `${name}=${value}`) + .join("; "); + headers.set("cookie", updatedCookies); + }; +} diff --git a/packages/better-auth/src/init.ts b/packages/better-auth/src/init.ts index d2762d2a..95e53a5c 100644 --- a/packages/better-auth/src/init.ts +++ b/packages/better-auth/src/init.ts @@ -1,6 +1,6 @@ import { defu } from "defu"; import { hashPassword, verifyPassword } from "./crypto/password"; -import { createInternalAdapter } from "./db"; +import { createInternalAdapter, getMigrations } from "./db"; import { getAuthTables } from "./db/get-tables"; import { getAdapter } from "./db/utils"; import type { @@ -26,6 +26,7 @@ import { env, isProduction } from "./utils/env"; import { checkPassword } from "./utils/password"; import { getBaseURL } from "./utils/url"; import type { LiteralUnion } from "./types/helper"; +import { BetterAuthError } from "./error"; export const init = async (options: BetterAuthOptions) => { const adapter = await getAdapter(options); @@ -135,8 +136,19 @@ export const init = async (options: BetterAuthOptions) => { generateId: generateIdFunc, }), createAuthCookie: createCookieGetter(options), + async runMigrations() { + //only run migrations if database is provided and it's not an adapter + if (!options.database || "updateMany" in options.database) { + throw new BetterAuthError( + "Database is not provided or it's an adapter. Migrations are only supported with a database instance.", + ); + } + const { runMigrations } = await getMigrations(options); + await runMigrations(); + }, }; let { context } = runPluginInit(ctx); + context; return context; }; @@ -198,6 +210,7 @@ export type AuthContext = { checkPassword: typeof checkPassword; }; tables: ReturnType; + runMigrations: () => Promise; }; function runPluginInit(ctx: AuthContext) { diff --git a/packages/better-auth/src/integrations/next-js.ts b/packages/better-auth/src/integrations/next-js.ts index af18e75a..1f23430f 100644 --- a/packages/better-auth/src/integrations/next-js.ts +++ b/packages/better-auth/src/integrations/next-js.ts @@ -1,5 +1,4 @@ import type { BetterAuthPlugin } from "../types"; -import { cookies } from "next/headers"; import { parseSetCookieHeader } from "../cookies"; import { createAuthMiddleware } from "../plugins"; @@ -37,6 +36,7 @@ export const nextCookies = () => { const setCookies = returned?.get("set-cookie"); if (!setCookies) return; const parsed = parseSetCookieHeader(setCookies); + const { cookies } = await import("next/headers"); const cookieHelper = await cookies(); parsed.forEach((value, key) => { if (!key) return; diff --git a/packages/better-auth/src/oauth2/state.ts b/packages/better-auth/src/oauth2/state.ts index 5a4fbb35..98fac77c 100644 --- a/packages/better-auth/src/oauth2/state.ts +++ b/packages/better-auth/src/oauth2/state.ts @@ -24,6 +24,7 @@ export async function generateState( errorURL: c.body?.errorCallbackURL, newUserURL: c.body?.newUserCallbackURL, link, + /** * This is the actual expiry time of the state */ diff --git a/packages/better-auth/src/plugins/oidc-provider/index.ts b/packages/better-auth/src/plugins/oidc-provider/index.ts index 56cd4f7d..f341ebd5 100644 --- a/packages/better-auth/src/plugins/oidc-provider/index.ts +++ b/packages/better-auth/src/plugins/oidc-provider/index.ts @@ -145,6 +145,9 @@ export const oidcProvider = (options: OIDCOptions) => { "/.well-known/openid-configuration", { method: "GET", + metadata: { + isAction: false, + }, }, async (ctx) => { const metadata = getMetadata(ctx, options); diff --git a/packages/better-auth/src/test-utils/test-instance.ts b/packages/better-auth/src/test-utils/test-instance.ts index 08917d50..dbd76202 100644 --- a/packages/better-auth/src/test-utils/test-instance.ts +++ b/packages/better-auth/src/test-utils/test-instance.ts @@ -5,7 +5,7 @@ import { betterAuth } from "../auth"; import { createAuthClient } from "../client/vanilla"; import type { BetterAuthOptions, ClientOptions, Session, User } from "../types"; import { getMigrations } from "../db/get-migration"; -import { parseSetCookieHeader } from "../cookies"; +import { parseSetCookieHeader, setCookieToHeader } from "../cookies"; import type { SuccessContext } from "@better-fetch/fetch"; import { getAdapter } from "../db/utils"; import Database from "better-sqlite3"; @@ -232,39 +232,7 @@ export async function getTestInstance< } }; } - function cookieSetter(headers: Headers) { - return (context: { - response: Response; - }) => { - const setCookieHeader = context.response.headers.get("set-cookie"); - if (!setCookieHeader) { - return; - } - const cookieMap = new Map(); - - const existingCookiesHeader = headers.get("cookie") || ""; - existingCookiesHeader.split(";").forEach((cookie) => { - const [name, ...rest] = cookie.trim().split("="); - if (name && rest.length > 0) { - cookieMap.set(name, rest.join("=")); - } - }); - - const setCookieHeaders = setCookieHeader.split(","); - setCookieHeaders.forEach((header) => { - const cookies = parseSetCookieHeader(header); - cookies.forEach((value, name) => { - cookieMap.set(name, value.value); - }); - }); - - const updatedCookies = Array.from(cookieMap.entries()) - .map(([name, value]) => `${name}=${value}`) - .join("; "); - headers.set("cookie", updatedCookies); - }; - } const client = createAuthClient({ ...(config?.clientOptions as C extends undefined ? {} : C), baseURL: getBaseURL( @@ -281,7 +249,7 @@ export async function getTestInstance< testUser, signInWithTestUser, signInWithUser, - cookieSetter, + cookieSetter: setCookieToHeader, customFetchImpl, sessionSetter, db: await getAdapter(auth.options), diff --git a/packages/better-auth/tsconfig.json b/packages/better-auth/tsconfig.json index 5cdb3168..0d97965b 100644 --- a/packages/better-auth/tsconfig.json +++ b/packages/better-auth/tsconfig.json @@ -22,5 +22,5 @@ }, "exclude": ["**/dist", "node_modules"], "references": [], - "include": ["src/**/*"] + "include": ["src"] } diff --git a/packages/stripe/build.config.ts b/packages/stripe/build.config.ts new file mode 100644 index 00000000..2b1a140c --- /dev/null +++ b/packages/stripe/build.config.ts @@ -0,0 +1,12 @@ +import { defineBuildConfig } from "unbuild"; + +export default defineBuildConfig({ + declaration: true, + rollup: { + emitCJS: true, + }, + outDir: "dist", + clean: false, + failOnWarn: false, + externals: ["better-auth", "better-call", "@better-fetch/fetch", "stripe"], +}); diff --git a/packages/stripe/package.json b/packages/stripe/package.json new file mode 100644 index 00000000..926bd2e7 --- /dev/null +++ b/packages/stripe/package.json @@ -0,0 +1,52 @@ +{ + "name": "@better-auth/stripe", + "author": "Bereket Engida", + "version": "1.1.15", + "main": "dist/index.cjs", + "license": "MIT", + "keywords": [ + "stripe", + "auth", + "stripe" + ], + "module": "dist/index.mjs", + "description": "Stripe plugin for Better Auth", + "scripts": { + "test": "vitest", + "build": "unbuild", + "dev": "unbuild --watch" + }, + "exports": { + ".": { + "types": "./dist/index.d.ts", + "import": "./dist/index.mjs", + "require": "./dist/index.cjs" + }, + "./client": { + "types": "./dist/client.d.ts", + "import": "./dist/client.mjs", + "require": "./dist/client.cjs" + } + }, + "typesVersions": { + "*": { + "*": [ + "./dist/index.d.ts" + ], + "client": [ + "./dist/client.d.ts" + ] + } + }, + "dependencies": { + "better-auth": "workspace:^", + "zod": "^3.24.1" + }, + "devDependencies": { + "@types/better-sqlite3": "^7.6.12", + "better-sqlite3": "^11.6.0", + "vitest": "^1.6.0", + "stripe": "^17.7.0", + "better-call": "catalog:" + } +} \ No newline at end of file diff --git a/packages/stripe/src/client.ts b/packages/stripe/src/client.ts new file mode 100644 index 00000000..2745fb40 --- /dev/null +++ b/packages/stripe/src/client.ts @@ -0,0 +1,31 @@ +import type { BetterAuthClientPlugin } from "better-auth"; +import type { stripe } from "./index"; + +export const stripeClient = < + O extends { + subscription: boolean; + }, +>( + options?: O, +) => { + return { + id: "stripe-client", + $InferServerPlugin: {} as ReturnType< + typeof stripe< + O["subscription"] extends true + ? { + stripeClient: any; + stripeWebhookSecret: ""; + subscription: { + enabled: true; + plans: []; + }; + } + : { + stripeClient: any; + stripeWebhookSecret: ""; + } + > + >, + } satisfies BetterAuthClientPlugin; +}; diff --git a/packages/stripe/src/hooks.ts b/packages/stripe/src/hooks.ts new file mode 100644 index 00000000..055c98eb --- /dev/null +++ b/packages/stripe/src/hooks.ts @@ -0,0 +1,180 @@ +import type { GenericEndpointContext } from "better-auth"; +import type Stripe from "stripe"; +import type { InputSubscription, StripeOptions, Subscription } from "./types"; +import { getPlanByPriceId } from "./utils"; + +export async function onCheckoutSessionCompleted( + ctx: GenericEndpointContext, + options: StripeOptions, + event: Stripe.Event, +) { + const client = options.stripeClient; + const checkoutSession = event.data.object as Stripe.Checkout.Session; + if (checkoutSession.mode === "setup" || !options.subscription?.enabled) { + return; + } + const subscription = await client.subscriptions.retrieve( + checkoutSession.subscription as string, + ); + const priceId = subscription.items.data[0]?.price.id; + const plan = await getPlanByPriceId(options, priceId as string); + if (plan) { + const referenceId = checkoutSession?.metadata?.referenceId; + const subscriptionId = checkoutSession?.metadata?.subscriptionId; + const seats = subscription.items.data[0].quantity; + if (referenceId && subscriptionId) { + const trial = + subscription.trial_start && subscription.trial_end + ? { + trialStart: new Date(subscription.trial_start * 1000), + trialEnd: new Date(subscription.trial_end * 1000), + } + : {}; + let dbSubscription = await ctx.context.adapter.update({ + model: "subscription", + update: { + plan: plan.name.toLowerCase(), + status: subscription.status, + updatedAt: new Date(), + periodStart: new Date(subscription.current_period_start * 1000), + periodEnd: new Date(subscription.current_period_end * 1000), + seats, + ...trial, + }, + where: [ + { + field: "id", + value: subscriptionId, + }, + ], + }); + if (!dbSubscription) { + dbSubscription = await ctx.context.adapter.findOne({ + model: "subscription", + where: [ + { + field: "id", + value: subscriptionId, + }, + ], + }); + } + await options.subscription?.onSubscriptionComplete?.({ + event, + subscription: dbSubscription as Subscription, + stripeSubscription: subscription, + plan, + }); + return; + } + } +} + +export async function onSubscriptionUpdated( + ctx: GenericEndpointContext, + options: StripeOptions, + event: Stripe.Event, +) { + if (!options.subscription?.enabled) { + return; + } + const subscriptionUpdated = event.data.object as Stripe.Subscription; + const priceId = subscriptionUpdated.items.data[0].price.id; + const plan = await getPlanByPriceId(options, priceId); + if (plan) { + const stripeId = subscriptionUpdated.customer.toString(); + const subscription = await ctx.context.adapter.findOne({ + model: "subscription", + where: [ + { + field: "stripeSubscriptionId", + value: stripeId, + }, + ], + }); + if (!subscription) { + return; + } + const seats = subscriptionUpdated.items.data[0].quantity; + await ctx.context.adapter.update({ + model: "subscription", + update: { + plan: plan.name.toLowerCase(), + limits: plan.limits, + updatedAt: new Date(), + status: subscriptionUpdated.status, + periodStart: new Date(subscriptionUpdated.current_period_start * 1000), + periodEnd: new Date(subscriptionUpdated.current_period_end * 1000), + cancelAtPeriodEnd: subscriptionUpdated.cancel_at_period_end, + seats, + }, + where: [ + { + field: "stripeSubscriptionId", + value: subscriptionUpdated.id, + }, + ], + }); + const subscriptionCanceled = + subscriptionUpdated.status === "active" && + subscriptionUpdated.cancel_at_period_end; + if (subscriptionCanceled) { + await options.subscription.onSubscriptionCancel?.({ + subscription, + cancellationDetails: + subscriptionUpdated.cancellation_details || undefined, + stripeSubscription: subscriptionUpdated, + event, + }); + } + await options.subscription.onSubscriptionUpdate?.({ + event, + subscription, + }); + } +} + +export async function onSubscriptionDeleted( + ctx: GenericEndpointContext, + options: StripeOptions, + event: Stripe.Event, +) { + if (!options.subscription?.enabled) { + return; + } + const subscriptionDeleted = event.data.object as Stripe.Subscription; + const subscriptionId = subscriptionDeleted.metadata?.subscriptionId; + const stripeSubscription = await options.stripeClient.subscriptions.retrieve( + subscriptionId as string, + ); + if (stripeSubscription.status === "canceled") { + const subscription = await ctx.context.adapter.findOne({ + model: "subscription", + where: [ + { + field: "id", + value: subscriptionId, + }, + ], + }); + if (subscription) { + await ctx.context.adapter.update({ + model: "subscription", + where: [ + { + field: "id", + value: subscription.id, + }, + ], + update: { + status: "canceled", + }, + }); + await options.subscription.onSubscriptionDeleted?.({ + event, + stripeSubscription: subscriptionDeleted, + subscription, + }); + } + } +} diff --git a/packages/stripe/src/index.ts b/packages/stripe/src/index.ts new file mode 100644 index 00000000..e0379fac --- /dev/null +++ b/packages/stripe/src/index.ts @@ -0,0 +1,642 @@ +import { + type GenericEndpointContext, + type BetterAuthPlugin, +} from "better-auth"; +import { createAuthEndpoint, createAuthMiddleware } from "better-auth/plugins"; +import Stripe from "stripe"; +import { z } from "zod"; +import { + sessionMiddleware, + APIError, + originCheck, + getSessionFromCtx, +} from "better-auth/api"; +import { generateRandomString } from "better-auth/crypto"; +import { + onCheckoutSessionCompleted, + onSubscriptionDeleted, + onSubscriptionUpdated, +} from "./hooks"; +import type { InputSubscription, StripeOptions, Subscription } from "./types"; +import { getPlanByName, getPlanByPriceId, getPlans } from "./utils"; +import { getSchema } from "./schema"; + +const STRIPE_ERROR_CODES = { + SUBSCRIPTION_NOT_FOUND: "Subscription not found", + SUBSCRIPTION_PLAN_NOT_FOUND: "Subscription plan not found", + ALREADY_SUBSCRIBED_PLAN: "You're already subscribed to this plan", + UNABLE_TO_CREATE_CUSTOMER: "Unable to create customer", + FAILED_TO_FETCH_PLANS: "Failed to fetch plans", + EMAIL_VERIFICATION_REQUIRED: + "Email verification is required before you can subscribe to a plan", +} as const; + +const getUrl = (ctx: GenericEndpointContext, url: string) => { + if (url.startsWith("http")) { + return url; + } + return `${ctx.context.options.baseURL}${ + url.startsWith("/") ? url : `/${url}` + }`; +}; + +export const stripe = (options: O) => { + const client = options.stripeClient; + + const referenceMiddleware = ( + action: + | "upgrade-subscription" + | "list-subscription" + | "cancel-subscription", + ) => + createAuthMiddleware(async (ctx) => { + const session = ctx.context.session; + if (!session) { + throw new APIError("UNAUTHORIZED"); + } + const referenceId = + ctx.body?.referenceId || ctx.query?.referenceId || session.user.id; + const isAuthorized = ctx.body?.referenceId + ? await options.subscription?.authorizeReference?.({ + user: session.user, + session: session.session, + referenceId, + action, + }) + : true; + if (!isAuthorized) { + throw new APIError("UNAUTHORIZED", { + message: "Unauthorized", + }); + } + }); + + const subscriptionEndpoints = { + upgradeSubscription: createAuthEndpoint( + "/subscription/upgrade", + { + method: "POST", + body: z.object({ + plan: z.string(), + referenceId: z.string().optional(), + metadata: z.record(z.string(), z.any()).optional(), + seats: z + .number({ + description: "Number of seats to upgrade to (if applicable)", + }) + .optional(), + uiMode: z.enum(["embedded", "hosted"]).default("hosted"), + successUrl: z + .string({ + description: + "callback url to redirect back after successful subscription", + }) + .default("/"), + cancelUrl: z + .string({ + description: + "callback url to redirect back after successful subscription", + }) + .default("/"), + returnUrl: z.string().optional(), + withoutTrial: z.boolean().optional(), + disableRedirect: z.boolean().default(false), + }), + use: [ + sessionMiddleware, + originCheck((c) => { + return [c.body.successURL as string, c.body.cancelURL as string]; + }), + referenceMiddleware("upgrade-subscription"), + ], + }, + async (ctx) => { + const { user, session } = ctx.context.session; + if ( + !user.emailVerified && + options.subscription?.requireEmailVerification + ) { + throw new APIError("BAD_REQUEST", { + message: STRIPE_ERROR_CODES.EMAIL_VERIFICATION_REQUIRED, + }); + } + const referenceId = ctx.body.referenceId || user.id; + const plan = await getPlanByName(options, ctx.body.plan); + if (!plan) { + throw new APIError("BAD_REQUEST", { + message: STRIPE_ERROR_CODES.SUBSCRIPTION_PLAN_NOT_FOUND, + }); + } + let customerId = user.stripeCustomerId; + if (!customerId) { + try { + const stripeCustomer = await client.customers.create( + { + email: user.email, + name: user.name, + metadata: { + ...ctx.body.metadata, + userId: user.id, + }, + }, + { + idempotencyKey: generateRandomString(32, "a-z", "0-9"), + }, + ); + await ctx.context.adapter.update({ + model: "user", + update: { + stripeCustomerId: stripeCustomer.id, + }, + where: [ + { + field: "id", + value: user.id, + }, + ], + }); + customerId = stripeCustomer.id; + } catch (e: any) { + ctx.context.logger.error(e); + throw new APIError("BAD_REQUEST", { + message: STRIPE_ERROR_CODES.UNABLE_TO_CREATE_CUSTOMER, + }); + } + } + + const activeSubscription = customerId + ? await client.subscriptions + .list({ + customer: customerId, + status: "active", + }) + .then((res) => res.data[0]) + .catch((e) => null) + : null; + const subscriptions = await ctx.context.adapter.findMany({ + model: "subscription", + where: [ + { + field: "referenceId", + value: ctx.body.referenceId || user.id, + }, + ], + }); + const existingSubscription = subscriptions.find( + (sub) => sub.status === "active" || sub.status === "trialing", + ); + if (activeSubscription && customerId) { + const { url } = await client.billingPortal.sessions + .create({ + customer: customerId, + return_url: getUrl(ctx, ctx.body.returnUrl || "/"), + flow_data: { + type: "subscription_update_confirm", + subscription_update_confirm: { + subscription: activeSubscription.id, + items: [ + { + id: activeSubscription.items.data[0]?.id as string, + quantity: 1, + price: plan.priceId, + }, + ], + }, + }, + }) + .catch(async (e) => { + if (e.message.includes("no changes")) { + /** + * If the subscription is already active on stripe, we need to + * update the status to the new status. + */ + const plan = await getPlanByPriceId( + options, + activeSubscription.items.data[0]?.plan.id, + ); + await ctx.context.adapter.update({ + model: "subscription", + update: { + status: activeSubscription.status, + seats: activeSubscription.items.data[0]?.quantity, + plan: plan?.name.toLowerCase(), + }, + where: [ + { + field: "referenceId", + value: referenceId, + }, + ], + }); + throw new APIError("BAD_REQUEST", { + message: STRIPE_ERROR_CODES.ALREADY_SUBSCRIBED_PLAN, + }); + } + throw ctx.error("BAD_REQUEST", { + message: e.message, + code: e.code, + }); + }); + return ctx.json({ + url, + redirect: true, + }); + } + + if ( + existingSubscription && + existingSubscription.status === "active" && + existingSubscription.plan === ctx.body.plan + ) { + throw new APIError("BAD_REQUEST", { + message: STRIPE_ERROR_CODES.ALREADY_SUBSCRIBED_PLAN, + }); + } + let subscription = existingSubscription; + if (!subscription) { + const newSubscription = await ctx.context.adapter.create< + InputSubscription, + Subscription + >({ + model: "subscription", + data: { + plan: plan.name.toLowerCase(), + stripeCustomerId: customerId, + status: "incomplete", + referenceId, + seats: ctx.body.seats || 1, + }, + }); + subscription = newSubscription; + } + + if (!subscription) { + ctx.context.logger.error("Subscription ID not found"); + throw new APIError("INTERNAL_SERVER_ERROR"); + } + + const params = await options.subscription?.getCheckoutSessionParams?.( + { + user, + session, + plan, + subscription, + }, + ctx.request, + ); + + const checkoutSession = await client.checkout.sessions + .create({ + ...(customerId + ? { + customer: customerId, + customer_update: { + name: "auto", + address: "auto", + }, + } + : { + customer_email: session.user.email, + }), + success_url: getUrl( + ctx, + `${ + ctx.context.baseURL + }/subscription/success?callbackURL=${encodeURIComponent( + ctx.body.successUrl, + )}&reference=${encodeURIComponent(referenceId)}`, + ), + cancel_url: getUrl(ctx, ctx.body.cancelUrl), + line_items: [ + { + price: plan.priceId, + quantity: ctx.body.seats || 1, + }, + ], + mode: "subscription", + client_reference_id: referenceId, + ...params, + metadata: { + userId: user.id, + subscriptionId: subscription.id, + referenceId, + ...params?.params?.metadata, + }, + }) + .catch(async (e) => { + throw ctx.error("BAD_REQUEST", { + message: e.message, + code: e.code, + }); + }); + return ctx.json({ + ...checkoutSession, + redirect: !ctx.body.disableRedirect, + }); + }, + ), + cancelSubscription: createAuthEndpoint( + "/subscription/cancel", + { + method: "POST", + body: z.object({ + referenceId: z.string().optional(), + returnUrl: z.string(), + }), + use: [ + sessionMiddleware, + originCheck((ctx) => ctx.body.returnUrl), + referenceMiddleware("cancel-subscription"), + ], + }, + async (ctx) => { + const referenceId = + ctx.body?.referenceId || ctx.context.session.user.id; + const subscription = await ctx.context.adapter.findOne({ + model: "subscription", + where: [ + { + field: "referenceId", + value: referenceId, + }, + ], + }); + if (!subscription || !subscription.stripeCustomerId) { + throw ctx.error("BAD_REQUEST", { + message: STRIPE_ERROR_CODES.SUBSCRIPTION_NOT_FOUND, + }); + } + const activeSubscription = await client.subscriptions + .list({ + customer: subscription.stripeCustomerId, + status: "active", + }) + .then((res) => res.data[0]); + if (!activeSubscription) { + throw ctx.error("BAD_REQUEST", { + message: STRIPE_ERROR_CODES.SUBSCRIPTION_NOT_FOUND, + }); + } + const { url } = await client.billingPortal.sessions.create({ + customer: subscription.stripeCustomerId, + return_url: getUrl(ctx, ctx.body?.returnUrl || "/"), + flow_data: { + type: "subscription_cancel", + subscription_cancel: { + subscription: activeSubscription.id, + }, + }, + }); + return { + url, + redirect: true, + }; + }, + ), + listActiveSubscriptions: createAuthEndpoint( + "/subscription/list", + { + method: "GET", + query: z.optional( + z.object({ + referenceId: z.string().optional(), + }), + ), + use: [sessionMiddleware, referenceMiddleware("list-subscription")], + }, + async (ctx) => { + const subscriptions = await ctx.context.adapter.findMany({ + model: "subscription", + where: [ + { + field: "referenceId", + value: ctx.query?.referenceId || ctx.context.session.user.id, + }, + ], + }); + if (!subscriptions.length) { + return []; + } + const plans = await getPlans(options); + if (!plans) { + return []; + } + const subs = subscriptions + .map((sub) => { + const plan = plans.find( + (p) => p.name.toLowerCase() === sub.plan.toLowerCase(), + ); + return { + ...sub, + limits: plan?.limits, + }; + }) + .filter((sub) => { + return sub.status === "active" || sub.status === "trialing"; + }); + return ctx.json(subs); + }, + ), + subscriptionSuccess: createAuthEndpoint( + "/subscription/success", + { + method: "GET", + query: z.record(z.string(), z.any()).optional(), + }, + async (ctx) => { + if (!ctx.query || !ctx.query.callbackURL || !ctx.query.reference) { + throw ctx.redirect(getUrl(ctx, ctx.query?.callbackURL || "/")); + } + const session = await getSessionFromCtx<{ stripeCustomerId: string }>( + ctx, + ); + if (!session) { + throw ctx.redirect(getUrl(ctx, ctx.query?.callbackURL || "/")); + } + const { user } = session; + const { callbackURL, reference } = ctx.query; + + const subscriptions = await ctx.context.adapter.findMany({ + model: "subscription", + where: [ + { + field: "referenceId", + value: reference, + }, + ], + }); + + const activeSubscription = subscriptions.find( + (sub) => sub.status === "active" || sub.status === "trialing", + ); + + if (activeSubscription) { + return ctx.redirect(getUrl(ctx, callbackURL)); + } + + if (user?.stripeCustomerId) { + try { + const subscription = + await ctx.context.adapter.findOne({ + model: "subscription", + where: [ + { + field: "referenceId", + value: reference, + }, + ], + }); + if (!subscription || subscription.status === "active") { + throw ctx.redirect(getUrl(ctx, callbackURL)); + } + const stripeSubscription = await client.subscriptions + .list({ + customer: user.stripeCustomerId, + status: "active", + }) + .then((res) => res.data[0]); + + if (stripeSubscription) { + const plan = await getPlanByPriceId( + options, + stripeSubscription.items.data[0]?.plan.id, + ); + + if (plan && subscriptions.length > 0) { + await ctx.context.adapter.update({ + model: "subscription", + update: { + status: stripeSubscription.status, + seats: stripeSubscription.items.data[0]?.quantity || 1, + plan: plan.name.toLowerCase(), + }, + where: [ + { + field: "referenceId", + value: reference, + }, + ], + }); + } + } + } catch (error) { + ctx.context.logger.error( + "Error fetching subscription from Stripe", + error, + ); + } + } + throw ctx.redirect(getUrl(ctx, callbackURL)); + }, + ), + } as const; + return { + id: "stripe", + endpoints: { + stripeWebhook: createAuthEndpoint( + "/stripe/webhook", + { + method: "POST", + metadata: { + isAction: false, + }, + cloneRequest: true, + }, + async (ctx) => { + if (!ctx.request?.body) { + throw new APIError("INTERNAL_SERVER_ERROR"); + } + const buf = await ctx.request.text(); + const sig = ctx.request.headers.get("stripe-signature") as string; + const webhookSecret = options.stripeWebhookSecret; + let event: Stripe.Event; + try { + if (!sig || !webhookSecret) { + throw new APIError("BAD_REQUEST", { + message: "Stripe webhook secret not found", + }); + } + event = client.webhooks.constructEvent(buf, sig, webhookSecret); + } catch (err: any) { + ctx.context.logger.error(`${err.message}`); + throw new APIError("BAD_REQUEST", { + message: `Webhook Error: ${err.message}`, + }); + } + try { + switch (event.type) { + case "checkout.session.completed": + await onCheckoutSessionCompleted(ctx, options, event); + await options.onEvent?.(event); + break; + case "customer.subscription.updated": + await onSubscriptionUpdated(ctx, options, event); + await options.onEvent?.(event); + break; + case "customer.subscription.deleted": + await onSubscriptionDeleted(ctx, options, event); + await options.onEvent?.(event); + break; + default: + await options.onEvent?.(event); + break; + } + } catch (e: any) { + ctx.context.logger.error( + `Stripe webhook failed. Error: ${e.message}`, + ); + throw new APIError("BAD_REQUEST", { + message: "Webhook error: See server logs for more information.", + }); + } + return ctx.json({ success: true }); + }, + ), + ...((options.subscription?.enabled + ? subscriptionEndpoints + : {}) as O["subscription"] extends { + enabled: boolean; + } + ? typeof subscriptionEndpoints + : {}), + }, + init(ctx) { + return { + options: { + databaseHooks: { + user: { + create: { + async after(user, ctx) { + if (ctx && options.createCustomerOnSignUp) { + const stripeCustomer = await client.customers.create({ + email: user.email, + name: user.name, + metadata: { + userId: user.id, + }, + }); + await ctx.context.adapter.update({ + model: "user", + update: { + stripeCustomerId: stripeCustomer.id, + }, + where: [ + { + field: "id", + value: user.id, + }, + ], + }); + } + }, + }, + }, + }, + }, + }; + }, + schema: getSchema(options), + } satisfies BetterAuthPlugin; +}; + +export type { Subscription }; diff --git a/packages/stripe/src/schema.ts b/packages/stripe/src/schema.ts new file mode 100644 index 00000000..d4f8fd79 --- /dev/null +++ b/packages/stripe/src/schema.ts @@ -0,0 +1,62 @@ +import type { AuthPluginSchema } from "better-auth"; +import type { StripeOptions } from "./types"; + +export const getSchema = (options: StripeOptions) => { + const subscriptions = { + subscription: { + fields: { + plan: { + type: "string", + required: true, + }, + referenceId: { + type: "string", + required: true, + }, + stripeCustomerId: { + type: "string", + required: false, + }, + stripeSubscriptionId: { + type: "string", + required: false, + }, + status: { + type: "string", + defaultValue: "incomplete", + }, + periodStart: { + type: "date", + required: false, + }, + periodEnd: { + type: "date", + required: false, + }, + cancelAtPeriodEnd: { + type: "boolean", + required: false, + defaultValue: false, + }, + seats: { + type: "number", + required: false, + }, + }, + }, + } satisfies AuthPluginSchema; + const user = { + user: { + fields: { + stripeCustomerId: { + type: "string", + required: false, + }, + }, + }, + } satisfies AuthPluginSchema; + return { + ...(options.subscription?.enabled ? subscriptions : {}), + ...user, + } as typeof user & typeof subscriptions; +}; diff --git a/packages/stripe/src/stripe.test.ts b/packages/stripe/src/stripe.test.ts new file mode 100644 index 00000000..0d78edd9 --- /dev/null +++ b/packages/stripe/src/stripe.test.ts @@ -0,0 +1,456 @@ +import { betterAuth, type User } from "better-auth"; +import { memoryAdapter } from "better-auth/adapters/memory"; +import { createAuthClient } from "better-auth/client"; +import { setCookieToHeader } from "better-auth/cookies"; +import { bearer } from "better-auth/plugins"; +import Stripe from "stripe"; +import { vi } from "vitest"; +import { stripe } from "."; +import { stripeClient } from "./client"; +import type { StripeOptions, Subscription } from "./types"; + +describe("stripe", async () => { + const mockStripe = { + customers: { + create: vi.fn().mockResolvedValue({ id: "cus_mock123" }), + }, + checkout: { + sessions: { + create: vi.fn().mockResolvedValue({ + url: "https://checkout.stripe.com/mock", + id: "", + }), + }, + }, + billingPortal: { + sessions: { + create: vi + .fn() + .mockResolvedValue({ url: "https://billing.stripe.com/mock" }), + }, + }, + subscriptions: { + retrieve: vi.fn(), + list: vi.fn().mockResolvedValue({ data: [] }), + }, + webhooks: { + constructEvent: vi.fn(), + }, + }; + + const _stripe = mockStripe as unknown as Stripe; + const data = { + user: [], + session: [], + verification: [], + account: [], + customer: [], + subscription: [], + }; + const memory = memoryAdapter(data); + const stripeOptions = { + stripeClient: _stripe, + stripeWebhookSecret: process.env.STRIPE_WEBHOOK_SECRET!, + createCustomerOnSignUp: true, + subscription: { + enabled: true, + plans: [ + { + priceId: process.env.STRIPE_PRICE_ID_1!, + name: "starter", + }, + { + priceId: process.env.STRIPE_PRICE_ID_2!, + name: "premium", + }, + ], + }, + } satisfies StripeOptions; + const auth = betterAuth({ + database: memory, + baseURL: "http://localhost:3000", + // database: new Database(":memory:"), + emailAndPassword: { + enabled: true, + }, + plugins: [stripe(stripeOptions)], + }); + const ctx = await auth.$context; + const authClient = createAuthClient({ + baseURL: "http://localhost:3000", + plugins: [ + bearer(), + stripeClient({ + subscription: true, + }), + ], + fetchOptions: { + customFetchImpl: async (url, init) => { + return auth.handler(new Request(url, init)); + }, + }, + }); + + const testUser = { + email: "test@email.com", + password: "password", + name: "Test User", + }; + + beforeEach(() => { + data.user = []; + data.session = []; + data.verification = []; + data.account = []; + data.customer = []; + data.subscription = []; + + vi.clearAllMocks(); + }); + + async function getHeader() { + const headers = new Headers(); + const userRes = await authClient.signIn.email(testUser, { + throw: true, + onSuccess: setCookieToHeader(headers), + }); + return { + headers, + response: userRes, + }; + } + + it("should create a customer on sign up", async () => { + const userRes = await authClient.signUp.email(testUser, { + throw: true, + }); + const res = await ctx.adapter.findOne({ + model: "user", + where: [ + { + field: "id", + value: userRes.user.id, + }, + ], + }); + expect(res).toMatchObject({ + id: expect.any(String), + stripeCustomerId: expect.any(String), + }); + }); + + it("should create a subscription", async () => { + const userRes = await authClient.signUp.email(testUser, { + throw: true, + }); + + const headers = new Headers(); + await authClient.signIn.email(testUser, { + throw: true, + onSuccess: setCookieToHeader(headers), + }); + + const res = await authClient.subscription.upgrade({ + plan: "starter", + fetchOptions: { + headers, + }, + }); + expect(res.data?.url).toBeDefined(); + const subscription = await ctx.adapter.findOne({ + model: "subscription", + where: [ + { + field: "referenceId", + value: userRes.user.id, + }, + ], + }); + expect(subscription).toMatchObject({ + id: expect.any(String), + plan: "starter", + referenceId: userRes.user.id, + stripeCustomerId: expect.any(String), + status: "incomplete", + periodStart: undefined, + cancelAtPeriodEnd: undefined, + }); + }); + + it("should list active subscriptions", async () => { + const userRes = await authClient.signUp.email( + { + ...testUser, + email: "list-test@email.com", + }, + { + throw: true, + }, + ); + const userId = userRes.user.id; + + const headers = new Headers(); + await authClient.signIn.email( + { + ...testUser, + email: "list-test@email.com", + }, + { + throw: true, + onSuccess: setCookieToHeader(headers), + }, + ); + + const listRes = await authClient.subscription.list({ + fetchOptions: { + headers, + }, + }); + + expect(Array.isArray(listRes.data)).toBe(true); + + await authClient.subscription.upgrade({ + plan: "starter", + fetchOptions: { + headers, + }, + }); + const listBeforeActive = await authClient.subscription.list({ + fetchOptions: { + headers, + }, + }); + expect(listBeforeActive.data?.length).toBe(0); + // Update the subscription status to active + await ctx.adapter.update({ + model: "subscription", + update: { + status: "active", + }, + where: [ + { + field: "referenceId", + value: userId, + }, + ], + }); + const listAfterRes = await authClient.subscription.list({ + fetchOptions: { + headers, + }, + }); + expect(listAfterRes.data?.length).toBeGreaterThan(0); + }); + + it("should handle subscription webhook events", async () => { + const testSubscriptionId = "sub_123456"; + const testReferenceId = "user_123"; + await ctx.adapter.create({ + model: "user", + data: { + id: testReferenceId, + email: "test@email.com", + }, + }); + await ctx.adapter.create({ + model: "subscription", + data: { + id: testSubscriptionId, + referenceId: testReferenceId, + stripeCustomerId: "cus_mock123", + status: "active", + plan: "starter", + }, + }); + const mockCheckoutSessionEvent = { + type: "checkout.session.completed", + data: { + object: { + mode: "subscription", + subscription: testSubscriptionId, + metadata: { + referenceId: testReferenceId, + subscriptionId: testSubscriptionId, + }, + }, + }, + }; + + const mockSubscription = { + id: testSubscriptionId, + status: "active", + items: { + data: [ + { + price: { id: process.env.STRIPE_PRICE_ID_1 }, + quantity: 1, + }, + ], + }, + current_period_start: Math.floor(Date.now() / 1000), + current_period_end: Math.floor(Date.now() / 1000) + 30 * 24 * 60 * 60, + }; + + const stripeForTest = { + ...stripeOptions.stripeClient, + subscriptions: { + ...stripeOptions.stripeClient.subscriptions, + retrieve: vi.fn().mockResolvedValue(mockSubscription), + }, + webhooks: { + constructEvent: vi.fn().mockReturnValue(mockCheckoutSessionEvent), + }, + }; + + const testOptions = { + ...stripeOptions, + stripeClient: stripeForTest as unknown as Stripe, + stripeWebhookSecret: "test_secret", + }; + + const testAuth = betterAuth({ + baseURL: "http://localhost:3000", + database: memory, + emailAndPassword: { + enabled: true, + }, + plugins: [stripe(testOptions)], + }); + + const testCtx = await testAuth.$context; + + const mockRequest = new Request( + "http://localhost:3000/api/auth/stripe/webhook", + { + method: "POST", + headers: { + "stripe-signature": "test_signature", + }, + body: JSON.stringify(mockCheckoutSessionEvent), + }, + ); + const response = await testAuth.handler(mockRequest); + expect(response.status).toBe(200); + + const updatedSubscription = await testCtx.adapter.findOne({ + model: "subscription", + where: [ + { + field: "id", + value: testSubscriptionId, + }, + ], + }); + expect(updatedSubscription).toMatchObject({ + id: testSubscriptionId, + status: "active", + periodStart: expect.any(Date), + periodEnd: expect.any(Date), + plan: "starter", + }); + }); + + it("should handle subscription deletion webhook", async () => { + const userId = "test_user"; + const subId = "test_sub_delete"; + + await ctx.adapter.create({ + model: "user", + data: { + id: userId, + email: "delete-test@email.com", + }, + }); + + await ctx.adapter.create({ + model: "subscription", + data: { + id: subId, + referenceId: userId, + stripeCustomerId: "cus_delete_test", + status: "active", + plan: "starter", + }, + }); + + const subscription = await ctx.adapter.findOne({ + model: "subscription", + where: [ + { + field: "referenceId", + value: userId, + }, + ], + }); + + const mockDeleteEvent = { + type: "customer.subscription.deleted", + data: { + object: { + id: "sub_deleted", + customer: subscription?.stripeCustomerId, + status: "canceled", + metadata: { + referenceId: subscription?.referenceId, + subscriptionId: subscription?.id, + }, + }, + }, + }; + + const stripeForTest = { + ...stripeOptions.stripeClient, + webhooks: { + constructEvent: vi.fn().mockReturnValue(mockDeleteEvent), + }, + subscriptions: { + retrieve: vi.fn().mockResolvedValue({ + status: "canceled", + id: subId, + }), + }, + }; + + const testOptions = { + ...stripeOptions, + stripeClient: stripeForTest as unknown as Stripe, + stripeWebhookSecret: "test_secret", + }; + + const testAuth = betterAuth({ + baseURL: "http://localhost:3000", + emailAndPassword: { + enabled: true, + }, + database: memory, + plugins: [stripe(testOptions)], + }); + + const mockRequest = new Request( + "http://localhost:3000/api/auth/stripe/webhook", + { + method: "POST", + headers: { + "stripe-signature": "test_signature", + }, + body: JSON.stringify(mockDeleteEvent), + }, + ); + + const response = await testAuth.handler(mockRequest); + expect(response.status).toBe(200); + + if (subscription) { + const updatedSubscription = await ctx.adapter.findOne({ + model: "subscription", + where: [ + { + field: "id", + value: subscription.id, + }, + ], + }); + expect(updatedSubscription?.status).toBe("canceled"); + } + }); +}); diff --git a/packages/stripe/src/types.ts b/packages/stripe/src/types.ts new file mode 100644 index 00000000..5522776d --- /dev/null +++ b/packages/stripe/src/types.ts @@ -0,0 +1,323 @@ +import type { Session, User } from "better-auth"; +import type Stripe from "stripe"; + +export type Plan = { + /** + * Monthly price id + */ + priceId?: string; + /** + * To use lookup key instead of price id + * + * https://docs.stripe.com/products-prices/ + * manage-prices#lookup-keys + */ + lookupKey?: string; + /** + * A yearly discount price id + * + * useful when you want to offer a discount for + * yearly subscription + */ + annualDiscountPriceId?: string; + /** + * Plan name + */ + name: string; + /** + * Limits for the plan + */ + limits?: Record; + /** + * Plan group name + * + * useful when you want to group plans or + * when a user can subscribe to multiple plans. + */ + group?: string; + /** + * Free trial days + */ + freeTrial?: { + /** + * Number of days + */ + days: number; + /** + * Only available for new users or users without existing subscription + * + * @default true + */ + forNewUsersOnly?: boolean; + /** + * A function that will be called when the trial + * starts. + * + * @param subscription + * @returns + */ + onTrialStart?: (subscription: Subscription) => Promise; + /** + * A function that will be called when the trial + * ends + * + * @param subscription - Subscription + * @returns + */ + onTrialEnd?: ( + data: { + subscription: Subscription; + user: User & Record; + }, + request?: Request, + ) => Promise; + /** + * A function that will be called when the trial + * expired. + * @param subscription - Subscription + * @returns + */ + onTrialExpired?: (subscription: Subscription) => Promise; + }; +}; + +export interface Subscription { + /** + * Database identifier + */ + id: string; + /** + * The plan name + */ + plan: string; + /** + * Stripe customer id + */ + stripeCustomerId?: string; + /** + * Stripe subscription id + */ + stripeSubscriptionId?: string; + /** + * Trial start date + */ + trialStart?: Date; + /** + * Trial end date + */ + trialEnd?: Date; + /** + * Price Id for the subscription + */ + priceId?: string; + /** + * To what reference id the subscription belongs to + * @example + * - userId for a user + * - workspace id for a saas platform + * - website id for a hosting platform + * + * @default - userId + */ + referenceId: string; + /** + * Subscription status + */ + status: + | "active" + | "canceled" + | "incomplete" + | "incomplete_expired" + | "past_due" + | "paused" + | "trialing" + | "unpaid"; + /** + * The billing cycle start date + */ + periodStart?: Date; + /** + * The billing cycle end date + */ + periodEnd?: Date; + /** + * Cancel at period end + */ + cancelAtPeriodEnd?: boolean; + /** + * A field to group subscriptions so you can have multiple subscriptions + * for one reference id + */ + groupId?: string; + /** + * Number of seats for the subscription (useful for team plans) + */ + seats?: number; +} + +export interface StripeOptions { + /** + * Stripe Client + */ + stripeClient: Stripe; + /** + * Stripe Webhook Secret + * + * @description Stripe webhook secret key + */ + stripeWebhookSecret: string; + /** + * Enable customer creation when a user signs up + */ + createCustomerOnSignUp?: boolean; + /** + * A callback to run after a customer has been created + * @param customer - Customer Data + * @param stripeCustomer - Stripe Customer Data + * @returns + */ + onCustomerCreate?: ( + data: { + customer: Customer; + stripeCustomer: Stripe.Customer; + user: User; + }, + request?: Request, + ) => Promise; + /** + * A custom function to get the customer create + * params + * @param data - data containing user and session + * @returns + */ + getCustomerCreateParams?: ( + data: { + user: User; + session: Session; + }, + request?: Request, + ) => Promise<{}>; + /** + * Subscriptions + */ + subscription?: { + enabled: boolean; + /** + * Subscription Configuration + */ + /** + * List of plan + */ + plans: Plan[] | (() => Promise); + /** + * Require email verification before a user is allowed to upgrade + * their subscriptions + * + * @default false + */ + requireEmailVerification?: boolean; + /** + * A callback to run after a user has subscribed to a package + * @param event - Stripe Event + * @param subscription - Subscription Data + * @returns + */ + onSubscriptionComplete?: ( + data: { + event: Stripe.Event; + stripeSubscription: Stripe.Subscription; + subscription: Subscription; + plan: Plan; + }, + request?: Request, + ) => Promise; + /** + * A callback to run after a user is about to cancel their subscription + * @returns + */ + onSubscriptionUpdate?: (data: { + event: Stripe.Event; + subscription: Subscription; + }) => Promise; + /** + * A callback to run after a user is about to cancel their subscription + * @returns + */ + onSubscriptionCancel?: (data: { + event: Stripe.Event; + subscription: Subscription; + stripeSubscription: Stripe.Subscription; + cancellationDetails?: Stripe.Subscription.CancellationDetails; + }) => Promise; + /** + * A function to check if the reference id is valid + * and belongs to the user + * + * @param data - data containing user, session and referenceId + * @param request - Request Object + * @returns + */ + authorizeReference?: ( + data: { + user: User & Record; + session: Session & Record; + referenceId: string; + action: + | "upgrade-subscription" + | "list-subscription" + | "cancel-subscription"; + }, + request?: Request, + ) => Promise; + /** + * A callback to run after a user has deleted their subscription + * @returns + */ + onSubscriptionDeleted?: (data: { + event: Stripe.Event; + stripeSubscription: Stripe.Subscription; + subscription: Subscription; + }) => Promise; + /** + * parameters for session create params + * + * @param data - data containing user, session and plan + * @param request - Request Object + */ + getCheckoutSessionParams?: ( + data: { + user: User & Record; + session: Session & Record; + plan: Plan; + subscription: Subscription; + }, + request?: Request, + ) => + | Promise<{ + params?: Stripe.Checkout.SessionCreateParams; + options?: Stripe.RequestOptions; + }> + | { + params?: Stripe.Checkout.SessionCreateParams; + options?: Stripe.RequestOptions; + }; + /** + * Enable organization subscription + */ + organization?: { + enabled: boolean; + }; + }; + onEvent?: (event: Stripe.Event) => Promise; +} + +export interface Customer { + id: string; + stripeCustomerId?: string; + userId: string; + createdAt: Date; + updatedAt: Date; +} + +export interface InputSubscription extends Omit {} +export interface InputCustomer extends Omit {} diff --git a/packages/stripe/src/utils.ts b/packages/stripe/src/utils.ts new file mode 100644 index 00000000..5991e66d --- /dev/null +++ b/packages/stripe/src/utils.ts @@ -0,0 +1,22 @@ +import type { StripeOptions } from "./types"; + +export async function getPlans(options: StripeOptions) { + return typeof options?.subscription?.plans === "function" + ? await options.subscription?.plans() + : options.subscription?.plans; +} + +export async function getPlanByPriceId( + options: StripeOptions, + priceId: string, +) { + return await getPlans(options).then((res) => + res?.find((plan) => plan.priceId === priceId), + ); +} + +export async function getPlanByName(options: StripeOptions, name: string) { + return await getPlans(options).then((res) => + res?.find((plan) => plan.name.toLowerCase() === name.toLowerCase()), + ); +} diff --git a/packages/stripe/tsconfig.json b/packages/stripe/tsconfig.json new file mode 100644 index 00000000..5eae3aee --- /dev/null +++ b/packages/stripe/tsconfig.json @@ -0,0 +1,20 @@ +{ + "compilerOptions": { + "esModuleInterop": true, + "skipLibCheck": true, + "target": "es2022", + "allowJs": true, + "resolveJsonModule": true, + "module": "ESNext", + "noEmit": true, + "moduleResolution": "Bundler", + "moduleDetection": "force", + "isolatedModules": true, + "verbatimModuleSyntax": true, + "strict": true, + "noImplicitOverride": true, + "noFallthroughCasesInSwitch": true + }, + "exclude": ["node_modules", "dist"], + "include": ["src"] +} diff --git a/packages/stripe/vitest.config.ts b/packages/stripe/vitest.config.ts new file mode 100644 index 00000000..2fd32219 --- /dev/null +++ b/packages/stripe/vitest.config.ts @@ -0,0 +1,10 @@ +import { defineConfig } from "vitest/config"; + +export default defineConfig({ + root: ".", + test: { + clearMocks: true, + globals: true, + setupFiles: ["dotenv/config"], + }, +}); diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 6021091d..4626159a 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -63,6 +63,9 @@ importers: demo/nextjs: dependencies: + '@better-auth/stripe': + specifier: workspace:* + version: link:../../packages/stripe '@better-fetch/fetch': specifier: 'catalog:' version: 1.1.15 @@ -75,6 +78,9 @@ importers: '@libsql/kysely-libsql': specifier: ^0.4.1 version: 0.4.1(kysely@0.27.4) + '@number-flow/react': + specifier: ^0.5.5 + version: 0.5.5(react-dom@18.3.1(react@18.3.1))(react@18.3.1) '@prisma/adapter-libsql': specifier: ^5.22.0 version: 5.22.0(@libsql/client@0.12.0) @@ -186,6 +192,9 @@ importers: better-sqlite3: specifier: ^11.6.0 version: 11.6.0 + canvas-confetti: + specifier: ^1.9.3 + version: 1.9.3 class-variance-authority: specifier: ^0.7.1 version: 0.7.1 @@ -286,6 +295,9 @@ importers: specifier: ^3.23.8 version: 3.24.1 devDependencies: + '@types/canvas-confetti': + specifier: ^1.9.0 + version: 1.9.0 '@types/node': specifier: ^20.17.9 version: 20.17.9 @@ -1719,6 +1731,31 @@ importers: specifier: ^1.6.0 version: 1.6.0(@types/node@22.10.7)(happy-dom@15.11.7)(less@4.2.1)(lightningcss@1.27.0)(sass@1.83.1)(terser@5.36.0) + packages/stripe: + dependencies: + better-auth: + specifier: workspace:^ + version: link:../better-auth + zod: + specifier: ^3.24.1 + version: 3.24.2 + devDependencies: + '@types/better-sqlite3': + specifier: ^7.6.12 + version: 7.6.12 + better-call: + specifier: 'catalog:' + version: 1.0.3 + better-sqlite3: + specifier: ^11.6.0 + version: 11.6.0 + stripe: + specifier: ^17.7.0 + version: 17.7.0 + vitest: + specifier: ^1.6.0 + version: 1.6.0(@types/node@22.10.7)(happy-dom@15.11.7)(less@4.2.1)(lightningcss@1.27.0)(sass@1.83.1)(terser@5.36.0) + packages: '@0no-co/graphql.web@1.0.13': @@ -5394,6 +5431,12 @@ packages: resolution: {integrity: sha512-gGq0NJkIGSwdbUt4yhdF8ZrmkGKVz9vAdVzpOfnom+V8PLSmSOVhZwbNvZZS1EYcJN5hzzKBxmmVVAInM6HQLg==} engines: {node: ^14.17.0 || ^16.13.0 || >=18.0.0} + '@number-flow/react@0.5.5': + resolution: {integrity: sha512-Zdju5n0osxrb+7jbcpUJ9L2VJ2+9ptwjz5+A+2wq9Q32hs3PW/noPJjHtLTrtGINM9mEw76DcDg0ac/dx6j1aA==} + peerDependencies: + react: ^18 || ^19 + react-dom: ^18 || ^19 + '@nuxt/devalue@2.0.2': resolution: {integrity: sha512-GBzP8zOc7CGWyFQS6dv1lQz8VVpz5C2yRszbXufwG/9zhStTIH50EtD87NmWbTMwXDvZLNg8GIpb1UFdH93JCA==} @@ -8964,6 +9007,9 @@ packages: '@types/bun@1.2.4': resolution: {integrity: sha512-QtuV5OMR8/rdKJs213iwXDpfVvnskPXY/S0ZiFbsTjQZycuqPbMW8Gf/XhLfwE5njW8sxI2WjISURXPlHypMFA==} + '@types/canvas-confetti@1.9.0': + resolution: {integrity: sha512-aBGj/dULrimR1XDZLtG9JwxX1b4HPRF6CX9Yfwh3NvstZEm1ZL7RBnel4keCPSqs1ANRu1u2Aoz9R+VmtjYuTg==} + '@types/chrome@0.0.258': resolution: {integrity: sha512-vicJi6cg2zaFuLmLY7laG6PHBknjKFusPYlaKQ9Zlycskofy71rStlGvW07MUuqUIVorZf8k5KH+zeTTGcH2dQ==} @@ -10736,6 +10782,9 @@ packages: caniuse-lite@1.0.30001676: resolution: {integrity: sha512-Qz6zwGCiPghQXGJvgQAem79esjitvJ+CxSbSQkW9H/UX5hg8XM88d4lp2W+MEQ81j+Hip58Il+jGVdazk1z9cw==} + canvas-confetti@1.9.3: + resolution: {integrity: sha512-rFfTURMvmVEX1gyXFgn5QMn81bYk70qa0HLzcIOSVEyl57n6o9ItHeBtUSWdvKAPY0xlvBHno4/v3QPrT83q9g==} + capture-stack-trace@1.0.2: resolution: {integrity: sha512-X/WM2UQs6VMHUtjUDnZTRI+i1crWteJySFzr9UpGoQa4WQffXVTTXuekjl7TjZRlcF2XfjgITT0HxZ9RnxeT0w==} engines: {node: '>=0.10.0'} @@ -16213,6 +16262,9 @@ packages: nullthrows@1.1.1: resolution: {integrity: sha512-2vPPEi+Z7WqML2jZYddDIfy5Dqb0r2fze2zTxNNknZaFpVHU3mFB3R+DWeJWGVx0ecvttSGlJTI+WG+8Z4cDWw==} + number-flow@0.5.3: + resolution: {integrity: sha512-iLKyssImNWQmJ41rza9K7P5lHRZTyishi/9FarWPLQHYY2Ydtl6eiXINEjZ1fa8dHeY0O7+YOD+Py3ZsJddYkg==} + nuxi@3.15.0: resolution: {integrity: sha512-ZVu45nuDrdb7nzKW2kLGY/N1vvFYLLbUVX6gUYw4BApKGGu4+GktTR5o48dGVgMYX9A8chaugl7TL9ZYmwC9Mg==} engines: {node: ^16.10.0 || >=18.0.0} @@ -18698,6 +18750,10 @@ packages: strip-literal@2.1.1: resolution: {integrity: sha512-631UJ6O00eNGfMiWG78ck80dfBab8X6IVFB51jZK5Icd7XAs60Z5y7QdSd/wGIklnWvRbUNloVzhOKKmutxQ6Q==} + stripe@17.7.0: + resolution: {integrity: sha512-aT2BU9KkizY9SATf14WhhYVv2uOapBWX0OFWF4xvcj1mPaNotlSc2CsxpS4DS46ZueSppmCF5BX1sNYBtwBvfw==} + engines: {node: '>=12.*'} + striptags@3.2.0: resolution: {integrity: sha512-g45ZOGzHDMe2bdYMdIvdAfCQkCTDMGBazSw1ypMowwGIee7ZQ5dU0rBJ8Jqgl+jAKIv4dbeE1jscZq9wid1Tkw==} @@ -24427,6 +24483,13 @@ snapshots: dependencies: which: 3.0.1 + '@number-flow/react@0.5.5(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + esm-env: 1.2.1 + number-flow: 0.5.3 + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + '@nuxt/devalue@2.0.2': {} '@nuxt/devtools-kit@1.6.0(magicast@0.3.5)(rollup@4.31.0)(vite@5.4.14(@types/node@22.10.7)(less@4.2.1)(lightningcss@1.27.0)(sass@1.83.1)(terser@5.36.0))': @@ -24442,7 +24505,7 @@ snapshots: '@nuxt/devtools-wizard@1.6.0': dependencies: - consola: 3.2.3 + consola: 3.4.0 diff: 7.0.0 execa: 7.2.0 global-directory: 4.0.1 @@ -29287,6 +29350,8 @@ snapshots: dependencies: bun-types: 1.2.4 + '@types/canvas-confetti@1.9.0': {} + '@types/chrome@0.0.258': dependencies: '@types/filesystem': 0.0.36 @@ -32018,6 +32083,8 @@ snapshots: caniuse-lite@1.0.30001676: {} + canvas-confetti@1.9.3: {} + capture-stack-trace@1.0.2: {} ccount@2.0.1: {} @@ -39109,6 +39176,10 @@ snapshots: nullthrows@1.1.1: {} + number-flow@0.5.3: + dependencies: + esm-env: 1.2.1 + nuxi@3.15.0: {} nuxt@3.14.1592(@azure/identity@4.6.0)(@biomejs/biome@1.9.4)(@libsql/client@0.12.0)(@parcel/watcher@2.4.1)(@types/node@22.10.7)(better-sqlite3@11.6.0)(drizzle-orm@0.39.3(@cloudflare/workers-types@4.20250214.0)(@libsql/client-wasm@0.14.0)(@libsql/client@0.12.0)(@prisma/client@5.22.0(prisma@5.22.0))(@types/better-sqlite3@7.6.12)(@types/pg@8.11.10)(better-sqlite3@11.6.0)(bun-types@1.2.4)(kysely@0.27.4)(mysql2@3.11.5)(pg@8.13.1)(prisma@5.22.0))(encoding@0.1.13)(eslint@8.57.1)(ioredis@5.4.1)(less@4.2.1)(lightningcss@1.27.0)(magicast@0.3.5)(mysql2@3.11.5)(optionator@0.9.4)(rollup@4.31.0)(sass@1.83.1)(terser@5.36.0)(typescript@5.7.2)(vite@5.4.14(@types/node@22.10.7)(less@4.2.1)(lightningcss@1.27.0)(sass@1.83.1)(terser@5.36.0)): @@ -42487,6 +42558,11 @@ snapshots: dependencies: js-tokens: 9.0.1 + stripe@17.7.0: + dependencies: + '@types/node': 22.10.7 + qs: 6.13.0 + striptags@3.2.0: {} strnum@1.0.5: {} @@ -42671,8 +42747,8 @@ snapshots: superstruct: 2.0.2 valibot: 1.0.0-beta.15(typescript@5.7.2) yup: 1.4.0 - zod: 3.24.1 - zod-to-json-schema: 3.23.5(zod@3.24.1) + zod: 3.24.2 + zod-to-json-schema: 3.23.5(zod@3.24.2) transitivePeerDependencies: - '@types/json-schema' - typescript @@ -44137,7 +44213,7 @@ snapshots: dependencies: esbuild: 0.24.2 postcss: 8.4.49 - rollup: 4.31.0 + rollup: 4.34.8 optionalDependencies: '@types/node': 22.10.7 fsevents: 2.3.3 @@ -44735,6 +44811,11 @@ snapshots: dependencies: zod: 3.24.1 + zod-to-json-schema@3.23.5(zod@3.24.2): + dependencies: + zod: 3.24.2 + optional: true + zod-to-ts@1.2.0(typescript@5.7.2)(zod@3.24.1): dependencies: typescript: 5.7.2