feat: hooks (#916)

This commit is contained in:
Bereket Engida
2024-12-20 21:00:14 +03:00
committed by GitHub
parent 8dffaa2cd6
commit 384ef1c344
7 changed files with 452 additions and 88 deletions

View File

@@ -233,6 +233,26 @@ export const contents: Content[] = [
</svg>
),
},
{
href: "/docs/concepts/hooks",
title: "Hooks",
icon: (props?: SVGProps<any>) => (
<svg
xmlns="http://www.w3.org/2000/svg"
width="1.2em"
height="1.2em"
viewBox="0 0 65 64"
fill="currentColor"
>
<path
fill="currentColor"
fillRule="evenodd"
d="M30.719 27.2c-6.18-3.098-8.44-7.209-6.946-12.241c1.316-4.435 5.752-7.337 10.27-6.655c2.253.34 4.235 1.263 5.78 3.023c2.337 2.667 2.865 5.78 2.151 9.246l2.135.579l3.201.868c2.273-6.234-.393-13.518-6.214-17.258c-6.061-3.893-13.914-3.071-19.062 2c-2.687 2.649-4.158 5.88-4.5 9.62c-.483 5.29 1.703 9.558 5.375 13.21l-5.84 9.793q-.283.02-.502.033c-.269.016-.48.029-.688.058c-3.839.544-6.54 3.958-5.895 7.44c.73 3.933 4.309 6.348 7.983 5.385c3.896-1.02 5.97-4.78 4.5-8.644c-.532-1.398-.203-2.294.463-3.394c1.877-3.101 3.727-6.219 5.61-9.394zm13.222 4.686l-5.647-9.96q.14-.349.272-.665v-.001c.185-.448.354-.858.495-1.277c.747-2.21.296-4.228-1.122-6.02c-1.736-2.194-4.764-2.991-7.345-2.004c-2.605.997-4.272 3.554-4.158 6.383c.115 2.86 2.034 5.414 5.008 5.929c1.78.308 2.652 1.154 3.442 2.61c1.68 3.1 3.42 6.165 5.162 9.233v.001q1.033 1.817 2.061 3.64c5.832-3.888 10.657-3.764 14.26.285c3.12 3.51 3.186 8.854.153 12.438c-3.557 4.201-8.348 4.368-13.826.82l-4.352 3.642c5.546 5.536 13.463 6.272 19.723 1.963c6.099-4.199 8.222-12.258 5.116-19.063c-2.57-5.633-9.737-10.895-19.242-7.954m-12.623 16.99H42.76q.238.321.455.63c.303.428.592.834.928 1.195c2.424 2.592 6.516 2.72 9.106.315c2.685-2.492 2.807-6.68.27-9.281c-2.483-2.547-6.725-2.79-9.03-.094c-1.4 1.639-2.835 1.831-4.694 1.802c-3.397-.052-6.795-.042-10.193-.032q-2.045.007-4.088.008c.309 6.695-2.222 10.867-7.242 11.858c-4.916.97-9.443-1.538-11.037-6.114c-1.81-5.2.428-9.359 6.898-12.66c-.487-1.763-.98-3.548-1.466-5.315C5.617 32.724.327 39.565.872 47.26c.483 6.793 5.963 12.827 12.665 13.907c3.64.588 7.06-.022 10.233-1.822c4.082-2.316 6.451-5.958 7.548-10.47"
clipRule="evenodd"
></path>
</svg>
),
},
{
href: "/docs/concepts/plugins",
title: "Plugins",

View File

@@ -0,0 +1,161 @@
---
title: Hooks
description: Better Auth Hooks let you customize BetterAuth's behavior
---
Hooks in Better Auth let you "hook into" the lifecycle and execute custom logic. They provide a way to customize Better Auth's behavior without writing a full plugin.
## Before Hooks
**Before hooks** run *before* an endpoint is executed. Use them to modify requests, pre validate data, or return early.
### Example: Enforce Email Domain Restriction
This hook ensures that users can only sign up if their email ends with `@example.com`:
```ts title="auth.ts"
import { betterAuth } from "better-auth";
import { createAuthMiddleware } from "better-auth/api";
export const auth = betterAuth({
hooks: {
before: createAuthMiddleware(async (ctx) => {
if (ctx.path !== "/sign-up/email") {
return;
}
if (!ctx.body?.email.endsWith("@example.com")) {
throw new APIError("BAD_REQUEST", {
message: "Email must end with @example.com",
});
}
}),
},
});
```
### Example: Modify Request Context
To adjust the request context before proceeding:
```ts
import { betterAuth } from "better-auth";
export const auth = betterAuth({
hooks: {
before: createAuthMiddleware(async (ctx) => {
if (ctx.path === "/sign-up/email") {
return {
...ctx,
body: {
...ctx.body,
name: "John Doe",
},
};
}
}),
},
});
```
## After Hooks
**After hooks** run *after* an endpoint is executed. Use them to modify responses.
### Example: Add a Custom Header
This hook adds a custom header to the response:
```ts
import { betterAuth } from "better-auth";
export const auth = betterAuth({
hooks: {
after: createAuthMiddleware(async (ctx) => {
ctx.response.headers.set("X-Custom-Header", "Hello World");
return {
responseHeader: ctx.responseHeader // return the updated response headers
}
}),
},
});
```
## Context Object
The `ctx` object provides:
- **Path:** `ctx.path` to get the current endpoint path.
- **Body:** `ctx.body` for parsed request body (available for POST requests).
- **Headers:** `ctx.headers` to access request headers.
- **Request:** `ctx.request` to access the request object (may not exist in server-only endpoints).
- **Query Parameters:** `ctx.query` to access query parameters.
and more.
### JSON Responses
Use `ctx.json` to send JSON responses:
```ts
const hook = createAuthMiddleware(async (ctx) => {
return ctx.json({
message: "Hello World",
});
});
```
### Redirects
Use `ctx.redirect` to redirect users:
```ts
const hook = createAuthMiddleware(async (ctx) => {
throw ctx.redirect("/sign-up/name");
});
```
### Cookies
- Set cookies: `ctx.setCookies` or `ctx.setSignedCookie`.
- Get cookies: `ctx.getCookies` or `ctx.getSignedCookies`.
Example:
```ts
const hook = createAuthMiddleware(async (ctx) => {
ctx.setCookies("my-cookie", "value");
await ctx.setSignedCookie("my-signed-cookie", "value", ctx.context.secret, {
maxAge: 1000,
});
const cookie = ctx.getCookies("my-cookie");
const signedCookie = await ctx.getSignedCookies("my-signed-cookie");
});
```
### Predefined Auth Cookies
Access BetterAuths predefined cookie properties:
```ts
const hook = createAuthMiddleware(async (ctx) => {
const cookieName = ctx.context.authCookies.sessionToken.name;
});
```
### Errors
Throw errors with `APIError` for a specific status code and message:
```ts
const hook = createAuthMiddleware(async (ctx) => {
throw new APIError("BAD_REQUEST", {
message: "Invalid request",
});
});
```
## Reusable Hooks
If you need to reuse a hook across multiple endpoints, consider creating a plugin. Learn more in the [Plugins Documentation](/docs/concepts/plugins).

View File

@@ -10,13 +10,14 @@ import { init } from "../init";
import type { BetterAuthOptions, BetterAuthPlugin } from "../types";
import { z } from "zod";
import { createAuthClient } from "../client";
import { convertSetCookieToCookie } from "../test-utils/headers";
describe("call", async () => {
const q = z.optional(
z.object({
testBeforeHook: z.string().optional(),
testBeforeGlobal: z.string().optional(),
testAfterHook: z.string().optional(),
testAfterGlobal: z.string().optional(),
testContext: z.string().optional(),
message: z.string().optional(),
}),
@@ -180,6 +181,28 @@ describe("call", async () => {
emailAndPassword: {
enabled: true,
},
hooks: {
before: createAuthMiddleware(async (ctx) => {
if (ctx.path === "/sign-up/email") {
return {
context: {
body: {
...ctx.body,
email: "changed@email.com",
},
},
};
}
if (ctx.query?.testBeforeGlobal) {
return ctx.json({ before: "global" });
}
}),
after: createAuthMiddleware(async (ctx) => {
if (ctx.query?.testAfterGlobal) {
return ctx.json({ after: "global" });
}
}),
},
} satisfies BetterAuthOptions;
const authContext = init(options);
const { api } = getEndpoints(authContext, options);
@@ -343,6 +366,39 @@ describe("call", async () => {
});
});
it("should intercept on global before hook", async () => {
const response = await api.test({
query: {
testBeforeGlobal: "true",
},
});
expect(response).toMatchObject({
before: "global",
});
});
it("should intercept on global after hook", async () => {
const response = await api.test({
query: {
testAfterGlobal: "true",
},
});
expect(response).toMatchObject({
after: "global",
});
});
it("global before hook should change the context", async (ctx) => {
const response = await api.signUpEmail({
body: {
email: "my-email@test.com",
password: "password",
name: "test",
},
});
expect(response.email).toBe("changed@email.com");
});
it("should fetch using a client", async () => {
const response = await client.$fetch("/test");
expect(response.data).toMatchObject({

View File

@@ -195,14 +195,43 @@ export function getEndpoints<
};
const plugins = options.plugins || [];
for (const plugin of plugins) {
const beforeHooks = plugin.hooks?.before ?? [];
const beforeHooks = plugins
.map((plugin) => {
if (plugin.hooks?.before) {
return plugin.hooks.before;
}
})
.filter((plugin) => plugin !== undefined)
.flat();
const afterHooks = plugins
.map((plugin) => {
if (plugin.hooks?.after) {
return plugin.hooks.after;
}
})
.filter((plugin) => plugin !== undefined)
.flat();
if (options.hooks?.before) {
beforeHooks.push({
matcher: () => true,
handler: options.hooks.before,
});
}
if (options.hooks?.after) {
afterHooks.push({
matcher: () => true,
handler: options.hooks.after,
});
}
for (const hook of beforeHooks) {
if (!hook.matcher(internalContext)) continue;
const hookRes = await hook.handler(internalContext);
if (hookRes && "context" in hookRes) {
// modify the context with the response from the hook
internalContext = defu(internalContext, hookRes.context);
internalContext = {
...internalContext,
...hookRes.context,
};
continue;
}
@@ -211,7 +240,6 @@ export function getEndpoints<
return hookRes;
}
}
}
let endpointRes: any;
try {
@@ -225,25 +253,16 @@ export function getEndpoints<
internalContext.context.newSession = newSession;
}
if (e instanceof APIError) {
const afterPlugins = options.plugins
?.map((plugin) => {
if (plugin.hooks?.after) {
return plugin.hooks.after;
}
})
.filter((plugin) => plugin !== undefined)
.flat();
/**
* If there are no after plugins, we can directly throw the error
*/
if (!afterPlugins?.length) {
if (!afterHooks?.length) {
e.headers = endpoint.headers;
throw e;
}
internalContext.context.returned = e;
internalContext.context.returned.headers = endpoint.headers;
for (const hook of afterPlugins || []) {
for (const hook of afterHooks || []) {
const match = hook.matcher(internalContext);
if (match) {
try {
@@ -272,9 +291,7 @@ export function getEndpoints<
}
internalContext.context.returned = endpointRes;
internalContext.responseHeader = endpoint.headers;
for (const plugin of options.plugins || []) {
if (plugin.hooks?.after) {
for (const hook of plugin.hooks.after) {
for (const hook of afterHooks) {
const match = hook.matcher(internalContext);
if (match) {
try {
@@ -296,8 +313,6 @@ export function getEndpoints<
}
}
}
}
}
const response = internalContext.context.returned;
if (response instanceof Response) {
endpoint.headers.forEach((value, key) => {

View File

@@ -1,4 +1,5 @@
import type { BetterAuthOptions } from ".";
import type { GenericEndpointContext } from "./context";
import type { BetterAuthOptions } from "./options";
/**
* Adapter where clause

View File

@@ -1,8 +1,12 @@
import type { Dialect, Kysely, MysqlPool, PostgresPool } from "kysely";
import type { Account, Session, User, Verification } from "../db/schema";
import type { BetterAuthPlugin } from ".";
import type {
BetterAuthPlugin,
HookAfterHandler,
HookBeforeHandler,
} from "./plugins";
import type { SocialProviderList, SocialProviders } from "../social-providers";
import type { AdapterInstance, SecondaryStorage } from ".";
import type { AdapterInstance, SecondaryStorage, Where } from "./adapter";
import type { KyselyDatabaseType } from "../adapters/kysely-adapter/types";
import type { FieldAttribute } from "../db";
import type { Models, RateLimit } from "./models";
@@ -586,8 +590,11 @@ export type BetterAuthOptions = {
* operations.
*/
databaseHooks?: {
/**
* User hooks
*/
user?: {
[key in "create" | "update"]?: {
create?: {
/**
* Hook that is called before a user is created.
* if the hook returns false, the user will not be created.
@@ -605,12 +612,33 @@ export type BetterAuthOptions = {
*/
after?: (user: User) => Promise<void>;
};
};
session?: {
[key in "create" | "update"]?: {
update?: {
/**
* Hook that is called before a user is created.
* if the hook returns false, the user will not be created.
* Hook that is called before a user is updated.
* if the hook returns false, the user will not be updated.
* If the hook returns an object, it'll be used instead of the original data
*/
before?: (user: Partial<User>) => Promise<
| boolean
| void
| {
data: User & Record<string, any>;
}
>;
/**
* Hook that is called after a user is updated.
*/
after?: (user: User) => Promise<void>;
};
};
/**
* Session Hook
*/
session?: {
create?: {
/**
* Hook that is called before a session is updated.
* if the hook returns false, the session will not be updated.
* If the hook returns an object, it'll be used instead of the original data
*/
before?: (session: Session) => Promise<
@@ -621,16 +649,40 @@ export type BetterAuthOptions = {
}
>;
/**
* Hook that is called after a user is created.
* Hook that is called after a session is updated.
*/
after?: (session: Session) => Promise<void>;
};
/**
* Update hook
*/
update?: {
/**
* Hook that is called before a user is updated.
* if the hook returns false, the session will not be updated.
* If the hook returns an object, it'll be used instead of the original data
*/
before?: (session: Partial<Session>) => Promise<
| boolean
| void
| {
data: Session & Record<string, any>;
}
>;
/**
* Hook that is called after a session is updated.
*/
after?: (session: Session) => Promise<void>;
};
};
account?: {
[key in "create" | "update"]?: {
/**
* Hook that is called before a user is created.
* If the hook returns false, the user will not be created.
* Account Hook
*/
account?: {
create?: {
/**
* Hook that is called before a account is created.
* If the hook returns false, the account will not be created.
* If the hook returns an object, it'll be used instead of the original data
*/
before?: (account: Account) => Promise<
@@ -641,16 +693,40 @@ export type BetterAuthOptions = {
}
>;
/**
* Hook that is called after a user is created.
* Hook that is called after a account is created.
*/
after?: (account: Account) => Promise<void>;
};
/**
* Update hook
*/
update?: {
/**
* Hook that is called before a account is update.
* If the hook returns false, the user will not be updated.
* If the hook returns an object, it'll be used instead of the original data
*/
before?: (account: Partial<Account>) => Promise<
| boolean
| void
| {
data: Account & Record<string, any>;
}
>;
/**
* Hook that is called after a account is updated.
*/
after?: (account: Account) => Promise<void>;
};
};
verification?: {
[key in "create" | "update"]: {
/**
* Hook that is called before a user is created.
* if the hook returns false, the user will not be created.
* Verification Hook
*/
verification?: {
create?: {
/**
* Hook that is called before a verification is created.
* if the hook returns false, the verification will not be created.
* If the hook returns an object, it'll be used instead of the original data
*/
before?: (verification: Verification) => Promise<
@@ -661,7 +737,25 @@ export type BetterAuthOptions = {
}
>;
/**
* Hook that is called after a user is created.
* Hook that is called after a verification is created.
*/
after?: (verification: Verification) => Promise<void>;
};
update?: {
/**
* Hook that is called before a verification is updated.
* if the hook returns false, the verification will not be updated.
* If the hook returns an object, it'll be used instead of the original data
*/
before?: (verification: Partial<Verification>) => Promise<
| boolean
| void
| {
data: Verification & Record<string, any>;
}
>;
/**
* Hook that is called after a verification is updated.
*/
after?: (verification: Verification) => Promise<void>;
};
@@ -685,4 +779,17 @@ export type BetterAuthOptions = {
*/
onError?: (error: unknown, ctx: AuthContext) => void | Promise<void>;
};
/**
* Hooks
*/
hooks?: {
/**
* Before a request is processed
*/
before?: HookBeforeHandler;
/**
* After a request is processed
*/
after?: HookAfterHandler;
};
};

View File

@@ -17,6 +17,32 @@ export type PluginSchema = {
};
};
export type HookBeforeHandler = (context: HookEndpointContext) => Promise<
| void
| {
context?: Partial<HookEndpointContext>;
}
| Response
| {
response: Record<string, any>;
body: any;
_flag: "json";
}
>;
export type HookAfterHandler = (context: HookEndpointContext) => Promise<
| void
| {
responseHeader?: Headers;
}
| Response
| {
response: Record<string, any>;
body: any;
_flag: "json";
}
>;
export type BetterAuthPlugin = {
id: LiteralString;
/**
@@ -55,18 +81,7 @@ export type BetterAuthPlugin = {
hooks?: {
before?: {
matcher: (context: HookEndpointContext) => boolean;
handler: (context: HookEndpointContext) => Promise<
| void
| {
context?: Partial<HookEndpointContext>;
}
| Response
| {
response: Record<string, any>;
body: any;
_flag: "json";
}
>;
handler: HookBeforeHandler;
}[];
after?: {
matcher: (
@@ -75,18 +90,7 @@ export type BetterAuthPlugin = {
endpoint: Endpoint;
}>,
) => boolean;
handler: (context: HookEndpointContext) => Promise<
| void
| {
responseHeader?: Headers;
}
| Response
| {
response: Record<string, any>;
body: any;
_flag: "json";
}
>;
handler: HookAfterHandler;
}[];
};
/**