r"""
The main object-oriented implementation of the MDF schema, with each core component of the `MDF specification <../Specification.html>`_
implemented as a :code:`class`. Instances of these objects can be composed to create a representation of
an MDF model as Python objects. These models can then be serialized and deserialized to and from JSON or YAML,
executed via the :mod:`~modeci_mdf.execution_engine` module, or imported and exported to supported external
environments using the :mod:`~modeci_mdf.interfaces` module.
"""
import copy
import numpy as np
from typing import List, Tuple, Dict, Set, Any, Union, Optional
import modelspec
from modelspec import has, field, fields, optional, instance_of, in_
from modelspec.base_types import (
Base,
converter,
ValueExprType,
value_expr_types,
value_expr_converter,
)
from ast import literal_eval as make_tuple
from modeci_mdf import MODECI_MDF_VERSION
from modeci_mdf import __version__
__all__ = [
"Model",
"Graph",
"Node",
"Function",
"InputPort",
"OutputPort",
"ParameterCondition",
"Parameter",
"Edge",
"ConditionSet",
"Condition",
]
@modelspec.define(eq=False)
class MdfBase(Base):
"""
Base class for all MDF core classes that implements common functionality.
Attributes:
metadata: Optional metadata field, an arbitrary dictionary of string keys and JSON serializable values.
"""
metadata: Optional[Dict[str, Any]] = field(
kw_only=True, default=None, validator=optional(instance_of(dict))
)
[docs]@modelspec.define(eq=False)
class Function(MdfBase):
r"""
A single value which is evaluated as a function of values on :class:`InputPort`(s) and other Functions
Attributes:
id: The unique (for this Node) id of the function, which will be used in other :class:`~Function`s and
the :class:`~OutputPort`s for its value
function: Which of the in-build MDF functions (:code:`linear`, etc.). See supported functions:
https://mdf.readthedocs.io/en/latest/api/MDF_function_specifications.html
args: Dictionary of values for each of the arguments for the Function, e.g. if the in-built function
is linear(slope),the args here could be {"slope":3} or {"slope":"input_port_0 + 2"}
value: If the function is a value expression, this attribute will contain the expression and the function
and args attributes will be None.
"""
id: str = field(validator=instance_of(str))
function: Optional[str] = field(
default=None, validator=optional(instance_of((str, dict)))
)
args: Optional[Dict[str, Any]] = field(default=None)
value: Optional[ValueExprType] = field(
default=None, validator=optional(instance_of(value_expr_types))
)
@classmethod
def _parse_cattr_structure(cls, o):
if "function" in o:
if isinstance(o["function"], dict):
# TODO: handle args already present and conflicting?
func_name = list(o["function"].keys())[0]
o["args"] = o["function"][func_name]
o["function"] = func_name
return o
[docs]@modelspec.define(eq=False)
class OutputPort(MdfBase):
r"""
The :class:`OutputPort` is an attribute of a :class:`Node` which exports information to another :class:`Node`
connected by an :class:`Edge`
Attributes:
id: Unique identifier for the output port.
value: The value of the :class:`OutputPort` in terms of the :class:`InputPort`, :class:`Function` values, and
:class:`Parameter` values.
shape: The shape of the output port. This uses the same syntax as numpy ndarray shapes
(e.g., :code:`numpy.zeros(shape)` would produce an array with the correct shape
type: The data type of the output sent by a port.
"""
id: str = field(validator=instance_of(str))
value: Optional[str] = field(validator=optional(instance_of(str)), default=None)
shape: Optional[Tuple[int, ...]] = field(
validator=optional(instance_of(tuple)),
default=None,
converter=lambda x: make_tuple(x) if type(x) == str else x,
)
type: Optional[str] = field(validator=optional(instance_of(str)), default=None)
[docs]@modelspec.define(eq=False)
class ParameterCondition(Base):
r"""
A condition to test on a Node's parameters, which if true, sets the value of this Parameter
Attributes:
id: A unique identifier for the ParameterCondition
test: The boolean expression to evaluate
value: The new value of the Parameter if the test is true
"""
id: str = field(validator=instance_of(str))
test: Optional[ValueExprType] = field(default=None)
value: Optional[ValueExprType] = field(default=None)
[docs]@modelspec.define(eq=False)
class Parameter(MdfBase):
r"""
A parameter of the :class:`Node`, which can be: 1) a specific fixed :code:`value` (a constant (int/float) or an array) 2) a string expression for the :code:`value`
referencing other named :class:`Parameter`(s). which may be stateful (i.e. can change value over multiple executions of the :class:`Node`); 3) be evaluated by an
inbuilt :code:`function` with :code:`args`; 4) or change from a :code:`default_initial_value` with a :code:`time_derivative`.
Attributes:
value: The next value of the parameter, in terms of the inputs, functions and PREVIOUS parameter values
default_initial_value: The initial value of the parameter, only used when parameter is stateful.
time_derivative: How the parameter changes with time, i.e. ds/dt. Units of time are seconds.
function: Which of the in-build MDF functions (linear etc.) this uses, See
https://mdf.readthedocs.io/en/latest/api/MDF_function_specifications.html
args: Dictionary of values for each of the arguments for the function of the parameter,
e.g. if the in-build function is :code:`linear(slope)`, the args here could be :code:`{"slope": 3}` or
:code:`{"slope": "input_port_0 + 2"}`
conditions: Parameter specific conditions
"""
id: str = field(validator=instance_of(str))
value: Optional[ValueExprType] = field(default=None)
default_initial_value: Optional[ValueExprType] = field(default=None)
time_derivative: Optional[str] = field(
validator=optional(instance_of(str)), default=None
)
function: Optional[str] = field(validator=optional(instance_of(str)), default=None)
args: Optional[Dict[str, Any]] = field(
validator=optional(instance_of(dict)), default=None
)
conditions: List[ParameterCondition] = field(
factory=list, validator=instance_of(list)
)
[docs] def summary(self):
"""
Short summary of Parameter...
"""
info = "Parameter: {} = {} (stateful: {})".format(
self.id,
self.value,
self.is_stateful(),
)
if self.default_initial_value is not None:
info += f", def init: {self.default_initial_value}"
if self.time_derivative is not None:
info += f", time deriv: {self.time_derivative}"
if self.conditions is not None:
for c in self.conditions:
info += f", {c}"
return info
[docs] def is_stateful(self) -> bool:
"""
Is the parameter stateful?
A parameter is considered stateful if it has a :code:`time_derivative`, :code:`default_initial_value`, or its
id is referenced in its value expression.
Returns:
:code:`True` if stateful, `False` if not.
"""
from modeci_mdf.execution_engine import get_required_variables_from_expression
if self.time_derivative is not None:
return True
if self.default_initial_value is not None:
return True
if self.value is not None and type(self.value) == str:
sf = self.id in get_required_variables_from_expression(self.value)
"""
print(
"Checking whether %s is stateful, %s: %s"
% (self, param_expr.free_symbols, sf)
)"""
return sf
return False
@staticmethod
def _unstructure(o: "Parameter"):
"""A custom unstructuring hook for parameters that removes default values of None from the output."""
# Unstructure the parameter and any value expressions it contains.
d = value_expr_converter.unstructure(o)
# Remove any fields that are None by default
d = {
name: val
for name, val in d.items()
if not ((val is None) or (type(val) == list and len(val) == 0))
}
return d
[docs]@modelspec.define(eq=False)
class Node(MdfBase):
r"""
A self contained unit of evaluation receiving input from other nodes on :class:`InputPort`(s).
The values from these are processed via a number of :class:`Function`(s) and one or more final values
are calculated on the :class:`OutputPort`(s)
Attributes:
id: A unique identifier for the node.
input_ports: Dictionary of the :class:`InputPort` objects in the Node
parameters: Dictionary of :class:`Parameter`(s) for the node
functions: The :class:`Function`(s) for computation the node
output_ports: The :class:`OutputPort`(s) containing evaluated quantities from the node
"""
id: str = field(validator=instance_of(str))
input_ports: List[InputPort] = field(factory=list, validator=instance_of(list))
functions: List[Function] = field(factory=list, validator=instance_of(list))
parameters: List[Parameter] = field(factory=list, validator=instance_of(list))
output_ports: List[OutputPort] = field(factory=list, validator=instance_of(list))
[docs] def get_parameter(self, id: str) -> Union[Parameter, None]:
r"""Get a parameter by its string :code:`id`
Args:
id: The unique string id of the :class:`Parameter`
Returns:
The :class:`Parameter` object stored on this node. :code:`None` if not found.
"""
for p in self.parameters:
if p.id == id:
return p
return None
[docs] def get_output_port(self, id: str) -> "OutputPort":
"""Retrieve :class:`OutputPort` object corresponding to the given id
Args:
id: Unique identifier of class:`OutputPort` object
Returns:
class:`OutputPort` object if the entered id matches with the id of class:`OutputPort` present in the class:`Node`
"""
for op in self.output_ports:
if id == op.id:
return op
[docs]@modelspec.define(eq=False)
class Edge(MdfBase):
r"""
An :class:`Edge` is an attribute of a :class:`Graph` that transmits computational results from a sender's
:class:`OutputPort` to a receiver's :class:`InputPort`.
Attributes:
id: A unique string identifier for this edge.
sender: The :code:`id` of the :class:`~Node` which is the source of the edge.
receiver: The :code:`id` of the :class:`~Node` which is the target of the edge.
sender_port: The id of the :class:`~OutputPort` on the sender :class:`~Node`, whose value should be sent to the
:code:`receiver_port`
receiver_port: The id of the InputPort on the receiver :class:`~Node`
parameters: Dictionary of parameters for the edge.
"""
id: str = field(validator=instance_of(str))
sender: str = field(validator=instance_of(str))
receiver: str = field(validator=instance_of(str))
sender_port: str = field(validator=instance_of(str))
receiver_port: str = field(validator=instance_of(str))
parameters: Optional[Dict[str, Any]] = field(
validator=optional(instance_of(dict)), default=None
)
[docs]@modelspec.define(eq=False)
class Condition(MdfBase):
r"""A set of descriptors which specifies conditional execution of Nodes to meet complex execution requirements.
Attributes:
type: The type of :class:`Condition` from the library
kwargs: The dictionary of keyword arguments needed to evaluate the :class:`Condition`
"""
type: str = field(validator=instance_of(str))
kwargs: Optional[Dict[str, Any]] = field(
validator=optional(instance_of(dict)), default=None
)
def __init__(
self,
type: Optional[str] = None,
**kwargs: Optional[Dict[str, Any]],
):
# We need to right our own __init__ in this case because the current API for Condition requires saving
# kwargs. This code is very attrs specific and hacky. Wish there was a better way to do this.
# Remove any field from kwargs that is pre-defined field on the MDF Condition class.
for a in self.__attrs_attrs__:
if a.name != "kwargs":
kwargs.pop(a.name, None)
# If there is a kwargs argument inside the kwargs argument (this happens when loading MDF from JSON because)
# kwargs is stored in JSON as a simple dict in a field called kwargs.
kwargs.update(kwargs.pop("kwargs", {}))
self.__attrs_init__(type, kwargs)
@classmethod
def _parse_cattr_structure(cls, o):
def _parse_as_condition(arg):
try:
return Condition(**cls._parse_cattr_structure(arg))
except (TypeError, ValueError):
return arg
if "kwargs" in o:
for k in o["kwargs"]:
if isinstance(o["kwargs"][k], list):
o["kwargs"][k] = [_parse_as_condition(a) for a in o["kwargs"][k]]
else:
try:
# read Model object as just object id
o["kwargs"][k] = o["kwargs"][k]["id"]
except (KeyError, TypeError):
o["kwargs"][k] = _parse_as_condition(o["kwargs"][k])
return o
@classmethod
def _parse_cattr_unstructure(cls, o):
def _clean_nested_Condition_objects(cond):
res = copy.copy(cond)
for k, v in res.kwargs.items():
if isinstance(v, Condition):
res.kwargs[k] = _clean_nested_Condition_objects(v)
elif isinstance(v, Base):
# replace mdf objects with their id to avoid making duplicates
res.kwargs[k] = v.id
return res
return _clean_nested_Condition_objects(o)
[docs]@modelspec.define(eq=False)
class ConditionSet(MdfBase):
r"""
Specifies the non-default pattern of execution of Nodes
Attributes:
node_specific: A dictionary mapping nodes to any non-default run conditions
termination: A dictionary mapping time scales of model execution to conditions indicating when they end
"""
node_specific: Optional[Dict[str, Condition]] = field(
validator=optional(instance_of(dict)), default=None
)
termination: Optional[Dict[str, Condition]] = field(
validator=optional(instance_of(dict)), default=None
)
@classmethod
def _parse_cattr_unstructure(cls, o):
res = copy.copy(o)
if res.termination is not None:
res.termination = {str(k): v for k, v in res.termination.items()}
return res
[docs]@modelspec.define(eq=False)
class Graph(MdfBase):
r"""
A directed graph consisting of :class:`~Node`s (with :class:`~Parameter`s and :class:`~Function`s evaluated internally) connected via :class:`~Edge`s.
Attributes:
id: A unique identifier for this Graph
nodes: One or more :class:`Node`(s) present in the graph
edges: Zero or more :class:`Edge`(s) present in the graph
parameters: Dictionary of global parameters for the Graph
conditions: The ConditionSet stored as dictionary for scheduling of the Graph
"""
id: str = field(validator=instance_of(str))
nodes: List[Node] = field(factory=list)
edges: List[Edge] = field(factory=list)
parameters: Optional[Dict[str, Any]] = None
conditions: Optional[ConditionSet] = None
[docs] def get_node(self, id: str) -> Union[Node, None]:
"""Retrieve Node object corresponding to the given id
Args:
id: Unique identifier of Node object
Returns:
:class:`Node` object if the entered :code:`id` matches with the :code:`id` of node present in the
:class:`~Graph`. :code:`None` if a node is not found with that id .
"""
for node in self.nodes:
if id == node.id:
return node
return None
@property
def dependency_dict(self) -> Dict[Node, Set[Node]]:
"""Returns the dependency among nodes as dictionary
Key: receiver, Value: Set of senders imparting information to the receiver
Returns:
Returns the dependency dictionary
"""
# assumes no cycles, need to develop a way to prune if cyclic
# graphs are to be supported
dependencies = {n: set() for n in self.nodes}
for edge in self.edges:
sender = self.get_node(edge.sender)
receiver = self.get_node(edge.receiver)
dependencies[receiver].add(sender)
return dependencies
@property
def inputs(self: "Graph") -> List[Tuple[Node, InputPort]]:
"""
Enumerate all Node-InputPort pairs that specify no incoming edge.
These are input ports for the graph itself and must be provided values to evaluate
Returns:
A list of Node, InputPort tuples
"""
# Get all input ports
all_ips = [(node.id, ip.id) for node in self.nodes for ip in node.input_ports]
# Get all receiver ports
all_receiver_ports = {(e.receiver, e.receiver_port) for e in self.edges}
# Find any input ports that aren't receiving values from an edge
return list(filter(lambda x: x not in all_receiver_ports, all_ips))
[docs]@modelspec.define(eq=False)
class Model(MdfBase):
r"""
The top level construct in MDF is Model, which may contain multiple :class:`Graph` objects and model attribute(s)
Attributes:
id: A unique identifier for this Model
graphs: The collection of graphs that make up the MDF model.
format: Information on the version of MDF used in this file
generating_application: Information on what application generated/saved this file
onnx_opset_version: The ONNX opset used for any ONNX functions in this model.
"""
id: str = field(validator=instance_of(str))
graphs: List[Graph] = field(factory=list)
format: str = field(
default=f"ModECI MDF v{MODECI_MDF_VERSION}", metadata={"omit_if_default": False}
)
generating_application: str = field(
default=f"Python modeci-mdf v{__version__}", metadata={"omit_if_default": False}
)
onnx_opset_version: Optional[str] = field(default=None)
[docs] def to_graph_image(
self,
engine: str = "dot",
output_format: str = "png",
view_on_render: bool = False,
level: int = 2,
filename_root: Optional[str] = None,
only_warn_on_fail: bool = False,
is_horizontal: bool = False,
solid_color=False,
):
"""Convert MDF graph to an image (png or svg) using the Graphviz export
Args:
engine: dot or other Graphviz formats
output_format: e.g. png (default) or svg
view_on_render: if True, will open generated image in system viewer
level: 1,2,3, depending on how much detail to include
filename_root: will change name of file generated to filename_root.png, etc.
only_warn_on_fail: just give a warning if this fails, e.g. no dot executable. Useful for preventing errors in automated tests
solid_color: if True, will return a solid color for the graphviz node. default is False
"""
from modeci_mdf.interfaces.graphviz.exporter import mdf_to_graphviz
try:
mdf_to_graphviz(
self.graphs[0],
engine=engine,
output_format=output_format,
view_on_render=view_on_render,
level=level,
filename_root=filename_root,
is_horizontal=is_horizontal,
solid_color=solid_color,
)
except Exception as e:
if only_warn_on_fail:
import traceback
print(traceback.format_exc())
print(
"Failure to generate image! Ensure Graphviz executables (dot etc.) are installed on native system. Error: \n%s"
% e
)
else:
raise (e)
# Add a special hook for handling unstructuring of Parameters and Functions
converter.register_unstructure_hook(Parameter, Parameter._unstructure)
[docs]def parsed_structure_factory(cl):
base_structure = converter._structure_func.dispatch(cl)
def new_structure(obj, t):
obj = cl._parse_cattr_structure(obj)
return base_structure(obj, t)
return new_structure
[docs]def parsed_unstructure_factory(cl):
base_unstructure = converter._unstructure_func.dispatch(cl)
def new_unstructure(obj):
obj = cl._parse_cattr_unstructure(obj)
return base_unstructure(obj)
return new_unstructure
for k, v in list(locals().items()):
# add structure hook for MdfBase classes where present
try:
v._parse_cattr_structure
except AttributeError:
pass
else:
converter.register_structure_hook(v, parsed_structure_factory(v))
# add unstructure hook for MdfBase classes where present
try:
v._parse_cattr_unstructure
except AttributeError:
pass
else:
converter.register_unstructure_hook(v, parsed_unstructure_factory(v))