#include "stdafx.h"
#include "ControlServer.h"
#include "RuntimeJson.h"

#include <wincrypt.h>

#include <algorithm>
#include <cctype>
#include <cstdint>
#include <exception>
#include <fstream>
#include <functional>
#include <map>
#include <sstream>
#include <string>

#pragma comment(lib, "Ws2_32.lib")
#pragma comment(lib, "Crypt32.lib")
#pragma comment(lib, "Advapi32.lib")

namespace
{
    // Period of the unconditional state push to WebSocket clients.
    constexpr DWORD kStateBroadcastIntervalMs = 250;
    // Minimum spacing between pushes requested through RequestBroadcastState().
    constexpr DWORD kStateBroadcastThrottleMs = 50;
    // Longest single wait for request bytes, or for send-buffer space, on a client socket.
    constexpr DWORD kSocketIoTimeoutMs = 1000;
    // Incoming requests (headers + body) larger than this are dropped.
    constexpr std::size_t kMaxRequestBytes = 1024 * 1024;
    // Largest slice handed to one send() call (send() takes an int length).
    constexpr std::size_t kMaxSendChunk = 64 * 1024;
    // Upper bound on connections handled per server-loop tick.
    constexpr int kMaxAcceptsPerTick = 8;

    // Starts Winsock 2.2. Every successful call must be balanced by WSACleanup().
    bool InitializeWinsock(std::string& error)
    {
        WSADATA wsaData = {};
        if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0)
        {
            error = "WSAStartup failed.";
            return false;
        }
        return true;
    }

    // Returns `text` with every byte passed through std::tolower.
    std::string ToLower(std::string text)
    {
        std::transform(text.begin(), text.end(), text.begin(), [](unsigned char ch) {
            return static_cast<char>(std::tolower(ch));
        });
        return text;
    }

    // Accepts only non-empty relative paths that cannot escape the directory they are
    // joined onto. Rejected:
    //  - paths with a root name or root directory ("C:", "/x", "//server"): operator/
    //    would replace the base directory with them instead of appending;
    //  - ".." components (callers pass lexically_normal() paths, so only leading ones remain);
    //  - components containing ':', which on NTFS name alternate data streams.
    bool IsSafeUiPath(const std::filesystem::path& relativePath)
    {
        if (relativePath.empty() || relativePath.has_root_path())
            return false;

        const auto colon = static_cast<std::filesystem::path::value_type>(':');
        for (const std::filesystem::path& part : relativePath)
        {
            if (part == ".." || part.native().find(colon) != std::filesystem::path::string_type::npos)
                return false;
        }
        return true;
    }

    // Joins the URL-supplied `relativePath` onto `root` when it is safe to serve.
    // Returns false for unsafe paths, and for byte sequences std::filesystem cannot
    // convert (so the exception never escapes the server thread).
    bool ResolveSafePath(const std::filesystem::path& root, const std::string& relativePath, std::filesystem::path& resolved)
    {
        try
        {
            const std::filesystem::path sanitizedPath = std::filesystem::path(relativePath).lexically_normal();
            if (!IsSafeUiPath(sanitizedPath))
                return false;
            resolved = root / sanitizedPath;
            return true;
        }
        catch (const std::exception&)
        {
            return false;
        }
    }

    // Maps a file extension to a Content-Type; unknown extensions fall back to text/html.
    std::string GuessContentType(const std::filesystem::path& assetPath)
    {
        const std::string extension = ToLower(assetPath.extension().string());
        if (extension == ".html" || extension == ".htm")
            return "text/html";
        if (extension == ".js" || extension == ".mjs")
            return "text/javascript";
        if (extension == ".css")
            return "text/css";
        if (extension == ".json" || extension == ".map")
            return "application/json";
        if (extension == ".yaml" || extension == ".yml")
            return "application/yaml";
        if (extension == ".svg")
            return "image/svg+xml";
        if (extension == ".png")
            return "image/png";
        if (extension == ".jpg" || extension == ".jpeg")
            return "image/jpeg";
        if (extension == ".ico")
            return "image/x-icon";
        if (extension == ".md")
            return "text/markdown";
        if (extension == ".txt")
            return "text/plain";
        if (extension == ".wasm")
            return "application/wasm";
        if (extension == ".woff")
            return "font/woff";
        if (extension == ".woff2")
            return "font/woff2";
        return "text/html";
    }

    // Returns true when an Origin header is absent or names a loopback host
    // (127.0.0.1, localhost or [::1], any port, http or https). Browsers attach Origin
    // to cross-site POSTs and WebSocket handshakes, so refusing other origins stops
    // arbitrary web pages from driving this local API. Non-browser clients send no Origin.
    bool IsTrustedOrigin(const std::string& origin)
    {
        if (origin.empty())
            return true;

        const std::string lowered = ToLower(origin);
        const std::size_t schemeEnd = lowered.find("://");
        if (schemeEnd == std::string::npos)
            return false;  // includes the opaque "null" origin

        const std::string scheme = lowered.substr(0, schemeEnd);
        if (scheme != "http" && scheme != "https")
            return false;

        const std::size_t hostStart = schemeEnd + 3;
        std::size_t hostEnd = std::string::npos;
        if (hostStart < lowered.size() && lowered[hostStart] == '[')
        {
            const std::size_t bracketEnd = lowered.find(']', hostStart);
            if (bracketEnd == std::string::npos)
                return false;
            hostEnd = bracketEnd + 1;
            if (hostEnd < lowered.size() && lowered[hostEnd] != ':' && lowered[hostEnd] != '/')
                return false;
        }
        else
        {
            hostEnd = lowered.find_first_of(":/", hostStart);
            if (hostEnd == std::string::npos)
                hostEnd = lowered.size();
        }

        const std::string host = lowered.substr(hostStart, hostEnd - hostStart);
        return host == "127.0.0.1" || host == "localhost" || host == "[::1]";
    }

    // Waits up to timeoutMs for `handle` to become readable (forWrite == false) or
    // writable (forWrite == true). Returns true when the socket is ready.
    bool WaitForSocket(SOCKET handle, bool forWrite, DWORD timeoutMs)
    {
        fd_set fds;
        FD_ZERO(&fds);
        FD_SET(handle, &fds);

        timeval timeout = {};
        timeout.tv_sec = static_cast<long>(timeoutMs / 1000);
        timeout.tv_usec = static_cast<long>((timeoutMs % 1000) * 1000);

        // Winsock ignores select()'s first argument.
        const int ready = select(0, forWrite ? nullptr : &fds, forWrite ? &fds : nullptr, nullptr, &timeout);
        return ready > 0;
    }

    // Sends all `size` bytes. A single send() on a non-blocking socket may accept only
    // part of the data or fail with WSAEWOULDBLOCK. Each stall waits at most
    // kSocketIoTimeoutMs. Returns false on error or timeout.
    bool SendAll(SOCKET handle, const char* data, std::size_t size)
    {
        std::size_t sent = 0;
        while (sent < size)
        {
            const std::size_t chunk = std::min<std::size_t>(size - sent, kMaxSendChunk);
            const int result = send(handle, data + sent, static_cast<int>(chunk), 0);
            if (result == SOCKET_ERROR)
            {
                if (WSAGetLastError() == WSAEWOULDBLOCK && WaitForSocket(handle, true, kSocketIoTimeoutMs))
                    continue;
                return false;
            }
            sent += static_cast<std::size_t>(result);
        }
        return true;
    }

    // Returns the Content-Length value in `headerBlock` (request line + headers,
    // without the terminating blank line), or 0 when absent. Values above
    // kMaxRequestBytes are reported as kMaxRequestBytes + 1.
    std::size_t FindContentLength(const std::string& headerBlock)
    {
        const std::string lowered = ToLower(headerBlock);
        const std::string name = "\r\ncontent-length:";
        const std::size_t position = lowered.find(name);
        if (position == std::string::npos)
            return 0;

        std::size_t index = position + name.size();
        while (index < lowered.size() && (lowered[index] == ' ' || lowered[index] == '\t'))
            ++index;

        std::size_t length = 0;
        for (; index < lowered.size() && std::isdigit(static_cast<unsigned char>(lowered[index])); ++index)
        {
            length = length * 10 + static_cast<std::size_t>(lowered[index] - '0');
            if (length > kMaxRequestBytes)
                return kMaxRequestBytes + 1;
        }
        return length;
    }

    // Reads one HTTP request: everything up to the blank line that ends the headers,
    // plus the Content-Length body bytes. Returns false if nothing usable arrives within
    // the timeout or the request exceeds kMaxRequestBytes. If the peer closes early, the
    // bytes received so far are returned for the parser to judge.
    bool ReceiveHttpRequest(SOCKET handle, std::string& request)
    {
        request.clear();
        std::size_t expectedSize = std::string::npos;  // known once the headers are complete
        char buffer[8192];

        while (expectedSize == std::string::npos || request.size() < expectedSize)
        {
            const int received = recv(handle, buffer, static_cast<int>(sizeof(buffer)), 0);
            if (received == 0)
                return !request.empty();
            if (received == SOCKET_ERROR)
            {
                if (WSAGetLastError() == WSAEWOULDBLOCK && WaitForSocket(handle, false, kSocketIoTimeoutMs))
                    continue;
                return false;
            }

            request.append(buffer, static_cast<std::size_t>(received));
            if (request.size() > kMaxRequestBytes)
                return false;

            if (expectedSize == std::string::npos)
            {
                const std::size_t headerEnd = request.find("\r\n\r\n");
                if (headerEnd != std::string::npos)
                {
                    expectedSize = headerEnd + 4 + FindContentLength(request.substr(0, headerEnd));
                    if (expectedSize > kMaxRequestBytes)
                        return false;
                }
            }
        }
        return true;
    }
}

