#include "stdafx.h"
#include "ControlServer.h"
#include "RuntimeJson.h"

#include <algorithm>
#include <cctype>
#include <cstddef>
#include <fstream>
#include <sstream>
#include <string>
#include <system_error>
#include <wincrypt.h>

#pragma comment(lib, "Ws2_32.lib")
#pragma comment(lib, "Crypt32.lib")
#pragma comment(lib, "Advapi32.lib")

namespace
{
    // Largest HTTP request (headers + body) the server is willing to buffer.
    constexpr std::size_t kMaxRequestBytes = 1024 * 1024;

    // Longest time one recv/send may wait for its socket before the transfer is abandoned.
    // This bounds how long a single slow client can stall the server thread (and Stop()).
    constexpr long kSocketTimeoutMs = 2000;

    // Largest chunk handed to one send() call; keeps the size_t -> int conversion safe.
    constexpr std::size_t kMaxSendChunk = 64 * 1024;

    // Number of consecutive ports tried, starting at the preferred port.
    constexpr unsigned int kPortSearchRange = 20;

    // Largest valid TCP port number.
    constexpr unsigned int kMaxPort = 65535;

    // Calls WSAStartup for Winsock 2.2. Every successful call must be balanced by WSACleanup.
    bool InitializeWinsock(std::string& error)
    {
        WSADATA wsaData = {};
        const int result = WSAStartup(MAKEWORD(2, 2), &wsaData);
        if (result != 0)
        {
            error = "WSAStartup failed.";
            return false;
        }
        return true;
    }

    // Returns an ASCII-lowercased copy of text.
    std::string ToLower(std::string text)
    {
        std::transform(text.begin(), text.end(), text.begin(),
            [](unsigned char ch) { return static_cast<char>(std::tolower(ch)); });
        return text;
    }

    // Waits up to timeoutMs for socketHandle to become readable (or writable when forWrite is true).
    // Returns false on timeout or error.
    bool WaitForSocket(SOCKET socketHandle, bool forWrite, long timeoutMs)
    {
        fd_set sockets;
        FD_ZERO(&sockets);
        FD_SET(socketHandle, &sockets);
        timeval timeout = {};
        timeout.tv_sec = timeoutMs / 1000;
        timeout.tv_usec = (timeoutMs % 1000) * 1000;
        // Winsock ignores select()'s first argument.
        return select(0, forWrite ? nullptr : &sockets, forWrite ? &sockets : nullptr, nullptr, &timeout) > 0;
    }

    // Sends all `length` bytes of `data`. Partial writes are continued and WSAEWOULDBLOCK
    // (the socket is non-blocking) waits for writability, up to kSocketTimeoutMs each time.
    // Returns false if the connection fails or stays unwritable.
    bool SendAll(SOCKET socketHandle, const char* data, std::size_t length)
    {
        std::size_t sent = 0;
        while (sent < length)
        {
            const std::size_t remaining = length - sent;
            const int chunk = static_cast<int>(remaining < kMaxSendChunk ? remaining : kMaxSendChunk);
            const int result = send(socketHandle, data + sent, chunk, 0);
            if (result > 0)
            {
                sent += static_cast<std::size_t>(result);
            }
            else if (result == SOCKET_ERROR && WSAGetLastError() == WSAEWOULDBLOCK)
            {
                if (!WaitForSocket(socketHandle, true, kSocketTimeoutMs))
                    return false;
            }
            else
            {
                return false;
            }
        }
        return true;
    }

    // Parses a Content-Length header value; an empty value (header absent) means no body.
    // Returns std::string::npos when the value is not plain decimal or exceeds kMaxRequestBytes.
    std::size_t ParseContentLength(const std::string& text)
    {
        std::size_t value = 0;
        for (const char ch : text)
        {
            if (ch < '0' || ch > '9')
                return std::string::npos;
            value = value * 10 + static_cast<std::size_t>(ch - '0');
            // Bail out early so the accumulation can never overflow.
            if (value > kMaxRequestBytes)
                return std::string::npos;
        }
        return value;
    }
}

ControlServer::ControlServer()
    : mListenSocket(INVALID_SOCKET),
      mPort(0),
      mRunning(false)
{
}

