Custom Operators
This document explains how users can define custom operators (such as other activation functions) that auto_LiRPA does not currently support, together with their bound propagation methods.
Write a Custom Operator
There are four steps to write an operator:
1. Define a torch.autograd.Function (or Function for short) class, wrap the computation of the operator into this Function, and also define a symbolic method so that the operator can be parsed in auto_LiRPA via ONNX. Please refer to the PyTorch documentation on defining a Function with a symbolic method.
2. Create a torch.nn.Module which uses the defined operator. Call the operator via the .apply() method of the Function.
3. Implement a Bound class to support bound propagation methods for this operator.
4. Create a mapping from the operator name (defined in step 1) to the Bound class (defined in step 3). Define a dict in which each item is such a mapping, and pass the dict to the custom_ops argument when calling BoundedModule (see the BoundedModule documentation). For example, if the operator name is MyRelu and the Bound class is BoundMyRelu, add "MyRelu": BoundMyRelu to the dict.
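Steps 1 and 2 can be sketched with plain PyTorch as below. This is a minimal illustration, not code taken from auto_LiRPA: the class names MyReluOp and MyRelu and the ONNX node name "custom::MyRelu" are assumptions chosen to match the MyRelu example in step 4. Steps 3 and 4 depend on the Bound interface of your auto_LiRPA version, so they appear only as a comment; check your version for the exact key format expected in custom_ops.

```python
import torch
import torch.nn as nn


# Step 1: wrap the computation in a torch.autograd.Function and give it a
# symbolic() method, so that ONNX export emits a single node for the operator.
class MyReluOp(torch.autograd.Function):
    @staticmethod
    def symbolic(g, x):
        # Hypothetical node name in a custom ONNX domain ("domain::name").
        return g.op("custom::MyRelu", x)

    @staticmethod
    def forward(ctx, x):
        # The actual computation of the operator (here, a ReLU).
        # No backward() is defined, so gradients through this op are unsupported.
        return x.clamp(min=0)


# Step 2: a torch.nn.Module that calls the operator through .apply().
class MyRelu(nn.Module):
    def forward(self, x):
        return MyReluOp.apply(x)


# Steps 3 and 4 (not executed here): given a Bound subclass BoundMyRelu,
# the mapping would be passed to BoundedModule, e.g.
#   BoundedModule(model, example_input, custom_ops={"MyRelu": BoundMyRelu})
```

Because MyRelu is an ordinary nn.Module, it can be placed in a model like any other layer, for example nn.Sequential(nn.Linear(3, 3), MyRelu()).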
Example
We provide a code example of a custom operator called “PlusConstant”.
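The example file itself is not reproduced here. As a rough sketch, assuming "PlusConstant" computes f(x) = x + c for a constant c, the arithmetic its Bound class has to provide is short. The standalone functions below (plus_constant_ibp and plus_constant_backward are hypothetical names, not auto_LiRPA's Bound methods) show interval bound propagation and backward-mode linear bound propagation for this operator, with the coefficient matrix A assumed to have shape (batch, spec, *input_shape).

```python
import torch


def plus_constant_ibp(h_L, h_U, const):
    """Interval bound propagation for f(x) = x + const.

    f is monotonically increasing, so applying it to the input lower and
    upper bounds gives valid output lower and upper bounds.
    """
    return h_L + const, h_U + const


def plus_constant_backward(last_A, const):
    """Backward-mode (linear) bound propagation for f(x) = x + const.

    last_A (assumed shape: (batch, spec, *input_shape), with input_shape
    non-empty) represents the linear function A . f(x). Since
    A . (x + const) = A . x + A . const, the new coefficient matrix is A
    itself, and the bias is A summed over the input dimensions times const.
    """
    if last_A is None:
        # This side (lower or upper) of the bound is not requested.
        return None, 0
    bias = last_A.sum(dim=list(range(2, last_A.ndim))) * const
    return last_A, bias
```

Because f is linear with slope 1, the backward pass leaves the coefficient matrix unchanged and only contributes a bias term; nonlinear operators instead need linear relaxations.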
Contributing to the Library
We encourage the community to contribute new operators to the auto_LiRPA library so that other users can use them as well. To do so, put the Function and the Bound class of the new operator in auto_LiRPA/operators, add the mapping to bound_op_map.py, and submit a pull request.