矩阵类

头文件:

    class X_MATH_EXPORT XMatrix {
    public:
        XMatrix() = delete;
        XMatrix(size_t row, size_t column, XDecimal val);
        XMatrix(size_t row, size_t column, const XDecimal vals[]);
        XMatrix(size_t row, size_t column, std::initializer_list<XDecimal> vals);
        XMatrix(const XMatrix& m);
        virtual ~XMatrix();

        /** 获取行数 */
        size_t                   getRowCount() const;
        /** 获取列数 */
        size_t                   getColumnCount() const;
        /** 获取元素值 */
        XDecimal                   getValue(size_t row, size_t column) const;
        /** 设置元素值 */
        void                   setValue(size_t row, size_t column, XDecimal val);
        /** 矩阵转置 */
        virtual XMatrix&       transpose();
        /** 是否为方阵 */
        virtual bool           isSquareMatrix() const;
        /** 是否为行矩阵 */
        virtual bool           isRowMatrix() const;
        /** 是否为列矩阵 */
        virtual bool           isColumnMatrix() const;
        /** 是否为零矩阵 */
        virtual bool           isZeroMatrix() const;
        /** 重载运算符 XMatrix == otherXMatrix */
        bool                   operator==(const XMatrix& m) const;
        /** 重载运算符 XMatrix != otherXMatrix */
        bool                   operator!=(const XMatrix& m) const;
        /** 重载运算符 XMatrix += otherXMatrix */
        XMatrix&               operator+=(const XMatrix& m);
        /** 重载运算符 XMatrix -= otherXMatrix */
        XMatrix&               operator-=(const XMatrix& m);
        /** 重载运算符 XMatrix *= XDecimal */
        XMatrix&               operator*=(XDecimal val);
        /** 重载运算符 XMatrix /= XDecimal */
        XMatrix&               operator/=(XDecimal val);

        // XDecimal*&               operator[](size_t row);

        /** 重载运算符 XMatrix + otherXMatrix */
        XMatrix                   operator+(const XMatrix& m) const;
        /** 重载运算符 XMatrix - otherXMatrix */
        XMatrix                   operator-(const XMatrix& m) const;
        /** 重载运算符 XMatrix * otherXMatrix */
        XMatrix                   operator*(const XMatrix& m) const;
        /** 重载运算符 XMatrix * XDecimal */
        XMatrix                   operator*(XDecimal val) const;
        /** 重载运算符 XMatrix / XDecimal */
        XMatrix                   operator/(XDecimal val) const;
        /** 重载运算符 -XMatrix */
        friend X_MATH_EXPORT XMatrix operator-(const XMatrix& m);
        /** 重载运算符 XDecimal * XMatrix */
        friend X_MATH_EXPORT XMatrix operator*(XDecimal factor, const XMatrix& m);
        /** 重载运算符 ostream << XMatrix */
        friend X_MATH_EXPORT std::ostream& operator<<(std::ostream& outStream, const XMatrix& m);

    protected:
        XMatrix(size_t row, size_t column);
        XMatrix& swapRow(size_t row1, size_t row2);
        XMatrix& swapColumn(size_t col1, size_t col2);
        XMatrix& addTimesRow(size_t targetRow, size_t sourceRow, XDecimal times = 1.0);
        XMatrix& addTimesColumn(size_t targetCol, size_t sourceCol, XDecimal times = 1.0);

    protected:
        /**
         * @brief 矩阵的数值(行主序)
         */
        XDecimal* _vals{ nullptr };
        size_t     _row{ 0 };
        size_t     _column{ 0 };
    };

 

源文件:

