Websocket split
This commit is contained in:
@@ -6,9 +6,7 @@
|
||||
#include <ws2tcpip.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cctype>
|
||||
#include <cstdint>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
@@ -45,116 +43,6 @@ bool IsKnownPostEndpoint(const std::string& path)
|
||||
|| path == "/api/screenshot";
|
||||
}
|
||||
|
||||
std::array<uint8_t, 20> Sha1(const std::string& input)
|
||||
{
|
||||
auto leftRotate = [](uint32_t value, uint32_t bits) {
|
||||
return (value << bits) | (value >> (32U - bits));
|
||||
};
|
||||
|
||||
std::vector<uint8_t> data(input.begin(), input.end());
|
||||
const uint64_t bitLength = static_cast<uint64_t>(data.size()) * 8ULL;
|
||||
data.push_back(0x80);
|
||||
while ((data.size() % 64) != 56)
|
||||
data.push_back(0);
|
||||
for (int shift = 56; shift >= 0; shift -= 8)
|
||||
data.push_back(static_cast<uint8_t>((bitLength >> shift) & 0xff));
|
||||
|
||||
uint32_t h0 = 0x67452301;
|
||||
uint32_t h1 = 0xefcdab89;
|
||||
uint32_t h2 = 0x98badcfe;
|
||||
uint32_t h3 = 0x10325476;
|
||||
uint32_t h4 = 0xc3d2e1f0;
|
||||
|
||||
for (std::size_t offset = 0; offset < data.size(); offset += 64)
|
||||
{
|
||||
uint32_t words[80] = {};
|
||||
for (std::size_t i = 0; i < 16; ++i)
|
||||
{
|
||||
const std::size_t index = offset + i * 4;
|
||||
words[i] = (static_cast<uint32_t>(data[index]) << 24)
|
||||
| (static_cast<uint32_t>(data[index + 1]) << 16)
|
||||
| (static_cast<uint32_t>(data[index + 2]) << 8)
|
||||
| static_cast<uint32_t>(data[index + 3]);
|
||||
}
|
||||
for (std::size_t i = 16; i < 80; ++i)
|
||||
words[i] = leftRotate(words[i - 3] ^ words[i - 8] ^ words[i - 14] ^ words[i - 16], 1);
|
||||
|
||||
uint32_t a = h0;
|
||||
uint32_t b = h1;
|
||||
uint32_t c = h2;
|
||||
uint32_t d = h3;
|
||||
uint32_t e = h4;
|
||||
|
||||
for (std::size_t i = 0; i < 80; ++i)
|
||||
{
|
||||
uint32_t f = 0;
|
||||
uint32_t k = 0;
|
||||
if (i < 20)
|
||||
{
|
||||
f = (b & c) | ((~b) & d);
|
||||
k = 0x5a827999;
|
||||
}
|
||||
else if (i < 40)
|
||||
{
|
||||
f = b ^ c ^ d;
|
||||
k = 0x6ed9eba1;
|
||||
}
|
||||
else if (i < 60)
|
||||
{
|
||||
f = (b & c) | (b & d) | (c & d);
|
||||
k = 0x8f1bbcdc;
|
||||
}
|
||||
else
|
||||
{
|
||||
f = b ^ c ^ d;
|
||||
k = 0xca62c1d6;
|
||||
}
|
||||
|
||||
const uint32_t temp = leftRotate(a, 5) + f + e + k + words[i];
|
||||
e = d;
|
||||
d = c;
|
||||
c = leftRotate(b, 30);
|
||||
b = a;
|
||||
a = temp;
|
||||
}
|
||||
|
||||
h0 += a;
|
||||
h1 += b;
|
||||
h2 += c;
|
||||
h3 += d;
|
||||
h4 += e;
|
||||
}
|
||||
|
||||
std::array<uint8_t, 20> digest = {};
|
||||
const uint32_t parts[] = { h0, h1, h2, h3, h4 };
|
||||
for (std::size_t i = 0; i < 5; ++i)
|
||||
{
|
||||
digest[i * 4] = static_cast<uint8_t>((parts[i] >> 24) & 0xff);
|
||||
digest[i * 4 + 1] = static_cast<uint8_t>((parts[i] >> 16) & 0xff);
|
||||
digest[i * 4 + 2] = static_cast<uint8_t>((parts[i] >> 8) & 0xff);
|
||||
digest[i * 4 + 3] = static_cast<uint8_t>(parts[i] & 0xff);
|
||||
}
|
||||
return digest;
|
||||
}
|
||||
|
||||
std::string Base64Encode(const uint8_t* data, std::size_t size)
|
||||
{
|
||||
static constexpr char kAlphabet[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
|
||||
std::string output;
|
||||
output.reserve(((size + 2) / 3) * 4);
|
||||
for (std::size_t i = 0; i < size; i += 3)
|
||||
{
|
||||
const uint32_t a = data[i];
|
||||
const uint32_t b = i + 1 < size ? data[i + 1] : 0;
|
||||
const uint32_t c = i + 2 < size ? data[i + 2] : 0;
|
||||
const uint32_t triple = (a << 16) | (b << 8) | c;
|
||||
output.push_back(kAlphabet[(triple >> 18) & 0x3f]);
|
||||
output.push_back(kAlphabet[(triple >> 12) & 0x3f]);
|
||||
output.push_back(i + 1 < size ? kAlphabet[(triple >> 6) & 0x3f] : '=');
|
||||
output.push_back(i + 2 < size ? kAlphabet[triple & 0x3f] : '=');
|
||||
}
|
||||
return output;
|
||||
}
|
||||
}
|
||||
|
||||
UniqueSocket::UniqueSocket(SOCKET socket) :
|
||||
@@ -347,75 +235,6 @@ bool HttpControlServer::HandleClient(UniqueSocket clientSocket)
|
||||
return SendResponse(clientSocket.get(), RouteRequest(request));
|
||||
}
|
||||
|
||||
bool HttpControlServer::HandleWebSocketClient(UniqueSocket clientSocket, const HttpRequest& request)
|
||||
{
|
||||
const auto keyIt = request.headers.find("sec-websocket-key");
|
||||
if (keyIt == request.headers.end() || keyIt->second.empty())
|
||||
return SendResponse(clientSocket.get(), TextResponse("400 Bad Request", "Missing WebSocket key"));
|
||||
|
||||
std::ostringstream stream;
|
||||
stream << "HTTP/1.1 101 Switching Protocols\r\n"
|
||||
<< "Upgrade: websocket\r\n"
|
||||
<< "Connection: Upgrade\r\n"
|
||||
<< "Sec-WebSocket-Accept: " << WebSocketAcceptKey(keyIt->second) << "\r\n\r\n";
|
||||
const std::string response = stream.str();
|
||||
if (send(clientSocket.get(), response.c_str(), static_cast<int>(response.size()), 0) != static_cast<int>(response.size()))
|
||||
return false;
|
||||
|
||||
u_long nonBlocking = 1;
|
||||
ioctlsocket(clientSocket.get(), FIONBIO, &nonBlocking);
|
||||
|
||||
std::thread thread([this, socket = std::move(clientSocket)]() mutable {
|
||||
WebSocketClientMain(std::move(socket));
|
||||
});
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mClientThreadsMutex);
|
||||
mClientThreads.push_back(std::move(thread));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void HttpControlServer::WebSocketClientMain(UniqueSocket clientSocket)
|
||||
{
|
||||
std::string previousState;
|
||||
while (mRunning.load(std::memory_order_acquire))
|
||||
{
|
||||
const std::string state = mCallbacks.getStateJson ? mCallbacks.getStateJson() : "{}";
|
||||
if (state != previousState)
|
||||
{
|
||||
if (!SendWebSocketText(clientSocket.get(), state))
|
||||
break;
|
||||
previousState = state;
|
||||
}
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(250));
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lock(mClientThreadsMutex);
|
||||
const std::thread::id currentId = std::this_thread::get_id();
|
||||
for (auto it = mClientThreads.begin(); it != mClientThreads.end(); ++it)
|
||||
{
|
||||
if (it->get_id() != currentId)
|
||||
continue;
|
||||
mFinishedClientThreads.push_back(std::move(*it));
|
||||
mClientThreads.erase(it);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void HttpControlServer::JoinFinishedClientThreads()
|
||||
{
|
||||
std::vector<std::thread> finished;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mClientThreadsMutex);
|
||||
finished.swap(mFinishedClientThreads);
|
||||
}
|
||||
for (std::thread& thread : finished)
|
||||
{
|
||||
if (thread.joinable())
|
||||
thread.join();
|
||||
}
|
||||
}
|
||||
|
||||
bool HttpControlServer::SendResponse(SOCKET clientSocket, const HttpResponse& response) const
|
||||
{
|
||||
std::ostringstream stream;
|
||||
@@ -561,61 +380,6 @@ std::string HttpControlServer::ActionResponse(bool ok, const std::string& error)
|
||||
return writer.StringValue();
|
||||
}
|
||||
|
||||
bool HttpControlServer::SendWebSocketText(SOCKET clientSocket, const std::string& text)
|
||||
{
|
||||
if (clientSocket == INVALID_SOCKET)
|
||||
return false;
|
||||
|
||||
std::vector<unsigned char> frame;
|
||||
frame.reserve(text.size() + 16);
|
||||
frame.push_back(0x81);
|
||||
if (text.size() <= 125)
|
||||
{
|
||||
frame.push_back(static_cast<unsigned char>(text.size()));
|
||||
}
|
||||
else if (text.size() <= 0xffff)
|
||||
{
|
||||
frame.push_back(126);
|
||||
frame.push_back(static_cast<unsigned char>((text.size() >> 8) & 0xff));
|
||||
frame.push_back(static_cast<unsigned char>(text.size() & 0xff));
|
||||
}
|
||||
else
|
||||
{
|
||||
frame.push_back(127);
|
||||
const uint64_t length = static_cast<uint64_t>(text.size());
|
||||
for (int shift = 56; shift >= 0; shift -= 8)
|
||||
frame.push_back(static_cast<unsigned char>((length >> shift) & 0xff));
|
||||
}
|
||||
frame.insert(frame.end(), text.begin(), text.end());
|
||||
|
||||
const char* data = reinterpret_cast<const char*>(frame.data());
|
||||
int remaining = static_cast<int>(frame.size());
|
||||
while (remaining > 0)
|
||||
{
|
||||
const int sent = send(clientSocket, data, remaining, 0);
|
||||
if (sent <= 0)
|
||||
{
|
||||
const int error = WSAGetLastError();
|
||||
if (error == WSAEWOULDBLOCK)
|
||||
{
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(2));
|
||||
continue;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
data += sent;
|
||||
remaining -= sent;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string HttpControlServer::WebSocketAcceptKey(const std::string& clientKey)
|
||||
{
|
||||
static constexpr const char* kWebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
|
||||
const std::array<uint8_t, 20> digest = Sha1(clientKey + kWebSocketGuid);
|
||||
return Base64Encode(digest.data(), digest.size());
|
||||
}
|
||||
|
||||
std::string HttpControlServer::GuessContentType(const std::filesystem::path& path)
|
||||
{
|
||||
const std::string extension = ToLower(path.extension().string());
|
||||
|
||||
@@ -0,0 +1,248 @@
|
||||
#include "HttpControlServer.h"
|
||||
|
||||
#include <array>
|
||||
#include <cstdint>
|
||||
#include <sstream>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace RenderCadenceCompositor
|
||||
{
|
||||
namespace
|
||||
{
|
||||
std::array<uint8_t, 20> Sha1(const std::string& input)
|
||||
{
|
||||
auto leftRotate = [](uint32_t value, uint32_t bits) {
|
||||
return (value << bits) | (value >> (32U - bits));
|
||||
};
|
||||
|
||||
std::vector<uint8_t> data(input.begin(), input.end());
|
||||
const uint64_t bitLength = static_cast<uint64_t>(data.size()) * 8ULL;
|
||||
data.push_back(0x80);
|
||||
while ((data.size() % 64) != 56)
|
||||
data.push_back(0);
|
||||
for (int shift = 56; shift >= 0; shift -= 8)
|
||||
data.push_back(static_cast<uint8_t>((bitLength >> shift) & 0xff));
|
||||
|
||||
uint32_t h0 = 0x67452301;
|
||||
uint32_t h1 = 0xefcdab89;
|
||||
uint32_t h2 = 0x98badcfe;
|
||||
uint32_t h3 = 0x10325476;
|
||||
uint32_t h4 = 0xc3d2e1f0;
|
||||
|
||||
for (std::size_t offset = 0; offset < data.size(); offset += 64)
|
||||
{
|
||||
uint32_t words[80] = {};
|
||||
for (std::size_t i = 0; i < 16; ++i)
|
||||
{
|
||||
const std::size_t index = offset + i * 4;
|
||||
words[i] = (static_cast<uint32_t>(data[index]) << 24)
|
||||
| (static_cast<uint32_t>(data[index + 1]) << 16)
|
||||
| (static_cast<uint32_t>(data[index + 2]) << 8)
|
||||
| static_cast<uint32_t>(data[index + 3]);
|
||||
}
|
||||
for (std::size_t i = 16; i < 80; ++i)
|
||||
words[i] = leftRotate(words[i - 3] ^ words[i - 8] ^ words[i - 14] ^ words[i - 16], 1);
|
||||
|
||||
uint32_t a = h0;
|
||||
uint32_t b = h1;
|
||||
uint32_t c = h2;
|
||||
uint32_t d = h3;
|
||||
uint32_t e = h4;
|
||||
|
||||
for (std::size_t i = 0; i < 80; ++i)
|
||||
{
|
||||
uint32_t f = 0;
|
||||
uint32_t k = 0;
|
||||
if (i < 20)
|
||||
{
|
||||
f = (b & c) | ((~b) & d);
|
||||
k = 0x5a827999;
|
||||
}
|
||||
else if (i < 40)
|
||||
{
|
||||
f = b ^ c ^ d;
|
||||
k = 0x6ed9eba1;
|
||||
}
|
||||
else if (i < 60)
|
||||
{
|
||||
f = (b & c) | (b & d) | (c & d);
|
||||
k = 0x8f1bbcdc;
|
||||
}
|
||||
else
|
||||
{
|
||||
f = b ^ c ^ d;
|
||||
k = 0xca62c1d6;
|
||||
}
|
||||
|
||||
const uint32_t temp = leftRotate(a, 5) + f + e + k + words[i];
|
||||
e = d;
|
||||
d = c;
|
||||
c = leftRotate(b, 30);
|
||||
b = a;
|
||||
a = temp;
|
||||
}
|
||||
|
||||
h0 += a;
|
||||
h1 += b;
|
||||
h2 += c;
|
||||
h3 += d;
|
||||
h4 += e;
|
||||
}
|
||||
|
||||
std::array<uint8_t, 20> digest = {};
|
||||
const uint32_t parts[] = { h0, h1, h2, h3, h4 };
|
||||
for (std::size_t i = 0; i < 5; ++i)
|
||||
{
|
||||
digest[i * 4] = static_cast<uint8_t>((parts[i] >> 24) & 0xff);
|
||||
digest[i * 4 + 1] = static_cast<uint8_t>((parts[i] >> 16) & 0xff);
|
||||
digest[i * 4 + 2] = static_cast<uint8_t>((parts[i] >> 8) & 0xff);
|
||||
digest[i * 4 + 3] = static_cast<uint8_t>(parts[i] & 0xff);
|
||||
}
|
||||
return digest;
|
||||
}
|
||||
|
||||
std::string Base64Encode(const uint8_t* data, std::size_t size)
|
||||
{
|
||||
static constexpr char kAlphabet[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
|
||||
std::string output;
|
||||
output.reserve(((size + 2) / 3) * 4);
|
||||
for (std::size_t i = 0; i < size; i += 3)
|
||||
{
|
||||
const uint32_t a = data[i];
|
||||
const uint32_t b = i + 1 < size ? data[i + 1] : 0;
|
||||
const uint32_t c = i + 2 < size ? data[i + 2] : 0;
|
||||
const uint32_t triple = (a << 16) | (b << 8) | c;
|
||||
output.push_back(kAlphabet[(triple >> 18) & 0x3f]);
|
||||
output.push_back(kAlphabet[(triple >> 12) & 0x3f]);
|
||||
output.push_back(i + 1 < size ? kAlphabet[(triple >> 6) & 0x3f] : '=');
|
||||
output.push_back(i + 2 < size ? kAlphabet[triple & 0x3f] : '=');
|
||||
}
|
||||
return output;
|
||||
}
|
||||
}
|
||||
|
||||
bool HttpControlServer::HandleWebSocketClient(UniqueSocket clientSocket, const HttpRequest& request)
|
||||
{
|
||||
const auto keyIt = request.headers.find("sec-websocket-key");
|
||||
if (keyIt == request.headers.end() || keyIt->second.empty())
|
||||
return SendResponse(clientSocket.get(), TextResponse("400 Bad Request", "Missing WebSocket key"));
|
||||
|
||||
std::ostringstream stream;
|
||||
stream << "HTTP/1.1 101 Switching Protocols\r\n"
|
||||
<< "Upgrade: websocket\r\n"
|
||||
<< "Connection: Upgrade\r\n"
|
||||
<< "Sec-WebSocket-Accept: " << WebSocketAcceptKey(keyIt->second) << "\r\n\r\n";
|
||||
const std::string response = stream.str();
|
||||
if (send(clientSocket.get(), response.c_str(), static_cast<int>(response.size()), 0) != static_cast<int>(response.size()))
|
||||
return false;
|
||||
|
||||
u_long nonBlocking = 1;
|
||||
ioctlsocket(clientSocket.get(), FIONBIO, &nonBlocking);
|
||||
|
||||
std::thread thread([this, socket = std::move(clientSocket)]() mutable {
|
||||
WebSocketClientMain(std::move(socket));
|
||||
});
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mClientThreadsMutex);
|
||||
mClientThreads.push_back(std::move(thread));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void HttpControlServer::WebSocketClientMain(UniqueSocket clientSocket)
|
||||
{
|
||||
std::string previousState;
|
||||
while (mRunning.load(std::memory_order_acquire))
|
||||
{
|
||||
const std::string state = mCallbacks.getStateJson ? mCallbacks.getStateJson() : "{}";
|
||||
if (state != previousState)
|
||||
{
|
||||
if (!SendWebSocketText(clientSocket.get(), state))
|
||||
break;
|
||||
previousState = state;
|
||||
}
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(250));
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lock(mClientThreadsMutex);
|
||||
const std::thread::id currentId = std::this_thread::get_id();
|
||||
for (auto it = mClientThreads.begin(); it != mClientThreads.end(); ++it)
|
||||
{
|
||||
if (it->get_id() != currentId)
|
||||
continue;
|
||||
mFinishedClientThreads.push_back(std::move(*it));
|
||||
mClientThreads.erase(it);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void HttpControlServer::JoinFinishedClientThreads()
|
||||
{
|
||||
std::vector<std::thread> finished;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mClientThreadsMutex);
|
||||
finished.swap(mFinishedClientThreads);
|
||||
}
|
||||
for (std::thread& thread : finished)
|
||||
{
|
||||
if (thread.joinable())
|
||||
thread.join();
|
||||
}
|
||||
}
|
||||
|
||||
bool HttpControlServer::SendWebSocketText(SOCKET clientSocket, const std::string& text)
|
||||
{
|
||||
if (clientSocket == INVALID_SOCKET)
|
||||
return false;
|
||||
|
||||
std::vector<unsigned char> frame;
|
||||
frame.reserve(text.size() + 16);
|
||||
frame.push_back(0x81);
|
||||
if (text.size() <= 125)
|
||||
{
|
||||
frame.push_back(static_cast<unsigned char>(text.size()));
|
||||
}
|
||||
else if (text.size() <= 0xffff)
|
||||
{
|
||||
frame.push_back(126);
|
||||
frame.push_back(static_cast<unsigned char>((text.size() >> 8) & 0xff));
|
||||
frame.push_back(static_cast<unsigned char>(text.size() & 0xff));
|
||||
}
|
||||
else
|
||||
{
|
||||
frame.push_back(127);
|
||||
const uint64_t length = static_cast<uint64_t>(text.size());
|
||||
for (int shift = 56; shift >= 0; shift -= 8)
|
||||
frame.push_back(static_cast<unsigned char>((length >> shift) & 0xff));
|
||||
}
|
||||
frame.insert(frame.end(), text.begin(), text.end());
|
||||
|
||||
const char* data = reinterpret_cast<const char*>(frame.data());
|
||||
int remaining = static_cast<int>(frame.size());
|
||||
while (remaining > 0)
|
||||
{
|
||||
const int sent = send(clientSocket, data, remaining, 0);
|
||||
if (sent <= 0)
|
||||
{
|
||||
const int error = WSAGetLastError();
|
||||
if (error == WSAEWOULDBLOCK)
|
||||
{
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(2));
|
||||
continue;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
data += sent;
|
||||
remaining -= sent;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string HttpControlServer::WebSocketAcceptKey(const std::string& clientKey)
|
||||
{
|
||||
static constexpr const char* kWebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
|
||||
const std::array<uint8_t, 20> digest = Sha1(clientKey + kWebSocketGuid);
|
||||
return Base64Encode(digest.data(), digest.size());
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user