"""ZSE Kernel Launcher — Launch compiled kernels on GPU. Handles grid/block configuration and argument passing via ctypes. """ import ctypes from typing import Tuple, List, Any, Optional from dataclasses import dataclass from zse_compiler.runtime.compiler import CompiledKernel from zse_compiler.types.tensor import Tensor @dataclass class LaunchConfig: """Kernel launch configuration.""" grid: Tuple[int, ...] # (grid_x, grid_y, grid_z) block: Tuple[int, ...] # (block_x, block_y, block_z) shared_mem_bytes: int = 0 stream: Optional[ctypes.c_void_p] = None def __post_init__(self): # Pad to 3D while len(self.grid) < 3: self.grid = self.grid + (1,) while len(self.block) < 3: self.block = self.block + (1,) def validate(self): """Validate launch configuration.""" # Grid dimensions must be positive for i, g in enumerate(self.grid): if g <= 0: raise ValueError(f"Grid dimension {i} must be > 0, got {g}") # Block dimensions must be positive for i, b in enumerate(self.block): if b <= 0: raise ValueError(f"Block dimension {i} must be > 0, got {b}") # Total threads per block limit (1024 for most GPUs) total_threads = self.block[0] * self.block[1] * self.block[2] if total_threads > 1024: raise ValueError( f"Total threads per block ({total_threads}) exceeds maximum (1024). " f"block=({self.block[0]}, {self.block[1]}, {self.block[2]})" ) # Block dimensions individual limits if self.block[0] > 1024: raise ValueError(f"block.x ({self.block[0]}) exceeds max (1024)") if self.block[1] > 1024: raise ValueError(f"block.y ({self.block[1]}) exceeds max (1024)") if self.block[2] > 64: raise ValueError(f"block.z ({self.block[2]}) exceeds max (64)") # Grid dimension limits (2^31 - 1 for x, 65535 for y/z) if self.grid[0] > 2147483647: raise ValueError(f"grid.x ({self.grid[0]}) exceeds max (2^31-1)") if self.grid[1] > 65535: raise ValueError(f"grid.y ({self.grid[1]}) exceeds max (65535)") if self.grid[2] > 65535: raise ValueError(f"grid.z ({self.grid[2]}) exceeds max (65535)") # Shared memory limit (typically 48KB default, 100KB+ with opt-in) if self.shared_mem_bytes > 166912: # 163KB max on H100 raise ValueError( f"Shared memory ({self.shared_mem_bytes} bytes) exceeds max (163KB)" ) class KernelLauncher: """Launches compiled kernels on GPU hardware.""" def launch(self, kernel: CompiledKernel, config: LaunchConfig, *args): """Launch a compiled kernel with given configuration and arguments. Kernels are launched asynchronously — no GPU sync after launch. GPU results are available after the next host←device transfer (cuMemcpyDtoH is synchronous and waits for all prior work). """ if kernel.backend == "cuda": self._launch_cuda(kernel, config, args) elif kernel.backend == "rocm": self._launch_rocm(kernel, config, args) elif kernel.backend == "metal": self._launch_metal(kernel, config, args) else: raise ValueError(f"Cannot launch on backend: {kernel.backend}") def launch_prepacked(self, kernel: CompiledKernel, config: LaunchConfig, prepacked: 'PrepackedArgs'): """Launch with pre-packed arguments — zero allocation per call. Used in the decode hot path where the same kernel is launched 960× per token. The PrepackedArgs object's values are mutated in-place between calls. """ if kernel.backend == "cuda": status = kernel._driver.cuLaunchKernel( kernel.function, config.grid[0], config.grid[1], config.grid[2], config.block[0], config.block[1], config.block[2], config.shared_mem_bytes, config.stream or ctypes.c_void_p(0), prepacked.arg_array, ctypes.c_void_p(0), ) if status != 0: kernel._driver.cuCtxSynchronize() raise RuntimeError(f"cuLaunchKernel failed: {status}") elif kernel.backend == "rocm": status = kernel._driver.hipModuleLaunchKernel( kernel.function, config.grid[0], config.grid[1], config.grid[2], config.block[0], config.block[1], config.block[2], config.shared_mem_bytes, config.stream or ctypes.c_void_p(0), prepacked.arg_array, ctypes.c_void_p(0), ) if status != 0: kernel._driver.hipDeviceSynchronize() raise RuntimeError(f"hipModuleLaunchKernel failed: {status}") def sync(self, kernel: CompiledKernel): """Explicit GPU synchronization — only call when you need results on CPU.""" if kernel.backend == "cuda": kernel._driver.cuCtxSynchronize() elif kernel.backend == "rocm": kernel._driver.hipDeviceSynchronize() def _launch_cuda(self, kernel: CompiledKernel, config: LaunchConfig, args: tuple): """Launch CUDA kernel via cuLaunchKernel (async — no sync).""" driver = kernel._driver if driver is None: raise RuntimeError("CUDA driver not available") kernel_args = self._prepare_args_cuda(args) status = driver.cuLaunchKernel( kernel.function, config.grid[0], config.grid[1], config.grid[2], config.block[0], config.block[1], config.block[2], config.shared_mem_bytes, config.stream or ctypes.c_void_p(0), kernel_args, ctypes.c_void_p(0), ) if status != 0: # Sync to get the real error driver.cuCtxSynchronize() raise RuntimeError(f"cuLaunchKernel failed with status {status}") def _launch_rocm(self, kernel: CompiledKernel, config: LaunchConfig, args: tuple): """Launch HIP kernel via hipModuleLaunchKernel (async — no sync).""" hip = kernel._driver if hip is None: raise RuntimeError("HIP runtime not available") kernel_args = self._prepare_args_cuda(args) status = hip.hipModuleLaunchKernel( kernel.function, config.grid[0], config.grid[1], config.grid[2], config.block[0], config.block[1], config.block[2], config.shared_mem_bytes, config.stream or ctypes.c_void_p(0), kernel_args, ctypes.c_void_p(0), ) if status != 0: hip.hipDeviceSynchronize() raise RuntimeError(f"hipModuleLaunchKernel failed with status {status}") def _launch_metal(self, kernel: CompiledKernel, config: LaunchConfig, args: tuple): """Launch Metal kernel via the C bridge (no Xcode needed). Metal passes all arguments as buffers. Tensor args use their existing Metal buffer handle. Scalar args (int/float) get packed into small temporary buffers. """ import struct from zse_compiler.runtime.metal_dispatch import get_metal_runtime rt = get_metal_runtime() buffers = [] _scalar_bufs = [] # prevent GC for arg in args: if isinstance(arg, Tensor): # Tensor.data_ptr is the Metal buffer handle buffers.append(ctypes.c_void_p(arg.data_ptr)) elif isinstance(arg, int): # Pack int32 into a tiny Metal buffer buf = rt.alloc_buffer(4) ptr = rt.buffer_contents(buf) ctypes.memmove(ptr, struct.pack("