diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index f2e743df..4ec98e56 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -77,4 +77,4 @@ jobs: with: path: ~/.cache/bazel key: bazel-${{ runner.os }} - - run: bazel build --cxxopt=-std=c++20 //:gemma --jobs=10 --show_progress_rate_limit=1 + - run: bazel build --cxxopt=-std=c++20 //:gemma_main --jobs=10 --show_progress_rate_limit=1 diff --git a/BUILD.bazel b/BUILD.bazel index 9eb60a4e..9eb3e4b9 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -55,7 +55,7 @@ cc_library( hdrs = ["util/args.h"], deps = [ ":basics", - "//io", # Path + "//gemma/io", # Path "@highway//:hwy", ], ) @@ -112,7 +112,7 @@ cc_library( ":threading", ":topology", ":zones", - "//io", + "//gemma/io", "@highway//:hwy", "@highway//:hwy_test_util", "@highway//:profiler", @@ -192,8 +192,8 @@ cc_library( deps = [ ":basics", "//compression:types", - "//io", - "//io:fields", + "//gemma/io", + "//gemma/io:fields", "@highway//:hwy", # base.h ], ) @@ -205,7 +205,7 @@ cc_test( ":configs", "@googletest//:gtest_main", # buildcleaner: keep "//compression:types", - "//io:fields", + "//gemma/io:fields", ], ) @@ -230,7 +230,7 @@ cc_library( ":tensor_info", ":threading_context", "//compression:types", - "//io:fields", + "//gemma/io:fields", "@highway//:hwy", "@highway//:profiler", ], @@ -261,9 +261,9 @@ cc_library( ":threading_context", ":tokenizer", "//compression:types", - "//io", - "//io:blob_store", - "//io:fields", + "//gemma/io", + "//gemma/io:blob_store", + "//gemma/io:fields", "@highway//:hwy", "@highway//:profiler", ], @@ -282,7 +282,7 @@ cc_library( ":threading_context", ":zones", "//compression:compress", - "//io:blob_store", + "//gemma/io:blob_store", "@highway//:hwy", "@highway//:profiler", ], @@ -577,7 +577,7 @@ cc_library( ":mat", ":threading_context", "//compression:types", - "//io", + "//gemma/io", "@highway//:hwy", "@highway//:profiler", ], @@ -641,7 +641,7 @@ cc_library( ":threading_context", ":zones", "//compression:compress", - "//io", + "//gemma/io", "@highway//:hwy", "@highway//:profiler", "@highway//:stats", @@ -681,8 +681,8 @@ cc_library( ":zones", "//compression:compress", "//compression:types", - "//io", - "//io:blob_store", + "//gemma/io", + "//gemma/io:blob_store", "//paligemma:image", "@highway//:hwy", "@highway//:nanobenchmark", # timer @@ -795,7 +795,7 @@ cc_test( ":configs", ":gemma_lib", "@googletest//:gtest_main", # buildcleaner: keep - "//io", + "//gemma/io", "@highway//:hwy", "@highway//:hwy_test_util", ], @@ -823,7 +823,7 @@ cc_test( ) cc_binary( - name = "gemma", + name = "gemma_main", srcs = ["gemma/run.cc"], deps = [ ":args", @@ -847,7 +847,7 @@ cc_binary( ":benchmark_helper", ":cross_entropy", ":gemma_lib", - "//io", + "//gemma/io", "@highway//:hwy", "@highway//:nanobenchmark", "@nlohmann_json//:json", @@ -874,7 +874,7 @@ cc_binary( ":args", ":benchmark_helper", ":gemma_lib", - "//io", + "//gemma/io", "@highway//:hwy", "@nlohmann_json//:json", ], @@ -887,7 +887,7 @@ cc_binary( ":args", ":benchmark_helper", ":gemma_lib", - "//io", + "//gemma/io", "@highway//:hwy", "@highway//:profiler", "@nlohmann_json//:json", diff --git a/CMakeLists.txt b/CMakeLists.txt index 47d7c4c2..f5d34380 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -99,13 +99,13 @@ set(SOURCES gemma/vit.h gemma/weights.cc gemma/weights.h - io/blob_store.cc - io/blob_store.h - io/fields.cc - io/fields.h - io/io_win.cc - io/io.cc - io/io.h + gemma/io/blob_store.cc + gemma/io/blob_store.h + gemma/io/fields.cc + gemma/io/fields.h + gemma/io/io_win.cc + gemma/io/io.cc + gemma/io/io.h ops/dot-inl.h ops/matmul_static_bf16.cc ops/matmul_static_f32.cc @@ -225,8 +225,8 @@ set(GEMMA_TEST_FILES gemma/gemma_args_test.cc gemma/flash_attention_test.cc gemma/tensor_info_test.cc - io/blob_store_test.cc - io/fields_test.cc + gemma/io/blob_store_test.cc + gemma/io/fields_test.cc ops/bench_matmul.cc ops/dot_test.cc ops/matmul_test.cc @@ -259,7 +259,7 @@ endif() # GEMMA_ENABLE_TESTS ## Tools -add_executable(migrate_weights io/migrate_weights.cc) +add_executable(migrate_weights gemma/io/migrate_weights.cc) target_link_libraries(migrate_weights libgemma hwy hwy_contrib) diff --git a/compression/python/BUILD.bazel b/compression/python/BUILD.bazel index e3b7e36f..72537203 100644 --- a/compression/python/BUILD.bazel +++ b/compression/python/BUILD.bazel @@ -23,8 +23,8 @@ cc_library( "//:threading_context", "//:tokenizer", "//compression:compress", - "//io", - "//io:blob_store", + "//gemma/io", + "//gemma/io:blob_store", "@highway//:hwy", ], ) diff --git a/compression/python/compression_clif_aux.cc b/compression/python/compression_clif_aux.cc index 3568ad34..57214995 100644 --- a/compression/python/compression_clif_aux.cc +++ b/compression/python/compression_clif_aux.cc @@ -29,19 +29,20 @@ #include "compression/compress.h" // ScaleWeights #include "gemma/configs.h" // ModelConfig +#include "gemma/io/blob_store.h" // BlobWriter +#include "gemma/io/io.h" // Path #include "gemma/model_store.h" // ModelStore #include "gemma/tensor_info.h" // TensorInfo #include "gemma/tokenizer.h" -#include "io/blob_store.h" // BlobWriter -#include "io/io.h" // Path #include "util/basics.h" #include "util/mat.h" #include "util/threading_context.h" + #undef HWY_TARGET_INCLUDE #define HWY_TARGET_INCLUDE \ "compression/python/compression_clif_aux.cc" // NOLINT -#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/foreach_target.h" // IWYU pragma: keep #include "hwy/highway.h" // After highway.h #include "compression/compress-inl.h" diff --git a/compression/python/compression_clif_aux.h b/compression/python/compression_clif_aux.h index 69798652..13b58188 100644 --- a/compression/python/compression_clif_aux.h +++ b/compression/python/compression_clif_aux.h @@ -23,11 +23,12 @@ #include "compression/types.h" // Type #include "gemma/configs.h" +#include "gemma/io/blob_store.h" #include "gemma/model_store.h" #include "gemma/tensor_info.h" -#include "io/blob_store.h" -#include "util/mat.h" #include "hwy/aligned_allocator.h" // Span +#include "util/mat.h" + namespace gcpp { diff --git a/evals/benchmark.cc b/evals/benchmark.cc index 69cd644a..95b6b4ff 100644 --- a/evals/benchmark.cc +++ b/evals/benchmark.cc @@ -11,11 +11,12 @@ #include "evals/benchmark_helper.h" #include "evals/cross_entropy.h" #include "gemma/gemma.h" -#include "io/io.h" // Path -#include "util/args.h" +#include "gemma/io/io.h" // Path #include "hwy/base.h" #include "hwy/timer.h" #include "nlohmann/json.hpp" +#include "util/args.h" + namespace gcpp { @@ -85,8 +86,8 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text, LogSpeedStats(time_start, pos + num_tokens); std::string text_slice = env.StringFromTokens(prompt_slice); total_input_len += text_slice.size(); - printf("Total cross entropy: %f [cumulative: %f]\n", - entropy, total_entropy); + printf("Total cross entropy: %f [cumulative: %f]\n", entropy, + total_entropy); printf("Cross entropy per byte: %f [cumulative: %f]\n", entropy / text_slice.size(), total_entropy / total_input_len); } diff --git a/evals/debug_prompt.cc b/evals/debug_prompt.cc index a6cf8c48..60be0080 100644 --- a/evals/debug_prompt.cc +++ b/evals/debug_prompt.cc @@ -20,10 +20,11 @@ #include "evals/benchmark_helper.h" #include "gemma/gemma.h" // LayersOutputFunc -#include "io/io.h" -#include "util/args.h" +#include "gemma/io/io.h" #include "hwy/base.h" #include "nlohmann/json.hpp" +#include "util/args.h" + using json = nlohmann::json; diff --git a/evals/run_mmlu.cc b/evals/run_mmlu.cc index c6ce9723..7476ce32 100644 --- a/evals/run_mmlu.cc +++ b/evals/run_mmlu.cc @@ -21,12 +21,13 @@ #include "evals/benchmark_helper.h" #include "gemma/gemma.h" // Gemma -#include "io/io.h" // Path -#include "util/args.h" +#include "gemma/io/io.h" // Path #include "hwy/base.h" #include "hwy/highway.h" #include "hwy/profiler.h" #include "nlohmann/json.hpp" +#include "util/args.h" + namespace gcpp { diff --git a/gemma/configs.cc b/gemma/configs.cc index cb508e84..c9263687 100644 --- a/gemma/configs.cc +++ b/gemma/configs.cc @@ -22,8 +22,8 @@ #include #include "compression/types.h" // Type -#include "io/fields.h" // IFields -#include "io/io.h" // Path +#include "gemma/io/fields.h" // IFields +#include "gemma/io/io.h" // Path #include "hwy/base.h" namespace gcpp { diff --git a/gemma/configs.h b/gemma/configs.h index f1bd0c52..c54b5cd4 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -26,8 +26,8 @@ #include #include "compression/types.h" // Type -#include "io/fields.h" // IFieldsVisitor -#include "io/io.h" // Path +#include "gemma/io/fields.h" // IFieldsVisitor +#include "gemma/io/io.h" // Path #include "util/basics.h" namespace gcpp { diff --git a/gemma/configs_test.cc b/gemma/configs_test.cc index 0ca4a848..95db2e56 100644 --- a/gemma/configs_test.cc +++ b/gemma/configs_test.cc @@ -7,7 +7,7 @@ #include "gtest/gtest.h" #include "compression/types.h" // Type -#include "io/fields.h" // Type +#include "gemma/io/fields.h" // Type namespace gcpp { diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 5a48d009..40e28629 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -58,8 +58,8 @@ #include "gemma/configs.h" #include "gemma/model_store.h" #include "gemma/weights.h" -#include "io/blob_store.h" -#include "io/io.h" // Path +#include "gemma/io/blob_store.h" +#include "gemma/io/io.h" // Path #include "ops/matmul.h" #include "paligemma/image.h" #include "util/basics.h" diff --git a/gemma/gemma.h b/gemma/gemma.h index b630a8c1..5e9d9dd6 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -29,8 +29,8 @@ #include "gemma/model_store.h" #include "gemma/query.h" #include "gemma/weights.h" -#include "io/blob_store.h" -#include "io/io.h" // Path +#include "gemma/io/blob_store.h" +#include "gemma/io/io.h" // Path #include "ops/matmul.h" // MatMulEnv #include "paligemma/image.h" #include "util/basics.h" // TokenAndProb diff --git a/gemma/gemma_args.h b/gemma/gemma_args.h index 6ccb5b38..fc854330 100644 --- a/gemma/gemma_args.h +++ b/gemma/gemma_args.h @@ -27,7 +27,7 @@ #include "compression/types.h" #include "gemma/configs.h" -#include "io/io.h" // Path +#include "gemma/io/io.h" // Path #include "util/args.h" // IWYU pragma: export #include "util/basics.h" // Tristate #include "util/mat.h" diff --git a/io/BUILD.bazel b/gemma/io/BUILD.bazel similarity index 100% rename from io/BUILD.bazel rename to gemma/io/BUILD.bazel diff --git a/io/blob_compare.cc b/gemma/io/blob_compare.cc similarity index 96% rename from io/blob_compare.cc rename to gemma/io/blob_compare.cc index 9bb860e5..85217e41 100644 --- a/io/blob_compare.cc +++ b/gemma/io/blob_compare.cc @@ -21,14 +21,15 @@ #include #include -#include "io/blob_store.h" -#include "io/io.h" // Path -#include "util/basics.h" // IndexRange -#include "util/threading.h" -#include "util/threading_context.h" +#include "gemma/io/blob_store.h" +#include "gemma/io/io.h" // Path #include "hwy/aligned_allocator.h" // Span #include "hwy/base.h" #include "hwy/timer.h" +#include "util/basics.h" // IndexRange +#include "util/threading.h" +#include "util/threading_context.h" + namespace gcpp { @@ -106,8 +107,8 @@ void ReadBlobs(BlobReader& reader, const RangeVec& ranges, BlobVec& blobs, ThreadingContext& ctx, size_t cluster_idx) { HWY_ASSERT(reader.Keys().size() == blobs.size()); HWY_ASSERT(ranges.size() == blobs.size()); - ParallelFor(Parallelism::kWithinCluster, blobs.size(), ctx, - cluster_idx, Callers::kTest, [&](size_t i, size_t /*thread*/) { + ParallelFor(Parallelism::kWithinCluster, blobs.size(), ctx, cluster_idx, + Callers::kTest, [&](size_t i, size_t /*thread*/) { HWY_ASSERT(ranges[i].bytes == blobs[i].size()); reader.file().Read(ranges[i].offset, ranges[i].bytes, blobs[i].data()); @@ -189,8 +190,8 @@ void CompareBlobs(const KeyVec& keys, BlobVec& blobs1, BlobVec& blobs2, const double t0 = hwy::platform::Now(); std::atomic blobs_equal{}; std::atomic blobs_diff{}; - ParallelFor(Parallelism::kHierarchical, keys.size(), ctx, 0, - Callers::kTest, [&](size_t i, size_t /*thread*/) { + ParallelFor(Parallelism::kHierarchical, keys.size(), ctx, 0, Callers::kTest, + [&](size_t i, size_t /*thread*/) { const size_t mismatches = BlobDifferences(blobs1[i], blobs2[i], keys[i]); if (mismatches != 0) { diff --git a/io/blob_store.cc b/gemma/io/blob_store.cc similarity index 96% rename from io/blob_store.cc rename to gemma/io/blob_store.cc index 8346e4ba..02b3c3e5 100644 --- a/io/blob_store.cc +++ b/gemma/io/blob_store.cc @@ -13,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "io/blob_store.h" +#include "gemma/io/blob_store.h" #include #include @@ -24,12 +24,13 @@ #include // std::move #include -#include "io/io.h" -#include "util/threading_context.h" +#include "gemma/io/io.h" #include "hwy/aligned_allocator.h" // Span #include "hwy/base.h" #include "hwy/detect_compiler_arch.h" #include "hwy/profiler.h" +#include "util/threading_context.h" + namespace gcpp { @@ -490,25 +491,24 @@ void BlobWriter::Add(const std::string& key, const void* data, size_t bytes) { const Parallelism parallelism = file_->IsAppendOnly() ? Parallelism::kNone : Parallelism::kFlat; - ParallelFor( - parallelism, writes.size(), ctx_, - /*cluster_idx=*/0, Callers::kBlobWriter, - [this, &writes](uint64_t i, size_t /*thread*/) { - const BlobRange& range = writes[i].range; - if (!file_->Write(writes[i].data, range.bytes, range.offset)) { - const std::string& key = StringFromKey(keys_[range.key_idx]); - HWY_ABORT("Write failed for %s from %zu, %zu bytes to %p.", - key.c_str(), static_cast(range.offset), range.bytes, - writes[i].data); - } - }); + ParallelFor(parallelism, writes.size(), ctx_, + /*cluster_idx=*/0, Callers::kBlobWriter, + [this, &writes](uint64_t i, size_t /*thread*/) { + const BlobRange& range = writes[i].range; + if (!file_->Write(writes[i].data, range.bytes, range.offset)) { + const std::string& key = StringFromKey(keys_[range.key_idx]); + HWY_ABORT("Write failed for %s from %zu, %zu bytes to %p.", + key.c_str(), static_cast(range.offset), + range.bytes, writes[i].data); + } + }); curr_offset_ = writes.back().range.End(); } void BlobWriter::Finalize() { if (!file_->IsAppendOnly() && curr_offset_ != file_->FileSize()) { - HWY_WARN("Computed offset %zu does not match file size %zu.", - curr_offset_, file_->FileSize()); + HWY_WARN("Computed offset %zu does not match file size %zu.", curr_offset_, + file_->FileSize()); } const BlobStore bs = BlobStore(keys_, blob_sizes_); diff --git a/io/blob_store.h b/gemma/io/blob_store.h similarity index 97% rename from io/blob_store.h rename to gemma/io/blob_store.h index 82c2357b..e5461d33 100644 --- a/io/blob_store.h +++ b/gemma/io/blob_store.h @@ -26,11 +26,12 @@ #include #include -#include "io/io.h" // File, Path, MapPtr -#include "util/basics.h" // Tristate -#include "util/threading_context.h" +#include "gemma/io/io.h" // File, Path, MapPtr #include "hwy/aligned_allocator.h" // Span #include "hwy/base.h" // HWY_ASSERT +#include "util/basics.h" // Tristate +#include "util/threading_context.h" + namespace gcpp { diff --git a/io/blob_store_test.cc b/gemma/io/blob_store_test.cc similarity index 96% rename from io/blob_store_test.cc rename to gemma/io/blob_store_test.cc index cf966849..e1f2995b 100644 --- a/io/blob_store_test.cc +++ b/gemma/io/blob_store_test.cc @@ -13,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "io/blob_store.h" +#include "gemma/io/blob_store.h" #include @@ -22,11 +22,12 @@ #include #include -#include "io/io.h" -#include "util/basics.h" -#include "util/threading_context.h" +#include "gemma/io/io.h" #include "hwy/tests/hwy_gtest.h" #include "hwy/tests/test_util-inl.h" // HWY_ASSERT_EQ +#include "util/basics.h" +#include "util/threading_context.h" + namespace gcpp { namespace { @@ -130,8 +131,8 @@ TEST(BlobStoreTest, TestNumBlobs) { HWY_ASSERT_EQ(reader.Keys().size(), num_blobs); ParallelFor( - Parallelism::kFlat, num_blobs, ctx, /*cluster_idx=*/0, - Callers::kTest, [&](uint64_t i, size_t /*thread*/) { + Parallelism::kFlat, num_blobs, ctx, /*cluster_idx=*/0, Callers::kTest, + [&](uint64_t i, size_t /*thread*/) { HWY_ASSERT_STRING_EQ(reader.Keys()[i].c_str(), std::to_string(i).c_str()); const BlobRange* range = reader.Find(keys[i]); diff --git a/io/fields.cc b/gemma/io/fields.cc similarity index 99% rename from io/fields.cc rename to gemma/io/fields.cc index 4516e891..dce67e66 100644 --- a/io/fields.cc +++ b/gemma/io/fields.cc @@ -13,8 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "io/fields.h" - +#include "gemma/io/fields.h" #include #include #include diff --git a/io/fields.h b/gemma/io/fields.h similarity index 100% rename from io/fields.h rename to gemma/io/fields.h diff --git a/io/fields_test.cc b/gemma/io/fields_test.cc similarity index 99% rename from io/fields_test.cc rename to gemma/io/fields_test.cc index f720c158..fe4645cf 100644 --- a/io/fields_test.cc +++ b/gemma/io/fields_test.cc @@ -13,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "io/fields.h" +#include "gemma/io/fields.h" #include #include @@ -21,11 +21,12 @@ #include #include -#include #include +#include #include "hwy/tests/hwy_gtest.h" + namespace gcpp { namespace { diff --git a/io/io.cc b/gemma/io/io.cc similarity index 99% rename from io/io.cc rename to gemma/io/io.cc index 2f479b21..e39671ec 100644 --- a/io/io.cc +++ b/gemma/io/io.cc @@ -31,7 +31,7 @@ #include #include -#include "io/io.h" +#include "gemma/io/io.h" #include "hwy/base.h" // HWY_ASSERT #if (HWY_OS_LINUX || HWY_OS_FREEBSD) && \ diff --git a/io/io.h b/gemma/io/io.h similarity index 100% rename from io/io.h rename to gemma/io/io.h diff --git a/io/io_win.cc b/gemma/io/io_win.cc similarity index 99% rename from io/io_win.cc rename to gemma/io/io_win.cc index 34773d33..7fa287ca 100644 --- a/io/io_win.cc +++ b/gemma/io/io_win.cc @@ -21,9 +21,10 @@ #include #include -#include "io/io.h" -#include "util/allocator.h" +#include "gemma/io/io.h" #include "hwy/base.h" // HWY_ASSERT +#include "util/allocator.h" + #ifndef WIN32_LEAN_AND_MEAN #define WIN32_LEAN_AND_MEAN #endif diff --git a/io/migrate_weights.cc b/gemma/io/migrate_weights.cc similarity index 100% rename from io/migrate_weights.cc rename to gemma/io/migrate_weights.cc diff --git a/gemma/model_store.cc b/gemma/model_store.cc index 76f0c754..ac6b701c 100644 --- a/gemma/model_store.cc +++ b/gemma/model_store.cc @@ -30,9 +30,9 @@ #include "gemma/configs.h" // ModelConfig, kMaxQKVDim #include "gemma/tensor_info.h" #include "gemma/tokenizer.h" -#include "io/blob_store.h" -#include "io/fields.h" -#include "io/io.h" // Path +#include "gemma/io/blob_store.h" +#include "gemma/io/fields.h" +#include "gemma/io/io.h" // Path #include "util/basics.h" #include "util/threading_context.h" #include "hwy/base.h" diff --git a/gemma/model_store.h b/gemma/model_store.h index 506fb77c..39f23c7a 100644 --- a/gemma/model_store.h +++ b/gemma/model_store.h @@ -28,11 +28,12 @@ // IWYU pragma: begin_exports #include "gemma/configs.h" // ModelConfig +#include "gemma/io/blob_store.h" #include "gemma/tokenizer.h" -#include "io/blob_store.h" -#include "io/io.h" // Path +#include "gemma/io/io.h" // Path #include "util/basics.h" // Tristate #include "util/mat.h" // MatPtr + // IWYU pragma: end_exports #include "util/allocator.h" diff --git a/gemma/tensor_stats.cc b/gemma/tensor_stats.cc index 62dcb98c..41cc3568 100644 --- a/gemma/tensor_stats.cc +++ b/gemma/tensor_stats.cc @@ -24,7 +24,7 @@ #include #include -#include "io/io.h" +#include "gemma/io/io.h" #include "util/mat.h" #include "util/threading_context.h" #include "util/zones.h" diff --git a/gemma/weights.cc b/gemma/weights.cc index 00c12c64..0c138ff0 100644 --- a/gemma/weights.cc +++ b/gemma/weights.cc @@ -29,7 +29,7 @@ #include "gemma/configs.h" #include "gemma/gemma_args.h" #include "gemma/model_store.h" -#include "io/blob_store.h" +#include "gemma/io/blob_store.h" #include "util/mat.h" #include "util/threading_context.h" #include "util/zones.h" diff --git a/gemma/weights.h b/gemma/weights.h index 4476e22d..b18e6c47 100644 --- a/gemma/weights.h +++ b/gemma/weights.h @@ -27,7 +27,7 @@ #include "gemma/gemma_args.h" // InferenceArgs #include "gemma/model_store.h" // ModelStore #include "gemma/tensor_info.h" // TensorInfoRegistry -#include "io/blob_store.h" // BlobWriter +#include "gemma/io/blob_store.h" // BlobWriter #include "util/mat.h" // MatPtr #include "util/threading_context.h" diff --git a/paligemma/BUILD.bazel b/paligemma/BUILD.bazel index b749e05d..390c79a3 100644 --- a/paligemma/BUILD.bazel +++ b/paligemma/BUILD.bazel @@ -16,7 +16,7 @@ cc_library( srcs = ["image.cc"], hdrs = ["image.h"], deps = [ - "//io", + "//gemma/io", "@highway//:hwy", "@highway//:profiler", ], @@ -59,12 +59,22 @@ cc_test( "no_tap", ], deps = [ + ":paligemma_helper", + "@googletest//:gtest_main", # buildcleaner: keep + "//:allocator", + "//:benchmark_helper", + "//:configs", + "//:gemma_lib", + + "//gemma/io", + "@highway//:hwy_test_util", + ], ) diff --git a/paligemma/image.cc b/paligemma/image.cc index d8b0cfc7..aa68e2a0 100644 --- a/paligemma/image.cc +++ b/paligemma/image.cc @@ -30,7 +30,7 @@ #include #include -#include "io/io.h" +#include "gemma/io/io.h" #include "hwy/aligned_allocator.h" // hwy::Span #include "hwy/base.h" #include "hwy/profiler.h" diff --git a/util/args.h b/util/args.h index 8c6423b3..d7eb7dc9 100644 --- a/util/args.h +++ b/util/args.h @@ -24,9 +24,10 @@ #include #include -#include "io/io.h" // Path +#include "gemma/io/io.h" // Path +#include "hwy/base.h" // HWY_ABORT #include "util/basics.h" // Tristate -#include "hwy/base.h" // HWY_ABORT + namespace gcpp { diff --git a/util/mat.h b/util/mat.h index 08300461..b68ef16c 100644 --- a/util/mat.h +++ b/util/mat.h @@ -24,10 +24,11 @@ // IWYU pragma: begin_exports #include "compression/types.h" // Type +#include "gemma/io/fields.h" #include "gemma/tensor_info.h" -#include "io/fields.h" #include "util/allocator.h" // AlignedPtr -#include "util/basics.h" // Extents2D +#include "util/basics.h" // Extents2D + // IWYU pragma: end_exports #include "hwy/base.h" @@ -457,7 +458,7 @@ decltype(auto) CallUpcastedActivation(const MatPtr* base, const Func& func, // Like CallUpcasted, but only for kv_cache types: kBF16 and kF32. template decltype(auto) CallUpcastedKV(const MatPtr* base, const Func& func, - Args&&... args) { + Args&&... args) { if (base->GetType() == Type::kF32) { const MatPtrT mat(*base); return func(&mat, std::forward(args)...); diff --git a/util/threading_context.h b/util/threading_context.h index 7e595ba6..1f8452d4 100644 --- a/util/threading_context.h +++ b/util/threading_context.h @@ -23,14 +23,15 @@ #include // IWYU pragma: begin_exports -#include "io/io.h" // Path +#include "gemma/io/io.h" // Path +#include "hwy/profiler.h" #include "util/allocator.h" #include "util/args.h" #include "util/basics.h" // Tristate #include "util/threading.h" #include "util/topology.h" #include "util/zones.h" -#include "hwy/profiler.h" + // IWYU pragma: end_exports namespace gcpp {