From 384ef1c3449192e49c5ea1fd001587029fae6380 Mon Sep 17 00:00:00 2001 From: Bereket Engida <86073083+Bekacru@users.noreply.github.com> Date: Fri, 20 Dec 2024 21:00:14 +0300 Subject: [PATCH] feat: hooks (#916) --- docs/components/sidebar-content.tsx | 20 +++ docs/content/docs/concepts/hooks.mdx | 161 ++++++++++++++++++++++ packages/better-auth/src/api/call.test.ts | 58 +++++++- packages/better-auth/src/api/index.ts | 105 ++++++++------ packages/better-auth/src/types/adapter.ts | 3 +- packages/better-auth/src/types/options.ts | 141 ++++++++++++++++--- packages/better-auth/src/types/plugins.ts | 52 +++---- 7 files changed, 452 insertions(+), 88 deletions(-) create mode 100644 docs/content/docs/concepts/hooks.mdx diff --git a/docs/components/sidebar-content.tsx b/docs/components/sidebar-content.tsx index 2cfa5719..6c9b6493 100644 --- a/docs/components/sidebar-content.tsx +++ b/docs/components/sidebar-content.tsx @@ -233,6 +233,26 @@ export const contents: Content[] = [ ), }, + { + href: "/docs/concepts/hooks", + title: "Hooks", + icon: (props?: SVGProps) => ( + + + + ), + }, { href: "/docs/concepts/plugins", title: "Plugins", diff --git a/docs/content/docs/concepts/hooks.mdx b/docs/content/docs/concepts/hooks.mdx new file mode 100644 index 00000000..71db2422 --- /dev/null +++ b/docs/content/docs/concepts/hooks.mdx @@ -0,0 +1,161 @@ +--- +title: Hooks +description: Better Auth Hooks let you customize BetterAuth's behavior +--- + +Hooks in Better Auth let you "hook into" the lifecycle and execute custom logic. They provide a way to customize Better Auth's behavior without writing a full plugin. + +## Before Hooks + +**Before hooks** run *before* an endpoint is executed. Use them to modify requests, pre validate data, or return early. + +### Example: Enforce Email Domain Restriction + +This hook ensures that users can only sign up if their email ends with `@example.com`: + +```ts title="auth.ts" +import { betterAuth } from "better-auth"; +import { createAuthMiddleware } from "better-auth/api"; + +export const auth = betterAuth({ + hooks: { + before: createAuthMiddleware(async (ctx) => { + if (ctx.path !== "/sign-up/email") { + return; + } + if (!ctx.body?.email.endsWith("@example.com")) { + throw new APIError("BAD_REQUEST", { + message: "Email must end with @example.com", + }); + } + }), + }, +}); +``` + +### Example: Modify Request Context + +To adjust the request context before proceeding: + +```ts +import { betterAuth } from "better-auth"; + +export const auth = betterAuth({ + hooks: { + before: createAuthMiddleware(async (ctx) => { + if (ctx.path === "/sign-up/email") { + return { + ...ctx, + body: { + ...ctx.body, + name: "John Doe", + }, + }; + } + }), + }, +}); +``` + +## After Hooks + +**After hooks** run *after* an endpoint is executed. Use them to modify responses. + +### Example: Add a Custom Header + +This hook adds a custom header to the response: + +```ts +import { betterAuth } from "better-auth"; + +export const auth = betterAuth({ + hooks: { + after: createAuthMiddleware(async (ctx) => { + ctx.response.headers.set("X-Custom-Header", "Hello World"); + return { + responseHeader: ctx.responseHeader // return the updated response headers + } + }), + }, +}); +``` + +## Context Object + +The `ctx` object provides: + +- **Path:** `ctx.path` to get the current endpoint path. +- **Body:** `ctx.body` for parsed request body (available for POST requests). +- **Headers:** `ctx.headers` to access request headers. +- **Request:** `ctx.request` to access the request object (may not exist in server-only endpoints). +- **Query Parameters:** `ctx.query` to access query parameters. + +and more. + +### JSON Responses + +Use `ctx.json` to send JSON responses: + +```ts +const hook = createAuthMiddleware(async (ctx) => { + return ctx.json({ + message: "Hello World", + }); +}); +``` + +### Redirects + +Use `ctx.redirect` to redirect users: + +```ts +const hook = createAuthMiddleware(async (ctx) => { + throw ctx.redirect("/sign-up/name"); +}); +``` + +### Cookies + +- Set cookies: `ctx.setCookies` or `ctx.setSignedCookie`. +- Get cookies: `ctx.getCookies` or `ctx.getSignedCookies`. + +Example: + +```ts +const hook = createAuthMiddleware(async (ctx) => { + ctx.setCookies("my-cookie", "value"); + await ctx.setSignedCookie("my-signed-cookie", "value", ctx.context.secret, { + maxAge: 1000, + }); + + const cookie = ctx.getCookies("my-cookie"); + const signedCookie = await ctx.getSignedCookies("my-signed-cookie"); +}); +``` + +### Predefined Auth Cookies + +Access BetterAuth’s predefined cookie properties: + +```ts +const hook = createAuthMiddleware(async (ctx) => { + const cookieName = ctx.context.authCookies.sessionToken.name; +}); +``` + +### Errors + +Throw errors with `APIError` for a specific status code and message: + +```ts +const hook = createAuthMiddleware(async (ctx) => { + throw new APIError("BAD_REQUEST", { + message: "Invalid request", + }); +}); +``` + +## Reusable Hooks + +If you need to reuse a hook across multiple endpoints, consider creating a plugin. Learn more in the [Plugins Documentation](/docs/concepts/plugins). + diff --git a/packages/better-auth/src/api/call.test.ts b/packages/better-auth/src/api/call.test.ts index 734acc7c..4ce737d8 100644 --- a/packages/better-auth/src/api/call.test.ts +++ b/packages/better-auth/src/api/call.test.ts @@ -10,13 +10,14 @@ import { init } from "../init"; import type { BetterAuthOptions, BetterAuthPlugin } from "../types"; import { z } from "zod"; import { createAuthClient } from "../client"; -import { convertSetCookieToCookie } from "../test-utils/headers"; describe("call", async () => { const q = z.optional( z.object({ testBeforeHook: z.string().optional(), + testBeforeGlobal: z.string().optional(), testAfterHook: z.string().optional(), + testAfterGlobal: z.string().optional(), testContext: z.string().optional(), message: z.string().optional(), }), @@ -180,6 +181,28 @@ describe("call", async () => { emailAndPassword: { enabled: true, }, + hooks: { + before: createAuthMiddleware(async (ctx) => { + if (ctx.path === "/sign-up/email") { + return { + context: { + body: { + ...ctx.body, + email: "changed@email.com", + }, + }, + }; + } + if (ctx.query?.testBeforeGlobal) { + return ctx.json({ before: "global" }); + } + }), + after: createAuthMiddleware(async (ctx) => { + if (ctx.query?.testAfterGlobal) { + return ctx.json({ after: "global" }); + } + }), + }, } satisfies BetterAuthOptions; const authContext = init(options); const { api } = getEndpoints(authContext, options); @@ -343,6 +366,39 @@ describe("call", async () => { }); }); + it("should intercept on global before hook", async () => { + const response = await api.test({ + query: { + testBeforeGlobal: "true", + }, + }); + expect(response).toMatchObject({ + before: "global", + }); + }); + + it("should intercept on global after hook", async () => { + const response = await api.test({ + query: { + testAfterGlobal: "true", + }, + }); + expect(response).toMatchObject({ + after: "global", + }); + }); + + it("global before hook should change the context", async (ctx) => { + const response = await api.signUpEmail({ + body: { + email: "my-email@test.com", + password: "password", + name: "test", + }, + }); + expect(response.email).toBe("changed@email.com"); + }); + it("should fetch using a client", async () => { const response = await client.$fetch("/test"); expect(response.data).toMatchObject({ diff --git a/packages/better-auth/src/api/index.ts b/packages/better-auth/src/api/index.ts index 76759a11..851684a7 100644 --- a/packages/better-auth/src/api/index.ts +++ b/packages/better-auth/src/api/index.ts @@ -195,21 +195,49 @@ export function getEndpoints< }; const plugins = options.plugins || []; - for (const plugin of plugins) { - const beforeHooks = plugin.hooks?.before ?? []; - for (const hook of beforeHooks) { - if (!hook.matcher(internalContext)) continue; - const hookRes = await hook.handler(internalContext); - if (hookRes && "context" in hookRes) { - // modify the context with the response from the hook - internalContext = defu(internalContext, hookRes.context); - continue; + const beforeHooks = plugins + .map((plugin) => { + if (plugin.hooks?.before) { + return plugin.hooks.before; } + }) + .filter((plugin) => plugin !== undefined) + .flat(); + const afterHooks = plugins + .map((plugin) => { + if (plugin.hooks?.after) { + return plugin.hooks.after; + } + }) + .filter((plugin) => plugin !== undefined) + .flat(); + if (options.hooks?.before) { + beforeHooks.push({ + matcher: () => true, + handler: options.hooks.before, + }); + } + if (options.hooks?.after) { + afterHooks.push({ + matcher: () => true, + handler: options.hooks.after, + }); + } + for (const hook of beforeHooks) { + if (!hook.matcher(internalContext)) continue; + const hookRes = await hook.handler(internalContext); + if (hookRes && "context" in hookRes) { + // modify the context with the response from the hook + internalContext = { + ...internalContext, + ...hookRes.context, + }; + continue; + } - if (hookRes) { - // return with the response from the hook - return hookRes; - } + if (hookRes) { + // return with the response from the hook + return hookRes; } } @@ -225,25 +253,16 @@ export function getEndpoints< internalContext.context.newSession = newSession; } if (e instanceof APIError) { - const afterPlugins = options.plugins - ?.map((plugin) => { - if (plugin.hooks?.after) { - return plugin.hooks.after; - } - }) - .filter((plugin) => plugin !== undefined) - .flat(); - /** * If there are no after plugins, we can directly throw the error */ - if (!afterPlugins?.length) { + if (!afterHooks?.length) { e.headers = endpoint.headers; throw e; } internalContext.context.returned = e; internalContext.context.returned.headers = endpoint.headers; - for (const hook of afterPlugins || []) { + for (const hook of afterHooks || []) { const match = hook.matcher(internalContext); if (match) { try { @@ -272,29 +291,25 @@ export function getEndpoints< } internalContext.context.returned = endpointRes; internalContext.responseHeader = endpoint.headers; - for (const plugin of options.plugins || []) { - if (plugin.hooks?.after) { - for (const hook of plugin.hooks.after) { - const match = hook.matcher(internalContext); - if (match) { - try { - const hookRes = await hook.handler(internalContext); - if (hookRes) { - if ("responseHeader" in hookRes) { - const headers = hookRes.responseHeader as Headers; - internalContext.responseHeader = headers; - } else { - internalContext.context.returned = hookRes; - } - } - } catch (e) { - if (e instanceof APIError) { - internalContext.context.returned = e; - continue; - } - throw e; + for (const hook of afterHooks) { + const match = hook.matcher(internalContext); + if (match) { + try { + const hookRes = await hook.handler(internalContext); + if (hookRes) { + if ("responseHeader" in hookRes) { + const headers = hookRes.responseHeader as Headers; + internalContext.responseHeader = headers; + } else { + internalContext.context.returned = hookRes; } } + } catch (e) { + if (e instanceof APIError) { + internalContext.context.returned = e; + continue; + } + throw e; } } } diff --git a/packages/better-auth/src/types/adapter.ts b/packages/better-auth/src/types/adapter.ts index d65edc81..33692524 100644 --- a/packages/better-auth/src/types/adapter.ts +++ b/packages/better-auth/src/types/adapter.ts @@ -1,4 +1,5 @@ -import type { BetterAuthOptions } from "."; +import type { GenericEndpointContext } from "./context"; +import type { BetterAuthOptions } from "./options"; /** * Adapter where clause diff --git a/packages/better-auth/src/types/options.ts b/packages/better-auth/src/types/options.ts index 84444c04..77d4050f 100644 --- a/packages/better-auth/src/types/options.ts +++ b/packages/better-auth/src/types/options.ts @@ -1,8 +1,12 @@ import type { Dialect, Kysely, MysqlPool, PostgresPool } from "kysely"; import type { Account, Session, User, Verification } from "../db/schema"; -import type { BetterAuthPlugin } from "."; +import type { + BetterAuthPlugin, + HookAfterHandler, + HookBeforeHandler, +} from "./plugins"; import type { SocialProviderList, SocialProviders } from "../social-providers"; -import type { AdapterInstance, SecondaryStorage } from "."; +import type { AdapterInstance, SecondaryStorage, Where } from "./adapter"; import type { KyselyDatabaseType } from "../adapters/kysely-adapter/types"; import type { FieldAttribute } from "../db"; import type { Models, RateLimit } from "./models"; @@ -586,8 +590,11 @@ export type BetterAuthOptions = { * operations. */ databaseHooks?: { + /** + * User hooks + */ user?: { - [key in "create" | "update"]?: { + create?: { /** * Hook that is called before a user is created. * if the hook returns false, the user will not be created. @@ -605,12 +612,33 @@ export type BetterAuthOptions = { */ after?: (user: User) => Promise; }; - }; - session?: { - [key in "create" | "update"]?: { + update?: { /** - * Hook that is called before a user is created. - * if the hook returns false, the user will not be created. + * Hook that is called before a user is updated. + * if the hook returns false, the user will not be updated. + * If the hook returns an object, it'll be used instead of the original data + */ + before?: (user: Partial) => Promise< + | boolean + | void + | { + data: User & Record; + } + >; + /** + * Hook that is called after a user is updated. + */ + after?: (user: User) => Promise; + }; + }; + /** + * Session Hook + */ + session?: { + create?: { + /** + * Hook that is called before a session is updated. + * if the hook returns false, the session will not be updated. * If the hook returns an object, it'll be used instead of the original data */ before?: (session: Session) => Promise< @@ -621,16 +649,40 @@ export type BetterAuthOptions = { } >; /** - * Hook that is called after a user is created. + * Hook that is called after a session is updated. + */ + after?: (session: Session) => Promise; + }; + /** + * Update hook + */ + update?: { + /** + * Hook that is called before a user is updated. + * if the hook returns false, the session will not be updated. + * If the hook returns an object, it'll be used instead of the original data + */ + before?: (session: Partial) => Promise< + | boolean + | void + | { + data: Session & Record; + } + >; + /** + * Hook that is called after a session is updated. */ after?: (session: Session) => Promise; }; }; + /** + * Account Hook + */ account?: { - [key in "create" | "update"]?: { + create?: { /** - * Hook that is called before a user is created. - * If the hook returns false, the user will not be created. + * Hook that is called before a account is created. + * If the hook returns false, the account will not be created. * If the hook returns an object, it'll be used instead of the original data */ before?: (account: Account) => Promise< @@ -641,16 +693,40 @@ export type BetterAuthOptions = { } >; /** - * Hook that is called after a user is created. + * Hook that is called after a account is created. + */ + after?: (account: Account) => Promise; + }; + /** + * Update hook + */ + update?: { + /** + * Hook that is called before a account is update. + * If the hook returns false, the user will not be updated. + * If the hook returns an object, it'll be used instead of the original data + */ + before?: (account: Partial) => Promise< + | boolean + | void + | { + data: Account & Record; + } + >; + /** + * Hook that is called after a account is updated. */ after?: (account: Account) => Promise; }; }; + /** + * Verification Hook + */ verification?: { - [key in "create" | "update"]: { + create?: { /** - * Hook that is called before a user is created. - * if the hook returns false, the user will not be created. + * Hook that is called before a verification is created. + * if the hook returns false, the verification will not be created. * If the hook returns an object, it'll be used instead of the original data */ before?: (verification: Verification) => Promise< @@ -661,7 +737,25 @@ export type BetterAuthOptions = { } >; /** - * Hook that is called after a user is created. + * Hook that is called after a verification is created. + */ + after?: (verification: Verification) => Promise; + }; + update?: { + /** + * Hook that is called before a verification is updated. + * if the hook returns false, the verification will not be updated. + * If the hook returns an object, it'll be used instead of the original data + */ + before?: (verification: Partial) => Promise< + | boolean + | void + | { + data: Verification & Record; + } + >; + /** + * Hook that is called after a verification is updated. */ after?: (verification: Verification) => Promise; }; @@ -685,4 +779,17 @@ export type BetterAuthOptions = { */ onError?: (error: unknown, ctx: AuthContext) => void | Promise; }; + /** + * Hooks + */ + hooks?: { + /** + * Before a request is processed + */ + before?: HookBeforeHandler; + /** + * After a request is processed + */ + after?: HookAfterHandler; + }; }; diff --git a/packages/better-auth/src/types/plugins.ts b/packages/better-auth/src/types/plugins.ts index dd892ffc..7a2695f5 100644 --- a/packages/better-auth/src/types/plugins.ts +++ b/packages/better-auth/src/types/plugins.ts @@ -17,6 +17,32 @@ export type PluginSchema = { }; }; +export type HookBeforeHandler = (context: HookEndpointContext) => Promise< + | void + | { + context?: Partial; + } + | Response + | { + response: Record; + body: any; + _flag: "json"; + } +>; + +export type HookAfterHandler = (context: HookEndpointContext) => Promise< + | void + | { + responseHeader?: Headers; + } + | Response + | { + response: Record; + body: any; + _flag: "json"; + } +>; + export type BetterAuthPlugin = { id: LiteralString; /** @@ -55,18 +81,7 @@ export type BetterAuthPlugin = { hooks?: { before?: { matcher: (context: HookEndpointContext) => boolean; - handler: (context: HookEndpointContext) => Promise< - | void - | { - context?: Partial; - } - | Response - | { - response: Record; - body: any; - _flag: "json"; - } - >; + handler: HookBeforeHandler; }[]; after?: { matcher: ( @@ -75,18 +90,7 @@ export type BetterAuthPlugin = { endpoint: Endpoint; }>, ) => boolean; - handler: (context: HookEndpointContext) => Promise< - | void - | { - responseHeader?: Headers; - } - | Response - | { - response: Record; - body: any; - _flag: "json"; - } - >; + handler: HookAfterHandler; }[]; }; /**