11. TensorRT Plugin Development
Chapter 11 of 18 · 15 min
TensorRT plugins extend the inference graph with custom operations, enabling deployment of proprietary quantization schemes and optimized kernels.
Plugin Interface
Implement a custom layer by inheriting from nvinfer1::IPluginV2 or nvinfer1::IPluginV2IOExt:
#include <cstring>
#include <stdexcept>

/// Custom TensorRT layer that multiplies its input by an int8-quantized
/// weight matrix with float scales.
///
/// Only the methods discussed in this chapter are shown; the remaining
/// IPluginV2IOExt virtuals must be implemented before the class can be
/// instantiated.
class QuantizedMatmulPlugin : public nvinfer1::IPluginV2IOExt {
private:
    /// Weight matrix together with its quantization metadata.
    struct PluginTensor {
        int32_t rows{0};
        int32_t cols{0};
        // Defaults to int8 to match the element type of quantized_data.
        nvinfer1::DataType type{nvinfer1::DataType::kINT8};
        std::vector<int8_t> quantized_data;
        std::vector<float> scales;
    };

    PluginTensor weight_;
    // Defaults mirror the ones the plugin creator uses when a field is absent,
    // so a deserialized plugin never reads indeterminate values.
    float alpha_{1.0f};
    float beta_{0.0f};
    nvinfer1::DataType output_type_{nvinfer1::DataType::kFLOAT};

public:
    /// Rebuilds the plugin from a serialized buffer.
    ///
    /// @param data   Start of the serialized bytes; the first eight bytes hold
    ///               rows and cols as two consecutive int32_t values.
    /// @param length Number of readable bytes at @p data.
    /// @throws std::invalid_argument if the buffer is null, shorter than the
    ///         header, or encodes a negative dimension. Callers inside
    ///         noexcept creator methods must catch this.
    QuantizedMatmulPlugin(const void* data, size_t length) {
        constexpr size_t kHeaderBytes = 2 * sizeof(int32_t);
        if (data == nullptr || length < kHeaderBytes) {
            throw std::invalid_argument(
                "QuantizedMatmulPlugin: serialized buffer is missing or too short");
        }

        // memcpy instead of dereferencing a cast pointer: the buffer carries no
        // alignment guarantee, so a direct int32_t load would be undefined.
        const char* d = static_cast<const char*>(data);
        std::memcpy(&weight_.rows, d, sizeof(int32_t));
        std::memcpy(&weight_.cols, d + sizeof(int32_t), sizeof(int32_t));
        if (weight_.rows < 0 || weight_.cols < 0) {
            throw std::invalid_argument(
                "QuantizedMatmulPlugin: negative dimension in serialized buffer");
        }
        // ... deserialize scales and quantized weights
    }

    /// Launches the quantized matmul on @p stream.
    ///
    /// @param batch_size Unused here.
    /// @param inputs     inputs[0] is forwarded to the kernel launcher.
    /// @param outputs    outputs[0] receives the result.
    /// @param stream     CUDA stream handed to the kernel launcher.
    /// @return Always 0, which TensorRT treats as success.
    int enqueue(int32_t batch_size, const void* const* inputs,
                void* const* outputs, void*, cudaStream_t stream) noexcept override {
        // NOTE(review): weight_.quantized_data and weight_.scales are
        // std::vector storage, i.e. host memory. If quantized_matmul_kernel
        // reads them on the GPU they must be copied to device memory first —
        // confirm against the launcher.
        quantized_matmul_kernel(
            inputs[0], weight_.quantized_data.data(),
            weight_.scales.data(), weight_.rows, weight_.cols,
            outputs[0], alpha_, beta_, stream
        );
        return 0;
    }

    /// Number of bytes needed to serialize the plugin: the rows/cols header,
    /// then the scales, then the int8 weights (one byte per element).
    ///
    /// NOTE(review): alpha_, beta_ and output_type_ are not counted, and the
    /// element counts of the two vectors are not stored — verify this agrees
    /// with serialize() and with the deserializing constructor.
    size_t getSerializationSize() const noexcept override {
        return sizeof(int32_t) * 2 + weight_.scales.size() * sizeof(float)
             + weight_.quantized_data.size();
    }
};
Plugin Registration
Register plugins with the plugin registry for dynamic loading:
REGISTER_TENSORRT_PLUGIN(QuantizedMatmulPluginCreator);
class QuantizedMatmulPluginCreator : public nvinfer1::IPluginCreator {
public:
nvinfer1::IPluginV2* createPlugin(
const char* name, const nvinfer1::PluginFieldCollection* fc) override {
float alpha = 1.0f, beta = 0.0f;
std::vector<int8_t> weight_data;
std::vector<float> scales;
for (int i = 0; i < fc->nbFields; i++) {
std::string field_name(fc->fields[i].name);
if (field_name == "alpha") alpha = *(float*)fc->fields[i].data;
if (field_name == "beta") beta = *(float*)fc->fields[i].data;
if (field_name == "weight") weight_data = *(std::vector<int8_t>*)fc->fields[i].data;
if (field_name == "scales") scales = *(std::vector<float>*)fc->fields[i].data;
}
return new QuantizedMatmulPlugin(weight_data, scales, alpha, beta);
}
};
Local verification checkpoint
Run the smallest example from this chapter in a local workspace and record the package version, runtime, data path, and observed output. If the result depends on model size, vector count, CPU/GPU backend, or available memory, note that constraint beside the exercise so the lesson remains reproducible.
EXERCISE
Create a TensorRT plugin for a symmetrically quantized convolution with per-channel scales. Test it with an ONNX model that contains the operation.