Efficient Matrix Multiplication: A Hybrid Approach Combining Strassen’s and Traditional Methods

Matrix multiplication is a fundamental operation in numerous fields, from computer graphics and machine learning to scientific computing. However, traditional matrix multiplication can become inefficient for large matrices due to its cubic time complexity.

In this post, we’ll discuss a hybrid approach that combines the traditional algorithm and Strassen’s method for matrix multiplication.

Prerequisites:
Familiarity with standard Matrix Multiplication and Strassen’s Algorithm for Matrix Multiplication is assumed; both topics are covered in detail in their own dedicated posts.

What is Matrix Multiplication?

Matrix multiplication involves computing a new matrix C by multiplying two matrices A and B. Given matrices A of size n × m and B of size m × p, the resulting matrix C will be of size n × p. The entry at position C[i][j] is computed as:

C[i][j] = \sum_{k=0}^{m-1} A[i][k] \times B[k][j]

While this method is straightforward, its time complexity is O(n^3) for square n × n matrices, which becomes impractical as the matrices grow large.
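For example, for 2 × 2 matrices A = [[1, 2], [3, 4]] and B = [[5, 6], [7, 8]], the formula gives:

C[0][0] = 1 \times 5 + 2 \times 7 = 19
C[0][1] = 1 \times 6 + 2 \times 8 = 22
C[1][0] = 3 \times 5 + 4 \times 7 = 43
C[1][1] = 3 \times 6 + 4 \times 8 = 50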

Strassen’s Algorithm

Strassen’s Algorithm, developed by Volker Strassen in 1969, was revolutionary because it reduced the time complexity of matrix multiplication from O(n^3) to approximately O(n^2.81) (more precisely, O(n^{\log_2 7})). The improvement comes from breaking large matrices into smaller submatrices, multiplying them recursively using only seven submatrix products instead of eight, and then recombining the results. Although it offers speed improvements, it also introduces implementation complexity and requires additional memory for managing submatrices.

Hybrid Approach

The hybrid approach combines the strengths of both methods: the simplicity of the naive algorithm and the asymptotic efficiency of Strassen’s algorithm. For small matrices, Strassen’s method does not provide significant performance gains, because the overhead of splitting and recombining submatrices outweighs the savings from fewer multiplications. The hybrid approach therefore uses the traditional method for matrices at or below a threshold size (typically selected experimentally, as sketched below) and switches to Strassen’s recursion above it. This balances speed and memory usage, making it practical for a wide range of matrix sizes.
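Since the best threshold depends on the hardware, the compiler, and the element type, one reasonable way to pick it is a small timing experiment. The sketch below is illustrative only: naiveMultiply is the cubic routine from the implementation later in this post, and pureStrassenMultiply is a hypothetical variant of that implementation’s strassenMultiply whose base case is lowered to a trivial size so the recursion is always exercised. The size at which the Strassen column starts winning suggests a sensible THRESHOLD.

#include <chrono>
#include <iostream>
#include <vector>
using namespace std;
using namespace chrono;

// Implemented later in this post; pureStrassenMultiply is a hypothetical variant of
// strassenMultiply whose recursion bottoms out at 2 x 2 blocks instead of THRESHOLD.
vector<vector<int>> naiveMultiply(const vector<vector<int>>& A, const vector<vector<int>>& B);
vector<vector<int>> pureStrassenMultiply(const vector<vector<int>>& A, const vector<vector<int>>& B);

void tuneThreshold() {
    for (int n : {16, 32, 64, 128, 256}) {   // candidate sizes (powers of two)
        vector<vector<int>> A(n, vector<int>(n, 1)), B(n, vector<int>(n, 1));

        auto t0 = high_resolution_clock::now();
        naiveMultiply(A, B);
        auto t1 = high_resolution_clock::now();
        pureStrassenMultiply(A, B);
        auto t2 = high_resolution_clock::now();

        cout << "n = " << n
             << "  naive: "    << duration_cast<microseconds>(t1 - t0).count() << " us"
             << "  strassen: " << duration_cast<microseconds>(t2 - t1).count() << " us\n";
    }
}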

Pseudocode for Hybrid Approach
Function HybridMultiply(A, B, THRESHOLD)
    n = size of matrix A

    // Base case: if the matrix is small, use naive multiplication
    If n <= THRESHOLD:
        return NaiveMultiply(A, B)

    // Split matrices A and B into submatrices
    Split A into A11, A12, A21, A22
    Split B into B11, B12, B21, B22

    // Calculate M1 to M7 using Strassen's method
    M1 = HybridMultiply(A11 + A22, B11 + B22, THRESHOLD)
    M2 = HybridMultiply(A21 + A22, B11, THRESHOLD)
    M3 = HybridMultiply(A11, B12 - B22, THRESHOLD)
    M4 = HybridMultiply(A22, B21 - B11, THRESHOLD)
    M5 = HybridMultiply(A11 + A12, B22, THRESHOLD)
    M6 = HybridMultiply(A21 - A11, B11 + B12, THRESHOLD)
    M7 = HybridMultiply(A12 - A22, B21 + B22, THRESHOLD)

    // Compute C submatrices using M1 to M7
    C11 = M1 + M4 - M5 + M7
    C12 = M3 + M5
    C21 = M2 + M4
    C22 = M1 - M2 + M3 + M6

    // Combine C11, C12, C21, C22 into result matrix C
    JoinMatrices(C11, C12, C21, C22, C)

    return C
Code Implementation

Below is a C++ implementation of the hybrid approach, combining Strassen’s algorithm with the traditional matrix multiplication method.

#include <iostream>
#include <vector>
#include <chrono>
using namespace std;
using namespace chrono;

// Threshold for switching to the naive method (selected experimentally)
const int THRESHOLD = 64;

// Function to print the matrix
void printMatrix(const vector<vector<int>>& matrix) {
    int n = (int) matrix.size();
    int m = (int) matrix[0].size();

    // Set the limits for rows and columns to be displayed
    int maxRows = min(n, 4); // Display up to 4 rows
    int maxCols = min(m, 4); // Display up to 4 columns

    for (int i = 0; i < maxRows; i++) {
        for (int j = 0; j < maxCols; j++) {
            cout << matrix[i][j] << " ";
        }

        // If there are more columns, indicate there are more
        if (m > maxCols) {
            cout << "...";
        }

        cout << endl;
    }

    // If there are more rows, indicate that there are more
    if (n > maxRows) {
        cout << "..." << endl;
    }
}

// Naive matrix multiplication for small matrices
vector<vector<int>> naiveMultiply(const vector<vector<int>>& A, const vector<vector<int>>& B) {
    int n = (int) A.size();
    vector<vector<int>> C(n, vector<int>(n, 0));
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            for (int k = 0; k < n; k++) {
                C[i][j] += A[i][k] * B[k][j];
            }
        }
    }
    
    return C;
}

// Function to add two matrices
vector<vector<int>> add(const vector<vector<int>>& A, const vector<vector<int>>& B) {
    int n = (int) A.size();
    vector<vector<int>> result(n, vector<int>(n));
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            result[i][j] = A[i][j] + B[i][j];
        }
    }
    
    return result;
}

// Function to subtract two matrices
vector<vector<int>> subtract(const vector<vector<int>>& A, const vector<vector<int>>& B) {
    int n = (int) A.size();
    vector<vector<int>> result(n, vector<int>(n));
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            result[i][j] = A[i][j] - B[i][j];
        }
    }
    
    return result;
}

// Function to split a matrix into 4 submatrices
void splitMatrix(const vector<vector<int>>& original, vector<vector<int>>& A11, vector<vector<int>>& A12,
                 vector<vector<int>>& A21, vector<vector<int>>& A22) {
    int newSize = (int) (original.size() / 2);
    for (int i = 0; i < newSize; i++) {
        for (int j = 0; j < newSize; j++) {
            A11[i][j] = original[i][j];
            A12[i][j] = original[i][j + newSize];
            A21[i][j] = original[i + newSize][j];
            A22[i][j] = original[i + newSize][j + newSize];
        }
    }
}

// Function to join 4 submatrices into a matrix
void joinMatrices(const vector<vector<int>>& A11, const vector<vector<int>>& A12,
                  const vector<vector<int>>& A21, const vector<vector<int>>& A22,
                  vector<vector<int>>& result) {
    int newSize = (int) A11.size();
    for (int i = 0; i < newSize; i++) {
        for (int j = 0; j < newSize; j++) {
            result[i][j] = A11[i][j];
            result[i][j + newSize] = A12[i][j];
            result[i + newSize][j] = A21[i][j];
            result[i + newSize][j + newSize] = A22[i][j];
        }
    }
}

// Hybrid Strassen's matrix multiplication function
vector<vector<int>> strassenMultiply(const vector<vector<int>>& A, const vector<vector<int>>& B) {
    int n = (int) A.size();
    
    // Base case: Use the naive method for small matrices
    if (n <= THRESHOLD) {
        return naiveMultiply(A, B);
    }

    // Splitting matrices into submatrices
    int newSize = n / 2;
    vector<vector<int>> A11(newSize, vector<int>(newSize));
    vector<vector<int>> A12(newSize, vector<int>(newSize));
    vector<vector<int>> A21(newSize, vector<int>(newSize));
    vector<vector<int>> A22(newSize, vector<int>(newSize));
    vector<vector<int>> B11(newSize, vector<int>(newSize));
    vector<vector<int>> B12(newSize, vector<int>(newSize));
    vector<vector<int>> B21(newSize, vector<int>(newSize));
    vector<vector<int>> B22(newSize, vector<int>(newSize));

    splitMatrix(A, A11, A12, A21, A22);
    splitMatrix(B, B11, B12, B21, B22);

    // Computing the 7 products using Strassen's method
    vector<vector<int>> M1 = strassenMultiply(add(A11, A22), add(B11, B22));
    vector<vector<int>> M2 = strassenMultiply(add(A21, A22), B11);
    vector<vector<int>> M3 = strassenMultiply(A11, subtract(B12, B22));
    vector<vector<int>> M4 = strassenMultiply(A22, subtract(B21, B11));
    vector<vector<int>> M5 = strassenMultiply(add(A11, A12), B22);
    vector<vector<int>> M6 = strassenMultiply(subtract(A21, A11), add(B11, B12));
    vector<vector<int>> M7 = strassenMultiply(subtract(A12, A22), add(B21, B22));

    // Calculating C submatrices
    vector<vector<int>> C11 = add(subtract(add(M1, M4), M5), M7);
    vector<vector<int>> C12 = add(M3, M5);
    vector<vector<int>> C21 = add(M2, M4);
    vector<vector<int>> C22 = add(subtract(add(M1, M3), M2), M6);

    // Joining the 4 submatrices into the result matrix
    vector<vector<int>> C(n, vector<int>(n));
    joinMatrices(C11, C12, C21, C22, C);

    return C;
}

// Function to measure execution time
void measureExecutionTime(const string& name, vector<vector<int>> (*multiplyFunc)(const vector<vector<int>>&, const vector<vector<int>>&), const vector<vector<int>>& A, const vector<vector<int>>& B) {
    auto start = high_resolution_clock::now();
    vector<vector<int>> C = multiplyFunc(A, B);
    auto end = high_resolution_clock::now();
    auto duration = duration_cast<milliseconds>(end - start).count();
    cout << endl << name << " took " << (duration/1000.0) << " seconds.\n";
 
    cout << "Result Matrix C:" << endl;
    printMatrix(C);
 
}

int main() {
    int n = 1024; // Example size, should be a power of 2 for simplicity when using Strassen's method
    vector<vector<int>> A(n, vector<int>(n, 1));
    vector<vector<int>> B(n, vector<int>(n, 1));
     
    cout << "Matrix A:" << endl;
    printMatrix(A);
     
    cout << endl << "Matrix B:" << endl;
    printMatrix(B);
 
    measureExecutionTime("Naive Multiplication", naiveMultiply, A, B);
    measureExecutionTime("Hybrid Multiplication", strassenMultiply, A, B);

    return 0;
}
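To try this locally, save the listing as, say, hybrid.cpp (the file name is an arbitrary choice) and compile with optimizations enabled, since timings from an unoptimized build are not representative:

g++ -std=c++17 -O2 hybrid.cpp -o hybrid
./hybrid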

Output

Matrix A:
1 1 1 1 ...
1 1 1 1 ...
1 1 1 1 ...
1 1 1 1 ...
...

Matrix B:
1 1 1 1 ...
1 1 1 1 ...
1 1 1 1 ...
1 1 1 1 ...
...

Naive Multiplication took 7.442 seconds.
Result Matrix C:
1024 1024 1024 1024 ...
1024 1024 1024 1024 ...
1024 1024 1024 1024 ...
1024 1024 1024 1024 ...
...

Hybrid Multiplication took 0.808 seconds.
Result Matrix C:
1024 1024 1024 1024 ...
1024 1024 1024 1024 ...
1024 1024 1024 1024 ...
1024 1024 1024 1024 ...
...
Explanation of the Code

The hybrid matrix multiplication code combines Strassen’s Algorithm and the naive (traditional) approach to optimize both performance and memory usage. The code defines a threshold (THRESHOLD), which determines when to switch between the two methods. If the matrix size is less than or equal to this threshold, the code uses the naive method for simplicity and efficiency on smaller matrices. For larger matrices, it uses Strassen’s method.

The Strassen method involves recursively splitting each matrix into four submatrices (A11, A12, A21, A22 for matrix A, and similarly for matrix B) and calculating seven matrix products (M1 to M7) from these submatrices. These products are then combined to form the final result. This reduces the number of recursive submatrix multiplications at each level from 8 (in the straightforward block-recursive multiplication) to 7, which is what yields the improved performance for large matrices.

Time and Space Complexity of Hybrid Approach

Time Complexity
The hybrid approach has a time complexity of approximately O(n^2.81) when using Strassen’s algorithm for large matrices and O(n^3) for smaller matrices handled by the naive method. The threshold determines when to switch between the two, so that Strassen’s splitting overhead is avoided for matrix sizes where it does not offer significant gains.

Thus, for matrices larger than the threshold, the complexity tends toward O(n^2.81), leveraging the recursive efficiency of Strassen’s algorithm.
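The 2.81 exponent follows from the standard recurrence for Strassen’s recursion: each call performs 7 half-size multiplications plus O(n^2) work for the additions, subtractions, and copying:

T(n) = 7\,T(n/2) + O(n^2) \implies T(n) = O(n^{\log_2 7}) \approx O(n^{2.81})

For comparison, the straightforward block-recursive multiplication satisfies T(n) = 8\,T(n/2) + O(n^2), which solves to O(n^3).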

Space Complexity
The hybrid approach incurs additional space overhead due to the recursive splitting in Strassen’s algorithm, resulting in an approximate space complexity of O(n^2). While the naive method also has O(n^2) space complexity, primarily for storing the result matrix, Strassen’s algorithm requires more storage for the submatrices and intermediate products generated during computation.
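A rough way to see the O(n^2) bound for the implementation above: the seven recursive calls run one after another, so at any moment only a single branch of the recursion holds its temporaries, and each level allocates a constant number of submatrices of size (n/2) × (n/2):

S(n) = S(n/2) + O(n^2) \implies S(n) = O(n^2)

The constant factor, however, is noticeably larger than for the naive method, which essentially only needs the result matrix.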

Consequently, while efficient, the hybrid approach uses more memory than the naive method due to its recursive structure and additional storage requirements.

In this post, we discussed how to efficiently perform matrix multiplication using a hybrid approach that combines the simplicity of the traditional method with the speed of Strassen’s algorithm. By dynamically switching between methods based on matrix size, we achieve an optimized balance of speed and memory usage.
