feat: wildcard matching for trusted origins

This commit is contained in:
Bereket Engida
2024-11-06 20:49:49 +03:00
parent 867d5dba4b
commit 227350ff99
2 changed files with 37 additions and 3 deletions

View File

@@ -4,7 +4,11 @@ import { createAuthClient } from "../../client";
describe("Origin Check", async (it) => {
const { customFetchImpl, testUser } = await getTestInstance({
trustedOrigins: ["http://localhost:5000", "https://trusted.com"],
trustedOrigins: [
"http://localhost:5000",
"https://trusted.com",
"*.my-site.com",
],
emailAndPassword: {
enabled: true,
async sendResetPassword(url, user) {},
@@ -166,4 +170,21 @@ describe("Origin Check", async (it) => {
});
expect(res2.data?.session).toBeDefined();
});
it("should work with wildcard trusted origins", async (ctx) => {
const client = createAuthClient({
baseURL: "https://sub-domain.my-site.com",
fetchOptions: {
customFetchImpl,
headers: {
origin: "https://sub-domain.my-site.com",
},
},
});
const res = await client.forgetPassword({
email: testUser.email,
redirectTo: "https://sub-domain.my-site.com/reset-password",
});
expect(res.data?.status).toBeTruthy();
});
});

View File

@@ -12,16 +12,29 @@ export const originCheckMiddleware = createAuthMiddleware(async (ctx) => {
const { body, query, context } = ctx;
const originHeader =
ctx.headers?.get("origin") || ctx.headers?.get("referer") || "";
const callbackURL = body?.callbackURL;
const callbackURL = body?.callbackURL || query?.callbackURL;
const redirectURL = body?.redirectTo;
const currentURL = query?.currentURL;
const trustedOrigins = context.trustedOrigins;
const usesCookies = ctx.headers?.has("cookie");
const matchesPattern = (url: string, pattern: string): boolean => {
if (pattern.includes("*")) {
const regex = new RegExp(
"^" + pattern.replace(/\*/g, "[^/]+").replace(/\./g, "\\.") + "$",
);
return regex.test(url);
}
return url.startsWith(pattern);
};
const validateURL = (url: string | undefined, label: string) => {
if (!url) {
return;
}
const isTrustedOrigin = trustedOrigins.some(
(origin) =>
url?.startsWith(origin) || (url?.startsWith("/") && label !== "origin"),
matchesPattern(url, origin) ||
(url?.startsWith("/") && label !== "origin" && !url.includes(":")),
);
if (!isTrustedOrigin) {
logger.error(`Invalid ${label}: ${url}`);