diff --git a/demo/nextjs/app/dashboard/user-card.tsx b/demo/nextjs/app/dashboard/user-card.tsx
index 40346ffc..4e576887 100644
--- a/demo/nextjs/app/dashboard/user-card.tsx
+++ b/demo/nextjs/app/dashboard/user-card.tsx
@@ -315,9 +315,6 @@ function ChangePassword() {
{
-
- }}
>
Change Password
@@ -376,7 +373,7 @@ function ChangePassword() {
})
setLoading(false);
if (res.error) {
- toast.error(res.error.message);
+ toast.error(res.error.message || "Couldn't change your password! Make sure it's correct");
} else {
setOpen(false);
toast.success("Password changed successfully");
diff --git a/demo/nextjs/components/sign-in-btn.tsx b/demo/nextjs/components/sign-in-btn.tsx
index b7ca8898..192b9c1a 100644
--- a/demo/nextjs/components/sign-in-btn.tsx
+++ b/demo/nextjs/components/sign-in-btn.tsx
@@ -32,10 +32,15 @@ export async function SignInButton() {
}
+function checkOptimisticSession(headers: Headers) {
+ const guessIsSignIn = headers.get("cookie")?.includes("better-auth.session") || headers.get("cookie")?.includes("__Secure-better-auth.session-token")
+ return !!guessIsSignIn;
+}
+
export function SignInFallback() {
//to avoid flash of unauthenticated state
- const guessIsSignIn = headers().get("cookie")?.includes("better-auth.session") || headers().get("cookie")?.includes("__Secure-better-auth.session-token")
+ const guessIsSignIn = checkOptimisticSession(headers())
return (
();
+
+/**
+ * Generate a unique key for the request to cache the
+ * request for 5 seconds for this specific request.
+ *
+ * This is to prevent reaching to database if getSession is
+ * called multiple times for the same request
+ */
+function getRequestUniqueKey(ctx: Context, token: string): string {
+ if (!ctx.request) {
+ return "";
+ }
+ const { method, url, headers } = ctx.request;
+ const userAgent = ctx.request.headers.get("User-Agent") || "";
+ const ip = getIp(ctx.request) || "";
+ const headerString = JSON.stringify(headers);
+ const uniqueString = `${method}:${url}:${headerString}:${userAgent}:${ip}:${token}`;
+ return uniqueString;
+}
export const getSession = createAuthEndpoint(
"/session",
@@ -24,6 +55,16 @@ export const getSession = createAuthEndpoint(
status: 401,
});
}
+
+ const key = getRequestUniqueKey(ctx, sessionCookieToken);
+ const cachedSession = sessionCache.get(key);
+ if (cachedSession) {
+ if (cachedSession.expiresAt > Date.now()) {
+ return ctx.json(cachedSession.data);
+ }
+ sessionCache.delete(key);
+ }
+
const session =
await ctx.context.internalAdapter.findSession(sessionCookieToken);
@@ -89,6 +130,10 @@ export const getSession = createAuthEndpoint(
user: session.user,
});
}
+ sessionCache.set(key, {
+ data: session,
+ expiresAt: Date.now() + 5000,
+ });
return ctx.json(session);
} catch (error) {
ctx.context.logger.error(error);
diff --git a/packages/better-auth/src/client/client.test.ts b/packages/better-auth/src/client/client.test.ts
index 699a4202..d3cf8796 100644
--- a/packages/better-auth/src/client/client.test.ts
+++ b/packages/better-auth/src/client/client.test.ts
@@ -95,6 +95,11 @@ describe("type", () => {
const client = createReactClient({
plugins: [testClientPlugin()],
baseURL: "http://localhost:3000",
+ fetchOptions: {
+ customFetchImpl: async (url, init) => {
+ return new Response();
+ },
+ },
});
type ReturnedSession = ReturnType;
expectTypeOf().toMatchTypeOf<{
@@ -121,6 +126,11 @@ describe("type", () => {
const client = createReactClient({
plugins: [testClientPlugin()],
baseURL: "http://localhost:3000",
+ fetchOptions: {
+ customFetchImpl: async (url, init) => {
+ return new Response();
+ },
+ },
});
expectTypeOf(client.useComputedAtom).toEqualTypeOf<() => number>();
});
@@ -128,6 +138,11 @@ describe("type", () => {
const client = createSolidClient({
plugins: [testClientPlugin()],
baseURL: "http://localhost:3000",
+ fetchOptions: {
+ customFetchImpl: async (url, init) => {
+ return new Response();
+ },
+ },
});
expectTypeOf(client.useComputedAtom).toEqualTypeOf<
() => Accessor
@@ -137,6 +152,11 @@ describe("type", () => {
const client = createVueClient({
plugins: [testClientPlugin()],
baseURL: "http://localhost:3000",
+ fetchOptions: {
+ customFetchImpl: async (url, init) => {
+ return new Response();
+ },
+ },
});
expectTypeOf(client.useComputedAtom).toEqualTypeOf<
() => Readonly[>
@@ -146,6 +166,11 @@ describe("type", () => {
const client = createSvelteClient({
plugins: [testClientPlugin()],
baseURL: "http://localhost:3000",
+ fetchOptions: {
+ customFetchImpl: async (url, init) => {
+ return new Response();
+ },
+ },
});
expectTypeOf(client.useComputedAtom).toEqualTypeOf>();
});
@@ -154,6 +179,11 @@ describe("type", () => {
const client = createSolidClient({
plugins: [testClientPlugin(), testClientPlugin2()],
baseURL: "http://localhost:3000",
+ fetchOptions: {
+ customFetchImpl: async (url, init) => {
+ return new Response();
+ },
+ },
});
expectTypeOf(client.setTestAtom).toEqualTypeOf<(value: boolean) => void>();
expectTypeOf(client.test.signOut).toEqualTypeOf<() => Promise>();
@@ -163,9 +193,14 @@ describe("type", () => {
const client = createSolidClient({
plugins: [testClientPlugin(), testClientPlugin2(), twoFactorClient()],
baseURL: "http://localhost:3000",
+ fetchOptions: {
+ customFetchImpl: async (url, init) => {
+ return new Response();
+ },
+ },
});
- const $infer = client.$infer;
- expectTypeOf($infer.session).toEqualTypeOf<{
+ const $infer = client.$Infer;
+ expectTypeOf($infer.Session).toEqualTypeOf<{
session: {
id: string;
userId: string;
diff --git a/packages/better-auth/src/client/fetch-plugins.ts b/packages/better-auth/src/client/fetch-plugins.ts
index d1dcf5e7..ecde1910 100644
--- a/packages/better-auth/src/client/fetch-plugins.ts
+++ b/packages/better-auth/src/client/fetch-plugins.ts
@@ -30,6 +30,7 @@ export const addCurrentURL = {
},
} satisfies BetterFetchPlugin;
+const cache = new Map();
export const csrfPlugin = {
id: "csrf",
name: "CSRF Check",
@@ -42,27 +43,39 @@ export const csrfPlugin = {
if (options?.method !== "GET") {
options = options || {};
- const { data, error } = await betterFetch<{
- csrfToken: string;
- }>("/csrf", {
- body: undefined,
- baseURL: options.baseURL,
- plugins: [],
- method: "GET",
- credentials: "include",
- customFetchImpl: options.customFetchImpl,
- });
- if (error?.status === 404) {
- throw new BetterAuthError(
- "Route not found. Make sure the server is running and the base URL is correct and includes the path (e.g. http://localhost:3000/api/auth).",
- );
- }
- if (error) {
- throw new BetterAuthError(error.message || "Failed to get CSRF token.");
+ const csrfToken = cache.get("CSRF_TOKEN");
+ if (!csrfToken) {
+ const { data, error } = await betterFetch<{
+ csrfToken: string;
+ }>("/csrf", {
+ body: undefined,
+ baseURL: options.baseURL,
+ plugins: [],
+ method: "GET",
+ credentials: "include",
+ customFetchImpl: options.customFetchImpl,
+ });
+ if (error) {
+ if (error.status === 404) {
+ throw new BetterAuthError(
+ "CSRF route not found. Make sure the server is running and the base URL is correct and includes the path (e.g. http://localhost:3000/api/auth).",
+ );
+ }
+ if (error.status === 429) {
+ return new Response(null, {
+ status: 429,
+ statusText: "Too Many Requests",
+ });
+ }
+ throw new BetterAuthError(
+ "Failed to fetch CSRF token: " + error.message,
+ );
+ }
+ cache.set("CSRF_TOKEN", data.csrfToken);
}
options.body = {
...options?.body,
- csrfToken: data.csrfToken,
+ csrfToken: csrfToken,
};
}
options.credentials = "include";
diff --git a/packages/better-auth/src/plugins/passkey/client.ts b/packages/better-auth/src/plugins/passkey/client.ts
index e0313055..9b7a31c3 100644
--- a/packages/better-auth/src/plugins/passkey/client.ts
+++ b/packages/better-auth/src/plugins/passkey/client.ts
@@ -86,12 +86,11 @@ export const getPasskeyActions = (
const verified = await $fetch<{
passkey: Passkey;
}>("/passkey/verify-registration", {
+ ...opts?.options,
body: {
response: res,
-
name: opts?.name,
},
- ...opts?.options,
});
if (!verified.data) {
return verified;
diff --git a/packages/better-auth/src/plugins/passkey/index.ts b/packages/better-auth/src/plugins/passkey/index.ts
index 439a640d..02d3b772 100644
--- a/packages/better-auth/src/plugins/passkey/index.ts
+++ b/packages/better-auth/src/plugins/passkey/index.ts
@@ -74,7 +74,10 @@ export const passkey = (options?: PasskeyOptions) => {
const baseURL = process.env.BETTER_AUTH_URL;
const rpID =
options?.rpID ||
- baseURL?.replace("http://", "").replace("https://", "") ||
+ baseURL
+ ?.replace("http://", "")
+ .replace("https://", "")
+ .replace(":3000", "") ||
"localhost";
if (!rpID) {
throw new BetterAuthError(
diff --git a/packages/better-auth/src/plugins/rate-limiter/get-key.ts b/packages/better-auth/src/plugins/rate-limiter/get-key.ts
new file mode 100644
index 00000000..7e718e93
--- /dev/null
+++ b/packages/better-auth/src/plugins/rate-limiter/get-key.ts
@@ -0,0 +1,19 @@
+import { getSession } from "../../api/routes";
+import { BetterAuthError } from "../../error/better-auth-error";
+import { getIp } from "../../utils/get-request-ip";
+
+export async function getRateLimitKey(req: Request) {
+ if (req.headers.get("Authorization") || req.headers.get("cookie")) {
+ const session = await getSession({
+ headers: req.headers,
+ });
+ if (session) {
+ return session.user.id;
+ }
+ }
+ const ip = getIp(req);
+ if (!ip) {
+ throw new BetterAuthError("IP not found");
+ }
+ return ip;
+}
diff --git a/packages/better-auth/src/plugins/rate-limiter/index.ts b/packages/better-auth/src/plugins/rate-limiter/index.ts
new file mode 100644
index 00000000..802f4fc7
--- /dev/null
+++ b/packages/better-auth/src/plugins/rate-limiter/index.ts
@@ -0,0 +1,239 @@
+import { APIError } from "better-call";
+import { createAuthMiddleware } from "../../api/call";
+import type { GenericEndpointContext } from "../../types/context";
+import type { BetterAuthPlugin } from "../../types/plugins";
+import { getRateLimitKey } from "./get-key";
+
+interface RateLimit {
+ key: string;
+ count: number;
+ lastRequest: number;
+}
+
+export interface RateLimitOptions {
+ /**
+ * Enable rate limiting. You can also pass a function
+ * to enable rate limiting for specific endpoints.
+ *
+ * @default true
+ */
+ enabled: boolean | ((request: Request) => boolean | Promise);
+ /**
+ * The window to use for rate limiting. The value
+ * should be in seconds.
+ * @default 15 minutes (15 * 60)
+ */
+ window?: number;
+ /**
+ * The maximum number of requests allowed within the window.
+ * @default 100
+ */
+ max?: number;
+ /**
+ * Function to get the key to use for rate limiting.
+ * @default "ip" or "userId" if the user is logged in.
+ */
+ getKey?: (request: Request) => string | Promise;
+ storage?:
+ | "database"
+ | "memory"
+ | {
+ get: (key: string) => Promise;
+ set: (key: string, value: RateLimit) => Promise;
+ };
+ /**
+ * Custom rate limiting function.
+ */
+ customRateLimit?: (request: Request) => Promise;
+ /**
+ * Special rules to apply to specific paths.
+ *
+ * By default, endpoints that starts with "/sign-in" or "/sign-up" are added
+ * to the rate limiting mechanism with a count value of 2.
+ * @example
+ * ```ts
+ * specialRules: [
+ * {
+ * matcher: (request) => request.url.startsWith("/sign-in"),
+ * // This will half the amount of requests allowed for the sign-in endpoint
+ * countValue: 2,
+ * }
+ * ]
+ * ```
+ */
+ specialRules?: {
+ /**
+ * Custom matcher to determine if the special rule should be applied.
+ */
+ matcher: (path: string) => boolean;
+ /**
+ * The value to use for the count.
+ *
+ */
+ countValue: number;
+ }[];
+}
+
+/**
+ * Rate limiting plugin for BetterAuth. It implements a simple rate limiting
+ * mechanism to prevent abuse. It can be configured to use a database, memory
+ * storage or a custom storage. It can also be configured to use a custom rate
+ * limiting function.
+ *
+ * @example
+ * ```ts
+ * const plugin = rateLimiter({
+ * enabled: true,
+ * window: 60,
+ * max: 100,
+ * });
+ * ```
+ */
+export const rateLimiter = (options: RateLimitOptions) => {
+ const opts = {
+ storage: "database",
+ max: 100,
+ window: 15 * 60,
+ specialRules: [
+ {
+ matcher(path) {
+ return path.startsWith("/sign-in") || path.startsWith("/sign-up");
+ },
+ countValue: 2,
+ },
+ ],
+ ...options,
+ } satisfies RateLimitOptions;
+ const schema =
+ opts.storage === "database"
+ ? ({
+ rateLimit: {
+ fields: {
+ key: {
+ type: "string",
+ },
+ count: {
+ type: "number",
+ },
+ lastRequest: {
+ type: "number",
+ },
+ },
+ },
+ } as const)
+ : undefined;
+
+ function createDBStorage(ctx: GenericEndpointContext) {
+ const db = ctx.context.db;
+ return {
+ get: async (key: string) => {
+ const result = await db
+ .selectFrom("rateLimit")
+ .where("key", "=", key)
+ .selectAll()
+ .executeTakeFirst();
+ return result as RateLimit | undefined;
+ },
+ set: async (key: string, value: RateLimit, isNew: boolean = true) => {
+ if (isNew) {
+ await db
+ .insertInto("rateLimit")
+ .values({
+ key,
+ count: value.count,
+ lastRequest: value.lastRequest,
+ })
+ .execute();
+ } else {
+ await db
+ .updateTable("rateLimit")
+ .set({
+ count: value.count,
+ lastRequest: value.lastRequest,
+ })
+ .where("key", "=", key)
+ .execute();
+ }
+ },
+ };
+ }
+ const storage = new Map();
+ function createMemoryStorage() {
+ return {
+ get: async (key: string) => {
+ return storage.get(key);
+ },
+ set: async (key: string, value: RateLimit) => {
+ storage.set(key, value);
+ },
+ };
+ }
+
+ return {
+ id: "rate-limiter",
+ middlewares: [
+ {
+ path: "/**",
+ middleware: createAuthMiddleware(async (ctx) => {
+ if (!ctx.request) {
+ return;
+ }
+ if (opts.customRateLimit) {
+ const shouldLimit = await opts.customRateLimit(ctx.request);
+ if (!shouldLimit) {
+ throw new APIError("TOO_MANY_REQUESTS", {
+ message: "Too many requests",
+ });
+ }
+ return;
+ }
+ const key = await getRateLimitKey(ctx.request);
+ const storage =
+ opts.storage === "database"
+ ? createDBStorage(ctx)
+ : opts.storage === "memory"
+ ? createMemoryStorage()
+ : opts.storage;
+ const rateLimit = await storage.get(key);
+ if (!rateLimit) {
+ await storage.set(key, {
+ key,
+ count: 0,
+ lastRequest: new Date().getTime(),
+ });
+ return;
+ }
+ const now = new Date().getTime();
+ const windowStart = now - opts.window * 1000;
+ if (
+ rateLimit.lastRequest >= windowStart &&
+ rateLimit.count >= opts.max
+ ) {
+ throw new APIError("TOO_MANY_REQUESTS", {
+ message: "Too many requests",
+ });
+ }
+
+ if (rateLimit.lastRequest < windowStart) {
+ rateLimit.count = 0;
+ }
+ const count =
+ opts.specialRules.find((rule) => rule.matcher(ctx.path))
+ ?.countValue ?? 1;
+
+ await storage.set(
+ key,
+ {
+ key,
+ count: rateLimit.count + count,
+ lastRequest: now,
+ },
+ false,
+ );
+ return;
+ }),
+ },
+ ],
+ schema,
+ } satisfies BetterAuthPlugin;
+};
diff --git a/packages/better-auth/src/plugins/rate-limiter/rate-limiter.test.ts b/packages/better-auth/src/plugins/rate-limiter/rate-limiter.test.ts
new file mode 100644
index 00000000..25e77d4b
--- /dev/null
+++ b/packages/better-auth/src/plugins/rate-limiter/rate-limiter.test.ts
@@ -0,0 +1,53 @@
+import { describe, it, beforeAll, expect, vi } from "vitest";
+import { getTestInstance } from "../../test-utils/test-instance";
+import { rateLimiter } from ".";
+
+describe("rate-limiter", async () => {
+ const { client, testUser } = await getTestInstance({
+ plugins: [
+ rateLimiter({
+ enabled: true,
+ storage: "memory",
+ max: 10,
+ window: 10,
+ }),
+ ],
+ });
+
+ it.only("should allow requests within the limit", async () => {
+ for (let i = 0; i < 10; i++) {
+ const response = await client.signIn.email({
+ email: testUser.email,
+ password: testUser.password,
+ });
+ if (i === 9) {
+ expect(response.error?.status).toBe(429);
+ } else {
+ expect(response.error).toBeNull();
+ }
+ }
+ });
+
+ it.only("should reset the limit after the window period", async () => {
+ vi.useFakeTimers();
+
+ // Make 10 requests to hit the limit
+ for (let i = 0; i < 10; i++) {
+ const res = await client.signIn.email({
+ email: testUser.email,
+ password: testUser.password,
+ });
+ if (res.error?.status === 429) {
+ break;
+ }
+ }
+ // Advance the timer by 11 seconds (just over the 10-second window)
+ vi.advanceTimersByTime(11000);
+ const response = await client.signIn.email({
+ email: testUser.email,
+ password: testUser.password,
+ });
+ expect(response.error).toBeNull();
+ vi.useRealTimers();
+ });
+});
diff --git a/packages/better-auth/src/utils/get-request-ip.ts b/packages/better-auth/src/utils/get-request-ip.ts
new file mode 100644
index 00000000..9ef54b90
--- /dev/null
+++ b/packages/better-auth/src/utils/get-request-ip.ts
@@ -0,0 +1,26 @@
+export function getIp(req: Request): string | null {
+ const testIP = "127.0.0.1";
+ if (process.env.NODE_ENV === "test") {
+ return testIP;
+ }
+ const headers = [
+ "x-client-ip",
+ "x-forwarded-for",
+ "cf-connecting-ip",
+ "fastly-client-ip",
+ "x-real-ip",
+ "x-cluster-client-ip",
+ "x-forwarded",
+ "forwarded-for",
+ "forwarded",
+ ];
+
+ for (const header of headers) {
+ const value = req.headers.get(header);
+ if (typeof value === "string") {
+ const ip = value.split(",")[0].trim();
+ if (ip) return ip;
+ }
+ }
+ return null;
+}
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index 3da083dc..5b0c50df 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -896,8 +896,8 @@ importers:
specifier: ^0.31.2
version: 0.31.2
better-call:
- specifier: 0.2.3-beta.1
- version: 0.2.3-beta.1
+ specifier: 0.2.3-beta.3
+ version: 0.2.3-beta.3
better-sqlite3:
specifier: ^11.1.2
version: 11.1.2
@@ -2627,7 +2627,7 @@ packages:
'@expo/bunyan@4.0.1':
resolution: {integrity: sha512-+Lla7nYSiHZirgK+U/uYzsLv/X+HaJienbD5AKX1UQZHYfWaP+9uuQluRB4GrEVWF0GZ7vEVp/jzaOT9k/SQlg==}
- engines: {'0': node >=0.10.0}
+ engines: {node: '>=0.10.0'}
'@expo/cli@0.18.29':
resolution: {integrity: sha512-X810C48Ss+67RdZU39YEO1khNYo1RmjouRV+vVe0QhMoTe8R6OA3t+XYEdwaNbJ5p/DJN7szfHfNmX2glpC7xg==}
@@ -6216,12 +6216,12 @@ packages:
peerDependencies:
typescript: ^5.6.0-beta
- better-call@0.2.3-beta.1:
- resolution: {integrity: sha512-I4+sm4OIgbnUdbtI5anc5Hd9dSCRnuy3pb2WdA2guLHwzjCEHlg7bTBlP6usG50FIRBceDrxfCnFBo4I85Iazg==}
-
better-call@0.2.3-beta.2:
resolution: {integrity: sha512-ybOtGcR4pOsHI2XE+urR9zcmK+s0YnhJSx8KDj6ul7MUEyYOiMEnq/bylyH62/7qXuYb9q8Oqkp9NF9vWOZ4Mg==}
+ better-call@0.2.3-beta.3:
+ resolution: {integrity: sha512-1W0HB1aJ1adhbkVAi5gBQE/0uhdz+Y1Nrg3I0xZDFQx1R6vocBq7R+bVDrI5eNX+HmbzrXdY6O+Q/CfhhczJcg==}
+
better-opn@3.0.2:
resolution: {integrity: sha512-aVNobHnJqLiUelTaHat9DZ1qM2w0C0Eym4LPI/3JxOnSokGVdsl1T1kN7TFvsEAD8G47A6VKQ0TVHqbBnYMJlQ==}
engines: {node: '>=12.0.0'}
@@ -9133,10 +9133,12 @@ packages:
libsql@0.3.19:
resolution: {integrity: sha512-Aj5cQ5uk/6fHdmeW0TiXK42FqUlwx7ytmMLPSaUQPin5HKKKuUPD62MAbN4OEweGBBI7q1BekoEN4gPUEL6MZA==}
+ cpu: [x64, arm64, wasm32]
os: [darwin, linux, win32]
libsql@0.4.5:
resolution: {integrity: sha512-sorTJV6PNt94Wap27Sai5gtVLIea4Otb2LUiAUyr3p6BPOScGMKGt5F1b5X/XgkNtcsDKeX5qfeBDj+PdShclQ==}
+ cpu: [x64, arm64, wasm32]
os: [darwin, linux, win32]
lighthouse-logger@1.4.2:
@@ -20812,7 +20814,7 @@ snapshots:
set-cookie-parser: 2.7.0
typescript: 5.6.2
- better-call@0.2.3-beta.1:
+ better-call@0.2.3-beta.2:
dependencies:
'@better-fetch/fetch': 1.1.8
'@types/set-cookie-parser': 2.4.10
@@ -20820,7 +20822,7 @@ snapshots:
set-cookie-parser: 2.7.0
typescript: 5.6.2
- better-call@0.2.3-beta.2:
+ better-call@0.2.3-beta.3:
dependencies:
'@better-fetch/fetch': 1.1.8
'@types/set-cookie-parser': 2.4.10
]