diff --git a/demo/nextjs/app/dashboard/user-card.tsx b/demo/nextjs/app/dashboard/user-card.tsx index 40346ffc..4e576887 100644 --- a/demo/nextjs/app/dashboard/user-card.tsx +++ b/demo/nextjs/app/dashboard/user-card.tsx @@ -315,9 +315,6 @@ function ChangePassword() { { - - }} > Change Password @@ -376,7 +373,7 @@ function ChangePassword() { }) setLoading(false); if (res.error) { - toast.error(res.error.message); + toast.error(res.error.message || "Couldn't change your password! Make sure it's correct"); } else { setOpen(false); toast.success("Password changed successfully"); diff --git a/demo/nextjs/components/sign-in-btn.tsx b/demo/nextjs/components/sign-in-btn.tsx index b7ca8898..192b9c1a 100644 --- a/demo/nextjs/components/sign-in-btn.tsx +++ b/demo/nextjs/components/sign-in-btn.tsx @@ -32,10 +32,15 @@ export async function SignInButton() { } +function checkOptimisticSession(headers: Headers) { + const guessIsSignIn = headers.get("cookie")?.includes("better-auth.session") || headers.get("cookie")?.includes("__Secure-better-auth.session-token") + return !!guessIsSignIn; +} + export function SignInFallback() { //to avoid flash of unauthenticated state - const guessIsSignIn = headers().get("cookie")?.includes("better-auth.session") || headers().get("cookie")?.includes("__Secure-better-auth.session-token") + const guessIsSignIn = checkOptimisticSession(headers()) return ( (); + +/** + * Generate a unique key for the request to cache the + * request for 5 seconds for this specific request. + * + * This is to prevent reaching to database if getSession is + * called multiple times for the same request + */ +function getRequestUniqueKey(ctx: Context, token: string): string { + if (!ctx.request) { + return ""; + } + const { method, url, headers } = ctx.request; + const userAgent = ctx.request.headers.get("User-Agent") || ""; + const ip = getIp(ctx.request) || ""; + const headerString = JSON.stringify(headers); + const uniqueString = `${method}:${url}:${headerString}:${userAgent}:${ip}:${token}`; + return uniqueString; +} export const getSession = createAuthEndpoint( "/session", @@ -24,6 +55,16 @@ export const getSession = createAuthEndpoint( status: 401, }); } + + const key = getRequestUniqueKey(ctx, sessionCookieToken); + const cachedSession = sessionCache.get(key); + if (cachedSession) { + if (cachedSession.expiresAt > Date.now()) { + return ctx.json(cachedSession.data); + } + sessionCache.delete(key); + } + const session = await ctx.context.internalAdapter.findSession(sessionCookieToken); @@ -89,6 +130,10 @@ export const getSession = createAuthEndpoint( user: session.user, }); } + sessionCache.set(key, { + data: session, + expiresAt: Date.now() + 5000, + }); return ctx.json(session); } catch (error) { ctx.context.logger.error(error); diff --git a/packages/better-auth/src/client/client.test.ts b/packages/better-auth/src/client/client.test.ts index 699a4202..d3cf8796 100644 --- a/packages/better-auth/src/client/client.test.ts +++ b/packages/better-auth/src/client/client.test.ts @@ -95,6 +95,11 @@ describe("type", () => { const client = createReactClient({ plugins: [testClientPlugin()], baseURL: "http://localhost:3000", + fetchOptions: { + customFetchImpl: async (url, init) => { + return new Response(); + }, + }, }); type ReturnedSession = ReturnType; expectTypeOf().toMatchTypeOf<{ @@ -121,6 +126,11 @@ describe("type", () => { const client = createReactClient({ plugins: [testClientPlugin()], baseURL: "http://localhost:3000", + fetchOptions: { + customFetchImpl: async (url, init) => { + return new Response(); + }, + }, }); expectTypeOf(client.useComputedAtom).toEqualTypeOf<() => number>(); }); @@ -128,6 +138,11 @@ describe("type", () => { const client = createSolidClient({ plugins: [testClientPlugin()], baseURL: "http://localhost:3000", + fetchOptions: { + customFetchImpl: async (url, init) => { + return new Response(); + }, + }, }); expectTypeOf(client.useComputedAtom).toEqualTypeOf< () => Accessor @@ -137,6 +152,11 @@ describe("type", () => { const client = createVueClient({ plugins: [testClientPlugin()], baseURL: "http://localhost:3000", + fetchOptions: { + customFetchImpl: async (url, init) => { + return new Response(); + }, + }, }); expectTypeOf(client.useComputedAtom).toEqualTypeOf< () => Readonly> @@ -146,6 +166,11 @@ describe("type", () => { const client = createSvelteClient({ plugins: [testClientPlugin()], baseURL: "http://localhost:3000", + fetchOptions: { + customFetchImpl: async (url, init) => { + return new Response(); + }, + }, }); expectTypeOf(client.useComputedAtom).toEqualTypeOf>(); }); @@ -154,6 +179,11 @@ describe("type", () => { const client = createSolidClient({ plugins: [testClientPlugin(), testClientPlugin2()], baseURL: "http://localhost:3000", + fetchOptions: { + customFetchImpl: async (url, init) => { + return new Response(); + }, + }, }); expectTypeOf(client.setTestAtom).toEqualTypeOf<(value: boolean) => void>(); expectTypeOf(client.test.signOut).toEqualTypeOf<() => Promise>(); @@ -163,9 +193,14 @@ describe("type", () => { const client = createSolidClient({ plugins: [testClientPlugin(), testClientPlugin2(), twoFactorClient()], baseURL: "http://localhost:3000", + fetchOptions: { + customFetchImpl: async (url, init) => { + return new Response(); + }, + }, }); - const $infer = client.$infer; - expectTypeOf($infer.session).toEqualTypeOf<{ + const $infer = client.$Infer; + expectTypeOf($infer.Session).toEqualTypeOf<{ session: { id: string; userId: string; diff --git a/packages/better-auth/src/client/fetch-plugins.ts b/packages/better-auth/src/client/fetch-plugins.ts index d1dcf5e7..ecde1910 100644 --- a/packages/better-auth/src/client/fetch-plugins.ts +++ b/packages/better-auth/src/client/fetch-plugins.ts @@ -30,6 +30,7 @@ export const addCurrentURL = { }, } satisfies BetterFetchPlugin; +const cache = new Map(); export const csrfPlugin = { id: "csrf", name: "CSRF Check", @@ -42,27 +43,39 @@ export const csrfPlugin = { if (options?.method !== "GET") { options = options || {}; - const { data, error } = await betterFetch<{ - csrfToken: string; - }>("/csrf", { - body: undefined, - baseURL: options.baseURL, - plugins: [], - method: "GET", - credentials: "include", - customFetchImpl: options.customFetchImpl, - }); - if (error?.status === 404) { - throw new BetterAuthError( - "Route not found. Make sure the server is running and the base URL is correct and includes the path (e.g. http://localhost:3000/api/auth).", - ); - } - if (error) { - throw new BetterAuthError(error.message || "Failed to get CSRF token."); + const csrfToken = cache.get("CSRF_TOKEN"); + if (!csrfToken) { + const { data, error } = await betterFetch<{ + csrfToken: string; + }>("/csrf", { + body: undefined, + baseURL: options.baseURL, + plugins: [], + method: "GET", + credentials: "include", + customFetchImpl: options.customFetchImpl, + }); + if (error) { + if (error.status === 404) { + throw new BetterAuthError( + "CSRF route not found. Make sure the server is running and the base URL is correct and includes the path (e.g. http://localhost:3000/api/auth).", + ); + } + if (error.status === 429) { + return new Response(null, { + status: 429, + statusText: "Too Many Requests", + }); + } + throw new BetterAuthError( + "Failed to fetch CSRF token: " + error.message, + ); + } + cache.set("CSRF_TOKEN", data.csrfToken); } options.body = { ...options?.body, - csrfToken: data.csrfToken, + csrfToken: csrfToken, }; } options.credentials = "include"; diff --git a/packages/better-auth/src/plugins/passkey/client.ts b/packages/better-auth/src/plugins/passkey/client.ts index e0313055..9b7a31c3 100644 --- a/packages/better-auth/src/plugins/passkey/client.ts +++ b/packages/better-auth/src/plugins/passkey/client.ts @@ -86,12 +86,11 @@ export const getPasskeyActions = ( const verified = await $fetch<{ passkey: Passkey; }>("/passkey/verify-registration", { + ...opts?.options, body: { response: res, - name: opts?.name, }, - ...opts?.options, }); if (!verified.data) { return verified; diff --git a/packages/better-auth/src/plugins/passkey/index.ts b/packages/better-auth/src/plugins/passkey/index.ts index 439a640d..02d3b772 100644 --- a/packages/better-auth/src/plugins/passkey/index.ts +++ b/packages/better-auth/src/plugins/passkey/index.ts @@ -74,7 +74,10 @@ export const passkey = (options?: PasskeyOptions) => { const baseURL = process.env.BETTER_AUTH_URL; const rpID = options?.rpID || - baseURL?.replace("http://", "").replace("https://", "") || + baseURL + ?.replace("http://", "") + .replace("https://", "") + .replace(":3000", "") || "localhost"; if (!rpID) { throw new BetterAuthError( diff --git a/packages/better-auth/src/plugins/rate-limiter/get-key.ts b/packages/better-auth/src/plugins/rate-limiter/get-key.ts new file mode 100644 index 00000000..7e718e93 --- /dev/null +++ b/packages/better-auth/src/plugins/rate-limiter/get-key.ts @@ -0,0 +1,19 @@ +import { getSession } from "../../api/routes"; +import { BetterAuthError } from "../../error/better-auth-error"; +import { getIp } from "../../utils/get-request-ip"; + +export async function getRateLimitKey(req: Request) { + if (req.headers.get("Authorization") || req.headers.get("cookie")) { + const session = await getSession({ + headers: req.headers, + }); + if (session) { + return session.user.id; + } + } + const ip = getIp(req); + if (!ip) { + throw new BetterAuthError("IP not found"); + } + return ip; +} diff --git a/packages/better-auth/src/plugins/rate-limiter/index.ts b/packages/better-auth/src/plugins/rate-limiter/index.ts new file mode 100644 index 00000000..802f4fc7 --- /dev/null +++ b/packages/better-auth/src/plugins/rate-limiter/index.ts @@ -0,0 +1,239 @@ +import { APIError } from "better-call"; +import { createAuthMiddleware } from "../../api/call"; +import type { GenericEndpointContext } from "../../types/context"; +import type { BetterAuthPlugin } from "../../types/plugins"; +import { getRateLimitKey } from "./get-key"; + +interface RateLimit { + key: string; + count: number; + lastRequest: number; +} + +export interface RateLimitOptions { + /** + * Enable rate limiting. You can also pass a function + * to enable rate limiting for specific endpoints. + * + * @default true + */ + enabled: boolean | ((request: Request) => boolean | Promise); + /** + * The window to use for rate limiting. The value + * should be in seconds. + * @default 15 minutes (15 * 60) + */ + window?: number; + /** + * The maximum number of requests allowed within the window. + * @default 100 + */ + max?: number; + /** + * Function to get the key to use for rate limiting. + * @default "ip" or "userId" if the user is logged in. + */ + getKey?: (request: Request) => string | Promise; + storage?: + | "database" + | "memory" + | { + get: (key: string) => Promise; + set: (key: string, value: RateLimit) => Promise; + }; + /** + * Custom rate limiting function. + */ + customRateLimit?: (request: Request) => Promise; + /** + * Special rules to apply to specific paths. + * + * By default, endpoints that starts with "/sign-in" or "/sign-up" are added + * to the rate limiting mechanism with a count value of 2. + * @example + * ```ts + * specialRules: [ + * { + * matcher: (request) => request.url.startsWith("/sign-in"), + * // This will half the amount of requests allowed for the sign-in endpoint + * countValue: 2, + * } + * ] + * ``` + */ + specialRules?: { + /** + * Custom matcher to determine if the special rule should be applied. + */ + matcher: (path: string) => boolean; + /** + * The value to use for the count. + * + */ + countValue: number; + }[]; +} + +/** + * Rate limiting plugin for BetterAuth. It implements a simple rate limiting + * mechanism to prevent abuse. It can be configured to use a database, memory + * storage or a custom storage. It can also be configured to use a custom rate + * limiting function. + * + * @example + * ```ts + * const plugin = rateLimiter({ + * enabled: true, + * window: 60, + * max: 100, + * }); + * ``` + */ +export const rateLimiter = (options: RateLimitOptions) => { + const opts = { + storage: "database", + max: 100, + window: 15 * 60, + specialRules: [ + { + matcher(path) { + return path.startsWith("/sign-in") || path.startsWith("/sign-up"); + }, + countValue: 2, + }, + ], + ...options, + } satisfies RateLimitOptions; + const schema = + opts.storage === "database" + ? ({ + rateLimit: { + fields: { + key: { + type: "string", + }, + count: { + type: "number", + }, + lastRequest: { + type: "number", + }, + }, + }, + } as const) + : undefined; + + function createDBStorage(ctx: GenericEndpointContext) { + const db = ctx.context.db; + return { + get: async (key: string) => { + const result = await db + .selectFrom("rateLimit") + .where("key", "=", key) + .selectAll() + .executeTakeFirst(); + return result as RateLimit | undefined; + }, + set: async (key: string, value: RateLimit, isNew: boolean = true) => { + if (isNew) { + await db + .insertInto("rateLimit") + .values({ + key, + count: value.count, + lastRequest: value.lastRequest, + }) + .execute(); + } else { + await db + .updateTable("rateLimit") + .set({ + count: value.count, + lastRequest: value.lastRequest, + }) + .where("key", "=", key) + .execute(); + } + }, + }; + } + const storage = new Map(); + function createMemoryStorage() { + return { + get: async (key: string) => { + return storage.get(key); + }, + set: async (key: string, value: RateLimit) => { + storage.set(key, value); + }, + }; + } + + return { + id: "rate-limiter", + middlewares: [ + { + path: "/**", + middleware: createAuthMiddleware(async (ctx) => { + if (!ctx.request) { + return; + } + if (opts.customRateLimit) { + const shouldLimit = await opts.customRateLimit(ctx.request); + if (!shouldLimit) { + throw new APIError("TOO_MANY_REQUESTS", { + message: "Too many requests", + }); + } + return; + } + const key = await getRateLimitKey(ctx.request); + const storage = + opts.storage === "database" + ? createDBStorage(ctx) + : opts.storage === "memory" + ? createMemoryStorage() + : opts.storage; + const rateLimit = await storage.get(key); + if (!rateLimit) { + await storage.set(key, { + key, + count: 0, + lastRequest: new Date().getTime(), + }); + return; + } + const now = new Date().getTime(); + const windowStart = now - opts.window * 1000; + if ( + rateLimit.lastRequest >= windowStart && + rateLimit.count >= opts.max + ) { + throw new APIError("TOO_MANY_REQUESTS", { + message: "Too many requests", + }); + } + + if (rateLimit.lastRequest < windowStart) { + rateLimit.count = 0; + } + const count = + opts.specialRules.find((rule) => rule.matcher(ctx.path)) + ?.countValue ?? 1; + + await storage.set( + key, + { + key, + count: rateLimit.count + count, + lastRequest: now, + }, + false, + ); + return; + }), + }, + ], + schema, + } satisfies BetterAuthPlugin; +}; diff --git a/packages/better-auth/src/plugins/rate-limiter/rate-limiter.test.ts b/packages/better-auth/src/plugins/rate-limiter/rate-limiter.test.ts new file mode 100644 index 00000000..25e77d4b --- /dev/null +++ b/packages/better-auth/src/plugins/rate-limiter/rate-limiter.test.ts @@ -0,0 +1,53 @@ +import { describe, it, beforeAll, expect, vi } from "vitest"; +import { getTestInstance } from "../../test-utils/test-instance"; +import { rateLimiter } from "."; + +describe("rate-limiter", async () => { + const { client, testUser } = await getTestInstance({ + plugins: [ + rateLimiter({ + enabled: true, + storage: "memory", + max: 10, + window: 10, + }), + ], + }); + + it.only("should allow requests within the limit", async () => { + for (let i = 0; i < 10; i++) { + const response = await client.signIn.email({ + email: testUser.email, + password: testUser.password, + }); + if (i === 9) { + expect(response.error?.status).toBe(429); + } else { + expect(response.error).toBeNull(); + } + } + }); + + it.only("should reset the limit after the window period", async () => { + vi.useFakeTimers(); + + // Make 10 requests to hit the limit + for (let i = 0; i < 10; i++) { + const res = await client.signIn.email({ + email: testUser.email, + password: testUser.password, + }); + if (res.error?.status === 429) { + break; + } + } + // Advance the timer by 11 seconds (just over the 10-second window) + vi.advanceTimersByTime(11000); + const response = await client.signIn.email({ + email: testUser.email, + password: testUser.password, + }); + expect(response.error).toBeNull(); + vi.useRealTimers(); + }); +}); diff --git a/packages/better-auth/src/utils/get-request-ip.ts b/packages/better-auth/src/utils/get-request-ip.ts new file mode 100644 index 00000000..9ef54b90 --- /dev/null +++ b/packages/better-auth/src/utils/get-request-ip.ts @@ -0,0 +1,26 @@ +export function getIp(req: Request): string | null { + const testIP = "127.0.0.1"; + if (process.env.NODE_ENV === "test") { + return testIP; + } + const headers = [ + "x-client-ip", + "x-forwarded-for", + "cf-connecting-ip", + "fastly-client-ip", + "x-real-ip", + "x-cluster-client-ip", + "x-forwarded", + "forwarded-for", + "forwarded", + ]; + + for (const header of headers) { + const value = req.headers.get(header); + if (typeof value === "string") { + const ip = value.split(",")[0].trim(); + if (ip) return ip; + } + } + return null; +} diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 3da083dc..5b0c50df 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -896,8 +896,8 @@ importers: specifier: ^0.31.2 version: 0.31.2 better-call: - specifier: 0.2.3-beta.1 - version: 0.2.3-beta.1 + specifier: 0.2.3-beta.3 + version: 0.2.3-beta.3 better-sqlite3: specifier: ^11.1.2 version: 11.1.2 @@ -2627,7 +2627,7 @@ packages: '@expo/bunyan@4.0.1': resolution: {integrity: sha512-+Lla7nYSiHZirgK+U/uYzsLv/X+HaJienbD5AKX1UQZHYfWaP+9uuQluRB4GrEVWF0GZ7vEVp/jzaOT9k/SQlg==} - engines: {'0': node >=0.10.0} + engines: {node: '>=0.10.0'} '@expo/cli@0.18.29': resolution: {integrity: sha512-X810C48Ss+67RdZU39YEO1khNYo1RmjouRV+vVe0QhMoTe8R6OA3t+XYEdwaNbJ5p/DJN7szfHfNmX2glpC7xg==} @@ -6216,12 +6216,12 @@ packages: peerDependencies: typescript: ^5.6.0-beta - better-call@0.2.3-beta.1: - resolution: {integrity: sha512-I4+sm4OIgbnUdbtI5anc5Hd9dSCRnuy3pb2WdA2guLHwzjCEHlg7bTBlP6usG50FIRBceDrxfCnFBo4I85Iazg==} - better-call@0.2.3-beta.2: resolution: {integrity: sha512-ybOtGcR4pOsHI2XE+urR9zcmK+s0YnhJSx8KDj6ul7MUEyYOiMEnq/bylyH62/7qXuYb9q8Oqkp9NF9vWOZ4Mg==} + better-call@0.2.3-beta.3: + resolution: {integrity: sha512-1W0HB1aJ1adhbkVAi5gBQE/0uhdz+Y1Nrg3I0xZDFQx1R6vocBq7R+bVDrI5eNX+HmbzrXdY6O+Q/CfhhczJcg==} + better-opn@3.0.2: resolution: {integrity: sha512-aVNobHnJqLiUelTaHat9DZ1qM2w0C0Eym4LPI/3JxOnSokGVdsl1T1kN7TFvsEAD8G47A6VKQ0TVHqbBnYMJlQ==} engines: {node: '>=12.0.0'} @@ -9133,10 +9133,12 @@ packages: libsql@0.3.19: resolution: {integrity: sha512-Aj5cQ5uk/6fHdmeW0TiXK42FqUlwx7ytmMLPSaUQPin5HKKKuUPD62MAbN4OEweGBBI7q1BekoEN4gPUEL6MZA==} + cpu: [x64, arm64, wasm32] os: [darwin, linux, win32] libsql@0.4.5: resolution: {integrity: sha512-sorTJV6PNt94Wap27Sai5gtVLIea4Otb2LUiAUyr3p6BPOScGMKGt5F1b5X/XgkNtcsDKeX5qfeBDj+PdShclQ==} + cpu: [x64, arm64, wasm32] os: [darwin, linux, win32] lighthouse-logger@1.4.2: @@ -20812,7 +20814,7 @@ snapshots: set-cookie-parser: 2.7.0 typescript: 5.6.2 - better-call@0.2.3-beta.1: + better-call@0.2.3-beta.2: dependencies: '@better-fetch/fetch': 1.1.8 '@types/set-cookie-parser': 2.4.10 @@ -20820,7 +20822,7 @@ snapshots: set-cookie-parser: 2.7.0 typescript: 5.6.2 - better-call@0.2.3-beta.2: + better-call@0.2.3-beta.3: dependencies: '@better-fetch/fetch': 1.1.8 '@types/set-cookie-parser': 2.4.10