Websocket split
All checks were successful
CI / React UI Build (push) Successful in 10s
CI / Native Windows Build And Tests (push) Successful in 3m0s
CI / Windows Release Package (push) Has been skipped

This commit is contained in:
Aiden
2026-05-12 15:36:02 +10:00
parent da7e1a93f6
commit 6a33bd02ab
3 changed files with 250 additions and 236 deletions

View File

@@ -312,6 +312,7 @@ set(RENDER_CADENCE_APP_SOURCES
"${RENDER_CADENCE_APP_DIR}/control/ControlActionResult.h"
"${RENDER_CADENCE_APP_DIR}/control/HttpControlServer.cpp"
"${RENDER_CADENCE_APP_DIR}/control/HttpControlServer.h"
"${RENDER_CADENCE_APP_DIR}/control/HttpControlServerWebSocket.cpp"
"${RENDER_CADENCE_APP_DIR}/control/RuntimeStateJson.h"
"${RENDER_CADENCE_APP_DIR}/frames/SystemFrameExchange.cpp"
"${RENDER_CADENCE_APP_DIR}/frames/SystemFrameExchange.h"
@@ -936,6 +937,7 @@ add_test(NAME RenderCadenceCompositorRuntimeStateJsonTests COMMAND RenderCadence
add_executable(RenderCadenceCompositorHttpControlServerTests
"${RENDER_CADENCE_APP_DIR}/control/HttpControlServer.cpp"
"${RENDER_CADENCE_APP_DIR}/control/HttpControlServerWebSocket.cpp"
"${RENDER_CADENCE_APP_DIR}/json/JsonWriter.cpp"
"${RENDER_CADENCE_APP_DIR}/logging/Logger.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/tests/RenderCadenceCompositorHttpControlServerTests.cpp"

View File

@@ -6,9 +6,7 @@
#include <ws2tcpip.h>
#include <algorithm>
#include <array>
#include <cctype>
#include <cstdint>
#include <fstream>
#include <sstream>
#include <vector>
@@ -45,116 +43,6 @@ bool IsKnownPostEndpoint(const std::string& path)
|| path == "/api/screenshot";
}
std::array<uint8_t, 20> Sha1(const std::string& input)
{
auto leftRotate = [](uint32_t value, uint32_t bits) {
return (value << bits) | (value >> (32U - bits));
};
std::vector<uint8_t> data(input.begin(), input.end());
const uint64_t bitLength = static_cast<uint64_t>(data.size()) * 8ULL;
data.push_back(0x80);
while ((data.size() % 64) != 56)
data.push_back(0);
for (int shift = 56; shift >= 0; shift -= 8)
data.push_back(static_cast<uint8_t>((bitLength >> shift) & 0xff));
uint32_t h0 = 0x67452301;
uint32_t h1 = 0xefcdab89;
uint32_t h2 = 0x98badcfe;
uint32_t h3 = 0x10325476;
uint32_t h4 = 0xc3d2e1f0;
for (std::size_t offset = 0; offset < data.size(); offset += 64)
{
uint32_t words[80] = {};
for (std::size_t i = 0; i < 16; ++i)
{
const std::size_t index = offset + i * 4;
words[i] = (static_cast<uint32_t>(data[index]) << 24)
| (static_cast<uint32_t>(data[index + 1]) << 16)
| (static_cast<uint32_t>(data[index + 2]) << 8)
| static_cast<uint32_t>(data[index + 3]);
}
for (std::size_t i = 16; i < 80; ++i)
words[i] = leftRotate(words[i - 3] ^ words[i - 8] ^ words[i - 14] ^ words[i - 16], 1);
uint32_t a = h0;
uint32_t b = h1;
uint32_t c = h2;
uint32_t d = h3;
uint32_t e = h4;
for (std::size_t i = 0; i < 80; ++i)
{
uint32_t f = 0;
uint32_t k = 0;
if (i < 20)
{
f = (b & c) | ((~b) & d);
k = 0x5a827999;
}
else if (i < 40)
{
f = b ^ c ^ d;
k = 0x6ed9eba1;
}
else if (i < 60)
{
f = (b & c) | (b & d) | (c & d);
k = 0x8f1bbcdc;
}
else
{
f = b ^ c ^ d;
k = 0xca62c1d6;
}
const uint32_t temp = leftRotate(a, 5) + f + e + k + words[i];
e = d;
d = c;
c = leftRotate(b, 30);
b = a;
a = temp;
}
h0 += a;
h1 += b;
h2 += c;
h3 += d;
h4 += e;
}
std::array<uint8_t, 20> digest = {};
const uint32_t parts[] = { h0, h1, h2, h3, h4 };
for (std::size_t i = 0; i < 5; ++i)
{
digest[i * 4] = static_cast<uint8_t>((parts[i] >> 24) & 0xff);
digest[i * 4 + 1] = static_cast<uint8_t>((parts[i] >> 16) & 0xff);
digest[i * 4 + 2] = static_cast<uint8_t>((parts[i] >> 8) & 0xff);
digest[i * 4 + 3] = static_cast<uint8_t>(parts[i] & 0xff);
}
return digest;
}
std::string Base64Encode(const uint8_t* data, std::size_t size)
{
static constexpr char kAlphabet[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
std::string output;
output.reserve(((size + 2) / 3) * 4);
for (std::size_t i = 0; i < size; i += 3)
{
const uint32_t a = data[i];
const uint32_t b = i + 1 < size ? data[i + 1] : 0;
const uint32_t c = i + 2 < size ? data[i + 2] : 0;
const uint32_t triple = (a << 16) | (b << 8) | c;
output.push_back(kAlphabet[(triple >> 18) & 0x3f]);
output.push_back(kAlphabet[(triple >> 12) & 0x3f]);
output.push_back(i + 1 < size ? kAlphabet[(triple >> 6) & 0x3f] : '=');
output.push_back(i + 2 < size ? kAlphabet[triple & 0x3f] : '=');
}
return output;
}
}
UniqueSocket::UniqueSocket(SOCKET socket) :
@@ -347,75 +235,6 @@ bool HttpControlServer::HandleClient(UniqueSocket clientSocket)
return SendResponse(clientSocket.get(), RouteRequest(request));
}
bool HttpControlServer::HandleWebSocketClient(UniqueSocket clientSocket, const HttpRequest& request)
{
const auto keyIt = request.headers.find("sec-websocket-key");
if (keyIt == request.headers.end() || keyIt->second.empty())
return SendResponse(clientSocket.get(), TextResponse("400 Bad Request", "Missing WebSocket key"));
std::ostringstream stream;
stream << "HTTP/1.1 101 Switching Protocols\r\n"
<< "Upgrade: websocket\r\n"
<< "Connection: Upgrade\r\n"
<< "Sec-WebSocket-Accept: " << WebSocketAcceptKey(keyIt->second) << "\r\n\r\n";
const std::string response = stream.str();
if (send(clientSocket.get(), response.c_str(), static_cast<int>(response.size()), 0) != static_cast<int>(response.size()))
return false;
u_long nonBlocking = 1;
ioctlsocket(clientSocket.get(), FIONBIO, &nonBlocking);
std::thread thread([this, socket = std::move(clientSocket)]() mutable {
WebSocketClientMain(std::move(socket));
});
{
std::lock_guard<std::mutex> lock(mClientThreadsMutex);
mClientThreads.push_back(std::move(thread));
}
return true;
}
void HttpControlServer::WebSocketClientMain(UniqueSocket clientSocket)
{
std::string previousState;
while (mRunning.load(std::memory_order_acquire))
{
const std::string state = mCallbacks.getStateJson ? mCallbacks.getStateJson() : "{}";
if (state != previousState)
{
if (!SendWebSocketText(clientSocket.get(), state))
break;
previousState = state;
}
std::this_thread::sleep_for(std::chrono::milliseconds(250));
}
std::lock_guard<std::mutex> lock(mClientThreadsMutex);
const std::thread::id currentId = std::this_thread::get_id();
for (auto it = mClientThreads.begin(); it != mClientThreads.end(); ++it)
{
if (it->get_id() != currentId)
continue;
mFinishedClientThreads.push_back(std::move(*it));
mClientThreads.erase(it);
break;
}
}
void HttpControlServer::JoinFinishedClientThreads()
{
std::vector<std::thread> finished;
{
std::lock_guard<std::mutex> lock(mClientThreadsMutex);
finished.swap(mFinishedClientThreads);
}
for (std::thread& thread : finished)
{
if (thread.joinable())
thread.join();
}
}
bool HttpControlServer::SendResponse(SOCKET clientSocket, const HttpResponse& response) const
{
std::ostringstream stream;
@@ -561,61 +380,6 @@ std::string HttpControlServer::ActionResponse(bool ok, const std::string& error)
return writer.StringValue();
}
bool HttpControlServer::SendWebSocketText(SOCKET clientSocket, const std::string& text)
{
if (clientSocket == INVALID_SOCKET)
return false;
std::vector<unsigned char> frame;
frame.reserve(text.size() + 16);
frame.push_back(0x81);
if (text.size() <= 125)
{
frame.push_back(static_cast<unsigned char>(text.size()));
}
else if (text.size() <= 0xffff)
{
frame.push_back(126);
frame.push_back(static_cast<unsigned char>((text.size() >> 8) & 0xff));
frame.push_back(static_cast<unsigned char>(text.size() & 0xff));
}
else
{
frame.push_back(127);
const uint64_t length = static_cast<uint64_t>(text.size());
for (int shift = 56; shift >= 0; shift -= 8)
frame.push_back(static_cast<unsigned char>((length >> shift) & 0xff));
}
frame.insert(frame.end(), text.begin(), text.end());
const char* data = reinterpret_cast<const char*>(frame.data());
int remaining = static_cast<int>(frame.size());
while (remaining > 0)
{
const int sent = send(clientSocket, data, remaining, 0);
if (sent <= 0)
{
const int error = WSAGetLastError();
if (error == WSAEWOULDBLOCK)
{
std::this_thread::sleep_for(std::chrono::milliseconds(2));
continue;
}
return false;
}
data += sent;
remaining -= sent;
}
return true;
}
std::string HttpControlServer::WebSocketAcceptKey(const std::string& clientKey)
{
static constexpr const char* kWebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
const std::array<uint8_t, 20> digest = Sha1(clientKey + kWebSocketGuid);
return Base64Encode(digest.data(), digest.size());
}
std::string HttpControlServer::GuessContentType(const std::filesystem::path& path)
{
const std::string extension = ToLower(path.extension().string());

View File

@@ -0,0 +1,248 @@
#include "HttpControlServer.h"
#include <array>
#include <cstdint>
#include <sstream>
#include <utility>
#include <vector>
namespace RenderCadenceCompositor
{
namespace
{
std::array<uint8_t, 20> Sha1(const std::string& input)
{
auto leftRotate = [](uint32_t value, uint32_t bits) {
return (value << bits) | (value >> (32U - bits));
};
std::vector<uint8_t> data(input.begin(), input.end());
const uint64_t bitLength = static_cast<uint64_t>(data.size()) * 8ULL;
data.push_back(0x80);
while ((data.size() % 64) != 56)
data.push_back(0);
for (int shift = 56; shift >= 0; shift -= 8)
data.push_back(static_cast<uint8_t>((bitLength >> shift) & 0xff));
uint32_t h0 = 0x67452301;
uint32_t h1 = 0xefcdab89;
uint32_t h2 = 0x98badcfe;
uint32_t h3 = 0x10325476;
uint32_t h4 = 0xc3d2e1f0;
for (std::size_t offset = 0; offset < data.size(); offset += 64)
{
uint32_t words[80] = {};
for (std::size_t i = 0; i < 16; ++i)
{
const std::size_t index = offset + i * 4;
words[i] = (static_cast<uint32_t>(data[index]) << 24)
| (static_cast<uint32_t>(data[index + 1]) << 16)
| (static_cast<uint32_t>(data[index + 2]) << 8)
| static_cast<uint32_t>(data[index + 3]);
}
for (std::size_t i = 16; i < 80; ++i)
words[i] = leftRotate(words[i - 3] ^ words[i - 8] ^ words[i - 14] ^ words[i - 16], 1);
uint32_t a = h0;
uint32_t b = h1;
uint32_t c = h2;
uint32_t d = h3;
uint32_t e = h4;
for (std::size_t i = 0; i < 80; ++i)
{
uint32_t f = 0;
uint32_t k = 0;
if (i < 20)
{
f = (b & c) | ((~b) & d);
k = 0x5a827999;
}
else if (i < 40)
{
f = b ^ c ^ d;
k = 0x6ed9eba1;
}
else if (i < 60)
{
f = (b & c) | (b & d) | (c & d);
k = 0x8f1bbcdc;
}
else
{
f = b ^ c ^ d;
k = 0xca62c1d6;
}
const uint32_t temp = leftRotate(a, 5) + f + e + k + words[i];
e = d;
d = c;
c = leftRotate(b, 30);
b = a;
a = temp;
}
h0 += a;
h1 += b;
h2 += c;
h3 += d;
h4 += e;
}
std::array<uint8_t, 20> digest = {};
const uint32_t parts[] = { h0, h1, h2, h3, h4 };
for (std::size_t i = 0; i < 5; ++i)
{
digest[i * 4] = static_cast<uint8_t>((parts[i] >> 24) & 0xff);
digest[i * 4 + 1] = static_cast<uint8_t>((parts[i] >> 16) & 0xff);
digest[i * 4 + 2] = static_cast<uint8_t>((parts[i] >> 8) & 0xff);
digest[i * 4 + 3] = static_cast<uint8_t>(parts[i] & 0xff);
}
return digest;
}
std::string Base64Encode(const uint8_t* data, std::size_t size)
{
static constexpr char kAlphabet[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
std::string output;
output.reserve(((size + 2) / 3) * 4);
for (std::size_t i = 0; i < size; i += 3)
{
const uint32_t a = data[i];
const uint32_t b = i + 1 < size ? data[i + 1] : 0;
const uint32_t c = i + 2 < size ? data[i + 2] : 0;
const uint32_t triple = (a << 16) | (b << 8) | c;
output.push_back(kAlphabet[(triple >> 18) & 0x3f]);
output.push_back(kAlphabet[(triple >> 12) & 0x3f]);
output.push_back(i + 1 < size ? kAlphabet[(triple >> 6) & 0x3f] : '=');
output.push_back(i + 2 < size ? kAlphabet[triple & 0x3f] : '=');
}
return output;
}
}
bool HttpControlServer::HandleWebSocketClient(UniqueSocket clientSocket, const HttpRequest& request)
{
const auto keyIt = request.headers.find("sec-websocket-key");
if (keyIt == request.headers.end() || keyIt->second.empty())
return SendResponse(clientSocket.get(), TextResponse("400 Bad Request", "Missing WebSocket key"));
std::ostringstream stream;
stream << "HTTP/1.1 101 Switching Protocols\r\n"
<< "Upgrade: websocket\r\n"
<< "Connection: Upgrade\r\n"
<< "Sec-WebSocket-Accept: " << WebSocketAcceptKey(keyIt->second) << "\r\n\r\n";
const std::string response = stream.str();
if (send(clientSocket.get(), response.c_str(), static_cast<int>(response.size()), 0) != static_cast<int>(response.size()))
return false;
u_long nonBlocking = 1;
ioctlsocket(clientSocket.get(), FIONBIO, &nonBlocking);
std::thread thread([this, socket = std::move(clientSocket)]() mutable {
WebSocketClientMain(std::move(socket));
});
{
std::lock_guard<std::mutex> lock(mClientThreadsMutex);
mClientThreads.push_back(std::move(thread));
}
return true;
}
void HttpControlServer::WebSocketClientMain(UniqueSocket clientSocket)
{
std::string previousState;
while (mRunning.load(std::memory_order_acquire))
{
const std::string state = mCallbacks.getStateJson ? mCallbacks.getStateJson() : "{}";
if (state != previousState)
{
if (!SendWebSocketText(clientSocket.get(), state))
break;
previousState = state;
}
std::this_thread::sleep_for(std::chrono::milliseconds(250));
}
std::lock_guard<std::mutex> lock(mClientThreadsMutex);
const std::thread::id currentId = std::this_thread::get_id();
for (auto it = mClientThreads.begin(); it != mClientThreads.end(); ++it)
{
if (it->get_id() != currentId)
continue;
mFinishedClientThreads.push_back(std::move(*it));
mClientThreads.erase(it);
break;
}
}
void HttpControlServer::JoinFinishedClientThreads()
{
std::vector<std::thread> finished;
{
std::lock_guard<std::mutex> lock(mClientThreadsMutex);
finished.swap(mFinishedClientThreads);
}
for (std::thread& thread : finished)
{
if (thread.joinable())
thread.join();
}
}
bool HttpControlServer::SendWebSocketText(SOCKET clientSocket, const std::string& text)
{
if (clientSocket == INVALID_SOCKET)
return false;
std::vector<unsigned char> frame;
frame.reserve(text.size() + 16);
frame.push_back(0x81);
if (text.size() <= 125)
{
frame.push_back(static_cast<unsigned char>(text.size()));
}
else if (text.size() <= 0xffff)
{
frame.push_back(126);
frame.push_back(static_cast<unsigned char>((text.size() >> 8) & 0xff));
frame.push_back(static_cast<unsigned char>(text.size() & 0xff));
}
else
{
frame.push_back(127);
const uint64_t length = static_cast<uint64_t>(text.size());
for (int shift = 56; shift >= 0; shift -= 8)
frame.push_back(static_cast<unsigned char>((length >> shift) & 0xff));
}
frame.insert(frame.end(), text.begin(), text.end());
const char* data = reinterpret_cast<const char*>(frame.data());
int remaining = static_cast<int>(frame.size());
while (remaining > 0)
{
const int sent = send(clientSocket, data, remaining, 0);
if (sent <= 0)
{
const int error = WSAGetLastError();
if (error == WSAEWOULDBLOCK)
{
std::this_thread::sleep_for(std::chrono::milliseconds(2));
continue;
}
return false;
}
data += sent;
remaining -= sent;
}
return true;
}
std::string HttpControlServer::WebSocketAcceptKey(const std::string& clientKey)
{
static constexpr const char* kWebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
const std::array<uint8_t, 20> digest = Sha1(clientKey + kWebSocketGuid);
return Base64Encode(digest.data(), digest.size());
}
}