diff --git a/EvtGenBase/EvtMatrix.hh b/EvtGenBase/EvtMatrix.hh
index dec2312..93e07c1 100644
--- a/EvtGenBase/EvtMatrix.hh
+++ b/EvtGenBase/EvtMatrix.hh
@@ -1,195 +1,196 @@
/***********************************************************************
* Copyright 1998-2020 CERN for the benefit of the EvtGen authors *
* *
* This file is part of EvtGen. *
* *
* EvtGen is free software: you can redistribute it and/or modify *
* it under the terms of the GNU General Public License as published by *
* the Free Software Foundation, either version 3 of the License, or *
* (at your option) any later version. *
* *
* EvtGen is distributed in the hope that it will be useful, *
* but WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
* GNU General Public License for more details. *
* *
* You should have received a copy of the GNU General Public License *
* along with EvtGen. If not, see . *
***********************************************************************/
#ifndef __EVT_MATRIX_HH__
#define __EVT_MATRIX_HH__
+#include
#include
#include
template
class EvtMatrix {
private:
T** _mat;
int _range;
public:
EvtMatrix() : _range( 0 ){};
~EvtMatrix();
inline void setRange( int range );
T& operator()( int row, int col ) { return _mat[row][col]; }
T* operator[]( int row ) { return _mat[row]; }
T det();
EvtMatrix* min( int row, int col );
EvtMatrix* inverse();
std::string dump();
template
friend EvtMatrix* operator*( const EvtMatrix& left,
const EvtMatrix& right );
};
template
inline void EvtMatrix::setRange( int range )
{
// If the range is changed, delete any previous matrix stored
// and allocate elements with the newly specified range.
if ( _range != range ) {
if ( _range ) {
for ( int row = 0; row < _range; row++ )
delete[] _mat[row];
delete[] _mat;
}
_mat = new T*[range];
for ( int row = 0; row < range; row++ )
_mat[row] = new T[range];
// Set the new range.
_range = range;
}
// Since user is willing to change the range, reset the matrix elements.
for ( int row = 0; row < _range; row++ )
for ( int col = 0; col < _range; col++ )
_mat[row][col] = 0.;
}
template
EvtMatrix::~EvtMatrix()
{
for ( int row = 0; row < _range; row++ )
delete[] _mat[row];
delete[] _mat;
}
template
std::string EvtMatrix::dump()
{
std::ostringstream str;
for ( int row = 0; row < _range; row++ ) {
str << "|";
for ( int col = 0; col < _range; col++ )
str << "\t" << _mat[row][col];
str << "\t|" << std::endl;
}
return str.str();
}
template
T EvtMatrix::det()
{
if ( _range == 1 )
return _mat[0][0];
// There's no need to define the range 2 determinant manually, but it may
// speed up the calculation.
if ( _range == 2 )
return _mat[0][0] * _mat[1][1] - _mat[0][1] * _mat[1][0];
T sum = 0.;
for ( int col = 0; col < _range; col++ ) {
EvtMatrix* minor = min( 0, col );
sum += std::pow( -1., col ) * _mat[0][col] * minor->det();
delete minor;
}
return sum;
}
// Returns the minor at (i, j).
template
EvtMatrix* EvtMatrix::min( int row, int col )
{
EvtMatrix* minor = new EvtMatrix();
minor->setRange( _range - 1 );
int minIndex = 0;
for ( int r = 0; r < _range; r++ )
for ( int c = 0; c < _range; c++ )
if ( ( r != row ) && ( c != col ) ) {
( *minor )( minIndex / ( _range - 1 ),
minIndex % ( _range - 1 ) ) = _mat[r][c];
minIndex++;
}
return minor;
}
template
EvtMatrix* EvtMatrix::inverse()
{
EvtMatrix* inv = new EvtMatrix();
inv->setRange( _range );
if ( det() == 0 ) {
std::cerr << "This matrix has a null determinant and cannot be inverted. Returning zero matrix."
<< std::endl;
for ( int row = 0; row < _range; row++ )
for ( int col = 0; col < _range; col++ )
( *inv )( row, col ) = 0.;
return inv;
}
T determinant = det();
for ( int row = 0; row < _range; row++ )
for ( int col = 0; col < _range; col++ ) {
EvtMatrix* minor = min( row, col );
inv->_mat[col][row] = std::pow( -1., row + col ) * minor->det() /
determinant;
delete minor;
}
return inv;
}
template
EvtMatrix* operator*( const EvtMatrix& left, const EvtMatrix& right )
{
// Chech that the matrices have the correct range.
if ( left._range != right._range ) {
std::cerr << "These matrices cannot be multiplied." << std::endl;
return new EvtMatrix();
}
EvtMatrix* mat = new EvtMatrix();
mat->setRange( left._range );
// Initialize the elements of the matrix.
for ( int row = 0; row < left._range; row++ )
for ( int col = 0; col < right._range; col++ )
( *mat )[row][col] = 0;
for ( int row = 0; row < left._range; row++ )
for ( int col = 0; col < right._range; col++ )
for ( int line = 0; line < right._range; line++ )
( *mat )[row][col] += left._mat[row][line] *
right._mat[line][col];
return mat;
}
#endif