This commit is contained in:
2026-05-03 11:16:56 +10:00
parent f6db9ee3e6
commit ee929374a8
14 changed files with 873 additions and 350 deletions

4
.gitignore vendored
View File

@@ -41,6 +41,8 @@ build.ninja
*.log *.log
*.dmp *.dmp
*.tmp *.tmp
/runtime/ /runtime/*
!/runtime/templates/
!/runtime/templates/**
/ui/node_modules/ /ui/node_modules/
/ui/dist/ /ui/dist/

View File

@@ -26,6 +26,8 @@ set(APP_SOURCES
"${APP_DIR}/LoopThroughWithOpenGLCompositing.cpp" "${APP_DIR}/LoopThroughWithOpenGLCompositing.cpp"
"${APP_DIR}/LoopThroughWithOpenGLCompositing.h" "${APP_DIR}/LoopThroughWithOpenGLCompositing.h"
"${APP_DIR}/LoopThroughWithOpenGLCompositing.rc" "${APP_DIR}/LoopThroughWithOpenGLCompositing.rc"
"${APP_DIR}/NativeHandles.h"
"${APP_DIR}/NativeSockets.h"
"${APP_DIR}/OpenGLComposite.cpp" "${APP_DIR}/OpenGLComposite.cpp"
"${APP_DIR}/OpenGLComposite.h" "${APP_DIR}/OpenGLComposite.h"
"${APP_DIR}/resource.h" "${APP_DIR}/resource.h"
@@ -72,6 +74,22 @@ if(MSVC)
target_compile_options(LoopThroughWithOpenGLCompositing PRIVATE /W3) target_compile_options(LoopThroughWithOpenGLCompositing PRIVATE /W3)
endif() endif()
add_executable(RuntimeJsonTests
"${APP_DIR}/RuntimeJson.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/tests/RuntimeJsonTests.cpp"
)
target_include_directories(RuntimeJsonTests PRIVATE
"${APP_DIR}"
)
if(MSVC)
target_compile_options(RuntimeJsonTests PRIVATE /W3)
endif()
enable_testing()
add_test(NAME RuntimeJsonTests COMMAND RuntimeJsonTests)
add_custom_command(TARGET LoopThroughWithOpenGLCompositing POST_BUILD add_custom_command(TARGET LoopThroughWithOpenGLCompositing POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy_if_different COMMAND ${CMAKE_COMMAND} -E copy_if_different
"${GPUDIRECT_DIR}/bin/x64/dvp.dll" "${GPUDIRECT_DIR}/bin/x64/dvp.dll"

View File

@@ -69,7 +69,7 @@ std::string GuessContentType(const std::filesystem::path& assetPath)
} }
ControlServer::ControlServer() ControlServer::ControlServer()
: mListenSocket(INVALID_SOCKET), mPort(0), mRunning(false) : mPort(0), mRunning(false)
{ {
} }
@@ -86,15 +86,15 @@ bool ControlServer::Start(const std::filesystem::path& uiRoot, unsigned short pr
if (!InitializeWinsock(error)) if (!InitializeWinsock(error))
return false; return false;
mListenSocket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); mListenSocket.reset(socket(AF_INET, SOCK_STREAM, IPPROTO_TCP));
if (mListenSocket == INVALID_SOCKET) if (!mListenSocket.valid())
{ {
error = "Could not create listening socket."; error = "Could not create listening socket.";
return false; return false;
} }
u_long nonBlocking = 1; u_long nonBlocking = 1;
ioctlsocket(mListenSocket, FIONBIO, &nonBlocking); ioctlsocket(mListenSocket.get(), FIONBIO, &nonBlocking);
sockaddr_in address = {}; sockaddr_in address = {};
address.sin_family = AF_INET; address.sin_family = AF_INET;
@@ -104,7 +104,7 @@ bool ControlServer::Start(const std::filesystem::path& uiRoot, unsigned short pr
for (unsigned short offset = 0; offset < 20; ++offset) for (unsigned short offset = 0; offset < 20; ++offset)
{ {
address.sin_port = htons(static_cast<u_short>(preferredPort + offset)); address.sin_port = htons(static_cast<u_short>(preferredPort + offset));
if (bind(mListenSocket, reinterpret_cast<sockaddr*>(&address), sizeof(address)) == 0) if (bind(mListenSocket.get(), reinterpret_cast<sockaddr*>(&address), sizeof(address)) == 0)
{ {
mPort = preferredPort + offset; mPort = preferredPort + offset;
bound = true; bound = true;
@@ -115,16 +115,14 @@ bool ControlServer::Start(const std::filesystem::path& uiRoot, unsigned short pr
if (!bound) if (!bound)
{ {
error = "Could not bind the local control server to any port in the preferred range."; error = "Could not bind the local control server to any port in the preferred range.";
closesocket(mListenSocket); mListenSocket.reset();
mListenSocket = INVALID_SOCKET;
return false; return false;
} }
if (listen(mListenSocket, SOMAXCONN) != 0) if (listen(mListenSocket.get(), SOMAXCONN) != 0)
{ {
error = "Could not start listening on the local control server socket."; error = "Could not start listening on the local control server socket.";
closesocket(mListenSocket); mListenSocket.reset();
mListenSocket = INVALID_SOCKET;
return false; return false;
} }
@@ -135,27 +133,17 @@ bool ControlServer::Start(const std::filesystem::path& uiRoot, unsigned short pr
void ControlServer::Stop() void ControlServer::Stop()
{ {
const bool wasActive = mRunning || mListenSocket != INVALID_SOCKET || mThread.joinable(); const bool wasActive = mRunning || mListenSocket.valid() || mThread.joinable();
mRunning = false; mRunning = false;
{ {
std::lock_guard<std::mutex> lock(mMutex); std::lock_guard<std::mutex> lock(mMutex);
for (ClientConnection& client : mClients) for (ClientConnection& client : mClients)
{ client.socket.reset();
if (client.socket != INVALID_SOCKET)
{
closesocket(client.socket);
client.socket = INVALID_SOCKET;
}
}
mClients.clear(); mClients.clear();
} }
if (mListenSocket != INVALID_SOCKET) mListenSocket.reset();
{
closesocket(mListenSocket);
mListenSocket = INVALID_SOCKET;
}
if (mThread.joinable()) if (mThread.joinable())
mThread.join(); mThread.join();
@@ -179,30 +167,27 @@ void ControlServer::ServerLoop()
} }
} }
bool ControlServer::HandleHttpClient(SOCKET clientSocket) bool ControlServer::HandleHttpClient(UniqueSocket clientSocket)
{ {
std::string request; std::string request;
char buffer[8192]; char buffer[8192];
int received = recv(clientSocket, buffer, sizeof(buffer), 0); int received = recv(clientSocket.get(), buffer, sizeof(buffer), 0);
if (received <= 0) if (received <= 0)
return false; return false;
request.assign(buffer, buffer + received); request.assign(buffer, buffer + received);
return HandleHttpRequest(clientSocket, request); return HandleHttpRequest(std::move(clientSocket), request);
} }
bool ControlServer::TryAcceptClient() bool ControlServer::TryAcceptClient()
{ {
sockaddr_in clientAddress = {}; sockaddr_in clientAddress = {};
int addressSize = sizeof(clientAddress); int addressSize = sizeof(clientAddress);
SOCKET clientSocket = accept(mListenSocket, reinterpret_cast<sockaddr*>(&clientAddress), &addressSize); UniqueSocket clientSocket(accept(mListenSocket.get(), reinterpret_cast<sockaddr*>(&clientAddress), &addressSize));
if (clientSocket == INVALID_SOCKET) if (!clientSocket.valid())
return false; return false;
bool handled = HandleHttpClient(clientSocket); return HandleHttpClient(std::move(clientSocket));
if (!handled)
closesocket(clientSocket);
return handled;
} }
bool ControlServer::SendHttpResponse(SOCKET clientSocket, const std::string& status, const std::string& contentType, const std::string& body) bool ControlServer::SendHttpResponse(SOCKET clientSocket, const std::string& status, const std::string& contentType, const std::string& body)
@@ -218,144 +203,181 @@ bool ControlServer::SendHttpResponse(SOCKET clientSocket, const std::string& sta
return send(clientSocket, payload.c_str(), static_cast<int>(payload.size()), 0) == static_cast<int>(payload.size()); return send(clientSocket, payload.c_str(), static_cast<int>(payload.size()), 0) == static_cast<int>(payload.size());
} }
bool ControlServer::HandleHttpRequest(SOCKET clientSocket, const std::string& request) bool ControlServer::SendHttpResponse(SOCKET clientSocket, const HttpResponse& response)
{ {
const std::string method = GetRequestMethod(request); return SendHttpResponse(clientSocket, response.status, response.contentType, response.body);
const std::string path = GetRequestPath(request);
if (ToLower(GetHeaderValue(request, "Upgrade")) == "websocket")
return HandleWebSocketUpgrade(clientSocket, request);
if (method == "GET")
{
if (path == "/" || path == "/index.html")
{
std::string contentType;
std::string body = LoadUiAsset("index.html", contentType);
SendHttpResponse(clientSocket, "200 OK", contentType, body);
closesocket(clientSocket);
return true;
} }
if (path == "/api/state")
bool ControlServer::HandleHttpRequest(UniqueSocket clientSocket, const std::string& request)
{ {
SendHttpResponse(clientSocket, "200 OK", "application/json", mCallbacks.getStateJson ? mCallbacks.getStateJson() : "{}"); HttpRequest httpRequest;
closesocket(clientSocket); if (!ParseHttpRequest(request, httpRequest))
{
SendHttpResponse(clientSocket.get(), "400 Bad Request", "text/plain", "Bad Request");
return true; return true;
} }
std::string contentType; if (ToLower(GetHeaderValue(httpRequest, "Upgrade")) == "websocket")
std::string body = LoadUiAsset(path.substr(1), contentType); return HandleWebSocketUpgrade(std::move(clientSocket), httpRequest);
if (!body.empty())
{
SendHttpResponse(clientSocket, "200 OK", contentType, body);
closesocket(clientSocket);
return true;
}
}
else if (method == "POST")
{
std::string body = GetRequestBody(request);
JsonValue root;
std::string parseError;
if (!ParseJson(body, root, parseError))
{
SendHttpResponse(clientSocket, "400 Bad Request", "application/json", BuildJsonResponse(false, parseError));
closesocket(clientSocket);
return true;
}
bool success = false; const HttpResponse response = RouteHttpRequest(httpRequest);
std::string actionError; SendHttpResponse(clientSocket.get(), response);
if (response.broadcastState)
if (path == "/api/layers/add")
{
const JsonValue* shaderId = root.find("shaderId");
success = shaderId && mCallbacks.addLayer && mCallbacks.addLayer(shaderId->asString(), actionError);
}
else if (path == "/api/layers/remove")
{
const JsonValue* layerId = root.find("layerId");
success = layerId && mCallbacks.removeLayer && mCallbacks.removeLayer(layerId->asString(), actionError);
}
else if (path == "/api/layers/move")
{
const JsonValue* layerId = root.find("layerId");
const JsonValue* direction = root.find("direction");
if (layerId && direction && mCallbacks.moveLayer)
success = mCallbacks.moveLayer(layerId->asString(), static_cast<int>(direction->asNumber()), actionError);
}
else if (path == "/api/layers/reorder")
{
const JsonValue* layerId = root.find("layerId");
const JsonValue* targetIndex = root.find("targetIndex");
if (layerId && targetIndex && mCallbacks.moveLayerToIndex)
success = mCallbacks.moveLayerToIndex(layerId->asString(), static_cast<std::size_t>(targetIndex->asNumber()), actionError);
}
else if (path == "/api/layers/set-bypass")
{
const JsonValue* layerId = root.find("layerId");
const JsonValue* bypass = root.find("bypass");
if (layerId && bypass && mCallbacks.setLayerBypass)
success = mCallbacks.setLayerBypass(layerId->asString(), bypass->asBoolean(), actionError);
}
else if (path == "/api/layers/set-shader")
{
const JsonValue* layerId = root.find("layerId");
const JsonValue* shaderId = root.find("shaderId");
if (layerId && shaderId && mCallbacks.setLayerShader)
success = mCallbacks.setLayerShader(layerId->asString(), shaderId->asString(), actionError);
}
else if (path == "/api/layers/update-parameter")
{
const JsonValue* layerId = root.find("layerId");
const JsonValue* parameterId = root.find("parameterId");
const JsonValue* value = root.find("value");
if (layerId && parameterId && value && mCallbacks.updateLayerParameter)
success = mCallbacks.updateLayerParameter(layerId->asString(), parameterId->asString(), SerializeJson(*value, false), actionError);
}
else if (path == "/api/layers/reset-parameters")
{
const JsonValue* layerId = root.find("layerId");
if (layerId && mCallbacks.resetLayerParameters)
success = mCallbacks.resetLayerParameters(layerId->asString(), actionError);
}
else if (path == "/api/stack-presets/save")
{
const JsonValue* presetName = root.find("presetName");
if (presetName && mCallbacks.saveStackPreset)
success = mCallbacks.saveStackPreset(presetName->asString(), actionError);
}
else if (path == "/api/stack-presets/load")
{
const JsonValue* presetName = root.find("presetName");
if (presetName && mCallbacks.loadStackPreset)
success = mCallbacks.loadStackPreset(presetName->asString(), actionError);
}
else if (path == "/api/reload")
{
if (mCallbacks.reloadShader)
success = mCallbacks.reloadShader(actionError);
}
SendHttpResponse(clientSocket, success ? "200 OK" : "400 Bad Request", "application/json", BuildJsonResponse(success, actionError));
closesocket(clientSocket);
if (success)
BroadcastState(); BroadcastState();
return true; return true;
} }
SendHttpResponse(clientSocket, "404 Not Found", "text/plain", "Not Found"); ControlServer::HttpResponse ControlServer::RouteHttpRequest(const HttpRequest& request)
closesocket(clientSocket); {
return true; if (request.method == "GET")
return ServeGetRequest(request);
if (request.method == "POST")
return HandleApiPost(request);
return { "404 Not Found", "text/plain", "Not Found" };
} }
bool ControlServer::HandleWebSocketUpgrade(SOCKET clientSocket, const std::string& request) ControlServer::HttpResponse ControlServer::ServeGetRequest(const HttpRequest& request) const
{
if (request.path == "/" || request.path == "/index.html")
return ServeUiAsset("index.html");
if (request.path == "/api/state")
return { "200 OK", "application/json", mCallbacks.getStateJson ? mCallbacks.getStateJson() : "{}" };
if (request.path.size() > 1)
{
const HttpResponse assetResponse = ServeUiAsset(request.path.substr(1));
if (!assetResponse.body.empty())
return assetResponse;
}
return { "404 Not Found", "text/plain", "Not Found" };
}
ControlServer::HttpResponse ControlServer::ServeUiAsset(const std::string& relativePath) const
{
std::string contentType;
const std::string body = LoadUiAsset(relativePath, contentType);
return body.empty()
? HttpResponse{ "404 Not Found", "text/plain", "Not Found" }
: HttpResponse{ "200 OK", contentType, body };
}
ControlServer::HttpResponse ControlServer::HandleApiPost(const HttpRequest& request)
{
JsonValue root;
std::string parseError;
if (!ParseJson(request.body, root, parseError))
return { "400 Bad Request", "application/json", BuildJsonResponse(false, parseError) };
std::string actionError;
const bool success = InvokePostRoute(request.path, root, actionError);
return {
success ? "200 OK" : "400 Bad Request",
"application/json",
BuildJsonResponse(success, actionError),
success
};
}
bool ControlServer::InvokePostRoute(const std::string& path, const JsonValue& root, std::string& actionError)
{
using PostHandler = std::function<bool(const JsonValue&, std::string&)>;
const std::map<std::string, PostHandler> postRoutes =
{
{ "/api/layers/add", [this](const JsonValue& json, std::string& error)
{
const JsonValue* shaderId = json.find("shaderId");
return shaderId && mCallbacks.addLayer && mCallbacks.addLayer(shaderId->asString(), error);
}
},
{ "/api/layers/remove", [this](const JsonValue& json, std::string& error)
{
const JsonValue* layerId = json.find("layerId");
return layerId && mCallbacks.removeLayer && mCallbacks.removeLayer(layerId->asString(), error);
}
},
{ "/api/layers/move", [this](const JsonValue& json, std::string& error)
{
const JsonValue* layerId = json.find("layerId");
const JsonValue* direction = json.find("direction");
return layerId && direction && mCallbacks.moveLayer &&
mCallbacks.moveLayer(layerId->asString(), static_cast<int>(direction->asNumber()), error);
}
},
{ "/api/layers/reorder", [this](const JsonValue& json, std::string& error)
{
const JsonValue* layerId = json.find("layerId");
const JsonValue* targetIndex = json.find("targetIndex");
return layerId && targetIndex && mCallbacks.moveLayerToIndex &&
mCallbacks.moveLayerToIndex(layerId->asString(), static_cast<std::size_t>(targetIndex->asNumber()), error);
}
},
{ "/api/layers/set-bypass", [this](const JsonValue& json, std::string& error)
{
const JsonValue* layerId = json.find("layerId");
const JsonValue* bypass = json.find("bypass");
return layerId && bypass && mCallbacks.setLayerBypass &&
mCallbacks.setLayerBypass(layerId->asString(), bypass->asBoolean(), error);
}
},
{ "/api/layers/set-shader", [this](const JsonValue& json, std::string& error)
{
const JsonValue* layerId = json.find("layerId");
const JsonValue* shaderId = json.find("shaderId");
return layerId && shaderId && mCallbacks.setLayerShader &&
mCallbacks.setLayerShader(layerId->asString(), shaderId->asString(), error);
}
},
{ "/api/layers/update-parameter", [this](const JsonValue& json, std::string& error)
{
const JsonValue* layerId = json.find("layerId");
const JsonValue* parameterId = json.find("parameterId");
const JsonValue* value = json.find("value");
return layerId && parameterId && value && mCallbacks.updateLayerParameter &&
mCallbacks.updateLayerParameter(layerId->asString(), parameterId->asString(), SerializeJson(*value, false), error);
}
},
{ "/api/layers/reset-parameters", [this](const JsonValue& json, std::string& error)
{
const JsonValue* layerId = json.find("layerId");
return layerId && mCallbacks.resetLayerParameters &&
mCallbacks.resetLayerParameters(layerId->asString(), error);
}
},
{ "/api/stack-presets/save", [this](const JsonValue& json, std::string& error)
{
const JsonValue* presetName = json.find("presetName");
return presetName && mCallbacks.saveStackPreset &&
mCallbacks.saveStackPreset(presetName->asString(), error);
}
},
{ "/api/stack-presets/load", [this](const JsonValue& json, std::string& error)
{
const JsonValue* presetName = json.find("presetName");
return presetName && mCallbacks.loadStackPreset &&
mCallbacks.loadStackPreset(presetName->asString(), error);
}
},
{ "/api/reload", [this](const JsonValue&, std::string& error)
{
return mCallbacks.reloadShader && mCallbacks.reloadShader(error);
}
}
};
const auto route = postRoutes.find(path);
return route != postRoutes.end() && route->second(root, actionError);
}
bool ControlServer::HandleWebSocketUpgrade(UniqueSocket clientSocket, const HttpRequest& request)
{ {
const std::string clientKey = GetHeaderValue(request, "Sec-WebSocket-Key"); const std::string clientKey = GetHeaderValue(request, "Sec-WebSocket-Key");
if (clientKey.empty()) if (clientKey.empty())
{ {
SendHttpResponse(clientSocket, "400 Bad Request", "text/plain", "Missing Sec-WebSocket-Key"); SendHttpResponse(clientSocket.get(), "400 Bad Request", "text/plain", "Missing Sec-WebSocket-Key");
closesocket(clientSocket);
return true; return true;
} }
@@ -366,14 +388,14 @@ bool ControlServer::HandleWebSocketUpgrade(SOCKET clientSocket, const std::strin
response << "Sec-WebSocket-Accept: " << ComputeWebSocketAcceptKey(clientKey) << "\r\n\r\n"; response << "Sec-WebSocket-Accept: " << ComputeWebSocketAcceptKey(clientKey) << "\r\n\r\n";
const std::string payload = response.str(); const std::string payload = response.str();
send(clientSocket, payload.c_str(), static_cast<int>(payload.size()), 0); send(clientSocket.get(), payload.c_str(), static_cast<int>(payload.size()), 0);
{ {
std::lock_guard<std::mutex> lock(mMutex); std::lock_guard<std::mutex> lock(mMutex);
ClientConnection client; ClientConnection client;
client.socket = clientSocket; client.socket.reset(clientSocket.release());
client.websocket = true; client.websocket = true;
mClients.push_back(client); mClients.push_back(std::move(client));
BroadcastStateLocked(); BroadcastStateLocked();
} }
return true; return true;
@@ -409,9 +431,8 @@ void ControlServer::BroadcastStateLocked()
const std::string stateMessage = mCallbacks.getStateJson ? mCallbacks.getStateJson() : "{}"; const std::string stateMessage = mCallbacks.getStateJson ? mCallbacks.getStateJson() : "{}";
for (auto it = mClients.begin(); it != mClients.end();) for (auto it = mClients.begin(); it != mClients.end();)
{ {
if (!SendWebSocketText(it->socket, stateMessage)) if (!SendWebSocketText(it->socket.get(), stateMessage))
{ {
closesocket(it->socket);
it = mClients.erase(it); it = mClients.erase(it);
} }
else else
@@ -480,46 +501,55 @@ std::string ControlServer::ComputeWebSocketAcceptKey(const std::string& clientKe
return Base64Encode(digest, digestLength); return Base64Encode(digest, digestLength);
} }
std::string ControlServer::GetHeaderValue(const std::string& request, const std::string& headerName) std::string ControlServer::GetHeaderValue(const HttpRequest& request, const std::string& headerName)
{ {
const std::string lowerRequest = ToLower(request); const auto header = request.headers.find(ToLower(headerName));
const std::string lowerHeaderName = ToLower(headerName) + ":"; return header == request.headers.end() ? std::string() : header->second;
const std::size_t start = lowerRequest.find(lowerHeaderName); }
if (start == std::string::npos)
return std::string();
const std::size_t valueStart = start + lowerHeaderName.size(); bool ControlServer::ParseHttpRequest(const std::string& rawRequest, HttpRequest& request)
const std::size_t lineEnd = request.find("\r\n", valueStart); {
if (lineEnd == std::string::npos) const std::size_t requestLineEnd = rawRequest.find("\r\n");
return std::string(); if (requestLineEnd == std::string::npos)
return false;
std::string value = request.substr(valueStart, lineEnd - valueStart); const std::string requestLine = rawRequest.substr(0, requestLineEnd);
const std::size_t methodEnd = requestLine.find(' ');
if (methodEnd == std::string::npos)
return false;
const std::size_t pathEnd = requestLine.find(' ', methodEnd + 1);
if (pathEnd == std::string::npos)
return false;
request.method = requestLine.substr(0, methodEnd);
request.path = requestLine.substr(methodEnd + 1, pathEnd - methodEnd - 1);
request.headers.clear();
const std::size_t headersStart = requestLineEnd + 2;
const std::size_t bodySeparator = rawRequest.find("\r\n\r\n", headersStart);
const std::size_t headersEnd = bodySeparator == std::string::npos ? rawRequest.size() : bodySeparator;
for (std::size_t lineStart = headersStart; lineStart < headersEnd;)
{
const std::size_t lineEnd = rawRequest.find("\r\n", lineStart);
const std::size_t currentLineEnd = lineEnd == std::string::npos ? headersEnd : std::min(lineEnd, headersEnd);
const std::string line = rawRequest.substr(lineStart, currentLineEnd - lineStart);
const std::size_t separator = line.find(':');
if (separator != std::string::npos)
{
const std::string key = ToLower(line.substr(0, separator));
std::string value = line.substr(separator + 1);
const std::size_t first = value.find_first_not_of(" \t"); const std::size_t first = value.find_first_not_of(" \t");
const std::size_t last = value.find_last_not_of(" \t"); const std::size_t last = value.find_last_not_of(" \t");
return first == std::string::npos ? std::string() : value.substr(first, last - first + 1); request.headers[key] = first == std::string::npos ? std::string() : value.substr(first, last - first + 1);
} }
std::string ControlServer::GetRequestPath(const std::string& request) if (lineEnd == std::string::npos || lineEnd >= headersEnd)
{ break;
const std::size_t methodEnd = request.find(' '); lineStart = lineEnd + 2;
if (methodEnd == std::string::npos)
return "/";
const std::size_t pathEnd = request.find(' ', methodEnd + 1);
if (pathEnd == std::string::npos)
return "/";
return request.substr(methodEnd + 1, pathEnd - methodEnd - 1);
} }
std::string ControlServer::GetRequestMethod(const std::string& request) request.body = bodySeparator == std::string::npos ? std::string() : rawRequest.substr(bodySeparator + 4);
{ return !request.method.empty() && !request.path.empty();
const std::size_t methodEnd = request.find(' ');
return methodEnd == std::string::npos ? std::string() : request.substr(0, methodEnd);
}
std::string ControlServer::GetRequestBody(const std::string& request)
{
const std::size_t separator = request.find("\r\n\r\n");
if (separator == std::string::npos)
return std::string();
return request.substr(separator + 4);
} }

View File

@@ -1,15 +1,20 @@
#pragma once #pragma once
#include "NativeSockets.h"
#include <winsock2.h> #include <winsock2.h>
#include <atomic> #include <atomic>
#include <filesystem> #include <filesystem>
#include <functional> #include <functional>
#include <map>
#include <mutex> #include <mutex>
#include <string> #include <string>
#include <thread> #include <thread>
#include <vector> #include <vector>
class JsonValue;
class ControlServer class ControlServer
{ {
public: public:
@@ -41,31 +46,51 @@ public:
private: private:
struct ClientConnection struct ClientConnection
{ {
SOCKET socket = INVALID_SOCKET; UniqueSocket socket;
bool websocket = false; bool websocket = false;
}; };
struct HttpRequest
{
std::string method;
std::string path;
std::map<std::string, std::string> headers;
std::string body;
};
struct HttpResponse
{
std::string status;
std::string contentType;
std::string body;
bool broadcastState = false;
};
void ServerLoop(); void ServerLoop();
bool HandleHttpClient(SOCKET clientSocket); bool HandleHttpClient(UniqueSocket clientSocket);
bool TryAcceptClient(); bool TryAcceptClient();
bool SendHttpResponse(SOCKET clientSocket, const HttpResponse& response);
bool SendHttpResponse(SOCKET clientSocket, const std::string& status, const std::string& contentType, const std::string& body); bool SendHttpResponse(SOCKET clientSocket, const std::string& status, const std::string& contentType, const std::string& body);
bool HandleHttpRequest(SOCKET clientSocket, const std::string& request); bool HandleHttpRequest(UniqueSocket clientSocket, const std::string& request);
bool HandleWebSocketUpgrade(SOCKET clientSocket, const std::string& request); bool HandleWebSocketUpgrade(UniqueSocket clientSocket, const HttpRequest& request);
HttpResponse RouteHttpRequest(const HttpRequest& request);
HttpResponse ServeGetRequest(const HttpRequest& request) const;
HttpResponse ServeUiAsset(const std::string& relativePath) const;
HttpResponse HandleApiPost(const HttpRequest& request);
bool InvokePostRoute(const std::string& path, const JsonValue& root, std::string& actionError);
bool SendWebSocketText(SOCKET clientSocket, const std::string& payload); bool SendWebSocketText(SOCKET clientSocket, const std::string& payload);
void BroadcastStateLocked(); void BroadcastStateLocked();
std::string LoadUiAsset(const std::string& relativePath, std::string& contentType) const; std::string LoadUiAsset(const std::string& relativePath, std::string& contentType) const;
std::string BuildJsonResponse(bool success, const std::string& error = std::string()) const; std::string BuildJsonResponse(bool success, const std::string& error = std::string()) const;
static std::string Base64Encode(const unsigned char* data, DWORD dataLength); static std::string Base64Encode(const unsigned char* data, DWORD dataLength);
static std::string ComputeWebSocketAcceptKey(const std::string& clientKey); static std::string ComputeWebSocketAcceptKey(const std::string& clientKey);
static std::string GetHeaderValue(const std::string& request, const std::string& headerName); static std::string GetHeaderValue(const HttpRequest& request, const std::string& headerName);
static std::string GetRequestPath(const std::string& request); static bool ParseHttpRequest(const std::string& rawRequest, HttpRequest& request);
static std::string GetRequestMethod(const std::string& request);
static std::string GetRequestBody(const std::string& request);
private: private:
std::filesystem::path mUiRoot; std::filesystem::path mUiRoot;
Callbacks mCallbacks; Callbacks mCallbacks;
SOCKET mListenSocket; UniqueSocket mListenSocket;
unsigned short mPort; unsigned short mPort;
std::thread mThread; std::thread mThread;
std::atomic<bool> mRunning; std::atomic<bool> mRunning;

View File

@@ -0,0 +1,42 @@
#pragma once
#include <windows.h>
class UniqueHandle
{
public:
explicit UniqueHandle(HANDLE handle = NULL) : mHandle(handle) {}
~UniqueHandle() { reset(); }
UniqueHandle(const UniqueHandle&) = delete;
UniqueHandle& operator=(const UniqueHandle&) = delete;
UniqueHandle(UniqueHandle&& other) noexcept : mHandle(other.release()) {}
UniqueHandle& operator=(UniqueHandle&& other) noexcept
{
if (this != &other)
reset(other.release());
return *this;
}
HANDLE get() const { return mHandle; }
bool valid() const { return mHandle != NULL && mHandle != INVALID_HANDLE_VALUE; }
HANDLE release()
{
HANDLE handle = mHandle;
mHandle = NULL;
return handle;
}
void reset(HANDLE handle = NULL)
{
if (valid())
CloseHandle(mHandle);
mHandle = handle;
}
private:
HANDLE mHandle;
};

View File

@@ -0,0 +1,42 @@
#pragma once
#include <winsock2.h>
class UniqueSocket
{
public:
explicit UniqueSocket(SOCKET socket = INVALID_SOCKET) : mSocket(socket) {}
~UniqueSocket() { reset(); }
UniqueSocket(const UniqueSocket&) = delete;
UniqueSocket& operator=(const UniqueSocket&) = delete;
UniqueSocket(UniqueSocket&& other) noexcept : mSocket(other.release()) {}
UniqueSocket& operator=(UniqueSocket&& other) noexcept
{
if (this != &other)
reset(other.release());
return *this;
}
SOCKET get() const { return mSocket; }
bool valid() const { return mSocket != INVALID_SOCKET; }
SOCKET release()
{
SOCKET socket = mSocket;
mSocket = INVALID_SOCKET;
return socket;
}
void reset(SOCKET socket = INVALID_SOCKET)
{
if (valid())
closesocket(mSocket);
mSocket = socket;
}
private:
SOCKET mSocket;
};

View File

@@ -248,6 +248,47 @@ bool NumberListFromJsonValue(const JsonValue& value, std::vector<double>& number
return false; return false;
} }
bool ValidateShaderIdentifier(const std::string& identifier, const std::string& fieldName, const std::filesystem::path& manifestPath, std::string& error)
{
if (identifier.empty() || !(std::isalpha(static_cast<unsigned char>(identifier.front())) || identifier.front() == '_'))
{
error = "Shader manifest field '" + fieldName + "' must be a valid shader identifier in: " + ManifestPathMessage(manifestPath);
return false;
}
for (char ch : identifier)
{
const unsigned char unsignedCh = static_cast<unsigned char>(ch);
if (!(std::isalnum(unsignedCh) || ch == '_'))
{
error = "Shader manifest field '" + fieldName + "' must be a valid shader identifier in: " + ManifestPathMessage(manifestPath);
return false;
}
}
return true;
}
bool ParseShaderMetadata(const JsonValue& manifestJson, ShaderPackage& shaderPackage, const std::filesystem::path& manifestPath, std::string& error)
{
if (!RequireStringField(manifestJson, "id", shaderPackage.id, manifestPath, error) ||
!RequireStringField(manifestJson, "name", shaderPackage.displayName, manifestPath, error) ||
!OptionalStringField(manifestJson, "description", shaderPackage.description, "", manifestPath, error) ||
!OptionalStringField(manifestJson, "category", shaderPackage.category, "", manifestPath, error) ||
!OptionalStringField(manifestJson, "entryPoint", shaderPackage.entryPoint, "shadeVideo", manifestPath, error))
{
return false;
}
if (!ValidateShaderIdentifier(shaderPackage.entryPoint, "entryPoint", manifestPath, error))
return false;
shaderPackage.directoryPath = manifestPath.parent_path();
shaderPackage.shaderPath = shaderPackage.directoryPath / "shader.slang";
shaderPackage.manifestPath = manifestPath;
return true;
}
bool ParseTextureAssets(const JsonValue& manifestJson, ShaderPackage& shaderPackage, const std::filesystem::path& manifestPath, std::string& error) bool ParseTextureAssets(const JsonValue& manifestJson, ShaderPackage& shaderPackage, const std::filesystem::path& manifestPath, std::string& error)
{ {
const JsonValue* texturesValue = nullptr; const JsonValue* texturesValue = nullptr;
@@ -272,6 +313,8 @@ bool ParseTextureAssets(const JsonValue& manifestJson, ShaderPackage& shaderPack
error = "Shader texture is missing required 'id' or 'path' in: " + ManifestPathMessage(manifestPath); error = "Shader texture is missing required 'id' or 'path' in: " + ManifestPathMessage(manifestPath);
return false; return false;
} }
if (!ValidateShaderIdentifier(textureId, "textures[].id", manifestPath, error))
return false;
ShaderTextureAsset textureAsset; ShaderTextureAsset textureAsset;
textureAsset.id = textureId; textureAsset.id = textureId;
@@ -374,7 +417,7 @@ bool ParseParameterDefault(const JsonValue& parameterJson, ShaderParameterDefini
return NumberListFromJsonValue(*defaultValue, definition.defaultNumbers, "default", manifestPath, error); return NumberListFromJsonValue(*defaultValue, definition.defaultNumbers, "default", manifestPath, error);
} }
bool ParseEnumOptions(const JsonValue& parameterJson, ShaderParameterDefinition& definition, const std::filesystem::path& manifestPath, std::string& error) bool ParseParameterOptions(const JsonValue& parameterJson, ShaderParameterDefinition& definition, const std::filesystem::path& manifestPath, std::string& error)
{ {
const JsonValue* optionsValue = nullptr; const JsonValue* optionsValue = nullptr;
if (!OptionalArrayField(parameterJson, "options", optionsValue, manifestPath, error) || !optionsValue) if (!OptionalArrayField(parameterJson, "options", optionsValue, manifestPath, error) || !optionsValue)
@@ -442,6 +485,8 @@ bool ParseParameterDefinition(const JsonValue& parameterJson, ShaderParameterDef
error = "Unsupported parameter type '" + typeName + "' in: " + ManifestPathMessage(manifestPath); error = "Unsupported parameter type '" + typeName + "' in: " + ManifestPathMessage(manifestPath);
return false; return false;
} }
if (!ValidateShaderIdentifier(definition.id, "parameters[].id", manifestPath, error))
return false;
if (!ParseParameterDefault(parameterJson, definition, manifestPath, error) || if (!ParseParameterDefault(parameterJson, definition, manifestPath, error) ||
!ParseParameterNumberField(parameterJson, "min", definition.minNumbers, manifestPath, error) || !ParseParameterNumberField(parameterJson, "min", definition.minNumbers, manifestPath, error) ||
@@ -452,7 +497,7 @@ bool ParseParameterDefinition(const JsonValue& parameterJson, ShaderParameterDef
} }
if (definition.type == ShaderParameterType::Enum) if (definition.type == ShaderParameterType::Enum)
return ParseEnumOptions(parameterJson, definition, manifestPath, error); return ParseParameterOptions(parameterJson, definition, manifestPath, error);
return true; return true;
} }
@@ -1265,18 +1310,8 @@ bool RuntimeHost::ParseShaderManifest(const std::filesystem::path& manifestPath,
return false; return false;
} }
if (!RequireStringField(manifestJson, "id", shaderPackage.id, manifestPath, error) || if (!ParseShaderMetadata(manifestJson, shaderPackage, manifestPath, error))
!RequireStringField(manifestJson, "name", shaderPackage.displayName, manifestPath, error) ||
!OptionalStringField(manifestJson, "description", shaderPackage.description, "", manifestPath, error) ||
!OptionalStringField(manifestJson, "category", shaderPackage.category, "", manifestPath, error) ||
!OptionalStringField(manifestJson, "entryPoint", shaderPackage.entryPoint, "shadeVideo", manifestPath, error))
{
return false; return false;
}
shaderPackage.directoryPath = manifestPath.parent_path();
shaderPackage.shaderPath = shaderPackage.directoryPath / "shader.slang";
shaderPackage.manifestPath = manifestPath;
if (!std::filesystem::exists(shaderPackage.shaderPath)) if (!std::filesystem::exists(shaderPackage.shaderPath))
{ {

View File

@@ -1,6 +1,7 @@
#include "stdafx.h" #include "stdafx.h"
#include "RuntimeJson.h" #include "RuntimeJson.h"
#include <cerrno>
#include <cctype> #include <cctype>
#include <cmath> #include <cmath>
#include <cstring> #include <cstring>
@@ -10,6 +11,59 @@
namespace namespace
{ {
int HexDigitValue(char ch)
{
if (ch >= '0' && ch <= '9')
return ch - '0';
if (ch >= 'a' && ch <= 'f')
return ch - 'a' + 10;
if (ch >= 'A' && ch <= 'F')
return ch - 'A' + 10;
return -1;
}
bool IsHighSurrogate(unsigned int codePoint)
{
return codePoint >= 0xD800 && codePoint <= 0xDBFF;
}
bool IsLowSurrogate(unsigned int codePoint)
{
return codePoint >= 0xDC00 && codePoint <= 0xDFFF;
}
void AppendUtf8(unsigned int codePoint, std::ostringstream& output)
{
if (codePoint <= 0x7F)
{
output << static_cast<char>(codePoint);
}
else if (codePoint <= 0x7FF)
{
output << static_cast<char>(0xC0 | ((codePoint >> 6) & 0x1F));
output << static_cast<char>(0x80 | (codePoint & 0x3F));
}
else if (codePoint <= 0xFFFF)
{
output << static_cast<char>(0xE0 | ((codePoint >> 12) & 0x0F));
output << static_cast<char>(0x80 | ((codePoint >> 6) & 0x3F));
output << static_cast<char>(0x80 | (codePoint & 0x3F));
}
else
{
output << static_cast<char>(0xF0 | ((codePoint >> 18) & 0x07));
output << static_cast<char>(0x80 | ((codePoint >> 12) & 0x3F));
output << static_cast<char>(0x80 | ((codePoint >> 6) & 0x3F));
output << static_cast<char>(0x80 | (codePoint & 0x3F));
}
}
void AppendControlEscape(unsigned char ch, std::ostringstream& output)
{
const char* digits = "0123456789ABCDEF";
output << "\\u00" << digits[(ch >> 4) & 0x0F] << digits[ch & 0x0F];
}
class JsonParser class JsonParser
{ {
public: public:
@@ -181,8 +235,9 @@ private:
case 'r': result << '\r'; break; case 'r': result << '\r'; break;
case 't': result << '\t'; break; case 't': result << '\t'; break;
case 'u': case 'u':
setError("Unicode escape sequences are not supported in this JSON parser."); if (!parseUnicodeEscape(result))
return false; return false;
break;
default: default:
setError("Invalid escape sequence in JSON string."); setError("Invalid escape sequence in JSON string.");
return false; return false;
@@ -190,6 +245,11 @@ private:
} }
else else
{ {
if (static_cast<unsigned char>(ch) < 0x20)
{
setError("Unescaped control character in JSON string.");
return false;
}
result << ch; result << ch;
} }
} }
@@ -198,6 +258,66 @@ private:
return false; return false;
} }
bool parseHexCodePoint(unsigned int& codePoint)
{
if (mPosition + 4 > mText.size())
{
setError("Unexpected end of Unicode escape sequence.");
return false;
}
codePoint = 0;
for (int i = 0; i < 4; ++i)
{
const int digit = HexDigitValue(mText[mPosition + i]);
if (digit < 0)
{
setError("Invalid Unicode escape sequence in JSON string.");
return false;
}
codePoint = (codePoint << 4) | static_cast<unsigned int>(digit);
}
mPosition += 4;
return true;
}
bool parseUnicodeEscape(std::ostringstream& result)
{
unsigned int codePoint = 0;
if (!parseHexCodePoint(codePoint))
return false;
if (IsHighSurrogate(codePoint))
{
if (mPosition + 2 > mText.size() || mText[mPosition] != '\\' || mText[mPosition + 1] != 'u')
{
setError("High surrogate Unicode escape must be followed by a low surrogate.");
return false;
}
mPosition += 2;
unsigned int lowSurrogate = 0;
if (!parseHexCodePoint(lowSurrogate))
return false;
if (!IsLowSurrogate(lowSurrogate))
{
setError("High surrogate Unicode escape must be followed by a low surrogate.");
return false;
}
codePoint = 0x10000 + (((codePoint - 0xD800) << 10) | (lowSurrogate - 0xDC00));
}
else if (IsLowSurrogate(codePoint))
{
setError("Low surrogate Unicode escape without preceding high surrogate.");
return false;
}
AppendUtf8(codePoint, result);
return true;
}
bool parseNumber(JsonValue& value) bool parseNumber(JsonValue& value)
{ {
std::size_t start = mPosition; std::size_t start = mPosition;
@@ -205,12 +325,40 @@ private:
if (mText[mPosition] == '-') if (mText[mPosition] == '-')
++mPosition; ++mPosition;
if (mPosition >= mText.size())
{
setError("Invalid JSON number.");
return false;
}
if (mText[mPosition] == '0')
{
++mPosition;
if (mPosition < mText.size() && std::isdigit(static_cast<unsigned char>(mText[mPosition])))
{
setError("JSON numbers must not contain leading zeroes.");
return false;
}
}
else if (mText[mPosition] >= '1' && mText[mPosition] <= '9')
{
while (mPosition < mText.size() && std::isdigit(static_cast<unsigned char>(mText[mPosition]))) while (mPosition < mText.size() && std::isdigit(static_cast<unsigned char>(mText[mPosition])))
++mPosition; ++mPosition;
}
else
{
setError("Invalid JSON number.");
return false;
}
if (mPosition < mText.size() && mText[mPosition] == '.') if (mPosition < mText.size() && mText[mPosition] == '.')
{ {
++mPosition; ++mPosition;
if (mPosition >= mText.size() || !std::isdigit(static_cast<unsigned char>(mText[mPosition])))
{
setError("JSON number fraction must contain at least one digit.");
return false;
}
while (mPosition < mText.size() && std::isdigit(static_cast<unsigned char>(mText[mPosition]))) while (mPosition < mText.size() && std::isdigit(static_cast<unsigned char>(mText[mPosition])))
++mPosition; ++mPosition;
} }
@@ -220,14 +368,20 @@ private:
++mPosition; ++mPosition;
if (mPosition < mText.size() && (mText[mPosition] == '+' || mText[mPosition] == '-')) if (mPosition < mText.size() && (mText[mPosition] == '+' || mText[mPosition] == '-'))
++mPosition; ++mPosition;
if (mPosition >= mText.size() || !std::isdigit(static_cast<unsigned char>(mText[mPosition])))
{
setError("JSON number exponent must contain at least one digit.");
return false;
}
while (mPosition < mText.size() && std::isdigit(static_cast<unsigned char>(mText[mPosition]))) while (mPosition < mText.size() && std::isdigit(static_cast<unsigned char>(mText[mPosition])))
++mPosition; ++mPosition;
} }
std::string token = mText.substr(start, mPosition - start); std::string token = mText.substr(start, mPosition - start);
char* endPtr = nullptr; char* endPtr = nullptr;
errno = 0;
double parsed = strtod(token.c_str(), &endPtr); double parsed = strtod(token.c_str(), &endPtr);
if (endPtr == token.c_str() || *endPtr != '\0') if (endPtr == token.c_str() || *endPtr != '\0' || errno == ERANGE || !std::isfinite(parsed))
{ {
setError("Invalid JSON number."); setError("Invalid JSON number.");
return false; return false;
@@ -322,7 +476,12 @@ void SerializeJsonImpl(const JsonValue& value, std::ostringstream& output, bool
case '\n': output << "\\n"; break; case '\n': output << "\\n"; break;
case '\r': output << "\\r"; break; case '\r': output << "\\r"; break;
case '\t': output << "\\t"; break; case '\t': output << "\\t"; break;
default: output << ch; break; default:
if (static_cast<unsigned char>(ch) < 0x20)
AppendControlEscape(static_cast<unsigned char>(ch), output);
else
output << ch;
break;
} }
} }
output << '"'; output << '"';
@@ -407,14 +566,14 @@ JsonValue::JsonValue(const std::string& value)
JsonValue JsonValue::MakeArray() JsonValue JsonValue::MakeArray()
{ {
JsonValue value; JsonValue value;
value.mType = Type::Array; value.reset(Type::Array);
return value; return value;
} }
JsonValue JsonValue::MakeObject() JsonValue JsonValue::MakeObject()
{ {
JsonValue value; JsonValue value;
value.mType = Type::Object; value.reset(Type::Object);
return value; return value;
} }
@@ -449,20 +608,14 @@ const std::map<std::string, JsonValue>& JsonValue::asObject() const
std::vector<JsonValue>& JsonValue::array() std::vector<JsonValue>& JsonValue::array()
{ {
if (mType != Type::Array) if (mType != Type::Array)
{ reset(Type::Array);
mType = Type::Array;
mArrayValue.clear();
}
return mArrayValue; return mArrayValue;
} }
std::map<std::string, JsonValue>& JsonValue::object() std::map<std::string, JsonValue>& JsonValue::object()
{ {
if (mType != Type::Object) if (mType != Type::Object)
{ reset(Type::Object);
mType = Type::Object;
mObjectValue.clear();
}
return mObjectValue; return mObjectValue;
} }
@@ -485,6 +638,16 @@ const JsonValue* JsonValue::find(const std::string& key) const
return iterator != mObjectValue.end() ? &iterator->second : nullptr; return iterator != mObjectValue.end() ? &iterator->second : nullptr;
} }
void JsonValue::reset(Type type)
{
mType = type;
mBooleanValue = false;
mNumberValue = 0.0;
mStringValue.clear();
mArrayValue.clear();
mObjectValue.clear();
}
bool ParseJson(const std::string& text, JsonValue& value, std::string& error) bool ParseJson(const std::string& text, JsonValue& value, std::string& error)
{ {
error.clear(); error.clear();

View File

@@ -50,6 +50,8 @@ public:
const JsonValue* find(const std::string& key) const; const JsonValue* find(const std::string& key) const;
private: private:
void reset(Type type);
Type mType; Type mType;
bool mBooleanValue; bool mBooleanValue;
double mNumberValue; double mNumberValue;

View File

@@ -1,7 +1,8 @@
#include "stdafx.h" #include "stdafx.h"
#include "ShaderCompiler.h" #include "ShaderCompiler.h"
#include <cstring> #include "NativeHandles.h"
#include <fstream> #include <fstream>
#include <regex> #include <regex>
#include <sstream> #include <sstream>
@@ -20,17 +21,51 @@ std::string ReplaceAll(std::string text, const std::string& from, const std::str
return text; return text;
} }
std::string SlangTypeForParameter(ShaderParameterType type) std::string SlangCBufferTypeForParameter(ShaderParameterType type)
{ {
switch (type) switch (type)
{ {
case ShaderParameterType::Float: return "uniform float"; case ShaderParameterType::Float: return "float";
case ShaderParameterType::Vec2: return "uniform float2"; case ShaderParameterType::Vec2: return "float2";
case ShaderParameterType::Color: return "uniform float4"; case ShaderParameterType::Color: return "float4";
case ShaderParameterType::Boolean: return "uniform bool"; case ShaderParameterType::Boolean: return "bool";
case ShaderParameterType::Enum: return "uniform int"; case ShaderParameterType::Enum: return "int";
} }
return "uniform float"; return "float";
}
std::string BuildParameterUniforms(const std::vector<ShaderParameterDefinition>& parameters)
{
std::ostringstream source;
for (const ShaderParameterDefinition& definition : parameters)
source << "\t" << SlangCBufferTypeForParameter(definition.type) << " " << definition.id << ";\n";
return source.str();
}
std::string BuildHistorySamplerDeclarations(const std::string& samplerPrefix, unsigned historyLength)
{
std::ostringstream source;
for (unsigned index = 0; index < historyLength; ++index)
source << "Sampler2D<float4> " << samplerPrefix << index << ";\n";
return source.str();
}
std::string BuildTextureSamplerDeclarations(const std::vector<ShaderTextureAsset>& textureAssets)
{
std::ostringstream source;
for (const ShaderTextureAsset& textureAsset : textureAssets)
source << "Sampler2D<float4> " << textureAsset.id << ";\n";
if (!textureAssets.empty())
source << "\n";
return source.str();
}
std::string BuildHistorySwitchCases(const std::string& samplerPrefix, unsigned historyLength)
{
std::ostringstream source;
for (unsigned index = 0; index < historyLength; ++index)
source << "\tcase " << index << ": return " << samplerPrefix << index << ".Sample(tc);\n";
return source.str();
} }
} }
@@ -50,7 +85,9 @@ ShaderCompiler::ShaderCompiler(
bool ShaderCompiler::BuildLayerFragmentShaderSource(const ShaderPackage& shaderPackage, std::string& fragmentShaderSource, std::string& error) const bool ShaderCompiler::BuildLayerFragmentShaderSource(const ShaderPackage& shaderPackage, std::string& fragmentShaderSource, std::string& error) const
{ {
const std::string wrapperSource = BuildWrapperSlangSource(shaderPackage); std::string wrapperSource;
if (!BuildWrapperSlangSource(shaderPackage, wrapperSource, error))
return false;
if (!WriteTextFile(mWrapperPath, wrapperSource, error)) if (!WriteTextFile(mWrapperPath, wrapperSource, error))
return false; return false;
@@ -70,104 +107,22 @@ bool ShaderCompiler::BuildLayerFragmentShaderSource(const ShaderPackage& shaderP
return true; return true;
} }
std::string ShaderCompiler::BuildWrapperSlangSource(const ShaderPackage& shaderPackage) const bool ShaderCompiler::BuildWrapperSlangSource(const ShaderPackage& shaderPackage, std::string& wrapperSource, std::string& error) const
{ {
std::ostringstream source; const std::filesystem::path templatePath = mRepoRoot / "runtime" / "templates" / "shader_wrapper.slang.in";
source << "struct FragmentInput\n"; wrapperSource = ReadTextFile(templatePath, error);
source << "{\n"; if (wrapperSource.empty())
source << "\tfloat4 position : SV_Position;\n"; return false;
source << "\tfloat2 texCoord : TEXCOORD0;\n";
source << "};\n\n"; wrapperSource = ReplaceAll(wrapperSource, "{{PARAMETER_UNIFORMS}}", BuildParameterUniforms(shaderPackage.parameters));
source << "struct ShaderContext\n"; wrapperSource = ReplaceAll(wrapperSource, "{{SOURCE_HISTORY_SAMPLERS}}", BuildHistorySamplerDeclarations("gSourceHistory", mMaxTemporalHistoryFrames));
source << "{\n"; wrapperSource = ReplaceAll(wrapperSource, "{{TEMPORAL_HISTORY_SAMPLERS}}", BuildHistorySamplerDeclarations("gTemporalHistory", mMaxTemporalHistoryFrames));
source << "\tfloat2 uv;\n"; wrapperSource = ReplaceAll(wrapperSource, "{{TEXTURE_SAMPLERS}}", BuildTextureSamplerDeclarations(shaderPackage.textureAssets));
source << "\tfloat4 sourceColor;\n"; wrapperSource = ReplaceAll(wrapperSource, "{{SOURCE_HISTORY_SWITCH_CASES}}", BuildHistorySwitchCases("gSourceHistory", mMaxTemporalHistoryFrames));
source << "\tfloat2 inputResolution;\n"; wrapperSource = ReplaceAll(wrapperSource, "{{TEMPORAL_HISTORY_SWITCH_CASES}}", BuildHistorySwitchCases("gTemporalHistory", mMaxTemporalHistoryFrames));
source << "\tfloat2 outputResolution;\n"; wrapperSource = ReplaceAll(wrapperSource, "{{USER_SHADER_INCLUDE}}", shaderPackage.shaderPath.generic_string());
source << "\tfloat time;\n"; wrapperSource = ReplaceAll(wrapperSource, "{{ENTRY_POINT_CALL}}", shaderPackage.entryPoint + "(context)");
source << "\tfloat frameCount;\n"; return true;
source << "\tfloat mixAmount;\n";
source << "\tfloat bypass;\n";
source << "\tint sourceHistoryLength;\n";
source << "\tint temporalHistoryLength;\n";
source << "};\n\n";
source << "cbuffer GlobalParams\n";
source << "{\n";
source << "\tfloat gTime;\n";
source << "\tfloat2 gInputResolution;\n";
source << "\tfloat2 gOutputResolution;\n";
source << "\tfloat gFrameCount;\n";
source << "\tfloat gMixAmount;\n";
source << "\tfloat gBypass;\n";
source << "\tint gSourceHistoryLength;\n";
source << "\tint gTemporalHistoryLength;\n";
for (const ShaderParameterDefinition& definition : shaderPackage.parameters)
source << "\t" << SlangTypeForParameter(definition.type).substr(strlen("uniform ")) << " " << definition.id << ";\n";
source << "};\n\n";
source << "Sampler2D<float4> gVideoInput;\n";
for (unsigned index = 0; index < mMaxTemporalHistoryFrames; ++index)
source << "Sampler2D<float4> gSourceHistory" << index << ";\n";
for (unsigned index = 0; index < mMaxTemporalHistoryFrames; ++index)
source << "Sampler2D<float4> gTemporalHistory" << index << ";\n";
for (const ShaderTextureAsset& textureAsset : shaderPackage.textureAssets)
source << "Sampler2D<float4> " << textureAsset.id << ";\n";
source << "\n";
source << "float4 sampleVideo(float2 tc)\n";
source << "{\n";
source << "\treturn gVideoInput.Sample(tc);\n";
source << "}\n\n";
source << "float4 sampleSourceHistory(int framesAgo, float2 tc)\n";
source << "{\n";
source << "\tif (gSourceHistoryLength <= 0)\n";
source << "\t\treturn sampleVideo(tc);\n";
source << "\tint clampedIndex = framesAgo;\n";
source << "\tif (clampedIndex < 0)\n";
source << "\t\tclampedIndex = 0;\n";
source << "\tif (clampedIndex >= gSourceHistoryLength)\n";
source << "\t\tclampedIndex = gSourceHistoryLength - 1;\n";
source << "\tswitch (clampedIndex)\n";
source << "\t{\n";
for (unsigned index = 0; index < mMaxTemporalHistoryFrames; ++index)
source << "\tcase " << index << ": return gSourceHistory" << index << ".Sample(tc);\n";
source << "\tdefault: return sampleVideo(tc);\n";
source << "\t}\n";
source << "}\n\n";
source << "float4 sampleTemporalHistory(int framesAgo, float2 tc)\n";
source << "{\n";
source << "\tif (gTemporalHistoryLength <= 0)\n";
source << "\t\treturn sampleVideo(tc);\n";
source << "\tint clampedIndex = framesAgo;\n";
source << "\tif (clampedIndex < 0)\n";
source << "\t\tclampedIndex = 0;\n";
source << "\tif (clampedIndex >= gTemporalHistoryLength)\n";
source << "\t\tclampedIndex = gTemporalHistoryLength - 1;\n";
source << "\tswitch (clampedIndex)\n";
source << "\t{\n";
for (unsigned index = 0; index < mMaxTemporalHistoryFrames; ++index)
source << "\tcase " << index << ": return gTemporalHistory" << index << ".Sample(tc);\n";
source << "\tdefault: return sampleVideo(tc);\n";
source << "\t}\n";
source << "}\n\n";
source << "#include \"" << shaderPackage.shaderPath.generic_string() << "\"\n\n";
source << "[shader(\"fragment\")]\n";
source << "float4 fragmentMain(FragmentInput input) : SV_Target\n";
source << "{\n";
source << "\tShaderContext context;\n";
source << "\tcontext.uv = input.texCoord;\n";
source << "\tcontext.sourceColor = sampleVideo(context.uv);\n";
source << "\tcontext.inputResolution = gInputResolution;\n";
source << "\tcontext.outputResolution = gOutputResolution;\n";
source << "\tcontext.time = gTime;\n";
source << "\tcontext.frameCount = gFrameCount;\n";
source << "\tcontext.mixAmount = gMixAmount;\n";
source << "\tcontext.bypass = gBypass;\n";
source << "\tcontext.sourceHistoryLength = gSourceHistoryLength;\n";
source << "\tcontext.temporalHistoryLength = gTemporalHistoryLength;\n";
source << "\tfloat4 effectedColor = " << shaderPackage.entryPoint << "(context);\n";
source << "\tfloat mixValue = clamp(gBypass > 0.5 ? 0.0 : gMixAmount, 0.0, 1.0);\n";
source << "\treturn lerp(context.sourceColor, effectedColor, mixValue);\n";
source << "}\n";
return source.str();
} }
bool ShaderCompiler::FindSlangCompiler(std::filesystem::path& compilerPath, std::string& error) const bool ShaderCompiler::FindSlangCompiler(std::filesystem::path& compilerPath, std::string& error) const
@@ -216,12 +171,13 @@ bool ShaderCompiler::RunSlangCompiler(const std::filesystem::path& wrapperPath,
return false; return false;
} }
WaitForSingleObject(processInfo.hProcess, INFINITE); UniqueHandle processHandle(processInfo.hProcess);
UniqueHandle threadHandle(processInfo.hThread);
WaitForSingleObject(processHandle.get(), INFINITE);
DWORD exitCode = 0; DWORD exitCode = 0;
GetExitCodeProcess(processInfo.hProcess, &exitCode); GetExitCodeProcess(processHandle.get(), &exitCode);
CloseHandle(processInfo.hThread);
CloseHandle(processInfo.hProcess);
if (exitCode != 0) if (exitCode != 0)
{ {

View File

@@ -18,7 +18,7 @@ public:
bool BuildLayerFragmentShaderSource(const ShaderPackage& shaderPackage, std::string& fragmentShaderSource, std::string& error) const; bool BuildLayerFragmentShaderSource(const ShaderPackage& shaderPackage, std::string& fragmentShaderSource, std::string& error) const;
private: private:
std::string BuildWrapperSlangSource(const ShaderPackage& shaderPackage) const; bool BuildWrapperSlangSource(const ShaderPackage& shaderPackage, std::string& wrapperSource, std::string& error) const;
bool FindSlangCompiler(std::filesystem::path& compilerPath, std::string& error) const; bool FindSlangCompiler(std::filesystem::path& compilerPath, std::string& error) const;
bool RunSlangCompiler(const std::filesystem::path& wrapperPath, const std::filesystem::path& outputPath, std::string& error) const; bool RunSlangCompiler(const std::filesystem::path& wrapperPath, const std::filesystem::path& outputPath, std::string& error) const;
bool PatchGeneratedGlsl(std::string& shaderText, std::string& error) const; bool PatchGeneratedGlsl(std::string& shaderText, std::string& error) const;

View File

@@ -39,6 +39,7 @@
*/ */
#include "VideoFrameTransfer.h" #include "VideoFrameTransfer.h"
#include "NativeHandles.h"
#define DVP_CHECK(cmd) { \ #define DVP_CHECK(cmd) { \
@@ -140,20 +141,19 @@ bool VideoFrameTransfer::initializeMemoryLocking(unsigned memSize)
{ {
// Increase the process working set size to allow pinning of memory. // Increase the process working set size to allow pinning of memory.
static SIZE_T dwMin = 0, dwMax = 0; static SIZE_T dwMin = 0, dwMax = 0;
HANDLE hProcess = OpenProcess(PROCESS_QUERY_INFORMATION | PROCESS_SET_QUOTA, FALSE, GetCurrentProcessId()); UniqueHandle processHandle(OpenProcess(PROCESS_QUERY_INFORMATION | PROCESS_SET_QUOTA, FALSE, GetCurrentProcessId()));
if (!hProcess) if (!processHandle.valid())
return false; return false;
// Retrieve the working set size of the process. // Retrieve the working set size of the process.
if (!dwMin && !GetProcessWorkingSetSize(hProcess, &dwMin, &dwMax)) if (!dwMin && !GetProcessWorkingSetSize(processHandle.get(), &dwMin, &dwMax))
return false; return false;
// Allow for 80 frames to be locked // Allow for 80 frames to be locked
BOOL res = SetProcessWorkingSetSize(hProcess, memSize * 80 + dwMin, memSize * 80 + (dwMax-dwMin)); BOOL res = SetProcessWorkingSetSize(processHandle.get(), memSize * 80 + dwMin, memSize * 80 + (dwMax-dwMin));
if (!res) if (!res)
return false; return false;
CloseHandle(hProcess);
return true; return true;
} }

