00001
00002
00003
00004
00005
00006
00007
00008
00009 #ifndef BOOST_BINDINGS_BLAS_BLAS3_HPP
00010 #define BOOST_BINDINGS_BLAS_BLAS3_HPP
00011
00012 #include <boost/numeric/bindings/blas/blas3_overloads.hpp>
00013 #include <boost/numeric/bindings/traits/traits.hpp>
00014 #include <boost/numeric/bindings/traits/transpose.hpp>
00015
00016 namespace boost { namespace numeric { namespace bindings { namespace blas {
00017
00018
00019
00020 template < typename value_type, typename matrix_type_a, typename matrix_type_b, typename matrix_type_c >
00021
00022 void gemm(const char TRANSA, const char TRANSB,
00023 const value_type& alpha,
00024 const matrix_type_a &a,
00025 const matrix_type_b &b,
00026 const value_type &beta,
00027 matrix_type_c &c
00028 )
00029 {
00030 const int m = TRANSA == traits::NO_TRANSPOSE ? traits::matrix_size1( a ) : traits::matrix_size2( a ) ;
00031 const int n = TRANSB == traits::NO_TRANSPOSE ? traits::matrix_size2( b ) : traits::matrix_size1( b );
00032 const int k = TRANSA == traits::NO_TRANSPOSE ? traits::matrix_size2( a ) : traits::matrix_size1( a ) ;
00033 assert( k == ( TRANSB == traits::NO_TRANSPOSE ? traits::matrix_size1( b ) : traits::matrix_size2( b ) ) ) ;
00034 assert( m == traits::matrix_size1( c ) );
00035 assert( n == traits::matrix_size2( c ) );
00036 const int lda = traits::leading_dimension( a );
00037 const int ldb = traits::leading_dimension( b );
00038 const int ldc = traits::leading_dimension( c );
00039
00040 const value_type *a_ptr = traits::matrix_storage( a ) ;
00041 const value_type *b_ptr = traits::matrix_storage( b ) ;
00042 value_type *c_ptr = traits::matrix_storage( c ) ;
00043
00044 detail::gemm( TRANSA, TRANSB, m, n, k, alpha, a_ptr, lda, b_ptr, ldb, beta, c_ptr, ldc ) ;
00045 }
00046
00047
00048
00049 template < typename value_type, typename matrix_type_a, typename matrix_type_b, typename matrix_type_c >
00050 void gemm(const value_type& alpha,
00051 const matrix_type_a &a,
00052 const matrix_type_b &b,
00053 const value_type &beta,
00054 matrix_type_c &c
00055 )
00056 {
00057 gemm( traits::NO_TRANSPOSE, traits::NO_TRANSPOSE, alpha, a, b, beta, c ) ;
00058 }
00059
00060
00061
00062
00063 template <
00064 typename matrix_type_a, typename matrix_type_b, typename matrix_type_c
00065 >
00066 void gemm(const matrix_type_a &a, const matrix_type_b &b, matrix_type_c &c)
00067 {
00068 typedef typename traits::matrix_traits<matrix_type_c>::value_type val_t;
00069 gemm( traits::NO_TRANSPOSE, traits::NO_TRANSPOSE, (val_t) 1, a, b, (val_t) 0, c ) ;
00070 }
00071
00072
00073
00074
00075 template < typename value_type, typename matrix_type_a, typename matrix_type_c >
00076 void syrk( char uplo, char trans, const value_type& alpha, const matrix_type_a& a,
00077 const value_type& beta, matrix_type_c& c) {
00078 const int n = traits::matrix_size1( c );
00079 assert( n == traits::matrix_size2( c ) );
00080 const int k = trans == traits::NO_TRANSPOSE ? traits::matrix_size2( a ) : traits::matrix_size1( a ) ;
00081 assert( n == traits::NO_TRANSPOSE ? traits::matrix_size1( a ) : traits::matrix_size2( a ) );
00082 const int lda = traits::leading_dimension( a );
00083 const int ldc = traits::leading_dimension( c );
00084
00085 const value_type *a_ptr = traits::matrix_storage( a ) ;
00086 value_type *c_ptr = traits::matrix_storage( c ) ;
00087
00088 detail::syrk( uplo, trans, n, k, alpha, a_ptr, lda, beta, c_ptr, ldc );
00089 }
00090
00091
00092
00093
00094 template < typename real_type, typename matrix_type_a, typename matrix_type_c >
00095 void herk( char uplo, char trans, const real_type& alpha, const matrix_type_a& a,
00096 const real_type& beta, matrix_type_c& c) {
00097 typedef typename matrix_type_c::value_type value_type ;
00098
00099 const int n = traits::matrix_size1( c );
00100 assert( n == traits::matrix_size2( c ) );
00101 const int k = trans == traits::NO_TRANSPOSE ? traits::matrix_size2( a ) : traits::matrix_size1( a ) ;
00102 assert( n == traits::NO_TRANSPOSE ? traits::matrix_size1( a ) : traits::matrix_size2( a ) );
00103 const int lda = traits::leading_dimension( a );
00104 const int ldc = traits::leading_dimension( c );
00105
00106 const value_type *a_ptr = traits::matrix_storage( a ) ;
00107 value_type *c_ptr = traits::matrix_storage( c ) ;
00108
00109 detail::herk( uplo, trans, n, k, alpha, a_ptr, lda, beta, c_ptr, ldc );
00110 }
00111
00112
00113
00114
00115 template < class T, class A, class B >
00116 void trsm( char side, char uplo, char transa, char diag, T const& alpha, A const& a, B& b ) {
00117 const int m = traits::matrix_size1( b ) ;
00118 const int n = traits::matrix_size2( b ) ;
00119 assert( ( side=='L' && m==traits::matrix_size2( a ) && m==traits::matrix_size1( a ) ) ||
00120 ( side=='R' && n==traits::matrix_size2( a ) && n==traits::matrix_size1( a ) ) ) ;
00121 assert( side=='R' || side=='L' ) ;
00122 assert( uplo=='U' || uplo=='L' ) ;
00123 assert( ( side=='L' && m==traits::matrix_size1( a ) ) || ( side=='R' && n==traits::matrix_size1( a ) ) ) ;
00124 detail::trsm( side, uplo, transa, diag, m, n, alpha,
00125 traits::matrix_storage( a ), traits::leading_dimension( a ),
00126 traits::matrix_storage( b ), traits::leading_dimension( b )
00127 ) ;
00128 }
00129
00130 }}}}
00131
00132 #endif // BOOST_BINDINGS_BLAS_BLAS3_HPP