"""
Functions for converting from PyTorch TorchScript to MDF models.
This code was originally inspired by the following blog post:
Mike He, "From Models to Computation Graphs (Part I)", https://ad1024.space/articles/22
"""
import inspect
import logging
from typing import Union, Dict, Any, Tuple, List, Callable
import onnx.defs
import torch
from modeci_mdf.mdf import Model, Graph, Node, Edge, InputPort, OutputPort, Parameter
from modeci_mdf.functions.onnx import onnx_opset_version as modeci_onnx_opset_version
# Module-level logger named after this module; used below to warn about missing op schemas.
logger = logging.getLogger(__name__)
def convert_to_serializable(value):
    """Convert some common torch values that JSON cannot serialize into types it can.

    - A :class:`torch.device` becomes its string form (e.g. ``"cpu"``).
    - A :class:`torch.Tensor`, including subclasses such as :class:`torch.nn.Parameter`, becomes a
      (possibly nested) Python list, or a Python scalar for a 0-dim tensor.

    Any other value is returned unchanged.

    Args:
        value: The value to convert.

    Returns:
        The converted value, or ``value`` itself when no conversion applies.
    """
    if isinstance(value, torch.device):
        return str(value)
    if isinstance(value, torch.Tensor):
        # detach() so tensors that require grad can be converted, and cpu() so tensors that live on an
        # accelerator do not fail; tolist() yields the same Python values numpy().tolist() would.
        return value.detach().cpu().tolist()
    return value
def make_node_id(node: torch.Node) -> str:
    """Build a unique MDF id for a TorchScript node.

    The id is the op name (the part of ``node.kind()`` after the last ``::``) followed by the unique
    ids of each of the node's outputs, all joined with underscores; e.g. ``aten::add`` with output
    ``%5`` becomes ``"add_5"``.
    """
    op_name = node.kind().split("::")[-1]
    output_ids = (str(out.unique()) for out in node.outputs())
    return "_".join([op_name, *output_ids])
def make_func_id(node: torch.Node) -> str:
    """Build the MDF id for the function parameter of a TorchScript node.

    Every ``::`` namespace separator in ``node.kind()`` becomes ``_`` and a ``_1`` suffix is
    appended; e.g. ``onnx::Add`` becomes ``"onnx_Add_1"``.
    """
    namespaced_kind = node.kind()
    return f"{namespaced_kind.replace('::', '_')}_1"
def make_model_graph_name(
    model: Union[torch.ScriptModule, torch.ScriptFunction]
) -> Tuple[str, str]:
    """Generate clean MDF model and graph names for a PyTorch / TorchScript model.

    The name is taken, in order of preference, from:

    1. ``model.original_name`` (compiled ``ScriptModule``); graph name ``"<name>Graph"``.
    2. ``model.qualified_name`` (``ScriptFunction``); graph name ``"<name>_graph"``.
    3. ``model.__name__`` (plain Python functions and other named callables); graph name
       ``"<name>Graph"``.
    4. The class name of ``model`` (e.g. an un-compiled ``nn.Module``); graph name ``"<name>Graph"``.

    In every case only the last ``.``-separated component of the name is kept.

    Args:
        model: The model, compiled or not, to name.

    Returns:
        A ``(model_name, graph_name)`` tuple.
    """
    try:
        model_name = model.original_name.split(".")[-1]
        return model_name, f"{model_name}Graph"
    except AttributeError:
        pass

    try:
        model_name = model.qualified_name.split(".")[-1]
        # NOTE: this "_graph" suffix differs from the other branches; it is kept as-is so the ids
        # generated for scripted functions do not change.
        return model_name, f"{model_name}_graph"
    except AttributeError:
        pass

    # Not compiled yet. Plain Python functions carry their name in __name__; without this check every
    # function would be named after its type, i.e. "function".
    name = getattr(model, "__name__", None)
    if not isinstance(name, str):
        name = type(model).__name__
    model_name = name.split(".")[-1]
    return model_name, f"{model_name}Graph"
