From 227350ff99a082fb0a8dba99b4d43e5a3163f1fc Mon Sep 17 00:00:00 2001 From: Bereket Engida Date: Wed, 6 Nov 2024 20:49:49 +0300 Subject: [PATCH] feat: wildcard matching for trusted origins --- .../src/api/middlewares/origin-check.test.ts | 23 ++++++++++++++++++- .../src/api/middlewares/origin-check.ts | 17 ++++++++++++-- 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/packages/better-auth/src/api/middlewares/origin-check.test.ts b/packages/better-auth/src/api/middlewares/origin-check.test.ts index 76497776..b8d6f58b 100644 --- a/packages/better-auth/src/api/middlewares/origin-check.test.ts +++ b/packages/better-auth/src/api/middlewares/origin-check.test.ts @@ -4,7 +4,11 @@ import { createAuthClient } from "../../client"; describe("Origin Check", async (it) => { const { customFetchImpl, testUser } = await getTestInstance({ - trustedOrigins: ["http://localhost:5000", "https://trusted.com"], + trustedOrigins: [ + "http://localhost:5000", + "https://trusted.com", + "*.my-site.com", + ], emailAndPassword: { enabled: true, async sendResetPassword(url, user) {}, @@ -166,4 +170,21 @@ describe("Origin Check", async (it) => { }); expect(res2.data?.session).toBeDefined(); }); + + it("should work with wildcard trusted origins", async (ctx) => { + const client = createAuthClient({ + baseURL: "https://sub-domain.my-site.com", + fetchOptions: { + customFetchImpl, + headers: { + origin: "https://sub-domain.my-site.com", + }, + }, + }); + const res = await client.forgetPassword({ + email: testUser.email, + redirectTo: "https://sub-domain.my-site.com/reset-password", + }); + expect(res.data?.status).toBeTruthy(); + }); }); diff --git a/packages/better-auth/src/api/middlewares/origin-check.ts b/packages/better-auth/src/api/middlewares/origin-check.ts index 0ac4bd72..12f47806 100644 --- a/packages/better-auth/src/api/middlewares/origin-check.ts +++ b/packages/better-auth/src/api/middlewares/origin-check.ts @@ -12,16 +12,29 @@ export const originCheckMiddleware = createAuthMiddleware(async (ctx) => { const { body, query, context } = ctx; const originHeader = ctx.headers?.get("origin") || ctx.headers?.get("referer") || ""; - const callbackURL = body?.callbackURL; + const callbackURL = body?.callbackURL || query?.callbackURL; const redirectURL = body?.redirectTo; const currentURL = query?.currentURL; const trustedOrigins = context.trustedOrigins; const usesCookies = ctx.headers?.has("cookie"); + const matchesPattern = (url: string, pattern: string): boolean => { + if (pattern.includes("*")) { + const regex = new RegExp( + "^" + pattern.replace(/\*/g, "[^/]+").replace(/\./g, "\\.") + "$", + ); + return regex.test(url); + } + return url.startsWith(pattern); + }; const validateURL = (url: string | undefined, label: string) => { + if (!url) { + return; + } const isTrustedOrigin = trustedOrigins.some( (origin) => - url?.startsWith(origin) || (url?.startsWith("/") && label !== "origin"), + matchesPattern(url, origin) || + (url?.startsWith("/") && label !== "origin" && !url.includes(":")), ); if (!isTrustedOrigin) { logger.error(`Invalid ${label}: ${url}`);