Allow variable coverage

This commit is contained in:
2026-04-12 20:51:19 -07:00
parent 59b5eade7e
commit 3fb10b78e3
4 changed files with 320 additions and 172 deletions

View File

@@ -153,7 +153,7 @@ fn bench_encode(c: &mut Criterion) {
group.bench_function("uniform", |b| { group.bench_function("uniform", |b| {
b.iter(|| { b.iter(|| {
for blk in &uniform { for blk in &uniform {
black_box(encode(black_box(blk))); black_box(encode(black_box(blk), 0.0));
} }
}); });
}); });
@@ -161,7 +161,7 @@ fn bench_encode(c: &mut Criterion) {
group.bench_function("rle_optimal", |b| { group.bench_function("rle_optimal", |b| {
b.iter(|| { b.iter(|| {
for blk in &rle_opt { for blk in &rle_opt {
black_box(encode(black_box(blk))); black_box(encode(black_box(blk), 0.0));
} }
}); });
}); });
@@ -186,8 +186,8 @@ fn bench_dequantize(c: &mut Criterion) {
let q4k_uniform = uniform_blocks(1).into_iter().next().unwrap(); let q4k_uniform = uniform_blocks(1).into_iter().next().unwrap();
let q4k_rle_opt = rle_optimal_blocks(1).into_iter().next().unwrap(); let q4k_rle_opt = rle_optimal_blocks(1).into_iter().next().unwrap();
let rle_raw = encode(&q4k_uniform); // IS_RLE = 0 let rle_raw = encode(&q4k_uniform, 0.0); // IS_RLE = 0
let rle_rle = encode(&q4k_rle_opt); // IS_RLE = 1 let rle_rle = encode(&q4k_rle_opt, 0.0); // IS_RLE = 1
// Confirm the fixtures ended up in the right encoding modes. // Confirm the fixtures ended up in the right encoding modes.
assert!(!rle_raw.is_rle(), "uniform block should encode to raw mode"); assert!(!rle_raw.is_rle(), "uniform block should encode to raw mode");
@@ -270,10 +270,10 @@ fn bench_matmul(c: &mut Criterion) {
// Build all four A variants and the shared B matrix for this config. // Build all four A variants and the shared B matrix for this config.
let a_q4k_u: Vec<BlockQ4K> = uniform_blocks(m * bpr); let a_q4k_u: Vec<BlockQ4K> = uniform_blocks(m * bpr);
let a_rle_u: Vec<BlockQ4KRle> = a_q4k_u.iter().map(encode).collect(); let a_rle_u: Vec<BlockQ4KRle> = a_q4k_u.iter().map(|b| encode(b, 0.0)).collect();
let a_q4k_r: Vec<BlockQ4K> = rle_optimal_blocks(m * bpr); let a_q4k_r: Vec<BlockQ4K> = rle_optimal_blocks(m * bpr);
let a_rle_r: Vec<BlockQ4KRle> = a_q4k_r.iter().map(encode).collect(); let a_rle_r: Vec<BlockQ4KRle> = a_q4k_r.iter().map(|b| encode(b, 0.0)).collect();
let b = fp16_ones(k, n); let b = fp16_ones(k, n);

View File

@@ -145,7 +145,7 @@ fn main() -> Result<(), Box<dyn Error>> {
// ── RLE encode (best of `trials`) ──────────────────────────────────────── // ── RLE encode (best of `trials`) ────────────────────────────────────────
let (rle_blocks, t_enc) = bench(trials, || -> Vec<BlockQ4KRle> { let (rle_blocks, t_enc) = bench(trials, || -> Vec<BlockQ4KRle> {
blocks.iter().map(encode).collect() blocks.iter().map(|b| encode(b, 0.0)).collect()
}); });
let n_rle = rle_blocks.iter().filter(|b| b.is_rle()).count(); let n_rle = rle_blocks.iter().filter(|b| b.is_rle()).count();

View File

@@ -101,11 +101,32 @@ fn fixed(s: &str, width: usize) -> String {
fn main() -> Result<(), Box<dyn Error>> { fn main() -> Result<(), Box<dyn Error>> {
let args: Vec<String> = env::args().collect(); let args: Vec<String> = env::args().collect();
if args.len() < 2 { if args.len() < 2 {
eprintln!("usage: {} <model.gguf>", args[0]); eprintln!("usage: {} <model.gguf> [--threshold <0.0..1.0>]", args[0]);
eprintln!();
eprintln!(" --threshold Minimum fraction of qs bytes that must be in runs of");
eprintln!(" length ≥ 2 for a block to use RLE mode. Default: 0.0");
eprintln!(" (use RLE whenever the pair count fits in 64 pairs).");
std::process::exit(1); std::process::exit(1);
} }
let path = &args[1]; let path = &args[1];
// Parse optional --threshold flag from the remaining arguments.
let mut threshold = 0.0f32;
let mut idx = 2usize;
while idx < args.len() {
if args[idx] == "--threshold" {
idx += 1;
threshold = args.get(idx)
.and_then(|s| s.parse::<f32>().ok())
.filter(|&v| (0.0..=1.0).contains(&v))
.unwrap_or_else(|| {
eprintln!("error: --threshold requires a value in [0.0, 1.0]");
std::process::exit(1);
});
}
idx += 1;
}
// ── Parse header ───────────────────────────────────────────────────────── // ── Parse header ─────────────────────────────────────────────────────────
eprintln!("Parsing {path}"); eprintln!("Parsing {path}");
let (tensors, data_start) = parse_header(path)?; let (tensors, data_start) = parse_header(path)?;
@@ -122,6 +143,8 @@ fn main() -> Result<(), Box<dyn Error>> {
q4k_tensors.len(), q4k_tensors.len(),
other_count, other_count,
); );
eprintln!(" RLE threshold: {threshold:.2} (blocks need ≥ {:.0}% of bytes in runs)",
threshold * 100.0);
eprintln!(); eprintln!();
// ── Header row ─────────────────────────────────────────────────────────── // ── Header row ───────────────────────────────────────────────────────────
@@ -145,7 +168,7 @@ fn main() -> Result<(), Box<dyn Error>> {
let mut stats = TensorStats::new(); let mut stats = TensorStats::new();
for_each_block(&mut file, data_start, tensor, |block| { for_each_block(&mut file, data_start, tensor, |block| {
let rle_block = encode(block); let rle_block = encode(block, threshold);
stats.observe(rle_block.is_rle(), rle_block.rle_len()); stats.observe(rle_block.is_rle(), rle_block.rle_len());
})?; })?;
@@ -187,10 +210,15 @@ fn main() -> Result<(), Box<dyn Error>> {
if !any_rle { if !any_rle {
println!(); println!();
println!("No blocks compressed with RLE — all weights are effectively random at"); println!("No blocks used RLE at threshold {threshold:.2}.");
println!("the byte level, which is typical for trained Q4_K quantised weights."); if threshold < 0.01 {
println!("RLE compression only helps for structured weight matrices (binary,"); println!("All weights are effectively random at the byte level — typical for");
println!("ternary, heavily pruned, or synthetic)."); println!("trained Q4_K weights. RLE only helps for structured weight matrices");
println!("(binary, ternary, heavily pruned, or synthetic).");
} else {
println!("Try a lower --threshold (e.g. --threshold 0.0) to see whether any");
println!("blocks have enough run structure to qualify at a looser threshold.");
}
} }
Ok(()) Ok(())

View File

@@ -12,18 +12,16 @@
//! //!
//! ## RLE format (when `IS_RLE` = 1) //! ## RLE format (when `IS_RLE` = 1)
//! //!
//! - `flags >> 1` gives the number of `(value, count)` pairs stored in `qs`. //! - `n_pairs` gives the number of `(value, count)` pairs stored in `qs`.
//! - For each pair `i`: //! - For each pair `i`:
//! - `qs[2*i]` — the byte value (two packed 4-bit weights, same packing //! - `qs[2*i]` — the byte value (two packed 4-bit weights, same packing
//! as the raw format). //! as the raw format).
//! - `qs[2*i + 1]` — the run length in bytes (1..=255). //! - `qs[2*i + 1]` — the run length in bytes (1..=255).
//! - The run lengths must sum to exactly 128 (the uncompressed `qs` size). //! - The run lengths must sum to exactly 128 (the uncompressed `qs` size).
//! //!
//! RLE encoding is chosen only when the compressed representation is //! The 256-byte `qs` field can hold up to 128 `(value, count)` pairs — enough
//! **strictly shorter** than the 128-byte raw payload, i.e. when //! to represent even fully-random blocks where every byte differs from its
//! `pairs * 2 < 128`. That caps the useful range at ≤ 63 pairs. The 7-bit //! neighbour.
//! `flags >> 1` sub-field can hold up to 127, so this ceiling is never a
//! concern in practice.
//! //!
//! ## Constructing blocks //! ## Constructing blocks
//! //!
@@ -49,42 +47,45 @@ pub const IS_RLE: u8 = 0x01;
/// A Q4_K super-block with optional byte-level RLE compression on the weights. /// A Q4_K super-block with optional byte-level RLE compression on the weights.
/// ///
/// Identical to [`crate::BlockQ4K`] except for the additional [`flags`](Self::flags) /// Unlike [`crate::BlockQ4K`], this format is **not** binary-compatible with
/// byte inserted between `scales` and `qs`. /// the GGUF on-disk layout. It uses a 256-byte `qs` field (vs the 128-byte
/// field in `BlockQ4K`) so the RLE stream can store up to 128 `(value, count)`
/// pairs — enough to represent even fully-random blocks where every byte
/// differs from its neighbour.
/// ///
/// Memory layout (repr C): /// Memory layout (`repr C`):
/// ///
/// | Offset | Field | Size | Notes | /// | Offset | Field | Size | Notes |
/// |--------|------------|-------|--------------------------------| /// |--------|-----------|--------|------------------------------------|
/// | 0 | `d` | 2 B | fp16 super-block scale | /// | 0 | `d` | 2 B | fp16 super-block scale |
/// | 2 | `dmin` | fp16 super-block min scale | 2 B | /// | 2 | `dmin` | 2 B | fp16 super-block min-scale |
/// | 4 | `scales` | 12 B | packed 6-bit sub-block params | /// | 4 | `scales` | 12 B | packed 6-bit sub-block params |
/// | 16 | `flags` | 1 B | encoding flags (see below) | /// | 16 | `flags` | 1 B | bit 0 = `IS_RLE`; bits 1-7 unused |
/// | 17 | `qs` | 128 B | raw nibbles or RLE stream | /// | 17 | `n_pairs` | 1 B | RLE pair count (0 when raw) |
/// | 145 | (padding) | 1 B | implicit trailing alignment pad| /// | 18 | `qs` | 256 B | raw nibbles (first 128 B) or RLE |
/// ///
/// **sizeof = 146 bytes** (padded to 2-byte alignment imposed by `u16` fields). /// **sizeof = 274 bytes.**
/// ///
/// ## `flags` bit layout /// ## `qs` interpretation
/// ///
/// | Bits | Meaning | /// | `IS_RLE` | Meaning |
/// |------|---------------------------------------------------------------| /// |----------|--------------------------------------------------------------|
/// | 0 | [`IS_RLE`] — 1 = `qs` is RLE-encoded, 0 = raw packed nibbles | /// | 0 | `qs[0..128]` holds raw packed nibbles (same as `BlockQ4K`) |
/// | 17 | When `IS_RLE`=1: number of `(value, count)` pairs in `qs` | /// | 1 | `qs[0..n_pairs*2]` holds `(value, count)` byte-pairs |
#[repr(C)] #[repr(C)]
#[derive(Clone, Copy, Debug)] #[derive(Clone, Copy, Debug)]
pub struct BlockQ4KRle { pub struct BlockQ4KRle {
/// Super-block scale for quantised sub-block scales (fp16 bits).
pub d: u16, pub d: u16,
/// Super-block scale for quantised sub-block mins (fp16 bits).
pub dmin: u16, pub dmin: u16,
/// Packed 6-bit sub-block scales and mins (same layout as [`crate::BlockQ4K`]).
pub scales: [u8; K_SCALE_SIZE], pub scales: [u8; K_SCALE_SIZE],
/// Encoding flags. Bit 0 = [`IS_RLE`]. Bits 1-7 = RLE pair count when /// Encoding flags. Only bit 0 (`IS_RLE`) is used; bits 1-7 are reserved.
/// `IS_RLE` is set.
pub flags: u8, pub flags: u8,
/// Raw packed-nibble weights (`IS_RLE` = 0) or RLE byte stream (`IS_RLE` = 1). /// When `IS_RLE` is set: number of `(value, count)` byte-pairs in `qs`.
pub qs: [u8; QK_K / 2], /// Zero when in raw mode.
pub n_pairs: u8,
/// Raw packed-nibble weights (IS_RLE = 0, first 128 bytes) or RLE stream
/// (IS_RLE = 1, first `n_pairs * 2` bytes).
pub qs: [u8; QK_K], // 256 bytes
} }
impl BlockQ4KRle { impl BlockQ4KRle {
@@ -94,12 +95,11 @@ impl BlockQ4KRle {
self.flags & IS_RLE != 0 self.flags & IS_RLE != 0
} }
/// Number of `(value, count)` byte-pairs stored at the start of `qs`. /// Number of `(value, count)` byte-pairs in `qs`.
/// /// Only meaningful when `is_rle()` is true.
/// Only meaningful when [`is_rle`](Self::is_rle) returns `true`.
#[inline] #[inline]
pub fn rle_len(&self) -> usize { pub fn rle_len(&self) -> usize {
(self.flags >> 1) as usize self.n_pairs as usize
} }
} }
@@ -109,24 +109,38 @@ impl BlockQ4KRle {
/// Encode a [`BlockQ4K`] block into a [`BlockQ4KRle`] block. /// Encode a [`BlockQ4K`] block into a [`BlockQ4KRle`] block.
/// ///
/// The 128-byte `qs` payload is scanned for runs of identical bytes. If the /// The `qs` payload is scanned for runs of equal consecutive bytes. RLE mode
/// RLE representation fits in the same 128-byte field **and is strictly /// is chosen when **both** conditions hold:
/// shorter** than the raw payload, it is stored with `IS_RLE` set. Otherwise
/// the raw bytes are copied unchanged and `IS_RLE` is cleared.
/// ///
/// The `d`, `dmin`, and `scales` fields are always copied verbatim. /// 1. **Coverage**: at least `min_coverage` fraction of the 128 `qs` bytes
pub fn encode(block: &BlockQ4K) -> BlockQ4KRle { /// belong to runs of length ≥ 2. These are the bytes whose weights can be
let mut raw = Vec::from(&block.qs); /// batched in `accumulate_rle_block`, replacing `2 * run_len` multiplies
/// with just 2 per group-segment.
///
/// 2. **Capacity**: the pair count does not exceed 128 (the physical limit of
/// the 256-byte `qs` field at 2 bytes per pair).
///
/// | `min_coverage` | Effect |
/// |----------------|------------------------------------------------------|
/// | `0.0` | RLE whenever pairs fit (≤ 128), regardless of runs |
/// | `0.5` | RLE only if ≥ 50 % of bytes are in repeated runs |
/// | `1.0` | RLE only when every byte is part of a run |
pub fn encode(block: &BlockQ4K, min_coverage: f32) -> BlockQ4KRle {
debug_assert!(
(0.0..=1.0).contains(&min_coverage),
"min_coverage must be in [0.0, 1.0], got {min_coverage}"
);
// Sort the raw numbers let raw = &block.qs; // [u8; 128]
raw.sort();
// Scan the 128-byte raw payload for runs of equal bytes. // Scan for runs of equal consecutive bytes.
let mut pairs: Vec<(u8, u8)> = Vec::with_capacity(64); // Track long_run_bytes: bytes in runs of length ≥ 2 (the bytes that
// benefit from RLE in the matmul).
let mut pairs: Vec<(u8, u8)> = Vec::with_capacity(QK_K / 2);
let mut long_run_bytes = 0usize;
let mut i = 0usize; let mut i = 0usize;
while i < raw.len() { while i < raw.len() {
let val = raw[i]; let val = raw[i];
// Count consecutive equal bytes; saturate at u8::MAX to stay in-range.
let mut run = 1u8; let mut run = 1u8;
while i + (run as usize) < raw.len() while i + (run as usize) < raw.len()
&& raw[i + (run as usize)] == val && raw[i + (run as usize)] == val
@@ -135,38 +149,40 @@ pub fn encode(block: &BlockQ4K) -> BlockQ4KRle {
run += 1; run += 1;
} }
pairs.push((val, run)); pairs.push((val, run));
if run >= 2 {
long_run_bytes += run as usize;
}
i += run as usize; i += run as usize;
} }
// Only switch to RLE when the encoded form is strictly smaller than the // Coverage: fraction of qs bytes that are in non-singleton runs.
// raw payload. Because each pair costs 2 bytes and the raw payload is let coverage = long_run_bytes as f32 / raw.len() as f32;
// 128 bytes, the condition pairs.len() * 2 < 128 also guarantees that
// pairs.len() ≤ 63, which fits in bits 1-7 of the flags byte.
if pairs.len() * 2 < raw.len() {
let n = pairs.len();
debug_assert!(n <= 63, "RLE pair count {n} unexpectedly exceeds 63");
let mut qs = [0u8; QK_K / 2]; if pairs.len() <= QK_K / 2 && coverage >= min_coverage {
let n = pairs.len();
let mut qs = [0u8; QK_K];
for (k, &(val, count)) in pairs.iter().enumerate() { for (k, &(val, count)) in pairs.iter().enumerate() {
qs[2 * k] = val; qs[2 * k] = val;
qs[2 * k + 1] = count; qs[2 * k + 1] = count;
} }
BlockQ4KRle { BlockQ4KRle {
d: block.d, d: block.d,
dmin: block.dmin, dmin: block.dmin,
scales: block.scales, scales: block.scales,
flags: IS_RLE | ((n as u8) << 1), flags: IS_RLE,
n_pairs: n as u8,
qs, qs,
} }
} else { } else {
// No space savings — copy raw bytes and leave IS_RLE clear. let mut qs = [0u8; QK_K];
qs[..QK_K / 2].copy_from_slice(&block.qs);
BlockQ4KRle { BlockQ4KRle {
d: block.d, d: block.d,
dmin: block.dmin, dmin: block.dmin,
scales: block.scales, scales: block.scales,
flags: 0, flags: 0,
qs: block.qs, n_pairs: 0,
qs,
} }
} }
} }
@@ -183,20 +199,18 @@ pub fn encode(block: &BlockQ4K) -> BlockQ4KRle {
/// Panics if the decoded RLE stream does not sum to exactly 128 bytes. /// Panics if the decoded RLE stream does not sum to exactly 128 bytes.
fn decode_qs(block: &BlockQ4KRle) -> [u8; QK_K / 2] { fn decode_qs(block: &BlockQ4KRle) -> [u8; QK_K / 2] {
if !block.is_rle() { if !block.is_rle() {
return block.qs; // First QK_K/2 bytes of qs hold the raw packed nibbles.
} block.qs[..QK_K / 2].try_into().unwrap()
} else {
let n = block.rle_len(); let n = block.rle_len();
let mut raw = [0u8; QK_K / 2]; let mut raw = [0u8; QK_K / 2];
let mut pos = 0usize; let mut pos = 0usize;
for i in 0..n { for i in 0..n {
let val = block.qs[2 * i]; let val = block.qs[2 * i];
let count = block.qs[2 * i + 1] as usize; let count = block.qs[2 * i + 1] as usize;
raw[pos..pos + count].fill(val); raw[pos..pos + count].fill(val);
pos += count; pos += count;
} }
debug_assert_eq!( debug_assert_eq!(
pos, pos,
QK_K / 2, QK_K / 2,
@@ -204,6 +218,7 @@ fn decode_qs(block: &BlockQ4KRle) -> [u8; QK_K / 2] {
QK_K / 2 QK_K / 2
); );
raw raw
}
} }
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
@@ -570,17 +585,18 @@ mod tests {
// ========================================================================= // =========================================================================
#[test] #[test]
fn block_q4k_rle_size_is_146_bytes() { fn block_q4k_rle_size_is_274_bytes() {
// d(2) + dmin(2) + scales(12) + flags(1) + qs(128) = 145 raw bytes, // d(2) + dmin(2) + scales(12) + flags(1) + n_pairs(1) + qs(256) = 274 bytes.
// rounded up to 146 by the 2-byte alignment imposed by the u16 fields. // No padding needed: struct is already 2-byte aligned and 274 is even.
assert_eq!(core::mem::size_of::<BlockQ4KRle>(), 146); assert_eq!(core::mem::size_of::<BlockQ4KRle>(), 274);
} }
#[test] #[test]
fn block_q4k_rle_is_two_bytes_larger_than_block_q4k() { fn block_q4k_rle_is_130_bytes_larger_than_block_q4k() {
// BlockQ4K = 144 bytes, BlockQ4KRle = 274 bytes, delta = 130.
assert_eq!( assert_eq!(
core::mem::size_of::<BlockQ4KRle>(), core::mem::size_of::<BlockQ4KRle>(),
core::mem::size_of::<BlockQ4K>() + 2, core::mem::size_of::<BlockQ4K>() + 130,
); );
} }
@@ -591,7 +607,7 @@ mod tests {
#[test] #[test]
fn is_rle_false_when_flag_clear() { fn is_rle_false_when_flag_clear() {
let b = BlockQ4KRle { let b = BlockQ4KRle {
d: 0, dmin: 0, scales: [0; K_SCALE_SIZE], flags: 0, qs: [0; QK_K / 2], d: 0, dmin: 0, scales: [0; K_SCALE_SIZE], flags: 0, n_pairs: 0, qs: [0; QK_K],
}; };
assert!(!b.is_rle()); assert!(!b.is_rle());
} }
@@ -599,7 +615,7 @@ mod tests {
#[test] #[test]
fn rle_len_zero_when_flag_clear() { fn rle_len_zero_when_flag_clear() {
let b = BlockQ4KRle { let b = BlockQ4KRle {
d: 0, dmin: 0, scales: [0; K_SCALE_SIZE], flags: 0, qs: [0; QK_K / 2], d: 0, dmin: 0, scales: [0; K_SCALE_SIZE], flags: 0, n_pairs: 0, qs: [0; QK_K],
}; };
assert_eq!(b.rle_len(), 0); assert_eq!(b.rle_len(), 0);
} }
@@ -608,19 +624,21 @@ mod tests {
fn is_rle_true_when_flag_set() { fn is_rle_true_when_flag_set() {
let b = BlockQ4KRle { let b = BlockQ4KRle {
d: 0, dmin: 0, scales: [0; K_SCALE_SIZE], d: 0, dmin: 0, scales: [0; K_SCALE_SIZE],
flags: IS_RLE | (5u8 << 1), flags: IS_RLE,
qs: [0; QK_K / 2], n_pairs: 5,
qs: [0; QK_K],
}; };
assert!(b.is_rle()); assert!(b.is_rle());
} }
#[test] #[test]
fn rle_len_reports_pair_count_from_flags() { fn rle_len_reports_pair_count_from_n_pairs() {
for n in [0usize, 1, 7, 31, 63] { for n in [0usize, 1, 7, 31, 63, 128] {
let b = BlockQ4KRle { let b = BlockQ4KRle {
d: 0, dmin: 0, scales: [0; K_SCALE_SIZE], d: 0, dmin: 0, scales: [0; K_SCALE_SIZE],
flags: IS_RLE | ((n as u8) << 1), flags: if n > 0 { IS_RLE } else { 0 },
qs: [0; QK_K / 2], n_pairs: n as u8,
qs: [0; QK_K],
}; };
assert_eq!(b.rle_len(), n, "expected rle_len {n}"); assert_eq!(b.rle_len(), n, "expected rle_len {n}");
} }
@@ -632,61 +650,64 @@ mod tests {
#[test] #[test]
fn encode_uniform_qs_uses_rle() { fn encode_uniform_qs_uses_rle() {
// 128 identical bytes → 1 pair → 2 bytes < 128 raw. // 128 identical bytes → 1 pair → 2 bytes stored in qs.
let src = make_block(1.0, 0.0, 1, 0, 0x77); let src = make_block(1.0, 0.0, 1, 0, 0x77);
let rle = encode(&src); let rle = encode(&src, 0.0);
assert!(rle.is_rle(), "uniform qs should trigger RLE mode"); assert!(rle.is_rle(), "uniform qs should trigger RLE mode");
} }
#[test] #[test]
fn encode_uniform_qs_rle_len_is_one() { fn encode_uniform_qs_rle_len_is_one() {
let src = make_block(1.0, 0.0, 1, 0, 0x55); let src = make_block(1.0, 0.0, 1, 0, 0x55);
let rle = encode(&src); let rle = encode(&src, 0.0);
assert_eq!(rle.rle_len(), 1); assert_eq!(rle.rle_len(), 1);
} }
#[test] #[test]
fn encode_uniform_qs_rle_entry_is_correct() { fn encode_uniform_qs_rle_entry_is_correct() {
let src = make_block(1.0, 0.0, 1, 0, 0xAB); let src = make_block(1.0, 0.0, 1, 0, 0xAB);
let rle = encode(&src); let rle = encode(&src, 0.0);
assert_eq!(rle.qs[0], 0xAB, "RLE value byte should equal the repeated byte"); assert_eq!(rle.qs[0], 0xAB, "RLE value byte should equal the repeated byte");
assert_eq!(rle.qs[1], 128, "RLE run length should be 128 bytes"); assert_eq!(rle.qs[1], 128, "RLE run length should be 128 bytes");
} }
#[test] #[test]
fn encode_alternating_bytes_stays_raw() { fn encode_alternating_bytes_stays_raw() {
// 128 single-byte runs → 128 pairs → 256 bytes ≥ 128 raw → raw mode. // Alternating 0xAA / 0x55 → 128 singleton pairs, coverage = 0%.
// At threshold 0.01 the 0% coverage fails → raw mode.
let mut qs = [0u8; QK_K / 2]; let mut qs = [0u8; QK_K / 2];
for (i, b) in qs.iter_mut().enumerate() { for (i, b) in qs.iter_mut().enumerate() {
*b = if i % 2 == 0 { 0xAA } else { 0x55 }; *b = if i % 2 == 0 { 0xAA } else { 0x55 };
} }
let src = make_block_with_qs(1.0, 0.0, 1, 0, qs); let src = make_block_with_qs(1.0, 0.0, 1, 0, qs);
let rle = encode(&src); let rle = encode(&src, 0.01);
assert!(!rle.is_rle(), "alternating bytes cannot be compressed → raw mode"); assert!(!rle.is_rle(), "0% coverage fails any threshold > 0 → raw mode");
} }
#[test] #[test]
fn encode_raw_mode_copies_qs_verbatim() { fn encode_raw_mode_copies_qs_verbatim() {
// Three-byte cycle of distinct values → 128 runs of 1 byte each,
// coverage = 0%. At threshold 0.01 the 0% coverage fails → raw mode.
let mut qs = [0u8; QK_K / 2]; let mut qs = [0u8; QK_K / 2];
for (i, b) in qs.iter_mut().enumerate() { for (i, b) in qs.iter_mut().enumerate() {
// Three-byte cycle of distinct values → 128 runs of 1 byte each.
*b = match i % 3 { 0 => 0x11, 1 => 0x22, _ => 0x33 }; *b = match i % 3 { 0 => 0x11, 1 => 0x22, _ => 0x33 };
} }
let src = make_block_with_qs(1.0, 0.0, 1, 0, qs); let src = make_block_with_qs(1.0, 0.0, 1, 0, qs);
let rle = encode(&src); let rle = encode(&src, 0.01);
assert!(!rle.is_rle()); assert!(!rle.is_rle());
assert_eq!(rle.qs, qs, "raw mode must preserve qs bytes unchanged"); // Raw mode copies the 128-byte qs into the first half of the 256-byte field.
assert_eq!(&rle.qs[..QK_K / 2], &qs[..], "raw mode must preserve qs bytes in first 128 bytes");
} }
#[test] #[test]
fn encode_two_runs_uses_rle_and_stores_correct_pairs() { fn encode_two_runs_uses_rle_and_stores_correct_pairs() {
// Two distinct runs: 64 bytes of 0x11 followed by 64 bytes of 0x22. // Two distinct runs: 64 bytes of 0x11 followed by 64 bytes of 0x22.
// → 2 pairs = 4 bytes < 128 bytes raw. // → 2 pairs = 4 bytes.
let mut qs = [0u8; QK_K / 2]; let mut qs = [0u8; QK_K / 2];
qs[..64].fill(0x11); qs[..64].fill(0x11);
qs[64..].fill(0x22); qs[64..].fill(0x22);
let src = make_block_with_qs(1.0, 0.0, 1, 0, qs); let src = make_block_with_qs(1.0, 0.0, 1, 0, qs);
let rle = encode(&src); let rle = encode(&src, 0.0);
assert!(rle.is_rle()); assert!(rle.is_rle());
assert_eq!(rle.rle_len(), 2); assert_eq!(rle.rle_len(), 2);
assert_eq!(rle.qs[0], 0x11, "first pair: value"); assert_eq!(rle.qs[0], 0x11, "first pair: value");
@@ -698,7 +719,7 @@ mod tests {
#[test] #[test]
fn encode_63_pairs_uses_rle() { fn encode_63_pairs_uses_rle() {
// Build 62 runs of 2 bytes each (124 bytes) + 1 run of 4 bytes = 128 bytes. // Build 62 runs of 2 bytes each (124 bytes) + 1 run of 4 bytes = 128 bytes.
// 63 pairs × 2 = 126 bytes < 128 → RLE should be chosen. // 63 pairs × 2 = 126 bytes; 63 ≤ 128 → RLE should be chosen.
let mut qs = [0u8; QK_K / 2]; let mut qs = [0u8; QK_K / 2];
let mut pos = 0usize; let mut pos = 0usize;
for run in 0..62usize { for run in 0..62usize {
@@ -712,15 +733,15 @@ mod tests {
qs[pos..].fill(0xFE); qs[pos..].fill(0xFE);
let src = make_block_with_qs(1.0, 0.0, 1, 0, qs); let src = make_block_with_qs(1.0, 0.0, 1, 0, qs);
let rle = encode(&src); let rle = encode(&src, 0.0);
assert!(rle.is_rle(), "63 pairs should use RLE"); assert!(rle.is_rle(), "63 pairs should use RLE");
assert_eq!(rle.rle_len(), 63); assert_eq!(rle.rle_len(), 63);
} }
#[test] #[test]
fn encode_64_pairs_stays_raw() { fn encode_64_pairs_uses_rle_at_zero_threshold() {
// 64 runs of 2 bytes each = 128 bytes total. // 64 runs of 2 bytes each = 128 bytes total, coverage = 100%.
// 64 pairs × 2 = 128 bytes, which is NOT strictly less than 128 → raw. // pairs (64) ≤ 128 AND 100% ≥ 0.0 → RLE mode.
let mut qs = [0u8; QK_K / 2]; let mut qs = [0u8; QK_K / 2];
let mut pos = 0usize; let mut pos = 0usize;
for run in 0..64usize { for run in 0..64usize {
@@ -730,14 +751,93 @@ mod tests {
pos += 2; pos += 2;
} }
let src = make_block_with_qs(1.0, 0.0, 1, 0, qs); let src = make_block_with_qs(1.0, 0.0, 1, 0, qs);
let rle = encode(&src); let rle = encode(&src, 0.0);
assert!(!rle.is_rle(), "64 pairs offers no saving → raw mode"); assert!(rle.is_rle(), "64 pairs, 100% coverage, threshold 0.0 → RLE");
assert_eq!(rle.rle_len(), 64);
}
#[test]
fn encode_128_pairs_uses_rle_at_zero_threshold() {
// 128 distinct consecutive bytes = 128 singleton runs = 128 pairs.
// With old cap (63 pairs, since pairs * 2 < 128), this was always raw.
// With new cap (128 pairs), threshold 0.0 accepts it.
// Coverage = 0 % (all singletons) → threshold > 0.0 rejects it.
let mut qs = [0u8; QK_K / 2];
for (i, b) in qs.iter_mut().enumerate() {
*b = i as u8; // 0x00, 0x01, ..., 0x7F — all distinct, all singletons
}
let src = make_block_with_qs(1.0, 0.0, 1, 0, qs);
assert!(
encode(&src, 0.0).is_rle(),
"128 pairs ≤ 128 limit AND 0% ≥ 0.0 → RLE at zero threshold"
);
assert_eq!(encode(&src, 0.0).rle_len(), 128);
assert!(
!encode(&src, 0.01).is_rle(),
"0% coverage fails any threshold > 0"
);
}
#[test]
fn encode_coverage_threshold_rejects_low_coverage_block() {
// Construct: 63 singletons + 1 run of 65 bytes = 64 pairs.
// coverage = 65/128 ≈ 50.8%.
// threshold 0.50 accepts it; threshold 0.60 rejects it.
let mut qs = [0u8; QK_K / 2];
qs[0] = 0x01;
for i in 1..63usize {
// Distinct odd bytes, none equal to 0x01 or adjacent values.
qs[i] = (i as u8).wrapping_mul(2).wrapping_add(5);
}
qs[63..].fill(0xAB); // 65-byte run; qs[62] = 62*2+5 = 129 = 0x81 ≠ 0xAB ✓
let src = make_block_with_qs(1.0, 0.0, 1, 0, qs);
assert!(
encode(&src, 0.50).is_rle(),
"50.8% coverage should meet 50% threshold"
);
assert!(
!encode(&src, 0.60).is_rle(),
"50.8% coverage should fail 60% threshold"
);
}
#[test]
fn encode_coverage_zero_threshold_always_uses_rle_when_pairs_fit() {
// Any block whose runs produce ≤ 128 pairs uses RLE at threshold 0.0,
// regardless of how many singletons it contains.
// Use the 63-pair block from encode_63_pairs_uses_rle.
let mut qs = [0u8; QK_K / 2];
let mut pos = 0usize;
for run in 0..62usize {
let v = (run as u8).wrapping_mul(3).wrapping_add(1);
qs[pos] = v;
qs[pos + 1] = v;
pos += 2;
}
qs[pos..].fill(0xFE);
let src = make_block_with_qs(1.0, 0.0, 1, 0, qs);
assert!(encode(&src, 0.0).is_rle());
}
#[test]
fn encode_coverage_one_threshold_requires_total_coverage() {
// A block with even one singleton byte fails the 100% threshold.
// Build: 1 singleton + 1 run of 127 bytes = 2 pairs, coverage = 127/128 ≈ 99.2%.
let mut qs = [0u8; QK_K / 2];
qs[0] = 0x01; // singleton (value distinct from rest)
qs[1..].fill(0x02); // 127-byte run
let src = make_block_with_qs(1.0, 0.0, 1, 0, qs);
assert!(!encode(&src, 1.0).is_rle(), "99.2% coverage should fail 100% threshold");
assert!(encode(&src, 0.99).is_rle(), "99.2% coverage should meet 99% threshold");
} }
#[test] #[test]
fn encode_preserves_d_dmin_scales() { fn encode_preserves_d_dmin_scales() {
let src = make_block(2.0, 0.5, 3, 2, 0x00); let src = make_block(2.0, 0.5, 3, 2, 0x00);
let rle = encode(&src); let rle = encode(&src, 0.0);
assert_eq!(rle.d, src.d); assert_eq!(rle.d, src.d);
assert_eq!(rle.dmin, src.dmin); assert_eq!(rle.dmin, src.dmin);
assert_eq!(rle.scales, src.scales); assert_eq!(rle.scales, src.scales);
@@ -749,25 +849,24 @@ mod tests {
#[test] #[test]
fn decode_qs_raw_mode_returns_qs_unchanged() { fn decode_qs_raw_mode_returns_qs_unchanged() {
// Build a raw BlockQ4KRle (flags = 0) with a non-trivial qs pattern. let mut qs = [0u8; QK_K];
let mut qs = [0u8; QK_K / 2]; for (i, b) in qs[..QK_K / 2].iter_mut().enumerate() { *b = i as u8; }
for (i, b) in qs.iter_mut().enumerate() { *b = i as u8; }
let rle = BlockQ4KRle { let rle = BlockQ4KRle {
d: 0, dmin: 0, scales: [0; K_SCALE_SIZE], flags: 0, qs, d: 0, dmin: 0, scales: [0; K_SCALE_SIZE], flags: 0, n_pairs: 0, qs,
}; };
assert_eq!(decode_qs(&rle), qs); let expected: [u8; QK_K / 2] = qs[..QK_K / 2].try_into().unwrap();
assert_eq!(decode_qs(&rle), expected);
} }
#[test] #[test]
fn decode_qs_rle_expands_two_pair_stream() { fn decode_qs_rle_expands_two_pair_stream() {
// Hand-craft an RLE block: [0xAA × 64, 0xBB × 64]. // Hand-craft an RLE block: [0xAA × 64, 0xBB × 64].
let mut qs = [0u8; QK_K / 2]; let mut qs = [0u8; QK_K];
qs[0] = 0xAA; qs[1] = 64; qs[0] = 0xAA; qs[1] = 64;
qs[2] = 0xBB; qs[3] = 64; qs[2] = 0xBB; qs[3] = 64;
let rle = BlockQ4KRle { let rle = BlockQ4KRle {
d: 0, dmin: 0, scales: [0; K_SCALE_SIZE], d: 0, dmin: 0, scales: [0; K_SCALE_SIZE],
flags: IS_RLE | (2u8 << 1), flags: IS_RLE, n_pairs: 2, qs,
qs,
}; };
let expanded = decode_qs(&rle); let expanded = decode_qs(&rle);
assert!(expanded[..64].iter().all(|&b| b == 0xAA), "first 64 bytes must be 0xAA"); assert!(expanded[..64].iter().all(|&b| b == 0xAA), "first 64 bytes must be 0xAA");
@@ -776,12 +875,11 @@ mod tests {
#[test] #[test]
fn decode_qs_rle_single_run_covers_all() { fn decode_qs_rle_single_run_covers_all() {
let mut qs = [0u8; QK_K / 2]; let mut qs = [0u8; QK_K];
qs[0] = 0xCD; qs[1] = 128; // one run of 128 bytes qs[0] = 0xCD; qs[1] = 128; // one run of 128 bytes
let rle = BlockQ4KRle { let rle = BlockQ4KRle {
d: 0, dmin: 0, scales: [0; K_SCALE_SIZE], d: 0, dmin: 0, scales: [0; K_SCALE_SIZE],
flags: IS_RLE | (1u8 << 1), flags: IS_RLE, n_pairs: 1, qs,
qs,
}; };
let expanded = decode_qs(&rle); let expanded = decode_qs(&rle);
assert!(expanded.iter().all(|&b| b == 0xCD)); assert!(expanded.iter().all(|&b| b == 0xCD));
@@ -794,7 +892,7 @@ mod tests {
#[test] #[test]
fn dequant_rle_zero_d_all_outputs_zero() { fn dequant_rle_zero_d_all_outputs_zero() {
let src = make_block(0.0, 0.0, 1, 0, 0x77); let src = make_block(0.0, 0.0, 1, 0, 0x77);
let rle = encode(&src); let rle = encode(&src, 0.0);
let mut out = [0.0f32; QK_K]; let mut out = [0.0f32; QK_K];
dequantize_block_q4k_rle(&rle, &mut out); dequantize_block_q4k_rle(&rle, &mut out);
assert_all_close(&out, 0.0, 0.0); assert_all_close(&out, 0.0, 0.0);
@@ -805,7 +903,7 @@ mod tests {
// qs_byte = 0x11 → both nibbles = 1; scale = 1, d = 1.0, min = 0. // qs_byte = 0x11 → both nibbles = 1; scale = 1, d = 1.0, min = 0.
// expected: 1.0 * 1 * 1 - 0.0 = 1.0 // expected: 1.0 * 1 * 1 - 0.0 = 1.0
let src = make_block(1.0, 0.0, 1, 0, 0x11); let src = make_block(1.0, 0.0, 1, 0, 0x11);
let rle = encode(&src); let rle = encode(&src, 0.0);
let mut out = [0.0f32; QK_K]; let mut out = [0.0f32; QK_K];
dequantize_block_q4k_rle(&rle, &mut out); dequantize_block_q4k_rle(&rle, &mut out);
assert_all_close(&out, 1.0, 1e-5); assert_all_close(&out, 1.0, 1e-5);
@@ -816,7 +914,7 @@ mod tests {
// nibble = 0, scale = 1, d = 1.0, min = 2, dmin = 1.0 // nibble = 0, scale = 1, d = 1.0, min = 2, dmin = 1.0
// expected: 1.0 * 1 * 0 - 1.0 * 2 = -2.0 // expected: 1.0 * 1 * 0 - 1.0 * 2 = -2.0
let src = make_block(1.0, 1.0, 1, 2, 0x00); let src = make_block(1.0, 1.0, 1, 2, 0x00);
let rle = encode(&src); let rle = encode(&src, 0.0);
let mut out = [0.0f32; QK_K]; let mut out = [0.0f32; QK_K];
dequantize_block_q4k_rle(&rle, &mut out); dequantize_block_q4k_rle(&rle, &mut out);
assert_all_close(&out, -2.0, 1e-5); assert_all_close(&out, -2.0, 1e-5);
@@ -827,7 +925,7 @@ mod tests {
// qs_byte = 0xFF → both nibbles = 15; scale = 1, d = 1.0, min = 0. // qs_byte = 0xFF → both nibbles = 15; scale = 1, d = 1.0, min = 0.
// expected: 1.0 * 1 * 15 - 0.0 = 15.0 // expected: 1.0 * 1 * 15 - 0.0 = 15.0
let src = make_block(1.0, 0.0, 1, 0, 0xFF); let src = make_block(1.0, 0.0, 1, 0, 0xFF);
let rle = encode(&src); let rle = encode(&src, 0.0);
let mut out = [0.0f32; QK_K]; let mut out = [0.0f32; QK_K];
dequantize_block_q4k_rle(&rle, &mut out); dequantize_block_q4k_rle(&rle, &mut out);
assert_all_close(&out, 15.0, 1e-5); assert_all_close(&out, 15.0, 1e-5);
@@ -836,7 +934,7 @@ mod tests {
#[test] #[test]
fn dequant_rle_output_count_is_qk_k() { fn dequant_rle_output_count_is_qk_k() {
let src = make_block(1.0, 0.0, 1, 0, 0x00); let src = make_block(1.0, 0.0, 1, 0, 0x00);
let rle = encode(&src); let rle = encode(&src, 0.0);
let mut out = [0.0f32; QK_K]; let mut out = [0.0f32; QK_K];
dequantize_block_q4k_rle(&rle, &mut out); dequantize_block_q4k_rle(&rle, &mut out);
assert_eq!(out.len(), QK_K); assert_eq!(out.len(), QK_K);
@@ -848,7 +946,7 @@ mod tests {
// expected: 2.0 * 4 * 3 - 0.0 = 24.0 // expected: 2.0 * 4 * 3 - 0.0 = 24.0
// qs_byte = 0x33 → both nibbles = 3 // qs_byte = 0x33 → both nibbles = 3
let src = make_block(2.0, 0.0, 4, 0, 0x33); let src = make_block(2.0, 0.0, 4, 0, 0x33);
let rle = encode(&src); let rle = encode(&src, 0.0);
let mut out = [0.0f32; QK_K]; let mut out = [0.0f32; QK_K];
dequantize_block_q4k_rle(&rle, &mut out); dequantize_block_q4k_rle(&rle, &mut out);
assert_all_close(&out, 24.0, 1e-4); assert_all_close(&out, 24.0, 1e-4);
@@ -862,7 +960,7 @@ mod tests {
fn roundtrip_rle_mode_matches_original() { fn roundtrip_rle_mode_matches_original() {
// Uniform qs → RLE mode selected. // Uniform qs → RLE mode selected.
let src = make_block(2.0, 0.5, 3, 1, 0x37); let src = make_block(2.0, 0.5, 3, 1, 0x37);
let rle = encode(&src); let rle = encode(&src, 0.0);
assert!(rle.is_rle()); assert!(rle.is_rle());
let mut got = [0.0f32; QK_K]; let mut got = [0.0f32; QK_K];
@@ -874,13 +972,14 @@ mod tests {
#[test] #[test]
fn roundtrip_raw_mode_matches_original() { fn roundtrip_raw_mode_matches_original() {
// Alternating bytes → raw mode selected; output must still be correct. // Alternating bytes → 128 singleton pairs, coverage = 0%.
// Use threshold 0.01 to force raw mode (0% < 0.01).
let mut qs = [0u8; QK_K / 2]; let mut qs = [0u8; QK_K / 2];
for (i, b) in qs.iter_mut().enumerate() { for (i, b) in qs.iter_mut().enumerate() {
*b = if i % 2 == 0 { 0x13 } else { 0x24 }; *b = if i % 2 == 0 { 0x13 } else { 0x24 };
} }
let src = make_block_with_qs(1.5, 0.25, 2, 1, qs); let src = make_block_with_qs(1.5, 0.25, 2, 1, qs);
let rle = encode(&src); let rle = encode(&src, 0.01);
assert!(!rle.is_rle()); assert!(!rle.is_rle());
let mut got = [0.0f32; QK_K]; let mut got = [0.0f32; QK_K];
@@ -896,7 +995,7 @@ mod tests {
qs[..64].fill(0x59); qs[..64].fill(0x59);
qs[64..].fill(0x8C); qs[64..].fill(0x8C);
let src = make_block_with_qs(3.0, 1.0, 5, 2, qs); let src = make_block_with_qs(3.0, 1.0, 5, 2, qs);
let rle = encode(&src); let rle = encode(&src, 0.0);
assert!(rle.is_rle()); assert!(rle.is_rle());
let mut got = [0.0f32; QK_K]; let mut got = [0.0f32; QK_K];
@@ -915,7 +1014,7 @@ mod tests {
qs[30..31].fill(0x33); qs[30..31].fill(0x33);
qs[31..].fill(0x44); qs[31..].fill(0x44);
let src = make_block_with_qs(1.0, 0.5, 7, 3, qs); let src = make_block_with_qs(1.0, 0.5, 7, 3, qs);
let rle = encode(&src); let rle = encode(&src, 0.0);
assert!(rle.is_rle(), "4-run block should compress"); assert!(rle.is_rle(), "4-run block should compress");
assert_eq!(rle.rle_len(), 4); assert_eq!(rle.rle_len(), 4);
@@ -929,7 +1028,7 @@ mod tests {
#[test] #[test]
fn roundtrip_zero_qs_matches_original() { fn roundtrip_zero_qs_matches_original() {
let src = make_block(1.0, 0.5, 2, 1, 0x00); let src = make_block(1.0, 0.5, 2, 1, 0x00);
let rle = encode(&src); let rle = encode(&src, 0.0);
let mut got = [0.0f32; QK_K]; let mut got = [0.0f32; QK_K];
let mut expected = [0.0f32; QK_K]; let mut expected = [0.0f32; QK_K];
dequantize_block_q4k_rle(&rle, &mut got); dequantize_block_q4k_rle(&rle, &mut got);
@@ -942,7 +1041,7 @@ mod tests {
// qs_byte = 0x37: low nibble = 7 (sub-block 0 path), high nibble = 3 // qs_byte = 0x37: low nibble = 7 (sub-block 0 path), high nibble = 3
// (sub-block 1 path). Verify both halves are dequantised correctly. // (sub-block 1 path). Verify both halves are dequantised correctly.
let src = make_block(1.0, 0.0, 1, 0, 0x37); let src = make_block(1.0, 0.0, 1, 0, 0x37);
let rle = encode(&src); let rle = encode(&src, 0.0);
let mut got = [0.0f32; QK_K]; let mut got = [0.0f32; QK_K];
let mut expected = [0.0f32; QK_K]; let mut expected = [0.0f32; QK_K];
dequantize_block_q4k_rle(&rle, &mut got); dequantize_block_q4k_rle(&rle, &mut got);
@@ -954,6 +1053,26 @@ mod tests {
assert_close(got[32], 3.0, 1e-5); assert_close(got[32], 3.0, 1e-5);
} }
#[test]
fn roundtrip_128_singleton_pairs_matches_original() {
// All-distinct bytes → 128 pairs, 0% coverage.
// encode at threshold 0.0 → RLE; dequantize must match baseline.
let mut qs = [0u8; QK_K / 2];
for (i, b) in qs.iter_mut().enumerate() {
*b = (i as u8).wrapping_mul(3).wrapping_add(7);
}
let src = make_block_with_qs(1.0, 0.0, 1, 0, qs);
let rle = encode(&src, 0.0);
assert!(rle.is_rle());
assert_eq!(rle.rle_len(), 128);
let mut got = [0.0f32; QK_K];
let mut expected = [0.0f32; QK_K];
dequantize_block_q4k_rle(&rle, &mut got);
dequantize_block_q4k(&src, &mut expected);
assert_slices_close(&got, &expected, 1e-5);
}
// ========================================================================= // =========================================================================
// matmul_q4k_rle_fp16 // matmul_q4k_rle_fp16
// ========================================================================= // =========================================================================
@@ -964,7 +1083,7 @@ mod tests {
// B: 256×1, all fp16 1.0 // B: 256×1, all fp16 1.0
// C = dot([1.0; 256], [1.0; 256]) = 256.0 // C = dot([1.0; 256], [1.0; 256]) = 256.0
let src = make_block(1.0, 0.0, 1, 0, 0x11); let src = make_block(1.0, 0.0, 1, 0, 0x11);
let a = vec![encode(&src)]; let a = vec![encode(&src, 0.0)];
let b = fp16_uniform(QK_K, 1, 1.0); let b = fp16_uniform(QK_K, 1, 1.0);
let c = matmul_q4k_rle_fp16(&a, &b, 1, QK_K, 1); let c = matmul_q4k_rle_fp16(&a, &b, 1, QK_K, 1);
assert_eq!(c.len(), 1); assert_eq!(c.len(), 1);
@@ -974,7 +1093,7 @@ mod tests {
#[test] #[test]
fn matmul_rle_2x256_times_256x3_all_ones() { fn matmul_rle_2x256_times_256x3_all_ones() {
let src = make_block(1.0, 0.0, 1, 0, 0x11); let src = make_block(1.0, 0.0, 1, 0, 0x11);
let a = vec![encode(&src), encode(&src)]; let a = vec![encode(&src, 0.0), encode(&src, 0.0)];
let b = fp16_uniform(QK_K, 3, 1.0); let b = fp16_uniform(QK_K, 3, 1.0);
let c = matmul_q4k_rle_fp16(&a, &b, 2, QK_K, 3); let c = matmul_q4k_rle_fp16(&a, &b, 2, QK_K, 3);
assert_eq!(c.len(), 6); assert_eq!(c.len(), 6);
@@ -984,7 +1103,7 @@ mod tests {
#[test] #[test]
fn matmul_rle_zero_a_gives_zero_c() { fn matmul_rle_zero_a_gives_zero_c() {
let src = make_block(0.0, 0.0, 1, 0, 0xFF); let src = make_block(0.0, 0.0, 1, 0, 0xFF);
let a = vec![encode(&src)]; let a = vec![encode(&src, 0.0)];
let b = fp16_uniform(QK_K, 4, 1.0); let b = fp16_uniform(QK_K, 4, 1.0);
let c = matmul_q4k_rle_fp16(&a, &b, 1, QK_K, 4); let c = matmul_q4k_rle_fp16(&a, &b, 1, QK_K, 4);
assert_all_close(&c, 0.0, 0.0); assert_all_close(&c, 0.0, 0.0);
@@ -993,7 +1112,7 @@ mod tests {
#[test] #[test]
fn matmul_rle_zero_b_gives_zero_c() { fn matmul_rle_zero_b_gives_zero_c() {
let src = make_block(1.0, 0.0, 1, 0, 0x11); let src = make_block(1.0, 0.0, 1, 0, 0x11);
let a = vec![encode(&src)]; let a = vec![encode(&src, 0.0)];
let b = fp16_uniform(QK_K, 2, 0.0); let b = fp16_uniform(QK_K, 2, 0.0);
let c = matmul_q4k_rle_fp16(&a, &b, 1, QK_K, 2); let c = matmul_q4k_rle_fp16(&a, &b, 1, QK_K, 2);
assert_all_close(&c, 0.0, 0.0); assert_all_close(&c, 0.0, 0.0);
@@ -1004,7 +1123,7 @@ mod tests {
// A: 1×512, two blocks, all nibble-1 weights; B: 512×1, all 1.0. // A: 1×512, two blocks, all nibble-1 weights; B: 512×1, all 1.0.
// Expected: 512.0 // Expected: 512.0
let src = make_block(1.0, 0.0, 1, 0, 0x11); let src = make_block(1.0, 0.0, 1, 0, 0x11);
let a = vec![encode(&src), encode(&src)]; let a = vec![encode(&src, 0.0), encode(&src, 0.0)];
let b = fp16_uniform(2 * QK_K, 1, 1.0); let b = fp16_uniform(2 * QK_K, 1, 1.0);
let c = matmul_q4k_rle_fp16(&a, &b, 1, 2 * QK_K, 1); let c = matmul_q4k_rle_fp16(&a, &b, 1, 2 * QK_K, 1);
assert_eq!(c.len(), 1); assert_eq!(c.len(), 1);
@@ -1015,7 +1134,7 @@ mod tests {
fn matmul_rle_output_shape_m_times_n() { fn matmul_rle_output_shape_m_times_n() {
// A: 3×512 (6 blocks), B: 512×4 → C: 3×4 = 12 elements. // A: 3×512 (6 blocks), B: 512×4 → C: 3×4 = 12 elements.
let src = make_block(1.0, 0.0, 1, 0, 0x00); let src = make_block(1.0, 0.0, 1, 0, 0x00);
let a: Vec<BlockQ4KRle> = (0..6).map(|_| encode(&src)).collect(); let a: Vec<BlockQ4KRle> = (0..6).map(|_| encode(&src, 0.0)).collect();
let b = fp16_uniform(2 * QK_K, 4, 0.0); let b = fp16_uniform(2 * QK_K, 4, 0.0);
let c = matmul_q4k_rle_fp16(&a, &b, 3, 2 * QK_K, 4); let c = matmul_q4k_rle_fp16(&a, &b, 3, 2 * QK_K, 4);
assert_eq!(c.len(), 12); assert_eq!(c.len(), 12);
@@ -1025,7 +1144,7 @@ mod tests {
fn matmul_rle_scalar_b_scales_output() { fn matmul_rle_scalar_b_scales_output() {
// Multiplying B by a scalar should scale C by the same factor. // Multiplying B by a scalar should scale C by the same factor.
let src = make_block(1.0, 0.0, 1, 0, 0x22); // nibble 2 → weight 2.0 let src = make_block(1.0, 0.0, 1, 0, 0x22); // nibble 2 → weight 2.0
let a = vec![encode(&src)]; let a = vec![encode(&src, 0.0)];
let b1 = fp16_uniform(QK_K, 1, 1.0); let b1 = fp16_uniform(QK_K, 1, 1.0);
let b2 = fp16_uniform(QK_K, 1, 3.0); let b2 = fp16_uniform(QK_K, 1, 3.0);
let c1 = matmul_q4k_rle_fp16(&a, &b1, 1, QK_K, 1); let c1 = matmul_q4k_rle_fp16(&a, &b1, 1, QK_K, 1);
@@ -1046,7 +1165,8 @@ mod tests {
let src_raw = make_block_with_qs(1.5, 0.25, 2, 1, qs_raw); let src_raw = make_block_with_qs(1.5, 0.25, 2, 1, qs_raw);
let a_orig: Vec<BlockQ4K> = vec![src_rle, src_raw]; let a_orig: Vec<BlockQ4K> = vec![src_rle, src_raw];
let a_rle: Vec<BlockQ4KRle> = a_orig.iter().map(encode).collect(); // Use threshold 0.01 so the alternating block (0% coverage) stays raw.
let a_rle: Vec<BlockQ4KRle> = a_orig.iter().map(|b| encode(b, 0.01)).collect();
// A: 1×512, B: 512×2 // A: 1×512, B: 512×2
let b = fp16_uniform(2 * QK_K, 2, 1.0); let b = fp16_uniform(2 * QK_K, 2, 1.0);
@@ -1060,7 +1180,7 @@ mod tests {
// A: 2×512 (4 blocks), B: 512×3, all weights 1 in A, all 1.0 in B. // A: 2×512 (4 blocks), B: 512×3, all weights 1 in A, all 1.0 in B.
// Each row dot product = 512.0; C should be all 512.0. // Each row dot product = 512.0; C should be all 512.0.
let src = make_block(1.0, 0.0, 1, 0, 0x11); let src = make_block(1.0, 0.0, 1, 0, 0x11);
let a: Vec<BlockQ4KRle> = (0..4).map(|_| encode(&src)).collect(); let a: Vec<BlockQ4KRle> = (0..4).map(|_| encode(&src, 0.0)).collect();
let b = fp16_uniform(2 * QK_K, 3, 1.0); let b = fp16_uniform(2 * QK_K, 3, 1.0);
let c = matmul_q4k_rle_fp16(&a, &b, 2, 2 * QK_K, 3); let c = matmul_q4k_rle_fp16(&a, &b, 2, 2 * QK_K, 3);
assert_eq!(c.len(), 6); assert_eq!(c.len(), 6);
@@ -1074,7 +1194,7 @@ mod tests {
#[test] #[test]
fn matmul_rle_panics_when_k_not_multiple_of_qkk() { fn matmul_rle_panics_when_k_not_multiple_of_qkk() {
let src = make_block(1.0, 0.0, 1, 0, 0x00); let src = make_block(1.0, 0.0, 1, 0, 0x00);
let a = vec![encode(&src)]; let a = vec![encode(&src, 0.0)];
let b = vec![0u16; 512]; let b = vec![0u16; 512];
let result = std::panic::catch_unwind(move || { let result = std::panic::catch_unwind(move || {
matmul_q4k_rle_fp16(&a, &b, 1, 512, 2); matmul_q4k_rle_fp16(&a, &b, 1, 512, 2);
@@ -1086,7 +1206,7 @@ mod tests {
fn matmul_rle_panics_on_wrong_a_length() { fn matmul_rle_panics_on_wrong_a_length() {
let src = make_block(1.0, 0.0, 1, 0, 0x00); let src = make_block(1.0, 0.0, 1, 0, 0x00);
// m=2, k=QK_K requires 2 blocks; only 1 is provided. // m=2, k=QK_K requires 2 blocks; only 1 is provided.
let a = vec![encode(&src)]; let a = vec![encode(&src, 0.0)];
let b = fp16_uniform(QK_K, 1, 1.0); let b = fp16_uniform(QK_K, 1, 1.0);
let result = std::panic::catch_unwind(move || { let result = std::panic::catch_unwind(move || {
matmul_q4k_rle_fp16(&a, &b, 2, QK_K, 1); matmul_q4k_rle_fp16(&a, &b, 2, QK_K, 1);
@@ -1097,7 +1217,7 @@ mod tests {
#[test] #[test]
fn matmul_rle_panics_on_wrong_b_length() { fn matmul_rle_panics_on_wrong_b_length() {
let src = make_block(1.0, 0.0, 1, 0, 0x00); let src = make_block(1.0, 0.0, 1, 0, 0x00);
let a = vec![encode(&src)]; let a = vec![encode(&src, 0.0)];
// B is too short for k=QK_K, n=3. // B is too short for k=QK_K, n=3.
let b = vec![0u16; 10]; let b = vec![0u16; 10];
let result = std::panic::catch_unwind(move || { let result = std::panic::catch_unwind(move || {