XMatrix::XMatrix(size_t row, size_t column, XDecimal val)
        : XMatrix(row, column)
    {
        // The delegated constructor asserts the dimensions and allocates the
        // buffer; every element is then written explicitly so its value never
        // depends on how the compiler default-initializes XDecimal.
        XDecimal* const last = _vals + _row * _column;
        for(XDecimal* cursor = _vals; cursor != last; ++cursor) {
            *cursor = val;
        }
    }

    /**
     * Builds a row x column matrix by copying row * column values (row-major)
     * from vals. The caller must supply at least that many elements.
     * A null vals is a programming error (asserted); if assertions are compiled
     * out, it yields a zero matrix instead of dereferencing a null pointer.
     */
    XMatrix::XMatrix(size_t row, size_t column, const XDecimal* vals)
        : _row(row)
        , _column(column)
    {
        XAssert(row != 0 && column != 0);
        XAssert(vals != nullptr);
        const size_t count = row * column;
        // Value-initialized so the null-source fallback leaves zeros, not garbage.
        _vals = new XDecimal[count]();
        if(vals != nullptr) {
            for(size_t i = 0; i < count; ++i) {
                _vals[i] = vals[i];
            }
        }
    }

    /**
     * Builds a row x column matrix from a brace list of row * column values in
     * row-major order. A size mismatch is asserted. If assertions are disabled,
     * extra values are ignored (never written past the buffer) and missing ones
     * are left as value-initialized (zero) elements.
     */
    XMatrix::XMatrix(size_t row, size_t column, std::initializer_list<XDecimal> vals)
        : _row(row)
        , _column(column)
    {
        XAssert(row != 0 && column != 0);
        const size_t count = row * column;
        XAssert(vals.size() == count);
        _vals = new XDecimal[count]();
        size_t index = 0;
        // Both bounds must hold: `&&`, not the comma operator (which would
        // discard the `index < count` check and allow a heap overflow).
        for(auto it = vals.begin(); index < count && it != vals.end(); ++it, ++index) {
            _vals[index] = *it;
        }
    }

    /**
     * Deep copy constructor: duplicates m's dimensions and, when m owns a
     * buffer, allocates a new one holding the same element values.
     */
    XMatrix::XMatrix(const XMatrix& m)
        : _row(m._row)
        , _column(m._column)
    {
        if(m._vals != nullptr) {
            const size_t total = _row * _column;
            _vals = new XDecimal[total];
            const XDecimal* src = m._vals;
            XDecimal*       dst = _vals;
            for(size_t k = 0; k < total; ++k) {
                *dst++ = *src++;
            }
        }
    }
    /** Releases the element buffer. delete[] on a null pointer is a no-op. */
    XMatrix::~XMatrix()
    {
        delete[] _vals;
        _vals = nullptr;
    }

    size_t XMatrix::getRowCount() const
    {
        return _row;
    }
    size_t XMatrix::getColumnCount() const
    {
        return _column;
    }
    /** Returns the element at (row, column) from the row-major buffer. */
    XDecimal XMatrix::getValue(size_t row, size_t column) const
    {
        XAssert(row < _row && column < _column);
        const size_t offset = row * _column + column;
        return _vals[offset];
    }
    /** Overwrites the element at (row, column) in the row-major buffer. */
    void XMatrix::setValue(size_t row, size_t column, XDecimal val)
    {
        XAssert(row < _row && column < _column);
        const size_t offset = row * _column + column;
        _vals[offset] = val;
    }

    /**
     * Two matrices are equal when they have the same shape and every pair of
     * corresponding elements compares equal under X::isEqual. Comparison stops
     * at the first mismatch.
     */
    bool XMatrix::operator==(const XMatrix& m) const
    {
        const bool sameShape = (_row == m._row) && (_column == m._column);
        if(!sameShape)
            return false;
        const size_t total = _row * _column;
        size_t       k     = 0;
        while(k < total && X::isEqual(_vals[k], m._vals[k]))
            ++k;
        return k == total;
    }
    /** Negation of operator==. */
    bool XMatrix::operator!=(const XMatrix& m) const
    {
        const bool equal = operator==(m);
        return !equal;
    }
    /** Adds m element-wise into this matrix; shapes must match. */
    XMatrix& XMatrix::operator+=(const XMatrix& m)
    {
        XAssert(_row == m._row && _column == m._column);
        const XDecimal* src = m._vals;
        XDecimal* const end = _vals + _row * _column;
        for(XDecimal* dst = _vals; dst != end; ++dst, ++src) {
            *dst += *src;
        }
        return *this;
    }
    /** Subtracts m element-wise from this matrix; shapes must match. */
    XMatrix& XMatrix::operator-=(const XMatrix& m)
    {
        XAssert(_row == m._row && _column == m._column);
        const XDecimal* src = m._vals;
        XDecimal* const end = _vals + _row * _column;
        for(XDecimal* dst = _vals; dst != end; ++dst, ++src) {
            *dst -= *src;
        }
        return *this;
    }
    /** Multiplies every element by the scalar val. */
    XMatrix& XMatrix::operator*=(XDecimal val)
    {
        for(size_t idx = 0, total = _row * _column; idx != total; ++idx) {
            _vals[idx] *= val;
        }
        return *this;
    }
    /** Divides every element by the scalar val (no zero-divisor check). */
    XMatrix& XMatrix::operator/=(XDecimal val)
    {
        for(size_t idx = 0, total = _row * _column; idx != total; ++idx) {
            _vals[idx] /= val;
        }
        return *this;
    }
    /** Returns the element-wise sum of *this and m; shapes must match. */
    XMatrix XMatrix::operator+(const XMatrix& m) const
    {
        XAssert(_row == m._row && _column == m._column);
        XMatrix sum(_row, _column);
        for(size_t idx = 0, total = _row * _column; idx != total; ++idx) {
            sum._vals[idx] = _vals[idx] + m._vals[idx];
        }
        return sum;
    }
    /** Returns the element-wise difference *this - m; shapes must match. */
    XMatrix XMatrix::operator-(const XMatrix& m) const
    {
        XAssert(_row == m._row && _column == m._column);
        XMatrix diff(_row, _column);
        for(size_t idx = 0, total = _row * _column; idx != total; ++idx) {
            diff._vals[idx] = _vals[idx] - m._vals[idx];
        }
        return diff;
    }

    /**
     * Matrix product (*this) * m. Requires this->_column == m._row; the result
     * has _row rows and m._column columns. Uses size_t indices throughout, so
     * the comparison against the unsigned dimensions is well-defined for any size.
     */
    XMatrix XMatrix::operator*(const XMatrix& m) const
    {
        XAssert(_column == m._row);
        XMatrix mat(_row, m._column);
        for(size_t row = 0; row < mat._row; ++row) {
            // Start of the current row of the left operand.
            const XDecimal* lhsRow = _vals + _column * row;
            for(size_t col = 0; col < mat._column; ++col) {
                XDecimal tempVal = 0.0;
                for(size_t i = 0; i < _column; ++i) {
                    tempVal += lhsRow[i] * m._vals[m._column * i + col];
                }
                mat._vals[mat._column * row + col] = tempVal;
            }
        }
        return mat;
    }

    /** Returns a copy of the matrix with every element multiplied by val. */
    XMatrix XMatrix::operator*(XDecimal val) const
    {
        XMatrix scaled(_row, _column);
        const XDecimal* src = _vals;
        XDecimal* const end = scaled._vals + _row * _column;
        for(XDecimal* dst = scaled._vals; dst != end; ++dst, ++src) {
            *dst = *src * val;
        }
        return scaled;
    }
    /** Returns a copy of the matrix with every element divided by val. */
    XMatrix XMatrix::operator/(XDecimal val) const
    {
        XMatrix quotient(_row, _column);
        const XDecimal* src = _vals;
        XDecimal* const end = quotient._vals + _row * _column;
        for(XDecimal* dst = quotient._vals; dst != end; ++dst, ++src) {
            *dst = *src / val;
        }
        return quotient;
    }
    /** Unary minus: returns a copy of m with every element negated. */
    XMatrix operator-(const XMatrix& m)
    {
        XMatrix negated(m._row, m._column);
        const size_t total = m._row * m._column;
        for(size_t idx = 0; idx != total; ++idx) {
            negated._vals[idx] = -m._vals[idx];
        }
        return negated;
    }
    /** Scalar-on-the-left multiplication; forwards to XMatrix::operator*(XDecimal). */
    XMatrix operator*(XDecimal val, const XMatrix& m)
    {
        return m.operator*(val);
    }
    /**
     * Writes the matrix as "{[a b c] [d e f] }". Each element is formatted with
     * printf's "% -4.8f" (left-justified, 8 decimals, leading space for
     * non-negative values) and elements within a row are separated by a single
     * space.
     */
    std::ostream& operator<<(std::ostream& outStream, const XMatrix& m)
    {
        // Holds "% -4.8f" output for magnitudes up to roughly 1e50; snprintf
        // truncates anything wider instead of overflowing the buffer.
        char paddedNumber[64];
        outStream << "{";
        for(size_t row = 0; row < m._row; ++row) {
            outStream << "[";
            for(size_t col = 0; col < m._column; ++col) {
                // %f expects a double, so convert explicitly whatever XDecimal is.
                snprintf(paddedNumber, sizeof(paddedNumber), "% -4.8f",
                         static_cast<double>(m._vals[row * m._column + col]));
                outStream << paddedNumber;
                // Separator between elements only, not after the last one in the row.
                if(col + 1 < m._column) {
                    outStream << " ";
                }
            }
            outStream << "] ";
        }
        outStream << "}";

        return outStream;
    }
    /**
     * Protected constructor: allocates storage for a row x column matrix
     * without setting element values (default-initialized, i.e. indeterminate
     * if XDecimal is a fundamental type). The callers visible here overwrite
     * every element.
     */
    XMatrix::XMatrix(size_t row, size_t column)
        : _row(row)
        , _column(column)
    {
        XAssert(row != 0 && column != 0);
        const size_t total = row * column;
        _vals = new XDecimal[total];
    }

    /** Elementary row operation: exchanges rows row1 and row2 in place. */
    XMatrix& XMatrix::swapRow(size_t row1, size_t row2)
    {
        XAssert(row1 < _row && row2 < _row);
        // Swapping a row with itself is a no-op.
        if(row1 != row2) {
            XDecimal* first  = _vals + _column * row1;
            XDecimal* second = _vals + _column * row2;
            for(size_t col = 0; col < _column; ++col) {
                const XDecimal held = first[col];
                first[col]          = second[col];
                second[col]         = held;
            }
        }
        return *this;
    }
    /**
     * Elementary column operation: exchanges columns col1 and col2 in place.
     * Column indices are bounded by the column count, and the swap visits each
     * of the _row rows; using _row/_column the other way round would index
     * out of bounds for non-square matrices.
     */
    XMatrix& XMatrix::swapColumn(size_t col1, size_t col2)
    {
        XAssert(col1 < _column && col2 < _column);
        // Swapping a column with itself is a no-op.
        if(col1 == col2)
            return *this;
        for(size_t row = 0; row < _row; ++row) {
            const size_t   base    = _column * row;
            const XDecimal tempVal = _vals[base + col1];
            _vals[base + col1]     = _vals[base + col2];
            _vals[base + col2]     = tempVal;
        }
        return *this;
    }

    /** Elementary row operation: targetRow += times * sourceRow. */
    XMatrix& XMatrix::addTimesRow(size_t targetRow, size_t sourceRow, XDecimal times)
    {
        XAssert(targetRow < _row && sourceRow < _row);
        // Adding zero times a row changes nothing, so skip the pass.
        if(X::isZero(times))
            return *this;
        XDecimal*       target = _vals + _column * targetRow;
        const XDecimal* source = _vals + _column * sourceRow;
        for(size_t col = 0; col < _column; ++col) {
            target[col] += source[col] * times;
        }
        return *this;
    }
    /**
     * Elementary column operation: targetCol += times * sourceCol.
     * Column indices are bounded by the column count, and every one of the
     * _row rows is updated.
     */
    XMatrix& XMatrix::addTimesColumn(size_t targetCol, size_t sourceCol, XDecimal times)
    {
        XAssert(targetCol < _column && sourceCol < _column);
        // Adding zero times a column changes nothing.
        if(!X::isZero(times)) {
            for(size_t row = 0; row < _row; ++row) {
                const size_t base = _column * row;
                _vals[base + targetCol] += _vals[base + sourceCol] * times;
            }
        }
        return *this;
    }

    /**
     * Transposes the matrix in place: element (r, c) moves to (c, r) and the
     * row/column counts are exchanged. Returns *this.
     */
    XMatrix& XMatrix::transpose()
    {
        // A 1xN or Nx1 matrix has the same row-major layout after transposition,
        // so only the dimensions need to be swapped.
        if(_row != 1 && _column != 1) {
            XDecimal* transposed = new XDecimal[_column * _row];
            for(size_t row = 0; row < _row; ++row) {
                for(size_t col = 0; col < _column; ++col) {
                    transposed[col * _row + row] = _vals[row * _column + col];
                }
            }
            // The new buffer is fully built before the old one is released.
            delete[] _vals;
            _vals = transposed;
        }
        const size_t oldRow = _row;
        _row                = _column;
        _column             = oldRow;
        return *this;
    }
    /** True when the row count equals the column count. */
    bool XMatrix::isSquareMatrix() const
    {
        const bool square = (_column == _row);
        return square;
    }
    /** True when the matrix has exactly one row. */
    bool XMatrix::isRowMatrix() const
    {
        const bool singleRow = (1 == _row);
        return singleRow;
    }
    /** True when the matrix has exactly one column. */
    bool XMatrix::isColumnMatrix() const
    {
        const bool singleColumn = (1 == _column);
        return singleColumn;
    }
    bool XMatrix::isZeroMatrix() const
    {
        for(int row = 0; row < _row; ++row) {
            for(int col = 0; col < _column; ++col) {
                if(!X::isZero(_vals[row * _column + col])) {
                    return false;
                }
            }
        }
        return true;
    }

 

posted @ 2024-07-17 11:43  禅元天道  阅读(8)  评论(0)  编辑  收藏  举报