diff --git a/apps/RenderCadenceCompositor/README.md b/apps/RenderCadenceCompositor/README.md index e87370a..4b72049 100644 --- a/apps/RenderCadenceCompositor/README.md +++ b/apps/RenderCadenceCompositor/README.md @@ -150,6 +150,7 @@ Current endpoints: - `GET /` and UI asset paths: serve the bundled control UI from `ui/dist` - `GET /api/state`: returns OpenAPI-shaped display data with cadence telemetry, supported shaders, output status, and a read-only current runtime layer +- `GET /ws`: upgrades to a WebSocket and streams state snapshots when they change - `GET /docs/openapi.yaml` and `GET /openapi.yaml`: serves the OpenAPI document - `GET /docs`: serves Swagger UI - `POST /api/layers/add` and `POST /api/layers/remove` mutate the app-owned display layer model only diff --git a/apps/RenderCadenceCompositor/control/HttpControlServer.cpp b/apps/RenderCadenceCompositor/control/HttpControlServer.cpp index 6f80f89..cc1eb44 100644 --- a/apps/RenderCadenceCompositor/control/HttpControlServer.cpp +++ b/apps/RenderCadenceCompositor/control/HttpControlServer.cpp @@ -6,9 +6,12 @@ #include #include +#include #include +#include #include #include +#include namespace RenderCadenceCompositor { @@ -41,6 +44,117 @@ bool IsKnownPostEndpoint(const std::string& path) || path == "/api/reload" || path == "/api/screenshot"; } + +std::array Sha1(const std::string& input) +{ + auto leftRotate = [](uint32_t value, uint32_t bits) { + return (value << bits) | (value >> (32U - bits)); + }; + + std::vector data(input.begin(), input.end()); + const uint64_t bitLength = static_cast(data.size()) * 8ULL; + data.push_back(0x80); + while ((data.size() % 64) != 56) + data.push_back(0); + for (int shift = 56; shift >= 0; shift -= 8) + data.push_back(static_cast((bitLength >> shift) & 0xff)); + + uint32_t h0 = 0x67452301; + uint32_t h1 = 0xefcdab89; + uint32_t h2 = 0x98badcfe; + uint32_t h3 = 0x10325476; + uint32_t h4 = 0xc3d2e1f0; + + for (std::size_t offset = 0; offset < data.size(); offset += 64) + { + uint32_t words[80] = {}; + for (std::size_t i = 0; i < 16; ++i) + { + const std::size_t index = offset + i * 4; + words[i] = (static_cast(data[index]) << 24) + | (static_cast(data[index + 1]) << 16) + | (static_cast(data[index + 2]) << 8) + | static_cast(data[index + 3]); + } + for (std::size_t i = 16; i < 80; ++i) + words[i] = leftRotate(words[i - 3] ^ words[i - 8] ^ words[i - 14] ^ words[i - 16], 1); + + uint32_t a = h0; + uint32_t b = h1; + uint32_t c = h2; + uint32_t d = h3; + uint32_t e = h4; + + for (std::size_t i = 0; i < 80; ++i) + { + uint32_t f = 0; + uint32_t k = 0; + if (i < 20) + { + f = (b & c) | ((~b) & d); + k = 0x5a827999; + } + else if (i < 40) + { + f = b ^ c ^ d; + k = 0x6ed9eba1; + } + else if (i < 60) + { + f = (b & c) | (b & d) | (c & d); + k = 0x8f1bbcdc; + } + else + { + f = b ^ c ^ d; + k = 0xca62c1d6; + } + + const uint32_t temp = leftRotate(a, 5) + f + e + k + words[i]; + e = d; + d = c; + c = leftRotate(b, 30); + b = a; + a = temp; + } + + h0 += a; + h1 += b; + h2 += c; + h3 += d; + h4 += e; + } + + std::array digest = {}; + const uint32_t parts[] = { h0, h1, h2, h3, h4 }; + for (std::size_t i = 0; i < 5; ++i) + { + digest[i * 4] = static_cast((parts[i] >> 24) & 0xff); + digest[i * 4 + 1] = static_cast((parts[i] >> 16) & 0xff); + digest[i * 4 + 2] = static_cast((parts[i] >> 8) & 0xff); + digest[i * 4 + 3] = static_cast(parts[i] & 0xff); + } + return digest; +} + +std::string Base64Encode(const uint8_t* data, std::size_t size) +{ + static constexpr char kAlphabet[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + std::string output; + output.reserve(((size + 2) / 3) * 4); + for (std::size_t i = 0; i < size; i += 3) + { + const uint32_t a = data[i]; + const uint32_t b = i + 1 < size ? data[i + 1] : 0; + const uint32_t c = i + 2 < size ? data[i + 2] : 0; + const uint32_t triple = (a << 16) | (b << 8) | c; + output.push_back(kAlphabet[(triple >> 18) & 0x3f]); + output.push_back(kAlphabet[(triple >> 12) & 0x3f]); + output.push_back(i + 1 < size ? kAlphabet[(triple >> 6) & 0x3f] : '='); + output.push_back(i + 2 < size ? kAlphabet[triple & 0x3f] : '='); + } + return output; +} } UniqueSocket::UniqueSocket(SOCKET socket) : @@ -157,6 +271,20 @@ void HttpControlServer::Stop() if (mThread.joinable()) mThread.join(); + std::vector clientThreads; + { + std::lock_guard lock(mClientThreadsMutex); + clientThreads.swap(mClientThreads); + for (std::thread& thread : mFinishedClientThreads) + clientThreads.push_back(std::move(thread)); + mFinishedClientThreads.clear(); + } + for (std::thread& thread : clientThreads) + { + if (thread.joinable()) + thread.join(); + } + if (mWinsockStarted) { WSACleanup(); @@ -185,6 +313,7 @@ void HttpControlServer::ThreadMain() { while (mRunning.load(std::memory_order_acquire)) { + JoinFinishedClientThreads(); TryAcceptClient(); std::this_thread::sleep_for(mConfig.idleSleep); } @@ -212,9 +341,81 @@ bool HttpControlServer::HandleClient(UniqueSocket clientSocket) if (!ParseHttpRequest(std::string(buffer, buffer + received), request)) return SendResponse(clientSocket.get(), TextResponse("400 Bad Request", "Bad Request")); + if (request.path == "/ws") + return HandleWebSocketClient(std::move(clientSocket), request); + return SendResponse(clientSocket.get(), RouteRequest(request)); } +bool HttpControlServer::HandleWebSocketClient(UniqueSocket clientSocket, const HttpRequest& request) +{ + const auto keyIt = request.headers.find("sec-websocket-key"); + if (keyIt == request.headers.end() || keyIt->second.empty()) + return SendResponse(clientSocket.get(), TextResponse("400 Bad Request", "Missing WebSocket key")); + + std::ostringstream stream; + stream << "HTTP/1.1 101 Switching Protocols\r\n" + << "Upgrade: websocket\r\n" + << "Connection: Upgrade\r\n" + << "Sec-WebSocket-Accept: " << WebSocketAcceptKey(keyIt->second) << "\r\n\r\n"; + const std::string response = stream.str(); + if (send(clientSocket.get(), response.c_str(), static_cast(response.size()), 0) != static_cast(response.size())) + return false; + + u_long nonBlocking = 1; + ioctlsocket(clientSocket.get(), FIONBIO, &nonBlocking); + + std::thread thread([this, socket = std::move(clientSocket)]() mutable { + WebSocketClientMain(std::move(socket)); + }); + { + std::lock_guard lock(mClientThreadsMutex); + mClientThreads.push_back(std::move(thread)); + } + return true; +} + +void HttpControlServer::WebSocketClientMain(UniqueSocket clientSocket) +{ + std::string previousState; + while (mRunning.load(std::memory_order_acquire)) + { + const std::string state = mCallbacks.getStateJson ? mCallbacks.getStateJson() : "{}"; + if (state != previousState) + { + if (!SendWebSocketText(clientSocket.get(), state)) + break; + previousState = state; + } + std::this_thread::sleep_for(std::chrono::milliseconds(250)); + } + + std::lock_guard lock(mClientThreadsMutex); + const std::thread::id currentId = std::this_thread::get_id(); + for (auto it = mClientThreads.begin(); it != mClientThreads.end(); ++it) + { + if (it->get_id() != currentId) + continue; + mFinishedClientThreads.push_back(std::move(*it)); + mClientThreads.erase(it); + break; + } +} + +void HttpControlServer::JoinFinishedClientThreads() +{ + std::vector finished; + { + std::lock_guard lock(mClientThreadsMutex); + finished.swap(mFinishedClientThreads); + } + for (std::thread& thread : finished) + { + if (thread.joinable()) + thread.join(); + } +} + bool HttpControlServer::SendResponse(SOCKET clientSocket, const HttpResponse& response) const { std::ostringstream stream; @@ -360,6 +561,61 @@ std::string HttpControlServer::ActionResponse(bool ok, const std::string& error) return writer.StringValue(); } +bool HttpControlServer::SendWebSocketText(SOCKET clientSocket, const std::string& text) +{ + if (clientSocket == INVALID_SOCKET) + return false; + + std::vector frame; + frame.reserve(text.size() + 16); + frame.push_back(0x81); + if (text.size() <= 125) + { + frame.push_back(static_cast(text.size())); + } + else if (text.size() <= 0xffff) + { + frame.push_back(126); + frame.push_back(static_cast((text.size() >> 8) & 0xff)); + frame.push_back(static_cast(text.size() & 0xff)); + } + else + { + frame.push_back(127); + const uint64_t length = static_cast(text.size()); + for (int shift = 56; shift >= 0; shift -= 8) + frame.push_back(static_cast((length >> shift) & 0xff)); + } + frame.insert(frame.end(), text.begin(), text.end()); + + const char* data = reinterpret_cast(frame.data()); + int remaining = static_cast(frame.size()); + while (remaining > 0) + { + const int sent = send(clientSocket, data, remaining, 0); + if (sent <= 0) + { + const int error = WSAGetLastError(); + if (error == WSAEWOULDBLOCK) + { + std::this_thread::sleep_for(std::chrono::milliseconds(2)); + continue; + } + return false; + } + data += sent; + remaining -= sent; + } + return true; +} + +std::string HttpControlServer::WebSocketAcceptKey(const std::string& clientKey) +{ + static constexpr const char* kWebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + const std::array digest = Sha1(clientKey + kWebSocketGuid); + return Base64Encode(digest.data(), digest.size()); +} + std::string HttpControlServer::GuessContentType(const std::filesystem::path& path) { const std::string extension = ToLower(path.extension().string()); diff --git a/apps/RenderCadenceCompositor/control/HttpControlServer.h b/apps/RenderCadenceCompositor/control/HttpControlServer.h index 31f15b8..8b811eb 100644 --- a/apps/RenderCadenceCompositor/control/HttpControlServer.h +++ b/apps/RenderCadenceCompositor/control/HttpControlServer.h @@ -9,8 +9,10 @@ #include #include #include +#include #include #include +#include namespace RenderCadenceCompositor { @@ -89,11 +91,15 @@ public: HttpResponse RouteRequestForTest(const HttpRequest& request) const; static bool ParseHttpRequest(const std::string& rawRequest, HttpRequest& request); + static std::string WebSocketAcceptKey(const std::string& clientKey); private: void ThreadMain(); bool TryAcceptClient(); bool HandleClient(UniqueSocket clientSocket); + bool HandleWebSocketClient(UniqueSocket clientSocket, const HttpRequest& request); + void WebSocketClientMain(UniqueSocket clientSocket); + void JoinFinishedClientThreads(); bool SendResponse(SOCKET clientSocket, const HttpResponse& response) const; HttpResponse RouteRequest(const HttpRequest& request) const; HttpResponse ServeGet(const HttpRequest& request) const; @@ -107,6 +113,7 @@ private: static HttpResponse TextResponse(const std::string& status, const std::string& body); static HttpResponse HtmlResponse(const std::string& status, const std::string& body); static std::string ActionResponse(bool ok, const std::string& error = std::string()); + static bool SendWebSocketText(SOCKET clientSocket, const std::string& text); static std::string GuessContentType(const std::filesystem::path& path); static bool IsSafeRelativePath(const std::filesystem::path& path); static std::string ToLower(std::string text); @@ -117,6 +124,9 @@ private: HttpControlServerCallbacks mCallbacks; UniqueSocket mListenSocket; std::thread mThread; + std::mutex mClientThreadsMutex; + std::vector mClientThreads; + std::vector mFinishedClientThreads; std::atomic mRunning{ false }; unsigned short mPort = 0; bool mWinsockStarted = false; diff --git a/docs/openapi.yaml b/docs/openapi.yaml index 7045158..6278885 100644 --- a/docs/openapi.yaml +++ b/docs/openapi.yaml @@ -8,8 +8,8 @@ info: The API is intended for local control tools and the bundled React UI. All mutating endpoints return a small action result object. - WebSocket state streaming is planned for the control UI but is not currently served - by RenderCadenceCompositor. Clients should poll `/api/state` until `/ws` is implemented. + RenderCadenceCompositor serves `/api/state` for snapshots and `/ws` for local + WebSocket state updates consumed by the bundled control UI. servers: - url: http://127.0.0.1:8080 description: Default local control server @@ -179,6 +179,24 @@ paths: application/json: schema: $ref: "#/components/schemas/RuntimeState" + /ws: + get: + tags: [State] + summary: Stream runtime state over WebSocket + description: | + Upgrades to a WebSocket connection. The server sends JSON runtime-state + snapshots using the same shape as `GET /api/state` whenever the serialized + state changes. + operationId: streamRuntimeState + responses: + "101": + description: WebSocket protocol upgrade accepted. + "400": + description: The request was not a valid WebSocket upgrade. + content: + text/plain: + schema: + type: string /api/layers/add: post: tags: [Layers] diff --git a/tests/RenderCadenceCompositorHttpControlServerTests.cpp b/tests/RenderCadenceCompositorHttpControlServerTests.cpp index 6865e15..44e936d 100644 --- a/tests/RenderCadenceCompositorHttpControlServerTests.cpp +++ b/tests/RenderCadenceCompositorHttpControlServerTests.cpp @@ -63,6 +63,14 @@ void TestStateEndpointUsesCallback() ExpectEquals(response.body, "{\"ok\":true}", "state endpoint returns callback JSON"); } +void TestWebSocketAcceptKey() +{ + using namespace RenderCadenceCompositor; + + const std::string acceptKey = HttpControlServer::WebSocketAcceptKey("dGhlIHNhbXBsZSBub25jZQ=="); + ExpectEquals(acceptKey, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", "WebSocket accept key matches RFC example"); +} + void TestRootServesUiIndex() { using namespace RenderCadenceCompositor; @@ -157,6 +165,7 @@ int main() { TestParsesHttpRequest(); TestStateEndpointUsesCallback(); + TestWebSocketAcceptKey(); TestRootServesUiIndex(); TestKnownPostEndpointReturnsActionError(); TestLayerPostEndpointsUseCallbacks();