feat: shared context between upgrade hook and peer (#111)

This commit is contained in:
Pooya Parsa
2025-01-22 01:56:32 +02:00
committed by GitHub
parent d17eee6ce0
commit 35842c6ecd
10 changed files with 67 additions and 22 deletions

View File

@@ -19,6 +19,7 @@ type ContextData = {
peer?: BunPeer; peer?: BunPeer;
request: Request; request: Request;
server?: Server; server?: Server;
context: Peer["context"];
}; };
// --- adapter --- // --- adapter ---
@@ -31,7 +32,8 @@ export default defineWebSocketAdapter<BunAdapter, BunOptions>(
return { return {
...adapterUtils(peers), ...adapterUtils(peers),
async handleUpgrade(request, server) { async handleUpgrade(request, server) {
const { upgradeHeaders, endResponse } = await hooks.upgrade(request); const { upgradeHeaders, endResponse, context } =
await hooks.upgrade(request);
if (endResponse) { if (endResponse) {
return endResponse; return endResponse;
} }
@@ -39,6 +41,7 @@ export default defineWebSocketAdapter<BunAdapter, BunOptions>(
data: { data: {
server, server,
request, request,
context,
} satisfies ContextData, } satisfies ContextData,
headers: upgradeHeaders, headers: upgradeHeaders,
}); });
@@ -93,6 +96,10 @@ class BunPeer extends Peer<{
return this._internal.ws.remoteAddress; return this._internal.ws.remoteAddress;
} }
get context() {
return this._internal.ws.data.context;
}
send(data: unknown, options?: { compress?: boolean }) { send(data: unknown, options?: { compress?: boolean }) {
return this._internal.ws.send(toBufferLike(data), options?.compress); return this._internal.ws.send(toBufferLike(data), options?.compress);
} }

View File

@@ -32,8 +32,8 @@ export default defineWebSocketAdapter<CloudflareAdapter, CloudflareOptions>(
const peers = new Set<CloudflarePeer>(); const peers = new Set<CloudflarePeer>();
return { return {
...adapterUtils(peers), ...adapterUtils(peers),
handleUpgrade: async (request, env, context) => { handleUpgrade: async (request, env, cfCtx) => {
const { upgradeHeaders, endResponse } = await hooks.upgrade( const { upgradeHeaders, endResponse, context } = await hooks.upgrade(
request as unknown as Request, request as unknown as Request,
); );
if (endResponse) { if (endResponse) {
@@ -49,7 +49,8 @@ export default defineWebSocketAdapter<CloudflareAdapter, CloudflareOptions>(
wsServer: server, wsServer: server,
request: request as unknown as Request, request: request as unknown as Request,
cfEnv: env, cfEnv: env,
cfCtx: context, cfCtx: cfCtx,
context,
}); });
peers.add(peer); peers.add(peer);
server.accept(); server.accept();
@@ -89,6 +90,7 @@ class CloudflarePeer extends Peer<{
wsServer: _cf.WebSocket; wsServer: _cf.WebSocket;
cfEnv: unknown; cfEnv: unknown;
cfCtx: _cf.ExecutionContext; cfCtx: _cf.ExecutionContext;
context: Peer["context"];
}> { }> {
send(data: unknown) { send(data: unknown) {
this._internal.wsServer.send(toBufferLike(data)); this._internal.wsServer.send(toBufferLike(data));

View File

@@ -31,7 +31,8 @@ export default defineWebSocketAdapter<DenoAdapter, DenoOptions>(
return { return {
...adapterUtils(peers), ...adapterUtils(peers),
handleUpgrade: async (request, info) => { handleUpgrade: async (request, info) => {
const { upgradeHeaders, endResponse } = await hooks.upgrade(request); const { upgradeHeaders, endResponse, context } =
await hooks.upgrade(request);
if (endResponse) { if (endResponse) {
return endResponse; return endResponse;
} }
@@ -45,6 +46,7 @@ export default defineWebSocketAdapter<DenoAdapter, DenoOptions>(
request, request,
peers, peers,
denoInfo: info, denoInfo: info,
context,
}); });
peers.add(peer); peers.add(peer);
upgrade.socket.addEventListener("open", () => { upgrade.socket.addEventListener("open", () => {
@@ -74,6 +76,7 @@ class DenoPeer extends Peer<{
request: Request; request: Request;
peers: Set<DenoPeer>; peers: Set<DenoPeer>;
denoInfo: ServeHandlerInfo; denoInfo: ServeHandlerInfo;
context: Peer["context"];
}> { }> {
get remoteAddress() { get remoteAddress() {
return this._internal.denoInfo.remoteAddr?.hostname; return this._internal.denoInfo.remoteAddr?.hostname;

View File

@@ -7,7 +7,7 @@ import { Message } from "../message.ts";
import { WSError } from "../error.ts"; import { WSError } from "../error.ts";
import { Peer } from "../peer.ts"; import { Peer } from "../peer.ts";
import type { ClientRequest, IncomingMessage } from "node:http"; import type { IncomingMessage } from "node:http";
import type { Duplex } from "node:stream"; import type { Duplex } from "node:stream";
import { WebSocketServer as _WebSocketServer } from "ws"; import { WebSocketServer as _WebSocketServer } from "ws";
import type { import type {
@@ -21,6 +21,7 @@ import type {
type AugmentedReq = IncomingMessage & { type AugmentedReq = IncomingMessage & {
_request: NodeReqProxy; _request: NodeReqProxy;
_upgradeHeaders?: HeadersInit; _upgradeHeaders?: HeadersInit;
_context: Peer["context"];
}; };
export interface NodeAdapter extends AdapterInstance { export interface NodeAdapter extends AdapterInstance {
@@ -87,13 +88,15 @@ export default defineWebSocketAdapter<NodeAdapter, NodeOptions>(
handleUpgrade: async (nodeReq, socket, head) => { handleUpgrade: async (nodeReq, socket, head) => {
const request = new NodeReqProxy(nodeReq); const request = new NodeReqProxy(nodeReq);
const { upgradeHeaders, endResponse } = await hooks.upgrade(request); const { upgradeHeaders, endResponse, context } =
await hooks.upgrade(request);
if (endResponse) { if (endResponse) {
return sendResponse(socket, endResponse); return sendResponse(socket, endResponse);
} }
(nodeReq as AugmentedReq)._request = request; (nodeReq as AugmentedReq)._request = request;
(nodeReq as AugmentedReq)._upgradeHeaders = upgradeHeaders; (nodeReq as AugmentedReq)._upgradeHeaders = upgradeHeaders;
(nodeReq as AugmentedReq)._context = context;
wss.handleUpgrade(nodeReq, socket, head, (ws) => { wss.handleUpgrade(nodeReq, socket, head, (ws) => {
wss.emit("connection", ws, nodeReq); wss.emit("connection", ws, nodeReq);
}); });
@@ -119,6 +122,10 @@ class NodePeer extends Peer<{
return this._internal.nodeReq.socket?.remoteAddress; return this._internal.nodeReq.socket?.remoteAddress;
} }
get context() {
return (this._internal.nodeReq as AugmentedReq)._context;
}
send(data: unknown, options?: { compress?: boolean }) { send(data: unknown, options?: { compress?: boolean }) {
const dataBuff = toBufferLike(data); const dataBuff = toBufferLike(data);
const isBinary = typeof data !== "string"; const isBinary = typeof data !== "string";

View File

@@ -27,7 +27,8 @@ export default defineWebSocketAdapter<SSEAdapter, SSEOptions>((opts = {}) => {
return { return {
...adapterUtils(peers), ...adapterUtils(peers),
fetch: async (request: Request) => { fetch: async (request: Request) => {
const { upgradeHeaders, endResponse } = await hooks.upgrade(request); const { upgradeHeaders, endResponse, context } =
await hooks.upgrade(request);
if (endResponse) { if (endResponse) {
return endResponse; return endResponse;
} }
@@ -60,6 +61,7 @@ export default defineWebSocketAdapter<SSEAdapter, SSEOptions>((opts = {}) => {
request, request,
hooks, hooks,
ws, ws,
context,
}); });
peers.add(peer); peers.add(peer);
if (opts.bidir) { if (opts.bidir) {
@@ -98,6 +100,7 @@ class SSEPeer extends Peer<{
request: Request; request: Request;
ws: SSEWebSocketStub; ws: SSEWebSocketStub;
hooks: AdapterHookable; hooks: AdapterHookable;
context: Peer["context"];
}> { }> {
_sseStream: ReadableStream; // server -> client _sseStream: ReadableStream; // server -> client
_sseStreamController?: ReadableStreamDefaultController; _sseStreamController?: ReadableStreamDefaultController;

View File

@@ -15,6 +15,7 @@ type UserData = {
res: uws.HttpResponse; res: uws.HttpResponse;
protocol: string; protocol: string;
extensions: string; extensions: string;
context: Peer["context"];
}; };
type WebSocketHandler = uws.WebSocketBehavior<UserData>; type WebSocketHandler = uws.WebSocketBehavior<UserData>;
@@ -70,13 +71,13 @@ export default defineWebSocketAdapter<UWSAdapter, UWSOptions>(
peers.add(peer); peers.add(peer);
hooks.callHook("open", peer); hooks.callHook("open", peer);
}, },
async upgrade(res, req, context) { async upgrade(res, req, uwsContext) {
let aborted = false; let aborted = false;
res.onAborted(() => { res.onAborted(() => {
aborted = true; aborted = true;
}); });
const { upgradeHeaders, endResponse } = await hooks.upgrade( const { upgradeHeaders, endResponse, context } = await hooks.upgrade(
new UWSReqProxy(req), new UWSReqProxy(req),
); );
if (endResponse) { if (endResponse) {
@@ -119,11 +120,12 @@ export default defineWebSocketAdapter<UWSAdapter, UWSOptions>(
res, res,
protocol, protocol,
extensions, extensions,
context,
}, },
key, key,
protocol, protocol,
extensions, extensions,
context, uwsContext,
); );
}); });
}, },
@@ -168,6 +170,10 @@ class UWSPeer extends Peer<{
} }
} }
get context() {
return this._internal.uwsData.context;
}
send(data: unknown, options?: { compress?: boolean }) { send(data: unknown, options?: { compress?: boolean }) {
const dataBuff = toBufferLike(data); const dataBuff = toBufferLike(data);
const isBinary = typeof data !== "string"; const isBinary = typeof data !== "string";

View File

@@ -40,30 +40,38 @@ export class AdapterHookable {
) as Promise<any>; ) as Promise<any>;
} }
async upgrade(request: UpgradeRequest): Promise<{ async upgrade(
request: UpgradeRequest & { context?: Peer["context"] },
): Promise<{
upgradeHeaders?: HeadersInit; upgradeHeaders?: HeadersInit;
endResponse?: Response; endResponse?: Response;
context: Peer["context"];
}> { }> {
const context = (request.context ??= {});
try { try {
const res = await this.callHook("upgrade", request); const res = await this.callHook(
"upgrade",
request as UpgradeRequest & { context: Peer["context"] },
);
if (!res) { if (!res) {
return {}; return { context };
} }
if ((res as Response).ok === false) { if ((res as Response).ok === false) {
return { endResponse: res as Response }; return { context, endResponse: res as Response };
} }
if (res.headers) { if (res.headers) {
return { return {
context,
upgradeHeaders: res.headers, upgradeHeaders: res.headers,
}; };
} }
} catch (error) { } catch (error) {
if (error instanceof Response) { if (error instanceof Response) {
return { endResponse: error }; return { context, endResponse: error };
} }
throw error; throw error;
} }
return {}; return { context };
} }
} }
@@ -96,7 +104,7 @@ export interface Hooks {
* @throws {Response} * @throws {Response}
*/ */
upgrade: ( upgrade: (
request: UpgradeRequest, request: UpgradeRequest & { context: Peer["context"] },
) => MaybePromise<Response | ResponseInit | undefined>; ) => MaybePromise<Response | ResponseInit | undefined>;
/** A message is received */ /** A message is received */

View File

@@ -5,6 +5,7 @@ export interface AdapterInternal {
ws: unknown; ws: unknown;
request?: Request | Partial<Request>; request?: Request | Partial<Request>;
peers?: Set<Peer>; peers?: Set<Peer>;
context?: Peer["context"];
} }
export abstract class Peer<Internal extends AdapterInternal = AdapterInternal> { export abstract class Peer<Internal extends AdapterInternal = AdapterInternal> {
@@ -14,14 +15,15 @@ export abstract class Peer<Internal extends AdapterInternal = AdapterInternal> {
#ws?: Partial<web.WebSocket>; #ws?: Partial<web.WebSocket>;
readonly context: Record<string, unknown>;
constructor(internal: Internal) { constructor(internal: Internal) {
this._topics = new Set(); this._topics = new Set();
this.context = {};
this._internal = internal; this._internal = internal;
} }
get context(): Record<string, unknown> {
return (this._internal.context ??= {});
}
/** /**
* Unique random [uuid v4](https://developer.mozilla.org/en-US/docs/Glossary/UUID) identifier for the peer. * Unique random [uuid v4](https://developer.mozilla.org/en-US/docs/Glossary/UUID) identifier for the peer.
*/ */

View File

@@ -28,6 +28,7 @@ export function createDemo<T extends Adapter<any, any>>(
peer.send({ peer.send({
id: peer.id, id: peer.id,
remoteAddress: peer.remoteAddress, remoteAddress: peer.remoteAddress,
context: peer.context,
request: { request: {
url: peer.request?.url, url: peer.request?.url,
headers: Object.fromEntries(peer.request?.headers || []), headers: Object.fromEntries(peer.request?.headers || []),
@@ -63,6 +64,7 @@ export function createDemo<T extends Adapter<any, any>>(
headers: { "x-error": "unauthorized" }, headers: { "x-error": "unauthorized" },
}); });
} }
req.context.test = "1";
return { return {
headers: { headers: {
"x-powered-by": "cross-ws", "x-powered-by": "cross-ws",

View File

@@ -69,7 +69,7 @@ export function wsTests(getURL: () => string, opts: WSTestOpts) {
headers: { "x-test": "1" }, headers: { "x-test": "1" },
}); });
await ws.send("debug"); await ws.send("debug");
const { request, remoteAddress } = await ws.next(); const { request, remoteAddress, context } = await ws.next();
// Headers // Headers
if (opts.adapter === "sse") { if (opts.adapter === "sse") {
@@ -88,6 +88,11 @@ export function wsTests(getURL: () => string, opts: WSTestOpts) {
if (!/sse|cloudflare/.test(opts.adapter)) { if (!/sse|cloudflare/.test(opts.adapter)) {
expect(remoteAddress).toMatch(/:{2}1|(?:0{4}:){7}0{3}1|127\.0\.\0\.1/); expect(remoteAddress).toMatch(/:{2}1|(?:0{4}:){7}0{3}1|127\.0\.\0\.1/);
} }
// Context
if (opts.adapter !== "cloudflare-durable") {
expect(context.test).toBe("1");
}
}); });
test("peer.websocket", async () => { test("peer.websocket", async () => {