feat: remember me and many more imporves

This commit is contained in:
Bereket Engida
2024-08-29 00:10:35 +03:00
parent ed3579d19e
commit f1e363bbff
36 changed files with 401 additions and 178 deletions

Binary file not shown.

View File

@@ -15,15 +15,20 @@ import { authClient } from "@/lib/auth-client";
import { useState } from "react"; import { useState } from "react";
import { Key } from "lucide-react"; import { Key } from "lucide-react";
import { PasswordInput } from "@/components/ui/password-input"; import { PasswordInput } from "@/components/ui/password-input";
import { Checkbox } from "@/components/ui/checkbox";
import { toast } from "sonner";
import { useRouter } from "next/navigation";
export default function Page() { export default function Page() {
const [email, setEmail] = useState(""); const [email, setEmail] = useState("");
const [password, setPassword] = useState(""); const [password, setPassword] = useState("");
const [rememberMe, setRememberMe] = useState(false);
const router = useRouter()
return ( return (
<div className="h-[50rem] w-full dark:bg-black bg-white dark:bg-grid-white/[0.2] bg-grid-black/[0.2] relative flex items-center justify-center"> <div className="h-[50rem] w-full dark:bg-black bg-white dark:bg-grid-white/[0.2] bg-grid-black/[0.2] relative flex items-center justify-center">
{/* Radial gradient for the container to give a faded look */} {/* Radial gradient for the container to give a faded look */}
<div className="absolute pointer-events-none inset-0 flex items-center justify-center dark:bg-black bg-white [mask-image:radial-gradient(ellipse_at_center,transparent_20%,black)]"></div> <div className="absolute pointer-events-none inset-0 flex items-center justify-center dark:bg-black bg-white [mask-image:radial-gradient(ellipse_at_center,transparent_20%,black)]"></div>
<Card className="mx-auto max-w-sm"> <Card className="mx-auto max-w-sm z-50">
<CardHeader> <CardHeader>
<CardTitle className="text-2xl">Login</CardTitle> <CardTitle className="text-2xl">Login</CardTitle>
<CardDescription> <CardDescription>
@@ -63,12 +68,22 @@ export default function Page() {
placeholder="Password" placeholder="Password"
/> />
</div> </div>
<div className="flex items-center gap-2">
<Checkbox onClick={() => {
setRememberMe(!rememberMe)
}} />
<Label>Remember me</Label>
</div>
<Button type="submit" className="w-full" onClick={async () => { <Button type="submit" className="w-full" onClick={async () => {
await authClient.signIn.credential({ const res = await authClient.signIn.credential({
email, email,
password, password,
callbackURL: "/" callbackURL: "/",
dontRememberMe: !rememberMe
}) })
if (res.error) {
toast.error(res.error.message)
}
}}> }}>
Login Login
</Button> </Button>
@@ -85,9 +100,14 @@ export default function Page() {
Login with Github Login with Github
</Button> </Button>
<Button variant="secondary" className="gap-2" onClick={async () => { <Button variant="secondary" className="gap-2" onClick={async () => {
await authClient.passkey.signIn({ const res = await authClient.passkey.signIn({
callbackURL: "/" callbackURL: "/"
}) })
if (res?.error) {
toast.error(res.error.message)
} else {
router.push("/")
}
}}> }}>
<Key size={16} /> <Key size={16} />
Login with Passkey Login with Passkey

View File

@@ -2,6 +2,7 @@ import type { Metadata } from "next";
import { Inter } from "next/font/google"; import { Inter } from "next/font/google";
import "./globals.css"; import "./globals.css";
import { ThemeWrapper } from "@/components/theme-provider"; import { ThemeWrapper } from "@/components/theme-provider";
import { Toaster } from "@/components/ui/sonner";
const inter = Inter({ subsets: ["latin"] }); const inter = Inter({ subsets: ["latin"] });
@@ -16,10 +17,13 @@ export default function RootLayout({
children: React.ReactNode; children: React.ReactNode;
}>) { }>) {
return ( return (
<html lang="en"> <html lang="en" suppressHydrationWarning>
<ThemeWrapper forcedTheme="dark" attribute="class"> <body className={inter.className}>
<body className={inter.className}>{children}</body> <ThemeWrapper forcedTheme="dark" attribute="class">
</ThemeWrapper> {children}
</ThemeWrapper>
<Toaster />
</body>
</html> </html>
); );
} }

View File

@@ -11,7 +11,7 @@ export default async function TypewriterEffectSmoothDemo() {
{/* Radial gradient for the container to give a faded look */} {/* Radial gradient for the container to give a faded look */}
<div className="absolute pointer-events-none inset-0 flex items-center justify-center dark:bg-black bg-white [mask-image:radial-gradient(ellipse_at_center,transparent_20%,black)]"></div> <div className="absolute pointer-events-none inset-0 flex items-center justify-center dark:bg-black bg-white [mask-image:radial-gradient(ellipse_at_center,transparent_20%,black)]"></div>
{ {
session ? <UserCard user={session.user} /> : null session ? <UserCard session={session} /> : null
} }
</div> </div>

View File

@@ -8,9 +8,7 @@ export const SignOut = () => {
<Button <Button
onClick={async () => { onClick={async () => {
await authClient.signOut({ await authClient.signOut({
body: {
callbackURL: "/"
}
}) })
}} }}
> >

View File

@@ -12,6 +12,8 @@ const Toaster = ({ ...props }: ToasterProps) => {
<Sonner <Sonner
theme={theme as ToasterProps["theme"]} theme={theme as ToasterProps["theme"]}
className="toaster group" className="toaster group"
richColors
closeButton
toastOptions={{ toastOptions={{
classNames: { classNames: {
toast: toast:

View File

@@ -3,21 +3,21 @@
import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar"; import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar";
import { Card, CardContent, CardFooter, CardHeader, CardTitle } from "@/components/ui/card"; import { Card, CardContent, CardFooter, CardHeader, CardTitle } from "@/components/ui/card";
import { Button } from "./ui/button"; import { Button } from "./ui/button";
import { LogOut } from "lucide-react"; import { Check, LogOut } from "lucide-react";
import { authClient } from "@/lib/auth-client"; import { authClient } from "@/lib/auth-client";
import { useRouter } from "next/navigation"; import { useRouter } from "next/navigation";
import AddPasskey from "./add-passkey"; import AddPasskey from "./add-passkey";
import { Session, User } from "@/lib/types";
import { toast } from "sonner";
export default function UserCard({ export default function UserCard(props: {
user, session: {
}: { user: User;
user: { session: Session
name: string; } | null
email: string;
image?: string;
};
}) { }) {
const router = useRouter(); const router = useRouter();
const session = authClient.useSession(props.session)
return ( return (
<Card> <Card>
<CardHeader> <CardHeader>
@@ -26,23 +26,42 @@ export default function UserCard({
<CardContent className="grid gap-8"> <CardContent className="grid gap-8">
<div className="flex items-center gap-4"> <div className="flex items-center gap-4">
<Avatar className="hidden h-9 w-9 sm:flex"> <Avatar className="hidden h-9 w-9 sm:flex">
<AvatarImage src={user.image || "#"} alt="Avatar" /> <AvatarImage src={session?.user.image || "#"} alt="Avatar" />
<AvatarFallback>{user.name.charAt(0)}</AvatarFallback> <AvatarFallback>{session?.user.name.charAt(0)}</AvatarFallback>
</Avatar> </Avatar>
<div className="grid gap-1"> <div className="grid gap-1">
<p className="text-sm font-medium leading-none">{user.name}</p> <p className="text-sm font-medium leading-none">{session?.user.name}</p>
<p className="text-sm text-muted-foreground">{user.email}</p> <p className="text-sm text-muted-foreground">{session?.user.email}</p>
</div> </div>
</div> </div>
<div className="border-y py-4 flex items-center justify-between"> <div className="border-y py-4 flex items-center justify-between gap-2">
<AddPasskey /> <AddPasskey />
{
session?.user.twoFactorEnabled ? <Button variant="secondary" className="gap-2" onClick={async () => {
const res = await authClient.twoFactor.disable()
if (res.error) {
toast.error(res.error.message)
}
}}>
Disable 2FA
</Button> : <Button variant="outline" className="gap-2" onClick={async () => {
const res = await authClient.twoFactor.enable()
if (res.error) {
toast.error(res.error.message)
}
}}>
<p>
Enable 2FA
</p>
</Button>
}
</div> </div>
</CardContent> </CardContent>
<CardFooter> <CardFooter>
<Button className="gap-2 z-10" variant="secondary"> <Button className="gap-2 z-10" variant="secondary">
<LogOut size={16} /> <LogOut size={16} />
<span className="text-sm" onClick={async () => { <span className="text-sm" onClick={async () => {
const res = await authClient.signOut() await authClient.signOut()
router.refresh() router.refresh()
}}> }}>
Sign Out Sign Out

View File

@@ -0,0 +1,5 @@
import { InferSession, InferUser } from "better-auth/types";
import type { auth } from "./auth";
export type User = InferUser<typeof auth>;
export type Session = InferSession<typeof auth>;

View File

@@ -15,7 +15,11 @@ export async function middleware(request: NextRequest) {
permission: { permission: {
invitation: ["create"], invitation: ["create"],
}, },
options: {
headers: request.headers,
},
}); });
console.log({ canInvite });
return NextResponse.next(); return NextResponse.next();
} }

View File

@@ -15,6 +15,7 @@
".": "./dist/index.js", ".": "./dist/index.js",
"./provider": "./dist/provider.js", "./provider": "./dist/provider.js",
"./client": "./dist/client.js", "./client": "./dist/client.js",
"./types": "./dist/types.js",
"./cli": "./dist/cli.js", "./cli": "./dist/cli.js",
"./react": "./dist/react.js", "./react": "./dist/react.js",
"./preact": "./dist/preact.js", "./preact": "./dist/preact.js",
@@ -52,7 +53,7 @@
"@simplewebauthn/browser": "^10.0.0", "@simplewebauthn/browser": "^10.0.0",
"@simplewebauthn/server": "^10.0.1", "@simplewebauthn/server": "^10.0.1",
"arctic": "^1.9.2", "arctic": "^1.9.2",
"better-call": "^0.1.33", "better-call": "^0.1.36",
"chalk": "^5.3.0", "chalk": "^5.3.0",
"commander": "^12.1.0", "commander": "^12.1.0",
"consola": "^3.2.3", "consola": "^3.2.3",

View File

@@ -93,16 +93,11 @@ export const createInternalAdapter = (
session.expiresAt.valueOf() - maxAge.valueOf() + updateDate <= session.expiresAt.valueOf() - maxAge.valueOf() + updateDate <=
Date.now(); Date.now();
if (shouldBeUpdated) { if (shouldBeUpdated) {
const updatedSession = await adapter.update<Session>({ const updatedSession = await adapter.create<Session>({
model: tables.session.tableName, model: tables.session.tableName,
where: [ data: {
{
field: "id",
value: session.id,
},
],
update: {
...session, ...session,
id: generateRandomString(32, alphabet("a-z", "0-9", "A-Z")),
expiresAt: new Date(Date.now() + sessionExpiration), expiresAt: new Date(Date.now() + sessionExpiration),
}, },
}); });

View File

@@ -1,22 +1,27 @@
import { createRouter, Endpoint } from "better-call"; import { Context, createRouter, Endpoint } from "better-call";
import { import {
signInOAuth, signInOAuth,
callbackOAuth, callbackOAuth,
getSession,
signOut, signOut,
signInCredential, signInCredential,
forgetPassword, forgetPassword,
resetPassword, resetPassword,
verifyEmail, verifyEmail,
sendVerificationEmail, sendVerificationEmail,
getSession,
} from "./routes"; } from "./routes";
import { AuthContext } from "../init"; import { AuthContext } from "../init";
import { csrfMiddleware } from "./middlewares/csrf"; import { csrfMiddleware } from "./middlewares/csrf";
import { getCSRFToken } from "./routes/csrf"; import { getCSRFToken } from "./routes/csrf";
import { signUpCredential } from "./routes/sign-up"; import { signUpCredential } from "./routes/sign-up";
import { parseAccount, parseSession, parseUser } from "../adapters/schema"; import { parseAccount, parseSession, parseUser } from "../adapters/schema";
import { BetterAuthOptions, InferSession, InferUser } from "../types";
import { Prettify } from "../types/helper";
export const router = <C extends AuthContext>(ctx: C) => { export const router = <C extends AuthContext, Option extends BetterAuthOptions>(
ctx: C,
option: Option,
) => {
const pluginEndpoints = ctx.options.plugins?.reduce( const pluginEndpoints = ctx.options.plugins?.reduce(
(acc, plugin) => { (acc, plugin) => {
return { return {
@@ -65,11 +70,30 @@ export const router = <C extends AuthContext>(ctx: C) => {
.filter((plugin) => plugin !== undefined) .filter((plugin) => plugin !== undefined)
.flat() || []; .flat() || [];
async function typedSession(
ctx: Context<
"/session",
{
method: "GET";
requireHeaders: true;
}
>,
) {
const handler = await getSession(ctx);
return handler as {
session: Prettify<InferSession<Option>>;
user: Prettify<InferUser<Option>>;
} | null;
}
typedSession.path = getSession.path;
typedSession.method = getSession.method;
typedSession.options = getSession.options;
typedSession.headers = getSession.headers;
const baseEndpoints = { const baseEndpoints = {
signInOAuth, signInOAuth,
callbackOAuth, callbackOAuth,
getCSRFToken, getCSRFToken,
getSession, getSession: typedSession,
signOut, signOut,
signUpCredential, signUpCredential,
signInCredential, signInCredential,
@@ -85,15 +109,49 @@ export const router = <C extends AuthContext>(ctx: C) => {
}; };
let api: Record<string, any> = {}; let api: Record<string, any> = {};
for (const [key, value] of Object.entries(endpoints)) { for (const [key, value] of Object.entries(endpoints)) {
api[key] = (context: any) => { api[key] = async (context: any) => {
for (const plugin of ctx.options.plugins || []) {
if (plugin.hooks?.before) {
for (const hook of plugin.hooks.before) {
const match = hook.matcher(context);
if (match) {
const hookRes = await hook.handler(context);
if (hookRes && "context" in hookRes) {
context = {
...context,
...hookRes.context,
};
}
}
}
}
}
//@ts-ignore //@ts-ignore
return value({ const endpointRes = value({
...context, ...context,
context: { context: {
...ctx, ...ctx,
...context.context, ...context.context,
}, },
}); });
let response = endpointRes;
for (const plugin of ctx.options.plugins || []) {
if (plugin.hooks?.after) {
for (const hook of plugin.hooks.after) {
const match = hook.matcher(context);
if (match) {
const hookRes = await hook.handler({
...context,
returned: endpointRes,
});
if (hookRes && "response" in hookRes) {
response = hookRes.response as any;
}
}
}
}
}
return response;
}; };
api[key].path = value.path; api[key].path = value.path;
api[key].method = value.method; api[key].method = value.method;
@@ -134,7 +192,7 @@ export const router = <C extends AuthContext>(ctx: C) => {
}); });
}, },
onError(e) { onError(e) {
console.log(e); // console.log(e);
}, },
}); });
}; };

View File

@@ -19,16 +19,14 @@ export const csrfMiddleware = createAuthMiddleware(
return; return;
} }
const url = new URL(ctx.request.url); const url = new URL(ctx.request.url);
console.log({ console.log(url.origin, ctx.context.options.baseURL);
url: ctx.request.url,
});
/** /**
* If origin is the same as baseURL or if the * If origin is the same as baseURL or if the
* origin is in the trustedOrigins then we * origin is in the trustedOrigins then we
* don't need to check the CSRF token. * don't need to check the CSRF token.
*/ */
if ( if (
url.origin === ctx.context.baseURL || url.origin === ctx.context.options.baseURL ||
ctx.context.options.trustedOrigins?.includes(url.origin) ctx.context.options.trustedOrigins?.includes(url.origin)
) { ) {
return; return;

View File

@@ -4,6 +4,7 @@ import { APIError } from "better-call";
import { parseState } from "../../utils/state"; import { parseState } from "../../utils/state";
import { userSchema } from "../../adapters/schema"; import { userSchema } from "../../adapters/schema";
import { HIDE_ON_CLIENT_METADATA } from "../../client/client-utils"; import { HIDE_ON_CLIENT_METADATA } from "../../client/client-utils";
import { generateId } from "../../utils/id";
export const callbackOAuth = createAuthEndpoint( export const callbackOAuth = createAuthEndpoint(
"/callback/:id", "/callback/:id",
@@ -36,11 +37,11 @@ export const callbackOAuth = createAuthEndpoint(
c.context.logger.error("Code verification failed"); c.context.logger.error("Code verification failed");
throw new APIError("UNAUTHORIZED"); throw new APIError("UNAUTHORIZED");
} }
const user = await provider.userInfo.getUserInfo(tokens); const user = await provider.userInfo.getUserInfo(tokens);
const id = generateId();
const data = userSchema.safeParse({ const data = userSchema.safeParse({
...user, ...user,
id: user?.id.toString(), id,
}); });
if (!user || data.success === false) { if (!user || data.success === false) {
throw new APIError("BAD_REQUEST"); throw new APIError("BAD_REQUEST");
@@ -52,7 +53,7 @@ export const callbackOAuth = createAuthEndpoint(
} }
//find user in db //find user in db
const dbUser = await c.context.internalAdapter.findUserByEmail(user.email); const dbUser = await c.context.internalAdapter.findUserByEmail(user.email);
let userId = dbUser?.user.id; const userId = dbUser?.user.id;
if (dbUser) { if (dbUser) {
//check if user has already linked this provider //check if user has already linked this provider
const hasBeenLinked = dbUser.accounts.find( const hasBeenLinked = dbUser.accounts.find(
@@ -76,14 +77,13 @@ export const callbackOAuth = createAuthEndpoint(
} }
} else { } else {
try { try {
await c.context.internalAdapter.createOAuthUser(user, { await c.context.internalAdapter.createOAuthUser(data.data, {
...tokens, ...tokens,
id: `${provider.id}:${user.id}`, id: `${provider.id}:${user.id}`,
providerId: provider.id, providerId: provider.id,
accountId: user.id, accountId: user.id,
userId: user.id, userId: id,
}); });
userId = user.id;
} catch (e) { } catch (e) {
const url = new URL(currentURL || callbackURL); const url = new URL(currentURL || callbackURL);
url.searchParams.set("error", "unable_to_create_user"); url.searchParams.set("error", "unable_to_create_user");
@@ -93,10 +93,9 @@ export const callbackOAuth = createAuthEndpoint(
} }
//this should never happen //this should never happen
if (!userId) throw new APIError("INTERNAL_SERVER_ERROR"); if (!userId) throw new APIError("INTERNAL_SERVER_ERROR");
//create session //create session
const session = await c.context.internalAdapter.createSession( const session = await c.context.internalAdapter.createSession(
userId, userId || id,
c.request, c.request,
); );
try { try {

View File

@@ -1,6 +1,5 @@
import { Context } from "better-call"; import { Context } from "better-call";
import { createAuthEndpoint } from "../call"; import { createAuthEndpoint } from "../call";
import { HIDE_ON_CLIENT_METADATA } from "../../client/client-utils";
export const getSession = createAuthEndpoint( export const getSession = createAuthEndpoint(
"/session", "/session",
@@ -31,6 +30,15 @@ export const getSession = createAuthEndpoint(
const updatedSession = await ctx.context.internalAdapter.updateSession( const updatedSession = await ctx.context.internalAdapter.updateSession(
session.session, session.session,
); );
await ctx.setSignedCookie(
ctx.context.authCookies.sessionToken.name,
updatedSession.id,
ctx.context.secret,
{
...ctx.context.authCookies.sessionToken.options,
maxAge: updatedSession.expiresAt.valueOf() - Date.now(),
},
);
return ctx.json({ return ctx.json({
session: updatedSession, session: updatedSession,
user: session.user, user: session.user,

View File

@@ -87,16 +87,18 @@ export const signInCredential = createAuthEndpoint(
email: z.string().email(), email: z.string().email(),
password: z.string(), password: z.string(),
callbackURL: z.string().optional(), callbackURL: z.string().optional(),
/**
* If this is true the session will only be valid for the current browser session
* @default false
*/
dontRememberMe: z.boolean().default(false).optional(),
}), }),
}, },
async (ctx) => { async (ctx) => {
if (!ctx.context.options?.emailAndPassword?.enabled) { if (!ctx.context.options?.emailAndPassword?.enabled) {
ctx.context.logger.error("Email and password is not enabled"); ctx.context.logger.error("Email and password is not enabled");
return ctx.json(null, { throw new APIError("BAD_REQUEST", {
body: { message: "Email and password is not enabled",
message: "Email and password is not enabled",
},
status: 400,
}); });
} }
const currentSession = await getSessionFromCtx(ctx); const currentSession = await getSessionFromCtx(ctx);
@@ -114,8 +116,8 @@ export const signInCredential = createAuthEndpoint(
const user = await ctx.context.internalAdapter.findUserByEmail(email); const user = await ctx.context.internalAdapter.findUserByEmail(email);
if (!user) { if (!user) {
ctx.context.logger.error("User not found", { email }); ctx.context.logger.error("User not found", { email });
return ctx.json(null, { throw new APIError("UNAUTHORIZED", {
status: 401, message: "Invalid email or password",
}); });
} }
const credentialAccount = user.accounts.find( const credentialAccount = user.accounts.find(
@@ -124,20 +126,17 @@ export const signInCredential = createAuthEndpoint(
const currentPassword = credentialAccount?.password; const currentPassword = credentialAccount?.password;
if (!currentPassword) { if (!currentPassword) {
ctx.context.logger.error("Password not found", { email }); ctx.context.logger.error("Password not found", { email });
return ctx.json(null, { throw new APIError("UNAUTHORIZED", {
status: 401, message: "Unexpected error",
body: { message: "Unexpected error" },
}); });
} }
const validPassword = await argon2id.verify(currentPassword, password); const validPassword = await argon2id.verify(currentPassword, password);
if (!validPassword) { if (!validPassword) {
ctx.context.logger.error("Invalid password"); ctx.context.logger.error("Invalid password");
return ctx.json(null, { throw new APIError("UNAUTHORIZED", {
status: 401, message: "Invalid email or password",
body: { message: "Invalid email or password" },
}); });
} }
const session = await ctx.context.internalAdapter.createSession( const session = await ctx.context.internalAdapter.createSession(
user.user.id, user.user.id,
ctx.request, ctx.request,
@@ -146,7 +145,12 @@ export const signInCredential = createAuthEndpoint(
ctx.context.authCookies.sessionToken.name, ctx.context.authCookies.sessionToken.name,
session.id, session.id,
ctx.context.secret, ctx.context.secret,
ctx.context.authCookies.sessionToken.options, ctx.body.dontRememberMe
? {
...ctx.context.authCookies.sessionToken.options,
maxAge: undefined,
}
: ctx.context.authCookies.sessionToken.options,
); );
return ctx.json({ return ctx.json({
user: user.user, user: user.user,

View File

@@ -1,7 +1,7 @@
import { router } from "./api"; import { router } from "./api";
import type { BetterAuthOptions } from "./types/options"; import type { BetterAuthOptions } from "./types/options";
import type { UnionToIntersection } from "type-fest"; import type { UnionToIntersection } from "type-fest";
import type { Plugin } from "./types/plugins"; import type { BetterAuthPlugin } from "./types/plugins";
import { init } from "./init"; import { init } from "./init";
import type { CustomProvider } from "./providers"; import type { CustomProvider } from "./providers";
@@ -9,20 +9,20 @@ export const betterAuth = <O extends BetterAuthOptions>(options: O) => {
const authContext = init(options); const authContext = init(options);
type PluginEndpoint = UnionToIntersection< type PluginEndpoint = UnionToIntersection<
O["plugins"] extends Array<infer T> O["plugins"] extends Array<infer T>
? T extends Plugin ? T extends BetterAuthPlugin
? T["endpoints"] ? T["endpoints"]
: Record<string, never> : {}
: Record<string, never> : {}
>; >;
type ProviderEndpoint = UnionToIntersection< type ProviderEndpoint = UnionToIntersection<
O["providers"] extends Array<infer T> O["providers"] extends Array<infer T>
? T extends CustomProvider ? T extends CustomProvider
? T["endpoints"] ? T["endpoints"]
: Record<string, never> : {}
: Record<string, never> : {}
>; >;
const { handler, endpoints } = router(authContext); const { handler, endpoints } = router(authContext, options);
type Endpoint = typeof endpoints; type Endpoint = typeof endpoints;
return { return {
handler, handler,
@@ -31,11 +31,7 @@ export const betterAuth = <O extends BetterAuthOptions>(options: O) => {
}; };
}; };
export type BetterAuth< export type BetterAuth<Endpoints extends Record<string, any> = {}> = {
Endpoints extends Record<string, any> = ReturnType<
typeof router
>["endpoints"],
> = {
handler: (request: Request) => Promise<Response>; handler: (request: Request) => Promise<Response>;
api: Endpoints; api: Endpoints;
options: BetterAuthOptions; options: BetterAuthOptions;

View File

@@ -19,7 +19,7 @@ export const createVanillaClient = <Auth extends BetterAuth = never>(
: BAuth["api"]; : BAuth["api"];
const $fetch = createFetch({ const $fetch = createFetch({
...options, ...options,
baseURL: getBaseURL(options?.baseURL), baseURL: getBaseURL(options?.baseURL).withPath,
plugins: [redirectPlugin, addCurrentURL, csrfPlugin], plugins: [redirectPlugin, addCurrentURL, csrfPlugin],
}); });
const { $session, $sessionSignal } = getSessionAtom<Auth>($fetch); const { $session, $sessionSignal } = getSessionAtom<Auth>($fetch);

View File

@@ -64,6 +64,7 @@ export function createDynamicPathProxy<T extends Record<string, any>>(
method, method,
onSuccess() { onSuccess() {
const signal = $signal?.[routePath as string]; const signal = $signal?.[routePath as string];
console.log({ routePath, signal });
if (signal) { if (signal) {
signal.set(!signal.get()); signal.set(!signal.get());
} }

View File

@@ -2,13 +2,23 @@ import { useStore } from "@nanostores/react";
import { createVanillaClient } from "./base"; import { createVanillaClient } from "./base";
import { BetterFetchOption } from "@better-fetch/fetch"; import { BetterFetchOption } from "@better-fetch/fetch";
import { BetterAuth } from "../auth"; import { BetterAuth } from "../auth";
import { InferSession, InferUser } from "../types";
export const createAuthClient = <Auth extends BetterAuth>( export const createAuthClient = <Auth extends BetterAuth>(
options?: BetterFetchOption, options?: BetterFetchOption,
) => { ) => {
const client = createVanillaClient<Auth>(options); const client = createVanillaClient<Auth>(options);
function useSession() { function useSession(
return useStore(client.$atoms.$session); initialValue: {
user: InferUser<Auth>;
session: InferSession<Auth>;
} | null = null,
) {
const session = useStore(client.$atoms.$session);
if (session) {
return session;
}
return initialValue;
} }
function useActiveOrganization() { function useActiveOrganization() {
return useStore(client.$atoms.$activeOrganization); return useStore(client.$atoms.$activeOrganization);

View File

@@ -1,64 +1,12 @@
import { atom, computed, task } from "nanostores"; import { atom, computed, task } from "nanostores";
import { Session, User } from "../adapters/schema"; import { Prettify } from "../types/helper";
import { Prettify, UnionToIntersection } from "../types/helper";
import { BetterAuth } from "../auth"; import { BetterAuth } from "../auth";
import { FieldAttribute, InferFieldOutput } from "../db";
import { BetterFetch } from "@better-fetch/fetch"; import { BetterFetch } from "@better-fetch/fetch";
import { InferSession, InferUser } from "../types/models";
export function getSessionAtom<Auth extends BetterAuth>(client: BetterFetch) { export function getSessionAtom<Auth extends BetterAuth>(client: BetterFetch) {
type AdditionalSessionFields = Auth["options"]["plugins"] extends Array< type UserWithAdditionalFields = InferUser<Auth["options"]>;
infer T type SessionWithAdditionalFields = InferSession<Auth["options"]>;
>
? T extends {
schema: {
session: {
fields: infer Field;
};
};
}
? Field extends Record<string, FieldAttribute>
? {
[key in keyof Field]: InferFieldOutput<Field[key]>;
}
: {}
: {}
: {};
type AdditionalUserFields = Auth["options"]["plugins"] extends Array<infer T>
? T extends {
schema: {
user: {
fields: infer Field;
};
};
}
? Field extends Record<infer Key, FieldAttribute>
? Prettify<
{
[key in Key as Field[key]["required"] extends false
? never
: Field[key]["defaultValue"] extends
| boolean
| string
| number
| Date
| Function
? key
: never]: InferFieldOutput<Field[key]>;
} & {
[key in Key as Field[key]["returned"] extends false
? never
: key]?: InferFieldOutput<Field[key]>;
}
>
: {}
: {}
: {};
type UserWithAdditionalFields = User &
UnionToIntersection<AdditionalUserFields>;
type SessionWithAdditionalFields = Session &
UnionToIntersection<AdditionalSessionFields>;
const $signal = atom<boolean>(false); const $signal = atom<boolean>(false);
const $session = computed($signal, () => const $session = computed($signal, () =>
task(async () => { task(async () => {

View File

@@ -1,3 +1,4 @@
import { Context, ContextTools } from "better-call";
import { createKyselyAdapter } from "./adapters/kysely"; import { createKyselyAdapter } from "./adapters/kysely";
import { getAdapter } from "./adapters/utils"; import { getAdapter } from "./adapters/utils";
import { createInternalAdapter } from "./db"; import { createInternalAdapter } from "./db";
@@ -13,13 +14,13 @@ import { createLogger } from "./utils/logger";
export const init = (options: BetterAuthOptions) => { export const init = (options: BetterAuthOptions) => {
const adapter = getAdapter(options); const adapter = getAdapter(options);
const db = createKyselyAdapter(options); const db = createKyselyAdapter(options);
const baseURL = getBaseURL(options.baseURL, options.basePath); const { baseURL, withPath } = getBaseURL(options.baseURL, options.basePath);
return { return {
options: { options: {
...options, ...options,
baseURL, baseURL: baseURL,
}, },
baseURL, baseURL: withPath,
secret: secret:
options.secret || options.secret ||
process.env.BETTER_AUTH_SECRET || process.env.BETTER_AUTH_SECRET ||

View File

@@ -3,7 +3,6 @@ import { createAuthMiddleware, optionsMiddleware } from "../../api/call";
import { OrganizationOptions } from "./organization"; import { OrganizationOptions } from "./organization";
import { defaultRoles, Role } from "./access"; import { defaultRoles, Role } from "./access";
import { Session, User } from "../../adapters/schema"; import { Session, User } from "../../adapters/schema";
import { getSession } from "../../api/routes";
import { sessionMiddleware } from "../../api/middlewares/session"; import { sessionMiddleware } from "../../api/middlewares/session";
export const orgMiddleware = createAuthMiddleware(async (ctx) => { export const orgMiddleware = createAuthMiddleware(async (ctx) => {

View File

@@ -9,7 +9,7 @@ import {
updateOrganization, updateOrganization,
} from "./routes/crud-org"; } from "./routes/crud-org";
import { AccessControl, defaultRoles, defaultStatements, Role } from "./access"; import { AccessControl, defaultRoles, defaultStatements, Role } from "./access";
import { getSession } from "../../api/routes"; import { getSessionFromCtx } from "../../api/routes";
import { AuthContext } from "../../init"; import { AuthContext } from "../../init";
import { import {
acceptInvitation, acceptInvitation,
@@ -18,6 +18,10 @@ import {
rejectInvitation, rejectInvitation,
} from "./routes/crud-invites"; } from "./routes/crud-invites";
import { deleteMember, updateMember } from "./routes/crud-members"; import { deleteMember, updateMember } from "./routes/crud-members";
import { sessionMiddleware } from "../../api/middlewares/session";
import { orgMiddleware, orgSessionMiddleware } from "./call";
import { getOrgAdapter } from "./adapter";
import { APIError } from "better-call";
export interface OrganizationOptions { export interface OrganizationOptions {
/** /**
@@ -94,7 +98,7 @@ export const organization = <O extends OrganizationOptions>(options?: O) => {
roles, roles,
getSession: async (context: AuthContext) => { getSession: async (context: AuthContext) => {
//@ts-expect-error //@ts-expect-error
return await getSession(context); return await getSessionFromCtx(context);
}, },
}); });
@@ -112,6 +116,7 @@ export const organization = <O extends OrganizationOptions>(options?: O) => {
"/org/has-permission", "/org/has-permission",
{ {
method: "POST", method: "POST",
requireHeaders: true,
body: z.object({ body: z.object({
permission: z.record(z.string(), z.array(z.string())), permission: z.record(z.string(), z.array(z.string())),
}) as unknown as ZodObject<{ }) as unknown as ZodObject<{
@@ -122,10 +127,42 @@ export const organization = <O extends OrganizationOptions>(options?: O) => {
>; >;
}>; }>;
}>, }>,
use: [orgSessionMiddleware],
}, },
async () => { async (ctx) => {
const hasPerm = true; if (!ctx.context.session.session.activeOrganizationId) {
return hasPerm; throw new APIError("BAD_REQUEST", {
message: "No active organization",
});
}
const adapter = getOrgAdapter(ctx.context.adapter);
const member = await adapter.findMemberByOrgId({
userId: ctx.context.session.user.id,
organizationId:
ctx.context.session.session.activeOrganizationId || "",
});
if (!member) {
throw new APIError("UNAUTHORIZED", {
message: "You are not a member of this organization",
});
}
const role = roles[member.role];
const result = role.authorize(ctx.body.permission as any);
if (result.error) {
return ctx.json(
{
error: result.error,
success: false,
},
{
status: 403,
},
);
}
return ctx.json({
error: null,
success: true,
});
}, },
), ),
}, },

View File

@@ -1,5 +0,0 @@
import { BetterAuthPlugin } from "../../types/plugins";
export const rememberMePlugin = async () => {
return {} satisfies BetterAuthPlugin;
};

View File

@@ -0,0 +1,13 @@
import { ContextTools } from "better-call";
import { AuthContext } from "../init";
export type GenericEndpointContext = ContextTools & {
context: AuthContext;
} & {
body: any;
request: Request;
headers: Headers;
params?: Record<string, string> | undefined;
query: any;
method: "*";
};

View File

@@ -1 +1,2 @@
export * from "./options"; export * from "./options";
export * from "./models";

View File

@@ -0,0 +1,73 @@
import { BetterAuthOptions } from ".";
import { Session, User } from "../adapters/schema";
import { BetterAuth } from "../auth";
import { FieldAttribute, InferFieldOutput } from "../db";
import { Prettify, UnionToIntersection } from "./helper";
type AdditionalSessionFields<Options extends BetterAuthOptions> =
Options["plugins"] extends Array<infer T>
? T extends {
schema: {
session: {
fields: infer Field;
};
};
}
? Field extends Record<string, FieldAttribute>
? {
[key in keyof Field]: InferFieldOutput<Field[key]>;
}
: {}
: {}
: {};
type AdditionalUserFields<Options extends BetterAuthOptions> =
Options["plugins"] extends Array<infer T>
? T extends {
schema: {
user: {
fields: infer Field;
};
};
}
? Field extends Record<infer Key, FieldAttribute>
? Prettify<
{
[key in Key as Field[key]["required"] extends false
? never
: Field[key]["defaultValue"] extends
| boolean
| string
| number
| Date
| Function
? key
: never]: InferFieldOutput<Field[key]>;
} & {
[key in Key as Field[key]["returned"] extends false
? never
: key]?: InferFieldOutput<Field[key]>;
}
>
: {}
: {}
: {};
export type InferUser<O extends BetterAuthOptions | BetterAuth> =
UnionToIntersection<
User &
(O extends BetterAuthOptions
? AdditionalUserFields<O>
: O extends BetterAuth
? AdditionalUserFields<O["options"]>
: {})
>;
export type InferSession<O extends BetterAuthOptions | BetterAuth> =
UnionToIntersection<
Session &
(O extends BetterAuthOptions
? AdditionalSessionFields<O>
: O extends BetterAuth
? AdditionalSessionFields<O["options"]>
: {})
>;

View File

@@ -1,7 +1,7 @@
import { Dialect } from "kysely"; import { Dialect } from "kysely";
import type { FieldAttribute } from "../db/field"; import type { FieldAttribute } from "../db/field";
import type { Provider } from "./provider"; import type { Provider } from "./provider";
import type { Plugin } from "./plugins"; import type { BetterAuthPlugin } from "./plugins";
import type { Adapter } from "./adapter"; import type { Adapter } from "./adapter";
import { User } from "../adapters/schema"; import { User } from "../adapters/schema";
@@ -55,7 +55,7 @@ export interface BetterAuthOptions {
/** /**
* Plugins * Plugins
*/ */
plugins?: Plugin[]; plugins?: BetterAuthPlugin[];
/** /**
* Advanced options * Advanced options
*/ */

View File

@@ -1,8 +1,9 @@
import { Migration } from "kysely"; import { Migration } from "kysely";
import { AuthEndpoint, AuthMiddleware } from "../api/call"; import { AuthEndpoint } from "../api/call";
import { FieldAttribute } from "../db/field"; import { FieldAttribute } from "../db/field";
import { LiteralString } from "./helper"; import { LiteralString } from "./helper";
import { Endpoint } from "better-call"; import { Endpoint, EndpointResponse } from "better-call";
import { GenericEndpointContext } from "./context";
export type PluginSchema = { export type PluginSchema = {
[table: string]: { [table: string]: {
@@ -22,6 +23,28 @@ export type BetterAuthPlugin = {
path: string; path: string;
middleware: Endpoint; middleware: Endpoint;
}[]; }[];
hooks?: {
before?: {
matcher: (context: GenericEndpointContext) => boolean;
handler: Endpoint<
(context: GenericEndpointContext) => Promise<void | {
context: Partial<GenericEndpointContext>;
}>
>;
}[];
after?: {
matcher: (context: GenericEndpointContext) => boolean;
handler: Endpoint<
(
context: GenericEndpointContext & {
returned: EndpointResponse;
},
) => Promise<void | {
response: EndpointResponse;
}>
>;
}[];
};
/** /**
* Schema the plugin needs * Schema the plugin needs
* *

View File

@@ -11,10 +11,16 @@ function checkHasPath(url: string): boolean {
function withPath(url: string, path = "/api/auth") { function withPath(url: string, path = "/api/auth") {
const hasPath = checkHasPath(url); const hasPath = checkHasPath(url);
if (hasPath) { if (hasPath) {
return url; return {
baseURL: new URL(url).origin,
withPath: url,
};
} }
path = path.startsWith("/") ? path : `/${path}`; path = path.startsWith("/") ? path : `/${path}`;
return `${url}${path}`; return {
baseURL: url,
withPath: `${url}${path}`,
};
} }
export function getBaseURL(url?: string, path?: string) { export function getBaseURL(url?: string, path?: string) {
@@ -33,7 +39,10 @@ export function getBaseURL(url?: string, path?: string) {
!fromEnv && !fromEnv &&
(process.env.NODE_ENV === "development" || process.env.NODE_ENV === "test") (process.env.NODE_ENV === "development" || process.env.NODE_ENV === "test")
) { ) {
return "http://localhost:3000/api/auth"; return {
baseURL: "http://localhost:3000",
withPath: "http://localhost:3000/api/auth",
};
} }
throw new Error( throw new Error(
"Could not infer baseURL from environment variables. Please pass it as an option to the createClient function.", "Could not infer baseURL from environment variables. Please pass it as an option to the createClient function.",

View File

@@ -4,6 +4,7 @@ export default defineConfig({
entry: { entry: {
index: "./src/index.ts", index: "./src/index.ts",
provider: "./src/providers/index.ts", provider: "./src/providers/index.ts",
types: "./src/types/index.ts",
client: "./src/client/index.ts", client: "./src/client/index.ts",
cli: "./src/cli/index.ts", cli: "./src/cli/index.ts",
react: "./src/client/react.ts", react: "./src/client/react.ts",

28
pnpm-lock.yaml generated
View File

@@ -51,7 +51,7 @@ importers:
devDependencies: devDependencies:
'@types/bun': '@types/bun':
specifier: latest specifier: latest
version: 1.1.7 version: 1.1.8
vite: vite:
specifier: ^5.3.5 specifier: ^5.3.5
version: 5.3.5(@types/node@20.14.12) version: 5.3.5(@types/node@20.14.12)
@@ -478,8 +478,8 @@ importers:
specifier: ^1.9.2 specifier: ^1.9.2
version: 1.9.2 version: 1.9.2
better-call: better-call:
specifier: ^0.1.33 specifier: ^0.1.36
version: 0.1.33(typescript@5.5.4) version: 0.1.36(typescript@5.5.4)
chalk: chalk:
specifier: ^5.3.0 specifier: ^5.3.0
version: 5.3.0 version: 5.3.0
@@ -2659,8 +2659,8 @@ packages:
'@types/better-sqlite3@7.6.11': '@types/better-sqlite3@7.6.11':
resolution: {integrity: sha512-i8KcD3PgGtGBLl3+mMYA8PdKkButvPyARxA7IQAd6qeslht13qxb1zzO8dRCtE7U3IoJS782zDBAeoKiM695kg==} resolution: {integrity: sha512-i8KcD3PgGtGBLl3+mMYA8PdKkButvPyARxA7IQAd6qeslht13qxb1zzO8dRCtE7U3IoJS782zDBAeoKiM695kg==}
'@types/bun@1.1.7': '@types/bun@1.1.8':
resolution: {integrity: sha512-iIIn26SOX8qI5E8Juh+0rUgBmFHvll1akscwerhp9O/fHZGdQBWNLJkkRg/3z2Mh6a3ZgWUIkXViLZZYg47TXw==} resolution: {integrity: sha512-PIwVFQKPviksiibobyvcWtMvMFMTj91T8dQEh9l1P3Ypr3ZuVn9w7HSr+5mTNrPqD1xpdDLEErzZPU8gqHBu6g==}
'@types/cookie@0.6.0': '@types/cookie@0.6.0':
resolution: {integrity: sha512-4Kh9a6B2bQciAhf7FSuMRRkUWecJgJu9nPnx3yzpsfXX/c50REIqpHY4C82bXP90qrLtXtkDxTZosYO3UpOwlA==} resolution: {integrity: sha512-4Kh9a6B2bQciAhf7FSuMRRkUWecJgJu9nPnx3yzpsfXX/c50REIqpHY4C82bXP90qrLtXtkDxTZosYO3UpOwlA==}
@@ -3042,10 +3042,10 @@ packages:
peerDependencies: peerDependencies:
typescript: ^5.0.0 typescript: ^5.0.0
better-call@0.1.33: better-call@0.1.36:
resolution: {integrity: sha512-gzthE/AnimwMCNBNyy9LRqqAtjXkqO+dR4n1OjCiUhiBK4X+NZMKekQUKIzpfwDRC4k3hCshb1LsPhqwiSM7Bw==} resolution: {integrity: sha512-+FsoIB8tMVRciTTN9UUXXoJEzAqIaaNPrxa9kDYoTaGCTCimNKbHNuLwsnIAibamG0Lo7CFCiSL4yyV+I32O4A==}
peerDependencies: peerDependencies:
typescript: ^5.0.0 typescript: ^5.6.0-beta
better-sqlite3@11.1.2: better-sqlite3@11.1.2:
resolution: {integrity: sha512-gujtFwavWU4MSPT+h9B+4pkvZdyOUkH54zgLdIrMmmmd4ZqiBIrRNBzNzYVFO417xo882uP5HBu4GjOfaSrIQw==} resolution: {integrity: sha512-gujtFwavWU4MSPT+h9B+4pkvZdyOUkH54zgLdIrMmmmd4ZqiBIrRNBzNzYVFO417xo882uP5HBu4GjOfaSrIQw==}
@@ -3089,8 +3089,8 @@ packages:
bun-html-live-reload@0.1.3: bun-html-live-reload@0.1.3:
resolution: {integrity: sha512-PW1sp9ZmBAqiAa0aUhHpFc6sEQmC6FgRNKVAvcjSDUMqASzgq7xYpNkEt2Z6VjuiPXKtOx/49b6sLLmjyojrOw==} resolution: {integrity: sha512-PW1sp9ZmBAqiAa0aUhHpFc6sEQmC6FgRNKVAvcjSDUMqASzgq7xYpNkEt2Z6VjuiPXKtOx/49b6sLLmjyojrOw==}
bun-types@1.1.25: bun-types@1.1.26:
resolution: {integrity: sha512-WpRb8/N3S5IE8UYdIn39+0Is1XzxsC78+MCe5cIdaer0lfFs6+DREtQH9TM6KJNKTxBYDvbx81RwbvxS5+CkVQ==} resolution: {integrity: sha512-n7jDe62LsB2+WE8Q8/mT3azkPaatKlj/2MyP6hi3mKvPz9oPpB6JW/Ll6JHtNLudasFFuvfgklYSE+rreGvBjw==}
bundle-require@5.0.0: bundle-require@5.0.0:
resolution: {integrity: sha512-GuziW3fSSmopcx4KRymQEJVbZUfqlCqcq7dvs6TYwKRZiegK/2buMxQTPs6MGlNv50wms1699qYO54R8XfRX4w==} resolution: {integrity: sha512-GuziW3fSSmopcx4KRymQEJVbZUfqlCqcq7dvs6TYwKRZiegK/2buMxQTPs6MGlNv50wms1699qYO54R8XfRX4w==}
@@ -8820,9 +8820,9 @@ snapshots:
dependencies: dependencies:
'@types/node': 20.14.12 '@types/node': 20.14.12
'@types/bun@1.1.7': '@types/bun@1.1.8':
dependencies: dependencies:
bun-types: 1.1.25 bun-types: 1.1.26
'@types/cookie@0.6.0': {} '@types/cookie@0.6.0': {}
@@ -9275,7 +9275,7 @@ snapshots:
rou3: 0.5.1 rou3: 0.5.1
typescript: 5.5.4 typescript: 5.5.4
better-call@0.1.33(typescript@5.5.4): better-call@0.1.36(typescript@5.5.4):
dependencies: dependencies:
'@better-fetch/fetch': 1.1.4 '@better-fetch/fetch': 1.1.4
'@types/set-cookie-parser': 2.4.10 '@types/set-cookie-parser': 2.4.10
@@ -9343,7 +9343,7 @@ snapshots:
bun-html-live-reload@0.1.3: {} bun-html-live-reload@0.1.3: {}
bun-types@1.1.25: bun-types@1.1.26:
dependencies: dependencies:
'@types/node': 20.12.14 '@types/node': 20.12.14
'@types/ws': 8.5.11 '@types/ws': 8.5.11

View File

@@ -1,4 +1,4 @@
## TODO ## TODO
[ ] handle migration when the config removes existing schema [ ] handle migration when the config removes existing schema
[ ] refresh oauth tokens [ ] refresh oauth tokens
[ ] remember me functionality [x] remember me functionality

View File

@@ -5,6 +5,7 @@
"target": "es2022", "target": "es2022",
"allowJs": true, "allowJs": true,
"resolveJsonModule": true, "resolveJsonModule": true,
"disableReferencedProjectLoad": true,
"moduleDetection": "force", "moduleDetection": "force",
"isolatedModules": true, "isolatedModules": true,
"strict": true, "strict": true,