This commit is contained in:
2026-04-12 15:30:04 -07:00
parent 16d1f37ae5
commit 5d310b8df5
2 changed files with 994 additions and 0 deletions

View File

@@ -18,6 +18,8 @@
//! dequantise each row of A into f32, convert each element of B from fp16 to
//! f32, accumulate dot-products. No SIMD, no tiling, no tricks.
pub mod rle;
// ---------------------------------------------------------------------------
// Constants matching GGML's ggml-common.h
// ---------------------------------------------------------------------------

992
src/rle.rs Normal file
View File

@@ -0,0 +1,992 @@
//! RLE-optional Q4_K super-block encoding.
//!
//! This module provides [`BlockQ4KRle`], a variant of [`crate::BlockQ4K`] that
//! optionally compresses the 128-byte weight payload using **byte-level
//! run-length encoding** (RLE). A flag bit in the [`BlockQ4KRle::flags`]
//! field indicates which mode is active:
//!
//! | `IS_RLE` bit | `qs` interpretation |
//! |--------------|------------------------------------------------------------|
//! | 0 | Raw packed nibbles, identical to [`crate::BlockQ4K::qs`] |
//! | 1 | RLE stream of `(value, count)` byte-pairs |
//!
//! ## RLE format (when `IS_RLE` = 1)
//!
//! - `flags >> 1` gives the number of `(value, count)` pairs stored in `qs`.
//! - For each pair `i`:
//! - `qs[2*i]` — the byte value (two packed 4-bit weights, same packing
//! as the raw format).
//! - `qs[2*i + 1]` — the run length in bytes (1..=255).
//! - The run lengths must sum to exactly 128 (the uncompressed `qs` size).
//!
//! RLE encoding is chosen only when the compressed representation is
//! **strictly shorter** than the 128-byte raw payload, i.e. when
//! `pairs * 2 < 128`. That caps the useful range at ≤ 63 pairs. The 7-bit
//! `flags >> 1` sub-field can hold up to 127, so this ceiling is never a
//! concern in practice.
//!
//! ## Constructing blocks
//!
//! Use [`encode`] to convert an existing [`crate::BlockQ4K`] into a
//! [`BlockQ4KRle`]. The function automatically selects the better mode.
//!
//! ## Adding this module to your crate
//!
//! Add `pub mod rle;` to `lib.rs`.
use crate::{fp16_to_f32, get_scale_min, BlockQ4K, K_SCALE_SIZE, QK_K};
// ---------------------------------------------------------------------------
// Flag constants
// ---------------------------------------------------------------------------
/// Flag bit in [`BlockQ4KRle::flags`]: if set, `qs` contains an RLE stream.
pub const IS_RLE: u8 = 0x01;
// ---------------------------------------------------------------------------
// Block definition
// ---------------------------------------------------------------------------
/// A Q4_K super-block with optional byte-level RLE compression on the weights.
///
/// Identical to [`crate::BlockQ4K`] except for the additional [`flags`](Self::flags)
/// byte inserted between `scales` and `qs`.
///
/// Memory layout (repr C):
///
/// | Offset | Field      | Size  | Notes                          |
/// |--------|------------|-------|--------------------------------|
/// | 0      | `d`        | 2 B   | fp16 super-block scale         |
/// | 2      | `dmin`     | 2 B   | fp16 super-block min scale     |
/// | 4      | `scales`   | 12 B  | packed 6-bit sub-block params  |
/// | 16     | `flags`    | 1 B   | encoding flags (see below)     |
/// | 17     | `qs`       | 128 B | raw nibbles or RLE stream      |
/// | 145    | (padding)  | 1 B   | implicit trailing alignment pad|
///
/// **sizeof = 146 bytes** (padded to 2-byte alignment imposed by `u16` fields).
///
/// ## `flags` bit layout
///
/// | Bits | Meaning                                                       |
/// |------|---------------------------------------------------------------|
/// | 0    | [`IS_RLE`] — 1 = `qs` is RLE-encoded, 0 = raw packed nibbles  |
/// | 1-7  | When `IS_RLE`=1: number of `(value, count)` pairs in `qs`     |
#[repr(C)]
#[derive(Clone, Copy, Debug)]
pub struct BlockQ4KRle {
    /// Super-block scale for quantised sub-block scales (fp16 bits).
    pub d: u16,
    /// Super-block scale for quantised sub-block mins (fp16 bits).
    pub dmin: u16,
    /// Packed 6-bit sub-block scales and mins (same layout as [`crate::BlockQ4K`]).
    pub scales: [u8; K_SCALE_SIZE],
    /// Encoding flags. Bit 0 = [`IS_RLE`]. Bits 1-7 = RLE pair count when
    /// `IS_RLE` is set.
    pub flags: u8,
    /// Raw packed-nibble weights (`IS_RLE` = 0) or RLE byte stream (`IS_RLE` = 1).
    pub qs: [u8; QK_K / 2],
}
impl BlockQ4KRle {
    /// Whether the weight payload in `qs` is stored as an RLE stream
    /// (i.e. the [`IS_RLE`] bit of `flags` is set).
    #[inline]
    pub fn is_rle(&self) -> bool {
        (self.flags & IS_RLE) == IS_RLE
    }
    /// Number of `(value, count)` byte-pairs stored at the start of `qs`,
    /// read from the upper seven bits of `flags`.
    ///
    /// Only meaningful when [`is_rle`](Self::is_rle) returns `true`.
    #[inline]
    pub fn rle_len(&self) -> usize {
        usize::from(self.flags >> 1)
    }
}
// ---------------------------------------------------------------------------
// Encoding
// ---------------------------------------------------------------------------
/// Encode a [`BlockQ4K`] block into a [`BlockQ4KRle`] block.
///
/// The 128-byte `qs` payload is scanned for maximal runs of identical bytes.
/// If the resulting list of `(value, count)` pairs fits in the 128-byte `qs`
/// field **and is strictly shorter** than the raw payload, it is stored with
/// [`IS_RLE`] set. Otherwise the raw bytes are copied unchanged and the flag
/// stays clear.
///
/// The `d`, `dmin`, and `scales` fields are always copied verbatim.
pub fn encode(block: &BlockQ4K) -> BlockQ4KRle {
    let raw = &block.qs;
    // Greedily collect maximal runs of equal bytes. Run lengths saturate at
    // 255 so a count always fits in one byte (moot for a 128-byte payload,
    // but keeps the scan robust).
    let mut pairs: Vec<(u8, u8)> = Vec::with_capacity(64);
    let mut start = 0usize;
    while start < raw.len() {
        let byte = raw[start];
        let mut len = 1usize;
        while start + len < raw.len()
            && raw[start + len] == byte
            && len < usize::from(u8::MAX)
        {
            len += 1;
        }
        pairs.push((byte, len as u8));
        start += len;
    }
    // RLE is chosen only when strictly smaller than the raw payload. Since
    // each pair costs 2 bytes, pairs.len() * 2 < 128 also bounds the pair
    // count at 63, which fits in the seven high bits of `flags`.
    if pairs.len() * 2 >= raw.len() {
        // No space savings — copy raw bytes and leave IS_RLE clear.
        return BlockQ4KRle {
            d: block.d,
            dmin: block.dmin,
            scales: block.scales,
            flags: 0,
            qs: block.qs,
        };
    }
    let n = pairs.len();
    debug_assert!(n <= 63, "RLE pair count {n} unexpectedly exceeds 63");
    // Serialise the pair list into the fixed-size qs field; unused trailing
    // bytes stay zero.
    let mut qs = [0u8; QK_K / 2];
    for (slot, &(byte, len)) in pairs.iter().enumerate() {
        qs[2 * slot] = byte;
        qs[2 * slot + 1] = len;
    }
    BlockQ4KRle {
        d: block.d,
        dmin: block.dmin,
        scales: block.scales,
        flags: IS_RLE | ((n as u8) << 1),
        qs,
    }
}
// ---------------------------------------------------------------------------
// Decoding helpers
// ---------------------------------------------------------------------------
/// Expand the `qs` field of a [`BlockQ4KRle`] block into the 128-byte raw
/// packed-nibble array, handling both raw and RLE modes transparently.
///
/// # Panics (debug builds only)
///
/// Panics if the decoded RLE stream does not sum to exactly 128 bytes.
fn decode_qs(block: &BlockQ4KRle) -> [u8; QK_K / 2] {
    if block.is_rle() {
        let mut raw = [0u8; QK_K / 2];
        let mut pos = 0usize;
        // Each (value, count) pair occupies two consecutive bytes of qs.
        for pair in block.qs[..2 * block.rle_len()].chunks_exact(2) {
            let byte = pair[0];
            let count = pair[1] as usize;
            raw[pos..pos + count].fill(byte);
            pos += count;
        }
        debug_assert_eq!(
            pos,
            QK_K / 2,
            "RLE run lengths sum to {pos}, expected {}",
            QK_K / 2
        );
        raw
    } else {
        // Raw mode: the payload is already the packed-nibble array.
        block.qs
    }
}
// ---------------------------------------------------------------------------
// Dequantisation
// ---------------------------------------------------------------------------
/// Dequantise one [`BlockQ4KRle`] super-block into [`QK_K`] (256) `f32` values.
///
/// When `IS_RLE` is set the RLE stream is first expanded into a 128-byte raw
/// buffer; thereafter the dequantisation is identical to
/// [`crate::dequantize_block_q4k`]:
///
/// ```text
/// out[i] = d * scale[s] * nibble[i] - dmin * min[s]
/// ```
///
/// where `s` is the sub-block index (0..8) that the element belongs to.
pub fn dequantize_block_q4k_rle(block: &BlockQ4KRle, out: &mut [f32; QK_K]) {
    let d = fp16_to_f32(block.d);
    let dmin = fp16_to_f32(block.dmin);
    let qs = decode_qs(block);
    // Each group consumes 32 payload bytes and produces 64 outputs: the low
    // nibbles feed sub-block 2*g, the high nibbles sub-block 2*g + 1.
    for g in 0..QK_K / 64 {
        let (sc_lo, mn_lo) = get_scale_min(2 * g, &block.scales);
        let (sc_hi, mn_hi) = get_scale_min(2 * g + 1, &block.scales);
        let d_lo = d * sc_lo as f32;
        let m_lo = dmin * mn_lo as f32;
        let d_hi = d * sc_hi as f32;
        let m_hi = dmin * mn_hi as f32;
        let bytes = &qs[32 * g..32 * (g + 1)];
        let base = 64 * g;
        for (l, &byte) in bytes.iter().enumerate() {
            out[base + l] = d_lo * (byte & 0x0F) as f32 - m_lo;
            out[base + 32 + l] = d_hi * (byte >> 4) as f32 - m_hi;
        }
    }
}
// ---------------------------------------------------------------------------
// Matrix multiplication C = A × B
// ---------------------------------------------------------------------------
/// Multiply a Q4_K_RLE matrix **A** by an FP16 matrix **B**, producing an f32
/// matrix **C**.
///
/// Identical semantics to [`crate::matmul_q4k_fp16`] but accepts
/// [`BlockQ4KRle`] blocks. Each block is dequantised on the fly via
/// [`dequantize_block_q4k_rle`], transparently handling mixed raw/RLE blocks
/// within the same matrix.
///
/// # Arguments
///
/// * `a` Row-major slice of [`BlockQ4KRle`]. Row `i` occupies blocks
///   `a[i * blocks_per_row .. (i+1) * blocks_per_row]`.
/// * `b` Row-major fp16 matrix stored as raw `u16` bits, shape \[K, N\].
///   Element `(ki, j)` is at index `ki * n + j`.
/// * `m` Number of rows in A (and C).
/// * `k` Number of columns in A = number of rows in B.
///   **Must** be a multiple of [`QK_K`] (256).
/// * `n` Number of columns in B (and C).
///
/// # Returns
///
/// A flat row-major `Vec<f32>` of shape \[M, N\].
///
/// # Panics
///
/// Panics if `k` is not a multiple of `QK_K`, or if the lengths of `a` or `b`
/// do not match the declared dimensions.
pub fn matmul_q4k_rle_fp16(
    a: &[BlockQ4KRle],
    b: &[u16],
    m: usize,
    k: usize,
    n: usize,
) -> Vec<f32> {
    assert_eq!(
        k % QK_K,
        0,
        "k ({k}) must be a multiple of QK_K ({QK_K})"
    );
    let blocks_per_row = k / QK_K;
    assert_eq!(
        a.len(),
        m * blocks_per_row,
        "A block count mismatch: expected {} blocks, got {}",
        m * blocks_per_row,
        a.len()
    );
    assert_eq!(
        b.len(),
        k * n,
        "B element count mismatch: expected {}, got {}",
        k * n,
        b.len()
    );
    let mut c = vec![0.0f32; m * n];
    // Scratch buffers reused across rows to avoid per-row allocation.
    let mut a_row = vec![0.0f32; k];
    let mut scratch = [0.0f32; QK_K];
    for row in 0..m {
        // Dequantise row `row` of A into a_row (f32), one block at a time.
        let row_blocks = &a[row * blocks_per_row..(row + 1) * blocks_per_row];
        for (dst, blk) in a_row.chunks_exact_mut(QK_K).zip(row_blocks.iter()) {
            dequantize_block_q4k_rle(blk, &mut scratch);
            dst.copy_from_slice(&scratch);
        }
        // Dot-product of the dequantised row with each column of B.
        for col in 0..n {
            let mut acc = 0.0f32;
            for (ki, &av) in a_row.iter().enumerate() {
                acc += av * fp16_to_f32(b[ki * n + col]);
            }
            c[row * n + col] = acc;
        }
    }
    c
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{dequantize_block_q4k, matmul_q4k_fp16, BlockQ4K};
    // -------------------------------------------------------------------------
    // Test helpers
    // -------------------------------------------------------------------------
    /// Convert a normal finite f32 to its IEEE 754 fp16 bit pattern.
    ///
    /// Panics if the value falls outside the representable fp16 normal range.
    ///
    /// Notes: the mantissa is truncated (`mant >> 13`), not round-to-nearest,
    /// which is fine for the exactly-representable values used in these tests.
    /// Also, `-0.0 == 0.0` in IEEE comparison, so -0.0 maps to the +0 pattern.
    fn f32_to_fp16_bits(f: f32) -> u16 {
        if f == 0.0 { return 0x0000; }
        if f == f32::INFINITY { return 0x7C00; }
        if f == f32::NEG_INFINITY { return 0xFC00; }
        if f.is_nan() { return 0x7E00; }
        let bits = f.to_bits();
        let sign = ((bits >> 31) as u16) << 15;
        let exp = (bits >> 23) & 0xFF;
        let mant = bits & 0x007F_FFFF;
        // Re-bias the exponent from f32 (bias 127) to fp16 (bias 15).
        let fp16_exp = exp as i32 - 127 + 15;
        assert!(
            fp16_exp > 0 && fp16_exp < 31,
            "f32 value {f} is outside the representable fp16 normal range"
        );
        sign | ((fp16_exp as u16) << 10) | ((mant >> 13) as u16)
    }
    /// Build a [`BlockQ4K`] where all 8 sub-blocks share the same `scale` and
    /// `min` (both < 16), and every byte in `qs` is `qs_byte`.
    ///
    /// With scale/min < 16 the high bits of `scales[0..8]` are zero, so the
    /// packed entries `scales[8..12]` reduce to `scale | (min << 4)` and
    /// `get_scale_min` recovers the same (scale, min) for every sub-block.
    fn make_block(d: f32, dmin: f32, scale: u8, min: u8, qs_byte: u8) -> BlockQ4K {
        assert!(
            scale < 16 && min < 16,
            "make_block: scale ({scale}) and min ({min}) must both be < 16"
        );
        let mut scales = [0u8; K_SCALE_SIZE];
        for j in 0..4 {
            scales[j] = scale;
            scales[j + 4] = min;
        }
        for j in 8..12 {
            scales[j] = (scale & 0x0F) | ((min & 0x0F) << 4);
        }
        BlockQ4K {
            d: f32_to_fp16_bits(d),
            dmin: f32_to_fp16_bits(dmin),
            scales,
            qs: [qs_byte; QK_K / 2],
        }
    }
    /// Build a [`BlockQ4K`] with a custom `qs` array.
    ///
    /// Same scale/min packing as [`make_block`]; only the weight payload
    /// differs.
    fn make_block_with_qs(
        d: f32,
        dmin: f32,
        scale: u8,
        min: u8,
        qs: [u8; QK_K / 2],
    ) -> BlockQ4K {
        assert!(
            scale < 16 && min < 16,
            "make_block_with_qs: scale ({scale}) and min ({min}) must both be < 16"
        );
        let mut scales = [0u8; K_SCALE_SIZE];
        for j in 0..4 {
            scales[j] = scale;
            scales[j + 4] = min;
        }
        for j in 8..12 {
            scales[j] = (scale & 0x0F) | ((min & 0x0F) << 4);
        }
        BlockQ4K {
            d: f32_to_fp16_bits(d),
            dmin: f32_to_fp16_bits(dmin),
            scales,
            qs,
        }
    }
    /// Assert `got` is within `tol` of `expected`.
    fn assert_close(got: f32, expected: f32, tol: f32) {
        assert!(
            (got - expected).abs() <= tol,
            "got {got}, expected {expected} (tol {tol})"
        );
    }
    /// Assert every element of `got` is within `tol` of `expected_scalar`.
    fn assert_all_close(got: &[f32], expected_scalar: f32, tol: f32) {
        for (i, &g) in got.iter().enumerate() {
            assert!(
                (g - expected_scalar).abs() <= tol,
                "element {i}: got {g}, expected {expected_scalar} (tol {tol})"
            );
        }
    }
    /// Assert two slices are element-wise within `tol` of each other.
    fn assert_slices_close(got: &[f32], expected: &[f32], tol: f32) {
        assert_eq!(got.len(), expected.len(), "slice length mismatch");
        for (i, (&g, &e)) in got.iter().zip(expected.iter()).enumerate() {
            assert!(
                (g - e).abs() <= tol,
                "element {i}: got {g}, expected {e} (tol {tol})"
            );
        }
    }
    /// Build a k×n fp16 matrix (as raw bits) filled with a single value.
    fn fp16_uniform(k: usize, n: usize, value: f32) -> Vec<u16> {
        vec![f32_to_fp16_bits(value); k * n]
    }
    // =========================================================================
    // Struct layout
    // =========================================================================
    #[test]
    fn block_q4k_rle_size_is_146_bytes() {
        // d(2) + dmin(2) + scales(12) + flags(1) + qs(128) = 145 raw bytes,
        // rounded up to 146 by the 2-byte alignment imposed by the u16 fields.
        assert_eq!(core::mem::size_of::<BlockQ4KRle>(), 146);
    }
    #[test]
    fn block_q4k_rle_is_two_bytes_larger_than_block_q4k() {
        // The extra `flags` byte plus one padding byte of 2-byte alignment.
        assert_eq!(
            core::mem::size_of::<BlockQ4KRle>(),
            core::mem::size_of::<BlockQ4K>() + 2,
        );
    }
    // =========================================================================
    // is_rle / rle_len
    // =========================================================================
    #[test]
    fn is_rle_false_when_flag_clear() {
        let b = BlockQ4KRle {
            d: 0, dmin: 0, scales: [0; K_SCALE_SIZE], flags: 0, qs: [0; QK_K / 2],
        };
        assert!(!b.is_rle());
    }
    #[test]
    fn rle_len_zero_when_flag_clear() {
        let b = BlockQ4KRle {
            d: 0, dmin: 0, scales: [0; K_SCALE_SIZE], flags: 0, qs: [0; QK_K / 2],
        };
        assert_eq!(b.rle_len(), 0);
    }
    #[test]
    fn is_rle_true_when_flag_set() {
        let b = BlockQ4KRle {
            d: 0, dmin: 0, scales: [0; K_SCALE_SIZE],
            flags: IS_RLE | (5u8 << 1),
            qs: [0; QK_K / 2],
        };
        assert!(b.is_rle());
    }
    #[test]
    fn rle_len_reports_pair_count_from_flags() {
        // 63 is the largest pair count `encode` can ever produce.
        for n in [0usize, 1, 7, 31, 63] {
            let b = BlockQ4KRle {
                d: 0, dmin: 0, scales: [0; K_SCALE_SIZE],
                flags: IS_RLE | ((n as u8) << 1),
                qs: [0; QK_K / 2],
            };
            assert_eq!(b.rle_len(), n, "expected rle_len {n}");
        }
    }
    // =========================================================================
    // encode: mode selection
    // =========================================================================
    #[test]
    fn encode_uniform_qs_uses_rle() {
        // 128 identical bytes → 1 pair → 2 bytes < 128 raw.
        let src = make_block(1.0, 0.0, 1, 0, 0x77);
        let rle = encode(&src);
        assert!(rle.is_rle(), "uniform qs should trigger RLE mode");
    }
    #[test]
    fn encode_uniform_qs_rle_len_is_one() {
        let src = make_block(1.0, 0.0, 1, 0, 0x55);
        let rle = encode(&src);
        assert_eq!(rle.rle_len(), 1);
    }
    #[test]
    fn encode_uniform_qs_rle_entry_is_correct() {
        let src = make_block(1.0, 0.0, 1, 0, 0xAB);
        let rle = encode(&src);
        assert_eq!(rle.qs[0], 0xAB, "RLE value byte should equal the repeated byte");
        assert_eq!(rle.qs[1], 128, "RLE run length should be 128 bytes");
    }
    #[test]
    fn encode_alternating_bytes_stays_raw() {
        // 128 single-byte runs → 128 pairs → 256 bytes ≥ 128 raw → raw mode.
        let mut qs = [0u8; QK_K / 2];
        for (i, b) in qs.iter_mut().enumerate() {
            *b = if i % 2 == 0 { 0xAA } else { 0x55 };
        }
        let src = make_block_with_qs(1.0, 0.0, 1, 0, qs);
        let rle = encode(&src);
        assert!(!rle.is_rle(), "alternating bytes cannot be compressed → raw mode");
    }
    #[test]
    fn encode_raw_mode_copies_qs_verbatim() {
        let mut qs = [0u8; QK_K / 2];
        for (i, b) in qs.iter_mut().enumerate() {
            // Three-byte cycle of distinct values → 128 runs of 1 byte each.
            *b = match i % 3 { 0 => 0x11, 1 => 0x22, _ => 0x33 };
        }
        let src = make_block_with_qs(1.0, 0.0, 1, 0, qs);
        let rle = encode(&src);
        assert!(!rle.is_rle());
        assert_eq!(rle.qs, qs, "raw mode must preserve qs bytes unchanged");
    }
    #[test]
    fn encode_two_runs_uses_rle_and_stores_correct_pairs() {
        // Two distinct runs: 64 bytes of 0x11 followed by 64 bytes of 0x22.
        // → 2 pairs = 4 bytes < 128 bytes raw.
        let mut qs = [0u8; QK_K / 2];
        qs[..64].fill(0x11);
        qs[64..].fill(0x22);
        let src = make_block_with_qs(1.0, 0.0, 1, 0, qs);
        let rle = encode(&src);
        assert!(rle.is_rle());
        assert_eq!(rle.rle_len(), 2);
        assert_eq!(rle.qs[0], 0x11, "first pair: value");
        assert_eq!(rle.qs[1], 64, "first pair: run length");
        assert_eq!(rle.qs[2], 0x22, "second pair: value");
        assert_eq!(rle.qs[3], 64, "second pair: run length");
    }
    #[test]
    fn encode_63_pairs_uses_rle() {
        // Build 62 runs of 2 bytes each (124 bytes) + 1 run of 4 bytes = 128 bytes.
        // 63 pairs × 2 = 126 bytes < 128 → RLE should be chosen.
        let mut qs = [0u8; QK_K / 2];
        let mut pos = 0usize;
        for run in 0..62usize {
            // Use a stride-3 sequence so consecutive values are always distinct.
            let v = (run as u8).wrapping_mul(3).wrapping_add(1);
            qs[pos] = v;
            qs[pos + 1] = v;
            pos += 2;
        }
        // Final run: 4 bytes, value chosen to differ from the previous one.
        qs[pos..].fill(0xFE);
        let src = make_block_with_qs(1.0, 0.0, 1, 0, qs);
        let rle = encode(&src);
        assert!(rle.is_rle(), "63 pairs should use RLE");
        assert_eq!(rle.rle_len(), 63);
    }
    #[test]
    fn encode_64_pairs_stays_raw() {
        // 64 runs of 2 bytes each = 128 bytes total.
        // 64 pairs × 2 = 128 bytes, which is NOT strictly less than 128 → raw.
        let mut qs = [0u8; QK_K / 2];
        let mut pos = 0usize;
        for run in 0..64usize {
            let v = (run as u8).wrapping_mul(3).wrapping_add(1);
            qs[pos] = v;
            qs[pos + 1] = v;
            pos += 2;
        }
        let src = make_block_with_qs(1.0, 0.0, 1, 0, qs);
        let rle = encode(&src);
        assert!(!rle.is_rle(), "64 pairs offers no saving → raw mode");
    }
    #[test]
    fn encode_preserves_d_dmin_scales() {
        let src = make_block(2.0, 0.5, 3, 2, 0x00);
        let rle = encode(&src);
        assert_eq!(rle.d, src.d);
        assert_eq!(rle.dmin, src.dmin);
        assert_eq!(rle.scales, src.scales);
    }
    // =========================================================================
    // decode_qs (tested indirectly through dequantise, but also directly)
    // =========================================================================
    #[test]
    fn decode_qs_raw_mode_returns_qs_unchanged() {
        // Build a raw BlockQ4KRle (flags = 0) with a non-trivial qs pattern.
        let mut qs = [0u8; QK_K / 2];
        for (i, b) in qs.iter_mut().enumerate() { *b = i as u8; }
        let rle = BlockQ4KRle {
            d: 0, dmin: 0, scales: [0; K_SCALE_SIZE], flags: 0, qs,
        };
        assert_eq!(decode_qs(&rle), qs);
    }
    #[test]
    fn decode_qs_rle_expands_two_pair_stream() {
        // Hand-craft an RLE block: [0xAA × 64, 0xBB × 64].
        let mut qs = [0u8; QK_K / 2];
        qs[0] = 0xAA; qs[1] = 64;
        qs[2] = 0xBB; qs[3] = 64;
        let rle = BlockQ4KRle {
            d: 0, dmin: 0, scales: [0; K_SCALE_SIZE],
            flags: IS_RLE | (2u8 << 1),
            qs,
        };
        let expanded = decode_qs(&rle);
        assert!(expanded[..64].iter().all(|&b| b == 0xAA), "first 64 bytes must be 0xAA");
        assert!(expanded[64..].iter().all(|&b| b == 0xBB), "last 64 bytes must be 0xBB");
    }
    #[test]
    fn decode_qs_rle_single_run_covers_all() {
        let mut qs = [0u8; QK_K / 2];
        qs[0] = 0xCD; qs[1] = 128; // one run of 128 bytes
        let rle = BlockQ4KRle {
            d: 0, dmin: 0, scales: [0; K_SCALE_SIZE],
            flags: IS_RLE | (1u8 << 1),
            qs,
        };
        let expanded = decode_qs(&rle);
        assert!(expanded.iter().all(|&b| b == 0xCD));
    }
    // =========================================================================
    // dequantize_block_q4k_rle
    // =========================================================================
    #[test]
    fn dequant_rle_zero_d_all_outputs_zero() {
        // d = dmin = 0 zeroes both the scale and min terms for every element.
        let src = make_block(0.0, 0.0, 1, 0, 0x77);
        let rle = encode(&src);
        let mut out = [0.0f32; QK_K];
        dequantize_block_q4k_rle(&rle, &mut out);
        assert_all_close(&out, 0.0, 0.0);
    }
    #[test]
    fn dequant_rle_uniform_nibble_one_scale_one() {
        // qs_byte = 0x11 → both nibbles = 1; scale = 1, d = 1.0, min = 0.
        // expected: 1.0 * 1 * 1 - 0.0 = 1.0
        let src = make_block(1.0, 0.0, 1, 0, 0x11);
        let rle = encode(&src);
        let mut out = [0.0f32; QK_K];
        dequantize_block_q4k_rle(&rle, &mut out);
        assert_all_close(&out, 1.0, 1e-5);
    }
    #[test]
    fn dequant_rle_non_zero_min_subtracts() {
        // nibble = 0, scale = 1, d = 1.0, min = 2, dmin = 1.0
        // expected: 1.0 * 1 * 0 - 1.0 * 2 = -2.0
        let src = make_block(1.0, 1.0, 1, 2, 0x00);
        let rle = encode(&src);
        let mut out = [0.0f32; QK_K];
        dequantize_block_q4k_rle(&rle, &mut out);
        assert_all_close(&out, -2.0, 1e-5);
    }
    #[test]
    fn dequant_rle_max_nibble_15() {
        // qs_byte = 0xFF → both nibbles = 15; scale = 1, d = 1.0, min = 0.
        // expected: 1.0 * 1 * 15 - 0.0 = 15.0
        let src = make_block(1.0, 0.0, 1, 0, 0xFF);
        let rle = encode(&src);
        let mut out = [0.0f32; QK_K];
        dequantize_block_q4k_rle(&rle, &mut out);
        assert_all_close(&out, 15.0, 1e-5);
    }
    #[test]
    fn dequant_rle_output_count_is_qk_k() {
        let src = make_block(1.0, 0.0, 1, 0, 0x00);
        let rle = encode(&src);
        let mut out = [0.0f32; QK_K];
        dequantize_block_q4k_rle(&rle, &mut out);
        assert_eq!(out.len(), QK_K);
    }
    #[test]
    fn dequant_rle_larger_scale_multiplies() {
        // nibble = 3, scale = 4, d = 2.0, min = 0
        // expected: 2.0 * 4 * 3 - 0.0 = 24.0
        // qs_byte = 0x33 → both nibbles = 3
        let src = make_block(2.0, 0.0, 4, 0, 0x33);
        let rle = encode(&src);
        let mut out = [0.0f32; QK_K];
        dequantize_block_q4k_rle(&rle, &mut out);
        assert_all_close(&out, 24.0, 1e-4);
    }
    // =========================================================================
    // Roundtrip: encode → dequantize must match original dequantize
    // =========================================================================
    #[test]
    fn roundtrip_rle_mode_matches_original() {
        // Uniform qs → RLE mode selected.
        let src = make_block(2.0, 0.5, 3, 1, 0x37);
        let rle = encode(&src);
        assert!(rle.is_rle());
        let mut got = [0.0f32; QK_K];
        let mut expected = [0.0f32; QK_K];
        dequantize_block_q4k_rle(&rle, &mut got);
        dequantize_block_q4k(&src, &mut expected);
        assert_slices_close(&got, &expected, 1e-5);
    }
    #[test]
    fn roundtrip_raw_mode_matches_original() {
        // Alternating bytes → raw mode selected; output must still be correct.
        let mut qs = [0u8; QK_K / 2];
        for (i, b) in qs.iter_mut().enumerate() {
            *b = if i % 2 == 0 { 0x13 } else { 0x24 };
        }
        let src = make_block_with_qs(1.5, 0.25, 2, 1, qs);
        let rle = encode(&src);
        assert!(!rle.is_rle());
        let mut got = [0.0f32; QK_K];
        let mut expected = [0.0f32; QK_K];
        dequantize_block_q4k_rle(&rle, &mut got);
        dequantize_block_q4k(&src, &mut expected);
        assert_slices_close(&got, &expected, 1e-5);
    }
    #[test]
    fn roundtrip_two_run_block_matches_original() {
        let mut qs = [0u8; QK_K / 2];
        qs[..64].fill(0x59);
        qs[64..].fill(0x8C);
        let src = make_block_with_qs(3.0, 1.0, 5, 2, qs);
        let rle = encode(&src);
        assert!(rle.is_rle());
        let mut got = [0.0f32; QK_K];
        let mut expected = [0.0f32; QK_K];
        dequantize_block_q4k_rle(&rle, &mut got);
        dequantize_block_q4k(&src, &mut expected);
        assert_slices_close(&got, &expected, 1e-5);
    }
    #[test]
    fn roundtrip_many_short_runs_matches_original() {
        // Four distinct runs of varying lengths → still compresses.
        let mut qs = [0u8; QK_K / 2];
        qs[..10].fill(0x11);
        qs[10..30].fill(0x22);
        qs[30..31].fill(0x33);
        qs[31..].fill(0x44);
        let src = make_block_with_qs(1.0, 0.5, 7, 3, qs);
        let rle = encode(&src);
        assert!(rle.is_rle(), "4-run block should compress");
        assert_eq!(rle.rle_len(), 4);
        let mut got = [0.0f32; QK_K];
        let mut expected = [0.0f32; QK_K];
        dequantize_block_q4k_rle(&rle, &mut got);
        dequantize_block_q4k(&src, &mut expected);
        assert_slices_close(&got, &expected, 1e-5);
    }
    #[test]
    fn roundtrip_zero_qs_matches_original() {
        let src = make_block(1.0, 0.5, 2, 1, 0x00);
        let rle = encode(&src);
        let mut got = [0.0f32; QK_K];
        let mut expected = [0.0f32; QK_K];
        dequantize_block_q4k_rle(&rle, &mut got);
        dequantize_block_q4k(&src, &mut expected);
        assert_slices_close(&got, &expected, 1e-5);
    }
    #[test]
    fn roundtrip_nibble_split_low_high_correct() {
        // qs_byte = 0x37: low nibble = 7 (sub-block 0 path), high nibble = 3
        // (sub-block 1 path). Verify both halves are dequantised correctly.
        let src = make_block(1.0, 0.0, 1, 0, 0x37);
        let rle = encode(&src);
        let mut got = [0.0f32; QK_K];
        let mut expected = [0.0f32; QK_K];
        dequantize_block_q4k_rle(&rle, &mut got);
        dequantize_block_q4k(&src, &mut expected);
        assert_slices_close(&got, &expected, 1e-5);
        // First 32 elements of each 64-element group → low nibble = 7.
        assert_close(got[0], 7.0, 1e-5);
        // Next 32 elements → high nibble = 3.
        assert_close(got[32], 3.0, 1e-5);
    }
    // =========================================================================
    // matmul_q4k_rle_fp16
    // =========================================================================
    #[test]
    fn matmul_rle_1x256_times_256x1_all_ones() {
        // A: 1×256, all weights = nibble 1, scale = 1, d = 1.0
        // B: 256×1, all fp16 1.0
        // C = dot([1.0; 256], [1.0; 256]) = 256.0
        let src = make_block(1.0, 0.0, 1, 0, 0x11);
        let a = vec![encode(&src)];
        let b = fp16_uniform(QK_K, 1, 1.0);
        let c = matmul_q4k_rle_fp16(&a, &b, 1, QK_K, 1);
        assert_eq!(c.len(), 1);
        assert_close(c[0], 256.0, 1e-3);
    }
    #[test]
    fn matmul_rle_2x256_times_256x3_all_ones() {
        let src = make_block(1.0, 0.0, 1, 0, 0x11);
        let a = vec![encode(&src), encode(&src)];
        let b = fp16_uniform(QK_K, 3, 1.0);
        let c = matmul_q4k_rle_fp16(&a, &b, 2, QK_K, 3);
        assert_eq!(c.len(), 6);
        assert_all_close(&c, 256.0, 1e-3);
    }
    #[test]
    fn matmul_rle_zero_a_gives_zero_c() {
        let src = make_block(0.0, 0.0, 1, 0, 0xFF);
        let a = vec![encode(&src)];
        let b = fp16_uniform(QK_K, 4, 1.0);
        let c = matmul_q4k_rle_fp16(&a, &b, 1, QK_K, 4);
        assert_all_close(&c, 0.0, 0.0);
    }
    #[test]
    fn matmul_rle_zero_b_gives_zero_c() {
        let src = make_block(1.0, 0.0, 1, 0, 0x11);
        let a = vec![encode(&src)];
        let b = fp16_uniform(QK_K, 2, 0.0);
        let c = matmul_q4k_rle_fp16(&a, &b, 1, QK_K, 2);
        assert_all_close(&c, 0.0, 0.0);
    }
    #[test]
    fn matmul_rle_two_blocks_per_row() {
        // A: 1×512, two blocks, all nibble-1 weights; B: 512×1, all 1.0.
        // Expected: 512.0
        let src = make_block(1.0, 0.0, 1, 0, 0x11);
        let a = vec![encode(&src), encode(&src)];
        let b = fp16_uniform(2 * QK_K, 1, 1.0);
        let c = matmul_q4k_rle_fp16(&a, &b, 1, 2 * QK_K, 1);
        assert_eq!(c.len(), 1);
        assert_close(c[0], 512.0, 1e-3);
    }
    #[test]
    fn matmul_rle_output_shape_m_times_n() {
        // A: 3×512 (6 blocks), B: 512×4 → C: 3×4 = 12 elements.
        let src = make_block(1.0, 0.0, 1, 0, 0x00);
        let a: Vec<BlockQ4KRle> = (0..6).map(|_| encode(&src)).collect();
        let b = fp16_uniform(2 * QK_K, 4, 0.0);
        let c = matmul_q4k_rle_fp16(&a, &b, 3, 2 * QK_K, 4);
        assert_eq!(c.len(), 12);
    }
    #[test]
    fn matmul_rle_scalar_b_scales_output() {
        // Multiplying B by a scalar should scale C by the same factor.
        let src = make_block(1.0, 0.0, 1, 0, 0x22); // nibble 2 → weight 2.0
        let a = vec![encode(&src)];
        let b1 = fp16_uniform(QK_K, 1, 1.0);
        let b2 = fp16_uniform(QK_K, 1, 3.0);
        let c1 = matmul_q4k_rle_fp16(&a, &b1, 1, QK_K, 1);
        let c2 = matmul_q4k_rle_fp16(&a, &b2, 1, QK_K, 1);
        assert_close(c2[0], c1[0] * 3.0, 1e-2);
    }
    #[test]
    fn matmul_rle_matches_original_matmul_mixed_blocks() {
        // Mix: first block uniform (RLE), second block alternating (raw).
        // Both matmul implementations should produce identical results.
        let src_rle = make_block(2.0, 0.5, 3, 1, 0x37);
        let mut qs_raw = [0u8; QK_K / 2];
        for (i, b) in qs_raw.iter_mut().enumerate() {
            *b = if i % 2 == 0 { 0x13 } else { 0x24 };
        }
        let src_raw = make_block_with_qs(1.5, 0.25, 2, 1, qs_raw);
        let a_orig: Vec<BlockQ4K> = vec![src_rle, src_raw];
        let a_rle: Vec<BlockQ4KRle> = a_orig.iter().map(encode).collect();
        // A: 1×512, B: 512×2
        let b = fp16_uniform(2 * QK_K, 2, 1.0);
        let c_orig = matmul_q4k_fp16(&a_orig, &b, 1, 2 * QK_K, 2);
        let c_rle = matmul_q4k_rle_fp16(&a_rle, &b, 1, 2 * QK_K, 2);
        assert_slices_close(&c_rle, &c_orig, 1e-4);
    }
    #[test]
    fn matmul_rle_multiple_rows_multiple_blocks_per_row() {
        // A: 2×512 (4 blocks), B: 512×3, all weights 1 in A, all 1.0 in B.
        // Each row dot product = 512.0; C should be all 512.0.
        let src = make_block(1.0, 0.0, 1, 0, 0x11);
        let a: Vec<BlockQ4KRle> = (0..4).map(|_| encode(&src)).collect();
        let b = fp16_uniform(2 * QK_K, 3, 1.0);
        let c = matmul_q4k_rle_fp16(&a, &b, 2, 2 * QK_K, 3);
        assert_eq!(c.len(), 6);
        assert_all_close(&c, 512.0, 1e-3);
    }
    // =========================================================================
    // Panic / contract checks
    // =========================================================================
    #[test]
    fn matmul_rle_panics_when_k_not_multiple_of_qkk() {
        let src = make_block(1.0, 0.0, 1, 0, 0x00);
        let a = vec![encode(&src)];
        let b = vec![0u16; 512];
        let result = std::panic::catch_unwind(move || {
            matmul_q4k_rle_fp16(&a, &b, 1, 512, 2);
        });
        assert!(result.is_err(), "should panic when k is not a multiple of QK_K");
    }
    #[test]
    fn matmul_rle_panics_on_wrong_a_length() {
        let src = make_block(1.0, 0.0, 1, 0, 0x00);
        // m=2, k=QK_K requires 2 blocks; only 1 is provided.
        let a = vec![encode(&src)];
        let b = fp16_uniform(QK_K, 1, 1.0);
        let result = std::panic::catch_unwind(move || {
            matmul_q4k_rle_fp16(&a, &b, 2, QK_K, 1);
        });
        assert!(result.is_err(), "should panic on wrong A block count");
    }
    #[test]
    fn matmul_rle_panics_on_wrong_b_length() {
        let src = make_block(1.0, 0.0, 1, 0, 0x00);
        let a = vec![encode(&src)];
        // B is too short for k=QK_K, n=3.
        let b = vec![0u16; 10];
        let result = std::panic::catch_unwind(move || {
            matmul_q4k_rle_fp16(&a, &b, 1, QK_K, 3);
        });
        assert!(result.is_err(), "should panic on wrong B element count");
    }
}