diff --git a/src/lib.rs b/src/lib.rs index 0a1cb56..8b2a21e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,6 +18,8 @@ //! dequantise each row of A into f32, convert each element of B from fp16 to //! f32, accumulate dot-products. No SIMD, no tiling, no tricks. +pub mod rle; + // --------------------------------------------------------------------------- // Constants matching GGML's ggml-common.h // --------------------------------------------------------------------------- diff --git a/src/rle.rs b/src/rle.rs new file mode 100644 index 0000000..4c6788a --- /dev/null +++ b/src/rle.rs @@ -0,0 +1,992 @@ +//! RLE-optional Q4_K super-block encoding. +//! +//! This module provides [`BlockQ4KRle`], a variant of [`crate::BlockQ4K`] that +//! optionally compresses the 128-byte weight payload using **byte-level +//! run-length encoding** (RLE). A flag bit in the [`BlockQ4KRle::flags`] +//! field indicates which mode is active: +//! +//! | `IS_RLE` bit | `qs` interpretation | +//! |--------------|------------------------------------------------------------| +//! | 0 | Raw packed nibbles, identical to [`crate::BlockQ4K::qs`] | +//! | 1 | RLE stream of `(value, count)` byte-pairs | +//! +//! ## RLE format (when `IS_RLE` = 1) +//! +//! - `flags >> 1` gives the number of `(value, count)` pairs stored in `qs`. +//! - For each pair `i`: +//! - `qs[2*i]` — the byte value (two packed 4-bit weights, same packing +//! as the raw format). +//! - `qs[2*i + 1]` — the run length in bytes (1..=255). +//! - The run lengths must sum to exactly 128 (the uncompressed `qs` size). +//! +//! RLE encoding is chosen only when the compressed representation is +//! **strictly shorter** than the 128-byte raw payload, i.e. when +//! `pairs * 2 < 128`. That caps the useful range at ≤ 63 pairs. The 7-bit +//! `flags >> 1` sub-field can hold up to 127, so this ceiling is never a +//! concern in practice. +//! +//! ## Constructing blocks +//! +//! Use [`encode`] to convert an existing [`crate::BlockQ4K`] into a +//! [`BlockQ4KRle`]. The function automatically selects the better mode. +//! +//! ## Adding this module to your crate +//! +//! Add `pub mod rle;` to `lib.rs`. + +use crate::{fp16_to_f32, get_scale_min, BlockQ4K, K_SCALE_SIZE, QK_K}; + +// --------------------------------------------------------------------------- +// Flag constants +// --------------------------------------------------------------------------- + +/// Flag bit in [`BlockQ4KRle::flags`]: if set, `qs` contains an RLE stream. +pub const IS_RLE: u8 = 0x01; + +// --------------------------------------------------------------------------- +// Block definition +// --------------------------------------------------------------------------- + +/// A Q4_K super-block with optional byte-level RLE compression on the weights. +/// +/// Identical to [`crate::BlockQ4K`] except for the additional [`flags`](Self::flags) +/// byte inserted between `scales` and `qs`. +/// +/// Memory layout (repr C): +/// +/// | Offset | Field | Size | Notes | +/// |--------|------------|-------|--------------------------------| +/// | 0 | `d` | 2 B | fp16 super-block scale | +/// | 2 | `dmin` | fp16 super-block min scale | 2 B | +/// | 4 | `scales` | 12 B | packed 6-bit sub-block params | +/// | 16 | `flags` | 1 B | encoding flags (see below) | +/// | 17 | `qs` | 128 B | raw nibbles or RLE stream | +/// | 145 | (padding) | 1 B | implicit trailing alignment pad| +/// +/// **sizeof = 146 bytes** (padded to 2-byte alignment imposed by `u16` fields). +/// +/// ## `flags` bit layout +/// +/// | Bits | Meaning | +/// |------|---------------------------------------------------------------| +/// | 0 | [`IS_RLE`] — 1 = `qs` is RLE-encoded, 0 = raw packed nibbles | +/// | 1–7 | When `IS_RLE`=1: number of `(value, count)` pairs in `qs` | +#[repr(C)] +#[derive(Clone, Copy, Debug)] +pub struct BlockQ4KRle { + /// Super-block scale for quantised sub-block scales (fp16 bits). + pub d: u16, + /// Super-block scale for quantised sub-block mins (fp16 bits). + pub dmin: u16, + /// Packed 6-bit sub-block scales and mins (same layout as [`crate::BlockQ4K`]). + pub scales: [u8; K_SCALE_SIZE], + /// Encoding flags. Bit 0 = [`IS_RLE`]. Bits 1-7 = RLE pair count when + /// `IS_RLE` is set. + pub flags: u8, + /// Raw packed-nibble weights (`IS_RLE` = 0) or RLE byte stream (`IS_RLE` = 1). + pub qs: [u8; QK_K / 2], +} + +impl BlockQ4KRle { + /// Returns `true` when `qs` holds an RLE-encoded stream. + #[inline] + pub fn is_rle(&self) -> bool { + self.flags & IS_RLE != 0 + } + + /// Number of `(value, count)` byte-pairs stored at the start of `qs`. + /// + /// Only meaningful when [`is_rle`](Self::is_rle) returns `true`. + #[inline] + pub fn rle_len(&self) -> usize { + (self.flags >> 1) as usize + } +} + +// --------------------------------------------------------------------------- +// Encoding +// --------------------------------------------------------------------------- + +/// Encode a [`BlockQ4K`] block into a [`BlockQ4KRle`] block. +/// +/// The 128-byte `qs` payload is scanned for runs of identical bytes. If the +/// RLE representation fits in the same 128-byte field **and is strictly +/// shorter** than the raw payload, it is stored with `IS_RLE` set. Otherwise +/// the raw bytes are copied unchanged and `IS_RLE` is cleared. +/// +/// The `d`, `dmin`, and `scales` fields are always copied verbatim. +pub fn encode(block: &BlockQ4K) -> BlockQ4KRle { + let raw = &block.qs; + + // Scan the 128-byte raw payload for runs of equal bytes. + let mut pairs: Vec<(u8, u8)> = Vec::with_capacity(64); + let mut i = 0usize; + while i < raw.len() { + let val = raw[i]; + // Count consecutive equal bytes; saturate at u8::MAX to stay in-range. + let mut run = 1u8; + while i + (run as usize) < raw.len() + && raw[i + (run as usize)] == val + && run < u8::MAX + { + run += 1; + } + pairs.push((val, run)); + i += run as usize; + } + + // Only switch to RLE when the encoded form is strictly smaller than the + // raw payload. Because each pair costs 2 bytes and the raw payload is + // 128 bytes, the condition pairs.len() * 2 < 128 also guarantees that + // pairs.len() ≤ 63, which fits in bits 1-7 of the flags byte. + if pairs.len() * 2 < raw.len() { + let n = pairs.len(); + debug_assert!(n <= 63, "RLE pair count {n} unexpectedly exceeds 63"); + + let mut qs = [0u8; QK_K / 2]; + for (k, &(val, count)) in pairs.iter().enumerate() { + qs[2 * k] = val; + qs[2 * k + 1] = count; + } + + BlockQ4KRle { + d: block.d, + dmin: block.dmin, + scales: block.scales, + flags: IS_RLE | ((n as u8) << 1), + qs, + } + } else { + // No space savings — copy raw bytes and leave IS_RLE clear. + BlockQ4KRle { + d: block.d, + dmin: block.dmin, + scales: block.scales, + flags: 0, + qs: block.qs, + } + } +} + +// --------------------------------------------------------------------------- +// Decoding helpers +// --------------------------------------------------------------------------- + +/// Expand the `qs` field of a [`BlockQ4KRle`] block into the 128-byte raw +/// packed-nibble array, handling both raw and RLE modes transparently. +/// +/// # Panics (debug builds only) +/// +/// Panics if the decoded RLE stream does not sum to exactly 128 bytes. +fn decode_qs(block: &BlockQ4KRle) -> [u8; QK_K / 2] { + if !block.is_rle() { + return block.qs; + } + + let n = block.rle_len(); + let mut raw = [0u8; QK_K / 2]; + let mut pos = 0usize; + + for i in 0..n { + let val = block.qs[2 * i]; + let count = block.qs[2 * i + 1] as usize; + raw[pos..pos + count].fill(val); + pos += count; + } + + debug_assert_eq!( + pos, + QK_K / 2, + "RLE run lengths sum to {pos}, expected {}", + QK_K / 2 + ); + raw +} + +// --------------------------------------------------------------------------- +// Dequantisation +// --------------------------------------------------------------------------- + +/// Dequantise one [`BlockQ4KRle`] super-block into [`QK_K`] (256) `f32` values. +/// +/// When `IS_RLE` is set the RLE stream is first expanded into a 128-byte raw +/// buffer; thereafter the dequantisation is identical to +/// [`crate::dequantize_block_q4k`]: +/// +/// ```text +/// out[i] = d * scale[s] * nibble[i] - dmin * min[s] +/// ``` +/// +/// where `s` is the sub-block index (0..8) that the element belongs to. +pub fn dequantize_block_q4k_rle(block: &BlockQ4KRle, out: &mut [f32; QK_K]) { + let d = fp16_to_f32(block.d); + let dmin = fp16_to_f32(block.dmin); + let qs = decode_qs(block); + + let mut q_off = 0usize; // byte cursor into the raw qs array + let mut out_off = 0usize; // element cursor into `out` + let mut is = 0usize; // sub-block pair index (0, 2, 4, 6) + + while out_off < QK_K { + let (sc1, mn1) = get_scale_min(is, &block.scales); + let (sc2, mn2) = get_scale_min(is + 1, &block.scales); + + let d1 = d * sc1 as f32; + let m1 = dmin * mn1 as f32; + let d2 = d * sc2 as f32; + let m2 = dmin * mn2 as f32; + + for l in 0..32 { + out[out_off + l] = d1 * (qs[q_off + l] & 0x0F) as f32 - m1; + } + for l in 0..32 { + out[out_off + 32 + l] = d2 * (qs[q_off + l] >> 4) as f32 - m2; + } + + q_off += 32; + out_off += 64; + is += 2; + } +} + +// --------------------------------------------------------------------------- +// Matrix multiplication C = A × B +// --------------------------------------------------------------------------- + +/// Multiply a Q4_K_RLE matrix **A** by an FP16 matrix **B**, producing an f32 +/// matrix **C**. +/// +/// Identical semantics to [`crate::matmul_q4k_fp16`] but accepts +/// [`BlockQ4KRle`] blocks. Each block is dequantised on the fly via +/// [`dequantize_block_q4k_rle`], transparently handling mixed raw/RLE blocks +/// within the same matrix. +/// +/// # Arguments +/// +/// * `a` – Row-major slice of [`BlockQ4KRle`]. Row `i` occupies blocks +/// `a[i * blocks_per_row .. (i+1) * blocks_per_row]`. +/// * `b` – Row-major fp16 matrix stored as raw `u16` bits, shape \[K, N\]. +/// Element `(ki, j)` is at index `ki * n + j`. +/// * `m` – Number of rows in A (and C). +/// * `k` – Number of columns in A = number of rows in B. +/// **Must** be a multiple of [`QK_K`] (256). +/// * `n` – Number of columns in B (and C). +/// +/// # Returns +/// +/// A flat row-major `Vec` of shape \[M, N\]. +/// +/// # Panics +/// +/// Panics if `k` is not a multiple of `QK_K`, or if the lengths of `a` or `b` +/// do not match the declared dimensions. +pub fn matmul_q4k_rle_fp16( + a: &[BlockQ4KRle], + b: &[u16], + m: usize, + k: usize, + n: usize, +) -> Vec { + assert_eq!( + k % QK_K, + 0, + "k ({k}) must be a multiple of QK_K ({QK_K})" + ); + let blocks_per_row = k / QK_K; + assert_eq!( + a.len(), + m * blocks_per_row, + "A block count mismatch: expected {} blocks, got {}", + m * blocks_per_row, + a.len() + ); + assert_eq!( + b.len(), + k * n, + "B element count mismatch: expected {}, got {}", + k * n, + b.len() + ); + + let mut c = vec![0.0f32; m * n]; + let mut a_row = vec![0.0f32; k]; + let mut block_buf = [0.0f32; QK_K]; + + for i in 0..m { + // Dequantise row i of A into a_row (f32). + for b_idx in 0..blocks_per_row { + let block = &a[i * blocks_per_row + b_idx]; + dequantize_block_q4k_rle(block, &mut block_buf); + let start = b_idx * QK_K; + a_row[start..start + QK_K].copy_from_slice(&block_buf); + } + + // Dot-product with each column of B. + for j in 0..n { + let mut sum = 0.0f32; + for ki in 0..k { + sum += a_row[ki] * fp16_to_f32(b[ki * n + j]); + } + c[i * n + j] = sum; + } + } + + c +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use crate::{dequantize_block_q4k, matmul_q4k_fp16, BlockQ4K}; + + // ------------------------------------------------------------------------- + // Test helpers + // ------------------------------------------------------------------------- + + /// Convert a normal finite f32 to its IEEE 754 fp16 bit pattern. + /// + /// Panics if the value falls outside the representable fp16 normal range. + fn f32_to_fp16_bits(f: f32) -> u16 { + if f == 0.0 { return 0x0000; } + if f == f32::INFINITY { return 0x7C00; } + if f == f32::NEG_INFINITY { return 0xFC00; } + if f.is_nan() { return 0x7E00; } + + let bits = f.to_bits(); + let sign = ((bits >> 31) as u16) << 15; + let exp = (bits >> 23) & 0xFF; + let mant = bits & 0x007F_FFFF; + let fp16_exp = exp as i32 - 127 + 15; + assert!( + fp16_exp > 0 && fp16_exp < 31, + "f32 value {f} is outside the representable fp16 normal range" + ); + sign | ((fp16_exp as u16) << 10) | ((mant >> 13) as u16) + } + + /// Build a [`BlockQ4K`] where all 8 sub-blocks share the same `scale` and + /// `min` (both < 16), and every byte in `qs` is `qs_byte`. + fn make_block(d: f32, dmin: f32, scale: u8, min: u8, qs_byte: u8) -> BlockQ4K { + assert!( + scale < 16 && min < 16, + "make_block: scale ({scale}) and min ({min}) must both be < 16" + ); + let mut scales = [0u8; K_SCALE_SIZE]; + for j in 0..4 { + scales[j] = scale; + scales[j + 4] = min; + } + for j in 8..12 { + scales[j] = (scale & 0x0F) | ((min & 0x0F) << 4); + } + BlockQ4K { + d: f32_to_fp16_bits(d), + dmin: f32_to_fp16_bits(dmin), + scales, + qs: [qs_byte; QK_K / 2], + } + } + + /// Build a [`BlockQ4K`] with a custom `qs` array. + fn make_block_with_qs( + d: f32, + dmin: f32, + scale: u8, + min: u8, + qs: [u8; QK_K / 2], + ) -> BlockQ4K { + assert!( + scale < 16 && min < 16, + "make_block_with_qs: scale ({scale}) and min ({min}) must both be < 16" + ); + let mut scales = [0u8; K_SCALE_SIZE]; + for j in 0..4 { + scales[j] = scale; + scales[j + 4] = min; + } + for j in 8..12 { + scales[j] = (scale & 0x0F) | ((min & 0x0F) << 4); + } + BlockQ4K { + d: f32_to_fp16_bits(d), + dmin: f32_to_fp16_bits(dmin), + scales, + qs, + } + } + + fn assert_close(got: f32, expected: f32, tol: f32) { + assert!( + (got - expected).abs() <= tol, + "got {got}, expected {expected} (tol {tol})" + ); + } + + fn assert_all_close(got: &[f32], expected_scalar: f32, tol: f32) { + for (i, &g) in got.iter().enumerate() { + assert!( + (g - expected_scalar).abs() <= tol, + "element {i}: got {g}, expected {expected_scalar} (tol {tol})" + ); + } + } + + fn assert_slices_close(got: &[f32], expected: &[f32], tol: f32) { + assert_eq!(got.len(), expected.len(), "slice length mismatch"); + for (i, (&g, &e)) in got.iter().zip(expected.iter()).enumerate() { + assert!( + (g - e).abs() <= tol, + "element {i}: got {g}, expected {e} (tol {tol})" + ); + } + } + + fn fp16_uniform(k: usize, n: usize, value: f32) -> Vec { + vec![f32_to_fp16_bits(value); k * n] + } + + // ========================================================================= + // Struct layout + // ========================================================================= + + #[test] + fn block_q4k_rle_size_is_146_bytes() { + // d(2) + dmin(2) + scales(12) + flags(1) + qs(128) = 145 raw bytes, + // rounded up to 146 by the 2-byte alignment imposed by the u16 fields. + assert_eq!(core::mem::size_of::(), 146); + } + + #[test] + fn block_q4k_rle_is_two_bytes_larger_than_block_q4k() { + assert_eq!( + core::mem::size_of::(), + core::mem::size_of::() + 2, + ); + } + + // ========================================================================= + // is_rle / rle_len + // ========================================================================= + + #[test] + fn is_rle_false_when_flag_clear() { + let b = BlockQ4KRle { + d: 0, dmin: 0, scales: [0; K_SCALE_SIZE], flags: 0, qs: [0; QK_K / 2], + }; + assert!(!b.is_rle()); + } + + #[test] + fn rle_len_zero_when_flag_clear() { + let b = BlockQ4KRle { + d: 0, dmin: 0, scales: [0; K_SCALE_SIZE], flags: 0, qs: [0; QK_K / 2], + }; + assert_eq!(b.rle_len(), 0); + } + + #[test] + fn is_rle_true_when_flag_set() { + let b = BlockQ4KRle { + d: 0, dmin: 0, scales: [0; K_SCALE_SIZE], + flags: IS_RLE | (5u8 << 1), + qs: [0; QK_K / 2], + }; + assert!(b.is_rle()); + } + + #[test] + fn rle_len_reports_pair_count_from_flags() { + for n in [0usize, 1, 7, 31, 63] { + let b = BlockQ4KRle { + d: 0, dmin: 0, scales: [0; K_SCALE_SIZE], + flags: IS_RLE | ((n as u8) << 1), + qs: [0; QK_K / 2], + }; + assert_eq!(b.rle_len(), n, "expected rle_len {n}"); + } + } + + // ========================================================================= + // encode: mode selection + // ========================================================================= + + #[test] + fn encode_uniform_qs_uses_rle() { + // 128 identical bytes → 1 pair → 2 bytes < 128 raw. + let src = make_block(1.0, 0.0, 1, 0, 0x77); + let rle = encode(&src); + assert!(rle.is_rle(), "uniform qs should trigger RLE mode"); + } + + #[test] + fn encode_uniform_qs_rle_len_is_one() { + let src = make_block(1.0, 0.0, 1, 0, 0x55); + let rle = encode(&src); + assert_eq!(rle.rle_len(), 1); + } + + #[test] + fn encode_uniform_qs_rle_entry_is_correct() { + let src = make_block(1.0, 0.0, 1, 0, 0xAB); + let rle = encode(&src); + assert_eq!(rle.qs[0], 0xAB, "RLE value byte should equal the repeated byte"); + assert_eq!(rle.qs[1], 128, "RLE run length should be 128 bytes"); + } + + #[test] + fn encode_alternating_bytes_stays_raw() { + // 128 single-byte runs → 128 pairs → 256 bytes ≥ 128 raw → raw mode. + let mut qs = [0u8; QK_K / 2]; + for (i, b) in qs.iter_mut().enumerate() { + *b = if i % 2 == 0 { 0xAA } else { 0x55 }; + } + let src = make_block_with_qs(1.0, 0.0, 1, 0, qs); + let rle = encode(&src); + assert!(!rle.is_rle(), "alternating bytes cannot be compressed → raw mode"); + } + + #[test] + fn encode_raw_mode_copies_qs_verbatim() { + let mut qs = [0u8; QK_K / 2]; + for (i, b) in qs.iter_mut().enumerate() { + // Three-byte cycle of distinct values → 128 runs of 1 byte each. + *b = match i % 3 { 0 => 0x11, 1 => 0x22, _ => 0x33 }; + } + let src = make_block_with_qs(1.0, 0.0, 1, 0, qs); + let rle = encode(&src); + assert!(!rle.is_rle()); + assert_eq!(rle.qs, qs, "raw mode must preserve qs bytes unchanged"); + } + + #[test] + fn encode_two_runs_uses_rle_and_stores_correct_pairs() { + // Two distinct runs: 64 bytes of 0x11 followed by 64 bytes of 0x22. + // → 2 pairs = 4 bytes < 128 bytes raw. + let mut qs = [0u8; QK_K / 2]; + qs[..64].fill(0x11); + qs[64..].fill(0x22); + let src = make_block_with_qs(1.0, 0.0, 1, 0, qs); + let rle = encode(&src); + assert!(rle.is_rle()); + assert_eq!(rle.rle_len(), 2); + assert_eq!(rle.qs[0], 0x11, "first pair: value"); + assert_eq!(rle.qs[1], 64, "first pair: run length"); + assert_eq!(rle.qs[2], 0x22, "second pair: value"); + assert_eq!(rle.qs[3], 64, "second pair: run length"); + } + + #[test] + fn encode_63_pairs_uses_rle() { + // Build 62 runs of 2 bytes each (124 bytes) + 1 run of 4 bytes = 128 bytes. + // 63 pairs × 2 = 126 bytes < 128 → RLE should be chosen. + let mut qs = [0u8; QK_K / 2]; + let mut pos = 0usize; + for run in 0..62usize { + // Use a stride-3 sequence so consecutive values are always distinct. + let v = (run as u8).wrapping_mul(3).wrapping_add(1); + qs[pos] = v; + qs[pos + 1] = v; + pos += 2; + } + // Final run: 4 bytes, value chosen to differ from the previous one. + qs[pos..].fill(0xFE); + + let src = make_block_with_qs(1.0, 0.0, 1, 0, qs); + let rle = encode(&src); + assert!(rle.is_rle(), "63 pairs should use RLE"); + assert_eq!(rle.rle_len(), 63); + } + + #[test] + fn encode_64_pairs_stays_raw() { + // 64 runs of 2 bytes each = 128 bytes total. + // 64 pairs × 2 = 128 bytes, which is NOT strictly less than 128 → raw. + let mut qs = [0u8; QK_K / 2]; + let mut pos = 0usize; + for run in 0..64usize { + let v = (run as u8).wrapping_mul(3).wrapping_add(1); + qs[pos] = v; + qs[pos + 1] = v; + pos += 2; + } + let src = make_block_with_qs(1.0, 0.0, 1, 0, qs); + let rle = encode(&src); + assert!(!rle.is_rle(), "64 pairs offers no saving → raw mode"); + } + + #[test] + fn encode_preserves_d_dmin_scales() { + let src = make_block(2.0, 0.5, 3, 2, 0x00); + let rle = encode(&src); + assert_eq!(rle.d, src.d); + assert_eq!(rle.dmin, src.dmin); + assert_eq!(rle.scales, src.scales); + } + + // ========================================================================= + // decode_qs (tested indirectly through dequantise, but also directly) + // ========================================================================= + + #[test] + fn decode_qs_raw_mode_returns_qs_unchanged() { + // Build a raw BlockQ4KRle (flags = 0) with a non-trivial qs pattern. + let mut qs = [0u8; QK_K / 2]; + for (i, b) in qs.iter_mut().enumerate() { *b = i as u8; } + let rle = BlockQ4KRle { + d: 0, dmin: 0, scales: [0; K_SCALE_SIZE], flags: 0, qs, + }; + assert_eq!(decode_qs(&rle), qs); + } + + #[test] + fn decode_qs_rle_expands_two_pair_stream() { + // Hand-craft an RLE block: [0xAA × 64, 0xBB × 64]. + let mut qs = [0u8; QK_K / 2]; + qs[0] = 0xAA; qs[1] = 64; + qs[2] = 0xBB; qs[3] = 64; + let rle = BlockQ4KRle { + d: 0, dmin: 0, scales: [0; K_SCALE_SIZE], + flags: IS_RLE | (2u8 << 1), + qs, + }; + let expanded = decode_qs(&rle); + assert!(expanded[..64].iter().all(|&b| b == 0xAA), "first 64 bytes must be 0xAA"); + assert!(expanded[64..].iter().all(|&b| b == 0xBB), "last 64 bytes must be 0xBB"); + } + + #[test] + fn decode_qs_rle_single_run_covers_all() { + let mut qs = [0u8; QK_K / 2]; + qs[0] = 0xCD; qs[1] = 128; // one run of 128 bytes + let rle = BlockQ4KRle { + d: 0, dmin: 0, scales: [0; K_SCALE_SIZE], + flags: IS_RLE | (1u8 << 1), + qs, + }; + let expanded = decode_qs(&rle); + assert!(expanded.iter().all(|&b| b == 0xCD)); + } + + // ========================================================================= + // dequantize_block_q4k_rle + // ========================================================================= + + #[test] + fn dequant_rle_zero_d_all_outputs_zero() { + let src = make_block(0.0, 0.0, 1, 0, 0x77); + let rle = encode(&src); + let mut out = [0.0f32; QK_K]; + dequantize_block_q4k_rle(&rle, &mut out); + assert_all_close(&out, 0.0, 0.0); + } + + #[test] + fn dequant_rle_uniform_nibble_one_scale_one() { + // qs_byte = 0x11 → both nibbles = 1; scale = 1, d = 1.0, min = 0. + // expected: 1.0 * 1 * 1 - 0.0 = 1.0 + let src = make_block(1.0, 0.0, 1, 0, 0x11); + let rle = encode(&src); + let mut out = [0.0f32; QK_K]; + dequantize_block_q4k_rle(&rle, &mut out); + assert_all_close(&out, 1.0, 1e-5); + } + + #[test] + fn dequant_rle_non_zero_min_subtracts() { + // nibble = 0, scale = 1, d = 1.0, min = 2, dmin = 1.0 + // expected: 1.0 * 1 * 0 - 1.0 * 2 = -2.0 + let src = make_block(1.0, 1.0, 1, 2, 0x00); + let rle = encode(&src); + let mut out = [0.0f32; QK_K]; + dequantize_block_q4k_rle(&rle, &mut out); + assert_all_close(&out, -2.0, 1e-5); + } + + #[test] + fn dequant_rle_max_nibble_15() { + // qs_byte = 0xFF → both nibbles = 15; scale = 1, d = 1.0, min = 0. + // expected: 1.0 * 1 * 15 - 0.0 = 15.0 + let src = make_block(1.0, 0.0, 1, 0, 0xFF); + let rle = encode(&src); + let mut out = [0.0f32; QK_K]; + dequantize_block_q4k_rle(&rle, &mut out); + assert_all_close(&out, 15.0, 1e-5); + } + + #[test] + fn dequant_rle_output_count_is_qk_k() { + let src = make_block(1.0, 0.0, 1, 0, 0x00); + let rle = encode(&src); + let mut out = [0.0f32; QK_K]; + dequantize_block_q4k_rle(&rle, &mut out); + assert_eq!(out.len(), QK_K); + } + + #[test] + fn dequant_rle_larger_scale_multiplies() { + // nibble = 3, scale = 4, d = 2.0, min = 0 + // expected: 2.0 * 4 * 3 - 0.0 = 24.0 + // qs_byte = 0x33 → both nibbles = 3 + let src = make_block(2.0, 0.0, 4, 0, 0x33); + let rle = encode(&src); + let mut out = [0.0f32; QK_K]; + dequantize_block_q4k_rle(&rle, &mut out); + assert_all_close(&out, 24.0, 1e-4); + } + + // ========================================================================= + // Roundtrip: encode → dequantize must match original dequantize + // ========================================================================= + + #[test] + fn roundtrip_rle_mode_matches_original() { + // Uniform qs → RLE mode selected. + let src = make_block(2.0, 0.5, 3, 1, 0x37); + let rle = encode(&src); + assert!(rle.is_rle()); + + let mut got = [0.0f32; QK_K]; + let mut expected = [0.0f32; QK_K]; + dequantize_block_q4k_rle(&rle, &mut got); + dequantize_block_q4k(&src, &mut expected); + assert_slices_close(&got, &expected, 1e-5); + } + + #[test] + fn roundtrip_raw_mode_matches_original() { + // Alternating bytes → raw mode selected; output must still be correct. + let mut qs = [0u8; QK_K / 2]; + for (i, b) in qs.iter_mut().enumerate() { + *b = if i % 2 == 0 { 0x13 } else { 0x24 }; + } + let src = make_block_with_qs(1.5, 0.25, 2, 1, qs); + let rle = encode(&src); + assert!(!rle.is_rle()); + + let mut got = [0.0f32; QK_K]; + let mut expected = [0.0f32; QK_K]; + dequantize_block_q4k_rle(&rle, &mut got); + dequantize_block_q4k(&src, &mut expected); + assert_slices_close(&got, &expected, 1e-5); + } + + #[test] + fn roundtrip_two_run_block_matches_original() { + let mut qs = [0u8; QK_K / 2]; + qs[..64].fill(0x59); + qs[64..].fill(0x8C); + let src = make_block_with_qs(3.0, 1.0, 5, 2, qs); + let rle = encode(&src); + assert!(rle.is_rle()); + + let mut got = [0.0f32; QK_K]; + let mut expected = [0.0f32; QK_K]; + dequantize_block_q4k_rle(&rle, &mut got); + dequantize_block_q4k(&src, &mut expected); + assert_slices_close(&got, &expected, 1e-5); + } + + #[test] + fn roundtrip_many_short_runs_matches_original() { + // Four distinct runs of varying lengths → still compresses. + let mut qs = [0u8; QK_K / 2]; + qs[..10].fill(0x11); + qs[10..30].fill(0x22); + qs[30..31].fill(0x33); + qs[31..].fill(0x44); + let src = make_block_with_qs(1.0, 0.5, 7, 3, qs); + let rle = encode(&src); + assert!(rle.is_rle(), "4-run block should compress"); + assert_eq!(rle.rle_len(), 4); + + let mut got = [0.0f32; QK_K]; + let mut expected = [0.0f32; QK_K]; + dequantize_block_q4k_rle(&rle, &mut got); + dequantize_block_q4k(&src, &mut expected); + assert_slices_close(&got, &expected, 1e-5); + } + + #[test] + fn roundtrip_zero_qs_matches_original() { + let src = make_block(1.0, 0.5, 2, 1, 0x00); + let rle = encode(&src); + let mut got = [0.0f32; QK_K]; + let mut expected = [0.0f32; QK_K]; + dequantize_block_q4k_rle(&rle, &mut got); + dequantize_block_q4k(&src, &mut expected); + assert_slices_close(&got, &expected, 1e-5); + } + + #[test] + fn roundtrip_nibble_split_low_high_correct() { + // qs_byte = 0x37: low nibble = 7 (sub-block 0 path), high nibble = 3 + // (sub-block 1 path). Verify both halves are dequantised correctly. + let src = make_block(1.0, 0.0, 1, 0, 0x37); + let rle = encode(&src); + let mut got = [0.0f32; QK_K]; + let mut expected = [0.0f32; QK_K]; + dequantize_block_q4k_rle(&rle, &mut got); + dequantize_block_q4k(&src, &mut expected); + assert_slices_close(&got, &expected, 1e-5); + // First 32 elements of each 64-element group → low nibble = 7. + assert_close(got[0], 7.0, 1e-5); + // Next 32 elements → high nibble = 3. + assert_close(got[32], 3.0, 1e-5); + } + + // ========================================================================= + // matmul_q4k_rle_fp16 + // ========================================================================= + + #[test] + fn matmul_rle_1x256_times_256x1_all_ones() { + // A: 1×256, all weights = nibble 1, scale = 1, d = 1.0 + // B: 256×1, all fp16 1.0 + // C = dot([1.0; 256], [1.0; 256]) = 256.0 + let src = make_block(1.0, 0.0, 1, 0, 0x11); + let a = vec![encode(&src)]; + let b = fp16_uniform(QK_K, 1, 1.0); + let c = matmul_q4k_rle_fp16(&a, &b, 1, QK_K, 1); + assert_eq!(c.len(), 1); + assert_close(c[0], 256.0, 1e-3); + } + + #[test] + fn matmul_rle_2x256_times_256x3_all_ones() { + let src = make_block(1.0, 0.0, 1, 0, 0x11); + let a = vec![encode(&src), encode(&src)]; + let b = fp16_uniform(QK_K, 3, 1.0); + let c = matmul_q4k_rle_fp16(&a, &b, 2, QK_K, 3); + assert_eq!(c.len(), 6); + assert_all_close(&c, 256.0, 1e-3); + } + + #[test] + fn matmul_rle_zero_a_gives_zero_c() { + let src = make_block(0.0, 0.0, 1, 0, 0xFF); + let a = vec![encode(&src)]; + let b = fp16_uniform(QK_K, 4, 1.0); + let c = matmul_q4k_rle_fp16(&a, &b, 1, QK_K, 4); + assert_all_close(&c, 0.0, 0.0); + } + + #[test] + fn matmul_rle_zero_b_gives_zero_c() { + let src = make_block(1.0, 0.0, 1, 0, 0x11); + let a = vec![encode(&src)]; + let b = fp16_uniform(QK_K, 2, 0.0); + let c = matmul_q4k_rle_fp16(&a, &b, 1, QK_K, 2); + assert_all_close(&c, 0.0, 0.0); + } + + #[test] + fn matmul_rle_two_blocks_per_row() { + // A: 1×512, two blocks, all nibble-1 weights; B: 512×1, all 1.0. + // Expected: 512.0 + let src = make_block(1.0, 0.0, 1, 0, 0x11); + let a = vec![encode(&src), encode(&src)]; + let b = fp16_uniform(2 * QK_K, 1, 1.0); + let c = matmul_q4k_rle_fp16(&a, &b, 1, 2 * QK_K, 1); + assert_eq!(c.len(), 1); + assert_close(c[0], 512.0, 1e-3); + } + + #[test] + fn matmul_rle_output_shape_m_times_n() { + // A: 3×512 (6 blocks), B: 512×4 → C: 3×4 = 12 elements. + let src = make_block(1.0, 0.0, 1, 0, 0x00); + let a: Vec = (0..6).map(|_| encode(&src)).collect(); + let b = fp16_uniform(2 * QK_K, 4, 0.0); + let c = matmul_q4k_rle_fp16(&a, &b, 3, 2 * QK_K, 4); + assert_eq!(c.len(), 12); + } + + #[test] + fn matmul_rle_scalar_b_scales_output() { + // Multiplying B by a scalar should scale C by the same factor. + let src = make_block(1.0, 0.0, 1, 0, 0x22); // nibble 2 → weight 2.0 + let a = vec![encode(&src)]; + let b1 = fp16_uniform(QK_K, 1, 1.0); + let b2 = fp16_uniform(QK_K, 1, 3.0); + let c1 = matmul_q4k_rle_fp16(&a, &b1, 1, QK_K, 1); + let c2 = matmul_q4k_rle_fp16(&a, &b2, 1, QK_K, 1); + assert_close(c2[0], c1[0] * 3.0, 1e-2); + } + + #[test] + fn matmul_rle_matches_original_matmul_mixed_blocks() { + // Mix: first block uniform (RLE), second block alternating (raw). + // Both matmul implementations should produce identical results. + let src_rle = make_block(2.0, 0.5, 3, 1, 0x37); + + let mut qs_raw = [0u8; QK_K / 2]; + for (i, b) in qs_raw.iter_mut().enumerate() { + *b = if i % 2 == 0 { 0x13 } else { 0x24 }; + } + let src_raw = make_block_with_qs(1.5, 0.25, 2, 1, qs_raw); + + let a_orig: Vec = vec![src_rle, src_raw]; + let a_rle: Vec = a_orig.iter().map(encode).collect(); + + // A: 1×512, B: 512×2 + let b = fp16_uniform(2 * QK_K, 2, 1.0); + let c_orig = matmul_q4k_fp16(&a_orig, &b, 1, 2 * QK_K, 2); + let c_rle = matmul_q4k_rle_fp16(&a_rle, &b, 1, 2 * QK_K, 2); + assert_slices_close(&c_rle, &c_orig, 1e-4); + } + + #[test] + fn matmul_rle_multiple_rows_multiple_blocks_per_row() { + // A: 2×512 (4 blocks), B: 512×3, all weights 1 in A, all 1.0 in B. + // Each row dot product = 512.0; C should be all 512.0. + let src = make_block(1.0, 0.0, 1, 0, 0x11); + let a: Vec = (0..4).map(|_| encode(&src)).collect(); + let b = fp16_uniform(2 * QK_K, 3, 1.0); + let c = matmul_q4k_rle_fp16(&a, &b, 2, 2 * QK_K, 3); + assert_eq!(c.len(), 6); + assert_all_close(&c, 512.0, 1e-3); + } + + // ========================================================================= + // Panic / contract checks + // ========================================================================= + + #[test] + fn matmul_rle_panics_when_k_not_multiple_of_qkk() { + let src = make_block(1.0, 0.0, 1, 0, 0x00); + let a = vec![encode(&src)]; + let b = vec![0u16; 512]; + let result = std::panic::catch_unwind(move || { + matmul_q4k_rle_fp16(&a, &b, 1, 512, 2); + }); + assert!(result.is_err(), "should panic when k is not a multiple of QK_K"); + } + + #[test] + fn matmul_rle_panics_on_wrong_a_length() { + let src = make_block(1.0, 0.0, 1, 0, 0x00); + // m=2, k=QK_K requires 2 blocks; only 1 is provided. + let a = vec![encode(&src)]; + let b = fp16_uniform(QK_K, 1, 1.0); + let result = std::panic::catch_unwind(move || { + matmul_q4k_rle_fp16(&a, &b, 2, QK_K, 1); + }); + assert!(result.is_err(), "should panic on wrong A block count"); + } + + #[test] + fn matmul_rle_panics_on_wrong_b_length() { + let src = make_block(1.0, 0.0, 1, 0, 0x00); + let a = vec![encode(&src)]; + // B is too short for k=QK_K, n=3. + let b = vec![0u16; 10]; + let result = std::panic::catch_unwind(move || { + matmul_q4k_rle_fp16(&a, &b, 1, QK_K, 3); + }); + assert!(result.is_err(), "should panic on wrong B element count"); + } +}