diff --git a/CMakeLists.txt b/CMakeLists.txt index aa72f65..3f4f55b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -14,9 +14,7 @@ include_directories( ) # === Source files === -set(SOURCE_FILES - src/main.cpp - +set(LIB_SOURCES include/agents/QLearningAgent.cpp include/core/ExecutionSimulator.cpp include/env/MarketEnvironment.cpp @@ -24,8 +22,12 @@ set(SOURCE_FILES include/utils/MathUtils.cpp ) +add_library(executionrl_lib ${LIB_SOURCES}) +target_compile_definitions(executionrl_lib PRIVATE DATA_PATH="${CMAKE_SOURCE_DIR}/data/price_series.csv") + # === Executable === -add_executable(ExecutionRL ${SOURCE_FILES}) +add_executable(ExecutionRL src/main.cpp) +target_link_libraries(ExecutionRL PRIVATE executionrl_lib) # === Tests === enable_testing() @@ -34,6 +36,10 @@ add_executable(test_agent tests/test_agent.cpp) add_executable(test_env tests/test_env.cpp) add_executable(test_simulation tests/test_simulation.cpp) +target_link_libraries(test_agent PRIVATE executionrl_lib) +target_link_libraries(test_env PRIVATE executionrl_lib) +target_link_libraries(test_simulation PRIVATE executionrl_lib) + target_include_directories(test_agent PRIVATE include) target_include_directories(test_env PRIVATE include) target_include_directories(test_simulation PRIVATE include) @@ -41,3 +47,6 @@ target_include_directories(test_simulation PRIVATE include) add_test(NAME AgentTest COMMAND test_agent) add_test(NAME EnvTest COMMAND test_env) add_test(NAME SimulationTest COMMAND test_simulation) + +set_tests_properties(AgentTest EnvTest SimulationTest + PROPERTIES WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}) diff --git a/include/core/ExecutionSimulator.cpp b/include/core/ExecutionSimulator.cpp index 0fd5277..5871c37 100644 --- a/include/core/ExecutionSimulator.cpp +++ b/include/core/ExecutionSimulator.cpp @@ -3,13 +3,15 @@ ExecutionSimulator::ExecutionSimulator(MarketEnvironment& env, QLearningAgent& agent) : env(env), agent(agent) {} -void ExecutionSimulator::run() { - env.reset(); - while (!env.isDone()) { - auto state = env.getState(); - int action = agent.chooseAction(state); - double reward = env.step(action); - auto nextState = env.getState(); - agent.update(state, action, reward, nextState); +void ExecutionSimulator::run(int episodes) { + for (int i = 0; i < episodes; ++i) { + env.reset(); + while (!env.isDone()) { + auto state = env.getState(); + int action = agent.chooseAction(state); + double reward = env.step(action); + auto nextState = env.getState(); + agent.update(state, action, reward, nextState); + } } } diff --git a/include/core/ExecutionSimulator.hpp b/include/core/ExecutionSimulator.hpp index 0ae0fe3..ccdbe16 100644 --- a/include/core/ExecutionSimulator.hpp +++ b/include/core/ExecutionSimulator.hpp @@ -8,7 +8,7 @@ class ExecutionSimulator { public: ExecutionSimulator(MarketEnvironment& env, QLearningAgent& agent); - void run(); + void run(int episodes = 1); private: MarketEnvironment& env; diff --git a/include/env/MarketEnvironment.cpp b/include/env/MarketEnvironment.cpp index c1bc781..8d7c66b 100644 --- a/include/env/MarketEnvironment.cpp +++ b/include/env/MarketEnvironment.cpp @@ -40,15 +40,25 @@ bool MarketEnvironment::isDone() const { } void MarketEnvironment::loadMarketData() { - std::ifstream file("prices.csv"); + std::ifstream file(DATA_PATH); std::string line; prices.clear(); + bool firstLine = true; while (std::getline(file, line)) { + if (firstLine) { // skip header if present + firstLine = false; + if (line.find_first_not_of("0123456789-.") != std::string::npos) + continue; + } std::stringstream ss(line); std::string cell; if (std::getline(ss, cell, ',')) { - prices.push_back(std::stod(cell)); + try { + prices.push_back(std::stod(cell)); + } catch (const std::invalid_argument&) { + // ignore malformed lines + } } } diff --git a/src/main.cpp b/src/main.cpp index f93774d..a08fd40 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1,6 +1,6 @@ -#include "include/core/ExecutionSimulator.hpp" -#include "include/env/MarketEnvironment.hpp" -#include "include/agents/QLearningAgent.hpp" +#include "core/ExecutionSimulator.hpp" +#include "env/MarketEnvironment.hpp" +#include "agents/QLearningAgent.hpp" int main() { MarketEnvironment env; diff --git a/tests/test_agent.cpp b/tests/test_agent.cpp index 9ee59bd..e95ea36 100644 --- a/tests/test_agent.cpp +++ b/tests/test_agent.cpp @@ -1,4 +1,4 @@ -#include "../include/agents/QLearningAgent.hpp" +#include "agents/QLearningAgent.hpp" #include #include diff --git a/tests/test_env.cpp b/tests/test_env.cpp index fbf9148..3746010 100644 --- a/tests/test_env.cpp +++ b/tests/test_env.cpp @@ -1,4 +1,4 @@ -#include "../include/env/MarketEnvironment.hpp" +#include "env/MarketEnvironment.hpp" #include #include diff --git a/tests/test_simulation.cpp b/tests/test_simulation.cpp index 7eac0bd..4d2e7f6 100644 --- a/tests/test_simulation.cpp +++ b/tests/test_simulation.cpp @@ -1,7 +1,7 @@ -#include "../include/core/ExecutionSimulator.hpp" -#include "../include/env/MarketEnvironment.hpp" -#include "../include/agents/QLearningAgent.hpp" -#include "../include/config.hpp" +#include "core/ExecutionSimulator.hpp" +#include "env/MarketEnvironment.hpp" +#include "agents/QLearningAgent.hpp" +#include "config.hpp" #include int main() {