ControlServer::ControlServer() :
    mPort(0),
    mRunning(false),
    mBroadcastPending(false)
{
}

ControlServer::~ControlServer()
{
    Stop();
}

// Binds a loopback-only listener on the first free port in
// [preferredPort, preferredPort + 19] (never past 65535) and starts the server thread.
// On failure `error` is set, every resource acquired here (including the Winsock
// reference) is released, and false is returned. Calling Start() on a running server
// fails instead of replacing the live thread.
bool ControlServer::Start(const std::filesystem::path& uiRoot, const std::filesystem::path& docsRoot, unsigned short preferredPort, const Callbacks& callbacks, std::string& error)
{
    if (mThread.joinable())
    {
        error = "The local control server is already running.";
        return false;
    }

    mUiRoot = uiRoot;
    mDocsRoot = docsRoot;
    mCallbacks = callbacks;

    if (!InitializeWinsock(error))
        return false;

    // Stop() only calls WSACleanup() for a server that actually started, so every
    // failure from here on must undo WSAStartup() itself.
    const auto fail = [this, &error](const char* message) {
        error = message;
        mListenSocket.reset();
        WSACleanup();
        return false;
    };

    mListenSocket.reset(socket(AF_INET, SOCK_STREAM, IPPROTO_TCP));
    if (!mListenSocket.valid())
        return fail("Could not create listening socket.");

    // Non-blocking so the server loop can poll accept() without stalling.
    u_long nonBlocking = 1;
    ioctlsocket(mListenSocket.get(), FIONBIO, &nonBlocking);

    sockaddr_in address = {};
    address.sin_family = AF_INET;
    address.sin_addr.s_addr = htonl(INADDR_LOOPBACK);

    bool bound = false;
    unsigned int boundPort = 0;
    for (unsigned int offset = 0; offset < 20; ++offset)
    {
        const unsigned int candidatePort = static_cast<unsigned int>(preferredPort) + offset;
        if (candidatePort > 65535)
            break;  // don't wrap around to port 0 / low ports

        address.sin_port = htons(static_cast<u_short>(candidatePort));
        if (bind(mListenSocket.get(), reinterpret_cast<const sockaddr*>(&address), static_cast<int>(sizeof(address))) == 0)
        {
            boundPort = candidatePort;
            bound = true;
            break;
        }
    }
    if (!bound)
        return fail("Could not bind the local control server to any port in the preferred range.");

    // Prefer the port the OS reports (differs from boundPort when preferredPort == 0).
    mPort = static_cast<unsigned short>(boundPort);
    sockaddr_in boundAddress = {};
    int boundAddressSize = static_cast<int>(sizeof(boundAddress));
    if (getsockname(mListenSocket.get(), reinterpret_cast<sockaddr*>(&boundAddress), &boundAddressSize) == 0)
        mPort = ntohs(boundAddress.sin_port);

    if (listen(mListenSocket.get(), SOMAXCONN) != 0)
        return fail("Could not start listening on the local control server socket.");

    mRunning = true;
    mThread = std::thread(&ControlServer::ServerLoop, this);
    return true;
}

