Browse Source

Optimize matrix multiplication to be more cache-friendly.

Mnist-1000: 15s -> 7.74s.
main
3gg 7 months ago
parent
commit
041613467a
  1. 20
      src/lib/src/matrix.c

20
src/lib/src/matrix.c

@ -131,21 +131,23 @@ void nnMatrixMul(const nnMatrix* left, const nnMatrix* right, nnMatrix* out) {
assert(out->cols == right->cols);
R* out_value = out->values;
for (int i = 0; i < out->rows * out->cols; ++i) {
*out_value++ = 0;
}
for (int i = 0; i < left->rows; ++i) {
const R* left_row = &left->values[i * left->cols];
const R* p_left_value = &left->values[i * left->cols];
for (int j = 0; j < right->cols; ++j) {
const R* right_col = &right->values[j];
*out_value = 0;
for (int j = 0; j < left->cols; ++j) {
const R left_value = *p_left_value;
const R* right_value = &right->values[j * right->cols];
R* out_value = &out->values[i * out->cols];
// Vector dot product.
for (int k = 0; k < left->cols; ++k) {
*out_value += left_row[k] * right_col[0];
right_col += right->cols; // Next row in the column.
for (int k = 0; k < right->cols; ++k) {
*out_value++ += left_value * *right_value++;
}
out_value++;
p_left_value++;
}
}
}

Loading…
Cancel
Save