feat: database transaction support (#4414)

Co-authored-by: Joel Solano <solano.joel@gmx.de>
This commit is contained in:
Alex Yang
2025-09-09 17:10:27 -07:00
committed by GitHub
parent 61b6a87435
commit 22c39b92be
16 changed files with 647 additions and 228 deletions

View File

@@ -20,7 +20,7 @@ const client = new MongoClient("mongodb://localhost:27017/database");
const db = client.db(); const db = client.db();
export const auth = betterAuth({ export const auth = betterAuth({
database: mongodbAdapter(db), database: mongodbAdapter(db, { client }),
}); });
``` ```

View File

@@ -452,6 +452,16 @@ const adapter = myAdapter({
}); });
``` ```
### `transaction`
Whether the adapter supports transactions. If `false`, operations run sequentially; otherwise provide a function that executes a callback with a `TransactionAdapter`.
<Callout type="warn">
If your database does not support transactions, the error handling and rollback
will not be as robust. We recommend using a database that supports transactions
for better data integrity.
</Callout>
### `debugLogs` ### `debugLogs`
Used to enable debug logs for the adapter. You can pass in a boolean, or an object with the following keys: `create`, `update`, `updateMany`, `findOne`, `findMany`, `delete`, `deleteMany`, `count`. Used to enable debug logs for the adapter. You can pass in a boolean, or an object with the following keys: `create`, `update`, `updateMany`, `findOne`, `findMany`, `delete`, `deleteMany`, `count`.

View File

