Initial commit

This commit is contained in:
2026-04-12 14:45:20 -07:00
commit 16d1f37ae5
5 changed files with 834 additions and 0 deletions

1
.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
/target

7
Cargo.lock generated Normal file
View File

@@ -0,0 +1,7 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 4
[[package]]
name = "matrix-testing"
version = "0.1.0"

6
Cargo.toml Normal file
View File

@@ -0,0 +1,6 @@
[package]
name = "matrix-testing"
version = "0.1.0"
edition = "2024"
[dependencies]

738
src/lib.rs Normal file
View File

@@ -0,0 +1,738 @@
//! Naive Q4_K_M × FP16 matrix multiplication.
//!
//! Q4_K_M (called `block_q4_K` in GGML) is a 4-bit K-quant format with 256
//! elements per super-block. Each super-block carries:
//!
//! - `d` (fp16) super-block scale for the sub-block scales
//! - `dmin` (fp16) super-block scale for the sub-block mins
//! - `scales` [12 u8] 8 pairs of (scale, min), each 6-bit, packed together
//! - `qs` [128 u8] 256 values at 4 bits each (two values per byte)
//!
//! The dequantised value for nibble `q` in sub-block `s` is:
//!
//! ```text
//! d * scales[s] * q - dmin * mins[s]
//! ```
//!
//! This library deliberately uses the simplest possible O(M·N·K) algorithm:
//! dequantise each row of A into f32, convert each element of B from fp16 to
//! f32, accumulate dot-products. No SIMD, no tiling, no tricks.
// ---------------------------------------------------------------------------
// Constants matching GGML's ggml-common.h
// ---------------------------------------------------------------------------
/// Number of elements in one Q4_K super-block.
pub const QK_K: usize = 256;
/// Number of bytes used to store the 8 (scale, min) pairs.
/// (8 pairs × two 6-bit values = 96 bits = 12 bytes.)
pub const K_SCALE_SIZE: usize = 12;
// ---------------------------------------------------------------------------
// Block definition
// ---------------------------------------------------------------------------
/// One Q4_K super-block, binary-compatible with GGML's `block_q4_K`.
///
/// Memory layout (in order):
///
/// | Offset | Field    | Size  |
/// |--------|----------|-------|
/// | 0      | `d`      | 2 B   |
/// | 2      | `dmin`   | 2 B   |
/// | 4      | `scales` | 12 B  |
/// | 16     | `qs`     | 128 B |
///
/// Total: 144 bytes. `#[repr(C)]` with these field types yields no padding,
/// so the in-memory layout matches the on-disk GGUF block byte-for-byte.
#[repr(C)]
#[derive(Clone, Copy, Debug)]
pub struct BlockQ4K {
    /// Super-block scale for the quantised sub-block scales (fp16 bits).
    pub d: u16,
    /// Super-block scale for the quantised sub-block mins (fp16 bits).
    pub dmin: u16,
    /// Packed 6-bit sub-block scales and mins.
    /// 8 scales + 8 mins (6 bits each) encoded into 12 bytes.
    pub scales: [u8; K_SCALE_SIZE],
    /// 4-bit quantised weights: two weights per byte, 128 bytes for 256 values.
    pub qs: [u8; QK_K / 2],
}
// ---------------------------------------------------------------------------
// FP16 → f32 conversion (no external dependencies)
// ---------------------------------------------------------------------------
/// Convert an IEEE 754 half-precision float stored as raw `u16` bits to `f32`.
///
/// Handles ±zero, normal numbers, infinity, and NaN. Subnormal fp16 values
/// (exponent field = 0, mantissa ≠ 0) are flushed to signed zero — they are
/// vanishingly small and irrelevant for LLM weights.
#[inline]
pub fn fp16_to_f32(bits: u16) -> f32 {
    let wide = bits as u32;
    let sign = (wide & 0x8000) << 16; // relocate sign bit to f32 bit 31
    let result = match bits & 0x7C00 {
        // Exponent field all zeros: ±zero or subnormal → signed zero.
        0x0000 => sign,
        // Exponent field all ones: infinity (mantissa 0) or NaN (mantissa ≠ 0).
        // Widening the mantissa by 13 bits keeps NaN payloads non-zero.
        0x7C00 => sign | 0x7F80_0000 | ((wide & 0x03FF) << 13),
        // Normal number: shift exponent+mantissa into f32 position and rebias
        // the exponent from fp16 (bias 15) to f32 (bias 127).
        // Δbias = 127 − 15 = 112, added at the fp16 exponent position (<< 10).
        _ => sign | (((wide & 0x7FFF) + (112 << 10)) << 13),
    };
    f32::from_bits(result)
}
// ---------------------------------------------------------------------------
// Scale extraction — mirrors GGML's get_scale_min_k4
// ---------------------------------------------------------------------------
/// Extract the 6-bit scale and 6-bit min for sub-block index `j` (0..8).
///
/// GGML packs 8 pairs of 6-bit values into 12 bytes using a two-part scheme:
///
/// **j = 0..3**
/// ```text
/// scale = scales[j]     & 0x3F
/// min   = scales[j + 4] & 0x3F
/// ```
///
/// **j = 4..7**
/// ```text
/// scale = (scales[j+4] & 0x0F) | ((scales[j-4] >> 6) << 4)
/// min   = (scales[j+4] >> 4)   | ((scales[j]   >> 6) << 4)
/// ```
#[inline]
pub(crate) fn get_scale_min(j: usize, scales: &[u8; K_SCALE_SIZE]) -> (u8, u8) {
    debug_assert!(j < 8);
    match j {
        // Sub-blocks 0..3: scale and min each occupy the low 6 bits of a
        // dedicated byte; the top 2 bits belong to sub-blocks 4..7.
        0..=3 => (scales[j] & 0x3F, scales[j + 4] & 0x3F),
        // Sub-blocks 4..7: the low nibbles live in bytes 8..12 and the top
        // 2 bits of each value are borrowed from bytes 0..8.
        _ => {
            let packed = scales[j + 4];
            let sc = (packed & 0x0F) | ((scales[j - 4] >> 6) << 4);
            let mn = (packed >> 4) | ((scales[j] >> 6) << 4);
            (sc, mn)
        }
    }
}
// ---------------------------------------------------------------------------
// Dequantisation
// ---------------------------------------------------------------------------
/// Dequantise one Q4_K super-block into 256 `f32` values.
///
/// The loop mirrors GGML's `dequantize_row_q4_K`:
///
/// ```text
/// for each group of 64 elements (4 groups total):
///     d1, m1 = scale × d, min × dmin for sub-block (2·group)
///     d2, m2 = scale × d, min × dmin for sub-block (2·group + 1)
///     out[0..32]  = d1 × lower_nibble(qs[0..32]) − m1
///     out[32..64] = d2 × upper_nibble(qs[0..32]) − m2
/// ```
pub fn dequantize_block_q4k(block: &BlockQ4K, out: &mut [f32; QK_K]) {
    let d = fp16_to_f32(block.d);
    let dmin = fp16_to_f32(block.dmin);
    // Four groups of 64 output elements; each group consumes 32 qs bytes
    // and one pair of (scale, min) sub-blocks.
    for group in 0..4 {
        let pair = group * 2;
        let (sc1, mn1) = get_scale_min(pair, &block.scales);
        let (sc2, mn2) = get_scale_min(pair + 1, &block.scales);
        let (d1, m1) = (d * sc1 as f32, dmin * mn1 as f32);
        let (d2, m2) = (d * sc2 as f32, dmin * mn2 as f32);
        let base = group * 64;
        let bytes = &block.qs[group * 32..group * 32 + 32];
        for (l, &byte) in bytes.iter().enumerate() {
            // Lower nibbles fill the first 32 elements of the group,
            // upper nibbles the next 32.
            out[base + l] = d1 * (byte & 0x0F) as f32 - m1;
            out[base + 32 + l] = d2 * (byte >> 4) as f32 - m2;
        }
    }
}
// ---------------------------------------------------------------------------
// Matrix multiplication C = A × B
// ---------------------------------------------------------------------------
/// Multiply a Q4_K_M matrix **A** by an FP16 matrix **B**, producing an f32
/// matrix **C**.
///
/// # Arguments
///
/// * `a` — Row-major slice of [`BlockQ4K`]. Row `i` occupies blocks
///   `a[i * blocks_per_row .. (i+1) * blocks_per_row]`.
/// * `b` — Row-major fp16 matrix stored as raw `u16` bits, shape \[K, N\].
///   Element `(ki, j)` is at index `ki * n + j`.
/// * `m` — Number of rows in A (and C).
/// * `k` — Number of columns in A = number of rows in B.
///   **Must** be a multiple of [`QK_K`] (256).
/// * `n` — Number of columns in B (and C).
///
/// # Returns
///
/// A flat row-major `Vec<f32>` of shape \[M, N\].
///
/// # Panics
///
/// Panics if `k` is not a multiple of `QK_K`, or if the lengths of `a` or `b`
/// do not match the declared dimensions.
pub fn matmul_q4k_fp16(
    a: &[BlockQ4K],
    b: &[u16],
    m: usize,
    k: usize,
    n: usize,
) -> Vec<f32> {
    assert_eq!(k % QK_K, 0,
        "k ({k}) must be a multiple of QK_K ({QK_K})");
    assert_eq!(a.len(), m * (k / QK_K),
        "A block count mismatch: expected {} blocks, got {}", m * (k / QK_K), a.len());
    assert_eq!(b.len(), k * n,
        "B element count mismatch: expected {}, got {}", k * n, b.len());
    let blocks_per_row = k / QK_K;
    let mut c = vec![0.0f32; m * n];
    // Convert B to f32 ONCE up front (k·n·4 bytes of scratch). Decoding fp16
    // inside the dot-product loop would repeat every conversion M times.
    let b_f32: Vec<f32> = b.iter().map(|&bits| fp16_to_f32(bits)).collect();
    // Scratch buffers allocated once and reused across rows.
    let mut a_row = vec![0.0f32; k];
    let mut block_buf = [0.0f32; QK_K];
    for i in 0..m {
        // Step 1: dequantise row i of A into a_row (f32).
        for b_idx in 0..blocks_per_row {
            let block = &a[i * blocks_per_row + b_idx];
            dequantize_block_q4k(block, &mut block_buf);
            let start = b_idx * QK_K;
            a_row[start..start + QK_K].copy_from_slice(&block_buf);
        }
        // Step 2: for each output column j, compute dot(a_row, B[:, j]).
        for j in 0..n {
            let mut sum = 0.0f32;
            for ki in 0..k {
                // B is row-major [K, N]: element (ki, j) → index ki*n+j.
                sum += a_row[ki] * b_f32[ki * n + j];
            }
            c[i * n + j] = sum;
        }
    }
    c
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;

    // -------------------------------------------------------------------------
    // Test helpers
    // -------------------------------------------------------------------------

    /// Convert a normal finite f32 (including ±inf and NaN) to its fp16 bit
    /// pattern. Panics if the value is a non-representable normal number
    /// (magnitude too large or too small for fp16's normal range).
    ///
    /// NOTE: -0.0 maps to the +0.0 bit pattern (the `f == 0.0` check matches
    /// both signs) — fine for these tests, which never use negative zero.
    fn f32_to_fp16_bits(f: f32) -> u16 {
        if f == 0.0 { return 0x0000; }
        if f == f32::INFINITY { return 0x7C00; }
        if f == f32::NEG_INFINITY { return 0xFC00; }
        if f.is_nan() { return 0x7E00; }
        let bits = f.to_bits();
        let sign = ((bits >> 31) as u16) << 15;
        let exp = (bits >> 23) & 0xFF;
        let mant = bits & 0x007F_FFFF;
        // Rebias exponent from f32 (bias 127) to fp16 (bias 15).
        let fp16_exp = exp as i32 - 127 + 15;
        assert!(
            fp16_exp > 0 && fp16_exp < 31,
            "f32 value {f} is outside the representable fp16 normal range"
        );
        // Mantissa is truncated (not rounded) from 23 to 10 bits — exact only
        // for values already representable in fp16, which is all we test.
        sign | ((fp16_exp as u16) << 10) | ((mant >> 13) as u16)
    }

    /// Build a [`BlockQ4K`] where all 8 sub-blocks share the same `scale` and
    /// `min` (both **must** be < 16 to keep the encoding simple), and every
    /// byte in `qs` is `qs_byte`.
    ///
    /// With scale, min < 16, the scales array encoding simplifies to:
    /// ```text
    /// scales[0..4]  = scale                                 (used for sub-blocks 0..3)
    /// scales[4..8]  = min                                   (used for sub-blocks 0..3)
    /// scales[8..12] = (scale & 0xF) | ((min & 0xF) << 4)    (sub-blocks 4..7)
    /// ```
    fn make_block(d: f32, dmin: f32, scale: u8, min: u8, qs_byte: u8) -> BlockQ4K {
        assert!(
            scale < 16 && min < 16,
            "make_block: scale ({scale}) and min ({min}) must both be < 16"
        );
        let mut scales = [0u8; K_SCALE_SIZE];
        for j in 0..4 {
            scales[j] = scale;
            scales[j + 4] = min;
        }
        for j in 8..12 {
            scales[j] = (scale & 0x0F) | ((min & 0x0F) << 4);
        }
        BlockQ4K {
            d: f32_to_fp16_bits(d),
            dmin: f32_to_fp16_bits(dmin),
            scales,
            qs: [qs_byte; QK_K / 2],
        }
    }

    /// Assert two f32 values are within `tol` of each other.
    fn assert_close(got: f32, expected: f32, tol: f32) {
        assert!(
            (got - expected).abs() <= tol,
            "got {got}, expected {expected} (tol {tol})"
        );
    }

    /// Assert every element of `got` equals `expected_scalar` within `tol`.
    fn assert_all_close(got: &[f32], expected_scalar: f32, tol: f32) {
        for (i, &g) in got.iter().enumerate() {
            assert!(
                (g - expected_scalar).abs() <= tol,
                "element {i}: got {g}, expected {expected_scalar} (tol {tol})"
            );
        }
    }

    /// Build a flat fp16 matrix (shape K×N) where every element is `value`.
    fn fp16_uniform(k: usize, n: usize, value: f32) -> Vec<u16> {
        vec![f32_to_fp16_bits(value); k * n]
    }

    // =========================================================================
    // fp16_to_f32
    // =========================================================================

    #[test]
    fn fp16_positive_zero() {
        let v = fp16_to_f32(0x0000);
        assert_eq!(v, 0.0f32);
        assert!(v.is_sign_positive());
    }

    #[test]
    fn fp16_negative_zero() {
        let v = fp16_to_f32(0x8000);
        // IEEE 754: -0.0 == +0.0
        assert_eq!(v, 0.0f32);
        assert!(v.is_sign_negative());
    }

    #[test]
    fn fp16_one_and_negative_one() {
        // 0x3C00: exp=15, mant=0 → (1+0) × 2^0 = 1.0
        assert_eq!(fp16_to_f32(0x3C00), 1.0f32);
        // 0xBC00: sign=1, same magnitude
        assert_eq!(fp16_to_f32(0xBC00), -1.0f32);
    }

    #[test]
    fn fp16_powers_of_two() {
        // 0.5 = 2^-1: exp=14, mant=0 → 0x3800
        assert_eq!(fp16_to_f32(0x3800), 0.5f32);
        // 2.0 = 2^1: exp=16, mant=0 → 0x4000
        assert_eq!(fp16_to_f32(0x4000), 2.0f32);
        // 4.0 = 2^2: exp=17, mant=0 → 0x4400
        assert_eq!(fp16_to_f32(0x4400), 4.0f32);
        // 8.0 = 2^3: exp=18, mant=0 → 0x4800
        assert_eq!(fp16_to_f32(0x4800), 8.0f32);
        // 64.0 = 2^6: exp=21, mant=0 → 0x5400
        assert_eq!(fp16_to_f32(0x5400), 64.0f32);
    }

    #[test]
    fn fp16_non_power_of_two() {
        // 3.0 = 1.5 × 2^1: exp=16, mant=512 → 0x4200
        // 0x4200 = bit14=1, bit9=1 → exp=16, mant=512 → (1+0.5)×2=3.0
        assert_eq!(fp16_to_f32(0x4200), 3.0f32);
        // 10.0 = 1.25 × 2^3: exp=18, mant=256 → 0x4900
        // 0x4900 = bits14-10=10010=18, bits9-0=01_0000_0000=256
        //        → (1+256/1024)×8 = 1.25×8 = 10.0
        assert_eq!(fp16_to_f32(0x4900), 10.0f32);
    }

    #[test]
    fn fp16_positive_infinity() {
        let v = fp16_to_f32(0x7C00);
        assert!(v.is_infinite());
        assert!(v.is_sign_positive());
    }

    #[test]
    fn fp16_negative_infinity() {
        let v = fp16_to_f32(0xFC00);
        assert!(v.is_infinite());
        assert!(v.is_sign_negative());
    }

    #[test]
    fn fp16_nan() {
        // 0x7E00: exp all-ones, non-zero mantissa → NaN
        assert!(fp16_to_f32(0x7E00).is_nan());
        // 0xFE00: signed NaN
        assert!(fp16_to_f32(0xFE00).is_nan());
    }

    #[test]
    fn fp16_subnormals_become_zero() {
        // Subnormals (exp=0, mant≠0) are too small for LLM weights; we
        // return signed zero rather than performing the full decode.
        assert_eq!(fp16_to_f32(0x0001), 0.0f32); // smallest positive subnormal
        assert_eq!(fp16_to_f32(0x03FF), 0.0f32); // largest positive subnormal
        assert_eq!(fp16_to_f32(0x8200), 0.0f32); // a negative subnormal
    }

    #[test]
    fn fp16_roundtrip_via_helper() {
        // Verify that f32_to_fp16_bits + fp16_to_f32 recovers the original
        // value exactly for numbers that are precisely representable in fp16.
        let values: &[f32] = &[0.5, 1.0, 2.0, 3.0, 4.0, 8.0, 10.0, 64.0, -1.0, -3.0];
        for &v in values {
            let bits = f32_to_fp16_bits(v);
            let recovered = fp16_to_f32(bits);
            assert_eq!(recovered, v, "round-trip failed for {v}");
        }
    }

    // =========================================================================
    // get_scale_min
    // =========================================================================

    #[test]
    fn scale_min_j_lt_4_basic() {
        // For j < 4: scale = scales[j] & 63, min = scales[j+4] & 63.
        let mut scales = [0u8; K_SCALE_SIZE];
        scales[0] = 42;
        scales[4] = 17;
        let (sc, mn) = get_scale_min(0, &scales);
        assert_eq!(sc, 42);
        assert_eq!(mn, 17);
    }

    #[test]
    fn scale_min_j_lt_4_all_four_indices() {
        let mut scales = [0u8; K_SCALE_SIZE];
        scales[0] = 10; scales[4] = 20;
        scales[1] = 11; scales[5] = 21;
        scales[2] = 12; scales[6] = 22;
        scales[3] = 13; scales[7] = 23;
        for (j, (exp_sc, exp_mn)) in (0..4).zip([(10u8, 20u8), (11, 21), (12, 22), (13, 23)]) {
            let (sc, mn) = get_scale_min(j, &scales);
            assert_eq!(sc, exp_sc, "scale mismatch at j={j}");
            assert_eq!(mn, exp_mn, "min mismatch at j={j}");
        }
    }

    #[test]
    fn scale_min_j_lt_4_masks_high_bits() {
        // Bits 6–7 of scales[j] / scales[j+4] contribute to the j+4 sub-block
        // encoding, not to the j sub-block. The & 63 must strip them.
        let mut scales = [0u8; K_SCALE_SIZE];
        scales[2] = 0b11_101010; // raw=0xEA=234, lower 6 bits = 0b101010 = 42
        scales[6] = 0b10_010101; // raw=0x95=149, lower 6 bits = 0b010101 = 21
        let (sc, mn) = get_scale_min(2, &scales);
        assert_eq!(sc, 42);
        assert_eq!(mn, 21);
    }

    #[test]
    fn scale_min_j_gte_4_small_values() {
        // j = 4, scale = 5 (<16), min = 3 (<16) — upper bits are zero.
        // scales[8] = (5 & 0xF) | ((3 & 0xF) << 4) = 0x05 | 0x30 = 0x35
        let mut scales = [0u8; K_SCALE_SIZE];
        scales[8] = 0x35;
        let (sc, mn) = get_scale_min(4, &scales);
        assert_eq!(sc, 5);
        assert_eq!(mn, 3);
    }

    #[test]
    fn scale_min_j_gte_4_needs_upper_bits() {
        // j = 4, scale = 31 (0b011111), min = 25 (0b011001).
        // Both need 5 bits, so the top bits must come from scales[0] and scales[4].
        //
        // Encoding:
        //   scales[8] = (31 & 0xF) | ((25 & 0xF) << 4) = 0x0F | 0x90 = 0x9F
        //   scales[0] |= ((31 >> 4) & 3) << 6 → bits 6-7 = 1 → 0x40
        //   scales[4] |= ((25 >> 4) & 3) << 6 → bits 6-7 = 1 → 0x40
        //
        // Decoding at j = 4:
        //   scale = (0x9F & 0x0F) | ((0x40 >> 6) << 4) = 15 | 16 = 31
        //   min   = (0x9F >> 4)   | ((0x40 >> 6) << 4) = 9  | 16 = 25
        let mut scales = [0u8; K_SCALE_SIZE];
        scales[0] = 0x40;
        scales[4] = 0x40;
        scales[8] = 0x9F;
        let (sc, mn) = get_scale_min(4, &scales);
        assert_eq!(sc, 31, "scale mismatch");
        assert_eq!(mn, 25, "min mismatch");
    }

    #[test]
    fn scale_min_j_7_last_index() {
        // j = 7: scale = (scales[11] & 0x0F) | ((scales[3] >> 6) << 4)
        //        min   = (scales[11] >> 4)   | ((scales[7] >> 6) << 4)
        // Choose scale = 20 (0b010100), min = 10 (0b001010).
        //   scales[11] = (20 & 0xF) | ((10 & 0xF) << 4) = 0x04 | 0xA0 = 0xA4
        //   scales[3] |= ((20 >> 4) & 3) << 6 → bit 6 = 1 → 0x40
        //   scales[7]: (10 >> 4) & 3 = 0, no change needed
        let mut scales = [0u8; K_SCALE_SIZE];
        scales[3] = 0x40;
        scales[11] = 0xA4;
        let (sc, mn) = get_scale_min(7, &scales);
        assert_eq!(sc, 20);
        assert_eq!(mn, 10);
    }

    // =========================================================================
    // dequantize_block_q4k
    // =========================================================================

    #[test]
    fn dequant_zero_d_all_outputs_zero() {
        // d = 0.0 → every product is 0; dmin = 0 so the subtracted min is also 0.
        let block = make_block(0.0, 0.0, 5, 3, 0xFF);
        let mut out = [f32::NAN; QK_K];
        dequantize_block_q4k(&block, &mut out);
        assert_all_close(&out, 0.0, 0.0);
    }

    #[test]
    fn dequant_uniform_nibble_one_scale_one() {
        // d=1.0, dmin=0.0, scale=1, min=0, all nibbles=1
        // formula: 1.0 × 1 × 1 − 0.0 × 0 = 1.0
        let block = make_block(1.0, 0.0, 1, 0, 0x11);
        let mut out = [0.0f32; QK_K];
        dequantize_block_q4k(&block, &mut out);
        assert_all_close(&out, 1.0, 0.0);
    }

    #[test]
    fn dequant_max_nibble_with_larger_scale() {
        // d=2.0, dmin=0.0, scale=3, min=0, all nibbles=15
        // formula: 2.0 × 3 × 15 − 0 = 90.0
        let block = make_block(2.0, 0.0, 3, 0, 0xFF);
        let mut out = [0.0f32; QK_K];
        dequantize_block_q4k(&block, &mut out);
        assert_all_close(&out, 90.0, 1e-4);
    }

    #[test]
    fn dequant_non_zero_min_subtracts() {
        // d=1.0, dmin=1.0, scale=4, min=3, nibble=5
        // formula: 1.0 × 4 × 5 − 1.0 × 3 = 20 − 3 = 17.0
        let block = make_block(1.0, 1.0, 4, 3, 0x55);
        let mut out = [0.0f32; QK_K];
        dequantize_block_q4k(&block, &mut out);
        assert_all_close(&out, 17.0, 1e-4);
    }

    #[test]
    fn dequant_zero_nibble_with_nonzero_min() {
        // nibble = 0, but the min offset is still subtracted.
        // d=1.0, dmin=1.0, scale=5, min=3, nibble=0
        // formula: 1.0 × 5 × 0 − 1.0 × 3 = −3.0
        let block = make_block(1.0, 1.0, 5, 3, 0x00);
        let mut out = [0.0f32; QK_K];
        dequantize_block_q4k(&block, &mut out);
        assert_all_close(&out, -3.0, 1e-4);
    }

    #[test]
    fn dequant_mixed_nibbles_correct_element_layout() {
        // qs_byte = 0x21: lower nibble = 1, upper nibble = 2.
        // d=1.0, dmin=0.0, scale=1, min=0.
        //
        // Per 64-element group the layout is:
        //   elements [0..32]  ← lower nibbles (nibble=1) → 1.0
        //   elements [32..64] ← upper nibbles (nibble=2) → 2.0
        //
        // This must hold for all four groups (elements 0–255).
        let block = make_block(1.0, 0.0, 1, 0, 0x21);
        let mut out = [0.0f32; QK_K];
        dequantize_block_q4k(&block, &mut out);
        for group in 0..4_usize {
            let base = group * 64;
            for l in 0..32 {
                assert_eq!(
                    out[base + l], 1.0,
                    "group {group} lower element {l} (out[{}])", base + l
                );
                assert_eq!(
                    out[base + 32 + l], 2.0,
                    "group {group} upper element {l} (out[{}])", base + 32 + l
                );
            }
        }
    }

    #[test]
    fn dequant_output_count_is_qk_k() {
        // Sanity-check: exactly QK_K = 256 values are written.
        let block = make_block(1.0, 0.0, 1, 0, 0x33);
        let mut out = [0.0f32; QK_K];
        dequantize_block_q4k(&block, &mut out);
        // All elements should be non-NaN (were actually written).
        assert!(out.iter().all(|v| !v.is_nan()));
    }

    #[test]
    fn block_q4k_size_is_144_bytes() {
        // 2 (d) + 2 (dmin) + 12 (scales) + 128 (qs) = 144 bytes
        assert_eq!(core::mem::size_of::<BlockQ4K>(), 144);
    }

    // =========================================================================
    // matmul_q4k_fp16
    // =========================================================================

    #[test]
    fn matmul_1x256_times_256x1_all_ones() {
        // A: 1×256, all weights = 1.0 (one block)
        // B: 256×1, all values = 1.0
        // C: 1×1, expected = 256.0
        let a = vec![make_block(1.0, 0.0, 1, 0, 0x11)];
        let b = fp16_uniform(256, 1, 1.0);
        let c = matmul_q4k_fp16(&a, &b, 1, 256, 1);
        assert_eq!(c.len(), 1);
        assert_close(c[0], 256.0, 0.1);
    }

    #[test]
    fn matmul_2x256_times_256x3_all_ones() {
        // A: 2×256, all weights = 1.0
        // B: 256×3, all values = 1.0
        // C: 2×3, all = 256.0
        let a = vec![make_block(1.0, 0.0, 1, 0, 0x11); 2];
        let b = fp16_uniform(256, 3, 1.0);
        let c = matmul_q4k_fp16(&a, &b, 2, 256, 3);
        assert_eq!(c.len(), 6);
        assert_all_close(&c, 256.0, 0.1);
    }

    #[test]
    fn matmul_zero_a_gives_zero_c() {
        // d = 0.0 → every weight = 0.0 → every output = 0.0
        let a = vec![make_block(0.0, 0.0, 1, 0, 0xFF); 3];
        let b = fp16_uniform(256, 4, 7.0);
        let c = matmul_q4k_fp16(&a, &b, 3, 256, 4);
        assert_all_close(&c, 0.0, 0.0);
    }

    #[test]
    fn matmul_zero_b_gives_zero_c() {
        let a = vec![make_block(1.0, 0.0, 1, 0, 0x55)];
        let b = fp16_uniform(256, 3, 0.0);
        let c = matmul_q4k_fp16(&a, &b, 1, 256, 3);
        assert_all_close(&c, 0.0, 0.0);
    }

    #[test]
    fn matmul_two_blocks_per_row() {
        // K = 512 = 2 × QK_K; each row has two all-ones blocks.
        // B = all ones → C = 512.0
        let block = make_block(1.0, 0.0, 1, 0, 0x11);
        let a = vec![block; 2]; // 1 row × 2 blocks
        let b = fp16_uniform(512, 1, 1.0);
        let c = matmul_q4k_fp16(&a, &b, 1, 512, 1);
        assert_close(c[0], 512.0, 0.1);
    }

    #[test]
    fn matmul_multiple_rows_multiple_blocks_per_row() {
        // A: 3 rows × 2 blocks each (K = 512), all weights = 1.0
        // B: 512×2, all values = 1.0
        // C: 3×2, all = 512.0
        let block = make_block(1.0, 0.0, 1, 0, 0x11);
        let a = vec![block; 3 * 2];
        let b = fp16_uniform(512, 2, 1.0);
        let c = matmul_q4k_fp16(&a, &b, 3, 512, 2);
        assert_eq!(c.len(), 6);
        assert_all_close(&c, 512.0, 0.1);
    }

    #[test]
    fn matmul_alternating_weights_known_dot_product() {
        // qs_byte = 0x21: lower nibble = 1, upper nibble = 2.
        // d=1.0, dmin=0.0, scale=1, min=0.
        // The dequantised row has the pattern: 32×1.0, 32×2.0 (×4 groups)
        // → 128 ones and 128 twos.
        // With B = all ones: C = 128×1.0 + 128×2.0 = 384.0
        let a = vec![make_block(1.0, 0.0, 1, 0, 0x21)];
        let b = fp16_uniform(256, 1, 1.0);
        let c = matmul_q4k_fp16(&a, &b, 1, 256, 1);
        assert_close(c[0], 384.0, 0.1);
    }

    #[test]
    fn matmul_b_scalar_scales_output() {
        // All A weights = 1.0, all B values = 2.0 → C = 256 × 2.0 = 512.0
        let a = vec![make_block(1.0, 0.0, 1, 0, 0x11)];
        let b = fp16_uniform(256, 1, 2.0);
        let c = matmul_q4k_fp16(&a, &b, 1, 256, 1);
        assert_close(c[0], 512.0, 0.2);
    }

    #[test]
    fn matmul_output_has_correct_shape() {
        let a = vec![make_block(1.0, 0.0, 1, 0, 0x11); 5];
        let b = fp16_uniform(256, 7, 1.0);
        let c = matmul_q4k_fp16(&a, &b, 5, 256, 7);
        assert_eq!(c.len(), 5 * 7);
    }

    #[test]
    #[should_panic(expected = "must be a multiple of QK_K")]
    fn matmul_panics_when_k_not_multiple_of_qkk() {
        let a = vec![make_block(1.0, 0.0, 1, 0, 0x11)];
        let b = fp16_uniform(128, 1, 1.0);
        let _ = matmul_q4k_fp16(&a, &b, 1, 128, 1);
    }

    #[test]
    #[should_panic(expected = "A block count mismatch")]
    fn matmul_panics_on_wrong_a_length() {
        // 1×256 needs 1 block, but we supply 2.
        let a = vec![make_block(1.0, 0.0, 1, 0, 0x11); 2];
        let b = fp16_uniform(256, 1, 1.0);
        let _ = matmul_q4k_fp16(&a, &b, 1, 256, 1);
    }

    #[test]
    #[should_panic(expected = "B element count mismatch")]
    fn matmul_panics_on_wrong_b_length() {
        // 256×1 needs 256 elements, but we supply 128.
        let a = vec![make_block(1.0, 0.0, 1, 0, 0x11)];
        let b = vec![0x3C00u16; 128]; // 128 fp16 ones
        let _ = matmul_q4k_fp16(&a, &b, 1, 256, 1);
    }
}

