diff --git a/BUILD.bazel b/BUILD.bazel index 130f18ff..9eb60a4e 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -523,6 +523,7 @@ cc_library( ":configs", ":gemma_args", ":mat", + "//compression:types", "@highway//:hwy", ], ) @@ -575,6 +576,7 @@ cc_library( ":configs", ":mat", ":threading_context", + "//compression:types", "//io", "@highway//:hwy", "@highway//:profiler", diff --git a/gemma/activations.h b/gemma/activations.h index 11e2b1c3..c1b943ee 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -35,6 +35,7 @@ namespace gcpp { typedef std::vector> AlignedFloatVector; +typedef std::vector> AlignedBF16Vector; // Returns the scale value to use for the query in the attention computation. // Also called by ops_test. diff --git a/gemma/gemma_args.h b/gemma/gemma_args.h index ba72db67..6ccb5b38 100644 --- a/gemma/gemma_args.h +++ b/gemma/gemma_args.h @@ -22,8 +22,10 @@ #include #include +#include #include +#include "compression/types.h" #include "gemma/configs.h" #include "io/io.h" // Path #include "util/args.h" // IWYU pragma: export diff --git a/gemma/kv_cache.cc b/gemma/kv_cache.cc index 2fe6885f..49276f83 100644 --- a/gemma/kv_cache.cc +++ b/gemma/kv_cache.cc @@ -16,8 +16,12 @@ #include "gemma/kv_cache.h" #include + +#include +#include #include +#include "compression/types.h" #include "gemma/configs.h" #include "gemma/gemma_args.h" #include "util/mat.h" // ZeroInit diff --git a/gemma/kv_cache.h b/gemma/kv_cache.h index bad66fa0..fe6a1ff9 100644 --- a/gemma/kv_cache.h +++ b/gemma/kv_cache.h @@ -19,12 +19,14 @@ #include #include +#include #include #include "gemma/configs.h" // ModelConfig #include "gemma/gemma_args.h" // InferenceArgs #include "util/basics.h" // BF16 #include "util/mat.h" +#include "hwy/base.h" namespace gcpp { diff --git a/gemma/kv_cache_test.cc b/gemma/kv_cache_test.cc index 157b3d95..7b7bed20 100644 --- a/gemma/kv_cache_test.cc +++ b/gemma/kv_cache_test.cc @@ -35,8 +35,13 @@ TEST(KVCacheTest, KVCacheToPtrs) { std::vector ptrs = 
ToKVCachePtrs({caches.data(), caches.size()}); ASSERT_EQ(ptrs.size(), 2); - EXPECT_EQ(ptrs[0].kv_cache.Row(0), caches[0].kv_cache.Row(0)); - EXPECT_EQ(ptrs[1].kv_cache.Row(0), caches[1].kv_cache.Row(0)); + if (caches[0].IsTiled()) { + EXPECT_EQ(ptrs[0].cache, &caches[0]); + EXPECT_EQ(ptrs[1].cache, &caches[1]); + } else { + EXPECT_EQ(ptrs[0].kv_cache.Row(0), caches[0].kv_cache.Row(0)); + EXPECT_EQ(ptrs[1].kv_cache.Row(0), caches[1].kv_cache.Row(0)); + } } } // namespace diff --git a/util/mat.h b/util/mat.h index 83d03b15..08300461 100644 --- a/util/mat.h +++ b/util/mat.h @@ -469,6 +469,38 @@ decltype(auto) CallUpcastedKV(const MatPtr* base, const Func& func, } } +// Calls `func` with a span of MatPtrT<T> for all elements in `base`. +// T is the dynamic type, read from base[0]; `base` must be non-empty and all of +// its elements are assumed to have the same type. +template +decltype(auto) CallUpcastedKVs(hwy::Span base, const Func& func, + Args&&... args) { + Type type = base[0].GetType(); + for ([[maybe_unused]] auto&& mat : base) { + HWY_DASSERT(mat.GetType() == type); + } + auto convert_to_matptr_t = [&base]() { + std::vector> matptrs; + matptrs.reserve(base.size()); + for (auto&& mat : base) { + matptrs.emplace_back(mat); + } + return matptrs; + }; + if (type == Type::kF32) { + auto matptrs = convert_to_matptr_t.template operator()(); + hwy::Span> matptrs_span(matptrs.data(), + matptrs.size()); + return func(matptrs_span, std::forward(args)...); + } else if (type == Type::kBF16) { + auto matptrs = convert_to_matptr_t.template operator()(); + hwy::Span> matptrs_span(matptrs.data(), matptrs.size()); + return func(matptrs_span, std::forward(args)...); + } else { + HWY_ABORT("Unhandled type %s.", TypeName(type)); + } +} + void CopyMat(const MatPtr& from, MatPtr& to); void ZeroInit(MatPtr& mat);