def process_torch_schema(
    node: torch.Node, consts: Dict, port_mapper: "PortMapper"
) -> Tuple[List[str], Dict[str, Any]]:
    """
    Parse a TorchScript node schema into argument names and constant attributes (parameters in MDF).

    Args:
        node: The TorchScript node to retrieve the schema from.
        consts: The constant nodes Dict for the graph we are working with, mapping a value's unique id
            to its constant value.
        port_mapper: Not used here; accepted so the signature matches :func:`process_onnx_schema`.

    Returns:
        A tuple containing a list of argument names (one per node input, in input order) and a Dict
        mapping the argument names of constant inputs to their values.
    """
    input_ids = [value.unique() for value in node.inputs()]
    schema_str = node.schema()

    # If this is a TorchScript function (aten::*), it should have a schema string to parse.
    if "no schema" not in schema_str:
        schema = torch._C.parse_schema(schema_str)
        declared = [argument.name for argument in schema.arguments]
        # Vararg schemas (e.g. "aten::format(str self, ...) -> str") can receive more inputs than
        # declared arguments; name the extras with positional placeholders instead of raising IndexError.
        arg_names = [
            declared[i] if i < len(declared) else f"arg{i}"
            for i in range(len(input_ids))
        ]
    else:
        logger.warning(
            f"Schema not found for TorchScript node ({node}), using placeholders for argument names."
        )
        arg_names = [f"arg{i}" for i in range(len(input_ids))]

    # Inputs that come from constant nodes (prim::Constant) become parameters.
    parameters = {
        name: consts[value_id]
        for name, value_id in zip(arg_names, input_ids)
        if value_id in consts
    }
    return arg_names, parameters
def _onnx_input_arguments(formal_inputs, port_names: List[str]) -> Dict[str, str]:
    """Map ONNX OpSchema formal input names to the MDF port names that feed them."""
    if len(formal_inputs) == 0:
        return {}
    first = formal_inputs[0]
    # If the first argument is variadic, it receives every input; represent this as the string form
    # of the list of input port names.
    if first.option.name == "Variadic":
        return {first.name: str(port_names)}
    # Inputs beyond the declared formals would otherwise raise IndexError; give them positional
    # placeholder names, the same convention used when no schema is found.
    return {
        (formal_inputs[i].name if i < len(formal_inputs) else f"arg{i}"): port
        for i, port in enumerate(port_names)
    }


def process_onnx_schema(
    node: torch.Node, consts: Dict, port_mapper: "PortMapper"
) -> Tuple[Dict[str, str], Dict[str, Any]]:
    """
    Retrieve the argument names and attributes (parameters in MDF) for this Operation.

    Args:
        node: The TorchScript node containing the ONNX operation.
        consts: Dict mapping the unique ids of constant values in the graph to their values.
        port_mapper: The utility class for assigning TorchScript input/output ids to Input/Output Port ids.

    Returns:
        A two element tuple:
            - A dict mapping argument names to input port ids
            - A dict mapping parameter names (constant inputs and ONNX attributes) to values

    Raises:
        ValueError: If ``node`` is not an ONNX (``onnx::*``) node.
    """
    if "onnx::" not in node.kind():
        raise ValueError(f"Cannot process ONNX schema for non ONNX node: {node}")

    input_ids = [value.unique() for value in node.inputs()]
    port_names = [port_mapper.id_to_port(value_id) for value_id in input_ids]

    # Only the schema lookup can raise SchemaError, so keep the try body to that call.
    try:
        schema = onnx.defs.get_schema(
            node.kind().split("::")[-1], modeci_onnx_opset_version
        )
    except onnx.onnx_cpp2py_export.defs.SchemaError:
        logger.warning(
            f"Could not find ONNX OpSchema for op {node.kind()}, using placeholder names for arguments."
        )
        schema_args = {f"arg{i}": port for i, port in enumerate(port_names)}
    else:
        schema_args = _onnx_input_arguments(schema.inputs, port_names)

    # Any inputs that are from constant nodes should be parameters in MDF
    parameters = {
        port: consts[value_id]
        for value_id, port in zip(input_ids, port_names)
        if value_id in consts
    }
    # ONNX attributes are equivalent to MDF parameters
    parameters.update(
        {aname: convert_to_serializable(node[aname]) for aname in node.attributeNames()}
    )
    return schema_args, parameters
