Vectorized Reduction
In this section, we will explore how to implement a vectorized reduction operation using CubeCL. Vectorization is a powerful technique that lets us process multiple data elements simultaneously, which can significantly improve performance, especially for I/O-bound computations such as this reduction.
What is vectorization?
Vectorization is the process of converting scalar operations (which operate on a single data element at a time) into vector operations (which operate on multiple data elements simultaneously). This is typically done using the SIMD (Single Instruction, Multiple Data) instructions available on modern CPUs and GPUs. For more information on vectorization in CubeCL, you can refer to this section.
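To make the idea concrete, here is a minimal sketch contrasting a scalar kernel with its vectorized counterpart; the kernel names and the doubling operation are ours for illustration and are not part of the benchmark below. The only structural difference is the Line<F> element type, which makes each load, multiply, and store operate on a whole line of elements at once.

use cubecl::prelude::*;

#[cube(launch_unchecked)]
fn scale_scalar<F: Float>(input: &Array<F>, output: &mut Array<F>) {
    // One element per unit: one load, one multiply, one store.
    output[UNIT_POS_X] = input[UNIT_POS_X] * F::new(2.0);
}

#[cube(launch_unchecked)]
fn scale_vectorized<F: Float>(input: &Array<Line<F>>, output: &mut Array<Line<F>>) {
    // One line per unit: the same code now processes a whole line of elements per instruction.
    output[UNIT_POS_X] = input[UNIT_POS_X] * Line::new(F::new(2.0));
}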
Application to the reduction problem
To apply vectorization to the reduction problem, we will modify our reduction kernel to process multiple elements at once. Instead of summing one element at a time, we will sum several elements per iteration, which can lead to substantial performance gains. The number of elements processed at a time is called the line size. To add vectorization, we just need to pass the LINE_SIZE to the TensorArgs and divide the number of iterations of reduce_matrix by it. The full benchmark with these changes follows.
use std::marker::PhantomData;

use cubecl::benchmark::{Benchmark, TimingMethod};
use cubecl::{future, prelude::*};
use cubecl_example::gpu_tensor::GpuTensor; // Change to the path of your own module containing the GpuTensor

pub struct ReductionBench<R: Runtime, F: Float + CubeElement> {
    input_shape: Vec<usize>,
    client: ComputeClient<R::Server, R::Channel>,
    _f: PhantomData<F>,
}

// Each Line will hold 4 elements, so each load and add processes 4 values at once.
const LINE_SIZE: u32 = 4;
impl<R: Runtime, F: Float + CubeElement> Benchmark for ReductionBench<R, F> {
    type Input = GpuTensor<R, F>;
    type Output = GpuTensor<R, F>;

    fn prepare(&self) -> Self::Input {
        GpuTensor::<R, F>::arange(self.input_shape.clone(), &self.client)
    }

    fn name(&self) -> String {
        format!("{}-reduction-{:?}", R::name(&self.client), self.input_shape).to_lowercase()
    }

    fn sync(&self) {
        future::block_on(self.client.sync())
    }

    fn execute(&self, input: Self::Input) -> Self::Output {
        let output_shape: Vec<usize> = vec![self.input_shape[0]];
        let output = GpuTensor::<R, F>::empty(output_shape, &self.client);

        unsafe {
            reduce_matrix::launch_unchecked::<F, R>(
                &self.client,
                CubeCount::Static(1, 1, 1),
                CubeDim::new(self.input_shape[0] as u32, 1, 1),
                // Passing the line size here is what enables vectorized loads and stores.
                input.into_tensor_arg(LINE_SIZE as u8),
                output.into_tensor_arg(LINE_SIZE as u8),
            );
        }

        output
    }
}
// Note the addition of the [Line] struct inside the tensor to guarantee that the data is
// contiguous and can be parallelized.
#[cube(launch_unchecked)]
fn reduce_matrix<F: Float>(input: &Tensor<Line<F>>, output: &mut Tensor<Line<F>>) {
    let mut acc = Line::new(F::new(0.0f32)); // A [Line] is also necessary here
    // Each index now addresses a whole line, so the loop runs LINE_SIZE times fewer iterations.
    for i in 0..input.shape(1) / LINE_SIZE {
        acc = acc + input[UNIT_POS_X * input.stride(0) + i];
    }
    output[UNIT_POS_X] = acc;
}
pub fn launch<R: Runtime, F: Float + CubeElement>(device: &R::Device) {
    let client = R::client(&device);

    let bench1 = ReductionBench::<R, F> {
        input_shape: vec![512, 8 * 1024],
        client: client.clone(),
        _f: PhantomData,
    };
    let bench2 = ReductionBench::<R, F> {
        input_shape: vec![128, 32 * 1024],
        client: client.clone(),
        _f: PhantomData,
    };

    for bench in [bench1, bench2] {
        println!("{}", bench.name());
        println!("{}", bench.run(TimingMethod::System));
    }
}

fn main() {
    launch::<cubecl::wgpu::WgpuRuntime, f32>(&Default::default());
}
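The line size is hardcoded to 4 above, which is widely supported, but the optimal value depends on the runtime and the element type. Below is a minimal sketch of one way to avoid the hardcoded constant, assuming your CubeCL version supports #[comptime] kernel arguments and exposes the runtime's supported line sizes via Runtime::supported_line_sizes; reduce_matrix_dyn is a hypothetical variant of the kernel above, not part of the benchmark.

// Hypothetical variant of reduce_matrix that receives the line size as a
// compile-time argument instead of reading a global const.
#[cube(launch_unchecked)]
fn reduce_matrix_dyn<F: Float>(
    input: &Tensor<Line<F>>,
    output: &mut Tensor<Line<F>>,
    #[comptime] line_size: u32,
) {
    let mut acc = Line::new(F::new(0.0f32));
    for i in 0..input.shape(1) / line_size {
        acc = acc + input[UNIT_POS_X * input.stride(0) + i];
    }
    output[UNIT_POS_X] = acc;
}

// At the launch site, one could pick the largest line size the runtime reports
// as supported (assumption: this API exists in your CubeCL version), pass it to
// both into_tensor_arg calls, and forward it as the final launch argument:
// let line_size = *R::supported_line_sizes().iter().max().unwrap_or(&1) as u32;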
The Result
The result of adding vectorization is an average of 3x speedup compared to the previous parallel reduction implementation. This is because each unit now processes LINE_SIZE elements at a time: for the [512, 8192] input, the inner loop runs 2048 iterations instead of 8192, and each iteration loads four floats in a single instruction.
wgpu<wgsl>-reduction-[512, 8192]
―――――――― Result ―――――――――
Timing system
Samples 10
Mean 1.085ms
Variance 14.000ns
Median 1.045ms
Min 998.981µs
Max 1.375ms
―――――――――――――――――――――――――
wgpu<wgsl>-reduction-[128, 32768]
―――――――― Result ―――――――――
Timing system
Samples 10
Mean 3.124ms
Variance 37.000ns
Median 3.061ms
Min 3.009ms
Max 3.670ms
―――――――――――――――――――――――――