Revert "fix: remove current url from requests (#1187)"

This reverts commit 3aa95aa125.
This commit is contained in:
Bereket Engida
2025-01-12 01:12:55 +03:00
parent b8c4680cdc
commit fd45786c8e
15 changed files with 149 additions and 15 deletions

View File

@@ -6,7 +6,6 @@
"skipLibCheck": true, "skipLibCheck": true,
"strict": true, "strict": true,
"noEmit": true, "noEmit": true,
"declaration": true,
"esModuleInterop": true, "esModuleInterop": true,
"module": "esnext", "module": "esnext",
"moduleResolution": "bundler", "moduleResolution": "bundler",

View File

@@ -126,6 +126,28 @@ describe("Origin Check", async (it) => {
expect(res.data?.user).toBeDefined(); expect(res.data?.user).toBeDefined();
}); });
it("shouldn't allow untrusted currentURL", async (ctx) => {
const client = createAuthClient({
baseURL: "http://localhost:3000",
fetchOptions: {
customFetchImpl,
},
});
const res2 = await client.signIn.email({
email: testUser.email,
password: testUser.password,
fetchOptions: {
// @ts-expect-error - query is not defined in the type
query: {
currentURL: "http://malicious.com",
},
},
});
expect(res2.error?.status).toBe(403);
expect(res2.error?.message).toBe("Invalid currentURL");
});
it("shouldn't allow untrusted redirectTo", async (ctx) => { it("shouldn't allow untrusted redirectTo", async (ctx) => {
const client = createAuthClient({ const client = createAuthClient({
baseURL: "http://localhost:3000", baseURL: "http://localhost:3000",
@@ -141,6 +163,35 @@ describe("Origin Check", async (it) => {
expect(res.error?.message).toBe("Invalid redirectURL"); expect(res.error?.message).toBe("Invalid redirectURL");
}); });
it("should work with list of trusted origins ", async (ctx) => {
const client = createAuthClient({
baseURL: "http://localhost:3000",
fetchOptions: {
customFetchImpl,
headers: {
origin: "https://trusted.com",
},
},
});
const res = await client.forgetPassword({
email: testUser.email,
redirectTo: "http://localhost:5000/reset-password",
});
expect(res.data?.status).toBeTruthy();
const res2 = await client.signIn.email({
email: testUser.email,
password: testUser.password,
fetchOptions: {
// @ts-expect-error - query is not defined in the type
query: {
currentURL: "http://localhost:5000",
},
},
});
expect(res2.data?.user).toBeDefined();
});
it("should work with wildcard trusted origins", async (ctx) => { it("should work with wildcard trusted origins", async (ctx) => {
const client = createAuthClient({ const client = createAuthClient({
baseURL: "https://sub-domain.my-site.com", baseURL: "https://sub-domain.my-site.com",

View File

@@ -17,6 +17,7 @@ export const originCheckMiddleware = createAuthMiddleware(async (ctx) => {
ctx.headers?.get("origin") || ctx.headers?.get("referer") || ""; ctx.headers?.get("origin") || ctx.headers?.get("referer") || "";
const callbackURL = body?.callbackURL || query?.callbackURL; const callbackURL = body?.callbackURL || query?.callbackURL;
const redirectURL = body?.redirectTo; const redirectURL = body?.redirectTo;
const currentURL = query?.currentURL;
const errorCallbackURL = body?.errorCallbackURL; const errorCallbackURL = body?.errorCallbackURL;
const newUserCallbackURL = body?.newUserCallbackURL; const newUserCallbackURL = body?.newUserCallbackURL;
const trustedOrigins = context.trustedOrigins; const trustedOrigins = context.trustedOrigins;
@@ -58,6 +59,7 @@ export const originCheckMiddleware = createAuthMiddleware(async (ctx) => {
} }
callbackURL && validateURL(callbackURL, "callbackURL"); callbackURL && validateURL(callbackURL, "callbackURL");
redirectURL && validateURL(redirectURL, "redirectURL"); redirectURL && validateURL(redirectURL, "redirectURL");
currentURL && validateURL(currentURL, "currentURL");
errorCallbackURL && validateURL(errorCallbackURL, "errorCallbackURL"); errorCallbackURL && validateURL(errorCallbackURL, "errorCallbackURL");
newUserCallbackURL && validateURL(newUserCallbackURL, "newUserCallbackURL"); newUserCallbackURL && validateURL(newUserCallbackURL, "newUserCallbackURL");
}); });

View File

@@ -65,6 +65,15 @@ export const linkSocialAccount = createAuthEndpoint(
{ {
method: "POST", method: "POST",
requireHeaders: true, requireHeaders: true,
query: z
.object({
/**
* Redirect to the current URL after the
* user has signed in.
*/
currentURL: z.string().optional(),
})
.optional(),
body: z.object({ body: z.object({
/** /**
* Callback URL to redirect to after the user has signed in. * Callback URL to redirect to after the user has signed in.

View File

@@ -53,7 +53,7 @@ export async function sendVerificationEmailFn(
ctx.context.options.emailVerification?.expiresIn, ctx.context.options.emailVerification?.expiresIn,
); );
const url = `${ctx.context.baseURL}/verify-email?token=${token}&callbackURL=${ const url = `${ctx.context.baseURL}/verify-email?token=${token}&callbackURL=${
ctx.body.callbackURL || "/" ctx.body.callbackURL || ctx.query?.currentURL || "/"
}`; }`;
await ctx.context.options.emailVerification.sendVerificationEmail( await ctx.context.options.emailVerification.sendVerificationEmail(
{ {
@@ -69,6 +69,15 @@ export const sendVerificationEmail = createAuthEndpoint(
"/send-verification-email", "/send-verification-email",
{ {
method: "POST", method: "POST",
query: z
.object({
currentURL: z
.string({
description: "The URL to use for email verification callback",
})
.optional(),
})
.optional(),
body: z.object({ body: z.object({
email: z email: z
.string({ .string({

View File

@@ -195,12 +195,13 @@ export const forgetPasswordCallback = createAuthEndpoint(
export const resetPassword = createAuthEndpoint( export const resetPassword = createAuthEndpoint(
"/reset-password", "/reset-password",
{ {
method: "POST", query: z.optional(
query: z z.object({
.object({
token: z.string().optional(), token: z.string().optional(),
}) currentURL: z.string().optional(),
.optional(), }),
),
method: "POST",
body: z.object({ body: z.object({
newPassword: z.string({ newPassword: z.string({
description: "The new password to set", description: "The new password to set",
@@ -235,7 +236,12 @@ export const resetPassword = createAuthEndpoint(
}, },
}, },
async (ctx) => { async (ctx) => {
const token = ctx.body.token || ctx.query?.token; const token =
ctx.body.token ||
ctx.query?.token ||
(ctx.query?.currentURL
? new URL(ctx.query.currentURL).searchParams.get("token")
: "");
if (!token) { if (!token) {
throw new APIError("BAD_REQUEST", { throw new APIError("BAD_REQUEST", {
message: BASE_ERROR_CODES.INVALID_TOKEN, message: BASE_ERROR_CODES.INVALID_TOKEN,

View File

@@ -12,6 +12,15 @@ export const signInSocial = createAuthEndpoint(
"/sign-in/social", "/sign-in/social",
{ {
method: "POST", method: "POST",
query: z
.object({
/**
* Redirect to the current URL after the
* user has signed in.
*/
currentURL: z.string().optional(),
})
.optional(),
body: z.object({ body: z.object({
/** /**
* Callback URL to redirect to after the user * Callback URL to redirect to after the user

View File

@@ -18,6 +18,11 @@ export const signUpEmail = <O extends BetterAuthOptions>() =>
"/sign-up/email", "/sign-up/email",
{ {
method: "POST", method: "POST",
query: z
.object({
currentURL: z.string().optional(),
})
.optional(),
body: z.record(z.string(), z.any()) as unknown as ZodObject<{ body: z.record(z.string(), z.any()) as unknown as ZodObject<{
name: ZodString; name: ZodString;
email: ZodString; email: ZodString;
@@ -192,7 +197,9 @@ export const signUpEmail = <O extends BetterAuthOptions>() =>
); );
const url = `${ const url = `${
ctx.context.baseURL ctx.context.baseURL
}/verify-email?token=${token}&callbackURL=${body.callbackURL || "/"}`; }/verify-email?token=${token}&callbackURL=${
body.callbackURL || ctx.query?.currentURL || "/"
}`;
await ctx.context.options.emailVerification?.sendVerificationEmail?.( await ctx.context.options.emailVerification?.sendVerificationEmail?.(
{ {
user: createdUser, user: createdUser,

View File

@@ -517,6 +517,11 @@ export const changeEmail = createAuthEndpoint(
"/change-email", "/change-email",
{ {
method: "POST", method: "POST",
query: z
.object({
currentURL: z.string().optional(),
})
.optional(),
body: z.object({ body: z.object({
newEmail: z newEmail: z
.string({ .string({
@@ -611,7 +616,9 @@ export const changeEmail = createAuthEndpoint(
); );
const url = `${ const url = `${
ctx.context.baseURL ctx.context.baseURL
}/verify-email?token=${token}&callbackURL=${ctx.body.callbackURL || "/"}`; }/verify-email?token=${token}&callbackURL=${
ctx.body.callbackURL || ctx.query?.currentURL || "/"
}`;
await ctx.context.options.user.changeEmail.sendChangeEmailVerification( await ctx.context.options.user.changeEmail.sendChangeEmailVerification(
{ {
user: ctx.context.session.user, user: ctx.context.session.user,

View File

@@ -2,7 +2,7 @@ import { createFetch } from "@better-fetch/fetch";
import { getBaseURL } from "../utils/url"; import { getBaseURL } from "../utils/url";
import { type WritableAtom } from "nanostores"; import { type WritableAtom } from "nanostores";
import type { AtomListener, ClientOptions } from "./types"; import type { AtomListener, ClientOptions } from "./types";
import { redirectPlugin } from "./fetch-plugins"; import { addCurrentURL, redirectPlugin } from "./fetch-plugins";
import { getSessionAtom } from "./session-atom"; import { getSessionAtom } from "./session-atom";
import { parseJSON } from "./parser"; import { parseJSON } from "./parser";
@@ -35,6 +35,7 @@ export const getClientConfig = (options?: ClientOptions) => {
? [...(options?.fetchOptions?.plugins || []), ...pluginsFetchPlugins] ? [...(options?.fetchOptions?.plugins || []), ...pluginsFetchPlugins]
: [ : [
redirectPlugin, redirectPlugin,
addCurrentURL,
...(options?.fetchOptions?.plugins || []), ...(options?.fetchOptions?.plugins || []),
...pluginsFetchPlugins, ...pluginsFetchPlugins,
], ],

View File

@@ -17,3 +17,22 @@ export const redirectPlugin = {
}, },
}, },
} satisfies BetterFetchPlugin; } satisfies BetterFetchPlugin;
export const addCurrentURL = {
id: "add-current-url",
name: "Add current URL",
hooks: {
onRequest(context) {
if (typeof window !== "undefined" && window.location) {
if (window.location) {
try {
const url = new URL(context.url);
url.searchParams.set("currentURL", window.location.href);
context.url = url;
} catch {}
}
}
return context;
},
},
} satisfies BetterFetchPlugin;

View File

@@ -16,4 +16,3 @@ export * from "../../plugins/custom-session/client";
export * from "./infer-plugin"; export * from "./infer-plugin";
export * from "../../plugins/sso/client"; export * from "../../plugins/sso/client";
export * from "../../plugins/oidc-provider/client"; export * from "../../plugins/oidc-provider/client";
export type * from "@simplewebauthn/server";

View File

@@ -1,6 +1,7 @@
import { z } from "zod"; import { z } from "zod";
import type { GenericEndpointContext } from "../types"; import type { GenericEndpointContext } from "../types";
import { APIError } from "better-call"; import { APIError } from "better-call";
import { getOrigin } from "../utils/url";
import { generateRandomString } from "../crypto"; import { generateRandomString } from "../crypto";
export async function generateState( export async function generateState(
@@ -10,7 +11,10 @@ export async function generateState(
userId: string; userId: string;
}, },
) { ) {
const callbackURL = c.body?.callbackURL || c.context.options.baseURL; const callbackURL =
c.body?.callbackURL ||
(c.query?.currentURL ? getOrigin(c.query?.currentURL) : "") ||
c.context.options.baseURL;
if (!callbackURL) { if (!callbackURL) {
throw new APIError("BAD_REQUEST", { throw new APIError("BAD_REQUEST", {
message: "callbackURL is required", message: "callbackURL is required",
@@ -21,7 +25,7 @@ export async function generateState(
const data = JSON.stringify({ const data = JSON.stringify({
callbackURL, callbackURL,
codeVerifier, codeVerifier,
errorURL: c.body?.errorCallbackURL, errorURL: c.body?.errorCallbackURL || c.query?.currentURL,
newUserURL: c.body?.newUserCallbackURL, newUserURL: c.body?.newUserCallbackURL,
link, link,
/** /**

View File

@@ -260,6 +260,19 @@ export const genericOAuth = (options: GenericOAuthOptions) => {
"/sign-in/oauth2", "/sign-in/oauth2",
{ {
method: "POST", method: "POST",
query: z
.object({
/**
* Redirect to the current URL after the
* user has signed in.
*/
currentURL: z
.string({
description: "Redirect to the current URL after sign in",
})
.optional(),
})
.optional(),
body: z.object({ body: z.object({
providerId: z.string({ providerId: z.string({
description: "The provider ID for the OAuth provider", description: "The provider ID for the OAuth provider",

View File

@@ -736,4 +736,3 @@ export const oidcProvider = (options: OIDCOptions) => {
schema, schema,
} satisfies BetterAuthPlugin; } satisfies BetterAuthPlugin;
}; };
export type * from "./types";