ControlServer::~ControlServer()
{
    Stop();
}

// Binds a loopback-only listening socket on the first free port in
// [preferredPort, preferredPort + 20) and starts the background server thread.
// On failure, `error` describes the problem, no socket is left open and WSAStartup is balanced.
bool ControlServer::Start(const std::filesystem::path& uiRoot, unsigned short preferredPort, const Callbacks& callbacks, std::string& error)
{
    // Re-starting a live server would overwrite the socket and assign to a joinable std::thread
    // (which calls std::terminate).
    if (mRunning || mThread.joinable() || mListenSocket != INVALID_SOCKET)
    {
        error = "The local control server is already running.";
        return false;
    }

    mUiRoot = uiRoot;
    mCallbacks = callbacks;

    if (!InitializeWinsock(error))
        return false;

    // Any failure past WSAStartup must close the socket and balance WSAStartup here:
    // Stop() only calls WSACleanup for a server that actually reached the running state.
    const auto fail = [this, &error](const char* message)
    {
        error = message;
        if (mListenSocket != INVALID_SOCKET)
        {
            closesocket(mListenSocket);
            mListenSocket = INVALID_SOCKET;
        }
        WSACleanup();
        return false;
    };

    mListenSocket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
    if (mListenSocket == INVALID_SOCKET)
        return fail("Could not create listening socket.");

    // The server loop polls accept(), so the listening socket must not block.
    u_long nonBlocking = 1;
    ioctlsocket(mListenSocket, FIONBIO, &nonBlocking);

    sockaddr_in address = {};
    address.sin_family = AF_INET;
    address.sin_addr.s_addr = htonl(INADDR_LOOPBACK);

    bool bound = false;
    for (unsigned int offset = 0; offset < kPortSearchRange; ++offset)
    {
        const unsigned int candidate = static_cast<unsigned int>(preferredPort) + offset;
        if (candidate > kMaxPort)
            break; // never wrap around into low port numbers

        address.sin_port = htons(static_cast<u_short>(candidate));
        if (bind(mListenSocket, reinterpret_cast<sockaddr*>(&address), sizeof(address)) == 0)
        {
            mPort = static_cast<unsigned short>(candidate);
            bound = true;
            break;
        }
    }

    if (!bound)
        return fail("Could not bind the local control server to any port in the preferred range.");

    // Port 0 asks the OS for an ephemeral port; report the one actually assigned.
    if (mPort == 0)
    {
        sockaddr_in boundAddress = {};
        int boundSize = sizeof(boundAddress);
        if (getsockname(mListenSocket, reinterpret_cast<sockaddr*>(&boundAddress), &boundSize) == 0)
            mPort = ntohs(boundAddress.sin_port);
    }

    if (listen(mListenSocket, SOMAXCONN) != 0)
        return fail("Could not start listening on the local control server socket.");

    mRunning = true;
    try
    {
        mThread = std::thread(&ControlServer::ServerLoop, this);
    }
    catch (const std::system_error&)
    {
        mRunning = false;
        return fail("Could not start the local control server thread.");
    }
    return true;
}

// Stops the server thread, closes every socket and balances the WSAStartup made by Start().
// Safe to call repeatedly and on a server that never started.
void ControlServer::Stop()
{
    const bool wasActive = mRunning || mListenSocket != INVALID_SOCKET || mThread.joinable();
    mRunning = false;

    // Join first: the server thread reads mListenSocket and may still be registering a
    // WebSocket client. Closing handles underneath it would race, and a client added after
    // the clear below would leak its socket.
    if (mThread.joinable())
        mThread.join();

    {
        std::lock_guard lock(mMutex);
        for (ClientConnection& client : mClients)
        {
            if (client.socket != INVALID_SOCKET)
            {
                closesocket(client.socket);
                client.socket = INVALID_SOCKET;
            }
        }
        mClients.clear();
    }

    if (mListenSocket != INVALID_SOCKET)
    {
        closesocket(mListenSocket);
        mListenSocket = INVALID_SOCKET;
    }

    if (wasActive)
        WSACleanup();
}

