refactor: use origin headers check instead of tokens for csrf protection (#356)

This commit is contained in:
Bereket Engida
2024-10-28 10:31:08 +03:00
committed by GitHub
parent e72f4597a0
commit c4aa713ccc
36 changed files with 162 additions and 298 deletions

View File

@@ -1,5 +1,14 @@
# Security Policy # Security Policy
## CSRF Protection
Better Auth protects against CSRF by enforcing strict origin checks and setting cookies with the `SameSite` attribute. As a best practice, any GET request should be designed to avoid modifying resources. If a GET request does alter data, such as in an OAuth callback, additional safeguards (e.g., state parameter verification) must be implemented. Any request containing cookies but missing an `Origin` or `Referer` header is rejected. Requests with these headers that dont match `trustedOrigins` are also discarded.
## Open Redirect Protection
Any endpoint added to a Better Auth instance, whether from a plugin or the core, should only use `callbackURL`, `currentURL`, or `redirectTo` for redirecting users post-action. These values are validated against `trustedOrigins` for security. Additionally, no endpoint handling GET requests should modify resources unless it has its own protection mechanisms in place.
## Reporting a Vulnerability ## Reporting a Vulnerability
If you discover a security vulnerability within Better Auth, please send an e-mail to security@better-auth.com. If you discover a security vulnerability within Better Auth, please send an e-mail to security@better-auth.com.

View File

@@ -3,9 +3,9 @@ title: Cookies
description: Learn how cookies are used in BetterAuth description: Learn how cookies are used in BetterAuth
--- ---
Cookies are used to store data such as session tokens, CSRF tokens, and more. All cookies are signed using the `secret` key provided in the auth options. Cookies are used to store data such as session tokens, OAuth state, and more. All cookies are signed using the `secret` key provided in the auth options.
Core Better Auth cookies like `session` and `csrf` will follow `betterauth.${cookie_name}` format. Core Better Auth cookies will follow `betterauth.${cookie_name}` format.
All cookies are `httpOnly` and `secure` if the server is running in production mode. All cookies are `httpOnly` and `secure` if the server is running in production mode.
@@ -34,20 +34,6 @@ export const auth = betterAuth({
This feature is experimental and may not work as expected in all scenarios. And this is specefically to share session cookies across subdomains. This feature is experimental and may not work as expected in all scenarios. And this is specefically to share session cookies across subdomains.
</Callout> </Callout>
### Disable CSRF Cookie (⚠︎ Not Recommended)
If you want to disable the CSRF cookie, you can set `disableCsrfCheck` to `true` in the `advanced` object in the auth options. If you disable the CSRF cookie, you should make sure that your framework handles CSRF protection itself.
```ts title="auth.ts"
import { betterAuth } from "better-auth"
export const auth = betterAuth({
advanced: {
disableCSRFCheck: true
}
})
```
### Secure Cookies ### Secure Cookies
By default, cookies are secure if the server is running in production mode. You can force cookies to be secure by setting `useSecureCookies` to `true` in the `advanced` object in the auth options. By default, cookies are secure if the server is running in production mode. You can force cookies to be secure by setting `useSecureCookies` to `true` in the `advanced` object in the auth options.
@@ -61,23 +47,3 @@ export const auth = betterAuth({
} }
}) })
``` ```
## CSRF Protection
**Cross-Site Request Forgery (CSRF) Protection in Better Auth**
Better Auth protects your app from CSRF attacks in two ways:
1. **Secure Cookies**: All cookies are marked as `HttpOnly`, `Secure`, and use the `SameSite=Lax` attribute. This ensures theyre inaccessible to client-side scripts, only sent over HTTPS, and not shared across sites.
2. **CSRF Tokens**: By default, CSRF token checks are disabled for the same origin as `baseURL`, since CSRF attacks only affect browser requests. For other origins, CSRF tokens are required for `POST` requests. It uses double submit cookies to validate the token. Each session has a unique CSRF token that is sent as a cookie and a header in every request. If the two dont match, the request is rejected.
You can adjust this behavior:
- Use `disableCSRFTokenCheck: true` on the client to skip token checks entirely.
- To allow untrusted origins, specify them in the `trustedOrigins` option on the server. These origins will be exempt from CSRF checks.
Untrusted requests without valid tokens will result in a `403` error.
<Callout type="warn">
You can also disable CSRF token check for all clients by setting `advanced.disableCSRFCheck` option on the server. You should only do this if your framework handles CSRF protection itself.
</Callout>

View File

@@ -237,7 +237,16 @@ const auth = betterAuth({
### `trustedOrignins` ### `trustedOrignins`
list of trusted origins. This will disable CSRF token check for the provided origins. By default, CSRF token check is disabled for origins that are same as `baseURL`. list of trusted origins. This is very important to prevent CSRF attacks and open redirects.
```ts title="auth.ts"
const auth = betterAuth({
trustedOrigins: [
'https://example.com',
'https://app.example.com'
]
})
```
### `advanced` ### `advanced`
@@ -252,7 +261,7 @@ list of trusted origins. This will disable CSRF token check for the provided ori
default: `true (if base URL is 'https')` default: `true (if base URL is 'https')`
}, },
disableCSRFCheck: { disableCSRFCheck: {
description: "Disable CSRF check.", description: "Disable csrf protection checks. ⚠︎ only use this if you know what you are doing.",
type: 'boolean', type: 'boolean',
default: false default: false
} }

