mirror of
https://github.com/LukeHagar/crossws.git
synced 2025-12-06 04:19:26 +00:00
feat: shared context between upgrade hook and peer (#111)
This commit is contained in:
@@ -19,6 +19,7 @@ type ContextData = {
|
||||
peer?: BunPeer;
|
||||
request: Request;
|
||||
server?: Server;
|
||||
context: Peer["context"];
|
||||
};
|
||||
|
||||
// --- adapter ---
|
||||
@@ -31,7 +32,8 @@ export default defineWebSocketAdapter<BunAdapter, BunOptions>(
|
||||
return {
|
||||
...adapterUtils(peers),
|
||||
async handleUpgrade(request, server) {
|
||||
const { upgradeHeaders, endResponse } = await hooks.upgrade(request);
|
||||
const { upgradeHeaders, endResponse, context } =
|
||||
await hooks.upgrade(request);
|
||||
if (endResponse) {
|
||||
return endResponse;
|
||||
}
|
||||
@@ -39,6 +41,7 @@ export default defineWebSocketAdapter<BunAdapter, BunOptions>(
|
||||
data: {
|
||||
server,
|
||||
request,
|
||||
context,
|
||||
} satisfies ContextData,
|
||||
headers: upgradeHeaders,
|
||||
});
|
||||
@@ -93,6 +96,10 @@ class BunPeer extends Peer<{
|
||||
return this._internal.ws.remoteAddress;
|
||||
}
|
||||
|
||||
get context() {
|
||||
return this._internal.ws.data.context;
|
||||
}
|
||||
|
||||
send(data: unknown, options?: { compress?: boolean }) {
|
||||
return this._internal.ws.send(toBufferLike(data), options?.compress);
|
||||
}
|
||||
|
||||
@@ -32,8 +32,8 @@ export default defineWebSocketAdapter<CloudflareAdapter, CloudflareOptions>(
|
||||
const peers = new Set<CloudflarePeer>();
|
||||
return {
|
||||
...adapterUtils(peers),
|
||||
handleUpgrade: async (request, env, context) => {
|
||||
const { upgradeHeaders, endResponse } = await hooks.upgrade(
|
||||
handleUpgrade: async (request, env, cfCtx) => {
|
||||
const { upgradeHeaders, endResponse, context } = await hooks.upgrade(
|
||||
request as unknown as Request,
|
||||
);
|
||||
if (endResponse) {
|
||||
@@ -49,7 +49,8 @@ export default defineWebSocketAdapter<CloudflareAdapter, CloudflareOptions>(
|
||||
wsServer: server,
|
||||
request: request as unknown as Request,
|
||||
cfEnv: env,
|
||||
cfCtx: context,
|
||||
cfCtx: cfCtx,
|
||||
context,
|
||||
});
|
||||
peers.add(peer);
|
||||
server.accept();
|
||||
@@ -89,6 +90,7 @@ class CloudflarePeer extends Peer<{
|
||||
wsServer: _cf.WebSocket;
|
||||
cfEnv: unknown;
|
||||
cfCtx: _cf.ExecutionContext;
|
||||
context: Peer["context"];
|
||||
}> {
|
||||
send(data: unknown) {
|
||||
this._internal.wsServer.send(toBufferLike(data));
|
||||
|
||||
@@ -31,7 +31,8 @@ export default defineWebSocketAdapter<DenoAdapter, DenoOptions>(
|
||||
return {
|
||||
...adapterUtils(peers),
|
||||
handleUpgrade: async (request, info) => {
|
||||
const { upgradeHeaders, endResponse } = await hooks.upgrade(request);
|
||||
const { upgradeHeaders, endResponse, context } =
|
||||
await hooks.upgrade(request);
|
||||
if (endResponse) {
|
||||
return endResponse;
|
||||
}
|
||||
@@ -45,6 +46,7 @@ export default defineWebSocketAdapter<DenoAdapter, DenoOptions>(
|
||||
request,
|
||||
peers,
|
||||
denoInfo: info,
|
||||
context,
|
||||
});
|
||||
peers.add(peer);
|
||||
upgrade.socket.addEventListener("open", () => {
|
||||
@@ -74,6 +76,7 @@ class DenoPeer extends Peer<{
|
||||
request: Request;
|
||||
peers: Set<DenoPeer>;
|
||||
denoInfo: ServeHandlerInfo;
|
||||
context: Peer["context"];
|
||||
}> {
|
||||
get remoteAddress() {
|
||||
return this._internal.denoInfo.remoteAddr?.hostname;
|
||||
|
||||
@@ -7,7 +7,7 @@ import { Message } from "../message.ts";
|
||||
import { WSError } from "../error.ts";
|
||||
import { Peer } from "../peer.ts";
|
||||
|
||||
import type { ClientRequest, IncomingMessage } from "node:http";
|
||||
import type { IncomingMessage } from "node:http";
|
||||
import type { Duplex } from "node:stream";
|
||||
import { WebSocketServer as _WebSocketServer } from "ws";
|
||||
import type {
|
||||
@@ -21,6 +21,7 @@ import type {
|
||||
type AugmentedReq = IncomingMessage & {
|
||||
_request: NodeReqProxy;
|
||||
_upgradeHeaders?: HeadersInit;
|
||||
_context: Peer["context"];
|
||||
};
|
||||
|
||||
export interface NodeAdapter extends AdapterInstance {
|
||||
@@ -87,13 +88,15 @@ export default defineWebSocketAdapter<NodeAdapter, NodeOptions>(
|
||||
handleUpgrade: async (nodeReq, socket, head) => {
|
||||
const request = new NodeReqProxy(nodeReq);
|
||||
|
||||
const { upgradeHeaders, endResponse } = await hooks.upgrade(request);
|
||||
const { upgradeHeaders, endResponse, context } =
|
||||
await hooks.upgrade(request);
|
||||
if (endResponse) {
|
||||
return sendResponse(socket, endResponse);
|
||||
}
|
||||
|
||||
(nodeReq as AugmentedReq)._request = request;
|
||||
(nodeReq as AugmentedReq)._upgradeHeaders = upgradeHeaders;
|
||||
(nodeReq as AugmentedReq)._context = context;
|
||||
wss.handleUpgrade(nodeReq, socket, head, (ws) => {
|
||||
wss.emit("connection", ws, nodeReq);
|
||||
});
|
||||
@@ -119,6 +122,10 @@ class NodePeer extends Peer<{
|
||||
return this._internal.nodeReq.socket?.remoteAddress;
|
||||
}
|
||||
|
||||
get context() {
|
||||
return (this._internal.nodeReq as AugmentedReq)._context;
|
||||
}
|
||||
|
||||
send(data: unknown, options?: { compress?: boolean }) {
|
||||
const dataBuff = toBufferLike(data);
|
||||
const isBinary = typeof data !== "string";
|
||||
|
||||
@@ -27,7 +27,8 @@ export default defineWebSocketAdapter<SSEAdapter, SSEOptions>((opts = {}) => {
|
||||
return {
|
||||
...adapterUtils(peers),
|
||||
fetch: async (request: Request) => {
|
||||
const { upgradeHeaders, endResponse } = await hooks.upgrade(request);
|
||||
const { upgradeHeaders, endResponse, context } =
|
||||
await hooks.upgrade(request);
|
||||
if (endResponse) {
|
||||
return endResponse;
|
||||
}
|
||||
@@ -60,6 +61,7 @@ export default defineWebSocketAdapter<SSEAdapter, SSEOptions>((opts = {}) => {
|
||||
request,
|
||||
hooks,
|
||||
ws,
|
||||
context,
|
||||
});
|
||||
peers.add(peer);
|
||||
if (opts.bidir) {
|
||||
@@ -98,6 +100,7 @@ class SSEPeer extends Peer<{
|
||||
request: Request;
|
||||
ws: SSEWebSocketStub;
|
||||
hooks: AdapterHookable;
|
||||
context: Peer["context"];
|
||||
}> {
|
||||
_sseStream: ReadableStream; // server -> client
|
||||
_sseStreamController?: ReadableStreamDefaultController;
|
||||
|
||||
@@ -15,6 +15,7 @@ type UserData = {
|
||||
res: uws.HttpResponse;
|
||||
protocol: string;
|
||||
extensions: string;
|
||||
context: Peer["context"];
|
||||
};
|
||||
|
||||
type WebSocketHandler = uws.WebSocketBehavior<UserData>;
|
||||
@@ -70,13 +71,13 @@ export default defineWebSocketAdapter<UWSAdapter, UWSOptions>(
|
||||
peers.add(peer);
|
||||
hooks.callHook("open", peer);
|
||||
},
|
||||
async upgrade(res, req, context) {
|
||||
async upgrade(res, req, uwsContext) {
|
||||
let aborted = false;
|
||||
res.onAborted(() => {
|
||||
aborted = true;
|
||||
});
|
||||
|
||||
const { upgradeHeaders, endResponse } = await hooks.upgrade(
|
||||
const { upgradeHeaders, endResponse, context } = await hooks.upgrade(
|
||||
new UWSReqProxy(req),
|
||||
);
|
||||
if (endResponse) {
|
||||
@@ -119,11 +120,12 @@ export default defineWebSocketAdapter<UWSAdapter, UWSOptions>(
|
||||
res,
|
||||
protocol,
|
||||
extensions,
|
||||
context,
|
||||
},
|
||||
key,
|
||||
protocol,
|
||||
extensions,
|
||||
context,
|
||||
uwsContext,
|
||||
);
|
||||
});
|
||||
},
|
||||
@@ -168,6 +170,10 @@ class UWSPeer extends Peer<{
|
||||
}
|
||||
}
|
||||
|
||||
get context() {
|
||||
return this._internal.uwsData.context;
|
||||
}
|
||||
|
||||
send(data: unknown, options?: { compress?: boolean }) {
|
||||
const dataBuff = toBufferLike(data);
|
||||
const isBinary = typeof data !== "string";
|
||||
|
||||
22
src/hooks.ts
22
src/hooks.ts
@@ -40,30 +40,38 @@ export class AdapterHookable {
|
||||
) as Promise<any>;
|
||||
}
|
||||
|
||||
async upgrade(request: UpgradeRequest): Promise<{
|
||||
async upgrade(
|
||||
request: UpgradeRequest & { context?: Peer["context"] },
|
||||
): Promise<{
|
||||
upgradeHeaders?: HeadersInit;
|
||||
endResponse?: Response;
|
||||
context: Peer["context"];
|
||||
}> {
|
||||
const context = (request.context ??= {});
|
||||
try {
|
||||
const res = await this.callHook("upgrade", request);
|
||||
const res = await this.callHook(
|
||||
"upgrade",
|
||||
request as UpgradeRequest & { context: Peer["context"] },
|
||||
);
|
||||
if (!res) {
|
||||
return {};
|
||||
return { context };
|
||||
}
|
||||
if ((res as Response).ok === false) {
|
||||
return { endResponse: res as Response };
|
||||
return { context, endResponse: res as Response };
|
||||
}
|
||||
if (res.headers) {
|
||||
return {
|
||||
context,
|
||||
upgradeHeaders: res.headers,
|
||||
};
|
||||
}
|
||||
} catch (error) {
|
||||
if (error instanceof Response) {
|
||||
return { endResponse: error };
|
||||
return { context, endResponse: error };
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
return {};
|
||||
return { context };
|
||||
}
|
||||
}
|
||||
|
||||
@@ -96,7 +104,7 @@ export interface Hooks {
|
||||
* @throws {Response}
|
||||
*/
|
||||
upgrade: (
|
||||
request: UpgradeRequest,
|
||||
request: UpgradeRequest & { context: Peer["context"] },
|
||||
) => MaybePromise<Response | ResponseInit | undefined>;
|
||||
|
||||
/** A message is received */
|
||||
|
||||
@@ -5,6 +5,7 @@ export interface AdapterInternal {
|
||||
ws: unknown;
|
||||
request?: Request | Partial<Request>;
|
||||
peers?: Set<Peer>;
|
||||
context?: Peer["context"];
|
||||
}
|
||||
|
||||
export abstract class Peer<Internal extends AdapterInternal = AdapterInternal> {
|
||||
@@ -14,14 +15,15 @@ export abstract class Peer<Internal extends AdapterInternal = AdapterInternal> {
|
||||
|
||||
#ws?: Partial<web.WebSocket>;
|
||||
|
||||
readonly context: Record<string, unknown>;
|
||||
|
||||
constructor(internal: Internal) {
|
||||
this._topics = new Set();
|
||||
this.context = {};
|
||||
this._internal = internal;
|
||||
}
|
||||
|
||||
get context(): Record<string, unknown> {
|
||||
return (this._internal.context ??= {});
|
||||
}
|
||||
|
||||
/**
|
||||
* Unique random [uuid v4](https://developer.mozilla.org/en-US/docs/Glossary/UUID) identifier for the peer.
|
||||
*/
|
||||
|
||||
@@ -28,6 +28,7 @@ export function createDemo<T extends Adapter<any, any>>(
|
||||
peer.send({
|
||||
id: peer.id,
|
||||
remoteAddress: peer.remoteAddress,
|
||||
context: peer.context,
|
||||
request: {
|
||||
url: peer.request?.url,
|
||||
headers: Object.fromEntries(peer.request?.headers || []),
|
||||
@@ -63,6 +64,7 @@ export function createDemo<T extends Adapter<any, any>>(
|
||||
headers: { "x-error": "unauthorized" },
|
||||
});
|
||||
}
|
||||
req.context.test = "1";
|
||||
return {
|
||||
headers: {
|
||||
"x-powered-by": "cross-ws",
|
||||
|
||||
@@ -69,7 +69,7 @@ export function wsTests(getURL: () => string, opts: WSTestOpts) {
|
||||
headers: { "x-test": "1" },
|
||||
});
|
||||
await ws.send("debug");
|
||||
const { request, remoteAddress } = await ws.next();
|
||||
const { request, remoteAddress, context } = await ws.next();
|
||||
|
||||
// Headers
|
||||
if (opts.adapter === "sse") {
|
||||
@@ -88,6 +88,11 @@ export function wsTests(getURL: () => string, opts: WSTestOpts) {
|
||||
if (!/sse|cloudflare/.test(opts.adapter)) {
|
||||
expect(remoteAddress).toMatch(/:{2}1|(?:0{4}:){7}0{3}1|127\.0\.\0\.1/);
|
||||
}
|
||||
|
||||
// Context
|
||||
if (opts.adapter !== "cloudflare-durable") {
|
||||
expect(context.test).toBe("1");
|
||||
}
|
||||
});
|
||||
|
||||
test("peer.websocket", async () => {
|
||||
|
||||
Reference in New Issue
Block a user