Custom Operators¶
This document describes how users can define custom operators (for example, additional activation functions) that auto_LiRPA does not currently support, together with the bound propagation methods for them.
Write a Custom Operator¶
There are four steps to write an operator:
1. Define a torch.autograd.Function (or Function for short) class, wrap the computation of the operator into this Function, and also define a symbolic method so that the operator can be parsed in auto_LiRPA via ONNX. Please refer to the PyTorch documentation on defining a Function with a symbolic method.
2. Create a torch.nn.Module that uses the defined operator. Call the operator via .apply() of the Function.
3. Implement a Bound class to support bound propagation methods for this operator.
4. Create a mapping from the operator name (defined in step 1) to the bound class (defined in step 3). Define a dict in which each item is one such mapping, and pass the dict to the custom_ops argument when calling BoundedModule (see the documentation). For example, if the operator name is MyRelu and the bound class is BoundMyRelu, then add "MyRelu": BoundMyRelu to the dict.
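The steps above can be sketched roughly as follows. This is a minimal illustration that depends only on PyTorch, not the library's reference implementation: the class name MyReluOp and the "custom" ONNX namespace are assumptions made for the example, and BoundMyRelu is only a placeholder, since a real bound class would build on auto_LiRPA's own Bound interface and implement the actual bound propagation methods.

```python
import torch
import torch.nn as nn


class MyReluOp(torch.autograd.Function):
    """Step 1: wrap the operator's computation in a Function."""

    @staticmethod
    def symbolic(g, x):
        # Declares how the operator appears in the exported ONNX graph.
        # The "custom" namespace is an assumption made for this sketch.
        return g.op("custom::MyRelu", x)

    @staticmethod
    def forward(ctx, x):
        # The actual computation of the operator.
        return torch.clamp(x, min=0)


class MyRelu(nn.Module):
    """Step 2: a Module that calls the operator via .apply()."""

    def forward(self, x):
        return MyReluOp.apply(x)


class BoundMyRelu:
    """Step 3 (placeholder only): a real bound class would implement
    the bound propagation methods for the operator."""


# Step 4: a dict whose items map operator names to bound classes.
custom_ops = {"MyRelu": BoundMyRelu}

# Quick check of the forward computation on a plain tensor.
y = MyRelu()(torch.tensor([-1.0, 0.5]))
```

The dict built in the last step is what would be passed to the custom_ops argument of BoundedModule; the exact call is left out here because it depends on the installed auto_LiRPA version.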
Example¶
We provide a code example that uses a custom operator called “PlusConstant”.
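To give a feel for what the bound computation of such an operator involves, the sketch below assumes that PlusConstant computes f(x) = x + c for a constant c (an assumption based on the operator's name). The function names and list-based data layout are hypothetical and written in plain Python for illustration; they are not auto_LiRPA's Bound interface.

```python
def interval_propagate(h_L, h_U, const):
    """Interval bounds of f(x) = x + const from elementwise bounds on x."""
    # f is monotonically increasing, so the bounds on x map directly
    # to the bounds on f(x).
    lower = [l + const for l in h_L]
    upper = [u + const for u in h_U]
    return lower, upper


def bound_backward(last_A, const):
    """Propagate a linear bound A . f(x) backward through f(x) = x + const.

    Since A . (x + const) = A . x + A . const, the coefficient matrix is
    unchanged, and each row contributes a bias of sum(row) * const.
    """
    bias = [sum(row) * const for row in last_A]
    return last_A, bias


# Input bounds -1 <= x0 <= 1 and 0 <= x1 <= 2, with c = 3.
lower, upper = interval_propagate([-1.0, 0.0], [1.0, 2.0], const=3.0)

# Two linear output bounds over the two inputs.
A, bias = bound_backward([[1.0, -1.0], [2.0, 0.5]], const=3.0)
```

Because the operator is linear with coefficient 1, both propagation modes only shift the bounds by terms involving c; nonlinear operators such as activations would additionally need a linear relaxation.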
Contributing to the Library¶
We encourage the community to contribute new operators to the auto_LiRPA library so that other users can use them as well. To do this, please put the Function and the Bound class of the new operator in auto_LiRPA/operators, add the mapping in bound_op_map.py, and submit a pull request.