feat: rate limiter

This commit is contained in:
Bereket Engida
2024-09-18 14:20:09 +03:00
parent 6c69b804d2
commit 1554e770b4
13 changed files with 478 additions and 42 deletions

View File

@@ -315,9 +315,6 @@ function ChangePassword() {
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24"><path fill="currentColor" d="M2.5 18.5v-1h19v1zm.535-5.973l-.762-.442l.965-1.693h-1.93v-.884h1.93l-.965-1.642l.762-.443L4 9.066l.966-1.643l.761.443l-.965 1.642h1.93v.884h-1.93l.965 1.693l-.762.442L4 10.835zm8 0l-.762-.442l.966-1.693H9.308v-.884h1.93l-.965-1.642l.762-.443L12 9.066l.966-1.643l.761.443l-.965 1.642h1.93v.884h-1.93l.965 1.693l-.762.442L12 10.835zm8 0l-.762-.442l.966-1.693h-1.931v-.884h1.93l-.965-1.642l.762-.443L20 9.066l.966-1.643l.761.443l-.965 1.642h1.93v.884h-1.93l.965 1.693l-.762.442L20 10.835z"></path></svg>
<span
className="text-sm text-muted-foreground"
onClick={async () => {
}}
>
Change Password
</span>
@@ -376,7 +373,7 @@ function ChangePassword() {
})
setLoading(false);
if (res.error) {
toast.error(res.error.message);
toast.error(res.error.message || "Couldn't change your password! Make sure it's correct");
} else {
setOpen(false);
toast.success("Password changed successfully");

View File

@@ -32,10 +32,15 @@ export async function SignInButton() {
}
function checkOptimisticSession(headers: Headers) {
const guessIsSignIn = headers.get("cookie")?.includes("better-auth.session") || headers.get("cookie")?.includes("__Secure-better-auth.session-token")
return !!guessIsSignIn;
}
export function SignInFallback() {
//to avoid flash of unauthenticated state
const guessIsSignIn = headers().get("cookie")?.includes("better-auth.session") || headers().get("cookie")?.includes("__Secure-better-auth.session-token")
const guessIsSignIn = checkOptimisticSession(headers())
return (
<Link href={
guessIsSignIn ? "/dashboard" : "/sign-in"

View File

@@ -63,11 +63,8 @@
"vue": "^3.5.0"
},
"dependencies": {
"@babel/preset-typescript": "^7.24.7",
"c12": "^1.11.2",
"chalk": "^5.3.0",
"commander": "^12.1.0",
"@babel/preset-react": "^7.24.7",
"@babel/preset-typescript": "^7.24.7",
"@better-fetch/fetch": "^1.1.8",
"@better-fetch/logger": "^1.1.3",
"@chronark/access-policies": "^0.0.2",
@@ -82,8 +79,11 @@
"@simplewebauthn/server": "^10.0.1",
"arctic": "2.0.0-next.5",
"argon2": "^0.31.2",
"better-call": "0.2.3-beta.1",
"better-call": "0.2.3-beta.3",
"better-sqlite3": "^11.1.2",
"c12": "^1.11.2",
"chalk": "^5.3.0",
"commander": "^12.1.0",
"consola": "^3.2.3",
"defu": "^6.1.4",
"dotenv": "^16.4.5",

View File

@@ -3,8 +3,39 @@ import { createAuthEndpoint } from "../call";
import { getDate } from "../../utils/date";
import { deleteSessionCookie, setSessionCookie } from "../../utils/cookies";
import { sessionMiddleware } from "../middlewares/session";
import type { Session } from "../../adapters/schema";
import type { Session, User } from "../../adapters/schema";
import { z } from "zod";
import { getIp } from "../../utils/get-request-ip";
const sessionCache = new Map<
string,
{
data: {
session: Session;
user: User;
};
expiresAt: number;
}
>();
/**
* Generate a unique key for the request to cache the
* request for 5 seconds for this specific request.
*
* This is to prevent reaching to database if getSession is
* called multiple times for the same request
*/
function getRequestUniqueKey(ctx: Context<any, any>, token: string): string {
if (!ctx.request) {
return "";
}
const { method, url, headers } = ctx.request;
const userAgent = ctx.request.headers.get("User-Agent") || "";
const ip = getIp(ctx.request) || "";
const headerString = JSON.stringify(headers);
const uniqueString = `${method}:${url}:${headerString}:${userAgent}:${ip}:${token}`;
return uniqueString;
}
export const getSession = createAuthEndpoint(
"/session",
@@ -24,6 +55,16 @@ export const getSession = createAuthEndpoint(
status: 401,
});
}
const key = getRequestUniqueKey(ctx, sessionCookieToken);
const cachedSession = sessionCache.get(key);
if (cachedSession) {
if (cachedSession.expiresAt > Date.now()) {
return ctx.json(cachedSession.data);
}
sessionCache.delete(key);
}
const session =
await ctx.context.internalAdapter.findSession(sessionCookieToken);
@@ -89,6 +130,10 @@ export const getSession = createAuthEndpoint(
user: session.user,
});
}
sessionCache.set(key, {
data: session,
expiresAt: Date.now() + 5000,
});
return ctx.json(session);
} catch (error) {
ctx.context.logger.error(error);

View File

@@ -95,6 +95,11 @@ describe("type", () => {
const client = createReactClient({
plugins: [testClientPlugin()],
baseURL: "http://localhost:3000",
fetchOptions: {
customFetchImpl: async (url, init) => {
return new Response();
},
},
});
type ReturnedSession = ReturnType<typeof client.useSession>;
expectTypeOf<ReturnedSession>().toMatchTypeOf<{
@@ -121,6 +126,11 @@ describe("type", () => {
const client = createReactClient({
plugins: [testClientPlugin()],
baseURL: "http://localhost:3000",
fetchOptions: {
customFetchImpl: async (url, init) => {
return new Response();
},
},
});
expectTypeOf(client.useComputedAtom).toEqualTypeOf<() => number>();
});
@@ -128,6 +138,11 @@ describe("type", () => {
const client = createSolidClient({
plugins: [testClientPlugin()],
baseURL: "http://localhost:3000",
fetchOptions: {
customFetchImpl: async (url, init) => {
return new Response();
},
},
});
expectTypeOf(client.useComputedAtom).toEqualTypeOf<
() => Accessor<number>
@@ -137,6 +152,11 @@ describe("type", () => {
const client = createVueClient({
plugins: [testClientPlugin()],
baseURL: "http://localhost:3000",
fetchOptions: {
customFetchImpl: async (url, init) => {
return new Response();
},
},
});
expectTypeOf(client.useComputedAtom).toEqualTypeOf<
() => Readonly<Ref<number>>
@@ -146,6 +166,11 @@ describe("type", () => {
const client = createSvelteClient({
plugins: [testClientPlugin()],
baseURL: "http://localhost:3000",
fetchOptions: {
customFetchImpl: async (url, init) => {
return new Response();
},
},
});
expectTypeOf(client.useComputedAtom).toEqualTypeOf<ReadableAtom<number>>();
});
@@ -154,6 +179,11 @@ describe("type", () => {
const client = createSolidClient({
plugins: [testClientPlugin(), testClientPlugin2()],
baseURL: "http://localhost:3000",
fetchOptions: {
customFetchImpl: async (url, init) => {
return new Response();
},
},
});
expectTypeOf(client.setTestAtom).toEqualTypeOf<(value: boolean) => void>();
expectTypeOf(client.test.signOut).toEqualTypeOf<() => Promise<void>>();
@@ -163,9 +193,14 @@ describe("type", () => {
const client = createSolidClient({
plugins: [testClientPlugin(), testClientPlugin2(), twoFactorClient()],
baseURL: "http://localhost:3000",
fetchOptions: {
customFetchImpl: async (url, init) => {
return new Response();
},
},
});
const $infer = client.$infer;
expectTypeOf($infer.session).toEqualTypeOf<{
const $infer = client.$Infer;
expectTypeOf($infer.Session).toEqualTypeOf<{
session: {
id: string;
userId: string;

View File

@@ -30,6 +30,7 @@ export const addCurrentURL = {
},
} satisfies BetterFetchPlugin;
const cache = new Map<string, string>();
export const csrfPlugin = {
id: "csrf",
name: "CSRF Check",
@@ -42,6 +43,8 @@ export const csrfPlugin = {
if (options?.method !== "GET") {
options = options || {};
const csrfToken = cache.get("CSRF_TOKEN");
if (!csrfToken) {
const { data, error } = await betterFetch<{
csrfToken: string;
}>("/csrf", {
@@ -52,17 +55,27 @@ export const csrfPlugin = {
credentials: "include",
customFetchImpl: options.customFetchImpl,
});
if (error?.status === 404) {
if (error) {
if (error.status === 404) {
throw new BetterAuthError(
"Route not found. Make sure the server is running and the base URL is correct and includes the path (e.g. http://localhost:3000/api/auth).",
"CSRF route not found. Make sure the server is running and the base URL is correct and includes the path (e.g. http://localhost:3000/api/auth).",
);
}
if (error) {
throw new BetterAuthError(error.message || "Failed to get CSRF token.");
if (error.status === 429) {
return new Response(null, {
status: 429,
statusText: "Too Many Requests",
});
}
throw new BetterAuthError(
"Failed to fetch CSRF token: " + error.message,
);
}
cache.set("CSRF_TOKEN", data.csrfToken);
}
options.body = {
...options?.body,
csrfToken: data.csrfToken,
csrfToken: csrfToken,
};
}
options.credentials = "include";

View File

@@ -86,12 +86,11 @@ export const getPasskeyActions = (
const verified = await $fetch<{
passkey: Passkey;
}>("/passkey/verify-registration", {
...opts?.options,
body: {
response: res,
name: opts?.name,
},
...opts?.options,
});
if (!verified.data) {
return verified;

View File

@@ -74,7 +74,10 @@ export const passkey = (options?: PasskeyOptions) => {
const baseURL = process.env.BETTER_AUTH_URL;
const rpID =
options?.rpID ||
baseURL?.replace("http://", "").replace("https://", "") ||
baseURL
?.replace("http://", "")
.replace("https://", "")
.replace(":3000", "") ||
"localhost";
if (!rpID) {
throw new BetterAuthError(

View File

@@ -0,0 +1,19 @@
import { getSession } from "../../api/routes";
import { BetterAuthError } from "../../error/better-auth-error";
import { getIp } from "../../utils/get-request-ip";
export async function getRateLimitKey(req: Request) {
if (req.headers.get("Authorization") || req.headers.get("cookie")) {
const session = await getSession({
headers: req.headers,
});
if (session) {
return session.user.id;
}
}
const ip = getIp(req);
if (!ip) {
throw new BetterAuthError("IP not found");
}
return ip;
}

View File

@@ -0,0 +1,239 @@
import { APIError } from "better-call";
import { createAuthMiddleware } from "../../api/call";
import type { GenericEndpointContext } from "../../types/context";
import type { BetterAuthPlugin } from "../../types/plugins";
import { getRateLimitKey } from "./get-key";
interface RateLimit {
key: string;
count: number;
lastRequest: number;
}
export interface RateLimitOptions {
/**
* Enable rate limiting. You can also pass a function
* to enable rate limiting for specific endpoints.
*
* @default true
*/
enabled: boolean | ((request: Request) => boolean | Promise<boolean>);
/**
* The window to use for rate limiting. The value
* should be in seconds.
* @default 15 minutes (15 * 60)
*/
window?: number;
/**
* The maximum number of requests allowed within the window.
* @default 100
*/
max?: number;
/**
* Function to get the key to use for rate limiting.
* @default "ip" or "userId" if the user is logged in.
*/
getKey?: (request: Request) => string | Promise<string>;
storage?:
| "database"
| "memory"
| {
get: (key: string) => Promise<RateLimit | undefined>;
set: (key: string, value: RateLimit) => Promise<void>;
};
/**
* Custom rate limiting function.
*/
customRateLimit?: (request: Request) => Promise<boolean>;
/**
* Special rules to apply to specific paths.
*
* By default, endpoints that starts with "/sign-in" or "/sign-up" are added
* to the rate limiting mechanism with a count value of 2.
* @example
* ```ts
* specialRules: [
* {
* matcher: (request) => request.url.startsWith("/sign-in"),
* // This will half the amount of requests allowed for the sign-in endpoint
* countValue: 2,
* }
* ]
* ```
*/
specialRules?: {
/**
* Custom matcher to determine if the special rule should be applied.
*/
matcher: (path: string) => boolean;
/**
* The value to use for the count.
*
*/
countValue: number;
}[];
}
/**
* Rate limiting plugin for BetterAuth. It implements a simple rate limiting
* mechanism to prevent abuse. It can be configured to use a database, memory
* storage or a custom storage. It can also be configured to use a custom rate
* limiting function.
*
* @example
* ```ts
* const plugin = rateLimiter({
* enabled: true,
* window: 60,
* max: 100,
* });
* ```
*/
export const rateLimiter = (options: RateLimitOptions) => {
const opts = {
storage: "database",
max: 100,
window: 15 * 60,
specialRules: [
{
matcher(path) {
return path.startsWith("/sign-in") || path.startsWith("/sign-up");
},
countValue: 2,
},
],
...options,
} satisfies RateLimitOptions;
const schema =
opts.storage === "database"
? ({
rateLimit: {
fields: {
key: {
type: "string",
},
count: {
type: "number",
},
lastRequest: {
type: "number",
},
},
},
} as const)
: undefined;
function createDBStorage(ctx: GenericEndpointContext) {
const db = ctx.context.db;
return {
get: async (key: string) => {
const result = await db
.selectFrom("rateLimit")
.where("key", "=", key)
.selectAll()
.executeTakeFirst();
return result as RateLimit | undefined;
},
set: async (key: string, value: RateLimit, isNew: boolean = true) => {
if (isNew) {
await db
.insertInto("rateLimit")
.values({
key,
count: value.count,
lastRequest: value.lastRequest,
})
.execute();
} else {
await db
.updateTable("rateLimit")
.set({
count: value.count,
lastRequest: value.lastRequest,
})
.where("key", "=", key)
.execute();
}
},
};
}
const storage = new Map<string, RateLimit>();
function createMemoryStorage() {
return {
get: async (key: string) => {
return storage.get(key);
},
set: async (key: string, value: RateLimit) => {
storage.set(key, value);
},
};
}
return {
id: "rate-limiter",
middlewares: [
{
path: "/**",
middleware: createAuthMiddleware(async (ctx) => {
if (!ctx.request) {
return;
}
if (opts.customRateLimit) {
const shouldLimit = await opts.customRateLimit(ctx.request);
if (!shouldLimit) {
throw new APIError("TOO_MANY_REQUESTS", {
message: "Too many requests",
});
}
return;
}
const key = await getRateLimitKey(ctx.request);
const storage =
opts.storage === "database"
? createDBStorage(ctx)
: opts.storage === "memory"
? createMemoryStorage()
: opts.storage;
const rateLimit = await storage.get(key);
if (!rateLimit) {
await storage.set(key, {
key,
count: 0,
lastRequest: new Date().getTime(),
});
return;
}
const now = new Date().getTime();
const windowStart = now - opts.window * 1000;
if (
rateLimit.lastRequest >= windowStart &&
rateLimit.count >= opts.max
) {
throw new APIError("TOO_MANY_REQUESTS", {
message: "Too many requests",
});
}
if (rateLimit.lastRequest < windowStart) {
rateLimit.count = 0;
}
const count =
opts.specialRules.find((rule) => rule.matcher(ctx.path))
?.countValue ?? 1;
await storage.set(
key,
{
key,
count: rateLimit.count + count,
lastRequest: now,
},
false,
);
return;
}),
},
],
schema,
} satisfies BetterAuthPlugin;
};

View File

@@ -0,0 +1,53 @@
import { describe, it, beforeAll, expect, vi } from "vitest";
import { getTestInstance } from "../../test-utils/test-instance";
import { rateLimiter } from ".";
describe("rate-limiter", async () => {
const { client, testUser } = await getTestInstance({
plugins: [
rateLimiter({
enabled: true,
storage: "memory",
max: 10,
window: 10,
}),
],
});
it.only("should allow requests within the limit", async () => {
for (let i = 0; i < 10; i++) {
const response = await client.signIn.email({
email: testUser.email,
password: testUser.password,
});
if (i === 9) {
expect(response.error?.status).toBe(429);
} else {
expect(response.error).toBeNull();
}
}
});
it.only("should reset the limit after the window period", async () => {
vi.useFakeTimers();
// Make 10 requests to hit the limit
for (let i = 0; i < 10; i++) {
const res = await client.signIn.email({
email: testUser.email,
password: testUser.password,
});
if (res.error?.status === 429) {
break;
}
}
// Advance the timer by 11 seconds (just over the 10-second window)
vi.advanceTimersByTime(11000);
const response = await client.signIn.email({
email: testUser.email,
password: testUser.password,
});
expect(response.error).toBeNull();
vi.useRealTimers();
});
});

View File

@@ -0,0 +1,26 @@
export function getIp(req: Request): string | null {
const testIP = "127.0.0.1";
if (process.env.NODE_ENV === "test") {
return testIP;
}
const headers = [
"x-client-ip",
"x-forwarded-for",
"cf-connecting-ip",
"fastly-client-ip",
"x-real-ip",
"x-cluster-client-ip",
"x-forwarded",
"forwarded-for",
"forwarded",
];
for (const header of headers) {
const value = req.headers.get(header);
if (typeof value === "string") {
const ip = value.split(",")[0].trim();
if (ip) return ip;
}
}
return null;
}

18
pnpm-lock.yaml generated
View File

@@ -896,8 +896,8 @@ importers:
specifier: ^0.31.2
version: 0.31.2
better-call:
specifier: 0.2.3-beta.1
version: 0.2.3-beta.1
specifier: 0.2.3-beta.3
version: 0.2.3-beta.3
better-sqlite3:
specifier: ^11.1.2
version: 11.1.2
@@ -2627,7 +2627,7 @@ packages:
'@expo/bunyan@4.0.1':
resolution: {integrity: sha512-+Lla7nYSiHZirgK+U/uYzsLv/X+HaJienbD5AKX1UQZHYfWaP+9uuQluRB4GrEVWF0GZ7vEVp/jzaOT9k/SQlg==}
engines: {'0': node >=0.10.0}
engines: {node: '>=0.10.0'}
'@expo/cli@0.18.29':
resolution: {integrity: sha512-X810C48Ss+67RdZU39YEO1khNYo1RmjouRV+vVe0QhMoTe8R6OA3t+XYEdwaNbJ5p/DJN7szfHfNmX2glpC7xg==}
@@ -6216,12 +6216,12 @@ packages:
peerDependencies:
typescript: ^5.6.0-beta
better-call@0.2.3-beta.1:
resolution: {integrity: sha512-I4+sm4OIgbnUdbtI5anc5Hd9dSCRnuy3pb2WdA2guLHwzjCEHlg7bTBlP6usG50FIRBceDrxfCnFBo4I85Iazg==}
better-call@0.2.3-beta.2:
resolution: {integrity: sha512-ybOtGcR4pOsHI2XE+urR9zcmK+s0YnhJSx8KDj6ul7MUEyYOiMEnq/bylyH62/7qXuYb9q8Oqkp9NF9vWOZ4Mg==}
better-call@0.2.3-beta.3:
resolution: {integrity: sha512-1W0HB1aJ1adhbkVAi5gBQE/0uhdz+Y1Nrg3I0xZDFQx1R6vocBq7R+bVDrI5eNX+HmbzrXdY6O+Q/CfhhczJcg==}
better-opn@3.0.2:
resolution: {integrity: sha512-aVNobHnJqLiUelTaHat9DZ1qM2w0C0Eym4LPI/3JxOnSokGVdsl1T1kN7TFvsEAD8G47A6VKQ0TVHqbBnYMJlQ==}
engines: {node: '>=12.0.0'}
@@ -9133,10 +9133,12 @@ packages:
libsql@0.3.19:
resolution: {integrity: sha512-Aj5cQ5uk/6fHdmeW0TiXK42FqUlwx7ytmMLPSaUQPin5HKKKuUPD62MAbN4OEweGBBI7q1BekoEN4gPUEL6MZA==}
cpu: [x64, arm64, wasm32]
os: [darwin, linux, win32]
libsql@0.4.5:
resolution: {integrity: sha512-sorTJV6PNt94Wap27Sai5gtVLIea4Otb2LUiAUyr3p6BPOScGMKGt5F1b5X/XgkNtcsDKeX5qfeBDj+PdShclQ==}
cpu: [x64, arm64, wasm32]
os: [darwin, linux, win32]
lighthouse-logger@1.4.2:
@@ -20812,7 +20814,7 @@ snapshots:
set-cookie-parser: 2.7.0
typescript: 5.6.2
better-call@0.2.3-beta.1:
better-call@0.2.3-beta.2:
dependencies:
'@better-fetch/fetch': 1.1.8
'@types/set-cookie-parser': 2.4.10
@@ -20820,7 +20822,7 @@ snapshots:
set-cookie-parser: 2.7.0
typescript: 5.6.2
better-call@0.2.3-beta.2:
better-call@0.2.3-beta.3:
dependencies:
'@better-fetch/fetch': 1.1.8
'@types/set-cookie-parser': 2.4.10