Revert "fix: remove current url from requests (#1187)"

This reverts commit 3aa95aa125.
This commit is contained in:
Bereket Engida
2025-01-12 01:12:55 +03:00
parent b8c4680cdc
commit fd45786c8e
15 changed files with 149 additions and 15 deletions

View File

@@ -6,7 +6,6 @@
"skipLibCheck": true,
"strict": true,
"noEmit": true,
"declaration": true,
"esModuleInterop": true,
"module": "esnext",
"moduleResolution": "bundler",

View File

@@ -126,6 +126,28 @@ describe("Origin Check", async (it) => {
expect(res.data?.user).toBeDefined();
});
it("shouldn't allow untrusted currentURL", async (ctx) => {
const client = createAuthClient({
baseURL: "http://localhost:3000",
fetchOptions: {
customFetchImpl,
},
});
const res2 = await client.signIn.email({
email: testUser.email,
password: testUser.password,
fetchOptions: {
// @ts-expect-error - query is not defined in the type
query: {
currentURL: "http://malicious.com",
},
},
});
expect(res2.error?.status).toBe(403);
expect(res2.error?.message).toBe("Invalid currentURL");
});
it("shouldn't allow untrusted redirectTo", async (ctx) => {
const client = createAuthClient({
baseURL: "http://localhost:3000",
@@ -141,6 +163,35 @@ describe("Origin Check", async (it) => {
expect(res.error?.message).toBe("Invalid redirectURL");
});
it("should work with list of trusted origins ", async (ctx) => {
const client = createAuthClient({
baseURL: "http://localhost:3000",
fetchOptions: {
customFetchImpl,
headers: {
origin: "https://trusted.com",
},
},
});
const res = await client.forgetPassword({
email: testUser.email,
redirectTo: "http://localhost:5000/reset-password",
});
expect(res.data?.status).toBeTruthy();
const res2 = await client.signIn.email({
email: testUser.email,
password: testUser.password,
fetchOptions: {
// @ts-expect-error - query is not defined in the type
query: {
currentURL: "http://localhost:5000",
},
},
});
expect(res2.data?.user).toBeDefined();
});
it("should work with wildcard trusted origins", async (ctx) => {
const client = createAuthClient({
baseURL: "https://sub-domain.my-site.com",

View File

@@ -17,6 +17,7 @@ export const originCheckMiddleware = createAuthMiddleware(async (ctx) => {
ctx.headers?.get("origin") || ctx.headers?.get("referer") || "";
const callbackURL = body?.callbackURL || query?.callbackURL;
const redirectURL = body?.redirectTo;
const currentURL = query?.currentURL;
const errorCallbackURL = body?.errorCallbackURL;
const newUserCallbackURL = body?.newUserCallbackURL;
const trustedOrigins = context.trustedOrigins;
@@ -58,6 +59,7 @@ export const originCheckMiddleware = createAuthMiddleware(async (ctx) => {
}
callbackURL && validateURL(callbackURL, "callbackURL");
redirectURL && validateURL(redirectURL, "redirectURL");
currentURL && validateURL(currentURL, "currentURL");
errorCallbackURL && validateURL(errorCallbackURL, "errorCallbackURL");
newUserCallbackURL && validateURL(newUserCallbackURL, "newUserCallbackURL");
});

View File

@@ -65,6 +65,15 @@ export const linkSocialAccount = createAuthEndpoint(
{
method: "POST",
requireHeaders: true,
query: z
.object({
/**
* Redirect to the current URL after the
* user has signed in.
*/
currentURL: z.string().optional(),
})
.optional(),
body: z.object({
/**
* Callback URL to redirect to after the user has signed in.

View File

@@ -53,7 +53,7 @@ export async function sendVerificationEmailFn(
ctx.context.options.emailVerification?.expiresIn,
);
const url = `${ctx.context.baseURL}/verify-email?token=${token}&callbackURL=${
ctx.body.callbackURL || "/"
ctx.body.callbackURL || ctx.query?.currentURL || "/"
}`;
await ctx.context.options.emailVerification.sendVerificationEmail(
{
@@ -69,6 +69,15 @@ export const sendVerificationEmail = createAuthEndpoint(
"/send-verification-email",
{
method: "POST",
query: z
.object({
currentURL: z
.string({
description: "The URL to use for email verification callback",
})
.optional(),
})
.optional(),
body: z.object({
email: z
.string({

View File

@@ -195,12 +195,13 @@ export const forgetPasswordCallback = createAuthEndpoint(
export const resetPassword = createAuthEndpoint(
"/reset-password",
{
method: "POST",
query: z
.object({
query: z.optional(
z.object({
token: z.string().optional(),
})
.optional(),
currentURL: z.string().optional(),
}),
),
method: "POST",
body: z.object({
newPassword: z.string({
description: "The new password to set",
@@ -235,7 +236,12 @@ export const resetPassword = createAuthEndpoint(
},
},
async (ctx) => {
const token = ctx.body.token || ctx.query?.token;
const token =
ctx.body.token ||
ctx.query?.token ||
(ctx.query?.currentURL
? new URL(ctx.query.currentURL).searchParams.get("token")
: "");
if (!token) {
throw new APIError("BAD_REQUEST", {
message: BASE_ERROR_CODES.INVALID_TOKEN,

View File

@@ -12,6 +12,15 @@ export const signInSocial = createAuthEndpoint(
"/sign-in/social",
{
method: "POST",
query: z
.object({
/**
* Redirect to the current URL after the
* user has signed in.
*/
currentURL: z.string().optional(),
})
.optional(),
body: z.object({
/**
* Callback URL to redirect to after the user

View File

@@ -18,6 +18,11 @@ export const signUpEmail = <O extends BetterAuthOptions>() =>
"/sign-up/email",
{
method: "POST",
query: z
.object({
currentURL: z.string().optional(),
})
.optional(),
body: z.record(z.string(), z.any()) as unknown as ZodObject<{
name: ZodString;
email: ZodString;
@@ -192,7 +197,9 @@ export const signUpEmail = <O extends BetterAuthOptions>() =>
);
const url = `${
ctx.context.baseURL
}/verify-email?token=${token}&callbackURL=${body.callbackURL || "/"}`;
}/verify-email?token=${token}&callbackURL=${
body.callbackURL || ctx.query?.currentURL || "/"
}`;
await ctx.context.options.emailVerification?.sendVerificationEmail?.(
{
user: createdUser,

View File

@@ -517,6 +517,11 @@ export const changeEmail = createAuthEndpoint(
"/change-email",
{
method: "POST",
query: z
.object({
currentURL: z.string().optional(),
})
.optional(),
body: z.object({
newEmail: z
.string({
@@ -611,7 +616,9 @@ export const changeEmail = createAuthEndpoint(
);
const url = `${
ctx.context.baseURL
}/verify-email?token=${token}&callbackURL=${ctx.body.callbackURL || "/"}`;
}/verify-email?token=${token}&callbackURL=${
ctx.body.callbackURL || ctx.query?.currentURL || "/"
}`;
await ctx.context.options.user.changeEmail.sendChangeEmailVerification(
{
user: ctx.context.session.user,

View File

@@ -2,7 +2,7 @@ import { createFetch } from "@better-fetch/fetch";
import { getBaseURL } from "../utils/url";
import { type WritableAtom } from "nanostores";
import type { AtomListener, ClientOptions } from "./types";
import { redirectPlugin } from "./fetch-plugins";
import { addCurrentURL, redirectPlugin } from "./fetch-plugins";
import { getSessionAtom } from "./session-atom";
import { parseJSON } from "./parser";
@@ -35,6 +35,7 @@ export const getClientConfig = (options?: ClientOptions) => {
? [...(options?.fetchOptions?.plugins || []), ...pluginsFetchPlugins]
: [
redirectPlugin,
addCurrentURL,
...(options?.fetchOptions?.plugins || []),
...pluginsFetchPlugins,
],

View File

@@ -17,3 +17,22 @@ export const redirectPlugin = {
},
},
} satisfies BetterFetchPlugin;
export const addCurrentURL = {
id: "add-current-url",
name: "Add current URL",
hooks: {
onRequest(context) {
if (typeof window !== "undefined" && window.location) {
if (window.location) {
try {
const url = new URL(context.url);
url.searchParams.set("currentURL", window.location.href);
context.url = url;
} catch {}
}
}
return context;
},
},
} satisfies BetterFetchPlugin;

View File

@@ -16,4 +16,3 @@ export * from "../../plugins/custom-session/client";
export * from "./infer-plugin";
export * from "../../plugins/sso/client";
export * from "../../plugins/oidc-provider/client";
export type * from "@simplewebauthn/server";

View File

@@ -1,6 +1,7 @@
import { z } from "zod";
import type { GenericEndpointContext } from "../types";
import { APIError } from "better-call";
import { getOrigin } from "../utils/url";
import { generateRandomString } from "../crypto";
export async function generateState(
@@ -10,7 +11,10 @@ export async function generateState(
userId: string;
},
) {
const callbackURL = c.body?.callbackURL || c.context.options.baseURL;
const callbackURL =
c.body?.callbackURL ||
(c.query?.currentURL ? getOrigin(c.query?.currentURL) : "") ||
c.context.options.baseURL;
if (!callbackURL) {
throw new APIError("BAD_REQUEST", {
message: "callbackURL is required",
@@ -21,7 +25,7 @@ export async function generateState(
const data = JSON.stringify({
callbackURL,
codeVerifier,
errorURL: c.body?.errorCallbackURL,
errorURL: c.body?.errorCallbackURL || c.query?.currentURL,
newUserURL: c.body?.newUserCallbackURL,
link,
/**

View File

@@ -260,6 +260,19 @@ export const genericOAuth = (options: GenericOAuthOptions) => {
"/sign-in/oauth2",
{
method: "POST",
query: z
.object({
/**
* Redirect to the current URL after the
* user has signed in.
*/
currentURL: z
.string({
description: "Redirect to the current URL after sign in",
})
.optional(),
})
.optional(),
body: z.object({
providerId: z.string({
description: "The provider ID for the OAuth provider",

View File

@@ -736,4 +736,3 @@ export const oidcProvider = (options: OIDCOptions) => {
schema,
} satisfies BetterAuthPlugin;
};
export type * from "./types";