// Stops the server thread, closes all sockets and releases Winsock. Safe to call when
// the server never started.
void ControlServer::Stop()
{
    const bool wasActive = mRunning || mListenSocket.valid() || mThread.joinable();
    mRunning = false;

    // Join before closing anything. The server thread uses mListenSocket and may
    // register new WebSocket clients until it observes mRunning == false.
    if (mThread.joinable())
        mThread.join();

    {
        std::lock_guard lock(mMutex);
        for (ClientConnection& client : mClients)
            client.socket.reset();
        mClients.clear();
    }

    mListenSocket.reset();

    if (wasActive)
        WSACleanup();
}

// Immediately pushes the current state to every WebSocket client.
void ControlServer::BroadcastState()
{
    mBroadcastPending = false;
    std::lock_guard lock(mMutex);
    BroadcastStateLocked();
}

// Asks the server thread to push the state soon; pushes are throttled to one per
// kStateBroadcastThrottleMs.
void ControlServer::RequestBroadcastState()
{
    mBroadcastPending = true;
}

// Server thread: accepts pending connections and broadcasts state, either when a
// broadcast was requested (throttled) or every kStateBroadcastIntervalMs.
void ControlServer::ServerLoop()
{
    DWORD lastStateBroadcastMs = GetTickCount();
    while (mRunning)
    {
        // Handle several queued connections per tick, so a page load (HTML, scripts,
        // styles, WebSocket) isn't serialized at one request per sleep interval.
        int accepted = 0;
        while (accepted < kMaxAcceptsPerTick && TryAcceptClient())
            ++accepted;

        const DWORD nowMs = GetTickCount();
        const DWORD elapsedMs = nowMs - lastStateBroadcastMs;  // unsigned: safe across tick wrap
        const bool requestedBroadcastDue = mBroadcastPending && elapsedMs >= kStateBroadcastThrottleMs;
        if (requestedBroadcastDue || elapsedMs >= kStateBroadcastIntervalMs)
        {
            BroadcastState();
            lastStateBroadcastMs = nowMs;
        }

        Sleep(25);
    }
}

