Allow variable coverage
This commit is contained in:
@@ -153,7 +153,7 @@ fn bench_encode(c: &mut Criterion) {
|
||||
group.bench_function("uniform", |b| {
|
||||
b.iter(|| {
|
||||
for blk in &uniform {
|
||||
black_box(encode(black_box(blk)));
|
||||
black_box(encode(black_box(blk), 0.0));
|
||||
}
|
||||
});
|
||||
});
|
||||
@@ -161,7 +161,7 @@ fn bench_encode(c: &mut Criterion) {
|
||||
group.bench_function("rle_optimal", |b| {
|
||||
b.iter(|| {
|
||||
for blk in &rle_opt {
|
||||
black_box(encode(black_box(blk)));
|
||||
black_box(encode(black_box(blk), 0.0));
|
||||
}
|
||||
});
|
||||
});
|
||||
@@ -186,8 +186,8 @@ fn bench_dequantize(c: &mut Criterion) {
|
||||
let q4k_uniform = uniform_blocks(1).into_iter().next().unwrap();
|
||||
let q4k_rle_opt = rle_optimal_blocks(1).into_iter().next().unwrap();
|
||||
|
||||
let rle_raw = encode(&q4k_uniform); // IS_RLE = 0
|
||||
let rle_rle = encode(&q4k_rle_opt); // IS_RLE = 1
|
||||
let rle_raw = encode(&q4k_uniform, 0.0); // IS_RLE = 0
|
||||
let rle_rle = encode(&q4k_rle_opt, 0.0); // IS_RLE = 1
|
||||
|
||||
// Confirm the fixtures ended up in the right encoding modes.
|
||||
assert!(!rle_raw.is_rle(), "uniform block should encode to raw mode");
|
||||
@@ -270,10 +270,10 @@ fn bench_matmul(c: &mut Criterion) {
|
||||
|
||||
// Build all four A variants and the shared B matrix for this config.
|
||||
let a_q4k_u: Vec<BlockQ4K> = uniform_blocks(m * bpr);
|
||||
let a_rle_u: Vec<BlockQ4KRle> = a_q4k_u.iter().map(encode).collect();
|
||||
let a_rle_u: Vec<BlockQ4KRle> = a_q4k_u.iter().map(|b| encode(b, 0.0)).collect();
|
||||
|
||||
let a_q4k_r: Vec<BlockQ4K> = rle_optimal_blocks(m * bpr);
|
||||
let a_rle_r: Vec<BlockQ4KRle> = a_q4k_r.iter().map(encode).collect();
|
||||
let a_rle_r: Vec<BlockQ4KRle> = a_q4k_r.iter().map(|b| encode(b, 0.0)).collect();
|
||||
|
||||
let b = fp16_ones(k, n);
|
||||
|
||||
|
||||
@@ -145,7 +145,7 @@ fn main() -> Result<(), Box<dyn Error>> {
|
||||
|
||||
// ── RLE encode (best of `trials`) ────────────────────────────────────────
|
||||
let (rle_blocks, t_enc) = bench(trials, || -> Vec<BlockQ4KRle> {
|
||||
blocks.iter().map(encode).collect()
|
||||
blocks.iter().map(|b| encode(b, 0.0)).collect()
|
||||
});
|
||||
|
||||
let n_rle = rle_blocks.iter().filter(|b| b.is_rle()).count();
|
||||
|
||||
@@ -101,11 +101,32 @@ fn fixed(s: &str, width: usize) -> String {
|
||||
fn main() -> Result<(), Box<dyn Error>> {
|
||||
let args: Vec<String> = env::args().collect();
|
||||
if args.len() < 2 {
|
||||
eprintln!("usage: {} <model.gguf>", args[0]);
|
||||
eprintln!("usage: {} <model.gguf> [--threshold <0.0..1.0>]", args[0]);
|
||||
eprintln!();
|
||||
eprintln!(" --threshold Minimum fraction of qs bytes that must be in runs of");
|
||||
eprintln!(" length ≥ 2 for a block to use RLE mode. Default: 0.0");
|
||||
eprintln!(" (use RLE whenever the pair count fits in 64 pairs).");
|
||||
std::process::exit(1);
|
||||
}
|
||||
let path = &args[1];
|
||||
|
||||
// Parse optional --threshold flag from the remaining arguments.
|
||||
let mut threshold = 0.0f32;
|
||||
let mut idx = 2usize;
|
||||
while idx < args.len() {
|
||||
if args[idx] == "--threshold" {
|
||||
idx += 1;
|
||||
threshold = args.get(idx)
|
||||
.and_then(|s| s.parse::<f32>().ok())
|
||||
.filter(|&v| (0.0..=1.0).contains(&v))
|
||||
.unwrap_or_else(|| {
|
||||
eprintln!("error: --threshold requires a value in [0.0, 1.0]");
|
||||
std::process::exit(1);
|
||||
});
|
||||
}
|
||||
idx += 1;
|
||||
}
|
||||
|
||||
// ── Parse header ─────────────────────────────────────────────────────────
|
||||
eprintln!("Parsing {path} …");
|
||||
let (tensors, data_start) = parse_header(path)?;
|
||||
@@ -122,6 +143,8 @@ fn main() -> Result<(), Box<dyn Error>> {
|
||||
q4k_tensors.len(),
|
||||
other_count,
|
||||
);
|
||||
eprintln!(" RLE threshold: {threshold:.2} (blocks need ≥ {:.0}% of bytes in runs)",
|
||||
threshold * 100.0);
|
||||
eprintln!();
|
||||
|
||||
// ── Header row ───────────────────────────────────────────────────────────
|
||||
@@ -145,7 +168,7 @@ fn main() -> Result<(), Box<dyn Error>> {
|
||||
let mut stats = TensorStats::new();
|
||||
|
||||
for_each_block(&mut file, data_start, tensor, |block| {
|
||||
let rle_block = encode(block);
|
||||
let rle_block = encode(block, threshold);
|
||||
stats.observe(rle_block.is_rle(), rle_block.rle_len());
|
||||
})?;
|
||||
|
||||
@@ -187,10 +210,15 @@ fn main() -> Result<(), Box<dyn Error>> {
|
||||
|
||||
if !any_rle {
|
||||
println!();
|
||||
println!("No blocks compressed with RLE — all weights are effectively random at");
|
||||
println!("the byte level, which is typical for trained Q4_K quantised weights.");
|
||||
println!("RLE compression only helps for structured weight matrices (binary,");
|
||||
println!("ternary, heavily pruned, or synthetic).");
|
||||
println!("No blocks used RLE at threshold {threshold:.2}.");
|
||||
if threshold < 0.01 {
|
||||
println!("All weights are effectively random at the byte level — typical for");
|
||||
println!("trained Q4_K weights. RLE only helps for structured weight matrices");
|
||||
println!("(binary, ternary, heavily pruned, or synthetic).");
|
||||
} else {
|
||||
println!("Try a lower --threshold (e.g. --threshold 0.0) to see whether any");
|
||||
println!("blocks have enough run structure to qualify at a looser threshold.");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
438
src/rle.rs
438
src/rle.rs
@@ -12,18 +12,16 @@
|
||||
//!
|
||||
//! ## RLE format (when `IS_RLE` = 1)
|
||||
//!
|
||||
//! - `flags >> 1` gives the number of `(value, count)` pairs stored in `qs`.
|
||||
//! - `n_pairs` gives the number of `(value, count)` pairs stored in `qs`.
|
||||
//! - For each pair `i`:
|
||||
//! - `qs[2*i]` — the byte value (two packed 4-bit weights, same packing
|
||||
//! as the raw format).
|
||||
//! - `qs[2*i + 1]` — the run length in bytes (1..=255).
|
||||
//! - The run lengths must sum to exactly 128 (the uncompressed `qs` size).
|
||||
//!
|
||||
//! RLE encoding is chosen only when the compressed representation is
|
||||
//! **strictly shorter** than the 128-byte raw payload, i.e. when
|
||||
//! `pairs * 2 < 128`. That caps the useful range at ≤ 63 pairs. The 7-bit
|
||||
//! `flags >> 1` sub-field can hold up to 127, so this ceiling is never a
|
||||
//! concern in practice.
|
||||
//! The 256-byte `qs` field can hold up to 128 `(value, count)` pairs — enough
|
||||
//! to represent even fully-random blocks where every byte differs from its
|
||||
//! neighbour.
|
||||
//!
|
||||
//! ## Constructing blocks
|
||||
//!
|
||||
@@ -49,42 +47,45 @@ pub const IS_RLE: u8 = 0x01;
|
||||
|
||||
/// A Q4_K super-block with optional byte-level RLE compression on the weights.
|
||||
///
|
||||
/// Identical to [`crate::BlockQ4K`] except for the additional [`flags`](Self::flags)
|
||||
/// byte inserted between `scales` and `qs`.
|
||||
/// Unlike [`crate::BlockQ4K`], this format is **not** binary-compatible with
|
||||
/// the GGUF on-disk layout. It uses a 256-byte `qs` field (vs the 128-byte
|
||||
/// field in `BlockQ4K`) so the RLE stream can store up to 128 `(value, count)`
|
||||
/// pairs — enough to represent even fully-random blocks where every byte
|
||||
/// differs from its neighbour.
|
||||
///
|
||||
/// Memory layout (repr C):
|
||||
/// Memory layout (`repr C`):
|
||||
///
|
||||
/// | Offset | Field | Size | Notes |
|
||||
/// |--------|------------|-------|--------------------------------|
|
||||
/// | 0 | `d` | 2 B | fp16 super-block scale |
|
||||
/// | 2 | `dmin` | 2 B | fp16 super-block min scale |
|
||||
/// | 4 | `scales` | 12 B | packed 6-bit sub-block params |
|
||||
/// | 16 | `flags` | 1 B | encoding flags (see below) |
|
||||
/// | 17 | `qs` | 128 B | raw nibbles or RLE stream |
|
||||
/// | 145 | (padding) | 1 B | implicit trailing alignment pad|
|
||||
/// | Offset | Field | Size | Notes |
|
||||
/// |--------|-----------|--------|------------------------------------|
|
||||
/// | 0 | `d` | 2 B | fp16 super-block scale |
|
||||
/// | 2 | `dmin` | 2 B | fp16 super-block min-scale |
|
||||
/// | 4 | `scales` | 12 B | packed 6-bit sub-block params |
|
||||
/// | 16 | `flags` | 1 B | bit 0 = `IS_RLE`; bits 1–7 unused |
|
||||
/// | 17 | `n_pairs` | 1 B | RLE pair count (0 when raw) |
|
||||
/// | 18 | `qs` | 256 B | raw nibbles (first 128 B) or RLE |
|
||||
///
|
||||
/// **sizeof = 146 bytes** (padded to 2-byte alignment imposed by `u16` fields).
|
||||
/// **sizeof = 274 bytes.**
|
||||
///
|
||||
/// ## `flags` bit layout
|
||||
/// ## `qs` interpretation
|
||||
///
|
||||
/// | Bits | Meaning |
|
||||
/// |------|---------------------------------------------------------------|
|
||||
/// | 0 | [`IS_RLE`] — 1 = `qs` is RLE-encoded, 0 = raw packed nibbles |
|
||||
/// | 1–7 | When `IS_RLE`=1: number of `(value, count)` pairs in `qs` |
|
||||
/// | `IS_RLE` | Meaning |
|
||||
/// |----------|--------------------------------------------------------------|
|
||||
/// | 0 | `qs[0..128]` holds raw packed nibbles (same as `BlockQ4K`) |
|
||||
/// | 1 | `qs[0..n_pairs*2]` holds `(value, count)` byte-pairs |
|
||||
#[repr(C)]
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct BlockQ4KRle {
|
||||
/// Super-block scale for quantised sub-block scales (fp16 bits).
|
||||
pub d: u16,
|
||||
/// Super-block scale for quantised sub-block mins (fp16 bits).
|
||||
pub dmin: u16,
|
||||
/// Packed 6-bit sub-block scales and mins (same layout as [`crate::BlockQ4K`]).
|
||||
pub scales: [u8; K_SCALE_SIZE],
|
||||
/// Encoding flags. Bit 0 = [`IS_RLE`]. Bits 1-7 = RLE pair count when
|
||||
/// `IS_RLE` is set.
|
||||
pub flags: u8,
|
||||
/// Raw packed-nibble weights (`IS_RLE` = 0) or RLE byte stream (`IS_RLE` = 1).
|
||||
pub qs: [u8; QK_K / 2],
|
||||
pub d: u16,
|
||||
pub dmin: u16,
|
||||
pub scales: [u8; K_SCALE_SIZE],
|
||||
/// Encoding flags. Only bit 0 (`IS_RLE`) is used; bits 1-7 are reserved.
|
||||
pub flags: u8,
|
||||
/// When `IS_RLE` is set: number of `(value, count)` byte-pairs in `qs`.
|
||||
/// Zero when in raw mode.
|
||||
pub n_pairs: u8,
|
||||
/// Raw packed-nibble weights (IS_RLE = 0, first 128 bytes) or RLE stream
|
||||
/// (IS_RLE = 1, first `n_pairs * 2` bytes).
|
||||
pub qs: [u8; QK_K], // 256 bytes
|
||||
}
|
||||
|
||||
impl BlockQ4KRle {
|
||||
@@ -94,12 +95,11 @@ impl BlockQ4KRle {
|
||||
self.flags & IS_RLE != 0
|
||||
}
|
||||
|
||||
/// Number of `(value, count)` byte-pairs stored at the start of `qs`.
|
||||
///
|
||||
/// Only meaningful when [`is_rle`](Self::is_rle) returns `true`.
|
||||
/// Number of `(value, count)` byte-pairs in `qs`.
|
||||
/// Only meaningful when `is_rle()` is true.
|
||||
#[inline]
|
||||
pub fn rle_len(&self) -> usize {
|
||||
(self.flags >> 1) as usize
|
||||
self.n_pairs as usize
|
||||
}
|
||||
}
|
||||
|
||||
@@ -109,24 +109,38 @@ impl BlockQ4KRle {
|
||||
|
||||
/// Encode a [`BlockQ4K`] block into a [`BlockQ4KRle`] block.
|
||||
///
|
||||
/// The 128-byte `qs` payload is scanned for runs of identical bytes. If the
|
||||
/// RLE representation fits in the same 128-byte field **and is strictly
|
||||
/// shorter** than the raw payload, it is stored with `IS_RLE` set. Otherwise
|
||||
/// the raw bytes are copied unchanged and `IS_RLE` is cleared.
|
||||
/// The `qs` payload is scanned for runs of equal consecutive bytes. RLE mode
|
||||
/// is chosen when **both** conditions hold:
|
||||
///
|
||||
/// The `d`, `dmin`, and `scales` fields are always copied verbatim.
|
||||
pub fn encode(block: &BlockQ4K) -> BlockQ4KRle {
|
||||
let mut raw = Vec::from(&block.qs);
|
||||
/// 1. **Coverage**: at least `min_coverage` fraction of the 128 `qs` bytes
|
||||
/// belong to runs of length ≥ 2. These are the bytes whose weights can be
|
||||
/// batched in `accumulate_rle_block`, replacing `2 * run_len` multiplies
|
||||
/// with just 2 per group-segment.
|
||||
///
|
||||
/// 2. **Capacity**: the pair count does not exceed 128 (the physical limit of
|
||||
/// the 256-byte `qs` field at 2 bytes per pair).
|
||||
///
|
||||
/// | `min_coverage` | Effect |
|
||||
/// |----------------|------------------------------------------------------|
|
||||
/// | `0.0` | RLE whenever pairs fit (≤ 128), regardless of runs |
|
||||
/// | `0.5` | RLE only if ≥ 50 % of bytes are in repeated runs |
|
||||
/// | `1.0` | RLE only when every byte is part of a run |
|
||||
pub fn encode(block: &BlockQ4K, min_coverage: f32) -> BlockQ4KRle {
|
||||
debug_assert!(
|
||||
(0.0..=1.0).contains(&min_coverage),
|
||||
"min_coverage must be in [0.0, 1.0], got {min_coverage}"
|
||||
);
|
||||
|
||||
// Sort the raw numbers
|
||||
raw.sort();
|
||||
let raw = &block.qs; // [u8; 128]
|
||||
|
||||
// Scan the 128-byte raw payload for runs of equal bytes.
|
||||
let mut pairs: Vec<(u8, u8)> = Vec::with_capacity(64);
|
||||
// Scan for runs of equal consecutive bytes.
|
||||
// Track long_run_bytes: bytes in runs of length ≥ 2 (the bytes that
|
||||
// benefit from RLE in the matmul).
|
||||
let mut pairs: Vec<(u8, u8)> = Vec::with_capacity(QK_K / 2);
|
||||
let mut long_run_bytes = 0usize;
|
||||
let mut i = 0usize;
|
||||
while i < raw.len() {
|
||||
let val = raw[i];
|
||||
// Count consecutive equal bytes; saturate at u8::MAX to stay in-range.
|
||||
let mut run = 1u8;
|
||||
while i + (run as usize) < raw.len()
|
||||
&& raw[i + (run as usize)] == val
|
||||
@@ -135,38 +149,40 @@ pub fn encode(block: &BlockQ4K) -> BlockQ4KRle {
|
||||
run += 1;
|
||||
}
|
||||
pairs.push((val, run));
|
||||
if run >= 2 {
|
||||
long_run_bytes += run as usize;
|
||||
}
|
||||
i += run as usize;
|
||||
}
|
||||
|
||||
// Only switch to RLE when the encoded form is strictly smaller than the
|
||||
// raw payload. Because each pair costs 2 bytes and the raw payload is
|
||||
// 128 bytes, the condition pairs.len() * 2 < 128 also guarantees that
|
||||
// pairs.len() ≤ 63, which fits in bits 1-7 of the flags byte.
|
||||
if pairs.len() * 2 < raw.len() {
|
||||
let n = pairs.len();
|
||||
debug_assert!(n <= 63, "RLE pair count {n} unexpectedly exceeds 63");
|
||||
// Coverage: fraction of qs bytes that are in non-singleton runs.
|
||||
let coverage = long_run_bytes as f32 / raw.len() as f32;
|
||||
|
||||
let mut qs = [0u8; QK_K / 2];
|
||||
if pairs.len() <= QK_K / 2 && coverage >= min_coverage {
|
||||
let n = pairs.len();
|
||||
let mut qs = [0u8; QK_K];
|
||||
for (k, &(val, count)) in pairs.iter().enumerate() {
|
||||
qs[2 * k] = val;
|
||||
qs[2 * k + 1] = count;
|
||||
}
|
||||
|
||||
BlockQ4KRle {
|
||||
d: block.d,
|
||||
dmin: block.dmin,
|
||||
scales: block.scales,
|
||||
flags: IS_RLE | ((n as u8) << 1),
|
||||
d: block.d,
|
||||
dmin: block.dmin,
|
||||
scales: block.scales,
|
||||
flags: IS_RLE,
|
||||
n_pairs: n as u8,
|
||||
qs,
|
||||
}
|
||||
} else {
|
||||
// No space savings — copy raw bytes and leave IS_RLE clear.
|
||||
let mut qs = [0u8; QK_K];
|
||||
qs[..QK_K / 2].copy_from_slice(&block.qs);
|
||||
BlockQ4KRle {
|
||||
d: block.d,
|
||||
dmin: block.dmin,
|
||||
scales: block.scales,
|
||||
flags: 0,
|
||||
qs: block.qs,
|
||||
d: block.d,
|
||||
dmin: block.dmin,
|
||||
scales: block.scales,
|
||||
flags: 0,
|
||||
n_pairs: 0,
|
||||
qs,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -183,27 +199,26 @@ pub fn encode(block: &BlockQ4K) -> BlockQ4KRle {
|
||||
/// Panics if the decoded RLE run lengths overrun the 128-byte output buffer.
/// In debug builds it also panics when they sum to fewer than 128 bytes.
|
||||
fn decode_qs(block: &BlockQ4KRle) -> [u8; QK_K / 2] {
|
||||
if !block.is_rle() {
|
||||
return block.qs;
|
||||
// First QK_K/2 bytes of qs hold the raw packed nibbles.
|
||||
block.qs[..QK_K / 2].try_into().unwrap()
|
||||
} else {
|
||||
let n = block.rle_len();
|
||||
let mut raw = [0u8; QK_K / 2];
|
||||
let mut pos = 0usize;
|
||||
for i in 0..n {
|
||||
let val = block.qs[2 * i];
|
||||
let count = block.qs[2 * i + 1] as usize;
|
||||
raw[pos..pos + count].fill(val);
|
||||
pos += count;
|
||||
}
|
||||
debug_assert_eq!(
|
||||
pos,
|
||||
QK_K / 2,
|
||||
"RLE run lengths sum to {pos}, expected {}",
|
||||
QK_K / 2
|
||||
);
|
||||
raw
|
||||
}
|
||||
|
||||
let n = block.rle_len();
|
||||
let mut raw = [0u8; QK_K / 2];
|
||||
let mut pos = 0usize;
|
||||
|
||||
for i in 0..n {
|
||||
let val = block.qs[2 * i];
|
||||
let count = block.qs[2 * i + 1] as usize;
|
||||
raw[pos..pos + count].fill(val);
|
||||
pos += count;
|
||||
}
|
||||
|
||||
debug_assert_eq!(
|
||||
pos,
|
||||
QK_K / 2,
|
||||
"RLE run lengths sum to {pos}, expected {}",
|
||||
QK_K / 2
|
||||
);
|
||||
raw
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -570,17 +585,18 @@ mod tests {
|
||||
// =========================================================================
|
||||
|
||||
#[test]
|
||||
fn block_q4k_rle_size_is_146_bytes() {
|
||||
// d(2) + dmin(2) + scales(12) + flags(1) + qs(128) = 145 raw bytes,
|
||||
// rounded up to 146 by the 2-byte alignment imposed by the u16 fields.
|
||||
assert_eq!(core::mem::size_of::<BlockQ4KRle>(), 146);
|
||||
fn block_q4k_rle_size_is_274_bytes() {
|
||||
// d(2) + dmin(2) + scales(12) + flags(1) + n_pairs(1) + qs(256) = 274 bytes.
|
||||
// No padding needed: struct is already 2-byte aligned and 274 is even.
|
||||
assert_eq!(core::mem::size_of::<BlockQ4KRle>(), 274);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn block_q4k_rle_is_two_bytes_larger_than_block_q4k() {
|
||||
fn block_q4k_rle_is_130_bytes_larger_than_block_q4k() {
|
||||
// BlockQ4K = 144 bytes, BlockQ4KRle = 274 bytes, delta = 130.
|
||||
assert_eq!(
|
||||
core::mem::size_of::<BlockQ4KRle>(),
|
||||
core::mem::size_of::<BlockQ4K>() + 2,
|
||||
core::mem::size_of::<BlockQ4K>() + 130,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -591,7 +607,7 @@ mod tests {
|
||||
#[test]
|
||||
fn is_rle_false_when_flag_clear() {
|
||||
let b = BlockQ4KRle {
|
||||
d: 0, dmin: 0, scales: [0; K_SCALE_SIZE], flags: 0, qs: [0; QK_K / 2],
|
||||
d: 0, dmin: 0, scales: [0; K_SCALE_SIZE], flags: 0, n_pairs: 0, qs: [0; QK_K],
|
||||
};
|
||||
assert!(!b.is_rle());
|
||||
}
|
||||
@@ -599,7 +615,7 @@ mod tests {
|
||||
#[test]
|
||||
fn rle_len_zero_when_flag_clear() {
|
||||
let b = BlockQ4KRle {
|
||||
d: 0, dmin: 0, scales: [0; K_SCALE_SIZE], flags: 0, qs: [0; QK_K / 2],
|
||||
d: 0, dmin: 0, scales: [0; K_SCALE_SIZE], flags: 0, n_pairs: 0, qs: [0; QK_K],
|
||||
};
|
||||
assert_eq!(b.rle_len(), 0);
|
||||
}
|
||||
@@ -608,19 +624,21 @@ mod tests {
|
||||
fn is_rle_true_when_flag_set() {
|
||||
let b = BlockQ4KRle {
|
||||
d: 0, dmin: 0, scales: [0; K_SCALE_SIZE],
|
||||
flags: IS_RLE | (5u8 << 1),
|
||||
qs: [0; QK_K / 2],
|
||||
flags: IS_RLE,
|
||||
n_pairs: 5,
|
||||
qs: [0; QK_K],
|
||||
};
|
||||
assert!(b.is_rle());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rle_len_reports_pair_count_from_flags() {
|
||||
for n in [0usize, 1, 7, 31, 63] {
|
||||
fn rle_len_reports_pair_count_from_n_pairs() {
|
||||
for n in [0usize, 1, 7, 31, 63, 128] {
|
||||
let b = BlockQ4KRle {
|
||||
d: 0, dmin: 0, scales: [0; K_SCALE_SIZE],
|
||||
flags: IS_RLE | ((n as u8) << 1),
|
||||
qs: [0; QK_K / 2],
|
||||
flags: if n > 0 { IS_RLE } else { 0 },
|
||||
n_pairs: n as u8,
|
||||
qs: [0; QK_K],
|
||||
};
|
||||
assert_eq!(b.rle_len(), n, "expected rle_len {n}");
|
||||
}
|
||||
@@ -632,61 +650,64 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn encode_uniform_qs_uses_rle() {
|
||||
// 128 identical bytes → 1 pair → 2 bytes < 128 raw.
|
||||
// 128 identical bytes → 1 pair → 2 bytes stored in qs.
|
||||
let src = make_block(1.0, 0.0, 1, 0, 0x77);
|
||||
let rle = encode(&src);
|
||||
let rle = encode(&src, 0.0);
|
||||
assert!(rle.is_rle(), "uniform qs should trigger RLE mode");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encode_uniform_qs_rle_len_is_one() {
|
||||
let src = make_block(1.0, 0.0, 1, 0, 0x55);
|
||||
let rle = encode(&src);
|
||||
let rle = encode(&src, 0.0);
|
||||
assert_eq!(rle.rle_len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encode_uniform_qs_rle_entry_is_correct() {
|
||||
let src = make_block(1.0, 0.0, 1, 0, 0xAB);
|
||||
let rle = encode(&src);
|
||||
let rle = encode(&src, 0.0);
|
||||
assert_eq!(rle.qs[0], 0xAB, "RLE value byte should equal the repeated byte");
|
||||
assert_eq!(rle.qs[1], 128, "RLE run length should be 128 bytes");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encode_alternating_bytes_stays_raw() {
|
||||
// 128 single-byte runs → 128 pairs → 256 bytes ≥ 128 raw → raw mode.
|
||||
// Alternating 0xAA / 0x55 → 128 singleton pairs, coverage = 0%.
|
||||
// At threshold 0.01 the 0% coverage fails → raw mode.
|
||||
let mut qs = [0u8; QK_K / 2];
|
||||
for (i, b) in qs.iter_mut().enumerate() {
|
||||
*b = if i % 2 == 0 { 0xAA } else { 0x55 };
|
||||
}
|
||||
let src = make_block_with_qs(1.0, 0.0, 1, 0, qs);
|
||||
let rle = encode(&src);
|
||||
assert!(!rle.is_rle(), "alternating bytes cannot be compressed → raw mode");
|
||||
let rle = encode(&src, 0.01);
|
||||
assert!(!rle.is_rle(), "0% coverage fails any threshold > 0 → raw mode");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encode_raw_mode_copies_qs_verbatim() {
|
||||
// Three-byte cycle of distinct values → 128 runs of 1 byte each,
|
||||
// coverage = 0%. At threshold 0.01 the 0% coverage fails → raw mode.
|
||||
let mut qs = [0u8; QK_K / 2];
|
||||
for (i, b) in qs.iter_mut().enumerate() {
|
||||
// Three-byte cycle of distinct values → 128 runs of 1 byte each.
|
||||
*b = match i % 3 { 0 => 0x11, 1 => 0x22, _ => 0x33 };
|
||||
}
|
||||
let src = make_block_with_qs(1.0, 0.0, 1, 0, qs);
|
||||
let rle = encode(&src);
|
||||
let rle = encode(&src, 0.01);
|
||||
assert!(!rle.is_rle());
|
||||
assert_eq!(rle.qs, qs, "raw mode must preserve qs bytes unchanged");
|
||||
// Raw mode copies the 128-byte qs into the first half of the 256-byte field.
|
||||
assert_eq!(&rle.qs[..QK_K / 2], &qs[..], "raw mode must preserve qs bytes in first 128 bytes");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encode_two_runs_uses_rle_and_stores_correct_pairs() {
|
||||
// Two distinct runs: 64 bytes of 0x11 followed by 64 bytes of 0x22.
|
||||
// → 2 pairs = 4 bytes < 128 bytes raw.
|
||||
// → 2 pairs = 4 bytes.
|
||||
let mut qs = [0u8; QK_K / 2];
|
||||
qs[..64].fill(0x11);
|
||||
qs[64..].fill(0x22);
|
||||
let src = make_block_with_qs(1.0, 0.0, 1, 0, qs);
|
||||
let rle = encode(&src);
|
||||
let rle = encode(&src, 0.0);
|
||||
assert!(rle.is_rle());
|
||||
assert_eq!(rle.rle_len(), 2);
|
||||
assert_eq!(rle.qs[0], 0x11, "first pair: value");
|
||||
@@ -698,7 +719,7 @@ mod tests {
|
||||
#[test]
|
||||
fn encode_63_pairs_uses_rle() {
|
||||
// Build 62 runs of 2 bytes each (124 bytes) + 1 run of 4 bytes = 128 bytes.
|
||||
// 63 pairs × 2 = 126 bytes < 128 → RLE should be chosen.
|
||||
// 63 pairs × 2 = 126 bytes; 63 ≤ 128 → RLE should be chosen.
|
||||
let mut qs = [0u8; QK_K / 2];
|
||||
let mut pos = 0usize;
|
||||
for run in 0..62usize {
|
||||
@@ -712,15 +733,15 @@ mod tests {
|
||||
qs[pos..].fill(0xFE);
|
||||
|
||||
let src = make_block_with_qs(1.0, 0.0, 1, 0, qs);
|
||||
let rle = encode(&src);
|
||||
let rle = encode(&src, 0.0);
|
||||
assert!(rle.is_rle(), "63 pairs should use RLE");
|
||||
assert_eq!(rle.rle_len(), 63);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encode_64_pairs_stays_raw() {
|
||||
// 64 runs of 2 bytes each = 128 bytes total.
|
||||
// 64 pairs × 2 = 128 bytes, which is NOT strictly less than 128 → raw.
|
||||
fn encode_64_pairs_uses_rle_at_zero_threshold() {
|
||||
// 64 runs of 2 bytes each = 128 bytes total, coverage = 100%.
|
||||
// pairs (64) ≤ 128 AND 100% ≥ 0.0 → RLE mode.
|
||||
let mut qs = [0u8; QK_K / 2];
|
||||
let mut pos = 0usize;
|
||||
for run in 0..64usize {
|
||||
@@ -730,14 +751,93 @@ mod tests {
|
||||
pos += 2;
|
||||
}
|
||||
let src = make_block_with_qs(1.0, 0.0, 1, 0, qs);
|
||||
let rle = encode(&src);
|
||||
assert!(!rle.is_rle(), "64 pairs offers no saving → raw mode");
|
||||
let rle = encode(&src, 0.0);
|
||||
assert!(rle.is_rle(), "64 pairs, 100% coverage, threshold 0.0 → RLE");
|
||||
assert_eq!(rle.rle_len(), 64);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encode_128_pairs_uses_rle_at_zero_threshold() {
|
||||
// 128 distinct consecutive bytes = 128 singleton runs = 128 pairs.
|
||||
// With the old cap (≤ 63 pairs, because pairs * 2 had to be < 128), this was always raw.
|
||||
// With new cap (128 pairs), threshold 0.0 accepts it.
|
||||
// Coverage = 0 % (all singletons) → threshold > 0.0 rejects it.
|
||||
let mut qs = [0u8; QK_K / 2];
|
||||
for (i, b) in qs.iter_mut().enumerate() {
|
||||
*b = i as u8; // 0x00, 0x01, ..., 0x7F — all distinct, all singletons
|
||||
}
|
||||
let src = make_block_with_qs(1.0, 0.0, 1, 0, qs);
|
||||
|
||||
assert!(
|
||||
encode(&src, 0.0).is_rle(),
|
||||
"128 pairs ≤ 128 limit AND 0% ≥ 0.0 → RLE at zero threshold"
|
||||
);
|
||||
assert_eq!(encode(&src, 0.0).rle_len(), 128);
|
||||
|
||||
assert!(
|
||||
!encode(&src, 0.01).is_rle(),
|
||||
"0% coverage fails any threshold > 0"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encode_coverage_threshold_rejects_low_coverage_block() {
|
||||
// Construct: 63 singletons + 1 run of 65 bytes = 64 pairs.
|
||||
// coverage = 65/128 ≈ 50.8%.
|
||||
// threshold 0.50 accepts it; threshold 0.60 rejects it.
|
||||
let mut qs = [0u8; QK_K / 2];
|
||||
qs[0] = 0x01;
|
||||
for i in 1..63usize {
|
||||
// Distinct odd bytes, none equal to 0x01 or adjacent values.
|
||||
qs[i] = (i as u8).wrapping_mul(2).wrapping_add(5);
|
||||
}
|
||||
qs[63..].fill(0xAB); // 65-byte run; qs[62] = 62*2+5 = 129 → wraps to 0x81 ≠ 0xAB ✓
|
||||
|
||||
let src = make_block_with_qs(1.0, 0.0, 1, 0, qs);
|
||||
assert!(
|
||||
encode(&src, 0.50).is_rle(),
|
||||
"50.8% coverage should meet 50% threshold"
|
||||
);
|
||||
assert!(
|
||||
!encode(&src, 0.60).is_rle(),
|
||||
"50.8% coverage should fail 60% threshold"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encode_coverage_zero_threshold_always_uses_rle_when_pairs_fit() {
|
||||
// Any block whose runs produce ≤ 128 pairs uses RLE at threshold 0.0,
|
||||
// regardless of how many singletons it contains.
|
||||
// Use the 63-pair block from encode_63_pairs_uses_rle.
|
||||
let mut qs = [0u8; QK_K / 2];
|
||||
let mut pos = 0usize;
|
||||
for run in 0..62usize {
|
||||
let v = (run as u8).wrapping_mul(3).wrapping_add(1);
|
||||
qs[pos] = v;
|
||||
qs[pos + 1] = v;
|
||||
pos += 2;
|
||||
}
|
||||
qs[pos..].fill(0xFE);
|
||||
let src = make_block_with_qs(1.0, 0.0, 1, 0, qs);
|
||||
assert!(encode(&src, 0.0).is_rle());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encode_coverage_one_threshold_requires_total_coverage() {
|
||||
// A block with even one singleton byte fails the 100% threshold.
|
||||
// Build: 1 singleton + 1 run of 127 bytes = 2 pairs, coverage = 127/128 ≈ 99.2%.
|
||||
let mut qs = [0u8; QK_K / 2];
|
||||
qs[0] = 0x01; // singleton (value distinct from rest)
|
||||
qs[1..].fill(0x02); // 127-byte run
|
||||
let src = make_block_with_qs(1.0, 0.0, 1, 0, qs);
|
||||
assert!(!encode(&src, 1.0).is_rle(), "99.2% coverage should fail 100% threshold");
|
||||
assert!(encode(&src, 0.99).is_rle(), "99.2% coverage should meet 99% threshold");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encode_preserves_d_dmin_scales() {
|
||||
let src = make_block(2.0, 0.5, 3, 2, 0x00);
|
||||
let rle = encode(&src);
|
||||
let rle = encode(&src, 0.0);
|
||||
assert_eq!(rle.d, src.d);
|
||||
assert_eq!(rle.dmin, src.dmin);
|
||||
assert_eq!(rle.scales, src.scales);
|
||||
@@ -749,25 +849,24 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn decode_qs_raw_mode_returns_qs_unchanged() {
|
||||
// Build a raw BlockQ4KRle (flags = 0) with a non-trivial qs pattern.
|
||||
let mut qs = [0u8; QK_K / 2];
|
||||
for (i, b) in qs.iter_mut().enumerate() { *b = i as u8; }
|
||||
let mut qs = [0u8; QK_K];
|
||||
for (i, b) in qs[..QK_K / 2].iter_mut().enumerate() { *b = i as u8; }
|
||||
let rle = BlockQ4KRle {
|
||||
d: 0, dmin: 0, scales: [0; K_SCALE_SIZE], flags: 0, qs,
|
||||
d: 0, dmin: 0, scales: [0; K_SCALE_SIZE], flags: 0, n_pairs: 0, qs,
|
||||
};
|
||||
assert_eq!(decode_qs(&rle), qs);
|
||||
let expected: [u8; QK_K / 2] = qs[..QK_K / 2].try_into().unwrap();
|
||||
assert_eq!(decode_qs(&rle), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_qs_rle_expands_two_pair_stream() {
|
||||
// Hand-craft an RLE block: [0xAA × 64, 0xBB × 64].
|
||||
let mut qs = [0u8; QK_K / 2];
|
||||
let mut qs = [0u8; QK_K];
|
||||
qs[0] = 0xAA; qs[1] = 64;
|
||||
qs[2] = 0xBB; qs[3] = 64;
|
||||
let rle = BlockQ4KRle {
|
||||
d: 0, dmin: 0, scales: [0; K_SCALE_SIZE],
|
||||
flags: IS_RLE | (2u8 << 1),
|
||||
qs,
|
||||
flags: IS_RLE, n_pairs: 2, qs,
|
||||
};
|
||||
let expanded = decode_qs(&rle);
|
||||
assert!(expanded[..64].iter().all(|&b| b == 0xAA), "first 64 bytes must be 0xAA");
|
||||
@@ -776,12 +875,11 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn decode_qs_rle_single_run_covers_all() {
|
||||
let mut qs = [0u8; QK_K / 2];
|
||||
let mut qs = [0u8; QK_K];
|
||||
qs[0] = 0xCD; qs[1] = 128; // one run of 128 bytes
|
||||
let rle = BlockQ4KRle {
|
||||
d: 0, dmin: 0, scales: [0; K_SCALE_SIZE],
|
||||
flags: IS_RLE | (1u8 << 1),
|
||||
qs,
|
||||
flags: IS_RLE, n_pairs: 1, qs,
|
||||
};
|
||||
let expanded = decode_qs(&rle);
|
||||
assert!(expanded.iter().all(|&b| b == 0xCD));
|
||||
@@ -794,7 +892,7 @@ mod tests {
|
||||
#[test]
|
||||
fn dequant_rle_zero_d_all_outputs_zero() {
|
||||
let src = make_block(0.0, 0.0, 1, 0, 0x77);
|
||||
let rle = encode(&src);
|
||||
let rle = encode(&src, 0.0);
|
||||
let mut out = [0.0f32; QK_K];
|
||||
dequantize_block_q4k_rle(&rle, &mut out);
|
||||
assert_all_close(&out, 0.0, 0.0);
|
||||
@@ -805,7 +903,7 @@ mod tests {
|
||||
// qs_byte = 0x11 → both nibbles = 1; scale = 1, d = 1.0, min = 0.
|
||||
// expected: 1.0 * 1 * 1 - 0.0 = 1.0
|
||||
let src = make_block(1.0, 0.0, 1, 0, 0x11);
|
||||
let rle = encode(&src);
|
||||
let rle = encode(&src, 0.0);
|
||||
let mut out = [0.0f32; QK_K];
|
||||
dequantize_block_q4k_rle(&rle, &mut out);
|
||||
assert_all_close(&out, 1.0, 1e-5);
|
||||
@@ -816,7 +914,7 @@ mod tests {
|
||||
// nibble = 0, scale = 1, d = 1.0, min = 2, dmin = 1.0
|
||||
// expected: 1.0 * 1 * 0 - 1.0 * 2 = -2.0
|
||||
let src = make_block(1.0, 1.0, 1, 2, 0x00);
|
||||
let rle = encode(&src);
|
||||
let rle = encode(&src, 0.0);
|
||||
let mut out = [0.0f32; QK_K];
|
||||
dequantize_block_q4k_rle(&rle, &mut out);
|
||||
assert_all_close(&out, -2.0, 1e-5);
|
||||
@@ -827,7 +925,7 @@ mod tests {
|
||||
// qs_byte = 0xFF → both nibbles = 15; scale = 1, d = 1.0, min = 0.
|
||||
// expected: 1.0 * 1 * 15 - 0.0 = 15.0
|
||||
let src = make_block(1.0, 0.0, 1, 0, 0xFF);
|
||||
let rle = encode(&src);
|
||||
let rle = encode(&src, 0.0);
|
||||
let mut out = [0.0f32; QK_K];
|
||||
dequantize_block_q4k_rle(&rle, &mut out);
|
||||
assert_all_close(&out, 15.0, 1e-5);
|
||||
@@ -836,7 +934,7 @@ mod tests {
|
||||
#[test]
|
||||
fn dequant_rle_output_count_is_qk_k() {
|
||||
let src = make_block(1.0, 0.0, 1, 0, 0x00);
|
||||
let rle = encode(&src);
|
||||
let rle = encode(&src, 0.0);
|
||||
let mut out = [0.0f32; QK_K];
|
||||
dequantize_block_q4k_rle(&rle, &mut out);
|
||||
assert_eq!(out.len(), QK_K);
|
||||
@@ -848,7 +946,7 @@ mod tests {
|
||||
// expected: 2.0 * 4 * 3 - 0.0 = 24.0
|
||||
// qs_byte = 0x33 → both nibbles = 3
|
||||
let src = make_block(2.0, 0.0, 4, 0, 0x33);
|
||||
let rle = encode(&src);
|
||||
let rle = encode(&src, 0.0);
|
||||
let mut out = [0.0f32; QK_K];
|
||||
dequantize_block_q4k_rle(&rle, &mut out);
|
||||
assert_all_close(&out, 24.0, 1e-4);
|
||||
@@ -862,7 +960,7 @@ mod tests {
|
||||
fn roundtrip_rle_mode_matches_original() {
|
||||
// Uniform qs → RLE mode selected.
|
||||
let src = make_block(2.0, 0.5, 3, 1, 0x37);
|
||||
let rle = encode(&src);
|
||||
let rle = encode(&src, 0.0);
|
||||
assert!(rle.is_rle());
|
||||
|
||||
let mut got = [0.0f32; QK_K];
|
||||
@@ -874,13 +972,14 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn roundtrip_raw_mode_matches_original() {
|
||||
// Alternating bytes → raw mode selected; output must still be correct.
|
||||
// Alternating bytes → 128 singleton pairs, coverage = 0%.
|
||||
// Use threshold 0.01 to force raw mode (0% < 0.01).
|
||||
let mut qs = [0u8; QK_K / 2];
|
||||
for (i, b) in qs.iter_mut().enumerate() {
|
||||
*b = if i % 2 == 0 { 0x13 } else { 0x24 };
|
||||
}
|
||||
let src = make_block_with_qs(1.5, 0.25, 2, 1, qs);
|
||||
let rle = encode(&src);
|
||||
let rle = encode(&src, 0.01);
|
||||
assert!(!rle.is_rle());
|
||||
|
||||
let mut got = [0.0f32; QK_K];
|
||||
@@ -896,7 +995,7 @@ mod tests {
|
||||
qs[..64].fill(0x59);
|
||||
qs[64..].fill(0x8C);
|
||||
let src = make_block_with_qs(3.0, 1.0, 5, 2, qs);
|
||||
let rle = encode(&src);
|
||||
let rle = encode(&src, 0.0);
|
||||
assert!(rle.is_rle());
|
||||
|
||||
let mut got = [0.0f32; QK_K];
|
||||
@@ -915,7 +1014,7 @@ mod tests {
|
||||
qs[30..31].fill(0x33);
|
||||
qs[31..].fill(0x44);
|
||||
let src = make_block_with_qs(1.0, 0.5, 7, 3, qs);
|
||||
let rle = encode(&src);
|
||||
let rle = encode(&src, 0.0);
|
||||
assert!(rle.is_rle(), "4-run block should compress");
|
||||
assert_eq!(rle.rle_len(), 4);
|
||||
|
||||
@@ -929,7 +1028,7 @@ mod tests {
|
||||
#[test]
|
||||
fn roundtrip_zero_qs_matches_original() {
|
||||
let src = make_block(1.0, 0.5, 2, 1, 0x00);
|
||||
let rle = encode(&src);
|
||||
let rle = encode(&src, 0.0);
|
||||
let mut got = [0.0f32; QK_K];
|
||||
let mut expected = [0.0f32; QK_K];
|
||||
dequantize_block_q4k_rle(&rle, &mut got);
|
||||
@@ -942,7 +1041,7 @@ mod tests {
|
||||
// qs_byte = 0x37: low nibble = 7 (sub-block 0 path), high nibble = 3
|
||||
// (sub-block 1 path). Verify both halves are dequantised correctly.
|
||||
let src = make_block(1.0, 0.0, 1, 0, 0x37);
|
||||
let rle = encode(&src);
|
||||
let rle = encode(&src, 0.0);
|
||||
let mut got = [0.0f32; QK_K];
|
||||
let mut expected = [0.0f32; QK_K];
|
||||
dequantize_block_q4k_rle(&rle, &mut got);
|
||||
@@ -954,6 +1053,26 @@ mod tests {
|
||||
assert_close(got[32], 3.0, 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn roundtrip_128_singleton_pairs_matches_original() {
|
||||
// All-distinct bytes → 128 pairs, 0% coverage.
|
||||
// encode at threshold 0.0 → RLE; dequantize must match baseline.
|
||||
let mut qs = [0u8; QK_K / 2];
|
||||
for (i, b) in qs.iter_mut().enumerate() {
|
||||
*b = (i as u8).wrapping_mul(3).wrapping_add(7);
|
||||
}
|
||||
let src = make_block_with_qs(1.0, 0.0, 1, 0, qs);
|
||||
let rle = encode(&src, 0.0);
|
||||
assert!(rle.is_rle());
|
||||
assert_eq!(rle.rle_len(), 128);
|
||||
|
||||
let mut got = [0.0f32; QK_K];
|
||||
let mut expected = [0.0f32; QK_K];
|
||||
dequantize_block_q4k_rle(&rle, &mut got);
|
||||
dequantize_block_q4k(&src, &mut expected);
|
||||
assert_slices_close(&got, &expected, 1e-5);
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// matmul_q4k_rle_fp16
|
||||
// =========================================================================
|
||||
@@ -964,7 +1083,7 @@ mod tests {
|
||||
// B: 256×1, all fp16 1.0
|
||||
// C = dot([1.0; 256], [1.0; 256]) = 256.0
|
||||
let src = make_block(1.0, 0.0, 1, 0, 0x11);
|
||||
let a = vec![encode(&src)];
|
||||
let a = vec![encode(&src, 0.0)];
|
||||
let b = fp16_uniform(QK_K, 1, 1.0);
|
||||
let c = matmul_q4k_rle_fp16(&a, &b, 1, QK_K, 1);
|
||||
assert_eq!(c.len(), 1);
|
||||
@@ -974,7 +1093,7 @@ mod tests {
|
||||
#[test]
|
||||
fn matmul_rle_2x256_times_256x3_all_ones() {
|
||||
let src = make_block(1.0, 0.0, 1, 0, 0x11);
|
||||
let a = vec![encode(&src), encode(&src)];
|
||||
let a = vec![encode(&src, 0.0), encode(&src, 0.0)];
|
||||
let b = fp16_uniform(QK_K, 3, 1.0);
|
||||
let c = matmul_q4k_rle_fp16(&a, &b, 2, QK_K, 3);
|
||||
assert_eq!(c.len(), 6);
|
||||
@@ -984,7 +1103,7 @@ mod tests {
|
||||
#[test]
|
||||
fn matmul_rle_zero_a_gives_zero_c() {
|
||||
let src = make_block(0.0, 0.0, 1, 0, 0xFF);
|
||||
let a = vec![encode(&src)];
|
||||
let a = vec![encode(&src, 0.0)];
|
||||
let b = fp16_uniform(QK_K, 4, 1.0);
|
||||
let c = matmul_q4k_rle_fp16(&a, &b, 1, QK_K, 4);
|
||||
assert_all_close(&c, 0.0, 0.0);
|
||||
@@ -993,7 +1112,7 @@ mod tests {
|
||||
#[test]
|
||||
fn matmul_rle_zero_b_gives_zero_c() {
|
||||
let src = make_block(1.0, 0.0, 1, 0, 0x11);
|
||||
let a = vec![encode(&src)];
|
||||
let a = vec![encode(&src, 0.0)];
|
||||
let b = fp16_uniform(QK_K, 2, 0.0);
|
||||
let c = matmul_q4k_rle_fp16(&a, &b, 1, QK_K, 2);
|
||||
assert_all_close(&c, 0.0, 0.0);
|
||||
@@ -1004,7 +1123,7 @@ mod tests {
|
||||
// A: 1×512, two blocks, all nibble-1 weights; B: 512×1, all 1.0.
|
||||
// Expected: 512.0
|
||||
let src = make_block(1.0, 0.0, 1, 0, 0x11);
|
||||
let a = vec![encode(&src), encode(&src)];
|
||||
let a = vec![encode(&src, 0.0), encode(&src, 0.0)];
|
||||
let b = fp16_uniform(2 * QK_K, 1, 1.0);
|
||||
let c = matmul_q4k_rle_fp16(&a, &b, 1, 2 * QK_K, 1);
|
||||
assert_eq!(c.len(), 1);
|
||||
@@ -1015,7 +1134,7 @@ mod tests {
|
||||
fn matmul_rle_output_shape_m_times_n() {
|
||||
// A: 3×512 (6 blocks), B: 512×4 → C: 3×4 = 12 elements.
|
||||
let src = make_block(1.0, 0.0, 1, 0, 0x00);
|
||||
let a: Vec<BlockQ4KRle> = (0..6).map(|_| encode(&src)).collect();
|
||||
let a: Vec<BlockQ4KRle> = (0..6).map(|_| encode(&src, 0.0)).collect();
|
||||
let b = fp16_uniform(2 * QK_K, 4, 0.0);
|
||||
let c = matmul_q4k_rle_fp16(&a, &b, 3, 2 * QK_K, 4);
|
||||
assert_eq!(c.len(), 12);
|
||||
@@ -1025,7 +1144,7 @@ mod tests {
|
||||
fn matmul_rle_scalar_b_scales_output() {
|
||||
// Multiplying B by a scalar should scale C by the same factor.
|
||||
let src = make_block(1.0, 0.0, 1, 0, 0x22); // nibble 2 → weight 2.0
|
||||
let a = vec![encode(&src)];
|
||||
let a = vec![encode(&src, 0.0)];
|
||||
let b1 = fp16_uniform(QK_K, 1, 1.0);
|
||||
let b2 = fp16_uniform(QK_K, 1, 3.0);
|
||||
let c1 = matmul_q4k_rle_fp16(&a, &b1, 1, QK_K, 1);
|
||||
@@ -1046,7 +1165,8 @@ mod tests {
|
||||
let src_raw = make_block_with_qs(1.5, 0.25, 2, 1, qs_raw);
|
||||
|
||||
let a_orig: Vec<BlockQ4K> = vec![src_rle, src_raw];
|
||||
let a_rle: Vec<BlockQ4KRle> = a_orig.iter().map(encode).collect();
|
||||
// Use threshold 0.01 so the alternating block (0% coverage) stays raw.
|
||||
let a_rle: Vec<BlockQ4KRle> = a_orig.iter().map(|b| encode(b, 0.01)).collect();
|
||||
|
||||
// A: 1×512, B: 512×2
|
||||
let b = fp16_uniform(2 * QK_K, 2, 1.0);
|
||||
@@ -1060,7 +1180,7 @@ mod tests {
|
||||
// A: 2×512 (4 blocks), B: 512×3, all weights 1 in A, all 1.0 in B.
|
||||
// Each row dot product = 512.0; C should be all 512.0.
|
||||
let src = make_block(1.0, 0.0, 1, 0, 0x11);
|
||||
let a: Vec<BlockQ4KRle> = (0..4).map(|_| encode(&src)).collect();
|
||||
let a: Vec<BlockQ4KRle> = (0..4).map(|_| encode(&src, 0.0)).collect();
|
||||
let b = fp16_uniform(2 * QK_K, 3, 1.0);
|
||||
let c = matmul_q4k_rle_fp16(&a, &b, 2, 2 * QK_K, 3);
|
||||
assert_eq!(c.len(), 6);
|
||||
@@ -1074,7 +1194,7 @@ mod tests {
|
||||
#[test]
|
||||
fn matmul_rle_panics_when_k_not_multiple_of_qkk() {
|
||||
let src = make_block(1.0, 0.0, 1, 0, 0x00);
|
||||
let a = vec![encode(&src)];
|
||||
let a = vec![encode(&src, 0.0)];
|
||||
let b = vec![0u16; 512];
|
||||
let result = std::panic::catch_unwind(move || {
|
||||
matmul_q4k_rle_fp16(&a, &b, 1, 512, 2);
|
||||
@@ -1086,7 +1206,7 @@ mod tests {
|
||||
fn matmul_rle_panics_on_wrong_a_length() {
|
||||
let src = make_block(1.0, 0.0, 1, 0, 0x00);
|
||||
// m=2, k=QK_K requires 2 blocks; only 1 is provided.
|
||||
let a = vec![encode(&src)];
|
||||
let a = vec![encode(&src, 0.0)];
|
||||
let b = fp16_uniform(QK_K, 1, 1.0);
|
||||
let result = std::panic::catch_unwind(move || {
|
||||
matmul_q4k_rle_fp16(&a, &b, 2, QK_K, 1);
|
||||
@@ -1097,7 +1217,7 @@ mod tests {
|
||||
#[test]
|
||||
fn matmul_rle_panics_on_wrong_b_length() {
|
||||
let src = make_block(1.0, 0.0, 1, 0, 0x00);
|
||||
let a = vec![encode(&src)];
|
||||
let a = vec![encode(&src, 0.0)];
|
||||
// B is too short for k=QK_K, n=3.
|
||||
let b = vec![0u16; 10];
|
||||
let result = std::panic::catch_unwind(move || {
|
||||
|
||||
Reference in New Issue
Block a user