diff --git a/CMakeLists.txt b/CMakeLists.txt index 1d8fbb5..7652dfa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -21,13 +21,10 @@ set(LIB_SOURCES include/utils/Logger.cpp include/utils/MathUtils.cpp) -add_library(executionrl_lib ${LIB_SOURCES}) -target_include_directories(executionrl_lib PUBLIC include) -target_compile_definitions(executionrl_lib PRIVATE DATA_PATH="${CMAKE_SOURCE_DIR}/data/price_series.csv") +add_library(ExecutionRLLib ${LIB_SOURCES}) add_executable(ExecutionRL src/main.cpp) -target_link_libraries(ExecutionRL PRIVATE executionrl_lib) -target_include_directories(ExecutionRL PRIVATE include) +target_link_libraries(ExecutionRL PRIVATE ExecutionRLLib) # === Tests === enable_testing() @@ -36,12 +33,9 @@ add_executable(test_agent tests/test_agent.cpp) add_executable(test_env tests/test_env.cpp) add_executable(test_simulation tests/test_simulation.cpp) -target_link_libraries(test_agent PRIVATE executionrl_lib) -target_link_libraries(test_env PRIVATE executionrl_lib) -target_link_libraries(test_simulation PRIVATE executionrl_lib) -target_include_directories(test_agent PRIVATE include) -target_include_directories(test_env PRIVATE include) -target_include_directories(test_simulation PRIVATE include) +target_link_libraries(test_agent PRIVATE ExecutionRLLib) +target_link_libraries(test_env PRIVATE ExecutionRLLib) +target_link_libraries(test_simulation PRIVATE ExecutionRLLib) add_test(NAME AgentTest COMMAND test_agent) add_test(NAME EnvTest COMMAND test_env) diff --git a/include/core/ExecutionSimulator.cpp b/include/core/ExecutionSimulator.cpp index 5871c37..0816a21 100644 --- a/include/core/ExecutionSimulator.cpp +++ b/include/core/ExecutionSimulator.cpp @@ -4,7 +4,7 @@ ExecutionSimulator::ExecutionSimulator(MarketEnvironment& env, QLearningAgent& a : env(env), agent(agent) {} void ExecutionSimulator::run(int episodes) { - for (int i = 0; i < episodes; ++i) { + for (int episode = 0; episode < episodes; ++episode) { env.reset(); while (!env.isDone()) { auto state = env.getState(); diff --git a/include/env/MarketEnvironment.cpp b/include/env/MarketEnvironment.cpp index 5e59240..f312d81 100644 --- a/include/env/MarketEnvironment.cpp +++ b/include/env/MarketEnvironment.cpp @@ -40,34 +40,24 @@ bool MarketEnvironment::isDone() const { } void MarketEnvironment::loadMarketData() { - std::ifstream file(DATA_PATH); + std::ifstream file("data/price_series.csv"); if (!file.is_open()) { file.open("../data/price_series.csv"); } std::string line; prices.clear(); - - // Skip header line if present - if (std::getline(file, line)) { - if (line.find("price") == std::string::npos) { - std::stringstream ss(line); - std::string timestamp, priceStr; - if (std::getline(ss, timestamp, ',') && std::getline(ss, priceStr, ',')) { - prices.push_back(std::stod(priceStr)); - } - } - } - while (std::getline(file, line)) { std::stringstream ss(line); - std::string timestamp, priceStr; - if (std::getline(ss, timestamp, ',') && std::getline(ss, priceStr, ',')) { - try { - prices.push_back(std::stod(priceStr)); - } catch (const std::exception& e) { - std::cerr << "Invalid price value: " << priceStr << "\n"; - } + std::string first, second; + if (!std::getline(ss, first, ',')) + continue; + if (!std::getline(ss, second, ',')) + continue; + try { + prices.push_back(std::stod(second)); + } catch (const std::invalid_argument&) { + // likely a header line, skip } } diff --git a/src/main.cpp b/src/main.cpp index 5e842e0..a08fd40 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -7,7 +7,7 @@ int main() { QLearningAgent agent(0.1, 0.9, 0.1); // alpha, gamma, epsilon ExecutionSimulator simulator(env, agent); - simulator.run(); + simulator.run(1000); // Run for 1000 episodes return 0; }