c++实现简单矩阵类Mat
2016-04-24 14:19
441 查看
刚学习C++,之前把 Primer 看了一遍,现在也在刷 leetcode,感觉学习编程语言光看书页刷题也是不够的,最好是能做一些实际的项目,这样要用到哪些东西时不明白再看书,就会印象深刻些,否则光看书只是走马观花,看了也就忘了。
打算自己用C++实现一个简单的矩阵 Mat 类,包括一些简单的操作就可以了,但实现起来发现也并没有那么简单,还是遇到很多问题。这个过程也还是学到了不少东西。
废话不多说,直接上程序:
头文件 Mat.h:
刚开始的时候,我是用 int* 来存储的数据,用int* 不方便的一个地方就是需要自己来管理内存,在实现一些 复制赋值和移动操作时都需要很小心处理,来管理内存,稍有不慎就会出问题。所以我后来改用C++标准容器 vector 了。这样就不用自己来管理内存了。
另外,像 + 、- 这样两边的数可以互换位置的操作符,重载时最好用非成员函数来重载,而且像 +=、 -= 这样的复合操作符,必须要用成员函数来重载。另外,可以用 +=、-= 来实现非成员函数的 + 、- 操作。最后要强调的是,非成员的运算符重载函数(+,-,,/)之类,返回类型应该是 const 类型,避免类似 a b = c 这样的语句通过编译。
像输出(<<)、输入(>>)运算符,需要用非成员函数来重载,一般还要声明为友元函数。不过可以通过成员函数完成输入输出操作,非成员函数来重载、调用成员函数输入输出来避免声明友元函数,后面一个例子就是这样实现的。
最后要强调的一点就是针对const对象,需要有对应的const版本的函数,具体说来,就是对一些不改变类类型的变量成员函数,尽量声明为const类型,这样const对象也能调用这些函数。在返回矩阵数据时,需要针对const 和 非const的对象,实现两个函数。
//access element value
int& operator()(size_t i, size_t j);
const int& operator()(size_t i, size_t j) const;
例如,重载 ( ) 运算符时,需要实现两个版本的函数,const 对象调用第二个版本函数,非 const 对象调用第一个版本的函数。
源文件Mat.cpp
头文件 Mat.h
需要说明的是,当把数据存储从指针改为vector之后,也就不需要单独记录矩阵的行和列了,这些信息都蕴含在vector之中了。要注意的是,在使用取元素下标 [ ] 的时候,需要确保vector不为空。否则就会出现 vector out of range 的错误。
Mat.c
需要说明的时,重载输入元算符<<时,我开始用的友元函数,但是出现连接错误,没有找到bug,然后在 stackoverflow 上看到别人推荐这种非友元的方式来实现,就采用了这种方式,避免了之前的链接错误。。。
如有错误,恳请指正!
打算自己用C++实现一个简单的矩阵 Mat 类,包括一些简单的操作就可以了,但实现起来发现也并没有那么简单,还是遇到很多问题。这个过程也还是学到了不少东西。
固定数据类型的矩阵
先从一个固定数据类型的矩阵的开始吧,以 int 为例子,那么我们需要什么呢? 首先,我们需要记录矩阵的行和列,也需要一块内存存储矩阵的数据;然后需要一些简单的构造函数和其他的成员函数,实现一些基本的操作。废话不多说,直接上程序:
头文件 Mat.h:
#ifndef _MAT_H_ #define _MAT_H_ #include <iostream> #include <ostream> #include <vector> #include <cstring> #include <cassert> //implement Mat class in c++ class Mat{ friend std::ostream& operator<<(std::ostream &os, const Mat &m); friend std::istream& operator>>(std::istream &is, Mat &m); public: typedef int value_type; typedef std::vector<int>::size_type size_type; //construct Mat(); Mat(size_t i, size_t j); //copy constructor Mat(const Mat& m); //copy assignment Mat& operator=(const Mat& m); // += Mat& operator+=(const Mat& m); // -= Mat& operator-=(const Mat& m); //destructor ~Mat(); //access element value int& operator()(size_t i, size_t j); const int& operator()(size_t i, size_t j) const; //get row and col number const size_t rows() const{ return row; } const size_t cols() const{ return col; } //resize void resize(size_t nr, size_t nc); private: size_t row; size_t col; std::vector<std::vector<int>> data; }; #endif
刚开始的时候,我是用 int* 来存储的数据,用int* 不方便的一个地方就是需要自己来管理内存,在实现一些 复制赋值和移动操作时都需要很小心处理,来管理内存,稍有不慎就会出问题。所以我后来改用C++标准容器 vector 了。这样就不用自己来管理内存了。
另外,像 + 、- 这样两边的数可以互换位置的操作符,重载时最好用非成员函数来重载,而且像 +=、 -= 这样的复合操作符,必须要用成员函数来重载。另外,可以用 +=、-= 来实现非成员函数的 + 、- 操作。最后要强调的是,非成员的运算符重载函数(+,-,,/)之类,返回类型应该是 const 类型,避免类似 a b = c 这样的语句通过编译。
像输出(<<)、输入(>>)运算符,需要用非成员函数来重载,一般还要声明为友元函数。不过可以通过成员函数完成输入输出操作,非成员函数来重载、调用成员函数输入输出来避免声明友元函数,后面一个例子就是这样实现的。
最后要强调的一点就是针对const对象,需要有对应的const版本的函数,具体说来,就是对一些不改变类类型的变量成员函数,尽量声明为const类型,这样const对象也能调用这些函数。在返回矩阵数据时,需要针对const 和 非const的对象,实现两个函数。
//access element value
int& operator()(size_t i, size_t j);
const int& operator()(size_t i, size_t j) const;
例如,重载 ( ) 运算符时,需要实现两个版本的函数,const 对象调用第二个版本函数,非 const 对象调用第一个版本的函数。
源文件Mat.cpp
#include <istream> #include <sstream> #include <algorithm> #include "matint.h" using std::cout; using std::endl; using std::istream; using std::ostream; using std::stringstream; ostream& operator<<(ostream &os, const Mat&m){ for (size_t i = 0; i < m.row; i++){ for (size_t j = 0; j < m.col; j++){ os << m.data[i][j] << " "; } os << std::endl; } os << std::endl; return os; } istream& operator>>(istream &is, Mat&m){ for (size_t i = 0; i < m.row; i++){ for (size_t j = 0; j < m.col; j++){ is >> m.data[i][j]; } } return is; } // + const Mat operator+(const Mat& m1, const Mat& m2){ Mat t = m1; t += m2; return t; } // - const Mat operator-(const Mat& m1, const Mat& m2){ Mat t = m1; t -= m2; return t; } //constructor Mat::Mat(){ cout << "default constructor" << endl; row = 0; col = 0; data.clear(); } Mat::Mat(size_t i, size_t j){ row = i; col = j; std::vector<std::vector<int>> vdata(row, std::vector<int>(col, 0)); data = std::move(vdata); } //copy constructor Mat::Mat(const Mat& m){ cout << "copy constructor" << endl; row = m.row; col = m.col; data = m.data; } //copy assignment Mat& Mat::operator=(const Mat& m){ cout << "copy assignment" << endl; row = m.row; col = m.col; data = m.data; return *this; } //destructor Mat::~Mat(){ data.clear(); } //access element value int& Mat::operator()(size_t i, size_t j){ assert(i >= 0 && j >= 0 && i < row && j < col); return data[i][j]; } const int& Mat::operator()(size_t i, size_t j) const{ assert(i >= 0 && j >= 0 && i < row && j < col); return data[i][j]; } //resize void Mat::resize(size_t nr, size_t nc){ data.resize(nr); for (size_t i = 0; i < nr; i++){ data[i].resize(nc); } col = nc; row = nr; } // += Mat& Mat::operator+=(const Mat& m){ if (row == m.row && col == m.col){ for (size_t i = 0; i < row; i++) { for (size_t j = 0; j < col; j++) data[i][j] += m.data[i][j]; } } else{ std::cerr << "mat must be the same size." << std::endl; } return *this; } // -= Mat& Mat::operator-=(const Mat& m){ if (row == m.row && col == m.col){ for (size_t i = 0; i < row; i++) { for (size_t j = 0; j < col; j++) data[i][j] -= m.data[i][j]; } } else{ std::cerr << "mat must be the same size." << std::endl; } return *this; } #if 1 int main(){ Mat mat1(3, 4); Mat mat2(3, 4); for (size_t i = 0; i < mat1.rows(); i++){ for (size_t j = 0; j < mat1.cols(); j++){ mat1(i, j) = 1; mat2(i, j) = 3; } } std::cout << "mat1: " << std::endl << mat1; std::cout << "mat2: " << std::endl << mat2; Mat mat3 = (mat2 + mat1); std::cout << "mat3 = mat2 + mat1: " << std::endl << mat3; Mat mat4 = (mat3 + mat2 - mat1); std::cout << "mat4 = mat3 + mat2 - mat1: " << std::endl << mat4; stringstream ss; ss << mat1; ss >> mat4; std::cout << "mat4:" << std::endl << mat4; const Mat mat6(mat4); std::cout << "const mat6:" << std::endl << mat6; cout << mat6(0, 0) << " " << mat6.rows() << " "<<mat6.cols()<<" "; Mat mat7 = mat2; std::cout << "mat7: " << std::endl << mat7; mat2(0, 0) = 11; std::cout << "mat7: " << std::endl << mat7; mat7.resize(2, 3); std::cout << "mat7.resize(2, 3): " << std::endl << mat7; mat7.resize(5, 6); std::cout << "mat7.resize(5, 6): " << std::endl << mat7; return 1; } #endif
使用Template实现通用数据类型的Mat
以上是针对 int 数据类型实现的矩阵类,那么如果我的矩阵数据类型是double怎么办?总不能再重新实现一遍吧。为了实现多种数据类型,有两种思路。第一种就是像 OpenCV 的矩阵类一样,单独用一个变量来定义使用什么样的数据类型,这样做的好处是不需要使用模板;第二种思路就是使用C++的Templeate来实现。我实现的是第二种思路。头文件 Mat.h
#ifndef _MAT_H_ #define _MAT_H_ #include <iostream> #include <ostream> #include <sstream> #include <vector> #include <cstring> #include <cassert> //implement Mat class in c++ template<typename T> class Mat{ public: typedef T value_type; //construct Mat(); Mat(size_t i, size_t j); ////copy constructor Mat(const Mat& m); ////copy assignment Mat& operator=(const Mat&m); // += Mat& operator+=(const Mat& m); // -= Mat& operator-=(const Mat& m); //move constructor Mat( Mat&& m); //move assignment Mat& operator=( Mat&& m); //destructor ~Mat(); //access element value T& operator()(size_t i, size_t j); const T& operator()(size_t i, size_t j) const; //get row and col number const size_t rows() const{ return vdata.size(); } const size_t cols() const{ if (vdata.empty()) return 0; else return vdata[0].size(); } //resize void resize(size_t nr, size_t nc); //print mat void CoutMat(std::ostream& os) const; void CinMat(std::istream& is); private: std::vector<std::vector<T>> vdata; }; #endif
需要说明的是,当把数据存储从指针改为vector之后,也就不需要单独记录矩阵的行和列了,这些信息都蕴含在vector之中了。要注意的是,在使用取元素下标 [ ] 的时候,需要确保vector不为空。否则就会出现 vector out of range 的错误。
Mat.c
#include "mat.h" using std::cout; using std::endl; using std::istream; using std::ostream; using std::stringstream; template<typename T> ostream& operator<<(ostream &os, const Mat<T> &m){ m.CoutMat(os); return os; } template<typename T> istream& operator>>(istream &is, Mat<T>&m){ m.CinMat(is); return is; } // + template<typename T> const Mat<T> operator+(const Mat<T>& m1, const Mat<T>& m2){ Mat<T> t(m1); t += m2; return t; } // - template<typename T> const Mat<T> operator-(const Mat<T>& m1, const Mat<T>& m2){ Mat<T> t(m1); t -= m2; return t; } // print mat template<typename T> void Mat<T>::CoutMat(std::ostream& os) const { if (vdata.empty()) return; for (size_t i = 0; i < vdata.size(); i++){ for (size_t j = 0; j < vdata[0].size(); j++){ os << vdata[i][j] << " "; } os << std::endl; } os << std::endl; } template<typename T> void Mat<T>::CinMat(std::istream& is) { if (vdata.empty()) return; for (size_t i = 0; i < vdata.size(); i++){ for (size_t j = 0; j < vdata[0].size(); j++){ is >> vdata[i][j]; } } } //construct template<typename T> Mat<T>::Mat(){ cout << "default constructor" << endl; vdata.clear(); } template<typename T> Mat<T>::Mat(size_t i, size_t j){ std::vector<std::vector<T>> tdata(i, std::vector<T>(j, 0)); vdata = std::move(tdata); } //copy constructor template<typename T> Mat<T>::Mat(const Mat<T>& m){ cout << "copy constructor" << endl; vdata.assign(m.vdata.cbegin(), m.vdata.cend()); } //copy assignment template<typename T> Mat<T>& Mat<T>::operator=(const Mat<T>& m){ cout << "copy assignment" << endl; if (this != &m){ vdata.assign(m.vdata.cbegin(), m.vdata.cend()); } return *this; } //move constructor template<typename T> Mat<T>::Mat( Mat<T>&& m ){ cout << "move constructor" << endl; vdata = std::move(m.vdata); } //move assignment template<typename T> Mat<T>& Mat<T>::operator=(Mat<T>&& m){ cout << "move assignment" << endl; if (this != &m){ vdata.clear(); vdata = std::move(m.vdata); } return *this; } //destructor template<typename T> Mat<T>::~Mat(){ vdata.clear(); } //access element value template<typename T> inline T& Mat<T>::operator()(size_t i, size_t j){ assert(!vdata.empty()); assert(i >= 0 && j >= 0 && i < vdata.size() && j < vdata[0].size()); return vdata[i][j]; } template<typename T> inline const T& Mat<T>::operator()(size_t i, size_t j) const{ assert(!vdata.empty()); assert(i >= 0 && j >= 0 && i < vdata.size() && j < vdata[0].size()); return vdata[i][j]; } // += template<typename T> Mat<T>& Mat<T>::operator+=(const Mat<T>& m){ if (vdata.empty() || m.vdata.empty()) return *this; const size_t row = vdata.size(); const size_t col = vdata[0].size(); const size_t mrow = m.vdata.size(); const size_t mcol = m.vdata[0].size(); if (row == mrow && col == mcol){ for (size_t i = 0; i < row; i++) for (size_t j = 0; j < col; j++) vdata[i][j] += m.vdata[i][j]; } else{ std::cerr << "mat must be the same size." << std::endl; } return *this; } // -= template<typename T> Mat<T>& Mat<T>::operator-=(const Mat<T>& m){ if (vdata.empty() || m.vdata.empty()) return *this; const size_t row = vdata.size(); const size_t col = vdata[0].size(); const size_t mrow = m.vdata.size(); const size_t mcol = m.vdata[0].size(); if (row == mrow && col == mcol){ for (size_t i = 0; i < row; i++) for (size_t j = 0; j < col; j++) vdata[i][j] -= m.vdata[i][j]; } else{ std::cerr << "mat must be the same size." << std::endl; } return *this; } //resize template<typename T> void Mat<T>::resize(size_t nr, size_t nc){ vdata.resize(nr); for (size_t i = 0; i < nr; i++){ vdata[i].resize(nc); } } //test Mat class typedef double Type; int main(){ Mat<Type> mat1(3, 4); Mat<Type> mat2(3, 4); for (size_t i = 0; i < mat1.rows(); i++){ for (size_t j = 0; j < mat1.cols(); j++){ mat1(i, j) = i*mat1.cols() + j; mat2(i, j) = 2 * i*mat1.cols() + 2 * j; } } std::cout << "mat1: " << std::endl << mat1; std::cout << "mat2: " << std::endl << mat2; Mat<Type> mat3 = (mat2 + mat1); std::cout << "mat3 = mat2 + mat1: " << std::endl << mat3; Mat<Type> mat4 = (mat3 + mat2 - mat1); std::cout << "mat4 = mat3 + mat2 - mat1: " << std::endl << mat4; stringstream ss; ss << mat1; ss >> mat4; std::cout << "mat4:" << std::endl << mat4; const Mat<Type> mat6(mat4); std::cout << "const mat6:" << std::endl << mat6; cout << mat6(0, 0) << " " << mat6.rows() << " " << mat6.cols() << endl; Mat<Type> mat7; mat7 = std::move(mat1); std::cout << "mat1: " << std::endl << mat1; std::cout << "mat7: " << std::endl << mat7; mat7.resize(2, 3); std::cout << "mat7.resize(2, 3): " << std::endl << mat7; mat7.resize(4, 6); std::cout << "mat7.resize(4, 6): " << std::endl << mat7; Mat<Type> mat8; cout << " " << mat8.rows() << " " << mat8.cols() << endl; //this will cause assertion error since mat8 is empty //cout<<mat8(0,0)<<endl; return 1; }
需要说明的时,重载输入元算符<<时,我开始用的友元函数,但是出现连接错误,没有找到bug,然后在 stackoverflow 上看到别人推荐这种非友元的方式来实现,就采用了这种方式,避免了之前的链接错误。。。
如有错误,恳请指正!
相关文章推荐