feat: shared context between upgrade hook and peer (#111)

This commit is contained in:
Pooya Parsa
2025-01-22 01:56:32 +02:00
committed by GitHub
parent d17eee6ce0
commit 35842c6ecd
10 changed files with 67 additions and 22 deletions

View File

@@ -19,6 +19,7 @@ type ContextData = {
peer?: BunPeer;
request: Request;
server?: Server;
context: Peer["context"];
};
// --- adapter ---
@@ -31,7 +32,8 @@ export default defineWebSocketAdapter<BunAdapter, BunOptions>(
return {
...adapterUtils(peers),
async handleUpgrade(request, server) {
const { upgradeHeaders, endResponse } = await hooks.upgrade(request);
const { upgradeHeaders, endResponse, context } =
await hooks.upgrade(request);
if (endResponse) {
return endResponse;
}
@@ -39,6 +41,7 @@ export default defineWebSocketAdapter<BunAdapter, BunOptions>(
data: {
server,
request,
context,
} satisfies ContextData,
headers: upgradeHeaders,
});
@@ -93,6 +96,10 @@ class BunPeer extends Peer<{
return this._internal.ws.remoteAddress;
}
get context() {
return this._internal.ws.data.context;
}
send(data: unknown, options?: { compress?: boolean }) {
return this._internal.ws.send(toBufferLike(data), options?.compress);
}

View File

@@ -32,8 +32,8 @@ export default defineWebSocketAdapter<CloudflareAdapter, CloudflareOptions>(
const peers = new Set<CloudflarePeer>();
return {
...adapterUtils(peers),
handleUpgrade: async (request, env, context) => {
const { upgradeHeaders, endResponse } = await hooks.upgrade(
handleUpgrade: async (request, env, cfCtx) => {
const { upgradeHeaders, endResponse, context } = await hooks.upgrade(
request as unknown as Request,
);
if (endResponse) {
@@ -49,7 +49,8 @@ export default defineWebSocketAdapter<CloudflareAdapter, CloudflareOptions>(
wsServer: server,
request: request as unknown as Request,
cfEnv: env,
cfCtx: context,
cfCtx: cfCtx,
context,
});
peers.add(peer);
server.accept();
@@ -89,6 +90,7 @@ class CloudflarePeer extends Peer<{
wsServer: _cf.WebSocket;
cfEnv: unknown;
cfCtx: _cf.ExecutionContext;
context: Peer["context"];
}> {
send(data: unknown) {
this._internal.wsServer.send(toBufferLike(data));

View File

@@ -31,7 +31,8 @@ export default defineWebSocketAdapter<DenoAdapter, DenoOptions>(
return {
...adapterUtils(peers),
handleUpgrade: async (request, info) => {
const { upgradeHeaders, endResponse } = await hooks.upgrade(request);
const { upgradeHeaders, endResponse, context } =
await hooks.upgrade(request);
if (endResponse) {
return endResponse;
}
@@ -45,6 +46,7 @@ export default defineWebSocketAdapter<DenoAdapter, DenoOptions>(
request,
peers,
denoInfo: info,
context,
});
peers.add(peer);
upgrade.socket.addEventListener("open", () => {
@@ -74,6 +76,7 @@ class DenoPeer extends Peer<{
request: Request;
peers: Set<DenoPeer>;
denoInfo: ServeHandlerInfo;
context: Peer["context"];
}> {
get remoteAddress() {
return this._internal.denoInfo.remoteAddr?.hostname;

View File

@@ -7,7 +7,7 @@ import { Message } from "../message.ts";
import { WSError } from "../error.ts";
import { Peer } from "../peer.ts";
import type { ClientRequest, IncomingMessage } from "node:http";
import type { IncomingMessage } from "node:http";
import type { Duplex } from "node:stream";
import { WebSocketServer as _WebSocketServer } from "ws";
import type {
@@ -21,6 +21,7 @@ import type {
type AugmentedReq = IncomingMessage & {
_request: NodeReqProxy;
_upgradeHeaders?: HeadersInit;
_context: Peer["context"];
};
export interface NodeAdapter extends AdapterInstance {
@@ -87,13 +88,15 @@ export default defineWebSocketAdapter<NodeAdapter, NodeOptions>(
handleUpgrade: async (nodeReq, socket, head) => {
const request = new NodeReqProxy(nodeReq);
const { upgradeHeaders, endResponse } = await hooks.upgrade(request);
const { upgradeHeaders, endResponse, context } =
await hooks.upgrade(request);
if (endResponse) {
return sendResponse(socket, endResponse);
}
(nodeReq as AugmentedReq)._request = request;
(nodeReq as AugmentedReq)._upgradeHeaders = upgradeHeaders;
(nodeReq as AugmentedReq)._context = context;
wss.handleUpgrade(nodeReq, socket, head, (ws) => {
wss.emit("connection", ws, nodeReq);
});
@@ -119,6 +122,10 @@ class NodePeer extends Peer<{
return this._internal.nodeReq.socket?.remoteAddress;
}
get context() {
return (this._internal.nodeReq as AugmentedReq)._context;
}
send(data: unknown, options?: { compress?: boolean }) {
const dataBuff = toBufferLike(data);
const isBinary = typeof data !== "string";

View File

@@ -27,7 +27,8 @@ export default defineWebSocketAdapter<SSEAdapter, SSEOptions>((opts = {}) => {
return {
...adapterUtils(peers),
fetch: async (request: Request) => {
const { upgradeHeaders, endResponse } = await hooks.upgrade(request);
const { upgradeHeaders, endResponse, context } =
await hooks.upgrade(request);
if (endResponse) {
return endResponse;
}
@@ -60,6 +61,7 @@ export default defineWebSocketAdapter<SSEAdapter, SSEOptions>((opts = {}) => {
request,
hooks,
ws,
context,
});
peers.add(peer);
if (opts.bidir) {
@@ -98,6 +100,7 @@ class SSEPeer extends Peer<{
request: Request;
ws: SSEWebSocketStub;
hooks: AdapterHookable;
context: Peer["context"];
}> {
_sseStream: ReadableStream; // server -> client
_sseStreamController?: ReadableStreamDefaultController;

View File

@@ -15,6 +15,7 @@ type UserData = {
res: uws.HttpResponse;
protocol: string;
extensions: string;
context: Peer["context"];
};
type WebSocketHandler = uws.WebSocketBehavior<UserData>;
@@ -70,13 +71,13 @@ export default defineWebSocketAdapter<UWSAdapter, UWSOptions>(
peers.add(peer);
hooks.callHook("open", peer);
},
async upgrade(res, req, context) {
async upgrade(res, req, uwsContext) {
let aborted = false;
res.onAborted(() => {
aborted = true;
});
const { upgradeHeaders, endResponse } = await hooks.upgrade(
const { upgradeHeaders, endResponse, context } = await hooks.upgrade(
new UWSReqProxy(req),
);
if (endResponse) {
@@ -119,11 +120,12 @@ export default defineWebSocketAdapter<UWSAdapter, UWSOptions>(
res,
protocol,
extensions,
context,
},
key,
protocol,
extensions,
context,
uwsContext,
);
});
},
@@ -168,6 +170,10 @@ class UWSPeer extends Peer<{
}
}
get context() {
return this._internal.uwsData.context;
}
send(data: unknown, options?: { compress?: boolean }) {
const dataBuff = toBufferLike(data);
const isBinary = typeof data !== "string";

View File

@@ -40,30 +40,38 @@ export class AdapterHookable {
) as Promise<any>;
}
async upgrade(request: UpgradeRequest): Promise<{
async upgrade(
request: UpgradeRequest & { context?: Peer["context"] },
): Promise<{
upgradeHeaders?: HeadersInit;
endResponse?: Response;
context: Peer["context"];
}> {
const context = (request.context ??= {});
try {
const res = await this.callHook("upgrade", request);
const res = await this.callHook(
"upgrade",
request as UpgradeRequest & { context: Peer["context"] },
);
if (!res) {
return {};
return { context };
}
if ((res as Response).ok === false) {
return { endResponse: res as Response };
return { context, endResponse: res as Response };
}
if (res.headers) {
return {
context,
upgradeHeaders: res.headers,
};
}
} catch (error) {
if (error instanceof Response) {
return { endResponse: error };
return { context, endResponse: error };
}
throw error;
}
return {};
return { context };
}
}
@@ -96,7 +104,7 @@ export interface Hooks {
* @throws {Response}
*/
upgrade: (
request: UpgradeRequest,
request: UpgradeRequest & { context: Peer["context"] },
) => MaybePromise<Response | ResponseInit | undefined>;
/** A message is received */

View File

@@ -5,6 +5,7 @@ export interface AdapterInternal {
ws: unknown;
request?: Request | Partial<Request>;
peers?: Set<Peer>;
context?: Peer["context"];
}
export abstract class Peer<Internal extends AdapterInternal = AdapterInternal> {
@@ -14,14 +15,15 @@ export abstract class Peer<Internal extends AdapterInternal = AdapterInternal> {
#ws?: Partial<web.WebSocket>;
readonly context: Record<string, unknown>;
constructor(internal: Internal) {
this._topics = new Set();
this.context = {};
this._internal = internal;
}
get context(): Record<string, unknown> {
return (this._internal.context ??= {});
}
/**
* Unique random [uuid v4](https://developer.mozilla.org/en-US/docs/Glossary/UUID) identifier for the peer.
*/

View File

@@ -28,6 +28,7 @@ export function createDemo<T extends Adapter<any, any>>(
peer.send({
id: peer.id,
remoteAddress: peer.remoteAddress,
context: peer.context,
request: {
url: peer.request?.url,
headers: Object.fromEntries(peer.request?.headers || []),
@@ -63,6 +64,7 @@ export function createDemo<T extends Adapter<any, any>>(
headers: { "x-error": "unauthorized" },
});
}
req.context.test = "1";
return {
headers: {
"x-powered-by": "cross-ws",

View File

@@ -69,7 +69,7 @@ export function wsTests(getURL: () => string, opts: WSTestOpts) {
headers: { "x-test": "1" },
});
await ws.send("debug");
const { request, remoteAddress } = await ws.next();
const { request, remoteAddress, context } = await ws.next();
// Headers
if (opts.adapter === "sse") {
@@ -88,6 +88,11 @@ export function wsTests(getURL: () => string, opts: WSTestOpts) {
if (!/sse|cloudflare/.test(opts.adapter)) {
expect(remoteAddress).toMatch(/:{2}1|(?:0{4}:){7}0{3}1|127\.0\.\0\.1/);
}
// Context
if (opts.adapter !== "cloudflare-durable") {
expect(context.test).toBe("1");
}
});
test("peer.websocket", async () => {