Function linear

pub fn linear<B, const D: usize>(
    input: Tensor<B, D>,
    weight: Tensor<B, 2>,
    bias: Option<Tensor<B, 1>>,
) -> Tensor<B, D>
where B: Backend,

Applies a linear transformation to the input tensor using the given weight and bias.

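y = x @ weight + [bias]

where @ denotes matrix multiplication over the last dimension of x, and the bracketed bias term is added, broadcast across the leading dimensions, only when bias is Some.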
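The formula can be illustrated with a dependency-free sketch for the 2-D case, assuming flat, row-major f32 slices. The helper name linear_2d is hypothetical. It is not Burn's implementation, which operates on Tensor values and supports any number of leading dimensions. Note that the weight is read directly as [d_input, d_output] with no transpose.

```rust
/// Hypothetical helper (not Burn's API): y = x @ weight + [bias] for a 2-D input.
/// All matrices are flat, row-major slices:
/// `x` is [n, d_input], `weight` is [d_input, d_output], `bias` is [d_output].
fn linear_2d(
    x: &[f32],
    n: usize,
    d_input: usize,
    weight: &[f32],
    d_output: usize,
    bias: Option<&[f32]>,
) -> Vec<f32> {
    assert_eq!(x.len(), n * d_input, "x must be [n, d_input]");
    assert_eq!(weight.len(), d_input * d_output, "weight must be [d_input, d_output]");
    if let Some(b) = bias {
        assert_eq!(b.len(), d_output, "bias must be [d_output]");
    }

    let mut y = vec![0.0f32; n * d_output];
    for i in 0..n {
        for j in 0..d_output {
            // Row i of x times column j of weight; the weight is used as-is, not transposed.
            let mut acc = 0.0f32;
            for k in 0..d_input {
                acc += x[i * d_input + k] * weight[k * d_output + j];
            }
            // The bias term is added only when one is supplied.
            if let Some(b) = bias {
                acc += b[j];
            }
            y[i * d_output + j] = acc;
        }
    }
    y
}
```

For example, with x = [[1, 2, 3], [4, 5, 6]], weight = [[1, 0], [0, 1], [1, 1]] and bias = [0.5, -0.5], the sketch returns [[4.5, 4.5], [10.5, 10.5]] (flattened as [4.5, 4.5, 10.5, 10.5]).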

Arguments

  • input: the input tensor, of shape [..., d_input], where ... stands for the D - 1 leading (batch) dimensions.
  • weight: the weight tensor, of shape [d_input, d_output].
  • bias: the optional bias tensor, of shape [d_output].
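The shape contract described by these arguments can be written as a small check. This is a hypothetical helper, linear_output_shape, that only computes the expected output shape; it makes no claim about how or whether Burn itself reports mismatched shapes.

```rust
/// Hypothetical helper (not part of Burn): output shape of `linear` given the
/// input shape [..., d_input], weight shape [d_input, d_output], and optional bias length.
fn linear_output_shape(
    input_shape: &[usize],
    weight_shape: [usize; 2],
    bias_len: Option<usize>,
) -> Result<Vec<usize>, String> {
    let [d_input, d_output] = weight_shape;

    // The last input dimension must match the first weight dimension.
    let last = *input_shape
        .last()
        .ok_or_else(|| "input must have at least one dimension".to_string())?;
    if last != d_input {
        return Err(format!("input last dim {last} != weight d_input {d_input}"));
    }

    // If a bias is given, its length must match the output feature count.
    if let Some(len) = bias_len {
        if len != d_output {
            return Err(format!("bias length {len} != weight d_output {d_output}"));
        }
    }

    // Leading dimensions are preserved; only the last one changes.
    let mut out = input_shape[..input_shape.len() - 1].to_vec();
    out.push(d_output);
    Ok(out)
}
```

For instance, an input of shape [8, 16, 32] with a [32, 64] weight yields [8, 16, 64].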

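Returns

The transformed tensor, of shape [..., d_output]: the leading dimensions of input are preserved and only the last dimension changes from d_input to d_output.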

Compatibility

This function differs from PyTorch's torch.nn.functional.linear in that it does not transpose the weight matrix. PyTorch expects the weight with shape [d_output, d_input] and transposes it before the multiplication:

y = x @ weight^T + [bias]
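A dependency-free sketch of how the two conventions relate, assuming flat, row-major f32 slices. The function names transpose, linear_torch_convention and linear_no_transpose are hypothetical and belong to neither Burn nor PyTorch. A PyTorch-layout weight of shape [d_output, d_input] used as y = x @ weight^T gives the same result as its transpose, of shape [d_input, d_output], used as y = x @ weight. The bias is left out because it is added identically in both conventions.

```rust
/// Hypothetical helper: transposes a flat, row-major [rows, cols] matrix.
fn transpose(m: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    assert_eq!(m.len(), rows * cols);
    let mut t = vec![0.0f32; rows * cols];
    for r in 0..rows {
        for c in 0..cols {
            // Element (r, c) of m becomes element (c, r) of t.
            t[c * rows + r] = m[r * cols + c];
        }
    }
    t
}

/// PyTorch convention: y = x @ weight^T, with weight stored as [d_output, d_input].
fn linear_torch_convention(x: &[f32], n: usize, d_input: usize, weight: &[f32], d_output: usize) -> Vec<f32> {
    assert_eq!(weight.len(), d_output * d_input);
    let mut y = vec![0.0f32; n * d_output];
    for i in 0..n {
        for j in 0..d_output {
            let mut acc = 0.0f32;
            for k in 0..d_input {
                // weight^T[k][j] is weight[j][k].
                acc += x[i * d_input + k] * weight[j * d_input + k];
            }
            y[i * d_output + j] = acc;
        }
    }
    y
}

/// Convention of this function: y = x @ weight, with weight stored as [d_input, d_output].
fn linear_no_transpose(x: &[f32], n: usize, d_input: usize, weight: &[f32], d_output: usize) -> Vec<f32> {
    assert_eq!(weight.len(), d_input * d_output);
    let mut y = vec![0.0f32; n * d_output];
    for i in 0..n {
        for j in 0..d_output {
            let mut acc = 0.0f32;
            for k in 0..d_input {
                acc += x[i * d_input + k] * weight[k * d_output + j];
            }
            y[i * d_output + j] = acc;
        }
    }
    y
}
```

With x = [1, 2, 3] and a PyTorch weight [[1, 0, 1], [0, 1, 1]], both paths produce [4, 5] once the weight is transposed to [[1, 0], [0, 1], [1, 1]] for linear_no_transpose. Transposing once when loading the weights, rather than on every call, would typically be the cheaper choice.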