// Pushes the current state JSON to every connected WebSocket client.
void ControlServer::BroadcastState()
{
    std::lock_guard lock(mMutex);
    BroadcastStateLocked();
}

// Background thread body: polls the non-blocking listening socket until Stop() clears mRunning.
void ControlServer::ServerLoop()
{
    while (mRunning)
    {
        // Drain every pending connection before sleeping, so a page load's burst
        // (HTML, script, stylesheet, WebSocket) is not serialized behind 25 ms naps.
        while (mRunning && TryAcceptClient())
        {
        }
        Sleep(25);
    }
}

// Reads one complete HTTP request from clientSocket and dispatches it.
// Reading continues until the header terminator and Content-Length body bytes have arrived,
// the peer stops sending, or the socket stays idle for kSocketTimeoutMs.
// Returns false (caller closes the socket) on socket errors, oversized requests or no data.
bool ControlServer::HandleHttpClient(SOCKET clientSocket)
{
    std::string request;
    char buffer[8192];
    // Total byte count of the request; unknown until the header section is complete.
    std::size_t expectedSize = std::string::npos;

    while (expectedSize == std::string::npos || request.size() < expectedSize)
    {
        const int received = recv(clientSocket, buffer, static_cast<int>(sizeof(buffer)), 0);
        if (received > 0)
        {
            request.append(buffer, static_cast<std::size_t>(received));
            if (request.size() > kMaxRequestBytes)
                return false;
        }
        else if (received == 0)
        {
            break; // peer finished sending
        }
        else if (WSAGetLastError() == WSAEWOULDBLOCK)
        {
            if (!WaitForSocket(clientSocket, false, kSocketTimeoutMs))
                break; // stop waiting; handle whatever has arrived
            continue;
        }
        else
        {
            return false;
        }

        if (expectedSize == std::string::npos)
        {
            const std::size_t headerEnd = request.find("\r\n\r\n");
            if (headerEnd != std::string::npos)
            {
                const std::size_t contentLength = ParseContentLength(GetHeaderValue(request, "Content-Length"));
                if (contentLength == std::string::npos)
                    return false;
                expectedSize = headerEnd + 4 + contentLength;
            }
        }
    }

    if (request.empty())
        return false;

    return HandleHttpRequest(clientSocket, request);
}

// Accepts at most one pending connection and serves it.
// Returns true if a connection was accepted and handled; false if none was pending or handling failed.
bool ControlServer::TryAcceptClient()
{
    sockaddr_in clientAddress = {};
    int addressSize = sizeof(clientAddress);
    SOCKET clientSocket = accept(mListenSocket, reinterpret_cast<sockaddr*>(&clientAddress), &addressSize);
    if (clientSocket == INVALID_SOCKET)
        return false;

    // Force non-blocking mode explicitly so the bounded-wait recv/send helpers
    // can never hang the server thread on a silent client.
    u_long nonBlocking = 1;
    ioctlsocket(clientSocket, FIONBIO, &nonBlocking);

    const bool handled = HandleHttpClient(clientSocket);
    if (!handled)
        closesocket(clientSocket);
    return handled;
}

// Writes a complete HTTP/1.1 response with "Connection: close". Returns true if every byte was sent.
bool ControlServer::SendHttpResponse(SOCKET clientSocket, const std::string& status, const std::string& contentType, const std::string& body)
{
    std::ostringstream response;
    response << "HTTP/1.1 " << status << "\r\n";
    response << "Content-Type: " << contentType << "\r\n";
    response << "Content-Length: " << body.size() << "\r\n";
    response << "Connection: close\r\n\r\n";
    response << body;

    const std::string payload = response.str();
    return SendAll(clientSocket, payload.data(), payload.size());
}

