13. FP8 Inference
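FP8 comes in two formats, E4M3 and E5M2. Neither is part of IEEE 754; both are defined in the OCP 8-bit Floating Point (OFP8) specification. E5M2 follows IEEE 754 conventions for infinities and NaNs, while E4M3 gives up infinities to extend its maximum to 448. Both trade precision for dynamic range at one byte per value, which halves storage and memory traffic relative to FP16 and suits transformer inference.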
FP8 Format Details
E4M3: 4-bit exponent, 3-bit mantissa, bias=7, max=448, min normal=2^-6, min subnormal=2^-9
E5M2: 5-bit exponent, 2-bit mantissa, bias=15, max=57344, min normal=2^-14, min subnormal=2^-16
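Inference operands (weights and activations) typically use E4M3, whose extra mantissa bit gives better precision. E5M2 trades that bit for range and is used mainly for gradients during training.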
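To make the E4M3 bit layout concrete, here is a minimal host-side sketch that decodes one E4M3 byte to a float. The name decode_e4m3 is hypothetical, and the code assumes the OFP8 variant of E4M3 in which there are no infinities and S.1111.111 is the only NaN encoding.

```cpp
#include <cmath>
#include <cstdint>

// Decode one E4M3 byte (1 sign bit, 4 exponent bits, 3 mantissa bits, bias 7).
// Assumption: OFP8-style E4M3 with no infinities; exponent=1111, mantissa=111 is NaN.
float decode_e4m3(std::uint8_t bits) {
    const int sign = (bits >> 7) & 0x1;
    const int exp  = (bits >> 3) & 0xF;
    const int man  = bits & 0x7;
    if (exp == 0xF && man == 0x7) {
        return std::nanf("");
    }
    float mag;
    if (exp == 0) {
        // Subnormal: no implicit leading 1, exponent fixed at 1 - bias = -6.
        mag = std::ldexp(static_cast<float>(man) / 8.0f, -6);
    } else {
        // Normal: implicit leading 1, unbiased exponent is exp - 7.
        mag = std::ldexp(1.0f + static_cast<float>(man) / 8.0f, exp - 7);
    }
    return sign ? -mag : mag;
}
```

With this decoder, 0x7E (exponent 1111, mantissa 110) gives 1.75 * 2^8 = 448, and 0x01 gives the smallest subnormal, 2^-9.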
FP8 GEMM Implementation
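#include <cuda_fp16.h>
#include <cuda_fp8.h>  // requires CUDA 11.8 or newer

// Round val to the nearest E4M3-representable value, saturating at +/-448,
// and return it as a half so it can be stored in an FP16 buffer.
__device__ __half float_to_fp8_e4m3(float val) {
    // Convert to the 8-bit E4M3 encoding with saturation to the finite range.
    __nv_fp8_storage_t q = __nv_cvt_float_to_fp8(val, __NV_SATFINITE, __NV_E4M3);
    // Widen back to half; every E4M3 value is exactly representable in FP16.
    return __half(__nv_cvt_fp8_to_halfraw(q, __NV_E4M3));
}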
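// Reference kernel that emulates an FP8 GEMM: operands are carried in FP16,
// accumulation is in FP32, and the scaled result is rounded to E4M3 precision.
// One thread computes one element of C (row-major, M x N).
__global__ void fp8_gemm_kernel(
    const __half* A_fp16, const __half* B_fp16,
    __half* C, int M, int N, int K, float scale) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= M || col >= N) return;  // guard threads outside the matrix
    float acc = 0.0f;
    for (int k = 0; k < K; k++) {
        float a = __half2float(A_fp16[row * K + k]);
        float b = __half2float(B_fp16[k * N + col]);
        acc += a * b;
    }
    C[row * N + col] = float_to_fp8_e4m3(acc * scale);
}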
FP8 Scaling Considerations
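Because E4M3 saturates at 448, a tensor is rescaled before quantization so that its largest magnitude maps to the format maximum. The scale can be computed once per tensor or separately per row; the per-tensor case looks like this: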
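#include <cmath>

// Per-tensor scale that maps the largest magnitude in the tensor to max_bound
// (448 for E4M3). Quantize x * scale; divide by scale to dequantize.
void compute_fp8_scales(const float* fp32_tensor, int size,
                        float& scale, float max_bound = 448.0f) {
    float abs_max = 0.0f;
    for (int i = 0; i < size; i++) {
        abs_max = std::fmax(abs_max, std::fabs(fp32_tensor[i]));
    }
    // An all-zero or empty tensor has abs_max == 0; fall back to a scale of 1.
    scale = (abs_max > 0.0f) ? max_bound / abs_max : 1.0f;
}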
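The per-row variant can be sketched in the same way. The helper below is hypothetical (compute_fp8_row_scales is not part of any library) and assumes a row-major matrix with one scale per row, so that a single outlier row does not reduce the precision available to the others.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical helper: one scale per row of a row-major [rows x cols] matrix.
// Each scale maps that row's largest magnitude to max_bound (448 for E4M3).
std::vector<float> compute_fp8_row_scales(const float* data, int rows, int cols,
                                          float max_bound = 448.0f) {
    assert(rows >= 0 && cols > 0);
    std::vector<float> scales(rows, 1.0f);
    for (int r = 0; r < rows; ++r) {
        float abs_max = 0.0f;
        for (int c = 0; c < cols; ++c) {
            abs_max = std::fmax(abs_max, std::fabs(data[r * cols + c]));
        }
        // All-zero rows keep a scale of 1 to avoid dividing by zero.
        if (abs_max > 0.0f) {
            scales[r] = max_bound / abs_max;
        }
    }
    return scales;
}
```

Per-row scales cost one extra float per row and require the matching row scale to be divided out after the GEMM, but they usually lose less precision than a single per-tensor scale when row magnitudes differ widely.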
Local verification checkpoint
Run the smallest example from this chapter in a local workspace and record the package version, runtime, data path, and observed output. If the result depends on model size, vector count, CPU/GPU backend, or available memory, note that constraint beside the exercise so the lesson remains reproducible.
Implement FP8 attention with scaled softmax, ensuring numerical stability for large logit values.
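As a possible starting point for the exercise, the sketch below shows only the stability part: a scaled softmax that subtracts the row maximum before exponentiating, so large logits cannot overflow. It runs in FP32 on the host; the name scaled_softmax is hypothetical, and quantizing Q, K, V, and the probabilities to FP8 is left to the reader.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Softmax over (logits * scale). Subtracting the maximum scaled logit keeps
// every exponent argument <= 0, so exp() never overflows.
std::vector<float> scaled_softmax(const std::vector<float>& logits, float scale) {
    assert(!logits.empty());
    float max_val = logits[0] * scale;
    for (float x : logits) {
        max_val = std::fmax(max_val, x * scale);
    }
    std::vector<float> probs(logits.size());
    float sum = 0.0f;
    for (std::size_t i = 0; i < logits.size(); ++i) {
        probs[i] = std::exp(logits[i] * scale - max_val);  // in (0, 1], or 0 on underflow
        sum += probs[i];
    }
    // sum >= 1 because the maximum element contributes exp(0) = 1.
    for (float& p : probs) {
        p /= sum;
    }
    return probs;
}
```

The outputs lie in [0, 1], well inside the E4M3 range, so range is not the concern when quantizing them; precision is, since E4M3 carries only 3 mantissa bits.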