// Reads a complete request from a freshly accepted socket and dispatches it.
// Returns false when no request could be read; the connection is then closed.
bool ControlServer::HandleHttpClient(UniqueSocket clientSocket)
{
    // accept()ed sockets may inherit the listener's blocking mode; set non-blocking
    // explicitly so a silent client can only stall this thread for the I/O timeout.
    u_long nonBlocking = 1;
    ioctlsocket(clientSocket.get(), FIONBIO, &nonBlocking);

    std::string request;
    if (!ReceiveHttpRequest(clientSocket.get(), request))
        return false;
    return HandleHttpRequest(std::move(clientSocket), request);
}

// Accepts one pending connection, if any, and handles it. Returns false when nothing
// was pending or the request could not be read.
bool ControlServer::TryAcceptClient()
{
    sockaddr_in clientAddress = {};
    int addressSize = static_cast<int>(sizeof(clientAddress));
    UniqueSocket clientSocket(accept(mListenSocket.get(), reinterpret_cast<sockaddr*>(&clientAddress), &addressSize));
    if (!clientSocket.valid())
        return false;
    return HandleHttpClient(std::move(clientSocket));
}

// Writes a complete "Connection: close" HTTP/1.1 response. Returns true if every
// byte was sent.
bool ControlServer::SendHttpResponse(SOCKET clientSocket, const std::string& status, const std::string& contentType, const std::string& body)
{
    std::ostringstream response;
    response << "HTTP/1.1 " << status << "\r\n";
    response << "Content-Type: " << contentType << "\r\n";
    response << "Content-Length: " << body.size() << "\r\n";
    response << "Connection: close\r\n\r\n";
    response << body;
    const std::string payload = response.str();
    return SendAll(clientSocket, payload.data(), payload.size());
}

bool ControlServer::SendHttpResponse(SOCKET clientSocket, const HttpResponse& response)
{
    return SendHttpResponse(clientSocket, response.status, response.contentType, response.body);
}

// Parses a raw request and either upgrades it to a WebSocket or answers it over HTTP.
// POSTs and WebSocket handshakes from non-loopback browser origins get 403.
bool ControlServer::HandleHttpRequest(UniqueSocket clientSocket, const std::string& request)
{
    HttpRequest httpRequest;
    if (!ParseHttpRequest(request, httpRequest))
    {
        SendHttpResponse(clientSocket.get(), "400 Bad Request", "text/plain", "Bad Request");
        return true;
    }

    const bool isWebSocketUpgrade = ToLower(GetHeaderValue(httpRequest, "Upgrade")) == "websocket";
    if ((isWebSocketUpgrade || httpRequest.method == "POST") && !IsTrustedOrigin(GetHeaderValue(httpRequest, "Origin")))
    {
        SendHttpResponse(clientSocket.get(), "403 Forbidden", "text/plain", "Forbidden");
        return true;
    }

    if (isWebSocketUpgrade)
        return HandleWebSocketUpgrade(std::move(clientSocket), httpRequest);

    const HttpResponse response = RouteHttpRequest(httpRequest);
    SendHttpResponse(clientSocket.get(), response);
    if (response.broadcastState)
        BroadcastState();
    return true;
}

