PyTorch and Brevitas

The PyTorch frontend in hls4ml is implemented by parsing the symbolic trace produced by the torch.fx framework, which ensures that the proper execution graph is captured. Therefore, only models that can be traced with the FX framework can be parsed by hls4ml.
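A quick way to check whether a model meets this requirement is to trace it directly. The sketch below (the model is a hypothetical example) shows a module that torch.fx can trace; models with data-dependent Python control flow in forward() would fail here and therefore cannot be parsed by hls4ml.

```python
# Minimal sketch of the FX tracing requirement: if symbolic_trace
# succeeds on a model, hls4ml can in principle parse it.
import torch
import torch.nn as nn
from torch.fx import symbolic_trace


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 4)

    def forward(self, x):
        # Data-dependent control flow here would break symbolic tracing
        return torch.relu(self.linear(x))


traced = symbolic_trace(Net())
print(traced.graph)  # the execution graph that gets parsed
```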

Provided the underlying operation is supported in hls4ml, we generally aim to support the use of both torch.nn classes and torch.nn.functional functions in the construction of PyTorch models. The use of classes is more thoroughly tested; please reach out if you experience any issues with either case.
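To illustrate the two styles, the sketch below (model names are hypothetical) builds the same network once with torch.nn classes and once with torch.nn.functional calls; both express the same computation.

```python
# Two equivalent ways to express the same network. hls4ml aims to
# support both, though the class-based style is more thoroughly tested.
import torch
import torch.nn as nn
import torch.nn.functional as F


class ClassStyle(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 4)
        self.act = nn.ReLU()  # torch.nn class

    def forward(self, x):
        return self.act(self.linear(x))


class FunctionalStyle(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 4)

    def forward(self, x):
        return F.relu(self.linear(x))  # torch.nn.functional function


m1, m2 = ClassStyle(), FunctionalStyle()
m2.load_state_dict(m1.state_dict())  # copy weights so outputs match
x = torch.randn(2, 8)
assert torch.allclose(m1(x), m2(x))
```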

The PyTorch/Brevitas parser is under heavy development and doesn't yet have the same feature set as the Keras parsers. Feel free to reach out to the developers if you find a missing feature that is present in the Keras parser and would like it implemented.

Note

The direct ingestion of models quantized with Brevitas is not currently supported. Instead, Brevitas models should be exported in the ONNX format (see here) and read with the hls4ml QONNX frontend. Issues may arise, for example, when non-power-of-2 or non-scalar quantization scales are used. Please reach out if you encounter any problems with this workflow.

For multi-dimensional tensors, hls4ml follows the channels-last convention adopted by Keras, whereas PyTorch uses channels-first. By default, hls4ml will automatically transpose any tensors associated with the weights and biases of the internal layers of the model. If the io_parallel I/O type (see Concepts) is used, a transpose node will be added to the model that also adjusts the input tensors. This is not available in the io_stream case, where inputs must be transposed by the user. Outputs are not transposed back by default, but in the io_parallel case, a transpose node can be added. If not needed, these adjustments can also be switched off. See config_from_pytorch_model for details.
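The difference between the two conventions can be illustrated with NumPy alone; this is a small sketch of the transpose that hls4ml applies (or that users must apply themselves in the io_stream case).

```python
# Channels-first (PyTorch) vs channels-last (Keras/hls4ml) layout.
import numpy as np

# A batch of 2 "images" with 3 channels and 4x4 pixels,
# stored channels-first as PyTorch does: (N, C, H, W)
x_cf = np.arange(2 * 3 * 4 * 4).reshape(2, 3, 4, 4)

# hls4ml expects channels-last: (N, H, W, C)
x_cl = np.transpose(x_cf, (0, 2, 3, 1))

assert x_cl.shape == (2, 4, 4, 3)
# The same element moves from index [n, c, h, w] to [n, h, w, c]
assert x_cf[0, 1, 2, 3] == x_cl[0, 2, 3, 1]
```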

The equivalent of the Keras extension API is not yet available for the PyTorch parser and will be provided in the future.