def get_graph_constants(graph: torch.Graph) -> Dict[str, Any]:
    """
    Collect the values of all constant nodes in a graph.

    ``prim::Constant`` outputs are passed through :func:`convert_to_serializable` (devices become
    strings, tensors become lists). ``onnx::Constant`` values are stored as the numpy arrays returned by
    ``node["value"].numpy()``, so they are not plain Python lists.

    Args:
        graph: The graph to extract constants from.

    Returns:
        A Dict that maps each constant output's TorchScript unique id to its value.
    """
    constants = {
        output.unique(): convert_to_serializable(output.toIValue())
        for const_node in graph.findAllNodes("prim::Constant")
        for output in const_node.outputs()
    }
    # ONNX constants: the value lives in the node's "value" attribute and is keyed by its first output.
    for onnx_const in graph.findAllNodes("onnx::Constant"):
        first_output = next(iter(onnx_const.outputs()))
        constants[first_output.unique()] = onnx_const["value"].numpy()
    return constants
class PortMapper:
    """
    Map TorchScript input/output value ids to MDF InputPort/OutputPort ids, and back.

    It also keeps track of graph-level inputs: the first ``len(args)`` graph inputs (the model's call
    arguments) are named ``input1``, ``input2``, ...; the remaining graph inputs keep their debug names.
    """

    def __init__(self, graph: torch.Graph, args: Tuple):
        """
        Args:
            graph: The TorchScript graph whose inputs should be named.
            args: The positional arguments the model is called with. ``None`` is treated as no arguments.
        """
        # Generate special names for all the graph inputs and parameters.
        self.graph_inputs = PortMapper._get_graph_inputs_dict(graph, args)

    def id_to_port(self, id: str):
        """Turn a unique TorchScript output/input value id into a valid MDF port name.

        Graph inputs are first replaced by their name in ``graph_inputs``. Then ``.``, ``::`` and ``-``
        become ``_`` (``::`` and ``-`` cause parsing issues in the execution engine). A name starting with
        a digit is prefixed with ``_`` so it can never be interpreted as a number down the line.
        """
        # ``id`` shadows the builtin but is kept because callers may pass it by keyword.
        if id in self.graph_inputs:
            id = self.graph_inputs[id]
        new_name = str(id).replace(".", "_").replace("::", "_").replace("-", "_")
        # [:1] instead of [0] so an empty name does not raise IndexError.
        if new_name[:1].isdigit():
            new_name = "_" + new_name
        return new_name

    def port_to_id(self, name: str):
        """Transform a port name back to its TorchScript id.

        Graph inputs are matched exactly by re-applying :meth:`id_to_port` to each known input, so names
        containing ``_`` or ``::`` (e.g. ``conv_1.weight``, ``onnx::Gemm_0``) round-trip correctly.
        Otherwise the name is decoded heuristically: one leading ``_`` is dropped, the remaining ``_``
        become ``.``, and an all-digit result is returned as an ``int``.
        """
        for input_id in self.graph_inputs:
            if self.id_to_port(input_id) == name:
                return input_id

        # Heuristic decoding; lossy, since it cannot tell an original "_" from a "." or "::".
        decoded = name[1:] if name[:1] == "_" else name
        decoded = decoded.replace("_", ".")
        # Only fully numeric ids are converted, so names like "3.1" no longer raise ValueError.
        if decoded.isdecimal():
            decoded = int(decoded)
        for input_id, debug_name in self.graph_inputs.items():
            if debug_name == decoded:
                return input_id
        return decoded

    @staticmethod
    def _get_graph_inputs_dict(
        graph: torch.Graph, args: Tuple[torch.Tensor]
    ) -> Dict[str, str]:
        """
        Create a dict mapping graph input torch.Node ids to default names. The default names are just:
        - input1
        - input2
        - etc.
        Any parameters for the model will also be graph inputs but their node.debugName() will be used
        instead.
        """
        graph_inputs = {inp.unique(): inp.debugName() for inp in graph.inputs()}
        # The first len(args) inputs should be the input arguments to the function or forward method;
        # canonicalize them. Slicing stops at the actual number of graph inputs instead of raising
        # IndexError, and args=None means there are no such arguments.
        n_args = 0 if args is None else len(args)
        for position, input_id in enumerate(list(graph_inputs)[:n_args], start=1):
            graph_inputs[input_id] = f"input{position}"
        return graph_inputs