ControlServer::HttpResponse ControlServer::RouteHttpRequest(const HttpRequest& request)
{
    if (request.method == "GET")
        return ServeGetRequest(request);
    if (request.method == "POST")
        return HandleApiPost(request);
    return { "404 Not Found", "text/plain", "Not Found" };
}

// GET routing: UI index, state JSON, OpenAPI spec, Swagger page, docs files, and any
// other path as a UI asset.
ControlServer::HttpResponse ControlServer::ServeGetRequest(const HttpRequest& request) const
{
    if (request.path == "/" || request.path == "/index.html")
        return ServeUiAsset("index.html");
    if (request.path == "/api/state")
        return { "200 OK", "application/json", mCallbacks.getStateJson ? mCallbacks.getStateJson() : "{}" };
    if (request.path == "/openapi.yaml" || request.path == "/docs/openapi.yaml")
        return ServeOpenApiSpec();
    if (request.path == "/docs" || request.path == "/docs/")
        return ServeSwaggerDocs();

    const std::string docsPrefix = "/docs/";
    if (request.path.rfind(docsPrefix, 0) == 0)
        return ServeDocsAsset(request.path.substr(docsPrefix.size()));

    if (request.path.size() > 1)
    {
        const HttpResponse assetResponse = ServeUiAsset(request.path.substr(1));
        if (!assetResponse.body.empty())
            return assetResponse;
    }
    return { "404 Not Found", "text/plain", "Not Found" };
}

ControlServer::HttpResponse ControlServer::ServeUiAsset(const std::string& relativePath) const
{
    std::string contentType;
    const std::string body = LoadUiAsset(relativePath, contentType);
    return body.empty()
        ? HttpResponse{ "404 Not Found", "text/plain", "Not Found" }
        : HttpResponse{ "200 OK", contentType, body };
}

// Serves a file below mDocsRoot; unsafe, missing or empty files give 404.
ControlServer::HttpResponse ControlServer::ServeDocsAsset(const std::string& relativePath) const
{
    std::filesystem::path docsPath;
    if (!ResolveSafePath(mDocsRoot, relativePath, docsPath))
        return { "404 Not Found", "text/plain", "Not Found" };

    const std::string body = LoadTextFile(docsPath);
    return body.empty()
        ? HttpResponse{ "404 Not Found", "text/plain", "Not Found" }
        : HttpResponse{ "200 OK", GuessContentType(docsPath), body };
}

ControlServer::HttpResponse ControlServer::ServeOpenApiSpec() const
{
    const std::filesystem::path specPath = mDocsRoot / "openapi.yaml";
    const std::string body = LoadTextFile(specPath);
    return body.empty()
        ? HttpResponse{ "404 Not Found", "text/plain", "OpenAPI spec not found" }
        : HttpResponse{ "200 OK", GuessContentType(specPath), body };
}

// Swagger UI page that renders the OpenAPI spec served at /docs/openapi.yaml.
// NOTE(review): assets come from the public swagger-ui-dist CDN (major version 5);
// confirm this matches the intended version pinning.
ControlServer::HttpResponse ControlServer::ServeSwaggerDocs() const
{
    static const char kHtml[] =
        "<!doctype html>\n"
        "<html lang=\"en\">\n"
        "<head>\n"
        "  <meta charset=\"utf-8\" />\n"
        "  <meta name=\"viewport\" content=\"width=device-width, initial-scale=1\" />\n"
        "  <title>Video Shader Toys API Docs</title>\n"
        "  <link rel=\"stylesheet\" href=\"https://unpkg.com/swagger-ui-dist@5/swagger-ui.css\" />\n"
        "</head>\n"
        "<body>\n"
        "  <div id=\"swagger-ui\"></div>\n"
        "  <script src=\"https://unpkg.com/swagger-ui-dist@5/swagger-ui-bundle.js\"></script>\n"
        "  <script>window.onload = function () { window.ui = SwaggerUIBundle({ url: '/docs/openapi.yaml', dom_id: '#swagger-ui' }); };</script>\n"
        "</body>\n"
        "</html>\n";
    return { "200 OK", "text/html", kHtml };
}

