249 lines
6.7 KiB
C++
249 lines
6.7 KiB
C++
#include "HttpControlServer.h"
|
|
|
|
#include <array>
|
|
#include <cstdint>
|
|
#include <sstream>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
namespace RenderCadenceCompositor
|
|
{
|
|
namespace
|
|
{
|
|
// Computes the SHA-1 digest (FIPS 180-4) of the bytes in `input`.
// The 20-byte result is laid out big-endian, h0 first, as the standard defines.
std::array<uint8_t, 20> Sha1(const std::string& input)
{
    // Circular left rotation. Every call site passes a count in [1, 31], so the
    // complementary right shift never reaches 32.
    const auto rotl = [](uint32_t word, uint32_t count) {
        return (word << count) | (word >> (32U - count));
    };

    // Padding: one 0x80 byte, zeros up to 56 mod 64, then the original length
    // in bits as a 64-bit big-endian integer. The result is a whole number of
    // 64-byte blocks.
    const uint64_t messageBits = static_cast<uint64_t>(input.size()) * 8ULL;
    std::vector<uint8_t> message(input.begin(), input.end());
    message.push_back(0x80);
    const std::size_t paddedSize = ((message.size() + 8 + 63) / 64) * 64;
    message.resize(paddedSize - 8, 0);
    for (int byte = 7; byte >= 0; --byte)
        message.push_back(static_cast<uint8_t>(messageBits >> (8 * byte)));

    uint32_t state[5] = { 0x67452301, 0xefcdab89, 0x98badcfe, 0x10325476, 0xc3d2e1f0 };

    for (std::size_t block = 0; block < message.size(); block += 64)
    {
        // Message schedule: 16 big-endian words from the block, expanded to 80.
        uint32_t w[80];
        for (std::size_t t = 0; t < 16; ++t)
        {
            const uint8_t* bytes = message.data() + block + 4 * t;
            w[t] = (static_cast<uint32_t>(bytes[0]) << 24)
                 | (static_cast<uint32_t>(bytes[1]) << 16)
                 | (static_cast<uint32_t>(bytes[2]) << 8)
                 | static_cast<uint32_t>(bytes[3]);
        }
        for (std::size_t t = 16; t < 80; ++t)
            w[t] = rotl(w[t - 3] ^ w[t - 8] ^ w[t - 14] ^ w[t - 16], 1);

        uint32_t a = state[0];
        uint32_t b = state[1];
        uint32_t c = state[2];
        uint32_t d = state[3];
        uint32_t e = state[4];

        for (std::size_t t = 0; t < 80; ++t)
        {
            // The round function and additive constant change every 20 rounds.
            uint32_t mix = 0;
            uint32_t constant = 0;
            switch (t / 20)
            {
            case 0:
                mix = (b & c) | (~b & d); // choose
                constant = 0x5a827999;
                break;
            case 1:
                mix = b ^ c ^ d; // parity
                constant = 0x6ed9eba1;
                break;
            case 2:
                mix = (b & c) | (b & d) | (c & d); // majority
                constant = 0x8f1bbcdc;
                break;
            default:
                mix = b ^ c ^ d; // parity
                constant = 0xca62c1d6;
                break;
            }

            const uint32_t next = rotl(a, 5) + mix + e + constant + w[t];
            e = d;
            d = c;
            c = rotl(b, 30);
            b = a;
            a = next;
        }

        state[0] += a;
        state[1] += b;
        state[2] += c;
        state[3] += d;
        state[4] += e;
    }

    // Serialize the five state words most-significant byte first.
    std::array<uint8_t, 20> digest{};
    for (std::size_t i = 0; i < digest.size(); ++i)
        digest[i] = static_cast<uint8_t>(state[i / 4] >> (24 - 8 * (i % 4)));
    return digest;
}
|
|
|
|
std::string Base64Encode(const uint8_t* data, std::size_t size)
|
|
{
|
|
static constexpr char kAlphabet[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
|
|
std::string output;
|
|
output.reserve(((size + 2) / 3) * 4);
|
|
for (std::size_t i = 0; i < size; i += 3)
|
|
{
|
|
const uint32_t a = data[i];
|
|
const uint32_t b = i + 1 < size ? data[i + 1] : 0;
|
|
const uint32_t c = i + 2 < size ? data[i + 2] : 0;
|
|
const uint32_t triple = (a << 16) | (b << 8) | c;
|
|
output.push_back(kAlphabet[(triple >> 18) & 0x3f]);
|
|
output.push_back(kAlphabet[(triple >> 12) & 0x3f]);
|
|
output.push_back(i + 1 < size ? kAlphabet[(triple >> 6) & 0x3f] : '=');
|
|
output.push_back(i + 2 < size ? kAlphabet[triple & 0x3f] : '=');
|
|
}
|
|
return output;
|
|
}
|
|
}
|
|
|
|
/// Completes the RFC 6455 opening handshake for an upgrade request and hands
/// the connection to a dedicated client thread running WebSocketClientMain.
///
/// @param clientSocket  Accepted connection. On success ownership moves into
///                      the client thread; on any early return it is released
///                      here when `clientSocket` goes out of scope.
/// @param request       Parsed HTTP request. Its headers are searched for the
///                      exact key "sec-websocket-key" (presumably the parser
///                      lowercases header names — confirm).
/// @return The SendResponse() result when the key is missing or empty; false if
///         the 101 response was not fully written by a single send(); true once
///         the client thread has been started and registered.
bool HttpControlServer::HandleWebSocketClient(UniqueSocket clientSocket, const HttpRequest& request)
{
    const auto keyIt = request.headers.find("sec-websocket-key");
    if (keyIt == request.headers.end() || keyIt->second.empty())
        return SendResponse(clientSocket.get(), TextResponse("400 Bad Request", "Missing WebSocket key"));

    std::ostringstream stream;
    stream << "HTTP/1.1 101 Switching Protocols\r\n"
           << "Upgrade: websocket\r\n"
           << "Connection: Upgrade\r\n"
           << "Sec-WebSocket-Accept: " << WebSocketAcceptKey(keyIt->second) << "\r\n\r\n";
    const std::string response = stream.str();
    if (send(clientSocket.get(), response.c_str(), static_cast<int>(response.size()), 0) != static_cast<int>(response.size()))
        return false;

    // Switch to non-blocking mode so SendWebSocketText can back off on
    // WSAEWOULDBLOCK. The result is not checked: if this fails, the socket
    // simply stays blocking.
    u_long nonBlocking = 1;
    ioctlsocket(clientSocket.get(), FIONBIO, &nonBlocking);

    // Start and register the thread under one lock. When WebSocketClientMain
    // finishes, it looks itself up in mClientThreads (under this mutex) to move
    // its handle to mFinishedClientThreads. If it were started before being
    // registered and exited immediately, that lookup would miss, and its handle
    // would then be stranded in mClientThreads. Holding the lock across
    // creation makes the thread's final lookup wait until registration is done.
    // emplace_back also guarantees no thread is started if the container cannot
    // grow, so a joinable std::thread is never destroyed (std::terminate).
    std::lock_guard<std::mutex> lock(mClientThreadsMutex);
    mClientThreads.emplace_back([this, socket = std::move(clientSocket)]() mutable {
        WebSocketClientMain(std::move(socket));
    });
    return true;
}
|
|
|
|
void HttpControlServer::WebSocketClientMain(UniqueSocket clientSocket)
|
|
{
|
|
std::string previousState;
|
|
while (mRunning.load(std::memory_order_acquire))
|
|
{
|
|
const std::string state = mCallbacks.getStateJson ? mCallbacks.getStateJson() : "{}";
|
|
if (state != previousState)
|
|
{
|
|
if (!SendWebSocketText(clientSocket.get(), state))
|
|
break;
|
|
previousState = state;
|
|
}
|
|
std::this_thread::sleep_for(std::chrono::milliseconds(250));
|
|
}
|
|
|
|
std::lock_guard<std::mutex> lock(mClientThreadsMutex);
|
|
const std::thread::id currentId = std::this_thread::get_id();
|
|
for (auto it = mClientThreads.begin(); it != mClientThreads.end(); ++it)
|
|
{
|
|
if (it->get_id() != currentId)
|
|
continue;
|
|
mFinishedClientThreads.push_back(std::move(*it));
|
|
mClientThreads.erase(it);
|
|
break;
|
|
}
|
|
}
|
|
|
|
void HttpControlServer::JoinFinishedClientThreads()
|
|
{
|
|
std::vector<std::thread> finished;
|
|
{
|
|
std::lock_guard<std::mutex> lock(mClientThreadsMutex);
|
|
finished.swap(mFinishedClientThreads);
|
|
}
|
|
for (std::thread& thread : finished)
|
|
{
|
|
if (thread.joinable())
|
|
thread.join();
|
|
}
|
|
}
|
|
|
|
/// Sends `text` to the client as a single unmasked, unfragmented WebSocket text
/// frame (FIN set, opcode 0x1). The payload length uses the 7-bit, 16-bit or
/// 64-bit form as RFC 6455 §5.2 requires.
///
/// On a non-blocking socket, WSAEWOULDBLOCK is retried with a short sleep, but
/// only while the peer keeps accepting data. If no bytes are accepted for
/// kStallTimeout, the client is treated as dead so a reader that has stopped
/// cannot wedge the calling thread indefinitely.
///
/// @return true once the whole frame has been accepted by send(); false for an
///         invalid socket, a socket error, or a stalled peer.
bool HttpControlServer::SendWebSocketText(SOCKET clientSocket, const std::string& text)
{
    if (clientSocket == INVALID_SOCKET)
        return false;

    std::vector<unsigned char> frame;
    frame.reserve(text.size() + 16);
    frame.push_back(0x81); // FIN + text opcode
    if (text.size() <= 125)
    {
        frame.push_back(static_cast<unsigned char>(text.size()));
    }
    else if (text.size() <= 0xffff)
    {
        frame.push_back(126);
        frame.push_back(static_cast<unsigned char>((text.size() >> 8) & 0xff));
        frame.push_back(static_cast<unsigned char>(text.size() & 0xff));
    }
    else
    {
        frame.push_back(127);
        const uint64_t length = static_cast<uint64_t>(text.size());
        for (int shift = 56; shift >= 0; shift -= 8)
            frame.push_back(static_cast<unsigned char>((length >> shift) & 0xff));
    }
    frame.insert(frame.end(), text.begin(), text.end());

    // send() takes an int length, so oversized frames go out in bounded chunks
    // instead of truncating size() to int.
    constexpr std::size_t kMaxChunk = std::size_t{ 1 } << 30;
    constexpr auto kStallTimeout = std::chrono::seconds(5);

    const char* data = reinterpret_cast<const char*>(frame.data());
    std::size_t remaining = frame.size();
    auto stallDeadline = std::chrono::steady_clock::now() + kStallTimeout;
    while (remaining > 0)
    {
        const int chunk = static_cast<int>(remaining < kMaxChunk ? remaining : kMaxChunk);
        const int sent = send(clientSocket, data, chunk, 0);
        if (sent <= 0)
        {
            if (WSAGetLastError() == WSAEWOULDBLOCK && std::chrono::steady_clock::now() < stallDeadline)
            {
                std::this_thread::sleep_for(std::chrono::milliseconds(2));
                continue;
            }
            return false;
        }
        data += sent;
        remaining -= static_cast<std::size_t>(sent);
        // Progress was made, so restart the stall window.
        stallDeadline = std::chrono::steady_clock::now() + kStallTimeout;
    }
    return true;
}
|
|
|
|
std::string HttpControlServer::WebSocketAcceptKey(const std::string& clientKey)
|
|
{
|
|
static constexpr const char* kWebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
|
|
const std::array<uint8_t, 20> digest = Sha1(clientKey + kWebSocketGuid);
|
|
return Base64Encode(digest.data(), digest.size());
|
|
}
|
|
}
|