// Routes one complete HTTP request. Every path either hands clientSocket over to the
// WebSocket client list or closes it after responding, so this always returns true.
bool ControlServer::HandleHttpRequest(SOCKET clientSocket, const std::string& request)
{
    const std::string method = GetRequestMethod(request);
    // Ignore any query string (e.g. a cache-busting "?v=2") when routing.
    const std::string target = GetRequestPath(request);
    const std::string path = target.substr(0, target.find('?'));

    if (ToLower(GetHeaderValue(request, "Upgrade")) == "websocket")
        return HandleWebSocketUpgrade(clientSocket, request);

    // Sends one response and closes the connection, matching the "Connection: close" header.
    const auto respond = [this, clientSocket](const std::string& status, const std::string& contentType, const std::string& body)
    {
        SendHttpResponse(clientSocket, status, contentType, body);
        closesocket(clientSocket);
        return true;
    };

    if (method == "GET")
    {
        if (path == "/" || path == "/index.html")
        {
            std::string contentType;
            const std::string body = LoadUiAsset("index.html", contentType);
            return respond("200 OK", contentType, body);
        }

        if (path == "/app.js" || path == "/styles.css")
        {
            std::string contentType;
            const std::string body = LoadUiAsset(path.substr(1), contentType);
            return respond("200 OK", contentType, body);
        }

        if (path == "/api/state")
            return respond("200 OK", "application/json", mCallbacks.getStateJson ? mCallbacks.getStateJson() : "{}");
    }
    else if (method == "POST")
    {
        std::string body = GetRequestBody(request);
        JsonValue root;
        std::string parseError;
        if (!ParseJson(body, root, parseError))
            return respond("400 Bad Request", "application/json", BuildJsonResponse(false, parseError));

        // An unknown endpoint, a missing field or a missing callback all leave success == false.
        bool success = false;
        std::string actionError;
        if (path == "/api/select-shader")
        {
            const JsonValue* shaderId = root.find("shaderId");
            success = shaderId && mCallbacks.selectShader && mCallbacks.selectShader(shaderId->asString(), actionError);
        }
        else if (path == "/api/update-parameter")
        {
            const JsonValue* shaderId = root.find("shaderId");
            const JsonValue* parameterId = root.find("parameterId");
            const JsonValue* value = root.find("value");
            if (shaderId && parameterId && value && mCallbacks.updateParameter)
                success = mCallbacks.updateParameter(shaderId->asString(), parameterId->asString(), SerializeJson(*value, false), actionError);
        }
        else if (path == "/api/set-bypass")
        {
            const JsonValue* bypass = root.find("bypass");
            if (bypass && mCallbacks.setBypass)
                success = mCallbacks.setBypass(bypass->asBoolean(), actionError);
        }
        else if (path == "/api/set-mix")
        {
            const JsonValue* mixAmount = root.find("mixAmount");
            if (mixAmount && mCallbacks.setMixAmount)
                success = mCallbacks.setMixAmount(mixAmount->asNumber(), actionError);
        }
        else if (path == "/api/reload")
        {
            if (mCallbacks.reloadShader)
                success = mCallbacks.reloadShader(actionError);
        }

        SendHttpResponse(clientSocket, success ? "200 OK" : "400 Bad Request", "application/json", BuildJsonResponse(success, actionError));
        closesocket(clientSocket);
        // Push the changed state to WebSocket clients only after the HTTP caller has its answer.
        if (success)
            BroadcastState();
        return true;
    }

    return respond("404 Not Found", "text/plain", "Not Found");
}

// Completes the RFC 6455 opening handshake and registers the socket as a WebSocket client.
// The new client immediately receives the current state. Always returns true: the socket is
// either kept in mClients or closed here.
bool ControlServer::HandleWebSocketUpgrade(SOCKET clientSocket, const std::string& request)
{
    const std::string clientKey = GetHeaderValue(request, "Sec-WebSocket-Key");
    if (clientKey.empty())
    {
        SendHttpResponse(clientSocket, "400 Bad Request", "text/plain", "Missing Sec-WebSocket-Key");
        closesocket(clientSocket);
        return true;
    }

    const std::string acceptKey = ComputeWebSocketAcceptKey(clientKey);
    if (acceptKey.empty())
    {
        SendHttpResponse(clientSocket, "500 Internal Server Error", "text/plain", "Could not compute Sec-WebSocket-Accept");
        closesocket(clientSocket);
        return true;
    }

    std::ostringstream response;
    response << "HTTP/1.1 101 Switching Protocols\r\n";
    response << "Upgrade: websocket\r\n";
    response << "Connection: Upgrade\r\n";
    response << "Sec-WebSocket-Accept: " << acceptKey << "\r\n\r\n";

    const std::string payload = response.str();
    if (!SendAll(clientSocket, payload.data(), payload.size()))
    {
        // The handshake never reached the client; do not register a half-open connection.
        closesocket(clientSocket);
        return true;
    }

    std::lock_guard lock(mMutex);
    ClientConnection client;
    client.socket = clientSocket;
    client.websocket = true;
    mClients.push_back(client);
    BroadcastStateLocked();
    return true;
}