// Parses the JSON body and runs the matching API action. The response is
// {"ok":bool[,"error":string]}; successful actions also trigger a state broadcast.
ControlServer::HttpResponse ControlServer::HandleApiPost(const HttpRequest& request)
{
    JsonValue root;
    std::string parseError;
    if (!ParseJson(request.body, root, parseError))
        return { "400 Bad Request", "application/json", BuildJsonResponse(false, parseError) };

    std::string actionError;
    const bool success = InvokePostRoute(request.path, root, actionError);
    return { success ? "200 OK" : "400 Bad Request", "application/json", BuildJsonResponse(success, actionError), success };
}

// Dispatches a POST path to its callback. Returns false (with `actionError` possibly
// set) when the route is unknown, a required field is missing, the callback is not
// installed, or the callback itself fails.
bool ControlServer::InvokePostRoute(const std::string& path, const JsonValue& root, std::string& actionError)
{
    using PostHandler = std::function<bool(const JsonValue&, std::string&)>;
    const std::map<std::string, PostHandler> postRoutes = {
        { "/api/layers/add", [this](const JsonValue& json, std::string& error) {
            const JsonValue* shaderId = json.find("shaderId");
            return shaderId && mCallbacks.addLayer && mCallbacks.addLayer(shaderId->asString(), error);
        } },
        { "/api/layers/remove", [this](const JsonValue& json, std::string& error) {
            const JsonValue* layerId = json.find("layerId");
            return layerId && mCallbacks.removeLayer && mCallbacks.removeLayer(layerId->asString(), error);
        } },
        { "/api/layers/move", [this](const JsonValue& json, std::string& error) {
            const JsonValue* layerId = json.find("layerId");
            const JsonValue* direction = json.find("direction");
            return layerId && direction && mCallbacks.moveLayer
                && mCallbacks.moveLayer(layerId->asString(), static_cast<int>(direction->asNumber()), error);
        } },
        { "/api/layers/reorder", [this](const JsonValue& json, std::string& error) {
            const JsonValue* layerId = json.find("layerId");
            const JsonValue* targetIndex = json.find("targetIndex");
            return layerId && targetIndex && mCallbacks.moveLayerToIndex
                && mCallbacks.moveLayerToIndex(layerId->asString(), static_cast<int>(targetIndex->asNumber()), error);
        } },
        { "/api/layers/set-bypass", [this](const JsonValue& json, std::string& error) {
            const JsonValue* layerId = json.find("layerId");
            const JsonValue* bypass = json.find("bypass");
            return layerId && bypass && mCallbacks.setLayerBypass
                && mCallbacks.setLayerBypass(layerId->asString(), bypass->asBoolean(), error);
        } },
        { "/api/layers/set-shader", [this](const JsonValue& json, std::string& error) {
            const JsonValue* layerId = json.find("layerId");
            const JsonValue* shaderId = json.find("shaderId");
            return layerId && shaderId && mCallbacks.setLayerShader
                && mCallbacks.setLayerShader(layerId->asString(), shaderId->asString(), error);
        } },
        { "/api/layers/update-parameter", [this](const JsonValue& json, std::string& error) {
            const JsonValue* layerId = json.find("layerId");
            const JsonValue* parameterId = json.find("parameterId");
            const JsonValue* value = json.find("value");
            return layerId && parameterId && value && mCallbacks.updateLayerParameter
                && mCallbacks.updateLayerParameter(layerId->asString(), parameterId->asString(), SerializeJson(*value, false), error);
        } },
        { "/api/layers/reset-parameters", [this](const JsonValue& json, std::string& error) {
            const JsonValue* layerId = json.find("layerId");
            return layerId && mCallbacks.resetLayerParameters && mCallbacks.resetLayerParameters(layerId->asString(), error);
        } },
        { "/api/stack-presets/save", [this](const JsonValue& json, std::string& error) {
            const JsonValue* presetName = json.find("presetName");
            return presetName && mCallbacks.saveStackPreset && mCallbacks.saveStackPreset(presetName->asString(), error);
        } },
        { "/api/stack-presets/load", [this](const JsonValue& json, std::string& error) {
            const JsonValue* presetName = json.find("presetName");
            return presetName && mCallbacks.loadStackPreset && mCallbacks.loadStackPreset(presetName->asString(), error);
        } },
        { "/api/reload", [this](const JsonValue&, std::string& error) {
            return mCallbacks.reloadShader && mCallbacks.reloadShader(error);
        } },
        { "/api/screenshot", [this](const JsonValue&, std::string& error) {
            return mCallbacks.requestScreenshot && mCallbacks.requestScreenshot(error);
        } }
    };

    const auto route = postRoutes.find(path);
    if (route == postRoutes.end())
    {
        actionError = "Unknown API route.";
        return false;
    }
    return route->second(root, actionError);
}

