diff --git a/demo/nextjs/tsconfig.json b/demo/nextjs/tsconfig.json index 0747c25d..96f8e1b6 100644 --- a/demo/nextjs/tsconfig.json +++ b/demo/nextjs/tsconfig.json @@ -6,7 +6,6 @@ "skipLibCheck": true, "strict": true, "noEmit": true, - "declaration": true, "esModuleInterop": true, "module": "esnext", "moduleResolution": "bundler", diff --git a/packages/better-auth/src/api/middlewares/origin-check.test.ts b/packages/better-auth/src/api/middlewares/origin-check.test.ts index d7a43890..2c77fd27 100644 --- a/packages/better-auth/src/api/middlewares/origin-check.test.ts +++ b/packages/better-auth/src/api/middlewares/origin-check.test.ts @@ -126,6 +126,28 @@ describe("Origin Check", async (it) => { expect(res.data?.user).toBeDefined(); }); + it("shouldn't allow untrusted currentURL", async (ctx) => { + const client = createAuthClient({ + baseURL: "http://localhost:3000", + fetchOptions: { + customFetchImpl, + }, + }); + + const res2 = await client.signIn.email({ + email: testUser.email, + password: testUser.password, + fetchOptions: { + // @ts-expect-error - query is not defined in the type + query: { + currentURL: "http://malicious.com", + }, + }, + }); + expect(res2.error?.status).toBe(403); + expect(res2.error?.message).toBe("Invalid currentURL"); + }); + it("shouldn't allow untrusted redirectTo", async (ctx) => { const client = createAuthClient({ baseURL: "http://localhost:3000", @@ -141,6 +163,35 @@ describe("Origin Check", async (it) => { expect(res.error?.message).toBe("Invalid redirectURL"); }); + it("should work with list of trusted origins ", async (ctx) => { + const client = createAuthClient({ + baseURL: "http://localhost:3000", + fetchOptions: { + customFetchImpl, + headers: { + origin: "https://trusted.com", + }, + }, + }); + const res = await client.forgetPassword({ + email: testUser.email, + redirectTo: "http://localhost:5000/reset-password", + }); + expect(res.data?.status).toBeTruthy(); + + const res2 = await client.signIn.email({ + email: testUser.email, + password: testUser.password, + fetchOptions: { + // @ts-expect-error - query is not defined in the type + query: { + currentURL: "http://localhost:5000", + }, + }, + }); + expect(res2.data?.user).toBeDefined(); + }); + it("should work with wildcard trusted origins", async (ctx) => { const client = createAuthClient({ baseURL: "https://sub-domain.my-site.com", diff --git a/packages/better-auth/src/api/middlewares/origin-check.ts b/packages/better-auth/src/api/middlewares/origin-check.ts index c3e33f9b..b2e87cb9 100644 --- a/packages/better-auth/src/api/middlewares/origin-check.ts +++ b/packages/better-auth/src/api/middlewares/origin-check.ts @@ -17,6 +17,7 @@ export const originCheckMiddleware = createAuthMiddleware(async (ctx) => { ctx.headers?.get("origin") || ctx.headers?.get("referer") || ""; const callbackURL = body?.callbackURL || query?.callbackURL; const redirectURL = body?.redirectTo; + const currentURL = query?.currentURL; const errorCallbackURL = body?.errorCallbackURL; const newUserCallbackURL = body?.newUserCallbackURL; const trustedOrigins = context.trustedOrigins; @@ -58,6 +59,7 @@ export const originCheckMiddleware = createAuthMiddleware(async (ctx) => { } callbackURL && validateURL(callbackURL, "callbackURL"); redirectURL && validateURL(redirectURL, "redirectURL"); + currentURL && validateURL(currentURL, "currentURL"); errorCallbackURL && validateURL(errorCallbackURL, "errorCallbackURL"); newUserCallbackURL && validateURL(newUserCallbackURL, "newUserCallbackURL"); }); diff --git a/packages/better-auth/src/api/routes/account.ts b/packages/better-auth/src/api/routes/account.ts index 6c2840f7..a7d7db2a 100644 --- a/packages/better-auth/src/api/routes/account.ts +++ b/packages/better-auth/src/api/routes/account.ts @@ -65,6 +65,15 @@ export const linkSocialAccount = createAuthEndpoint( { method: "POST", requireHeaders: true, + query: z + .object({ + /** + * Redirect to the current URL after the + * user has signed in. + */ + currentURL: z.string().optional(), + }) + .optional(), body: z.object({ /** * Callback URL to redirect to after the user has signed in. diff --git a/packages/better-auth/src/api/routes/email-verification.ts b/packages/better-auth/src/api/routes/email-verification.ts index d1a2c15b..5f2ce63f 100644 --- a/packages/better-auth/src/api/routes/email-verification.ts +++ b/packages/better-auth/src/api/routes/email-verification.ts @@ -53,7 +53,7 @@ export async function sendVerificationEmailFn( ctx.context.options.emailVerification?.expiresIn, ); const url = `${ctx.context.baseURL}/verify-email?token=${token}&callbackURL=${ - ctx.body.callbackURL || "/" + ctx.body.callbackURL || ctx.query?.currentURL || "/" }`; await ctx.context.options.emailVerification.sendVerificationEmail( { @@ -69,6 +69,15 @@ export const sendVerificationEmail = createAuthEndpoint( "/send-verification-email", { method: "POST", + query: z + .object({ + currentURL: z + .string({ + description: "The URL to use for email verification callback", + }) + .optional(), + }) + .optional(), body: z.object({ email: z .string({ diff --git a/packages/better-auth/src/api/routes/forget-password.ts b/packages/better-auth/src/api/routes/forget-password.ts index 9ebe0fa1..7fe122ad 100644 --- a/packages/better-auth/src/api/routes/forget-password.ts +++ b/packages/better-auth/src/api/routes/forget-password.ts @@ -195,12 +195,13 @@ export const forgetPasswordCallback = createAuthEndpoint( export const resetPassword = createAuthEndpoint( "/reset-password", { - method: "POST", - query: z - .object({ + query: z.optional( + z.object({ token: z.string().optional(), - }) - .optional(), + currentURL: z.string().optional(), + }), + ), + method: "POST", body: z.object({ newPassword: z.string({ description: "The new password to set", @@ -235,7 +236,12 @@ export const resetPassword = createAuthEndpoint( }, }, async (ctx) => { - const token = ctx.body.token || ctx.query?.token; + const token = + ctx.body.token || + ctx.query?.token || + (ctx.query?.currentURL + ? new URL(ctx.query.currentURL).searchParams.get("token") + : ""); if (!token) { throw new APIError("BAD_REQUEST", { message: BASE_ERROR_CODES.INVALID_TOKEN, diff --git a/packages/better-auth/src/api/routes/sign-in.ts b/packages/better-auth/src/api/routes/sign-in.ts index 2a536dfa..9b0bde0c 100644 --- a/packages/better-auth/src/api/routes/sign-in.ts +++ b/packages/better-auth/src/api/routes/sign-in.ts @@ -12,6 +12,15 @@ export const signInSocial = createAuthEndpoint( "/sign-in/social", { method: "POST", + query: z + .object({ + /** + * Redirect to the current URL after the + * user has signed in. + */ + currentURL: z.string().optional(), + }) + .optional(), body: z.object({ /** * Callback URL to redirect to after the user diff --git a/packages/better-auth/src/api/routes/sign-up.ts b/packages/better-auth/src/api/routes/sign-up.ts index eb5aecb7..d7743b10 100644 --- a/packages/better-auth/src/api/routes/sign-up.ts +++ b/packages/better-auth/src/api/routes/sign-up.ts @@ -18,6 +18,11 @@ export const signUpEmail = () => "/sign-up/email", { method: "POST", + query: z + .object({ + currentURL: z.string().optional(), + }) + .optional(), body: z.record(z.string(), z.any()) as unknown as ZodObject<{ name: ZodString; email: ZodString; @@ -192,7 +197,9 @@ export const signUpEmail = () => ); const url = `${ ctx.context.baseURL - }/verify-email?token=${token}&callbackURL=${body.callbackURL || "/"}`; + }/verify-email?token=${token}&callbackURL=${ + body.callbackURL || ctx.query?.currentURL || "/" + }`; await ctx.context.options.emailVerification?.sendVerificationEmail?.( { user: createdUser, diff --git a/packages/better-auth/src/api/routes/update-user.ts b/packages/better-auth/src/api/routes/update-user.ts index b2f614f2..fc340fda 100644 --- a/packages/better-auth/src/api/routes/update-user.ts +++ b/packages/better-auth/src/api/routes/update-user.ts @@ -517,6 +517,11 @@ export const changeEmail = createAuthEndpoint( "/change-email", { method: "POST", + query: z + .object({ + currentURL: z.string().optional(), + }) + .optional(), body: z.object({ newEmail: z .string({ @@ -611,7 +616,9 @@ export const changeEmail = createAuthEndpoint( ); const url = `${ ctx.context.baseURL - }/verify-email?token=${token}&callbackURL=${ctx.body.callbackURL || "/"}`; + }/verify-email?token=${token}&callbackURL=${ + ctx.body.callbackURL || ctx.query?.currentURL || "/" + }`; await ctx.context.options.user.changeEmail.sendChangeEmailVerification( { user: ctx.context.session.user, diff --git a/packages/better-auth/src/client/config.ts b/packages/better-auth/src/client/config.ts index 591a29ae..dce07b21 100644 --- a/packages/better-auth/src/client/config.ts +++ b/packages/better-auth/src/client/config.ts @@ -2,7 +2,7 @@ import { createFetch } from "@better-fetch/fetch"; import { getBaseURL } from "../utils/url"; import { type WritableAtom } from "nanostores"; import type { AtomListener, ClientOptions } from "./types"; -import { redirectPlugin } from "./fetch-plugins"; +import { addCurrentURL, redirectPlugin } from "./fetch-plugins"; import { getSessionAtom } from "./session-atom"; import { parseJSON } from "./parser"; @@ -35,6 +35,7 @@ export const getClientConfig = (options?: ClientOptions) => { ? [...(options?.fetchOptions?.plugins || []), ...pluginsFetchPlugins] : [ redirectPlugin, + addCurrentURL, ...(options?.fetchOptions?.plugins || []), ...pluginsFetchPlugins, ], diff --git a/packages/better-auth/src/client/fetch-plugins.ts b/packages/better-auth/src/client/fetch-plugins.ts index 439afe22..d990ba0f 100644 --- a/packages/better-auth/src/client/fetch-plugins.ts +++ b/packages/better-auth/src/client/fetch-plugins.ts @@ -17,3 +17,22 @@ export const redirectPlugin = { }, }, } satisfies BetterFetchPlugin; + +export const addCurrentURL = { + id: "add-current-url", + name: "Add current URL", + hooks: { + onRequest(context) { + if (typeof window !== "undefined" && window.location) { + if (window.location) { + try { + const url = new URL(context.url); + url.searchParams.set("currentURL", window.location.href); + context.url = url; + } catch {} + } + } + return context; + }, + }, +} satisfies BetterFetchPlugin; diff --git a/packages/better-auth/src/client/plugins/index.ts b/packages/better-auth/src/client/plugins/index.ts index 26ed3083..8d70f321 100644 --- a/packages/better-auth/src/client/plugins/index.ts +++ b/packages/better-auth/src/client/plugins/index.ts @@ -16,4 +16,3 @@ export * from "../../plugins/custom-session/client"; export * from "./infer-plugin"; export * from "../../plugins/sso/client"; export * from "../../plugins/oidc-provider/client"; -export type * from "@simplewebauthn/server"; diff --git a/packages/better-auth/src/oauth2/state.ts b/packages/better-auth/src/oauth2/state.ts index 51fcf896..e681f0dd 100644 --- a/packages/better-auth/src/oauth2/state.ts +++ b/packages/better-auth/src/oauth2/state.ts @@ -1,6 +1,7 @@ import { z } from "zod"; import type { GenericEndpointContext } from "../types"; import { APIError } from "better-call"; +import { getOrigin } from "../utils/url"; import { generateRandomString } from "../crypto"; export async function generateState( @@ -10,7 +11,10 @@ export async function generateState( userId: string; }, ) { - const callbackURL = c.body?.callbackURL || c.context.options.baseURL; + const callbackURL = + c.body?.callbackURL || + (c.query?.currentURL ? getOrigin(c.query?.currentURL) : "") || + c.context.options.baseURL; if (!callbackURL) { throw new APIError("BAD_REQUEST", { message: "callbackURL is required", @@ -21,7 +25,7 @@ export async function generateState( const data = JSON.stringify({ callbackURL, codeVerifier, - errorURL: c.body?.errorCallbackURL, + errorURL: c.body?.errorCallbackURL || c.query?.currentURL, newUserURL: c.body?.newUserCallbackURL, link, /** diff --git a/packages/better-auth/src/plugins/generic-oauth/index.ts b/packages/better-auth/src/plugins/generic-oauth/index.ts index 0323fcee..a5183f35 100644 --- a/packages/better-auth/src/plugins/generic-oauth/index.ts +++ b/packages/better-auth/src/plugins/generic-oauth/index.ts @@ -260,6 +260,19 @@ export const genericOAuth = (options: GenericOAuthOptions) => { "/sign-in/oauth2", { method: "POST", + query: z + .object({ + /** + * Redirect to the current URL after the + * user has signed in. + */ + currentURL: z + .string({ + description: "Redirect to the current URL after sign in", + }) + .optional(), + }) + .optional(), body: z.object({ providerId: z.string({ description: "The provider ID for the OAuth provider", diff --git a/packages/better-auth/src/plugins/oidc-provider/index.ts b/packages/better-auth/src/plugins/oidc-provider/index.ts index 82fa5280..eaf9083a 100644 --- a/packages/better-auth/src/plugins/oidc-provider/index.ts +++ b/packages/better-auth/src/plugins/oidc-provider/index.ts @@ -736,4 +736,3 @@ export const oidcProvider = (options: OIDCOptions) => { schema, } satisfies BetterAuthPlugin; }; -export type * from "./types";