diff --git a/CMakeLists.txt b/CMakeLists.txt index aa72f65..384ff36 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -14,18 +14,17 @@ include_directories( ) # === Source files === -set(SOURCE_FILES - src/main.cpp - +set(LIB_SOURCES include/agents/QLearningAgent.cpp include/core/ExecutionSimulator.cpp include/env/MarketEnvironment.cpp include/utils/Logger.cpp - include/utils/MathUtils.cpp -) + include/utils/MathUtils.cpp) + +add_library(ExecutionRLLib ${LIB_SOURCES}) -# === Executable === -add_executable(ExecutionRL ${SOURCE_FILES}) +add_executable(ExecutionRL src/main.cpp) +target_link_libraries(ExecutionRL PRIVATE ExecutionRLLib) # === Tests === enable_testing() @@ -34,9 +33,9 @@ add_executable(test_agent tests/test_agent.cpp) add_executable(test_env tests/test_env.cpp) add_executable(test_simulation tests/test_simulation.cpp) -target_include_directories(test_agent PRIVATE include) -target_include_directories(test_env PRIVATE include) -target_include_directories(test_simulation PRIVATE include) +target_link_libraries(test_agent PRIVATE ExecutionRLLib) +target_link_libraries(test_env PRIVATE ExecutionRLLib) +target_link_libraries(test_simulation PRIVATE ExecutionRLLib) add_test(NAME AgentTest COMMAND test_agent) add_test(NAME EnvTest COMMAND test_env) diff --git a/include/core/ExecutionSimulator.cpp b/include/core/ExecutionSimulator.cpp index 0fd5277..0816a21 100644 --- a/include/core/ExecutionSimulator.cpp +++ b/include/core/ExecutionSimulator.cpp @@ -3,13 +3,15 @@ ExecutionSimulator::ExecutionSimulator(MarketEnvironment& env, QLearningAgent& agent) : env(env), agent(agent) {} -void ExecutionSimulator::run() { - env.reset(); - while (!env.isDone()) { - auto state = env.getState(); - int action = agent.chooseAction(state); - double reward = env.step(action); - auto nextState = env.getState(); - agent.update(state, action, reward, nextState); +void ExecutionSimulator::run(int episodes) { + for (int episode = 0; episode < episodes; ++episode) { + env.reset(); + while (!env.isDone()) { + auto state = env.getState(); + int action = agent.chooseAction(state); + double reward = env.step(action); + auto nextState = env.getState(); + agent.update(state, action, reward, nextState); + } } } diff --git a/include/core/ExecutionSimulator.hpp b/include/core/ExecutionSimulator.hpp index 0ae0fe3..ccdbe16 100644 --- a/include/core/ExecutionSimulator.hpp +++ b/include/core/ExecutionSimulator.hpp @@ -8,7 +8,7 @@ class ExecutionSimulator { public: ExecutionSimulator(MarketEnvironment& env, QLearningAgent& agent); - void run(); + void run(int episodes = 1); private: MarketEnvironment& env; diff --git a/include/env/MarketEnvironment.cpp b/include/env/MarketEnvironment.cpp index c1bc781..f312d81 100644 --- a/include/env/MarketEnvironment.cpp +++ b/include/env/MarketEnvironment.cpp @@ -40,15 +40,24 @@ bool MarketEnvironment::isDone() const { } void MarketEnvironment::loadMarketData() { - std::ifstream file("prices.csv"); + std::ifstream file("data/price_series.csv"); + if (!file.is_open()) { + file.open("../data/price_series.csv"); + } std::string line; prices.clear(); while (std::getline(file, line)) { std::stringstream ss(line); - std::string cell; - if (std::getline(ss, cell, ',')) { - prices.push_back(std::stod(cell)); + std::string first, second; + if (!std::getline(ss, first, ',')) + continue; + if (!std::getline(ss, second, ',')) + continue; + try { + prices.push_back(std::stod(second)); + } catch (const std::invalid_argument&) { + // likely a header line, skip } } diff --git a/src/main.cpp b/src/main.cpp index f93774d..a08fd40 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1,6 +1,6 @@ -#include "include/core/ExecutionSimulator.hpp" -#include "include/env/MarketEnvironment.hpp" -#include "include/agents/QLearningAgent.hpp" +#include "core/ExecutionSimulator.hpp" +#include "env/MarketEnvironment.hpp" +#include "agents/QLearningAgent.hpp" int main() { MarketEnvironment env;