feat: custom table names and fields for plugins (#570)

This commit is contained in:
Bereket Engida
2024-11-18 11:25:39 +03:00
committed by GitHub
parent a3867bdbba
commit afecf4ce41
32 changed files with 388 additions and 234 deletions

View File

@@ -470,6 +470,27 @@ export const auth = betterAuth({
Type inference in your code will still use the original field names (e.g., `user.name`, not `user.full_name`). Type inference in your code will still use the original field names (e.g., `user.name`, not `user.full_name`).
</Callout> </Callout>
To customize table names and column name for plugins, you can use the `schema` property in the plugin config:
```ts title="auth.ts"
import { betterAuth } from "better-auth";
export const auth = betterAuth({
plugins: {
twoFactor: {
schema: {
user: {
fields: {
twoFactorEnabled: "two_factor_enabled",
twoFactorSecret: "two_factor_secret"
}
}
}
}
}
})
```
### Extending Core Schema ### Extending Core Schema

View File

@@ -118,7 +118,7 @@ const myPlugin = ()=> {
type: "string" type: "string"
} }
}, },
tableName: "myTable" // optional if you want to use a different name than the key modelName: "myTable" // optional if you want to use a different name than the key
} }
} }
} satisfies BetterAuthPlugin } satisfies BetterAuthPlugin

View File

@@ -64,7 +64,7 @@ export const auth = betterAuth({
//...other options //...other options
rateLimit: { rateLimit: {
storage: "database", storage: "database",
tableName: "rateLimit", //optional by default "rateLimit" is used modelName: "rateLimit", //optional by default "rateLimit" is used
}, },
}) })
``` ```

View File

@@ -655,6 +655,25 @@ Table Name: `invitation`
]} ]}
/> />
### Customizing the Schema
To change the schema table name or fields, you can pass `schema` option to the organization plugin.
```ts title="auth.ts"
const auth = betterAuth({
plugins: [organization({
schema: {
organization: {
modelName: "organizations", //map the organization table to organizations
fields: {
name: "title" //map the name field to title
}
}
}
})]
})
```
## Options ## Options
**allowUserToCreateOrganization**: `boolean` | `((user: User) => Promise<boolean> | boolean)` - A function that determines whether a user can create an organization. By default, it's `true`. You can set it to `false` to restrict users from creating organizations. **allowUserToCreateOrganization**: `boolean` | `((user: User) => Promise<boolean> | boolean)` - A function that determines whether a user can create an organization. By default, it's `true`. You can set it to `false` to restrict users from creating organizations.

View File

@@ -174,8 +174,8 @@ exports[`init > should match config 1`] = `
"type": "string", "type": "string",
}, },
}, },
"modelName": "account",
"order": 3, "order": 3,
"tableName": "account",
}, },
"session": { "session": {
"fields": { "fields": {
@@ -205,8 +205,8 @@ exports[`init > should match config 1`] = `
"type": "string", "type": "string",
}, },
}, },
"modelName": "session",
"order": 2, "order": 2,
"tableName": "session",
}, },
"user": { "user": {
"fields": { "fields": {
@@ -245,8 +245,8 @@ exports[`init > should match config 1`] = `
"type": "date", "type": "date",
}, },
}, },
"modelName": "user",
"order": 1, "order": 1,
"tableName": "user",
}, },
"verification": { "verification": {
"fields": { "fields": {
@@ -272,8 +272,8 @@ exports[`init > should match config 1`] = `
"type": "string", "type": "string",
}, },
}, },
"modelName": "verification",
"order": 4, "order": 4,
"tableName": "verification",
}, },
}, },
"trustedOrigins": [ "trustedOrigins": [

View File

@@ -40,8 +40,8 @@ const createTransform = (
} }
const getModelName = (model: string) => { const getModelName = (model: string) => {
return schema[model].tableName !== model return schema[model].modelName !== model
? schema[model].tableName ? schema[model].modelName
: config.usePlural : config.usePlural
? `${model}s` ? `${model}s`
: model; : model;

View File

@@ -83,7 +83,7 @@ const createTransform = (
} }
function getModelName(model: string) { function getModelName(model: string) {
return schema[model].tableName; return schema[model].modelName;
} }
const shouldGenerateId = config?.generateId !== false; const shouldGenerateId = config?.generateId !== false;

View File

@@ -183,7 +183,7 @@ const createTransform = (options: BetterAuthOptions) => {
return clause; return clause;
}, },
getModelName: (model: string) => { getModelName: (model: string) => {
return schema[model].tableName; return schema[model].modelName;
}, },
getField, getField,
}; };

