Files
video-shader-toys/apps/LoopThroughWithOpenGLCompositing/control/ControlServer.cpp
Aiden 05d0bcbedd
All checks were successful
CI / React UI Build (push) Successful in 11s
CI / Native Windows Build And Tests (push) Successful in 1m35s
CI / Windows Release Package (push) Successful in 2m17s
PNG writer
2026-05-08 15:33:40 +10:00

633 lines
19 KiB
C++

#include "stdafx.h"
#include "ControlServer.h"
#include "RuntimeJson.h"
#include <Wincrypt.h>
#include <ws2tcpip.h>
#include <algorithm>
#include <fstream>
#include <limits>
#include <sstream>
#pragma comment(lib, "Ws2_32.lib")
#pragma comment(lib, "Crypt32.lib")
#pragma comment(lib, "Advapi32.lib")
namespace
{
// Period, in milliseconds, between pushes of the full state JSON to WebSocket clients.
constexpr DWORD kStateBroadcastIntervalMs = 250;

// Starts Winsock 2.2 for this process. A successful call must later be balanced by WSACleanup().
// On failure, writes a message to `error` and returns false.
bool InitializeWinsock(std::string& error)
{
	WSADATA wsaData = {};
	const int result = WSAStartup(MAKEWORD(2, 2), &wsaData);
	if (result != 0)
	{
		error = "WSAStartup failed.";
		return false;
	}
	return true;
}

// Returns a lowercased copy of `text` (per std::tolower); used for case-insensitive matching of
// header names/values and file extensions.
std::string ToLower(std::string text)
{
	std::transform(text.begin(), text.end(), text.begin(),
		[](unsigned char ch) { return static_cast<char>(std::tolower(ch)); });
	return text;
}

// Returns true when `relativePath` (lexically normalised by the caller) can be appended to a
// content root without escaping it. The path must be non-empty, must not be rooted, and must not
// contain a ".." component.
// The rooted check matters because std::filesystem::path::operator/ discards the left-hand root
// when the right-hand side has one. Without it, "C:/x", "/x" or "//server/x" would resolve outside
// the served directory.
bool IsSafeUiPath(const std::filesystem::path& relativePath)
{
	if (relativePath.empty() || relativePath.has_root_path())
		return false;
	for (const std::filesystem::path& part : relativePath)
	{
		if (part == "..")
			return false;
	}
	return true;
}

// Maps a file extension (case-insensitive) to the Content-Type used when serving it.
// Unlisted extensions keep the historical "text/html" fallback.
std::string GuessContentType(const std::filesystem::path& assetPath)
{
	const std::string extension = ToLower(assetPath.extension().string());
	if (extension == ".html" || extension == ".htm")
		return "text/html";
	if (extension == ".js" || extension == ".mjs")
		return "text/javascript";
	if (extension == ".css")
		return "text/css";
	if (extension == ".json" || extension == ".map")
		return "application/json";
	if (extension == ".yaml" || extension == ".yml")
		return "application/yaml";
	if (extension == ".md")
		return "text/markdown";
	if (extension == ".txt")
		return "text/plain";
	if (extension == ".svg")
		return "image/svg+xml";
	if (extension == ".png")
		return "image/png";
	if (extension == ".jpg" || extension == ".jpeg")
		return "image/jpeg";
	if (extension == ".gif")
		return "image/gif";
	if (extension == ".webp")
		return "image/webp";
	if (extension == ".ico")
		return "image/x-icon";
	if (extension == ".woff")
		return "font/woff";
	if (extension == ".woff2")
		return "font/woff2";
	// WebAssembly.instantiateStreaming rejects responses that are not application/wasm.
	if (extension == ".wasm")
		return "application/wasm";
	return "text/html";
}
}
ControlServer::ControlServer()
: mPort(0), mRunning(false)
{
}
// Tears the server down on destruction. Stop() only joins or cleans up what Start() actually
// set up, so this is harmless for a server that never started.
ControlServer::~ControlServer()
{
	this->Stop();
}
bool ControlServer::Start(const std::filesystem::path& uiRoot, const std::filesystem::path& docsRoot, unsigned short preferredPort, const Callbacks& callbacks, std::string& error)
{
mUiRoot = uiRoot;
mDocsRoot = docsRoot;
mCallbacks = callbacks;
if (!InitializeWinsock(error))
return false;
mListenSocket.reset(socket(AF_INET, SOCK_STREAM, IPPROTO_TCP));
if (!mListenSocket.valid())
{
error = "Could not create listening socket.";
return false;
}
u_long nonBlocking = 1;
ioctlsocket(mListenSocket.get(), FIONBIO, &nonBlocking);
sockaddr_in address = {};
address.sin_family = AF_INET;
address.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
bool bound = false;
for (unsigned short offset = 0; offset < 20; ++offset)
{
address.sin_port = htons(static_cast<u_short>(preferredPort + offset));
if (bind(mListenSocket.get(), reinterpret_cast<sockaddr*>(&address), sizeof(address)) == 0)
{
mPort = preferredPort + offset;
bound = true;
break;
}
}
if (!bound)
{
error = "Could not bind the local control server to any port in the preferred range.";
mListenSocket.reset();
return false;
}
if (listen(mListenSocket.get(), SOMAXCONN) != 0)
{
error = "Could not start listening on the local control server socket.";
mListenSocket.reset();
return false;
}
mRunning = true;
mThread = std::thread(&ControlServer::ServerLoop, this);
return true;
}
// Stops the server thread, closes every client and the listening socket, and balances the
// WSAStartup made by Start(). Safe to call repeatedly or on a server that never started.
void ControlServer::Stop()
{
	// Start() reached WSAStartup if any of these is set; used to keep WSACleanup balanced.
	const bool wasActive = mRunning || mListenSocket.valid() || mThread.joinable();
	mRunning = false;

	// Join before touching the sockets. In this file ServerLoop (via TryAcceptClient) is the
	// only other user of mListenSocket. Once it has exited, the socket can be closed without a
	// data race, and no WebSocket client can be registered after the list below is cleared.
	// The loop re-checks mRunning after each short sleep, so the join is prompt.
	if (mThread.joinable())
		mThread.join();

	{
		std::lock_guard<std::mutex> lock(mMutex);
		for (ClientConnection& client : mClients)
			client.socket.reset();
		mClients.clear();
	}
	mListenSocket.reset();

	// Sockets are all closed above, so Winsock can be released last.
	if (wasActive)
		WSACleanup();
}
// Acquires mMutex and pushes the current state to every registered WebSocket client.
// Use BroadcastStateLocked() directly when mMutex is already held.
void ControlServer::BroadcastState()
{
	const std::lock_guard<std::mutex> guard(mMutex);
	BroadcastStateLocked();
}
// Background thread body. While mRunning is set, it serves pending HTTP/WebSocket connections,
// broadcasts state every kStateBroadcastIntervalMs, and sleeps 25 ms between iterations.
void ControlServer::ServerLoop()
{
	// Cap on connections served per iteration so a burst of requests cannot starve broadcasts.
	constexpr int kMaxAcceptsPerIteration = 16;
	DWORD lastStateBroadcastMs = GetTickCount();
	while (mRunning)
	{
		// Serve every pending connection (up to the cap) rather than one per sleep. Each HTTP
		// response closes its connection, so loading the UI opens one connection per asset.
		int handled = 0;
		while (handled < kMaxAcceptsPerIteration && mRunning && TryAcceptClient())
			++handled;

		// DWORD subtraction stays correct when GetTickCount() wraps around.
		const DWORD nowMs = GetTickCount();
		if (nowMs - lastStateBroadcastMs >= kStateBroadcastIntervalMs)
		{
			BroadcastState();
			lastStateBroadcastMs = nowMs;
		}
		Sleep(25);
	}
}
// Reads one complete HTTP request from a freshly accepted socket and hands it to
// HandleHttpRequest. "Complete" means the headers plus Content-Length bytes of body, or
// whatever arrived before the wait below expired.
// Returns false when nothing could be read (the socket is closed on return).
bool ControlServer::HandleHttpClient(UniqueSocket clientSocket)
{
	// Accepted sockets are waited on with select() before recv(). The listening socket is
	// non-blocking and accept() may hand back a socket in the same mode, so an immediate recv()
	// can fail simply because the request bytes have not arrived yet. Loopback clients deliver
	// within milliseconds, so a short wait keeps idle/pre-connected sockets from stalling ServerLoop.
	constexpr long kReadWaitMs = 200;
	// Upper bound on a buffered request (headers + body) so a client cannot exhaust memory.
	constexpr std::size_t kMaxRequestBytes = 1024 * 1024;

	const SOCKET socketHandle = clientSocket.get();
	std::string request;
	// Total bytes expected (headers + body); unknown (npos) until the header block is complete.
	std::size_t expectedSize = std::string::npos;
	char buffer[8192];
	while (request.size() < expectedSize)
	{
		fd_set readSet;
		FD_ZERO(&readSet);
		FD_SET(socketHandle, &readSet);
		timeval timeout = {};
		timeout.tv_sec = kReadWaitMs / 1000;
		timeout.tv_usec = (kReadWaitMs % 1000) * 1000;
		if (select(0, &readSet, nullptr, nullptr, &timeout) <= 0)
			break;

		const int received = recv(socketHandle, buffer, sizeof(buffer), 0);
		if (received <= 0)
			break;
		request.append(buffer, static_cast<std::size_t>(received));
		if (request.size() > kMaxRequestBytes)
		{
			SendHttpResponse(socketHandle, "413 Payload Too Large", "text/plain", "Payload Too Large");
			return true;
		}

		if (expectedSize != std::string::npos)
			continue;
		const std::size_t headersEnd = request.find("\r\n\r\n");
		if (headersEnd == std::string::npos)
			continue;

		// Headers are complete: work out how much body still has to arrive.
		std::size_t contentLength = 0;
		HttpRequest head;
		if (ParseHttpRequest(request.substr(0, headersEnd + 4), head))
		{
			const std::string lengthText = GetHeaderValue(head, "Content-Length");
			for (const char ch : lengthText)
			{
				if (ch < '0' || ch > '9' || contentLength > kMaxRequestBytes)
					break;
				contentLength = contentLength * 10 + static_cast<std::size_t>(ch - '0');
			}
		}
		if (contentLength > kMaxRequestBytes)
		{
			SendHttpResponse(socketHandle, "413 Payload Too Large", "text/plain", "Payload Too Large");
			return true;
		}
		expectedSize = headersEnd + 4 + contentLength;
	}

	if (request.empty())
		return false;
	return HandleHttpRequest(std::move(clientSocket), request);
}
bool ControlServer::TryAcceptClient()
{
sockaddr_in clientAddress = {};
int addressSize = sizeof(clientAddress);
UniqueSocket clientSocket(accept(mListenSocket.get(), reinterpret_cast<sockaddr*>(&clientAddress), &addressSize));
if (!clientSocket.valid())
return false;
return HandleHttpClient(std::move(clientSocket));
}
// Writes a complete HTTP/1.1 response with Content-Type and Content-Length, and
// "Connection: close" (the caller closes the socket afterwards).
// Returns true only if every byte of the response was written.
bool ControlServer::SendHttpResponse(SOCKET clientSocket, const std::string& status, const std::string& contentType, const std::string& body)
{
	std::ostringstream response;
	response << "HTTP/1.1 " << status << "\r\n";
	response << "Content-Type: " << contentType << "\r\n";
	response << "Content-Length: " << body.size() << "\r\n";
	response << "Connection: close\r\n\r\n";
	response << body;
	const std::string payload = response.str();

	// A single send() may accept only part of a large asset, or fail with WSAEWOULDBLOCK on a
	// non-blocking socket whose buffer is full. Keep writing, waiting for writability in between.
	constexpr long kWriteWaitMs = 1000;
	constexpr std::size_t kMaxChunkBytes = 1 << 20;
	const char* cursor = payload.data();
	std::size_t remaining = payload.size();
	while (remaining > 0)
	{
		const int chunk = static_cast<int>(std::min<std::size_t>(remaining, kMaxChunkBytes));
		const int written = send(clientSocket, cursor, chunk, 0);
		if (written > 0)
		{
			cursor += written;
			remaining -= static_cast<std::size_t>(written);
			continue;
		}
		if (written == SOCKET_ERROR && WSAGetLastError() == WSAEWOULDBLOCK)
		{
			fd_set writeSet;
			FD_ZERO(&writeSet);
			FD_SET(clientSocket, &writeSet);
			timeval timeout = {};
			timeout.tv_sec = kWriteWaitMs / 1000;
			timeout.tv_usec = (kWriteWaitMs % 1000) * 1000;
			if (select(0, nullptr, &writeSet, nullptr, &timeout) > 0)
				continue;
		}
		return false;
	}
	return true;
}
// Overload that writes a routed HttpResponse, forwarding its status, content type and body
// to the string-based SendHttpResponse.
bool ControlServer::SendHttpResponse(SOCKET clientSocket, const HttpResponse& response)
{
	const auto& status = response.status;
	const auto& contentType = response.contentType;
	const auto& body = response.body;
	return SendHttpResponse(clientSocket, status, contentType, body);
}
// Parses the raw request and dispatches it:
// - malformed requests get 400;
// - "Upgrade: websocket" requests are handed to the WebSocket handshake;
// - everything else is routed and answered, then state is broadcast if the route asked for it.
// Returns false only when the WebSocket upgrade handler reports failure.
bool ControlServer::HandleHttpRequest(UniqueSocket clientSocket, const std::string& request)
{
	HttpRequest parsed;
	const bool wellFormed = ParseHttpRequest(request, parsed);
	if (!wellFormed)
	{
		SendHttpResponse(clientSocket.get(), "400 Bad Request", "text/plain", "Bad Request");
		return true;
	}

	const bool wantsWebSocket = ToLower(GetHeaderValue(parsed, "Upgrade")) == "websocket";
	if (wantsWebSocket)
		return HandleWebSocketUpgrade(std::move(clientSocket), parsed);

	const HttpResponse reply = RouteHttpRequest(parsed);
	SendHttpResponse(clientSocket.get(), reply);
	if (reply.broadcastState)
		BroadcastState();
	return true;
}
// Dispatches a parsed request by method: GET goes to static/docs/state handlers, POST to the
// JSON API. Any other method is answered with 404.
// Routing matches on the path only. A query string or fragment (e.g. cache-busting "?v=2") is
// stripped first, so "/api/state?x=1" or "/app.js?v=2" resolve like their bare paths.
ControlServer::HttpResponse ControlServer::RouteHttpRequest(const HttpRequest& request)
{
	const auto dispatch = [this](const HttpRequest& routed) -> HttpResponse
	{
		if (routed.method == "GET")
			return ServeGetRequest(routed);
		if (routed.method == "POST")
			return HandleApiPost(routed);
		return { "404 Not Found", "text/plain", "Not Found" };
	};

	const std::size_t suffix = request.path.find_first_of("?#");
	if (suffix == std::string::npos)
		return dispatch(request);

	// Copy only when the path actually needs trimming.
	HttpRequest routed = request;
	routed.path.erase(suffix);
	return dispatch(routed);
}
// Answers GET requests:
// - "/" and "/index.html" serve the UI entry page;
// - "/api/state" returns the state JSON from the getStateJson callback ("{}" if unset);
// - "/openapi.yaml" and "/docs/openapi.yaml" serve the OpenAPI spec;
// - "/docs" and "/docs/" serve the Swagger UI page;
// - other "/docs/..." paths serve files from the docs root;
// - anything else is looked up under the UI root.
ControlServer::HttpResponse ControlServer::ServeGetRequest(const HttpRequest& request) const
{
	const std::string& path = request.path;
	if (path == "/" || path == "/index.html")
		return ServeUiAsset("index.html");
	if (path == "/api/state")
		return { "200 OK", "application/json", mCallbacks.getStateJson ? mCallbacks.getStateJson() : "{}" };
	if (path == "/openapi.yaml" || path == "/docs/openapi.yaml")
		return ServeOpenApiSpec();
	if (path == "/docs" || path == "/docs/")
		return ServeSwaggerDocs();

	const std::string docsPrefix = "/docs/";
	if (path.rfind(docsPrefix, 0) == 0)
		return ServeDocsAsset(path.substr(docsPrefix.size()));

	// ServeUiAsset already returns the 404 response for missing or unsafe paths, so its
	// result is final.
	if (path.size() > 1)
		return ServeUiAsset(path.substr(1));
	return { "404 Not Found", "text/plain", "Not Found" };
}
// Serves a file from the UI root with a Content-Type guessed from its extension.
// Responds 404 when the path is unsafe, the file is missing, or the file is empty.
ControlServer::HttpResponse ControlServer::ServeUiAsset(const std::string& relativePath) const
{
	std::string mimeType;
	std::string content = LoadUiAsset(relativePath, mimeType);
	if (content.empty())
		return HttpResponse{ "404 Not Found", "text/plain", "Not Found" };
	return HttpResponse{ "200 OK", mimeType, std::move(content) };
}
// Serves a file from the docs root. The request path is lexically normalised and must pass
// IsSafeUiPath; unsafe, missing or empty files produce 404.
ControlServer::HttpResponse ControlServer::ServeDocsAsset(const std::string& relativePath) const
{
	const std::filesystem::path normalized = std::filesystem::path(relativePath).lexically_normal();
	if (IsSafeUiPath(normalized))
	{
		const std::filesystem::path target = mDocsRoot / normalized;
		std::string content = LoadTextFile(target);
		if (!content.empty())
			return HttpResponse{ "200 OK", GuessContentType(target), std::move(content) };
	}
	return HttpResponse{ "404 Not Found", "text/plain", "Not Found" };
}
// Serves "openapi.yaml" from the docs root, or a 404 with an explanatory body when it is
// missing or empty.
ControlServer::HttpResponse ControlServer::ServeOpenApiSpec() const
{
	const std::filesystem::path spec = mDocsRoot / "openapi.yaml";
	std::string yaml = LoadTextFile(spec);
	if (yaml.empty())
		return HttpResponse{ "404 Not Found", "text/plain", "OpenAPI spec not found" };
	return HttpResponse{ "200 OK", GuessContentType(spec), std::move(yaml) };
}
// Returns a static HTML page that loads Swagger UI (v5) from the unpkg CDN and points it at
// "/docs/openapi.yaml". The page therefore needs network access to render.
ControlServer::HttpResponse ControlServer::ServeSwaggerDocs() const
{
	static const char kSwaggerPage[] =
		"<!doctype html>\n"
		"<html lang=\"en\">\n"
		"<head>\n"
		" <meta charset=\"utf-8\">\n"
		" <meta name=\"viewport\" content=\"width=device-width, initial-scale=1\">\n"
		" <title>Video Shader Toys API Docs</title>\n"
		" <link rel=\"stylesheet\" href=\"https://unpkg.com/swagger-ui-dist@5/swagger-ui.css\">\n"
		"</head>\n"
		"<body>\n"
		" <div id=\"swagger-ui\"></div>\n"
		" <script src=\"https://unpkg.com/swagger-ui-dist@5/swagger-ui-bundle.js\"></script>\n"
		" <script>SwaggerUIBundle({url:'/docs/openapi.yaml',dom_id:'#swagger-ui'});</script>\n"
		"</body>\n"
		"</html>\n";
	return { "200 OK", "text/html", std::string(kSwaggerPage) };
}
// Handles a POST to the JSON API.
// The body is parsed as JSON, the action is invoked via InvokePostRoute, and the reply is
// {"ok": bool, "error"?: string}. Successful actions set broadcastState so clients see the change.
// An empty or whitespace-only body is treated as "{}". Parameterless actions such as
// /api/reload and /api/screenshot are often posted without a body, and a parse error there
// would be unhelpful.
ControlServer::HttpResponse ControlServer::HandleApiPost(const HttpRequest& request)
{
	const bool emptyBody = request.body.find_first_not_of(" \t\r\n") == std::string::npos;
	JsonValue root = emptyBody ? JsonValue::MakeObject() : JsonValue();
	if (!emptyBody)
	{
		std::string parseError;
		if (!ParseJson(request.body, root, parseError))
			return { "400 Bad Request", "application/json", BuildJsonResponse(false, parseError) };
	}

	std::string actionError;
	const bool success = InvokePostRoute(request.path, root, actionError);
	return {
		success ? "200 OK" : "400 Bad Request",
		"application/json",
		BuildJsonResponse(success, actionError),
		success
	};
}
bool ControlServer::InvokePostRoute(const std::string& path, const JsonValue& root, std::string& actionError)
{
using PostHandler = std::function<bool(const JsonValue&, std::string&)>;
const std::map<std::string, PostHandler> postRoutes =
{
{ "/api/layers/add", [this](const JsonValue& json, std::string& error)
{
const JsonValue* shaderId = json.find("shaderId");
return shaderId && mCallbacks.addLayer && mCallbacks.addLayer(shaderId->asString(), error);
}
},
{ "/api/layers/remove", [this](const JsonValue& json, std::string& error)
{
const JsonValue* layerId = json.find("layerId");
return layerId && mCallbacks.removeLayer && mCallbacks.removeLayer(layerId->asString(), error);
}
},
{ "/api/layers/move", [this](const JsonValue& json, std::string& error)
{
const JsonValue* layerId = json.find("layerId");
const JsonValue* direction = json.find("direction");
return layerId && direction && mCallbacks.moveLayer &&
mCallbacks.moveLayer(layerId->asString(), static_cast<int>(direction->asNumber()), error);
}
},
{ "/api/layers/reorder", [this](const JsonValue& json, std::string& error)
{
const JsonValue* layerId = json.find("layerId");
const JsonValue* targetIndex = json.find("targetIndex");
return layerId && targetIndex && mCallbacks.moveLayerToIndex &&
mCallbacks.moveLayerToIndex(layerId->asString(), static_cast<std::size_t>(targetIndex->asNumber()), error);
}
},
{ "/api/layers/set-bypass", [this](const JsonValue& json, std::string& error)
{
const JsonValue* layerId = json.find("layerId");
const JsonValue* bypass = json.find("bypass");
return layerId && bypass && mCallbacks.setLayerBypass &&
mCallbacks.setLayerBypass(layerId->asString(), bypass->asBoolean(), error);
}
},
{ "/api/layers/set-shader", [this](const JsonValue& json, std::string& error)
{
const JsonValue* layerId = json.find("layerId");
const JsonValue* shaderId = json.find("shaderId");
return layerId && shaderId && mCallbacks.setLayerShader &&
mCallbacks.setLayerShader(layerId->asString(), shaderId->asString(), error);
}
},
{ "/api/layers/update-parameter", [this](const JsonValue& json, std::string& error)
{
const JsonValue* layerId = json.find("layerId");
const JsonValue* parameterId = json.find("parameterId");
const JsonValue* value = json.find("value");
return layerId && parameterId && value && mCallbacks.updateLayerParameter &&
mCallbacks.updateLayerParameter(layerId->asString(), parameterId->asString(), SerializeJson(*value, false), error);
}
},
{ "/api/layers/reset-parameters", [this](const JsonValue& json, std::string& error)
{
const JsonValue* layerId = json.find("layerId");
return layerId && mCallbacks.resetLayerParameters &&
mCallbacks.resetLayerParameters(layerId->asString(), error);
}
},
{ "/api/stack-presets/save", [this](const JsonValue& json, std::string& error)
{
const JsonValue* presetName = json.find("presetName");
return presetName && mCallbacks.saveStackPreset &&
mCallbacks.saveStackPreset(presetName->asString(), error);
}
},
{ "/api/stack-presets/load", [this](const JsonValue& json, std::string& error)
{
const JsonValue* presetName = json.find("presetName");
return presetName && mCallbacks.loadStackPreset &&
mCallbacks.loadStackPreset(presetName->asString(), error);
}
},
{ "/api/reload", [this](const JsonValue&, std::string& error)
{
return mCallbacks.reloadShader && mCallbacks.reloadShader(error);
}
},
{ "/api/screenshot", [this](const JsonValue&, std::string& error)
{
return mCallbacks.requestScreenshot && mCallbacks.requestScreenshot(error);
}
}
};
const auto route = postRoutes.find(path);
return route != postRoutes.end() && route->second(root, actionError);
}
// Completes an RFC 6455 handshake and registers the socket as a state-broadcast client.
// The new client is sent a state snapshot immediately.
// Answers 400 when Sec-WebSocket-Key is missing, and 500 when the accept key cannot be computed.
// Returns false if the 101 response could not be fully written; the client is then not registered.
bool ControlServer::HandleWebSocketUpgrade(UniqueSocket clientSocket, const HttpRequest& request)
{
	const std::string clientKey = GetHeaderValue(request, "Sec-WebSocket-Key");
	if (clientKey.empty())
	{
		SendHttpResponse(clientSocket.get(), "400 Bad Request", "text/plain", "Missing Sec-WebSocket-Key");
		return true;
	}

	// The accept key may come back empty when the CryptoAPI encoding fails. The browser would
	// reject a 101 with an empty Sec-WebSocket-Accept anyway, so report the failure explicitly.
	const std::string acceptKey = ComputeWebSocketAcceptKey(clientKey);
	if (acceptKey.empty())
	{
		SendHttpResponse(clientSocket.get(), "500 Internal Server Error", "text/plain", "WebSocket handshake failed");
		return true;
	}

	std::ostringstream response;
	response << "HTTP/1.1 101 Switching Protocols\r\n";
	response << "Upgrade: websocket\r\n";
	response << "Connection: Upgrade\r\n";
	response << "Sec-WebSocket-Accept: " << acceptKey << "\r\n\r\n";
	const std::string payload = response.str();
	// Only register the client if the whole handshake went out. A client that never received
	// the 101 response cannot interpret the frames broadcast to it.
	if (send(clientSocket.get(), payload.c_str(), static_cast<int>(payload.size()), 0) != static_cast<int>(payload.size()))
		return false;

	{
		std::lock_guard<std::mutex> lock(mMutex);
		ClientConnection client;
		client.socket.reset(clientSocket.release());
		client.websocket = true;
		mClients.push_back(std::move(client));
		// Send the current state now rather than waiting for the next periodic broadcast.
		BroadcastStateLocked();
	}
	return true;
}
// Sends `payload` as one unmasked WebSocket text frame (FIN + opcode 0x1). Servers must not
// mask, per RFC 6455 section 5.1. Returns true only if the entire frame was written.
// On false the stream may hold a partial frame, so the caller should drop the client.
bool ControlServer::SendWebSocketText(SOCKET clientSocket, const std::string& payload)
{
	std::string frame;
	frame.reserve(payload.size() + 10);
	frame.push_back(static_cast<char>(0x81));

	// Widen before shifting. With a 32-bit size_t, shifting by 32 or more bits would be undefined.
	const unsigned long long length = payload.size();
	if (length <= 125)
	{
		frame.push_back(static_cast<char>(length));
	}
	else if (length <= 65535)
	{
		// 16-bit extended payload length, network byte order.
		frame.push_back(static_cast<char>(126));
		frame.push_back(static_cast<char>((length >> 8) & 0xFF));
		frame.push_back(static_cast<char>(length & 0xFF));
	}
	else
	{
		// 64-bit extended payload length, network byte order.
		frame.push_back(static_cast<char>(127));
		for (int shift = 56; shift >= 0; shift -= 8)
			frame.push_back(static_cast<char>((length >> shift) & 0xFF));
	}
	frame.append(payload);

	// send() may write only part of a large frame, or report WSAEWOULDBLOCK on a non-blocking
	// socket. Keep writing with a short writability wait; BroadcastStateLocked calls this while
	// holding mMutex, so the wait must stay brief.
	constexpr long kWriteWaitMs = 250;
	const char* cursor = frame.data();
	std::size_t remaining = frame.size();
	while (remaining > 0)
	{
		const int written = send(clientSocket, cursor, static_cast<int>(remaining), 0);
		if (written > 0)
		{
			cursor += written;
			remaining -= static_cast<std::size_t>(written);
			continue;
		}
		if (written == SOCKET_ERROR && WSAGetLastError() == WSAEWOULDBLOCK)
		{
			fd_set writeSet;
			FD_ZERO(&writeSet);
			FD_SET(clientSocket, &writeSet);
			timeval timeout = {};
			timeout.tv_sec = kWriteWaitMs / 1000;
			timeout.tv_usec = (kWriteWaitMs % 1000) * 1000;
			if (select(0, nullptr, &writeSet, nullptr, &timeout) > 0)
				continue;
		}
		return false;
	}
	return true;
}
// Sends the current state JSON (from getStateJson, or "{}" if unset) to every registered
// client, and drops any client whose frame could not be fully written.
// The caller must hold mMutex.
void ControlServer::BroadcastStateLocked()
{
	// ServerLoop triggers this every kStateBroadcastIntervalMs via BroadcastState. With no
	// subscribers there is nobody to send to, so skip building the state snapshot.
	if (mClients.empty())
		return;

	const std::string stateMessage = mCallbacks.getStateJson ? mCallbacks.getStateJson() : "{}";
	for (auto it = mClients.begin(); it != mClients.end();)
	{
		if (SendWebSocketText(it->socket.get(), stateMessage))
			++it;
		else
			it = mClients.erase(it); // destroys the connection's UniqueSocket
	}
}
// Loads a file from the UI root after lexical normalisation and the IsSafeUiPath check.
// `contentType` is set only when the path is accepted. Returns an empty string for rejected
// paths or unreadable files.
std::string ControlServer::LoadUiAsset(const std::string& relativePath, std::string& contentType) const
{
	std::filesystem::path candidate(relativePath);
	candidate = candidate.lexically_normal();
	if (IsSafeUiPath(candidate))
	{
		const std::filesystem::path onDisk = mUiRoot / candidate;
		contentType = GuessContentType(onDisk);
		return LoadTextFile(onDisk);
	}
	return std::string();
}
// Reads the whole file in binary mode, so image assets are served byte-exact despite the name.
// Returns an empty string when the file cannot be opened (or is empty).
std::string ControlServer::LoadTextFile(const std::filesystem::path& path) const
{
	std::ifstream stream(path, std::ios::binary);
	std::ostringstream contents;
	if (stream)
		contents << stream.rdbuf();
	return contents.str();
}
std::string ControlServer::BuildJsonResponse(bool success, const std::string& error) const
{
JsonValue response = JsonValue::MakeObject();
response.set("ok", JsonValue(success));
if (!error.empty())
response.set("error", JsonValue(error));
return SerializeJson(response, false);
}
std::string ControlServer::Base64Encode(const unsigned char* data, DWORD dataLength)
{
DWORD outputLength = 0;
CryptBinaryToStringA(data, dataLength, CRYPT_STRING_BASE64 | CRYPT_STRING_NOCRLF, NULL, &outputLength);
std::string encoded(outputLength, '\0');
CryptBinaryToStringA(data, dataLength, CRYPT_STRING_BASE64 | CRYPT_STRING_NOCRLF, &encoded[0], &outputLength);
if (!encoded.empty() && encoded.back() == '\0')
encoded.pop_back();
return encoded;
}
// Computes Sec-WebSocket-Accept as base64(SHA-1(clientKey + RFC 6455 GUID)).
// Returns an empty string if any CryptoAPI step fails. Encoding an unfilled (zero) digest
// would produce a plausible-looking but wrong key.
std::string ControlServer::ComputeWebSocketAcceptKey(const std::string& clientKey)
{
	const std::string combined = clientKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
	HCRYPTPROV provider = 0;
	HCRYPTHASH hash = 0;
	BYTE digest[20] = {};
	DWORD digestLength = sizeof(digest);

	// Short-circuit so each step runs only if the previous one succeeded.
	const bool hashed =
		CryptAcquireContext(&provider, nullptr, nullptr, PROV_RSA_FULL, CRYPT_VERIFYCONTEXT) &&
		CryptCreateHash(provider, CALG_SHA1, 0, 0, &hash) &&
		CryptHashData(hash, reinterpret_cast<const BYTE*>(combined.data()), static_cast<DWORD>(combined.size()), 0) &&
		CryptGetHashParam(hash, HP_HASHVAL, digest, &digestLength, 0);

	// Release whatever was acquired, whether or not hashing succeeded.
	if (hash)
		CryptDestroyHash(hash);
	if (provider)
		CryptReleaseContext(provider, 0);

	return hashed ? Base64Encode(digest, digestLength) : std::string();
}
// Case-insensitive header lookup: header names are stored lowercased by ParseHttpRequest.
// Returns an empty string when the header is absent.
std::string ControlServer::GetHeaderValue(const HttpRequest& request, const std::string& headerName)
{
	const std::string key = ToLower(headerName);
	const auto found = request.headers.find(key);
	if (found != request.headers.end())
		return found->second;
	return std::string();
}
// Parses a raw HTTP/1.x request into `request`:
// - method and path from the request line;
// - headers (names lowercased, values trimmed of spaces and tabs);
// - body (everything after the first blank line).
// Returns false if the request line is incomplete or the method or path is empty.
bool ControlServer::ParseHttpRequest(const std::string& rawRequest, HttpRequest& request)
{
	// Request line: METHOD SP PATH SP VERSION CRLF
	const std::size_t requestLineEnd = rawRequest.find("\r\n");
	if (requestLineEnd == std::string::npos)
		return false;
	const std::string requestLine = rawRequest.substr(0, requestLineEnd);
	const std::size_t methodEnd = requestLine.find(' ');
	if (methodEnd == std::string::npos)
		return false;
	const std::size_t pathEnd = requestLine.find(' ', methodEnd + 1);
	if (pathEnd == std::string::npos)
		return false;
	request.method = requestLine.substr(0, methodEnd);
	request.path = requestLine.substr(methodEnd + 1, pathEnd - methodEnd - 1);
	request.headers.clear();

	const std::size_t headersStart = requestLineEnd + 2;
	// Search for the blank line starting at the request line's own CRLF. When there are no
	// header lines, "\r\n\r\n" begins exactly at requestLineEnd; searching from headersStart would
	// miss it, losing the body and misreading it as headers. In that case headersEnd <
	// headersStart and the header loop below does not run.
	const std::size_t bodySeparator = rawRequest.find("\r\n\r\n", requestLineEnd);
	const std::size_t headersEnd = bodySeparator == std::string::npos ? rawRequest.size() : bodySeparator;

	// Header lines look like "Name: value". Lines without ':' are ignored, and a repeated
	// header keeps its last value.
	for (std::size_t lineStart = headersStart; lineStart < headersEnd;)
	{
		const std::size_t lineEnd = rawRequest.find("\r\n", lineStart);
		const std::size_t currentLineEnd = lineEnd == std::string::npos ? headersEnd : std::min(lineEnd, headersEnd);
		const std::string line = rawRequest.substr(lineStart, currentLineEnd - lineStart);
		const std::size_t separator = line.find(':');
		if (separator != std::string::npos)
		{
			const std::string key = ToLower(line.substr(0, separator));
			const std::string value = line.substr(separator + 1);
			const std::size_t first = value.find_first_not_of(" \t");
			const std::size_t last = value.find_last_not_of(" \t");
			request.headers[key] = first == std::string::npos ? std::string() : value.substr(first, last - first + 1);
		}
		if (lineEnd == std::string::npos || lineEnd >= headersEnd)
			break;
		lineStart = lineEnd + 2;
	}

	request.body = bodySeparator == std::string::npos ? std::string() : rawRequest.substr(bodySeparator + 4);
	return !request.method.empty() && !request.path.empty();
}