blas3.hpp

Go to the documentation of this file.
00001 //
00002 //  Copyright Toon Knapen and Kresimir Fresl
00003 //
00004 // Distributed under the Boost Software License, Version 1.0.
00005 // (See accompanying file LICENSE_1_0.txt or copy at
00006 // http://www.boost.org/LICENSE_1_0.txt)
00007 //
00008 
00009 #ifndef BOOST_BINDINGS_BLAS_BLAS3_HPP
00010 #define BOOST_BINDINGS_BLAS_BLAS3_HPP
00011 
00012 #include <boost/numeric/bindings/blas/blas3_overloads.hpp>
00013 #include <boost/numeric/bindings/traits/traits.hpp>
00014 #include <boost/numeric/bindings/traits/transpose.hpp>
00015 
00016 namespace boost { namespace numeric { namespace bindings { namespace blas {
00017 
00018   // C <- alpha * op (A) * op (B) + beta * C 
00019   // op (X) == X || X^T || X^H
00020   template < typename value_type, typename matrix_type_a, typename matrix_type_b, typename matrix_type_c >
00021   // ! CAUTION this function assumes that all matrices involved are column-major matrices
00022   void gemm(const char TRANSA, const char TRANSB, 
00023             const value_type& alpha,
00024             const matrix_type_a &a,
00025             const matrix_type_b &b,
00026             const value_type &beta,
00027             matrix_type_c &c
00028             )
00029   {
00030     const int m = TRANSA == traits::NO_TRANSPOSE ? traits::matrix_size1( a ) : traits::matrix_size2( a ) ;
00031     const int n = TRANSB == traits::NO_TRANSPOSE ? traits::matrix_size2( b ) : traits::matrix_size1( b );
00032     const int k = TRANSA == traits::NO_TRANSPOSE ? traits::matrix_size2( a ) : traits::matrix_size1( a ) ;
00033     assert( k ==  ( TRANSB == traits::NO_TRANSPOSE ? traits::matrix_size1( b ) : traits::matrix_size2( b ) ) ) ;
00034     assert( m == traits::matrix_size1( c ) ); 
00035     assert( n == traits::matrix_size2( c ) ); 
00036     const int lda = traits::leading_dimension( a );
00037     const int ldb = traits::leading_dimension( b );
00038     const int ldc = traits::leading_dimension( c );
00039 
00040     const value_type *a_ptr = traits::matrix_storage( a ) ;
00041     const value_type *b_ptr = traits::matrix_storage( b ) ;
00042     value_type *c_ptr = traits::matrix_storage( c ) ;
00043 
00044     detail::gemm( TRANSA, TRANSB, m, n, k, alpha, a_ptr, lda, b_ptr, ldb, beta, c_ptr, ldc ) ;
00045   }
00046 
00047 
00048   // C <- alpha * A * B + beta * C 
00049   template < typename value_type, typename matrix_type_a, typename matrix_type_b, typename matrix_type_c >
00050   void gemm(const value_type& alpha,
00051             const matrix_type_a &a,
00052             const matrix_type_b &b,
00053             const value_type &beta,
00054             matrix_type_c &c
00055             )
00056   {
00057     gemm( traits::NO_TRANSPOSE, traits::NO_TRANSPOSE, alpha, a, b, beta, c ) ;
00058   }
00059 
00060 
00061   // C <- A * B 
00062   // ! CAUTION this function assumes that all matrices involved are column-major matrices
00063   template < 
00064     typename matrix_type_a, typename matrix_type_b, typename matrix_type_c 
00065     >
00066   void gemm(const matrix_type_a &a, const matrix_type_b &b, matrix_type_c &c)
00067   {
00068     typedef typename traits::matrix_traits<matrix_type_c>::value_type val_t; 
00069     gemm( traits::NO_TRANSPOSE, traits::NO_TRANSPOSE, (val_t) 1, a, b, (val_t) 0, c ) ;
00070   }
00071 
00072 
00073   // C <- alpha * A * A^T + beta * C
00074   // C <- alpha * A^T * A + beta * C
00075   template < typename value_type, typename matrix_type_a, typename matrix_type_c >
00076   void syrk( char uplo, char trans, const value_type& alpha, const matrix_type_a& a,
00077              const value_type& beta, matrix_type_c& c) {
00078      const int n = traits::matrix_size1( c );
00079      assert( n == traits::matrix_size2( c ) );
00080      const int k = trans == traits::NO_TRANSPOSE ? traits::matrix_size2( a ) : traits::matrix_size1( a ) ;
00081      assert( n == traits::NO_TRANSPOSE ? traits::matrix_size1( a ) : traits::matrix_size2( a ) );
00082      const int lda = traits::leading_dimension( a );
00083      const int ldc = traits::leading_dimension( c );
00084 
00085      const value_type *a_ptr = traits::matrix_storage( a ) ;
00086      value_type *c_ptr = traits::matrix_storage( c ) ;
00087 
00088      detail::syrk( uplo, trans, n, k, alpha, a_ptr, lda, beta, c_ptr, ldc );
00089   } // syrk()
00090 
00091 
00092   // C <- alpha * A * A^H + beta * C
00093   // C <- alpha * A^H * A + beta * C
00094   template < typename real_type, typename matrix_type_a, typename matrix_type_c >
00095   void herk( char uplo, char trans, const real_type& alpha, const matrix_type_a& a,
00096              const real_type& beta, matrix_type_c& c) {
00097      typedef typename matrix_type_c::value_type value_type ;
00098 
00099      const int n = traits::matrix_size1( c );
00100      assert( n == traits::matrix_size2( c ) );
00101      const int k = trans == traits::NO_TRANSPOSE ? traits::matrix_size2( a ) : traits::matrix_size1( a ) ;
00102      assert( n == traits::NO_TRANSPOSE ? traits::matrix_size1( a ) : traits::matrix_size2( a ) );
00103      const int lda = traits::leading_dimension( a );
00104      const int ldc = traits::leading_dimension( c );
00105 
00106      const value_type *a_ptr = traits::matrix_storage( a ) ;
00107      value_type *c_ptr = traits::matrix_storage( c ) ;
00108 
00109      detail::herk( uplo, trans, n, k, alpha, a_ptr, lda, beta, c_ptr, ldc );
00110   } // herk()
00111 
00112   // B <- alpha * op( A^-1 )
00113   // B <- alpha * B op( A^-1 )
00114   // op( A ) = A, A^T, A^H
00115   template < class T, class A, class B >
00116   void trsm( char side, char uplo, char transa, char diag, T const& alpha, A const& a, B& b ) {
00117      const int m = traits::matrix_size1( b ) ;
00118      const int n = traits::matrix_size2( b ) ;
00119      assert( ( side=='L' && m==traits::matrix_size2( a ) && m==traits::matrix_size1( a ) ) ||
00120              ( side=='R' && n==traits::matrix_size2( a ) && n==traits::matrix_size1( a ) ) ) ;
00121      assert( side=='R' || side=='L' ) ;
00122      assert( uplo=='U' || uplo=='L' ) ;
00123      assert( ( side=='L' && m==traits::matrix_size1( a ) ) || ( side=='R' && n==traits::matrix_size1( a ) ) ) ;
00124      detail::trsm( side, uplo, transa, diag, m, n, alpha,
00125                    traits::matrix_storage( a ), traits::leading_dimension( a ),
00126                    traits::matrix_storage( b ), traits::leading_dimension( b )
00127                  ) ;
00128   }
00129 
00130 }}}}
00131 
00132 #endif // BOOST_BINDINGS_BLAS_BLAS3_HPP

Generated on Wed Nov 23 18:59:58 2011 for FreeCAD by  doxygen 1.6.1