def _torch_type_info(torch_type) -> Tuple[str, Union[Tuple, None]]:
    """Return ``(dtype string, shape tuple or None)`` for a TorchScript value type."""
    try:
        dtype = str(torch_type.dtype()).replace("torch.", "")
    except (RuntimeError, AttributeError):
        # Not a tensor type: lists expose an element type; any other type (one without
        # getElementType) falls back to the type's own string form instead of crashing.
        try:
            dtype = str(torch_type.getElementType())
        except AttributeError:
            dtype = str(torch_type)
    try:
        sizes = torch_type.sizes()
        shape = tuple(sizes) if sizes else None
    except (RuntimeError, AttributeError):
        shape = None
    return dtype, shape


def _declared_output_count(node: torch.Node) -> int:
    """Return how many outputs ``node``'s op declares; used to decide whether outputs need subscripts.

    ONNX ops use their ONNX OpSchema at the MDF opset version. Ops without an ONNX schema (aten::*,
    prim::*, or unknown ONNX ops) fall back to the number of outputs the node actually has.
    """
    op = node.kind()
    if "onnx::" in op:
        try:
            schema = onnx.defs.get_schema(
                op.replace("onnx::", ""), modeci_onnx_opset_version
            )
        except onnx.onnx_cpp2py_export.defs.SchemaError:
            pass
        else:
            return len(schema.outputs)
    return len(list(node.outputs()))


def torchnode_to_mdfnode(
    node: torch.Node,
    graph: torch.Graph,
    consts: Dict[str, Any],
    port_mapper: "PortMapper",
) -> Union[Node, Graph, None]:
    """
    Convert a TorchScript node to an MDF node.

    Args:
        node: The node to convert.
        graph: The graph that this node is a member of.
        consts: A dict containing any constants in the graph.
        port_mapper: Translates TorchScript value ids into MDF port names.

    Returns:
        The MDF node for this TorchScript node. For an ``onnx::Loop`` node, an MDF Graph holding the
        translated loop body is returned instead. prim::Constant / onnx::Constant nodes are excluded from
        the MDF graph (they become parameters of the nodes they feed); in that case, return None.
    """
    op = node.kind()

    # Exclude constants (as nodes) from the MDF graph. We will instead insert them as parameters to the
    # nodes that they project to. This check comes before any ONNX schema lookup, because looking up a
    # non-ONNX op such as prim::Constant raises SchemaError.
    if op in ("prim::Constant", "onnx::Constant"):
        return None

    # If we are dealing with a loop node, we need to recursively create a sub-graph for the loop body
    if op == "onnx::Loop":
        sub_mdf_graph = Graph(id=f"LoopSubgraph{make_node_id(node)}")
        block_graph = list(node.blocks())[0]
        translate_graph(
            graph=block_graph,
            mdf_graph=sub_mdf_graph,
            consts=consts,
            port_mapper=port_mapper,
        )
        return sub_mdf_graph

    outputs = [o.unique() for o in node.outputs()]
    inputs = [i.unique() for i in node.inputs()]

    # Get the argument names and parameter names and values for this Node's operation
    if "onnx::" in op:
        arguments, parameters = process_onnx_schema(node, consts, port_mapper)
    else:
        arguments, parameters = process_torch_schema(node, consts, port_mapper)

    mdf_node = Node(id=make_node_id(node))

    from modeci_mdf.interfaces.onnx.importer import (
        get_category_of_onnx_node,
        get_color_for_onnx_category,
    )

    category = get_category_of_onnx_node(mdf_node.id)
    color = get_color_for_onnx_category(category)
    mdf_node.metadata = color

    for name, value in parameters.items():
        mdf_node.parameters.append(Parameter(id=name, value=value))

    # Add any output ports. When the op declares several outputs, each port refers to one indexed
    # element of the function's result.
    func_id = make_func_id(node)
    needs_subscript = _declared_output_count(node) > 1
    for out_num, o in enumerate(outputs):
        out_dtype, shape = _torch_type_info(node.outputsAt(out_num).type())
        mdf_node.output_ports.append(
            OutputPort(
                id=port_mapper.id_to_port(o),
                value=f"{func_id}[{out_num}]" if needs_subscript else func_id,
                shape=shape,
                type=out_dtype,
            )
        )

    # Add any input ports to the node, exclude inputs from constant nodes, these are parameters now
    for inp_i, inp in enumerate(inputs):
        if inp in consts:
            continue
        inp_dtype, shape = _torch_type_info(node.inputsAt(inp_i).type())
        mdf_node.input_ports.append(
            InputPort(id=port_mapper.id_to_port(inp), shape=shape, type=inp_dtype)
        )

    # Add the function Parameter; TorchScript schemas yield a plain list of argument names.
    if isinstance(arguments, list):
        arguments = {"arguments": arguments}
    mdf_node.parameters.append(Parameter(id=func_id, function=op, args=arguments))
    return mdf_node
