mirror of
https://github.com/LukeHagar/better-auth.git
synced 2025-12-06 20:37:44 +00:00
feat: rate limiter
This commit is contained in:
@@ -315,9 +315,6 @@ function ChangePassword() {
|
|||||||
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24"><path fill="currentColor" d="M2.5 18.5v-1h19v1zm.535-5.973l-.762-.442l.965-1.693h-1.93v-.884h1.93l-.965-1.642l.762-.443L4 9.066l.966-1.643l.761.443l-.965 1.642h1.93v.884h-1.93l.965 1.693l-.762.442L4 10.835zm8 0l-.762-.442l.966-1.693H9.308v-.884h1.93l-.965-1.642l.762-.443L12 9.066l.966-1.643l.761.443l-.965 1.642h1.93v.884h-1.93l.965 1.693l-.762.442L12 10.835zm8 0l-.762-.442l.966-1.693h-1.931v-.884h1.93l-.965-1.642l.762-.443L20 9.066l.966-1.643l.761.443l-.965 1.642h1.93v.884h-1.93l.965 1.693l-.762.442L20 10.835z"></path></svg>
|
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24"><path fill="currentColor" d="M2.5 18.5v-1h19v1zm.535-5.973l-.762-.442l.965-1.693h-1.93v-.884h1.93l-.965-1.642l.762-.443L4 9.066l.966-1.643l.761.443l-.965 1.642h1.93v.884h-1.93l.965 1.693l-.762.442L4 10.835zm8 0l-.762-.442l.966-1.693H9.308v-.884h1.93l-.965-1.642l.762-.443L12 9.066l.966-1.643l.761.443l-.965 1.642h1.93v.884h-1.93l.965 1.693l-.762.442L12 10.835zm8 0l-.762-.442l.966-1.693h-1.931v-.884h1.93l-.965-1.642l.762-.443L20 9.066l.966-1.643l.761.443l-.965 1.642h1.93v.884h-1.93l.965 1.693l-.762.442L20 10.835z"></path></svg>
|
||||||
<span
|
<span
|
||||||
className="text-sm text-muted-foreground"
|
className="text-sm text-muted-foreground"
|
||||||
onClick={async () => {
|
|
||||||
|
|
||||||
}}
|
|
||||||
>
|
>
|
||||||
Change Password
|
Change Password
|
||||||
</span>
|
</span>
|
||||||
@@ -376,7 +373,7 @@ function ChangePassword() {
|
|||||||
})
|
})
|
||||||
setLoading(false);
|
setLoading(false);
|
||||||
if (res.error) {
|
if (res.error) {
|
||||||
toast.error(res.error.message);
|
toast.error(res.error.message || "Couldn't change your password! Make sure it's correct");
|
||||||
} else {
|
} else {
|
||||||
setOpen(false);
|
setOpen(false);
|
||||||
toast.success("Password changed successfully");
|
toast.success("Password changed successfully");
|
||||||
|
|||||||
@@ -32,10 +32,15 @@ export async function SignInButton() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
function checkOptimisticSession(headers: Headers) {
|
||||||
|
const guessIsSignIn = headers.get("cookie")?.includes("better-auth.session") || headers.get("cookie")?.includes("__Secure-better-auth.session-token")
|
||||||
|
return !!guessIsSignIn;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
export function SignInFallback() {
|
export function SignInFallback() {
|
||||||
//to avoid flash of unauthenticated state
|
//to avoid flash of unauthenticated state
|
||||||
const guessIsSignIn = headers().get("cookie")?.includes("better-auth.session") || headers().get("cookie")?.includes("__Secure-better-auth.session-token")
|
const guessIsSignIn = checkOptimisticSession(headers())
|
||||||
return (
|
return (
|
||||||
<Link href={
|
<Link href={
|
||||||
guessIsSignIn ? "/dashboard" : "/sign-in"
|
guessIsSignIn ? "/dashboard" : "/sign-in"
|
||||||
|
|||||||
@@ -63,11 +63,8 @@
|
|||||||
"vue": "^3.5.0"
|
"vue": "^3.5.0"
|
||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@babel/preset-typescript": "^7.24.7",
|
|
||||||
"c12": "^1.11.2",
|
|
||||||
"chalk": "^5.3.0",
|
|
||||||
"commander": "^12.1.0",
|
|
||||||
"@babel/preset-react": "^7.24.7",
|
"@babel/preset-react": "^7.24.7",
|
||||||
|
"@babel/preset-typescript": "^7.24.7",
|
||||||
"@better-fetch/fetch": "^1.1.8",
|
"@better-fetch/fetch": "^1.1.8",
|
||||||
"@better-fetch/logger": "^1.1.3",
|
"@better-fetch/logger": "^1.1.3",
|
||||||
"@chronark/access-policies": "^0.0.2",
|
"@chronark/access-policies": "^0.0.2",
|
||||||
@@ -82,8 +79,11 @@
|
|||||||
"@simplewebauthn/server": "^10.0.1",
|
"@simplewebauthn/server": "^10.0.1",
|
||||||
"arctic": "2.0.0-next.5",
|
"arctic": "2.0.0-next.5",
|
||||||
"argon2": "^0.31.2",
|
"argon2": "^0.31.2",
|
||||||
"better-call": "0.2.3-beta.1",
|
"better-call": "0.2.3-beta.3",
|
||||||
"better-sqlite3": "^11.1.2",
|
"better-sqlite3": "^11.1.2",
|
||||||
|
"c12": "^1.11.2",
|
||||||
|
"chalk": "^5.3.0",
|
||||||
|
"commander": "^12.1.0",
|
||||||
"consola": "^3.2.3",
|
"consola": "^3.2.3",
|
||||||
"defu": "^6.1.4",
|
"defu": "^6.1.4",
|
||||||
"dotenv": "^16.4.5",
|
"dotenv": "^16.4.5",
|
||||||
|
|||||||
@@ -3,8 +3,39 @@ import { createAuthEndpoint } from "../call";
|
|||||||
import { getDate } from "../../utils/date";
|
import { getDate } from "../../utils/date";
|
||||||
import { deleteSessionCookie, setSessionCookie } from "../../utils/cookies";
|
import { deleteSessionCookie, setSessionCookie } from "../../utils/cookies";
|
||||||
import { sessionMiddleware } from "../middlewares/session";
|
import { sessionMiddleware } from "../middlewares/session";
|
||||||
import type { Session } from "../../adapters/schema";
|
import type { Session, User } from "../../adapters/schema";
|
||||||
import { z } from "zod";
|
import { z } from "zod";
|
||||||
|
import { getIp } from "../../utils/get-request-ip";
|
||||||
|
|
||||||
|
const sessionCache = new Map<
|
||||||
|
string,
|
||||||
|
{
|
||||||
|
data: {
|
||||||
|
session: Session;
|
||||||
|
user: User;
|
||||||
|
};
|
||||||
|
expiresAt: number;
|
||||||
|
}
|
||||||
|
>();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate a unique key for the request to cache the
|
||||||
|
* request for 5 seconds for this specific request.
|
||||||
|
*
|
||||||
|
* This is to prevent reaching to database if getSession is
|
||||||
|
* called multiple times for the same request
|
||||||
|
*/
|
||||||
|
function getRequestUniqueKey(ctx: Context<any, any>, token: string): string {
|
||||||
|
if (!ctx.request) {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
const { method, url, headers } = ctx.request;
|
||||||
|
const userAgent = ctx.request.headers.get("User-Agent") || "";
|
||||||
|
const ip = getIp(ctx.request) || "";
|
||||||
|
const headerString = JSON.stringify(headers);
|
||||||
|
const uniqueString = `${method}:${url}:${headerString}:${userAgent}:${ip}:${token}`;
|
||||||
|
return uniqueString;
|
||||||
|
}
|
||||||
|
|
||||||
export const getSession = createAuthEndpoint(
|
export const getSession = createAuthEndpoint(
|
||||||
"/session",
|
"/session",
|
||||||
@@ -24,6 +55,16 @@ export const getSession = createAuthEndpoint(
|
|||||||
status: 401,
|
status: 401,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const key = getRequestUniqueKey(ctx, sessionCookieToken);
|
||||||
|
const cachedSession = sessionCache.get(key);
|
||||||
|
if (cachedSession) {
|
||||||
|
if (cachedSession.expiresAt > Date.now()) {
|
||||||
|
return ctx.json(cachedSession.data);
|
||||||
|
}
|
||||||
|
sessionCache.delete(key);
|
||||||
|
}
|
||||||
|
|
||||||
const session =
|
const session =
|
||||||
await ctx.context.internalAdapter.findSession(sessionCookieToken);
|
await ctx.context.internalAdapter.findSession(sessionCookieToken);
|
||||||
|
|
||||||
@@ -89,6 +130,10 @@ export const getSession = createAuthEndpoint(
|
|||||||
user: session.user,
|
user: session.user,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
sessionCache.set(key, {
|
||||||
|
data: session,
|
||||||
|
expiresAt: Date.now() + 5000,
|
||||||
|
});
|
||||||
return ctx.json(session);
|
return ctx.json(session);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
ctx.context.logger.error(error);
|
ctx.context.logger.error(error);
|
||||||
|
|||||||
@@ -95,6 +95,11 @@ describe("type", () => {
|
|||||||
const client = createReactClient({
|
const client = createReactClient({
|
||||||
plugins: [testClientPlugin()],
|
plugins: [testClientPlugin()],
|
||||||
baseURL: "http://localhost:3000",
|
baseURL: "http://localhost:3000",
|
||||||
|
fetchOptions: {
|
||||||
|
customFetchImpl: async (url, init) => {
|
||||||
|
return new Response();
|
||||||
|
},
|
||||||
|
},
|
||||||
});
|
});
|
||||||
type ReturnedSession = ReturnType<typeof client.useSession>;
|
type ReturnedSession = ReturnType<typeof client.useSession>;
|
||||||
expectTypeOf<ReturnedSession>().toMatchTypeOf<{
|
expectTypeOf<ReturnedSession>().toMatchTypeOf<{
|
||||||
@@ -121,6 +126,11 @@ describe("type", () => {
|
|||||||
const client = createReactClient({
|
const client = createReactClient({
|
||||||
plugins: [testClientPlugin()],
|
plugins: [testClientPlugin()],
|
||||||
baseURL: "http://localhost:3000",
|
baseURL: "http://localhost:3000",
|
||||||
|
fetchOptions: {
|
||||||
|
customFetchImpl: async (url, init) => {
|
||||||
|
return new Response();
|
||||||
|
},
|
||||||
|
},
|
||||||
});
|
});
|
||||||
expectTypeOf(client.useComputedAtom).toEqualTypeOf<() => number>();
|
expectTypeOf(client.useComputedAtom).toEqualTypeOf<() => number>();
|
||||||
});
|
});
|
||||||
@@ -128,6 +138,11 @@ describe("type", () => {
|
|||||||
const client = createSolidClient({
|
const client = createSolidClient({
|
||||||
plugins: [testClientPlugin()],
|
plugins: [testClientPlugin()],
|
||||||
baseURL: "http://localhost:3000",
|
baseURL: "http://localhost:3000",
|
||||||
|
fetchOptions: {
|
||||||
|
customFetchImpl: async (url, init) => {
|
||||||
|
return new Response();
|
||||||
|
},
|
||||||
|
},
|
||||||
});
|
});
|
||||||
expectTypeOf(client.useComputedAtom).toEqualTypeOf<
|
expectTypeOf(client.useComputedAtom).toEqualTypeOf<
|
||||||
() => Accessor<number>
|
() => Accessor<number>
|
||||||
@@ -137,6 +152,11 @@ describe("type", () => {
|
|||||||
const client = createVueClient({
|
const client = createVueClient({
|
||||||
plugins: [testClientPlugin()],
|
plugins: [testClientPlugin()],
|
||||||
baseURL: "http://localhost:3000",
|
baseURL: "http://localhost:3000",
|
||||||
|
fetchOptions: {
|
||||||
|
customFetchImpl: async (url, init) => {
|
||||||
|
return new Response();
|
||||||
|
},
|
||||||
|
},
|
||||||
});
|
});
|
||||||
expectTypeOf(client.useComputedAtom).toEqualTypeOf<
|
expectTypeOf(client.useComputedAtom).toEqualTypeOf<
|
||||||
() => Readonly<Ref<number>>
|
() => Readonly<Ref<number>>
|
||||||
@@ -146,6 +166,11 @@ describe("type", () => {
|
|||||||
const client = createSvelteClient({
|
const client = createSvelteClient({
|
||||||
plugins: [testClientPlugin()],
|
plugins: [testClientPlugin()],
|
||||||
baseURL: "http://localhost:3000",
|
baseURL: "http://localhost:3000",
|
||||||
|
fetchOptions: {
|
||||||
|
customFetchImpl: async (url, init) => {
|
||||||
|
return new Response();
|
||||||
|
},
|
||||||
|
},
|
||||||
});
|
});
|
||||||
expectTypeOf(client.useComputedAtom).toEqualTypeOf<ReadableAtom<number>>();
|
expectTypeOf(client.useComputedAtom).toEqualTypeOf<ReadableAtom<number>>();
|
||||||
});
|
});
|
||||||
@@ -154,6 +179,11 @@ describe("type", () => {
|
|||||||
const client = createSolidClient({
|
const client = createSolidClient({
|
||||||
plugins: [testClientPlugin(), testClientPlugin2()],
|
plugins: [testClientPlugin(), testClientPlugin2()],
|
||||||
baseURL: "http://localhost:3000",
|
baseURL: "http://localhost:3000",
|
||||||
|
fetchOptions: {
|
||||||
|
customFetchImpl: async (url, init) => {
|
||||||
|
return new Response();
|
||||||
|
},
|
||||||
|
},
|
||||||
});
|
});
|
||||||
expectTypeOf(client.setTestAtom).toEqualTypeOf<(value: boolean) => void>();
|
expectTypeOf(client.setTestAtom).toEqualTypeOf<(value: boolean) => void>();
|
||||||
expectTypeOf(client.test.signOut).toEqualTypeOf<() => Promise<void>>();
|
expectTypeOf(client.test.signOut).toEqualTypeOf<() => Promise<void>>();
|
||||||
@@ -163,9 +193,14 @@ describe("type", () => {
|
|||||||
const client = createSolidClient({
|
const client = createSolidClient({
|
||||||
plugins: [testClientPlugin(), testClientPlugin2(), twoFactorClient()],
|
plugins: [testClientPlugin(), testClientPlugin2(), twoFactorClient()],
|
||||||
baseURL: "http://localhost:3000",
|
baseURL: "http://localhost:3000",
|
||||||
|
fetchOptions: {
|
||||||
|
customFetchImpl: async (url, init) => {
|
||||||
|
return new Response();
|
||||||
|
},
|
||||||
|
},
|
||||||
});
|
});
|
||||||
const $infer = client.$infer;
|
const $infer = client.$Infer;
|
||||||
expectTypeOf($infer.session).toEqualTypeOf<{
|
expectTypeOf($infer.Session).toEqualTypeOf<{
|
||||||
session: {
|
session: {
|
||||||
id: string;
|
id: string;
|
||||||
userId: string;
|
userId: string;
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ export const addCurrentURL = {
|
|||||||
},
|
},
|
||||||
} satisfies BetterFetchPlugin;
|
} satisfies BetterFetchPlugin;
|
||||||
|
|
||||||
|
const cache = new Map<string, string>();
|
||||||
export const csrfPlugin = {
|
export const csrfPlugin = {
|
||||||
id: "csrf",
|
id: "csrf",
|
||||||
name: "CSRF Check",
|
name: "CSRF Check",
|
||||||
@@ -42,27 +43,39 @@ export const csrfPlugin = {
|
|||||||
|
|
||||||
if (options?.method !== "GET") {
|
if (options?.method !== "GET") {
|
||||||
options = options || {};
|
options = options || {};
|
||||||
const { data, error } = await betterFetch<{
|
const csrfToken = cache.get("CSRF_TOKEN");
|
||||||
csrfToken: string;
|
if (!csrfToken) {
|
||||||
}>("/csrf", {
|
const { data, error } = await betterFetch<{
|
||||||
body: undefined,
|
csrfToken: string;
|
||||||
baseURL: options.baseURL,
|
}>("/csrf", {
|
||||||
plugins: [],
|
body: undefined,
|
||||||
method: "GET",
|
baseURL: options.baseURL,
|
||||||
credentials: "include",
|
plugins: [],
|
||||||
customFetchImpl: options.customFetchImpl,
|
method: "GET",
|
||||||
});
|
credentials: "include",
|
||||||
if (error?.status === 404) {
|
customFetchImpl: options.customFetchImpl,
|
||||||
throw new BetterAuthError(
|
});
|
||||||
"Route not found. Make sure the server is running and the base URL is correct and includes the path (e.g. http://localhost:3000/api/auth).",
|
if (error) {
|
||||||
);
|
if (error.status === 404) {
|
||||||
}
|
throw new BetterAuthError(
|
||||||
if (error) {
|
"CSRF route not found. Make sure the server is running and the base URL is correct and includes the path (e.g. http://localhost:3000/api/auth).",
|
||||||
throw new BetterAuthError(error.message || "Failed to get CSRF token.");
|
);
|
||||||
|
}
|
||||||
|
if (error.status === 429) {
|
||||||
|
return new Response(null, {
|
||||||
|
status: 429,
|
||||||
|
statusText: "Too Many Requests",
|
||||||
|
});
|
||||||
|
}
|
||||||
|
throw new BetterAuthError(
|
||||||
|
"Failed to fetch CSRF token: " + error.message,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
cache.set("CSRF_TOKEN", data.csrfToken);
|
||||||
}
|
}
|
||||||
options.body = {
|
options.body = {
|
||||||
...options?.body,
|
...options?.body,
|
||||||
csrfToken: data.csrfToken,
|
csrfToken: csrfToken,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
options.credentials = "include";
|
options.credentials = "include";
|
||||||
|
|||||||
@@ -86,12 +86,11 @@ export const getPasskeyActions = (
|
|||||||
const verified = await $fetch<{
|
const verified = await $fetch<{
|
||||||
passkey: Passkey;
|
passkey: Passkey;
|
||||||
}>("/passkey/verify-registration", {
|
}>("/passkey/verify-registration", {
|
||||||
|
...opts?.options,
|
||||||
body: {
|
body: {
|
||||||
response: res,
|
response: res,
|
||||||
|
|
||||||
name: opts?.name,
|
name: opts?.name,
|
||||||
},
|
},
|
||||||
...opts?.options,
|
|
||||||
});
|
});
|
||||||
if (!verified.data) {
|
if (!verified.data) {
|
||||||
return verified;
|
return verified;
|
||||||
|
|||||||
@@ -74,7 +74,10 @@ export const passkey = (options?: PasskeyOptions) => {
|
|||||||
const baseURL = process.env.BETTER_AUTH_URL;
|
const baseURL = process.env.BETTER_AUTH_URL;
|
||||||
const rpID =
|
const rpID =
|
||||||
options?.rpID ||
|
options?.rpID ||
|
||||||
baseURL?.replace("http://", "").replace("https://", "") ||
|
baseURL
|
||||||
|
?.replace("http://", "")
|
||||||
|
.replace("https://", "")
|
||||||
|
.replace(":3000", "") ||
|
||||||
"localhost";
|
"localhost";
|
||||||
if (!rpID) {
|
if (!rpID) {
|
||||||
throw new BetterAuthError(
|
throw new BetterAuthError(
|
||||||
|
|||||||
19
packages/better-auth/src/plugins/rate-limiter/get-key.ts
Normal file
19
packages/better-auth/src/plugins/rate-limiter/get-key.ts
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
import { getSession } from "../../api/routes";
|
||||||
|
import { BetterAuthError } from "../../error/better-auth-error";
|
||||||
|
import { getIp } from "../../utils/get-request-ip";
|
||||||
|
|
||||||
|
export async function getRateLimitKey(req: Request) {
|
||||||
|
if (req.headers.get("Authorization") || req.headers.get("cookie")) {
|
||||||
|
const session = await getSession({
|
||||||
|
headers: req.headers,
|
||||||
|
});
|
||||||
|
if (session) {
|
||||||
|
return session.user.id;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const ip = getIp(req);
|
||||||
|
if (!ip) {
|
||||||
|
throw new BetterAuthError("IP not found");
|
||||||
|
}
|
||||||
|
return ip;
|
||||||
|
}
|
||||||
239
packages/better-auth/src/plugins/rate-limiter/index.ts
Normal file
239
packages/better-auth/src/plugins/rate-limiter/index.ts
Normal file
@@ -0,0 +1,239 @@
|
|||||||
|
import { APIError } from "better-call";
|
||||||
|
import { createAuthMiddleware } from "../../api/call";
|
||||||
|
import type { GenericEndpointContext } from "../../types/context";
|
||||||
|
import type { BetterAuthPlugin } from "../../types/plugins";
|
||||||
|
import { getRateLimitKey } from "./get-key";
|
||||||
|
|
||||||
|
interface RateLimit {
|
||||||
|
key: string;
|
||||||
|
count: number;
|
||||||
|
lastRequest: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface RateLimitOptions {
|
||||||
|
/**
|
||||||
|
* Enable rate limiting. You can also pass a function
|
||||||
|
* to enable rate limiting for specific endpoints.
|
||||||
|
*
|
||||||
|
* @default true
|
||||||
|
*/
|
||||||
|
enabled: boolean | ((request: Request) => boolean | Promise<boolean>);
|
||||||
|
/**
|
||||||
|
* The window to use for rate limiting. The value
|
||||||
|
* should be in seconds.
|
||||||
|
* @default 15 minutes (15 * 60)
|
||||||
|
*/
|
||||||
|
window?: number;
|
||||||
|
/**
|
||||||
|
* The maximum number of requests allowed within the window.
|
||||||
|
* @default 100
|
||||||
|
*/
|
||||||
|
max?: number;
|
||||||
|
/**
|
||||||
|
* Function to get the key to use for rate limiting.
|
||||||
|
* @default "ip" or "userId" if the user is logged in.
|
||||||
|
*/
|
||||||
|
getKey?: (request: Request) => string | Promise<string>;
|
||||||
|
storage?:
|
||||||
|
| "database"
|
||||||
|
| "memory"
|
||||||
|
| {
|
||||||
|
get: (key: string) => Promise<RateLimit | undefined>;
|
||||||
|
set: (key: string, value: RateLimit) => Promise<void>;
|
||||||
|
};
|
||||||
|
/**
|
||||||
|
* Custom rate limiting function.
|
||||||
|
*/
|
||||||
|
customRateLimit?: (request: Request) => Promise<boolean>;
|
||||||
|
/**
|
||||||
|
* Special rules to apply to specific paths.
|
||||||
|
*
|
||||||
|
* By default, endpoints that starts with "/sign-in" or "/sign-up" are added
|
||||||
|
* to the rate limiting mechanism with a count value of 2.
|
||||||
|
* @example
|
||||||
|
* ```ts
|
||||||
|
* specialRules: [
|
||||||
|
* {
|
||||||
|
* matcher: (request) => request.url.startsWith("/sign-in"),
|
||||||
|
* // This will half the amount of requests allowed for the sign-in endpoint
|
||||||
|
* countValue: 2,
|
||||||
|
* }
|
||||||
|
* ]
|
||||||
|
* ```
|
||||||
|
*/
|
||||||
|
specialRules?: {
|
||||||
|
/**
|
||||||
|
* Custom matcher to determine if the special rule should be applied.
|
||||||
|
*/
|
||||||
|
matcher: (path: string) => boolean;
|
||||||
|
/**
|
||||||
|
* The value to use for the count.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
countValue: number;
|
||||||
|
}[];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Rate limiting plugin for BetterAuth. It implements a simple rate limiting
|
||||||
|
* mechanism to prevent abuse. It can be configured to use a database, memory
|
||||||
|
* storage or a custom storage. It can also be configured to use a custom rate
|
||||||
|
* limiting function.
|
||||||
|
*
|
||||||
|
* @example
|
||||||
|
* ```ts
|
||||||
|
* const plugin = rateLimiter({
|
||||||
|
* enabled: true,
|
||||||
|
* window: 60,
|
||||||
|
* max: 100,
|
||||||
|
* });
|
||||||
|
* ```
|
||||||
|
*/
|
||||||
|
export const rateLimiter = (options: RateLimitOptions) => {
|
||||||
|
const opts = {
|
||||||
|
storage: "database",
|
||||||
|
max: 100,
|
||||||
|
window: 15 * 60,
|
||||||
|
specialRules: [
|
||||||
|
{
|
||||||
|
matcher(path) {
|
||||||
|
return path.startsWith("/sign-in") || path.startsWith("/sign-up");
|
||||||
|
},
|
||||||
|
countValue: 2,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
...options,
|
||||||
|
} satisfies RateLimitOptions;
|
||||||
|
const schema =
|
||||||
|
opts.storage === "database"
|
||||||
|
? ({
|
||||||
|
rateLimit: {
|
||||||
|
fields: {
|
||||||
|
key: {
|
||||||
|
type: "string",
|
||||||
|
},
|
||||||
|
count: {
|
||||||
|
type: "number",
|
||||||
|
},
|
||||||
|
lastRequest: {
|
||||||
|
type: "number",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
} as const)
|
||||||
|
: undefined;
|
||||||
|
|
||||||
|
function createDBStorage(ctx: GenericEndpointContext) {
|
||||||
|
const db = ctx.context.db;
|
||||||
|
return {
|
||||||
|
get: async (key: string) => {
|
||||||
|
const result = await db
|
||||||
|
.selectFrom("rateLimit")
|
||||||
|
.where("key", "=", key)
|
||||||
|
.selectAll()
|
||||||
|
.executeTakeFirst();
|
||||||
|
return result as RateLimit | undefined;
|
||||||
|
},
|
||||||
|
set: async (key: string, value: RateLimit, isNew: boolean = true) => {
|
||||||
|
if (isNew) {
|
||||||
|
await db
|
||||||
|
.insertInto("rateLimit")
|
||||||
|
.values({
|
||||||
|
key,
|
||||||
|
count: value.count,
|
||||||
|
lastRequest: value.lastRequest,
|
||||||
|
})
|
||||||
|
.execute();
|
||||||
|
} else {
|
||||||
|
await db
|
||||||
|
.updateTable("rateLimit")
|
||||||
|
.set({
|
||||||
|
count: value.count,
|
||||||
|
lastRequest: value.lastRequest,
|
||||||
|
})
|
||||||
|
.where("key", "=", key)
|
||||||
|
.execute();
|
||||||
|
}
|
||||||
|
},
|
||||||
|
};
|
||||||
|
}
|
||||||
|
const storage = new Map<string, RateLimit>();
|
||||||
|
function createMemoryStorage() {
|
||||||
|
return {
|
||||||
|
get: async (key: string) => {
|
||||||
|
return storage.get(key);
|
||||||
|
},
|
||||||
|
set: async (key: string, value: RateLimit) => {
|
||||||
|
storage.set(key, value);
|
||||||
|
},
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
id: "rate-limiter",
|
||||||
|
middlewares: [
|
||||||
|
{
|
||||||
|
path: "/**",
|
||||||
|
middleware: createAuthMiddleware(async (ctx) => {
|
||||||
|
if (!ctx.request) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (opts.customRateLimit) {
|
||||||
|
const shouldLimit = await opts.customRateLimit(ctx.request);
|
||||||
|
if (!shouldLimit) {
|
||||||
|
throw new APIError("TOO_MANY_REQUESTS", {
|
||||||
|
message: "Too many requests",
|
||||||
|
});
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const key = await getRateLimitKey(ctx.request);
|
||||||
|
const storage =
|
||||||
|
opts.storage === "database"
|
||||||
|
? createDBStorage(ctx)
|
||||||
|
: opts.storage === "memory"
|
||||||
|
? createMemoryStorage()
|
||||||
|
: opts.storage;
|
||||||
|
const rateLimit = await storage.get(key);
|
||||||
|
if (!rateLimit) {
|
||||||
|
await storage.set(key, {
|
||||||
|
key,
|
||||||
|
count: 0,
|
||||||
|
lastRequest: new Date().getTime(),
|
||||||
|
});
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const now = new Date().getTime();
|
||||||
|
const windowStart = now - opts.window * 1000;
|
||||||
|
if (
|
||||||
|
rateLimit.lastRequest >= windowStart &&
|
||||||
|
rateLimit.count >= opts.max
|
||||||
|
) {
|
||||||
|
throw new APIError("TOO_MANY_REQUESTS", {
|
||||||
|
message: "Too many requests",
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if (rateLimit.lastRequest < windowStart) {
|
||||||
|
rateLimit.count = 0;
|
||||||
|
}
|
||||||
|
const count =
|
||||||
|
opts.specialRules.find((rule) => rule.matcher(ctx.path))
|
||||||
|
?.countValue ?? 1;
|
||||||
|
|
||||||
|
await storage.set(
|
||||||
|
key,
|
||||||
|
{
|
||||||
|
key,
|
||||||
|
count: rateLimit.count + count,
|
||||||
|
lastRequest: now,
|
||||||
|
},
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
schema,
|
||||||
|
} satisfies BetterAuthPlugin;
|
||||||
|
};
|
||||||
@@ -0,0 +1,53 @@
|
|||||||
|
import { describe, it, beforeAll, expect, vi } from "vitest";
|
||||||
|
import { getTestInstance } from "../../test-utils/test-instance";
|
||||||
|
import { rateLimiter } from ".";
|
||||||
|
|
||||||
|
describe("rate-limiter", async () => {
|
||||||
|
const { client, testUser } = await getTestInstance({
|
||||||
|
plugins: [
|
||||||
|
rateLimiter({
|
||||||
|
enabled: true,
|
||||||
|
storage: "memory",
|
||||||
|
max: 10,
|
||||||
|
window: 10,
|
||||||
|
}),
|
||||||
|
],
|
||||||
|
});
|
||||||
|
|
||||||
|
it.only("should allow requests within the limit", async () => {
|
||||||
|
for (let i = 0; i < 10; i++) {
|
||||||
|
const response = await client.signIn.email({
|
||||||
|
email: testUser.email,
|
||||||
|
password: testUser.password,
|
||||||
|
});
|
||||||
|
if (i === 9) {
|
||||||
|
expect(response.error?.status).toBe(429);
|
||||||
|
} else {
|
||||||
|
expect(response.error).toBeNull();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
it.only("should reset the limit after the window period", async () => {
|
||||||
|
vi.useFakeTimers();
|
||||||
|
|
||||||
|
// Make 10 requests to hit the limit
|
||||||
|
for (let i = 0; i < 10; i++) {
|
||||||
|
const res = await client.signIn.email({
|
||||||
|
email: testUser.email,
|
||||||
|
password: testUser.password,
|
||||||
|
});
|
||||||
|
if (res.error?.status === 429) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Advance the timer by 11 seconds (just over the 10-second window)
|
||||||
|
vi.advanceTimersByTime(11000);
|
||||||
|
const response = await client.signIn.email({
|
||||||
|
email: testUser.email,
|
||||||
|
password: testUser.password,
|
||||||
|
});
|
||||||
|
expect(response.error).toBeNull();
|
||||||
|
vi.useRealTimers();
|
||||||
|
});
|
||||||
|
});
|
||||||
26
packages/better-auth/src/utils/get-request-ip.ts
Normal file
26
packages/better-auth/src/utils/get-request-ip.ts
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
export function getIp(req: Request): string | null {
|
||||||
|
const testIP = "127.0.0.1";
|
||||||
|
if (process.env.NODE_ENV === "test") {
|
||||||
|
return testIP;
|
||||||
|
}
|
||||||
|
const headers = [
|
||||||
|
"x-client-ip",
|
||||||
|
"x-forwarded-for",
|
||||||
|
"cf-connecting-ip",
|
||||||
|
"fastly-client-ip",
|
||||||
|
"x-real-ip",
|
||||||
|
"x-cluster-client-ip",
|
||||||
|
"x-forwarded",
|
||||||
|
"forwarded-for",
|
||||||
|
"forwarded",
|
||||||
|
];
|
||||||
|
|
||||||
|
for (const header of headers) {
|
||||||
|
const value = req.headers.get(header);
|
||||||
|
if (typeof value === "string") {
|
||||||
|
const ip = value.split(",")[0].trim();
|
||||||
|
if (ip) return ip;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
18
pnpm-lock.yaml
generated
18
pnpm-lock.yaml
generated
@@ -896,8 +896,8 @@ importers:
|
|||||||
specifier: ^0.31.2
|
specifier: ^0.31.2
|
||||||
version: 0.31.2
|
version: 0.31.2
|
||||||
better-call:
|
better-call:
|
||||||
specifier: 0.2.3-beta.1
|
specifier: 0.2.3-beta.3
|
||||||
version: 0.2.3-beta.1
|
version: 0.2.3-beta.3
|
||||||
better-sqlite3:
|
better-sqlite3:
|
||||||
specifier: ^11.1.2
|
specifier: ^11.1.2
|
||||||
version: 11.1.2
|
version: 11.1.2
|
||||||
@@ -2627,7 +2627,7 @@ packages:
|
|||||||
|
|
||||||
'@expo/bunyan@4.0.1':
|
'@expo/bunyan@4.0.1':
|
||||||
resolution: {integrity: sha512-+Lla7nYSiHZirgK+U/uYzsLv/X+HaJienbD5AKX1UQZHYfWaP+9uuQluRB4GrEVWF0GZ7vEVp/jzaOT9k/SQlg==}
|
resolution: {integrity: sha512-+Lla7nYSiHZirgK+U/uYzsLv/X+HaJienbD5AKX1UQZHYfWaP+9uuQluRB4GrEVWF0GZ7vEVp/jzaOT9k/SQlg==}
|
||||||
engines: {'0': node >=0.10.0}
|
engines: {node: '>=0.10.0'}
|
||||||
|
|
||||||
'@expo/cli@0.18.29':
|
'@expo/cli@0.18.29':
|
||||||
resolution: {integrity: sha512-X810C48Ss+67RdZU39YEO1khNYo1RmjouRV+vVe0QhMoTe8R6OA3t+XYEdwaNbJ5p/DJN7szfHfNmX2glpC7xg==}
|
resolution: {integrity: sha512-X810C48Ss+67RdZU39YEO1khNYo1RmjouRV+vVe0QhMoTe8R6OA3t+XYEdwaNbJ5p/DJN7szfHfNmX2glpC7xg==}
|
||||||
@@ -6216,12 +6216,12 @@ packages:
|
|||||||
peerDependencies:
|
peerDependencies:
|
||||||
typescript: ^5.6.0-beta
|
typescript: ^5.6.0-beta
|
||||||
|
|
||||||
better-call@0.2.3-beta.1:
|
|
||||||
resolution: {integrity: sha512-I4+sm4OIgbnUdbtI5anc5Hd9dSCRnuy3pb2WdA2guLHwzjCEHlg7bTBlP6usG50FIRBceDrxfCnFBo4I85Iazg==}
|
|
||||||
|
|
||||||
better-call@0.2.3-beta.2:
|
better-call@0.2.3-beta.2:
|
||||||
resolution: {integrity: sha512-ybOtGcR4pOsHI2XE+urR9zcmK+s0YnhJSx8KDj6ul7MUEyYOiMEnq/bylyH62/7qXuYb9q8Oqkp9NF9vWOZ4Mg==}
|
resolution: {integrity: sha512-ybOtGcR4pOsHI2XE+urR9zcmK+s0YnhJSx8KDj6ul7MUEyYOiMEnq/bylyH62/7qXuYb9q8Oqkp9NF9vWOZ4Mg==}
|
||||||
|
|
||||||
|
better-call@0.2.3-beta.3:
|
||||||
|
resolution: {integrity: sha512-1W0HB1aJ1adhbkVAi5gBQE/0uhdz+Y1Nrg3I0xZDFQx1R6vocBq7R+bVDrI5eNX+HmbzrXdY6O+Q/CfhhczJcg==}
|
||||||
|
|
||||||
better-opn@3.0.2:
|
better-opn@3.0.2:
|
||||||
resolution: {integrity: sha512-aVNobHnJqLiUelTaHat9DZ1qM2w0C0Eym4LPI/3JxOnSokGVdsl1T1kN7TFvsEAD8G47A6VKQ0TVHqbBnYMJlQ==}
|
resolution: {integrity: sha512-aVNobHnJqLiUelTaHat9DZ1qM2w0C0Eym4LPI/3JxOnSokGVdsl1T1kN7TFvsEAD8G47A6VKQ0TVHqbBnYMJlQ==}
|
||||||
engines: {node: '>=12.0.0'}
|
engines: {node: '>=12.0.0'}
|
||||||
@@ -9133,10 +9133,12 @@ packages:
|
|||||||
|
|
||||||
libsql@0.3.19:
|
libsql@0.3.19:
|
||||||
resolution: {integrity: sha512-Aj5cQ5uk/6fHdmeW0TiXK42FqUlwx7ytmMLPSaUQPin5HKKKuUPD62MAbN4OEweGBBI7q1BekoEN4gPUEL6MZA==}
|
resolution: {integrity: sha512-Aj5cQ5uk/6fHdmeW0TiXK42FqUlwx7ytmMLPSaUQPin5HKKKuUPD62MAbN4OEweGBBI7q1BekoEN4gPUEL6MZA==}
|
||||||
|
cpu: [x64, arm64, wasm32]
|
||||||
os: [darwin, linux, win32]
|
os: [darwin, linux, win32]
|
||||||
|
|
||||||
libsql@0.4.5:
|
libsql@0.4.5:
|
||||||
resolution: {integrity: sha512-sorTJV6PNt94Wap27Sai5gtVLIea4Otb2LUiAUyr3p6BPOScGMKGt5F1b5X/XgkNtcsDKeX5qfeBDj+PdShclQ==}
|
resolution: {integrity: sha512-sorTJV6PNt94Wap27Sai5gtVLIea4Otb2LUiAUyr3p6BPOScGMKGt5F1b5X/XgkNtcsDKeX5qfeBDj+PdShclQ==}
|
||||||
|
cpu: [x64, arm64, wasm32]
|
||||||
os: [darwin, linux, win32]
|
os: [darwin, linux, win32]
|
||||||
|
|
||||||
lighthouse-logger@1.4.2:
|
lighthouse-logger@1.4.2:
|
||||||
@@ -20812,7 +20814,7 @@ snapshots:
|
|||||||
set-cookie-parser: 2.7.0
|
set-cookie-parser: 2.7.0
|
||||||
typescript: 5.6.2
|
typescript: 5.6.2
|
||||||
|
|
||||||
better-call@0.2.3-beta.1:
|
better-call@0.2.3-beta.2:
|
||||||
dependencies:
|
dependencies:
|
||||||
'@better-fetch/fetch': 1.1.8
|
'@better-fetch/fetch': 1.1.8
|
||||||
'@types/set-cookie-parser': 2.4.10
|
'@types/set-cookie-parser': 2.4.10
|
||||||
@@ -20820,7 +20822,7 @@ snapshots:
|
|||||||
set-cookie-parser: 2.7.0
|
set-cookie-parser: 2.7.0
|
||||||
typescript: 5.6.2
|
typescript: 5.6.2
|
||||||
|
|
||||||
better-call@0.2.3-beta.2:
|
better-call@0.2.3-beta.3:
|
||||||
dependencies:
|
dependencies:
|
||||||
'@better-fetch/fetch': 1.1.8
|
'@better-fetch/fetch': 1.1.8
|
||||||
'@types/set-cookie-parser': 2.4.10
|
'@types/set-cookie-parser': 2.4.10
|
||||||
|
|||||||
Reference in New Issue
Block a user