#include "HttpControlServer.h"

#include "../logging/Logger.h"

#include <algorithm>
#include <cctype>
#include <chrono>
#include <cstddef>
#include <limits>
#include <mutex>
#include <sstream>
#include <string>
#include <thread>
#include <utility>
#include <vector>

namespace RenderCadenceCompositor
{
namespace
{
// Size of each recv() call while reading a request.
constexpr std::size_t kReceiveChunkBytes = 16384;
// Upper bound on a buffered request (request line + headers + body); anything beyond is dropped.
constexpr std::size_t kMaxRequestBytes = 1024 * 1024;
// How long a client socket may make no progress (no bytes in or out) before it is abandoned.
constexpr std::chrono::milliseconds kSocketIdleTimeout(2000);
// Back-off between retries when a non-blocking socket reports WSAEWOULDBLOCK.
constexpr std::chrono::milliseconds kSocketPollInterval(1);

// Starts Winsock 2.2. On failure writes a message into `error` and returns false.
bool InitializeWinsock(std::string& error)
{
    WSADATA wsaData = {};
    const int result = WSAStartup(MAKEWORD(2, 2), &wsaData);
    if (result != 0)
    {
        error = "WSAStartup failed.";
        return false;
    }
    return true;
}

// Returns the Content-Length declared in `headerBlock` (request line plus header lines, without the
// terminating blank line), matched case-insensitively and clamped to kMaxRequestBytes.
// Returns 0 when the header is absent or its value does not start with a digit.
std::size_t FindContentLength(const std::string& headerBlock)
{
    for (std::size_t cursor = headerBlock.find("\r\n"); cursor != std::string::npos;)
    {
        const std::size_t lineStart = cursor + 2;
        const std::size_t lineEnd = headerBlock.find("\r\n", lineStart);
        const std::size_t lineStop = lineEnd == std::string::npos ? headerBlock.size() : lineEnd;
        const std::size_t colon = headerBlock.find(':', lineStart);
        if (colon != std::string::npos && colon < lineStop)
        {
            std::string name = headerBlock.substr(lineStart, colon - lineStart);
            std::transform(name.begin(), name.end(), name.begin(),
                [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
            if (name == "content-length")
            {
                std::size_t position = headerBlock.find_first_not_of(" \t", colon + 1);
                std::size_t length = 0;
                while (position < lineStop && std::isdigit(static_cast<unsigned char>(headerBlock[position])))
                {
                    length = length * 10 + static_cast<std::size_t>(headerBlock[position] - '0');
                    if (length >= kMaxRequestBytes)
                        return kMaxRequestBytes; // Clamp early so the accumulation cannot overflow.
                    ++position;
                }
                return length;
            }
        }
        cursor = lineEnd;
    }
    return 0;
}

// True once `rawRequest` contains the full header block and at least Content-Length body bytes.
bool IsRequestComplete(const std::string& rawRequest)
{
    const std::size_t headerEnd = rawRequest.find("\r\n\r\n");
    if (headerEnd == std::string::npos)
        return false;
    const std::size_t bodyBytes = rawRequest.size() - (headerEnd + 4);
    return bodyBytes >= FindContentLength(rawRequest.substr(0, headerEnd));
}

// Reads one HTTP request from `socket` into `rawRequest`.
// The socket may be blocking or non-blocking (on Windows an accepted socket inherits the listener's
// non-blocking mode), so WSAEWOULDBLOCK is retried until the request is complete, the peer closes,
// the socket stays idle for kSocketIdleTimeout, or `keepWaiting()` returns false.
// Returns false on a hard socket error or when nothing was received.
template <typename KeepWaiting>
bool ReceiveHttpRequest(SOCKET socket, KeepWaiting keepWaiting, std::string& rawRequest)
{
    char buffer[kReceiveChunkBytes];
    auto deadline = std::chrono::steady_clock::now() + kSocketIdleTimeout;
    while (rawRequest.size() < kMaxRequestBytes)
    {
        const std::size_t room = (std::min)(sizeof(buffer), kMaxRequestBytes - rawRequest.size());
        const int received = recv(socket, buffer, static_cast<int>(room), 0);
        if (received > 0)
        {
            rawRequest.append(buffer, static_cast<std::size_t>(received));
            if (IsRequestComplete(rawRequest))
                return true;
            deadline = std::chrono::steady_clock::now() + kSocketIdleTimeout;
            continue;
        }
        if (received == 0)
            break; // Peer finished sending; hand over whatever arrived.
        if (WSAGetLastError() != WSAEWOULDBLOCK)
            return false; // Connection is broken; nothing useful can be sent back.
        if (!keepWaiting() || std::chrono::steady_clock::now() >= deadline)
            break;
        std::this_thread::sleep_for(kSocketPollInterval);
    }
    return !rawRequest.empty();
}

// Sends every byte of `payload`, retrying partial writes and WSAEWOULDBLOCK until the socket has been
// idle for kSocketIdleTimeout. Returns true only when the whole payload was handed to the socket.
bool SendAll(SOCKET socket, const std::string& payload)
{
    auto deadline = std::chrono::steady_clock::now() + kSocketIdleTimeout;
    std::size_t offset = 0;
    while (offset < payload.size())
    {
        const std::size_t remaining = payload.size() - offset;
        const int chunk = static_cast<int>(
            (std::min)(remaining, static_cast<std::size_t>((std::numeric_limits<int>::max)())));
        const int sent = send(socket, payload.data() + offset, chunk, 0);
        if (sent > 0)
        {
            offset += static_cast<std::size_t>(sent);
            deadline = std::chrono::steady_clock::now() + kSocketIdleTimeout;
            continue;
        }
        if (sent == SOCKET_ERROR && WSAGetLastError() == WSAEWOULDBLOCK &&
            std::chrono::steady_clock::now() < deadline)
        {
            std::this_thread::sleep_for(kSocketPollInterval);
            continue;
        }
        return false;
    }
    return true;
}
} // namespace

UniqueSocket::UniqueSocket(SOCKET socket) :
    mSocket(socket)
{
}

UniqueSocket::~UniqueSocket()
{
    reset();
}

UniqueSocket::UniqueSocket(UniqueSocket&& other) noexcept :
    mSocket(other.release())
{
}

UniqueSocket& UniqueSocket::operator=(UniqueSocket&& other) noexcept
{
    if (this != &other)
        reset(other.release());
    return *this;
}

// Gives up ownership of the handle without closing it; this object becomes invalid.
SOCKET UniqueSocket::release()
{
    const SOCKET socket = mSocket;
    mSocket = INVALID_SOCKET;
    return socket;
}

// Closes the currently owned handle (if any) and takes ownership of `socket`.
void UniqueSocket::reset(SOCKET socket)
{
    if (socket == mSocket)
        return; // Re-seating the same handle must not close it out from under ourselves.
    if (valid())
        closesocket(mSocket);
    mSocket = socket;
}

HttpControlServer::~HttpControlServer()
{
    Stop();
}

// Binds a non-blocking loopback listener on the first free port in
// [preferredPort, preferredPort + portSearchCount) and starts the accept thread.
// On failure, writes a message into `error`, releases everything acquired so far and returns false.
bool HttpControlServer::Start(
    const std::filesystem::path& uiRoot,
    const std::filesystem::path& docsRoot,
    HttpControlServerConfig config,
    HttpControlServerCallbacks callbacks,
    std::string& error)
{
    Stop();

    if (!InitializeWinsock(error))
        return false;
    mWinsockStarted = true;

    mUiRoot = uiRoot;
    mDocsRoot = docsRoot;
    mConfig = config;
    mCallbacks = std::move(callbacks);

    mListenSocket.reset(socket(AF_INET, SOCK_STREAM, IPPROTO_TCP));
    if (!mListenSocket.valid())
    {
        error = "Could not create HTTP control server socket.";
        Stop();
        return false;
    }

    // The accept loop polls and Stop() joins it before closing the socket, so a blocking
    // listener would hang shutdown; treat failure here as fatal.
    u_long nonBlocking = 1;
    if (ioctlsocket(mListenSocket.get(), FIONBIO, &nonBlocking) != 0)
    {
        error = "Could not make HTTP control server socket non-blocking.";
        Stop();
        return false;
    }

    sockaddr_in address = {};
    address.sin_family = AF_INET;
    address.sin_addr.s_addr = htonl(INADDR_LOOPBACK);

    bool bound = false;
    for (unsigned short offset = 0; offset < mConfig.portSearchCount; ++offset)
    {
        const unsigned long candidate = static_cast<unsigned long>(mConfig.preferredPort) + offset;
        if (candidate > 0xFFFFul)
            break; // Past the top of the port range; htons would silently wrap.
        address.sin_port = htons(static_cast<u_short>(candidate));
        if (bind(mListenSocket.get(), reinterpret_cast<const sockaddr*>(&address), static_cast<int>(sizeof(address))) == 0)
        {
            mPort = static_cast<unsigned short>(candidate);
            bound = true;
            break;
        }
    }

    if (!bound)
    {
        error = "Could not bind HTTP control server to loopback.";
        Stop();
        return false;
    }

    if (listen(mListenSocket.get(), SOMAXCONN) != 0)
    {
        error = "Could not listen on HTTP control server socket.";
        Stop();
        return false;
    }

    mRunning.store(true, std::memory_order_release);
    mThread = std::thread([this]() { ThreadMain(); });
    Log("http", "HTTP control server listening on http://127.0.0.1:" + std::to_string(mPort));
    return true;
}

// Stops the accept thread, closes the listener, joins all client threads and shuts down Winsock.
// Safe to call repeatedly.
void HttpControlServer::Stop()
{
    mRunning.store(false, std::memory_order_release);

    // Join before closing: ThreadMain reads mListenSocket without synchronization, so closing it
    // first would race with accept() (and the freed handle value could be reused elsewhere).
    // The listener is non-blocking, so the thread notices mRunning == false promptly.
    if (mThread.joinable())
        mThread.join();
    mListenSocket.reset();

    std::vector<std::thread> clientThreads;
    {
        std::lock_guard lock(mClientThreadsMutex);
        clientThreads.swap(mClientThreads);
        for (std::thread& thread : mFinishedClientThreads)
            clientThreads.push_back(std::move(thread));
        mFinishedClientThreads.clear();
    }

    for (std::thread& thread : clientThreads)
    {
        if (thread.joinable())
            thread.join();
    }

    if (mWinsockStarted)
    {
        WSACleanup();
        mWinsockStarted = false;
    }
    mPort = 0;
}

HttpControlServer::HttpResponse HttpControlServer::RouteRequestForTest(const HttpRequest& request) const
{
    return RouteRequest(request);
}

void HttpControlServer::SetCallbacksForTest(HttpControlServerCallbacks callbacks)
{
    mCallbacks = std::move(callbacks);
}

void HttpControlServer::SetRootsForTest(const std::filesystem::path& uiRoot, const std::filesystem::path& docsRoot)
{
    mUiRoot = uiRoot;
    mDocsRoot = docsRoot;
}

// Accept loop: reaps finished client threads, polls for one connection, then sleeps idleSleep.
void HttpControlServer::ThreadMain()
{
    while (mRunning.load(std::memory_order_acquire))
    {
        JoinFinishedClientThreads();
        TryAcceptClient();
        std::this_thread::sleep_for(mConfig.idleSleep);
    }
}

// Accepts at most one pending connection and handles it. Returns false when nothing was accepted
// or the client could not be served.
bool HttpControlServer::TryAcceptClient()
{
    sockaddr_in clientAddress = {};
    int addressSize = static_cast<int>(sizeof(clientAddress));
    UniqueSocket clientSocket(accept(mListenSocket.get(), reinterpret_cast<sockaddr*>(&clientAddress), &addressSize));
    if (!clientSocket.valid())
        return false;
    return HandleClient(std::move(clientSocket));
}

// Reads a full request, then either hands the socket to the WebSocket handler ("/ws") or sends a
// routed HTTP response. Malformed requests receive "400 Bad Request".
bool HttpControlServer::HandleClient(UniqueSocket clientSocket)
{
    std::string rawRequest;
    const auto isRunning = [this]() { return mRunning.load(std::memory_order_acquire); };
    if (!ReceiveHttpRequest(clientSocket.get(), isRunning, rawRequest))
        return false;

    HttpRequest request;
    if (!ParseHttpRequest(rawRequest, request))
        return SendResponse(clientSocket.get(), TextResponse("400 Bad Request", "Bad Request"));

    if (request.path == "/ws")
        return HandleWebSocketClient(std::move(clientSocket), request);

    return SendResponse(clientSocket.get(), RouteRequest(request));
}

// Serializes `response` as an HTTP/1.1 message (with permissive CORS and Connection: close) and
// writes all of it to `clientSocket`. Returns true only if every byte was sent.
bool HttpControlServer::SendResponse(SOCKET clientSocket, const HttpResponse& response) const
{
    std::ostringstream stream;
    stream << "HTTP/1.1 " << response.status << "\r\n"
           << "Content-Type: " << response.contentType << "\r\n"
           << "Content-Length: " << response.body.size() << "\r\n"
           << "Access-Control-Allow-Origin: *\r\n"
           << "Connection: close\r\n\r\n"
           << response.body;
    return SendAll(clientSocket, stream.str());
}

// Dispatches by method: GET and POST to their handlers, OPTIONS to an empty 204, others to 404.
HttpControlServer::HttpResponse HttpControlServer::RouteRequest(const HttpRequest& request) const
{
    if (request.method == "GET")
        return ServeGet(request);
    if (request.method == "POST")
        return ServePost(request);
    if (request.method == "OPTIONS")
        return TextResponse("204 No Content", std::string());
    return TextResponse("404 Not Found", "Not Found");
}

// Parses the request line, headers and body of `rawRequest` into `request`.
// The query string is stripped from the path, and header names are lower-cased via ToLower with
// values trimmed of spaces/tabs. The body is everything after the first blank line (empty if
// there is none). Returns false when the request line is malformed or method/path are empty.
bool HttpControlServer::ParseHttpRequest(const std::string& rawRequest, HttpRequest& request)
{
    const std::size_t requestLineEnd = rawRequest.find("\r\n");
    if (requestLineEnd == std::string::npos)
        return false;

    const std::string requestLine = rawRequest.substr(0, requestLineEnd);
    const std::size_t methodEnd = requestLine.find(' ');
    if (methodEnd == std::string::npos)
        return false;
    const std::size_t pathEnd = requestLine.find(' ', methodEnd + 1);
    if (pathEnd == std::string::npos)
        return false;

    request.method = requestLine.substr(0, methodEnd);
    request.path = requestLine.substr(methodEnd + 1, pathEnd - methodEnd - 1);
    request.headers.clear();

    const std::size_t queryStart = request.path.find('?');
    if (queryStart != std::string::npos)
        request.path = request.path.substr(0, queryStart);

    const std::size_t headersStart = requestLineEnd + 2;
    const std::size_t bodySeparator = rawRequest.find("\r\n\r\n", headersStart);
    const std::size_t headersEnd = bodySeparator == std::string::npos ? rawRequest.size() : bodySeparator;
    for (std::size_t lineStart = headersStart; lineStart < headersEnd;)
    {
        const std::size_t lineEnd = rawRequest.find("\r\n", lineStart);
        const std::size_t currentLineEnd = lineEnd == std::string::npos ? headersEnd : (std::min)(lineEnd, headersEnd);
        const std::string line = rawRequest.substr(lineStart, currentLineEnd - lineStart);
        const std::size_t separator = line.find(':');
        if (separator != std::string::npos)
        {
            const std::string key = ToLower(line.substr(0, separator));
            std::string value = line.substr(separator + 1);
            const std::size_t first = value.find_first_not_of(" \t");
            const std::size_t last = value.find_last_not_of(" \t");
            request.headers[key] = first == std::string::npos ? std::string() : value.substr(first, last - first + 1);
        }
        if (lineEnd == std::string::npos || lineEnd >= headersEnd)
            break;
        lineStart = lineEnd + 2;
    }

    request.body = bodySeparator == std::string::npos ? std::string() : rawRequest.substr(bodySeparator + 4);
    return !request.method.empty() && !request.path.empty();
}
} // namespace RenderCadenceCompositor