make loops match
This commit is contained in:
27
src/lib.rs
27
src/lib.rs
@@ -211,27 +211,28 @@ pub fn matmul_q4k_fp16(
|
||||
let blocks_per_row = k / QK_K;
|
||||
let mut c = vec![0.0f32; m * n];
|
||||
|
||||
// Scratch buffers allocated once and reused across rows.
|
||||
let mut a_row = vec![0.0f32; k];
|
||||
// Scratch buffer for one dequantised block, reused across iterations.
|
||||
let mut block_buf = [0.0f32; QK_K];
|
||||
|
||||
for i in 0..m {
|
||||
// Step 1: dequantise row i of A into a_row (f32).
|
||||
let c_row = &mut c[i * n..(i + 1) * n];
|
||||
|
||||
// Dequantise one block at a time and saxpy its weights directly into
|
||||
// c_row. The inner loop order (weight-outer, column-inner) keeps each
|
||||
// B row in a contiguous stride-1 access, which is more cache-friendly
|
||||
// than the alternative (column-outer, weight-inner) that jumps by N
|
||||
// between consecutive B reads.
|
||||
for b_idx in 0..blocks_per_row {
|
||||
let block = &a[i * blocks_per_row + b_idx];
|
||||
let ki_base = b_idx * QK_K;
|
||||
dequantize_block_q4k(block, &mut block_buf);
|
||||
let start = b_idx * QK_K;
|
||||
a_row[start..start + QK_K].copy_from_slice(&block_buf);
|
||||
}
|
||||
|
||||
// Step 2: for each output column j, compute dot(a_row, B[:, j]).
|
||||
for l in 0..QK_K {
|
||||
let w = block_buf[l];
|
||||
let b_off = (ki_base + l) * n;
|
||||
for j in 0..n {
|
||||
let mut sum = 0.0f32;
|
||||
for ki in 0..k {
|
||||
// B is row-major [K, N]: element (ki, j) → index ki*n+j.
|
||||
sum += a_row[ki] * fp16_to_f32(b[ki * n + j]);
|
||||
c_row[j] += w * fp16_to_f32(b[b_off + j]);
|
||||
}
|
||||
}
|
||||
c[i * n + j] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user