def translate_graph(
    graph: Union[torch.Graph, torch.Block],
    mdf_graph: Graph,
    consts: Dict[str, Any],
    port_mapper: "PortMapper",
):
    """
    Go through a :class:`~torch.Graph` or :class:`~torch.Block` and translate the nodes and edges to MDF
    nodes and edges.

    Each TorchScript node is converted with :func:`torchnode_to_mdfnode`; nodes for which that returns
    ``None`` (constants) are skipped. For every converted node that is not a sub-graph, one MDF edge is
    added per (output value, consuming node) pair, with consumers visited in graph order.

    Args:
        graph: The graph to translate.
        mdf_graph: The MDF graph to store the translation into.
        consts: Constant to use for parameters of nodes.
        port_mapper: A port mapper instance to handle translating names.

    Returns:
        None. Nodes and edges are appended to ``mdf_graph`` in place.
    """
    # Snapshot the nodes once; list positions identify nodes below, so nothing depends on how
    # TorchScript node wrapper objects hash or compare.
    nodes = list(graph.nodes())
    node_inputs = [[value.unique() for value in n.inputs()] for n in nodes]

    # Index value id -> positions of the nodes consuming it. This replaces scanning every node for
    # every producer (previously O(n^2) in the number of nodes).
    consumers: Dict[Any, List[int]] = {}
    for position, input_ids in enumerate(node_inputs):
        for value_id in input_ids:
            consumers.setdefault(value_id, []).append(position)

    for node in nodes:
        mdf_node = torchnode_to_mdfnode(
            node=node, graph=graph, consts=consts, port_mapper=port_mapper
        )
        # If we are excluding this node from the MDF graph, skip it.
        if mdf_node is None:
            continue
        mdf_graph.nodes.append(mdf_node)
        # Loop bodies come back as whole sub-graphs; edges inside them were added to the sub-graph by the
        # recursive translate_graph call.
        if isinstance(mdf_node, Graph):
            continue

        outputs = [value.unique() for value in node.outputs()]
        from_id = make_node_id(node)
        receiver_positions = sorted(
            {pos for value_id in outputs for pos in consumers.get(value_id, ())}
        )
        for position in receiver_positions:
            to_id = make_node_id(nodes[position])
            # Same set expression as the original scan, so edges come out in the same order.
            shared = set(outputs) & set(node_inputs[position])
            for edge in shared:
                # The sender output port and the receiver input port are both derived from this value's id
                # via port_mapper.id_to_port, so they share the same name.
                port = mdf_node.output_ports[outputs.index(edge)].id
                edge_id = f"{from_id}_{to_id}"
                if len(shared) > 1:
                    # Several values flow between this pair of nodes; suffix the port so ids stay unique.
                    edge_id = f"{edge_id}_{port}"
                mdf_graph.edges.append(
                    Edge(
                        id=edge_id,
                        sender=from_id,
                        sender_port=f"{port}",
                        receiver=to_id,
                        receiver_port=f"{port}",
                    )
                )
