optimize rle a bit

This commit is contained in:
2026-04-12 15:56:39 -07:00
parent 4ca68c7f94
commit e80cd09415

View File

@@ -253,13 +253,113 @@ pub fn dequantize_block_q4k_rle(block: &BlockQ4KRle, out: &mut [f32; QK_K]) {
// Matrix multiplication C = A × B
// ---------------------------------------------------------------------------
/// Fold one RLE-encoded block's contribution into `c_row` without ever
/// materialising the dequantised weights.
///
/// Within a `(value, count)` run the dequantised weight is constant for as
/// long as the 32-byte sub-block group (and therefore the scale/min pair)
/// does not change, so the dot-product contribution of a whole segment
/// collapses from `2 * run_len` multiplies per output column to just two:
///
/// ```text
/// naive:     Σ_{l} ( dq_lo * B[ki_lo+l, j] + dq_hi * B[ki_hi+l, j] )
///
/// factored:  dq_lo * Σ_{l} B[ki_lo+l, j] + dq_hi * Σ_{l} B[ki_hi+l, j]
/// ```
///
/// Runs that straddle a 32-byte group boundary are cut at the boundary and
/// each piece is processed on its own, keeping the scale/min constant per
/// segment.
///
/// `sum_lo` / `sum_hi` are caller-owned scratch buffers (length `≥ n`),
/// reused across calls so no allocation happens here.
fn accumulate_rle_block(
    block: &BlockQ4KRle,
    b: &[u16],
    ki_base: usize, // first B-row index for this block (= b_idx * QK_K)
    n: usize,
    c_row: &mut [f32],
    sum_lo: &mut [f32],
    sum_hi: &mut [f32],
) {
    let d = fp16_to_f32(block.d);
    let dmin = fp16_to_f32(block.dmin);
    // Running byte offset into the 128-byte qs payload.
    let mut cursor = 0usize;
    for pair in 0..block.rle_len() {
        let packed = block.qs[2 * pair];
        let run_len = block.qs[2 * pair + 1] as usize;
        let nib_lo = (packed & 0x0F) as f32;
        let nib_hi = (packed >> 4) as f32;
        let run_end = cursor + run_len;
        let mut seg_start = cursor;
        while seg_start < run_end {
            // Cut the run at the next 32-byte group boundary so the
            // sub-block scale/min pair is constant across the segment.
            let group = seg_start / 32; // 0..4
            let seg_end = run_end.min((group + 1) * 32);
            let seg_len = seg_end - seg_start;
            let in_group = seg_start % 32; // byte offset within this group
            // Per-group scale/min → one constant dequantised value per
            // nibble level.
            let (sc_lo, mn_lo) = get_scale_min(group * 2, &block.scales);
            let (sc_hi, mn_hi) = get_scale_min(group * 2 + 1, &block.scales);
            let dq_lo = d * sc_lo as f32 * nib_lo - dmin * mn_lo as f32;
            let dq_hi = d * sc_hi as f32 * nib_hi - dmin * mn_hi as f32;
            // Byte position → dequantised-output index (0..QK_K):
            // low nibbles land at group*64 + in_group, high nibbles 32 later.
            let out_lo = group * 64 + in_group;
            let out_hi = out_lo + 32;
            // Accumulate the B-row sums over the segment; B is walked
            // stride-1 within each row, which stays cache-friendly.
            let (acc_lo, acc_hi) = (&mut sum_lo[..n], &mut sum_hi[..n]);
            acc_lo.fill(0.0);
            acc_hi.fill(0.0);
            for l in 0..seg_len {
                let row_lo = (ki_base + out_lo + l) * n;
                let row_hi = (ki_base + out_hi + l) * n;
                for j in 0..n {
                    acc_lo[j] += fp16_to_f32(b[row_lo + j]);
                    acc_hi[j] += fp16_to_f32(b[row_hi + j]);
                }
            }
            // One multiply per nibble level per output column, instead of
            // one per weight element.
            for j in 0..n {
                c_row[j] += dq_lo * acc_lo[j] + dq_hi * acc_hi[j];
            }
            seg_start = seg_end;
        }
        cursor = run_end;
    }
}
/// Multiply a Q4_K_RLE matrix **A** by an FP16 matrix **B**, producing an f32
/// matrix **C**.
///
/// For blocks in **RLE mode** (`IS_RLE = 1`) the intermediate decompressed row
/// is eliminated entirely. [`accumulate_rle_block`] works directly over the
/// `(value, count)` pairs: within each run the dequantised weight is constant
/// across all elements in the run, so each output column `j` requires only
/// **2 multiplies per group-segment** rather than 2 per weight element:
///
/// ```text
/// c[i, j] += dq_lo * Σ B[ki_lo, j] + dq_hi * Σ B[ki_hi, j]
///            ───────────────────────────────────────────
///              summed over seg_len consecutive positions
/// ```
///
/// For a single-run block (all bytes identical) this reduces the multiply
/// count from `2 * QK_K = 512` to `2 * 4 = 8` per output column (4 groups,
/// 2 nibble levels each), while B is still read exactly once.
///
/// For blocks in **raw mode** (`IS_RLE = 0`) the block is dequantised into a
/// scratch buffer and its contribution is accumulated via a saxpy loop
/// (weight-outer, column-inner), which accesses B in row-major order.
///
/// # Arguments
///
/// # Arguments /// # Arguments
/// ///
@@ -308,26 +408,39 @@ pub fn matmul_q4k_rle_fp16(
b.len() b.len()
); );
let mut c = vec![0.0f32; m * n]; let mut c = vec![0.0f32; m * n];
let mut a_row = vec![0.0f32; k];
// Scratch for raw-mode block dequantisation.
let mut block_buf = [0.0f32; QK_K]; let mut block_buf = [0.0f32; QK_K];
// Scratch for RLE-mode B-column sums; allocated once and reused per segment.
let mut sum_lo = vec![0.0f32; n];
let mut sum_hi = vec![0.0f32; n];
for i in 0..m { for i in 0..m {
// Dequantise row i of A into a_row (f32). let c_row = &mut c[i * n..(i + 1) * n];
for b_idx in 0..blocks_per_row {
let block = &a[i * blocks_per_row + b_idx];
dequantize_block_q4k_rle(block, &mut block_buf);
let start = b_idx * QK_K;
a_row[start..start + QK_K].copy_from_slice(&block_buf);
}
// Dot-product with each column of B. for b_idx in 0..blocks_per_row {
for j in 0..n { let block = &a[i * blocks_per_row + b_idx];
let mut sum = 0.0f32; let ki_base = b_idx * QK_K;
for ki in 0..k {
sum += a_row[ki] * fp16_to_f32(b[ki * n + j]); if block.is_rle() {
// RLE path: accumulate directly from runs, no decompression.
accumulate_rle_block(
block, b, ki_base, n, c_row,
&mut sum_lo, &mut sum_hi,
);
} else {
// Raw path: dequantise once, then saxpy into c_row.
// Outer loop over weights (l) keeps B access stride-1 per row.
dequantize_block_q4k_rle(block, &mut block_buf);
for l in 0..QK_K {
let w = block_buf[l];
let b_off = (ki_base + l) * n;
for j in 0..n {
c_row[j] += w * fp16_to_f32(b[b_off + j]);
}
}
} }
c[i * n + j] = sum;
} }
} }