feat: support throwing responses in upgrade hook (#91)

Co-authored-by: Pooya Parsa <pooya@pi0.io>
This commit is contained in:
Luke Hagar
2025-01-21 16:33:57 -06:00
committed by GitHub
parent 676c577f41
commit c850a66ceb
8 changed files with 96 additions and 48 deletions

View File

@@ -31,17 +31,18 @@ export default defineWebSocketAdapter<BunAdapter, BunOptions>(
return {
...adapterUtils(peers),
async handleUpgrade(request, server) {
const res = await hooks.callHook("upgrade", request);
if (res instanceof Response) {
return res;
const { upgradeHeaders, endResponse } = await hooks.upgrade(request);
if (endResponse) {
return endResponse;
}
const upgradeOK = server.upgrade(request, {
data: {
server,
request,
} satisfies ContextData,
headers: res?.headers,
headers: upgradeHeaders,
});
if (!upgradeOK) {
return new Response("Upgrade failed", { status: 500 });
}

View File

@@ -31,10 +31,13 @@ export default defineWebSocketAdapter<
// placeholder
},
handleDurableUpgrade: async (obj, request) => {
const res = await hooks.callHook("upgrade", request as Request);
if (res instanceof Response) {
return res;
const { upgradeHeaders, endResponse } = await hooks.upgrade(
request as Request,
);
if (endResponse) {
return endResponse;
}
const pair = new WebSocketPair();
const client = pair[0];
const server = pair[1];
@@ -46,11 +49,12 @@ export default defineWebSocketAdapter<
peers.add(peer);
(obj as DurableObjectPub).ctx.acceptWebSocket(server);
await hooks.callHook("open", peer);
// eslint-disable-next-line unicorn/no-null
return new Response(null, {
status: 101,
webSocket: client,
headers: res?.headers,
headers: upgradeHeaders,
});
},
handleDurableMessage: async (obj, ws, message) => {

View File

@@ -33,13 +33,13 @@ export default defineWebSocketAdapter<CloudflareAdapter, CloudflareOptions>(
return {
...adapterUtils(peers),
handleUpgrade: async (request, env, context) => {
const res = await hooks.callHook(
"upgrade",
const { upgradeHeaders, endResponse } = await hooks.upgrade(
request as unknown as Request,
);
if (res instanceof Response) {
return res;
if (endResponse) {
return endResponse as unknown as _cf.Response;
}
const pair = new WebSocketPair();
const client = pair[0];
const server = pair[1];
@@ -73,7 +73,7 @@ export default defineWebSocketAdapter<CloudflareAdapter, CloudflareOptions>(
return new Response(null, {
status: 101,
webSocket: client,
headers: res?.headers,
headers: upgradeHeaders,
});
},
};

View File

@@ -31,13 +31,14 @@ export default defineWebSocketAdapter<DenoAdapter, DenoOptions>(
return {
...adapterUtils(peers),
handleUpgrade: async (request, info) => {
const res = await hooks.callHook("upgrade", request);
if (res instanceof Response) {
return res;
const { upgradeHeaders, endResponse } = await hooks.upgrade(request);
if (endResponse) {
return endResponse;
}
const upgrade = Deno.upgradeWebSocket(request, {
// @ts-expect-error https://github.com/denoland/deno/pull/22242
headers: res?.headers,
headers: upgradeHeaders,
});
const peer = new DenoPeer({
ws: upgrade.socket,

View File

@@ -86,12 +86,14 @@ export default defineWebSocketAdapter<NodeAdapter, NodeOptions>(
...adapterUtils(peers),
handleUpgrade: async (nodeReq, socket, head) => {
const request = new NodeReqProxy(nodeReq);
const res = await hooks.callHook("upgrade", request);
if (res instanceof Response) {
return sendResponse(socket, res);
const { upgradeHeaders, endResponse } = await hooks.upgrade(request);
if (endResponse) {
return sendResponse(socket, endResponse);
}
(nodeReq as AugmentedReq)._request = request;
(nodeReq as AugmentedReq)._upgradeHeaders = res?.headers;
(nodeReq as AugmentedReq)._upgradeHeaders = upgradeHeaders;
wss.handleUpgrade(nodeReq, socket, head, (ws) => {
wss.emit("connection", ws, nodeReq);
});

View File

@@ -27,9 +27,9 @@ export default defineWebSocketAdapter<SSEAdapter, SSEOptions>((opts = {}) => {
return {
...adapterUtils(peers),
fetch: async (request: Request) => {
const _res = await hooks.callHook("upgrade", request);
if (_res instanceof Response) {
return _res;
const { upgradeHeaders, endResponse } = await hooks.upgrade(request);
if (endResponse) {
return endResponse;
}
let peer: SSEPeer;
@@ -73,17 +73,19 @@ export default defineWebSocketAdapter<SSEAdapter, SSEOptions>((opts = {}) => {
"Cache-Control": "no-cache",
Connection: "keep-alive",
};
if (opts.bidir) {
headers["x-crossws-id"] = peer.id;
}
if (_res?.headers) {
if (upgradeHeaders) {
headers = new Headers(headers);
for (const [key, value] of new Headers(_res.headers)) {
for (const [key, value] of new Headers(upgradeHeaders)) {
headers.set(key, value);
}
}
return new Response(peer._sseStream, { ..._res, headers });
return new Response(peer._sseStream, { headers });
},
};
});

View File

@@ -75,20 +75,18 @@ export default defineWebSocketAdapter<UWSAdapter, UWSOptions>(
res.onAborted(() => {
aborted = true;
});
const _res = await hooks.callHook("upgrade", new UWSReqProxy(req));
if (aborted) {
return;
}
if (_res instanceof Response) {
res.writeStatus(`${_res.status} ${_res.statusText}`);
for (const [key, value] of _res.headers) {
const { upgradeHeaders, endResponse } = await hooks.upgrade(
new UWSReqProxy(req),
);
if (endResponse) {
res.writeStatus(`${endResponse.status} ${endResponse.statusText}`);
for (const [key, value] of endResponse.headers) {
res.writeHeader(key, value);
}
if (_res.body) {
for await (const chunk of _res.body) {
if (aborted) {
break;
}
if (endResponse.body) {
for await (const chunk of endResponse.body) {
if (aborted) break;
res.write(chunk);
}
}
@@ -97,9 +95,16 @@ export default defineWebSocketAdapter<UWSAdapter, UWSOptions>(
}
return;
}
if (aborted) {
return;
}
res.writeStatus("101 Switching Protocols");
if (_res?.headers) {
for (const [key, value] of new Headers(_res.headers)) {
if (upgradeHeaders) {
// prettier-ignore
const headers = upgradeHeaders instanceof Headers ? upgradeHeaders : new Headers(upgradeHeaders);
for (const [key, value] of headers) {
res.writeHeader(key, value);
}
}

View File

@@ -39,6 +39,32 @@ export class AdapterHookable {
},
) as Promise<any>;
}
async upgrade(request: UpgradeRequest): Promise<{
upgradeHeaders?: HeadersInit;
endResponse?: Response;
}> {
try {
const res = await this.callHook("upgrade", request);
if (!res) {
return {};
}
if ((res as Response).ok === false) {
return { endResponse: res as Response };
}
if (res.headers) {
return {
upgradeHeaders: res.headers,
};
}
} catch (error) {
if (error instanceof Response) {
return { endResponse: error };
}
throw error;
}
return {};
}
}
// --- types ---
@@ -60,16 +86,23 @@ type HookFn<ArgsT extends any[] = any, RT = void> = (
...args: ArgsT
) => MaybePromise<RT>;
export interface Hooks {
/** Upgrading */
upgrade: (
request:
export type UpgradeRequest =
| Request
| {
url: string;
headers: Headers;
},
) => MaybePromise<Response | ResponseInit | void>;
};
export interface Hooks {
/** Upgrading */
/**
*
* @param request
* @throws {Response}
*/
upgrade: (
request: UpgradeRequest,
) => MaybePromise<Response | ResponseInit | undefined>;
/** A message is received */
message: (peer: Peer, message: Message) => MaybePromise<void>;