294 lines
7.6 KiB
C++
294 lines
7.6 KiB
C++
#include "HttpControlServer.h"
|
|
|
|
#include "../logging/Logger.h"
|
|
|
|
#include <ws2tcpip.h>
|
|
|
|
#include <algorithm>
|
|
#include <sstream>
|
|
|
|
namespace RenderCadenceCompositor
|
|
{
|
|
namespace
|
|
{
|
|
bool InitializeWinsock(std::string& error)
|
|
{
|
|
WSADATA wsaData = {};
|
|
const int result = WSAStartup(MAKEWORD(2, 2), &wsaData);
|
|
if (result != 0)
|
|
{
|
|
error = "WSAStartup failed.";
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
}
|
|
|
|
UniqueSocket::UniqueSocket(SOCKET socket) :
|
|
mSocket(socket)
|
|
{
|
|
}
|
|
|
|
// Closes the owned handle, if any, by delegating to reset() with its default argument.
UniqueSocket::~UniqueSocket()
{
    this->reset();
}
|
|
|
|
UniqueSocket::UniqueSocket(UniqueSocket&& other) noexcept :
|
|
mSocket(other.release())
|
|
{
|
|
}
|
|
|
|
// Move assignment: closes the currently owned handle and takes over `other`'s handle.
// Self-move is a no-op (releasing first would otherwise abandon our own handle).
UniqueSocket& UniqueSocket::operator=(UniqueSocket&& other) noexcept
{
    if (&other == this)
        return *this;

    reset(other.release());
    return *this;
}
|
|
|
|
// Gives up ownership without closing: returns the raw handle and leaves this wrapper empty
// (INVALID_SOCKET). The caller becomes responsible for closing the returned handle.
SOCKET UniqueSocket::release()
{
    const SOCKET handedOff{mSocket};
    mSocket = INVALID_SOCKET;
    return handedOff;
}
|
|
|
|
// Replaces the owned handle with `socket`, closing the previously owned handle if valid.
//
// Resetting to the handle that is already owned is a no-op: closing it first and then storing
// the same value would leave this wrapper holding a closed handle, which Winsock may hand out
// again to an unrelated socket and which the destructor would later close a second time.
void UniqueSocket::reset(SOCKET socket)
{
    if (socket == mSocket)
        return;

    if (valid())
        closesocket(mSocket);

    mSocket = socket;
}
|
|
|
|
// Tears the server down (accept thread, client threads, listening socket, Winsock) before
// the members themselves are destroyed.
HttpControlServer::~HttpControlServer()
{
    this->Stop();
}
|
|
|
|
bool HttpControlServer::Start(
|
|
const std::filesystem::path& uiRoot,
|
|
const std::filesystem::path& docsRoot,
|
|
HttpControlServerConfig config,
|
|
HttpControlServerCallbacks callbacks,
|
|
std::string& error)
|
|
{
|
|
Stop();
|
|
|
|
if (!InitializeWinsock(error))
|
|
return false;
|
|
mWinsockStarted = true;
|
|
|
|
mUiRoot = uiRoot;
|
|
mDocsRoot = docsRoot;
|
|
mConfig = config;
|
|
mCallbacks = std::move(callbacks);
|
|
|
|
mListenSocket.reset(socket(AF_INET, SOCK_STREAM, IPPROTO_TCP));
|
|
if (!mListenSocket.valid())
|
|
{
|
|
error = "Could not create HTTP control server socket.";
|
|
Stop();
|
|
return false;
|
|
}
|
|
|
|
u_long nonBlocking = 1;
|
|
ioctlsocket(mListenSocket.get(), FIONBIO, &nonBlocking);
|
|
|
|
sockaddr_in address = {};
|
|
address.sin_family = AF_INET;
|
|
address.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
|
|
|
|
bool bound = false;
|
|
for (unsigned short offset = 0; offset < mConfig.portSearchCount; ++offset)
|
|
{
|
|
address.sin_port = htons(static_cast<u_short>(mConfig.preferredPort + offset));
|
|
if (bind(mListenSocket.get(), reinterpret_cast<sockaddr*>(&address), sizeof(address)) == 0)
|
|
{
|
|
mPort = static_cast<unsigned short>(mConfig.preferredPort + offset);
|
|
bound = true;
|
|
break;
|
|
}
|
|
}
|
|
|
|
if (!bound)
|
|
{
|
|
error = "Could not bind HTTP control server to loopback.";
|
|
Stop();
|
|
return false;
|
|
}
|
|
|
|
if (listen(mListenSocket.get(), SOMAXCONN) != 0)
|
|
{
|
|
error = "Could not listen on HTTP control server socket.";
|
|
Stop();
|
|
return false;
|
|
}
|
|
|
|
mRunning.store(true, std::memory_order_release);
|
|
mThread = std::thread([this]() { ThreadMain(); });
|
|
Log("http", "HTTP control server listening on http://127.0.0.1:" + std::to_string(mPort));
|
|
return true;
|
|
}
|
|
|
|
void HttpControlServer::Stop()
|
|
{
|
|
mRunning.store(false, std::memory_order_release);
|
|
mListenSocket.reset();
|
|
|
|
if (mThread.joinable())
|
|
mThread.join();
|
|
|
|
std::vector<std::thread> clientThreads;
|
|
{
|
|
std::lock_guard<std::mutex> lock(mClientThreadsMutex);
|
|
clientThreads.swap(mClientThreads);
|
|
for (std::thread& thread : mFinishedClientThreads)
|
|
clientThreads.push_back(std::move(thread));
|
|
mFinishedClientThreads.clear();
|
|
}
|
|
for (std::thread& thread : clientThreads)
|
|
{
|
|
if (thread.joinable())
|
|
thread.join();
|
|
}
|
|
|
|
if (mWinsockStarted)
|
|
{
|
|
WSACleanup();
|
|
mWinsockStarted = false;
|
|
}
|
|
mPort = 0;
|
|
}
|
|
|
|
// Test hook: runs the routing logic on an already-parsed request, no socket involved.
HttpControlServer::HttpResponse HttpControlServer::RouteRequestForTest(const HttpRequest& request) const
{
    return this->RouteRequest(request);
}
|
|
|
|
// Test hook: installs callbacks without going through Start(). The argument is moved in.
void HttpControlServer::SetCallbacksForTest(HttpControlServerCallbacks callbacks)
{
    this->mCallbacks = std::move(callbacks);
}
|
|
|
|
void HttpControlServer::SetRootsForTest(const std::filesystem::path& uiRoot, const std::filesystem::path& docsRoot)
|
|
{
|
|
mUiRoot = uiRoot;
|
|
mDocsRoot = docsRoot;
|
|
}
|
|
|
|
void HttpControlServer::ThreadMain()
|
|
{
|
|
while (mRunning.load(std::memory_order_acquire))
|
|
{
|
|
JoinFinishedClientThreads();
|
|
TryAcceptClient();
|
|
std::this_thread::sleep_for(mConfig.idleSleep);
|
|
}
|
|
}
|
|
|
|
bool HttpControlServer::TryAcceptClient()
|
|
{
|
|
sockaddr_in clientAddress = {};
|
|
int addressSize = sizeof(clientAddress);
|
|
UniqueSocket clientSocket(accept(mListenSocket.get(), reinterpret_cast<sockaddr*>(&clientAddress), &addressSize));
|
|
if (!clientSocket.valid())
|
|
return false;
|
|
|
|
return HandleClient(std::move(clientSocket));
|
|
}
|
|
|
|
bool HttpControlServer::HandleClient(UniqueSocket clientSocket)
|
|
{
|
|
char buffer[16384];
|
|
const int received = recv(clientSocket.get(), buffer, sizeof(buffer), 0);
|
|
if (received <= 0)
|
|
return false;
|
|
|
|
HttpRequest request;
|
|
if (!ParseHttpRequest(std::string(buffer, buffer + received), request))
|
|
return SendResponse(clientSocket.get(), TextResponse("400 Bad Request", "Bad Request"));
|
|
|
|
if (request.path == "/ws")
|
|
return HandleWebSocketClient(std::move(clientSocket), request);
|
|
|
|
return SendResponse(clientSocket.get(), RouteRequest(request));
|
|
}
|
|
|
|
bool HttpControlServer::SendResponse(SOCKET clientSocket, const HttpResponse& response) const
|
|
{
|
|
std::ostringstream stream;
|
|
stream << "HTTP/1.1 " << response.status << "\r\n"
|
|
<< "Content-Type: " << response.contentType << "\r\n"
|
|
<< "Content-Length: " << response.body.size() << "\r\n"
|
|
<< "Access-Control-Allow-Origin: *\r\n"
|
|
<< "Connection: close\r\n\r\n"
|
|
<< response.body;
|
|
|
|
const std::string payload = stream.str();
|
|
return send(clientSocket, payload.c_str(), static_cast<int>(payload.size()), 0) == static_cast<int>(payload.size());
|
|
}
|
|
|
|
// Dispatches a parsed request by HTTP method:
//   GET     -> ServeGet
//   POST    -> ServePost
//   OPTIONS -> empty "204 No Content"
//   other   -> plain-text "404 Not Found"
HttpControlServer::HttpResponse HttpControlServer::RouteRequest(const HttpRequest& request) const
{
    const auto& method = request.method;

    if (method == "OPTIONS")
        return TextResponse("204 No Content", std::string());
    if (method == "POST")
        return ServePost(request);
    if (method != "GET")
        return TextResponse("404 Not Found", "Not Found");

    return ServeGet(request);
}
|
|
|
|
bool HttpControlServer::ParseHttpRequest(const std::string& rawRequest, HttpRequest& request)
|
|
{
|
|
const std::size_t requestLineEnd = rawRequest.find("\r\n");
|
|
if (requestLineEnd == std::string::npos)
|
|
return false;
|
|
|
|
const std::string requestLine = rawRequest.substr(0, requestLineEnd);
|
|
const std::size_t methodEnd = requestLine.find(' ');
|
|
if (methodEnd == std::string::npos)
|
|
return false;
|
|
|
|
const std::size_t pathEnd = requestLine.find(' ', methodEnd + 1);
|
|
if (pathEnd == std::string::npos)
|
|
return false;
|
|
|
|
request.method = requestLine.substr(0, methodEnd);
|
|
request.path = requestLine.substr(methodEnd + 1, pathEnd - methodEnd - 1);
|
|
request.headers.clear();
|
|
|
|
const std::size_t queryStart = request.path.find('?');
|
|
if (queryStart != std::string::npos)
|
|
request.path = request.path.substr(0, queryStart);
|
|
|
|
const std::size_t headersStart = requestLineEnd + 2;
|
|
const std::size_t bodySeparator = rawRequest.find("\r\n\r\n", headersStart);
|
|
const std::size_t headersEnd = bodySeparator == std::string::npos ? rawRequest.size() : bodySeparator;
|
|
|
|
for (std::size_t lineStart = headersStart; lineStart < headersEnd;)
|
|
{
|
|
const std::size_t lineEnd = rawRequest.find("\r\n", lineStart);
|
|
const std::size_t currentLineEnd = lineEnd == std::string::npos ? headersEnd : (std::min)(lineEnd, headersEnd);
|
|
const std::string line = rawRequest.substr(lineStart, currentLineEnd - lineStart);
|
|
const std::size_t separator = line.find(':');
|
|
if (separator != std::string::npos)
|
|
{
|
|
const std::string key = ToLower(line.substr(0, separator));
|
|
std::string value = line.substr(separator + 1);
|
|
const std::size_t first = value.find_first_not_of(" \t");
|
|
const std::size_t last = value.find_last_not_of(" \t");
|
|
request.headers[key] = first == std::string::npos ? std::string() : value.substr(first, last - first + 1);
|
|
}
|
|
|
|
if (lineEnd == std::string::npos || lineEnd >= headersEnd)
|
|
break;
|
|
lineStart = lineEnd + 2;
|
|
}
|
|
|
|
request.body = bodySeparator == std::string::npos ? std::string() : rawRequest.substr(bodySeparator + 4);
|
|
return !request.method.empty() && !request.path.empty();
|
|
}
|
|
}
|