View File

@@ -0,0 +1,89 @@
struct FragmentInput
{
float4 position : SV_Position;
float2 texCoord : TEXCOORD0;
};
struct ShaderContext
{
float2 uv;
float4 sourceColor;
float2 inputResolution;
float2 outputResolution;
float time;
float frameCount;
float mixAmount;
float bypass;
int sourceHistoryLength;
int temporalHistoryLength;
};
cbuffer GlobalParams
{
float gTime;
float2 gInputResolution;
float2 gOutputResolution;
float gFrameCount;
float gMixAmount;
float gBypass;
int gSourceHistoryLength;
int gTemporalHistoryLength;
{{PARAMETER_UNIFORMS}}};
Sampler2D<float4> gVideoInput;
{{SOURCE_HISTORY_SAMPLERS}}{{TEMPORAL_HISTORY_SAMPLERS}}{{TEXTURE_SAMPLERS}}
float4 sampleVideo(float2 tc)
{
return gVideoInput.Sample(tc);
}
float4 sampleSourceHistory(int framesAgo, float2 tc)
{
if (gSourceHistoryLength <= 0)
return sampleVideo(tc);
int clampedIndex = framesAgo;
if (clampedIndex < 0)
clampedIndex = 0;
if (clampedIndex >= gSourceHistoryLength)
clampedIndex = gSourceHistoryLength - 1;
switch (clampedIndex)
{
{{SOURCE_HISTORY_SWITCH_CASES}} default: return sampleVideo(tc);
}
}
float4 sampleTemporalHistory(int framesAgo, float2 tc)
{
if (gTemporalHistoryLength <= 0)
return sampleVideo(tc);
int clampedIndex = framesAgo;
if (clampedIndex < 0)
clampedIndex = 0;
if (clampedIndex >= gTemporalHistoryLength)
clampedIndex = gTemporalHistoryLength - 1;
switch (clampedIndex)
{
{{TEMPORAL_HISTORY_SWITCH_CASES}} default: return sampleVideo(tc);
}
}
#include "{{USER_SHADER_INCLUDE}}"
[shader("fragment")]
float4 fragmentMain(FragmentInput input) : SV_Target
{
ShaderContext context;
context.uv = input.texCoord;
context.sourceColor = sampleVideo(context.uv);
context.inputResolution = gInputResolution;
context.outputResolution = gOutputResolution;
context.time = gTime;
context.frameCount = gFrameCount;
context.mixAmount = gMixAmount;
context.bypass = gBypass;
context.sourceHistoryLength = gSourceHistoryLength;
context.temporalHistoryLength = gTemporalHistoryLength;
float4 effectedColor = {{ENTRY_POINT_CALL}};
float mixValue = clamp(gBypass > 0.5 ? 0.0 : gMixAmount, 0.0, 1.0);
return lerp(context.sourceColor, effectedColor, mixValue);
}