// Completes the RFC 6455 handshake, registers the socket as a WebSocket client and
// sends it the current state. Always returns true (the request was handled).
bool ControlServer::HandleWebSocketUpgrade(UniqueSocket clientSocket, const HttpRequest& request)
{
    const std::string clientKey = GetHeaderValue(request, "Sec-WebSocket-Key");
    if (clientKey.empty())
    {
        SendHttpResponse(clientSocket.get(), "400 Bad Request", "text/plain", "Missing Sec-WebSocket-Key");
        return true;
    }

    const std::string acceptKey = ComputeWebSocketAcceptKey(clientKey);
    if (acceptKey.empty())
    {
        SendHttpResponse(clientSocket.get(), "500 Internal Server Error", "text/plain", "Could not compute Sec-WebSocket-Accept");
        return true;
    }

    std::ostringstream response;
    response << "HTTP/1.1 101 Switching Protocols\r\n";
    response << "Upgrade: websocket\r\n";
    response << "Connection: Upgrade\r\n";
    response << "Sec-WebSocket-Accept: " << acceptKey << "\r\n\r\n";
    const std::string payload = response.str();
    if (!SendAll(clientSocket.get(), payload.data(), payload.size()))
        return true;  // handshake never reached the client: drop the connection

    {
        std::lock_guard lock(mMutex);
        ClientConnection client;
        client.socket.reset(clientSocket.release());
        client.websocket = true;
        mClients.push_back(std::move(client));
        mBroadcastPending = false;
        BroadcastStateLocked();
    }
    return true;
}

// Sends `payload` as a single unmasked, final text frame (opcode 0x1).
bool ControlServer::SendWebSocketText(SOCKET clientSocket, const std::string& payload)
{
    std::string frame;
    frame.reserve(payload.size() + 10);
    frame.push_back(static_cast<char>(0x81));

    // Widen before shifting: shifting a 32-bit size_t by >= 32 bits is undefined.
    const std::uint64_t length = static_cast<std::uint64_t>(payload.size());
    if (length <= 125)
    {
        frame.push_back(static_cast<char>(length));
    }
    else if (length <= 65535)
    {
        frame.push_back(static_cast<char>(126));
        frame.push_back(static_cast<char>((length >> 8) & 0xFF));
        frame.push_back(static_cast<char>(length & 0xFF));
    }
    else
    {
        frame.push_back(static_cast<char>(127));
        for (int shift = 56; shift >= 0; shift -= 8)
            frame.push_back(static_cast<char>((length >> shift) & 0xFF));
    }

    frame.append(payload);
    return SendAll(clientSocket, frame.data(), frame.size());
}

// Sends the current state to every client and drops clients whose send fails.
// Caller must hold mMutex.
void ControlServer::BroadcastStateLocked()
{
    if (mClients.empty())
        return;

    const std::string stateMessage = mCallbacks.getStateJson ? mCallbacks.getStateJson() : "{}";
    for (auto it = mClients.begin(); it != mClients.end();)
    {
        if (!SendWebSocketText(it->socket.get(), stateMessage))
            it = mClients.erase(it);
        else
            ++it;
    }
}

// Loads a file below mUiRoot; returns "" (and leaves contentType untouched) for
// unsafe or unreadable paths.
std::string ControlServer::LoadUiAsset(const std::string& relativePath, std::string& contentType) const
{
    std::filesystem::path assetPath;
    if (!ResolveSafePath(mUiRoot, relativePath, assetPath))
        return std::string();

    contentType = GuessContentType(assetPath);
    return LoadTextFile(assetPath);
}

// Reads a whole file in binary mode; returns "" if it cannot be opened or is empty.
std::string ControlServer::LoadTextFile(const std::filesystem::path& path) const
{
    std::ifstream input(path, std::ios::binary);
    if (!input)
        return std::string();

    std::ostringstream buffer;
    buffer << input.rdbuf();
    return buffer.str();
}

// Builds {"ok":success} plus "error" when `error` is non-empty.
std::string ControlServer::BuildJsonResponse(bool success, const std::string& error) const
{
    JsonValue response = JsonValue::MakeObject();
    response.set("ok", JsonValue(success));
    if (!error.empty())
        response.set("error", JsonValue(error));
    return SerializeJson(response, false);
}

