From 3fb10b78e321f8e3d51c5b52aaeab7dd90a6b9c8 Mon Sep 17 00:00:00 2001 From: charles Date: Sun, 12 Apr 2026 20:51:19 -0700 Subject: [PATCH] Allow variable coverage --- benches/matmul.rs | 12 +- src/bin/gguf_matmul.rs | 2 +- src/bin/gguf_scan.rs | 40 +++- src/rle.rs | 438 ++++++++++++++++++++++++++--------------- 4 files changed, 320 insertions(+), 172 deletions(-) diff --git a/benches/matmul.rs b/benches/matmul.rs index ef12a1e..71b6fe3 100644 --- a/benches/matmul.rs +++ b/benches/matmul.rs @@ -153,7 +153,7 @@ fn bench_encode(c: &mut Criterion) { group.bench_function("uniform", |b| { b.iter(|| { for blk in &uniform { - black_box(encode(black_box(blk))); + black_box(encode(black_box(blk), 0.0)); } }); }); @@ -161,7 +161,7 @@ fn bench_encode(c: &mut Criterion) { group.bench_function("rle_optimal", |b| { b.iter(|| { for blk in &rle_opt { - black_box(encode(black_box(blk))); + black_box(encode(black_box(blk), 0.0)); } }); }); @@ -186,8 +186,8 @@ fn bench_dequantize(c: &mut Criterion) { let q4k_uniform = uniform_blocks(1).into_iter().next().unwrap(); let q4k_rle_opt = rle_optimal_blocks(1).into_iter().next().unwrap(); - let rle_raw = encode(&q4k_uniform); // IS_RLE = 0 - let rle_rle = encode(&q4k_rle_opt); // IS_RLE = 1 + let rle_raw = encode(&q4k_uniform, 0.0); // IS_RLE = 0 + let rle_rle = encode(&q4k_rle_opt, 0.0); // IS_RLE = 1 // Confirm the fixtures ended up in the right encoding modes. assert!(!rle_raw.is_rle(), "uniform block should encode to raw mode"); @@ -270,10 +270,10 @@ fn bench_matmul(c: &mut Criterion) { // Build all four A variants and the shared B matrix for this config. let a_q4k_u: Vec = uniform_blocks(m * bpr); - let a_rle_u: Vec = a_q4k_u.iter().map(encode).collect(); + let a_rle_u: Vec = a_q4k_u.iter().map(|b| encode(b, 0.0)).collect(); let a_q4k_r: Vec = rle_optimal_blocks(m * bpr); - let a_rle_r: Vec = a_q4k_r.iter().map(encode).collect(); + let a_rle_r: Vec = a_q4k_r.iter().map(|b| encode(b, 0.0)).collect(); let b = fp16_ones(k, n); diff --git a/src/bin/gguf_matmul.rs b/src/bin/gguf_matmul.rs index 0b4ef81..dbcdec4 100644 --- a/src/bin/gguf_matmul.rs +++ b/src/bin/gguf_matmul.rs @@ -145,7 +145,7 @@ fn main() -> Result<(), Box> { // ── RLE encode (best of `trials`) ──────────────────────────────────────── let (rle_blocks, t_enc) = bench(trials, || -> Vec { - blocks.iter().map(encode).collect() + blocks.iter().map(|b| encode(b, 0.0)).collect() }); let n_rle = rle_blocks.iter().filter(|b| b.is_rle()).count(); diff --git a/src/bin/gguf_scan.rs b/src/bin/gguf_scan.rs index d930496..279a4e5 100644 --- a/src/bin/gguf_scan.rs +++ b/src/bin/gguf_scan.rs @@ -101,11 +101,32 @@ fn fixed(s: &str, width: usize) -> String { fn main() -> Result<(), Box> { let args: Vec = env::args().collect(); if args.len() < 2 { - eprintln!("usage: {} ", args[0]); + eprintln!("usage: {} [--threshold <0.0..1.0>]", args[0]); + eprintln!(); + eprintln!(" --threshold Minimum fraction of qs bytes that must be in runs of"); + eprintln!(" length ≥ 2 for a block to use RLE mode. Default: 0.0"); + eprintln!(" (use RLE whenever the pair count fits in 64 pairs)."); std::process::exit(1); } let path = &args[1]; + // Parse optional --threshold flag from the remaining arguments. + let mut threshold = 0.0f32; + let mut idx = 2usize; + while idx < args.len() { + if args[idx] == "--threshold" { + idx += 1; + threshold = args.get(idx) + .and_then(|s| s.parse::().ok()) + .filter(|&v| (0.0..=1.0).contains(&v)) + .unwrap_or_else(|| { + eprintln!("error: --threshold requires a value in [0.0, 1.0]"); + std::process::exit(1); + }); + } + idx += 1; + } + // ── Parse header ───────────────────────────────────────────────────────── eprintln!("Parsing {path} …"); let (tensors, data_start) = parse_header(path)?; @@ -122,6 +143,8 @@ fn main() -> Result<(), Box> { q4k_tensors.len(), other_count, ); + eprintln!(" RLE threshold: {threshold:.2} (blocks need ≥ {:.0}% of bytes in runs)", + threshold * 100.0); eprintln!(); // ── Header row ─────────────────────────────────────────────────────────── @@ -145,7 +168,7 @@ fn main() -> Result<(), Box> { let mut stats = TensorStats::new(); for_each_block(&mut file, data_start, tensor, |block| { - let rle_block = encode(block); + let rle_block = encode(block, threshold); stats.observe(rle_block.is_rle(), rle_block.rle_len()); })?; @@ -187,10 +210,15 @@ fn main() -> Result<(), Box> { if !any_rle { println!(); - println!("No blocks compressed with RLE — all weights are effectively random at"); - println!("the byte level, which is typical for trained Q4_K quantised weights."); - println!("RLE compression only helps for structured weight matrices (binary,"); - println!("ternary, heavily pruned, or synthetic)."); + println!("No blocks used RLE at threshold {threshold:.2}."); + if threshold < 0.01 { + println!("All weights are effectively random at the byte level — typical for"); + println!("trained Q4_K weights. RLE only helps for structured weight matrices"); + println!("(binary, ternary, heavily pruned, or synthetic)."); + } else { + println!("Try a lower --threshold (e.g. --threshold 0.0) to see whether any"); + println!("blocks have enough run structure to qualify at a looser threshold."); + } } Ok(()) diff --git a/src/rle.rs b/src/rle.rs index 662e69b..c560d96 100644 --- a/src/rle.rs +++ b/src/rle.rs @@ -12,18 +12,16 @@ //! //! ## RLE format (when `IS_RLE` = 1) //! -//! - `flags >> 1` gives the number of `(value, count)` pairs stored in `qs`. +//! - `n_pairs` gives the number of `(value, count)` pairs stored in `qs`. //! - For each pair `i`: //! - `qs[2*i]` — the byte value (two packed 4-bit weights, same packing //! as the raw format). //! - `qs[2*i + 1]` — the run length in bytes (1..=255). //! - The run lengths must sum to exactly 128 (the uncompressed `qs` size). //! -//! RLE encoding is chosen only when the compressed representation is -//! **strictly shorter** than the 128-byte raw payload, i.e. when -//! `pairs * 2 < 128`. That caps the useful range at ≤ 63 pairs. The 7-bit -//! `flags >> 1` sub-field can hold up to 127, so this ceiling is never a -//! concern in practice. +//! The 256-byte `qs` field can hold up to 128 `(value, count)` pairs — enough +//! to represent even fully-random blocks where every byte differs from its +//! neighbour. //! //! ## Constructing blocks //! @@ -49,42 +47,45 @@ pub const IS_RLE: u8 = 0x01; /// A Q4_K super-block with optional byte-level RLE compression on the weights. /// -/// Identical to [`crate::BlockQ4K`] except for the additional [`flags`](Self::flags) -/// byte inserted between `scales` and `qs`. +/// Unlike [`crate::BlockQ4K`], this format is **not** binary-compatible with +/// the GGUF on-disk layout. It uses a 256-byte `qs` field (vs the 128-byte +/// field in `BlockQ4K`) so the RLE stream can store up to 128 `(value, count)` +/// pairs — enough to represent even fully-random blocks where every byte +/// differs from its neighbour. /// -/// Memory layout (repr C): +/// Memory layout (`repr C`): /// -/// | Offset | Field | Size | Notes | -/// |--------|------------|-------|--------------------------------| -/// | 0 | `d` | 2 B | fp16 super-block scale | -/// | 2 | `dmin` | fp16 super-block min scale | 2 B | -/// | 4 | `scales` | 12 B | packed 6-bit sub-block params | -/// | 16 | `flags` | 1 B | encoding flags (see below) | -/// | 17 | `qs` | 128 B | raw nibbles or RLE stream | -/// | 145 | (padding) | 1 B | implicit trailing alignment pad| +/// | Offset | Field | Size | Notes | +/// |--------|-----------|--------|------------------------------------| +/// | 0 | `d` | 2 B | fp16 super-block scale | +/// | 2 | `dmin` | 2 B | fp16 super-block min-scale | +/// | 4 | `scales` | 12 B | packed 6-bit sub-block params | +/// | 16 | `flags` | 1 B | bit 0 = `IS_RLE`; bits 1–7 unused | +/// | 17 | `n_pairs` | 1 B | RLE pair count (0 when raw) | +/// | 18 | `qs` | 256 B | raw nibbles (first 128 B) or RLE | /// -/// **sizeof = 146 bytes** (padded to 2-byte alignment imposed by `u16` fields). +/// **sizeof = 274 bytes.** /// -/// ## `flags` bit layout +/// ## `qs` interpretation /// -/// | Bits | Meaning | -/// |------|---------------------------------------------------------------| -/// | 0 | [`IS_RLE`] — 1 = `qs` is RLE-encoded, 0 = raw packed nibbles | -/// | 1–7 | When `IS_RLE`=1: number of `(value, count)` pairs in `qs` | +/// | `IS_RLE` | Meaning | +/// |----------|--------------------------------------------------------------| +/// | 0 | `qs[0..128]` holds raw packed nibbles (same as `BlockQ4K`) | +/// | 1 | `qs[0..n_pairs*2]` holds `(value, count)` byte-pairs | #[repr(C)] #[derive(Clone, Copy, Debug)] pub struct BlockQ4KRle { - /// Super-block scale for quantised sub-block scales (fp16 bits). - pub d: u16, - /// Super-block scale for quantised sub-block mins (fp16 bits). - pub dmin: u16, - /// Packed 6-bit sub-block scales and mins (same layout as [`crate::BlockQ4K`]). - pub scales: [u8; K_SCALE_SIZE], - /// Encoding flags. Bit 0 = [`IS_RLE`]. Bits 1-7 = RLE pair count when - /// `IS_RLE` is set. - pub flags: u8, - /// Raw packed-nibble weights (`IS_RLE` = 0) or RLE byte stream (`IS_RLE` = 1). - pub qs: [u8; QK_K / 2], + pub d: u16, + pub dmin: u16, + pub scales: [u8; K_SCALE_SIZE], + /// Encoding flags. Only bit 0 (`IS_RLE`) is used; bits 1-7 are reserved. + pub flags: u8, + /// When `IS_RLE` is set: number of `(value, count)` byte-pairs in `qs`. + /// Zero when in raw mode. + pub n_pairs: u8, + /// Raw packed-nibble weights (IS_RLE = 0, first 128 bytes) or RLE stream + /// (IS_RLE = 1, first `n_pairs * 2` bytes). + pub qs: [u8; QK_K], // 256 bytes } impl BlockQ4KRle { @@ -94,12 +95,11 @@ impl BlockQ4KRle { self.flags & IS_RLE != 0 } - /// Number of `(value, count)` byte-pairs stored at the start of `qs`. - /// - /// Only meaningful when [`is_rle`](Self::is_rle) returns `true`. + /// Number of `(value, count)` byte-pairs in `qs`. + /// Only meaningful when `is_rle()` is true. #[inline] pub fn rle_len(&self) -> usize { - (self.flags >> 1) as usize + self.n_pairs as usize } } @@ -109,24 +109,38 @@ impl BlockQ4KRle { /// Encode a [`BlockQ4K`] block into a [`BlockQ4KRle`] block. /// -/// The 128-byte `qs` payload is scanned for runs of identical bytes. If the -/// RLE representation fits in the same 128-byte field **and is strictly -/// shorter** than the raw payload, it is stored with `IS_RLE` set. Otherwise -/// the raw bytes are copied unchanged and `IS_RLE` is cleared. +/// The `qs` payload is scanned for runs of equal consecutive bytes. RLE mode +/// is chosen when **both** conditions hold: /// -/// The `d`, `dmin`, and `scales` fields are always copied verbatim. -pub fn encode(block: &BlockQ4K) -> BlockQ4KRle { - let mut raw = Vec::from(&block.qs); +/// 1. **Coverage**: at least `min_coverage` fraction of the 128 `qs` bytes +/// belong to runs of length ≥ 2. These are the bytes whose weights can be +/// batched in `accumulate_rle_block`, replacing `2 * run_len` multiplies +/// with just 2 per group-segment. +/// +/// 2. **Capacity**: the pair count does not exceed 128 (the physical limit of +/// the 256-byte `qs` field at 2 bytes per pair). +/// +/// | `min_coverage` | Effect | +/// |----------------|------------------------------------------------------| +/// | `0.0` | RLE whenever pairs fit (≤ 128), regardless of runs | +/// | `0.5` | RLE only if ≥ 50 % of bytes are in repeated runs | +/// | `1.0` | RLE only when every byte is part of a run | +pub fn encode(block: &BlockQ4K, min_coverage: f32) -> BlockQ4KRle { + debug_assert!( + (0.0..=1.0).contains(&min_coverage), + "min_coverage must be in [0.0, 1.0], got {min_coverage}" + ); - // Sort the raw numbers - raw.sort(); + let raw = &block.qs; // [u8; 128] - // Scan the 128-byte raw payload for runs of equal bytes. - let mut pairs: Vec<(u8, u8)> = Vec::with_capacity(64); + // Scan for runs of equal consecutive bytes. + // Track long_run_bytes: bytes in runs of length ≥ 2 (the bytes that + // benefit from RLE in the matmul). + let mut pairs: Vec<(u8, u8)> = Vec::with_capacity(QK_K / 2); + let mut long_run_bytes = 0usize; let mut i = 0usize; while i < raw.len() { let val = raw[i]; - // Count consecutive equal bytes; saturate at u8::MAX to stay in-range. let mut run = 1u8; while i + (run as usize) < raw.len() && raw[i + (run as usize)] == val @@ -135,38 +149,40 @@ pub fn encode(block: &BlockQ4K) -> BlockQ4KRle { run += 1; } pairs.push((val, run)); + if run >= 2 { + long_run_bytes += run as usize; + } i += run as usize; } - // Only switch to RLE when the encoded form is strictly smaller than the - // raw payload. Because each pair costs 2 bytes and the raw payload is - // 128 bytes, the condition pairs.len() * 2 < 128 also guarantees that - // pairs.len() ≤ 63, which fits in bits 1-7 of the flags byte. - if pairs.len() * 2 < raw.len() { - let n = pairs.len(); - debug_assert!(n <= 63, "RLE pair count {n} unexpectedly exceeds 63"); + // Coverage: fraction of qs bytes that are in non-singleton runs. + let coverage = long_run_bytes as f32 / raw.len() as f32; - let mut qs = [0u8; QK_K / 2]; + if pairs.len() <= QK_K / 2 && coverage >= min_coverage { + let n = pairs.len(); + let mut qs = [0u8; QK_K]; for (k, &(val, count)) in pairs.iter().enumerate() { qs[2 * k] = val; qs[2 * k + 1] = count; } - BlockQ4KRle { - d: block.d, - dmin: block.dmin, - scales: block.scales, - flags: IS_RLE | ((n as u8) << 1), + d: block.d, + dmin: block.dmin, + scales: block.scales, + flags: IS_RLE, + n_pairs: n as u8, qs, } } else { - // No space savings — copy raw bytes and leave IS_RLE clear. + let mut qs = [0u8; QK_K]; + qs[..QK_K / 2].copy_from_slice(&block.qs); BlockQ4KRle { - d: block.d, - dmin: block.dmin, - scales: block.scales, - flags: 0, - qs: block.qs, + d: block.d, + dmin: block.dmin, + scales: block.scales, + flags: 0, + n_pairs: 0, + qs, } } } @@ -183,27 +199,26 @@ pub fn encode(block: &BlockQ4K) -> BlockQ4KRle { /// Panics if the decoded RLE stream does not sum to exactly 128 bytes. fn decode_qs(block: &BlockQ4KRle) -> [u8; QK_K / 2] { if !block.is_rle() { - return block.qs; + // First QK_K/2 bytes of qs hold the raw packed nibbles. + block.qs[..QK_K / 2].try_into().unwrap() + } else { + let n = block.rle_len(); + let mut raw = [0u8; QK_K / 2]; + let mut pos = 0usize; + for i in 0..n { + let val = block.qs[2 * i]; + let count = block.qs[2 * i + 1] as usize; + raw[pos..pos + count].fill(val); + pos += count; + } + debug_assert_eq!( + pos, + QK_K / 2, + "RLE run lengths sum to {pos}, expected {}", + QK_K / 2 + ); + raw } - - let n = block.rle_len(); - let mut raw = [0u8; QK_K / 2]; - let mut pos = 0usize; - - for i in 0..n { - let val = block.qs[2 * i]; - let count = block.qs[2 * i + 1] as usize; - raw[pos..pos + count].fill(val); - pos += count; - } - - debug_assert_eq!( - pos, - QK_K / 2, - "RLE run lengths sum to {pos}, expected {}", - QK_K / 2 - ); - raw } // --------------------------------------------------------------------------- @@ -570,17 +585,18 @@ mod tests { // ========================================================================= #[test] - fn block_q4k_rle_size_is_146_bytes() { - // d(2) + dmin(2) + scales(12) + flags(1) + qs(128) = 145 raw bytes, - // rounded up to 146 by the 2-byte alignment imposed by the u16 fields. - assert_eq!(core::mem::size_of::(), 146); + fn block_q4k_rle_size_is_274_bytes() { + // d(2) + dmin(2) + scales(12) + flags(1) + n_pairs(1) + qs(256) = 274 bytes. + // No padding needed: struct is already 2-byte aligned and 274 is even. + assert_eq!(core::mem::size_of::(), 274); } #[test] - fn block_q4k_rle_is_two_bytes_larger_than_block_q4k() { + fn block_q4k_rle_is_130_bytes_larger_than_block_q4k() { + // BlockQ4K = 144 bytes, BlockQ4KRle = 274 bytes, delta = 130. assert_eq!( core::mem::size_of::(), - core::mem::size_of::() + 2, + core::mem::size_of::() + 130, ); } @@ -591,7 +607,7 @@ mod tests { #[test] fn is_rle_false_when_flag_clear() { let b = BlockQ4KRle { - d: 0, dmin: 0, scales: [0; K_SCALE_SIZE], flags: 0, qs: [0; QK_K / 2], + d: 0, dmin: 0, scales: [0; K_SCALE_SIZE], flags: 0, n_pairs: 0, qs: [0; QK_K], }; assert!(!b.is_rle()); } @@ -599,7 +615,7 @@ mod tests { #[test] fn rle_len_zero_when_flag_clear() { let b = BlockQ4KRle { - d: 0, dmin: 0, scales: [0; K_SCALE_SIZE], flags: 0, qs: [0; QK_K / 2], + d: 0, dmin: 0, scales: [0; K_SCALE_SIZE], flags: 0, n_pairs: 0, qs: [0; QK_K], }; assert_eq!(b.rle_len(), 0); } @@ -608,19 +624,21 @@ mod tests { fn is_rle_true_when_flag_set() { let b = BlockQ4KRle { d: 0, dmin: 0, scales: [0; K_SCALE_SIZE], - flags: IS_RLE | (5u8 << 1), - qs: [0; QK_K / 2], + flags: IS_RLE, + n_pairs: 5, + qs: [0; QK_K], }; assert!(b.is_rle()); } #[test] - fn rle_len_reports_pair_count_from_flags() { - for n in [0usize, 1, 7, 31, 63] { + fn rle_len_reports_pair_count_from_n_pairs() { + for n in [0usize, 1, 7, 31, 63, 128] { let b = BlockQ4KRle { d: 0, dmin: 0, scales: [0; K_SCALE_SIZE], - flags: IS_RLE | ((n as u8) << 1), - qs: [0; QK_K / 2], + flags: if n > 0 { IS_RLE } else { 0 }, + n_pairs: n as u8, + qs: [0; QK_K], }; assert_eq!(b.rle_len(), n, "expected rle_len {n}"); } @@ -632,61 +650,64 @@ mod tests { #[test] fn encode_uniform_qs_uses_rle() { - // 128 identical bytes → 1 pair → 2 bytes < 128 raw. + // 128 identical bytes → 1 pair → 2 bytes stored in qs. let src = make_block(1.0, 0.0, 1, 0, 0x77); - let rle = encode(&src); + let rle = encode(&src, 0.0); assert!(rle.is_rle(), "uniform qs should trigger RLE mode"); } #[test] fn encode_uniform_qs_rle_len_is_one() { let src = make_block(1.0, 0.0, 1, 0, 0x55); - let rle = encode(&src); + let rle = encode(&src, 0.0); assert_eq!(rle.rle_len(), 1); } #[test] fn encode_uniform_qs_rle_entry_is_correct() { let src = make_block(1.0, 0.0, 1, 0, 0xAB); - let rle = encode(&src); + let rle = encode(&src, 0.0); assert_eq!(rle.qs[0], 0xAB, "RLE value byte should equal the repeated byte"); assert_eq!(rle.qs[1], 128, "RLE run length should be 128 bytes"); } #[test] fn encode_alternating_bytes_stays_raw() { - // 128 single-byte runs → 128 pairs → 256 bytes ≥ 128 raw → raw mode. + // Alternating 0xAA / 0x55 → 128 singleton pairs, coverage = 0%. + // At threshold 0.01 the 0% coverage fails → raw mode. let mut qs = [0u8; QK_K / 2]; for (i, b) in qs.iter_mut().enumerate() { *b = if i % 2 == 0 { 0xAA } else { 0x55 }; } let src = make_block_with_qs(1.0, 0.0, 1, 0, qs); - let rle = encode(&src); - assert!(!rle.is_rle(), "alternating bytes cannot be compressed → raw mode"); + let rle = encode(&src, 0.01); + assert!(!rle.is_rle(), "0% coverage fails any threshold > 0 → raw mode"); } #[test] fn encode_raw_mode_copies_qs_verbatim() { + // Three-byte cycle of distinct values → 128 runs of 1 byte each, + // coverage = 0%. At threshold 0.01 the 0% coverage fails → raw mode. let mut qs = [0u8; QK_K / 2]; for (i, b) in qs.iter_mut().enumerate() { - // Three-byte cycle of distinct values → 128 runs of 1 byte each. *b = match i % 3 { 0 => 0x11, 1 => 0x22, _ => 0x33 }; } let src = make_block_with_qs(1.0, 0.0, 1, 0, qs); - let rle = encode(&src); + let rle = encode(&src, 0.01); assert!(!rle.is_rle()); - assert_eq!(rle.qs, qs, "raw mode must preserve qs bytes unchanged"); + // Raw mode copies the 128-byte qs into the first half of the 256-byte field. + assert_eq!(&rle.qs[..QK_K / 2], &qs[..], "raw mode must preserve qs bytes in first 128 bytes"); } #[test] fn encode_two_runs_uses_rle_and_stores_correct_pairs() { // Two distinct runs: 64 bytes of 0x11 followed by 64 bytes of 0x22. - // → 2 pairs = 4 bytes < 128 bytes raw. + // → 2 pairs = 4 bytes. let mut qs = [0u8; QK_K / 2]; qs[..64].fill(0x11); qs[64..].fill(0x22); let src = make_block_with_qs(1.0, 0.0, 1, 0, qs); - let rle = encode(&src); + let rle = encode(&src, 0.0); assert!(rle.is_rle()); assert_eq!(rle.rle_len(), 2); assert_eq!(rle.qs[0], 0x11, "first pair: value"); @@ -698,7 +719,7 @@ mod tests { #[test] fn encode_63_pairs_uses_rle() { // Build 62 runs of 2 bytes each (124 bytes) + 1 run of 4 bytes = 128 bytes. - // 63 pairs × 2 = 126 bytes < 128 → RLE should be chosen. + // 63 pairs × 2 = 126 bytes; 63 ≤ 128 → RLE should be chosen. let mut qs = [0u8; QK_K / 2]; let mut pos = 0usize; for run in 0..62usize { @@ -712,15 +733,15 @@ mod tests { qs[pos..].fill(0xFE); let src = make_block_with_qs(1.0, 0.0, 1, 0, qs); - let rle = encode(&src); + let rle = encode(&src, 0.0); assert!(rle.is_rle(), "63 pairs should use RLE"); assert_eq!(rle.rle_len(), 63); } #[test] - fn encode_64_pairs_stays_raw() { - // 64 runs of 2 bytes each = 128 bytes total. - // 64 pairs × 2 = 128 bytes, which is NOT strictly less than 128 → raw. + fn encode_64_pairs_uses_rle_at_zero_threshold() { + // 64 runs of 2 bytes each = 128 bytes total, coverage = 100%. + // pairs (64) ≤ 128 AND 100% ≥ 0.0 → RLE mode. let mut qs = [0u8; QK_K / 2]; let mut pos = 0usize; for run in 0..64usize { @@ -730,14 +751,93 @@ mod tests { pos += 2; } let src = make_block_with_qs(1.0, 0.0, 1, 0, qs); - let rle = encode(&src); - assert!(!rle.is_rle(), "64 pairs offers no saving → raw mode"); + let rle = encode(&src, 0.0); + assert!(rle.is_rle(), "64 pairs, 100% coverage, threshold 0.0 → RLE"); + assert_eq!(rle.rle_len(), 64); + } + + #[test] + fn encode_128_pairs_uses_rle_at_zero_threshold() { + // 128 distinct consecutive bytes = 128 singleton runs = 128 pairs. + // With old cap (64 pairs), this was always raw. + // With new cap (128 pairs), threshold 0.0 accepts it. + // Coverage = 0 % (all singletons) → threshold > 0.0 rejects it. + let mut qs = [0u8; QK_K / 2]; + for (i, b) in qs.iter_mut().enumerate() { + *b = i as u8; // 0x00, 0x01, ..., 0x7F — all distinct, all singletons + } + let src = make_block_with_qs(1.0, 0.0, 1, 0, qs); + + assert!( + encode(&src, 0.0).is_rle(), + "128 pairs ≤ 128 limit AND 0% ≥ 0.0 → RLE at zero threshold" + ); + assert_eq!(encode(&src, 0.0).rle_len(), 128); + + assert!( + !encode(&src, 0.01).is_rle(), + "0% coverage fails any threshold > 0" + ); + } + + #[test] + fn encode_coverage_threshold_rejects_low_coverage_block() { + // Construct: 63 singletons + 1 run of 65 bytes = 64 pairs. + // coverage = 65/128 ≈ 50.8%. + // threshold 0.50 accepts it; threshold 0.60 rejects it. + let mut qs = [0u8; QK_K / 2]; + qs[0] = 0x01; + for i in 1..63usize { + // Distinct odd bytes, none equal to 0x01 or adjacent values. + qs[i] = (i as u8).wrapping_mul(2).wrapping_add(5); + } + qs[63..].fill(0xAB); // 65-byte run; qs[62] = 62*2+5 = 129 → wraps to 0x81 ≠ 0xAB ✓ + + let src = make_block_with_qs(1.0, 0.0, 1, 0, qs); + assert!( + encode(&src, 0.50).is_rle(), + "50.8% coverage should meet 50% threshold" + ); + assert!( + !encode(&src, 0.60).is_rle(), + "50.8% coverage should fail 60% threshold" + ); + } + + #[test] + fn encode_coverage_zero_threshold_always_uses_rle_when_pairs_fit() { + // Any block whose runs produce ≤ 128 pairs uses RLE at threshold 0.0, + // regardless of how many singletons it contains. + // Use the 63-pair block from encode_63_pairs_uses_rle. + let mut qs = [0u8; QK_K / 2]; + let mut pos = 0usize; + for run in 0..62usize { + let v = (run as u8).wrapping_mul(3).wrapping_add(1); + qs[pos] = v; + qs[pos + 1] = v; + pos += 2; + } + qs[pos..].fill(0xFE); + let src = make_block_with_qs(1.0, 0.0, 1, 0, qs); + assert!(encode(&src, 0.0).is_rle()); + } + + #[test] + fn encode_coverage_one_threshold_requires_total_coverage() { + // A block with even one singleton byte fails the 100% threshold. + // Build: 1 singleton + 1 run of 127 bytes = 2 pairs, coverage = 127/128 ≈ 99.2%. + let mut qs = [0u8; QK_K / 2]; + qs[0] = 0x01; // singleton (value distinct from rest) + qs[1..].fill(0x02); // 127-byte run + let src = make_block_with_qs(1.0, 0.0, 1, 0, qs); + assert!(!encode(&src, 1.0).is_rle(), "99.2% coverage should fail 100% threshold"); + assert!(encode(&src, 0.99).is_rle(), "99.2% coverage should meet 99% threshold"); } #[test] fn encode_preserves_d_dmin_scales() { let src = make_block(2.0, 0.5, 3, 2, 0x00); - let rle = encode(&src); + let rle = encode(&src, 0.0); assert_eq!(rle.d, src.d); assert_eq!(rle.dmin, src.dmin); assert_eq!(rle.scales, src.scales); @@ -749,25 +849,24 @@ mod tests { #[test] fn decode_qs_raw_mode_returns_qs_unchanged() { - // Build a raw BlockQ4KRle (flags = 0) with a non-trivial qs pattern. - let mut qs = [0u8; QK_K / 2]; - for (i, b) in qs.iter_mut().enumerate() { *b = i as u8; } + let mut qs = [0u8; QK_K]; + for (i, b) in qs[..QK_K / 2].iter_mut().enumerate() { *b = i as u8; } let rle = BlockQ4KRle { - d: 0, dmin: 0, scales: [0; K_SCALE_SIZE], flags: 0, qs, + d: 0, dmin: 0, scales: [0; K_SCALE_SIZE], flags: 0, n_pairs: 0, qs, }; - assert_eq!(decode_qs(&rle), qs); + let expected: [u8; QK_K / 2] = qs[..QK_K / 2].try_into().unwrap(); + assert_eq!(decode_qs(&rle), expected); } #[test] fn decode_qs_rle_expands_two_pair_stream() { // Hand-craft an RLE block: [0xAA × 64, 0xBB × 64]. - let mut qs = [0u8; QK_K / 2]; + let mut qs = [0u8; QK_K]; qs[0] = 0xAA; qs[1] = 64; qs[2] = 0xBB; qs[3] = 64; let rle = BlockQ4KRle { d: 0, dmin: 0, scales: [0; K_SCALE_SIZE], - flags: IS_RLE | (2u8 << 1), - qs, + flags: IS_RLE, n_pairs: 2, qs, }; let expanded = decode_qs(&rle); assert!(expanded[..64].iter().all(|&b| b == 0xAA), "first 64 bytes must be 0xAA"); @@ -776,12 +875,11 @@ mod tests { #[test] fn decode_qs_rle_single_run_covers_all() { - let mut qs = [0u8; QK_K / 2]; + let mut qs = [0u8; QK_K]; qs[0] = 0xCD; qs[1] = 128; // one run of 128 bytes let rle = BlockQ4KRle { d: 0, dmin: 0, scales: [0; K_SCALE_SIZE], - flags: IS_RLE | (1u8 << 1), - qs, + flags: IS_RLE, n_pairs: 1, qs, }; let expanded = decode_qs(&rle); assert!(expanded.iter().all(|&b| b == 0xCD)); @@ -794,7 +892,7 @@ mod tests { #[test] fn dequant_rle_zero_d_all_outputs_zero() { let src = make_block(0.0, 0.0, 1, 0, 0x77); - let rle = encode(&src); + let rle = encode(&src, 0.0); let mut out = [0.0f32; QK_K]; dequantize_block_q4k_rle(&rle, &mut out); assert_all_close(&out, 0.0, 0.0); @@ -805,7 +903,7 @@ mod tests { // qs_byte = 0x11 → both nibbles = 1; scale = 1, d = 1.0, min = 0. // expected: 1.0 * 1 * 1 - 0.0 = 1.0 let src = make_block(1.0, 0.0, 1, 0, 0x11); - let rle = encode(&src); + let rle = encode(&src, 0.0); let mut out = [0.0f32; QK_K]; dequantize_block_q4k_rle(&rle, &mut out); assert_all_close(&out, 1.0, 1e-5); @@ -816,7 +914,7 @@ mod tests { // nibble = 0, scale = 1, d = 1.0, min = 2, dmin = 1.0 // expected: 1.0 * 1 * 0 - 1.0 * 2 = -2.0 let src = make_block(1.0, 1.0, 1, 2, 0x00); - let rle = encode(&src); + let rle = encode(&src, 0.0); let mut out = [0.0f32; QK_K]; dequantize_block_q4k_rle(&rle, &mut out); assert_all_close(&out, -2.0, 1e-5); @@ -827,7 +925,7 @@ mod tests { // qs_byte = 0xFF → both nibbles = 15; scale = 1, d = 1.0, min = 0. // expected: 1.0 * 1 * 15 - 0.0 = 15.0 let src = make_block(1.0, 0.0, 1, 0, 0xFF); - let rle = encode(&src); + let rle = encode(&src, 0.0); let mut out = [0.0f32; QK_K]; dequantize_block_q4k_rle(&rle, &mut out); assert_all_close(&out, 15.0, 1e-5); @@ -836,7 +934,7 @@ mod tests { #[test] fn dequant_rle_output_count_is_qk_k() { let src = make_block(1.0, 0.0, 1, 0, 0x00); - let rle = encode(&src); + let rle = encode(&src, 0.0); let mut out = [0.0f32; QK_K]; dequantize_block_q4k_rle(&rle, &mut out); assert_eq!(out.len(), QK_K); @@ -848,7 +946,7 @@ mod tests { // expected: 2.0 * 4 * 3 - 0.0 = 24.0 // qs_byte = 0x33 → both nibbles = 3 let src = make_block(2.0, 0.0, 4, 0, 0x33); - let rle = encode(&src); + let rle = encode(&src, 0.0); let mut out = [0.0f32; QK_K]; dequantize_block_q4k_rle(&rle, &mut out); assert_all_close(&out, 24.0, 1e-4); @@ -862,7 +960,7 @@ mod tests { fn roundtrip_rle_mode_matches_original() { // Uniform qs → RLE mode selected. let src = make_block(2.0, 0.5, 3, 1, 0x37); - let rle = encode(&src); + let rle = encode(&src, 0.0); assert!(rle.is_rle()); let mut got = [0.0f32; QK_K]; @@ -874,13 +972,14 @@ mod tests { #[test] fn roundtrip_raw_mode_matches_original() { - // Alternating bytes → raw mode selected; output must still be correct. + // Alternating bytes → 128 singleton pairs, coverage = 0%. + // Use threshold 0.01 to force raw mode (0% < 0.01). let mut qs = [0u8; QK_K / 2]; for (i, b) in qs.iter_mut().enumerate() { *b = if i % 2 == 0 { 0x13 } else { 0x24 }; } let src = make_block_with_qs(1.5, 0.25, 2, 1, qs); - let rle = encode(&src); + let rle = encode(&src, 0.01); assert!(!rle.is_rle()); let mut got = [0.0f32; QK_K]; @@ -896,7 +995,7 @@ mod tests { qs[..64].fill(0x59); qs[64..].fill(0x8C); let src = make_block_with_qs(3.0, 1.0, 5, 2, qs); - let rle = encode(&src); + let rle = encode(&src, 0.0); assert!(rle.is_rle()); let mut got = [0.0f32; QK_K]; @@ -915,7 +1014,7 @@ mod tests { qs[30..31].fill(0x33); qs[31..].fill(0x44); let src = make_block_with_qs(1.0, 0.5, 7, 3, qs); - let rle = encode(&src); + let rle = encode(&src, 0.0); assert!(rle.is_rle(), "4-run block should compress"); assert_eq!(rle.rle_len(), 4); @@ -929,7 +1028,7 @@ mod tests { #[test] fn roundtrip_zero_qs_matches_original() { let src = make_block(1.0, 0.5, 2, 1, 0x00); - let rle = encode(&src); + let rle = encode(&src, 0.0); let mut got = [0.0f32; QK_K]; let mut expected = [0.0f32; QK_K]; dequantize_block_q4k_rle(&rle, &mut got); @@ -942,7 +1041,7 @@ mod tests { // qs_byte = 0x37: low nibble = 7 (sub-block 0 path), high nibble = 3 // (sub-block 1 path). Verify both halves are dequantised correctly. let src = make_block(1.0, 0.0, 1, 0, 0x37); - let rle = encode(&src); + let rle = encode(&src, 0.0); let mut got = [0.0f32; QK_K]; let mut expected = [0.0f32; QK_K]; dequantize_block_q4k_rle(&rle, &mut got); @@ -954,6 +1053,26 @@ mod tests { assert_close(got[32], 3.0, 1e-5); } + #[test] + fn roundtrip_128_singleton_pairs_matches_original() { + // All-distinct bytes → 128 pairs, 0% coverage. + // encode at threshold 0.0 → RLE; dequantize must match baseline. + let mut qs = [0u8; QK_K / 2]; + for (i, b) in qs.iter_mut().enumerate() { + *b = (i as u8).wrapping_mul(3).wrapping_add(7); + } + let src = make_block_with_qs(1.0, 0.0, 1, 0, qs); + let rle = encode(&src, 0.0); + assert!(rle.is_rle()); + assert_eq!(rle.rle_len(), 128); + + let mut got = [0.0f32; QK_K]; + let mut expected = [0.0f32; QK_K]; + dequantize_block_q4k_rle(&rle, &mut got); + dequantize_block_q4k(&src, &mut expected); + assert_slices_close(&got, &expected, 1e-5); + } + // ========================================================================= // matmul_q4k_rle_fp16 // ========================================================================= @@ -964,7 +1083,7 @@ mod tests { // B: 256×1, all fp16 1.0 // C = dot([1.0; 256], [1.0; 256]) = 256.0 let src = make_block(1.0, 0.0, 1, 0, 0x11); - let a = vec![encode(&src)]; + let a = vec![encode(&src, 0.0)]; let b = fp16_uniform(QK_K, 1, 1.0); let c = matmul_q4k_rle_fp16(&a, &b, 1, QK_K, 1); assert_eq!(c.len(), 1); @@ -974,7 +1093,7 @@ mod tests { #[test] fn matmul_rle_2x256_times_256x3_all_ones() { let src = make_block(1.0, 0.0, 1, 0, 0x11); - let a = vec![encode(&src), encode(&src)]; + let a = vec![encode(&src, 0.0), encode(&src, 0.0)]; let b = fp16_uniform(QK_K, 3, 1.0); let c = matmul_q4k_rle_fp16(&a, &b, 2, QK_K, 3); assert_eq!(c.len(), 6); @@ -984,7 +1103,7 @@ mod tests { #[test] fn matmul_rle_zero_a_gives_zero_c() { let src = make_block(0.0, 0.0, 1, 0, 0xFF); - let a = vec![encode(&src)]; + let a = vec![encode(&src, 0.0)]; let b = fp16_uniform(QK_K, 4, 1.0); let c = matmul_q4k_rle_fp16(&a, &b, 1, QK_K, 4); assert_all_close(&c, 0.0, 0.0); @@ -993,7 +1112,7 @@ mod tests { #[test] fn matmul_rle_zero_b_gives_zero_c() { let src = make_block(1.0, 0.0, 1, 0, 0x11); - let a = vec![encode(&src)]; + let a = vec![encode(&src, 0.0)]; let b = fp16_uniform(QK_K, 2, 0.0); let c = matmul_q4k_rle_fp16(&a, &b, 1, QK_K, 2); assert_all_close(&c, 0.0, 0.0); @@ -1004,7 +1123,7 @@ mod tests { // A: 1×512, two blocks, all nibble-1 weights; B: 512×1, all 1.0. // Expected: 512.0 let src = make_block(1.0, 0.0, 1, 0, 0x11); - let a = vec![encode(&src), encode(&src)]; + let a = vec![encode(&src, 0.0), encode(&src, 0.0)]; let b = fp16_uniform(2 * QK_K, 1, 1.0); let c = matmul_q4k_rle_fp16(&a, &b, 1, 2 * QK_K, 1); assert_eq!(c.len(), 1); @@ -1015,7 +1134,7 @@ mod tests { fn matmul_rle_output_shape_m_times_n() { // A: 3×512 (6 blocks), B: 512×4 → C: 3×4 = 12 elements. let src = make_block(1.0, 0.0, 1, 0, 0x00); - let a: Vec = (0..6).map(|_| encode(&src)).collect(); + let a: Vec = (0..6).map(|_| encode(&src, 0.0)).collect(); let b = fp16_uniform(2 * QK_K, 4, 0.0); let c = matmul_q4k_rle_fp16(&a, &b, 3, 2 * QK_K, 4); assert_eq!(c.len(), 12); @@ -1025,7 +1144,7 @@ mod tests { fn matmul_rle_scalar_b_scales_output() { // Multiplying B by a scalar should scale C by the same factor. let src = make_block(1.0, 0.0, 1, 0, 0x22); // nibble 2 → weight 2.0 - let a = vec![encode(&src)]; + let a = vec![encode(&src, 0.0)]; let b1 = fp16_uniform(QK_K, 1, 1.0); let b2 = fp16_uniform(QK_K, 1, 3.0); let c1 = matmul_q4k_rle_fp16(&a, &b1, 1, QK_K, 1); @@ -1046,7 +1165,8 @@ mod tests { let src_raw = make_block_with_qs(1.5, 0.25, 2, 1, qs_raw); let a_orig: Vec = vec![src_rle, src_raw]; - let a_rle: Vec = a_orig.iter().map(encode).collect(); + // Use threshold 0.01 so the alternating block (0% coverage) stays raw. + let a_rle: Vec = a_orig.iter().map(|b| encode(b, 0.01)).collect(); // A: 1×512, B: 512×2 let b = fp16_uniform(2 * QK_K, 2, 1.0); @@ -1060,7 +1180,7 @@ mod tests { // A: 2×512 (4 blocks), B: 512×3, all weights 1 in A, all 1.0 in B. // Each row dot product = 512.0; C should be all 512.0. let src = make_block(1.0, 0.0, 1, 0, 0x11); - let a: Vec = (0..4).map(|_| encode(&src)).collect(); + let a: Vec = (0..4).map(|_| encode(&src, 0.0)).collect(); let b = fp16_uniform(2 * QK_K, 3, 1.0); let c = matmul_q4k_rle_fp16(&a, &b, 2, 2 * QK_K, 3); assert_eq!(c.len(), 6); @@ -1074,7 +1194,7 @@ mod tests { #[test] fn matmul_rle_panics_when_k_not_multiple_of_qkk() { let src = make_block(1.0, 0.0, 1, 0, 0x00); - let a = vec![encode(&src)]; + let a = vec![encode(&src, 0.0)]; let b = vec![0u16; 512]; let result = std::panic::catch_unwind(move || { matmul_q4k_rle_fp16(&a, &b, 1, 512, 2); @@ -1086,7 +1206,7 @@ mod tests { fn matmul_rle_panics_on_wrong_a_length() { let src = make_block(1.0, 0.0, 1, 0, 0x00); // m=2, k=QK_K requires 2 blocks; only 1 is provided. - let a = vec![encode(&src)]; + let a = vec![encode(&src, 0.0)]; let b = fp16_uniform(QK_K, 1, 1.0); let result = std::panic::catch_unwind(move || { matmul_q4k_rle_fp16(&a, &b, 2, QK_K, 1); @@ -1097,7 +1217,7 @@ mod tests { #[test] fn matmul_rle_panics_on_wrong_b_length() { let src = make_block(1.0, 0.0, 1, 0, 0x00); - let a = vec![encode(&src)]; + let a = vec![encode(&src, 0.0)]; // B is too short for k=QK_K, n=3. let b = vec![0u16; 10]; let result = std::panic::catch_unwind(move || {