Sunday, July 16, 2023

torch.fx Notes

make_fx

decomposition_table 

The decomposition_table argument to make_fx is a dictionary that maps ATen operator overloads (for example torch.ops.aten.addmm.default) to Python functions that compute the same result using simpler operators. While tracing, make_fx calls the registered function whenever it encounters the corresponding operator, so the traced graph records the simpler operators instead of the original one. This is mainly useful for lowering a graph onto a backend or pass that only understands a small set of primitive operators. As a sketch, the following code registers a hand-written decomposition for aten.addmm (PyTorch ships its own registered version in torch._decomp, covered below) and traces a function through it: 

Python 
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def make_decomposition_table():
    table = {}
    # Hand-written, simplified decomposition: compute addmm as an
    # explicit matrix multiply followed by scaled adds.
    def addmm_decomp(bias, mat1, mat2, *, beta=1, alpha=1):
        return beta * bias + alpha * torch.mm(mat1, mat2)
    table[torch.ops.aten.addmm.default] = addmm_decomp
    return table

decomposition_table = make_decomposition_table()

def f(bias, mat1, mat2):
    return torch.addmm(bias, mat1, mat2)

# make_fx returns a wrapper; calling it with example inputs produces a
# torch.fx.GraphModule whose graph contains the decomposed operators.
gm = make_fx(f, decomposition_table=decomposition_table)(
    torch.randn(4), torch.randn(4, 4), torch.randn(4, 4)
)
print(gm.graph)
  • The make_decomposition_table function builds a dictionary that maps torch.ops.aten.addmm.default to a Python function computing the same result with an mm, two scalar multiplies, and an add. 
  • make_fx consults the decomposition_table while tracing: whenever it encounters aten.addmm.default, it calls the registered function on the proxy tensors instead, so the resulting graph contains mm/mul/add nodes rather than a single addmm node (see the comparison sketch below). 
  • An entry can be supplied for any ATen operator overload, but a decomposed graph is not automatically faster than the original; it usually contains more, smaller operations. Inspect or benchmark the traced graph before using it in production.
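
One quick way to see the effect is to trace the same function with and without the table and compare the graphs. The snippet below is a sketch that reuses f and decomposition_table from the example above; example_inputs is just an illustrative set of shapes.

Python
example_inputs = (torch.randn(4), torch.randn(4, 4), torch.randn(4, 4))

# Without decompositions the graph keeps aten.addmm.default;
# with the table it records the ops the decomposition emits.
gm_plain = make_fx(f)(*example_inputs)
gm_decomposed = make_fx(f, decomposition_table=decomposition_table)(*example_inputs)

print(gm_plain.graph)       # contains an addmm node
print(gm_decomposed.graph)  # contains mm/mul/add nodes instead

# Confirm the decomposition actually fired.
targets = [n.target for n in gm_decomposed.graph.nodes if n.op == "call_function"]
assert torch.ops.aten.addmm.default not in targets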

get_decompositions 

The get_decompositions function lives in the torch._decomp module. It takes a sequence of ATen operators (OpOverloads or OpOverloadPackets) and returns a dictionary mapping each overload that has a registered decomposition to the Python function implementing it. This is a convenient way to build a decomposition_table for make_fx out of the decompositions PyTorch already ships, and to see how a given operator gets decomposed. For example:

Python
import torch
from torch._decomp import get_decompositions

# Ask for the registered decompositions of the operators we care about.
decompositions = get_decompositions([
    torch.ops.aten.native_dropout,
    torch.ops.aten.addmm,
])
print(decompositions)
  • This prints a dictionary whose keys are ATen operator overloads (for example aten.native_dropout.default and aten.addmm.default) and whose values are the Python functions that implement their decompositions. 
  • Inspecting this dictionary shows which operators have registered decompositions and what they decompose into, which is useful when debugging traced graphs. (The full registry of decompositions is available as torch._decomp.decomposition_table.) The printed output looks roughly like this, with the exact function addresses varying:
    {
      <OpOverload(op='aten.native_dropout', overload='default')>: <function native_dropout at 0x...>,
      <OpOverload(op='aten.addmm', overload='default')>: <function addmm at 0x...>,
      ...
    }
    
    Because the returned dictionary maps ATen operator overloads to their decomposition functions, it can be passed straight to make_fx as the decomposition_table argument, as in the sketch below.
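
Putting the two pieces together, here is a minimal sketch that fetches the stock decompositions and hands them to make_fx. It assumes aten.addmm has a registered decomposition (recent PyTorch releases ship one); f is just an illustrative function.

Python
import torch
from torch._decomp import get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx

# Fetch the registered decompositions for the operators we want expanded.
table = get_decompositions([torch.ops.aten.addmm])

def f(bias, mat1, mat2):
    return torch.addmm(bias, mat1, mat2)

gm = make_fx(f, decomposition_table=table)(
    torch.randn(4), torch.randn(4, 4), torch.randn(4, 4)
)
print(gm.graph)  # the addmm node is replaced by the ops its decomposition emits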