feat: database transaction support (#4414)

Co-authored-by: Joel Solano <solano.joel@gmx.de>
This commit is contained in:
Alex Yang
2025-09-09 17:10:27 -07:00
committed by GitHub
parent 61b6a87435
commit 22c39b92be
16 changed files with 647 additions and 228 deletions

View File

@@ -20,7 +20,7 @@ const client = new MongoClient("mongodb://localhost:27017/database");
const db = client.db();
export const auth = betterAuth({
database: mongodbAdapter(db),
database: mongodbAdapter(db, { client }),
});
```

View File

@@ -452,6 +452,16 @@ const adapter = myAdapter({
});
```
### `transaction`
Whether the adapter supports transactions. If `false`, operations run sequentially; otherwise provide a function that executes a callback with a `TransactionAdapter`.
<Callout type="warn">
If your database does not support transactions, the error handling and rollback
will not be as robust. We recommend using a database that supports transactions
for better data integrity.
</Callout>
### `debugLogs`
Used to enable debug logs for the adapter. You can pass in a boolean, or an object with the following keys: `create`, `update`, `updateMany`, `findOne`, `findMany`, `delete`, `deleteMany`, `count`.

View File

@@ -20,11 +20,13 @@ exports[`init > should match config 1`] = `
"supportsDates": false,
"supportsJSON": false,
"supportsNumericIds": true,
"transaction": [Function],
"usePlural": undefined,
},
"debugLogs": false,
"type": "sqlite",
},
"transaction": [Function],
"update": [Function],
"updateMany": [Function],
},

View File

@@ -1,13 +1,17 @@
import { safeJSONParse } from "../../utils/json";
import { withApplyDefault } from "../../adapters/utils";
import { getAuthTables } from "../../db/get-tables";
import type { Adapter, BetterAuthOptions, Where } from "../../types";
import type {
Adapter,
BetterAuthOptions,
TransactionAdapter,
Where,
} from "../../types";
import { generateId as defaultGenerateId, logger } from "../../utils";
import type {
AdapterConfig,
CreateAdapterOptions,
AdapterTestDebugLogs,
CleanedWhere,
CreateCustomAdapter,
} from "./types";
import type { FieldAttribute } from "../../db";
export * from "./types";
@@ -45,14 +49,13 @@ const colors = {
},
};
const createAsIsTransaction =
(adapter: Adapter) =>
<R>(fn: (trx: TransactionAdapter) => Promise<R>) =>
fn(adapter);
export const createAdapter =
({
adapter,
config: cfg,
}: {
config: AdapterConfig;
adapter: CreateCustomAdapter;
}) =>
({ adapter: customAdapter, config: cfg }: CreateAdapterOptions) =>
(options: BetterAuthOptions): Adapter => {
const config = {
...cfg,
@@ -315,7 +318,7 @@ export const createAdapter =
return fields[defaultFieldName];
};
const adapterInstance = adapter({
const adapterInstance = customAdapter({
options,
schema,
debugLog,
@@ -560,7 +563,24 @@ export const createAdapter =
}) as any;
};
return {
let lazyLoadTransaction: Adapter["transaction"] | null = null;
const adapter: Adapter = {
transaction: async (cb) => {
if (!lazyLoadTransaction) {
if (!config.transaction) {
logger.warn(
`[${config.adapterName}] - Transactions are not supported. Executing operations sequentially.`,
);
lazyLoadTransaction = createAsIsTransaction(adapter);
} else {
logger.debug(
`[${config.adapterName}] - Using provided transaction implementation.`,
);
lazyLoadTransaction = config.transaction;
}
}
return lazyLoadTransaction(cb);
},
create: async <T extends Record<string, any>, R = T>({
data: unsafeData,
model: unsafeModel,
@@ -1016,6 +1036,7 @@ export const createAdapter =
}
: {}),
};
return adapter;
};
function formatTransactionId(transactionId: number) {

View File

@@ -3,6 +3,7 @@ import type { BetterAuthDbSchema } from "../../db/get-tables";
import type {
AdapterSchemaCreation,
BetterAuthOptions,
TransactionAdapter,
Where,
} from "../../types";
import type { Prettify } from "../../types/helper";
@@ -32,6 +33,11 @@ export type AdapterDebugLogs =
isRunningAdapterTests: boolean;
};
export type CreateAdapterOptions = {
config: AdapterConfig;
adapter: CreateCustomAdapter;
};
export interface AdapterConfig {
/**
* Use plural table names.
@@ -89,6 +95,15 @@ export interface AdapterConfig {
* @default true
*/
supportsBooleans?: boolean;
/**
* Execute multiple operations in a transaction.
*
* If the database doesn't support transactions, set this to `false` and operations will be executed sequentially.
*
*/
transaction?:
| false
| (<R>(callback: (trx: TransactionAdapter) => Promise<R>) => Promise<R>);
/**
* Disable id generation for the `create` method.
*

View File

@@ -17,8 +17,13 @@ import {
SQL,
} from "drizzle-orm";
import { BetterAuthError } from "../../error";
import type { Where } from "../../types";
import { createAdapter, type AdapterDebugLogs } from "../create-adapter";
import type { Adapter, BetterAuthOptions, Where } from "../../types";
import {
createAdapter,
type AdapterDebugLogs,
type CreateAdapterOptions,
type CreateCustomAdapter,
} from "../create-adapter";
export interface DB {
[key: string]: any;
@@ -52,17 +57,21 @@ export interface DrizzleAdapterConfig {
* @default false
*/
camelCase?: boolean;
/**
* Whether to execute multiple operations in a transaction.
*
* If the database doesn't support transactions,
* set this to `false` and operations will be executed sequentially.
* @default true
*/
transaction?: boolean;
}
export const drizzleAdapter = (db: DB, config: DrizzleAdapterConfig) =>
createAdapter({
config: {
adapterId: "drizzle",
adapterName: "Drizzle Adapter",
usePlural: config.usePlural ?? false,
debugLogs: config.debugLogs ?? false,
},
adapter: ({ getFieldName, debugLog }) => {
export const drizzleAdapter = (db: DB, config: DrizzleAdapterConfig) => {
let lazyOptions: BetterAuthOptions | null = null;
const createCustomAdapter =
(db: DB): CreateCustomAdapter =>
({ getFieldName, debugLog }) => {
function getSchema(model: string) {
const schema = config.schema || db._.fullSchema;
if (!schema) {
@@ -343,5 +352,31 @@ export const drizzleAdapter = (db: DB, config: DrizzleAdapterConfig) =>
},
options: config,
};
};
let adapterOptions: CreateAdapterOptions | null = null;
adapterOptions = {
config: {
adapterId: "drizzle",
adapterName: "Drizzle Adapter",
usePlural: config.usePlural ?? false,
debugLogs: config.debugLogs ?? false,
transaction:
(config.transaction ?? true)
? (cb) =>
db.transaction((tx: DB) => {
const adapter = createAdapter({
config: adapterOptions!.config,
adapter: createCustomAdapter(tx),
})(lazyOptions!);
return cb(adapter);
})
: false,
},
});
adapter: createCustomAdapter(db),
};
const adapter = createAdapter(adapterOptions);
return (options: BetterAuthOptions): Adapter => {
lazyOptions = options;
return adapter(options);
};
};

View File

@@ -1,5 +1,10 @@
import { createAdapter, type AdapterDebugLogs } from "../create-adapter";
import type { Where } from "../../types";
import {
createAdapter,
type AdapterDebugLogs,
type CreateCustomAdapter,
type CreateAdapterOptions,
} from "../create-adapter";
import type { Adapter, BetterAuthOptions, Where } from "../../types";
import type { KyselyDatabaseType } from "./types";
import type { InsertQueryBuilder, Kysely, UpdateQueryBuilder } from "kysely";
@@ -20,26 +25,23 @@ interface KyselyAdapterConfig {
* @default false
*/
usePlural?: boolean;
/**
* Whether to execute multiple operations in a transaction.
*
* If the database doesn't support transactions,
* set this to `false` and operations will be executed sequentially.
* @default true
*/
transaction?: boolean;
}
export const kyselyAdapter = (db: Kysely<any>, config?: KyselyAdapterConfig) =>
createAdapter({
config: {
adapterId: "kysely",
adapterName: "Kysely Adapter",
usePlural: config?.usePlural,
debugLogs: config?.debugLogs,
supportsBooleans:
config?.type === "sqlite" || config?.type === "mssql" || !config?.type
? false
: true,
supportsDates:
config?.type === "sqlite" || config?.type === "mssql" || !config?.type
? false
: true,
supportsJSON: false,
},
adapter: ({ getFieldName, schema }) => {
export const kyselyAdapter = (
db: Kysely<any>,
config?: KyselyAdapterConfig,
) => {
let lazyOptions: BetterAuthOptions | null = null;
const createCustomAdapter = (db: Kysely<any>): CreateCustomAdapter => {
return ({ getFieldName, schema }) => {
const withReturning = async (
values: Record<string, any>,
builder:
@@ -315,5 +317,43 @@ export const kyselyAdapter = (db: Kysely<any>, config?: KyselyAdapterConfig) =>
},
options: config,
};
};
};
let adapterOptions: CreateAdapterOptions | null = null;
adapterOptions = {
config: {
adapterId: "kysely",
adapterName: "Kysely Adapter",
usePlural: config?.usePlural,
debugLogs: config?.debugLogs,
supportsBooleans:
config?.type === "sqlite" || config?.type === "mssql" || !config?.type
? false
: true,
supportsDates:
config?.type === "sqlite" || config?.type === "mssql" || !config?.type
? false
: true,
supportsJSON: false,
transaction:
(config?.transaction ?? true)
? (cb) =>
db.transaction().execute((trx) => {
const adapter = createAdapter({
config: adapterOptions!.config,
adapter: createCustomAdapter(trx),
})(lazyOptions!);
return cb(adapter);
})
: false,
},
});
adapter: createCustomAdapter(db),
};
const adapter = createAdapter(adapterOptions);
return (options: BetterAuthOptions): Adapter => {
lazyOptions = options;
return adapter(options);
};
};

View File

@@ -24,6 +24,10 @@ describe("adapter test", async () => {
...customOptions,
});
},
disableTests: {
SHOULD_ROLLBACK_FAILING_TRANSACTION: true,
SHOULD_RETURN_TRANSACTION_RESULT: true,
},
});
});

View File

@@ -4,6 +4,7 @@ import {
type AdapterDebugLogs,
type CleanedWhere,
} from "../create-adapter";
import type { BetterAuthOptions } from "../../types";
export interface MemoryDB {
[key: string]: any[];
@@ -13,8 +14,9 @@ export interface MemoryAdapterConfig {
debugLogs?: AdapterDebugLogs;
}
export const memoryAdapter = (db: MemoryDB, config?: MemoryAdapterConfig) =>
createAdapter({
export const memoryAdapter = (db: MemoryDB, config?: MemoryAdapterConfig) => {
let lazyOptions: BetterAuthOptions | null = null;
let adapterCreator = createAdapter({
config: {
adapterId: "memory",
adapterName: "Memory Adapter",
@@ -30,6 +32,18 @@ export const memoryAdapter = (db: MemoryDB, config?: MemoryAdapterConfig) =>
}
return props.data;
},
transaction: async (cb) => {
let clone = structuredClone(db);
try {
return cb(adapterCreator(lazyOptions!));
} catch {
// Rollback changes
Object.keys(db).forEach((key) => {
db[key] = clone[key];
});
throw new Error("Transaction failed, rolling back changes");
}
},
},
adapter: ({ getFieldName, options, debugLog }) => {
function convertWhereClause(where: CleanedWhere[], model: string) {
@@ -147,3 +161,8 @@ export const memoryAdapter = (db: MemoryDB, config?: MemoryAdapterConfig) =>
};
},
});
return (options: BetterAuthOptions) => {
lazyOptions = options;
return adapterCreator(options);
};
};

View File

@@ -9,11 +9,14 @@ describe("adapter test", async () => {
const client = new MongoClient(connectionString);
await client.connect();
const db = client.db(dbName);
return db;
return { db, client };
};
const user = "user";
const db = await dbClient("mongodb://127.0.0.1:27017", "better-auth");
const { db, client } = await dbClient(
"mongodb://127.0.0.1:27017",
"better-auth",
);
async function clearDb() {
await db.collection(user).deleteMany({});
await db.collection("session").deleteMany({});
@@ -23,7 +26,10 @@ describe("adapter test", async () => {
await clearDb();
});
const adapter = mongodbAdapter(db);
const adapter = mongodbAdapter(db, {
// MongoDB transactions require a replica set or a sharded cluster
// client,
});
await runAdapterTest({
getAdapter: async (customOptions = {}) => {
return adapter({
@@ -46,6 +52,8 @@ describe("adapter test", async () => {
},
disableTests: {
SHOULD_PREFER_GENERATE_ID_IF_PROVIDED: true,
SHOULD_RETURN_TRANSACTION_RESULT: true,
SHOULD_ROLLBACK_FAILING_TRANSACTION: true,
},
});
});

View File

@@ -1,8 +1,18 @@
import { ObjectId, type Db } from "mongodb";
import type { BetterAuthOptions, Where } from "../../types";
import { createAdapter, type AdapterDebugLogs } from "../create-adapter";
import { ClientSession, ObjectId, type Db, type MongoClient } from "mongodb";
import type { Adapter, BetterAuthOptions, Where } from "../../types";
import {
createAdapter,
type AdapterDebugLogs,
type CreateAdapterOptions,
type CreateCustomAdapter,
} from "../create-adapter";
export interface MongoDBAdapterConfig {
/**
* MongoDB Client used for transactions.
* If not provided, operations will be executed without a session.
*/
client?: MongoClient;
/**
* Enable debug logs for the adapter
*
@@ -15,9 +25,18 @@ export interface MongoDBAdapterConfig {
* @default false
*/
usePlural?: boolean;
/**
* Whether to execute multiple operations in a transaction.
*
* If the database doesn't support transactions,
* set this to `false` and operations will be executed sequentially.
* @default true
*/
transaction?: boolean;
}
export const mongodbAdapter = (db: Db, config?: MongoDBAdapterConfig) => {
let lazyOptions: BetterAuthOptions | null;
const getCustomIdGenerator = (options: BetterAuthOptions) => {
const generator =
options.advanced?.database?.generateId || options.advanced?.generateId;
@@ -26,70 +45,10 @@ export const mongodbAdapter = (db: Db, config?: MongoDBAdapterConfig) => {
}
return undefined;
};
return createAdapter({
config: {
adapterId: "mongodb-adapter",
adapterName: "MongoDB Adapter",
usePlural: config?.usePlural ?? false,
debugLogs: config?.debugLogs ?? false,
mapKeysTransformInput: {
id: "_id",
},
mapKeysTransformOutput: {
_id: "id",
},
supportsNumericIds: false,
customTransformInput({
action,
data,
field,
fieldAttributes,
schema,
model,
options,
}) {
const customIdGen = getCustomIdGenerator(options);
if (field === "_id" || fieldAttributes.references?.field === "id") {
if (customIdGen) {
return data;
}
if (action === "update") {
return data;
}
if (Array.isArray(data)) {
return data.map((v) => new ObjectId());
}
if (typeof data === "string") {
try {
return new ObjectId(data);
} catch (error) {
return new ObjectId();
}
}
return new ObjectId();
}
return data;
},
customTransformOutput({ data, field, fieldAttributes }) {
if (field === "id" || fieldAttributes.references?.field === "id") {
if (data instanceof ObjectId) {
return data.toHexString();
}
if (Array.isArray(data)) {
return data.map((v) => {
if (v instanceof ObjectId) {
return v.toHexString();
}
return v;
});
}
return data;
}
return data;
},
},
adapter: ({ options, getFieldName, schema, getDefaultModelName }) => {
const createCustomAdapter =
(db: Db, session?: ClientSession): CreateCustomAdapter =>
({ options, getFieldName, schema, getDefaultModelName }) => {
const customIdGen = getCustomIdGenerator(options);
function serializeID({
@@ -238,19 +197,19 @@ export const mongodbAdapter = (db: Db, config?: MongoDBAdapterConfig) => {
return {
async create({ model, data: values }) {
const res = await db.collection(model).insertOne(values);
const res = await db.collection(model).insertOne(values, { session });
const insertedData = { _id: res.insertedId.toString(), ...values };
return insertedData as any;
},
async findOne({ model, where, select }) {
const clause = convertWhereClause({ where, model });
const res = await db.collection(model).findOne(clause);
const res = await db.collection(model).findOne(clause, { session });
if (!res) return null;
return res as any;
},
async findMany({ model, where, limit, offset, sortBy }) {
const clause = where ? convertWhereClause({ where, model }) : {};
const cursor = db.collection(model).find(clause);
const cursor = db.collection(model).find(clause, { session });
if (limit) cursor.limit(limit);
if (offset) cursor.skip(offset);
if (sortBy)
@@ -262,7 +221,9 @@ export const mongodbAdapter = (db: Db, config?: MongoDBAdapterConfig) => {
return res as any;
},
async count({ model }) {
const res = await db.collection(model).countDocuments();
const res = await db
.collection(model)
.countDocuments(undefined, { session });
return res;
},
async update({ model, where, update: values }) {
@@ -272,6 +233,7 @@ export const mongodbAdapter = (db: Db, config?: MongoDBAdapterConfig) => {
clause,
{ $set: values as any },
{
session,
returnDocument: "after",
},
);
@@ -281,21 +243,129 @@ export const mongodbAdapter = (db: Db, config?: MongoDBAdapterConfig) => {
async updateMany({ model, where, update: values }) {
const clause = convertWhereClause({ where, model });
const res = await db.collection(model).updateMany(clause, {
$set: values as any,
});
const res = await db.collection(model).updateMany(
clause,
{
$set: values as any,
},
{ session },
);
return res.modifiedCount;
},
async delete({ model, where }) {
const clause = convertWhereClause({ where, model });
await db.collection(model).deleteOne(clause);
await db.collection(model).deleteOne(clause, { session });
},
async deleteMany({ model, where }) {
const clause = convertWhereClause({ where, model });
const res = await db.collection(model).deleteMany(clause);
const res = await db
.collection(model)
.deleteMany(clause, { session });
return res.deletedCount;
},
};
};
let lazyAdapter: ((options: BetterAuthOptions) => Adapter) | null = null;
let adapterOptions: CreateAdapterOptions | null = null;
adapterOptions = {
config: {
adapterId: "mongodb-adapter",
adapterName: "MongoDB Adapter",
usePlural: config?.usePlural ?? false,
debugLogs: config?.debugLogs ?? false,
mapKeysTransformInput: {
id: "_id",
},
mapKeysTransformOutput: {
_id: "id",
},
supportsNumericIds: false,
transaction:
config?.client && (config?.transaction ?? true)
? async (cb) => {
if (!config.client) {
return cb(lazyAdapter!(lazyOptions!));
}
const session = config.client.startSession();
try {
session.startTransaction();
const adapter = createAdapter({
config: adapterOptions!.config,
adapter: createCustomAdapter(db, session),
})(lazyOptions!);
const result = await cb(adapter);
await session.commitTransaction();
return result;
} catch (err) {
await session.abortTransaction();
throw err;
} finally {
await session.endSession();
}
}
: false,
customTransformInput({
action,
data,
field,
fieldAttributes,
schema,
model,
options,
}) {
const customIdGen = getCustomIdGenerator(options);
if (field === "_id" || fieldAttributes.references?.field === "id") {
if (customIdGen) {
return data;
}
if (action === "update") {
return data;
}
if (Array.isArray(data)) {
return data.map((v) => new ObjectId());
}
if (typeof data === "string") {
try {
return new ObjectId(data);
} catch (error) {
return new ObjectId();
}
}
return new ObjectId();
}
return data;
},
customTransformOutput({ data, field, fieldAttributes }) {
if (field === "id" || fieldAttributes.references?.field === "id") {
if (data instanceof ObjectId) {
return data.toHexString();
}
if (Array.isArray(data)) {
return data.map((v) => {
if (v instanceof ObjectId) {
return v.toHexString();
}
return v;
});
}
return data;
}
return data;
},
},
});
adapter: createCustomAdapter(db),
};
lazyAdapter = createAdapter(adapterOptions);
return (options: BetterAuthOptions): Adapter => {
lazyOptions = options;
return lazyAdapter(options);
};
};

View File

@@ -1,6 +1,11 @@
import { BetterAuthError } from "../../error";
import type { Where } from "../../types";
import { createAdapter, type AdapterDebugLogs } from "../create-adapter";
import type { Adapter, BetterAuthOptions, Where } from "../../types";
import {
createAdapter,
type AdapterDebugLogs,
type CreateAdapterOptions,
type CreateCustomAdapter,
} from "../create-adapter";
export interface PrismaConfig {
/**
@@ -27,11 +32,24 @@ export interface PrismaConfig {
* @default false
*/
usePlural?: boolean;
/**
* Whether to execute multiple operations in a transaction.
*
* If the database doesn't support transactions,
* set this to `false` and operations will be executed sequentially.
* @default true
*/
transaction?: boolean;
}
interface PrismaClient {}
interface PrismaClientInternal {
type PrismaClientInternal = {
$transaction: (
callback: (db: PrismaClient) => Promise<any> | any,
) => Promise<any>;
} & {
[model: string]: {
create: (data: any) => Promise<any>;
findFirst: (data: any) => Promise<any>;
@@ -40,17 +58,13 @@ interface PrismaClientInternal {
delete: (data: any) => Promise<any>;
[key: string]: any;
};
}
};
export const prismaAdapter = (prisma: PrismaClient, config: PrismaConfig) =>
createAdapter({
config: {
adapterId: "prisma",
adapterName: "Prisma Adapter",
usePlural: config.usePlural ?? false,
debugLogs: config.debugLogs ?? false,
},
adapter: ({ getFieldName }) => {
export const prismaAdapter = (prisma: PrismaClient, config: PrismaConfig) => {
let lazyOptions: BetterAuthOptions | null = null;
const createCustomAdapter =
(prisma: PrismaClient): CreateCustomAdapter =>
({ getFieldName }) => {
const db = prisma as PrismaClientInternal;
const convertSelect = (select?: string[], model?: string) => {
@@ -217,5 +231,33 @@ export const prismaAdapter = (prisma: PrismaClient, config: PrismaConfig) =>
},
options: config,
};
};
let adapterOptions: CreateAdapterOptions | null = null;
adapterOptions = {
config: {
adapterId: "prisma",
adapterName: "Prisma Adapter",
usePlural: config.usePlural ?? false,
debugLogs: config.debugLogs ?? false,
transaction:
(config.transaction ?? true)
? (cb) =>
(prisma as PrismaClientInternal).$transaction((tx) => {
const adapter = createAdapter({
config: adapterOptions!.config,
adapter: createCustomAdapter(tx),
})(lazyOptions!);
return cb(adapter);
})
: false,
},
});
adapter: createCustomAdapter(prisma),
};
const adapter = createAdapter(adapterOptions);
return (options: BetterAuthOptions): Adapter => {
lazyOptions = options;
return adapter(options);
};
};

View File

@@ -47,6 +47,8 @@ const adapterTests = {
SHOULD_SEARCH_USERS_WITH_STARTS_WITH: "should search users with startsWith",
SHOULD_SEARCH_USERS_WITH_ENDS_WITH: "should search users with endsWith",
SHOULD_PREFER_GENERATE_ID_IF_PROVIDED: "should prefer generateId if provided",
SHOULD_ROLLBACK_FAILING_TRANSACTION: "should rollback failing transaction",
SHOULD_RETURN_TRANSACTION_RESULT: "should return transaction result",
} as const;
const { ...numberIdAdapterTestsCopy } = adapterTests;
@@ -820,6 +822,83 @@ async function adapterTest(
expect(res.id).toBe("mocked-id");
},
);
test.skipIf(disabledTests?.SHOULD_ROLLBACK_FAILING_TRANSACTION)(
`${testPrefix ? `${testPrefix} - ` : ""}${adapterTests.SHOULD_ROLLBACK_FAILING_TRANSACTION}`,
async ({ onTestFailed }) => {
await resetDebugLogs();
onTestFailed(async () => {
await printDebugLogs();
});
const customAdapter = await adapter();
const user5 = {
name: "user5",
email: "user5@email.com",
emailVerified: true,
createdAt: new Date(),
updatedAt: new Date(),
};
const user6 = {
name: "user6",
email: "user6@email.com",
emailVerified: true,
createdAt: new Date(),
updatedAt: new Date(),
};
await expect(
customAdapter.transaction(async (tx) => {
await tx.create({ model: "user", data: user5 });
throw new Error("Simulated failure");
await tx.create({ model: "user", data: user6 });
}),
).rejects.toThrow("Simulated failure");
await expect(
customAdapter.findMany({
model: "user",
where: [
{
field: "email",
value: user5.email,
connector: "OR",
},
{
field: "email",
value: user6.email,
connector: "OR",
},
],
}),
).resolves.toEqual([]);
},
);
test.skipIf(disabledTests?.SHOULD_RETURN_TRANSACTION_RESULT)(
`${testPrefix ? `${testPrefix} - ` : ""}${adapterTests.SHOULD_RETURN_TRANSACTION_RESULT}`,
async ({ onTestFailed }) => {
await resetDebugLogs();
onTestFailed(async () => {
await printDebugLogs();
});
const customAdapter = await adapter();
const result = await customAdapter.transaction(async (tx) => {
const createdUser = await tx.create<User>({
model: "user",
data: {
name: "user6",
email: "user6@email.com",
emailVerified: true,
createdAt: new Date(),
updatedAt: new Date(),
},
});
return createdUser.email;
});
expect(result).toEqual("user6@email.com");
},
);
}
export async function runAdapterTest(opts: AdapterTestOptions) {

View File

@@ -5,6 +5,7 @@ import type {
AuthContext,
BetterAuthOptions,
GenericEndpointContext,
TransactionAdapter,
Where,
} from "../types";
import {
@@ -41,39 +42,44 @@ export const createInternalAdapter = (
Partial<Account>,
context?: GenericEndpointContext,
) => {
const createdUser = await createWithHooks(
{
// todo: we should remove auto setting createdAt and updatedAt in the next major release, since the db generators already handle that
createdAt: new Date(),
updatedAt: new Date(),
...user,
},
"user",
undefined,
context,
);
const createdAccount = await createWithHooks(
{
...account,
userId: createdUser!.id,
// todo: we should remove auto setting createdAt and updatedAt in the next major release, since the db generators already handle that
createdAt: new Date(),
updatedAt: new Date(),
},
"account",
undefined,
context,
);
return {
user: createdUser,
account: createdAccount,
};
return adapter.transaction(async (trxAdapter) => {
const createdUser = await createWithHooks(
{
// todo: we should remove auto setting createdAt and updatedAt in the next major release, since the db generators already handle that
createdAt: new Date(),
updatedAt: new Date(),
...user,
},
"user",
undefined,
context,
trxAdapter,
);
const createdAccount = await createWithHooks(
{
...account,
userId: createdUser!.id,
// todo: we should remove auto setting createdAt and updatedAt in the next major release, since the db generators already handle that
createdAt: new Date(),
updatedAt: new Date(),
},
"account",
undefined,
context,
trxAdapter,
);
return {
user: createdUser,
account: createdAccount,
};
});
},
createUser: async <T>(
user: Omit<User, "id" | "createdAt" | "updatedAt" | "emailVerified"> &
Partial<User> &
Record<string, any>,
context?: GenericEndpointContext,
trxAdapter?: TransactionAdapter,
) => {
const createdUser = await createWithHooks(
{
@@ -86,6 +92,7 @@ export const createInternalAdapter = (
"user",
undefined,
context,
trxAdapter,
);
return createdUser as T & User;
},
@@ -94,6 +101,7 @@ export const createInternalAdapter = (
Partial<Account> &
T,
context?: GenericEndpointContext,
trxAdapter?: TransactionAdapter,
) => {
const createdAccount = await createWithHooks(
{
@@ -105,10 +113,11 @@ export const createInternalAdapter = (
"account",
undefined,
context,
trxAdapter,
);
return createdAccount as T & Account;
},
listSessions: async (userId: string) => {
listSessions: async (userId: string, trxAdapter?: TransactionAdapter) => {
if (secondaryStorage) {
const currentList = await secondaryStorage.get(
`active-sessions-${userId}`,
@@ -140,7 +149,7 @@ export const createInternalAdapter = (
return sessions;
}
const sessions = await adapter.findMany<Session>({
const sessions = await (trxAdapter || adapter).findMany<Session>({
model: "session",
where: [
{
@@ -159,8 +168,9 @@ export const createInternalAdapter = (
direction: "asc" | "desc";
},
where?: Where[],
trxAdapter?: TransactionAdapter,
) => {
const users = await adapter.findMany<User>({
const users = await (trxAdapter || adapter).findMany<User>({
model: "user",
limit,
offset,
@@ -169,8 +179,11 @@ export const createInternalAdapter = (
});
return users;
},
countTotalUsers: async (where?: Where[]) => {
const total = await adapter.count({
countTotalUsers: async (
where?: Where[],
trxAdapter?: TransactionAdapter,
) => {
const total = await (trxAdapter || adapter).count({
model: "user",
where,
});
@@ -179,13 +192,13 @@ export const createInternalAdapter = (
}
return total;
},
deleteUser: async (userId: string) => {
deleteUser: async (userId: string, trxAdapter?: TransactionAdapter) => {
if (secondaryStorage) {
await secondaryStorage.delete(`active-sessions-${userId}`);
}
if (!secondaryStorage || options.session?.storeSessionInDatabase) {
await adapter.deleteMany({
await (trxAdapter || adapter).deleteMany({
model: "session",
where: [
{
@@ -196,7 +209,7 @@ export const createInternalAdapter = (
});
}
await adapter.deleteMany({
await (trxAdapter || adapter).deleteMany({
model: "account",
where: [
{
@@ -205,7 +218,7 @@ export const createInternalAdapter = (
},
],
});
await adapter.delete({
await (trxAdapter || adapter).delete({
model: "user",
where: [
{
@@ -221,6 +234,7 @@ export const createInternalAdapter = (
dontRememberMe?: boolean,
override?: Partial<Session> & Record<string, any>,
overrideAll?: boolean,
trxAdapter?: TransactionAdapter,
) => {
const headers = ctx.headers || ctx.request?.headers;
const { id: _, ...rest } = override || {};
@@ -285,11 +299,13 @@ export const createInternalAdapter = (
}
: undefined,
ctx,
trxAdapter,
);
return res as Session;
},
findSession: async (
token: string,
trxAdapter?: TransactionAdapter,
): Promise<{
session: Session & Record<string, any>;
user: User & Record<string, any>;
@@ -323,7 +339,7 @@ export const createInternalAdapter = (
}
}
const session = await adapter.findOne<Session>({
const session = await (trxAdapter || adapter).findOne<Session>({
model: "session",
where: [
{
@@ -337,7 +353,7 @@ export const createInternalAdapter = (
return null;
}
const user = await adapter.findOne<User>({
const user = await (trxAdapter || adapter).findOne<User>({
model: "user",
where: [
{
@@ -357,7 +373,10 @@ export const createInternalAdapter = (
user: parsedUser,
};
},
findSessions: async (sessionTokens: string[]) => {
findSessions: async (
sessionTokens: string[],
trxAdapter?: TransactionAdapter,
) => {
if (secondaryStorage) {
const sessions: {
session: Session;
@@ -391,7 +410,7 @@ export const createInternalAdapter = (
return sessions;
}
const sessions = await adapter.findMany<Session>({
const sessions = await (trxAdapter || adapter).findMany<Session>({
model: "session",
where: [
{
@@ -405,7 +424,7 @@ export const createInternalAdapter = (
return session.userId;
});
if (!userIds.length) return [];
const users = await adapter.findMany<User>({
const users = await (trxAdapter || adapter).findMany<User>({
model: "user",
where: [
{
@@ -431,6 +450,7 @@ export const createInternalAdapter = (
sessionToken: string,
session: Partial<Session> & Record<string, any>,
context?: GenericEndpointContext,
trxAdapter?: TransactionAdapter,
) => {
const updatedSession = await updateWithHooks<Session>(
session,
@@ -460,10 +480,11 @@ export const createInternalAdapter = (
}
: undefined,
context,
trxAdapter,
);
return updatedSession;
},
deleteSession: async (token: string) => {
deleteSession: async (token: string, trxAdapter?: TransactionAdapter) => {
if (secondaryStorage) {
// remove the session from the active sessions list
const data = await secondaryStorage.get(token);
@@ -510,7 +531,7 @@ export const createInternalAdapter = (
return;
}
}
await adapter.delete<Session>({
await (trxAdapter || adapter).delete<Session>({
model: "session",
where: [
{
@@ -520,8 +541,8 @@ export const createInternalAdapter = (
],
});
},
deleteAccounts: async (userId: string) => {
await adapter.deleteMany({
deleteAccounts: async (userId: string, trxAdapter?: TransactionAdapter) => {
await (trxAdapter || adapter).deleteMany({
model: "account",
where: [
{
@@ -531,8 +552,11 @@ export const createInternalAdapter = (
],
});
},
deleteAccount: async (accountId: string) => {
await adapter.delete({
deleteAccount: async (
accountId: string,
trxAdapter?: TransactionAdapter,
) => {
await (trxAdapter || adapter).delete({
model: "account",
where: [
{
@@ -542,7 +566,10 @@ export const createInternalAdapter = (
],
});
},
deleteSessions: async (userIdOrSessionTokens: string | string[]) => {
deleteSessions: async (
userIdOrSessionTokens: string | string[],
trxAdapter?: TransactionAdapter,
) => {
if (secondaryStorage) {
if (typeof userIdOrSessionTokens === "string") {
const activeSession = await secondaryStorage.get(
@@ -571,7 +598,7 @@ export const createInternalAdapter = (
return;
}
}
await adapter.deleteMany({
await (trxAdapter || adapter).deleteMany({
model: "session",
where: [
{
@@ -586,9 +613,10 @@ export const createInternalAdapter = (
email: string,
accountId: string,
providerId: string,
trxAdapter?: TransactionAdapter,
) => {
// we need to find account first to avoid missing user if the email changed with the provider for the same account
const account = await adapter
const account = await (trxAdapter || adapter)
.findMany<Account>({
model: "account",
where: [
@@ -602,7 +630,7 @@ export const createInternalAdapter = (
return accounts.find((a) => a.providerId === providerId);
});
if (account) {
const user = await adapter.findOne<User>({
const user = await (trxAdapter || adapter).findOne<User>({
model: "user",
where: [
{
@@ -617,7 +645,7 @@ export const createInternalAdapter = (
accounts: [account],
};
} else {
const user = await adapter.findOne<User>({
const user = await (trxAdapter || adapter).findOne<User>({
model: "user",
where: [
{
@@ -635,7 +663,7 @@ export const createInternalAdapter = (
return null;
}
} else {
const user = await adapter.findOne<User>({
const user = await (trxAdapter || adapter).findOne<User>({
model: "user",
where: [
{
@@ -645,7 +673,7 @@ export const createInternalAdapter = (
],
});
if (user) {
const accounts = await adapter.findMany<Account>({
const accounts = await (trxAdapter || adapter).findMany<Account>({
model: "account",
where: [
{
@@ -666,8 +694,9 @@ export const createInternalAdapter = (
findUserByEmail: async (
email: string,
options?: { includeAccounts: boolean },
trxAdapter?: TransactionAdapter,
) => {
const user = await adapter.findOne<User>({
const user = await (trxAdapter || adapter).findOne<User>({
model: "user",
where: [
{
@@ -678,7 +707,7 @@ export const createInternalAdapter = (
});
if (!user) return null;
if (options?.includeAccounts) {
const accounts = await adapter.findMany<Account>({
const accounts = await (trxAdapter || adapter).findMany<Account>({
model: "account",
where: [
{
@@ -697,8 +726,8 @@ export const createInternalAdapter = (
accounts: [],
};
},
findUserById: async (userId: string) => {
const user = await adapter.findOne<User>({
findUserById: async (userId: string, trxAdapter?: TransactionAdapter) => {
const user = await (trxAdapter || adapter).findOne<User>({
model: "user",
where: [
{
@@ -713,6 +742,7 @@ export const createInternalAdapter = (
account: Omit<Account, "id" | "createdAt" | "updatedAt"> &
Partial<Account>,
context?: GenericEndpointContext,
trxAdapter?: TransactionAdapter,
) => {
const _account = await createWithHooks(
{
@@ -724,6 +754,7 @@ export const createInternalAdapter = (
"account",
undefined,
context,
trxAdapter,
);
return _account;
},
@@ -731,6 +762,7 @@ export const createInternalAdapter = (
userId: string,
data: Partial<User> & Record<string, any>,
context?: GenericEndpointContext,
trxAdapter?: TransactionAdapter,
) => {
const user = await updateWithHooks<User>(
data,
@@ -743,6 +775,7 @@ export const createInternalAdapter = (
"user",
undefined,
context,
trxAdapter,
);
if (secondaryStorage && user) {
const listRaw = await secondaryStorage.get(`active-sessions-${userId}`);
@@ -785,6 +818,7 @@ export const createInternalAdapter = (
email: string,
data: Partial<User & Record<string, any>>,
context?: GenericEndpointContext,
trxAdapter?: TransactionAdapter,
) => {
const user = await updateWithHooks<User>(
data,
@@ -797,6 +831,7 @@ export const createInternalAdapter = (
"user",
undefined,
context,
trxAdapter,
);
return user;
},
@@ -804,6 +839,7 @@ export const createInternalAdapter = (
userId: string,
password: string,
context?: GenericEndpointContext,
trxAdapter?: TransactionAdapter,
) => {
await updateManyWithHooks(
{
@@ -822,10 +858,11 @@ export const createInternalAdapter = (
"account",
undefined,
context,
trxAdapter,
);
},
findAccounts: async (userId: string) => {
const accounts = await adapter.findMany<Account>({
findAccounts: async (userId: string, trxAdapter?: TransactionAdapter) => {
const accounts = await (trxAdapter || adapter).findMany<Account>({
model: "account",
where: [
{
@@ -836,8 +873,8 @@ export const createInternalAdapter = (
});
return accounts;
},
findAccount: async (accountId: string) => {
const account = await adapter.findOne<Account>({
findAccount: async (accountId: string, trxAdapter?: TransactionAdapter) => {
const account = await (trxAdapter || adapter).findOne<Account>({
model: "account",
where: [
{
@@ -848,8 +885,12 @@ export const createInternalAdapter = (
});
return account;
},
findAccountByProviderId: async (accountId: string, providerId: string) => {
const account = await adapter.findOne<Account>({
findAccountByProviderId: async (
accountId: string,
providerId: string,
trxAdapter?: TransactionAdapter,
) => {
const account = await (trxAdapter || adapter).findOne<Account>({
model: "account",
where: [
{
@@ -864,8 +905,11 @@ export const createInternalAdapter = (
});
return account;
},
findAccountByUserId: async (userId: string) => {
const account = await adapter.findMany<Account>({
findAccountByUserId: async (
userId: string,
trxAdapter?: TransactionAdapter,
) => {
const account = await (trxAdapter || adapter).findMany<Account>({
model: "account",
where: [
{
@@ -880,6 +924,7 @@ export const createInternalAdapter = (
id: string,
data: Partial<Account>,
context?: GenericEndpointContext,
trxAdapter?: TransactionAdapter,
) => {
const account = await updateWithHooks<Account>(
data,
@@ -887,6 +932,7 @@ export const createInternalAdapter = (
"account",
undefined,
context,
trxAdapter,
);
return account;
},
@@ -894,6 +940,7 @@ export const createInternalAdapter = (
data: Omit<Verification, "createdAt" | "id" | "updatedAt"> &
Partial<Verification>,
context?: GenericEndpointContext,
trxAdapter?: TransactionAdapter,
) => {
const verification = await createWithHooks(
{
@@ -905,26 +952,32 @@ export const createInternalAdapter = (
"verification",
undefined,
context,
trxAdapter,
);
return verification as Verification;
},
findVerificationValue: async (identifier: string) => {
const verification = await adapter.findMany<Verification>({
model: "verification",
where: [
{
field: "identifier",
value: identifier,
findVerificationValue: async (
identifier: string,
trxAdapter?: TransactionAdapter,
) => {
const verification = await (trxAdapter || adapter).findMany<Verification>(
{
model: "verification",
where: [
{
field: "identifier",
value: identifier,
},
],
sortBy: {
field: "createdAt",
direction: "desc",
},
],
sortBy: {
field: "createdAt",
direction: "desc",
limit: 1,
},
limit: 1,
});
);
if (!options.verification?.disableCleanup) {
await adapter.deleteMany({
await (trxAdapter || adapter).deleteMany({
model: "verification",
where: [
{
@@ -938,8 +991,11 @@ export const createInternalAdapter = (
const lastVerification = verification[0];
return lastVerification as Verification | null;
},
deleteVerificationValue: async (id: string) => {
await adapter.delete<Verification>({
deleteVerificationValue: async (
id: string,
trxAdapter?: TransactionAdapter,
) => {
await (trxAdapter || adapter).delete<Verification>({
model: "verification",
where: [
{
@@ -949,8 +1005,11 @@ export const createInternalAdapter = (
],
});
},
deleteVerificationByIdentifier: async (identifier: string) => {
await adapter.delete<Verification>({
deleteVerificationByIdentifier: async (
identifier: string,
trxAdapter?: TransactionAdapter,
) => {
await (trxAdapter || adapter).delete<Verification>({
model: "verification",
where: [
{
@@ -964,6 +1023,7 @@ export const createInternalAdapter = (
id: string,
data: Partial<Verification>,
context?: GenericEndpointContext,
trxAdapter?: TransactionAdapter,
) => {
const verification = await updateWithHooks<Verification>(
data,
@@ -971,6 +1031,7 @@ export const createInternalAdapter = (
"verification",
undefined,
context,
trxAdapter,
);
return verification;
},

View File

@@ -3,6 +3,7 @@ import type {
BetterAuthOptions,
GenericEndpointContext,
Models,
TransactionAdapter,
Where,
} from "../types";
@@ -26,6 +27,7 @@ export function getWithHooks(
executeMainFn?: boolean;
},
context?: GenericEndpointContext,
trxAdapter?: TransactionAdapter,
) {
let actualData = data;
for (const hook of hooks || []) {
@@ -50,7 +52,7 @@ export function getWithHooks(
: null;
const created =
!customCreateFn || customCreateFn.executeMainFn
? await adapter.create<T>({
? await (trxAdapter || adapter).create<T>({
model,
data: actualData as any,
forceAllowId: true,
@@ -76,6 +78,7 @@ export function getWithHooks(
executeMainFn?: boolean;
},
context?: GenericEndpointContext,
trxAdapter?: TransactionAdapter,
) {
let actualData = data;
@@ -97,7 +100,7 @@ export function getWithHooks(
const updated =
!customUpdateFn || customUpdateFn.executeMainFn
? await adapter.update<T>({
? await (trxAdapter || adapter).update<T>({
model,
update: actualData,
where,
@@ -122,6 +125,7 @@ export function getWithHooks(
executeMainFn?: boolean;
},
context?: GenericEndpointContext,
trxAdapter?: TransactionAdapter,
) {
let actualData = data;
@@ -143,7 +147,7 @@ export function getWithHooks(
const updated =
!customUpdateFn || customUpdateFn.executeMainFn
? await adapter.updateMany({
? await (trxAdapter || adapter).updateMany({
model,
update: actualData,
where,

View File

@@ -70,6 +70,13 @@ export type Adapter = {
}) => Promise<number>;
delete: <T>(data: { model: string; where: Where[] }) => Promise<void>;
deleteMany: (data: { model: string; where: Where[] }) => Promise<number>;
/**
* Execute multiple operations in a transaction.
* If the adapter doesn't support transactions, operations will be executed sequentially.
*/
transaction: <R>(
callback: (tx: Omit<Adapter, "transaction">) => Promise<R>,
) => Promise<R>;
/**
*
* @param options
@@ -84,6 +91,8 @@ export type Adapter = {
} & CustomAdapter["options"];
};
export type TransactionAdapter = Omit<Adapter, "transaction">;
export type AdapterSchemaCreation = {
/**
* Code to be inserted into the file