def pytorch_to_mdf(
    model: Union[Callable, torch.nn.Module, torch.ScriptFunction, torch.ScriptModule],
    args: Union[None, torch.Tensor, Tuple[torch.Tensor]] = None,
    trace: bool = False,
    use_onnx_ops: bool = True,
) -> Tuple[Model, Dict[str, Any]]:
    r"""
    Convert a PyTorch model to an MDF model.

    By default, this function invokes `torch.jit.script` on the model to compile it down to TorchScript IR,
    then simplifies the graph before exporting the MDF. The default is to use ONNX operations when possible
    and fall back to ATen/Torch ops when ONNX support is not available
    (`torch._C._onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK` mode). To use only ATen/Torch ops, set
    use_onnx_ops to False.

    Args:
        model: The model to translate into MDF.
        args: The input arguments for this model. If a nn.Module is passed then the model will be traced with
            these inputs. If a ScriptModule is passed, they are still needed to determine input shapes.
        trace: Force the use of tracing to compile the model. The default is to use torch.jit.script.
        use_onnx_ops: Use ONNX ops when possible, fall back to ATen ops when not available. Default is True.
            If False, use only ATen ops.

    Returns:
        A tuple ``(mdf_model, params_dict)``: the translated MDF model, and the parameter dict returned by
        torch's ``_model_to_graph``, re-keyed by MDF port name.
    """
    # Special case for common case of passing a single Tensor
    if isinstance(args, (torch.Tensor, int, float, bool)):
        args = (args,)

    # Get the graph and nodes from the TorchScript model
    try:
        # If the graph attribute is available, we are dealing with an already jitted model (ScriptModule,
        # ScriptFunction, etc.)
        graph = model.graph
        jit_model = model
    except AttributeError:
        # Script the model unless the user wants to trace; standard Python functions are always scripted.
        if not trace or inspect.isfunction(model):
            jit_model = torch.jit.script(model)
            graph = jit_model.graph
        else:
            # If the user wants to trace, _model_to_graph below will take care of that for us.
            jit_model = None
            graph = None

    if use_onnx_ops:
        operator_export_type = torch._C._onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
    else:
        operator_export_type = torch._C._onnx.OperatorExportTypes.RAW

    # Call out to a part of the ONNX exporter that simplifies the graph before ONNX export.
    from torch.onnx.utils import _model_to_graph
    from torch.onnx import TrainingMode
    from torch.onnx.symbolic_helper import _set_opset_version

    try:
        from torch.onnx.symbolic_helper import _export_onnx_opset_version
    except ImportError:
        # This is needed for PyTorch 1.12
        from torch.onnx._globals import GLOBALS

        _export_onnx_opset_version = GLOBALS.export_onnx_opset_version

    # _export_onnx_opset_version is a snapshot of the global opset taken *before* it is changed below.
    previous_opset_version = _export_onnx_opset_version
    _set_opset_version(modeci_onnx_opset_version)
    try:
        graph, params_dict, _ = _model_to_graph(
            model=jit_model if graph is not None else model,
            args=args,
            do_constant_folding=False,
            training=TrainingMode.EVAL,
            operator_export_type=operator_export_type,
            dynamic_axes={},
        )
    finally:
        # Restore the global opset even if export fails, so later torch.onnx calls are unaffected.
        _set_opset_version(previous_opset_version)

    model_name, graph_name = make_model_graph_name(model)

    # Setup the MDF model and graph
    mdf_model = Model(id=model_name)
    mdf_graph = Graph(id=graph_name)
    mdf_model.graphs.append(mdf_graph)

    # Get all constant nodes in the graph
    consts = get_graph_constants(graph)

    # Get any inputs to the graph, and their debug names. Pass args so we know how many original input
    # arguments the graph has. ONNX lowering from _model_to_graph makes all parameters to the model inputs.
    port_mapper = PortMapper(graph=graph, args=args)

    # Translate the TorchScript graph to an MDF graph object. This could be a recursive call
    translate_graph(
        graph=graph, mdf_graph=mdf_graph, consts=consts, port_mapper=port_mapper
    )

    # Replace "." with "_" in parameter names. We have done this elsewhere when creating the input ports
    # for these parameters.
    params_dict = {port_mapper.id_to_port(k): v for k, v in params_dict.items()}

    # Record the opset the graph was actually lowered with. The previous code stored
    # _export_onnx_opset_version, which holds the stale pre-export value.
    mdf_model.onnx_opset_version = modeci_onnx_opset_version

    return mdf_model, params_dict
if __name__ == "__main__":
    # Minimal demo: convert a plain Python function that adds two scalar tensors.

    def simple(x, y):
        return x + y

    example_inputs = (torch.tensor(1.0), torch.tensor(2.0))
    mdf_model, param_dict = pytorch_to_mdf(simple, args=example_inputs)