Tutorial 26 - Generic Math Node

This tutorial demonstrates how to write nodes that perform mathematical operations in Python using NumPy. Using NumPy has the advantage that it is API-compatible with cuNumeric. As demonstrated in the Extended Attributes tutorial, generic math nodes use extended attributes to accept inputs and produce outputs of arbitrary numeric types, specified using the “numerics” keyword:

"inputs": {
    "myNumbericAttribute": {
        "description": "Accepts an incoming connection from any type of numeric value",
        "type": ["numerics"]
    }
}

OgnTutorialGenericMathNode.ogn

The .ogn file defines a node named “omni.graph.tutorials.GenericMathNode”, which takes two inputs of any numeric type and multiplies them.

 1{
 2    "GenericMathNode": {
 3        "description": [
 4            "This is a tutorial node. It is functionally equivalent to the built-in Multiply node,",
 5            "but written in python as a practical demonstration of using extended attributes to ",
 6            "write math nodes that work with any numeric types, including arrays and tuples."
 7        ],
 8        "version": 1,
 9        "language": "python",
10        "uiName": "Tutorial Python Node: Generic Math Node",
11        "categories": "tutorials",
12        "inputs": {
13            "a": {
14                "type": ["numerics"],
15                "description": "First number to multiply",
16                "uiName": "A"
17            },
18            "b": {
19                "type": ["numerics"],
20                "description": "Second number to multiply",
21                "uiName": "B"
22            }
23        },
24        "outputs": {
25            "product": {
26                "type": ["numerics"],
27                "description": "Product of the two numbers",
28                "uiName": "Product"
29            }
30        },
31        "tests" : [
32            {  "inputs:a": {"type": "int", "value": 2}, "inputs:b": {"type": "int", "value": 3}, "outputs:product": {"type": "int", "value": 6} },
33            {  "inputs:a": {"type": "int", "value": 2}, "inputs:b": {"type": "int64", "value": 3}, "outputs:product": {"type": "int64", "value": 6} },
34            {  "inputs:a": {"type": "int", "value": 2}, "inputs:b": {"type": "half", "value": 3}, "outputs:product": {"type": "float", "value": 6} },
35            {  "inputs:a": {"type": "int", "value": 2}, "inputs:b": {"type": "float", "value": 3}, "outputs:product": {"type": "float", "value": 6} },
36            {  "inputs:a": {"type": "int", "value": 2}, "inputs:b": {"type": "double", "value": 3}, "outputs:product": {"type": "double", "value": 6} },
37            {  "inputs:a": {"type": "int64", "value": 2}, "inputs:b": {"type": "int64", "value": 3}, "outputs:product": {"type": "int64", "value": 6} },
38            {  "inputs:a": {"type": "int64", "value": 2}, "inputs:b": {"type": "half", "value": 3}, "outputs:product": {"type": "double", "value": 6} },
39            {  "inputs:a": {"type": "int64", "value": 2}, "inputs:b": {"type": "float", "value": 3}, "outputs:product": {"type": "double", "value": 6} },
40            {  "inputs:a": {"type": "int64", "value": 2}, "inputs:b": {"type": "double", "value": 3}, "outputs:product": {"type": "double", "value": 6} },
41            {  "inputs:a": {"type": "half", "value": 2}, "inputs:b": {"type": "half", "value": 3}, "outputs:product": {"type": "half", "value": 6} },
42            {  "inputs:a": {"type": "half", "value": 2}, "inputs:b": {"type": "float", "value": 3}, "outputs:product": {"type": "float", "value": 6} },
43            {  "inputs:a": {"type": "half", "value": 2}, "inputs:b": {"type": "double", "value": 3}, "outputs:product": {"type": "double", "value": 6} },
44            {  "inputs:a": {"type": "float", "value": 2}, "inputs:b": {"type": "float", "value": 3}, "outputs:product": {"type": "float", "value": 6} },
45            {  "inputs:a": {"type": "float", "value": 2}, "inputs:b": {"type": "double", "value": 3}, "outputs:product": {"type": "double", "value": 6} },
46            {  "inputs:a": {"type": "double", "value": 2}, "inputs:b": {"type": "double", "value": 3}, "outputs:product": {"type": "double", "value": 6} },
47            {
48                "inputs:a": {"type": "double[2]", "value": [1.0, 42.0]}, "inputs:b": {"type": "double[2]", "value": [2.0, 1.0]},
49                "outputs:product": {"type": "double[2]", "value": [2.0, 42.0]}
50            },
51            {
52                "inputs:a": {"type": "double[]", "value": [1.0, 42.0]}, "inputs:b": {"type": "double", "value": 2.0},
53                "outputs:product": {"type": "double[]", "value": [2.0, 84.0]}
54            },
55            {
56                "inputs:a": {"type": "double[2][]", "value": [[10, 5], [1, 1]]}, 
57                "inputs:b": {"type": "double[2]", "value": [5, 5]},
58                "outputs:product": {"type": "double[2][]", "value": [[50, 25], [5, 5]]}
59            }
60        ]
61    }
62}

OgnTutorialGenericMathNode.py