// Base64-encodes `data` without line breaks; returns "" on failure.
std::string ControlServer::Base64Encode(const unsigned char* data, DWORD dataLength)
{
    const DWORD flags = CRYPT_STRING_BASE64 | CRYPT_STRING_NOCRLF;

    // First call reports the buffer size including the terminating NUL.
    DWORD outputLength = 0;
    if (!CryptBinaryToStringA(data, dataLength, flags, nullptr, &outputLength) || outputLength == 0)
        return std::string();

    std::string encoded(outputLength, '\0');
    if (!CryptBinaryToStringA(data, dataLength, flags, &encoded[0], &outputLength))
        return std::string();

    // Second call reports the characters written, excluding the NUL.
    encoded.resize(outputLength);
    return encoded;
}

// Sec-WebSocket-Accept = base64(SHA-1(clientKey + RFC 6455 GUID)); "" if CryptoAPI fails.
std::string ControlServer::ComputeWebSocketAcceptKey(const std::string& clientKey)
{
    const std::string combined = clientKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
    HCRYPTPROV provider = 0;
    HCRYPTHASH hash = 0;
    BYTE digest[20] = {};
    DWORD digestLength = sizeof(digest);

    bool ok = CryptAcquireContext(&provider, nullptr, nullptr, PROV_RSA_FULL, CRYPT_VERIFYCONTEXT) != FALSE;
    ok = ok && CryptCreateHash(provider, CALG_SHA1, 0, 0, &hash) != FALSE;
    ok = ok && CryptHashData(hash, reinterpret_cast<const BYTE*>(combined.data()), static_cast<DWORD>(combined.size()), 0) != FALSE;
    ok = ok && CryptGetHashParam(hash, HP_HASHVAL, digest, &digestLength, 0) != FALSE;

    if (hash)
        CryptDestroyHash(hash);
    if (provider)
        CryptReleaseContext(provider, 0);

    return ok ? Base64Encode(digest, digestLength) : std::string();
}

// Case-insensitive header lookup (ParseHttpRequest stores lower-cased names).
std::string ControlServer::GetHeaderValue(const HttpRequest& request, const std::string& headerName)
{
    const auto header = request.headers.find(ToLower(headerName));
    return header == request.headers.end() ? std::string() : header->second;
}

// Splits a raw request into method, path (query string and fragment removed),
// lower-cased headers with trimmed values, and body. Returns false when the request
// line is malformed or the method/path is empty.
bool ControlServer::ParseHttpRequest(const std::string& rawRequest, HttpRequest& request)
{
    const std::size_t requestLineEnd = rawRequest.find("\r\n");
    if (requestLineEnd == std::string::npos)
        return false;

    const std::string requestLine = rawRequest.substr(0, requestLineEnd);
    const std::size_t methodEnd = requestLine.find(' ');
    if (methodEnd == std::string::npos)
        return false;
    const std::size_t pathEnd = requestLine.find(' ', methodEnd + 1);
    if (pathEnd == std::string::npos)
        return false;

    request.method = requestLine.substr(0, methodEnd);
    request.path = requestLine.substr(methodEnd + 1, pathEnd - methodEnd - 1);

    // Route on the path only, so "/index.html?v=2" still resolves.
    const std::size_t queryStart = request.path.find_first_of("?#");
    if (queryStart != std::string::npos)
        request.path.erase(queryStart);

    request.headers.clear();
    const std::size_t headersStart = requestLineEnd + 2;
    const std::size_t bodySeparator = rawRequest.find("\r\n\r\n", headersStart);
    const std::size_t headersEnd = bodySeparator == std::string::npos ? rawRequest.size() : bodySeparator;

    for (std::size_t lineStart = headersStart; lineStart < headersEnd;)
    {
        const std::size_t lineEnd = rawRequest.find("\r\n", lineStart);
        const std::size_t currentLineEnd = lineEnd == std::string::npos ? headersEnd : (std::min)(lineEnd, headersEnd);
        const std::string line = rawRequest.substr(lineStart, currentLineEnd - lineStart);

        const std::size_t separator = line.find(':');
        if (separator != std::string::npos)
        {
            const std::string key = ToLower(line.substr(0, separator));
            const std::string value = line.substr(separator + 1);
            const std::size_t first = value.find_first_not_of(" \t");
            const std::size_t last = value.find_last_not_of(" \t");
            request.headers[key] = first == std::string::npos ? std::string() : value.substr(first, last - first + 1);
        }

        if (lineEnd == std::string::npos || lineEnd >= headersEnd)
            break;
        lineStart = lineEnd + 2;
    }

    request.body = bodySeparator == std::string::npos ? std::string() : rawRequest.substr(bodySeparator + 4);
    return !request.method.empty() && !request.path.empty();
}