// Sends `payload` as one unmasked, final text frame (opcode 0x1).
// The 7-, 16- or 64-bit length form is chosen by payload size, per RFC 6455 section 5.2.
bool ControlServer::SendWebSocketText(SOCKET clientSocket, const std::string& payload)
{
    std::string frame;
    frame.push_back(static_cast<char>(0x81)); // FIN bit + text opcode

    if (payload.size() <= 125)
    {
        frame.push_back(static_cast<char>(payload.size()));
    }
    else if (payload.size() <= 65535)
    {
        frame.push_back(126);
        frame.push_back(static_cast<char>((payload.size() >> 8) & 0xFF));
        frame.push_back(static_cast<char>(payload.size() & 0xFF));
    }
    else
    {
        frame.push_back(127);
        const unsigned long long size = payload.size();
        for (int shift = 56; shift >= 0; shift -= 8)
            frame.push_back(static_cast<char>((size >> shift) & 0xFF));
    }

    frame.append(payload);
    // A partial write would corrupt the frame stream, so the whole frame must go out.
    return SendAll(clientSocket, frame.data(), frame.size());
}

// Sends the current state to every WebSocket client and drops clients whose send fails.
// Caller must hold mMutex.
void ControlServer::BroadcastStateLocked()
{
    const std::string stateMessage = mCallbacks.getStateJson ? mCallbacks.getStateJson() : "{}";
    for (auto it = mClients.begin(); it != mClients.end();)
    {
        if (!SendWebSocketText(it->socket, stateMessage))
        {
            closesocket(it->socket);
            it = mClients.erase(it);
        }
        else
        {
            ++it;
        }
    }
}

// Reads mUiRoot / relativePath and sets contentType from its extension (.js, .css, else HTML).
// If the file cannot be opened, a small HTML placeholder page is returned instead.
std::string ControlServer::LoadUiAsset(const std::string& relativePath, std::string& contentType) const
{
    const std::filesystem::path assetPath = mUiRoot / relativePath;
    std::ifstream input(assetPath, std::ios::binary);
    if (!input)
    {
        // The placeholder is HTML whatever was requested; label it so the header is never empty.
        contentType = "text/html";
        return "<!DOCTYPE html><html><head><title>Missing UI asset</title></head>"
               "<body><p>UI asset missing.</p></body></html>";
    }

    const std::filesystem::path extension = assetPath.extension();
    if (extension == ".js")
        contentType = "text/javascript";
    else if (extension == ".css")
        contentType = "text/css";
    else
        contentType = "text/html";

    std::ostringstream buffer;
    buffer << input.rdbuf();
    return buffer.str();
}

// Builds the compact JSON reply {"ok": success[, "error": error]} used by the POST endpoints.
std::string ControlServer::BuildJsonResponse(bool success, const std::string& error) const
{
    JsonValue response = JsonValue::MakeObject();
    response.set("ok", JsonValue(success));
    if (!error.empty())
        response.set("error", JsonValue(error));
    return SerializeJson(response, false);
}

// Base64-encodes `dataLength` bytes with CryptBinaryToStringA (no line breaks).
// Returns an empty string if the API fails.
std::string ControlServer::Base64Encode(const unsigned char* data, DWORD dataLength)
{
    const DWORD flags = CRYPT_STRING_BASE64 | CRYPT_STRING_NOCRLF;
    DWORD outputLength = 0;
    // First call: reports the required buffer size, including the terminating NUL.
    if (!CryptBinaryToStringA(data, dataLength, flags, nullptr, &outputLength) || outputLength == 0)
        return std::string();

    std::string encoded(outputLength, '\0');
    // Second call: writes the text and reports its length excluding the NUL.
    if (!CryptBinaryToStringA(data, dataLength, flags, &encoded[0], &outputLength))
        return std::string();

    encoded.resize(outputLength);
    while (!encoded.empty() && encoded.back() == '\0')
        encoded.pop_back();
    return encoded;
}