119
tests/RuntimeJsonTests.cpp Normal file
View File

@@ -0,0 +1,119 @@
#include "RuntimeJson.h"
#include <cmath>
#include <iostream>
#include <string>
namespace
{
int gFailures = 0;
void Expect(bool condition, const char* message)
{
if (condition)
return;
std::cerr << "FAIL: " << message << "\n";
++gFailures;
}
JsonValue ParseOrFail(const std::string& text, const char* message)
{
JsonValue value;
std::string error;
if (!ParseJson(text, value, error))
{
std::cerr << "FAIL: " << message << ": " << error << "\n";
++gFailures;
}
return value;
}
void ExpectParseFails(const std::string& text, const char* message)
{
JsonValue value;
std::string error;
Expect(!ParseJson(text, value, error), message);
Expect(!error.empty(), "failed parse should include an error message");
}
void TestStringsAndEscaping()
{
const JsonValue value = ParseOrFail(R"({"text":"line\nquote:\" slash:\\ tab:\t"})", "escaped string parses");
const JsonValue* text = value.find("text");
Expect(text && text->isString(), "escaped string field is present");
Expect(text && text->asString().find("line\nquote:\" slash:\\ tab:\t") != std::string::npos, "escaped string unescapes expected characters");
JsonValue controlString(std::string("a\001b", 3));
Expect(SerializeJson(controlString, false) == R"("a\u0001b")", "serializer escapes raw control characters");
ExpectParseFails(std::string("{\"bad\":\"a") + static_cast<char>(1) + "b\"}", "raw control characters are rejected");
}
void TestUnicodeEscapes()
{
const JsonValue copyright = ParseOrFail(R"({"s":"\u00A9"})", "basic unicode escape parses");
Expect(copyright.find("s") && copyright.find("s")->asString() == "\xC2\xA9", "basic unicode escape becomes UTF-8");
const JsonValue music = ParseOrFail(R"({"s":"\uD834\uDD1E"})", "surrogate pair parses");
Expect(music.find("s") && music.find("s")->asString() == "\xF0\x9D\x84\x9E", "surrogate pair becomes UTF-8");
ExpectParseFails(R"({"s":"\uD834x"})", "unpaired high surrogate is rejected");
ExpectParseFails(R"({"s":"\uDD1E"})", "unpaired low surrogate is rejected");
ExpectParseFails(R"({"s":"\u00XZ"})", "invalid unicode hex is rejected");
}
void TestNumbers()
{
const JsonValue value = ParseOrFail(R"({"a":0,"b":-12.5e+2,"c":3.25})", "valid numbers parse");
Expect(value.find("a") && value.find("a")->asNumber(-1.0) == 0.0, "zero parses");
Expect(value.find("b") && std::fabs(value.find("b")->asNumber() + 1250.0) < 0.0001, "exponent parses");
Expect(value.find("c") && std::fabs(value.find("c")->asNumber() - 3.25) < 0.0001, "fraction parses");
ExpectParseFails("01", "leading zero is rejected");
ExpectParseFails("1.", "fraction without digits is rejected");
ExpectParseFails("1e", "exponent without digits is rejected");
ExpectParseFails("1e9999", "non-finite number is rejected");
}
void TestRoundtripAndMutation()
{
JsonValue root = JsonValue::MakeObject();
root.set("version", JsonValue(1.0));
root.set("name", JsonValue("Preset One"));
JsonValue layers = JsonValue::MakeArray();
JsonValue layer = JsonValue::MakeObject();
layer.set("id", JsonValue("layer-a"));
layer.set("bypass", JsonValue(false));
layers.pushBack(layer);
root.set("layers", layers);
const std::string serialized = SerializeJson(root, true);
const JsonValue parsed = ParseOrFail(serialized, "serialized preset-style object parses");
Expect(parsed.find("layers") && parsed.find("layers")->isArray(), "roundtripped layers array exists");
JsonValue mutableValue = JsonValue::MakeObject();
mutableValue.set("stale", JsonValue("value"));
mutableValue.pushBack(JsonValue(7.0));
Expect(mutableValue.isArray(), "pushBack changes object into array");
Expect(mutableValue.asObject().empty(), "object storage is cleared when changing to array");
Expect(mutableValue.asArray().size() == 1, "new array receives pushed value");
}
}
int main()
{
TestStringsAndEscaping();
TestUnicodeEscapes();
TestNumbers();
TestRoundtripAndMutation();
if (gFailures != 0)
{
std::cerr << gFailures << " RuntimeJson test failure(s).\n";
return 1;
}
std::cout << "RuntimeJson tests passed.\n";
return 0;
}