c++ 简单实现矩阵类

665 阅读3分钟

由于数值分析课程需要使用C++实现矩阵相关的算法,此处先实现简单的矩阵类便于之后使用。

该矩阵类主要使用其中的get set方法来获取、设置矩阵中的某个位置的值,并以行数和列数作为参数,避免与数组从0开始的概念冲突。

通过重载运算符实现了矩阵的加、减、乘法,并重载了输出,方便观察结果,同时实现了一些简单的错误处理,方便在调试矩阵算法时发现错误。

建议使用嵌套的vector作为矩阵类的二维数组,而不是iMatrix[][]二维数组,否则在对象删除时调用析构函数删除二维数组总会引发内存泄漏问题。

Matrix.h

#include <vector>
#include <iostream>
#include <string>
using namespace std;

// 定义自定义的矩阵类
class Matrix {
private:
    // 矩阵的行和列数
    int row;
    int column;
    // 矩阵本身,二维数组
    vector<vector<double>> iMatrix;
public:
    // 矩阵的构造函数,指定行列数
    Matrix();
    Matrix(int n);
    Matrix(int row, int column);
    ~Matrix();
    // 按行读入形成矩阵
    void init(vector<double> list);
    // 初始化为全0
    void allZero();
    // 获得、设定矩阵某个元素
    double get(int row, int column);
    void set(int row, int column, double num);
    int getRow();
    int getColumn();
    // 复制该矩阵
    Matrix& copy();
    // 返回某行和某列
    vector<double> getRow(int rowNum);
    vector<double> getColumn(int columnNum);
    void setRow(int rowNum, vector<double> rowVector);
    void setColumn(int columnNum, vector<double> columnVector);
    // 求倒置矩阵
    Matrix& getT();
    // 矩阵加法
    Matrix& operator+(Matrix matrix);
    // 矩阵减法
    Matrix& operator-(Matrix matrix);
    // 乘一个常数
    Matrix& operator*(double num);
    // 乘矩阵
    Matrix& operator*(Matrix matrix);
    // 重载输出,易于观察输出
    friend ostream& operator << (ostream& os, Matrix& matrix);
};

Martix.cpp

#include "Matrix.h"

void throwErr(const string &msg) {
    cout << msg << endl;
    throw new exception;
}


Matrix::Matrix(int n) : Matrix(n, n) {

}

Matrix::Matrix(int row, int column) {
    this->row = row;
    this->column = column;

    vector<vector<double>> v(row);
    for (int i = 0; i < row; ++i) {
        v.at(i) = vector<double>(column);
    }
    this->iMatrix = v;

}

void Matrix::init(std::vector<double> list) {
    if (list.size() != this->row * this->column) {
        throwErr("矩阵初始化错误:给出的list长度与矩阵元素个数不同");
    }
    int n = 0;
    for (int i = 1; i <= this->row; ++i) {
        for (int j = 1; j <= this->column; ++j) {
            this->set(i, j, list.at(n));
            n++;
        }
    }
}

double Matrix::get(int row, int column) {
    return this->iMatrix.at(row - 1).at(column - 1);
}

void Matrix::set(int row, int column, double num) {
    this->iMatrix.at(row - 1).at(column - 1) = num;
}

ostream &operator<<(ostream &os, Matrix &matrix) {
    os << "***** 矩阵 " << matrix.row << "x" << matrix.column << " " << "*****" << endl;

    for (int i = 1; i <= matrix.row; ++i) {
        for (int j = 1; j <= matrix.column; ++j) {
            os << matrix.get(i, j) << " ";
        }
        os << endl;
    }
    os << "********************" << endl;
    return os;
}

int Matrix::getRow() {
    return this->row;
}

int Matrix::getColumn() {
    return this->column;
}

Matrix &Matrix::getT() {
    Matrix *matrix = new Matrix(this->column, this->row);
    vector<double> v{};
    for (int i = 1; i <= this->column; ++i) {
        for (int j = 1; j <= this->row; ++j) {
            v.push_back(this->get(j, i));
        }
    }
    matrix->init(v);
    return *matrix;
}