View File

@@ -13,7 +13,7 @@
"lint": "biome check .", "lint": "biome check .",
"lint:fix": "biome check . --apply", "lint:fix": "biome check . --apply",
"release": "turbo --filter \"./packages/*\" build && bumpp && pnpm -r publish --access public --no-git-checks", "release": "turbo --filter \"./packages/*\" build && bumpp && pnpm -r publish --access public --no-git-checks",
"release:no-build": "bumpp && pnpm -r publish --access public --no-git-checks", "release:no-build": "bumpp && pnpm -r publish --access public --no-git-checks --tag next",
"release:beta": "turbo --filter \"./packages/*\" build && bumpp && pnpm -r publish --access public --tag next --no-git-checks", "release:beta": "turbo --filter \"./packages/*\" build && bumpp && pnpm -r publish --access public --tag next --no-git-checks",
"test": "turbo --filter \"./packages/*\" test", "test": "turbo --filter \"./packages/*\" test",
"typecheck": "turbo --filter \"./packages/*\" typecheck" "typecheck": "turbo --filter \"./packages/*\" typecheck"

View File

@@ -1,6 +1,6 @@
{ {
"name": "better-auth", "name": "better-auth",
"version": "0.6.1", "version": "0.6.2-beta.7",
"description": "The most comprehensive authentication library for TypeScript.", "description": "The most comprehensive authentication library for TypeScript.",
"type": "module", "type": "module",
"repository": { "repository": {

View File

@@ -13,16 +13,6 @@ exports[`init > should match config 1`] = `
}, },
"appName": "Better Auth", "appName": "Better Auth",
"authCookies": { "authCookies": {
"csrfToken": {
"name": "better-auth.csrf_token",
"options": {
"httpOnly": true,
"maxAge": 604800,
"path": "/",
"sameSite": "lax",
"secure": false,
},
},
"dontRememberToken": { "dontRememberToken": {
"name": "better-auth.dont_remember", "name": "better-auth.dont_remember",
"options": { "options": {

View File

@@ -2,7 +2,7 @@ import { APIError, type Endpoint, createRouter, statusCode } from "better-call";
import type { AuthContext } from "../init"; import type { AuthContext } from "../init";
import type { BetterAuthOptions } from "../types"; import type { BetterAuthOptions } from "../types";
import type { UnionToIntersection } from "../types/helper"; import type { UnionToIntersection } from "../types/helper";
import { csrfMiddleware } from "./middlewares/csrf"; import { originCheckMiddleware } from "./middlewares/origin-check";
import { import {
callbackOAuth, callbackOAuth,
forgetPassword, forgetPassword,
@@ -25,7 +25,6 @@ import {
setPassword, setPassword,
updateUser, updateUser,
} from "./routes"; } from "./routes";
import { getCSRFToken } from "./routes/csrf";
import { ok } from "./routes/ok"; import { ok } from "./routes/ok";
import { signUpEmail } from "./routes/sign-up"; import { signUpEmail } from "./routes/sign-up";
import { error } from "./routes/error"; import { error } from "./routes/error";
@@ -87,7 +86,6 @@ export function getEndpoints<
const baseEndpoints = { const baseEndpoints = {
signInOAuth, signInOAuth,
callbackOAuth, callbackOAuth,
getCSRFToken,
getSession: getSession<Option>(), getSession: getSession<Option>(),
signOut, signOut,
signUpEmail: signUpEmail<Option>(), signUpEmail: signUpEmail<Option>(),
@@ -239,7 +237,7 @@ export const router = <C extends AuthContext, Option extends BetterAuthOptions>(
routerMiddleware: [ routerMiddleware: [
{ {
path: "/**", path: "/**",
middleware: csrfMiddleware, middleware: originCheckMiddleware,
}, },
...middlewares, ...middlewares,
], ],

View File

@@ -1,61 +0,0 @@
import { APIError } from "better-call";
import { z } from "zod";
import { hs256 } from "../../crypto";
import { createAuthMiddleware } from "../call";
import { deleteSessionCookie } from "../../cookies";
export const csrfMiddleware = createAuthMiddleware(
{
body: z
.object({
csrfToken: z.string().optional(),
})
.optional(),
},
async (ctx) => {
if (
ctx.request?.method !== "POST" ||
ctx.context.options.advanced?.disableCSRFCheck
) {
return;
}
const originHeader = ctx.headers?.get("origin") || "";
/**
* If origin is the same as baseURL or if the
* origin is in the trustedOrigins then we
* don't need to check the CSRF token.
*/
if (originHeader) {
const origin = new URL(originHeader).origin;
if (ctx.context.trustedOrigins.includes(origin)) {
return;
}
}
const csrfToken = ctx.body?.csrfToken;
if (!csrfToken) {
throw new APIError("UNAUTHORIZED", {
message: "CSRF Token is required",
});
}
const csrfCookie = await ctx.getSignedCookie(
ctx.context.authCookies.csrfToken.name,
ctx.context.secret,
);
const [token, hash] = csrfCookie?.split("!") || [null, null];
if (!csrfToken || !token || !hash || token !== csrfToken) {
throw new APIError("UNAUTHORIZED", {
message: "Invalid CSRF Token",
});
}
const expectedHash = await hs256(ctx.context.secret, token);
if (hash !== expectedHash) {
ctx.setCookie(ctx.context.authCookies.csrfToken.name, "", {
maxAge: 0,
});
throw new APIError("UNAUTHORIZED", {
message: "Invalid CSRF Token",
});
}
},
);

View File

@@ -1 +1 @@
export * from "./csrf"; export * from "./origin-check";

View File

@@ -2,13 +2,16 @@ import { describe, expect } from "vitest";
import { getTestInstance } from "../../test-utils/test-instance"; import { getTestInstance } from "../../test-utils/test-instance";
import { createAuthClient } from "../../client"; import { createAuthClient } from "../../client";
describe("redirectURLMiddleware", async (it) => { describe("Origin Check", async (it) => {
const { customFetchImpl, testUser } = await getTestInstance({ const { customFetchImpl, testUser } = await getTestInstance({
trustedOrigins: ["http://localhost:5000", "https://trusted.com"], trustedOrigins: ["http://localhost:5000", "https://trusted.com"],
emailAndPassword: { emailAndPassword: {
enabled: true, enabled: true,
async sendResetPassword(url, user) {}, async sendResetPassword(url, user) {},
}, },
advanced: {
disableCSRFCheck: false,
},
}); });
it("should not allow untrusted origins", async (ctx) => { it("should not allow untrusted origins", async (ctx) => {
@@ -32,6 +35,9 @@ describe("redirectURLMiddleware", async (it) => {
baseURL: "http://localhost:3000", baseURL: "http://localhost:3000",
fetchOptions: { fetchOptions: {
customFetchImpl, customFetchImpl,
headers: {
origin: "http://localhost:3000",
},
}, },
}); });
const res = await client.signIn.email({ const res = await client.signIn.email({
@@ -42,6 +48,59 @@ describe("redirectURLMiddleware", async (it) => {
expect(res.data?.session).toBeDefined(); expect(res.data?.session).toBeDefined();
}); });
it("shouldn't allow untrusted origin headers", async (ctx) => {
const client = createAuthClient({
baseURL: "http://localhost:3000",
fetchOptions: {
customFetchImpl,
headers: {
origin: "malicious.com",
cookie: "session=123",
},
},
});
const res = await client.signIn.email({
email: testUser.email,
password: testUser.password,
});
expect(res.error?.status).toBe(403);
});
it("shouldn't allow untrusted origin subdomains", async (ctx) => {
const client = createAuthClient({
baseURL: "http://localhost:3000",
fetchOptions: {
customFetchImpl,
headers: {
origin: "http://sub-domain.trusted.com",
cookie: "session=123",
},
},
});
const res = await client.signIn.email({
email: testUser.email,
password: testUser.password,
});
expect(res.error?.status).toBe(403);
});
it("should allow untrusted origin if they don't contain cookies", async (ctx) => {
const client = createAuthClient({
baseURL: "http://localhost:3000",
fetchOptions: {
customFetchImpl,
headers: {
origin: "http://sub-domain.trusted.com",
},
},
});
const res = await client.signIn.email({
email: testUser.email,
password: testUser.password,
});
expect(res.data?.session).toBeDefined();
});
it("shouldn't allow untrusted currentURL", async (ctx) => { it("shouldn't allow untrusted currentURL", async (ctx) => {
const client = createAuthClient({ const client = createAuthClient({
baseURL: "http://localhost:3000", baseURL: "http://localhost:3000",
@@ -76,7 +135,7 @@ describe("redirectURLMiddleware", async (it) => {
redirectTo: "http://malicious.com", redirectTo: "http://malicious.com",
}); });
expect(res.error?.status).toBe(403); expect(res.error?.status).toBe(403);
expect(res.error?.message).toBe("Invalid callbackURL"); expect(res.error?.message).toBe("Invalid redirectURL");
}); });
it("should work with list of trusted origins ", async (ctx) => { it("should work with list of trusted origins ", async (ctx) => {
@@ -84,6 +143,9 @@ describe("redirectURLMiddleware", async (it) => {
baseURL: "http://localhost:3000", baseURL: "http://localhost:3000",
fetchOptions: { fetchOptions: {
customFetchImpl, customFetchImpl,
headers: {
origin: "https://trusted.com",
},
}, },
}); });
const res = await client.forgetPassword({ const res = await client.forgetPassword({

View File

@@ -0,0 +1,41 @@
import { APIError } from "better-call";
import { createAuthMiddleware } from "../call";
import { logger } from "../../utils";
/**
* A middleware to validate callbackURL, redirectURL, currentURL and origin against trustedOrigins.
*/
export const originCheckMiddleware = createAuthMiddleware(async (ctx) => {
if (ctx.request?.method !== "POST") {
return;
}
const { body, query, context } = ctx;
const originHeader =
ctx.headers?.get("origin") || ctx.headers?.get("referer") || "";
const callbackURL = body?.callbackURL;
const redirectURL = body?.redirectTo;
const currentURL = query?.currentURL;
const trustedOrigins = context.trustedOrigins;
const usesCookies = ctx.headers?.has("cookie");
const validateURL = (url: string | undefined, label: string) => {
const isTrustedOrigin = trustedOrigins.some(
(origin) =>
url?.startsWith(origin) || (url?.startsWith("/") && label !== "origin"),
);
if (!isTrustedOrigin) {
logger.error(`Invalid ${label}: ${url}`);
logger.info(
`If it's a valid URL, please add ${url} to trustedOrigins in your auth config\n`,
`Current list of trustedOrigins: ${trustedOrigins}`,
);
throw new APIError("FORBIDDEN", { message: `Invalid ${label}` });
}
};
if (usesCookies && !ctx.context.options.advanced?.disableCSRFCheck) {
validateURL(originHeader, "origin");
}
callbackURL && validateURL(callbackURL, "callbackURL");
redirectURL && validateURL(redirectURL, "redirectURL");
currentURL && validateURL(currentURL, "currentURL");
});

