diff --git a/CMakeLists.txt b/CMakeLists.txt index 5f8bcd2..b2f0af8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -312,6 +312,7 @@ set(RENDER_CADENCE_APP_SOURCES "${RENDER_CADENCE_APP_DIR}/control/ControlActionResult.h" "${RENDER_CADENCE_APP_DIR}/control/HttpControlServer.cpp" "${RENDER_CADENCE_APP_DIR}/control/HttpControlServer.h" + "${RENDER_CADENCE_APP_DIR}/control/HttpControlServerWebSocket.cpp" "${RENDER_CADENCE_APP_DIR}/control/RuntimeStateJson.h" "${RENDER_CADENCE_APP_DIR}/frames/SystemFrameExchange.cpp" "${RENDER_CADENCE_APP_DIR}/frames/SystemFrameExchange.h" @@ -936,6 +937,7 @@ add_test(NAME RenderCadenceCompositorRuntimeStateJsonTests COMMAND RenderCadence add_executable(RenderCadenceCompositorHttpControlServerTests "${RENDER_CADENCE_APP_DIR}/control/HttpControlServer.cpp" + "${RENDER_CADENCE_APP_DIR}/control/HttpControlServerWebSocket.cpp" "${RENDER_CADENCE_APP_DIR}/json/JsonWriter.cpp" "${RENDER_CADENCE_APP_DIR}/logging/Logger.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/tests/RenderCadenceCompositorHttpControlServerTests.cpp" diff --git a/apps/RenderCadenceCompositor/control/HttpControlServer.cpp b/apps/RenderCadenceCompositor/control/HttpControlServer.cpp index cc1eb44..4ffc962 100644 --- a/apps/RenderCadenceCompositor/control/HttpControlServer.cpp +++ b/apps/RenderCadenceCompositor/control/HttpControlServer.cpp @@ -6,9 +6,7 @@ #include #include -#include #include -#include #include #include #include @@ -45,116 +43,6 @@ bool IsKnownPostEndpoint(const std::string& path) || path == "/api/screenshot"; } -std::array Sha1(const std::string& input) -{ - auto leftRotate = [](uint32_t value, uint32_t bits) { - return (value << bits) | (value >> (32U - bits)); - }; - - std::vector data(input.begin(), input.end()); - const uint64_t bitLength = static_cast(data.size()) * 8ULL; - data.push_back(0x80); - while ((data.size() % 64) != 56) - data.push_back(0); - for (int shift = 56; shift >= 0; shift -= 8) - data.push_back(static_cast((bitLength >> shift) & 0xff)); - - uint32_t h0 = 0x67452301; - uint32_t h1 = 0xefcdab89; - uint32_t h2 = 0x98badcfe; - uint32_t h3 = 0x10325476; - uint32_t h4 = 0xc3d2e1f0; - - for (std::size_t offset = 0; offset < data.size(); offset += 64) - { - uint32_t words[80] = {}; - for (std::size_t i = 0; i < 16; ++i) - { - const std::size_t index = offset + i * 4; - words[i] = (static_cast(data[index]) << 24) - | (static_cast(data[index + 1]) << 16) - | (static_cast(data[index + 2]) << 8) - | static_cast(data[index + 3]); - } - for (std::size_t i = 16; i < 80; ++i) - words[i] = leftRotate(words[i - 3] ^ words[i - 8] ^ words[i - 14] ^ words[i - 16], 1); - - uint32_t a = h0; - uint32_t b = h1; - uint32_t c = h2; - uint32_t d = h3; - uint32_t e = h4; - - for (std::size_t i = 0; i < 80; ++i) - { - uint32_t f = 0; - uint32_t k = 0; - if (i < 20) - { - f = (b & c) | ((~b) & d); - k = 0x5a827999; - } - else if (i < 40) - { - f = b ^ c ^ d; - k = 0x6ed9eba1; - } - else if (i < 60) - { - f = (b & c) | (b & d) | (c & d); - k = 0x8f1bbcdc; - } - else - { - f = b ^ c ^ d; - k = 0xca62c1d6; - } - - const uint32_t temp = leftRotate(a, 5) + f + e + k + words[i]; - e = d; - d = c; - c = leftRotate(b, 30); - b = a; - a = temp; - } - - h0 += a; - h1 += b; - h2 += c; - h3 += d; - h4 += e; - } - - std::array digest = {}; - const uint32_t parts[] = { h0, h1, h2, h3, h4 }; - for (std::size_t i = 0; i < 5; ++i) - { - digest[i * 4] = static_cast((parts[i] >> 24) & 0xff); - digest[i * 4 + 1] = static_cast((parts[i] >> 16) & 0xff); - digest[i * 4 + 2] = static_cast((parts[i] >> 8) & 0xff); - digest[i * 4 + 3] = static_cast(parts[i] & 0xff); - } - return digest; -} - -std::string Base64Encode(const uint8_t* data, std::size_t size) -{ - static constexpr char kAlphabet[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; - std::string output; - output.reserve(((size + 2) / 3) * 4); - for (std::size_t i = 0; i < size; i += 3) - { - const uint32_t a = data[i]; - const uint32_t b = i + 1 < size ? data[i + 1] : 0; - const uint32_t c = i + 2 < size ? data[i + 2] : 0; - const uint32_t triple = (a << 16) | (b << 8) | c; - output.push_back(kAlphabet[(triple >> 18) & 0x3f]); - output.push_back(kAlphabet[(triple >> 12) & 0x3f]); - output.push_back(i + 1 < size ? kAlphabet[(triple >> 6) & 0x3f] : '='); - output.push_back(i + 2 < size ? kAlphabet[triple & 0x3f] : '='); - } - return output; -} } UniqueSocket::UniqueSocket(SOCKET socket) : @@ -347,75 +235,6 @@ bool HttpControlServer::HandleClient(UniqueSocket clientSocket) return SendResponse(clientSocket.get(), RouteRequest(request)); } -bool HttpControlServer::HandleWebSocketClient(UniqueSocket clientSocket, const HttpRequest& request) -{ - const auto keyIt = request.headers.find("sec-websocket-key"); - if (keyIt == request.headers.end() || keyIt->second.empty()) - return SendResponse(clientSocket.get(), TextResponse("400 Bad Request", "Missing WebSocket key")); - - std::ostringstream stream; - stream << "HTTP/1.1 101 Switching Protocols\r\n" - << "Upgrade: websocket\r\n" - << "Connection: Upgrade\r\n" - << "Sec-WebSocket-Accept: " << WebSocketAcceptKey(keyIt->second) << "\r\n\r\n"; - const std::string response = stream.str(); - if (send(clientSocket.get(), response.c_str(), static_cast(response.size()), 0) != static_cast(response.size())) - return false; - - u_long nonBlocking = 1; - ioctlsocket(clientSocket.get(), FIONBIO, &nonBlocking); - - std::thread thread([this, socket = std::move(clientSocket)]() mutable { - WebSocketClientMain(std::move(socket)); - }); - { - std::lock_guard lock(mClientThreadsMutex); - mClientThreads.push_back(std::move(thread)); - } - return true; -} - -void HttpControlServer::WebSocketClientMain(UniqueSocket clientSocket) -{ - std::string previousState; - while (mRunning.load(std::memory_order_acquire)) - { - const std::string state = mCallbacks.getStateJson ? mCallbacks.getStateJson() : "{}"; - if (state != previousState) - { - if (!SendWebSocketText(clientSocket.get(), state)) - break; - previousState = state; - } - std::this_thread::sleep_for(std::chrono::milliseconds(250)); - } - - std::lock_guard lock(mClientThreadsMutex); - const std::thread::id currentId = std::this_thread::get_id(); - for (auto it = mClientThreads.begin(); it != mClientThreads.end(); ++it) - { - if (it->get_id() != currentId) - continue; - mFinishedClientThreads.push_back(std::move(*it)); - mClientThreads.erase(it); - break; - } -} - -void HttpControlServer::JoinFinishedClientThreads() -{ - std::vector finished; - { - std::lock_guard lock(mClientThreadsMutex); - finished.swap(mFinishedClientThreads); - } - for (std::thread& thread : finished) - { - if (thread.joinable()) - thread.join(); - } -} - bool HttpControlServer::SendResponse(SOCKET clientSocket, const HttpResponse& response) const { std::ostringstream stream; @@ -561,61 +380,6 @@ std::string HttpControlServer::ActionResponse(bool ok, const std::string& error) return writer.StringValue(); } -bool HttpControlServer::SendWebSocketText(SOCKET clientSocket, const std::string& text) -{ - if (clientSocket == INVALID_SOCKET) - return false; - - std::vector frame; - frame.reserve(text.size() + 16); - frame.push_back(0x81); - if (text.size() <= 125) - { - frame.push_back(static_cast(text.size())); - } - else if (text.size() <= 0xffff) - { - frame.push_back(126); - frame.push_back(static_cast((text.size() >> 8) & 0xff)); - frame.push_back(static_cast(text.size() & 0xff)); - } - else - { - frame.push_back(127); - const uint64_t length = static_cast(text.size()); - for (int shift = 56; shift >= 0; shift -= 8) - frame.push_back(static_cast((length >> shift) & 0xff)); - } - frame.insert(frame.end(), text.begin(), text.end()); - - const char* data = reinterpret_cast(frame.data()); - int remaining = static_cast(frame.size()); - while (remaining > 0) - { - const int sent = send(clientSocket, data, remaining, 0); - if (sent <= 0) - { - const int error = WSAGetLastError(); - if (error == WSAEWOULDBLOCK) - { - std::this_thread::sleep_for(std::chrono::milliseconds(2)); - continue; - } - return false; - } - data += sent; - remaining -= sent; - } - return true; -} - -std::string HttpControlServer::WebSocketAcceptKey(const std::string& clientKey) -{ - static constexpr const char* kWebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; - const std::array digest = Sha1(clientKey + kWebSocketGuid); - return Base64Encode(digest.data(), digest.size()); -} - std::string HttpControlServer::GuessContentType(const std::filesystem::path& path) { const std::string extension = ToLower(path.extension().string()); diff --git a/apps/RenderCadenceCompositor/control/HttpControlServerWebSocket.cpp b/apps/RenderCadenceCompositor/control/HttpControlServerWebSocket.cpp new file mode 100644 index 0000000..aa94df7 --- /dev/null +++ b/apps/RenderCadenceCompositor/control/HttpControlServerWebSocket.cpp @@ -0,0 +1,248 @@ +#include "HttpControlServer.h" + +#include +#include +#include +#include +#include + +namespace RenderCadenceCompositor +{ +namespace +{ +std::array Sha1(const std::string& input) +{ + auto leftRotate = [](uint32_t value, uint32_t bits) { + return (value << bits) | (value >> (32U - bits)); + }; + + std::vector data(input.begin(), input.end()); + const uint64_t bitLength = static_cast(data.size()) * 8ULL; + data.push_back(0x80); + while ((data.size() % 64) != 56) + data.push_back(0); + for (int shift = 56; shift >= 0; shift -= 8) + data.push_back(static_cast((bitLength >> shift) & 0xff)); + + uint32_t h0 = 0x67452301; + uint32_t h1 = 0xefcdab89; + uint32_t h2 = 0x98badcfe; + uint32_t h3 = 0x10325476; + uint32_t h4 = 0xc3d2e1f0; + + for (std::size_t offset = 0; offset < data.size(); offset += 64) + { + uint32_t words[80] = {}; + for (std::size_t i = 0; i < 16; ++i) + { + const std::size_t index = offset + i * 4; + words[i] = (static_cast(data[index]) << 24) + | (static_cast(data[index + 1]) << 16) + | (static_cast(data[index + 2]) << 8) + | static_cast(data[index + 3]); + } + for (std::size_t i = 16; i < 80; ++i) + words[i] = leftRotate(words[i - 3] ^ words[i - 8] ^ words[i - 14] ^ words[i - 16], 1); + + uint32_t a = h0; + uint32_t b = h1; + uint32_t c = h2; + uint32_t d = h3; + uint32_t e = h4; + + for (std::size_t i = 0; i < 80; ++i) + { + uint32_t f = 0; + uint32_t k = 0; + if (i < 20) + { + f = (b & c) | ((~b) & d); + k = 0x5a827999; + } + else if (i < 40) + { + f = b ^ c ^ d; + k = 0x6ed9eba1; + } + else if (i < 60) + { + f = (b & c) | (b & d) | (c & d); + k = 0x8f1bbcdc; + } + else + { + f = b ^ c ^ d; + k = 0xca62c1d6; + } + + const uint32_t temp = leftRotate(a, 5) + f + e + k + words[i]; + e = d; + d = c; + c = leftRotate(b, 30); + b = a; + a = temp; + } + + h0 += a; + h1 += b; + h2 += c; + h3 += d; + h4 += e; + } + + std::array digest = {}; + const uint32_t parts[] = { h0, h1, h2, h3, h4 }; + for (std::size_t i = 0; i < 5; ++i) + { + digest[i * 4] = static_cast((parts[i] >> 24) & 0xff); + digest[i * 4 + 1] = static_cast((parts[i] >> 16) & 0xff); + digest[i * 4 + 2] = static_cast((parts[i] >> 8) & 0xff); + digest[i * 4 + 3] = static_cast(parts[i] & 0xff); + } + return digest; +} + +std::string Base64Encode(const uint8_t* data, std::size_t size) +{ + static constexpr char kAlphabet[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + std::string output; + output.reserve(((size + 2) / 3) * 4); + for (std::size_t i = 0; i < size; i += 3) + { + const uint32_t a = data[i]; + const uint32_t b = i + 1 < size ? data[i + 1] : 0; + const uint32_t c = i + 2 < size ? data[i + 2] : 0; + const uint32_t triple = (a << 16) | (b << 8) | c; + output.push_back(kAlphabet[(triple >> 18) & 0x3f]); + output.push_back(kAlphabet[(triple >> 12) & 0x3f]); + output.push_back(i + 1 < size ? kAlphabet[(triple >> 6) & 0x3f] : '='); + output.push_back(i + 2 < size ? kAlphabet[triple & 0x3f] : '='); + } + return output; +} +} + +bool HttpControlServer::HandleWebSocketClient(UniqueSocket clientSocket, const HttpRequest& request) +{ + const auto keyIt = request.headers.find("sec-websocket-key"); + if (keyIt == request.headers.end() || keyIt->second.empty()) + return SendResponse(clientSocket.get(), TextResponse("400 Bad Request", "Missing WebSocket key")); + + std::ostringstream stream; + stream << "HTTP/1.1 101 Switching Protocols\r\n" + << "Upgrade: websocket\r\n" + << "Connection: Upgrade\r\n" + << "Sec-WebSocket-Accept: " << WebSocketAcceptKey(keyIt->second) << "\r\n\r\n"; + const std::string response = stream.str(); + if (send(clientSocket.get(), response.c_str(), static_cast(response.size()), 0) != static_cast(response.size())) + return false; + + u_long nonBlocking = 1; + ioctlsocket(clientSocket.get(), FIONBIO, &nonBlocking); + + std::thread thread([this, socket = std::move(clientSocket)]() mutable { + WebSocketClientMain(std::move(socket)); + }); + { + std::lock_guard lock(mClientThreadsMutex); + mClientThreads.push_back(std::move(thread)); + } + return true; +} + +void HttpControlServer::WebSocketClientMain(UniqueSocket clientSocket) +{ + std::string previousState; + while (mRunning.load(std::memory_order_acquire)) + { + const std::string state = mCallbacks.getStateJson ? mCallbacks.getStateJson() : "{}"; + if (state != previousState) + { + if (!SendWebSocketText(clientSocket.get(), state)) + break; + previousState = state; + } + std::this_thread::sleep_for(std::chrono::milliseconds(250)); + } + + std::lock_guard lock(mClientThreadsMutex); + const std::thread::id currentId = std::this_thread::get_id(); + for (auto it = mClientThreads.begin(); it != mClientThreads.end(); ++it) + { + if (it->get_id() != currentId) + continue; + mFinishedClientThreads.push_back(std::move(*it)); + mClientThreads.erase(it); + break; + } +} + +void HttpControlServer::JoinFinishedClientThreads() +{ + std::vector finished; + { + std::lock_guard lock(mClientThreadsMutex); + finished.swap(mFinishedClientThreads); + } + for (std::thread& thread : finished) + { + if (thread.joinable()) + thread.join(); + } +} + +bool HttpControlServer::SendWebSocketText(SOCKET clientSocket, const std::string& text) +{ + if (clientSocket == INVALID_SOCKET) + return false; + + std::vector frame; + frame.reserve(text.size() + 16); + frame.push_back(0x81); + if (text.size() <= 125) + { + frame.push_back(static_cast(text.size())); + } + else if (text.size() <= 0xffff) + { + frame.push_back(126); + frame.push_back(static_cast((text.size() >> 8) & 0xff)); + frame.push_back(static_cast(text.size() & 0xff)); + } + else + { + frame.push_back(127); + const uint64_t length = static_cast(text.size()); + for (int shift = 56; shift >= 0; shift -= 8) + frame.push_back(static_cast((length >> shift) & 0xff)); + } + frame.insert(frame.end(), text.begin(), text.end()); + + const char* data = reinterpret_cast(frame.data()); + int remaining = static_cast(frame.size()); + while (remaining > 0) + { + const int sent = send(clientSocket, data, remaining, 0); + if (sent <= 0) + { + const int error = WSAGetLastError(); + if (error == WSAEWOULDBLOCK) + { + std::this_thread::sleep_for(std::chrono::milliseconds(2)); + continue; + } + return false; + } + data += sent; + remaining -= sent; + } + return true; +} + +std::string HttpControlServer::WebSocketAcceptKey(const std::string& clientKey) +{ + static constexpr const char* kWebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + const std::array digest = Sha1(clientKey + kWebSocketGuid); + return Base64Encode(digest.data(), digest.size()); +} +}