diff --git a/.gitignore b/.gitignore index 7ef9afa..76df258 100644 --- a/.gitignore +++ b/.gitignore @@ -41,6 +41,8 @@ build.ninja *.log *.dmp *.tmp -/runtime/ +/runtime/* +!/runtime/templates/ +!/runtime/templates/** /ui/node_modules/ /ui/dist/ diff --git a/CMakeLists.txt b/CMakeLists.txt index 72dee02..cff965b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -26,6 +26,8 @@ set(APP_SOURCES "${APP_DIR}/LoopThroughWithOpenGLCompositing.cpp" "${APP_DIR}/LoopThroughWithOpenGLCompositing.h" "${APP_DIR}/LoopThroughWithOpenGLCompositing.rc" + "${APP_DIR}/NativeHandles.h" + "${APP_DIR}/NativeSockets.h" "${APP_DIR}/OpenGLComposite.cpp" "${APP_DIR}/OpenGLComposite.h" "${APP_DIR}/resource.h" @@ -72,6 +74,22 @@ if(MSVC) target_compile_options(LoopThroughWithOpenGLCompositing PRIVATE /W3) endif() +add_executable(RuntimeJsonTests + "${APP_DIR}/RuntimeJson.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/tests/RuntimeJsonTests.cpp" +) + +target_include_directories(RuntimeJsonTests PRIVATE + "${APP_DIR}" +) + +if(MSVC) + target_compile_options(RuntimeJsonTests PRIVATE /W3) +endif() + +enable_testing() +add_test(NAME RuntimeJsonTests COMMAND RuntimeJsonTests) + add_custom_command(TARGET LoopThroughWithOpenGLCompositing POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different "${GPUDIRECT_DIR}/bin/x64/dvp.dll" diff --git a/apps/LoopThroughWithOpenGLCompositing/ControlServer.cpp b/apps/LoopThroughWithOpenGLCompositing/ControlServer.cpp index c746e71..3b5dfe8 100644 --- a/apps/LoopThroughWithOpenGLCompositing/ControlServer.cpp +++ b/apps/LoopThroughWithOpenGLCompositing/ControlServer.cpp @@ -69,7 +69,7 @@ std::string GuessContentType(const std::filesystem::path& assetPath) } ControlServer::ControlServer() - : mListenSocket(INVALID_SOCKET), mPort(0), mRunning(false) + : mPort(0), mRunning(false) { } @@ -86,15 +86,15 @@ bool ControlServer::Start(const std::filesystem::path& uiRoot, unsigned short pr if (!InitializeWinsock(error)) return false; - mListenSocket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); - if (mListenSocket == INVALID_SOCKET) + mListenSocket.reset(socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)); + if (!mListenSocket.valid()) { error = "Could not create listening socket."; return false; } u_long nonBlocking = 1; - ioctlsocket(mListenSocket, FIONBIO, &nonBlocking); + ioctlsocket(mListenSocket.get(), FIONBIO, &nonBlocking); sockaddr_in address = {}; address.sin_family = AF_INET; @@ -104,7 +104,7 @@ bool ControlServer::Start(const std::filesystem::path& uiRoot, unsigned short pr for (unsigned short offset = 0; offset < 20; ++offset) { address.sin_port = htons(static_cast(preferredPort + offset)); - if (bind(mListenSocket, reinterpret_cast(&address), sizeof(address)) == 0) + if (bind(mListenSocket.get(), reinterpret_cast(&address), sizeof(address)) == 0) { mPort = preferredPort + offset; bound = true; @@ -115,16 +115,14 @@ bool ControlServer::Start(const std::filesystem::path& uiRoot, unsigned short pr if (!bound) { error = "Could not bind the local control server to any port in the preferred range."; - closesocket(mListenSocket); - mListenSocket = INVALID_SOCKET; + mListenSocket.reset(); return false; } - if (listen(mListenSocket, SOMAXCONN) != 0) + if (listen(mListenSocket.get(), SOMAXCONN) != 0) { error = "Could not start listening on the local control server socket."; - closesocket(mListenSocket); - mListenSocket = INVALID_SOCKET; + mListenSocket.reset(); return false; } @@ -135,27 +133,17 @@ bool ControlServer::Start(const std::filesystem::path& uiRoot, unsigned short pr void ControlServer::Stop() { - const bool wasActive = mRunning || mListenSocket != INVALID_SOCKET || mThread.joinable(); + const bool wasActive = mRunning || mListenSocket.valid() || mThread.joinable(); mRunning = false; { std::lock_guard lock(mMutex); for (ClientConnection& client : mClients) - { - if (client.socket != INVALID_SOCKET) - { - closesocket(client.socket); - client.socket = INVALID_SOCKET; - } - } + client.socket.reset(); mClients.clear(); } - if (mListenSocket != INVALID_SOCKET) - { - closesocket(mListenSocket); - mListenSocket = INVALID_SOCKET; - } + mListenSocket.reset(); if (mThread.joinable()) mThread.join(); @@ -179,30 +167,27 @@ void ControlServer::ServerLoop() } } -bool ControlServer::HandleHttpClient(SOCKET clientSocket) +bool ControlServer::HandleHttpClient(UniqueSocket clientSocket) { std::string request; char buffer[8192]; - int received = recv(clientSocket, buffer, sizeof(buffer), 0); + int received = recv(clientSocket.get(), buffer, sizeof(buffer), 0); if (received <= 0) return false; request.assign(buffer, buffer + received); - return HandleHttpRequest(clientSocket, request); + return HandleHttpRequest(std::move(clientSocket), request); } bool ControlServer::TryAcceptClient() { sockaddr_in clientAddress = {}; int addressSize = sizeof(clientAddress); - SOCKET clientSocket = accept(mListenSocket, reinterpret_cast(&clientAddress), &addressSize); - if (clientSocket == INVALID_SOCKET) + UniqueSocket clientSocket(accept(mListenSocket.get(), reinterpret_cast(&clientAddress), &addressSize)); + if (!clientSocket.valid()) return false; - bool handled = HandleHttpClient(clientSocket); - if (!handled) - closesocket(clientSocket); - return handled; + return HandleHttpClient(std::move(clientSocket)); } bool ControlServer::SendHttpResponse(SOCKET clientSocket, const std::string& status, const std::string& contentType, const std::string& body) @@ -218,144 +203,181 @@ bool ControlServer::SendHttpResponse(SOCKET clientSocket, const std::string& sta return send(clientSocket, payload.c_str(), static_cast(payload.size()), 0) == static_cast(payload.size()); } -bool ControlServer::HandleHttpRequest(SOCKET clientSocket, const std::string& request) +bool ControlServer::SendHttpResponse(SOCKET clientSocket, const HttpResponse& response) { - const std::string method = GetRequestMethod(request); - const std::string path = GetRequestPath(request); + return SendHttpResponse(clientSocket, response.status, response.contentType, response.body); +} - if (ToLower(GetHeaderValue(request, "Upgrade")) == "websocket") - return HandleWebSocketUpgrade(clientSocket, request); - - if (method == "GET") +bool ControlServer::HandleHttpRequest(UniqueSocket clientSocket, const std::string& request) +{ + HttpRequest httpRequest; + if (!ParseHttpRequest(request, httpRequest)) { - if (path == "/" || path == "/index.html") - { - std::string contentType; - std::string body = LoadUiAsset("index.html", contentType); - SendHttpResponse(clientSocket, "200 OK", contentType, body); - closesocket(clientSocket); - return true; - } - if (path == "/api/state") - { - SendHttpResponse(clientSocket, "200 OK", "application/json", mCallbacks.getStateJson ? mCallbacks.getStateJson() : "{}"); - closesocket(clientSocket); - return true; - } - - std::string contentType; - std::string body = LoadUiAsset(path.substr(1), contentType); - if (!body.empty()) - { - SendHttpResponse(clientSocket, "200 OK", contentType, body); - closesocket(clientSocket); - return true; - } - } - else if (method == "POST") - { - std::string body = GetRequestBody(request); - JsonValue root; - std::string parseError; - if (!ParseJson(body, root, parseError)) - { - SendHttpResponse(clientSocket, "400 Bad Request", "application/json", BuildJsonResponse(false, parseError)); - closesocket(clientSocket); - return true; - } - - bool success = false; - std::string actionError; - - if (path == "/api/layers/add") - { - const JsonValue* shaderId = root.find("shaderId"); - success = shaderId && mCallbacks.addLayer && mCallbacks.addLayer(shaderId->asString(), actionError); - } - else if (path == "/api/layers/remove") - { - const JsonValue* layerId = root.find("layerId"); - success = layerId && mCallbacks.removeLayer && mCallbacks.removeLayer(layerId->asString(), actionError); - } - else if (path == "/api/layers/move") - { - const JsonValue* layerId = root.find("layerId"); - const JsonValue* direction = root.find("direction"); - if (layerId && direction && mCallbacks.moveLayer) - success = mCallbacks.moveLayer(layerId->asString(), static_cast(direction->asNumber()), actionError); - } - else if (path == "/api/layers/reorder") - { - const JsonValue* layerId = root.find("layerId"); - const JsonValue* targetIndex = root.find("targetIndex"); - if (layerId && targetIndex && mCallbacks.moveLayerToIndex) - success = mCallbacks.moveLayerToIndex(layerId->asString(), static_cast(targetIndex->asNumber()), actionError); - } - else if (path == "/api/layers/set-bypass") - { - const JsonValue* layerId = root.find("layerId"); - const JsonValue* bypass = root.find("bypass"); - if (layerId && bypass && mCallbacks.setLayerBypass) - success = mCallbacks.setLayerBypass(layerId->asString(), bypass->asBoolean(), actionError); - } - else if (path == "/api/layers/set-shader") - { - const JsonValue* layerId = root.find("layerId"); - const JsonValue* shaderId = root.find("shaderId"); - if (layerId && shaderId && mCallbacks.setLayerShader) - success = mCallbacks.setLayerShader(layerId->asString(), shaderId->asString(), actionError); - } - else if (path == "/api/layers/update-parameter") - { - const JsonValue* layerId = root.find("layerId"); - const JsonValue* parameterId = root.find("parameterId"); - const JsonValue* value = root.find("value"); - if (layerId && parameterId && value && mCallbacks.updateLayerParameter) - success = mCallbacks.updateLayerParameter(layerId->asString(), parameterId->asString(), SerializeJson(*value, false), actionError); - } - else if (path == "/api/layers/reset-parameters") - { - const JsonValue* layerId = root.find("layerId"); - if (layerId && mCallbacks.resetLayerParameters) - success = mCallbacks.resetLayerParameters(layerId->asString(), actionError); - } - else if (path == "/api/stack-presets/save") - { - const JsonValue* presetName = root.find("presetName"); - if (presetName && mCallbacks.saveStackPreset) - success = mCallbacks.saveStackPreset(presetName->asString(), actionError); - } - else if (path == "/api/stack-presets/load") - { - const JsonValue* presetName = root.find("presetName"); - if (presetName && mCallbacks.loadStackPreset) - success = mCallbacks.loadStackPreset(presetName->asString(), actionError); - } - else if (path == "/api/reload") - { - if (mCallbacks.reloadShader) - success = mCallbacks.reloadShader(actionError); - } - - SendHttpResponse(clientSocket, success ? "200 OK" : "400 Bad Request", "application/json", BuildJsonResponse(success, actionError)); - closesocket(clientSocket); - if (success) - BroadcastState(); + SendHttpResponse(clientSocket.get(), "400 Bad Request", "text/plain", "Bad Request"); return true; } - SendHttpResponse(clientSocket, "404 Not Found", "text/plain", "Not Found"); - closesocket(clientSocket); + if (ToLower(GetHeaderValue(httpRequest, "Upgrade")) == "websocket") + return HandleWebSocketUpgrade(std::move(clientSocket), httpRequest); + + const HttpResponse response = RouteHttpRequest(httpRequest); + SendHttpResponse(clientSocket.get(), response); + if (response.broadcastState) + BroadcastState(); return true; } -bool ControlServer::HandleWebSocketUpgrade(SOCKET clientSocket, const std::string& request) +ControlServer::HttpResponse ControlServer::RouteHttpRequest(const HttpRequest& request) +{ + if (request.method == "GET") + return ServeGetRequest(request); + + if (request.method == "POST") + return HandleApiPost(request); + + return { "404 Not Found", "text/plain", "Not Found" }; +} + +ControlServer::HttpResponse ControlServer::ServeGetRequest(const HttpRequest& request) const +{ + if (request.path == "/" || request.path == "/index.html") + return ServeUiAsset("index.html"); + + if (request.path == "/api/state") + return { "200 OK", "application/json", mCallbacks.getStateJson ? mCallbacks.getStateJson() : "{}" }; + + if (request.path.size() > 1) + { + const HttpResponse assetResponse = ServeUiAsset(request.path.substr(1)); + if (!assetResponse.body.empty()) + return assetResponse; + } + + return { "404 Not Found", "text/plain", "Not Found" }; +} + +ControlServer::HttpResponse ControlServer::ServeUiAsset(const std::string& relativePath) const +{ + std::string contentType; + const std::string body = LoadUiAsset(relativePath, contentType); + return body.empty() + ? HttpResponse{ "404 Not Found", "text/plain", "Not Found" } + : HttpResponse{ "200 OK", contentType, body }; +} + +ControlServer::HttpResponse ControlServer::HandleApiPost(const HttpRequest& request) +{ + JsonValue root; + std::string parseError; + if (!ParseJson(request.body, root, parseError)) + return { "400 Bad Request", "application/json", BuildJsonResponse(false, parseError) }; + + std::string actionError; + const bool success = InvokePostRoute(request.path, root, actionError); + return { + success ? "200 OK" : "400 Bad Request", + "application/json", + BuildJsonResponse(success, actionError), + success + }; +} + +bool ControlServer::InvokePostRoute(const std::string& path, const JsonValue& root, std::string& actionError) +{ + using PostHandler = std::function; + const std::map postRoutes = + { + { "/api/layers/add", [this](const JsonValue& json, std::string& error) + { + const JsonValue* shaderId = json.find("shaderId"); + return shaderId && mCallbacks.addLayer && mCallbacks.addLayer(shaderId->asString(), error); + } + }, + { "/api/layers/remove", [this](const JsonValue& json, std::string& error) + { + const JsonValue* layerId = json.find("layerId"); + return layerId && mCallbacks.removeLayer && mCallbacks.removeLayer(layerId->asString(), error); + } + }, + { "/api/layers/move", [this](const JsonValue& json, std::string& error) + { + const JsonValue* layerId = json.find("layerId"); + const JsonValue* direction = json.find("direction"); + return layerId && direction && mCallbacks.moveLayer && + mCallbacks.moveLayer(layerId->asString(), static_cast(direction->asNumber()), error); + } + }, + { "/api/layers/reorder", [this](const JsonValue& json, std::string& error) + { + const JsonValue* layerId = json.find("layerId"); + const JsonValue* targetIndex = json.find("targetIndex"); + return layerId && targetIndex && mCallbacks.moveLayerToIndex && + mCallbacks.moveLayerToIndex(layerId->asString(), static_cast(targetIndex->asNumber()), error); + } + }, + { "/api/layers/set-bypass", [this](const JsonValue& json, std::string& error) + { + const JsonValue* layerId = json.find("layerId"); + const JsonValue* bypass = json.find("bypass"); + return layerId && bypass && mCallbacks.setLayerBypass && + mCallbacks.setLayerBypass(layerId->asString(), bypass->asBoolean(), error); + } + }, + { "/api/layers/set-shader", [this](const JsonValue& json, std::string& error) + { + const JsonValue* layerId = json.find("layerId"); + const JsonValue* shaderId = json.find("shaderId"); + return layerId && shaderId && mCallbacks.setLayerShader && + mCallbacks.setLayerShader(layerId->asString(), shaderId->asString(), error); + } + }, + { "/api/layers/update-parameter", [this](const JsonValue& json, std::string& error) + { + const JsonValue* layerId = json.find("layerId"); + const JsonValue* parameterId = json.find("parameterId"); + const JsonValue* value = json.find("value"); + return layerId && parameterId && value && mCallbacks.updateLayerParameter && + mCallbacks.updateLayerParameter(layerId->asString(), parameterId->asString(), SerializeJson(*value, false), error); + } + }, + { "/api/layers/reset-parameters", [this](const JsonValue& json, std::string& error) + { + const JsonValue* layerId = json.find("layerId"); + return layerId && mCallbacks.resetLayerParameters && + mCallbacks.resetLayerParameters(layerId->asString(), error); + } + }, + { "/api/stack-presets/save", [this](const JsonValue& json, std::string& error) + { + const JsonValue* presetName = json.find("presetName"); + return presetName && mCallbacks.saveStackPreset && + mCallbacks.saveStackPreset(presetName->asString(), error); + } + }, + { "/api/stack-presets/load", [this](const JsonValue& json, std::string& error) + { + const JsonValue* presetName = json.find("presetName"); + return presetName && mCallbacks.loadStackPreset && + mCallbacks.loadStackPreset(presetName->asString(), error); + } + }, + { "/api/reload", [this](const JsonValue&, std::string& error) + { + return mCallbacks.reloadShader && mCallbacks.reloadShader(error); + } + } + }; + + const auto route = postRoutes.find(path); + return route != postRoutes.end() && route->second(root, actionError); +} + +bool ControlServer::HandleWebSocketUpgrade(UniqueSocket clientSocket, const HttpRequest& request) { const std::string clientKey = GetHeaderValue(request, "Sec-WebSocket-Key"); if (clientKey.empty()) { - SendHttpResponse(clientSocket, "400 Bad Request", "text/plain", "Missing Sec-WebSocket-Key"); - closesocket(clientSocket); + SendHttpResponse(clientSocket.get(), "400 Bad Request", "text/plain", "Missing Sec-WebSocket-Key"); return true; } @@ -366,14 +388,14 @@ bool ControlServer::HandleWebSocketUpgrade(SOCKET clientSocket, const std::strin response << "Sec-WebSocket-Accept: " << ComputeWebSocketAcceptKey(clientKey) << "\r\n\r\n"; const std::string payload = response.str(); - send(clientSocket, payload.c_str(), static_cast(payload.size()), 0); + send(clientSocket.get(), payload.c_str(), static_cast(payload.size()), 0); { std::lock_guard lock(mMutex); ClientConnection client; - client.socket = clientSocket; + client.socket.reset(clientSocket.release()); client.websocket = true; - mClients.push_back(client); + mClients.push_back(std::move(client)); BroadcastStateLocked(); } return true; @@ -409,9 +431,8 @@ void ControlServer::BroadcastStateLocked() const std::string stateMessage = mCallbacks.getStateJson ? mCallbacks.getStateJson() : "{}"; for (auto it = mClients.begin(); it != mClients.end();) { - if (!SendWebSocketText(it->socket, stateMessage)) + if (!SendWebSocketText(it->socket.get(), stateMessage)) { - closesocket(it->socket); it = mClients.erase(it); } else @@ -480,46 +501,55 @@ std::string ControlServer::ComputeWebSocketAcceptKey(const std::string& clientKe return Base64Encode(digest, digestLength); } -std::string ControlServer::GetHeaderValue(const std::string& request, const std::string& headerName) +std::string ControlServer::GetHeaderValue(const HttpRequest& request, const std::string& headerName) { - const std::string lowerRequest = ToLower(request); - const std::string lowerHeaderName = ToLower(headerName) + ":"; - const std::size_t start = lowerRequest.find(lowerHeaderName); - if (start == std::string::npos) - return std::string(); - - const std::size_t valueStart = start + lowerHeaderName.size(); - const std::size_t lineEnd = request.find("\r\n", valueStart); - if (lineEnd == std::string::npos) - return std::string(); - - std::string value = request.substr(valueStart, lineEnd - valueStart); - const std::size_t first = value.find_first_not_of(" \t"); - const std::size_t last = value.find_last_not_of(" \t"); - return first == std::string::npos ? std::string() : value.substr(first, last - first + 1); + const auto header = request.headers.find(ToLower(headerName)); + return header == request.headers.end() ? std::string() : header->second; } -std::string ControlServer::GetRequestPath(const std::string& request) +bool ControlServer::ParseHttpRequest(const std::string& rawRequest, HttpRequest& request) { - const std::size_t methodEnd = request.find(' '); + const std::size_t requestLineEnd = rawRequest.find("\r\n"); + if (requestLineEnd == std::string::npos) + return false; + + const std::string requestLine = rawRequest.substr(0, requestLineEnd); + const std::size_t methodEnd = requestLine.find(' '); if (methodEnd == std::string::npos) - return "/"; - const std::size_t pathEnd = request.find(' ', methodEnd + 1); + return false; + + const std::size_t pathEnd = requestLine.find(' ', methodEnd + 1); if (pathEnd == std::string::npos) - return "/"; - return request.substr(methodEnd + 1, pathEnd - methodEnd - 1); -} + return false; -std::string ControlServer::GetRequestMethod(const std::string& request) -{ - const std::size_t methodEnd = request.find(' '); - return methodEnd == std::string::npos ? std::string() : request.substr(0, methodEnd); -} + request.method = requestLine.substr(0, methodEnd); + request.path = requestLine.substr(methodEnd + 1, pathEnd - methodEnd - 1); + request.headers.clear(); -std::string ControlServer::GetRequestBody(const std::string& request) -{ - const std::size_t separator = request.find("\r\n\r\n"); - if (separator == std::string::npos) - return std::string(); - return request.substr(separator + 4); + const std::size_t headersStart = requestLineEnd + 2; + const std::size_t bodySeparator = rawRequest.find("\r\n\r\n", headersStart); + const std::size_t headersEnd = bodySeparator == std::string::npos ? rawRequest.size() : bodySeparator; + + for (std::size_t lineStart = headersStart; lineStart < headersEnd;) + { + const std::size_t lineEnd = rawRequest.find("\r\n", lineStart); + const std::size_t currentLineEnd = lineEnd == std::string::npos ? headersEnd : std::min(lineEnd, headersEnd); + const std::string line = rawRequest.substr(lineStart, currentLineEnd - lineStart); + const std::size_t separator = line.find(':'); + if (separator != std::string::npos) + { + const std::string key = ToLower(line.substr(0, separator)); + std::string value = line.substr(separator + 1); + const std::size_t first = value.find_first_not_of(" \t"); + const std::size_t last = value.find_last_not_of(" \t"); + request.headers[key] = first == std::string::npos ? std::string() : value.substr(first, last - first + 1); + } + + if (lineEnd == std::string::npos || lineEnd >= headersEnd) + break; + lineStart = lineEnd + 2; + } + + request.body = bodySeparator == std::string::npos ? std::string() : rawRequest.substr(bodySeparator + 4); + return !request.method.empty() && !request.path.empty(); } diff --git a/apps/LoopThroughWithOpenGLCompositing/ControlServer.h b/apps/LoopThroughWithOpenGLCompositing/ControlServer.h index 755bc33..0c821bc 100644 --- a/apps/LoopThroughWithOpenGLCompositing/ControlServer.h +++ b/apps/LoopThroughWithOpenGLCompositing/ControlServer.h @@ -1,15 +1,20 @@ #pragma once +#include "NativeSockets.h" + #include #include #include #include +#include #include #include #include #include +class JsonValue; + class ControlServer { public: @@ -41,31 +46,51 @@ public: private: struct ClientConnection { - SOCKET socket = INVALID_SOCKET; + UniqueSocket socket; bool websocket = false; }; + struct HttpRequest + { + std::string method; + std::string path; + std::map headers; + std::string body; + }; + + struct HttpResponse + { + std::string status; + std::string contentType; + std::string body; + bool broadcastState = false; + }; + void ServerLoop(); - bool HandleHttpClient(SOCKET clientSocket); + bool HandleHttpClient(UniqueSocket clientSocket); bool TryAcceptClient(); + bool SendHttpResponse(SOCKET clientSocket, const HttpResponse& response); bool SendHttpResponse(SOCKET clientSocket, const std::string& status, const std::string& contentType, const std::string& body); - bool HandleHttpRequest(SOCKET clientSocket, const std::string& request); - bool HandleWebSocketUpgrade(SOCKET clientSocket, const std::string& request); + bool HandleHttpRequest(UniqueSocket clientSocket, const std::string& request); + bool HandleWebSocketUpgrade(UniqueSocket clientSocket, const HttpRequest& request); + HttpResponse RouteHttpRequest(const HttpRequest& request); + HttpResponse ServeGetRequest(const HttpRequest& request) const; + HttpResponse ServeUiAsset(const std::string& relativePath) const; + HttpResponse HandleApiPost(const HttpRequest& request); + bool InvokePostRoute(const std::string& path, const JsonValue& root, std::string& actionError); bool SendWebSocketText(SOCKET clientSocket, const std::string& payload); void BroadcastStateLocked(); std::string LoadUiAsset(const std::string& relativePath, std::string& contentType) const; std::string BuildJsonResponse(bool success, const std::string& error = std::string()) const; static std::string Base64Encode(const unsigned char* data, DWORD dataLength); static std::string ComputeWebSocketAcceptKey(const std::string& clientKey); - static std::string GetHeaderValue(const std::string& request, const std::string& headerName); - static std::string GetRequestPath(const std::string& request); - static std::string GetRequestMethod(const std::string& request); - static std::string GetRequestBody(const std::string& request); + static std::string GetHeaderValue(const HttpRequest& request, const std::string& headerName); + static bool ParseHttpRequest(const std::string& rawRequest, HttpRequest& request); private: std::filesystem::path mUiRoot; Callbacks mCallbacks; - SOCKET mListenSocket; + UniqueSocket mListenSocket; unsigned short mPort; std::thread mThread; std::atomic mRunning; diff --git a/apps/LoopThroughWithOpenGLCompositing/NativeHandles.h b/apps/LoopThroughWithOpenGLCompositing/NativeHandles.h new file mode 100644 index 0000000..7a31d1d --- /dev/null +++ b/apps/LoopThroughWithOpenGLCompositing/NativeHandles.h @@ -0,0 +1,42 @@ +#pragma once + +#include + +class UniqueHandle +{ +public: + explicit UniqueHandle(HANDLE handle = NULL) : mHandle(handle) {} + ~UniqueHandle() { reset(); } + + UniqueHandle(const UniqueHandle&) = delete; + UniqueHandle& operator=(const UniqueHandle&) = delete; + + UniqueHandle(UniqueHandle&& other) noexcept : mHandle(other.release()) {} + + UniqueHandle& operator=(UniqueHandle&& other) noexcept + { + if (this != &other) + reset(other.release()); + return *this; + } + + HANDLE get() const { return mHandle; } + bool valid() const { return mHandle != NULL && mHandle != INVALID_HANDLE_VALUE; } + + HANDLE release() + { + HANDLE handle = mHandle; + mHandle = NULL; + return handle; + } + + void reset(HANDLE handle = NULL) + { + if (valid()) + CloseHandle(mHandle); + mHandle = handle; + } + +private: + HANDLE mHandle; +}; diff --git a/apps/LoopThroughWithOpenGLCompositing/NativeSockets.h b/apps/LoopThroughWithOpenGLCompositing/NativeSockets.h new file mode 100644 index 0000000..a9469a2 --- /dev/null +++ b/apps/LoopThroughWithOpenGLCompositing/NativeSockets.h @@ -0,0 +1,42 @@ +#pragma once + +#include + +class UniqueSocket +{ +public: + explicit UniqueSocket(SOCKET socket = INVALID_SOCKET) : mSocket(socket) {} + ~UniqueSocket() { reset(); } + + UniqueSocket(const UniqueSocket&) = delete; + UniqueSocket& operator=(const UniqueSocket&) = delete; + + UniqueSocket(UniqueSocket&& other) noexcept : mSocket(other.release()) {} + + UniqueSocket& operator=(UniqueSocket&& other) noexcept + { + if (this != &other) + reset(other.release()); + return *this; + } + + SOCKET get() const { return mSocket; } + bool valid() const { return mSocket != INVALID_SOCKET; } + + SOCKET release() + { + SOCKET socket = mSocket; + mSocket = INVALID_SOCKET; + return socket; + } + + void reset(SOCKET socket = INVALID_SOCKET) + { + if (valid()) + closesocket(mSocket); + mSocket = socket; + } + +private: + SOCKET mSocket; +}; diff --git a/apps/LoopThroughWithOpenGLCompositing/RuntimeHost.cpp b/apps/LoopThroughWithOpenGLCompositing/RuntimeHost.cpp index 5cdf72e..71fef1c 100644 --- a/apps/LoopThroughWithOpenGLCompositing/RuntimeHost.cpp +++ b/apps/LoopThroughWithOpenGLCompositing/RuntimeHost.cpp @@ -248,6 +248,47 @@ bool NumberListFromJsonValue(const JsonValue& value, std::vector& number return false; } +bool ValidateShaderIdentifier(const std::string& identifier, const std::string& fieldName, const std::filesystem::path& manifestPath, std::string& error) +{ + if (identifier.empty() || !(std::isalpha(static_cast(identifier.front())) || identifier.front() == '_')) + { + error = "Shader manifest field '" + fieldName + "' must be a valid shader identifier in: " + ManifestPathMessage(manifestPath); + return false; + } + + for (char ch : identifier) + { + const unsigned char unsignedCh = static_cast(ch); + if (!(std::isalnum(unsignedCh) || ch == '_')) + { + error = "Shader manifest field '" + fieldName + "' must be a valid shader identifier in: " + ManifestPathMessage(manifestPath); + return false; + } + } + + return true; +} + +bool ParseShaderMetadata(const JsonValue& manifestJson, ShaderPackage& shaderPackage, const std::filesystem::path& manifestPath, std::string& error) +{ + if (!RequireStringField(manifestJson, "id", shaderPackage.id, manifestPath, error) || + !RequireStringField(manifestJson, "name", shaderPackage.displayName, manifestPath, error) || + !OptionalStringField(manifestJson, "description", shaderPackage.description, "", manifestPath, error) || + !OptionalStringField(manifestJson, "category", shaderPackage.category, "", manifestPath, error) || + !OptionalStringField(manifestJson, "entryPoint", shaderPackage.entryPoint, "shadeVideo", manifestPath, error)) + { + return false; + } + + if (!ValidateShaderIdentifier(shaderPackage.entryPoint, "entryPoint", manifestPath, error)) + return false; + + shaderPackage.directoryPath = manifestPath.parent_path(); + shaderPackage.shaderPath = shaderPackage.directoryPath / "shader.slang"; + shaderPackage.manifestPath = manifestPath; + return true; +} + bool ParseTextureAssets(const JsonValue& manifestJson, ShaderPackage& shaderPackage, const std::filesystem::path& manifestPath, std::string& error) { const JsonValue* texturesValue = nullptr; @@ -272,6 +313,8 @@ bool ParseTextureAssets(const JsonValue& manifestJson, ShaderPackage& shaderPack error = "Shader texture is missing required 'id' or 'path' in: " + ManifestPathMessage(manifestPath); return false; } + if (!ValidateShaderIdentifier(textureId, "textures[].id", manifestPath, error)) + return false; ShaderTextureAsset textureAsset; textureAsset.id = textureId; @@ -374,7 +417,7 @@ bool ParseParameterDefault(const JsonValue& parameterJson, ShaderParameterDefini return NumberListFromJsonValue(*defaultValue, definition.defaultNumbers, "default", manifestPath, error); } -bool ParseEnumOptions(const JsonValue& parameterJson, ShaderParameterDefinition& definition, const std::filesystem::path& manifestPath, std::string& error) +bool ParseParameterOptions(const JsonValue& parameterJson, ShaderParameterDefinition& definition, const std::filesystem::path& manifestPath, std::string& error) { const JsonValue* optionsValue = nullptr; if (!OptionalArrayField(parameterJson, "options", optionsValue, manifestPath, error) || !optionsValue) @@ -442,6 +485,8 @@ bool ParseParameterDefinition(const JsonValue& parameterJson, ShaderParameterDef error = "Unsupported parameter type '" + typeName + "' in: " + ManifestPathMessage(manifestPath); return false; } + if (!ValidateShaderIdentifier(definition.id, "parameters[].id", manifestPath, error)) + return false; if (!ParseParameterDefault(parameterJson, definition, manifestPath, error) || !ParseParameterNumberField(parameterJson, "min", definition.minNumbers, manifestPath, error) || @@ -452,7 +497,7 @@ bool ParseParameterDefinition(const JsonValue& parameterJson, ShaderParameterDef } if (definition.type == ShaderParameterType::Enum) - return ParseEnumOptions(parameterJson, definition, manifestPath, error); + return ParseParameterOptions(parameterJson, definition, manifestPath, error); return true; } @@ -1265,18 +1310,8 @@ bool RuntimeHost::ParseShaderManifest(const std::filesystem::path& manifestPath, return false; } - if (!RequireStringField(manifestJson, "id", shaderPackage.id, manifestPath, error) || - !RequireStringField(manifestJson, "name", shaderPackage.displayName, manifestPath, error) || - !OptionalStringField(manifestJson, "description", shaderPackage.description, "", manifestPath, error) || - !OptionalStringField(manifestJson, "category", shaderPackage.category, "", manifestPath, error) || - !OptionalStringField(manifestJson, "entryPoint", shaderPackage.entryPoint, "shadeVideo", manifestPath, error)) - { + if (!ParseShaderMetadata(manifestJson, shaderPackage, manifestPath, error)) return false; - } - - shaderPackage.directoryPath = manifestPath.parent_path(); - shaderPackage.shaderPath = shaderPackage.directoryPath / "shader.slang"; - shaderPackage.manifestPath = manifestPath; if (!std::filesystem::exists(shaderPackage.shaderPath)) { diff --git a/apps/LoopThroughWithOpenGLCompositing/RuntimeJson.cpp b/apps/LoopThroughWithOpenGLCompositing/RuntimeJson.cpp index 5a74cde..785ad42 100644 --- a/apps/LoopThroughWithOpenGLCompositing/RuntimeJson.cpp +++ b/apps/LoopThroughWithOpenGLCompositing/RuntimeJson.cpp @@ -1,6 +1,7 @@ #include "stdafx.h" #include "RuntimeJson.h" +#include #include #include #include @@ -10,6 +11,59 @@ namespace { +int HexDigitValue(char ch) +{ + if (ch >= '0' && ch <= '9') + return ch - '0'; + if (ch >= 'a' && ch <= 'f') + return ch - 'a' + 10; + if (ch >= 'A' && ch <= 'F') + return ch - 'A' + 10; + return -1; +} + +bool IsHighSurrogate(unsigned int codePoint) +{ + return codePoint >= 0xD800 && codePoint <= 0xDBFF; +} + +bool IsLowSurrogate(unsigned int codePoint) +{ + return codePoint >= 0xDC00 && codePoint <= 0xDFFF; +} + +void AppendUtf8(unsigned int codePoint, std::ostringstream& output) +{ + if (codePoint <= 0x7F) + { + output << static_cast(codePoint); + } + else if (codePoint <= 0x7FF) + { + output << static_cast(0xC0 | ((codePoint >> 6) & 0x1F)); + output << static_cast(0x80 | (codePoint & 0x3F)); + } + else if (codePoint <= 0xFFFF) + { + output << static_cast(0xE0 | ((codePoint >> 12) & 0x0F)); + output << static_cast(0x80 | ((codePoint >> 6) & 0x3F)); + output << static_cast(0x80 | (codePoint & 0x3F)); + } + else + { + output << static_cast(0xF0 | ((codePoint >> 18) & 0x07)); + output << static_cast(0x80 | ((codePoint >> 12) & 0x3F)); + output << static_cast(0x80 | ((codePoint >> 6) & 0x3F)); + output << static_cast(0x80 | (codePoint & 0x3F)); + } +} + +void AppendControlEscape(unsigned char ch, std::ostringstream& output) +{ + const char* digits = "0123456789ABCDEF"; + output << "\\u00" << digits[(ch >> 4) & 0x0F] << digits[ch & 0x0F]; +} + class JsonParser { public: @@ -181,8 +235,9 @@ private: case 'r': result << '\r'; break; case 't': result << '\t'; break; case 'u': - setError("Unicode escape sequences are not supported in this JSON parser."); - return false; + if (!parseUnicodeEscape(result)) + return false; + break; default: setError("Invalid escape sequence in JSON string."); return false; @@ -190,6 +245,11 @@ private: } else { + if (static_cast(ch) < 0x20) + { + setError("Unescaped control character in JSON string."); + return false; + } result << ch; } } @@ -198,6 +258,66 @@ private: return false; } + bool parseHexCodePoint(unsigned int& codePoint) + { + if (mPosition + 4 > mText.size()) + { + setError("Unexpected end of Unicode escape sequence."); + return false; + } + + codePoint = 0; + for (int i = 0; i < 4; ++i) + { + const int digit = HexDigitValue(mText[mPosition + i]); + if (digit < 0) + { + setError("Invalid Unicode escape sequence in JSON string."); + return false; + } + codePoint = (codePoint << 4) | static_cast(digit); + } + + mPosition += 4; + return true; + } + + bool parseUnicodeEscape(std::ostringstream& result) + { + unsigned int codePoint = 0; + if (!parseHexCodePoint(codePoint)) + return false; + + if (IsHighSurrogate(codePoint)) + { + if (mPosition + 2 > mText.size() || mText[mPosition] != '\\' || mText[mPosition + 1] != 'u') + { + setError("High surrogate Unicode escape must be followed by a low surrogate."); + return false; + } + + mPosition += 2; + unsigned int lowSurrogate = 0; + if (!parseHexCodePoint(lowSurrogate)) + return false; + if (!IsLowSurrogate(lowSurrogate)) + { + setError("High surrogate Unicode escape must be followed by a low surrogate."); + return false; + } + + codePoint = 0x10000 + (((codePoint - 0xD800) << 10) | (lowSurrogate - 0xDC00)); + } + else if (IsLowSurrogate(codePoint)) + { + setError("Low surrogate Unicode escape without preceding high surrogate."); + return false; + } + + AppendUtf8(codePoint, result); + return true; + } + bool parseNumber(JsonValue& value) { std::size_t start = mPosition; @@ -205,12 +325,40 @@ private: if (mText[mPosition] == '-') ++mPosition; - while (mPosition < mText.size() && std::isdigit(static_cast(mText[mPosition]))) + if (mPosition >= mText.size()) + { + setError("Invalid JSON number."); + return false; + } + + if (mText[mPosition] == '0') + { ++mPosition; + if (mPosition < mText.size() && std::isdigit(static_cast(mText[mPosition]))) + { + setError("JSON numbers must not contain leading zeroes."); + return false; + } + } + else if (mText[mPosition] >= '1' && mText[mPosition] <= '9') + { + while (mPosition < mText.size() && std::isdigit(static_cast(mText[mPosition]))) + ++mPosition; + } + else + { + setError("Invalid JSON number."); + return false; + } if (mPosition < mText.size() && mText[mPosition] == '.') { ++mPosition; + if (mPosition >= mText.size() || !std::isdigit(static_cast(mText[mPosition]))) + { + setError("JSON number fraction must contain at least one digit."); + return false; + } while (mPosition < mText.size() && std::isdigit(static_cast(mText[mPosition]))) ++mPosition; } @@ -220,14 +368,20 @@ private: ++mPosition; if (mPosition < mText.size() && (mText[mPosition] == '+' || mText[mPosition] == '-')) ++mPosition; + if (mPosition >= mText.size() || !std::isdigit(static_cast(mText[mPosition]))) + { + setError("JSON number exponent must contain at least one digit."); + return false; + } while (mPosition < mText.size() && std::isdigit(static_cast(mText[mPosition]))) ++mPosition; } std::string token = mText.substr(start, mPosition - start); char* endPtr = nullptr; + errno = 0; double parsed = strtod(token.c_str(), &endPtr); - if (endPtr == token.c_str() || *endPtr != '\0') + if (endPtr == token.c_str() || *endPtr != '\0' || errno == ERANGE || !std::isfinite(parsed)) { setError("Invalid JSON number."); return false; @@ -322,7 +476,12 @@ void SerializeJsonImpl(const JsonValue& value, std::ostringstream& output, bool case '\n': output << "\\n"; break; case '\r': output << "\\r"; break; case '\t': output << "\\t"; break; - default: output << ch; break; + default: + if (static_cast(ch) < 0x20) + AppendControlEscape(static_cast(ch), output); + else + output << ch; + break; } } output << '"'; @@ -407,14 +566,14 @@ JsonValue::JsonValue(const std::string& value) JsonValue JsonValue::MakeArray() { JsonValue value; - value.mType = Type::Array; + value.reset(Type::Array); return value; } JsonValue JsonValue::MakeObject() { JsonValue value; - value.mType = Type::Object; + value.reset(Type::Object); return value; } @@ -449,20 +608,14 @@ const std::map& JsonValue::asObject() const std::vector& JsonValue::array() { if (mType != Type::Array) - { - mType = Type::Array; - mArrayValue.clear(); - } + reset(Type::Array); return mArrayValue; } std::map& JsonValue::object() { if (mType != Type::Object) - { - mType = Type::Object; - mObjectValue.clear(); - } + reset(Type::Object); return mObjectValue; } @@ -485,6 +638,16 @@ const JsonValue* JsonValue::find(const std::string& key) const return iterator != mObjectValue.end() ? &iterator->second : nullptr; } +void JsonValue::reset(Type type) +{ + mType = type; + mBooleanValue = false; + mNumberValue = 0.0; + mStringValue.clear(); + mArrayValue.clear(); + mObjectValue.clear(); +} + bool ParseJson(const std::string& text, JsonValue& value, std::string& error) { error.clear(); diff --git a/apps/LoopThroughWithOpenGLCompositing/RuntimeJson.h b/apps/LoopThroughWithOpenGLCompositing/RuntimeJson.h index 4adfb97..47e601e 100644 --- a/apps/LoopThroughWithOpenGLCompositing/RuntimeJson.h +++ b/apps/LoopThroughWithOpenGLCompositing/RuntimeJson.h @@ -50,6 +50,8 @@ public: const JsonValue* find(const std::string& key) const; private: + void reset(Type type); + Type mType; bool mBooleanValue; double mNumberValue; diff --git a/apps/LoopThroughWithOpenGLCompositing/ShaderCompiler.cpp b/apps/LoopThroughWithOpenGLCompositing/ShaderCompiler.cpp index dc6f503..dfa41d9 100644 --- a/apps/LoopThroughWithOpenGLCompositing/ShaderCompiler.cpp +++ b/apps/LoopThroughWithOpenGLCompositing/ShaderCompiler.cpp @@ -1,7 +1,8 @@ #include "stdafx.h" #include "ShaderCompiler.h" -#include +#include "NativeHandles.h" + #include #include #include @@ -20,17 +21,51 @@ std::string ReplaceAll(std::string text, const std::string& from, const std::str return text; } -std::string SlangTypeForParameter(ShaderParameterType type) +std::string SlangCBufferTypeForParameter(ShaderParameterType type) { switch (type) { - case ShaderParameterType::Float: return "uniform float"; - case ShaderParameterType::Vec2: return "uniform float2"; - case ShaderParameterType::Color: return "uniform float4"; - case ShaderParameterType::Boolean: return "uniform bool"; - case ShaderParameterType::Enum: return "uniform int"; + case ShaderParameterType::Float: return "float"; + case ShaderParameterType::Vec2: return "float2"; + case ShaderParameterType::Color: return "float4"; + case ShaderParameterType::Boolean: return "bool"; + case ShaderParameterType::Enum: return "int"; } - return "uniform float"; + return "float"; +} + +std::string BuildParameterUniforms(const std::vector& parameters) +{ + std::ostringstream source; + for (const ShaderParameterDefinition& definition : parameters) + source << "\t" << SlangCBufferTypeForParameter(definition.type) << " " << definition.id << ";\n"; + return source.str(); +} + +std::string BuildHistorySamplerDeclarations(const std::string& samplerPrefix, unsigned historyLength) +{ + std::ostringstream source; + for (unsigned index = 0; index < historyLength; ++index) + source << "Sampler2D " << samplerPrefix << index << ";\n"; + return source.str(); +} + +std::string BuildTextureSamplerDeclarations(const std::vector& textureAssets) +{ + std::ostringstream source; + for (const ShaderTextureAsset& textureAsset : textureAssets) + source << "Sampler2D " << textureAsset.id << ";\n"; + if (!textureAssets.empty()) + source << "\n"; + return source.str(); +} + +std::string BuildHistorySwitchCases(const std::string& samplerPrefix, unsigned historyLength) +{ + std::ostringstream source; + for (unsigned index = 0; index < historyLength; ++index) + source << "\tcase " << index << ": return " << samplerPrefix << index << ".Sample(tc);\n"; + return source.str(); } } @@ -50,7 +85,9 @@ ShaderCompiler::ShaderCompiler( bool ShaderCompiler::BuildLayerFragmentShaderSource(const ShaderPackage& shaderPackage, std::string& fragmentShaderSource, std::string& error) const { - const std::string wrapperSource = BuildWrapperSlangSource(shaderPackage); + std::string wrapperSource; + if (!BuildWrapperSlangSource(shaderPackage, wrapperSource, error)) + return false; if (!WriteTextFile(mWrapperPath, wrapperSource, error)) return false; @@ -70,104 +107,22 @@ bool ShaderCompiler::BuildLayerFragmentShaderSource(const ShaderPackage& shaderP return true; } -std::string ShaderCompiler::BuildWrapperSlangSource(const ShaderPackage& shaderPackage) const +bool ShaderCompiler::BuildWrapperSlangSource(const ShaderPackage& shaderPackage, std::string& wrapperSource, std::string& error) const { - std::ostringstream source; - source << "struct FragmentInput\n"; - source << "{\n"; - source << "\tfloat4 position : SV_Position;\n"; - source << "\tfloat2 texCoord : TEXCOORD0;\n"; - source << "};\n\n"; - source << "struct ShaderContext\n"; - source << "{\n"; - source << "\tfloat2 uv;\n"; - source << "\tfloat4 sourceColor;\n"; - source << "\tfloat2 inputResolution;\n"; - source << "\tfloat2 outputResolution;\n"; - source << "\tfloat time;\n"; - source << "\tfloat frameCount;\n"; - source << "\tfloat mixAmount;\n"; - source << "\tfloat bypass;\n"; - source << "\tint sourceHistoryLength;\n"; - source << "\tint temporalHistoryLength;\n"; - source << "};\n\n"; - source << "cbuffer GlobalParams\n"; - source << "{\n"; - source << "\tfloat gTime;\n"; - source << "\tfloat2 gInputResolution;\n"; - source << "\tfloat2 gOutputResolution;\n"; - source << "\tfloat gFrameCount;\n"; - source << "\tfloat gMixAmount;\n"; - source << "\tfloat gBypass;\n"; - source << "\tint gSourceHistoryLength;\n"; - source << "\tint gTemporalHistoryLength;\n"; - for (const ShaderParameterDefinition& definition : shaderPackage.parameters) - source << "\t" << SlangTypeForParameter(definition.type).substr(strlen("uniform ")) << " " << definition.id << ";\n"; - source << "};\n\n"; - source << "Sampler2D gVideoInput;\n"; - for (unsigned index = 0; index < mMaxTemporalHistoryFrames; ++index) - source << "Sampler2D gSourceHistory" << index << ";\n"; - for (unsigned index = 0; index < mMaxTemporalHistoryFrames; ++index) - source << "Sampler2D gTemporalHistory" << index << ";\n"; - for (const ShaderTextureAsset& textureAsset : shaderPackage.textureAssets) - source << "Sampler2D " << textureAsset.id << ";\n"; - source << "\n"; - source << "float4 sampleVideo(float2 tc)\n"; - source << "{\n"; - source << "\treturn gVideoInput.Sample(tc);\n"; - source << "}\n\n"; - source << "float4 sampleSourceHistory(int framesAgo, float2 tc)\n"; - source << "{\n"; - source << "\tif (gSourceHistoryLength <= 0)\n"; - source << "\t\treturn sampleVideo(tc);\n"; - source << "\tint clampedIndex = framesAgo;\n"; - source << "\tif (clampedIndex < 0)\n"; - source << "\t\tclampedIndex = 0;\n"; - source << "\tif (clampedIndex >= gSourceHistoryLength)\n"; - source << "\t\tclampedIndex = gSourceHistoryLength - 1;\n"; - source << "\tswitch (clampedIndex)\n"; - source << "\t{\n"; - for (unsigned index = 0; index < mMaxTemporalHistoryFrames; ++index) - source << "\tcase " << index << ": return gSourceHistory" << index << ".Sample(tc);\n"; - source << "\tdefault: return sampleVideo(tc);\n"; - source << "\t}\n"; - source << "}\n\n"; - source << "float4 sampleTemporalHistory(int framesAgo, float2 tc)\n"; - source << "{\n"; - source << "\tif (gTemporalHistoryLength <= 0)\n"; - source << "\t\treturn sampleVideo(tc);\n"; - source << "\tint clampedIndex = framesAgo;\n"; - source << "\tif (clampedIndex < 0)\n"; - source << "\t\tclampedIndex = 0;\n"; - source << "\tif (clampedIndex >= gTemporalHistoryLength)\n"; - source << "\t\tclampedIndex = gTemporalHistoryLength - 1;\n"; - source << "\tswitch (clampedIndex)\n"; - source << "\t{\n"; - for (unsigned index = 0; index < mMaxTemporalHistoryFrames; ++index) - source << "\tcase " << index << ": return gTemporalHistory" << index << ".Sample(tc);\n"; - source << "\tdefault: return sampleVideo(tc);\n"; - source << "\t}\n"; - source << "}\n\n"; - source << "#include \"" << shaderPackage.shaderPath.generic_string() << "\"\n\n"; - source << "[shader(\"fragment\")]\n"; - source << "float4 fragmentMain(FragmentInput input) : SV_Target\n"; - source << "{\n"; - source << "\tShaderContext context;\n"; - source << "\tcontext.uv = input.texCoord;\n"; - source << "\tcontext.sourceColor = sampleVideo(context.uv);\n"; - source << "\tcontext.inputResolution = gInputResolution;\n"; - source << "\tcontext.outputResolution = gOutputResolution;\n"; - source << "\tcontext.time = gTime;\n"; - source << "\tcontext.frameCount = gFrameCount;\n"; - source << "\tcontext.mixAmount = gMixAmount;\n"; - source << "\tcontext.bypass = gBypass;\n"; - source << "\tcontext.sourceHistoryLength = gSourceHistoryLength;\n"; - source << "\tcontext.temporalHistoryLength = gTemporalHistoryLength;\n"; - source << "\tfloat4 effectedColor = " << shaderPackage.entryPoint << "(context);\n"; - source << "\tfloat mixValue = clamp(gBypass > 0.5 ? 0.0 : gMixAmount, 0.0, 1.0);\n"; - source << "\treturn lerp(context.sourceColor, effectedColor, mixValue);\n"; - source << "}\n"; - return source.str(); + const std::filesystem::path templatePath = mRepoRoot / "runtime" / "templates" / "shader_wrapper.slang.in"; + wrapperSource = ReadTextFile(templatePath, error); + if (wrapperSource.empty()) + return false; + + wrapperSource = ReplaceAll(wrapperSource, "{{PARAMETER_UNIFORMS}}", BuildParameterUniforms(shaderPackage.parameters)); + wrapperSource = ReplaceAll(wrapperSource, "{{SOURCE_HISTORY_SAMPLERS}}", BuildHistorySamplerDeclarations("gSourceHistory", mMaxTemporalHistoryFrames)); + wrapperSource = ReplaceAll(wrapperSource, "{{TEMPORAL_HISTORY_SAMPLERS}}", BuildHistorySamplerDeclarations("gTemporalHistory", mMaxTemporalHistoryFrames)); + wrapperSource = ReplaceAll(wrapperSource, "{{TEXTURE_SAMPLERS}}", BuildTextureSamplerDeclarations(shaderPackage.textureAssets)); + wrapperSource = ReplaceAll(wrapperSource, "{{SOURCE_HISTORY_SWITCH_CASES}}", BuildHistorySwitchCases("gSourceHistory", mMaxTemporalHistoryFrames)); + wrapperSource = ReplaceAll(wrapperSource, "{{TEMPORAL_HISTORY_SWITCH_CASES}}", BuildHistorySwitchCases("gTemporalHistory", mMaxTemporalHistoryFrames)); + wrapperSource = ReplaceAll(wrapperSource, "{{USER_SHADER_INCLUDE}}", shaderPackage.shaderPath.generic_string()); + wrapperSource = ReplaceAll(wrapperSource, "{{ENTRY_POINT_CALL}}", shaderPackage.entryPoint + "(context)"); + return true; } bool ShaderCompiler::FindSlangCompiler(std::filesystem::path& compilerPath, std::string& error) const @@ -216,12 +171,13 @@ bool ShaderCompiler::RunSlangCompiler(const std::filesystem::path& wrapperPath, return false; } - WaitForSingleObject(processInfo.hProcess, INFINITE); + UniqueHandle processHandle(processInfo.hProcess); + UniqueHandle threadHandle(processInfo.hThread); + + WaitForSingleObject(processHandle.get(), INFINITE); DWORD exitCode = 0; - GetExitCodeProcess(processInfo.hProcess, &exitCode); - CloseHandle(processInfo.hThread); - CloseHandle(processInfo.hProcess); + GetExitCodeProcess(processHandle.get(), &exitCode); if (exitCode != 0) { diff --git a/apps/LoopThroughWithOpenGLCompositing/ShaderCompiler.h b/apps/LoopThroughWithOpenGLCompositing/ShaderCompiler.h index b93214f..2e4d8eb 100644 --- a/apps/LoopThroughWithOpenGLCompositing/ShaderCompiler.h +++ b/apps/LoopThroughWithOpenGLCompositing/ShaderCompiler.h @@ -18,7 +18,7 @@ public: bool BuildLayerFragmentShaderSource(const ShaderPackage& shaderPackage, std::string& fragmentShaderSource, std::string& error) const; private: - std::string BuildWrapperSlangSource(const ShaderPackage& shaderPackage) const; + bool BuildWrapperSlangSource(const ShaderPackage& shaderPackage, std::string& wrapperSource, std::string& error) const; bool FindSlangCompiler(std::filesystem::path& compilerPath, std::string& error) const; bool RunSlangCompiler(const std::filesystem::path& wrapperPath, const std::filesystem::path& outputPath, std::string& error) const; bool PatchGeneratedGlsl(std::string& shaderText, std::string& error) const; diff --git a/apps/LoopThroughWithOpenGLCompositing/VideoFrameTransfer.cpp b/apps/LoopThroughWithOpenGLCompositing/VideoFrameTransfer.cpp index 94f79c1..697099f 100644 --- a/apps/LoopThroughWithOpenGLCompositing/VideoFrameTransfer.cpp +++ b/apps/LoopThroughWithOpenGLCompositing/VideoFrameTransfer.cpp @@ -39,6 +39,7 @@ */ #include "VideoFrameTransfer.h" +#include "NativeHandles.h" #define DVP_CHECK(cmd) { \ @@ -140,20 +141,19 @@ bool VideoFrameTransfer::initializeMemoryLocking(unsigned memSize) { // Increase the process working set size to allow pinning of memory. static SIZE_T dwMin = 0, dwMax = 0; - HANDLE hProcess = OpenProcess(PROCESS_QUERY_INFORMATION | PROCESS_SET_QUOTA, FALSE, GetCurrentProcessId()); - if (!hProcess) + UniqueHandle processHandle(OpenProcess(PROCESS_QUERY_INFORMATION | PROCESS_SET_QUOTA, FALSE, GetCurrentProcessId())); + if (!processHandle.valid()) return false; // Retrieve the working set size of the process. - if (!dwMin && !GetProcessWorkingSetSize(hProcess, &dwMin, &dwMax)) + if (!dwMin && !GetProcessWorkingSetSize(processHandle.get(), &dwMin, &dwMax)) return false; // Allow for 80 frames to be locked - BOOL res = SetProcessWorkingSetSize(hProcess, memSize * 80 + dwMin, memSize * 80 + (dwMax-dwMin)); + BOOL res = SetProcessWorkingSetSize(processHandle.get(), memSize * 80 + dwMin, memSize * 80 + (dwMax-dwMin)); if (!res) return false; - CloseHandle(hProcess); return true; } diff --git a/runtime/templates/shader_wrapper.slang.in b/runtime/templates/shader_wrapper.slang.in new file mode 100644 index 0000000..c6a6557 --- /dev/null +++ b/runtime/templates/shader_wrapper.slang.in @@ -0,0 +1,89 @@ +struct FragmentInput +{ + float4 position : SV_Position; + float2 texCoord : TEXCOORD0; +}; + +struct ShaderContext +{ + float2 uv; + float4 sourceColor; + float2 inputResolution; + float2 outputResolution; + float time; + float frameCount; + float mixAmount; + float bypass; + int sourceHistoryLength; + int temporalHistoryLength; +}; + +cbuffer GlobalParams +{ + float gTime; + float2 gInputResolution; + float2 gOutputResolution; + float gFrameCount; + float gMixAmount; + float gBypass; + int gSourceHistoryLength; + int gTemporalHistoryLength; +{{PARAMETER_UNIFORMS}}}; + +Sampler2D gVideoInput; +{{SOURCE_HISTORY_SAMPLERS}}{{TEMPORAL_HISTORY_SAMPLERS}}{{TEXTURE_SAMPLERS}} +float4 sampleVideo(float2 tc) +{ + return gVideoInput.Sample(tc); +} + +float4 sampleSourceHistory(int framesAgo, float2 tc) +{ + if (gSourceHistoryLength <= 0) + return sampleVideo(tc); + int clampedIndex = framesAgo; + if (clampedIndex < 0) + clampedIndex = 0; + if (clampedIndex >= gSourceHistoryLength) + clampedIndex = gSourceHistoryLength - 1; + switch (clampedIndex) + { +{{SOURCE_HISTORY_SWITCH_CASES}} default: return sampleVideo(tc); + } +} + +float4 sampleTemporalHistory(int framesAgo, float2 tc) +{ + if (gTemporalHistoryLength <= 0) + return sampleVideo(tc); + int clampedIndex = framesAgo; + if (clampedIndex < 0) + clampedIndex = 0; + if (clampedIndex >= gTemporalHistoryLength) + clampedIndex = gTemporalHistoryLength - 1; + switch (clampedIndex) + { +{{TEMPORAL_HISTORY_SWITCH_CASES}} default: return sampleVideo(tc); + } +} + +#include "{{USER_SHADER_INCLUDE}}" + +[shader("fragment")] +float4 fragmentMain(FragmentInput input) : SV_Target +{ + ShaderContext context; + context.uv = input.texCoord; + context.sourceColor = sampleVideo(context.uv); + context.inputResolution = gInputResolution; + context.outputResolution = gOutputResolution; + context.time = gTime; + context.frameCount = gFrameCount; + context.mixAmount = gMixAmount; + context.bypass = gBypass; + context.sourceHistoryLength = gSourceHistoryLength; + context.temporalHistoryLength = gTemporalHistoryLength; + float4 effectedColor = {{ENTRY_POINT_CALL}}; + float mixValue = clamp(gBypass > 0.5 ? 0.0 : gMixAmount, 0.0, 1.0); + return lerp(context.sourceColor, effectedColor, mixValue); +} diff --git a/tests/RuntimeJsonTests.cpp b/tests/RuntimeJsonTests.cpp new file mode 100644 index 0000000..b87bfbe --- /dev/null +++ b/tests/RuntimeJsonTests.cpp @@ -0,0 +1,119 @@ +#include "RuntimeJson.h" + +#include +#include +#include + +namespace +{ +int gFailures = 0; + +void Expect(bool condition, const char* message) +{ + if (condition) + return; + + std::cerr << "FAIL: " << message << "\n"; + ++gFailures; +} + +JsonValue ParseOrFail(const std::string& text, const char* message) +{ + JsonValue value; + std::string error; + if (!ParseJson(text, value, error)) + { + std::cerr << "FAIL: " << message << ": " << error << "\n"; + ++gFailures; + } + return value; +} + +void ExpectParseFails(const std::string& text, const char* message) +{ + JsonValue value; + std::string error; + Expect(!ParseJson(text, value, error), message); + Expect(!error.empty(), "failed parse should include an error message"); +} + +void TestStringsAndEscaping() +{ + const JsonValue value = ParseOrFail(R"({"text":"line\nquote:\" slash:\\ tab:\t"})", "escaped string parses"); + const JsonValue* text = value.find("text"); + Expect(text && text->isString(), "escaped string field is present"); + Expect(text && text->asString().find("line\nquote:\" slash:\\ tab:\t") != std::string::npos, "escaped string unescapes expected characters"); + + JsonValue controlString(std::string("a\001b", 3)); + Expect(SerializeJson(controlString, false) == R"("a\u0001b")", "serializer escapes raw control characters"); + ExpectParseFails(std::string("{\"bad\":\"a") + static_cast(1) + "b\"}", "raw control characters are rejected"); +} + +void TestUnicodeEscapes() +{ + const JsonValue copyright = ParseOrFail(R"({"s":"\u00A9"})", "basic unicode escape parses"); + Expect(copyright.find("s") && copyright.find("s")->asString() == "\xC2\xA9", "basic unicode escape becomes UTF-8"); + + const JsonValue music = ParseOrFail(R"({"s":"\uD834\uDD1E"})", "surrogate pair parses"); + Expect(music.find("s") && music.find("s")->asString() == "\xF0\x9D\x84\x9E", "surrogate pair becomes UTF-8"); + + ExpectParseFails(R"({"s":"\uD834x"})", "unpaired high surrogate is rejected"); + ExpectParseFails(R"({"s":"\uDD1E"})", "unpaired low surrogate is rejected"); + ExpectParseFails(R"({"s":"\u00XZ"})", "invalid unicode hex is rejected"); +} + +void TestNumbers() +{ + const JsonValue value = ParseOrFail(R"({"a":0,"b":-12.5e+2,"c":3.25})", "valid numbers parse"); + Expect(value.find("a") && value.find("a")->asNumber(-1.0) == 0.0, "zero parses"); + Expect(value.find("b") && std::fabs(value.find("b")->asNumber() + 1250.0) < 0.0001, "exponent parses"); + Expect(value.find("c") && std::fabs(value.find("c")->asNumber() - 3.25) < 0.0001, "fraction parses"); + + ExpectParseFails("01", "leading zero is rejected"); + ExpectParseFails("1.", "fraction without digits is rejected"); + ExpectParseFails("1e", "exponent without digits is rejected"); + ExpectParseFails("1e9999", "non-finite number is rejected"); +} + +void TestRoundtripAndMutation() +{ + JsonValue root = JsonValue::MakeObject(); + root.set("version", JsonValue(1.0)); + root.set("name", JsonValue("Preset One")); + + JsonValue layers = JsonValue::MakeArray(); + JsonValue layer = JsonValue::MakeObject(); + layer.set("id", JsonValue("layer-a")); + layer.set("bypass", JsonValue(false)); + layers.pushBack(layer); + root.set("layers", layers); + + const std::string serialized = SerializeJson(root, true); + const JsonValue parsed = ParseOrFail(serialized, "serialized preset-style object parses"); + Expect(parsed.find("layers") && parsed.find("layers")->isArray(), "roundtripped layers array exists"); + + JsonValue mutableValue = JsonValue::MakeObject(); + mutableValue.set("stale", JsonValue("value")); + mutableValue.pushBack(JsonValue(7.0)); + Expect(mutableValue.isArray(), "pushBack changes object into array"); + Expect(mutableValue.asObject().empty(), "object storage is cleared when changing to array"); + Expect(mutableValue.asArray().size() == 1, "new array receives pushed value"); +} +} + +int main() +{ + TestStringsAndEscaping(); + TestUnicodeEscapes(); + TestNumbers(); + TestRoundtripAndMutation(); + + if (gFailures != 0) + { + std::cerr << gFailures << " RuntimeJson test failure(s).\n"; + return 1; + } + + std::cout << "RuntimeJson tests passed.\n"; + return 0; +}