View File

@@ -1,38 +0,0 @@
import { APIError } from "better-call";
import { createAuthMiddleware } from "../call";
import { logger } from "../../utils/logger";
/**
* Middleware to validate callbackURL and currentURL against trustedOrigins,
* preventing open redirect attacks.
*/
export const redirectURLMiddleware = createAuthMiddleware(async (ctx) => {
const { body, query, context } = ctx;
const callbackURL =
body?.callbackURL ||
query?.callbackURL ||
query?.redirectTo ||
body?.redirectTo;
const currentURL = query?.currentURL;
const trustedOrigins = context.trustedOrigins;
const validateURL = (url: string | undefined, label: string) => {
if (url?.startsWith("http")) {
const isTrustedOrigin = trustedOrigins.some((origin) =>
url.startsWith(origin),
);
if (!isTrustedOrigin) {
logger.error(`Invalid ${label}: ${url}`);
logger.info(
`If it's a valid URL, please add ${url} to trustedOrigins in your auth config\n`,
`Current list of trustedOrigins: ${trustedOrigins}`,
);
throw new APIError("FORBIDDEN", { message: `Invalid ${label}` });
}
}
};
validateURL(callbackURL, "callbackURL");
validateURL(currentURL, "currentURL");
});

View File

@@ -1,7 +1,6 @@
import { z } from "zod"; import { z } from "zod";
import { createAuthEndpoint } from "../call"; import { createAuthEndpoint } from "../call";
import { socialProviderList } from "../../social-providers"; import { socialProviderList } from "../../social-providers";
import { redirectURLMiddleware } from "../middlewares/redirect";
import { APIError } from "better-call"; import { APIError } from "better-call";
import { generateState, parseState, type OAuth2Tokens } from "../../oauth2"; import { generateState, parseState, type OAuth2Tokens } from "../../oauth2";
import { generateCodeVerifier } from "oslo/oauth2"; import { generateCodeVerifier } from "oslo/oauth2";
@@ -49,7 +48,7 @@ export const linkSocialAccount = createAuthEndpoint(
*/ */
provider: z.enum(socialProviderList), provider: z.enum(socialProviderList),
}), }),
use: [redirectURLMiddleware, sessionMiddleware], use: [sessionMiddleware],
}, },
async (c) => { async (c) => {
const session = c.context.session; const session = c.context.session;

View File

@@ -1,38 +0,0 @@
import { alphabet, generateRandomString } from "../../crypto/random";
import { hs256 } from "../../crypto";
import { createAuthEndpoint } from "../call";
import { HIDE_METADATA } from "../../utils/hide-metadata";
export const getCSRFToken = createAuthEndpoint(
"/csrf",
{
method: "GET",
metadata: HIDE_METADATA,
},
async (ctx) => {
const csrfCookie = await ctx.getSignedCookie(
ctx.context.authCookies.csrfToken.name,
ctx.context.secret,
);
if (csrfCookie) {
const [token, _] = csrfCookie.split("!") || [null, null];
return ctx.json({
csrfToken: token,
});
}
const token = generateRandomString(32, alphabet("a-z", "0-9", "A-Z"));
const hash = await hs256(ctx.context.secret, token);
const cookie = `${token}!${hash}`;
await ctx.setSignedCookie(
ctx.context.authCookies.csrfToken.name,
cookie,
ctx.context.secret,
ctx.context.authCookies.csrfToken.options,
);
return ctx.json({
csrfToken: token,
});
},
);

View File

@@ -3,7 +3,6 @@ import { createJWT, validateJWT, type JWT } from "oslo/jwt";
import { z } from "zod"; import { z } from "zod";
import { createAuthEndpoint } from "../call"; import { createAuthEndpoint } from "../call";
import { APIError } from "better-call"; import { APIError } from "better-call";
import { redirectURLMiddleware } from "../middlewares/redirect";
import { getSessionFromCtx } from "./session"; import { getSessionFromCtx } from "./session";
export async function createEmailVerificationToken( export async function createEmailVerificationToken(
@@ -45,7 +44,6 @@ export const sendVerificationEmail = createAuthEndpoint(
email: z.string().email(), email: z.string().email(),
callbackURL: z.string().optional(), callbackURL: z.string().optional(),
}), }),
use: [redirectURLMiddleware],
}, },
async (ctx) => { async (ctx) => {
if (!ctx.context.options.emailVerification?.sendVerificationEmail) { if (!ctx.context.options.emailVerification?.sendVerificationEmail) {
@@ -86,7 +84,6 @@ export const verifyEmail = createAuthEndpoint(
token: z.string(), token: z.string(),
callbackURL: z.string().optional(), callbackURL: z.string().optional(),
}), }),
use: [redirectURLMiddleware],
}, },
async (ctx) => { async (ctx) => {
const { token } = ctx.query; const { token } = ctx.query;

View File

@@ -1,7 +1,6 @@
import { z } from "zod"; import { z } from "zod";
import { createAuthEndpoint } from "../call"; import { createAuthEndpoint } from "../call";
import { APIError } from "better-call"; import { APIError } from "better-call";
import { redirectURLMiddleware } from "../middlewares/redirect";
export const forgetPassword = createAuthEndpoint( export const forgetPassword = createAuthEndpoint(
"/forget-password", "/forget-password",
@@ -20,7 +19,6 @@ export const forgetPassword = createAuthEndpoint(
*/ */
redirectTo: z.string(), redirectTo: z.string(),
}), }),
use: [redirectURLMiddleware],
}, },
async (ctx) => { async (ctx) => {
if (!ctx.context.options.emailAndPassword?.sendResetPassword) { if (!ctx.context.options.emailAndPassword?.sendResetPassword) {
@@ -82,7 +80,6 @@ export const forgetPasswordCallback = createAuthEndpoint(
query: z.object({ query: z.object({
callbackURL: z.string(), callbackURL: z.string(),
}), }),
use: [redirectURLMiddleware],
}, },
async (ctx) => { async (ctx) => {
const { token } = ctx.params; const { token } = ctx.params;

View File

@@ -5,7 +5,6 @@ export * from "./sign-out";
export * from "./forget-password"; export * from "./forget-password";
export * from "./email-verification"; export * from "./email-verification";
export * from "./update-user"; export * from "./update-user";
export * from "./csrf";
export * from "./error"; export * from "./error";
export * from "./ok"; export * from "./ok";
export * from "./sign-up"; export * from "./sign-up";

View File

@@ -324,12 +324,13 @@ describe("session storage", async () => {
}, },
}); });
expect(session.data).not.toBeNull(); expect(session.data).not.toBeNull();
await client.user.revokeSession({ const res = await client.user.revokeSession({
fetchOptions: { fetchOptions: {
headers, headers,
}, },
id: session.data?.session?.id || "", id: session.data?.session?.id || "",
}); });
console.log(res);
const revokedSession = await client.getSession({ const revokedSession = await client.getSession({
fetchOptions: { fetchOptions: {
headers, headers,

View File

@@ -4,7 +4,6 @@ import { z } from "zod";
import { generateState } from "../../oauth2/state"; import { generateState } from "../../oauth2/state";
import { createAuthEndpoint } from "../call"; import { createAuthEndpoint } from "../call";
import { setSessionCookie } from "../../cookies"; import { setSessionCookie } from "../../cookies";
import { redirectURLMiddleware } from "../middlewares/redirect";
import { socialProviderList } from "../../social-providers"; import { socialProviderList } from "../../social-providers";
import { createEmailVerificationToken } from "./email-verification"; import { createEmailVerificationToken } from "./email-verification";
import { logger } from "../../utils"; import { logger } from "../../utils";
@@ -33,7 +32,6 @@ export const signInOAuth = createAuthEndpoint(
*/ */
provider: z.enum(socialProviderList), provider: z.enum(socialProviderList),
}), }),
use: [redirectURLMiddleware],
}, },
async (c) => { async (c) => {
const provider = c.context.socialProviders.find( const provider = c.context.socialProviders.find(
@@ -103,7 +101,6 @@ export const signInEmail = createAuthEndpoint(
*/ */
dontRememberMe: z.boolean().default(false).optional(), dontRememberMe: z.boolean().default(false).optional(),
}), }),
use: [redirectURLMiddleware],
}, },
async (ctx) => { async (ctx) => {
if (!ctx.context.options?.emailAndPassword?.enabled) { if (!ctx.context.options?.emailAndPassword?.enabled) {

View File

@@ -2,7 +2,6 @@ import { z } from "zod";
import { createAuthEndpoint } from "../call"; import { createAuthEndpoint } from "../call";
import { deleteSessionCookie } from "../../cookies"; import { deleteSessionCookie } from "../../cookies";
import { APIError } from "better-call"; import { APIError } from "better-call";
import { redirectURLMiddleware } from "../middlewares/redirect";
export const signOut = createAuthEndpoint( export const signOut = createAuthEndpoint(
"/sign-out", "/sign-out",

View File

@@ -13,7 +13,6 @@ import type {
import type { toZod } from "../../types/to-zod"; import type { toZod } from "../../types/to-zod";
import { parseUserInput } from "../../db/schema"; import { parseUserInput } from "../../db/schema";
import { getDate } from "../../utils/date"; import { getDate } from "../../utils/date";
import { redirectURLMiddleware } from "../middlewares/redirect";
import { logger } from "../../utils"; import { logger } from "../../utils";
export const signUpEmail = <O extends BetterAuthOptions>() => export const signUpEmail = <O extends BetterAuthOptions>() =>
@@ -33,7 +32,6 @@ export const signUpEmail = <O extends BetterAuthOptions>() =>
callbackURL: ZodOptional<ZodString>; callbackURL: ZodOptional<ZodString>;
}> & }> &
toZod<AdditionalUserFieldsInput<O>>, toZod<AdditionalUserFieldsInput<O>>,
use: [redirectURLMiddleware],
}, },
async (ctx) => { async (ctx) => {
if (!ctx.context.options.emailAndPassword?.enabled) { if (!ctx.context.options.emailAndPassword?.enabled) {

View File

@@ -4,7 +4,6 @@ import { alphabet, generateRandomString } from "../../crypto/random";
import { deleteSessionCookie, setSessionCookie } from "../../cookies"; import { deleteSessionCookie, setSessionCookie } from "../../cookies";
import { sessionMiddleware } from "./session"; import { sessionMiddleware } from "./session";
import { APIError } from "better-call"; import { APIError } from "better-call";
import { redirectURLMiddleware } from "../middlewares/redirect";
import { createEmailVerificationToken } from "./email-verification"; import { createEmailVerificationToken } from "./email-verification";
import type { toZod } from "../../types/to-zod"; import type { toZod } from "../../types/to-zod";
import type { AdditionalUserFieldsInput, BetterAuthOptions } from "../../types"; import type { AdditionalUserFieldsInput, BetterAuthOptions } from "../../types";
@@ -20,7 +19,7 @@ export const updateUser = <O extends BetterAuthOptions>() =>
image: ZodOptional<ZodString>; image: ZodOptional<ZodString>;
}> & }> &
toZod<AdditionalUserFieldsInput<O>>, toZod<AdditionalUserFieldsInput<O>>,
use: [sessionMiddleware, redirectURLMiddleware], use: [sessionMiddleware],
}, },
async (ctx) => { async (ctx) => {
const body = ctx.body as { const body = ctx.body as {
@@ -263,7 +262,7 @@ export const changeEmail = createAuthEndpoint(
newEmail: z.string().email(), newEmail: z.string().email(),
callbackURL: z.string().optional(), callbackURL: z.string().optional(),
}), }),
use: [sessionMiddleware, redirectURLMiddleware], use: [sessionMiddleware],
}, },
async (ctx) => { async (ctx) => {
if (!ctx.context.options.user?.changeEmail?.enabled) { if (!ctx.context.options.user?.changeEmail?.enabled) {

View File

@@ -2,18 +2,23 @@ import { createFetch } from "@better-fetch/fetch";
import { getBaseURL } from "../utils/url"; import { getBaseURL } from "../utils/url";
import { type Atom } from "nanostores"; import { type Atom } from "nanostores";
import type { AtomListener, ClientOptions } from "./types"; import type { AtomListener, ClientOptions } from "./types";
import { addCurrentURL, csrfPlugin, redirectPlugin } from "./fetch-plugins"; import { addCurrentURL, redirectPlugin } from "./fetch-plugins";
export const getClientConfig = <O extends ClientOptions>(options?: O) => { export const getClientConfig = <O extends ClientOptions>(options?: O) => {
/* check if the credentials property is supported. Useful for cf workers */
const isCredentialsSupported = "credentials" in Request.prototype;
const baseURL = getBaseURL(
options?.fetchOptions?.baseURL || options?.baseURL,
);
const $fetch = createFetch({ const $fetch = createFetch({
baseURL: getBaseURL(options?.fetchOptions?.baseURL || options?.baseURL), baseURL,
credentials: "include", ...(isCredentialsSupported ? { credentials: "include" } : {}),
method: "GET", method: "GET",
...options?.fetchOptions, ...options?.fetchOptions,
plugins: options?.disableDefaultFetchPlugins plugins: options?.disableDefaultFetchPlugins
? options.fetchOptions?.plugins ? options.fetchOptions?.plugins
: [ : [
...(!options?.disableCSRFTokenCheck ? [csrfPlugin] : []),
redirectPlugin, redirectPlugin,
addCurrentURL, addCurrentURL,
...(options?.fetchOptions?.plugins?.filter( ...(options?.fetchOptions?.plugins?.filter(

View File

@@ -1,5 +1,4 @@
import { type BetterFetchPlugin, betterFetch } from "@better-fetch/fetch"; import { type BetterFetchPlugin } from "@better-fetch/fetch";
import { BetterAuthError } from "../error";
export const redirectPlugin = { export const redirectPlugin = {
id: "redirect", id: "redirect",
@@ -29,52 +28,3 @@ export const addCurrentURL = {
}, },
}, },
} satisfies BetterFetchPlugin; } satisfies BetterFetchPlugin;
export const csrfPlugin = {
id: "csrf",
name: "CSRF Check",
async init(url, options) {
if (options?.method !== "GET") {
options = options || {};
const { data, error } = await betterFetch<{
csrfToken: string;
}>("/csrf", {
body: undefined,
baseURL: options.baseURL,
plugins: [],
method: "GET",
credentials: "include",
customFetchImpl: options.customFetchImpl,
});
if (error) {
if (error.status === 404) {
throw new BetterAuthError(
"CSRF route not found. Make sure the server is running and the base URL is correct and includes the path (e.g. http://localhost:3000/api/auth).",
);
}
if (error.status === 429) {
return new Response(
JSON.stringify({
message: "Too many requests. Please try again later.",
}),
{
status: 429,
statusText: "Too Many Requests",
},
);
}
throw new BetterAuthError(
"Failed to fetch CSRF token: " + error.message,
);
}
const csrfToken = data?.csrfToken;
options.body = {
...options?.body,
csrfToken: csrfToken,
};
}
options.credentials = "include";
return { url, options };
},
} satisfies BetterFetchPlugin;

View File

@@ -68,11 +68,12 @@ export function createDynamicPathProxy<T extends Record<string, any>>(
const options = { const options = {
...fetchOptions, ...fetchOptions,
...argFetchOptions, ...argFetchOptions,
}; } as BetterFetchOption;
const method = getMethod(routePath, knownPathMethods, arg); const method = getMethod(routePath, knownPathMethods, arg);
return await client(routePath, { return await client(routePath, {
...options, ...options,
body: body:
method === "GET" method === "GET"
? undefined ? undefined

View File

@@ -57,7 +57,6 @@ export interface ClientOptions {
fetchOptions?: BetterFetchOption; fetchOptions?: BetterFetchOption;
plugins?: BetterAuthClientPlugin[]; plugins?: BetterAuthClientPlugin[];
baseURL?: string; baseURL?: string;
disableCSRFTokenCheck?: boolean;
disableDefaultFetchPlugins?: boolean; disableDefaultFetchPlugins?: boolean;
} }

View File

@@ -41,7 +41,7 @@ describe("cookies", async () => {
const { client, testUser } = await getTestInstance({ const { client, testUser } = await getTestInstance({
advanced: { useSecureCookies: true }, advanced: { useSecureCookies: true },
}); });
await client.signIn.email( const res = await client.signIn.email(
{ {
email: testUser.email, email: testUser.email,
password: testUser.password, password: testUser.password,
@@ -49,10 +49,12 @@ describe("cookies", async () => {
{ {
onResponse(context) { onResponse(context) {
const setCookie = context.response.headers.get("set-cookie"); const setCookie = context.response.headers.get("set-cookie");
console.log(setCookie, context);
expect(setCookie).toContain("Secure"); expect(setCookie).toContain("Secure");
}, },
}, },
); );
console.log(res);
}); });
it("should use secure cookies when the base url is https", async () => { it("should use secure cookies when the base url is https", async () => {
@@ -108,7 +110,6 @@ describe("crossSubdomainCookies", () => {
crossSubDomainCookies: { crossSubDomainCookies: {
enabled: true, enabled: true,
}, },
disableCSRFCheck: true,
}, },
}); });

View File

@@ -62,17 +62,6 @@ export function getCookies(options: BetterAuthOptions) {
...(crossSubdomainEnabled ? { domain } : {}), ...(crossSubdomainEnabled ? { domain } : {}),
} satisfies CookieOptions, } satisfies CookieOptions,
}, },
csrfToken: {
name: `${secureCookiePrefix}${cookiePrefix}.csrf_token`,
options: {
httpOnly: true,
sameSite,
path: "/",
secure: !!secureCookiePrefix,
maxAge: 60 * 60 * 24 * 7,
...(crossSubdomainEnabled ? { domain } : {}),
} satisfies CookieOptions,
},
state: { state: {
name: `${secureCookiePrefix}${cookiePrefix}.state`, name: `${secureCookiePrefix}${cookiePrefix}.state`,
options: { options: {

View File

@@ -10,7 +10,6 @@ import { parseJWT } from "oslo/jwt";
import { userSchema } from "../../db/schema"; import { userSchema } from "../../db/schema";
import { generateId } from "../../utils/id"; import { generateId } from "../../utils/id";
import { setSessionCookie } from "../../cookies"; import { setSessionCookie } from "../../cookies";
import { redirectURLMiddleware } from "../../api/middlewares/redirect";
import { import {
createAuthorizationURL, createAuthorizationURL,
validateAuthorizationCode, validateAuthorizationCode,
@@ -148,7 +147,6 @@ export const genericOAuth = (options: GenericOAuthOptions) => {
providerId: z.string(), providerId: z.string(),
callbackURL: z.string().optional(), callbackURL: z.string().optional(),
}), }),
use: [redirectURLMiddleware],
}, },
async (ctx) => { async (ctx) => {
const { providerId } = ctx.body; const { providerId } = ctx.body;

View File

@@ -3,7 +3,6 @@ import { createAuthEndpoint } from "../../api/call";
import type { BetterAuthPlugin } from "../../types/plugins"; import type { BetterAuthPlugin } from "../../types/plugins";
import { APIError } from "better-call"; import { APIError } from "better-call";
import { setSessionCookie } from "../../cookies"; import { setSessionCookie } from "../../cookies";
import { redirectURLMiddleware } from "../../api/middlewares/redirect";
import { alphabet, generateRandomString } from "../../crypto"; import { alphabet, generateRandomString } from "../../crypto";
interface MagicLinkOptions { interface MagicLinkOptions {
@@ -53,7 +52,6 @@ export const magicLink = (options: MagicLinkOptions) => {
email: z.string().email(), email: z.string().email(),
callbackURL: z.string().optional(), callbackURL: z.string().optional(),
}), }),
use: [redirectURLMiddleware],
}, },
async (ctx) => { async (ctx) => {
const { email } = ctx.body; const { email } = ctx.body;

View File

@@ -107,7 +107,6 @@ export const organizationClient = <O extends OrganizationClientOptions>(
$fetch, $fetch,
() => ({ () => ({
method: "POST", method: "POST",
credentials: "include",
body: { body: {
orgId: activeOrgId.get(), orgId: activeOrgId.get(),
}, },

View File

@@ -184,7 +184,6 @@ export const passkeyClient = () => {
$fetch, $fetch,
{ {
method: "GET", method: "GET",
credentials: "include",
}, },
); );
return { return {

View File

@@ -132,6 +132,7 @@ describe("Social Providers", async () => {
}, },
}, },
); );
expect(signInRes.error?.status).toBe(403); expect(signInRes.error?.status).toBe(403);
expect(signInRes.error?.message).toBe("Invalid callbackURL"); expect(signInRes.error?.message).toBe("Invalid callbackURL");
}); });

View File

@@ -60,9 +60,6 @@ export async function getTestInstance<
emailAndPassword: { emailAndPassword: {
enabled: true, enabled: true,
}, },
advanced: {
disableCSRFCheck: true,
},
rateLimit: { rateLimit: {
enabled: false, enabled: false,
}, },
@@ -72,6 +69,10 @@ export async function getTestInstance<
baseURL: "http://localhost:" + (config?.port || 3000), baseURL: "http://localhost:" + (config?.port || 3000),
...opts, ...opts,
...options, ...options,
advanced: {
disableCSRFCheck: true,
...options?.advanced,
},
} as O extends undefined ? typeof opts : O & typeof opts); } as O extends undefined ? typeof opts : O & typeof opts);
const testUser = { const testUser = {
@@ -187,9 +188,6 @@ export async function getTestInstance<
), ),
fetchOptions: { fetchOptions: {
customFetchImpl, customFetchImpl,
headers: {
origin: "http://localhost:" + (config?.port || 3000),
},
}, },
}); });
return { return {

View File

@@ -401,7 +401,9 @@ export interface BetterAuthOptions {
*/ */
useSecureCookies?: boolean; useSecureCookies?: boolean;
/** /**
* Disable CSRF check * Disable trusted origins check
*
* ⚠︎ This is a security risk and it may expose your application to CSRF attacks
*/ */
disableCSRFCheck?: boolean; disableCSRFCheck?: boolean;
/** /**

View File

@@ -1,6 +1,6 @@
{ {
"name": "@better-auth/cli", "name": "@better-auth/cli",
"version": "0.6.1", "version": "0.6.2-beta.7",
"description": "The CLI for Better Auth", "description": "The CLI for Better Auth",
"module": "dist/index.mjs", "module": "dist/index.mjs",
"repository": { "repository": {