feat: support throwing responses in upgrade hook (#91)

Co-authored-by: Pooya Parsa <pooya@pi0.io>
This commit is contained in:
Luke Hagar
2025-01-21 16:33:57 -06:00
committed by GitHub
parent 676c577f41
commit c850a66ceb
8 changed files with 96 additions and 48 deletions

View File

@@ -31,17 +31,18 @@ export default defineWebSocketAdapter<BunAdapter, BunOptions>(
return { return {
...adapterUtils(peers), ...adapterUtils(peers),
async handleUpgrade(request, server) { async handleUpgrade(request, server) {
const res = await hooks.callHook("upgrade", request); const { upgradeHeaders, endResponse } = await hooks.upgrade(request);
if (res instanceof Response) { if (endResponse) {
return res; return endResponse;
} }
const upgradeOK = server.upgrade(request, { const upgradeOK = server.upgrade(request, {
data: { data: {
server, server,
request, request,
} satisfies ContextData, } satisfies ContextData,
headers: res?.headers, headers: upgradeHeaders,
}); });
if (!upgradeOK) { if (!upgradeOK) {
return new Response("Upgrade failed", { status: 500 }); return new Response("Upgrade failed", { status: 500 });
} }

View File

@@ -31,10 +31,13 @@ export default defineWebSocketAdapter<
// placeholder // placeholder
}, },
handleDurableUpgrade: async (obj, request) => { handleDurableUpgrade: async (obj, request) => {
const res = await hooks.callHook("upgrade", request as Request); const { upgradeHeaders, endResponse } = await hooks.upgrade(
if (res instanceof Response) { request as Request,
return res; );
if (endResponse) {
return endResponse;
} }
const pair = new WebSocketPair(); const pair = new WebSocketPair();
const client = pair[0]; const client = pair[0];
const server = pair[1]; const server = pair[1];
@@ -46,11 +49,12 @@ export default defineWebSocketAdapter<
peers.add(peer); peers.add(peer);
(obj as DurableObjectPub).ctx.acceptWebSocket(server); (obj as DurableObjectPub).ctx.acceptWebSocket(server);
await hooks.callHook("open", peer); await hooks.callHook("open", peer);
// eslint-disable-next-line unicorn/no-null // eslint-disable-next-line unicorn/no-null
return new Response(null, { return new Response(null, {
status: 101, status: 101,
webSocket: client, webSocket: client,
headers: res?.headers, headers: upgradeHeaders,
}); });
}, },
handleDurableMessage: async (obj, ws, message) => { handleDurableMessage: async (obj, ws, message) => {

View File

@@ -33,13 +33,13 @@ export default defineWebSocketAdapter<CloudflareAdapter, CloudflareOptions>(
return { return {
...adapterUtils(peers), ...adapterUtils(peers),
handleUpgrade: async (request, env, context) => { handleUpgrade: async (request, env, context) => {
const res = await hooks.callHook( const { upgradeHeaders, endResponse } = await hooks.upgrade(
"upgrade",
request as unknown as Request, request as unknown as Request,
); );
if (res instanceof Response) { if (endResponse) {
return res; return endResponse as unknown as _cf.Response;
} }
const pair = new WebSocketPair(); const pair = new WebSocketPair();
const client = pair[0]; const client = pair[0];
const server = pair[1]; const server = pair[1];
@@ -73,7 +73,7 @@ export default defineWebSocketAdapter<CloudflareAdapter, CloudflareOptions>(
return new Response(null, { return new Response(null, {
status: 101, status: 101,
webSocket: client, webSocket: client,
headers: res?.headers, headers: upgradeHeaders,
}); });
}, },
}; };

View File

@@ -31,13 +31,14 @@ export default defineWebSocketAdapter<DenoAdapter, DenoOptions>(
return { return {
...adapterUtils(peers), ...adapterUtils(peers),
handleUpgrade: async (request, info) => { handleUpgrade: async (request, info) => {
const res = await hooks.callHook("upgrade", request); const { upgradeHeaders, endResponse } = await hooks.upgrade(request);
if (res instanceof Response) { if (endResponse) {
return res; return endResponse;
} }
const upgrade = Deno.upgradeWebSocket(request, { const upgrade = Deno.upgradeWebSocket(request, {
// @ts-expect-error https://github.com/denoland/deno/pull/22242 // @ts-expect-error https://github.com/denoland/deno/pull/22242
headers: res?.headers, headers: upgradeHeaders,
}); });
const peer = new DenoPeer({ const peer = new DenoPeer({
ws: upgrade.socket, ws: upgrade.socket,

View File

@@ -86,12 +86,14 @@ export default defineWebSocketAdapter<NodeAdapter, NodeOptions>(
...adapterUtils(peers), ...adapterUtils(peers),
handleUpgrade: async (nodeReq, socket, head) => { handleUpgrade: async (nodeReq, socket, head) => {
const request = new NodeReqProxy(nodeReq); const request = new NodeReqProxy(nodeReq);
const res = await hooks.callHook("upgrade", request);
if (res instanceof Response) { const { upgradeHeaders, endResponse } = await hooks.upgrade(request);
return sendResponse(socket, res); if (endResponse) {
return sendResponse(socket, endResponse);
} }
(nodeReq as AugmentedReq)._request = request; (nodeReq as AugmentedReq)._request = request;
(nodeReq as AugmentedReq)._upgradeHeaders = res?.headers; (nodeReq as AugmentedReq)._upgradeHeaders = upgradeHeaders;
wss.handleUpgrade(nodeReq, socket, head, (ws) => { wss.handleUpgrade(nodeReq, socket, head, (ws) => {
wss.emit("connection", ws, nodeReq); wss.emit("connection", ws, nodeReq);
}); });

View File

@@ -27,9 +27,9 @@ export default defineWebSocketAdapter<SSEAdapter, SSEOptions>((opts = {}) => {
return { return {
...adapterUtils(peers), ...adapterUtils(peers),
fetch: async (request: Request) => { fetch: async (request: Request) => {
const _res = await hooks.callHook("upgrade", request); const { upgradeHeaders, endResponse } = await hooks.upgrade(request);
if (_res instanceof Response) { if (endResponse) {
return _res; return endResponse;
} }
let peer: SSEPeer; let peer: SSEPeer;
@@ -73,17 +73,19 @@ export default defineWebSocketAdapter<SSEAdapter, SSEOptions>((opts = {}) => {
"Cache-Control": "no-cache", "Cache-Control": "no-cache",
Connection: "keep-alive", Connection: "keep-alive",
}; };
if (opts.bidir) { if (opts.bidir) {
headers["x-crossws-id"] = peer.id; headers["x-crossws-id"] = peer.id;
} }
if (_res?.headers) {
if (upgradeHeaders) {
headers = new Headers(headers); headers = new Headers(headers);
for (const [key, value] of new Headers(_res.headers)) { for (const [key, value] of new Headers(upgradeHeaders)) {
headers.set(key, value); headers.set(key, value);
} }
} }
return new Response(peer._sseStream, { ..._res, headers }); return new Response(peer._sseStream, { headers });
}, },
}; };
}); });

View File

@@ -75,20 +75,18 @@ export default defineWebSocketAdapter<UWSAdapter, UWSOptions>(
res.onAborted(() => { res.onAborted(() => {
aborted = true; aborted = true;
}); });
const _res = await hooks.callHook("upgrade", new UWSReqProxy(req));
if (aborted) { const { upgradeHeaders, endResponse } = await hooks.upgrade(
return; new UWSReqProxy(req),
} );
if (_res instanceof Response) { if (endResponse) {
res.writeStatus(`${_res.status} ${_res.statusText}`); res.writeStatus(`${endResponse.status} ${endResponse.statusText}`);
for (const [key, value] of _res.headers) { for (const [key, value] of endResponse.headers) {
res.writeHeader(key, value); res.writeHeader(key, value);
} }
if (_res.body) { if (endResponse.body) {
for await (const chunk of _res.body) { for await (const chunk of endResponse.body) {
if (aborted) { if (aborted) break;
break;
}
res.write(chunk); res.write(chunk);
} }
} }
@@ -97,9 +95,16 @@ export default defineWebSocketAdapter<UWSAdapter, UWSOptions>(
} }
return; return;
} }
if (aborted) {
return;
}
res.writeStatus("101 Switching Protocols"); res.writeStatus("101 Switching Protocols");
if (_res?.headers) { if (upgradeHeaders) {
for (const [key, value] of new Headers(_res.headers)) { // prettier-ignore
const headers = upgradeHeaders instanceof Headers ? upgradeHeaders : new Headers(upgradeHeaders);
for (const [key, value] of headers) {
res.writeHeader(key, value); res.writeHeader(key, value);
} }
} }

View File

@@ -39,6 +39,32 @@ export class AdapterHookable {
}, },
) as Promise<any>; ) as Promise<any>;
} }
async upgrade(request: UpgradeRequest): Promise<{
upgradeHeaders?: HeadersInit;
endResponse?: Response;
}> {
try {
const res = await this.callHook("upgrade", request);
if (!res) {
return {};
}
if ((res as Response).ok === false) {
return { endResponse: res as Response };
}
if (res.headers) {
return {
upgradeHeaders: res.headers,
};
}
} catch (error) {
if (error instanceof Response) {
return { endResponse: error };
}
throw error;
}
return {};
}
} }
// --- types --- // --- types ---
@@ -60,16 +86,23 @@ type HookFn<ArgsT extends any[] = any, RT = void> = (
...args: ArgsT ...args: ArgsT
) => MaybePromise<RT>; ) => MaybePromise<RT>;
export type UpgradeRequest =
| Request
| {
url: string;
headers: Headers;
};
export interface Hooks { export interface Hooks {
/** Upgrading */ /** Upgrading */
/**
*
* @param request
* @throws {Response}
*/
upgrade: ( upgrade: (
request: request: UpgradeRequest,
| Request ) => MaybePromise<Response | ResponseInit | undefined>;
| {
url: string;
headers: Headers;
},
) => MaybePromise<Response | ResponseInit | void>;
/** A message is received */ /** A message is received */
message: (peer: Peer, message: Message) => MaybePromise<void>; message: (peer: Peer, message: Message) => MaybePromise<void>;