gemm
```rust
gemm(
    a: &Tensor<T>,
    b: &Tensor<T> | Tensor<T>,
    alpha: T,
    beta: T,
    conj_dst: bool,
    conj_lhs: bool,
    conj_rhs: bool,
) -> Result<Tensor<T>, TensorError>
```
Perform general matrix multiplication of two tensors. The behavior depends on the dimensions of the input tensors:
- If both tensors are 2D, they are multiplied as matrices
- If either tensor is ND (N > 2), it is treated as a stack of matrices
- Broadcasting is applied to match dimensions
Unlike the `matmul` method, `gemm` only supports `f16`, `f32`, `f64`, `Complex32`, and `Complex64`.
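The dimension rules above can be sketched as a small shape calculation. This is an illustration only: `batched_matmul_shape` is a hypothetical helper, not part of hpt, and it assumes NumPy-style broadcasting of the leading (batch) dimensions with both operands at least 2D. The library's own broadcasting rules may differ in detail.

```rust
/// Hypothetical helper (not part of hpt): output shape of a batched matrix
/// product, assuming NumPy-style broadcasting of the leading batch dimensions.
fn batched_matmul_shape(a: &[usize], b: &[usize]) -> Option<Vec<usize>> {
    // Both operands must be at least 2D in this sketch.
    if a.len() < 2 || b.len() < 2 {
        return None;
    }
    let (a_batch, a_mat) = a.split_at(a.len() - 2);
    let (b_batch, b_mat) = b.split_at(b.len() - 2);
    // Inner dimensions must agree: (m x k) @ (k x n).
    if a_mat[1] != b_mat[0] {
        return None;
    }
    let rank = a_batch.len().max(b_batch.len());
    let mut out = Vec::with_capacity(rank + 2);
    for i in 0..rank {
        // Align batch dimensions from the right; missing ones count as 1.
        let da = if i + a_batch.len() >= rank {
            a_batch[i + a_batch.len() - rank]
        } else {
            1
        };
        let db = if i + b_batch.len() >= rank {
            b_batch[i + b_batch.len() - rank]
        } else {
            1
        };
        match (da, db) {
            (x, y) if x == y => out.push(x),
            (1, y) => out.push(y),
            (x, 1) => out.push(x),
            _ => return None, // incompatible batch dimensions
        }
    }
    out.push(a_mat[0]); // m
    out.push(b_mat[1]); // n
    Some(out)
}

fn main() {
    // Two stacks of matrices: (2, 2, 3) @ (2, 3, 2) -> (2, 2, 2)
    println!("{:?}", batched_matmul_shape(&[2, 2, 3], &[2, 3, 2]));
    // A 2D matrix broadcast against a stack: (2, 3) @ (4, 3, 5) -> (4, 2, 5)
    println!("{:?}", batched_matmul_shape(&[2, 3], &[4, 3, 5]));
}
```

Under these assumptions, a 2D operand is simply reused for every matrix in the other operand's stack.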
Parameters:
`a`: First input tensor.
`b`: Second input tensor.
`alpha`: Scaling factor for the matrix product (A @ B).
`beta`: Scaling factor for the existing values in output matrix C.
`conj_dst`: Whether to conjugate C before scaling with beta.
`conj_lhs`: Whether to conjugate A before multiplication.
`conj_rhs`: Whether to conjugate B before multiplication.
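Read literally, the parameter descriptions above correspond to the update C = alpha * (A @ B) + beta * C. The following is a minimal reference sketch of that formula in plain Rust on row-major `f64` slices. `gemm_ref` is a hypothetical name, not an hpt function. The conjugation flags are omitted because they have no effect on real-valued types. The scaling convention the library actually applies to `alpha` and `beta` should be confirmed against hpt itself.

```rust
/// Reference sketch (not hpt's implementation):
/// c = alpha * (a @ b) + beta * c, with row-major storage.
/// `a` is m x k, `b` is k x n, `c` is m x n.
fn gemm_ref(
    m: usize,
    k: usize,
    n: usize,
    a: &[f64],
    b: &[f64],
    c: &mut [f64],
    alpha: f64,
    beta: f64,
) {
    assert_eq!(a.len(), m * k);
    assert_eq!(b.len(), k * n);
    assert_eq!(c.len(), m * n);
    for i in 0..m {
        for j in 0..n {
            // Dot product of row i of `a` with column j of `b`.
            let mut acc = 0.0;
            for p in 0..k {
                acc += a[i * k + p] * b[p * n + j];
            }
            // Scale the product and the existing output value, then combine.
            c[i * n + j] = alpha * acc + beta * c[i * n + j];
        }
    }
}

fn main() {
    let a = [1., 2., 3., 4.]; // [[1, 2], [3, 4]]
    let b = [5., 6., 7., 8.]; // [[5, 6], [7, 8]]
    let mut c = [0.0; 4];
    // Plain matrix product: scale A @ B by 1 and ignore the old contents of c.
    gemm_ref(2, 2, 2, &a, &b, &mut c, 1.0, 0.0);
    println!("{:?}", c); // [19.0, 22.0, 43.0, 50.0]
}
```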
Returns:
A new Tensor containing the result of general matrix multiplication.
Examples:
```rust
use hpt::{
    error::TensorError,
    ops::{Gemm, TensorCreator},
    Tensor,
};

fn main() -> Result<(), TensorError> {
    // 2D matrix multiplication
    let a = Tensor::<f64>::new(&[[1., 2.], [3., 4.]]);
    let b = Tensor::<f64>::new(&[[5., 6.], [7., 8.]]);
    let c = a.gemm(&b, 0.0, 1.0, false, false, false)?;
    println!("2D result:\n{}", c);

    // 3D batch matrix multiplication
    let d = Tensor::<f64>::ones(&[2, 2, 3])?; // 2 matrices of shape 2x3
    let e = Tensor::<f64>::ones(&[2, 3, 2])?; // 2 matrices of shape 3x2
    let f = d.gemm(&e, 0.0, 1.0, false, false, false)?; // 2 matrices of shape 2x2
    println!("3D result:\n{}", f);

    Ok(())
}
```
Backend Support
Backend | Supported |
---|---|
CPU | ✅ |
Cuda | ✅ |