View File

@@ -59,7 +59,7 @@ const createTransform = (config: PrismaConfig, options: BetterAuthOptions) => {
} }
function getModelName(model: string) { function getModelName(model: string) {
return schema[model].tableName; return schema[model].modelName;
} }
const shouldGenerateId = config?.generateId !== false; const shouldGenerateId = config?.generateId !== false;
return { return {

View File

@@ -33,8 +33,8 @@ function getRetryAfter(lastRequest: number, window: number) {
return Math.ceil((lastRequest + windowInMs - now) / 1000); return Math.ceil((lastRequest + windowInMs - now) / 1000);
} }
function createDBStorage(ctx: AuthContext, tableName?: string) { function createDBStorage(ctx: AuthContext, modelName?: string) {
const model = tableName ?? "rateLimit"; const model = "rateLimit";
const db = ctx.adapter; const db = ctx.adapter;
return { return {
get: async (key: string) => { get: async (key: string) => {
@@ -48,7 +48,7 @@ function createDBStorage(ctx: AuthContext, tableName?: string) {
try { try {
if (_update) { if (_update) {
await db.update({ await db.update({
model: tableName ?? "rateLimit", model: modelName ?? "rateLimit",
where: [{ field: "key", value: key }], where: [{ field: "key", value: key }],
update: { update: {
count: value.count, count: value.count,
@@ -57,7 +57,7 @@ function createDBStorage(ctx: AuthContext, tableName?: string) {
}); });
} else { } else {
await db.create({ await db.create({
model: tableName ?? "rateLimit", model: modelName ?? "rateLimit",
data: { data: {
key, key,
count: value.count, count: value.count,
@@ -96,7 +96,7 @@ export function getRateLimitStorage(ctx: AuthContext) {
}, },
}; };
} }
return createDBStorage(ctx, ctx.rateLimit.tableName); return createDBStorage(ctx, ctx.rateLimit.modelName);
} }
export async function onRequestRateLimit(req: Request, ctx: AuthContext) { export async function onRequestRateLimit(req: Request, ctx: AuthContext) {

View File

@@ -61,6 +61,7 @@ export const useAuthQuery = <T>(
}); });
await opts?.onError?.(context); await opts?.onError?.(context);
}, },
async onRequest(context) { async onRequest(context) {
const currentValue = value.get(); const currentValue = value.get();
value.set({ value.set({

View File

@@ -8,6 +8,8 @@ export type FieldType =
| "date" | "date"
| `${"string" | "number"}[]`; | `${"string" | "number"}[]`;
type Primitive = string | number | boolean | Date | null | undefined;
export type FieldAttributeConfig<T extends FieldType = FieldType> = { export type FieldAttributeConfig<T extends FieldType = FieldType> = {
/** /**
* If the field should be required on a new record. * If the field should be required on a new record.
@@ -24,22 +26,22 @@ export type FieldAttributeConfig<T extends FieldType = FieldType> = {
* @default true * @default true
*/ */
input?: boolean; input?: boolean;
/**
* If the value should be hashed when it's stored.
* @default false
*/
hashValue?: boolean;
/** /**
* Default value for the field * Default value for the field
* *
* Note: This will not create a default value on the database level. It will only * Note: This will not create a default value on the database level. It will only
* be used when creating a new record. * be used when creating a new record.
*/ */
defaultValue?: any; defaultValue?: Primitive | (() => Primitive);
/** /**
* transform the value before storing it. * transform the value before storing it.
*/ */
transform?: (value: InferValueType<T>) => InferValueType<T>; transform?: {
input?: (value: InferValueType<T>) => Primitive | Promise<Primitive>;
output?: (
value: Primitive,
) => InferValueType<T> | Promise<InferValueType<T>>;
};
/** /**
* Reference to another model. * Reference to another model.
*/ */
@@ -67,14 +69,12 @@ export type FieldAttributeConfig<T extends FieldType = FieldType> = {
/** /**
* A zod schema to validate the value. * A zod schema to validate the value.
*/ */
validator?: ZodSchema; validator?: {
input?: ZodSchema;
output?: ZodSchema;
};
/** /**
* The name of the field on the database. * The name of the field on the database.
*
* @default
* ```txt
* the key in the fields object.
* ```
*/ */
fieldName?: string; fieldName?: string;
}; };

View File

@@ -17,14 +17,14 @@ export function getSchema(config: BetterAuthOptions) {
Object.entries(fields).forEach(([key, field]) => { Object.entries(fields).forEach(([key, field]) => {
actualFields[field.fieldName || key] = field; actualFields[field.fieldName || key] = field;
}); });
if (schema[table.tableName]) { if (schema[table.modelName]) {
schema[table.tableName].fields = { schema[table.modelName].fields = {
...schema[table.tableName].fields, ...schema[table.modelName].fields,
...actualFields, ...actualFields,
}; };
continue; continue;
} }
schema[table.tableName] = { schema[table.modelName] = {
fields: actualFields, fields: actualFields,
order: table.order || Infinity, order: table.order || Infinity,
}; };

View File

@@ -7,7 +7,7 @@ export type BetterAuthDbSchema = Record<
/** /**
* The name of the table in the database * The name of the table in the database
*/ */
tableName: string; modelName: string;
/** /**
* The fields of the table * The fields of the table
*/ */
@@ -37,21 +37,21 @@ export const getAuthTables = (
...acc[key]?.fields, ...acc[key]?.fields,
...value.fields, ...value.fields,
}, },
tableName: value.tableName || key, modelName: value.modelName || key,
}; };
} }
return acc; return acc;
}, },
{} as Record< {} as Record<
string, string,
{ fields: Record<string, FieldAttribute>; tableName: string } { fields: Record<string, FieldAttribute>; modelName: string }
>, >,
); );
const shouldAddRateLimitTable = options.rateLimit?.storage === "database"; const shouldAddRateLimitTable = options.rateLimit?.storage === "database";
const rateLimitTable = { const rateLimitTable = {
rateLimit: { rateLimit: {
tableName: options.rateLimit?.tableName || "rateLimit", modelName: options.rateLimit?.modelName || "rateLimit",
fields: { fields: {
key: { key: {
type: "string", type: "string",
@@ -72,7 +72,7 @@ export const getAuthTables = (
const { user, session, account, ...pluginTables } = pluginSchema || {}; const { user, session, account, ...pluginTables } = pluginSchema || {};
return { return {
user: { user: {
tableName: options.user?.modelName || "user", modelName: options.user?.modelName || "user",
fields: { fields: {
name: { name: {
type: "string", type: "string",
@@ -114,7 +114,7 @@ export const getAuthTables = (
order: 1, order: 1,
}, },
session: { session: {
tableName: options.session?.modelName || "session", modelName: options.session?.modelName || "session",
fields: { fields: {
expiresAt: { expiresAt: {
type: "date", type: "date",
@@ -147,7 +147,7 @@ export const getAuthTables = (
order: 2, order: 2,
}, },
account: { account: {
tableName: options.account?.modelName || "account", modelName: options.account?.modelName || "account",
fields: { fields: {
accountId: { accountId: {
type: "string", type: "string",
@@ -199,7 +199,7 @@ export const getAuthTables = (
order: 3, order: 3,
}, },
verification: { verification: {
tableName: options.verification?.modelName || "verification", modelName: options.verification?.modelName || "verification",
fields: { fields: {
identifier: { identifier: {
type: "string", type: "string",

View File

@@ -1,6 +1,6 @@
import { z } from "zod"; import { z } from "zod";
import type { FieldAttribute } from "."; import type { FieldAttribute } from ".";
import type { BetterAuthOptions } from "../types"; import type { BetterAuthOptions, PluginSchema } from "../types";
export const accountSchema = z.object({ export const accountSchema = z.object({
id: z.string(), id: z.string(),
@@ -172,3 +172,33 @@ export function parseSessionInput(
const schema = getAllFields(options, "session"); const schema = getAllFields(options, "session");
return parseInputData(session, { fields: schema }); return parseInputData(session, { fields: schema });
} }
export function mergeSchema<S extends PluginSchema>(
schema: S,
newSchema?: {
[K in keyof S]?: {
modelName?: string;
fields?: {
[P: string]: string;
};
};
},
) {
if (!newSchema) {
return schema;
}
for (const table in newSchema) {
const newModelName = newSchema[table]?.modelName;
if (newModelName) {
schema[table].modelName = newModelName;
}
for (const field in schema[table].fields) {
const newField = newSchema[table]?.fields?.[field];
if (!newField) {
continue;
}
schema[table].fields[field].fieldName = newField;
}
}
return schema;
}

View File

@@ -6,7 +6,17 @@ import { adminClient } from "./client";
describe("Admin plugin", async () => { describe("Admin plugin", async () => {
const { client, signInWithTestUser } = await getTestInstance( const { client, signInWithTestUser } = await getTestInstance(
{ {
plugins: [admin()], plugins: [
admin({
schema: {
user: {
fields: {
role: "_role",
},
},
},
}),
],
logger: { logger: {
level: "error", level: "error",
}, },

View File

@@ -5,9 +5,17 @@ import {
createAuthMiddleware, createAuthMiddleware,
getSessionFromCtx, getSessionFromCtx,
} from "../../api"; } from "../../api";
import type { BetterAuthPlugin, Session, User, Where } from "../../types"; import {
type BetterAuthPlugin,
type InferOptionSchema,
type PluginSchema,
type Session,
type User,
type Where,
} from "../../types";
import { setSessionCookie } from "../../cookies"; import { setSessionCookie } from "../../cookies";
import { getDate } from "../../utils/date"; import { getDate } from "../../utils/date";
import { mergeSchema } from "../../db/schema";
export interface UserWithRole extends User { export interface UserWithRole extends User {
role?: string | null; role?: string | null;
@@ -53,9 +61,13 @@ interface AdminOptions {
* By default, the impersonation session lasts 1 hour * By default, the impersonation session lasts 1 hour
*/ */
impersonationSessionDuration?: number; impersonationSessionDuration?: number;
/**
* Custom schema for the admin plugin
*/
schema?: InferOptionSchema<typeof schema>;
} }
export const admin = (options?: AdminOptions) => { export const admin = <O extends AdminOptions>(options?: O) => {
const opts = { const opts = {
defaultRole: "user", defaultRole: "user",
adminRole: "admin", adminRole: "admin",
@@ -478,7 +490,11 @@ export const admin = (options?: AdminOptions) => {
}, },
), ),
}, },
schema: { schema: mergeSchema(schema, opts.schema),
} satisfies BetterAuthPlugin;
};
const schema = {
user: { user: {
fields: { fields: {
role: { role: {
@@ -512,6 +528,4 @@ export const admin = (options?: AdminOptions) => {
}, },
}, },
}, },
}, } satisfies PluginSchema;
} satisfies BetterAuthPlugin;
};

View File

@@ -12,6 +12,13 @@ describe("anonymous", async () => {
async onLinkAccount(data) { async onLinkAccount(data) {
linkAccountFn(data); linkAccountFn(data);
}, },
schema: {
user: {
fields: {
isAnonymous: "is_anon",
},
},
},
}), }),
], ],
}); });

View File

@@ -1,14 +1,16 @@
import { import { APIError, createAuthEndpoint, getSessionFromCtx } from "../../api";
APIError, import type {
createAuthEndpoint, BetterAuthPlugin,
getSessionFromCtx, InferOptionSchema,
sessionMiddleware, PluginSchema,
} from "../../api"; Session,
import type { BetterAuthPlugin, Session, User } from "../../types"; User,
} from "../../types";
import { parseSetCookieHeader, setSessionCookie } from "../../cookies"; import { parseSetCookieHeader, setSessionCookie } from "../../cookies";
import { z } from "zod"; import { z } from "zod";
import { generateId } from "../../utils/id"; import { generateId } from "../../utils/id";
import { getOrigin } from "../../utils/url"; import { getOrigin } from "../../utils/url";
import { mergeSchema } from "../../db/schema";
export interface UserWithAnonymous extends User { export interface UserWithAnonymous extends User {
isAnonymous: boolean; isAnonymous: boolean;
@@ -38,8 +40,23 @@ export interface AnonymousOptions {
* Disable deleting the anonymous user after linking * Disable deleting the anonymous user after linking
*/ */
disableDeleteAnonymousUser?: boolean; disableDeleteAnonymousUser?: boolean;
/**
* Custom schema for the admin plugin
*/
schema?: InferOptionSchema<typeof schema>;
} }
const schema = {
user: {
fields: {
isAnonymous: {
type: "boolean",
required: false,
},
},
},
} satisfies PluginSchema;
export const anonymous = (options?: AnonymousOptions) => { export const anonymous = (options?: AnonymousOptions) => {
return { return {
id: "anonymous", id: "anonymous",
@@ -154,15 +171,6 @@ export const anonymous = (options?: AnonymousOptions) => {
}, },
], ],
}, },
schema: { schema: mergeSchema(schema, options?.schema),
user: {
fields: {
isAnonymous: {
type: "boolean",
required: false,
},
},
},
},
} satisfies BetterAuthPlugin; } satisfies BetterAuthPlugin;
}; };

View File

@@ -1,9 +1,10 @@
import type { BetterAuthPlugin, User } from "../../types"; import type { BetterAuthPlugin, InferOptionSchema, User } from "../../types";
import { type Jwk, schema } from "./schema"; import { type Jwk, schema } from "./schema";
import { getJwksAdapter } from "./adapter"; import { getJwksAdapter } from "./adapter";
import { exportJWK, generateKeyPair, importJWK, SignJWT } from "jose"; import { exportJWK, generateKeyPair, importJWK, SignJWT } from "jose";
import { createAuthEndpoint, sessionMiddleware } from "../../api"; import { createAuthEndpoint, sessionMiddleware } from "../../api";
import { symmetricDecrypt, symmetricEncrypt } from "../../crypto"; import { symmetricDecrypt, symmetricEncrypt } from "../../crypto";
import { mergeSchema } from "../../db/schema";
type JWKOptions = type JWKOptions =
| { | {
@@ -83,6 +84,10 @@ export interface JwtOptions {
user: User, user: User,
) => Promise<Record<string, any>> | Record<string, any>; ) => Promise<Record<string, any>> | Record<string, any>;
}; };
/**
* Custom schema for the admin plugin
*/
schema?: InferOptionSchema<typeof schema>;
} }
export const jwt = (options?: JwtOptions) => { export const jwt = (options?: JwtOptions) => {
@@ -189,6 +194,6 @@ export const jwt = (options?: JwtOptions) => {
}, },
), ),
}, },
schema, schema: mergeSchema(schema, options?.schema),
} satisfies BetterAuthPlugin; } satisfies BetterAuthPlugin;
}; };

View File

@@ -1,7 +1,7 @@
import type { PluginSchema } from "../../types"; import type { PluginSchema } from "../../types";
import { z } from "zod"; import { z } from "zod";
export const schema: PluginSchema = { export const schema = {
jwks: { jwks: {
fields: { fields: {
publicKey: { publicKey: {
@@ -18,7 +18,7 @@ export const schema: PluginSchema = {
}, },
}, },
}, },
}; } satisfies PluginSchema;
export const jwk = z.object({ export const jwk = z.object({
id: z.string(), id: z.string(),

View File

@@ -71,7 +71,7 @@ export const getOrgAdapter = (
organizationId: string; organizationId: string;
}) => { }) => {
const user = await adapter.findOne<User>({ const user = await adapter.findOne<User>({
model: context.tables.user.tableName, model: context.tables.user.modelName,
where: [ where: [
{ {
field: "email", field: "email",
@@ -127,7 +127,7 @@ export const getOrgAdapter = (
], ],
}), }),
await adapter.findOne<User>({ await adapter.findOne<User>({
model: context.tables.user.tableName, model: context.tables.user.modelName,
where: [ where: [
{ {
field: "id", field: "id",
@@ -163,7 +163,7 @@ export const getOrgAdapter = (
return null; return null;
} }
const user = await adapter.findOne<User>({ const user = await adapter.findOne<User>({
model: context.tables.user.tableName, model: context.tables.user.modelName,
where: [ where: [
{ {
field: "id", field: "id",
@@ -269,7 +269,7 @@ export const getOrgAdapter = (
organizationId: string | null, organizationId: string | null,
) => { ) => {
const session = await adapter.update<Session>({ const session = await adapter.update<Session>({
model: context.tables.session.tableName, model: context.tables.session.modelName,
where: [ where: [
{ {
field: "id", field: "id",
@@ -317,7 +317,7 @@ export const getOrgAdapter = (
const userIds = members.map((member) => member.userId); const userIds = members.map((member) => member.userId);
const users = await adapter.findMany<User>({ const users = await adapter.findMany<User>({
model: context.tables.user.tableName, model: context.tables.user.modelName,
where: [{ field: "id", value: userIds, operator: "in" }], where: [{ field: "id", value: userIds, operator: "in" }],
}); });

View File

@@ -16,11 +16,16 @@ import { z } from "zod";
import { createAuthEndpoint } from "../../api/call"; import { createAuthEndpoint } from "../../api/call";
import { sessionMiddleware } from "../../api"; import { sessionMiddleware } from "../../api";
import { getSessionFromCtx } from "../../api/routes"; import { getSessionFromCtx } from "../../api/routes";
import type { BetterAuthPlugin } from "../../types/plugins"; import type {
BetterAuthPlugin,
InferOptionSchema,
PluginSchema,
} from "../../types/plugins";
import { setSessionCookie } from "../../cookies"; import { setSessionCookie } from "../../cookies";
import { BetterAuthError } from "../../error"; import { BetterAuthError } from "../../error";
import { generateId } from "../../utils/id"; import { generateId } from "../../utils/id";
import { env } from "../../utils/env"; import { env } from "../../utils/env";
import { mergeSchema } from "../../db/schema";
interface WebAuthnChallengeValue { interface WebAuthnChallengeValue {
expectedChallenge: string; expectedChallenge: string;
@@ -58,6 +63,10 @@ export interface PasskeyOptions {
advanced?: { advanced?: {
webAuthnChallengeCookie?: string; webAuthnChallengeCookie?: string;
}; };
/**
* Schema for the passkey model
*/
schema?: InferOptionSchema<typeof schema>;
} }
export type Passkey = { export type Passkey = {
@@ -498,7 +507,11 @@ export const passkey = (options?: PasskeyOptions) => {
}, },
), ),
}, },
schema: { schema: mergeSchema(schema, options?.schema),
} satisfies BetterAuthPlugin;
};
const schema = {
passkey: { passkey: {
fields: { fields: {
name: { name: {
@@ -544,8 +557,6 @@ export const passkey = (options?: PasskeyOptions) => {
}, },
}, },
}, },
}, } satisfies PluginSchema;
} satisfies BetterAuthPlugin;
};
export * from "./client"; export * from "./client";

View File

@@ -1,8 +1,12 @@
import { z } from "zod"; import { z } from "zod";
import { createAuthEndpoint } from "../../api/call"; import { createAuthEndpoint } from "../../api/call";
import type { BetterAuthPlugin } from "../../types/plugins"; import type {
BetterAuthPlugin,
InferOptionSchema,
PluginSchema,
} from "../../types/plugins";
import { APIError } from "better-call"; import { APIError } from "better-call";
import type { User } from "../../db/schema"; import { mergeSchema, type User } from "../../db/schema";
import { alphabet, generateRandomString } from "../../crypto/random"; import { alphabet, generateRandomString } from "../../crypto/random";
import { getSessionFromCtx } from "../../api"; import { getSessionFromCtx } from "../../api";
import { getDate } from "../../utils/date"; import { getDate } from "../../utils/date";
@@ -83,6 +87,10 @@ export const phoneNumber = (options?: {
*/ */
getTempName?: (phoneNumber: string) => string; getTempName?: (phoneNumber: string) => string;
}; };
/**
* Custom schema for the admin plugin
*/
schema?: InferOptionSchema<typeof schema>;
}) => { }) => {
const opts = { const opts = {
phoneNumber: "phoneNumber", phoneNumber: "phoneNumber",
@@ -204,7 +212,7 @@ export const phoneNumber = (options?: {
} }
let user = await ctx.context.adapter.findOne<UserWithPhoneNumber>({ let user = await ctx.context.adapter.findOne<UserWithPhoneNumber>({
model: ctx.context.tables.user.tableName, model: ctx.context.tables.user.modelName,
where: [ where: [
{ {
value: ctx.body.phoneNumber, value: ctx.body.phoneNumber,
@@ -280,7 +288,11 @@ export const phoneNumber = (options?: {
}, },
), ),
}, },
schema: { schema: mergeSchema(schema, options?.schema),
} satisfies BetterAuthPlugin;
};
const schema = {
user: { user: {
fields: { fields: {
phoneNumber: { phoneNumber: {
@@ -297,6 +309,4 @@ export const phoneNumber = (options?: {
}, },
}, },
}, },
}, } satisfies PluginSchema;
} satisfies BetterAuthPlugin;
};

View File

@@ -8,17 +8,18 @@ import { backupCode2fa, generateBackupCodes } from "./backup-codes";
import { otp2fa } from "./otp"; import { otp2fa } from "./otp";
import { totp2fa } from "./totp"; import { totp2fa } from "./totp";
import type { TwoFactorOptions, UserWithTwoFactor } from "./types"; import type { TwoFactorOptions, UserWithTwoFactor } from "./types";
import type { Session } from "../../db/schema"; import { mergeSchema, type Session } from "../../db/schema";
import { TWO_FACTOR_COOKIE_NAME, TRUST_DEVICE_COOKIE_NAME } from "./constant"; import { TWO_FACTOR_COOKIE_NAME, TRUST_DEVICE_COOKIE_NAME } from "./constant";
import { validatePassword } from "../../utils/password"; import { validatePassword } from "../../utils/password";
import { APIError } from "better-call"; import { APIError } from "better-call";
import { createTOTPKeyURI } from "oslo/otp"; import { createTOTPKeyURI } from "oslo/otp";
import { TimeSpan } from "oslo"; import { TimeSpan } from "oslo";
import { deleteSessionCookie, setSessionCookie } from "../../cookies"; import { deleteSessionCookie, setSessionCookie } from "../../cookies";
import { schema } from "./schema";
export const twoFactor = (options?: TwoFactorOptions) => { export const twoFactor = (options?: TwoFactorOptions) => {
const opts = { const opts = {
twoFactorTable: options?.twoFactorTable || ("twoFactor" as const), twoFactorTable: "twoFactor",
}; };
const totp = totp2fa( const totp = totp2fa(
{ {
@@ -267,42 +268,7 @@ export const twoFactor = (options?: TwoFactorOptions) => {
}, },
], ],
}, },
schema: { schema: mergeSchema(schema, options?.schema),
user: {
fields: {
twoFactorEnabled: {
type: "boolean",
required: false,
defaultValue: false,
input: false,
},
},
},
twoFactor: {
tableName: opts.twoFactorTable,
fields: {
secret: {
type: "string",
required: true,
returned: false,
},
backupCodes: {
type: "string",
required: true,
returned: false,
},
userId: {
type: "string",
required: true,
returned: false,
references: {
model: "user",
field: "id",
},
},
},
},
},
rateLimit: [ rateLimit: [
{ {
pathMatcher(path) { pathMatcher(path) {

View File

@@ -0,0 +1,37 @@
import type { PluginSchema } from "../../types";
export const schema = {
user: {
fields: {
twoFactorEnabled: {
type: "boolean",
required: false,
defaultValue: false,
input: false,
},
},
},
twoFactor: {
fields: {
secret: {
type: "string",
required: true,
returned: false,
},
backupCodes: {
type: "string",
required: true,
returned: false,
},
userId: {
type: "string",
required: true,
returned: false,
references: {
model: "user",
field: "id",
},
},
},
},
} satisfies PluginSchema;

View File

@@ -4,6 +4,8 @@ import type { LiteralString } from "../../types/helper";
import type { BackupCodeOptions } from "./backup-codes"; import type { BackupCodeOptions } from "./backup-codes";
import type { OTPOptions } from "./otp"; import type { OTPOptions } from "./otp";
import type { TOTPOptions } from "./totp"; import type { TOTPOptions } from "./totp";
import type { InferOptionSchema } from "../../types";
import type { schema } from "./schema";
export interface TwoFactorOptions { export interface TwoFactorOptions {
/** /**
@@ -22,16 +24,15 @@ export interface TwoFactorOptions {
* Backup code options * Backup code options
*/ */
backupCodeOptions?: BackupCodeOptions; backupCodeOptions?: BackupCodeOptions;
/**
* Table name for two factor authentication.
* @default "userTwoFactor"
*/
twoFactorTable?: string;
/** /**
* Skip verification on enabling two factor authentication. * Skip verification on enabling two factor authentication.
* @default false * @default false
*/ */
skipVerificationOnEnable?: boolean; skipVerificationOnEnable?: boolean;
/**
* Custom schema for the two factor plugin
*/
schema?: InferOptionSchema<typeof schema>;
} }
export interface UserWithTwoFactor extends User { export interface UserWithTwoFactor extends User {

View File

@@ -20,7 +20,7 @@ export const username = () => {
}, },
async (ctx) => { async (ctx) => {
const user = await ctx.context.adapter.findOne<User>({ const user = await ctx.context.adapter.findOne<User>({
model: ctx.context.tables.user.tableName, model: ctx.context.tables.user.modelName,
where: [ where: [
{ {
field: "username", field: "username",
@@ -36,7 +36,7 @@ export const username = () => {
}); });
} }
const account = await ctx.context.adapter.findOne<Account>({ const account = await ctx.context.adapter.findOne<Account>({
model: ctx.context.tables.account.tableName, model: "account",
where: [ where: [
{ {
field: field:

View File

@@ -412,7 +412,7 @@ export interface BetterAuthOptions {
* *
* @default "rateLimit" * @default "rateLimit"
*/ */
tableName?: string; modelName?: string;
/** /**
* Custom field names for the rate limit table * Custom field names for the rate limit table
*/ */

View File

@@ -12,7 +12,7 @@ export type PluginSchema = {
[field in string]: FieldAttribute; [field in string]: FieldAttribute;
}; };
disableMigration?: boolean; disableMigration?: boolean;
tableName?: string; modelName?: string;
}; };
}; };
@@ -125,3 +125,17 @@ export type BetterAuthPlugin = {
pathMatcher: (path: string) => boolean; pathMatcher: (path: string) => boolean;
}[]; }[];
}; };
export type InferOptionSchema<S extends PluginSchema> = S extends Record<
string,
{ fields: infer Fields }
>
? {
[K in keyof S]?: {
modelName?: string;
fields: {
[P in keyof Fields]?: string;
};
};
}
: never;

View File

@@ -19,7 +19,7 @@ export const generateDrizzleSchema: SchemaGenerator = async ({
const fileExist = existsSync(filePath); const fileExist = existsSync(filePath);
for (const table in tables) { for (const table in tables) {
const tableName = tables[table].tableName; const modelName = tables[table].modelName;
const fields = tables[table].fields; const fields = tables[table].fields;
function getType(name: string, type: FieldType) { function getType(name: string, type: FieldType) {
if (type === "string") { if (type === "string") {
@@ -45,7 +45,7 @@ export const generateDrizzleSchema: SchemaGenerator = async ({
return `timestamp('${name}')`; return `timestamp('${name}')`;
} }
} }
const schema = `export const ${table} = ${databaseType}Table("${tableName}", { const schema = `export const ${table} = ${databaseType}Table("${modelName}", {
id: text("id").primaryKey(), id: text("id").primaryKey(),
${Object.keys(fields) ${Object.keys(fields)
.map((field) => { .map((field) => {

View File

@@ -28,8 +28,8 @@ export const generatePrismaSchema: SchemaGenerator = async ({
const schema = produceSchema(schemaPrisma, (builder) => { const schema = produceSchema(schemaPrisma, (builder) => {
for (const table in tables) { for (const table in tables) {
const fields = tables[table]?.fields; const fields = tables[table]?.fields;
const originalTable = tables[table]?.tableName; const originalTable = tables[table]?.modelName;
const tableName = capitalizeFirstLetter(originalTable || ""); const modelName = capitalizeFirstLetter(originalTable || "");
function getType(type: FieldType, isOptional: boolean) { function getType(type: FieldType, isOptional: boolean) {
if (type === "string") { if (type === "string") {
return isOptional ? "String?" : "String"; return isOptional ? "String?" : "String";
@@ -51,17 +51,17 @@ export const generatePrismaSchema: SchemaGenerator = async ({
} }
} }
const prismaModel = builder.findByType("model", { const prismaModel = builder.findByType("model", {
name: tableName, name: modelName,
}); });
if (!prismaModel) { if (!prismaModel) {
if (provider === "mongodb") { if (provider === "mongodb") {
builder builder
.model(tableName) .model(modelName)
.field("id", "String") .field("id", "String")
.attribute("id") .attribute("id")
.attribute(`map("_id")`); .attribute(`map("_id")`);
} else { } else {
builder.model(tableName).field("id", "String").attribute("id"); builder.model(modelName).field("id", "String").attribute("id");
} }
} }
@@ -79,14 +79,14 @@ export const generatePrismaSchema: SchemaGenerator = async ({
} }
builder builder
.model(tableName) .model(modelName)
.field(field, getType(attr.type, !attr?.required)); .field(field, getType(attr.type, !attr?.required));
if (attr.unique) { if (attr.unique) {
builder.model(tableName).blockAttribute(`unique([${field}])`); builder.model(modelName).blockAttribute(`unique([${field}])`);
} }
if (attr.references) { if (attr.references) {
builder builder
.model(tableName) .model(modelName)
.field( .field(
`${attr.references.model.toLowerCase()}`, `${attr.references.model.toLowerCase()}`,
capitalizeFirstLetter(attr.references.model), capitalizeFirstLetter(attr.references.model),
@@ -100,8 +100,8 @@ export const generatePrismaSchema: SchemaGenerator = async ({
name: "map", name: "map",
within: prismaModel?.properties, within: prismaModel?.properties,
}); });
if (originalTable !== tableName && !hasAttribute) { if (originalTable !== modelName && !hasAttribute) {
builder.model(tableName).blockAttribute("map", originalTable); builder.model(modelName).blockAttribute("map", originalTable);
} }
} }
}); });