Matrix &Matrix::operator*(Matrix matrix) {
    // 首先检查该矩阵的列数和乘的矩阵行数是否一致
    if (this->column != matrix.getRow()) {
        throwErr("矩阵的列数和乘的矩阵行数不一致");
    }
    Matrix *newMatrix = new Matrix(this->row, matrix.getColumn());
    // 矩阵乘法
    vector<double> v1{};
    vector<double> v2{};
    double res;
    for (int i = 1; i <= this->row; ++i) {
        for (int j = 1; j <= matrix.getColumn(); ++j) {
            v1 = this->getRow(i);
            v2 = matrix.getColumn(j);
            res = 0;
            for (int k = 0; k < v1.size(); ++k) {
                res += v1.at(k) * v2.at(k);
            }
            newMatrix->set(i, j, res);
        }
    }
    return *newMatrix;
}

vector<double> Matrix::getRow(int rowNum) {
    vector<double> v1{};
    for (int i = 1; i <= this->column; ++i) {
        v1.push_back(this->get(rowNum, i));
    }
    return v1;
}

vector<double> Matrix::getColumn(int columnNum) {
    vector<double> v1{};
    for (int i = 1; i <= this->row; ++i) {
        v1.push_back(this->get(i, columnNum));
    }
    return v1;
}

Matrix &Matrix::operator+(Matrix matrix) {
    // 检查两个矩阵的列数和行数
    if (this->row != matrix.getRow() || this->column != matrix.getColumn()) {
        throwErr("两个相加的矩阵行列数不相同");
    }
    double val;
    Matrix *newMatrix = new Matrix(this->getRow(), this->getColumn());
    for (int i = 1; i <= this->getRow(); ++i) {
        for (int j = 1; j <= this->getColumn(); ++j) {
            val = this->get(i, j) + matrix.get(i, j);
            newMatrix->set(i, j, val);
        }
    }
    return *newMatrix;
}

Matrix &Matrix::operator-(Matrix matrix) {
    // 检查两个矩阵的列数和行数
    if (this->row != matrix.getRow() || this->column != matrix.getColumn()) {
        throwErr("两个相减的矩阵行列数不相同");
    }
    double val;
    Matrix *newMatrix = new Matrix(this->getRow(), this->getColumn());
    for (int i = 1; i <= this->getRow(); ++i) {
        for (int j = 1; j <= this->getColumn(); ++j) {
            val = this->get(i, j) - matrix.get(i, j);
            newMatrix->set(i, j, val);
        }
    }
    return *newMatrix;
}

void Matrix::setRow(int rowNum, vector<double> rowVector) {
    // 检查行向量的尺寸和该矩阵的列数
    if (this->column != rowVector.size()) {
        throwErr("传入的行向量尺寸不对应");
    }
    for (int i = 1; i <= rowVector.size(); ++i) {
        this->set(rowNum, i, rowVector.at(i - 1));
    }
}

void Matrix::setColumn(int columnNum, vector<double> columnVector) {
    // 检查列向量的尺寸和该矩阵的行数
    if (this->row != columnVector.size()) {
        throwErr("传入的列向量尺寸不对应");
    }
    for (int i = 1; i <= columnVector.size(); ++i) {
        this->set(i, columnNum, columnVector.at(i - 1));
    }
}

Matrix::Matrix() : Matrix(1) {

}

Matrix &Matrix::operator*(double num) {
    Matrix *newMatrix = new Matrix(this->row, this->column);
    for (int i = 1; i <= this->getRow(); ++i) {
        for (int j = 1; j <= this->getColumn(); ++j) {
            double val = this->get(i, j);
            newMatrix->set(i, j, val * num);
        }
    }
    return *newMatrix;
}

void Matrix::allZero() {
    for (int i = 1; i <= this->getRow(); ++i) {
        for (int j = 1; j <= this->getColumn(); ++j) {
            this->set(i, j, 0);
        }
    }
}

Matrix::~Matrix() {
}

Matrix &Matrix::copy() {
    Matrix *newMatrix = new Matrix(this->row, this->column);
    for (int i = 1; i <= this->row; ++i) {
        newMatrix->setRow(i, this->getRow(i));
    }
    return *newMatrix;
}