diff --git a/biome.json b/biome.json index dfb9de93..e152c402 100644 --- a/biome.json +++ b/biome.json @@ -66,7 +66,8 @@ "!**/.cache", "!**/dev/cloudflare/drizzle", "!**/playwright-report", - "!**/.output" + "!**/.output", + "!**/.tmp" ] } } diff --git a/docker-compose.yml b/docker-compose.yml index 89740f24..0c9fb147 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -8,7 +8,8 @@ services: - "27017:27017" volumes: - mongodb_data:/data/db - + + # drizzle postgres: image: postgres:latest container_name: postgres @@ -21,6 +22,31 @@ services: volumes: - postgres_data:/var/lib/postgresql/data + postgres-kysely: + image: postgres:latest + container_name: postgres-kysely + environment: + POSTGRES_USER: user + POSTGRES_PASSWORD: password + POSTGRES_DB: better_auth + ports: + - "5433:5432" + volumes: + - postgres-kysely_data:/var/lib/postgresql/data + + postgres-prisma: + image: postgres:latest + container_name: postgres-prisma + environment: + POSTGRES_USER: user + POSTGRES_PASSWORD: password + POSTGRES_DB: better_auth + ports: + - "5434:5432" + volumes: + - postgres-prisma_data:/var/lib/postgresql/data + + # Drizzle tests mysql: image: mysql:latest container_name: mysql @@ -34,6 +60,34 @@ services: volumes: - mysql_data:/var/lib/mysql + + mysql-kysely: + image: mysql:latest + container_name: mysql-kysely + environment: + MYSQL_ROOT_PASSWORD: root_password + MYSQL_DATABASE: better_auth + MYSQL_USER: user + MYSQL_PASSWORD: password + ports: + - "3307:3306" + volumes: + - mysql-kysely_data:/var/lib/mysql + + mysql-prisma: + image: mysql:latest + container_name: mysql-prisma + environment: + MYSQL_ROOT_PASSWORD: root_password + MYSQL_DATABASE: better_auth + MYSQL_USER: user + MYSQL_PASSWORD: password + ports: + - "3308:3306" + volumes: + - mysql-prisma_data:/var/lib/mysql + + mssql: image: mcr.microsoft.com/mssql/server:latest container_name: mssql @@ -48,5 +102,9 @@ services: volumes: mongodb_data: postgres_data: + postgres-kysely_data: + postgres-prisma_data: mysql_data: - mssql_data: \ No newline at end of file + mssql_data: + mysql-kysely_data: + mysql-prisma_data: \ No newline at end of file diff --git a/docs/content/docs/adapters/mssql.mdx b/docs/content/docs/adapters/mssql.mdx index e9696cf1..c096cb89 100644 --- a/docs/content/docs/adapters/mssql.mdx +++ b/docs/content/docs/adapters/mssql.mdx @@ -43,6 +43,10 @@ const dialect = new MssqlDialect({ server: 'localhost', }), }, + TYPES: { + ...Tedious.TYPES, + DateTime: Tedious.TYPES.DateTime2, + }, }) export const auth = betterAuth({ diff --git a/docs/content/docs/adapters/mysql.mdx b/docs/content/docs/adapters/mysql.mdx index 3bdb0a4b..b16851e8 100644 --- a/docs/content/docs/adapters/mysql.mdx +++ b/docs/content/docs/adapters/mysql.mdx @@ -21,6 +21,7 @@ export const auth = betterAuth({ user: "root", password: "password", database: "database", + timezone: "Z", // Important to ensure consistent timezone values }), }); ``` diff --git a/packages/better-auth/package.json b/packages/better-auth/package.json index 29fa42fa..efe7ae18 100644 --- a/packages/better-auth/package.json +++ b/packages/better-auth/package.json @@ -26,7 +26,7 @@ "build": "unbuild --clean", "stub": "unbuild --stub", "test": "vitest", - "prepare": "prisma generate --schema ./src/adapters/prisma-adapter/test/normal-tests/schema.prisma && prisma generate --schema ./src/adapters/prisma-adapter/test/number-id-tests/schema.prisma", + "prepare": "prisma generate --schema ./src/adapters/prisma-adapter/test/base.prisma", "typecheck": "tsc --project tsconfig.json" }, "main": "./dist/index.cjs", @@ -757,6 +757,7 @@ "better-sqlite3": "^12.2.0", "concurrently": "^9.2.1", "deepmerge": "^4.3.1", + "drizzle-kit": "^0.31.4", "drizzle-orm": "^0.38.2", "happy-dom": "^18.0.1", "hono": "^4.9.7", diff --git a/packages/better-auth/src/__snapshots__/init.test.ts.snap b/packages/better-auth/src/__snapshots__/init.test.ts.snap index af24f865..f0f733f5 100644 --- a/packages/better-auth/src/__snapshots__/init.test.ts.snap +++ b/packages/better-auth/src/__snapshots__/init.test.ts.snap @@ -16,6 +16,8 @@ exports[`init > should match config 1`] = ` "adapterId": "kysely", "adapterName": "Kysely Adapter", "debugLogs": false, + "disableTransformInput": false, + "disableTransformOutput": false, "supportsBooleans": false, "supportsDates": false, "supportsJSON": false, diff --git a/packages/better-auth/src/adapters/adapter-factory/index.ts b/packages/better-auth/src/adapters/adapter-factory/index.ts index 050b6831..2b291677 100644 --- a/packages/better-auth/src/adapters/adapter-factory/index.ts +++ b/packages/better-auth/src/adapters/adapter-factory/index.ts @@ -14,42 +14,13 @@ import type { AdapterTestDebugLogs, CleanedWhere, } from "./types"; +import { colors } from "../../utils/colors"; import type { DBFieldAttribute } from "@better-auth/core/db"; export * from "./types"; -let debugLogs: any[] = []; +let debugLogs: { instance: string; args: any[] }[] = []; let transactionId = -1; -const colors = { - reset: "\x1b[0m", - bright: "\x1b[1m", - dim: "\x1b[2m", - underscore: "\x1b[4m", - blink: "\x1b[5m", - reverse: "\x1b[7m", - hidden: "\x1b[8m", - fg: { - black: "\x1b[30m", - red: "\x1b[31m", - green: "\x1b[32m", - yellow: "\x1b[33m", - blue: "\x1b[34m", - magenta: "\x1b[35m", - cyan: "\x1b[36m", - white: "\x1b[37m", - }, - bg: { - black: "\x1b[40m", - red: "\x1b[41m", - green: "\x1b[42m", - yellow: "\x1b[43m", - blue: "\x1b[44m", - magenta: "\x1b[45m", - cyan: "\x1b[46m", - white: "\x1b[47m", - }, -}; - const createAsIsTransaction = (adapter: Adapter) => (fn: (trx: TransactionAdapter) => Promise) => @@ -63,6 +34,10 @@ export const createAdapterFactory = config: cfg, }: AdapterFactoryOptions): AdapterFactory => (options: BetterAuthOptions): Adapter => { + const uniqueAdapterFactoryInstanceId = Math.random() + .toString(36) + .substring(2, 15); + const config = { ...cfg, supportsBooleans: cfg.supportsBooleans ?? true, @@ -71,6 +46,8 @@ export const createAdapterFactory = adapterName: cfg.adapterName ?? cfg.adapterId, supportsNumericIds: cfg.supportsNumericIds ?? true, transaction: cfg.transaction ?? false, + disableTransformInput: cfg.disableTransformInput ?? false, + disableTransformOutput: cfg.disableTransformOutput ?? false, } satisfies AdapterFactoryConfig; if ( @@ -112,9 +89,13 @@ export const createAdapterFactory = let f = schema[model]?.fields[field]; if (!f) { - f = Object.values(schema[model]?.fields!).find( - (f) => f.fieldName === field, + const result = Object.entries(schema[model]!.fields!).find( + ([_, f]) => f.fieldName === field, ); + if (result) { + f = result[1]; + field = result[0]; + } } if (!f) { debugLog(`Field ${field} not found in model ${model}`); @@ -215,7 +196,7 @@ export const createAdapterFactory = ) { if (config.debugLogs.isRunningAdapterTests) { args.shift(); // Removes the {method: "..."} object from the args array. - debugLogs.push(args); + debugLogs.push({ instance: uniqueAdapterFactoryInstanceId, args }); } return; } @@ -316,7 +297,7 @@ export const createAdapterFactory = const defaultModelName = getDefaultModelName(model); const defaultFieldName = getDefaultFieldName({ field: field, - model: model, + model: defaultModelName, }); const fields = schema[defaultModelName]!.fields; @@ -324,32 +305,21 @@ export const createAdapterFactory = return fields[defaultFieldName]!; }; - const adapterInstance = customAdapter({ - options, - schema, - debugLog, - getFieldName, - getModelName, - getDefaultModelName, - getDefaultFieldName, - getFieldAttributes, - }); - const transformInput = async ( data: Record, - unsafe_model: string, + defaultModelName: string, action: "create" | "update", forceAllowId?: boolean, ) => { const transformedData: Record = {}; - const fields = schema[unsafe_model]!.fields; + const fields = schema[defaultModelName]!.fields; const newMappedKeys = config.mapKeysTransformInput ?? {}; if ( !config.disableIdGeneration && !options.advanced?.database?.useNumberId ) { fields.id = idField({ - customModelName: unsafe_model, + customModelName: defaultModelName, forceAllowId: forceAllowId && "id" in data, }); } @@ -411,7 +381,7 @@ export const createAdapterFactory = action, field: newFieldName, fieldAttributes: fieldAttributes!, - model: unsafe_model, + model: defaultModelName, schema, options, }); @@ -446,6 +416,7 @@ export const createAdapterFactory = const field = tableSchema[key]; if (field) { const originalKey = field.fieldName || key; + // If the field is mapped, we'll use the mapped key. Otherwise, we'll use the original key. let newValue = data[ @@ -573,14 +544,36 @@ export const createAdapterFactory = }) as any; }; + const adapterInstance = customAdapter({ + options, + schema, + debugLog, + getFieldName, + getModelName, + getDefaultModelName, + getDefaultFieldName, + getFieldAttributes, + transformInput, + transformOutput, + transformWhereClause, + }); + let lazyLoadTransaction: Adapter["transaction"] | null = null; const adapter: Adapter = { transaction: async (cb) => { if (!lazyLoadTransaction) { if (!config.transaction) { - logger.warn( - `[${config.adapterName}] - Transactions are not supported. Executing operations sequentially.`, - ); + if ( + typeof config.debugLogs === "object" && + "isRunningAdapterTests" in config.debugLogs && + config.debugLogs.isRunningAdapterTests + ) { + // hide warning in adapter tests + } else { + logger.warn( + `[${config.adapterName}] - Transactions are not supported. Executing operations sequentially.`, + ); + } lazyLoadTransaction = createAsIsTransaction(adapter); } else { logger.debug( @@ -626,12 +619,15 @@ export const createAdapterFactory = `${formatMethod("create")} ${formatAction("Unsafe Input")}:`, { model, data: unsafeData }, ); - const data = (await transformInput( - unsafeData, - unsafeModel, - "create", - forceAllowId, - )) as T; + let data = unsafeData; + if (!config.disableTransformInput) { + data = (await transformInput( + unsafeData, + unsafeModel, + "create", + forceAllowId, + )) as T; + } debugLog( { method: "create" }, `${formatTransactionId(thisTransactionId)} ${formatStep(2, 4)}`, @@ -645,7 +641,10 @@ export const createAdapterFactory = `${formatMethod("create")} ${formatAction("DB Result")}:`, { model, res }, ); - const transformed = await transformOutput(res, unsafeModel, select); + let transformed = res as any; + if (!config.disableTransformOutput) { + transformed = await transformOutput(res as any, unsafeModel, select); + } debugLog( { method: "create" }, `${formatTransactionId(thisTransactionId)} ${formatStep(4, 4)}`, @@ -676,11 +675,10 @@ export const createAdapterFactory = `${formatMethod("update")} ${formatAction("Unsafe Input")}:`, { model, data: unsafeData }, ); - const data = (await transformInput( - unsafeData, - unsafeModel, - "update", - )) as T; + let data = unsafeData as T; + if (!config.disableTransformInput) { + data = (await transformInput(unsafeData, unsafeModel, "update")) as T; + } debugLog( { method: "update" }, `${formatTransactionId(thisTransactionId)} ${formatStep(2, 4)}`, @@ -698,7 +696,10 @@ export const createAdapterFactory = `${formatMethod("update")} ${formatAction("DB Result")}:`, { model, data: res }, ); - const transformed = await transformOutput(res as any, unsafeModel); + let transformed = res as any; + if (!config.disableTransformOutput) { + transformed = await transformOutput(res as any, unsafeModel); + } debugLog( { method: "update" }, `${formatTransactionId(thisTransactionId)} ${formatStep(4, 4)}`, @@ -729,7 +730,10 @@ export const createAdapterFactory = `${formatMethod("updateMany")} ${formatAction("Unsafe Input")}:`, { model, data: unsafeData }, ); - const data = await transformInput(unsafeData, unsafeModel, "update"); + let data = unsafeData; + if (!config.disableTransformInput) { + data = await transformInput(unsafeData, unsafeModel, "update"); + } debugLog( { method: "updateMany" }, `${formatTransactionId(thisTransactionId)} ${formatStep(2, 4)}`, @@ -789,11 +793,10 @@ export const createAdapterFactory = `${formatMethod("findOne")} ${formatAction("DB Result")}:`, { model, data: res }, ); - const transformed = await transformOutput( - res as any, - unsafeModel, - select, - ); + let transformed = res as any; + if (!config.disableTransformOutput) { + transformed = await transformOutput(res as any, unsafeModel, select); + } debugLog( { method: "findOne" }, `${formatTransactionId(thisTransactionId)} ${formatStep(3, 3)}`, @@ -845,9 +848,12 @@ export const createAdapterFactory = `${formatMethod("findMany")} ${formatAction("DB Result")}:`, { model, data: res }, ); - const transformed = await Promise.all( - res.map(async (r) => await transformOutput(r as any, unsafeModel)), - ); + let transformed = res as any; + if (!config.disableTransformOutput) { + transformed = await Promise.all( + res.map(async (r) => await transformOutput(r as any, unsafeModel)), + ); + } debugLog( { method: "findMany" }, `${formatTransactionId(thisTransactionId)} ${formatStep(3, 3)}`, @@ -1021,17 +1027,25 @@ export const createAdapterFactory = ? { adapterTestDebugLogs: { resetDebugLogs() { - debugLogs = []; + debugLogs = debugLogs.filter( + (log) => log.instance !== uniqueAdapterFactoryInstanceId, + ); }, printDebugLogs() { const separator = `─`.repeat(80); + const logs = debugLogs.filter( + (log) => log.instance === uniqueAdapterFactoryInstanceId, + ); + if (logs.length === 0) { + return; + } //`${colors.fg.blue}|${colors.reset} `, - let log: any[] = debugLogs + let log: any[] = logs .reverse() .map((log) => { - log[0] = `\n${log[0]!}`; - return [...log, "\n"]; + log.args[0] = `\n${log.args[0]}`; + return [...log.args, "\n"]; }) .reduce( (prev, curr) => { diff --git a/packages/better-auth/src/adapters/adapter-factory/types.ts b/packages/better-auth/src/adapters/adapter-factory/types.ts index 48249fe4..0c399250 100644 --- a/packages/better-auth/src/adapters/adapter-factory/types.ts +++ b/packages/better-auth/src/adapters/adapter-factory/types.ts @@ -239,6 +239,18 @@ export interface AdapterFactoryConfig { * ``` */ customIdGenerator?: (props: { model: string }) => string; + /** + * Whether to disable the transform output. + * Do not use this option unless you know what you are doing. + * @default false + */ + disableTransformOutput?: boolean; + /** + * Whether to disable the transform input. + * Do not use this option unless you know what you are doing. + * @default false + */ + disableTransformInput?: boolean; } export type AdapterFactoryCustomizeAdapterCreator = (config: { @@ -304,6 +316,25 @@ export type AdapterFactoryCustomizeAdapterCreator = (config: { model: string; field: string; }) => DBFieldAttribute; + // The following functions are exposed primarily for the purpose of having wrapper adapters. + transformInput: ( + data: Record, + defaultModelName: string, + action: "create" | "update", + forceAllowId?: boolean, + ) => Promise>; + transformOutput: ( + data: Record, + defaultModelName: string, + select?: string[], + ) => Promise>; + transformWhereClause: ({ + model, + where, + }: { + where: W; + model: string; + }) => W extends undefined ? undefined : CleanedWhere[]; }) => CustomAdapter; export interface CustomAdapter { diff --git a/packages/better-auth/src/adapters/create-test-suite.ts b/packages/better-auth/src/adapters/create-test-suite.ts new file mode 100644 index 00000000..b238e6a7 --- /dev/null +++ b/packages/better-auth/src/adapters/create-test-suite.ts @@ -0,0 +1,534 @@ +import type { Adapter } from "../types"; +import type { User, Session, Verification, Account } from "../types"; +import type { BetterAuthOptions } from "../types"; +import { createAdapterFactory } from "./adapter-factory"; +import { test } from "vitest"; +import { generateId } from "../utils"; +import type { Logger } from "./test-adapter"; +import { colors } from "../utils/colors"; +import { betterAuth } from "../auth"; +import { deepmerge } from "./utils"; + +type GenerateFn = ( + Model: M, +) => Promise< + M extends "user" + ? User + : M extends "session" + ? Session + : M extends "verification" + ? Verification + : M extends "account" + ? Account + : undefined +>; + +type Success = { + data: T; + error: null; +}; + +type Failure = { + data: null; + error: E; +}; + +type Result = Success | Failure; + +async function tryCatch( + promise: Promise, +): Promise> { + try { + const data = await promise; + return { data, error: null }; + } catch (error) { + return { data: null, error: error as E }; + } +} + +export type InsertRandomFn = < + M extends "user" | "session" | "verification" | "account", + Count extends number = 1, +>( + model: M, + count?: Count, +) => Promise< + Count extends 1 + ? M extends "user" + ? [User] + : M extends "session" + ? [User, Session] + : M extends "verification" + ? [Verification] + : M extends "account" + ? [User, Account] + : [undefined] + : Array< + M extends "user" + ? [User] + : M extends "session" + ? [User, Session] + : M extends "verification" + ? [Verification] + : M extends "account" + ? [User, Account] + : [undefined] + > +>; + +export const createTestSuite = < + Tests extends Record< + string, + (context: { + /** + * Mark tests as skipped. All execution after this call will be skipped. + * This function throws an error, so make sure you are not catching it accidentally. + * @see {@link https://vitest.dev/guide/test-context#skip} + */ + readonly skip: { + (note?: string): never; + (condition: boolean, note?: string): void; + }; + }) => Promise + >, + AdditionalOptions extends Record = {}, +>( + suiteName: string, + config: { + defaultBetterAuthOptions?: BetterAuthOptions; + /** + * Helpful if the default better auth options require migrations to be run. + */ + alwaysMigrate?: boolean; + prefixTests?: string; + }, + tests: ( + helpers: { + adapter: Adapter; + log: Logger; + generate: GenerateFn; + insertRandom: InsertRandomFn; + /** + * A light cleanup function that will only delete rows it knows about. + */ + cleanup: () => Promise; + /** + * A hard cleanup function that will delete all rows from the database. + */ + hardCleanup: () => Promise; + modifyBetterAuthOptions: ( + options: BetterAuthOptions, + shouldRunMigrations: boolean, + ) => Promise; + getBetterAuthOptions: () => BetterAuthOptions; + sortModels: ( + models: Array< + Record & { + id: string; + } + >, + by?: "id" | "createdAt", + ) => (Record & { + id: string; + })[]; + getAuth: () => Promise>; + tryCatch(promise: Promise): Promise>; + customIdGenerator?: () => string | Promise; + }, + additionalOptions?: AdditionalOptions, + ) => Tests, +) => { + return ( + options?: { + disableTests?: Partial & { ALL?: boolean }>; + } & AdditionalOptions, + ) => { + return async (helpers: { + adapter: () => Promise; + log: Logger; + adapterDisplayName: string; + getBetterAuthOptions: () => BetterAuthOptions; + modifyBetterAuthOptions: ( + options: BetterAuthOptions, + ) => Promise; + cleanup: () => Promise; + runMigrations: () => Promise; + prefixTests?: string; + onTestFinish: () => Promise; + customIdGenerator?: () => string | Promise; + defaultRetryCount?: number; + }) => { + const createdRows: Record = {}; + + let adapter = await helpers.adapter(); + const wrapperAdapter = (overrideOptions?: BetterAuthOptions) => { + const options = deepmerge( + deepmerge( + helpers.getBetterAuthOptions(), + config?.defaultBetterAuthOptions || {}, + ), + overrideOptions || {}, + ); + const adapterConfig = { + adapterId: helpers.adapterDisplayName, + ...(adapter.options?.adapterConfig || {}), + adapterName: `Wrapped ${adapter.options?.adapterConfig.adapterName}`, + disableTransformOutput: true, + disableTransformInput: true, + }; + const adapterCreator = (options: BetterAuthOptions): Adapter => + createAdapterFactory({ + config: { + ...adapterConfig, + transaction: adapter.transaction, + }, + adapter: ({ getDefaultModelName }) => { + //@ts-expect-error + adapter.transaction = undefined; + return { + count: adapter.count, + deleteMany: adapter.deleteMany, + delete: adapter.delete, + findOne: adapter.findOne, + findMany: adapter.findMany, + update: adapter.update as any, + updateMany: adapter.updateMany, + + createSchema: adapter.createSchema as any, + async create({ data, model, select }) { + const defaultModelName = getDefaultModelName(model); + adapter = await helpers.adapter(); + const res = await adapter.create({ + data: data, + model: defaultModelName, + select, + forceAllowId: true, + }); + createdRows[model] = [...(createdRows[model] || []), res]; + return res as any; + }, + options: adapter.options, + }; + }, + })(options); + + return adapterCreator(options); + }; + + const resetDebugLogs = () => { + //@ts-expect-error + wrapperAdapter()?.adapterTestDebugLogs?.resetDebugLogs(); + }; + + const printDebugLogs = () => { + //@ts-expect-error + wrapperAdapter()?.adapterTestDebugLogs?.printDebugLogs(); + }; + + const cleanupCreatedRows = async () => { + adapter = await helpers.adapter(); + for (const model of Object.keys(createdRows)) { + for (const row of createdRows[model]!) { + try { + await adapter.delete({ + model, + where: [{ field: "id", value: row.id }], + }); + } catch (error) { + // We ignore any failed attempts to delete the created rows. + } + if (createdRows[model]!.length === 1) { + delete createdRows[model]; + } + } + } + }; + + let didMigrateOnOptionsModify = false; + + const resetBetterAuthOptions = async () => { + adapter = await helpers.adapter(); + await helpers.modifyBetterAuthOptions( + config.defaultBetterAuthOptions || {}, + ); + if (didMigrateOnOptionsModify) { + didMigrateOnOptionsModify = false; + await helpers.runMigrations(); + adapter = await helpers.adapter(); + } + }; + + const generateModel: GenerateFn = async (model: string) => { + const id = (await helpers.customIdGenerator?.()) || generateId(); + const randomDate = new Date( + Date.now() - Math.random() * 1000 * 60 * 60 * 24 * 365, + ); + if (model === "user") { + const user: User = { + id, + createdAt: randomDate, + updatedAt: new Date(), + email: `user-${id}@email.com`.toLowerCase(), + emailVerified: true, + name: `user-${id}`, + image: null, + }; + return user as any; + } + if (model === "session") { + const session: Session = { + id, + createdAt: randomDate, + updatedAt: new Date(), + expiresAt: new Date(), + token: generateId(32), + userId: generateId(), + ipAddress: "127.0.0.1", + userAgent: "Some User Agent", + }; + return session as any; + } + if (model === "verification") { + const verification: Verification = { + id, + createdAt: randomDate, + updatedAt: new Date(), + expiresAt: new Date(), + identifier: `test:${generateId()}`, + value: generateId(), + }; + return verification as any; + } + if (model === "account") { + const account: Account = { + id, + createdAt: randomDate, + updatedAt: new Date(), + accountId: generateId(), + providerId: "test", + userId: generateId(), + accessToken: generateId(), + refreshToken: generateId(), + idToken: generateId(), + accessTokenExpiresAt: new Date(), + refreshTokenExpiresAt: new Date(), + scope: "test", + }; + return account as any; + } + // This should never happen given the type constraints, but TypeScript needs an exhaustive check + throw new Error(`Unknown model type: ${model}`); + }; + + const insertRandom: InsertRandomFn = async < + M extends "user" | "session" | "verification" | "account", + Count extends number = 1, + >( + model: M, + count: Count = 1 as Count, + ) => { + let res: any[] = []; + const a = wrapperAdapter(); + + for (let i = 0; i < count; i++) { + const modelResults = []; + + if (model === "user") { + const user = await generateModel("user"); + modelResults.push( + await a.create({ + data: user, + model: "user", + forceAllowId: true, + }), + ); + } + if (model === "session") { + const user = await generateModel("user"); + const userRes = await a.create({ + data: user, + model: "user", + forceAllowId: true, + }); + const session = await generateModel("session"); + session.userId = userRes.id; + const sessionRes = await a.create({ + data: session, + model: "session", + forceAllowId: true, + }); + modelResults.push(userRes, sessionRes); + } + if (model === "verification") { + const verification = await generateModel("verification"); + modelResults.push( + await a.create({ + data: verification, + model: "verification", + forceAllowId: true, + }), + ); + } + if (model === "account") { + const user = await generateModel("user"); + const account = await generateModel("account"); + const userRes = await a.create({ + data: user, + model: "user", + forceAllowId: true, + }); + account.userId = userRes.id; + const accRes = await a.create({ + data: account, + model: "account", + forceAllowId: true, + }); + modelResults.push(userRes, accRes); + } + res.push(modelResults); + } + return res.length === 1 ? res[0] : (res as any); + }; + + const sortModels = ( + models: Array & { id: string }>, + by: "id" | "createdAt" = "id", + ) => { + return models.sort((a, b) => { + if (by === "createdAt") { + return ( + new Date(a.createdAt).getTime() - new Date(b.createdAt).getTime() + ); + } + return a.id.localeCompare(b.id); + }); + }; + + const modifyBetterAuthOptions = async ( + opts: BetterAuthOptions, + shouldRunMigrations: boolean, + ) => { + const res = helpers.modifyBetterAuthOptions( + deepmerge(config?.defaultBetterAuthOptions || {}, opts), + ); + if (config.alwaysMigrate || shouldRunMigrations) { + didMigrateOnOptionsModify = true; + await helpers.runMigrations(); + } + return res; + }; + + const additionalOptions = { ...options }; + additionalOptions.disableTests = undefined; + + const fullTests = tests( + { + adapter: new Proxy({} as any, { + get(target, prop) { + const adapter = wrapperAdapter(); + if (prop === "transaction") { + return adapter.transaction; + } + const value = adapter[prop as keyof typeof adapter]; + if (typeof value === "function") { + return value.bind(adapter); + } + return value; + }, + }), + getAuth: async () => { + adapter = await helpers.adapter(); + const auth = betterAuth({ + ...helpers.getBetterAuthOptions(), + ...(config?.defaultBetterAuthOptions || {}), + database: (options: BetterAuthOptions) => { + const adapter = wrapperAdapter(options); + return adapter; + }, + }); + return auth; + }, + log: helpers.log, + generate: generateModel, + cleanup: cleanupCreatedRows, + hardCleanup: helpers.cleanup, + insertRandom, + modifyBetterAuthOptions, + getBetterAuthOptions: helpers.getBetterAuthOptions, + sortModels, + tryCatch, + customIdGenerator: helpers.customIdGenerator, + }, + additionalOptions as AdditionalOptions, + ); + + const dash = `─`; + const allDisabled: boolean = options?.disableTests?.ALL ?? false; + + // Here to display a label in the tests showing the suite name + test(`\n${colors.fg.white}${" ".repeat(3)}${dash.repeat(35)} [${colors.fg.magenta}${suiteName}${colors.fg.white}] ${dash.repeat(35)}`, async () => { + try { + await helpers.cleanup(); + } catch {} + if (config.defaultBetterAuthOptions && !allDisabled) { + await helpers.modifyBetterAuthOptions( + config.defaultBetterAuthOptions, + ); + if (config.alwaysMigrate) { + await helpers.runMigrations(); + } + } + }); + + const onFinish = async (testName: string) => { + await cleanupCreatedRows(); + await resetBetterAuthOptions(); + // Check if this is the last test by comparing current test index with total tests + const testEntries = Object.entries(fullTests); + const currentTestIndex = testEntries.findIndex( + ([name]) => + name === testName.replace(/\[.*?\] /, "").replace(/ ─ /g, " - "), + ); + const isLastTest = currentTestIndex === testEntries.length - 1; + + if (isLastTest) { + await helpers.onTestFinish(); + } + }; + + if (allDisabled) { + await resetBetterAuthOptions(); + } + + for (let [testName, testFn] of Object.entries(fullTests)) { + let shouldSkip = + (allDisabled && options?.disableTests?.[testName] !== false) || + (options?.disableTests?.[testName] ?? false); + testName = testName.replace( + " - ", + ` ${colors.dim}${dash}${colors.undim} `, + ); + if (config.prefixTests) { + testName = `${config.prefixTests} ${colors.dim}>${colors.undim} ${testName}`; + } + if (helpers.prefixTests) { + testName = `[${colors.dim}${helpers.prefixTests}${colors.undim}] ${testName}`; + } + + test.skipIf(shouldSkip)( + testName, + { retry: helpers?.defaultRetryCount ?? 10, timeout: 10000 }, + async ({ onTestFailed, skip }) => { + resetDebugLogs(); + onTestFailed(async () => { + printDebugLogs(); + await onFinish(testName); + }); + await testFn({ skip }); + await onFinish(testName); + }, + ); + } + }; + }; +}; diff --git a/packages/better-auth/src/adapters/drizzle-adapter/test/.gitignore b/packages/better-auth/src/adapters/drizzle-adapter/test/.gitignore new file mode 100644 index 00000000..f0de3173 --- /dev/null +++ b/packages/better-auth/src/adapters/drizzle-adapter/test/.gitignore @@ -0,0 +1,2 @@ +.tmp +drizzle \ No newline at end of file diff --git a/packages/better-auth/src/adapters/drizzle-adapter/test/adapter.drizzle.mysql.test.ts b/packages/better-auth/src/adapters/drizzle-adapter/test/adapter.drizzle.mysql.test.ts index 60eb76d2..0f7304c8 100644 --- a/packages/better-auth/src/adapters/drizzle-adapter/test/adapter.drizzle.mysql.test.ts +++ b/packages/better-auth/src/adapters/drizzle-adapter/test/adapter.drizzle.mysql.test.ts @@ -1,196 +1,83 @@ -import { afterAll, beforeAll, describe, expect, it } from "vitest"; -import * as schema from "./schema.mysql"; +import { drizzleAdapter } from "../drizzle-adapter"; +import { testAdapter } from "../../test-adapter"; import { - recoverProcessTZ, - runAdapterTest, - runNumberIdAdapterTest, -} from "../../test"; -import { drizzleAdapter } from ".."; -import { getMigrations } from "../../../db/get-migration"; + authFlowTestSuite, + normalTestSuite, + numberIdTestSuite, + performanceTestSuite, + transactionsTestSuite, +} from "../../tests"; import { drizzle } from "drizzle-orm/mysql2"; -import type { BetterAuthOptions } from "../../../types"; -import { createPool, type Pool } from "mysql2/promise"; -import { Kysely, MysqlDialect } from "kysely"; -import { betterAuth } from "../../../auth"; -import merge from "deepmerge"; +import { generateDrizzleSchema, resetGenerationCount } from "./generate-schema"; +import { createPool } from "mysql2/promise"; +import { assert } from "vitest"; +import { execSync } from "child_process"; -const TEST_DB_MYSQL_URL = "mysql://user:password@localhost:3306/better_auth"; - -const createTestPool = () => createPool(TEST_DB_MYSQL_URL); - -const createKyselyInstance = (pool: any) => - new Kysely({ - dialect: new MysqlDialect({ pool }), - }); - -const cleanupDatabase = async (mysql: Pool, shouldDestroy = true) => { - try { - await mysql.query("DROP DATABASE IF EXISTS better_auth"); - await mysql.query("CREATE DATABASE better_auth"); - await mysql.query("USE better_auth"); - } catch (error) { - console.log(error); - } - if (shouldDestroy) { - await mysql.end(); - } else { - await new Promise((resolve) => setTimeout(resolve, 1000)); - } -}; - -const createTestOptions = (pool: any, useNumberId = false) => - ({ - database: pool, - user: { - fields: { email: "email_address" }, - additionalFields: { - test: { - type: "string", - defaultValue: "test", - }, - }, - }, - session: { - modelName: "sessions", - }, - advanced: { - database: { - useNumberId, - }, - }, - }) satisfies BetterAuthOptions; - -describe("Drizzle Adapter Tests (MySQL)", async () => { - let pool: any; - let mysql: Kysely; - - pool = createTestPool(); - mysql = createKyselyInstance(pool); - let opts = createTestOptions(pool); - const { runMigrations } = await getMigrations(opts); - await runMigrations(); - - const db = drizzle({ - client: pool, - }); - const adapter = drizzleAdapter(db, { - provider: "mysql", - schema, - debugLogs: { - isRunningAdapterTests: true, - }, - }); - - runAdapterTest({ - getAdapter: async (customOptions = {}) => { - const db = opts.database; - opts.database = undefined; - const merged = merge(opts, customOptions); - merged.database = db; - return adapter(merged); - }, - }); +const mysqlDB = createPool({ + uri: "mysql://user:password@localhost:3306", + timezone: "Z", }); -describe("Drizzle Adapter Authentication Flow Tests (MySQL)", async () => { - const pool = createTestPool(); - const opts = createTestOptions(pool); - const testUser = { - email: "test-email@email.com", - password: "password", - name: "Test Name", - }; - - beforeAll(async () => { - const { runMigrations } = await getMigrations(opts); - await runMigrations(); - }); - - const auth = betterAuth({ - ...opts, - database: drizzleAdapter(drizzle({ client: pool }), { - provider: "mysql", +const { execute } = await testAdapter({ + adapter: async (options) => { + const { schema } = await generateDrizzleSchema(mysqlDB, options, "mysql"); + return drizzleAdapter(drizzle(mysqlDB), { + debugLogs: { isRunningAdapterTests: true }, schema, - }), - emailAndPassword: { - enabled: true, - }, - }); - - it("should successfully sign up a new user", async () => { - const user = await auth.api.signUpEmail({ body: testUser }); - expect(user).toBeDefined(); - expect(user.user.id).toBeDefined(); - }); - - it("should successfully sign in an existing user", async () => { - const user = await auth.api.signInEmail({ body: testUser }); - expect(user.user).toBeDefined(); - expect(user.user.id).toBeDefined(); - }); - - it("stores and retrieves timestamps correctly across timezones", async () => { - using _ = recoverProcessTZ(); - - const sampleUser = { - name: "sample", - email: "sampler@test.com", - password: "samplerrrrr", - }; - - process.env.TZ = "Europe/London"; - const userSignUp = await auth.api.signUpEmail({ - body: { - name: sampleUser.name, - email: sampleUser.email, - password: sampleUser.password, - }, - }); - process.env.TZ = "America/Los_Angeles"; - const userSignIn = await auth.api.signInEmail({ - body: { email: sampleUser.email, password: sampleUser.password }, + provider: "mysql", }); + }, + async runMigrations(betterAuthOptions) { + await mysqlDB.query("DROP DATABASE IF EXISTS better_auth"); + await mysqlDB.query("CREATE DATABASE better_auth"); + await mysqlDB.query("USE better_auth"); - expect(userSignUp.user.createdAt).toStrictEqual(userSignIn.user.createdAt); - }); + const { fileName } = await generateDrizzleSchema( + mysqlDB, + betterAuthOptions, + "mysql", + ); + + const command = `npx drizzle-kit push --dialect=mysql --schema=${fileName}.ts --url=mysql://user:password@localhost:3306/better_auth`; + console.log(`Running: ${command}`); + console.log(`Options:`, betterAuthOptions); + try { + // wait for the above console.log to be printed + await new Promise((resolve) => setTimeout(resolve, 10)); + execSync(command, { + cwd: import.meta.dirname, + stdio: "inherit", + }); + } catch (error) { + console.error("Failed to push drizzle schema (mysql):", error); + throw error; + } + + // ensure migrations were run successfully + const [tables_result] = (await mysqlDB.query("SHOW TABLES")) as unknown as [ + { Tables_in_better_auth: string }[], + ]; + const tables = tables_result.map((table) => table.Tables_in_better_auth); + assert(tables.length > 0, "No tables found"); + assert( + !["user", "session", "account", "verification"].find( + (x) => !tables.includes(x), + ), + "No tables found", + ); + }, + prefixTests: "mysql", + tests: [ + normalTestSuite(), + transactionsTestSuite({ disableTests: { ALL: true } }), + authFlowTestSuite(), + numberIdTestSuite(), + performanceTestSuite({ dialect: "mysql" }), + ], + async onFinish() { + await mysqlDB.end(); + resetGenerationCount(); + }, }); -describe("Drizzle Adapter Number Id Test (MySQL)", async () => { - let pool: any; - let mysql: Kysely; - - pool = createTestPool(); - mysql = createKyselyInstance(pool); - let opts = createTestOptions(pool, true); - - beforeAll(async () => { - await cleanupDatabase(pool, false); - const { runMigrations } = await getMigrations(opts); - await runMigrations(); - }); - - afterAll(async () => { - await cleanupDatabase(pool); - }); - - const db = drizzle({ - client: pool, - }); - const adapter = drizzleAdapter(db, { - provider: "mysql", - schema, - debugLogs: { - isRunningAdapterTests: true, - }, - }); - - runNumberIdAdapterTest({ - getAdapter: async (customOptions = {}) => { - const db = opts.database; - opts.database = undefined; - const merged = merge(opts, customOptions); - merged.database = db; - return adapter(merged); - }, - }); -}); +execute(); diff --git a/packages/better-auth/src/adapters/drizzle-adapter/test/adapter.drizzle.pg.test.ts b/packages/better-auth/src/adapters/drizzle-adapter/test/adapter.drizzle.pg.test.ts new file mode 100644 index 00000000..f8b34eac --- /dev/null +++ b/packages/better-auth/src/adapters/drizzle-adapter/test/adapter.drizzle.pg.test.ts @@ -0,0 +1,73 @@ +import { drizzleAdapter } from "../drizzle-adapter"; +import { testAdapter } from "../../test-adapter"; +import { + authFlowTestSuite, + normalTestSuite, + numberIdTestSuite, + performanceTestSuite, + transactionsTestSuite, +} from "../../tests"; +import { drizzle } from "drizzle-orm/node-postgres"; +import { generateDrizzleSchema, resetGenerationCount } from "./generate-schema"; +import { Pool } from "pg"; +import { execSync } from "child_process"; + +const pgDB = new Pool({ + connectionString: "postgres://user:password@localhost:5432/better_auth", +}); + +const cleanupDatabase = async (shouldDestroy = false) => { + await pgDB.query(`DROP SCHEMA public CASCADE; CREATE SCHEMA public;`); + if (shouldDestroy) { + await pgDB.end(); + } +}; + +const { execute } = await testAdapter({ + adapter: async (options) => { + const { schema } = await generateDrizzleSchema(pgDB, options, "pg"); + return drizzleAdapter(drizzle(pgDB), { + debugLogs: { isRunningAdapterTests: true }, + schema, + provider: "pg", + transaction: true, + }); + }, + async runMigrations(betterAuthOptions) { + await cleanupDatabase(); + const { fileName } = await generateDrizzleSchema( + pgDB, + betterAuthOptions, + "pg", + ); + + const command = `npx drizzle-kit push --dialect=postgresql --schema=${fileName}.ts --url=postgres://user:password@localhost:5432/better_auth`; + console.log(`Running: ${command}`); + console.log(`Options:`, betterAuthOptions); + try { + // wait for the above console.log to be printed + await new Promise((resolve) => setTimeout(resolve, 10)); + execSync(command, { + cwd: import.meta.dirname, + stdio: "inherit", + }); + } catch (error) { + console.error("Failed to push drizzle schema (pg):", error); + throw error; + } + }, + prefixTests: "pg", + tests: [ + normalTestSuite(), + transactionsTestSuite({ disableTests: { ALL: true } }), + authFlowTestSuite(), + numberIdTestSuite(), + performanceTestSuite({ dialect: "pg" }), + ], + async onFinish() { + await cleanupDatabase(true); + resetGenerationCount(); + }, +}); + +execute(); diff --git a/packages/better-auth/src/adapters/drizzle-adapter/test/adapter.drizzle.sqlite.test.ts b/packages/better-auth/src/adapters/drizzle-adapter/test/adapter.drizzle.sqlite.test.ts new file mode 100644 index 00000000..2c8f9fca --- /dev/null +++ b/packages/better-auth/src/adapters/drizzle-adapter/test/adapter.drizzle.sqlite.test.ts @@ -0,0 +1,77 @@ +import Database from "better-sqlite3"; +import { drizzleAdapter } from "../drizzle-adapter"; +import { testAdapter } from "../../test-adapter"; +import { + authFlowTestSuite, + normalTestSuite, + numberIdTestSuite, + performanceTestSuite, + transactionsTestSuite, +} from "../../tests"; +import { drizzle } from "drizzle-orm/better-sqlite3"; +import path from "path"; +import { + clearSchemaCache, + generateDrizzleSchema, + resetGenerationCount, +} from "./generate-schema"; +import fs from "fs/promises"; +import { execSync } from "child_process"; + +const dbFilePath = path.join(import.meta.dirname, "test.db"); +let sqliteDB = new Database(dbFilePath); + +const { execute } = await testAdapter({ + adapter: async (options) => { + const { schema } = await generateDrizzleSchema(sqliteDB, options, "sqlite"); + return drizzleAdapter(drizzle(sqliteDB), { + debugLogs: { isRunningAdapterTests: true }, + schema, + provider: "sqlite", + }); + }, + async runMigrations(betterAuthOptions) { + sqliteDB.close(); + try { + await fs.unlink(dbFilePath); + } catch { + console.log("db file not found"); + } + sqliteDB = new Database(dbFilePath); + + const { fileName } = await generateDrizzleSchema( + sqliteDB, + betterAuthOptions, + "sqlite", + ); + + const command = `npx drizzle-kit push --dialect=sqlite --schema=${fileName}.ts --url=./test.db`; + console.log(`Running: ${command}`); + console.log(`Options:`, betterAuthOptions); + try { + // wait for the above console.log to be printed + await new Promise((resolve) => setTimeout(resolve, 10)); + execSync(command, { + cwd: import.meta.dirname, + stdio: "inherit", + }); + } catch (error) { + console.error("Failed to push drizzle schema (sqlite):", error); + throw error; + } + }, + prefixTests: "sqlite", + tests: [ + normalTestSuite(), + transactionsTestSuite({ disableTests: { ALL: true } }), + authFlowTestSuite(), + numberIdTestSuite(), + performanceTestSuite({ dialect: "sqlite" }), + ], + async onFinish() { + clearSchemaCache(); + resetGenerationCount(); + }, +}); + +execute(); diff --git a/packages/better-auth/src/adapters/drizzle-adapter/test/adapter.drizzle.test.ts b/packages/better-auth/src/adapters/drizzle-adapter/test/adapter.drizzle.test.ts deleted file mode 100644 index 0b49c98a..00000000 --- a/packages/better-auth/src/adapters/drizzle-adapter/test/adapter.drizzle.test.ts +++ /dev/null @@ -1,183 +0,0 @@ -import merge from "deepmerge"; -import { afterAll, beforeAll, describe, expect, it } from "vitest"; -import * as schema from "./schema"; -import { - recoverProcessTZ, - runAdapterTest, - runNumberIdAdapterTest, -} from "../../test"; -import { drizzleAdapter } from ".."; -import { getMigrations } from "../../../db/get-migration"; -import { drizzle } from "drizzle-orm/node-postgres"; -import type { BetterAuthOptions } from "../../../types"; -import { Pool } from "pg"; -import { Kysely, PostgresDialect, sql } from "kysely"; -import { betterAuth } from "../../../auth"; - -const TEST_DB_URL = "postgres://user:password@localhost:5432/better_auth"; - -const createTestPool = () => new Pool({ connectionString: TEST_DB_URL }); - -const createKyselyInstance = (pool: Pool) => - new Kysely({ - dialect: new PostgresDialect({ pool }), - }); - -const cleanupDatabase = async (postgres: Kysely, shouldDestroy = true) => { - await sql`DROP SCHEMA public CASCADE; CREATE SCHEMA public;`.execute( - postgres, - ); - if (shouldDestroy) { - await postgres.destroy(); - } -}; - -const createTestOptions = (pg: Pool, useNumberId = false) => - ({ - database: pg, - user: { - fields: { email: "email_address" }, - additionalFields: { - test: { - type: "string", - defaultValue: "test", - }, - }, - }, - session: { - modelName: "sessions", - }, - advanced: { - database: { - useNumberId, - }, - }, - }) satisfies BetterAuthOptions; - -describe("Drizzle Adapter Tests", async () => { - let pg: Pool; - let postgres: Kysely; - pg = createTestPool(); - postgres = createKyselyInstance(pg); - const opts = createTestOptions(pg); - await cleanupDatabase(postgres, false); - const { runMigrations } = await getMigrations(opts); - await runMigrations(); - - afterAll(async () => { - await cleanupDatabase(postgres); - }); - const db = drizzle(pg); - const adapter = drizzleAdapter(db, { provider: "pg", schema }); - - runAdapterTest({ - getAdapter: async (customOptions = {}) => { - const db = opts.database; - //@ts-expect-error - opts.database = undefined; - const merged = merge(opts, customOptions); - merged.database = db; - return adapter(merged); - }, - }); -}); - -describe("Drizzle Adapter Authentication Flow Tests", async () => { - const pg = createTestPool(); - let postgres: Kysely; - const opts = createTestOptions(pg); - const testUser = { - email: "test-email@email.com", - password: "password", - name: "Test Name", - }; - beforeAll(async () => { - postgres = createKyselyInstance(pg); - - const { runMigrations } = await getMigrations(opts); - await runMigrations(); - }); - - const auth = betterAuth({ - ...opts, - database: drizzleAdapter(drizzle(pg), { provider: "pg", schema }), - emailAndPassword: { - enabled: true, - }, - }); - - afterAll(async () => { - await cleanupDatabase(postgres); - }); - - it("should successfully sign up a new user", async () => { - const user = await auth.api.signUpEmail({ body: testUser }); - expect(user).toBeDefined(); - }); - - it("should successfully sign in an existing user", async () => { - const user = await auth.api.signInEmail({ body: testUser }); - expect(user.user).toBeDefined(); - }); - - it("stores and retrieves timestamps correctly across timezones", async () => { - using _ = recoverProcessTZ(); - - const sampleUser = { - name: "sample", - email: "sampler@test.com", - password: "samplerrrrr", - }; - - process.env.TZ = "Europe/London"; - const userSignUp = await auth.api.signUpEmail({ - body: { - name: sampleUser.name, - email: sampleUser.email, - password: sampleUser.password, - }, - }); - process.env.TZ = "America/Los_Angeles"; - const userSignIn = await auth.api.signInEmail({ - body: { email: sampleUser.email, password: sampleUser.password }, - }); - - expect(userSignUp.user.createdAt).toStrictEqual(userSignIn.user.createdAt); - }); -}); - -describe("Drizzle Adapter Number Id Test", async () => { - let pg: Pool; - let postgres: Kysely; - pg = createTestPool(); - postgres = createKyselyInstance(pg); - const opts = createTestOptions(pg, true); - beforeAll(async () => { - await cleanupDatabase(postgres, false); - const { runMigrations } = await getMigrations(opts); - await runMigrations(); - }); - - afterAll(async () => { - await cleanupDatabase(postgres); - }); - const db = drizzle(pg); - const adapter = drizzleAdapter(db, { - provider: "pg", - schema, - debugLogs: { - isRunningAdapterTests: true, - }, - }); - - runNumberIdAdapterTest({ - getAdapter: async (customOptions = {}) => { - const db = opts.database; - //@ts-expect-error - opts.database = undefined; - const merged = merge(opts, customOptions); - merged.database = db; - return adapter(merged); - }, - }); -}); diff --git a/packages/better-auth/src/adapters/drizzle-adapter/test/generate-schema.ts b/packages/better-auth/src/adapters/drizzle-adapter/test/generate-schema.ts new file mode 100644 index 00000000..5ca2a887 --- /dev/null +++ b/packages/better-auth/src/adapters/drizzle-adapter/test/generate-schema.ts @@ -0,0 +1,98 @@ +import type { Adapter, BetterAuthOptions } from "../../../types"; +import { drizzleAdapter } from "../drizzle-adapter"; +import fs from "fs/promises"; +import { join } from "path"; + +let generationCount = 0; + +const schemaCache = new Map(); + +/** + * generates a drizzle schema based on BetterAuthOptions & a given dialect. + * + * Useful for testing the Drizzle adapter. + */ +export const generateDrizzleSchema = async ( + db: any, + options: BetterAuthOptions, + dialect: "sqlite" | "mysql" | "pg", +) => { + const cacheKey = `${dialect}-${JSON.stringify(options)}`; + if (schemaCache.has(cacheKey)) { + const { count, schema } = schemaCache.get(cacheKey)!; + return { + schema, + fileName: `./.tmp/generated-${dialect}-schema-${count}`, + }; + } + generationCount++; + let thisCount = generationCount; + const i = async (x: string) => { + // Clear the Node.js module cache for the generated schema file to ensure fresh import + try { + const resolvedPath = + require?.resolve?.(x) || + (import.meta && new URL(x, import.meta.url).pathname); + if (resolvedPath && typeof resolvedPath === "string" && require?.cache) { + delete require.cache[resolvedPath]; + } + } catch (error) {} + return await import(x); + }; + + const { generateSchema } = (await i( + "./../../../../../cli/src/generators/index", + )) as { + generateSchema: (opts: { + adapter: Adapter; + file?: string; + options: BetterAuthOptions; + }) => Promise<{ + code: string | undefined; + fileName: string; + overwrite: boolean | undefined; + }>; + }; + + const exists = await fs + .access(join(import.meta.dirname, `/.tmp`)) + .then(() => true) + .catch(() => false); + if (!exists) { + await fs.mkdir(join(import.meta.dirname, `/.tmp`), { recursive: true }); + } + + let adapter = drizzleAdapter(db, { provider: dialect })(options); + + let { code } = await generateSchema({ + adapter, + options, + }); + + await fs.writeFile( + join( + import.meta.dirname, + `/.tmp/generated-${dialect}-schema-${thisCount}.ts`, + ), + code || "", + "utf-8", + ); + + const res = await i(`./.tmp/generated-${dialect}-schema-${thisCount}`); + schemaCache.set(cacheKey, { + count: thisCount, + schema: res, + }); + return { + schema: res, + fileName: `./.tmp/generated-${dialect}-schema-${thisCount}`, + }; +}; + +export const clearSchemaCache = () => { + schemaCache.clear(); +}; + +export const resetGenerationCount = () => { + generationCount = 0; +}; diff --git a/packages/better-auth/src/adapters/drizzle-adapter/test/schema.mysql.ts b/packages/better-auth/src/adapters/drizzle-adapter/test/schema.mysql.ts deleted file mode 100644 index e9de9845..00000000 --- a/packages/better-auth/src/adapters/drizzle-adapter/test/schema.mysql.ts +++ /dev/null @@ -1,53 +0,0 @@ -import { boolean, text, varchar, datetime } from "drizzle-orm/mysql-core"; -import { mysqlTable } from "drizzle-orm/mysql-core"; - -export const user = mysqlTable("user", { - id: varchar("id", { length: 255 }).primaryKey(), - name: varchar("name", { length: 255 }).notNull(), - email_address: varchar("email_address", { length: 255 }).notNull().unique(), - emailVerified: boolean("emailVerified").notNull(), - test: text("test").notNull(), - image: text("image"), - createdAt: datetime("createdAt", { mode: "date" }).notNull(), // Use `date` mode - updatedAt: datetime("updatedAt", { mode: "date" }).notNull(), // Use `date` mode -}); - -export const sessions = mysqlTable("sessions", { - id: varchar("id", { length: 255 }).primaryKey(), - expiresAt: datetime("expiresAt", { mode: "date" }).notNull(), // Use `date` mode - ipAddress: varchar("ipAddress", { length: 255 }), - userAgent: varchar("userAgent", { length: 255 }), - token: varchar("token", { length: 255 }).notNull(), - createdAt: datetime("createdAt", { mode: "date" }).notNull(), // Use `date` mode - updatedAt: datetime("updatedAt", { mode: "date" }).notNull(), // Use `date` mode - userId: varchar("userId", { length: 255 }) - .notNull() - .references(() => user.id), -}); - -export const account = mysqlTable("account", { - id: varchar("id", { length: 255 }).primaryKey(), - accountId: varchar("accountId", { length: 255 }).notNull(), - providerId: varchar("providerId", { length: 255 }).notNull(), - userId: varchar("userId", { length: 255 }) - .notNull() - .references(() => user.id), - accessToken: text("accessToken"), - createdAt: datetime("createdAt", { mode: "date" }).notNull(), // Use `date` mode - updatedAt: datetime("updatedAt", { mode: "date" }).notNull(), // Use `date` mode - refreshToken: text("refreshToken"), - idToken: text("idToken"), - accessTokenExpiresAt: datetime("accessTokenExpiresAt", { mode: "date" }), - refreshTokenExpiresAt: datetime("refreshTokenExpiresAt", { mode: "date" }), - scope: text("scope"), - password: text("password"), -}); - -export const verification = mysqlTable("verification", { - id: varchar("id", { length: 255 }).primaryKey(), - identifier: varchar("identifier", { length: 255 }).notNull(), - value: varchar("value", { length: 255 }).notNull(), - expiresAt: datetime("expiresAt", { mode: "date" }).notNull(), // Use `date` mode - createdAt: datetime("createdAt", { mode: "date" }).notNull(), // Use `date` mode - updatedAt: datetime("updatedAt", { mode: "date" }).notNull(), // Use `date` mode -}); diff --git a/packages/better-auth/src/adapters/drizzle-adapter/test/schema.ts b/packages/better-auth/src/adapters/drizzle-adapter/test/schema.ts deleted file mode 100644 index 5f9c640f..00000000 --- a/packages/better-auth/src/adapters/drizzle-adapter/test/schema.ts +++ /dev/null @@ -1,64 +0,0 @@ -/* - -This file is used explicitly for testing purposes. - -It's not used in the production code. - -For information on how to use the drizzle-adapter, please refer to the documentation. - -https://www.better-auth.com/docs/adapters/drizzle - -*/ -import { boolean, text, timestamp } from "drizzle-orm/pg-core"; -import { pgTable } from "drizzle-orm/pg-core"; - -export const user = pgTable("user", { - id: text("id").primaryKey(), - name: text("name").notNull(), - email_address: text("email_address").notNull().unique(), - emailVerified: boolean("emailVerified").notNull(), - test: text("test").notNull(), - image: text("image"), - createdAt: timestamp("createdAt").notNull(), - updatedAt: timestamp("updatedAt").notNull(), -}); - -export const sessions = pgTable("sessions", { - id: text("id").primaryKey(), - expiresAt: timestamp("expiresAt").notNull(), - ipAddress: text("ipAddress"), - userAgent: text("userAgent"), - token: text("token").notNull(), - createdAt: timestamp("createdAt").notNull(), - updatedAt: timestamp("updatedAt").notNull(), - userId: text("userId") - .notNull() - .references(() => user.id), -}); - -export const account = pgTable("account", { - id: text("id").primaryKey(), - accountId: text("accountId").notNull(), - providerId: text("providerId").notNull(), - userId: text("userId") - .notNull() - .references(() => user.id), - accessToken: text("accessToken"), - createdAt: timestamp("createdAt").notNull(), - updatedAt: timestamp("updatedAt").notNull(), - refreshToken: text("refreshToken"), - idToken: text("idToken"), - accessTokenExpiresAt: timestamp("accessTokenExpiresAt"), - refreshTokenExpiresAt: timestamp("refreshTokenExpiresAt"), - scope: text("scope"), - password: text("password"), -}); - -export const verification = pgTable("verification", { - id: text("id").primaryKey(), - identifier: text("identifier").notNull(), - value: text("value").notNull(), - expiresAt: timestamp("expiresAt").notNull(), - createdAt: timestamp("createdAt").notNull(), - updatedAt: timestamp("updatedAt").notNull(), -}); diff --git a/packages/better-auth/src/adapters/kysely-adapter/kysely-adapter.ts b/packages/better-auth/src/adapters/kysely-adapter/kysely-adapter.ts index bb778ae7..99cfb4dc 100644 --- a/packages/better-auth/src/adapters/kysely-adapter/kysely-adapter.ts +++ b/packages/better-auth/src/adapters/kysely-adapter/kysely-adapter.ts @@ -6,8 +6,11 @@ import { } from "../adapter-factory"; import type { Adapter, BetterAuthOptions, Where } from "../../types"; import type { KyselyDatabaseType } from "./types"; -import type { InsertQueryBuilder, Kysely, UpdateQueryBuilder } from "kysely"; -import { ensureUTC } from "../../utils/ensure-utc"; +import { + type InsertQueryBuilder, + type Kysely, + type UpdateQueryBuilder, +} from "kysely"; interface KyselyAdapterConfig { /** @@ -44,7 +47,7 @@ export const kyselyAdapter = ( const createCustomAdapter = ( db: Kysely, ): AdapterFactoryCustomizeAdapterCreator => { - return ({ getFieldName, schema }) => { + return ({ getFieldName, schema, getDefaultModelName }) => { const withReturning = async ( values: Record, builder: @@ -110,7 +113,8 @@ export const kyselyAdapter = ( return value ? 1 : 0; } if (f!.type === "date" && value && value instanceof Date) { - return type === "sqlite" ? value.toISOString() : value; + if (type === "sqlite") return value.toISOString(); + return value; } return value; } @@ -123,7 +127,7 @@ export const kyselyAdapter = ( const field = obj[key]; if (field instanceof Date && config?.type === "mysql") { - obj[key] = ensureUTC(field); + // obj[key] = ensureUTC(field); } else if (typeof field === "object" && field !== null) { transformObject(field); } @@ -232,10 +236,8 @@ export const kyselyAdapter = ( return { async create({ data, model }) { const builder = db.insertInto(model).values(data); - - return transformValueFromDB( - await withReturning(data, builder, model, []), - ) as any; + const returned = await withReturning(data, builder, model, []); + return transformValueFromDB(returned) as any; }, async findOne({ model, where, select }) { const { and, or } = convertWhereClause(model, where); @@ -326,7 +328,13 @@ export const kyselyAdapter = ( query = query.where((eb) => eb.or(or.map((expr) => expr(eb)))); } const res = await query.execute(); - return res[0]!.count as number; + if (typeof res[0]!.count === "number") { + return res[0]!.count; + } + if (typeof res[0]!.count === "bigint") { + return Number(res[0]!.count); + } + return parseInt(res[0]!.count); }, async delete({ model, where }) { const { and, or } = convertWhereClause(model, where); @@ -363,7 +371,10 @@ export const kyselyAdapter = ( usePlural: config?.usePlural, debugLogs: config?.debugLogs, supportsBooleans: - config?.type === "sqlite" || config?.type === "mssql" || !config?.type + config?.type === "sqlite" || + config?.type === "mssql" || + config?.type === "mysql" || + !config?.type ? false : true, supportsDates: diff --git a/packages/better-auth/src/adapters/kysely-adapter/test/adapter.kysely.mssql.test.ts b/packages/better-auth/src/adapters/kysely-adapter/test/adapter.kysely.mssql.test.ts new file mode 100644 index 00000000..c32d3f5b --- /dev/null +++ b/packages/better-auth/src/adapters/kysely-adapter/test/adapter.kysely.mssql.test.ts @@ -0,0 +1,317 @@ +import { Kysely, MssqlDialect } from "kysely"; +import { testAdapter } from "../../test-adapter"; +import { kyselyAdapter } from "../kysely-adapter"; +import { + authFlowTestSuite, + normalTestSuite, + numberIdTestSuite, + performanceTestSuite, + transactionsTestSuite, +} from "../../tests"; +import { getMigrations } from "../../../db"; +import * as Tedious from "tedious"; +import * as Tarn from "tarn"; +import type { BetterAuthOptions } from "../../../types"; + +// We are not allowed to handle the mssql connection +// we must let kysely handle it. This is because if kysely is already +// handling it, and we were to connect it ourselves, it will create bugs. +const dialect = new MssqlDialect({ + tarn: { + ...Tarn, + options: { + min: 0, + max: 10, + }, + }, + tedious: { + ...Tedious, + connectionFactory: () => + new Tedious.Connection({ + authentication: { + options: { + password: "Password123!", + userName: "sa", + }, + type: "default", + }, + options: { + database: "master", // Start with master database, will create better_auth if needed + port: 1433, + trustServerCertificate: true, + encrypt: false, + }, + server: "localhost", + }), + TYPES: { + ...Tedious.TYPES, + DateTime: Tedious.TYPES.DateTime2, + }, + }, +}); + +const kyselyDB = new Kysely({ + dialect: dialect, +}); + +// Create better_auth database if it doesn't exist +const ensureDatabaseExists = async () => { + try { + console.log("Ensuring better_auth database exists..."); + await kyselyDB.getExecutor().executeQuery({ + sql: ` + IF NOT EXISTS (SELECT name FROM sys.databases WHERE name = 'better_auth') + BEGIN + CREATE DATABASE better_auth; + PRINT 'Database better_auth created successfully'; + END + ELSE + BEGIN + PRINT 'Database better_auth already exists'; + END + `, + parameters: [], + query: { kind: "SelectQueryNode" }, + queryId: { queryId: "ensure-db" }, + }); + console.log("Database check/creation completed"); + } catch (error) { + console.error("Failed to ensure database exists:", error); + throw error; + } +}; + +// Warm up connection for CI environments +const warmupConnection = async () => { + const isCI = + process.env.CI === "true" || process.env.GITHUB_ACTIONS === "true"; + if (isCI) { + console.log("Warming up MSSQL connection for CI environment..."); + console.log( + `Environment: CI=${process.env.CI}, GITHUB_ACTIONS=${process.env.GITHUB_ACTIONS}`, + ); + + try { + await ensureDatabaseExists(); + + // Try a simple query to establish the connection + await kyselyDB.getExecutor().executeQuery({ + sql: "SELECT 1 as warmup, @@VERSION as version", + parameters: [], + query: { kind: "SelectQueryNode" }, + queryId: { queryId: "warmup" }, + }); + console.log("Connection warmup successful"); + } catch (error) { + console.warn( + "Connection warmup failed, will retry during validation:", + error, + ); + // Log additional debugging info for CI + if (isCI) { + console.log("CI Debug Info:"); + console.log("- MSSQL server may not be ready yet"); + console.log("- Network connectivity issues possible"); + console.log("- Database may not exist yet"); + } + } + } else { + // For local development, also ensure database exists + await ensureDatabaseExists(); + } +}; + +// Add connection validation helper with CI-specific handling +const validateConnection = async (retries: number = 10): Promise => { + const isCI = + process.env.CI === "true" || process.env.GITHUB_ACTIONS === "true"; + const maxRetries = isCI ? 15 : retries; // More retries in CI + const baseDelay = isCI ? 2000 : 1000; // Longer delays in CI + + console.log( + `Validating connection (CI: ${isCI}, max retries: ${maxRetries})`, + ); + + for (let i = 0; i < maxRetries; i++) { + try { + await query("SELECT 1 as test", isCI ? 10000 : 5000); + console.log("Connection validated successfully"); + return true; + } catch (error) { + console.warn( + `Connection validation attempt ${i + 1}/${maxRetries} failed:`, + error, + ); + if (i === maxRetries - 1) { + console.error("All connection validation attempts failed"); + return false; + } + // Exponential backoff with longer delays in CI + const delay = baseDelay * Math.pow(1.5, i); + console.log(`Waiting ${delay}ms before retry...`); + await new Promise((resolve) => setTimeout(resolve, delay)); + } + } + return false; +}; + +const query = async (sql: string, timeoutMs: number = 30000) => { + const isCI = + process.env.CI === "true" || process.env.GITHUB_ACTIONS === "true"; + const actualTimeout = isCI ? Math.max(timeoutMs, 60000) : timeoutMs; // Minimum 60s timeout in CI + + try { + console.log( + `Executing SQL: ${sql.substring(0, 100)}... (timeout: ${actualTimeout}ms, CI: ${isCI})`, + ); + + // Ensure we're using the better_auth database for queries + const sqlWithContext = sql.includes("USE ") + ? sql + : `USE better_auth; ${sql}`; + + const result = (await Promise.race([ + kyselyDB.getExecutor().executeQuery({ + sql: sqlWithContext, + parameters: [], + query: { kind: "SelectQueryNode" }, + queryId: { queryId: "" }, + }), + new Promise((_, reject) => + setTimeout( + () => reject(new Error(`Query timeout after ${actualTimeout}ms`)), + actualTimeout, + ), + ), + ])) as any; + console.log(`Query completed successfully`); + return { rows: result.rows, rowCount: result.rows.length }; + } catch (error) { + console.error(`Query failed: ${error}`); + throw error; + } +}; + +const showDB = async () => { + const DB = { + users: await query("SELECT * FROM [user]"), + sessions: await query("SELECT * FROM [session]"), + accounts: await query("SELECT * FROM [account]"), + verifications: await query("SELECT * FROM [verification]"), + }; + console.log(`DB`, DB); +}; + +const resetDB = async (retryCount: number = 0) => { + const isCI = + process.env.CI === "true" || process.env.GITHUB_ACTIONS === "true"; + const maxRetries = isCI ? 3 : 1; // Allow retries in CI + + try { + console.log( + `Starting database reset... (attempt ${retryCount + 1}/${maxRetries + 1})`, + ); + + // Warm up connection first (especially important for CI) + await warmupConnection(); + + const isConnected = await validateConnection(); + if (!isConnected) { + throw new Error("Database connection validation failed"); + } + + // First, try to disable foreign key checks and drop constraints + await query( + ` + -- Disable all foreign key constraints + EXEC sp_MSforeachtable "ALTER TABLE ? NOCHECK CONSTRAINT all"; + `, + 15000, + ); + + // Drop foreign key constraints + await query( + ` + DECLARE @sql NVARCHAR(MAX) = ''; + SELECT @sql = @sql + 'ALTER TABLE [' + TABLE_SCHEMA + '].[' + TABLE_NAME + '] DROP CONSTRAINT [' + CONSTRAINT_NAME + '];' + CHAR(13) + FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS + WHERE CONSTRAINT_TYPE = 'FOREIGN KEY' + AND TABLE_CATALOG = DB_NAME(); + IF LEN(@sql) > 0 + EXEC sp_executesql @sql; + `, + 15000, + ); + + // Then drop all tables + await query( + ` + DECLARE @sql NVARCHAR(MAX) = ''; + SELECT @sql = @sql + 'DROP TABLE [' + TABLE_NAME + '];' + CHAR(13) + FROM INFORMATION_SCHEMA.TABLES + WHERE TABLE_TYPE = 'BASE TABLE' + AND TABLE_CATALOG = DB_NAME() + AND TABLE_SCHEMA = 'dbo'; + IF LEN(@sql) > 0 + EXEC sp_executesql @sql; + `, + 15000, + ); + + console.log("Database reset completed successfully"); + } catch (error) { + console.error("Database reset failed:", error); + + // Retry logic for CI environments + if (retryCount < maxRetries) { + const delay = 5000 * (retryCount + 1); // Increasing delay + console.log( + `Retrying in ${delay}ms... (attempt ${retryCount + 2}/${maxRetries + 1})`, + ); + await new Promise((resolve) => setTimeout(resolve, delay)); + return resetDB(retryCount + 1); + } + + // Final fallback - try to recreate the database + try { + console.log("Attempting database recreation..."); + // This would require a separate connection to master database + // For now, just throw the error with better context + throw new Error(`Database reset failed completely: ${error}`); + } catch (finalError) { + console.error("Final fallback also failed:", finalError); + throw new Error( + `Database reset failed: ${error}. All fallback attempts failed: ${finalError}`, + ); + } + } +}; + +const { execute } = await testAdapter({ + adapter: () => { + return kyselyAdapter(kyselyDB, { + type: "mssql", + debugLogs: { isRunningAdapterTests: true }, + }); + }, + async runMigrations(betterAuthOptions) { + await resetDB(); + const opts = Object.assign(betterAuthOptions, { + database: { db: kyselyDB, type: "mssql" }, + } satisfies BetterAuthOptions); + const { runMigrations } = await getMigrations(opts); + await runMigrations(); + }, + prefixTests: "mssql", + tests: [ + normalTestSuite(), + transactionsTestSuite({ disableTests: { ALL: true } }), + authFlowTestSuite({ showDB }), + numberIdTestSuite(), + performanceTestSuite({ dialect: "mssql" }), + ], + async onFinish() { + kyselyDB.destroy(); + }, +}); +execute(); diff --git a/packages/better-auth/src/adapters/kysely-adapter/test/adapter.kysely.mysql.test.ts b/packages/better-auth/src/adapters/kysely-adapter/test/adapter.kysely.mysql.test.ts new file mode 100644 index 00000000..abc16a5f --- /dev/null +++ b/packages/better-auth/src/adapters/kysely-adapter/test/adapter.kysely.mysql.test.ts @@ -0,0 +1,63 @@ +import { Kysely, MysqlDialect } from "kysely"; +import { testAdapter } from "../../test-adapter"; +import { kyselyAdapter } from "../kysely-adapter"; +import { createPool } from "mysql2/promise"; +import { + authFlowTestSuite, + normalTestSuite, + numberIdTestSuite, + performanceTestSuite, + transactionsTestSuite, +} from "../../tests"; +import { getMigrations } from "../../../db"; +import { assert } from "vitest"; + +const mysqlDB = createPool({ + uri: "mysql://user:password@localhost:3307/better_auth", + timezone: "Z", +}); + +let kyselyDB = new Kysely({ + dialect: new MysqlDialect(mysqlDB), +}); + +const { execute } = await testAdapter({ + adapter: () => + kyselyAdapter(kyselyDB, { + type: "mysql", + debugLogs: { isRunningAdapterTests: true }, + }), + async runMigrations(betterAuthOptions) { + await mysqlDB.query("DROP DATABASE IF EXISTS better_auth"); + await mysqlDB.query("CREATE DATABASE better_auth"); + await mysqlDB.query("USE better_auth"); + const opts = Object.assign(betterAuthOptions, { database: mysqlDB }); + const { runMigrations } = await getMigrations(opts); + await runMigrations(); + + // ensure migrations were run successfully + const [tables_result] = (await mysqlDB.query("SHOW TABLES")) as unknown as [ + { Tables_in_better_auth: string }[], + ]; + const tables = tables_result.map((table) => table.Tables_in_better_auth); + assert(tables.length > 0, "No tables found"); + assert( + !["user", "session", "account", "verification"].find( + (x) => !tables.includes(x), + ), + "No tables found", + ); + }, + prefixTests: "mysql", + tests: [ + normalTestSuite(), + transactionsTestSuite({ disableTests: { ALL: true } }), + authFlowTestSuite(), + numberIdTestSuite(), + performanceTestSuite({ dialect: "mysql" }), + ], + async onFinish() { + await mysqlDB.end(); + }, +}); +execute(); diff --git a/packages/better-auth/src/adapters/kysely-adapter/test/adapter.kysely.pg.test.ts b/packages/better-auth/src/adapters/kysely-adapter/test/adapter.kysely.pg.test.ts new file mode 100644 index 00000000..47f39ca2 --- /dev/null +++ b/packages/better-auth/src/adapters/kysely-adapter/test/adapter.kysely.pg.test.ts @@ -0,0 +1,53 @@ +import { Kysely, PostgresDialect } from "kysely"; +import { testAdapter } from "../../test-adapter"; +import { kyselyAdapter } from "../kysely-adapter"; +import { Pool } from "pg"; +import { + authFlowTestSuite, + normalTestSuite, + numberIdTestSuite, + performanceTestSuite, + transactionsTestSuite, +} from "../../tests"; +import { getMigrations } from "../../../db"; +import type { BetterAuthOptions } from "../../../types"; + +const pgDB = new Pool({ + connectionString: "postgres://user:password@localhost:5433/better_auth", +}); + +let kyselyDB = new Kysely({ + dialect: new PostgresDialect({ pool: pgDB }), +}); + +const cleanupDatabase = async () => { + await pgDB.query(`DROP SCHEMA public CASCADE; CREATE SCHEMA public;`); +}; + +const { execute } = await testAdapter({ + adapter: () => + kyselyAdapter(kyselyDB, { + type: "postgres", + debugLogs: { isRunningAdapterTests: true }, + }), + prefixTests: "pg", + async runMigrations(betterAuthOptions) { + await cleanupDatabase(); + const opts = Object.assign(betterAuthOptions, { + database: pgDB, + } satisfies BetterAuthOptions); + const { runMigrations } = await getMigrations(opts); + await runMigrations(); + }, + tests: [ + normalTestSuite(), + transactionsTestSuite({ disableTests: { ALL: true } }), + authFlowTestSuite(), + numberIdTestSuite(), + performanceTestSuite({ dialect: "pg" }), + ], + async onFinish() { + await pgDB.end(); + }, +}); +execute(); diff --git a/packages/better-auth/src/adapters/kysely-adapter/test/adapter.kysely.sqlite.test.ts b/packages/better-auth/src/adapters/kysely-adapter/test/adapter.kysely.sqlite.test.ts new file mode 100644 index 00000000..3b5dcc62 --- /dev/null +++ b/packages/better-auth/src/adapters/kysely-adapter/test/adapter.kysely.sqlite.test.ts @@ -0,0 +1,55 @@ +import { Kysely, SqliteDialect } from "kysely"; +import { testAdapter } from "../../test-adapter"; +import { kyselyAdapter } from "../kysely-adapter"; +import Database from "better-sqlite3"; +import { + authFlowTestSuite, + normalTestSuite, + numberIdTestSuite, + performanceTestSuite, + transactionsTestSuite, +} from "../../tests"; +import path from "path"; +import { getMigrations } from "../../../db"; +import fs from "fs/promises"; + +const dbPath = path.join(__dirname, "test.db"); +let database = new Database(dbPath); + +let kyselyDB = new Kysely({ + dialect: new SqliteDialect({ database }), +}); + +const { execute } = await testAdapter({ + adapter: () => { + return kyselyAdapter(kyselyDB, { + type: "sqlite", + debugLogs: { isRunningAdapterTests: true }, + }); + }, + prefixTests: "sqlite", + async runMigrations(betterAuthOptions) { + database.close(); + try { + await fs.unlink(dbPath); + } catch { + console.log("db doesnt exist"); + } + database = new Database(dbPath); + kyselyDB = new Kysely({ dialect: new SqliteDialect({ database }) }); + const opts = Object.assign(betterAuthOptions, { database }); + const { runMigrations } = await getMigrations(opts); + await runMigrations(); + }, + tests: [ + normalTestSuite(), + transactionsTestSuite({ disableTests: { ALL: true } }), + authFlowTestSuite(), + numberIdTestSuite(), + performanceTestSuite({ dialect: "sqlite" }), + ], + async onFinish() { + database.close(); + }, +}); +execute(); diff --git a/packages/better-auth/src/adapters/kysely-adapter/test/normal/node-sqlite-dialect.test.ts b/packages/better-auth/src/adapters/kysely-adapter/test/node-sqlite-dialect.test.ts similarity index 95% rename from packages/better-auth/src/adapters/kysely-adapter/test/normal/node-sqlite-dialect.test.ts rename to packages/better-auth/src/adapters/kysely-adapter/test/node-sqlite-dialect.test.ts index a95b24a7..6c5edea7 100644 --- a/packages/better-auth/src/adapters/kysely-adapter/test/normal/node-sqlite-dialect.test.ts +++ b/packages/better-auth/src/adapters/kysely-adapter/test/node-sqlite-dialect.test.ts @@ -1,10 +1,10 @@ import { describe, it, expect, beforeAll, afterAll } from "vitest"; import { Kysely, sql } from "kysely"; -import { NodeSqliteDialect } from "../../node-sqlite-dialect"; -import { kyselyAdapter } from "../../kysely-adapter"; -import { runAdapterTest } from "../../../test"; -import { getMigrations } from "../../../../db/get-migration"; -import type { BetterAuthOptions } from "../../../../types"; +import { NodeSqliteDialect } from "../node-sqlite-dialect"; +import { kyselyAdapter } from "../kysely-adapter"; +import { runAdapterTest } from "../../test"; +import { getMigrations } from "../../../db/get-migration"; +import type { BetterAuthOptions } from "../../../types"; import merge from "deepmerge"; import type { DatabaseSync } from "node:sqlite"; const nodeVersion = process.version; diff --git a/packages/better-auth/src/adapters/kysely-adapter/test/normal/adapter.kysely.test.ts b/packages/better-auth/src/adapters/kysely-adapter/test/normal/adapter.kysely.test.ts deleted file mode 100644 index 34c5a3eb..00000000 --- a/packages/better-auth/src/adapters/kysely-adapter/test/normal/adapter.kysely.test.ts +++ /dev/null @@ -1,558 +0,0 @@ -import merge from "deepmerge"; -import fsPromises from "fs/promises"; -import { afterAll, beforeAll, describe, expect, it } from "vitest"; -import { recoverProcessTZ, runAdapterTest } from "../../../test"; -import { getMigrations } from "../../../../db/get-migration"; -import path from "path"; -import Database from "better-sqlite3"; -import { kyselyAdapter } from "../.."; -import { - Kysely, - MysqlDialect, - PostgresDialect, - sql, - SqliteDialect, -} from "kysely"; -import type { BetterAuthOptions } from "../../../../types"; -import { createPool } from "mysql2/promise"; - -import * as tedious from "tedious"; -import * as tarn from "tarn"; -import { MssqlDialect } from "kysely"; -import { getTestInstance } from "../../../../test-utils/test-instance"; -import { setState } from "../state"; -import { Pool } from "pg"; - -describe("adapter test", async () => { - const sqlite = new Database(path.join(__dirname, "test.db")); - const mysql = createPool("mysql://user:password@localhost:3306/better_auth"); - const sqliteKy = new Kysely({ - dialect: new SqliteDialect({ - database: sqlite, - }), - }); - const mysqlKy = new Kysely({ - dialect: new MysqlDialect(mysql), - }); - - const opts = ({ - database, - isNumberIdTest, - }: { - database: BetterAuthOptions["database"]; - isNumberIdTest: boolean; - }) => - ({ - database: database, - user: { - fields: { - email: "email_address", - }, - additionalFields: { - test: { - type: "string", - defaultValue: "test", - }, - }, - }, - session: { - modelName: "sessions", - }, - advanced: { - database: { - useNumberId: isNumberIdTest, - }, - }, - }) satisfies BetterAuthOptions; - - const mysqlOptions = opts({ - database: { - db: mysqlKy, - type: "mysql", - }, - isNumberIdTest: false, - }); - - const sqliteOptions = opts({ - database: { - db: sqliteKy, - type: "sqlite", - }, - isNumberIdTest: false, - }); - beforeAll(async () => { - setState("RUNNING"); - console.log(`Now running Number ID Kysely adapter test...`); - await (await getMigrations(mysqlOptions)).runMigrations(); - await (await getMigrations(sqliteOptions)).runMigrations(); - }); - - afterAll(async () => { - await mysql.query("DROP DATABASE IF EXISTS better_auth"); - await mysql.query("CREATE DATABASE better_auth"); - await mysql.end(); - await fsPromises.unlink(path.join(__dirname, "test.db")); - }); - - const mysqlAdapter = kyselyAdapter(mysqlKy, { - type: "mysql", - debugLogs: { - isRunningAdapterTests: true, - }, - }); - - runAdapterTest({ - getAdapter: async (customOptions = {}) => { - return mysqlAdapter(merge(customOptions, mysqlOptions)); - }, - testPrefix: "mysql", - }); - - const sqliteAdapter = kyselyAdapter(sqliteKy, { - type: "sqlite", - debugLogs: { - isRunningAdapterTests: true, - }, - }); - runAdapterTest({ - getAdapter: async (customOptions = {}) => { - return sqliteAdapter(merge(customOptions, sqliteOptions)); - }, - testPrefix: "sqlite", - }); -}); - -describe("mssql", async () => { - const dialect = new MssqlDialect({ - tarn: { - ...tarn, - options: { - min: 0, - max: 10, - }, - }, - tedious: { - ...tedious, - connectionFactory: () => - new tedious.Connection({ - authentication: { - options: { - password: "Password123!", - userName: "sa", - }, - type: "default", - }, - options: { - port: 1433, - trustServerCertificate: true, - }, - server: "localhost", - }), - }, - }); - const opts = { - database: dialect, - user: { - modelName: "users", - }, - } satisfies BetterAuthOptions; - beforeAll(async () => { - const { runMigrations, toBeAdded, toBeCreated } = await getMigrations(opts); - await runMigrations(); - return async () => { - await resetDB(); - console.log( - `Normal Kysely adapter test finished. Now allowing number ID Kysely tests to run.`, - ); - setState("IDLE"); - }; - }); - const mssql = new Kysely({ - dialect: dialect, - }); - const getAdapter = kyselyAdapter(mssql, { - type: "mssql", - debugLogs: { - isRunningAdapterTests: true, - }, - }); - - async function resetDB() { - await sql`DROP TABLE dbo.session;`.execute(mssql); - await sql`DROP TABLE dbo.verification;`.execute(mssql); - await sql`DROP TABLE dbo.account;`.execute(mssql); - await sql`DROP TABLE dbo.users;`.execute(mssql); - } - - runAdapterTest({ - getAdapter: async (customOptions = {}) => { - // const merged = merge( customOptions,opts); - // merged.database = opts.database; - return getAdapter(opts); - }, - disableTests: { - SHOULD_PREFER_GENERATE_ID_IF_PROVIDED: true, - }, - }); - - describe("simple flow", async () => { - const { auth } = await getTestInstance( - { - database: dialect, - user: { - modelName: "users", - }, - }, - { - disableTestUser: true, - }, - ); - it("should sign-up", async () => { - const res = await auth.api.signUpEmail({ - body: { - name: "test", - password: "password", - email: "test-2@email.com", - }, - }); - expect(res.user.name).toBe("test"); - expect(res.token?.length).toBeTruthy(); - }); - - let token = ""; - it("should sign in", async () => { - const signInRes = await auth.api.signInEmail({ - body: { - password: "password", - email: "test-2@email.com", - }, - }); - - expect(signInRes.token?.length).toBeTruthy(); - expect(signInRes.user.name).toBe("test"); - token = signInRes.token; - }); - - it("should return session", async () => { - const session = await auth.api.getSession({ - headers: new Headers({ - Authorization: `Bearer ${token}`, - }), - }); - expect(session).toMatchObject({ - session: { - token, - userId: expect.any(String), - }, - user: { - name: "test", - email: "test-2@email.com", - }, - }); - }); - - it("stores and retrieves timestamps correctly across timezones", async () => { - using _ = recoverProcessTZ(); - - const sampleUser = { - name: "sample", - email: "sampler@test.com", - password: "samplerrrrr", - }; - - process.env.TZ = "Europe/London"; - const userSignUp = await auth.api.signUpEmail({ - body: { - name: sampleUser.name, - email: sampleUser.email, - password: sampleUser.password, - }, - }); - process.env.TZ = "America/Los_Angeles"; - const userSignIn = await auth.api.signInEmail({ - body: { email: sampleUser.email, password: sampleUser.password }, - }); - - expect(userSignUp.user.createdAt).toStrictEqual( - userSignIn.user.createdAt, - ); - }); - }); -}); - -describe("postgres", async () => { - const pool = new Pool({ - connectionString: "postgres://user:password@localhost:5432/better_auth", - }); - - const dialect = new PostgresDialect({ - pool, - }); - - const opts = { - database: dialect, - user: { - modelName: "users", - }, - } satisfies BetterAuthOptions; - - beforeAll(async () => { - const { runMigrations, toBeAdded, toBeCreated } = await getMigrations(opts); - await runMigrations(); - return async () => { - await resetDB(); - await pool.end(); - setState("IDLE"); - }; - }); - - const pg = new Kysely({ dialect }); - - const getAdapter = kyselyAdapter(pg, { - type: "postgres", - debugLogs: { - isRunningAdapterTests: true, - }, - }); - - async function resetDB() { - await sql`DROP TABLE session;`.execute(pg); - await sql`DROP TABLE verification;`.execute(pg); - await sql`DROP TABLE account;`.execute(pg); - await sql`DROP TABLE users;`.execute(pg); - } - - runAdapterTest({ - getAdapter: async (customOptions = {}) => { - // const merged = merge( customOptions,opts); - // merged.database = opts.database; - return getAdapter(opts); - }, - disableTests: { - SHOULD_PREFER_GENERATE_ID_IF_PROVIDED: true, - }, - }); - - describe("simple flow", async () => { - const { auth } = await getTestInstance( - { - database: dialect, - user: { - modelName: "users", - }, - }, - { - disableTestUser: true, - }, - ); - it("should sign-up", async () => { - const res = await auth.api.signUpEmail({ - body: { - name: "test", - password: "password", - email: "test-2@email.com", - }, - }); - expect(res.user.name).toBe("test"); - expect(res.token?.length).toBeTruthy(); - }); - - let token = ""; - it("should sign in", async () => { - const signInRes = await auth.api.signInEmail({ - body: { - password: "password", - email: "test-2@email.com", - }, - }); - - expect(signInRes.token?.length).toBeTruthy(); - expect(signInRes.user.name).toBe("test"); - token = signInRes.token; - }); - - it("should return session", async () => { - const session = await auth.api.getSession({ - headers: new Headers({ - Authorization: `Bearer ${token}`, - }), - }); - expect(session).toMatchObject({ - session: { - token, - userId: expect.any(String), - }, - user: { - name: "test", - email: "test-2@email.com", - }, - }); - }); - - it("stores and retrieves timestamps correctly across timezones", async () => { - using _ = recoverProcessTZ(); - - const sampleUser = { - name: "sample", - email: "sampler@test.com", - password: "samplerrrrr", - }; - - process.env.TZ = "Europe/London"; - const userSignUp = await auth.api.signUpEmail({ - body: { - name: sampleUser.name, - email: sampleUser.email, - password: sampleUser.password, - }, - }); - process.env.TZ = "America/Los_Angeles"; - const userSignIn = await auth.api.signInEmail({ - body: { email: sampleUser.email, password: sampleUser.password }, - }); - - expect(userSignUp.user.createdAt).toStrictEqual( - userSignIn.user.createdAt, - ); - }); - }); -}); - -describe("mysql", async () => { - const pool = createPool("mysql://user:password@localhost:3306/better_auth"); - const dialect = new MysqlDialect(pool); - - const opts = { - database: dialect, - user: { - modelName: "users", - }, - } satisfies BetterAuthOptions; - - beforeAll(async () => { - await pool.query("DROP DATABASE IF EXISTS better_auth"); - await pool.query("CREATE DATABASE better_auth"); - await pool.query("USE better_auth"); - - const { runMigrations } = await getMigrations(opts); - await runMigrations(); - - return async () => { - await resetDB(); - await pool.end(); - setState("IDLE"); - }; - }); - - const mysql = new Kysely({ dialect }); - - const getAdapter = kyselyAdapter(mysql, { - type: "mysql", - debugLogs: { - isRunningAdapterTests: true, - }, - }); - - async function resetDB() { - await sql`DROP TABLE session;`.execute(mysql); - await sql`DROP TABLE verification;`.execute(mysql); - await sql`DROP TABLE account;`.execute(mysql); - await sql`DROP TABLE users;`.execute(mysql); - } - - runAdapterTest({ - getAdapter: async (customOptions = {}) => { - // const merged = merge( customOptions,opts); - // merged.database = opts.database; - return getAdapter(opts); - }, - disableTests: { - SHOULD_PREFER_GENERATE_ID_IF_PROVIDED: true, - }, - }); - - describe("simple flow", async () => { - const { auth } = await getTestInstance( - { - database: pool, - user: { - modelName: "users", - }, - }, - { - disableTestUser: true, - }, - ); - it("should sign-up", async () => { - const res = await auth.api.signUpEmail({ - body: { - name: "test", - password: "password", - email: "test-2@email.com", - }, - }); - expect(res.user.name).toBe("test"); - expect(res.token?.length).toBeTruthy(); - }); - - let token = ""; - it("should sign in", async () => { - const signInRes = await auth.api.signInEmail({ - body: { - password: "password", - email: "test-2@email.com", - }, - }); - - expect(signInRes.token?.length).toBeTruthy(); - expect(signInRes.user.name).toBe("test"); - token = signInRes.token; - }); - - it("should return session", async () => { - const session = await auth.api.getSession({ - headers: new Headers({ - Authorization: `Bearer ${token}`, - }), - }); - expect(session).toMatchObject({ - session: { - token, - userId: expect.any(String), - }, - user: { - name: "test", - email: "test-2@email.com", - }, - }); - }); - - it("stores and retrieves timestamps correctly across timezones", async () => { - using _ = recoverProcessTZ(); - - const sampleUser = { - name: "sample", - email: "sampler@test.com", - password: "samplerrrrr", - }; - - process.env.TZ = "Africa/Addis_Ababa"; - const userSignUp = await auth.api.signUpEmail({ - body: { - name: sampleUser.name, - email: sampleUser.email, - password: sampleUser.password, - }, - }); - - process.env.TZ = "America/Los_Angeles"; - const userSignIn = await auth.api.signInEmail({ - body: { email: sampleUser.email, password: sampleUser.password }, - }); - - expect(userSignUp.user.createdAt).toStrictEqual( - userSignIn.user.createdAt, - ); - }); - }); -}); diff --git a/packages/better-auth/src/adapters/kysely-adapter/test/number-id/adapter.kysely.number-id.test.ts b/packages/better-auth/src/adapters/kysely-adapter/test/number-id/adapter.kysely.number-id.test.ts deleted file mode 100644 index 8b8da383..00000000 --- a/packages/better-auth/src/adapters/kysely-adapter/test/number-id/adapter.kysely.number-id.test.ts +++ /dev/null @@ -1,125 +0,0 @@ -import merge from "deepmerge"; -import fs from "fs"; -import fsPromises from "fs/promises"; -import { afterAll, beforeAll, describe } from "vitest"; -import { runNumberIdAdapterTest } from "../../../test"; -import path from "path"; -import Database from "better-sqlite3"; -import { kyselyAdapter } from "../.."; -import { Kysely, MysqlDialect, SqliteDialect } from "kysely"; -import { createPool } from "mysql2/promise"; -import { getState, stateFilePath } from "../state"; -import { getMigrations } from "../../../../db/get-migration"; -import type { BetterAuthOptions } from "../../../../types"; - -export const opts = ({ - database, - isNumberIdTest, -}: { - database: BetterAuthOptions["database"]; - isNumberIdTest: boolean; -}): BetterAuthOptions => ({ - database: database, - user: { - fields: { - email: "email_address", - }, - additionalFields: { - test: { - type: "string", - defaultValue: "test", - }, - }, - }, - session: { - modelName: "sessions", - }, - advanced: { - database: { - useNumberId: isNumberIdTest, - }, - }, -}); - -const sqlite = new Database(path.join(__dirname, "test.db")); -const mysql = createPool("mysql://user:password@localhost:3306/better_auth"); -const sqliteKy = new Kysely({ - dialect: new SqliteDialect({ - database: sqlite, - }), -}); -const mysqlKy = new Kysely({ - dialect: new MysqlDialect(mysql), -}); - -describe("Number ID Adapter tests", async () => { - const mysqlOptions = opts({ - database: { - db: mysqlKy, - type: "mysql", - }, - isNumberIdTest: true, - }); - const sqliteOptions = opts({ - database: { - db: sqliteKy, - type: "sqlite", - }, - isNumberIdTest: true, - }); - - beforeAll(async () => { - await new Promise(async (resolve) => { - await new Promise((r) => setTimeout(r, 800)); - if (getState() === "IDLE") { - resolve(true); - return; - } - console.log(`Waiting for state to be IDLE...`); - fs.watch(stateFilePath, () => { - if (getState() === "IDLE") { - resolve(true); - return; - } - }); - }); - console.log(`Now running Number ID Kysely adapter test...`); - await (await getMigrations(mysqlOptions)).runMigrations(); - await (await getMigrations(sqliteOptions)).runMigrations(); - }); - - afterAll(async () => { - await mysql.query("DROP DATABASE IF EXISTS better_auth"); - await mysql.query("CREATE DATABASE better_auth"); - await mysql.end(); - await fsPromises.unlink(path.join(__dirname, "test.db")); - }); - - const mysqlAdapter = kyselyAdapter(mysqlKy, { - type: "mysql", - debugLogs: { - isRunningAdapterTests: false, - }, - }); - runNumberIdAdapterTest({ - getAdapter: async (customOptions = {}) => { - const merged = merge(customOptions, mysqlOptions); - return mysqlAdapter(merged); - }, - testPrefix: "mysql", - }); - - const sqliteAdapter = kyselyAdapter(sqliteKy, { - type: "sqlite", - debugLogs: { - isRunningAdapterTests: false, - }, - }); - - runNumberIdAdapterTest({ - getAdapter: async (customOptions = {}) => { - return sqliteAdapter(merge(customOptions, sqliteOptions)); - }, - testPrefix: "sqlite", - }); -}); diff --git a/packages/better-auth/src/adapters/kysely-adapter/test/state.ts b/packages/better-auth/src/adapters/kysely-adapter/test/state.ts deleted file mode 100644 index 7412174b..00000000 --- a/packages/better-auth/src/adapters/kysely-adapter/test/state.ts +++ /dev/null @@ -1,3 +0,0 @@ -import { makeTestState } from "../../../test-utils/state"; - -export const { stateFilePath, getState, setState } = makeTestState(__dirname); diff --git a/packages/better-auth/src/adapters/memory-adapter/adapter.memory.test.ts b/packages/better-auth/src/adapters/memory-adapter/adapter.memory.test.ts index 300dc088..a87dc289 100644 --- a/packages/better-auth/src/adapters/memory-adapter/adapter.memory.test.ts +++ b/packages/better-auth/src/adapters/memory-adapter/adapter.memory.test.ts @@ -1,56 +1,34 @@ -import { describe } from "vitest"; +import { getAuthTables } from "../../db"; +import { testAdapter } from "../test-adapter"; import { memoryAdapter } from "./memory-adapter"; -import { runAdapterTest, runNumberIdAdapterTest } from "../test"; +import { + performanceTestSuite, + normalTestSuite, + transactionsTestSuite, + authFlowTestSuite, + numberIdTestSuite, +} from "../tests"; +let db: Record = {}; -describe("adapter test", async () => { - const db = { - user: [], - session: [], - account: [], - }; - const adapter = memoryAdapter(db, { - debugLogs: { - isRunningAdapterTests: true, - }, - }); - runAdapterTest({ - getAdapter: async (customOptions = {}) => { - return adapter({ - user: { - fields: { - email: "email_address", - }, - }, - ...customOptions, - }); - }, - disableTests: { - SHOULD_ROLLBACK_FAILING_TRANSACTION: true, - SHOULD_RETURN_TRANSACTION_RESULT: true, - }, - }); +const { execute } = await testAdapter({ + adapter: () => { + return memoryAdapter(db); + }, + runMigrations: (options) => { + db = {}; + const allModels = Object.keys(getAuthTables(options)); + for (const model of allModels) { + db[model] = []; + } + }, + tests: [ + normalTestSuite(), + transactionsTestSuite({ disableTests: { ALL: true } }), + authFlowTestSuite(), + numberIdTestSuite(), + performanceTestSuite(), + ], + async onFinish() {}, }); -describe("Number Id Adapter Test", async () => { - const db = { - user: [], - session: [], - account: [], - }; - const adapter = memoryAdapter(db, { - debugLogs: { - isRunningAdapterTests: true, - }, - }); - runNumberIdAdapterTest({ - getAdapter: async (customOptions = {}) => { - return adapter({ - ...customOptions, - }); - }, - disableTests: { - SHOULD_ROLLBACK_FAILING_TRANSACTION: true, - SHOULD_RETURN_TRANSACTION_RESULT: true, - }, - }); -}); +execute(); diff --git a/packages/better-auth/src/adapters/memory-adapter/memory-adapter.ts b/packages/better-auth/src/adapters/memory-adapter/memory-adapter.ts index 79a1c317..dff6a1f1 100644 --- a/packages/better-auth/src/adapters/memory-adapter/memory-adapter.ts +++ b/packages/better-auth/src/adapters/memory-adapter/memory-adapter.ts @@ -35,13 +35,14 @@ export const memoryAdapter = (db: MemoryDB, config?: MemoryAdapterConfig) => { transaction: async (cb) => { let clone = structuredClone(db); try { - return cb(adapterCreator(lazyOptions!)); - } catch { + const r = await cb(adapterCreator(lazyOptions!)); + return r; + } catch (error) { // Rollback changes Object.keys(db).forEach((key) => { db[key] = clone[key]!; }); - throw new Error("Transaction failed, rolling back changes"); + throw error; } }, }, @@ -77,6 +78,16 @@ export const memoryAdapter = (db: MemoryDB, config?: MemoryAdapterConfig) => { return record[field].startsWith(value); case "ends_with": return record[field].endsWith(value); + case "ne": + return record[field] !== value; + case "gt": + return value != null && Boolean(record[field] > value); + case "gte": + return value != null && Boolean(record[field] >= value); + case "lt": + return value != null && Boolean(record[field] < value); + case "lte": + return value != null && Boolean(record[field] <= value); default: return record[field] === value; } @@ -126,11 +137,50 @@ export const memoryAdapter = (db: MemoryDB, config?: MemoryAdapterConfig) => { if (sortBy) { table = table!.sort((a, b) => { const field = getFieldName({ model, field: sortBy.field }); - if (sortBy.direction === "asc") { - return a[field] > b[field] ? 1 : -1; - } else { - return a[field] < b[field] ? 1 : -1; + const aValue = a[field]; + const bValue = b[field]; + + let comparison = 0; + + // Handle null/undefined values + if (aValue == null && bValue == null) { + comparison = 0; + } else if (aValue == null) { + comparison = -1; + } else if (bValue == null) { + comparison = 1; } + // Handle string comparison + else if ( + typeof aValue === "string" && + typeof bValue === "string" + ) { + comparison = aValue.localeCompare(bValue); + } + // Handle date comparison + else if (aValue instanceof Date && bValue instanceof Date) { + comparison = aValue.getTime() - bValue.getTime(); + } + // Handle numeric comparison + else if ( + typeof aValue === "number" && + typeof bValue === "number" + ) { + comparison = aValue - bValue; + } + // Handle boolean comparison + else if ( + typeof aValue === "boolean" && + typeof bValue === "boolean" + ) { + comparison = aValue === bValue ? 0 : aValue ? 1 : -1; + } + // Fallback to string comparison + else { + comparison = String(aValue).localeCompare(String(bValue)); + } + + return sortBy.direction === "asc" ? comparison : -comparison; }); } if (offset !== undefined) { @@ -141,7 +191,11 @@ export const memoryAdapter = (db: MemoryDB, config?: MemoryAdapterConfig) => { } return table || []; }, - count: async ({ model }) => { + count: async ({ model, where }) => { + if (where) { + const filteredRecords = convertWhereClause(where, model); + return filteredRecords.length; + } return db[model]!.length; }, update: async ({ model, where, update }) => { diff --git a/packages/better-auth/src/adapters/mongodb-adapter/adapter.mongo-db.test.ts b/packages/better-auth/src/adapters/mongodb-adapter/adapter.mongo-db.test.ts index b5aa695c..e4c78132 100644 --- a/packages/better-auth/src/adapters/mongodb-adapter/adapter.mongo-db.test.ts +++ b/packages/better-auth/src/adapters/mongodb-adapter/adapter.mongo-db.test.ts @@ -1,105 +1,39 @@ -import { describe, beforeAll, it, expect } from "vitest"; +import { MongoClient, ObjectId } from "mongodb"; +import { testAdapter } from "../test-adapter"; +import { mongodbAdapter } from "./mongodb-adapter"; +import { + normalTestSuite, + performanceTestSuite, + authFlowTestSuite, + transactionsTestSuite, +} from "../tests"; -import { MongoClient } from "mongodb"; -import { runAdapterTest } from "../test"; -import { mongodbAdapter } from "."; -import { getTestInstance } from "../../test-utils/test-instance"; -describe("adapter test", async () => { - const dbClient = async (connectionString: string, dbName: string) => { - const client = new MongoClient(connectionString); - await client.connect(); - const db = client.db(dbName); - return { db, client }; - }; +const dbClient = async (connectionString: string, dbName: string) => { + const client = new MongoClient(connectionString); + await client.connect(); + const db = client.db(dbName); + return { db, client }; +}; - const user = "user"; - const { db, client } = await dbClient( - "mongodb://127.0.0.1:27017", - "better-auth", - ); - async function clearDb() { - await db.collection(user).deleteMany({}); - await db.collection("session").deleteMany({}); - } +const { db, client } = await dbClient( + "mongodb://127.0.0.1:27017", + "better-auth", +); - beforeAll(async () => { - await clearDb(); - }); - - const adapter = mongodbAdapter(db, { - // MongoDB transactions require a replica set or a sharded cluster - // client, - }); - runAdapterTest({ - getAdapter: async (customOptions = {}) => { - return adapter({ - user: { - fields: { - email: "email_address", - }, - additionalFields: { - test: { - type: "string", - defaultValue: "test", - }, - }, - }, - session: { - modelName: "sessions", - }, - ...customOptions, - }); - }, - disableTests: { - SHOULD_PREFER_GENERATE_ID_IF_PROVIDED: true, - SHOULD_RETURN_TRANSACTION_RESULT: true, - SHOULD_ROLLBACK_FAILING_TRANSACTION: true, - }, - }); +const { execute } = await testAdapter({ + adapter: (options) => { + return mongodbAdapter(db, { transaction: false }); + }, + runMigrations: async (betterAuthOptions) => {}, + tests: [ + normalTestSuite(), + authFlowTestSuite(), + transactionsTestSuite(), + // numberIdTestSuite(), // Mongo doesn't support number ids + performanceTestSuite(), + ], + customIdGenerator: () => new ObjectId().toString(), + defaultRetryCount: 20, }); -describe("simple-flow", async () => { - const { auth, client, sessionSetter, db } = await getTestInstance( - {}, - { - disableTestUser: true, - testWith: "mongodb", - }, - ); - const testUser = { - email: "test-eamil@email.com", - password: "password", - name: "Test Name", - }; - - it("should sign up", async () => { - const user = await auth.api.signUpEmail({ - body: testUser, - }); - expect(user).toBeDefined(); - }); - - it("should sign in", async () => { - const user = await auth.api.signInEmail({ - body: testUser, - }); - expect(user).toBeDefined(); - }); - - it("should get session", async () => { - const headers = new Headers(); - await client.signIn.email( - { - email: testUser.email, - password: testUser.password, - }, - { - onSuccess: sessionSetter(headers), - }, - ); - const { data: session } = await client.getSession({ - fetchOptions: { headers }, - }); - expect(session?.user).toBeDefined(); - }); -}); +execute(); diff --git a/packages/better-auth/src/adapters/mongodb-adapter/mongodb-adapter.ts b/packages/better-auth/src/adapters/mongodb-adapter/mongodb-adapter.ts index febc2f42..8b084012 100644 --- a/packages/better-auth/src/adapters/mongodb-adapter/mongodb-adapter.ts +++ b/packages/better-auth/src/adapters/mongodb-adapter/mongodb-adapter.ts @@ -37,20 +37,10 @@ export interface MongoDBAdapterConfig { export const mongodbAdapter = (db: Db, config?: MongoDBAdapterConfig) => { let lazyOptions: BetterAuthOptions | null; - const getCustomIdGenerator = (options: BetterAuthOptions) => { - const generator = - options.advanced?.database?.generateId || options.advanced?.generateId; - if (typeof generator === "function") { - return generator; - } - return undefined; - }; const createCustomAdapter = (db: Db, session?: ClientSession): AdapterFactoryCustomizeAdapterCreator => ({ options, getFieldName, schema, getDefaultModelName }) => { - const customIdGen = getCustomIdGenerator(options); - function serializeID({ field, value, @@ -60,9 +50,6 @@ export const mongodbAdapter = (db: Db, config?: MongoDBAdapterConfig) => { value: any; model: string; }) { - if (customIdGen) { - return value; - } model = getDefaultModelName(model); if ( field === "id" || @@ -203,7 +190,14 @@ export const mongodbAdapter = (db: Db, config?: MongoDBAdapterConfig) => { }, async findOne({ model, where, select }) { const clause = convertWhereClause({ where, model }); - const res = await db.collection(model).findOne(clause, { session }); + const projection = select + ? Object.fromEntries( + select.map((field) => [getFieldName({ field, model }), 1]), + ) + : undefined; + const res = await db + .collection(model) + .findOne(clause, { session, projection }); if (!res) return null; return res as any; }, @@ -220,10 +214,11 @@ export const mongodbAdapter = (db: Db, config?: MongoDBAdapterConfig) => { const res = await cursor.toArray(); return res as any; }, - async count({ model }) { + async count({ model, where }) { + const clause = where ? convertWhereClause({ where, model }) : {}; const res = await db .collection(model) - .countDocuments(undefined, { session }); + .countDocuments(clause, { session }); return res; }, async update({ model, where, update: values }) { @@ -319,17 +314,28 @@ export const mongodbAdapter = (db: Db, config?: MongoDBAdapterConfig) => { model, options, }) { - const customIdGen = getCustomIdGenerator(options); - if (field === "_id" || fieldAttributes.references?.field === "id") { - if (customIdGen) { - return data; - } if (action === "update") { + if (typeof data === "string") { + try { + return new ObjectId(data); + } catch (error) { + return data; + } + } return data; } if (Array.isArray(data)) { - return data.map((v) => new ObjectId()); + return data.map((v) => { + if (typeof v === "string") { + try { + return new ObjectId(v); + } catch (error) { + return v; + } + } + return v; + }); } if (typeof data === "string") { try { @@ -359,6 +365,9 @@ export const mongodbAdapter = (db: Db, config?: MongoDBAdapterConfig) => { } return data; }, + customIdGenerator(props) { + return new ObjectId().toString(); + }, }, adapter: createCustomAdapter(db), }; diff --git a/packages/better-auth/src/adapters/prisma-adapter/prisma-adapter.ts b/packages/better-auth/src/adapters/prisma-adapter/prisma-adapter.ts index cf2fbdbb..3e9f4ae9 100644 --- a/packages/better-auth/src/adapters/prisma-adapter/prisma-adapter.ts +++ b/packages/better-auth/src/adapters/prisma-adapter/prisma-adapter.ts @@ -55,6 +55,7 @@ type PrismaClientInternal = { findFirst: (data: any) => Promise; findMany: (data: any) => Promise; update: (data: any) => Promise; + updateMany: (data: any) => Promise; delete: (data: any) => Promise; [key: string]: any; }; @@ -91,7 +92,7 @@ export const prismaAdapter = (prisma: PrismaClient, config: PrismaConfig) => { } } const convertWhereClause = (model: string, where?: Where[]) => { - if (!where) return {}; + if (!where || !where.length) return {}; if (where.length === 1) { const w = where[0]!; if (!w) { diff --git a/packages/better-auth/src/adapters/prisma-adapter/test/.gitignore b/packages/better-auth/src/adapters/prisma-adapter/test/.gitignore new file mode 100644 index 00000000..c9e9469d --- /dev/null +++ b/packages/better-auth/src/adapters/prisma-adapter/test/.gitignore @@ -0,0 +1,5 @@ +auth.ts +schema-mysql.prisma +schema-sqlite.prisma +schema-postgresql.prisma +.tmp \ No newline at end of file diff --git a/packages/better-auth/src/adapters/prisma-adapter/test/base.prisma b/packages/better-auth/src/adapters/prisma-adapter/test/base.prisma new file mode 100644 index 00000000..a43f7411 --- /dev/null +++ b/packages/better-auth/src/adapters/prisma-adapter/test/base.prisma @@ -0,0 +1,69 @@ + +generator client { + provider = "prisma-client-js" +} + +datasource db { + provider = "sqlite" + url = "file:./dev.db" +} + +model User { + id String @id + name String + email String + emailVerified Boolean @default(false) + image String? + createdAt DateTime @default(now()) + updatedAt DateTime @default(now()) @updatedAt + sessions Session[] + accounts Account[] + + @@unique([email]) + @@map("user") +} + +model Session { + id String @id + expiresAt DateTime + token String + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + ipAddress String? + userAgent String? + userId String + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + + @@unique([token]) + @@map("session") +} + +model Account { + id String @id + accountId String + providerId String + userId String + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + accessToken String? + refreshToken String? + idToken String? + accessTokenExpiresAt DateTime? + refreshTokenExpiresAt DateTime? + scope String? + password String? + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + + @@map("account") +} + +model Verification { + id String @id + identifier String + value String + expiresAt DateTime + createdAt DateTime @default(now()) + updatedAt DateTime @default(now()) @updatedAt + + @@map("verification") +} diff --git a/packages/better-auth/src/adapters/prisma-adapter/test/generate-auth-config.ts b/packages/better-auth/src/adapters/prisma-adapter/test/generate-auth-config.ts new file mode 100644 index 00000000..14f9e332 --- /dev/null +++ b/packages/better-auth/src/adapters/prisma-adapter/test/generate-auth-config.ts @@ -0,0 +1,22 @@ +import fs from "fs/promises"; +import type { BetterAuthOptions } from "../../../types"; +import path from "path"; + +export const generateAuthConfigFile = async (_options: BetterAuthOptions) => { + const options = { ..._options }; + // biome-ignore lint/performance/noDelete: perf doesn't matter here. + delete options.database; + let code = `import { betterAuth } from "../../../auth"; +import { prismaAdapter } from "../prisma-adapter"; +import { PrismaClient } from "@prisma/client"; +const db = new PrismaClient(); + +export const auth = betterAuth({ + database: prismaAdapter(db, { + provider: 'sqlite' + }), + ${JSON.stringify(options, null, 2).slice(1, -1)} +})`; + + await fs.writeFile(path.join(import.meta.dirname, "auth.ts"), code); +}; diff --git a/packages/better-auth/src/adapters/prisma-adapter/test/generate-prisma-schema.ts b/packages/better-auth/src/adapters/prisma-adapter/test/generate-prisma-schema.ts new file mode 100644 index 00000000..f9e38820 --- /dev/null +++ b/packages/better-auth/src/adapters/prisma-adapter/test/generate-prisma-schema.ts @@ -0,0 +1,57 @@ +import type { PrismaClient } from "@prisma/client"; +import type { Adapter, BetterAuthOptions } from "../../../types"; +import { prismaAdapter } from "../prisma-adapter"; +import { join } from "path"; +import fs from "fs/promises"; + +export async function generatePrismaSchema( + betterAuthOptions: BetterAuthOptions, + db: PrismaClient, + iteration: number, + dialect: "sqlite" | "postgresql" | "mysql", +) { + const i = async (x: string) => await import(x); + const { generateSchema } = (await i( + "./../../../../../cli/src/generators/index", + )) as { + generateSchema: (opts: { + adapter: Adapter; + file?: string; + options: BetterAuthOptions; + }) => Promise<{ + code: string | undefined; + fileName: string; + overwrite: boolean | undefined; + }>; + }; + + const prismaDB = prismaAdapter(db, { provider: dialect }); + let { fileName, code } = await generateSchema({ + file: join(import.meta.dirname, `schema-${dialect}.prisma`), + adapter: prismaDB({}), + options: { ...betterAuthOptions, database: prismaDB }, + }); + if (dialect === "postgresql") { + code = code?.replace( + `env("DATABASE_URL")`, + '"postgres://user:password@localhost:5434/better_auth"', + ); + } else if (dialect === "mysql") { + code = code?.replace( + `env("DATABASE_URL")`, + '"mysql://user:password@localhost:3308/better_auth"', + ); + } + code = code + ?.split("\n") + .map((line, index) => { + if (index === 2) { + return ( + line + `\n output = "./.tmp/prisma-client-${dialect}-${iteration}"` + ); + } + return line; + }) + .join("\n"); + await fs.writeFile(fileName, code || "", "utf-8"); +} diff --git a/packages/better-auth/src/adapters/prisma-adapter/test/get-prisma-client.ts b/packages/better-auth/src/adapters/prisma-adapter/test/get-prisma-client.ts new file mode 100644 index 00000000..e3cde1e5 --- /dev/null +++ b/packages/better-auth/src/adapters/prisma-adapter/test/get-prisma-client.ts @@ -0,0 +1,39 @@ +import type { PrismaClient } from "@prisma/client"; +type PC = InstanceType; + +let migrationCount = 0; +const clientMap = new Map(); +export const getPrismaClient = async ( + dialect: "sqlite" | "postgresql" | "mysql", +) => { + if (clientMap.has(`${dialect}-${migrationCount}`)) { + return clientMap.get(`${dialect}-${migrationCount}`) as PC; + } + const { PrismaClient } = await import( + migrationCount === 0 + ? "@prisma/client" + : `./.tmp/prisma-client-${dialect}-${migrationCount}` + ); + const db = new PrismaClient(); + clientMap.set(`${dialect}-${migrationCount}`, db); + return db as PC; +}; + +export const incrementMigrationCount = () => { + migrationCount++; + return migrationCount; +}; + +export const destroyPrismaClient = ({ + migrationCount, + dialect, +}: { + migrationCount: number; + dialect: "sqlite" | "postgresql" | "mysql"; +}) => { + const db = clientMap.get(`${dialect}-${migrationCount}`); + if (db) { + db.$disconnect(); + } + clientMap.delete(`${dialect}-${migrationCount}`); +}; diff --git a/packages/better-auth/src/adapters/prisma-adapter/test/normal-tests/adapter.prisma.test.ts b/packages/better-auth/src/adapters/prisma-adapter/test/normal-tests/adapter.prisma.test.ts deleted file mode 100644 index 4963c4ab..00000000 --- a/packages/better-auth/src/adapters/prisma-adapter/test/normal-tests/adapter.prisma.test.ts +++ /dev/null @@ -1,46 +0,0 @@ -import { beforeAll, describe } from "vitest"; -import { pushPrismaSchema } from "../push-schema"; -import { createTestOptions } from "../test-options"; -import { runAdapterTest } from "../../../test"; -import { setState } from "../state"; - -describe("Adapter tests", async () => { - beforeAll(async () => { - setState("RUNNING"); - pushPrismaSchema("normal"); - console.log("Successfully pushed normal Prisma Schema using pnpm..."); - const { getAdapter } = await import("./get-adapter"); - const { clearDb } = getAdapter(); - await clearDb(); - return () => { - console.log( - `Normal Prisma adapter test finished. Now allowing number ID prisma tests to run.`, - ); - setState("IDLE"); - }; - }); - - runAdapterTest({ - getAdapter: async (customOptions = {}) => { - const { getAdapter } = await import("./get-adapter"); - const { adapter } = getAdapter(); - const { advanced, database, session, user } = createTestOptions(adapter); - return adapter({ - ...customOptions, - user: { - ...user, - ...customOptions.user, - }, - session: { - ...session, - ...customOptions.session, - }, - advanced: { - ...advanced, - ...customOptions.advanced, - }, - database, - }); - }, - }); -}); diff --git a/packages/better-auth/src/adapters/prisma-adapter/test/normal-tests/get-adapter.ts b/packages/better-auth/src/adapters/prisma-adapter/test/normal-tests/get-adapter.ts deleted file mode 100644 index d2a308ee..00000000 --- a/packages/better-auth/src/adapters/prisma-adapter/test/normal-tests/get-adapter.ts +++ /dev/null @@ -1,20 +0,0 @@ -import { PrismaClient } from "@prisma/client"; -import { prismaAdapter } from "../.."; - -export function getAdapter() { - const db = new PrismaClient(); - - async function clearDb() { - await db.sessions.deleteMany(); - await db.user.deleteMany(); - } - - const adapter = prismaAdapter(db, { - provider: "sqlite", - debugLogs: { - isRunningAdapterTests: true, - }, - }); - - return { adapter, clearDb }; -} diff --git a/packages/better-auth/src/adapters/prisma-adapter/test/normal-tests/schema.prisma b/packages/better-auth/src/adapters/prisma-adapter/test/normal-tests/schema.prisma deleted file mode 100644 index 89732b75..00000000 --- a/packages/better-auth/src/adapters/prisma-adapter/test/normal-tests/schema.prisma +++ /dev/null @@ -1,28 +0,0 @@ -generator client { - provider = "prisma-client-js" - previewFeatures = ["strictUndefinedChecks"] -} - -datasource db { - provider = "sqlite" - url = "file:.db/dev.db" -} - -model User { - id String @id @default(cuid()) - email_address String @unique - test String - emailVerified Boolean @default(false) - name String - createdAt DateTime @default(now()) - updatedAt DateTime @default(now()) @updatedAt -} - -model Sessions { - id String @id @default(cuid()) - userId String - token String @unique - expiresAt DateTime - createdAt DateTime @default(now()) - updatedAt DateTime @default(now()) @updatedAt -} diff --git a/packages/better-auth/src/adapters/prisma-adapter/test/number-id-tests/adapter.prisma.number-id.test.ts b/packages/better-auth/src/adapters/prisma-adapter/test/number-id-tests/adapter.prisma.number-id.test.ts deleted file mode 100644 index 3a3b4f93..00000000 --- a/packages/better-auth/src/adapters/prisma-adapter/test/number-id-tests/adapter.prisma.number-id.test.ts +++ /dev/null @@ -1,55 +0,0 @@ -import { beforeAll, describe } from "vitest"; -import { runNumberIdAdapterTest } from "../../../test"; -import { pushPrismaSchema } from "../push-schema"; -import { createTestOptions } from "../test-options"; -import * as fs from "fs"; -import { getState, stateFilePath } from "../state"; - -describe("Number Id Adapter Test", async () => { - beforeAll(async () => { - await new Promise(async (resolve) => { - await new Promise((r) => setTimeout(r, 500)); - if (getState() === "IDLE") { - resolve(true); - return; - } - console.log(`Waiting for state to be IDLE...`); - fs.watch(stateFilePath, () => { - if (getState() === "IDLE") { - resolve(true); - return; - } - }); - }); - console.log(`Now running Number ID Prisma adapter test...`); - pushPrismaSchema("number-id"); - console.log(`Successfully pushed number id Prisma Schema using pnpm...`); - const { getAdapter } = await import("./get-adapter"); - const { clearDb } = getAdapter(); - await clearDb(); - }, Number.POSITIVE_INFINITY); - - runNumberIdAdapterTest({ - getAdapter: async (customOptions = {}) => { - const { getAdapter } = await import("./get-adapter"); - const { adapter } = getAdapter(); - const { advanced, database, session, user } = createTestOptions(adapter); - return adapter({ - ...customOptions, - user: { - ...user, - ...customOptions.user, - }, - session: { - ...session, - ...customOptions.session, - }, - advanced: { - ...advanced, - ...customOptions.advanced, - }, - database, - }); - }, - }); -}); diff --git a/packages/better-auth/src/adapters/prisma-adapter/test/number-id-tests/get-adapter.ts b/packages/better-auth/src/adapters/prisma-adapter/test/number-id-tests/get-adapter.ts deleted file mode 100644 index 7ac18da1..00000000 --- a/packages/better-auth/src/adapters/prisma-adapter/test/number-id-tests/get-adapter.ts +++ /dev/null @@ -1,27 +0,0 @@ -import { PrismaClient } from "@prisma/client"; -import { prismaAdapter } from "../.."; - -export function getAdapter() { - const db = new PrismaClient(); - - async function clearDb() { - await db.sessions.deleteMany(); - await db.user.deleteMany(); - try { - await db.$executeRaw`DELETE FROM sqlite_sequence WHERE name = 'User'`; - } catch {} - try { - // it's `sessions` not `session` because our `createTestOptions` uses `modelName: "sessions"` - await db.$executeRaw`DELETE FROM sqlite_sequence WHERE name = 'Sessions'`; - } catch {} - } - - const adapter = prismaAdapter(db, { - provider: "sqlite", - debugLogs: { - isRunningAdapterTests: true, - }, - }); - - return { adapter, clearDb }; -} diff --git a/packages/better-auth/src/adapters/prisma-adapter/test/number-id-tests/schema.prisma b/packages/better-auth/src/adapters/prisma-adapter/test/number-id-tests/schema.prisma deleted file mode 100644 index b7a013ca..00000000 --- a/packages/better-auth/src/adapters/prisma-adapter/test/number-id-tests/schema.prisma +++ /dev/null @@ -1,28 +0,0 @@ -generator client { - provider = "prisma-client-js" - previewFeatures = ["strictUndefinedChecks"] -} - -datasource db { - provider = "sqlite" - url = "file:.db/dev.db" -} - -model User { - id Int @id @default(autoincrement()) - email_address String @unique - test String - emailVerified Boolean @default(false) - name String - createdAt DateTime @default(now()) - updatedAt DateTime @default(now()) @updatedAt -} - -model Sessions { - id Int @id @default(autoincrement()) - userId Int - token String @unique - expiresAt DateTime - createdAt DateTime @default(now()) - updatedAt DateTime @default(now()) @updatedAt -} diff --git a/packages/better-auth/src/adapters/prisma-adapter/test/prisma.mysql.test.ts b/packages/better-auth/src/adapters/prisma-adapter/test/prisma.mysql.test.ts new file mode 100644 index 00000000..5c66f7f7 --- /dev/null +++ b/packages/better-auth/src/adapters/prisma-adapter/test/prisma.mysql.test.ts @@ -0,0 +1,56 @@ +import { testAdapter } from "../../test-adapter"; +import { + authFlowTestSuite, + normalTestSuite, + numberIdTestSuite, + performanceTestSuite, + transactionsTestSuite, +} from "../../tests"; +import { prismaAdapter } from "../prisma-adapter"; +import { generateAuthConfigFile } from "./generate-auth-config"; +import { generatePrismaSchema } from "./generate-prisma-schema"; +import { pushPrismaSchema } from "./push-prisma-schema"; +import type { BetterAuthOptions } from "../../../types"; +import { + destroyPrismaClient, + getPrismaClient, + incrementMigrationCount, +} from "./get-prisma-client"; +import { createPool } from "mysql2/promise"; + +const dialect = "mysql"; +const { execute } = await testAdapter({ + adapter: async () => { + const db = await getPrismaClient(dialect); + return prismaAdapter(db, { + provider: dialect, + debugLogs: { isRunningAdapterTests: true }, + }); + }, + runMigrations: async (options: BetterAuthOptions) => { + const mysqlDB = createPool({ + uri: "mysql://user:password@localhost:3308/better_auth", + timezone: "Z", + }); + await mysqlDB.query("DROP DATABASE IF EXISTS better_auth"); + await mysqlDB.query("CREATE DATABASE better_auth"); + await mysqlDB.end(); + const db = await getPrismaClient(dialect); + const migrationCount = incrementMigrationCount(); + await generateAuthConfigFile(options); + await generatePrismaSchema(options, db, migrationCount, dialect); + await pushPrismaSchema(dialect); + destroyPrismaClient({ migrationCount: migrationCount - 1, dialect }); + }, + tests: [ + normalTestSuite(), + transactionsTestSuite(), + authFlowTestSuite(), + numberIdTestSuite(), + performanceTestSuite({ dialect }), + ], + onFinish: async () => {}, + prefixTests: dialect, +}); + +execute(); diff --git a/packages/better-auth/src/adapters/prisma-adapter/test/prisma.pg.test.ts b/packages/better-auth/src/adapters/prisma-adapter/test/prisma.pg.test.ts new file mode 100644 index 00000000..3e8170bf --- /dev/null +++ b/packages/better-auth/src/adapters/prisma-adapter/test/prisma.pg.test.ts @@ -0,0 +1,54 @@ +import { testAdapter } from "../../test-adapter"; +import { + authFlowTestSuite, + normalTestSuite, + numberIdTestSuite, + performanceTestSuite, + transactionsTestSuite, +} from "../../tests"; +import { prismaAdapter } from "../prisma-adapter"; +import { generateAuthConfigFile } from "./generate-auth-config"; +import { generatePrismaSchema } from "./generate-prisma-schema"; +import { pushPrismaSchema } from "./push-prisma-schema"; +import type { BetterAuthOptions } from "../../../types"; +import { + destroyPrismaClient, + getPrismaClient, + incrementMigrationCount, +} from "./get-prisma-client"; +import { Pool } from "pg"; + +const dialect = "postgresql"; +const { execute } = await testAdapter({ + adapter: async () => { + const db = await getPrismaClient(dialect); + return prismaAdapter(db, { + provider: dialect, + debugLogs: { isRunningAdapterTests: true }, + }); + }, + runMigrations: async (options: BetterAuthOptions) => { + const db = await getPrismaClient(dialect); + const pgDB = new Pool({ + connectionString: "postgres://user:password@localhost:5434/better_auth", + }); + await pgDB.query(`DROP SCHEMA public CASCADE; CREATE SCHEMA public;`); + await pgDB.end(); + const migrationCount = incrementMigrationCount(); + await generateAuthConfigFile(options); + await generatePrismaSchema(options, db, migrationCount, dialect); + await pushPrismaSchema(dialect); + destroyPrismaClient({ migrationCount: migrationCount - 1, dialect }); + }, + tests: [ + normalTestSuite(), + transactionsTestSuite(), + authFlowTestSuite(), + numberIdTestSuite(), + performanceTestSuite({ dialect }), + ], + onFinish: async () => {}, + prefixTests: "pg", +}); + +execute(); diff --git a/packages/better-auth/src/adapters/prisma-adapter/test/prisma.sqlite.test.ts b/packages/better-auth/src/adapters/prisma-adapter/test/prisma.sqlite.test.ts new file mode 100644 index 00000000..c1b3b8ef --- /dev/null +++ b/packages/better-auth/src/adapters/prisma-adapter/test/prisma.sqlite.test.ts @@ -0,0 +1,57 @@ +import { testAdapter } from "../../test-adapter"; +import { + authFlowTestSuite, + normalTestSuite, + numberIdTestSuite, + performanceTestSuite, + transactionsTestSuite, +} from "../../tests"; +import { prismaAdapter } from "../prisma-adapter"; +import { generateAuthConfigFile } from "./generate-auth-config"; +import { generatePrismaSchema } from "./generate-prisma-schema"; +import { pushPrismaSchema } from "./push-prisma-schema"; +import type { BetterAuthOptions } from "../../../types"; +import { join } from "path"; +import fs from "node:fs/promises"; +import { + destroyPrismaClient, + getPrismaClient, + incrementMigrationCount, +} from "./get-prisma-client"; + +const dialect = "sqlite"; +const { execute } = await testAdapter({ + adapter: async () => { + const db = await getPrismaClient(dialect); + return prismaAdapter(db, { + provider: dialect, + debugLogs: { isRunningAdapterTests: true }, + }); + }, + runMigrations: async (options: BetterAuthOptions) => { + const dbPath = join(import.meta.dirname, "dev.db"); + try { + await fs.unlink(dbPath); + } catch { + console.log("db path not found"); + } + const db = await getPrismaClient(dialect); + const migrationCount = incrementMigrationCount(); + await generateAuthConfigFile(options); + await generatePrismaSchema(options, db, migrationCount, dialect); + await pushPrismaSchema(dialect); + await db.$disconnect(); + destroyPrismaClient({ migrationCount: migrationCount - 1, dialect }); + }, + tests: [ + normalTestSuite(), + transactionsTestSuite(), + authFlowTestSuite(), + numberIdTestSuite(), + performanceTestSuite({ dialect }), + ], + onFinish: async () => {}, + prefixTests: dialect, +}); + +execute(); diff --git a/packages/better-auth/src/adapters/prisma-adapter/test/push-prisma-schema.ts b/packages/better-auth/src/adapters/prisma-adapter/test/push-prisma-schema.ts new file mode 100644 index 00000000..07f1a41f --- /dev/null +++ b/packages/better-auth/src/adapters/prisma-adapter/test/push-prisma-schema.ts @@ -0,0 +1,14 @@ +import { execSync } from "node:child_process"; +import { createRequire } from "node:module"; +import { join } from "node:path"; + +export async function pushPrismaSchema( + dialect: "sqlite" | "postgresql" | "mysql", +) { + const node = process.execPath; + const cli = createRequire(import.meta.url).resolve("prisma"); + execSync(`${node} ${cli} db push --schema ./schema-${dialect}.prisma`, { + stdio: "ignore", // use `inherit` if you want to see the output + cwd: join(import.meta.dirname), + }); +} diff --git a/packages/better-auth/src/adapters/prisma-adapter/test/push-schema.ts b/packages/better-auth/src/adapters/prisma-adapter/test/push-schema.ts deleted file mode 100644 index 9a67e4d7..00000000 --- a/packages/better-auth/src/adapters/prisma-adapter/test/push-schema.ts +++ /dev/null @@ -1,19 +0,0 @@ -import { execSync } from "child_process"; -import { join } from "node:path"; -import { createRequire } from "node:module"; - -export function pushPrismaSchema(schema: "normal" | "number-id") { - const node = process.execPath; - const cli = createRequire(import.meta.url).resolve("prisma"); - if (schema === "normal") { - execSync(`${node} ${cli} db push --schema ./schema.prisma`, { - stdio: "inherit", - cwd: join(import.meta.dirname, "normal-tests"), - }); - } else { - execSync(`${node} ${cli} db push --schema ./schema.prisma`, { - stdio: "inherit", - cwd: join(import.meta.dirname, "number-id-tests"), - }); - } -} diff --git a/packages/better-auth/src/adapters/prisma-adapter/test/schema-mysql.prisma b/packages/better-auth/src/adapters/prisma-adapter/test/schema-mysql.prisma new file mode 100644 index 00000000..1dff4c5f --- /dev/null +++ b/packages/better-auth/src/adapters/prisma-adapter/test/schema-mysql.prisma @@ -0,0 +1,70 @@ + +generator client { + provider = "prisma-client-js" + output = "./.tmp/prisma-client-mysql-6" +} + +datasource db { + provider = "mysql" + url = "mysql://user:password@localhost:3308/better_auth" +} + +model User { + id Int @id @default(autoincrement()) + name String @db.Text + email String + emailVerified Boolean @default(false) + image String? @db.Text + createdAt DateTime @default(now()) + updatedAt DateTime @default(now()) @updatedAt + sessions Session[] + accounts Account[] + + @@unique([email]) + @@map("user") +} + +model Session { + id Int @id @default(autoincrement()) + expiresAt DateTime + token String + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + ipAddress String? @db.Text + userAgent String? @db.Text + userId Int + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + + @@unique([token]) + @@map("session") +} + +model Account { + id Int @id @default(autoincrement()) + accountId String @db.Text + providerId String @db.Text + userId Int + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + accessToken String? @db.Text + refreshToken String? @db.Text + idToken String? @db.Text + accessTokenExpiresAt DateTime? + refreshTokenExpiresAt DateTime? + scope String? @db.Text + password String? @db.Text + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + + @@map("account") +} + +model Verification { + id Int @id @default(autoincrement()) + identifier String @db.Text + value String @db.Text + expiresAt DateTime + createdAt DateTime @default(now()) + updatedAt DateTime @default(now()) @updatedAt + + @@map("verification") +} diff --git a/packages/better-auth/src/adapters/prisma-adapter/test/schema-postgresql.prisma b/packages/better-auth/src/adapters/prisma-adapter/test/schema-postgresql.prisma new file mode 100644 index 00000000..889b2dac --- /dev/null +++ b/packages/better-auth/src/adapters/prisma-adapter/test/schema-postgresql.prisma @@ -0,0 +1,70 @@ + +generator client { + provider = "prisma-client-js" + output = "./.tmp/prisma-client-postgresql-6" +} + +datasource db { + provider = "postgresql" + url = "postgres://user:password@localhost:5434/better_auth" +} + +model User { + id Int @id @default(autoincrement()) + name String + email String + emailVerified Boolean @default(false) + image String? + createdAt DateTime @default(now()) + updatedAt DateTime @default(now()) @updatedAt + sessions Session[] + accounts Account[] + + @@unique([email]) + @@map("user") +} + +model Session { + id Int @id @default(autoincrement()) + expiresAt DateTime + token String + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + ipAddress String? + userAgent String? + userId Int + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + + @@unique([token]) + @@map("session") +} + +model Account { + id Int @id @default(autoincrement()) + accountId String + providerId String + userId Int + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + accessToken String? + refreshToken String? + idToken String? + accessTokenExpiresAt DateTime? + refreshTokenExpiresAt DateTime? + scope String? + password String? + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + + @@map("account") +} + +model Verification { + id Int @id @default(autoincrement()) + identifier String + value String + expiresAt DateTime + createdAt DateTime @default(now()) + updatedAt DateTime @default(now()) @updatedAt + + @@map("verification") +} diff --git a/packages/better-auth/src/adapters/prisma-adapter/test/schema-sqlite.prisma b/packages/better-auth/src/adapters/prisma-adapter/test/schema-sqlite.prisma new file mode 100644 index 00000000..07182188 --- /dev/null +++ b/packages/better-auth/src/adapters/prisma-adapter/test/schema-sqlite.prisma @@ -0,0 +1,70 @@ + +generator client { + provider = "prisma-client-js" + output = "./.tmp/prisma-client-sqlite-6" +} + +datasource db { + provider = "sqlite" + url = "file:./dev.db" +} + +model User { + id Int @id @default(autoincrement()) + name String + email String + emailVerified Boolean @default(false) + image String? + createdAt DateTime @default(now()) + updatedAt DateTime @default(now()) @updatedAt + sessions Session[] + accounts Account[] + + @@unique([email]) + @@map("user") +} + +model Session { + id Int @id @default(autoincrement()) + expiresAt DateTime + token String + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + ipAddress String? + userAgent String? + userId Int + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + + @@unique([token]) + @@map("session") +} + +model Account { + id Int @id @default(autoincrement()) + accountId String + providerId String + userId Int + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + accessToken String? + refreshToken String? + idToken String? + accessTokenExpiresAt DateTime? + refreshTokenExpiresAt DateTime? + scope String? + password String? + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + + @@map("account") +} + +model Verification { + id Int @id @default(autoincrement()) + identifier String + value String + expiresAt DateTime + createdAt DateTime @default(now()) + updatedAt DateTime @default(now()) @updatedAt + + @@map("verification") +} diff --git a/packages/better-auth/src/adapters/prisma-adapter/test/state.ts b/packages/better-auth/src/adapters/prisma-adapter/test/state.ts deleted file mode 100644 index 7412174b..00000000 --- a/packages/better-auth/src/adapters/prisma-adapter/test/state.ts +++ /dev/null @@ -1,3 +0,0 @@ -import { makeTestState } from "../../../test-utils/state"; - -export const { stateFilePath, getState, setState } = makeTestState(__dirname); diff --git a/packages/better-auth/src/adapters/prisma-adapter/test/test-options.ts b/packages/better-auth/src/adapters/prisma-adapter/test/test-options.ts deleted file mode 100644 index d732666e..00000000 --- a/packages/better-auth/src/adapters/prisma-adapter/test/test-options.ts +++ /dev/null @@ -1,27 +0,0 @@ -import type { Adapter, BetterAuthAdvancedOptions } from "../../../types"; -import type { BetterAuthOptions } from "../../../types"; - -export const createTestOptions = ( - adapter: (options: BetterAuthOptions) => Adapter, - databaseAdvancedOptions: Required["database"] = { - useNumberId: false, - }, -) => - ({ - database: adapter, - user: { - fields: { email: "email_address" }, - additionalFields: { - test: { - type: "string", - defaultValue: "test", - }, - }, - }, - session: { - modelName: "sessions", - }, - advanced: { - database: databaseAdvancedOptions, - }, - }) satisfies BetterAuthOptions; diff --git a/packages/better-auth/src/adapters/test-adapter.ts b/packages/better-auth/src/adapters/test-adapter.ts new file mode 100644 index 00000000..a047ac94 --- /dev/null +++ b/packages/better-auth/src/adapters/test-adapter.ts @@ -0,0 +1,235 @@ +import { afterAll, beforeAll, describe } from "vitest"; +import type { Adapter, BetterAuthOptions } from "../types"; +import { getAuthTables } from "../db"; +import type { createTestSuite } from "./create-test-suite"; +import { colors } from "../utils/colors"; +import { deepmerge } from "./utils"; + +export type Logger = { + info: (...args: any[]) => void; + success: (...args: any[]) => void; + warn: (...args: any[]) => void; + error: (...args: any[]) => void; + debug: (...args: any[]) => void; +}; + +export const testAdapter = async ({ + adapter: getAdapter, + runMigrations, + overrideBetterAuthOptions, + additionalCleanups, + tests, + prefixTests, + onFinish, + customIdGenerator, + defaultRetryCount, +}: { + /** + * A function that will return the adapter instance to test with. + * + * @example + * ```ts + * testAdapter({ + * adapter: (options) => drizzleAdapter(drizzle(db), { + * schema: generateSchema(options), + * }), + * }) + */ + adapter: ( + options: BetterAuthOptions, + ) => + | Promise<(options: BetterAuthOptions) => Adapter> + | ((options: BetterAuthOptions) => Adapter); + /** + * A function that will run the database migrations. + */ + runMigrations: (betterAuthOptions: BetterAuthOptions) => Promise | void; + /** + * Any potential better-auth options overrides. + */ + overrideBetterAuthOptions?: < + Passed extends BetterAuthOptions, + Returned extends BetterAuthOptions, + >( + betterAuthOptions: Passed, + ) => Returned; + /** + * By default we will cleanup all tables automatically, + * but if you have additional cleanup logic, you can pass it here. + * + * Such as deleting a DB file that could had been created. + */ + additionalCleanups?: () => Promise | void; + /** + * A test suite to run. + */ + tests: ReturnType>[]; + /** + * A prefix to add to the test suite name. + */ + prefixTests?: string; + /** + * Upon finish of the tests, this function will be called. + */ + onFinish?: () => Promise | void; + /** + * Custom ID generator function to be used by the helper functions. (such as `insertRandom`) + */ + customIdGenerator?: () => string | Promise; + /** + * Default retry count for the tests. + */ + defaultRetryCount?: number; +}) => { + const defaultBAOptions = {} satisfies BetterAuthOptions; + let betterAuthOptions = (() => { + return { + ...defaultBAOptions, + ...(overrideBetterAuthOptions?.(defaultBAOptions) || {}), + } satisfies BetterAuthOptions; + })(); + + let adapter: Adapter = (await getAdapter(betterAuthOptions))( + betterAuthOptions, + ); + + const adapterName = adapter.options?.adapterConfig.adapterName; + const adapterId = adapter.options?.adapterConfig.adapterId || adapter.id; + const adapterDisplayName = adapterName || adapterId; + + const refreshAdapter = async (betterAuthOptions: BetterAuthOptions) => { + adapter = (await getAdapter(betterAuthOptions))(betterAuthOptions); + }; + + /** + * A helper function to log to the console. + */ + const log: Logger = (() => { + return { + info: (...args: any[]) => + console.log( + `${colors.fg.blue}INFO ${colors.reset} [${adapterDisplayName}]`, + ...args, + ), + success: (...args: any[]) => + console.log( + `${colors.fg.green}SUCCESS${colors.reset} [${adapterDisplayName}]`, + ...args, + ), + warn: (...args: any[]) => + console.log( + `${colors.fg.yellow}WARN ${colors.reset} [${adapterDisplayName}]`, + ...args, + ), + error: (...args: any[]) => + console.log( + `${colors.fg.red}ERROR ${colors.reset} [${adapterDisplayName}]`, + ...args, + ), + debug: (...args: any[]) => + console.log( + `${colors.fg.magenta}DEBUG ${colors.reset} [${adapterDisplayName}]`, + ...args, + ), + }; + })(); + + /** + * Cleanup function to remove all rows from the database. + */ + const cleanup = async () => { + const start = performance.now(); + await refreshAdapter(betterAuthOptions); + const getAllModels = getAuthTables(betterAuthOptions); + + // Clean up all rows from all models + for (const model of Object.keys(getAllModels)) { + try { + await adapter.deleteMany({ model: model, where: [] }); + } catch (error) { + const msg = `Error while cleaning up all rows from ${model}`; + log.error(msg, error); + throw new Error(msg, { + cause: error, + }); + } + } + + // Run additional cleanups + try { + await additionalCleanups?.(); + } catch (error) { + const msg = `Error while running additional cleanups`; + log.error(msg, error); + throw new Error(msg, { + cause: error, + }); + } + await refreshAdapter(betterAuthOptions); + log.success( + `${colors.bright}CLEAN-UP${colors.reset} completed successfully (${(performance.now() - start).toFixed(3)}ms)`, + ); + }; + + /** + * A function that will run the database migrations. + */ + const migrate = async () => { + const start = performance.now(); + + try { + await runMigrations(betterAuthOptions); + } catch (error) { + const msg = `Error while running migrations`; + log.error(msg, error); + throw new Error(msg, { + cause: error, + }); + } + log.success( + `${colors.bright}MIGRATIONS${colors.reset} completed successfully (${(performance.now() - start).toFixed(3)}ms)`, + ); + }; + + return { + execute: () => { + describe(adapterDisplayName, async () => { + beforeAll(async () => { + await migrate(); + }, 20000); + + afterAll(async () => { + await cleanup(); + await onFinish?.(); + }, 20000); + + for (const testSuite of tests) { + await testSuite({ + adapter: async () => { + await refreshAdapter(betterAuthOptions); + return adapter; + }, + adapterDisplayName, + log, + getBetterAuthOptions: () => betterAuthOptions, + modifyBetterAuthOptions: async (options) => { + const newOptions = deepmerge(defaultBAOptions, options); + betterAuthOptions = deepmerge( + newOptions, + overrideBetterAuthOptions?.(newOptions) || {}, + ); + await refreshAdapter(betterAuthOptions); + return betterAuthOptions; + }, + cleanup, + prefixTests, + runMigrations: migrate, + onTestFinish: async () => {}, + customIdGenerator, + defaultRetryCount: defaultRetryCount, + }); + } + }); + }, + }; +}; diff --git a/packages/better-auth/src/adapters/test.ts b/packages/better-auth/src/adapters/test.ts index a8c48373..92a60db2 100644 --- a/packages/better-auth/src/adapters/test.ts +++ b/packages/better-auth/src/adapters/test.ts @@ -5,7 +5,7 @@ import { generateId } from "../utils"; interface AdapterTestOptions { getAdapter: ( customOptions?: Omit, - ) => Promise; + ) => Promise | Adapter; disableTests?: Partial>; testPrefix?: string; } diff --git a/packages/better-auth/src/adapters/tests/auth-flow.ts b/packages/better-auth/src/adapters/tests/auth-flow.ts new file mode 100644 index 00000000..6b2070ef --- /dev/null +++ b/packages/better-auth/src/adapters/tests/auth-flow.ts @@ -0,0 +1,173 @@ +import { expect } from "vitest"; +import { createTestSuite } from "../create-test-suite"; + +/** + * This test suite tests basic authentication flow using the adapter. + */ +export const authFlowTestSuite = createTestSuite( + "auth-flow", + {}, + ( + { generate, getAuth, modifyBetterAuthOptions, tryCatch }, + debug?: { showDB?: () => Promise }, + ) => ({ + "should successfully sign up": async () => { + await modifyBetterAuthOptions( + { + emailAndPassword: { + enabled: true, + password: { hash: async (password) => password }, + }, + }, + false, + ); + const auth = await getAuth(); + const user = await generate("user"); + const start = Date.now(); + const result = await auth.api.signUpEmail({ + body: { + email: user.email, + password: crypto.randomUUID(), + name: user.name, + image: user.image || "", + }, + }); + const end = Date.now(); + console.log(`signUpEmail took ${end - start}ms (without hashing)`); + expect(result.user).toBeDefined(); + expect(result.user.email).toBe(user.email); + expect(result.user.name).toBe(user.name); + expect(result.user.image).toBe(user.image || ""); + expect(result.user.emailVerified).toBe(false); + expect(result.user.createdAt).toBeDefined(); + expect(result.user.updatedAt).toBeDefined(); + }, + "should successfully sign in": async () => { + await modifyBetterAuthOptions( + { + emailAndPassword: { + enabled: true, + password: { + hash: async (password) => password, + async verify(data) { + return data.hash === data.password; + }, + }, + }, + }, + false, + ); + const auth = await getAuth(); + const user = await generate("user"); + const password = crypto.randomUUID(); + const signUpResult = await auth.api.signUpEmail({ + body: { + email: user.email, + password: password, + name: user.name, + image: user.image || "", + }, + }); + const start = Date.now(); + const result = await auth.api.signInEmail({ + body: { email: user.email, password: password }, + }); + const end = Date.now(); + console.log(`signInEmail took ${end - start}ms (without hashing)`); + expect(result.user).toBeDefined(); + expect(result.user.id).toBe(signUpResult.user.id); + }, + "should successfully get session": async () => { + await modifyBetterAuthOptions( + { + emailAndPassword: { + enabled: true, + password: { hash: async (password) => password }, + }, + }, + false, + ); + const auth = await getAuth(); + const user = await generate("user"); + const password = crypto.randomUUID(); + + const { headers, response: signUpResult } = await auth.api.signUpEmail({ + body: { + email: user.email, + password: password, + name: user.name, + image: user.image || "", + }, + returnHeaders: true, + }); + + // Convert set-cookie header to cookie header for getSession call + const modifiedHeaders = new Headers(headers); + if (headers.has("set-cookie")) { + modifiedHeaders.set("cookie", headers.getSetCookie().join("; ")); + modifiedHeaders.delete("set-cookie"); + } + + const start = Date.now(); + const result = await auth.api.getSession({ + headers: modifiedHeaders, + }); + const end = Date.now(); + console.log(`getSession took ${end - start}ms`); + expect(result?.user).toBeDefined(); + expect(result?.user).toStrictEqual(signUpResult.user); + expect(result?.session).toBeDefined(); + }, + "should not sign in with invalid email": async () => { + await modifyBetterAuthOptions( + { emailAndPassword: { enabled: true } }, + false, + ); + const auth = await getAuth(); + const user = await generate("user"); + const { data, error } = await tryCatch( + auth.api.signInEmail({ + body: { email: user.email, password: crypto.randomUUID() }, + }), + ); + expect(data).toBeNull(); + expect(error).toBeDefined(); + }, + "should store and retrieve timestamps correctly across timezones": + async () => { + using _ = recoverProcessTZ(); + await modifyBetterAuthOptions( + { emailAndPassword: { enabled: true } }, + false, + ); + const auth = await getAuth(); + const user = await generate("user"); + const password = crypto.randomUUID(); + const userSignUp = await auth.api.signUpEmail({ + body: { + email: user.email, + password: password, + name: user.name, + image: user.image || "", + }, + }); + process.env.TZ = "Europe/London"; + const userSignIn = await auth.api.signInEmail({ + body: { email: user.email, password: password }, + }); + process.env.TZ = "America/Los_Angeles"; + expect(userSignUp.user.createdAt.toISOString()).toStrictEqual( + userSignIn.user.createdAt.toISOString(), + ); + }, + }), +); + +function recoverProcessTZ() { + const originalTZ = process.env.TZ; + return { + [Symbol.dispose]: () => { + process.env.TZ = originalTZ; + }, + }; +} diff --git a/packages/better-auth/src/adapters/tests/index.ts b/packages/better-auth/src/adapters/tests/index.ts new file mode 100644 index 00000000..63f846a7 --- /dev/null +++ b/packages/better-auth/src/adapters/tests/index.ts @@ -0,0 +1,5 @@ +export * from "./normal"; +export * from "./performance"; +export * from "./transactions"; +export * from "./auth-flow"; +export * from "./number-id"; diff --git a/packages/better-auth/src/adapters/tests/normal.ts b/packages/better-auth/src/adapters/tests/normal.ts new file mode 100644 index 00000000..9964ecb2 --- /dev/null +++ b/packages/better-auth/src/adapters/tests/normal.ts @@ -0,0 +1,561 @@ +import { expect } from "vitest"; +import { createTestSuite } from "../create-test-suite"; +import type { User } from "../../types"; + +/** + * This test suite tests the basic CRUD operations of the adapter. + */ +export const normalTestSuite = createTestSuite("normal", {}, (helpers) => { + const tests = getNormalTestSuiteTests(helpers); + return { + "init - tests": async () => { + const opts = helpers.getBetterAuthOptions(); + expect(opts.advanced?.database?.useNumberId).toBe(undefined); + }, + ...tests, + }; +}); + +export const getNormalTestSuiteTests = ({ + adapter, + generate, + insertRandom, + modifyBetterAuthOptions, + sortModels, + customIdGenerator, + getBetterAuthOptions, +}: Parameters[2]>[0]) => { + /** + * Some databases (such as SQLite) sort rows orders using raw byte values + * Meaning that capitalization, numbers and others goes before the rest of the alphabet + * Because of the inconsistency, as a bare minimum for testing sorting functionality, we should + * remove all capitalizations and numbers from the `name` field + */ + const createBinarySortFriendlyUsers = async (count: number) => { + let users: User[] = []; + for (let i = 0; i < count; i++) { + const user = await generate("user"); + const userResult = await adapter.create({ + model: "user", + data: { + ...user, + name: user.name.replace(/[0-9]/g, "").toLowerCase(), + }, + forceAllowId: true, + }); + users.push(userResult); + } + return users; + }; + + return { + "create - should create a model": async () => { + const user = await generate("user"); + const result = await adapter.create({ + model: "user", + data: user, + forceAllowId: true, + }); + const options = getBetterAuthOptions(); + if (options.advanced?.database?.useNumberId) { + expect(typeof result.id).toEqual("string"); + user.id = result.id; + } else { + expect(typeof result.id).toEqual("string"); + } + expect(result).toEqual(user); + }, + "create - should always return an id": async () => { + const { id: _, ...user } = await generate("user"); + const res = await adapter.create({ + model: "user", + data: user, + }); + expect(res).toHaveProperty("id"); + expect(typeof res.id).toEqual("string"); + }, + "create - should use generateId if provided": async () => { + const ID = (await customIdGenerator?.()) || "MOCK-ID"; + await modifyBetterAuthOptions( + { + advanced: { + database: { + generateId: () => ID, + }, + }, + }, + false, + ); + const { id: _, ...user } = await generate("user"); + const res = await adapter.create({ + model: "user", + data: user, + }); + expect(res.id).toEqual(ID); + const findResult = await adapter.findOne({ + model: "user", + where: [{ field: "id", value: res.id }], + }); + expect(findResult).toEqual(res); + }, + "findOne - should find a model": async () => { + const [user] = await insertRandom("user"); + const result = await adapter.findOne({ + model: "user", + where: [{ field: "id", value: user.id }], + }); + expect(result).toEqual(user); + }, + "findOne - should find a model using a reference field": async () => { + const [user, session] = await insertRandom("session"); + const result = await adapter.findOne({ + model: "session", + where: [{ field: "userId", value: user.id }], + }); + expect(result).toEqual(session); + }, + "findOne - should not throw on record not found": async () => { + const result = await adapter.findOne({ + model: "user", + where: [{ field: "id", value: "100000" }], + }); + expect(result).toBeNull(); + }, + "findOne - should find a model without id": async () => { + const [user] = await insertRandom("user"); + const result = await adapter.findOne({ + model: "user", + where: [{ field: "email", value: user.email }], + }); + expect(result).toEqual(user); + }, + "findOne - should find a model with modified field name": async () => { + await modifyBetterAuthOptions( + { + user: { + fields: { + email: "email_address", + }, + }, + }, + true, + ); + const [user] = await insertRandom("user"); + const result = await adapter.findOne({ + model: "user", + where: [{ field: "email", value: user.email }], + }); + expect(result).toEqual(user); + expect(result?.email).toEqual(user.email); + expect(true).toEqual(true); + }, + "findOne - should select fields": async () => { + const [user] = await insertRandom("user"); + const result = await adapter.findOne>({ + model: "user", + where: [{ field: "id", value: user.id }], + select: ["email", "name"], + }); + expect(result).toEqual({ email: user.email, name: user.name }); + }, + "findMany - should find many models": async () => { + const users = (await insertRandom("user", 3)).map((x) => x[0]); + const result = await adapter.findMany({ + model: "user", + }); + expect(sortModels(result)).toEqual(sortModels(users)); + }, + "findMany - should return an empty array when no models are found": + async () => { + const result = await adapter.findMany({ + model: "user", + where: [{ field: "id", value: "100000" }], + }); + expect(result).toEqual([]); + }, + "findMany - should find many models with starts_with operator": + async () => { + const users = (await insertRandom("user", 3)).map((x) => x[0]); + const result = await adapter.findMany({ + model: "user", + where: [{ field: "name", value: "user", operator: "starts_with" }], + }); + expect(sortModels(result)).toEqual(sortModels(users)); + }, + "findMany - should find many models with ends_with operator": async () => { + const users = (await insertRandom("user", 3)).map((x) => x[0]); + const result = await adapter.findMany({ + model: "user", + where: [ + { + field: "name", + value: users[0]!.name.slice(-1), + operator: "ends_with", + }, + ], + }); + const expectedResult = sortModels( + users.filter((user) => user.name.endsWith(users[0]!.name.slice(-1))), + ); + expect(sortModels(result)).toEqual(sortModels(expectedResult)); + }, + "findMany - should find many models with contains operator": async () => { + const users = (await insertRandom("user", 3)).map((x) => x[0]); + const result = await adapter.findMany({ + model: "user", + where: [{ field: "email", value: "@", operator: "contains" }], + }); + expect(sortModels(result)).toEqual(sortModels(users)); + }, + "findMany - should find many models with eq operator": async () => { + const users = (await insertRandom("user", 3)).map((x) => x[0]); + const result = await adapter.findMany({ + model: "user", + where: [{ field: "email", value: users[0]!.email, operator: "eq" }], + }); + expect(sortModels(result)).toEqual(sortModels([users[0]!])); + }, + "findMany - should find many models with ne operator": async () => { + const users = (await insertRandom("user", 3)).map((x) => x[0]); + const result = await adapter.findMany({ + model: "user", + where: [{ field: "email", value: users[0]!.email, operator: "ne" }], + }); + expect(sortModels(result)).toEqual(sortModels(users.slice(1))); + }, + "findMany - should find many models with gt operator": async () => { + const users = (await insertRandom("user", 3)).map((x) => x[0]); + const oldestUser = users.sort( + (a, b) => a.createdAt.getTime() - b.createdAt.getTime(), + )[0]!; + const result = await adapter.findMany({ + model: "user", + where: [ + { + field: "createdAt", + value: oldestUser.createdAt, + operator: "gt", + }, + ], + }); + const expectedResult = sortModels( + users.filter((user) => user.createdAt > oldestUser.createdAt), + ); + expect(result.length).not.toBe(0); + expect(sortModels(result)).toEqual(expectedResult); + }, + "findMany - should find many models with gte operator": async () => { + const users = (await insertRandom("user", 3)).map((x) => x[0]); + const oldestUser = users.sort( + (a, b) => b.createdAt.getTime() - a.createdAt.getTime(), + )[0]!; + const result = await adapter.findMany({ + model: "user", + where: [ + { + field: "createdAt", + value: oldestUser.createdAt, + operator: "gte", + }, + ], + }); + const expectedResult = users.filter( + (user) => user.createdAt >= oldestUser.createdAt, + ); + expect(result.length).not.toBe(0); + expect(sortModels(result)).toEqual(sortModels(expectedResult)); + }, + "findMany - should find many models with lte operator": async () => { + const users = (await insertRandom("user", 3)).map((x) => x[0]); + const result = await adapter.findMany({ + model: "user", + where: [ + { field: "createdAt", value: users[0]!.createdAt, operator: "lte" }, + ], + }); + const expectedResult = users.filter( + (user) => user.createdAt <= users[0]!.createdAt, + ); + expect(sortModels(result)).toEqual(sortModels(expectedResult)); + }, + "findMany - should find many models with lt operator": async () => { + const users = (await insertRandom("user", 3)).map((x) => x[0]); + const result = await adapter.findMany({ + model: "user", + where: [ + { field: "createdAt", value: users[0]!.createdAt, operator: "lt" }, + ], + }); + const expectedResult = users.filter( + (user) => user.createdAt < users[0]!.createdAt, + ); + expect(sortModels(result)).toEqual(sortModels(expectedResult)); + }, + "findMany - should find many models with in operator": async () => { + const users = (await insertRandom("user", 3)).map((x) => x[0]); + const result = await adapter.findMany({ + model: "user", + where: [ + { + field: "id", + value: [users[0]!.id, users[1]!.id], + operator: "in", + }, + ], + }); + const expectedResult = users.filter( + (user) => user.id === users[0]!.id || user.id === users[1]!.id, + ); + expect(sortModels(result)).toEqual(sortModels(expectedResult)); + }, + "findMany - should find many models with not_in operator": async () => { + const users = (await insertRandom("user", 3)).map((x) => x[0]); + const result = await adapter.findMany({ + model: "user", + where: [ + { + field: "id", + value: [users[0]!.id, users[1]!.id], + operator: "not_in", + }, + ], + }); + expect(sortModels(result)).toEqual([users[2]]); + }, + "findMany - should find many models with sortBy": async () => { + const users = await createBinarySortFriendlyUsers(5); + const result = await adapter.findMany({ + model: "user", + sortBy: { field: "name", direction: "asc" }, + }); + expect(result.map((x) => x.name)).toEqual( + users.map((x) => x.name).sort((a, b) => a.localeCompare(b)), + ); + }, + "findMany - should find many models with limit": async () => { + const users = (await insertRandom("user", 3)).map((x) => x[0]); + const result = await adapter.findMany({ + model: "user", + limit: 1, + }); + expect(result.length).toEqual(1); + expect(users.find((x) => x.id === result[0]!.id)).not.toBeNull(); + }, + "findMany - should find many models with offset": async () => { + // Note: The returned rows are ordered in no particular order + // This is because databases return rows in whatever order is fastest for the query. + const count = 10; + await insertRandom("user", count); + const result = await adapter.findMany({ + model: "user", + offset: 2, + }); + expect(result.length).toEqual(count - 2); + }, + "findMany - should find many models with limit and offset": async () => { + // Note: The returned rows are ordered in no particular order + // This is because databases return rows in whatever order is fastest for the query. + const count = 5; + await insertRandom("user", count); + const result = await adapter.findMany({ + model: "user", + limit: 2, + offset: 2, + }); + expect(result.length).toEqual(2); + expect(result).toBeInstanceOf(Array); + result.forEach((user) => { + expect(user).toHaveProperty("id"); + expect(user).toHaveProperty("name"); + expect(user).toHaveProperty("email"); + }); + }, + "findMany - should find many models with sortBy and offset": async () => { + const users = await createBinarySortFriendlyUsers(5); + const result = await adapter.findMany({ + model: "user", + sortBy: { field: "name", direction: "asc" }, + offset: 2, + }); + expect(result).toHaveLength(3); + expect(result).toEqual( + users.sort((a, b) => a["name"].localeCompare(b["name"])).slice(2), + ); + }, + "findMany - should find many models with sortBy and limit": async () => { + const users = await createBinarySortFriendlyUsers(5); + const result = await adapter.findMany({ + model: "user", + sortBy: { field: "name", direction: "asc" }, + limit: 2, + }); + expect(result).toEqual( + users.sort((a, b) => a["name"].localeCompare(b["name"])).slice(0, 2), + ); + }, + "findMany - should find many models with sortBy and limit and offset": + async () => { + const users = await createBinarySortFriendlyUsers(5); + const result = await adapter.findMany({ + model: "user", + sortBy: { field: "name", direction: "asc" }, + limit: 2, + offset: 2, + }); + expect(result).toEqual( + users.sort((a, b) => a["name"].localeCompare(b["name"])).slice(2, 4), + ); + }, + "findMany - should find many models with sortBy and limit and offset and where": + async () => { + const users = await createBinarySortFriendlyUsers(5); + const result = await adapter.findMany({ + model: "user", + sortBy: { field: "name", direction: "asc" }, + limit: 2, + offset: 2, + where: [{ field: "name", value: "user", operator: "starts_with" }], + }); + expect(result).toEqual( + users.sort((a, b) => a["name"].localeCompare(b["name"])).slice(2, 4), + ); + }, + "update - should update a model": async () => { + const [user] = await insertRandom("user"); + const result = await adapter.update({ + model: "user", + where: [{ field: "id", value: user.id }], + update: { name: "test-name" }, + }); + const expectedResult = { + ...user, + name: "test-name", + }; + // because of `onUpdate` hook, the updatedAt field will be different + result!.updatedAt = user.updatedAt; + expect(result).toEqual(expectedResult); + const findResult = await adapter.findOne({ + model: "user", + where: [{ field: "id", value: user.id }], + }); + // because of `onUpdate` hook, the updatedAt field will be different + findResult!.updatedAt = user.updatedAt; + expect(findResult).toEqual(expectedResult); + }, + "updateMany - should update all models when where is empty": async () => { + const users = (await insertRandom("user", 3)).map((x) => x[0]); + await adapter.updateMany({ + model: "user", + where: [], + update: { name: "test-name" }, + }); + const result = await adapter.findMany({ + model: "user", + }); + expect(sortModels(result)).toEqual( + sortModels(users).map((user, i) => ({ + ...user, + name: "test-name", + updatedAt: sortModels(result)[i]!.updatedAt, + })), + ); + }, + "updateMany - should update many models with a specific where": + async () => { + const users = (await insertRandom("user", 3)).map((x) => x[0]); + await adapter.updateMany({ + model: "user", + where: [{ field: "id", value: users[0]!.id }], + update: { name: "test-name" }, + }); + const result = await adapter.findOne({ + model: "user", + where: [{ field: "id", value: users[0]!.id }], + }); + expect(result).toEqual({ + ...users[0], + name: "test-name", + updatedAt: result!.updatedAt, + }); + }, + "updateMany - should update many models with a multiple where": + async () => { + const users = (await insertRandom("user", 3)).map((x) => x[0]); + await adapter.updateMany({ + model: "user", + where: [ + { field: "id", value: users[0]!.id, connector: "OR" }, + { field: "id", value: users[1]!.id, connector: "OR" }, + ], + update: { name: "test-name" }, + }); + const result = await adapter.findOne({ + model: "user", + where: [{ field: "id", value: users[0]!.id }], + }); + expect(result).toEqual({ + ...users[0], + name: "test-name", + updatedAt: result!.updatedAt, + }); + }, + "delete - should delete a model": async () => { + const [user] = await insertRandom("user"); + await adapter.delete({ + model: "user", + where: [{ field: "id", value: user.id }], + }); + const result = await adapter.findOne({ + model: "user", + where: [{ field: "id", value: user.id }], + }); + expect(result).toBeNull(); + }, + "delete - should not throw on record not found": async () => { + await expect( + adapter.delete({ + model: "user", + where: [{ field: "id", value: "100000" }], + }), + ).resolves.not.toThrow(); + }, + "deleteMany - should delete many models": async () => { + const users = (await insertRandom("user", 3)).map((x) => x[0]); + await adapter.deleteMany({ + model: "user", + where: [ + { field: "id", value: users[0]!.id, connector: "OR" }, + { field: "id", value: users[1]!.id, connector: "OR" }, + ], + }); + const result = await adapter.findMany({ + model: "user", + }); + expect(sortModels(result)).toEqual(sortModels(users.slice(2))); + }, + "count - should count many models": async () => { + const users = await insertRandom("user", 15); + const result = await adapter.count({ + model: "user", + }); + expect(result).toEqual(users.length); + }, + "count - should return 0 with no rows to count": async () => { + const result = await adapter.count({ + model: "user", + }); + expect(result).toEqual(0); + }, + "count - should count with where clause": async () => { + const users = (await insertRandom("user", 15)).map((x) => x[0]); + const result = await adapter.count({ + model: "user", + where: [ + { field: "id", value: users[2]!.id, connector: "OR" }, + { field: "id", value: users[3]!.id, connector: "OR" }, + ], + }); + expect(result).toEqual(2); + }, + }; +}; diff --git a/packages/better-auth/src/adapters/tests/number-id.ts b/packages/better-auth/src/adapters/tests/number-id.ts new file mode 100644 index 00000000..cee830fa --- /dev/null +++ b/packages/better-auth/src/adapters/tests/number-id.ts @@ -0,0 +1,42 @@ +import { expect } from "vitest"; +import { createTestSuite } from "../create-test-suite"; +import type { User } from "better-auth/types"; +import { getNormalTestSuiteTests } from "./normal"; + +export const numberIdTestSuite = createTestSuite( + "number-id", + { + defaultBetterAuthOptions: { + advanced: { + database: { + useNumberId: true, + }, + }, + }, + alwaysMigrate: true, + prefixTests: "number-id", + }, + (helpers) => { + const { "create - should use generateId if provided": _, ...normalTests } = + getNormalTestSuiteTests({ ...helpers }); + + return { + "init - tests": async () => { + const opts = helpers.getBetterAuthOptions(); + expect(opts.advanced?.database?.useNumberId).toBe(true); + }, + "create - should return a number id": async () => { + const user = await helpers.generate("user"); + const res = await helpers.adapter.create({ + model: "user", + data: user, + forceAllowId: true, + }); + expect(res).toHaveProperty("id"); + expect(typeof res.id).toBe("string"); + expect(parseInt(res.id)).toBeGreaterThan(0); + }, + ...normalTests, + }; + }, +); diff --git a/packages/better-auth/src/adapters/tests/performance.ts b/packages/better-auth/src/adapters/tests/performance.ts new file mode 100644 index 00000000..08e0e5e2 --- /dev/null +++ b/packages/better-auth/src/adapters/tests/performance.ts @@ -0,0 +1,155 @@ +import { assert, expect } from "vitest"; +import { createTestSuite } from "../create-test-suite"; + +/** + * This test suite tests the performance of the adapter and logs the results. + */ +export const performanceTestSuite = createTestSuite( + "performance", + {}, + ( + { adapter, generate, cleanup }, + config?: { iterations?: number; userSeedCount?: number; dialect?: string }, + ) => { + const tests = { + create: [] as number[], + update: [] as number[], + delete: [] as number[], + count: [] as number[], + findOne: [] as number[], + findMany: [] as number[], + }; + + const iterations = config?.iterations ?? 10; + const userSeedCount = config?.userSeedCount ?? 15; + + assert( + userSeedCount >= iterations, + "userSeedCount must be greater than iterations", + ); + + const seedUser = async () => { + const user = await generate("user"); + return await adapter.create({ + model: "user", + data: user, + forceAllowId: true, + }); + }; + const seedManyUsers = async () => { + const users = []; + for (let i = 0; i < userSeedCount; i++) { + users.push(await seedUser()); + } + return users; + }; + + const performanceTests = { + create: async () => { + for (let i = 0; i < iterations; i++) { + const start = performance.now(); + await seedUser(); + const end = performance.now(); + tests.create.push(end - start); + } + }, + update: async () => { + const users = await seedManyUsers(); + for (let i = 0; i < iterations; i++) { + const start = performance.now(); + await adapter.update({ + model: "user", + where: [{ field: "id", value: users[i]!.id }], + update: { + name: `user-${i}`, + }, + }); + const end = performance.now(); + tests.update.push(end - start); + } + }, + delete: async () => { + const users = await seedManyUsers(); + for (let i = 0; i < iterations; i++) { + const start = performance.now(); + await adapter.delete({ + model: "user", + where: [{ field: "id", value: users[i]!.id }], + }); + const end = performance.now(); + tests.delete.push(end - start); + } + }, + count: async () => { + const users = await seedManyUsers(); + for (let i = 0; i < iterations; i++) { + const start = performance.now(); + const c = await adapter.count({ + model: "user", + }); + const end = performance.now(); + tests.count.push(end - start); + expect(c).toEqual(users.length); + } + }, + findOne: async () => { + const users = await seedManyUsers(); + for (let i = 0; i < iterations; i++) { + const start = performance.now(); + await adapter.findOne({ + model: "user", + where: [{ field: "id", value: users[i]!.id }], + }); + const end = performance.now(); + tests.findOne.push(end - start); + } + }, + findMany: async () => { + const users = await seedManyUsers(); + for (let i = 0; i < iterations; i++) { + const start = performance.now(); + const result = await adapter.findMany({ + model: "user", + where: [{ field: "name", value: "user", operator: "starts_with" }], + limit: users.length, + }); + const end = performance.now(); + tests.findMany.push(end - start); + expect(result.length).toBe(users.length); + } + }, + }; + + return { + "run performance test": async () => { + for (const test of Object.keys(performanceTests)) { + await performanceTests[test as keyof typeof performanceTests](); + await cleanup(); + } + + // Calculate averages for each test + const averages = Object.entries(tests).reduce( + (acc, [key, values]) => { + const average = + values.length > 0 + ? values.reduce((sum, val) => sum + val, 0) / values.length + : 0; + acc[key] = `${average.toFixed(3)}ms`; + return acc; + }, + {} as Record, + ); + + console.log(`Performance tests results, counting averages:`); + console.table(averages); + console.log({ + iterations, + userSeedCount, + adapter: adapter.options?.adapterConfig.adapterId, + ...(config?.dialect ? { dialect: config.dialect } : {}), + }); + expect(1).toBe(1); + }, + }; + }, +); diff --git a/packages/better-auth/src/adapters/tests/transactions.ts b/packages/better-auth/src/adapters/tests/transactions.ts new file mode 100644 index 00000000..d322fde8 --- /dev/null +++ b/packages/better-auth/src/adapters/tests/transactions.ts @@ -0,0 +1,40 @@ +import { expect } from "vitest"; +import { createTestSuite } from "../create-test-suite"; +import type { User } from "../../types"; + +/** + * This test suite tests the transaction functionality of the adapter. + */ +export const transactionsTestSuite = createTestSuite( + "transactions", + {}, + ({ adapter, generate, hardCleanup }) => ({ + "transaction - should rollback failing transaction": async ({ skip }) => { + const isEnabled = adapter.options?.adapterConfig.transaction; + if (!isEnabled) { + skip( + `Skipping test: ${adapter.options?.adapterConfig.adapterName} does not support transactions`, + ); + return; + } + + const user1 = await generate("user"); + const user2 = await generate("user"); + await expect( + adapter.transaction(async (tx) => { + await tx.create({ model: "user", data: user1, forceAllowId: true }); + const users = await tx.findMany({ model: "user" }); + expect(users).toHaveLength(1); + throw new Error("Simulated failure"); + await tx.create({ model: "user", data: user2, forceAllowId: true }); + }), + ).rejects.toThrow("Simulated failure"); + const result = await adapter.findMany({ + model: "user", + }); + //Transactions made rows are unable to be automatically cleaned up, so we need to clean them up manually + await hardCleanup(); + expect(result.length).toBe(0); + }, + }), +); diff --git a/packages/better-auth/src/adapters/utils.ts b/packages/better-auth/src/adapters/utils.ts index 8be8cc7d..bad0eb99 100644 --- a/packages/better-auth/src/adapters/utils.ts +++ b/packages/better-auth/src/adapters/utils.ts @@ -25,3 +25,34 @@ export function withApplyDefault( } return value; } + +function isObject(item: unknown): item is Record { + return item !== null && typeof item === "object" && !Array.isArray(item); +} + +export function deepmerge(target: T, source: Partial): T { + if (Array.isArray(target) && Array.isArray(source)) { + // merge arrays by concatenation + return [...target, ...source] as T; + } else if (isObject(target) && isObject(source)) { + const result: Record = { ...target }; + + for (const [key, value] of Object.entries(source)) { + if (value === undefined) continue; // skip undefineds + + if (key in target) { + result[key] = deepmerge( + (target as Record)[key], + value as unknown as Partial, + ); + } else { + result[key] = value; + } + } + + return result as T; + } + + // primitives and fallback: source overrides target + return source as T; +} diff --git a/packages/better-auth/src/db/get-migration.ts b/packages/better-auth/src/db/get-migration.ts index d7f68c9c..267f97c2 100644 --- a/packages/better-auth/src/db/get-migration.ts +++ b/packages/better-auth/src/db/get-migration.ts @@ -53,7 +53,7 @@ const mssqlMap = { string: ["varchar", "nvarchar"], number: ["int", "bigint", "smallint", "decimal", "float", "double"], boolean: ["bit", "smallint"], - date: ["datetime", "date"], + date: ["datetime2", "date", "datetime"], json: ["varchar", "nvarchar"], }; @@ -207,8 +207,8 @@ export async function getMigrations(config: BetterAuthOptions) { date: { sqlite: "date", postgres: "timestamptz", - mysql: "timestamp", - mssql: "datetime", + mysql: "timestamp(3)", + mssql: sql`datetime2(3)`, }, json: { sqlite: "text", @@ -266,7 +266,11 @@ export async function getMigrations(config: BetterAuthOptions) { dbType === "mysql" || dbType === "mssql") ) { - col = col.defaultTo(sql`CURRENT_TIMESTAMP`); + if (dbType === "mysql") { + col = col.defaultTo(sql`CURRENT_TIMESTAMP(3)`); + } else { + col = col.defaultTo(sql`CURRENT_TIMESTAMP`); + } } return col; }); @@ -291,6 +295,8 @@ export async function getMigrations(config: BetterAuthOptions) { if (config.advanced?.database?.useNumberId) { if (dbType === "postgres" || dbType === "sqlite") { return col.primaryKey().notNull(); + } else if (dbType === "mssql") { + return col.identity().primaryKey().notNull(); } return col.autoIncrement().primaryKey().notNull(); } @@ -316,7 +322,11 @@ export async function getMigrations(config: BetterAuthOptions) { typeof field.defaultValue === "function" && (dbType === "postgres" || dbType === "mysql" || dbType === "mssql") ) { - col = col.defaultTo(sql`CURRENT_TIMESTAMP`); + if (dbType === "mysql") { + col = col.defaultTo(sql`CURRENT_TIMESTAMP(3)`); + } else { + col = col.defaultTo(sql`CURRENT_TIMESTAMP`); + } } return col; }); diff --git a/packages/better-auth/src/utils/colors.ts b/packages/better-auth/src/utils/colors.ts new file mode 100644 index 00000000..f075e65e --- /dev/null +++ b/packages/better-auth/src/utils/colors.ts @@ -0,0 +1,30 @@ +export const colors = { + reset: "\x1b[0m", + bright: "\x1b[1m", + dim: "\x1b[2m", + undim: "\x1b[22m", + underscore: "\x1b[4m", + blink: "\x1b[5m", + reverse: "\x1b[7m", + hidden: "\x1b[8m", + fg: { + black: "\x1b[30m", + red: "\x1b[31m", + green: "\x1b[32m", + yellow: "\x1b[33m", + blue: "\x1b[34m", + magenta: "\x1b[35m", + cyan: "\x1b[36m", + white: "\x1b[37m", + }, + bg: { + black: "\x1b[40m", + red: "\x1b[41m", + green: "\x1b[42m", + yellow: "\x1b[43m", + blue: "\x1b[44m", + magenta: "\x1b[45m", + cyan: "\x1b[46m", + white: "\x1b[47m", + }, +}; diff --git a/packages/better-auth/src/utils/logger.ts b/packages/better-auth/src/utils/logger.ts index ce056a56..7944d897 100644 --- a/packages/better-auth/src/utils/logger.ts +++ b/packages/better-auth/src/utils/logger.ts @@ -1,4 +1,5 @@ import { getColorDepth } from "./color-depth"; +import { colors } from "./colors"; export type LogLevel = "info" | "success" | "warn" | "error" | "debug"; @@ -29,36 +30,6 @@ export type LogHandlerParams = Parameters> extends [ ? Rest : never; -const colors = { - reset: "\x1b[0m", - bright: "\x1b[1m", - dim: "\x1b[2m", - underscore: "\x1b[4m", - blink: "\x1b[5m", - reverse: "\x1b[7m", - hidden: "\x1b[8m", - fg: { - black: "\x1b[30m", - red: "\x1b[31m", - green: "\x1b[32m", - yellow: "\x1b[33m", - blue: "\x1b[34m", - magenta: "\x1b[35m", - cyan: "\x1b[36m", - white: "\x1b[37m", - }, - bg: { - black: "\x1b[40m", - red: "\x1b[41m", - green: "\x1b[42m", - yellow: "\x1b[43m", - blue: "\x1b[44m", - magenta: "\x1b[45m", - cyan: "\x1b[46m", - white: "\x1b[47m", - }, -}; - const levelColors: Record = { info: colors.fg.blue, success: colors.fg.green, diff --git a/packages/cli/src/generators/drizzle.ts b/packages/cli/src/generators/drizzle.ts index 1c965d81..4d5e4962 100644 --- a/packages/cli/src/generators/drizzle.ts +++ b/packages/cli/src/generators/drizzle.ts @@ -106,7 +106,7 @@ export const generateDrizzleSchema: SchemaGenerator = async ({ date: { sqlite: `integer('${name}', { mode: 'timestamp_ms' })`, pg: `timestamp('${name}')`, - mysql: `timestamp('${name}')`, + mysql: `timestamp('${name}', { fsp: 3 })`, }, "number[]": { sqlite: `integer('${name}').array()`, @@ -137,7 +137,7 @@ export const generateDrizzleSchema: SchemaGenerator = async ({ if (databaseType === "pg") { id = `serial("id").primaryKey()`; } else if (databaseType === "sqlite") { - id = `int("id").primaryKey()`; + id = `integer("id", { mode: "number" }).primaryKey({ autoIncrement: true })`; } else { id = `int("id").autoincrement().primaryKey()`; } @@ -159,7 +159,8 @@ export const generateDrizzleSchema: SchemaGenerator = async ({ ${Object.keys(fields) .map((field) => { const attr = fields[field]!; - let type = getType(field, attr); + const fieldName = attr.fieldName || field; + let type = getType(fieldName, attr); if ( attr.defaultValue !== null && typeof attr.defaultValue !== "undefined" @@ -190,7 +191,7 @@ export const generateDrizzleSchema: SchemaGenerator = async ({ type += `.$onUpdate(${attr.onUpdate})`; } } - return `${field}: ${type}${attr.required ? ".notNull()" : ""}${ + return `${fieldName}: ${type}${attr.required ? ".notNull()" : ""}${ attr.unique ? ".unique()" : "" }${ attr.references @@ -198,7 +199,7 @@ export const generateDrizzleSchema: SchemaGenerator = async ({ tables[attr.references.model]?.modelName || attr.references.model, adapter.options, - )}.${attr.references.field}, { onDelete: '${ + )}.${fields[attr.references.field]?.fieldName || attr.references.field}, { onDelete: '${ attr.references.onDelete || "cascade" }' })` : "" diff --git a/packages/cli/src/generators/prisma.ts b/packages/cli/src/generators/prisma.ts index 8de663e0..e5304fa6 100644 --- a/packages/cli/src/generators/prisma.ts +++ b/packages/cli/src/generators/prisma.ts @@ -109,13 +109,11 @@ export const generatePrismaSchema: SchemaGenerator = async ({ .attribute(`map("_id")`); } else { if (options.advanced?.database?.useNumberId) { - const col = builder + builder .model(modelName) .field("id", "Int") - .attribute("id"); - if (provider !== "sqlite") { - col.attribute("default(autoincrement())"); - } + .attribute("id") + .attribute("default(autoincrement())"); } else { builder.model(modelName).field("id", "String").attribute("id"); } @@ -160,8 +158,6 @@ export const generatePrismaSchema: SchemaGenerator = async ({ if (provider === "mongodb") { fieldBuilder.attribute(`map("_id")`); } - } else if (fieldName !== field) { - fieldBuilder.attribute(`map("${field}")`); } if (attr.unique) { diff --git a/packages/cli/test/__snapshots__/auth-schema-mysql-number-id.txt b/packages/cli/test/__snapshots__/auth-schema-mysql-number-id.txt index e9255a4f..b93a360a 100644 --- a/packages/cli/test/__snapshots__/auth-schema-mysql-number-id.txt +++ b/packages/cli/test/__snapshots__/auth-schema-mysql-number-id.txt @@ -13,8 +13,8 @@ export const custom_user = mysqlTable("custom_user", { email: varchar("email", { length: 255 }).notNull().unique(), emailVerified: boolean("email_verified").default(false).notNull(), image: text("image"), - createdAt: timestamp("created_at").defaultNow().notNull(), - updatedAt: timestamp("updated_at") + createdAt: timestamp("created_at", { fsp: 3 }).defaultNow().notNull(), + updatedAt: timestamp("updated_at", { fsp: 3 }) .defaultNow() .$onUpdate(() => /* @__PURE__ */ new Date()) .notNull(), @@ -25,10 +25,10 @@ export const custom_user = mysqlTable("custom_user", { export const custom_session = mysqlTable("custom_session", { id: int("id").autoincrement().primaryKey(), - expiresAt: timestamp("expires_at").notNull(), + expiresAt: timestamp("expires_at", { fsp: 3 }).notNull(), token: varchar("token", { length: 255 }).notNull().unique(), - createdAt: timestamp("created_at").defaultNow().notNull(), - updatedAt: timestamp("updated_at") + createdAt: timestamp("created_at", { fsp: 3 }).defaultNow().notNull(), + updatedAt: timestamp("updated_at", { fsp: 3 }) .$onUpdate(() => /* @__PURE__ */ new Date()) .notNull(), ipAddress: text("ip_address"), @@ -48,12 +48,12 @@ export const custom_account = mysqlTable("custom_account", { accessToken: text("access_token"), refreshToken: text("refresh_token"), idToken: text("id_token"), - accessTokenExpiresAt: timestamp("access_token_expires_at"), - refreshTokenExpiresAt: timestamp("refresh_token_expires_at"), + accessTokenExpiresAt: timestamp("access_token_expires_at", { fsp: 3 }), + refreshTokenExpiresAt: timestamp("refresh_token_expires_at", { fsp: 3 }), scope: text("scope"), password: text("password"), - createdAt: timestamp("created_at").defaultNow().notNull(), - updatedAt: timestamp("updated_at") + createdAt: timestamp("created_at", { fsp: 3 }).defaultNow().notNull(), + updatedAt: timestamp("updated_at", { fsp: 3 }) .$onUpdate(() => /* @__PURE__ */ new Date()) .notNull(), }); @@ -62,9 +62,9 @@ export const custom_verification = mysqlTable("custom_verification", { id: int("id").autoincrement().primaryKey(), identifier: text("identifier").notNull(), value: text("value").notNull(), - expiresAt: timestamp("expires_at").notNull(), - createdAt: timestamp("created_at").defaultNow().notNull(), - updatedAt: timestamp("updated_at") + expiresAt: timestamp("expires_at", { fsp: 3 }).notNull(), + createdAt: timestamp("created_at", { fsp: 3 }).defaultNow().notNull(), + updatedAt: timestamp("updated_at", { fsp: 3 }) .defaultNow() .$onUpdate(() => /* @__PURE__ */ new Date()) .notNull(), diff --git a/packages/cli/test/__snapshots__/auth-schema-mysql-passkey-number-id.txt b/packages/cli/test/__snapshots__/auth-schema-mysql-passkey-number-id.txt index ccb9b41a..683362c7 100644 --- a/packages/cli/test/__snapshots__/auth-schema-mysql-passkey-number-id.txt +++ b/packages/cli/test/__snapshots__/auth-schema-mysql-passkey-number-id.txt @@ -13,8 +13,8 @@ export const custom_user = mysqlTable("custom_user", { email: varchar("email", { length: 255 }).notNull().unique(), emailVerified: boolean("email_verified").default(false).notNull(), image: text("image"), - createdAt: timestamp("created_at").defaultNow().notNull(), - updatedAt: timestamp("updated_at") + createdAt: timestamp("created_at", { fsp: 3 }).defaultNow().notNull(), + updatedAt: timestamp("updated_at", { fsp: 3 }) .defaultNow() .$onUpdate(() => /* @__PURE__ */ new Date()) .notNull(), @@ -22,10 +22,10 @@ export const custom_user = mysqlTable("custom_user", { export const custom_session = mysqlTable("custom_session", { id: int("id").autoincrement().primaryKey(), - expiresAt: timestamp("expires_at").notNull(), + expiresAt: timestamp("expires_at", { fsp: 3 }).notNull(), token: varchar("token", { length: 255 }).notNull().unique(), - createdAt: timestamp("created_at").defaultNow().notNull(), - updatedAt: timestamp("updated_at") + createdAt: timestamp("created_at", { fsp: 3 }).defaultNow().notNull(), + updatedAt: timestamp("updated_at", { fsp: 3 }) .$onUpdate(() => /* @__PURE__ */ new Date()) .notNull(), ipAddress: text("ip_address"), @@ -45,12 +45,12 @@ export const custom_account = mysqlTable("custom_account", { accessToken: text("access_token"), refreshToken: text("refresh_token"), idToken: text("id_token"), - accessTokenExpiresAt: timestamp("access_token_expires_at"), - refreshTokenExpiresAt: timestamp("refresh_token_expires_at"), + accessTokenExpiresAt: timestamp("access_token_expires_at", { fsp: 3 }), + refreshTokenExpiresAt: timestamp("refresh_token_expires_at", { fsp: 3 }), scope: text("scope"), password: text("password"), - createdAt: timestamp("created_at").defaultNow().notNull(), - updatedAt: timestamp("updated_at") + createdAt: timestamp("created_at", { fsp: 3 }).defaultNow().notNull(), + updatedAt: timestamp("updated_at", { fsp: 3 }) .$onUpdate(() => /* @__PURE__ */ new Date()) .notNull(), }); @@ -59,9 +59,9 @@ export const custom_verification = mysqlTable("custom_verification", { id: int("id").autoincrement().primaryKey(), identifier: text("identifier").notNull(), value: text("value").notNull(), - expiresAt: timestamp("expires_at").notNull(), - createdAt: timestamp("created_at").defaultNow().notNull(), - updatedAt: timestamp("updated_at") + expiresAt: timestamp("expires_at", { fsp: 3 }).notNull(), + createdAt: timestamp("created_at", { fsp: 3 }).defaultNow().notNull(), + updatedAt: timestamp("updated_at", { fsp: 3 }) .defaultNow() .$onUpdate(() => /* @__PURE__ */ new Date()) .notNull(), @@ -79,6 +79,6 @@ export const passkey = mysqlTable("passkey", { deviceType: text("device_type").notNull(), backedUp: boolean("backed_up").notNull(), transports: text("transports"), - createdAt: timestamp("created_at"), + createdAt: timestamp("created_at", { fsp: 3 }), aaguid: text("aaguid"), }); diff --git a/packages/cli/test/__snapshots__/auth-schema-mysql-passkey.txt b/packages/cli/test/__snapshots__/auth-schema-mysql-passkey.txt index 44ef95ac..19d5b129 100644 --- a/packages/cli/test/__snapshots__/auth-schema-mysql-passkey.txt +++ b/packages/cli/test/__snapshots__/auth-schema-mysql-passkey.txt @@ -13,8 +13,8 @@ export const custom_user = mysqlTable("custom_user", { email: varchar("email", { length: 255 }).notNull().unique(), emailVerified: boolean("email_verified").default(false).notNull(), image: text("image"), - createdAt: timestamp("created_at").defaultNow().notNull(), - updatedAt: timestamp("updated_at") + createdAt: timestamp("created_at", { fsp: 3 }).defaultNow().notNull(), + updatedAt: timestamp("updated_at", { fsp: 3 }) .defaultNow() .$onUpdate(() => /* @__PURE__ */ new Date()) .notNull(), @@ -22,10 +22,10 @@ export const custom_user = mysqlTable("custom_user", { export const custom_session = mysqlTable("custom_session", { id: varchar("id", { length: 36 }).primaryKey(), - expiresAt: timestamp("expires_at").notNull(), + expiresAt: timestamp("expires_at", { fsp: 3 }).notNull(), token: varchar("token", { length: 255 }).notNull().unique(), - createdAt: timestamp("created_at").defaultNow().notNull(), - updatedAt: timestamp("updated_at") + createdAt: timestamp("created_at", { fsp: 3 }).defaultNow().notNull(), + updatedAt: timestamp("updated_at", { fsp: 3 }) .$onUpdate(() => /* @__PURE__ */ new Date()) .notNull(), ipAddress: text("ip_address"), @@ -45,12 +45,12 @@ export const custom_account = mysqlTable("custom_account", { accessToken: text("access_token"), refreshToken: text("refresh_token"), idToken: text("id_token"), - accessTokenExpiresAt: timestamp("access_token_expires_at"), - refreshTokenExpiresAt: timestamp("refresh_token_expires_at"), + accessTokenExpiresAt: timestamp("access_token_expires_at", { fsp: 3 }), + refreshTokenExpiresAt: timestamp("refresh_token_expires_at", { fsp: 3 }), scope: text("scope"), password: text("password"), - createdAt: timestamp("created_at").defaultNow().notNull(), - updatedAt: timestamp("updated_at") + createdAt: timestamp("created_at", { fsp: 3 }).defaultNow().notNull(), + updatedAt: timestamp("updated_at", { fsp: 3 }) .$onUpdate(() => /* @__PURE__ */ new Date()) .notNull(), }); @@ -59,9 +59,9 @@ export const custom_verification = mysqlTable("custom_verification", { id: varchar("id", { length: 36 }).primaryKey(), identifier: text("identifier").notNull(), value: text("value").notNull(), - expiresAt: timestamp("expires_at").notNull(), - createdAt: timestamp("created_at").defaultNow().notNull(), - updatedAt: timestamp("updated_at") + expiresAt: timestamp("expires_at", { fsp: 3 }).notNull(), + createdAt: timestamp("created_at", { fsp: 3 }).defaultNow().notNull(), + updatedAt: timestamp("updated_at", { fsp: 3 }) .defaultNow() .$onUpdate(() => /* @__PURE__ */ new Date()) .notNull(), @@ -79,6 +79,6 @@ export const passkey = mysqlTable("passkey", { deviceType: text("device_type").notNull(), backedUp: boolean("backed_up").notNull(), transports: text("transports"), - createdAt: timestamp("created_at"), + createdAt: timestamp("created_at", { fsp: 3 }), aaguid: text("aaguid"), }); diff --git a/packages/cli/test/__snapshots__/auth-schema-mysql.txt b/packages/cli/test/__snapshots__/auth-schema-mysql.txt index 3bc03762..569bbe49 100644 --- a/packages/cli/test/__snapshots__/auth-schema-mysql.txt +++ b/packages/cli/test/__snapshots__/auth-schema-mysql.txt @@ -12,8 +12,8 @@ export const custom_user = mysqlTable("custom_user", { email: varchar("email", { length: 255 }).notNull().unique(), emailVerified: boolean("email_verified").default(false).notNull(), image: text("image"), - createdAt: timestamp("created_at").defaultNow().notNull(), - updatedAt: timestamp("updated_at") + createdAt: timestamp("created_at", { fsp: 3 }).defaultNow().notNull(), + updatedAt: timestamp("updated_at", { fsp: 3 }) .defaultNow() .$onUpdate(() => /* @__PURE__ */ new Date()) .notNull(), @@ -24,10 +24,10 @@ export const custom_user = mysqlTable("custom_user", { export const custom_session = mysqlTable("custom_session", { id: varchar("id", { length: 36 }).primaryKey(), - expiresAt: timestamp("expires_at").notNull(), + expiresAt: timestamp("expires_at", { fsp: 3 }).notNull(), token: varchar("token", { length: 255 }).notNull().unique(), - createdAt: timestamp("created_at").defaultNow().notNull(), - updatedAt: timestamp("updated_at") + createdAt: timestamp("created_at", { fsp: 3 }).defaultNow().notNull(), + updatedAt: timestamp("updated_at", { fsp: 3 }) .$onUpdate(() => /* @__PURE__ */ new Date()) .notNull(), ipAddress: text("ip_address"), @@ -47,12 +47,12 @@ export const custom_account = mysqlTable("custom_account", { accessToken: text("access_token"), refreshToken: text("refresh_token"), idToken: text("id_token"), - accessTokenExpiresAt: timestamp("access_token_expires_at"), - refreshTokenExpiresAt: timestamp("refresh_token_expires_at"), + accessTokenExpiresAt: timestamp("access_token_expires_at", { fsp: 3 }), + refreshTokenExpiresAt: timestamp("refresh_token_expires_at", { fsp: 3 }), scope: text("scope"), password: text("password"), - createdAt: timestamp("created_at").defaultNow().notNull(), - updatedAt: timestamp("updated_at") + createdAt: timestamp("created_at", { fsp: 3 }).defaultNow().notNull(), + updatedAt: timestamp("updated_at", { fsp: 3 }) .$onUpdate(() => /* @__PURE__ */ new Date()) .notNull(), }); @@ -61,9 +61,9 @@ export const custom_verification = mysqlTable("custom_verification", { id: varchar("id", { length: 36 }).primaryKey(), identifier: text("identifier").notNull(), value: text("value").notNull(), - expiresAt: timestamp("expires_at").notNull(), - createdAt: timestamp("created_at").defaultNow().notNull(), - updatedAt: timestamp("updated_at") + expiresAt: timestamp("expires_at", { fsp: 3 }).notNull(), + createdAt: timestamp("created_at", { fsp: 3 }).defaultNow().notNull(), + updatedAt: timestamp("updated_at", { fsp: 3 }) .defaultNow() .$onUpdate(() => /* @__PURE__ */ new Date()) .notNull(), diff --git a/packages/cli/test/__snapshots__/auth-schema-sqlite-number-id.txt b/packages/cli/test/__snapshots__/auth-schema-sqlite-number-id.txt index 81570a04..4c33a8cb 100644 --- a/packages/cli/test/__snapshots__/auth-schema-sqlite-number-id.txt +++ b/packages/cli/test/__snapshots__/auth-schema-sqlite-number-id.txt @@ -2,7 +2,7 @@ import { sql } from "drizzle-orm"; import { sqliteTable, text, integer } from "drizzle-orm/sqlite-core"; export const custom_user = sqliteTable("custom_user", { - id: int("id").primaryKey(), + id: integer("id", { mode: "number" }).primaryKey({ autoIncrement: true }), name: text("name").notNull(), email: text("email").notNull().unique(), emailVerified: integer("email_verified", { mode: "boolean" }) @@ -24,7 +24,7 @@ export const custom_user = sqliteTable("custom_user", { }); export const custom_session = sqliteTable("custom_session", { - id: int("id").primaryKey(), + id: integer("id", { mode: "number" }).primaryKey({ autoIncrement: true }), expiresAt: integer("expires_at", { mode: "timestamp_ms" }).notNull(), token: text("token").notNull().unique(), createdAt: integer("created_at", { mode: "timestamp_ms" }) @@ -41,7 +41,7 @@ export const custom_session = sqliteTable("custom_session", { }); export const custom_account = sqliteTable("custom_account", { - id: int("id").primaryKey(), + id: integer("id", { mode: "number" }).primaryKey({ autoIncrement: true }), accountId: text("account_id").notNull(), providerId: text("provider_id").notNull(), userId: integer("user_id") @@ -67,7 +67,7 @@ export const custom_account = sqliteTable("custom_account", { }); export const custom_verification = sqliteTable("custom_verification", { - id: int("id").primaryKey(), + id: integer("id", { mode: "number" }).primaryKey({ autoIncrement: true }), identifier: text("identifier").notNull(), value: text("value").notNull(), expiresAt: integer("expires_at", { mode: "timestamp_ms" }).notNull(), @@ -81,7 +81,7 @@ export const custom_verification = sqliteTable("custom_verification", { }); export const twoFactor = sqliteTable("two_factor", { - id: int("id").primaryKey(), + id: integer("id", { mode: "number" }).primaryKey({ autoIncrement: true }), secret: text("secret").notNull(), backupCodes: text("backup_codes").notNull(), userId: integer("user_id") diff --git a/packages/cli/test/__snapshots__/auth-schema-sqlite-passkey-number-id.txt b/packages/cli/test/__snapshots__/auth-schema-sqlite-passkey-number-id.txt index 5d5d20be..ddedf438 100644 --- a/packages/cli/test/__snapshots__/auth-schema-sqlite-passkey-number-id.txt +++ b/packages/cli/test/__snapshots__/auth-schema-sqlite-passkey-number-id.txt @@ -2,7 +2,7 @@ import { sql } from "drizzle-orm"; import { sqliteTable, text, integer } from "drizzle-orm/sqlite-core"; export const custom_user = sqliteTable("custom_user", { - id: int("id").primaryKey(), + id: integer("id", { mode: "number" }).primaryKey({ autoIncrement: true }), name: text("name").notNull(), email: text("email").notNull().unique(), emailVerified: integer("email_verified", { mode: "boolean" }) @@ -19,7 +19,7 @@ export const custom_user = sqliteTable("custom_user", { }); export const custom_session = sqliteTable("custom_session", { - id: int("id").primaryKey(), + id: integer("id", { mode: "number" }).primaryKey({ autoIncrement: true }), expiresAt: integer("expires_at", { mode: "timestamp_ms" }).notNull(), token: text("token").notNull().unique(), createdAt: integer("created_at", { mode: "timestamp_ms" }) @@ -36,7 +36,7 @@ export const custom_session = sqliteTable("custom_session", { }); export const custom_account = sqliteTable("custom_account", { - id: int("id").primaryKey(), + id: integer("id", { mode: "number" }).primaryKey({ autoIncrement: true }), accountId: text("account_id").notNull(), providerId: text("provider_id").notNull(), userId: integer("user_id") @@ -62,7 +62,7 @@ export const custom_account = sqliteTable("custom_account", { }); export const custom_verification = sqliteTable("custom_verification", { - id: int("id").primaryKey(), + id: integer("id", { mode: "number" }).primaryKey({ autoIncrement: true }), identifier: text("identifier").notNull(), value: text("value").notNull(), expiresAt: integer("expires_at", { mode: "timestamp_ms" }).notNull(), @@ -76,7 +76,7 @@ export const custom_verification = sqliteTable("custom_verification", { }); export const passkey = sqliteTable("passkey", { - id: int("id").primaryKey(), + id: integer("id", { mode: "number" }).primaryKey({ autoIncrement: true }), name: text("name"), publicKey: text("public_key").notNull(), userId: integer("user_id") diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 02e06e04..25bac82d 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -895,6 +895,9 @@ importers: deepmerge: specifier: ^4.3.1 version: 4.3.1 + drizzle-kit: + specifier: ^0.31.4 + version: 0.31.4 drizzle-orm: specifier: ^0.38.2 version: 0.38.4(@cloudflare/workers-types@4.20250903.0)(@libsql/client@0.15.14)(@prisma/client@5.22.0(prisma@5.22.0))(@types/better-sqlite3@7.6.13)(@types/pg@8.15.5)(@types/react@18.3.23)(better-sqlite3@12.2.0)(bun-types@1.2.21(@types/react@18.3.23))(kysely@0.28.5)(mysql2@3.14.4)(pg@8.16.3)(postgres@3.4.7)(prisma@5.22.0)(react@19.1.1) @@ -14925,7 +14928,7 @@ snapshots: postcss: 8.4.49 resolve-from: 5.0.0 optionalDependencies: - expo: 54.0.10(@babel/core@7.28.4)(@expo/metro-runtime@6.1.2)(expo-router@6.0.8)(graphql@16.11.0)(react-native@0.81.4(@babel/core@7.28.4)(@react-native-community/cli@20.0.1(typescript@5.9.2))(@react-native/metro-config@0.81.0(@babel/core@7.28.4))(@types/react@19.1.12)(react@19.1.1))(react@19.1.1) + expo: 54.0.10(@babel/core@7.28.4)(@expo/metro-runtime@6.1.2)(expo-router@6.0.8)(graphql@16.11.0)(react-native@0.80.2(@babel/core@7.28.4)(@react-native-community/cli@20.0.1(typescript@5.9.2))(@types/react@19.1.12)(react@19.1.1))(react@19.1.1) transitivePeerDependencies: - bufferutil - supports-color @@ -15010,7 +15013,7 @@ snapshots: '@expo/json-file': 10.0.7 '@react-native/normalize-colors': 0.81.4 debug: 4.4.1 - expo: 54.0.10(@babel/core@7.28.4)(@expo/metro-runtime@6.1.2)(expo-router@6.0.8)(graphql@16.11.0)(react-native@0.81.4(@babel/core@7.28.4)(@react-native-community/cli@20.0.1(typescript@5.9.2))(@react-native/metro-config@0.81.0(@babel/core@7.28.4))(@types/react@19.1.12)(react@19.1.1))(react@19.1.1) + expo: 54.0.10(@babel/core@7.28.4)(@expo/metro-runtime@6.1.2)(expo-router@6.0.8)(graphql@16.11.0)(react-native@0.80.2(@babel/core@7.28.4)(@react-native-community/cli@20.0.1(typescript@5.9.2))(@types/react@19.1.12)(react@19.1.1))(react@19.1.1) resolve-from: 5.0.0 semver: 7.7.2 xml2js: 0.6.0 @@ -17164,7 +17167,9 @@ snapshots: metro-runtime: 0.83.1 transitivePeerDependencies: - '@babel/core' + - bufferutil - supports-color + - utf-8-validate optional: true '@react-native/normalize-colors@0.74.89': {} @@ -19162,7 +19167,7 @@ snapshots: resolve-from: 5.0.0 optionalDependencies: '@babel/runtime': 7.28.4 - expo: 54.0.10(@babel/core@7.28.4)(@expo/metro-runtime@6.1.2)(expo-router@6.0.8)(graphql@16.11.0)(react-native@0.81.4(@babel/core@7.28.4)(@react-native-community/cli@20.0.1(typescript@5.9.2))(@react-native/metro-config@0.81.0(@babel/core@7.28.4))(@types/react@19.1.12)(react@19.1.1))(react@19.1.1) + expo: 54.0.10(@babel/core@7.28.4)(@expo/metro-runtime@6.1.2)(expo-router@6.0.8)(graphql@16.11.0)(react-native@0.80.2(@babel/core@7.28.4)(@react-native-community/cli@20.0.1(typescript@5.9.2))(@types/react@19.1.12)(react@19.1.1))(react@19.1.1) transitivePeerDependencies: - '@babel/core' - supports-color @@ -20721,7 +20726,7 @@ snapshots: expo-keep-awake@15.0.7(expo@54.0.10)(react@19.1.1): dependencies: - expo: 54.0.10(@babel/core@7.28.4)(@expo/metro-runtime@6.1.2)(expo-router@6.0.8)(graphql@16.11.0)(react-native@0.81.4(@babel/core@7.28.4)(@react-native-community/cli@20.0.1(typescript@5.9.2))(@react-native/metro-config@0.81.0(@babel/core@7.28.4))(@types/react@19.1.12)(react@19.1.1))(react@19.1.1) + expo: 54.0.10(@babel/core@7.28.4)(@expo/metro-runtime@6.1.2)(expo-router@6.0.8)(graphql@16.11.0)(react-native@0.80.2(@babel/core@7.28.4)(@react-native-community/cli@20.0.1(typescript@5.9.2))(@types/react@19.1.12)(react@19.1.1))(react@19.1.1) react: 19.1.1 expo-linking@7.1.7(expo@54.0.10)(react-native@0.80.2(@babel/core@7.28.4)(@react-native-community/cli@20.0.1(typescript@5.9.2))(@types/react@19.1.12)(react@19.1.1))(react@19.1.1):