82
src/main.rs Normal file
View File

@@ -0,0 +1,82 @@
//! Demo binary for the `matrix-testing` library.
//!
//! Constructs a small Q4_K_M × FP16 matrix multiply and prints the result.
use matrix_testing::{matmul_q4k_fp16, BlockQ4K, K_SCALE_SIZE, QK_K};
fn main() {
    // -----------------------------------------------------------------------
    // Build a tiny test case: A is (2 × 256) Q4_K_M, B is (256 × 3) fp16.
    //
    // We construct A so that every dequantised weight is exactly 1.0 and B so
    // that every fp16 value is also 1.0. Then every output element should
    // equal K = 256.
    // -----------------------------------------------------------------------
    const M: usize = 2; // rows of A / rows of C
    const K: usize = 256; // cols of A / rows of B (one Q4_K block wide)
    const N: usize = 3; // cols of B / cols of C
    let fp16_one = 0x3C00u16; // 1.0 in fp16
    let fp16_zero = 0x0000u16; // 0.0 in fp16
    // ---- Build A -----------------------------------------------------------
    // Goal: dequant(q) == 1.0 for every element.
    //
    // formula: d * scale * q - dmin * min = 1.0
    //
    // Choosing d=1.0, dmin=0.0, scale=1, min=0, nibble=1:
    //   1.0 * 1 * 1 - 0.0 * 0 = 1.0 ✓
    //
    // Scale encoding for values < 16 (upper bits are zero):
    //   scales[0..4]  = scale = 1
    //   scales[4..8]  = min = 0
    //   scales[8..12] = (scale & 0xF) | ((min & 0xF) << 4) = 0x01
    let mut scales = [0u8; K_SCALE_SIZE];
    for j in 0..4 {
        scales[j] = 1; // scale = 1 for sub-blocks 0..3
        scales[j + 4] = 0; // min = 0 for sub-blocks 0..3
    }
    for s in scales.iter_mut().skip(8) {
        *s = 0x01; // scale=1, min=0 for sub-blocks 4..7
    }
    let block_template = BlockQ4K {
        d: fp16_one,
        dmin: fp16_zero,
        scales,
        qs: [0x11u8; QK_K / 2], // nibble=1 in both halves of every byte
    };
    let a_blocks: Vec<BlockQ4K> = vec![block_template; M * (K / QK_K)];
    // ---- Build B -----------------------------------------------------------
    let b_fp16: Vec<u16> = vec![fp16_one; K * N];
    // ---- Run the multiply --------------------------------------------------
    let c = matmul_q4k_fp16(&a_blocks, &b_fp16, M, K, N);
    // ---- Print results -----------------------------------------------------
    println!("Output matrix C ({M} x {N}):");
    for i in 0..M {
        print!("  row {i}: ");
        for j in 0..N {
            print!("{:.1} ", c[i * N + j]);
        }
        println!();
    }
    // Self-check: every output must equal K (= 256.0) to within fp rounding.
    let expected = K as f32;
    let all_ok = c.iter().all(|&v| (v - expected).abs() < 0.1);
    if all_ok {
        println!("All outputs == {expected:.1}");
    } else {
        eprintln!("FAIL: unexpected output values");
        std::process::exit(1);
    }
    println!();
    println!("Note: this is the naive O(M·N·K) implementation.");
    println!("It is intentionally simple — no SIMD, no tiling, no tricks.");
}