// Computes Sec-WebSocket-Accept = base64(SHA-1(clientKey + RFC 6455 GUID)).
// Returns an empty string if any CryptoAPI step fails.
std::string ControlServer::ComputeWebSocketAcceptKey(const std::string& clientKey)
{
    const std::string combined = clientKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
    HCRYPTPROV provider = 0;
    HCRYPTHASH hash = 0;
    BYTE digest[20] = {};
    DWORD digestLength = sizeof(digest);

    const bool hashed =
        CryptAcquireContext(&provider, nullptr, nullptr, PROV_RSA_FULL, CRYPT_VERIFYCONTEXT) != FALSE &&
        CryptCreateHash(provider, CALG_SHA1, 0, 0, &hash) != FALSE &&
        CryptHashData(hash, reinterpret_cast<const BYTE*>(combined.data()), static_cast<DWORD>(combined.size()), 0) != FALSE &&
        CryptGetHashParam(hash, HP_HASHVAL, digest, &digestLength, 0) != FALSE;

    if (hash)
        CryptDestroyHash(hash);
    if (provider)
        CryptReleaseContext(provider, 0);

    return hashed ? Base64Encode(digest, digestLength) : std::string();
}

// Returns the trimmed value of the first `headerName` header (case-insensitive name match),
// with the value's original case preserved. Returns "" if absent.
// Only whole header lines in the header section are matched, so body text or a header whose
// name merely ends with headerName (e.g. "X-Upgrade") is never picked up.
std::string ControlServer::GetHeaderValue(const std::string& request, const std::string& headerName)
{
    const std::size_t headerEnd = request.find("\r\n\r\n");
    // Keep the CRLF that ends the last header line so its value can be delimited too.
    const std::size_t headerSectionSize = headerEnd == std::string::npos ? request.size() : headerEnd + 2;
    // Lowercasing preserves length, so offsets into `headers` are valid offsets into `request`.
    const std::string headers = ToLower(request.substr(0, headerSectionSize));
    // Every header line follows the request line, so it starts right after a CRLF.
    const std::string needle = "\r\n" + ToLower(headerName) + ":";

    const std::size_t start = headers.find(needle);
    if (start == std::string::npos)
        return std::string();

    const std::size_t valueStart = start + needle.size();
    const std::size_t lineEnd = headers.find("\r\n", valueStart);
    if (lineEnd == std::string::npos)
        return std::string();

    const std::string value = request.substr(valueStart, lineEnd - valueStart);
    const std::size_t first = value.find_first_not_of(" \t");
    const std::size_t last = value.find_last_not_of(" \t");
    return first == std::string::npos ? std::string() : value.substr(first, last - first + 1);
}

// Returns the request target (the second space-separated token of the request line),
// or "/" if the request line is malformed.
std::string ControlServer::GetRequestPath(const std::string& request)
{
    const std::size_t methodEnd = request.find(' ');
    if (methodEnd == std::string::npos)
        return "/";

    const std::size_t pathEnd = request.find(' ', methodEnd + 1);
    if (pathEnd == std::string::npos)
        return "/";

    return request.substr(methodEnd + 1, pathEnd - methodEnd - 1);
}

// Returns the HTTP method (text before the first space), or "" if there is none.
std::string ControlServer::GetRequestMethod(const std::string& request)
{
    const std::size_t methodEnd = request.find(' ');
    return methodEnd == std::string::npos ? std::string() : request.substr(0, methodEnd);
}

// Returns everything after the blank line that ends the headers, or "" if there is no such line.
std::string ControlServer::GetRequestBody(const std::string& request)
{
    const std::size_t separator = request.find("\r\n\r\n");
    if (separator == std::string::npos)
        return std::string();
    return request.substr(separator + 4);
}