@@ -20,11 +20,13 @@ exports[`init > should match config 1`] = `
"supportsDates": false, "supportsDates": false,
"supportsJSON": false, "supportsJSON": false,
"supportsNumericIds": true, "supportsNumericIds": true,
"transaction": [Function],
"usePlural": undefined, "usePlural": undefined,
}, },
"debugLogs": false, "debugLogs": false,
"type": "sqlite", "type": "sqlite",
}, },
"transaction": [Function],
"update": [Function], "update": [Function],
"updateMany": [Function], "updateMany": [Function],
}, },

View File

@@ -1,13 +1,17 @@
import { safeJSONParse } from "../../utils/json"; import { safeJSONParse } from "../../utils/json";
import { withApplyDefault } from "../../adapters/utils"; import { withApplyDefault } from "../../adapters/utils";
import { getAuthTables } from "../../db/get-tables"; import { getAuthTables } from "../../db/get-tables";
import type { Adapter, BetterAuthOptions, Where } from "../../types"; import type {
Adapter,
BetterAuthOptions,
TransactionAdapter,
Where,
} from "../../types";
import { generateId as defaultGenerateId, logger } from "../../utils"; import { generateId as defaultGenerateId, logger } from "../../utils";
import type { import type {
AdapterConfig, CreateAdapterOptions,
AdapterTestDebugLogs, AdapterTestDebugLogs,
CleanedWhere, CleanedWhere,
CreateCustomAdapter,
} from "./types"; } from "./types";
import type { FieldAttribute } from "../../db"; import type { FieldAttribute } from "../../db";
export * from "./types"; export * from "./types";
@@ -45,14 +49,13 @@ const colors = {
}, },
}; };
const createAsIsTransaction =
(adapter: Adapter) =>
<R>(fn: (trx: TransactionAdapter) => Promise<R>) =>
fn(adapter);
export const createAdapter = export const createAdapter =
({ ({ adapter: customAdapter, config: cfg }: CreateAdapterOptions) =>
adapter,
config: cfg,
}: {
config: AdapterConfig;
adapter: CreateCustomAdapter;
}) =>
(options: BetterAuthOptions): Adapter => { (options: BetterAuthOptions): Adapter => {
const config = { const config = {
...cfg, ...cfg,
@@ -315,7 +318,7 @@ export const createAdapter =
return fields[defaultFieldName]; return fields[defaultFieldName];
}; };
const adapterInstance = adapter({ const adapterInstance = customAdapter({
options, options,
schema, schema,
debugLog, debugLog,
@@ -560,7 +563,24 @@ export const createAdapter =
}) as any; }) as any;
}; };
return { let lazyLoadTransaction: Adapter["transaction"] | null = null;
const adapter: Adapter = {
transaction: async (cb) => {
if (!lazyLoadTransaction) {
if (!config.transaction) {
logger.warn(
`[${config.adapterName}] - Transactions are not supported. Executing operations sequentially.`,
);
lazyLoadTransaction = createAsIsTransaction(adapter);
} else {
logger.debug(
`[${config.adapterName}] - Using provided transaction implementation.`,
);
lazyLoadTransaction = config.transaction;
}
}
return lazyLoadTransaction(cb);
},
create: async <T extends Record<string, any>, R = T>({ create: async <T extends Record<string, any>, R = T>({
data: unsafeData, data: unsafeData,
model: unsafeModel, model: unsafeModel,
@@ -1016,6 +1036,7 @@ export const createAdapter =
} }
: {}), : {}),
}; };
return adapter;
}; };
function formatTransactionId(transactionId: number) { function formatTransactionId(transactionId: number) {

View File

@@ -3,6 +3,7 @@ import type { BetterAuthDbSchema } from "../../db/get-tables";
import type { import type {
AdapterSchemaCreation, AdapterSchemaCreation,
BetterAuthOptions, BetterAuthOptions,
TransactionAdapter,
Where, Where,
} from "../../types"; } from "../../types";
import type { Prettify } from "../../types/helper"; import type { Prettify } from "../../types/helper";
@@ -32,6 +33,11 @@ export type AdapterDebugLogs =
isRunningAdapterTests: boolean; isRunningAdapterTests: boolean;
}; };
export type CreateAdapterOptions = {
config: AdapterConfig;
adapter: CreateCustomAdapter;
};
export interface AdapterConfig { export interface AdapterConfig {
/** /**
* Use plural table names. * Use plural table names.
@@ -89,6 +95,15 @@ export interface AdapterConfig {
* @default true * @default true
*/ */
supportsBooleans?: boolean; supportsBooleans?: boolean;
/**
* Execute multiple operations in a transaction.
*
* If the database doesn't support transactions, set this to `false` and operations will be executed sequentially.
*
*/
transaction?:
| false
| (<R>(callback: (trx: TransactionAdapter) => Promise<R>) => Promise<R>);
/** /**
* Disable id generation for the `create` method. * Disable id generation for the `create` method.
* *

View File

@@ -17,8 +17,13 @@ import {
SQL, SQL,
} from "drizzle-orm"; } from "drizzle-orm";
import { BetterAuthError } from "../../error"; import { BetterAuthError } from "../../error";
import type { Where } from "../../types"; import type { Adapter, BetterAuthOptions, Where } from "../../types";
import { createAdapter, type AdapterDebugLogs } from "../create-adapter"; import {
createAdapter,
type AdapterDebugLogs,
type CreateAdapterOptions,
type CreateCustomAdapter,
} from "../create-adapter";
export interface DB { export interface DB {
[key: string]: any; [key: string]: any;
@@ -52,17 +57,21 @@ export interface DrizzleAdapterConfig {
* @default false * @default false
*/ */
camelCase?: boolean; camelCase?: boolean;
/**
* Whether to execute multiple operations in a transaction.
*
* If the database doesn't support transactions,
* set this to `false` and operations will be executed sequentially.
* @default true
*/
transaction?: boolean;
} }
export const drizzleAdapter = (db: DB, config: DrizzleAdapterConfig) => export const drizzleAdapter = (db: DB, config: DrizzleAdapterConfig) => {
createAdapter({ let lazyOptions: BetterAuthOptions | null = null;
config: { const createCustomAdapter =
adapterId: "drizzle", (db: DB): CreateCustomAdapter =>
adapterName: "Drizzle Adapter", ({ getFieldName, debugLog }) => {
usePlural: config.usePlural ?? false,
debugLogs: config.debugLogs ?? false,
},
adapter: ({ getFieldName, debugLog }) => {
function getSchema(model: string) { function getSchema(model: string) {
const schema = config.schema || db._.fullSchema; const schema = config.schema || db._.fullSchema;
if (!schema) { if (!schema) {
@@ -343,5 +352,31 @@ export const drizzleAdapter = (db: DB, config: DrizzleAdapterConfig) =>
}, },
options: config, options: config,
}; };
};
let adapterOptions: CreateAdapterOptions | null = null;
adapterOptions = {
config: {
adapterId: "drizzle",
adapterName: "Drizzle Adapter",
usePlural: config.usePlural ?? false,
debugLogs: config.debugLogs ?? false,
transaction:
(config.transaction ?? true)
? (cb) =>
db.transaction((tx: DB) => {
const adapter = createAdapter({
config: adapterOptions!.config,
adapter: createCustomAdapter(tx),
})(lazyOptions!);
return cb(adapter);
})
: false,
}, },
}); adapter: createCustomAdapter(db),
};
const adapter = createAdapter(adapterOptions);
return (options: BetterAuthOptions): Adapter => {
lazyOptions = options;
return adapter(options);
};
};

View File

@@ -1,5 +1,10 @@
import { createAdapter, type AdapterDebugLogs } from "../create-adapter"; import {
import type { Where } from "../../types"; createAdapter,
type AdapterDebugLogs,
type CreateCustomAdapter,
type CreateAdapterOptions,
} from "../create-adapter";
import type { Adapter, BetterAuthOptions, Where } from "../../types";
import type { KyselyDatabaseType } from "./types"; import type { KyselyDatabaseType } from "./types";
import type { InsertQueryBuilder, Kysely, UpdateQueryBuilder } from "kysely"; import type { InsertQueryBuilder, Kysely, UpdateQueryBuilder } from "kysely";
@@ -20,26 +25,23 @@ interface KyselyAdapterConfig {
* @default false * @default false
*/ */
usePlural?: boolean; usePlural?: boolean;
/**
* Whether to execute multiple operations in a transaction.
*
* If the database doesn't support transactions,
* set this to `false` and operations will be executed sequentially.
* @default true
*/
transaction?: boolean;
} }
export const kyselyAdapter = (db: Kysely<any>, config?: KyselyAdapterConfig) => export const kyselyAdapter = (
createAdapter({ db: Kysely<any>,
config: { config?: KyselyAdapterConfig,
adapterId: "kysely", ) => {
adapterName: "Kysely Adapter", let lazyOptions: BetterAuthOptions | null = null;
usePlural: config?.usePlural, const createCustomAdapter = (db: Kysely<any>): CreateCustomAdapter => {
debugLogs: config?.debugLogs, return ({ getFieldName, schema }) => {
supportsBooleans:
config?.type === "sqlite" || config?.type === "mssql" || !config?.type
? false
: true,
supportsDates:
config?.type === "sqlite" || config?.type === "mssql" || !config?.type
? false
: true,
supportsJSON: false,
},
adapter: ({ getFieldName, schema }) => {
const withReturning = async ( const withReturning = async (
values: Record<string, any>, values: Record<string, any>,
builder: builder:
@@ -315,5 +317,43 @@ export const kyselyAdapter = (db: Kysely<any>, config?: KyselyAdapterConfig) =>
}, },
options: config, options: config,
}; };
};
};
let adapterOptions: CreateAdapterOptions | null = null;
adapterOptions = {
config: {
adapterId: "kysely",
adapterName: "Kysely Adapter",
usePlural: config?.usePlural,
debugLogs: config?.debugLogs,
supportsBooleans:
config?.type === "sqlite" || config?.type === "mssql" || !config?.type
? false
: true,
supportsDates:
config?.type === "sqlite" || config?.type === "mssql" || !config?.type
? false
: true,
supportsJSON: false,
transaction:
(config?.transaction ?? true)
? (cb) =>
db.transaction().execute((trx) => {
const adapter = createAdapter({
config: adapterOptions!.config,
adapter: createCustomAdapter(trx),
})(lazyOptions!);
return cb(adapter);
})
: false,
}, },
}); adapter: createCustomAdapter(db),
};
const adapter = createAdapter(adapterOptions);
return (options: BetterAuthOptions): Adapter => {
lazyOptions = options;
return adapter(options);
};
};

View File

@@ -24,6 +24,10 @@ describe("adapter test", async () => {
...customOptions, ...customOptions,
}); });
}, },
disableTests: {
SHOULD_ROLLBACK_FAILING_TRANSACTION: true,
SHOULD_RETURN_TRANSACTION_RESULT: true,
},
}); });
}); });

View File

@@ -4,6 +4,7 @@ import {
type AdapterDebugLogs, type AdapterDebugLogs,
type CleanedWhere, type CleanedWhere,
} from "../create-adapter"; } from "../create-adapter";
import type { BetterAuthOptions } from "../../types";
export interface MemoryDB { export interface MemoryDB {
[key: string]: any[]; [key: string]: any[];
@@ -13,8 +14,9 @@ export interface MemoryAdapterConfig {
debugLogs?: AdapterDebugLogs; debugLogs?: AdapterDebugLogs;
} }
export const memoryAdapter = (db: MemoryDB, config?: MemoryAdapterConfig) => export const memoryAdapter = (db: MemoryDB, config?: MemoryAdapterConfig) => {
createAdapter({ let lazyOptions: BetterAuthOptions | null = null;
let adapterCreator = createAdapter({
config: { config: {
adapterId: "memory", adapterId: "memory",
adapterName: "Memory Adapter", adapterName: "Memory Adapter",
@@ -30,6 +32,18 @@ export const memoryAdapter = (db: MemoryDB, config?: MemoryAdapterConfig) =>
} }
return props.data; return props.data;
}, },
transaction: async (cb) => {
let clone = structuredClone(db);
try {
return cb(adapterCreator(lazyOptions!));
} catch {
// Rollback changes
Object.keys(db).forEach((key) => {
db[key] = clone[key];
});
throw new Error("Transaction failed, rolling back changes");
}
},
}, },
adapter: ({ getFieldName, options, debugLog }) => { adapter: ({ getFieldName, options, debugLog }) => {
function convertWhereClause(where: CleanedWhere[], model: string) { function convertWhereClause(where: CleanedWhere[], model: string) {
@@ -147,3 +161,8 @@ export const memoryAdapter = (db: MemoryDB, config?: MemoryAdapterConfig) =>
}; };
}, },
}); });
return (options: BetterAuthOptions) => {
lazyOptions = options;
return adapterCreator(options);
};
};

View File

@@ -9,11 +9,14 @@ describe("adapter test", async () => {
const client = new MongoClient(connectionString); const client = new MongoClient(connectionString);
await client.connect(); await client.connect();
const db = client.db(dbName); const db = client.db(dbName);
return db; return { db, client };
}; };
const user = "user"; const user = "user";
const db = await dbClient("mongodb://127.0.0.1:27017", "better-auth"); const { db, client } = await dbClient(
"mongodb://127.0.0.1:27017",
"better-auth",
);
async function clearDb() { async function clearDb() {
await db.collection(user).deleteMany({}); await db.collection(user).deleteMany({});
await db.collection("session").deleteMany({}); await db.collection("session").deleteMany({});
@@ -23,7 +26,10 @@ describe("adapter test", async () => {
await clearDb(); await clearDb();
}); });
const adapter = mongodbAdapter(db); const adapter = mongodbAdapter(db, {
// MongoDB transactions require a replica set or a sharded cluster
// client,
});
await runAdapterTest({ await runAdapterTest({
getAdapter: async (customOptions = {}) => { getAdapter: async (customOptions = {}) => {
return adapter({ return adapter({
@@ -46,6 +52,8 @@ describe("adapter test", async () => {
}, },
disableTests: { disableTests: {
SHOULD_PREFER_GENERATE_ID_IF_PROVIDED: true, SHOULD_PREFER_GENERATE_ID_IF_PROVIDED: true,
SHOULD_RETURN_TRANSACTION_RESULT: true,
SHOULD_ROLLBACK_FAILING_TRANSACTION: true,
}, },
}); });
}); });

View File

@@ -1,8 +1,18 @@
import { ObjectId, type Db } from "mongodb"; import { ClientSession, ObjectId, type Db, type MongoClient } from "mongodb";
import type { BetterAuthOptions, Where } from "../../types"; import type { Adapter, BetterAuthOptions, Where } from "../../types";
import { createAdapter, type AdapterDebugLogs } from "../create-adapter"; import {
createAdapter,
type AdapterDebugLogs,
type CreateAdapterOptions,
type CreateCustomAdapter,
} from "../create-adapter";
export interface MongoDBAdapterConfig { export interface MongoDBAdapterConfig {
/**
* MongoDB Client used for transactions.
* If not provided, operations will be executed without a session.
*/
client?: MongoClient;
/** /**
* Enable debug logs for the adapter * Enable debug logs for the adapter
* *
@@ -15,9 +25,18 @@ export interface MongoDBAdapterConfig {
* @default false * @default false
*/ */
usePlural?: boolean; usePlural?: boolean;
/**
* Whether to execute multiple operations in a transaction.
*
* If the database doesn't support transactions,
* set this to `false` and operations will be executed sequentially.
* @default true
*/
transaction?: boolean;
} }
export const mongodbAdapter = (db: Db, config?: MongoDBAdapterConfig) => { export const mongodbAdapter = (db: Db, config?: MongoDBAdapterConfig) => {
let lazyOptions: BetterAuthOptions | null;
const getCustomIdGenerator = (options: BetterAuthOptions) => { const getCustomIdGenerator = (options: BetterAuthOptions) => {
const generator = const generator =
options.advanced?.database?.generateId || options.advanced?.generateId; options.advanced?.database?.generateId || options.advanced?.generateId;
@@ -26,70 +45,10 @@ export const mongodbAdapter = (db: Db, config?: MongoDBAdapterConfig) => {
} }
return undefined; return undefined;
}; };
return createAdapter({
config: {
adapterId: "mongodb-adapter",
adapterName: "MongoDB Adapter",
usePlural: config?.usePlural ?? false,
debugLogs: config?.debugLogs ?? false,
mapKeysTransformInput: {
id: "_id",
},
mapKeysTransformOutput: {
_id: "id",
},
supportsNumericIds: false,
customTransformInput({
action,
data,
field,
fieldAttributes,
schema,
model,
options,
}) {
const customIdGen = getCustomIdGenerator(options);
if (field === "_id" || fieldAttributes.references?.field === "id") { const createCustomAdapter =
if (customIdGen) { (db: Db, session?: ClientSession): CreateCustomAdapter =>
return data; ({ options, getFieldName, schema, getDefaultModelName }) => {
}
if (action === "update") {
return data;
}
if (Array.isArray(data)) {
return data.map((v) => new ObjectId());
}
if (typeof data === "string") {
try {
return new ObjectId(data);
} catch (error) {
return new ObjectId();
}
}
return new ObjectId();
}
return data;
},
customTransformOutput({ data, field, fieldAttributes }) {
if (field === "id" || fieldAttributes.references?.field === "id") {
if (data instanceof ObjectId) {
return data.toHexString();
}
if (Array.isArray(data)) {
return data.map((v) => {
if (v instanceof ObjectId) {
return v.toHexString();
}
return v;
});
}
return data;
}
return data;
},
},
adapter: ({ options, getFieldName, schema, getDefaultModelName }) => {
const customIdGen = getCustomIdGenerator(options); const customIdGen = getCustomIdGenerator(options);
function serializeID({ function serializeID({
@@ -238,19 +197,19 @@ export const mongodbAdapter = (db: Db, config?: MongoDBAdapterConfig) => {
return { return {
async create({ model, data: values }) { async create({ model, data: values }) {
const res = await db.collection(model).insertOne(values); const res = await db.collection(model).insertOne(values, { session });
const insertedData = { _id: res.insertedId.toString(), ...values }; const insertedData = { _id: res.insertedId.toString(), ...values };
return insertedData as any; return insertedData as any;
}, },
async findOne({ model, where, select }) { async findOne({ model, where, select }) {
const clause = convertWhereClause({ where, model }); const clause = convertWhereClause({ where, model });
const res = await db.collection(model).findOne(clause); const res = await db.collection(model).findOne(clause, { session });
if (!res) return null; if (!res) return null;
return res as any; return res as any;
}, },
async findMany({ model, where, limit, offset, sortBy }) { async findMany({ model, where, limit, offset, sortBy }) {
const clause = where ? convertWhereClause({ where, model }) : {}; const clause = where ? convertWhereClause({ where, model }) : {};
const cursor = db.collection(model).find(clause); const cursor = db.collection(model).find(clause, { session });
if (limit) cursor.limit(limit); if (limit) cursor.limit(limit);
if (offset) cursor.skip(offset); if (offset) cursor.skip(offset);
if (sortBy) if (sortBy)
@@ -262,7 +221,9 @@ export const mongodbAdapter = (db: Db, config?: MongoDBAdapterConfig) => {
return res as any; return res as any;
}, },
async count({ model }) { async count({ model }) {
const res = await db.collection(model).countDocuments(); const res = await db
.collection(model)
.countDocuments(undefined, { session });
return res; return res;
}, },
async update({ model, where, update: values }) { async update({ model, where, update: values }) {
@@ -272,6 +233,7 @@ export const mongodbAdapter = (db: Db, config?: MongoDBAdapterConfig) => {
clause, clause,
{ $set: values as any }, { $set: values as any },
{ {
session,
returnDocument: "after", returnDocument: "after",
}, },
); );
@@ -281,21 +243,129 @@ export const mongodbAdapter = (db: Db, config?: MongoDBAdapterConfig) => {
async updateMany({ model, where, update: values }) { async updateMany({ model, where, update: values }) {
const clause = convertWhereClause({ where, model }); const clause = convertWhereClause({ where, model });
const res = await db.collection(model).updateMany(clause, { const res = await db.collection(model).updateMany(
clause,
{
$set: values as any, $set: values as any,
}); },
{ session },
);
return res.modifiedCount; return res.modifiedCount;
}, },
async delete({ model, where }) { async delete({ model, where }) {
const clause = convertWhereClause({ where, model }); const clause = convertWhereClause({ where, model });
await db.collection(model).deleteOne(clause); await db.collection(model).deleteOne(clause, { session });
}, },
async deleteMany({ model, where }) { async deleteMany({ model, where }) {
const clause = convertWhereClause({ where, model }); const clause = convertWhereClause({ where, model });
const res = await db.collection(model).deleteMany(clause); const res = await db
.collection(model)
.deleteMany(clause, { session });
return res.deletedCount; return res.deletedCount;
}, },
}; };
}, };
});
let lazyAdapter: ((options: BetterAuthOptions) => Adapter) | null = null;
let adapterOptions: CreateAdapterOptions | null = null;
adapterOptions = {
config: {
adapterId: "mongodb-adapter",
adapterName: "MongoDB Adapter",
usePlural: config?.usePlural ?? false,
debugLogs: config?.debugLogs ?? false,
mapKeysTransformInput: {
id: "_id",
},
mapKeysTransformOutput: {
_id: "id",
},
supportsNumericIds: false,
transaction:
config?.client && (config?.transaction ?? true)
? async (cb) => {
if (!config.client) {
return cb(lazyAdapter!(lazyOptions!));
}
const session = config.client.startSession();
try {
session.startTransaction();
const adapter = createAdapter({
config: adapterOptions!.config,
adapter: createCustomAdapter(db, session),
})(lazyOptions!);
const result = await cb(adapter);
await session.commitTransaction();
return result;
} catch (err) {
await session.abortTransaction();
throw err;
} finally {
await session.endSession();
}
}
: false,
customTransformInput({
action,
data,
field,
fieldAttributes,
schema,
model,
options,
}) {
const customIdGen = getCustomIdGenerator(options);
if (field === "_id" || fieldAttributes.references?.field === "id") {
if (customIdGen) {
return data;
}
if (action === "update") {
return data;
}
if (Array.isArray(data)) {
return data.map((v) => new ObjectId());
}
if (typeof data === "string") {
try {
return new ObjectId(data);
} catch (error) {
return new ObjectId();
}
}
return new ObjectId();
}
return data;
},
customTransformOutput({ data, field, fieldAttributes }) {
if (field === "id" || fieldAttributes.references?.field === "id") {
if (data instanceof ObjectId) {
return data.toHexString();
}
if (Array.isArray(data)) {
return data.map((v) => {
if (v instanceof ObjectId) {
return v.toHexString();
}
return v;
});
}
return data;
}
return data;
},
},
adapter: createCustomAdapter(db),
};
lazyAdapter = createAdapter(adapterOptions);
return (options: BetterAuthOptions): Adapter => {
lazyOptions = options;
return lazyAdapter(options);
};
}; };

View File

@@ -1,6 +1,11 @@
import { BetterAuthError } from "../../error"; import { BetterAuthError } from "../../error";
import type { Where } from "../../types"; import type { Adapter, BetterAuthOptions, Where } from "../../types";
import { createAdapter, type AdapterDebugLogs } from "../create-adapter"; import {
createAdapter,
type AdapterDebugLogs,
type CreateAdapterOptions,
type CreateCustomAdapter,
} from "../create-adapter";
export interface PrismaConfig { export interface PrismaConfig {
/** /**
@@ -27,11 +32,24 @@ export interface PrismaConfig {
* @default false * @default false
*/ */
usePlural?: boolean; usePlural?: boolean;
/**
* Whether to execute multiple operations in a transaction.
*
* If the database doesn't support transactions,
* set this to `false` and operations will be executed sequentially.
* @default true
*/
transaction?: boolean;
} }
interface PrismaClient {} interface PrismaClient {}
interface PrismaClientInternal { type PrismaClientInternal = {
$transaction: (
callback: (db: PrismaClient) => Promise<any> | any,
) => Promise<any>;
} & {
[model: string]: { [model: string]: {
create: (data: any) => Promise<any>; create: (data: any) => Promise<any>;
findFirst: (data: any) => Promise<any>; findFirst: (data: any) => Promise<any>;
@@ -40,17 +58,13 @@ interface PrismaClientInternal {
delete: (data: any) => Promise<any>; delete: (data: any) => Promise<any>;
[key: string]: any; [key: string]: any;
}; };
} };
export const prismaAdapter = (prisma: PrismaClient, config: PrismaConfig) => export const prismaAdapter = (prisma: PrismaClient, config: PrismaConfig) => {
createAdapter({ let lazyOptions: BetterAuthOptions | null = null;
config: { const createCustomAdapter =
adapterId: "prisma", (prisma: PrismaClient): CreateCustomAdapter =>
adapterName: "Prisma Adapter", ({ getFieldName }) => {
usePlural: config.usePlural ?? false,
debugLogs: config.debugLogs ?? false,
},
adapter: ({ getFieldName }) => {
const db = prisma as PrismaClientInternal; const db = prisma as PrismaClientInternal;
const convertSelect = (select?: string[], model?: string) => { const convertSelect = (select?: string[], model?: string) => {
@@ -217,5 +231,33 @@ export const prismaAdapter = (prisma: PrismaClient, config: PrismaConfig) =>
}, },
options: config, options: config,
}; };
};
let adapterOptions: CreateAdapterOptions | null = null;
adapterOptions = {
config: {
adapterId: "prisma",
adapterName: "Prisma Adapter",
usePlural: config.usePlural ?? false,
debugLogs: config.debugLogs ?? false,
transaction:
(config.transaction ?? true)
? (cb) =>
(prisma as PrismaClientInternal).$transaction((tx) => {
const adapter = createAdapter({
config: adapterOptions!.config,
adapter: createCustomAdapter(tx),
})(lazyOptions!);
return cb(adapter);
})
: false,
}, },
}); adapter: createCustomAdapter(prisma),
};
const adapter = createAdapter(adapterOptions);
return (options: BetterAuthOptions): Adapter => {
lazyOptions = options;
return adapter(options);
};
};

View File

@@ -47,6 +47,8 @@ const adapterTests = {
SHOULD_SEARCH_USERS_WITH_STARTS_WITH: "should search users with startsWith", SHOULD_SEARCH_USERS_WITH_STARTS_WITH: "should search users with startsWith",
SHOULD_SEARCH_USERS_WITH_ENDS_WITH: "should search users with endsWith", SHOULD_SEARCH_USERS_WITH_ENDS_WITH: "should search users with endsWith",
SHOULD_PREFER_GENERATE_ID_IF_PROVIDED: "should prefer generateId if provided", SHOULD_PREFER_GENERATE_ID_IF_PROVIDED: "should prefer generateId if provided",
SHOULD_ROLLBACK_FAILING_TRANSACTION: "should rollback failing transaction",
SHOULD_RETURN_TRANSACTION_RESULT: "should return transaction result",
} as const; } as const;
const { ...numberIdAdapterTestsCopy } = adapterTests; const { ...numberIdAdapterTestsCopy } = adapterTests;
@@ -820,6 +822,83 @@ async function adapterTest(
expect(res.id).toBe("mocked-id"); expect(res.id).toBe("mocked-id");
}, },
); );
test.skipIf(disabledTests?.SHOULD_ROLLBACK_FAILING_TRANSACTION)(
`${testPrefix ? `${testPrefix} - ` : ""}${adapterTests.SHOULD_ROLLBACK_FAILING_TRANSACTION}`,
async ({ onTestFailed }) => {
await resetDebugLogs();
onTestFailed(async () => {
await printDebugLogs();
});
const customAdapter = await adapter();
const user5 = {
name: "user5",
email: "user5@email.com",
emailVerified: true,
createdAt: new Date(),
updatedAt: new Date(),
};
const user6 = {
name: "user6",
email: "user6@email.com",
emailVerified: true,
createdAt: new Date(),
updatedAt: new Date(),
};
await expect(
customAdapter.transaction(async (tx) => {
await tx.create({ model: "user", data: user5 });
throw new Error("Simulated failure");
await tx.create({ model: "user", data: user6 });
}),
).rejects.toThrow("Simulated failure");
await expect(
customAdapter.findMany({
model: "user",
where: [
{
field: "email",
value: user5.email,
connector: "OR",
},
{
field: "email",
value: user6.email,
connector: "OR",
},
],
}),
).resolves.toEqual([]);
},
);
test.skipIf(disabledTests?.SHOULD_RETURN_TRANSACTION_RESULT)(
`${testPrefix ? `${testPrefix} - ` : ""}${adapterTests.SHOULD_RETURN_TRANSACTION_RESULT}`,
async ({ onTestFailed }) => {
await resetDebugLogs();
onTestFailed(async () => {
await printDebugLogs();
});
const customAdapter = await adapter();
const result = await customAdapter.transaction(async (tx) => {
const createdUser = await tx.create<User>({
model: "user",
data: {
name: "user6",
email: "user6@email.com",
emailVerified: true,
createdAt: new Date(),
updatedAt: new Date(),
},
});
return createdUser.email;
});
expect(result).toEqual("user6@email.com");
},
);
} }
export async function runAdapterTest(opts: AdapterTestOptions) { export async function runAdapterTest(opts: AdapterTestOptions) {

View File

@@ -5,6 +5,7 @@ import type {
AuthContext, AuthContext,
BetterAuthOptions, BetterAuthOptions,
GenericEndpointContext, GenericEndpointContext,
TransactionAdapter,
Where, Where,
} from "../types"; } from "../types";
import { import {
@@ -41,6 +42,7 @@ export const createInternalAdapter = (
Partial<Account>, Partial<Account>,
context?: GenericEndpointContext, context?: GenericEndpointContext,
) => { ) => {
return adapter.transaction(async (trxAdapter) => {
const createdUser = await createWithHooks( const createdUser = await createWithHooks(
{ {
// todo: we should remove auto setting createdAt and updatedAt in the next major release, since the db generators already handle that // todo: we should remove auto setting createdAt and updatedAt in the next major release, since the db generators already handle that
@@ -51,6 +53,7 @@ export const createInternalAdapter = (
"user", "user",
undefined, undefined,
context, context,
trxAdapter,
); );
const createdAccount = await createWithHooks( const createdAccount = await createWithHooks(
{ {
@@ -63,17 +66,20 @@ export const createInternalAdapter = (
"account", "account",
undefined, undefined,
context, context,
trxAdapter,
); );
return { return {
user: createdUser, user: createdUser,
account: createdAccount, account: createdAccount,
}; };
});
}, },
createUser: async <T>( createUser: async <T>(
user: Omit<User, "id" | "createdAt" | "updatedAt" | "emailVerified"> & user: Omit<User, "id" | "createdAt" | "updatedAt" | "emailVerified"> &
Partial<User> & Partial<User> &
Record<string, any>, Record<string, any>,
context?: GenericEndpointContext, context?: GenericEndpointContext,
trxAdapter?: TransactionAdapter,
) => { ) => {
const createdUser = await createWithHooks( const createdUser = await createWithHooks(
{ {
@@ -86,6 +92,7 @@ export const createInternalAdapter = (
"user", "user",
undefined, undefined,
context, context,
trxAdapter,
); );
return createdUser as T & User; return createdUser as T & User;
}, },
@@ -94,6 +101,7 @@ export const createInternalAdapter = (
Partial<Account> & Partial<Account> &
T, T,
context?: GenericEndpointContext, context?: GenericEndpointContext,
trxAdapter?: TransactionAdapter,
) => { ) => {
const createdAccount = await createWithHooks( const createdAccount = await createWithHooks(
{ {
@@ -105,10 +113,11 @@ export const createInternalAdapter = (
"account", "account",
undefined, undefined,
context, context,
trxAdapter,
); );
return createdAccount as T & Account; return createdAccount as T & Account;
}, },
listSessions: async (userId: string) => { listSessions: async (userId: string, trxAdapter?: TransactionAdapter) => {
if (secondaryStorage) { if (secondaryStorage) {
const currentList = await secondaryStorage.get( const currentList = await secondaryStorage.get(
`active-sessions-${userId}`, `active-sessions-${userId}`,
@@ -140,7 +149,7 @@ export const createInternalAdapter = (
return sessions; return sessions;
} }
const sessions = await adapter.findMany<Session>({ const sessions = await (trxAdapter || adapter).findMany<Session>({
model: "session", model: "session",
where: [ where: [
{ {
@@ -159,8 +168,9 @@ export const createInternalAdapter = (
direction: "asc" | "desc"; direction: "asc" | "desc";
}, },
where?: Where[], where?: Where[],
trxAdapter?: TransactionAdapter,
) => { ) => {
const users = await adapter.findMany<User>({ const users = await (trxAdapter || adapter).findMany<User>({
model: "user", model: "user",
limit, limit,
offset, offset,
@@ -169,8 +179,11 @@ export const createInternalAdapter = (
}); });
return users; return users;
}, },
countTotalUsers: async (where?: Where[]) => { countTotalUsers: async (
const total = await adapter.count({ where?: Where[],
trxAdapter?: TransactionAdapter,
) => {
const total = await (trxAdapter || adapter).count({
model: "user", model: "user",
where, where,
}); });
@@ -179,13 +192,13 @@ export const createInternalAdapter = (
} }
return total; return total;
}, },
deleteUser: async (userId: string) => { deleteUser: async (userId: string, trxAdapter?: TransactionAdapter) => {
if (secondaryStorage) { if (secondaryStorage) {
await secondaryStorage.delete(`active-sessions-${userId}`); await secondaryStorage.delete(`active-sessions-${userId}`);
} }
if (!secondaryStorage || options.session?.storeSessionInDatabase) { if (!secondaryStorage || options.session?.storeSessionInDatabase) {
await adapter.deleteMany({ await (trxAdapter || adapter).deleteMany({
model: "session", model: "session",
where: [ where: [
{ {
@@ -196,7 +209,7 @@ export const createInternalAdapter = (
}); });
} }
await adapter.deleteMany({ await (trxAdapter || adapter).deleteMany({
model: "account", model: "account",
where: [ where: [
{ {
@@ -205,7 +218,7 @@ export const createInternalAdapter = (
}, },
], ],
}); });
await adapter.delete({ await (trxAdapter || adapter).delete({
model: "user", model: "user",
where: [ where: [
{ {
@@ -221,6 +234,7 @@ export const createInternalAdapter = (
dontRememberMe?: boolean, dontRememberMe?: boolean,
override?: Partial<Session> & Record<string, any>, override?: Partial<Session> & Record<string, any>,
overrideAll?: boolean, overrideAll?: boolean,
trxAdapter?: TransactionAdapter,
) => { ) => {
const headers = ctx.headers || ctx.request?.headers; const headers = ctx.headers || ctx.request?.headers;
const { id: _, ...rest } = override || {}; const { id: _, ...rest } = override || {};
@@ -285,11 +299,13 @@ export const createInternalAdapter = (
} }
: undefined, : undefined,
ctx, ctx,
trxAdapter,
); );
return res as Session; return res as Session;
}, },
findSession: async ( findSession: async (
token: string, token: string,
trxAdapter?: TransactionAdapter,
): Promise<{ ): Promise<{
session: Session & Record<string, any>; session: Session & Record<string, any>;
user: User & Record<string, any>; user: User & Record<string, any>;
@@ -323,7 +339,7 @@ export const createInternalAdapter = (
} }
} }
const session = await adapter.findOne<Session>({ const session = await (trxAdapter || adapter).findOne<Session>({
model: "session", model: "session",
where: [ where: [
{ {
@@ -337,7 +353,7 @@ export const createInternalAdapter = (
return null; return null;
} }
const user = await adapter.findOne<User>({ const user = await (trxAdapter || adapter).findOne<User>({
model: "user", model: "user",
where: [ where: [
{ {
@@ -357,7 +373,10 @@ export const createInternalAdapter = (
user: parsedUser, user: parsedUser,
}; };
}, },
findSessions: async (sessionTokens: string[]) => { findSessions: async (
sessionTokens: string[],
trxAdapter?: TransactionAdapter,
) => {
if (secondaryStorage) { if (secondaryStorage) {
const sessions: { const sessions: {
session: Session; session: Session;
@@ -391,7 +410,7 @@ export const createInternalAdapter = (
return sessions; return sessions;
} }
const sessions = await adapter.findMany<Session>({ const sessions = await (trxAdapter || adapter).findMany<Session>({
model: "session", model: "session",
where: [ where: [
{ {
@@ -405,7 +424,7 @@ export const createInternalAdapter = (
return session.userId; return session.userId;
}); });
if (!userIds.length) return []; if (!userIds.length) return [];
const users = await adapter.findMany<User>({ const users = await (trxAdapter || adapter).findMany<User>({
model: "user", model: "user",
where: [ where: [
{ {
@@ -431,6 +450,7 @@ export const createInternalAdapter = (
sessionToken: string, sessionToken: string,
session: Partial<Session> & Record<string, any>, session: Partial<Session> & Record<string, any>,
context?: GenericEndpointContext, context?: GenericEndpointContext,
trxAdapter?: TransactionAdapter,
) => { ) => {
const updatedSession = await updateWithHooks<Session>( const updatedSession = await updateWithHooks<Session>(
session, session,
@@ -460,10 +480,11 @@ export const createInternalAdapter = (
} }
: undefined, : undefined,
context, context,
trxAdapter,
); );
return updatedSession; return updatedSession;
}, },
deleteSession: async (token: string) => { deleteSession: async (token: string, trxAdapter?: TransactionAdapter) => {
if (secondaryStorage) { if (secondaryStorage) {
// remove the session from the active sessions list // remove the session from the active sessions list
const data = await secondaryStorage.get(token); const data = await secondaryStorage.get(token);
@@ -510,7 +531,7 @@ export const createInternalAdapter = (
return; return;
} }
} }
await adapter.delete<Session>({ await (trxAdapter || adapter).delete<Session>({
model: "session", model: "session",
where: [ where: [
{ {
@@ -520,8 +541,8 @@ export const createInternalAdapter = (
], ],
}); });
}, },
deleteAccounts: async (userId: string) => { deleteAccounts: async (userId: string, trxAdapter?: TransactionAdapter) => {
await adapter.deleteMany({ await (trxAdapter || adapter).deleteMany({
model: "account", model: "account",
where: [ where: [
{ {
@@ -531,8 +552,11 @@ export const createInternalAdapter = (
], ],
}); });
}, },
deleteAccount: async (accountId: string) => { deleteAccount: async (
await adapter.delete({ accountId: string,
trxAdapter?: TransactionAdapter,
) => {
await (trxAdapter || adapter).delete({
model: "account", model: "account",
where: [ where: [
{ {
@@ -542,7 +566,10 @@ export const createInternalAdapter = (
], ],
}); });
}, },
deleteSessions: async (userIdOrSessionTokens: string | string[]) => { deleteSessions: async (
userIdOrSessionTokens: string | string[],
trxAdapter?: TransactionAdapter,
) => {
if (secondaryStorage) { if (secondaryStorage) {
if (typeof userIdOrSessionTokens === "string") { if (typeof userIdOrSessionTokens === "string") {
const activeSession = await secondaryStorage.get( const activeSession = await secondaryStorage.get(
@@ -571,7 +598,7 @@ export const createInternalAdapter = (
return; return;
} }
} }
await adapter.deleteMany({ await (trxAdapter || adapter).deleteMany({
model: "session", model: "session",
where: [ where: [
{ {
@@ -586,9 +613,10 @@ export const createInternalAdapter = (
email: string, email: string,
accountId: string, accountId: string,
providerId: string, providerId: string,
trxAdapter?: TransactionAdapter,
) => { ) => {
// we need to find account first to avoid missing user if the email changed with the provider for the same account // we need to find account first to avoid missing user if the email changed with the provider for the same account
const account = await adapter const account = await (trxAdapter || adapter)
.findMany<Account>({ .findMany<Account>({
model: "account", model: "account",
where: [ where: [
@@ -602,7 +630,7 @@ export const createInternalAdapter = (
return accounts.find((a) => a.providerId === providerId); return accounts.find((a) => a.providerId === providerId);
}); });
if (account) { if (account) {
const user = await adapter.findOne<User>({ const user = await (trxAdapter || adapter).findOne<User>({
model: "user", model: "user",
where: [ where: [
{ {
@@ -617,7 +645,7 @@ export const createInternalAdapter = (
accounts: [account], accounts: [account],
}; };
} else { } else {
const user = await adapter.findOne<User>({ const user = await (trxAdapter || adapter).findOne<User>({
model: "user", model: "user",
where: [ where: [
{ {
@@ -635,7 +663,7 @@ export const createInternalAdapter = (
return null; return null;
} }
} else { } else {
const user = await adapter.findOne<User>({ const user = await (trxAdapter || adapter).findOne<User>({
model: "user", model: "user",
where: [ where: [
{ {
@@ -645,7 +673,7 @@ export const createInternalAdapter = (
], ],
}); });
if (user) { if (user) {
const accounts = await adapter.findMany<Account>({ const accounts = await (trxAdapter || adapter).findMany<Account>({
model: "account", model: "account",
where: [ where: [
{ {
@@ -666,8 +694,9 @@ export const createInternalAdapter = (
findUserByEmail: async ( findUserByEmail: async (
email: string, email: string,
options?: { includeAccounts: boolean }, options?: { includeAccounts: boolean },
trxAdapter?: TransactionAdapter,
) => { ) => {
const user = await adapter.findOne<User>({ const user = await (trxAdapter || adapter).findOne<User>({
model: "user", model: "user",
where: [ where: [
{ {
@@ -678,7 +707,7 @@ export const createInternalAdapter = (
}); });
if (!user) return null; if (!user) return null;
if (options?.includeAccounts) { if (options?.includeAccounts) {
const accounts = await adapter.findMany<Account>({ const accounts = await (trxAdapter || adapter).findMany<Account>({
model: "account", model: "account",
where: [ where: [
{ {
@@ -697,8 +726,8 @@ export const createInternalAdapter = (
accounts: [], accounts: [],
}; };
}, },
findUserById: async (userId: string) => { findUserById: async (userId: string, trxAdapter?: TransactionAdapter) => {
const user = await adapter.findOne<User>({ const user = await (trxAdapter || adapter).findOne<User>({
model: "user", model: "user",
where: [ where: [
{ {
@@ -713,6 +742,7 @@ export const createInternalAdapter = (
account: Omit<Account, "id" | "createdAt" | "updatedAt"> & account: Omit<Account, "id" | "createdAt" | "updatedAt"> &
Partial<Account>, Partial<Account>,
context?: GenericEndpointContext, context?: GenericEndpointContext,
trxAdapter?: TransactionAdapter,
) => { ) => {
const _account = await createWithHooks( const _account = await createWithHooks(
{ {
@@ -724,6 +754,7 @@ export const createInternalAdapter = (
"account", "account",
undefined, undefined,
context, context,
trxAdapter,
); );
return _account; return _account;
}, },
@@ -731,6 +762,7 @@ export const createInternalAdapter = (
userId: string, userId: string,
data: Partial<User> & Record<string, any>, data: Partial<User> & Record<string, any>,
context?: GenericEndpointContext, context?: GenericEndpointContext,
trxAdapter?: TransactionAdapter,
) => { ) => {
const user = await updateWithHooks<User>( const user = await updateWithHooks<User>(
data, data,
@@ -743,6 +775,7 @@ export const createInternalAdapter = (
"user", "user",
undefined, undefined,
context, context,
trxAdapter,
); );
if (secondaryStorage && user) { if (secondaryStorage && user) {
const listRaw = await secondaryStorage.get(`active-sessions-${userId}`); const listRaw = await secondaryStorage.get(`active-sessions-${userId}`);
@@ -785,6 +818,7 @@ export const createInternalAdapter = (
email: string, email: string,
data: Partial<User & Record<string, any>>, data: Partial<User & Record<string, any>>,
context?: GenericEndpointContext, context?: GenericEndpointContext,
trxAdapter?: TransactionAdapter,
) => { ) => {
const user = await updateWithHooks<User>( const user = await updateWithHooks<User>(
data, data,
@@ -797,6 +831,7 @@ export const createInternalAdapter = (
"user", "user",
undefined, undefined,
context, context,
trxAdapter,
); );
return user; return user;
}, },
@@ -804,6 +839,7 @@ export const createInternalAdapter = (
userId: string, userId: string,
password: string, password: string,
context?: GenericEndpointContext, context?: GenericEndpointContext,
trxAdapter?: TransactionAdapter,
) => { ) => {
await updateManyWithHooks( await updateManyWithHooks(
{ {
@@ -822,10 +858,11 @@ export const createInternalAdapter = (
"account", "account",
undefined, undefined,
context, context,
trxAdapter,
); );
}, },
findAccounts: async (userId: string) => { findAccounts: async (userId: string, trxAdapter?: TransactionAdapter) => {
const accounts = await adapter.findMany<Account>({ const accounts = await (trxAdapter || adapter).findMany<Account>({
model: "account", model: "account",
where: [ where: [
{ {
@@ -836,8 +873,8 @@ export const createInternalAdapter = (
}); });
return accounts; return accounts;
}, },
findAccount: async (accountId: string) => { findAccount: async (accountId: string, trxAdapter?: TransactionAdapter) => {
const account = await adapter.findOne<Account>({ const account = await (trxAdapter || adapter).findOne<Account>({
model: "account", model: "account",
where: [ where: [
{ {
@@ -848,8 +885,12 @@ export const createInternalAdapter = (
}); });
return account; return account;
}, },
findAccountByProviderId: async (accountId: string, providerId: string) => { findAccountByProviderId: async (
const account = await adapter.findOne<Account>({ accountId: string,
providerId: string,
trxAdapter?: TransactionAdapter,
) => {
const account = await (trxAdapter || adapter).findOne<Account>({
model: "account", model: "account",
where: [ where: [
{ {
@@ -864,8 +905,11 @@ export const createInternalAdapter = (
}); });
return account; return account;
}, },
findAccountByUserId: async (userId: string) => { findAccountByUserId: async (
const account = await adapter.findMany<Account>({ userId: string,
trxAdapter?: TransactionAdapter,
) => {
const account = await (trxAdapter || adapter).findMany<Account>({
model: "account", model: "account",
where: [ where: [
{ {
@@ -880,6 +924,7 @@ export const createInternalAdapter = (
id: string, id: string,
data: Partial<Account>, data: Partial<Account>,
context?: GenericEndpointContext, context?: GenericEndpointContext,
trxAdapter?: TransactionAdapter,
) => { ) => {
const account = await updateWithHooks<Account>( const account = await updateWithHooks<Account>(
data, data,
@@ -887,6 +932,7 @@ export const createInternalAdapter = (
"account", "account",
undefined, undefined,
context, context,
trxAdapter,
); );
return account; return account;
}, },
@@ -894,6 +940,7 @@ export const createInternalAdapter = (
data: Omit<Verification, "createdAt" | "id" | "updatedAt"> & data: Omit<Verification, "createdAt" | "id" | "updatedAt"> &
Partial<Verification>, Partial<Verification>,
context?: GenericEndpointContext, context?: GenericEndpointContext,
trxAdapter?: TransactionAdapter,
) => { ) => {
const verification = await createWithHooks( const verification = await createWithHooks(
{ {
@@ -905,11 +952,16 @@ export const createInternalAdapter = (
"verification", "verification",
undefined, undefined,
context, context,
trxAdapter,
); );
return verification as Verification; return verification as Verification;
}, },
findVerificationValue: async (identifier: string) => { findVerificationValue: async (
const verification = await adapter.findMany<Verification>({ identifier: string,
trxAdapter?: TransactionAdapter,
) => {
const verification = await (trxAdapter || adapter).findMany<Verification>(
{
model: "verification", model: "verification",
where: [ where: [
{ {
@@ -922,9 +974,10 @@ export const createInternalAdapter = (
direction: "desc", direction: "desc",
}, },
limit: 1, limit: 1,
}); },
);
if (!options.verification?.disableCleanup) { if (!options.verification?.disableCleanup) {
await adapter.deleteMany({ await (trxAdapter || adapter).deleteMany({
model: "verification", model: "verification",
where: [ where: [
{ {
@@ -938,8 +991,11 @@ export const createInternalAdapter = (
const lastVerification = verification[0]; const lastVerification = verification[0];
return lastVerification as Verification | null; return lastVerification as Verification | null;
}, },
deleteVerificationValue: async (id: string) => { deleteVerificationValue: async (
await adapter.delete<Verification>({ id: string,
trxAdapter?: TransactionAdapter,
) => {
await (trxAdapter || adapter).delete<Verification>({
model: "verification", model: "verification",
where: [ where: [
{ {
@@ -949,8 +1005,11 @@ export const createInternalAdapter = (
], ],
}); });
}, },
deleteVerificationByIdentifier: async (identifier: string) => { deleteVerificationByIdentifier: async (
await adapter.delete<Verification>({ identifier: string,
trxAdapter?: TransactionAdapter,
) => {
await (trxAdapter || adapter).delete<Verification>({
model: "verification", model: "verification",
where: [ where: [
{ {
@@ -964,6 +1023,7 @@ export const createInternalAdapter = (
id: string, id: string,
data: Partial<Verification>, data: Partial<Verification>,
context?: GenericEndpointContext, context?: GenericEndpointContext,
trxAdapter?: TransactionAdapter,
) => { ) => {
const verification = await updateWithHooks<Verification>( const verification = await updateWithHooks<Verification>(
data, data,
@@ -971,6 +1031,7 @@ export const createInternalAdapter = (
"verification", "verification",
undefined, undefined,
context, context,
trxAdapter,
); );
return verification; return verification;
}, },

View File

@@ -3,6 +3,7 @@ import type {
BetterAuthOptions, BetterAuthOptions,
GenericEndpointContext, GenericEndpointContext,
Models, Models,
TransactionAdapter,
Where, Where,
} from "../types"; } from "../types";
@@ -26,6 +27,7 @@ export function getWithHooks(
executeMainFn?: boolean; executeMainFn?: boolean;
}, },
context?: GenericEndpointContext, context?: GenericEndpointContext,
trxAdapter?: TransactionAdapter,
) { ) {
let actualData = data; let actualData = data;
for (const hook of hooks || []) { for (const hook of hooks || []) {
@@ -50,7 +52,7 @@ export function getWithHooks(
: null; : null;
const created = const created =
!customCreateFn || customCreateFn.executeMainFn !customCreateFn || customCreateFn.executeMainFn
? await adapter.create<T>({ ? await (trxAdapter || adapter).create<T>({
model, model,
data: actualData as any, data: actualData as any,
forceAllowId: true, forceAllowId: true,
@@ -76,6 +78,7 @@ export function getWithHooks(
executeMainFn?: boolean; executeMainFn?: boolean;
}, },
context?: GenericEndpointContext, context?: GenericEndpointContext,
trxAdapter?: TransactionAdapter,
) { ) {
let actualData = data; let actualData = data;
@@ -97,7 +100,7 @@ export function getWithHooks(
const updated = const updated =
!customUpdateFn || customUpdateFn.executeMainFn !customUpdateFn || customUpdateFn.executeMainFn
? await adapter.update<T>({ ? await (trxAdapter || adapter).update<T>({
model, model,
update: actualData, update: actualData,
where, where,
@@ -122,6 +125,7 @@ export function getWithHooks(
executeMainFn?: boolean; executeMainFn?: boolean;
}, },
context?: GenericEndpointContext, context?: GenericEndpointContext,
trxAdapter?: TransactionAdapter,
) { ) {
let actualData = data; let actualData = data;
@@ -143,7 +147,7 @@ export function getWithHooks(
const updated = const updated =
!customUpdateFn || customUpdateFn.executeMainFn !customUpdateFn || customUpdateFn.executeMainFn
? await adapter.updateMany({ ? await (trxAdapter || adapter).updateMany({
model, model,
update: actualData, update: actualData,
where, where,

View File

@@ -70,6 +70,13 @@ export type Adapter = {
}) => Promise<number>; }) => Promise<number>;
delete: <T>(data: { model: string; where: Where[] }) => Promise<void>; delete: <T>(data: { model: string; where: Where[] }) => Promise<void>;
deleteMany: (data: { model: string; where: Where[] }) => Promise<number>; deleteMany: (data: { model: string; where: Where[] }) => Promise<number>;
/**
* Execute multiple operations in a transaction.
* If the adapter doesn't support transactions, operations will be executed sequentially.
*/
transaction: <R>(
callback: (tx: Omit<Adapter, "transaction">) => Promise<R>,
) => Promise<R>;
/** /**
* *
* @param options * @param options
@@ -84,6 +91,8 @@ export type Adapter = {
} & CustomAdapter["options"]; } & CustomAdapter["options"];
}; };
export type TransactionAdapter = Omit<Adapter, "transaction">;
export type AdapterSchemaCreation = { export type AdapterSchemaCreation = {
/** /**
* Code to be inserted into the file * Code to be inserted into the file