diff --git a/SerialPrograms/Source/CommonTools/OCR/OCR_NumberReader.cpp b/SerialPrograms/Source/CommonTools/OCR/OCR_NumberReader.cpp index 0f4ba24e1..f54909159 100644 --- a/SerialPrograms/Source/CommonTools/OCR/OCR_NumberReader.cpp +++ b/SerialPrograms/Source/CommonTools/OCR/OCR_NumberReader.cpp @@ -14,10 +14,12 @@ #include "CommonFramework/ImageTypes/ImageRGB32.h" #include "CommonFramework/ImageTools/ImageBoxes.h" #include "CommonFramework/Tools/GlobalThreadPools.h" +#include "CommonFramework/GlobalSettingsPanel.h" #include "CommonTools/Images/ImageManip.h" #include "CommonTools/Images/ImageFilter.h" #include "CommonTools/Images/BinaryImage_FilterRgb32.h" #include "OCR_RawOCR.h" +#include "OCR_RawPaddleOCR.h" #include "OCR_NumberReader.h" #include @@ -81,7 +83,14 @@ std::string run_number_normalization(const std::string& input){ int read_number(Logger& logger, const ImageViewRGB32& image, Language language){ - std::string ocr_text = OCR::ocr_read(language, image, OCR::PageSegMode::SINGLE_LINE); + bool use_paddle_ocr = false; // GlobalSettings::instance().USE_PADDLE_OCR; + std::string ocr_text; + if (use_paddle_ocr){ + ocr_text = OCR::paddle_ocr_read(language, image); + }else{ + ocr_text = OCR::ocr_read(language, image, OCR::PageSegMode::SINGLE_LINE); + } + std::string normalized = run_number_normalization(ocr_text); std::string str; @@ -167,8 +176,13 @@ std::string read_number_waterfill_no_normalization( } ImageRGB32 padded = pad_image(cropped, 1 * cropped.width(), 0xffffffff); - std::string ocr = OCR::ocr_read(Language::English, padded, OCR::PageSegMode::SINGLE_CHAR); - + bool use_paddle_ocr = false; // GlobalSettings::instance().USE_PADDLE_OCR; + std::string ocr; + if (use_paddle_ocr){ + ocr = OCR::paddle_ocr_read(Language::English, padded); + }else{ + ocr = OCR::ocr_read(Language::English, padded, OCR::PageSegMode::SINGLE_CHAR); + } // padded.save("zztest-cropped" + std::to_string(c) + "-" + std::to_string(i++) + ".png"); // std::cout << ocr[0] << std::endl; if 
(!ocr.empty()){ diff --git a/SerialPrograms/Source/CommonTools/OCR/OCR_RawPaddleOCR.cpp b/SerialPrograms/Source/CommonTools/OCR/OCR_RawPaddleOCR.cpp new file mode 100644 index 000000000..c59e92d00 --- /dev/null +++ b/SerialPrograms/Source/CommonTools/OCR/OCR_RawPaddleOCR.cpp @@ -0,0 +1,137 @@ +/* Cached PaddleOCR instances (one per language group) + * + * From: https://github.com/PokemonAutomation/ + * + */ + +#include +#include +#include +#include +#include "ML/Inference/ML_PaddleOCRPipeline.h" +#include "Common/Cpp/Exceptions.h" +#include "Common/Cpp/Concurrency/SpinLock.h" +#include "CommonFramework/Globals.h" +#include "CommonFramework/Logging/Logger.h" +#include "CommonFramework/ImageTypes/ImageViewRGB32.h" +#include "OCR_RawOCR.h" + +#include +using std::cout; +using std::endl; + +namespace PokemonAutomation{ +namespace OCR{ + + + +enum class LanguageGroup { /* Key of the PaddleOCR instance cache; several languages map to the same group. */ + None, + English, + ChineseJapanese, + Latin, + Korean, +}; + +LanguageGroup language_to_languagegroup(Language language){ /* Maps a Language to its LanguageGroup; throws InternalProgramError for Language::None and for any value not listed below. */ + switch(language){ + case Language::None: + throw InternalProgramError(nullptr, PA_CURRENT_FUNCTION, "Attempted to call OCR without a language."); + case Language::English: + return LanguageGroup::English; + case Language::Japanese: + return LanguageGroup::ChineseJapanese; + case Language::Spanish: + return LanguageGroup::Latin; + case Language::French: + return LanguageGroup::Latin; + case Language::German: + return LanguageGroup::Latin; + case Language::Italian: + return LanguageGroup::Latin; + case Language::Korean: + return LanguageGroup::Korean; + case Language::ChineseSimplified: + return LanguageGroup::ChineseJapanese; + case Language::ChineseTraditional: + return LanguageGroup::ChineseJapanese; + default: + throw InternalProgramError(nullptr, PA_CURRENT_FUNCTION, "Attempted to call OCR on an unknown language."); + } +} + + +// Global singleton managing the single PaddleOCR instance for each language group. 
+// ocr_pool_lock protects the map +struct PaddleOcrGlobals{ + SpinLock ocr_pool_lock; // Protects ocr_pool map. + std::map ocr_pool; // One instance per language group. + + static PaddleOcrGlobals& instance(){ + static PaddleOcrGlobals globals; + return globals; + } +}; + +ML::PaddleOCRPipeline& ensure_paddle_ocr_instance(Language language){ + if (language == Language::None){ + throw InternalProgramError(nullptr, PA_CURRENT_FUNCTION, "Attempted to call OCR without a language."); + } + + LanguageGroup language_group = language_to_languagegroup(language); + + PaddleOcrGlobals& globals = PaddleOcrGlobals::instance(); + std::map& ocr_pool = globals.ocr_pool; + + // Get or create the Paddle instance for this language's group. + std::map::iterator iter; + { + WriteSpinLock lg(globals.ocr_pool_lock, "ensure_paddle_ocr_instances()"); + // std::lock_guard lg(globals.ocr_pool_lock); + iter = ocr_pool.find(language_group); + if (iter == ocr_pool.end()){ + // This is creating a Paddle instance while under a lock; it isn't ideal if we need to run OCR on different languages at the same time. + // In practice, however, this doesn't really happen in our code base. Note: the instance is constructed from the first language requested in its group. + iter = ocr_pool.try_emplace(language_group, language).first; + } + } + + return iter->second; // The reference stays valid until clear_paddle_ocr_cache() empties the map. +} + + +std::string paddle_ocr_read(Language language, const ImageViewRGB32& image){ +// static size_t c = 0; +// image.save("ocr-" + std::to_string(c++) + ".png"); + + ML::PaddleOCRPipeline& paddle_instance = ensure_paddle_ocr_instance(language); + + // Run inference with the paddle model. + // PaddleOCR with Onnx is threadsafe, so a single instance can be called by multiple threads. 
+ std::string ret = paddle_instance.recognize(image); + +// global_logger_tagged().log(ret); + + return ret; +} + + + + +void clear_paddle_ocr_cache(){ + PaddleOcrGlobals& globals = PaddleOcrGlobals::instance(); + std::map& ocr_pool = globals.ocr_pool; + WriteSpinLock lg(globals.ocr_pool_lock, "clear_paddle_ocr_cache()"); + // std::lock_guard lg(globals.ocr_pool_lock); + ocr_pool.clear(); // Destroys every cached PaddleOCR instance. +} + + + + +} +} + + + + diff --git a/SerialPrograms/Source/CommonTools/OCR/OCR_RawPaddleOCR.h b/SerialPrograms/Source/CommonTools/OCR/OCR_RawPaddleOCR.h new file mode 100644 index 000000000..90cf5524c --- /dev/null +++ b/SerialPrograms/Source/CommonTools/OCR/OCR_RawPaddleOCR.h @@ -0,0 +1,49 @@ +/* Cached PaddleOCR instances (one per language group) + * + * From: https://github.com/PokemonAutomation/ + * + */ + +#ifndef PokemonAutomation_CommonTools_OCR_RawPaddleOCR_H +#define PokemonAutomation_CommonTools_OCR_RawPaddleOCR_H + +#include +#include "CommonFramework/Language.h" + +namespace PokemonAutomation{ + class ImageViewRGB32; + namespace ML { + class PaddleOCRPipeline; + } +namespace OCR{ + + +// Pre-warm the PaddleOCR instance for a language. Ensures one instance exists for the language's group. +// Avoids lazy initialization delays during runtime. Thread-safe. +// Returns a reference to the Paddle instance for the given language. Throws if the language is Language::None. +ML::PaddleOCRPipeline& ensure_paddle_ocr_instance(Language language); + +// OCR the image in the specified language. +// Main OCR entry point. Performs OCR on the image using the specified language. +// Thread-safe: internally shares a single PaddleOCR instance per language group, +// which can be called by multiple threads concurrently. +// The instance is created lazily on first use. You can +// call `ensure_paddle_ocr_instance()` to pre-warm the instance for a language. +// +std::string paddle_ocr_read( + Language language, + const ImageViewRGB32& image +); + + + +// Clear all PaddleOCR instances for all languages. 
Used for cleanup or +// forcing re-initialization. +// This is not safe to call while any OCR is still running! +void clear_paddle_ocr_cache(); + + + +} +} +#endif diff --git a/SerialPrograms/Source/CommonTools/OCR/OCR_Routines.cpp b/SerialPrograms/Source/CommonTools/OCR/OCR_Routines.cpp index 02d61bad7..16998b096 100644 --- a/SerialPrograms/Source/CommonTools/OCR/OCR_Routines.cpp +++ b/SerialPrograms/Source/CommonTools/OCR/OCR_Routines.cpp @@ -8,7 +8,7 @@ #include "CommonFramework/Tools/GlobalThreadPools.h" #include "CommonFramework/GlobalSettingsPanel.h" #include "CommonTools/Images/ImageFilter.h" -#include "ML/Inference/ML_PaddleOCRPipeline.h" +#include "OCR_RawPaddleOCR.h" #include "OCR_RawOCR.h" #include "OCR_DictionaryMatcher.h" #include "OCR_Routines.h" @@ -45,11 +45,6 @@ StringMatchResult multifiltered_OCR( double pixels_inv = 1. / (image.width() * image.height()); bool use_paddle_ocr = GlobalSettings::instance().USE_PADDLE_OCR; - std::unique_ptr paddle_ocr; - if (use_paddle_ocr) { - // Initialize only if the setting is enabled - paddle_ocr = std::make_unique(language); - } // Run all the filters. SpinLock lock; @@ -60,7 +55,7 @@ StringMatchResult multifiltered_OCR( std::string text; if (use_paddle_ocr) { - text = paddle_ocr->recognize(filtered.first); + text = paddle_ocr_read(language, filtered.first); }else{ text = ocr_read(language, filtered.first, psm); } @@ -117,8 +112,7 @@ StringMatchResult dictionary_OCR( // Run all the filters. 
std::string text; if (GlobalSettings::instance().USE_PADDLE_OCR){ - ML::PaddleOCRPipeline paddle_ocr(language); - text = paddle_ocr.recognize(image); + text = paddle_ocr_read(language, image); }else{ text = ocr_read(language, image, psm); } diff --git a/SerialPrograms/Source/NintendoSwitch/DevPrograms/TestProgramSwitch.cpp b/SerialPrograms/Source/NintendoSwitch/DevPrograms/TestProgramSwitch.cpp index 09580cdcc..9b084b899 100644 --- a/SerialPrograms/Source/NintendoSwitch/DevPrograms/TestProgramSwitch.cpp +++ b/SerialPrograms/Source/NintendoSwitch/DevPrograms/TestProgramSwitch.cpp @@ -172,6 +172,7 @@ #include "Common/PABotBase2/PABotbase2_ReliableStreamConnection.h" #include "Common/Cpp/StreamConnections/MockDevice.h" #include "ML/Inference/ML_PaddleOCRPipeline.h" +#include "CommonTools/OCR/OCR_RawPaddleOCR.h" @@ -771,12 +772,13 @@ void TestProgram::program(MultiSwitchProgramEnvironment& env, CancellableScope& // ImageRGB32 image1(IMAGE_PATH); auto image1 = feed.snapshot(); ImageViewRGB32 cropped = extract_box_reference(image1, ImageFloatBox{BOX.x(), BOX.y(), BOX.width(), BOX.height()}); - ML::PaddleOCRPipeline paddle_ocr(LANGUAGE); // auto snapshot = feed.snapshot(); - std::string text = paddle_ocr.recognize(cropped); + std::string text = OCR::paddle_ocr_read(LANGUAGE, cropped); cout << text << endl; + + #endif #if 0 diff --git a/SerialPrograms/cmake/SourceFiles.cmake b/SerialPrograms/cmake/SourceFiles.cmake index a579e65cf..d626d6e6b 100644 --- a/SerialPrograms/cmake/SourceFiles.cmake +++ b/SerialPrograms/cmake/SourceFiles.cmake @@ -615,6 +615,8 @@ file(GLOB LIBRARY_SOURCES Source/CommonTools/OCR/OCR_LargeDictionaryMatcher.h Source/CommonTools/OCR/OCR_NumberReader.cpp Source/CommonTools/OCR/OCR_NumberReader.h + Source/CommonTools/OCR/OCR_RawPaddleOCR.cpp + Source/CommonTools/OCR/OCR_RawPaddleOCR.h Source/CommonTools/OCR/OCR_RawOCR.cpp Source/CommonTools/OCR/OCR_RawOCR.h Source/CommonTools/OCR/OCR_Routines.cpp