The .py file contains the implementation of the node. It multiplies its two numeric inputs, demonstrating how to handle inputs that are both numeric but differ in precision, format, or dimension.

 1import numpy as np
 2import omni.graph.core as og
 3
 4# Mappings of possible numpy dtypes from the result data type and back
 5dtype_from_basetype = {
 6    og.BaseDataType.INT: np.int32,
 7    og.BaseDataType.INT64: np.int64,
 8    og.BaseDataType.HALF: np.float16,
 9    og.BaseDataType.FLOAT: np.float32,
10    og.BaseDataType.DOUBLE: np.float64,
11}
12
13supported_basetypes = [
14    og.BaseDataType.INT,
15    og.BaseDataType.INT64,
16    og.BaseDataType.HALF,
17    og.BaseDataType.FLOAT,
18    og.BaseDataType.DOUBLE,
19]
20
21basetype_resolution_table = [
22    [0, 1, 3, 3, 4],  # Int
23    [1, 1, 4, 4, 4],  # Int64
24    [3, 4, 2, 3, 4],  # Half
25    [3, 4, 3, 3, 4],  # Float
26    [4, 4, 4, 4, 4],  # Double
27]
28
29
class OgnTutorialGenericMathNode:
    """Node to multiply two values of any numeric type"""

    @staticmethod
    def compute(db) -> bool:
        """Compute the product of two values, if the types are all resolved.

        When the types are not compatible for multiplication, or the result type is not compatible with the
        resolved output type, the method will log an error and fail.

        Args:
            db: The node's database, giving access to inputs:a, inputs:b and outputs:product.

        Returns:
            True if the product was computed and written to outputs:product, False otherwise.
        """
        try:
            atype = db.inputs.a.type
            btype = db.inputs.b.type
            rtype = db.outputs.product.type

            # Output base types without a numpy mapping yield None, which lets numpy choose the result dtype
            result_dtype = dtype_from_basetype.get(rtype.base_type, None)

            # numpy handles scalars, tuples and arrays uniformly, and the dtype argument converts the product
            # to whatever the output has been resolved to
            db.outputs.product.value = OgnTutorialGenericMathNode._multiply(
                db.inputs.a.value,
                db.inputs.b.value,
                atype.tuple_count,
                atype.array_depth,
                btype.tuple_count,
                btype.array_depth,
                result_dtype,
            )
        except (TypeError, ValueError) as error:
            # TypeError: numpy refused to cast the operands to the resolved output dtype.
            # ValueError: the operand shapes could not be broadcast against each other.
            db.log_error(f"Multiplication could not be performed: {error}")
            return False

        return True

    @staticmethod
    def on_connection_type_resolve(node) -> None:
        """Resolve the output type from the input types once both inputs are resolved.

        The output base type comes from the resolution table. The output tuple count and array depth are
        the larger of the two inputs' values. An output whose type is already resolved is left unchanged.

        Args:
            node: The node whose attribute connections changed.
        """
        atype = node.get_attribute("inputs:a").get_resolved_type()
        btype = node.get_attribute("inputs:b").get_resolved_type()
        productattr = node.get_attribute("outputs:product")
        producttype = productattr.get_resolved_type()

        # The output type can only be inferred when both input types are resolved.
        if (
            atype.base_type != og.BaseDataType.UNKNOWN
            and btype.base_type != og.BaseDataType.UNKNOWN
            and producttype.base_type == og.BaseDataType.UNKNOWN
        ):
            base_type = OgnTutorialGenericMathNode._resolve_base_type(atype.base_type, btype.base_type)
            productattr.set_resolved_type(
                og.Type(base_type, max(atype.tuple_count, btype.tuple_count), max(atype.array_depth, btype.array_depth))
            )

    @staticmethod
    def _resolve_base_type(a_base_type, b_base_type):
        """Return the base type of the product of a value of a_base_type and a value of b_base_type.

        Pairs of types listed in supported_basetypes are looked up in basetype_resolution_table. If either
        type is not in supported_basetypes, the result falls back to DOUBLE.
        """
        if a_base_type not in supported_basetypes or b_base_type not in supported_basetypes:
            # list.index() raises ValueError for a missing entry (it never returns -1), so membership must be
            # checked explicitly to reach the DOUBLE fallback instead of raising out of the callback.
            return og.BaseDataType.DOUBLE
        a_index = supported_basetypes.index(a_base_type)
        b_index = supported_basetypes.index(b_base_type)
        return supported_basetypes[basetype_resolution_table[a_index][b_index]]

    @staticmethod
    def _multiply(a_value, b_value, a_tuple_count, a_array_depth, b_tuple_count, b_array_depth, dtype=None):
        """Multiply two numeric values with numpy, broadcasting an array of scalars across tuples.

        numpy aligns shapes from the trailing axis, so an array of N scalars (shape (N,)) does not broadcast
        against an array of N tuples (shape (N, T)). Giving the array of scalars a trailing axis (shape (N, 1))
        makes each scalar scale the matching tuple. The same expansion makes an array of scalars times a single
        tuple produce one scaled tuple per scalar. The expansion applies to whichever operand is the array of
        narrower tuples.

        Args:
            a_value: First operand (scalar, tuple, or array of either).
            b_value: Second operand (scalar, tuple, or array of either).
            a_tuple_count: Tuple count of the first operand's type (1 for scalars).
            a_array_depth: Array depth of the first operand's type (0 for non-arrays).
            b_tuple_count: Tuple count of the second operand's type (1 for scalars).
            b_array_depth: Array depth of the second operand's type (0 for non-arrays).
            dtype: numpy dtype of the result, or None to let numpy choose.

        Returns:
            The product as computed by np.multiply.

        Raises:
            TypeError: If numpy cannot cast the operands to dtype.
            ValueError: If the operand shapes cannot be broadcast together.
        """
        if b_array_depth > 0 and b_tuple_count < a_tuple_count:
            b_value = np.asarray(b_value)[:, np.newaxis]
        elif a_array_depth > 0 and a_tuple_count < b_tuple_count:
            a_value = np.asarray(a_value)[:, np.newaxis]
        return np.multiply(a_value, b_value, dtype=dtype)