diff --git a/demo/nextjs/lib/auth-client.ts b/demo/nextjs/lib/auth-client.ts index 81365f20..913b8658 100644 --- a/demo/nextjs/lib/auth-client.ts +++ b/demo/nextjs/lib/auth-client.ts @@ -53,5 +53,3 @@ export const { useListOrganizations, useActiveOrganization, } = client; - -client.$store.listen("$sessionSignal", async () => {}); diff --git a/docs/app/blog/[[...slug]]/page.tsx b/docs/app/blog/[[...slug]]/page.tsx index 7cf64d11..37ad51df 100644 --- a/docs/app/blog/[[...slug]]/page.tsx +++ b/docs/app/blog/[[...slug]]/page.tsx @@ -22,6 +22,7 @@ import { BookIcon, GitHubIcon, XIcon } from "../_components/icons"; import { DiscordLogoIcon } from "@radix-ui/react-icons"; import { StarField } from "../_components/stat-field"; import Image from "next/image"; +import { BlogPage } from "../_components/blog-list"; const metaTitle = "Blogs"; const metaDescription = "Latest changes , fixes and updates."; @@ -33,6 +34,9 @@ export default async function Page({ params: Promise<{ slug?: string[] }>; }) { const { slug } = await params; + if (!slug) { + return ; + } const page = blogs.getPage(slug); if (!page) { notFound(); @@ -41,31 +45,33 @@ export default async function Page({ const toc = page.data?.toc; const { title, description, date } = page.data; return ( -
-
+
+
-
-
- - - + +
+
+ + + +
+

+ {title}{" "} +

-

- {title}{" "} -

-
+

{description}

@@ -111,7 +117,7 @@ export default async function Page({

-
+
diff --git a/docs/app/page.tsx b/docs/app/page.tsx index 059b28f2..0b73b761 100644 --- a/docs/app/page.tsx +++ b/docs/app/page.tsx @@ -34,24 +34,22 @@ export default async function HomePage() {
- Introducing{" "} - - Better Auth Infrastructure - + Announcing Our{" "} + $5M seed round | - Join the waitlist → + Read more → - Join the waitlist → + Read more →
diff --git a/docs/components/nav-bar.tsx b/docs/components/nav-bar.tsx index 4bdef964..6d5ac7e5 100644 --- a/docs/components/nav-bar.tsx +++ b/docs/components/nav-bar.tsx @@ -120,6 +120,10 @@ export const navMenu = [ name: "changelogs", path: "/changelogs", }, + { + name: "blogs", + path: "/blog", + }, { name: "community", path: "/community", diff --git a/docs/components/nav-mobile.tsx b/docs/components/nav-mobile.tsx index b660c893..5df43f39 100644 --- a/docs/components/nav-mobile.tsx +++ b/docs/components/nav-mobile.tsx @@ -217,6 +217,10 @@ export const navMenu: { name: "changelogs", path: "/changelogs", }, + { + name: "blogs", + path: "/blog", + }, { name: "community", path: "/community", diff --git a/docs/components/sidebar-content.tsx b/docs/components/sidebar-content.tsx index 647a2817..4a50301b 100644 --- a/docs/components/sidebar-content.tsx +++ b/docs/components/sidebar-content.tsx @@ -516,6 +516,23 @@ export const contents: Content[] = [ ), }, + { + title: "Hugging Face", + href: "/docs/authentication/huggingface", + icon: () => ( + + + + ), + }, { title: "Kick", href: "/docs/authentication/kick", @@ -1283,20 +1300,20 @@ C0.7,239.6,62.1,0.5,62.2,0.4c0,0,54,13.8,119.9,30.8S302.1,62,302.2,62c0.2,0,0.2, ), @@ -1571,8 +1588,8 @@ C0.7,239.6,62.1,0.5,62.2,0.4c0,0,54,13.8,119.9,30.8S302.1,62,302.2,62c0.2,0,0.2, xmlns="http://www.w3.org/2000/svg" > diff --git a/docs/content/blogs/seed-round.mdx b/docs/content/blogs/seed-round.mdx new file mode 100644 index 00000000..6c2c53c9 --- /dev/null +++ b/docs/content/blogs/seed-round.mdx @@ -0,0 +1,35 @@ +--- +title: "Announcing our $5M seed round" +description: "We raised $5M seed led by Peak XV Partners" +date: 2025-06-24 +author: + name: "Bereket Engida" + avatar: "/blogs/bereket.png" + twitter: "iambereket" +image: "/blogs/seed-round.png" +tags: ["seed round", "authentication", "funding"] +--- + +## Announcing our $5M seed round + +We’re excited to share that Better Auth has raised a $5 million seed round led by Peak XV Partners (formerly Sequoia Capital India & SEA), with participation from Y Combinator, Chapter One, P1 Ventures, and a group of incredible investors and angels. + +This funding fuels the next phase of **Better Auth**. + +From the start we are obsessed with making it possible for developers to **own their auth**. To **democratize high quality authentication** and make rolling your own auth not just doable, but the obvious choice. + +It started with building the framework. Since then, we’ve seen incredible growth and support from the community. Thank you everyone for being part of this journey. It’s still early days, and there’s so much more to build. This funding will allow us to have more people invloved and to push the boundaries of what's possible. + +On top of the framework, we’re also building the infrastructure to cover the gaps we couldn't cover in the framework: + +* A unified dashboard to manage users and user analytics +* Enterprise-grade security: bot, abuse, and fraud protection +* Authentication Email and SMS service +* Fast, globally distributed session storage +* and more. + +[Join the waitlist](https://better-auth.build) to get early access to the infrastructure. + +And if you're excited about making auth accessible - we're hiring! + +Reach out to [bereket@better-auth.com](mailto:bereket@better-auth.com). \ No newline at end of file diff --git a/docs/content/docs/authentication/github.mdx b/docs/content/docs/authentication/github.mdx index 3891425f..bcf7e064 100644 --- a/docs/content/docs/authentication/github.mdx +++ b/docs/content/docs/authentication/github.mdx @@ -10,7 +10,7 @@ description: GitHub provider setup and usage. Make sure to set the redirect URL to `http://localhost:3000/api/auth/callback/github` for local development. For production, you should set it to the URL of your application. If you change the base path of the auth routes, you should update the redirect URL accordingly. - Important: You MUST include the user.email scope in your Github app. See details below. + Important: You MUST include the user:email scope in your GitHub app. See details below. diff --git a/docs/content/docs/authentication/huggingface.mdx b/docs/content/docs/authentication/huggingface.mdx new file mode 100644 index 00000000..6e53be2b --- /dev/null +++ b/docs/content/docs/authentication/huggingface.mdx @@ -0,0 +1,47 @@ +--- +title: Hugging Face +description: Hugging Face provider setup and usage. +--- + + + + ### Get your Hugging Face credentials + To use Hugging Face sign in, you need a client ID and client secret. [Hugging Face OAuth documentation](https://huggingface.co/docs/hub/oauth). Make sure the created oauth app on Hugging Face has the "email" scope. + + Make sure to set the redirect URL to `http://localhost:3000/api/auth/callback/huggingface` for local development. For production, you should set it to the URL of your application. If you change the base path of the auth routes, you should update the redirect URL accordingly. + + + + ### Configure the provider + To configure the provider, you need to import the provider and pass it to the `socialProviders` option of the auth instance. + + ```ts title="auth.ts" + import { betterAuth } from "better-auth" + + export const auth = betterAuth({ + socialProviders: { + huggingface: { // [!code highlight] + clientId: process.env.HUGGINGFACE_CLIENT_ID as string, // [!code highlight] + clientSecret: process.env.HUGGINGFACE_CLIENT_SECRET as string, // [!code highlight] + }, // [!code highlight] + }, + }) + ``` + + + ### Sign In with Hugging Face + To sign in with Hugging Face, you can use the `signIn.social` function provided by the client. The `signIn` function takes an object with the following properties: + - `provider`: The provider to use. It should be set to `huggingface`. + + ```ts title="auth-client.ts" + import { createAuthClient } from "better-auth/client" + const authClient = createAuthClient() + + const signIn = async () => { + const data = await authClient.signIn.social({ + provider: "huggingface" + }) + } + ``` + + diff --git a/docs/content/docs/authentication/microsoft.mdx b/docs/content/docs/authentication/microsoft.mdx index 8eec0a09..853a7469 100644 --- a/docs/content/docs/authentication/microsoft.mdx +++ b/docs/content/docs/authentication/microsoft.mdx @@ -44,7 +44,7 @@ To sign in with Microsoft, you can use the `signIn.social` function provided by - `provider`: The provider to use. It should be set to `microsoft`. -```ts title="auth-client.ts" / +```ts title="auth-client.ts" import { createAuthClient } from "better-auth/client"; const authClient = createAuthClient(); diff --git a/docs/content/docs/authentication/spotify.mdx b/docs/content/docs/authentication/spotify.mdx index a30bf8ac..76800f09 100644 --- a/docs/content/docs/authentication/spotify.mdx +++ b/docs/content/docs/authentication/spotify.mdx @@ -34,7 +34,7 @@ description: Spotify provider setup and usage. To sign in with Spotify, you can use the `signIn.social` function provided by the client. The `signIn` function takes an object with the following properties: - `provider`: The provider to use. It should be set to `spotify`. - ```ts title="auth-client.ts" / + ```ts title="auth-client.ts" import { createAuthClient } from "better-auth/client" const authClient = createAuthClient() diff --git a/docs/content/docs/concepts/plugins.mdx b/docs/content/docs/concepts/plugins.mdx index 9bf34edd..7ca4e03a 100644 --- a/docs/content/docs/concepts/plugins.mdx +++ b/docs/content/docs/concepts/plugins.mdx @@ -5,7 +5,7 @@ description: Learn how to use plugins with Better Auth. Plugins are a key part of Better Auth, they let you extend the base functionalities. You can use them to add new authentication methods, features, or customize behaviors. -Better Auth offers comes with many built-in plugins ready to use. Check the plugins section for details. You can also create your own plugins. +Better Auth comes with many built-in plugins ready to use. Check the plugins section for details. You can also create your own plugins. ## Using a Plugin @@ -510,7 +510,7 @@ See built-in plugins for examples of how to use atoms properly. ### Path methods -by default, inferred paths use `GET` method if they don't require a body and `POST` if they do. You can override this by passing a `pathMethods` object. The key should be the path and the value should be the method ("POST" | "GET"). +By default, inferred paths use `GET` method if they don't require a body and `POST` if they do. You can override this by passing a `pathMethods` object. The key should be the path and the value should be the method ("POST" | "GET"). ```ts title="client-plugin.ts" import type { BetterAuthClientPlugin } from "better-auth/client"; diff --git a/docs/content/docs/concepts/typescript.mdx b/docs/content/docs/concepts/typescript.mdx index 2474b31c..de2837b2 100644 --- a/docs/content/docs/concepts/typescript.mdx +++ b/docs/content/docs/concepts/typescript.mdx @@ -71,7 +71,8 @@ export const auth = betterAuth({ user: { additionalFields: { role: { - type: "string" + type: "string", + input: false } } } @@ -83,6 +84,26 @@ type Session = typeof auth.$Infer.Session In the example above, we added a `role` field to the user object. This field is now available on the `Session` type. + +### The `input` property + +The `input` property in an additional field configuration determines whether the field should be included in the user input. This property defaults to `true`, meaning the field will be part of the user input during operations like registration. + +To prevent a field from being part of the user input, you must explicitly set `input: false`: + +```ts +additionalFields: { + role: { + type: "string", + input: false + } +} +``` + +When `input` is set to `false`, the field will be excluded from user input, preventing users from passing a value for it. + +By default, additional fields are included in the user input, which can lead to security vulnerabilities if not handled carefully. For fields that should not be set by the user, like a `role`, it is crucial to set `input: false` in the configuration. + ### Inferring Additional Fields on Client To make sure proper type inference for additional fields on the client side, you need to inform the client about these fields. There are two approaches to achieve this, depending on your project structure: diff --git a/docs/content/docs/concepts/users-accounts.mdx b/docs/content/docs/concepts/users-accounts.mdx index d85a23c2..851f4645 100644 --- a/docs/content/docs/concepts/users-accounts.mdx +++ b/docs/content/docs/concepts/users-accounts.mdx @@ -357,6 +357,27 @@ Users already signed in can manually link their account to additional social pro }); ``` + You can also link accounts using ID tokens directly, without redirecting to the provider's OAuth flow: + + ```ts + await authClient.linkSocial({ + provider: "google", + idToken: { + token: "id_token_from_provider", + nonce: "nonce_used_for_token", // Optional + accessToken: "access_token", // Optional, may be required by some providers + refreshToken: "refresh_token" // Optional + } + }); + ``` + + This is useful when you already have valid tokens from the provider, for example: + - After signing in with a native SDK + - When using a mobile app that handles authentication + - When implementing custom OAuth flows + + The ID token must be valid and the provider must support ID token verification. + If you want your users to be able to link a social account with a different email address than the user, or if you want to use a provider that does not return email addresses, you will need to enable this in the account linking settings. ```ts title="auth.ts" export const auth = betterAuth({ @@ -368,6 +389,18 @@ Users already signed in can manually link their account to additional social pro }); ``` + If you want the newly linked accounts to update the user information, you need to enable this in the account linking settings. + + ```ts title="auth.ts" + export const auth = betterAuth({ + account: { + accountLinking: { + updateUserInfoOnLink: true + } + }, + }); + ``` + - **Linking Credential-Based Accounts:** To link a credential-based account (e.g., email and password), users can initiate a "forgot password" flow, or you can call the `setPassword` method on the server. ```ts diff --git a/docs/content/docs/integrations/next.mdx b/docs/content/docs/integrations/next.mdx index 4549f2a1..ca8fad29 100644 --- a/docs/content/docs/integrations/next.mdx +++ b/docs/content/docs/integrations/next.mdx @@ -142,6 +142,9 @@ import { getSessionCookie } from "better-auth/cookies"; export async function middleware(request: NextRequest) { const sessionCookie = getSessionCookie(request); + // THIS IS NOT SECURE! + // This is the recommended approach to optimistically redirect users + // We recommend handling auth checks in each page/route if (!sessionCookie) { return NextResponse.redirect(new URL("/", request.url)); } @@ -178,6 +181,33 @@ export async function middleware(request: NextRequest) { } ``` +### How to handle auth checks in each page/route + +In this example, we are using the `auth.api.getSession` function within a server component to get the session object, +then we are checking if the session is valid. If it's not, we are redirecting the user to the sign-in page. + +```tsx title="app/dashboard/page.tsx" +import { auth } from "@/lib/auth"; +import { headers } from "next/headers"; +import { redirect } from "next/navigation"; + +export default async function DashboardPage() { + const session = await auth.api.getSession({ + headers: await headers() + }) + + if(!session) { + redirect("/sign-in") + } + + return ( +
+

Welcome {session.user.name}

+
+ ) +} +``` + ### For Next.js release `15.1.7` and below If you need the full session object, you'll have to fetch it from the `/get-session` API route. Since Next.js middleware doesn't support running Node.js APIs directly, you must make an HTTP request. diff --git a/docs/content/docs/integrations/tanstack.mdx b/docs/content/docs/integrations/tanstack.mdx index 40a34d22..45364dc9 100644 --- a/docs/content/docs/integrations/tanstack.mdx +++ b/docs/content/docs/integrations/tanstack.mdx @@ -9,14 +9,14 @@ Before you start, make sure you have a Better Auth instance configured. If you h ### Mount the handler -We need to mount the handler to a TanStack API endpoint. -Create a new file: `/app/routes/api/auth/$.ts` +We need to mount the handler to a TanStack API endpoint/Server Route. +Create a new file: `/src/routes/api/auth/$.ts` -```ts title="routes/api/auth/$.ts" +```ts title="src/routes/api/auth/$.ts" import { auth } from '@/lib/auth' // import your auth instance -import { createAPIFileRoute } from '@tanstack/react-start/api' +import { createServerFileRoute } from '@tanstack/react-start/server' -export const APIRoute = createAPIFileRoute('/api/auth/$')({ +export const ServerRoute = createServerFileRoute('/api/auth/$').methods({ GET: ({ request }) => { return auth.handler(request) }, @@ -26,15 +26,18 @@ export const APIRoute = createAPIFileRoute('/api/auth/$')({ }) ``` -If you haven't defined an API Route yet, you can do so by creating a file: `/app/api.ts` +If you haven't created your server route handler yet, you can do so by creating a file: `/src/server.ts` -```ts title="app/api.ts" +```ts title="src/server.ts" import { - createStartAPIHandler, - defaultAPIFileRouteHandler, -} from '@tanstack/react-start/api' + createStartHandler, + defaultStreamHandler, +} from '@tanstack/react-start/server' +import { createRouter } from './router' -export default createStartAPIHandler(defaultAPIFileRouteHandler) +export default createStartHandler({ + createRouter, +})(defaultStreamHandler) ``` ### Usage tips @@ -42,7 +45,7 @@ export default createStartAPIHandler(defaultAPIFileRouteHandler) - We recommend using the client SDK or `authClient` to handle authentication, rather than server actions with `auth.api`. - When you call functions that need to set cookies (like `signInEmail` or `signUpEmail`), you'll need to handle cookie setting for TanStack Start. Better Auth provides a `reactStartCookies` plugin to automatically handle this for you. -```ts title="auth.ts" +```ts title="src/lib/auth.ts" import { betterAuth } from "better-auth"; import { reactStartCookies } from "better-auth/react-start"; diff --git a/docs/content/docs/plugins/oauth-proxy.mdx b/docs/content/docs/plugins/oauth-proxy.mdx index 27de652d..a11d7564 100644 --- a/docs/content/docs/plugins/oauth-proxy.mdx +++ b/docs/content/docs/plugins/oauth-proxy.mdx @@ -66,6 +66,6 @@ To share cookies between the proxy server and your main server it uses URL query ## Options -**currentURL**: The application's current URL is automatically determined by the plugin. It first it check for the request URL if invoked by a client, then it checks the base URL from popular hosting providers, and finally falls back to the `baseURL` in your auth config. If the URL isn’t inferred correctly, you can specify it manually here. +**currentURL**: The application's current URL is automatically determined by the plugin. It first checks for the request URL if invoked by a client, then it checks the base URL from popular hosting providers, and finally falls back to the `baseURL` in your auth config. If the URL isn’t inferred correctly, you can specify it manually here. **productionURL**: If this value matches the `baseURL` in your auth config, requests will not be proxied. Defaults to the `BETTER_AUTH_URL` environment variable. diff --git a/docs/content/docs/plugins/sso.mdx b/docs/content/docs/plugins/sso.mdx index f5407f99..3345c19c 100644 --- a/docs/content/docs/plugins/sso.mdx +++ b/docs/content/docs/plugins/sso.mdx @@ -3,13 +3,9 @@ title: Single Sign-On (SSO) description: Integrate Single Sign-On (SSO) with your application. --- -`OIDC` `OAuth2` `SSO` +`OIDC` `OAuth2` `SSO` `SAML` -Single Sign-On (SSO) allows users to authenticate with multiple applications using a single set of credentials. This plugin supports OpenID Connect (OIDC) and OAuth2 providers. - - - SAML support is coming soon. Upvote the feature request on our [GitHub](https://github.com/better-auth/better-auth/issues/96) - +Single Sign-On (SSO) allows users to authenticate with multiple applications using a single set of credentials. This plugin supports OpenID Connect (OIDC), OAuth2 providers, and SAML 2.0. ## Installation @@ -67,30 +63,30 @@ Single Sign-On (SSO) allows users to authenticate with multiple applications usi ### Register an OIDC Provider -To register an OIDC provider, use the `createOIDCProvider` endpoint and provide the necessary configuration details for the provider. +To register an OIDC provider, use the `registerSSOProvider` endpoint and provide the necessary configuration details for the provider. A redirect URL will be automatically generated using the provider ID. For instance, if the provider ID is `hydra`, the redirect URL would be `{baseURL}/api/auth/sso/callback/hydra`. Note that `/api/auth` may vary depending on your base path configuration. -```ts title="register-provider.ts" +```ts title="register-oidc-provider.ts" import { authClient } from "@/lib/auth-client"; -// only with issuer if the provider supports discovery +// Register with OIDC configuration await authClient.sso.register({ - issuer: "https://idp.example.com", providerId: "example-provider", -}); - -// with all fields -await authClient.sso.register({ issuer: "https://idp.example.com", domain: "example.com", - clientId: "client-id", - clientSecret: "client-secret", - authorizationEndpoint: "https://idp.example.com/authorize", - tokenEndpoint: "https://idp.example.com/token", - jwksEndpoint: "https://idp.example.com/jwks", + oidcConfig: { + clientId: "client-id", + clientSecret: "client-secret", + authorizationEndpoint: "https://idp.example.com/authorize", + tokenEndpoint: "https://idp.example.com/token", + jwksEndpoint: "https://idp.example.com/jwks", + discoveryEndpoint: "https://idp.example.com/.well-known/openid-configuration", + scopes: ["openid", "email", "profile"], + pkce: true, + }, mapping: { id: "sub", email: "email", @@ -98,23 +94,28 @@ await authClient.sso.register({ name: "name", image: "picture", }, - providerId: "example-provider", }); ``` -```ts title="register-provider.ts" +```ts title="register-oidc-provider.ts" const { headers } = await signInWithTestUser(); -await auth.api.createOIDCProvider({ +await auth.api.registerSSOProvider({ body: { + providerId: "example-provider", issuer: "https://idp.example.com", domain: "example.com", - clientId: "your-client-id", - clientSecret: "your-client-secret", - authorizationEndpoint: "https://idp.example.com/authorize", - tokenEndpoint: "https://idp.example.com/token", - jwksEndpoint: "https://idp.example.com/jwks", + oidcConfig: { + clientId: "your-client-id", + clientSecret: "your-client-secret", + authorizationEndpoint: "https://idp.example.com/authorize", + tokenEndpoint: "https://idp.example.com/token", + jwksEndpoint: "https://idp.example.com/jwks", + discoveryEndpoint: "https://idp.example.com/.well-known/openid-configuration", + scopes: ["openid", "email", "profile"], + pkce: true, + }, mapping: { id: "sub", email: "email", @@ -122,7 +123,6 @@ await auth.api.createOIDCProvider({ name: "name", image: "picture", }, - providerId: "example-provider", }, headers, }); @@ -130,6 +130,130 @@ await auth.api.createOIDCProvider({ +### Register a SAML Provider + +To register a SAML provider, use the `registerSSOProvider` endpoint with SAML configuration details. The provider will act as a Service Provider (SP) and integrate with your Identity Provider (IdP). + + + +```ts title="register-saml-provider.ts" +import { authClient } from "@/lib/auth-client"; + +await authClient.sso.register({ + providerId: "saml-provider", + issuer: "https://idp.example.com", + domain: "example.com", + samlConfig: { + entryPoint: "https://idp.example.com/sso", + cert: "-----BEGIN CERTIFICATE-----\n...\n-----END CERTIFICATE-----", + callbackUrl: "https://yourapp.com/api/auth/sso/saml2/callback/saml-provider", + audience: "https://yourapp.com", + wantAssertionsSigned: true, + signatureAlgorithm: "sha256", + digestAlgorithm: "sha256", + identifierFormat: "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress", + idpMetadata: { + metadata: "", + privateKey: "-----BEGIN RSA PRIVATE KEY-----\n...\n-----END RSA PRIVATE KEY-----", + privateKeyPass: "your-private-key-password", + isAssertionEncrypted: true, + encPrivateKey: "-----BEGIN RSA PRIVATE KEY-----\n...\n-----END RSA PRIVATE KEY-----", + encPrivateKeyPass: "your-encryption-key-password" + }, + spMetadata: { + metadata: "", + binding: "post", + privateKey: "-----BEGIN RSA PRIVATE KEY-----\n...\n-----END RSA PRIVATE KEY-----", + privateKeyPass: "your-sp-private-key-password", + isAssertionEncrypted: true, + encPrivateKey: "-----BEGIN RSA PRIVATE KEY-----\n...\n-----END RSA PRIVATE KEY-----", + encPrivateKeyPass: "your-sp-encryption-key-password" + } + }, + mapping: { + id: "nameID", + email: "email", + name: "displayName", + firstName: "givenName", + lastName: "surname", + extraFields: { + department: "department", + role: "role" + } + }, +}); +``` + + + +```ts title="register-saml-provider.ts" +const { headers } = await signInWithTestUser(); +await auth.api.registerSSOProvider({ + body: { + providerId: "saml-provider", + issuer: "https://idp.example.com", + domain: "example.com", + samlConfig: { + entryPoint: "https://idp.example.com/sso", + cert: "-----BEGIN CERTIFICATE-----\n...\n-----END CERTIFICATE-----", + callbackUrl: "https://yourapp.com/api/auth/sso/saml2/callback/saml-provider", + audience: "https://yourapp.com", + wantAssertionsSigned: true, + signatureAlgorithm: "sha256", + digestAlgorithm: "sha256", + identifierFormat: "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress", + idpMetadata: { + metadata: "", + privateKey: "-----BEGIN RSA PRIVATE KEY-----\n...\n-----END RSA PRIVATE KEY-----", + privateKeyPass: "your-private-key-password", + isAssertionEncrypted: true, + encPrivateKey: "-----BEGIN RSA PRIVATE KEY-----\n...\n-----END RSA PRIVATE KEY-----", + encPrivateKeyPass: "your-encryption-key-password" + }, + spMetadata: { + metadata: "", + binding: "post", + privateKey: "-----BEGIN RSA PRIVATE KEY-----\n...\n-----END RSA PRIVATE KEY-----", + privateKeyPass: "your-sp-private-key-password", + isAssertionEncrypted: true, + encPrivateKey: "-----BEGIN RSA PRIVATE KEY-----\n...\n-----END RSA PRIVATE KEY-----", + encPrivateKeyPass: "your-sp-encryption-key-password" + } + }, + mapping: { + id: "nameID", + email: "email", + name: "displayName", + firstName: "givenName", + lastName: "surname", + extraFields: { + department: "department", + role: "role" + } + }, + }, + headers, +}); +``` + + + +### Get Service Provider Metadata + +For SAML providers, you can retrieve the Service Provider metadata XML that needs to be configured in your Identity Provider: + +```ts title="get-sp-metadata.ts" +const response = await auth.api.spMetadata({ + query: { + providerId: "saml-provider", + format: "xml" // or "json" + } +}); + +const metadataXML = await response.text(); +console.log(metadataXML); +``` + ### Sign In with SSO To sign in with an SSO provider, you can call `signIn.sso` @@ -183,7 +307,6 @@ const res = await auth.api.signInSSO({ When a user is authenticated, if the user does not exist, the user will be provisioned using the `provisionUser` function. If the organization provisioning is enabled and a provider is associated with an organization, the user will be added to the organization. - ```ts title="auth.ts" const auth = betterAuth({ plugins: [ @@ -203,6 +326,280 @@ const auth = betterAuth({ }); ``` +## Provisioning + +The SSO plugin provides powerful provisioning capabilities to automatically set up users and manage their organization memberships when they sign in through SSO providers. + +### User Provisioning + +User provisioning allows you to run custom logic whenever a user signs in through an SSO provider. This is useful for: + +- Setting up user profiles with additional data from the SSO provider +- Synchronizing user attributes with external systems +- Creating user-specific resources +- Logging SSO sign-ins +- Updating user information from the SSO provider + +```ts title="auth.ts" +const auth = betterAuth({ + plugins: [ + sso({ + provisionUser: async ({ user, userInfo, token, provider }) => { + // Update user profile with SSO data + await updateUserProfile(user.id, { + department: userInfo.attributes?.department, + jobTitle: userInfo.attributes?.jobTitle, + manager: userInfo.attributes?.manager, + lastSSOLogin: new Date(), + }); + + // Create user-specific resources + await createUserWorkspace(user.id); + + // Sync with external systems + await syncUserWithCRM(user.id, userInfo); + + // Log the SSO sign-in + await auditLog.create({ + userId: user.id, + action: 'sso_signin', + provider: provider.providerId, + metadata: { + email: userInfo.email, + ssoProvider: provider.issuer, + }, + }); + }, + }), + ], +}); +``` + +The `provisionUser` function receives: +- **user**: The user object from the database +- **userInfo**: User information from the SSO provider (includes attributes, email, name, etc.) +- **token**: OAuth2 tokens (for OIDC providers) - may be undefined for SAML +- **provider**: The SSO provider configuration + +### Organization Provisioning + +Organization provisioning automatically manages user memberships in organizations when SSO providers are linked to specific organizations. This is particularly useful for: + +- Enterprise SSO where each company/domain maps to an organization +- Automatic role assignment based on SSO attributes +- Managing team memberships through SSO + +#### Basic Organization Provisioning + +```ts title="auth.ts" +const auth = betterAuth({ + plugins: [ + sso({ + organizationProvisioning: { + disabled: false, // Enable org provisioning + defaultRole: "member", // Default role for new members + }, + }), + ], +}); +``` + +#### Advanced Organization Provisioning with Custom Roles + +```ts title="auth.ts" +const auth = betterAuth({ + plugins: [ + sso({ + organizationProvisioning: { + disabled: false, + defaultRole: "member", + getRole: async ({ user, userInfo, provider }) => { + // Assign roles based on SSO attributes + const department = userInfo.attributes?.department; + const jobTitle = userInfo.attributes?.jobTitle; + + // Admins based on job title + if (jobTitle?.toLowerCase().includes('manager') || + jobTitle?.toLowerCase().includes('director') || + jobTitle?.toLowerCase().includes('vp')) { + return "admin"; + } + + // Special roles for IT department + if (department?.toLowerCase() === 'it') { + return "admin"; + } + + // Default to member for everyone else + return "member"; + }, + }, + }), + ], +}); +``` + +#### Linking SSO Providers to Organizations + +When registering an SSO provider, you can link it to a specific organization: + +```ts title="register-org-provider.ts" +await auth.api.registerSSOProvider({ + body: { + providerId: "acme-corp-saml", + issuer: "https://acme-corp.okta.com", + domain: "acmecorp.com", + organizationId: "org_acme_corp_id", // Link to organization + samlConfig: { + // SAML configuration... + }, + }, + headers, +}); +``` + +Now when users from `acmecorp.com` sign in through this provider, they'll automatically be added to the "Acme Corp" organization with the appropriate role. + +#### Multiple Organizations Example + +You can set up multiple SSO providers for different organizations: + +```ts title="multi-org-setup.ts" +// Acme Corp SAML provider +await auth.api.registerSSOProvider({ + body: { + providerId: "acme-corp", + issuer: "https://acme.okta.com", + domain: "acmecorp.com", + organizationId: "org_acme_id", + samlConfig: { /* ... */ }, + }, + headers, +}); + +// TechStart OIDC provider +await auth.api.registerSSOProvider({ + body: { + providerId: "techstart-google", + issuer: "https://accounts.google.com", + domain: "techstart.io", + organizationId: "org_techstart_id", + oidcConfig: { /* ... */ }, + }, + headers, +}); +``` + +#### Organization Provisioning Flow + +1. **User signs in** through an SSO provider linked to an organization +2. **User is authenticated** and either found or created in the database +3. **Organization membership is checked** - if the user isn't already a member of the linked organization +4. **Role is determined** using either the `defaultRole` or `getRole` function +5. **User is added** to the organization with the determined role +6. **User provisioning runs** (if configured) for additional setup + +### Provisioning Best Practices + +#### 1. Idempotent Operations +Make sure your provisioning functions can be safely run multiple times: + +```ts +provisionUser: async ({ user, userInfo }) => { + // Check if already provisioned + const existingProfile = await getUserProfile(user.id); + if (!existingProfile.ssoProvisioned) { + await createUserResources(user.id); + await markAsProvisioned(user.id); + } + + // Always update attributes (they might change) + await updateUserAttributes(user.id, userInfo.attributes); +}, +``` + +#### 2. Error Handling +Handle errors gracefully to avoid blocking user sign-in: + +```ts +provisionUser: async ({ user, userInfo }) => { + try { + await syncWithExternalSystem(user, userInfo); + } catch (error) { + // Log error but don't throw - user can still sign in + console.error('Failed to sync user with external system:', error); + await logProvisioningError(user.id, error); + } +}, +``` + +#### 3. Conditional Provisioning +Only run certain provisioning steps when needed: + +```ts +organizationProvisioning: { + disabled: false, + getRole: async ({ user, userInfo, provider }) => { + // Only process role assignment for certain providers + if (provider.providerId.includes('enterprise')) { + return determineEnterpriseRole(userInfo); + } + return "member"; + }, +}, +``` + +## SAML Configuration + +### Service Provider Configuration + +When registering a SAML provider, you need to provide Service Provider (SP) metadata configuration: + +- **metadata**: XML metadata for the Service Provider +- **binding**: The binding method, typically "post" or "redirect" +- **privateKey**: Private key for signing (optional) +- **privateKeyPass**: Password for the private key (if encrypted) +- **isAssertionEncrypted**: Whether assertions should be encrypted +- **encPrivateKey**: Private key for decryption (if encryption is enabled) +- **encPrivateKeyPass**: Password for the encryption private key + +### Identity Provider Configuration + +You also need to provide Identity Provider (IdP) configuration: + +- **metadata**: XML metadata from your Identity Provider +- **privateKey**: Private key for the IdP communication (optional) +- **privateKeyPass**: Password for the IdP private key (if encrypted) +- **isAssertionEncrypted**: Whether assertions from IdP are encrypted +- **encPrivateKey**: Private key for IdP assertion decryption +- **encPrivateKeyPass**: Password for the IdP decryption key + +### SAML Attribute Mapping + +Configure how SAML attributes map to user fields: + +```ts +mapping: { + id: "nameID", // Default: "nameID" + email: "email", // Default: "email" or "nameID" + name: "displayName", // Default: "displayName" + firstName: "givenName", // Default: "givenName" + lastName: "surname", // Default: "surname" + extraFields: { + department: "department", + role: "jobTitle", + phone: "telephoneNumber" + } +} +``` + +### SAML Endpoints + +The plugin automatically creates the following SAML endpoints: + +- **SP Metadata**: `/api/auth/sso/saml2/sp/metadata?providerId={providerId}` +- **SAML Callback**: `/api/auth/sso/saml2/callback/{providerId}` + ## Schema The plugin requires additional fields in the `ssoProvider` table to store the provider's configuration. @@ -214,7 +611,8 @@ The plugin requires additional fields in the `ssoProvider` table to store the pr }, { name: "issuer", type: "string", description: "The issuer identifier", isRequired: true }, { name: "domain", type: "string", description: "The domain of the provider", isRequired: true }, - { name: "oidcConfig", type: "string", description: "The OIDC configuration", isRequired: false }, + { name: "oidcConfig", type: "string", description: "The OIDC configuration (JSON string)", isRequired: false }, + { name: "samlConfig", type: "string", description: "The SAML configuration (JSON string)", isRequired: false }, { name: "userId", type: "string", description: "The user ID", isRequired: true, references: { model: "user", field: "id" } }, { name: "providerId", type: "string", description: "The provider ID. Used to identify a provider and to generate a redirect URL.", isRequired: true, isUnique: true }, { name: "organizationId", type: "string", description: "The organization Id. If provider is linked to an organization.", isRequired: false }, @@ -229,6 +627,10 @@ The plugin requires additional fields in the `ssoProvider` table to store the pr **organizationProvisioning**: Options for provisioning users to an organization. +**defaultOverrideUserInfo**: Override user info with the provider info by default. + +**disableImplicitSignUp**: Disable implicit sign up for new users. + diff --git a/docs/public/blogs/seed-round.png b/docs/public/blogs/seed-round.png new file mode 100644 index 00000000..022b758c Binary files /dev/null and b/docs/public/blogs/seed-round.png differ diff --git a/examples/nextjs-mcp/README.md b/examples/nextjs-mcp/README.md index 0b1307ca..b1e7d56e 100644 --- a/examples/nextjs-mcp/README.md +++ b/examples/nextjs-mcp/README.md @@ -1,6 +1,6 @@ # Better Auth - MCP Demo -This is example repo on how to setup Better Auth for MCP Auth using Nextjs and Vercel MCP adapter. +This is an example repo on how to setup Better Auth for MCP Auth using Nextjs and Vercel MCP adapter. ## Usage @@ -12,7 +12,7 @@ First, add the plugin to your auth instance import { betterAuth } from "better-auth"; import { mcp } from "better-auth/plugins"; -export cosnt auth = betterAuth({ +export const auth = betterAuth({ plugins: [ mcp({ loginPage: "/sign-in" // path to a page where users login @@ -46,7 +46,7 @@ import { toNextJsHandler } from "better-auth/next-js"; export const { GET, POST } = toNextJsHandler(auth); ``` -Use `auth.api.getMcpSession` to get the session using the access token sent from the MCP client +You can use the helper function `withMcpAuth` to get the session and handle unauthenticated calls automatically. ```ts import { auth } from "@/lib/auth"; @@ -54,7 +54,7 @@ import { createMcpHandler } from "@vercel/mcp-adapter"; import { withMcpAuth } from "better-auth/plugins"; import { z } from "zod"; -const handler = withMcpAuth(auth, (req, sesssion) => { +const handler = withMcpAuth(auth, (req, session) => { //session => This isn’t a typical Better Auth session - instead, it returns the access token record along with the scopes and user ID. return createMcpHandler( (server) => { diff --git a/packages/better-auth/build.config.ts b/packages/better-auth/build.config.ts index 9c808149..56e813e4 100644 --- a/packages/better-auth/build.config.ts +++ b/packages/better-auth/build.config.ts @@ -110,5 +110,6 @@ export default defineBuildConfig({ "./src/plugins/username/index.ts", "./src/plugins/haveibeenpwned/index.ts", "./src/plugins/one-time-token/index.ts", + "./src/test-utils/index.ts", ], }); diff --git a/packages/better-auth/package.json b/packages/better-auth/package.json index 695202d7..dd9e7073 100644 --- a/packages/better-auth/package.json +++ b/packages/better-auth/package.json @@ -1,6 +1,6 @@ { "name": "better-auth", - "version": "1.2.10-beta.1", + "version": "1.2.10", "description": "The most comprehensive authentication library for TypeScript.", "type": "module", "license": "MIT", @@ -135,6 +135,16 @@ "default": "./dist/client/solid/index.cjs" } }, + "./test": { + "import": { + "types": "./dist/test-utils/index.d.ts", + "default": "./dist/test-utils/index.mjs" + }, + "require": { + "types": "./dist/test-utils/index.d.cts", + "default": "./dist/test-utils/index.cjs" + } + }, "./api": { "import": { "types": "./dist/api/index.d.ts", diff --git a/packages/better-auth/src/adapters/create-adapter/index.ts b/packages/better-auth/src/adapters/create-adapter/index.ts index 68f3e069..9e727c92 100644 --- a/packages/better-auth/src/adapters/create-adapter/index.ts +++ b/packages/better-auth/src/adapters/create-adapter/index.ts @@ -318,7 +318,10 @@ export const createAdapter = !config.disableIdGeneration && !options.advanced?.database?.useNumberId ) { - fields.id = idField({ customModelName: unsafe_model, forceAllowId }); + fields.id = idField({ + customModelName: unsafe_model, + forceAllowId: forceAllowId && "id" in data, + }); } for (const field in fields) { const value = data[field]; diff --git a/packages/better-auth/src/adapters/mongodb-adapter/mongodb-adapter.ts b/packages/better-auth/src/adapters/mongodb-adapter/mongodb-adapter.ts index e206af99..b679c71b 100644 --- a/packages/better-auth/src/adapters/mongodb-adapter/mongodb-adapter.ts +++ b/packages/better-auth/src/adapters/mongodb-adapter/mongodb-adapter.ts @@ -1,326 +1,274 @@ import { ObjectId, type Db } from "mongodb"; -import { getAuthTables } from "../../db"; -import type { Adapter, BetterAuthOptions, Where } from "../../types"; -import { withApplyDefault } from "../utils"; +import type { Where } from "../../types"; +import { createAdapter, type AdapterDebugLogs } from "../create-adapter"; -const createTransform = (options: BetterAuthOptions) => { - const schema = getAuthTables(options); +export interface MongoDBAdapterConfig { /** - * if custom id gen is provided we don't want to override with object id + * Enable debug logs for the adapter + * + * @default false */ - const customIdGen = - options.advanced?.database?.generateId || options.advanced?.generateId; + debugLogs?: AdapterDebugLogs; + /** + * Use plural table names + * + * @default false + */ + usePlural?: boolean; +} - function serializeID(field: string, value: any, model: string) { - if (customIdGen) { - return value; - } - if ( - field === "id" || - field === "_id" || - schema[model].fields[field].references?.field === "id" - ) { - if (typeof value !== "string") { - if (value instanceof ObjectId) { +export const mongodbAdapter = (db: Db, config?: MongoDBAdapterConfig) => + createAdapter({ + config: { + adapterId: "mongodb-adapter", + adapterName: "MongoDB Adapter", + usePlural: config?.usePlural ?? false, + debugLogs: config?.debugLogs ?? false, + mapKeysTransformInput: { + id: "_id", + }, + mapKeysTransformOutput: { + _id: "id", + }, + supportsNumericIds: false, + customTransformInput({ + action, + data, + field, + fieldAttributes, + schema, + model, + }) { + // Given the key transformation, we know that `id` is already mapped to `_id` + if (field === "_id" || fieldAttributes.references?.field === "id") { + if (action === "update") { + return data; + } + if (Array.isArray(data)) { + return data.map((v) => new ObjectId()); + } + if (typeof data === "string") { + try { + return new ObjectId(data); + } catch (error) { + return new ObjectId(); + } + } + return new ObjectId(); + } + return data; + }, + customTransformOutput({ data, field, fieldAttributes }) { + if (field === "id" || fieldAttributes.references?.field === "id") { + if (data instanceof ObjectId) { + return data.toHexString(); + } + if (Array.isArray(data)) { + return data.map((v) => { + if (v instanceof ObjectId) { + return v.toHexString(); + } + return v; + }); + } + return data; + } + return data; + }, + }, + adapter: ({ options, getFieldName, schema, getDefaultModelName }) => { + /** + * if custom id gen is provided we don't want to override with object id + */ + const customIdGen = options.advanced?.database?.generateId; + + function serializeID({ + field, + value, + model, + }: { field: string; value: any; model: string }) { + if (customIdGen) { return value; } - if (Array.isArray(value)) { - return value.map((v) => { - if (typeof v === "string") { - try { - return new ObjectId(v); - } catch (e) { - return v; - } + model = getDefaultModelName(model); + if ( + field === "id" || + field === "_id" || + schema[model].fields[field]?.references?.field === "id" + ) { + if (typeof value !== "string") { + if (value instanceof ObjectId) { + return value; } - if (v instanceof ObjectId) { - return v; + if (Array.isArray(value)) { + return value.map((v) => { + if (typeof v === "string") { + try { + return new ObjectId(v); + } catch (e) { + return v; + } + } + if (v instanceof ObjectId) { + return v; + } + throw new Error("Invalid id value"); + }); } throw new Error("Invalid id value"); - }); + } + try { + return new ObjectId(value); + } catch (e) { + return value; + } } - throw new Error("Invalid id value"); - } - try { - return new ObjectId(value); - } catch (e) { return value; } - } - return value; - } - function deserializeID(field: string, value: any, model: string) { - if (customIdGen) { - return value; - } - if ( - field === "id" || - schema[model].fields[field].references?.field === "id" - ) { - if (value instanceof ObjectId) { - return value.toHexString(); - } - if (Array.isArray(value)) { - return value.map((v) => { - if (v instanceof ObjectId) { - return v.toHexString(); - } - return v; - }); - } - return value; - } - return value; - } - - function getField(field: string, model: string) { - if (field === "id") { - if (customIdGen) { - return "id"; - } - return "_id"; - } - const f = schema[model].fields[field]; - return f.fieldName || field; - } - - return { - transformInput( - data: Record, - model: string, - action: "create" | "update", - ) { - const transformedData: Record = - action === "update" - ? {} - : customIdGen - ? { - id: customIdGen({ model }), - } - : { - _id: new ObjectId(), + function convertWhereClause({ + where, + model, + }: { where: Where[]; model: string }) { + if (!where.length) return {}; + const conditions = where.map((w) => { + const { + field: field_, + value, + operator = "eq", + connector = "AND", + } = w; + let condition: any; + let field = getFieldName({ model, field: field_ }); + if (field === "id") field = "_id"; + switch (operator.toLowerCase()) { + case "eq": + condition = { + [field]: serializeID({ + field, + value, + model, + }), }; - const fields = schema[model].fields; - for (const field in fields) { - const value = data[field]; - if ( - value === undefined && - (!fields[field].defaultValue || action === "update") - ) { - continue; - } - transformedData[fields[field].fieldName || field] = withApplyDefault( - serializeID(field, value, model), - fields[field], - action, - ); - } - return transformedData; - }, - transformOutput( - data: Record, - model: string, - select: string[] = [], - ) { - const transformedData: Record = - data.id || data._id - ? select.length === 0 || select.includes("id") - ? { - id: data.id ? data.id.toString() : data._id.toString(), - } - : {} - : {}; + break; + case "in": + condition = { + [field]: { + $in: Array.isArray(value) + ? value.map((v) => serializeID({ field, value: v, model })) + : [serializeID({ field, value, model })], + }, + }; + break; + case "gt": + condition = { [field]: { $gt: value } }; + break; + case "gte": + condition = { [field]: { $gte: value } }; + break; + case "lt": + condition = { [field]: { $lt: value } }; + break; + case "lte": + condition = { [field]: { $lte: value } }; + break; + case "ne": + condition = { [field]: { $ne: value } }; + break; - const tableSchema = schema[model].fields; - for (const key in tableSchema) { - if (select.length && !select.includes(key)) { - continue; + case "contains": + condition = { [field]: { $regex: `.*${value}.*` } }; + break; + case "starts_with": + condition = { [field]: { $regex: `${value}.*` } }; + break; + case "ends_with": + condition = { [field]: { $regex: `.*${value}` } }; + break; + default: + throw new Error(`Unsupported operator: ${operator}`); + } + return { condition, connector }; + }); + if (conditions.length === 1) { + return conditions[0].condition; } - const field = tableSchema[key]; - if (field) { - transformedData[key] = deserializeID( - key, - data[field.fieldName || key], - model, + const andConditions = conditions + .filter((c) => c.connector === "AND") + .map((c) => c.condition); + const orConditions = conditions + .filter((c) => c.connector === "OR") + .map((c) => c.condition); + + let clause = {}; + if (andConditions.length) { + clause = { ...clause, $and: andConditions }; + } + if (orConditions.length) { + clause = { ...clause, $or: orConditions }; + } + return clause; + } + + return { + async create({ model, data: values }) { + const res = await db.collection(model).insertOne(values); + const insertedData = { _id: res.insertedId.toString(), ...values }; + return insertedData as any; + }, + async findOne({ model, where, select }) { + const clause = convertWhereClause({ where, model }); + const res = await db.collection(model).findOne(clause); + if (!res) return null; + return res as any; + }, + async findMany({ model, where, limit, offset, sortBy }) { + const clause = where ? convertWhereClause({ where, model }) : {}; + const cursor = db.collection(model).find(clause); + if (limit) cursor.limit(limit); + if (offset) cursor.skip(offset); + if (sortBy) + cursor.sort( + getFieldName({ field: sortBy.field, model }), + sortBy.direction === "desc" ? -1 : 1, + ); + const res = await cursor.toArray(); + return res as any; + }, + async count({ model }) { + const res = await db.collection(model).countDocuments(); + return res; + }, + async update({ model, where, update: values }) { + const clause = convertWhereClause({ where, model }); + + const res = await db.collection(model).findOneAndUpdate( + clause, + { $set: values as any }, + { + returnDocument: "after", + }, ); - } - } - return transformedData as any; - }, - convertWhereClause(where: Where[], model: string) { - if (!where.length) return {}; - const conditions = where.map((w) => { - const { field: _field, value, operator = "eq", connector = "AND" } = w; - let condition: any; - const field = getField(_field, model); - switch (operator.toLowerCase()) { - case "eq": - condition = { - [field]: serializeID(_field, value, model), - }; - break; - case "in": - condition = { - [field]: { - $in: Array.isArray(value) - ? serializeID(_field, value, model) - : [serializeID(_field, value, model)], - }, - }; - break; - case "gt": - condition = { [field]: { $gt: value } }; - break; - case "gte": - condition = { [field]: { $gte: value } }; - break; - case "lt": - condition = { [field]: { $lt: value } }; - break; - case "lte": - condition = { [field]: { $lte: value } }; - break; - case "ne": - condition = { [field]: { $ne: value } }; - break; + if (!res) return null; + return res as any; + }, + async updateMany({ model, where, update: values }) { + const clause = convertWhereClause({ where, model }); - case "contains": - condition = { [field]: { $regex: `.*${value}.*` } }; - break; - case "starts_with": - condition = { [field]: { $regex: `${value}.*` } }; - break; - case "ends_with": - condition = { [field]: { $regex: `.*${value}` } }; - break; - default: - throw new Error(`Unsupported operator: ${operator}`); - } - return { condition, connector }; - }); - if (conditions.length === 1) { - return conditions[0].condition; - } - const andConditions = conditions - .filter((c) => c.connector === "AND") - .map((c) => c.condition); - const orConditions = conditions - .filter((c) => c.connector === "OR") - .map((c) => c.condition); - - let clause = {}; - if (andConditions.length) { - clause = { ...clause, $and: andConditions }; - } - if (orConditions.length) { - clause = { ...clause, $or: orConditions }; - } - return clause; + const res = await db.collection(model).updateMany(clause, { + $set: values as any, + }); + return res.modifiedCount; + }, + async delete({ model, where }) { + const clause = convertWhereClause({ where, model }); + await db.collection(model).deleteOne(clause); + }, + async deleteMany({ model, where }) { + const clause = convertWhereClause({ where, model }); + const res = await db.collection(model).deleteMany(clause); + return res.deletedCount; + }, + }; }, - getModelName: (model: string) => { - return schema[model].modelName; - }, - getField, - }; -}; - -export const mongodbAdapter = (db: Db) => (options: BetterAuthOptions) => { - const transform = createTransform(options); - const hasCustomId = options.advanced?.generateId; - return { - id: "mongodb-adapter", - async create(data) { - const { model, data: values, select } = data; - const transformedData = transform.transformInput(values, model, "create"); - if (transformedData.id && !hasCustomId) { - // biome-ignore lint/performance/noDelete: setting id to undefined will cause the id to be null in the database which is not what we want - delete transformedData.id; - } - const res = await db - .collection(transform.getModelName(model)) - .insertOne(transformedData); - const id = res.insertedId; - const insertedData = { id: id.toString(), ...transformedData }; - const t = transform.transformOutput(insertedData, model, select); - return t; - }, - async findOne(data) { - const { model, where, select } = data; - const clause = transform.convertWhereClause(where, model); - const res = await db - .collection(transform.getModelName(model)) - .findOne(clause); - if (!res) return null; - const transformedData = transform.transformOutput(res, model, select); - return transformedData; - }, - async findMany(data) { - const { model, where, limit, offset, sortBy } = data; - const clause = where ? transform.convertWhereClause(where, model) : {}; - const cursor = db.collection(transform.getModelName(model)).find(clause); - if (limit) cursor.limit(limit); - if (offset) cursor.skip(offset); - if (sortBy) - cursor.sort( - transform.getField(sortBy.field, model), - sortBy.direction === "desc" ? -1 : 1, - ); - const res = await cursor.toArray(); - return res.map((r) => transform.transformOutput(r, model)); - }, - async count(data) { - const { model } = data; - const res = await db - .collection(transform.getModelName(model)) - .countDocuments(); - return res; - }, - async update(data) { - const { model, where, update: values } = data; - const clause = transform.convertWhereClause(where, model); - - const transformedData = transform.transformInput(values, model, "update"); - - const res = await db - .collection(transform.getModelName(model)) - .findOneAndUpdate( - clause, - { $set: transformedData }, - { - returnDocument: "after", - }, - ); - const output = res?.value ?? res; - if (!output) return null; - return transform.transformOutput(output, model); - }, - async updateMany(data) { - const { model, where, update: values } = data; - const clause = transform.convertWhereClause(where, model); - const transformedData = transform.transformInput(values, model, "update"); - const res = await db - .collection(transform.getModelName(model)) - .updateMany(clause, { $set: transformedData }); - return res.modifiedCount; - }, - async delete(data) { - const { model, where } = data; - const clause = transform.convertWhereClause(where, model); - const res = await db - .collection(transform.getModelName(model)) - .findOneAndDelete(clause); - const output = res?.value ?? res; - if (!output) return null; - return transform.transformOutput(output, model); - }, - async deleteMany(data) { - const { model, where } = data; - const clause = transform.convertWhereClause(where, model); - const res = await db - .collection(transform.getModelName(model)) - .deleteMany(clause); - return res.deletedCount; - }, - } satisfies Adapter; -}; + }); diff --git a/packages/better-auth/src/api/routes/account.test.ts b/packages/better-auth/src/api/routes/account.test.ts index f8ee62c9..be46c8c0 100644 --- a/packages/better-auth/src/api/routes/account.test.ts +++ b/packages/better-auth/src/api/routes/account.test.ts @@ -1,4 +1,12 @@ -import { describe, expect, it, vi } from "vitest"; +import { + afterEach, + beforeAll, + describe, + expect, + it, + vi, + type MockInstance, +} from "vitest"; import { getTestInstance } from "../../test-utils/test-instance"; import { parseSetCookieHeader } from "../../cookies"; import type { GoogleProfile } from "../../social-providers"; @@ -61,6 +69,20 @@ describe("account", async () => { const ctx = await auth.$context; + let googleVerifyIdTokenMock: MockInstance; + let googleGetUserInfoMock: MockInstance; + beforeAll(() => { + const googleProvider = ctx.socialProviders.find((v) => v.id === "google")!; + expect(googleProvider).toBeTruthy(); + + googleVerifyIdTokenMock = vi.spyOn(googleProvider, "verifyIdToken"); + googleGetUserInfoMock = vi.spyOn(googleProvider, "getUserInfo"); + }); + afterEach(() => { + googleVerifyIdTokenMock.mockClear(); + googleGetUserInfoMock.mockClear(); + }); + const { headers } = await signInWithTestUser(); it("should list all accounts", async () => { @@ -96,7 +118,9 @@ describe("account", async () => { redirect: true, }); const state = - new URL(linkAccountRes.data!.url).searchParams.get("state") || ""; + linkAccountRes.data && "url" in linkAccountRes.data + ? new URL(linkAccountRes.data.url).searchParams.get("state") || "" + : ""; email = "test@test.com"; await client.$fetch("/callback/google", { query: { @@ -139,7 +163,10 @@ describe("account", async () => { redirect: true, }); - const url = new URL(linkAccountRes.data!.url); + const url = + linkAccountRes.data && "url" in linkAccountRes.data + ? new URL(linkAccountRes.data.url) + : new URL(""); const scopesParam = url.searchParams.get("scope"); expect(scopesParam).toContain(customScope); }); @@ -169,7 +196,9 @@ describe("account", async () => { redirect: true, }); const state = - new URL(linkAccountRes.data!.url).searchParams.get("state") || ""; + linkAccountRes.data && "url" in linkAccountRes.data + ? new URL(linkAccountRes.data.url).searchParams.get("state") || "" + : ""; email = "test2@test.com"; await client.$fetch("/callback/google", { query: { @@ -192,6 +221,53 @@ describe("account", async () => { }); expect(accounts.data?.length).toBe(2); }); + + it("should link third account with idToken", async () => { + googleVerifyIdTokenMock.mockResolvedValueOnce(true); + const user = { + id: "0987654321", + name: "test2", + email: "test2@gmail.com", + sub: "test2", + emailVerified: true, + }; + const userInfo = { + user, + data: user, + }; + googleGetUserInfoMock.mockResolvedValueOnce(userInfo); + + const { headers: headers2 } = await signInWithTestUser(); + const linkAccountRes = await client.linkSocial( + { + provider: "google", + callbackURL: "/callback", + idToken: { token: "test" }, + }, + { + headers: headers2, + onSuccess(context) { + const cookies = parseSetCookieHeader( + context.response.headers.get("set-cookie") || "", + ); + headers.set( + "cookie", + `better-auth.state=${cookies.get("better-auth.state")?.value}`, + ); + }, + }, + ); + + expect(googleVerifyIdTokenMock).toHaveBeenCalledOnce(); + expect(googleGetUserInfoMock).toHaveBeenCalledOnce(); + + const { headers: headers3 } = await signInWithTestUser(); + const accounts = await client.listAccounts({ + fetchOptions: { headers: headers3 }, + }); + expect(accounts.data?.length).toBe(3); + }); + it("should unlink account", async () => { const { headers } = await signInWithTestUser(); const previousAccounts = await client.listAccounts({ @@ -199,7 +275,7 @@ describe("account", async () => { headers, }, }); - expect(previousAccounts.data?.length).toBe(2); + expect(previousAccounts.data?.length).toBe(3); const unlinkAccountId = previousAccounts.data![1].accountId; const unlinkRes = await client.unlinkAccount({ providerId: "google", @@ -214,7 +290,7 @@ describe("account", async () => { headers, }, }); - expect(accounts.data?.length).toBe(1); + expect(accounts.data?.length).toBe(2); }); it("should fail to unlink the last account of a provider", async () => { diff --git a/packages/better-auth/src/api/routes/account.ts b/packages/better-auth/src/api/routes/account.ts index 68d72ad3..5694a197 100644 --- a/packages/better-auth/src/api/routes/account.ts +++ b/packages/better-auth/src/api/routes/account.ts @@ -104,6 +104,22 @@ export const linkSocialAccount = createAuthEndpoint( * OAuth2 provider to use */ provider: SocialProviderListEnum, + /** + * ID Token for direct authentication without redirect + */ + idToken: z + .object({ + token: z.string(), + nonce: z.string().optional(), + accessToken: z.string().optional(), + refreshToken: z.string().optional(), + scopes: z.array(z.string()).optional(), + }) + .optional(), + /** + * Whether to allow sign up for new users + */ + requestSignUp: z.boolean().optional(), /** * Additional scopes to request when linking the account. * This is useful for requesting additional permissions when @@ -146,8 +162,11 @@ export const linkSocialAccount = createAuthEndpoint( description: "Indicates if the user should be redirected to the authorization URL", }, + status: { + type: "boolean", + }, }, - required: ["url", "redirect"], + required: ["redirect"], }, }, }, @@ -175,6 +194,133 @@ export const linkSocialAccount = createAuthEndpoint( }); } + // Handle ID Token flow if provided + if (c.body.idToken) { + if (!provider.verifyIdToken) { + c.context.logger.error( + "Provider does not support id token verification", + { + provider: c.body.provider, + }, + ); + throw new APIError("NOT_FOUND", { + message: BASE_ERROR_CODES.ID_TOKEN_NOT_SUPPORTED, + }); + } + + const { token, nonce } = c.body.idToken; + const valid = await provider.verifyIdToken(token, nonce); + if (!valid) { + c.context.logger.error("Invalid id token", { + provider: c.body.provider, + }); + throw new APIError("UNAUTHORIZED", { + message: BASE_ERROR_CODES.INVALID_TOKEN, + }); + } + + const linkingUserInfo = await provider.getUserInfo({ + idToken: token, + accessToken: c.body.idToken.accessToken, + refreshToken: c.body.idToken.refreshToken, + }); + + if (!linkingUserInfo || !linkingUserInfo?.user) { + c.context.logger.error("Failed to get user info", { + provider: c.body.provider, + }); + throw new APIError("UNAUTHORIZED", { + message: BASE_ERROR_CODES.FAILED_TO_GET_USER_INFO, + }); + } + + if (!linkingUserInfo.user.email) { + c.context.logger.error("User email not found", { + provider: c.body.provider, + }); + throw new APIError("UNAUTHORIZED", { + message: BASE_ERROR_CODES.USER_EMAIL_NOT_FOUND, + }); + } + + const existingAccounts = await c.context.internalAdapter.findAccounts( + session.user.id, + ); + + const hasBeenLinked = existingAccounts.find( + (a) => + a.providerId === provider.id && + a.accountId === linkingUserInfo.user.id, + ); + + if (hasBeenLinked) { + return c.json({ + redirect: false, + status: true, + }); + } + + const trustedProviders = + c.context.options.account?.accountLinking?.trustedProviders; + + const isTrustedProvider = trustedProviders?.includes(provider.id); + if ( + (!isTrustedProvider && !linkingUserInfo.user.emailVerified) || + c.context.options.account?.accountLinking?.enabled === false + ) { + throw new APIError("UNAUTHORIZED", { + message: "Account not linked - linking not allowed", + }); + } + + if ( + linkingUserInfo.user.email !== session.user.email && + c.context.options.account?.accountLinking?.allowDifferentEmails !== true + ) { + throw new APIError("UNAUTHORIZED", { + message: "Account not linked - different emails not allowed", + }); + } + + try { + await c.context.internalAdapter.createAccount( + { + userId: session.user.id, + providerId: provider.id, + accountId: linkingUserInfo.user.id.toString(), + accessToken: c.body.idToken.accessToken, + idToken: token, + refreshToken: c.body.idToken.refreshToken, + scope: c.body.idToken.scopes?.join(","), + }, + c, + ); + } catch (e: any) { + throw new APIError("EXPECTATION_FAILED", { + message: "Account not linked - unable to create account", + }); + } + + if ( + c.context.options.account?.accountLinking?.updateUserInfoOnLink === true + ) { + try { + await c.context.internalAdapter.updateUser(session.user.id, { + name: linkingUserInfo.user?.name, + image: linkingUserInfo.user?.image, + }); + } catch (e: any) { + console.warn("Could not update user - " + e.toString()); + } + } + + return c.json({ + redirect: false, + status: true, + }); + } + + // Handle OAuth flow const state = await generateState(c, { userId: session.user.id, email: session.user.email, diff --git a/packages/better-auth/src/api/routes/callback.ts b/packages/better-auth/src/api/routes/callback.ts index 040e2867..d306cd42 100644 --- a/packages/better-auth/src/api/routes/callback.ts +++ b/packages/better-auth/src/api/routes/callback.ts @@ -126,6 +126,19 @@ export const callbackOAuth = createAuthEndpoint( } if (link) { + const trustedProviders = + c.context.options.account?.accountLinking?.trustedProviders; + const isTrustedProvider = trustedProviders?.includes( + provider.id as "apple", + ); + if ( + (!isTrustedProvider && !userInfo.emailVerified) || + c.context.options.account?.accountLinking?.enabled === false + ) { + c.context.logger.error("Unable to link account - untrusted provider"); + return redirectOnError("unable_to_link_account"); + } + const existingAccount = await c.context.internalAdapter.findAccount( userInfo.id, ); diff --git a/packages/better-auth/src/api/routes/update-user.ts b/packages/better-auth/src/api/routes/update-user.ts index 9ba062fd..86caaafa 100644 --- a/packages/better-auth/src/api/routes/update-user.ts +++ b/packages/better-auth/src/api/routes/update-user.ts @@ -421,17 +421,14 @@ export const deleteUser = createAuthEndpoint( throw new APIError("NOT_FOUND"); } const session = ctx.context.session; - let canDelete = false; - const accounts = await ctx.context.internalAdapter.findAccounts( - session.user.id, - ); - const account = accounts.find( - (account) => account.providerId === "credential" && account.password, - ); - - // If the user has a password, we can try to delete the account if (ctx.body.password) { + const accounts = await ctx.context.internalAdapter.findAccounts( + session.user.id, + ); + const account = accounts.find( + (account) => account.providerId === "credential" && account.password, + ); if (!account || !account.password) { throw new APIError("BAD_REQUEST", { message: BASE_ERROR_CODES.CREDENTIAL_ACCOUNT_NOT_FOUND, @@ -446,10 +443,8 @@ export const deleteUser = createAuthEndpoint( message: BASE_ERROR_CODES.INVALID_PASSWORD, }); } - canDelete = true; } - // If the user has a token, we can try to delete the account if (ctx.body.token) { //@ts-expect-error await deleteUserCallback({ @@ -464,15 +459,7 @@ export const deleteUser = createAuthEndpoint( }); } - // if user didn't provide a password or token, try sending email verification if (ctx.context.options.user.deleteUser?.sendDeleteAccountVerification) { - // if the user has a password but it was not provided, we can't delete the account - if (account && account.password && !canDelete) { - throw new APIError("BAD_REQUEST", { - message: BASE_ERROR_CODES.USER_ALREADY_HAS_PASSWORD, - }); - } - const token = generateRandomString(32, "0-9", "a-z"); await ctx.context.internalAdapter.createVerificationValue( { @@ -506,25 +493,15 @@ export const deleteUser = createAuthEndpoint( }); } - // if the user didn't provide a password or token, or email verification is not enabled - // we can check if the session is fresh and delete based on that - if (ctx.context.options.session?.freshAge) { + if (!ctx.body.password && ctx.context.sessionConfig.freshAge !== 0) { const currentAge = session.session.createdAt.getTime(); - const freshAge = ctx.context.options.session.freshAge; + const freshAge = ctx.context.sessionConfig.freshAge * 1000; const now = Date.now(); if (now - currentAge > freshAge * 1000) { throw new APIError("BAD_REQUEST", { message: BASE_ERROR_CODES.SESSION_EXPIRED, }); } - canDelete = true; - } - - // if password/fresh session didn't work, we can't delete the account - if (!canDelete) { - throw new APIError("BAD_REQUEST", { - message: "User cannot be deleted. please provide a password or token", - }); } const beforeDelete = ctx.context.options.user.deleteUser?.beforeDelete; diff --git a/packages/better-auth/src/client/index.ts b/packages/better-auth/src/client/index.ts index cc45a1dd..052fe755 100644 --- a/packages/better-auth/src/client/index.ts +++ b/packages/better-auth/src/client/index.ts @@ -1,4 +1,4 @@ -import type { BetterAuthPlugin } from "../types"; +import type { BetterAuthOptions, BetterAuthPlugin } from "../types"; import type { BetterAuthClientPlugin } from "./types"; export * from "./vanilla"; export * from "./query"; @@ -11,6 +11,10 @@ export const InferPlugin = () => { } satisfies BetterAuthClientPlugin; }; +export function InferAuth() { + return {} as O["options"]; +} + //@ts-expect-error export type * from "nanostores"; export type * from "@better-fetch/fetch"; diff --git a/packages/better-auth/src/client/types.ts b/packages/better-auth/src/client/types.ts index bd77027d..ff1449ba 100644 --- a/packages/better-auth/src/client/types.ts +++ b/packages/better-auth/src/client/types.ts @@ -12,7 +12,7 @@ import type { } from "../types/helper"; import type { Auth } from "../auth"; import type { InferRoutes } from "./path-to-object"; -import type { Session, User } from "../types"; +import type { BetterAuthOptions, Session, User } from "../types"; import type { InferFieldsInputClient, InferFieldsOutput } from "../db"; export type AtomListener = { @@ -72,6 +72,7 @@ export interface ClientOptions { baseURL?: string; basePath?: string; disableDefaultFetchPlugins?: boolean; + $InferAuth?: BetterAuthOptions; } export type InferClientAPI = InferRoutes< diff --git a/packages/better-auth/src/client/vanilla.ts b/packages/better-auth/src/client/vanilla.ts index de490893..59283160 100644 --- a/packages/better-auth/src/client/vanilla.ts +++ b/packages/better-auth/src/client/vanilla.ts @@ -16,6 +16,7 @@ import type { BetterFetchResponse, } from "@better-fetch/fetch"; import type { BASE_ERROR_CODES } from "../error/codes"; +import type { InferRoutes } from "./path-to-object"; type InferResolvedHooks = O["plugins"] extends Array< infer Plugin @@ -89,5 +90,18 @@ export function createAuthClient