矩阵类
头文件:
class X_MATH_EXPORT XMatrix { public: XMatrix() = delete; XMatrix(size_t row, size_t column, XDecimal val); XMatrix(size_t row, size_t column, const XDecimal vals[]); XMatrix(size_t row, size_t column, std::initializer_list<XDecimal> vals); XMatrix(const XMatrix& m); virtual ~XMatrix(); /** 获取行数 */ size_t getRowCount() const; /** 获取列数 */ size_t getColumnCount() const; /** 获取元素值 */ XDecimal getValue(size_t row, size_t column) const; /** 设置元素值 */ void setValue(size_t row, size_t column, XDecimal val); /** 矩阵转置 */ virtual XMatrix& transpose(); /** 是否为方阵 */ virtual bool isSquareMatrix() const; /** 是否为行矩阵 */ virtual bool isRowMatrix() const; /** 是否为列矩阵 */ virtual bool isColumnMatrix() const; /** 是否为零矩阵 */ virtual bool isZeroMatrix() const; /** 重载运算符 XMatrix == otherXMatrix */ bool operator==(const XMatrix& m) const; /** 重载运算符 XMatrix != otherXMatrix */ bool operator!=(const XMatrix& m) const; /** 重载运算符 XMatrix += otherXMatrix */ XMatrix& operator+=(const XMatrix& m); /** 重载运算符 XMatrix -= otherXMatrix */ XMatrix& operator-=(const XMatrix& m); /** 重载运算符 XMatrix *= XDecimal */ XMatrix& operator*=(XDecimal val); /** 重载运算符 XMatrix /= XDecimal */ XMatrix& operator/=(XDecimal val); // XDecimal*& operator[](size_t row); /** 重载运算符 XMatrix + otherXMatrix */ XMatrix operator+(const XMatrix& m) const; /** 重载运算符 XMatrix - otherXMatrix */ XMatrix operator-(const XMatrix& m) const; /** 重载运算符 XMatrix * otherXMatrix */ XMatrix operator*(const XMatrix& m) const; /** 重载运算符 XMatrix * XDecimal */ XMatrix operator*(XDecimal val) const; /** 重载运算符 XMatrix / XDecimal */ XMatrix operator/(XDecimal val) const; /** 重载运算符 -XMatrix */ friend X_MATH_EXPORT XMatrix operator-(const XMatrix& m); /** 重载运算符 XDecimal * XMatrix */ friend X_MATH_EXPORT XMatrix operator*(XDecimal factor, const XMatrix& m); /** 重载运算符 ostream << XMatrix */ friend X_MATH_EXPORT std::ostream& operator<<(std::ostream& outStream, const XMatrix& m); protected: XMatrix(size_t row, size_t column); XMatrix& swapRow(size_t row1, size_t row2); XMatrix& swapColumn(size_t col1, size_t col2); XMatrix& addTimesRow(size_t targetRow, size_t sourceRow, XDecimal times = 1.0); XMatrix& addTimesColumn(size_t targetCol, size_t sourceCol, XDecimal times = 1.0); protected: /** * @brief 矩阵的数值(行主序) */ XDecimal* _vals{ nullptr }; size_t _row{ 0 }; size_t _column{ 0 }; };
源文件:
XMatrix::XMatrix(size_t row, size_t column, XDecimal val) : _row(row) , _column(column) { XAssert(row != 0 && column != 0); size_t count = row * column; _vals = new XDecimal[count]; // 手动填充数组,避免编译器默认处理不同的情况 for(size_t i = 0; i < count; ++i) { _vals[i] = val; } } XMatrix::XMatrix(size_t row, size_t column, const XDecimal* vals) : _row(row) , _column(column) { XAssert(row != 0 && column != 0); size_t count = row * column; _vals = new XDecimal[count]; for(size_t i = 0; i < count; ++i) { _vals[i] = vals[i]; } } XMatrix::XMatrix(size_t row, size_t column, std::initializer_list<XDecimal> vals) : _row(row) , _column(column) { size_t count = row * column; XAssert(vals.size() == count); _vals = new XDecimal[count]; size_t index = 0; for(auto it = vals.begin(); index < count, it != vals.end(); ++it) { _vals[index] = *it; ++index; } } XMatrix::XMatrix(const XMatrix& m) : _row(m._row) , _column(m._column) { if(m._vals == nullptr) return; size_t count = _row * _column; _vals = new XDecimal[count]; for(size_t i = 0; i < count; ++i) { _vals[i] = m._vals[i]; } } XMatrix::~XMatrix() { if(_vals != nullptr) { delete[] _vals; _vals = nullptr; } } size_t XMatrix::getRowCount() const { return _row; } size_t XMatrix::getColumnCount() const { return _column; } XDecimal XMatrix::getValue(size_t row, size_t column) const { XAssert(row < _row && column < _column); return _vals[_column * row + column]; } void XMatrix::setValue(size_t row, size_t column, XDecimal val) { XAssert(row < _row && column < _column); _vals[_column * row + column] = val; } bool XMatrix::operator==(const XMatrix& m) const { if(_row != m._row || _column != m._column) return false; size_t count = _row * _column; for(size_t i = 0; i < count; ++i) { if(!X::isEqual(_vals[i], m._vals[i])) return false; } return true; } bool XMatrix::operator!=(const XMatrix& m) const { return !(*this == m); } XMatrix& XMatrix::operator+=(const XMatrix& m) { XAssert(_row == m._row && _column == m._column); size_t count = _row * _column; for(size_t i = 0; i < count; ++i) { _vals[i] += m._vals[i]; } return *this; } XMatrix& XMatrix::operator-=(const XMatrix& m) { XAssert(_row == m._row && _column == m._column); size_t count = _row * _column; for(size_t i = 0; i < count; ++i) { _vals[i] -= m._vals[i]; } return *this; } XMatrix& XMatrix::operator*=(XDecimal val) { size_t count = _row * _column; for(size_t i = 0; i < count; ++i) { _vals[i] *= val; } return *this; } XMatrix& XMatrix::operator/=(XDecimal val) { size_t count = _row * _column; for(size_t i = 0; i < count; ++i) { _vals[i] /= val; } return *this; } XMatrix XMatrix::operator+(const XMatrix& m) const { XAssert(_row == m._row && _column == m._column); XMatrix mat(_row, _column); size_t count = _row * _column; for(size_t i = 0; i < count; ++i) { mat._vals[i] = _vals[i] + m._vals[i]; } return mat; } XMatrix XMatrix::operator-(const XMatrix& m) const { XAssert(_row == m._row && _column == m._column); XMatrix mat(_row, _column); size_t count = _row * _column; for(size_t i = 0; i < count; ++i) { mat._vals[i] = _vals[i] - m._vals[i]; } return mat; } XMatrix XMatrix::operator*(const XMatrix& m) const { XAssert(_column == m._row); XMatrix mat(_row, m._column); XDecimal tempVal = 0.0; for(size_t row = 0; row < mat._row; ++row) { for(size_t col = 0; col < mat._column; ++col) { tempVal = 0.0; for(int i = 0; i < _column; ++i) { tempVal += _vals[_column * row + i] * m._vals[m._column * i + col]; } mat._vals[mat._column * row + col] = tempVal; } } return mat; } XMatrix XMatrix::operator*(XDecimal val) const { XMatrix mat(_row, _column); size_t count = _row * _column; for(size_t i = 0; i < count; ++i) { mat._vals[i] = _vals[i] * val; } return mat; } XMatrix XMatrix::operator/(XDecimal val) const { XMatrix mat(_row, _column); size_t count = _row * _column; for(size_t i = 0; i < count; ++i) { mat._vals[i] = _vals[i] / val; } return mat; } XMatrix operator-(const XMatrix& m) { XMatrix mat(m._row, m._column); size_t count = m._row * m._column; for(size_t i = 0; i < count; ++i) { mat._vals[i] = -m._vals[i]; } return mat; } XMatrix operator*(XDecimal val, const XMatrix& m) { return m * val; } std::ostream& operator<<(std::ostream& outStream, const XMatrix& m) { char paddedNumber[25]; outStream << "{"; for(int row = 0; row < m._row; ++row) { outStream << "["; for(int col = 0; col < m._column; ++col) { sprintf(paddedNumber, "% -4.8f", m._vals[row * m._column + col]); outStream << paddedNumber; if(col < m._column) { outStream << " "; } } outStream << "] "; } outStream << "}"; return outStream; } XMatrix::XMatrix(size_t row, size_t column) : _row(row) , _column(column) { XAssert(row != 0 && column != 0); _vals = new XDecimal[row * column]; } XMatrix& XMatrix::swapRow(size_t row1, size_t row2) { XAssert(row1 < _row && row2 < _row); // 如果自己和自己交换则直接返回 if(row1 == row2) return *this; XDecimal tempVal = 0.0; for(size_t col = 0; col < _column; ++col) { tempVal = _vals[_column * row1 + col]; _vals[_column * row1 + col] = _vals[_column * row2 + col]; this->_vals[_column * row2 + col] = tempVal; } return *this; } XMatrix& XMatrix::swapColumn(size_t col1, size_t col2) { XAssert(col1 < _row && col2 < _row); // 如果自己和自己交换则直接返回 if(col1 == col2) return *this; XDecimal tempVal = 0.0; for(size_t row = 0; row < _column; ++row) { tempVal = _vals[_column * row + col1]; _vals[_column * row + col1] = _vals[_column * row + col2]; this->_vals[_column * row + col2] = tempVal; } return *this; } XMatrix& XMatrix::addTimesRow(size_t targetRow, size_t sourceRow, XDecimal times) { XAssert(targetRow < _row && sourceRow < _row); if(!X::isZero(times)) { for(size_t col = 0; col < _column; ++col) { _vals[_column * targetRow + col] += _vals[_column * sourceRow + col] * times; } } return *this; } XMatrix& XMatrix::addTimesColumn(size_t targetCol, size_t sourceCol, XDecimal times) { XAssert(targetCol < _row && sourceCol < _row); if(!X::isZero(times)) { for(size_t row = 0; row < _column; ++row) { _vals[_column * row + targetCol] += _vals[_column * row + sourceCol] * times; } } return *this; } XMatrix& XMatrix::transpose() { auto* vals = new XDecimal[_column * _row]; for(int row = 0; row < _row; ++row) { for(int col = 0; col < _column; ++col) { vals[col * _row + row] = _vals[row * _column + col]; } } delete[] _vals; _vals = vals; auto row = _row; _row = _column; _column = row; return *this; } bool XMatrix::isSquareMatrix() const { return _row == _column; } bool XMatrix::isRowMatrix() const { return _row == 1; } bool XMatrix::isColumnMatrix() const { return _column == 1; } bool XMatrix::isZeroMatrix() const { for(int row = 0; row < _row; ++row) { for(int col = 0; col < _column; ++col) { if(!X::isZero(_vals[row * _column + col])) { return false; } } } return true; }