chore: cleanup

This commit is contained in:
Bereket Engida
2024-10-12 02:51:39 +03:00
parent ec0fed68dd
commit 94f1fe10c0
6 changed files with 43 additions and 52 deletions

View File

@@ -15,7 +15,6 @@ import { useParams, useRouter } from "next/navigation";
import { Skeleton } from "@/components/ui/skeleton"; import { Skeleton } from "@/components/ui/skeleton";
import { client, organization } from "@/lib/auth-client"; import { client, organization } from "@/lib/auth-client";
import { InvitationError } from "./invitation-error"; import { InvitationError } from "./invitation-error";
import { Invitation } from "@/lib/auth-types";
export default function InvitationPage() { export default function InvitationPage() {
const params = useParams<{ const params = useParams<{

View File

@@ -56,16 +56,16 @@ export const callbackOAuth = createAuthEndpoint(
} }
const { const {
data: { callbackURL, currentURL, dontRememberMe, code }, data: { callbackURL, currentURL, code: stateCode },
} = parsedState; } = parsedState;
const storedCode = await c.getSignedCookie( const storedState = await c.getSignedCookie(
c.context.authCookies.state.name, c.context.authCookies.state.name,
c.context.secret, c.context.secret,
); );
if (storedCode !== code) { if (storedState !== stateCode) {
logger.error("Oauth code mismatch", storedCode, code); logger.error("OAuth state mismatch", storedState, stateCode);
throw c.redirect( throw c.redirect(
`${c.context.baseURL}/error?error=please_restart_the_process`, `${c.context.baseURL}/error?error=please_restart_the_process`,
); );
@@ -186,7 +186,6 @@ export const callbackOAuth = createAuthEndpoint(
const session = await c.context.internalAdapter.createSession( const session = await c.context.internalAdapter.createSession(
userId || id, userId || id,
c.request, c.request,
dontRememberMe,
); );
if (!session) { if (!session) {
const url = new URL(currentURL || callbackURL); const url = new URL(currentURL || callbackURL);
@@ -194,7 +193,7 @@ export const callbackOAuth = createAuthEndpoint(
throw c.redirect(url.toString()); throw c.redirect(url.toString());
} }
try { try {
await setSessionCookie(c, session.id, dontRememberMe); await setSessionCookie(c, session.id);
} catch (e) { } catch (e) {
c.context.logger.error("Unable to set session cookie", e); c.context.logger.error("Unable to set session cookie", e);
const url = new URL(currentURL || callbackURL); const url = new URL(currentURL || callbackURL);

View File

@@ -30,10 +30,6 @@ export const signInOAuth = createAuthEndpoint(
* OAuth2 provider to use` * OAuth2 provider to use`
*/ */
provider: z.enum(oAuthProviderList), provider: z.enum(oAuthProviderList),
/**
* If this is true the session will only be valid for the current browser session
*/
dontRememberMe: z.boolean().default(false).optional(),
}), }),
}, },
async (c) => { async (c) => {
@@ -62,37 +58,33 @@ export const signInOAuth = createAuthEndpoint(
callbackURL || currentURL?.origin || c.context.baseURL, callbackURL || currentURL?.origin || c.context.baseURL,
c.query?.currentURL, c.query?.currentURL,
); );
try { await c.setSignedCookie(
await c.setSignedCookie( cookie.state.name,
cookie.state.name, state.code,
state.code, c.context.secret,
c.context.secret, cookie.state.options,
cookie.state.options, );
); const codeVerifier = generateCodeVerifier();
const codeVerifier = generateCodeVerifier(); await c.setSignedCookie(
await c.setSignedCookie( cookie.pkCodeVerifier.name,
cookie.pkCodeVerifier.name, codeVerifier,
codeVerifier, c.context.secret,
c.context.secret, cookie.pkCodeVerifier.options,
cookie.pkCodeVerifier.options, );
); const url = provider.createAuthorizationURL({
const url = provider.createAuthorizationURL({ state: state.state,
state: state.state, codeVerifier,
codeVerifier, });
}); url.searchParams.set(
url.searchParams.set( "redirect_uri",
"redirect_uri", `${c.context.baseURL}/callback/${c.body.provider}`,
`${c.context.baseURL}/callback/${c.body.provider}`, );
); return {
return { url: url.toString(),
url: url.toString(), state: state.state,
state: state.state, codeVerifier,
codeVerifier, redirect: true,
redirect: true, };
};
} catch (e) {
throw new APIError("INTERNAL_SERVER_ERROR");
}
}, },
); );

View File

@@ -164,6 +164,12 @@ export function deleteSessionCookie(ctx: GenericEndpointContext) {
ctx.setCookie(ctx.context.authCookies.sessionToken.name, "", { ctx.setCookie(ctx.context.authCookies.sessionToken.name, "", {
maxAge: 0, maxAge: 0,
}); });
ctx.setCookie(ctx.context.authCookies.pkCodeVerifier.name, "", {
maxAge: 0,
});
ctx.setCookie(ctx.context.authCookies.state.name, "", {
maxAge: 0,
});
ctx.setCookie(ctx.context.authCookies.dontRememberToken.name, "", { ctx.setCookie(ctx.context.authCookies.dontRememberToken.name, "", {
maxAge: 0, maxAge: 0,
}); });

View File

@@ -65,6 +65,7 @@ describe("Social Providers", async () => {
const signInRes = await client.signIn.social( const signInRes = await client.signIn.social(
{ {
provider: "google", provider: "google",
callbackURL: "/callback",
}, },
{ {
onSuccess(context) { onSuccess(context) {
@@ -97,6 +98,9 @@ describe("Social Providers", async () => {
headers, headers,
onError(context) { onError(context) {
expect(context.response.status).toBe(302); expect(context.response.status).toBe(302);
const location = context.response.headers.get("location");
expect(location).toBeDefined();
expect(location).toContain("/callback");
const cookies = parseSetCookieHeader( const cookies = parseSetCookieHeader(
context.response.headers.get("set-cookie") || "", context.response.headers.get("set-cookie") || "",
); );

View File

@@ -1,19 +1,12 @@
import { generateState as generateStateOAuth } from "oslo/oauth2"; import { generateState as generateStateOAuth } from "oslo/oauth2";
import { z } from "zod"; import { z } from "zod";
export function generateState( export function generateState(callbackURL?: string, currentURL?: string) {
callbackURL?: string,
currentURL?: string,
dontRememberMe?: boolean,
additionalFields?: Record<string, any>,
) {
const code = generateStateOAuth(); const code = generateStateOAuth();
const state = JSON.stringify({ const state = JSON.stringify({
code, code,
callbackURL, callbackURL,
currentURL, currentURL,
dontRememberMe,
additionalFields,
}); });
return { state, code }; return { state, code };
} }
@@ -24,8 +17,6 @@ export function parseState(state: string) {
code: z.string(), code: z.string(),
callbackURL: z.string().optional(), callbackURL: z.string().optional(),
currentURL: z.string().optional(), currentURL: z.string().optional(),
dontRememberMe: z.boolean().optional(),
additionalFields: z.record(z.string()).optional(),
}) })
.safeParse(JSON.parse(state)); .safeParse(JSON.parse(state));
return data; return data;