00001
00002
00003
00004 #ifndef MATRIX_H
00005 #define MATRIX_H
00006
00007 #include "Object.h"
00008 #include "ObjectParser.h"
00009 #include "typetraits.h"
00010 #include "binio.h"
00011 #include "net_types.h"
00012
00013 namespace FD {
00014
00015 class BaseMatrix : public Object {
00016
00017 public:
00019 BaseMatrix(){}
00020
00022 virtual size_t msize() const = 0;
00023
00025 virtual bool mempty() const = 0;
00026
00033 virtual ObjectRef getIndex(int _row, int _col) {
00034 throw new GeneralException(std::string("Matrix index not implemented for object : ") + className(),__FILE__,__LINE__);
00035 }
00036
00043 virtual void setIndex(int _row, int _col, ObjectRef val) {
00044 throw new GeneralException(std::string("Matrix index not implemented for object : ") + className(),__FILE__,__LINE__);
00045 }
00046
00050 virtual ObjectRef clone() = 0;
00051
00052 };
00053
00064 template<class T>
00065 class Matrix : public BaseMatrix
00066 {
00067 protected:
00068
00070 int rows;
00071
00073 int cols;
00074
00075
00076 T *data;
00077
00078 public:
00079
00081 typedef T basicType;
00082
00084 Matrix()
00085 : rows(0)
00086 , cols(0)
00087 , data(NULL)
00088 {}
00089
00095 Matrix(const Matrix &mat, bool transpose=false)
00096 : rows(mat.rows)
00097 , cols(mat.cols)
00098 , data(new T [rows*cols])
00099 {
00100 if (transpose)
00101 {
00102 rows=mat.cols;
00103 cols=mat.rows;
00104 for (int i=0;i<rows;i++)
00105 for (int j=0;j<cols;j++)
00106 data[i*cols+j] = mat.data[j*mat.cols+i];
00107 } else {
00108 for (int i=0;i<rows*cols;i++)
00109 data[i] = mat.data[i];
00110 }
00111 }
00112
00118 Matrix(int _rows, int _cols)
00119 : rows(_rows)
00120 , cols(_cols)
00121 , data(new T [_rows*_cols])
00122 {}
00123
00125 virtual ~Matrix() {delete [] data;}
00126
00127
00134 void resize(int _rows, int _cols)
00135 {
00136 T *new_data = new T [_rows*_cols];
00137 int min_rows = _rows < rows ? _rows : rows;
00138 int min_cols = _cols < cols ? _cols : cols;
00139 for (int i=0;i<min_rows;i++)
00140 for (int j=0;j<min_cols;j++)
00141 new_data[i*_cols+j] = data[i*cols+j];
00142 if (data)
00143 delete [] data;
00144 data = new_data;
00145 cols = _cols;
00146 rows = _rows;
00147 }
00148
00154 T *operator [] (int i)
00155 {
00156 return data+i*cols;
00157 }
00158
00164 const T *operator [] (int i) const
00165 {
00166 return data+i*cols;
00167 }
00168
00169
00176 T &operator () (int i, int j)
00177 {
00178 return data[i*cols+j];
00179 }
00180
00187 const T &operator () (int i, int j) const
00188 {
00189 return data[i*cols+j];
00190 }
00191
00193 int nrows() const {return rows;}
00194
00196 int ncols() const {return cols;}
00197
00198
00199
00201 void transpose()
00202 {
00203 if (rows==cols)
00204 {
00205 for (int i=0;i<rows;i++)
00206 for (int j=0;j<i+1;j++)
00207 {
00208 float tmp=data[i*cols+j];
00209 data[i*cols+j] = data[j*cols+i];
00210 data[j*cols+i] = tmp;
00211 }
00212 } else {
00213 Matrix mat(*this);
00214 for (int i=0;i<rows;i++)
00215 for (int j=0;j<cols;j++)
00216 data[i*cols+j] = mat.data[j*mat.cols+i];
00217 }
00218 }
00219
00225 void printOn(std::ostream &out) const
00226 {
00227 out << "<"<<className() << std::endl;
00228 out << "<rows " << rows << ">" << std::endl;
00229 out << "<cols " << cols << ">" << std::endl;
00230 out << "<data " << std::endl;
00231 for (int i=0;i<rows;i++)
00232 {
00233 for (int j=0;j<cols;j++)
00234 out << data[i*cols + j] << " ";
00235 out << std::endl;
00236 }
00237 out << ">" << std::endl;
00238 out << ">\n";
00239 }
00240
00241
00242
00243
00244
00245
00246
00247
00248
00254 void readFrom(std::istream &in=std::cin);
00255
00256
00261 int size() const {
00262 return (cols * rows);
00263 }
00264
00266 virtual size_t msize() const {return cols * rows;}
00267
00269 virtual bool mempty() const {return (cols == 0 && rows == 0);}
00270
00271
00277 virtual void serialize(std::ostream &out) const;
00278
00279
00285 virtual void unserialize(std::istream &in);
00286
00288 static std::string GetClassName()
00289 {
00290 std::string name = ObjectGetClassName<Matrix<T> >();
00291 if (name == "unknown")
00292 return std::string("Matrix");
00293 else
00294 return name;
00295 }
00296
00303 virtual ObjectRef getIndex(int _row, int _col);
00304
00305
00312 virtual void setIndex(int _row, int _col, ObjectRef val);
00313
00314
00318 virtual ObjectRef clone();
00319
00320 };
00321
00323 template <class T>
00324 inline ObjectRef Matrix<T>::clone() {
00325
00326 Matrix<T> *cpy = new Matrix<T>(this->nrows(), this->ncols());
00327
00328 for (int i = 0; i < this->nrows(); i++) {
00329 for (int j = 0; j < this->ncols(); j++) {
00330 (*cpy)(i,j) = (*this)(i,j);
00331 }
00332 }
00333 return ObjectRef(cpy);
00334 }
00335
00336
00338 template <>
00339 inline ObjectRef Matrix<ObjectRef>::clone() {
00340
00341 Matrix<ObjectRef> *cpy = new Matrix<ObjectRef>(this->nrows(), this->ncols());
00342
00343 for (int i = 0; i < this->nrows(); i++) {
00344 for (int j = 0; j < this->ncols(); j++) {
00345
00346 (*cpy)(i,j) = (*this)(i,j)->clone();
00347 }
00348 }
00349
00350 return ObjectRef(cpy);
00351 }
00352
00353 template <class T>
00354 inline void Matrix<T>::readFrom(std::istream &in)
00355 {
00356 std::string tag;
00357 int new_cols, new_rows;
00358 while (1)
00359 {
00360 char ch;
00361 in >> ch;
00362 if (ch == '>') break;
00363 else if (ch != '<')
00364 throw new ParsingException ("Matrix<T>::readFrom : Parse error: '<' expected");
00365 in >> tag;
00366 if (tag == "rows")
00367 in >> new_rows;
00368 else if (tag == "cols")
00369 in >> new_cols;
00370 else if (tag == "data")
00371 {
00372 resize(new_rows,new_cols);
00373 for (int i=0;i<rows*cols;i++)
00374 in >> data[i];
00375 } else
00376 throw new ParsingException ("Matrix<T>::readFrom : unknown argument: " + tag);
00377
00378 if (!in) throw new ParsingException ("Matrix<T>::readFrom : Parse error trying to build " + tag);
00379
00380 in >> tag;
00381 if (tag != ">")
00382 throw new ParsingException ("Matrix<T>::readFrom : Parse error: '>' expected ");
00383 }
00384
00385 }
00386
00387
00388 template<class T, int I>
00389 struct MatrixMethod {
00390 static inline void serialize(const Matrix<T> &m, std::ostream &out)
00391 {
00392 throw new GeneralException("MatrixMethod default serialize should never be called", __FILE__, __LINE__);
00393 }
00394 static inline void unserialize(Matrix<T> &m, std::istream &in)
00395 {
00396 throw new GeneralException("MatrixMethod default unserialize should never be called", __FILE__, __LINE__);
00397 }
00398 static inline ObjectRef getIndex(Matrix<T> &m, int _row, int _col)
00399 {
00400 throw new GeneralException("MatrixMethod getIndex should never be called", __FILE__, __LINE__);
00401 }
00402 static inline void setIndex(Matrix<T> &m, int _row, int _col, ObjectRef val)
00403 {
00404 throw new GeneralException("MatrixMethod setIndex should never be called", __FILE__, __LINE__);
00405 }
00406 };
00407
00408 template<class T>
00409 struct MatrixMethod<T,TTraits::Object> {
00410
00411 static inline void serialize(const Matrix<T> &m, std::ostream &out) {
00412 out << "{" << m.className() << std::endl;
00413 out << "|";
00414
00415
00416 int tmp = m.nrows();
00417 BinIO::write(out, &tmp, 1);
00418
00419
00420 tmp = m.ncols();
00421 BinIO::write(out, &tmp, 1);
00422
00423
00424 for (size_t i=0;i<m.nrows();i++) {
00425 for (size_t j=0;j < m.ncols(); j++) {
00426 m(i,j).serialize(out);
00427 }
00428 }
00429 out << "}";
00430 }
00431
00432 static inline void unserialize(Matrix<T> &m, std::istream &in)
00433 {
00434 int ncols, nrows;
00435 std::string expected = Matrix<T>::GetClassName();
00436
00437
00438 BinIO::read(in, &nrows, 1);
00439 BinIO::read(in, &ncols, 1);
00440
00441
00442 m.resize(nrows,ncols);
00443
00444
00445 for (size_t i=0;i<m.nrows();i++) {
00446 for (size_t j=0;j<m.ncols();j++) {
00447 if (!isValidType(in, expected))
00448 throw new ParsingException("Expected type " + expected);
00449 m(i,j).unserialize(in);
00450 }
00451 }
00452
00453
00454 char ch;
00455 in >> ch;
00456 }
00457
00458 static inline ObjectRef getIndex(Matrix<T> &m, int _row, int _col)
00459 {
00460 if (_row < 0 || _row >= m.nrows() ||
00461 _col < 0 || _col >= m.ncols() ) {
00462 throw new GeneralException("Matrix getIndex : index out of bound",__FILE__,__LINE__);
00463 }
00464
00465 return ObjectRef(m(_row,_col).clone());
00466 }
00467 static inline void setIndex(Matrix<T> &m, int _row, int _col, ObjectRef val)
00468 {
00469
00470 if (_row < 0 || _row >= m.nrows() ||
00471 _col < 0 || _col >= m.ncols() ) {
00472 throw new GeneralException("Matrix setIndex : index out of bound",__FILE__,__LINE__);
00473 }
00474 RCPtr<T> obj = val;
00475 m(_row,_col) = *obj;
00476 }
00477
00478 };
00479
00480
00481 template<class T>
00482 struct MatrixMethod<T,TTraits::ObjectPointer> {
00483 static inline void serialize(const Matrix<T> &m, std::ostream &out)
00484 {
00485 out << "{" << m.className() << std::endl;
00486 out << "|";
00487
00488
00489 int tmp = m.nrows();
00490 BinIO::write(out, &tmp, 1);
00491
00492
00493 tmp = m.ncols();
00494 BinIO::write(out, &tmp, 1);
00495
00496
00497 for (size_t i=0;i<m.nrows();i++) {
00498 for (size_t j=0;j < m.ncols(); j++) {
00499 m(i,j)->serialize(out);
00500 }
00501 }
00502
00503 out << "}";
00504 }
00505
00506 static inline void unserialize(Matrix<T> &m, std::istream &in)
00507 {
00508 int nrows,ncols;
00509
00510
00511 BinIO::read(in, &nrows, 1);
00512 BinIO::read(in, &ncols, 1);
00513
00514
00515 m.resize(nrows,ncols);
00516
00517 for (size_t i=0;i<m.nrows();i++) {
00518 for (size_t j=0;j<m.ncols();j++) {
00519 in >> m(i,j);
00520 }
00521 }
00522
00523 char ch;
00524 in >> ch;
00525 }
00526
00527 static inline ObjectRef getIndex(Matrix<T> &m, int _row, int _col)
00528 {
00529 if (_row < 0 || _row >= m.nrows() ||
00530 _col < 0 || _col >= m.ncols() ) {
00531 throw new GeneralException("Matrix getIndex : index out of bound",__FILE__,__LINE__);
00532 }
00533 return m(_row,_col);
00534
00535 }
00536 static inline void setIndex(Matrix<T> &m, int _row, int _col, ObjectRef val)
00537 {
00538
00539 if (_row < 0 || _row >= m.nrows() ||
00540 _col < 0 || _col >= m.ncols() ) {
00541 throw new GeneralException("Matrix setIndex : index out of bound",__FILE__,__LINE__);
00542 }
00543 m(_row,_col) = val;
00544 }
00545
00546 };
00547
00548
00549 template<class T>
00550 struct MatrixMethod<T,TTraits::Basic> {
00551 static inline void serialize(const Matrix<T> &m, std::ostream &out)
00552 {
00553 out << "{" << m.className() << std::endl;
00554 out << "|";
00555
00556
00557 int tmp = m.nrows();
00558 BinIO::write(out, &tmp, 1);
00559
00560
00561 tmp = m.ncols();
00562 BinIO::write(out, &tmp, 1);
00563
00564
00565 BinIO::write(out,m[0], m.size());
00566
00567 out << "}";
00568 }
00569 static inline void unserialize(Matrix<T> &m, std::istream &in)
00570 {
00571 int nrows,ncols;
00572
00573
00574 BinIO::read(in, &nrows, 1);
00575 BinIO::read(in, &ncols, 1);
00576
00577
00578 m.resize(nrows,ncols);
00579
00580
00581 BinIO::read(in,m[0], m.size());
00582 char ch;
00583 in >> ch;
00584 }
00585 static inline ObjectRef getIndex(Matrix<T> &m, int _row, int _col)
00586 {
00587 if (_row < 0 || _row >= m.nrows() ||
00588 _col < 0 || _col >= m.ncols() ) {
00589 throw new GeneralException("Matrix getIndex : index out of bound",__FILE__,__LINE__);
00590 }
00591
00592 return ObjectRef(NetCType<T>::alloc(m(_row,_col)));
00593 }
00594 static inline void setIndex(Matrix<T> &m, int _row, int _col, ObjectRef val)
00595 {
00596
00597 if (_row < 0 || _row >= m.nrows() ||
00598 _col < 0 || _col >= m.ncols() ) {
00599 throw new GeneralException("Matrix setIndex : index out of bound",__FILE__,__LINE__);
00600 }
00601
00602 RCPtr<NetCType<T> > obj = val;
00603 m(_row,_col) = *obj;
00604 }
00605
00606 };
00607
00608
00609 template<class T>
00610 struct MatrixMethod<T,TTraits::Unknown> {
00611 static inline void serialize(const Matrix<T> &m, std::ostream &out)
00612 {
00613 throw new GeneralException(std::string("Sorry, can't serialize this kind of object (") + typeid(T).name()
00614 + ")", __FILE__, __LINE__);
00615 }
00616 static inline void unserialize(Matrix<T> &m, std::istream &in)
00617 {
00618 throw new GeneralException(std::string("Sorry, can't unserialize this kind of object (") + typeid(T).name()
00619 + ")", __FILE__, __LINE__);
00620 }
00621 static inline ObjectRef getIndex(Matrix<T> &m, int _row, int _col)
00622 {
00623 throw new GeneralException(std::string("Sorry, can't getIndex this kind of object (") + typeid(T).name()
00624 + ")", __FILE__, __LINE__);
00625 }
00626 static inline void setIndex(Matrix<T> &m, int _row, int _col, ObjectRef val)
00627 {
00628 throw new GeneralException(std::string("Sorry, can't setIndex this kind of object (") + typeid(T).name()
00629 + ")", __FILE__, __LINE__);
00630 }
00631 };
00632
00633
00634 template <class T>
00635 inline void Matrix<T>::serialize(std::ostream &out) const
00636 {
00637 MatrixMethod<T, TypeTraits<T>::kind>::serialize(*this, out);
00638 }
00639
00640 template <class T>
00641 inline void Matrix<T>::unserialize(std::istream &in)
00642 {
00643 MatrixMethod<T, TypeTraits<T>::kind>::unserialize(*this, in);
00644 }
00645
00646 template <class T>
00647 inline ObjectRef Matrix<T>::getIndex(int _row, int _col)
00648 {
00649 return MatrixMethod<T, TypeTraits<T>::kind>::getIndex(*this,_row,_col);
00650 }
00651
00652 template <class T>
00653 inline void Matrix<T>::setIndex(int _row, int _col, ObjectRef val)
00654 {
00655 MatrixMethod<T,TypeTraits<T>::kind>::setIndex(*this,_row,_col,val);
00656 }
00657
00658 }
00659
00660 #endif