4.1.3. Graph

class Graph : public primitiv::mixins::DefaultSettable<Graph>, primitiv::mixins::Nonmovable<Graph>

Computation graph.

Public Functions

void clear()

Clears all operators in the graph.

Remark
After calling this method, all Node objects supplied by the graph are invalidated.
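The invalidation rule can be illustrated with a minimal sketch. It assumes a generation counter, which is one common technique and not necessarily how primitiv tracks Node validity; ToyGraph, ToyNode, add_value() and valid() are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical toy graph, NOT primitiv's implementation: it only shows one
// way a graph can detect node handles that outlived a clear() call.
class ToyGraph;

struct ToyNode {
  const ToyGraph *graph = nullptr;
  std::uint64_t generation = 0;  // graph generation when the node was created
  std::uint32_t id = 0;          // index of the value inside the graph
  bool valid() const;
};

class ToyGraph {
 public:
  // Stores a value and returns a handle tagged with the current generation.
  ToyNode add_value(double v) {
    values_.push_back(v);
    return ToyNode{this, generation_,
                   static_cast<std::uint32_t>(values_.size() - 1)};
  }

  // Drops all values; bumping the generation makes every older handle stale.
  void clear() {
    values_.clear();
    ++generation_;
  }

  std::uint64_t generation() const { return generation_; }
  std::size_t size() const { return values_.size(); }

 private:
  std::vector<double> values_;
  std::uint64_t generation_ = 0;
};

// A handle is valid only if it belongs to a graph, was created in the graph's
// current generation, and still indexes an existing value.
bool ToyNode::valid() const {
  return graph != nullptr && generation == graph->generation() &&
         id < graph->size();
}
```

A handle created before clear() reports itself invalid afterwards, while handles created after the call are valid.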

std::vector<Node> add_operator(std::unique_ptr<Operator> &&op, const std::vector<Node> &args)

Adds an operator into the graph.

Return
New Node objects of resulting values.
Parameters
  • op: Interface of the new operator.
  • args: List of arguments. Each node should point to a node in the same computation graph.
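A minimal sketch of what add_operator() has to do: reject arguments that come from another graph and return one node per output value of the operator. ToyOperator, num_returns(), Input and Split2 are hypothetical stand-ins for primitiv's Operator interface, and throwing std::invalid_argument is an assumed error-handling policy.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Hypothetical operator interface: each operator declares how many values it returns.
struct ToyOperator {
  virtual ~ToyOperator() = default;
  virtual std::string name() const = 0;
  virtual std::uint32_t num_returns() const = 0;
};

struct Input : ToyOperator {  // hypothetical operator with one output
  std::string name() const override { return "input"; }
  std::uint32_t num_returns() const override { return 1; }
};

struct Split2 : ToyOperator {  // hypothetical operator with two outputs
  std::string name() const override { return "split2"; }
  std::uint32_t num_returns() const override { return 2; }
};

class ToyGraph;

struct ToyNode {
  const ToyGraph *graph;
  std::uint32_t oper_id;   // which operator produced this value
  std::uint32_t value_id;  // which output of that operator
};

class ToyGraph {
 public:
  std::vector<ToyNode> add_operator(std::unique_ptr<ToyOperator> &&op,
                                    const std::vector<ToyNode> &args) {
    // Every argument must belong to this graph.
    for (const ToyNode &arg : args) {
      if (arg.graph != this) {
        throw std::invalid_argument("argument node belongs to another graph");
      }
    }
    const auto oper_id = static_cast<std::uint32_t>(ops_.size());
    std::vector<ToyNode> results;
    for (std::uint32_t i = 0; i < op->num_returns(); ++i) {
      results.push_back(ToyNode{this, oper_id, i});  // one node per output
    }
    ops_.push_back(Entry{std::move(op), args});  // the graph takes ownership
    return results;
  }

  std::uint32_t num_operators() const {
    return static_cast<std::uint32_t>(ops_.size());
  }

 private:
  struct Entry {
    std::unique_ptr<ToyOperator> op;
    std::vector<ToyNode> args;
  };
  std::vector<Entry> ops_;
};
```

Because one operator may yield several nodes, adding an input and a two-output split gives two operators but three nodes.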

const Tensor &forward(const Node &node)

Calculates the value of the given node.

Return
Calculated value.
Remark
This function calculates only the subgraph required to calculate the target node. Each intermediate result is stored in the corresponding node of the subgraph and reused in later calculations, i.e., each node is calculated at most once during the lifetime of the Graph object.
Parameters
  • node: Node object specifying the target node.
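The memoization described in the remark can be sketched with scalar values in place of Tensor objects. MemoGraph, add() and evaluations() are hypothetical; the sketch only assumes that each node caches its result once it has been computed.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <utility>
#include <vector>

// Hypothetical sketch, not primitiv's code.
class MemoGraph {
 public:
  // A node's function receives the values of its arguments, in order.
  using Fn = std::function<double(const std::vector<double> &)>;

  // Adds a node; args must refer to nodes that were added earlier.
  std::size_t add(Fn fn, std::vector<std::size_t> args) {
    nodes_.push_back(Entry{std::move(fn), std::move(args), false, 0.0});
    return nodes_.size() - 1;
  }

  // Evaluates only the nodes that `id` depends on; each result is cached.
  double forward(std::size_t id) {
    Entry &e = nodes_[id];       // safe: nodes_ is not resized during forward()
    if (e.done) return e.value;  // already calculated: reuse the cached value
    std::vector<double> inputs;
    for (std::size_t a : e.args) inputs.push_back(forward(a));
    e.value = e.fn(inputs);
    e.done = true;
    ++evaluations_;
    return e.value;
  }

  // Number of node evaluations performed so far.
  std::size_t evaluations() const { return evaluations_; }

 private:
  struct Entry {
    Fn fn;
    std::vector<std::size_t> args;
    bool done;
    double value;
  };
  std::vector<Entry> nodes_;
  std::size_t evaluations_ = 0;
};
```

Nodes that the target does not depend on are never evaluated, and a repeated forward() call returns the cached value without recomputing anything.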

void backward(const Node &node)

Performs backpropagation starting from the given output node.

Remark
If node has not been forwarded yet, this function implicitly calls forward(node).
Parameters
  • node: Node object specifying the output node.
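A scalar sketch of backward() including the implicit forward pass. BackpropGraph and its input/add/mul methods are hypothetical. The sketch assumes nodes are stored in creation order, which is a topological order because a node's arguments must exist before it is added, so gradients can be accumulated by walking that order in reverse.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch, not primitiv's code: scalar reverse-mode differentiation.
class BackpropGraph {
 public:
  enum class Op { kInput, kAdd, kMul };

  std::size_t input(double v) { return push(Op::kInput, 0, 0, v); }
  std::size_t add(std::size_t a, std::size_t b) { return push(Op::kAdd, a, b, 0.0); }
  std::size_t mul(std::size_t a, std::size_t b) { return push(Op::kMul, a, b, 0.0); }

  // Computes only the nodes that `id` depends on, each at most once.
  void forward(std::size_t id) {
    Node &n = nodes_[id];  // safe: nodes_ is not resized here
    if (n.done) return;
    forward(n.a);
    forward(n.b);
    n.value = (n.op == Op::kAdd) ? nodes_[n.a].value + nodes_[n.b].value
                                 : nodes_[n.a].value * nodes_[n.b].value;
    n.done = true;
  }

  void backward(std::size_t id) {
    if (!nodes_[id].done) forward(id);  // implicit forward, as documented
    for (Node &n : nodes_) n.grad = 0.0;
    nodes_[id].grad = 1.0;  // d(output)/d(output)
    // Visit id, id-1, ..., 0: every consumer is seen before its arguments.
    for (std::size_t i = id + 1; i-- > 0;) {
      const Node &n = nodes_[i];
      if (!n.done || n.op == Op::kInput) continue;
      if (n.op == Op::kAdd) {
        nodes_[n.a].grad += n.grad;
        nodes_[n.b].grad += n.grad;
      } else {  // product rule
        nodes_[n.a].grad += n.grad * nodes_[n.b].value;
        nodes_[n.b].grad += n.grad * nodes_[n.a].value;
      }
    }
  }

  double value(std::size_t id) const { return nodes_[id].value; }
  double grad(std::size_t id) const { return nodes_[id].grad; }

 private:
  struct Node {
    Op op;
    std::size_t a, b;  // argument indices (unused for inputs)
    double value;
    double grad;
    bool done;
  };

  std::size_t push(Op op, std::size_t a, std::size_t b, double v) {
    nodes_.push_back(Node{op, a, b, v, 0.0, op == Op::kInput});
    return nodes_.size() - 1;
  }

  std::vector<Node> nodes_;
};
```

For y = (x + w) * x with x = 3 and w = 2, calling backward(y) without a prior forward(y) still yields y = 15, dy/dx = 2x + w = 8 and dy/dw = x = 3.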

Shape get_shape(const Node &node) const

Retrieves the shape of the node.

Return
The shape of the node.
Parameters
  • node: Node object specifying the target node.

Device &get_device(const Node &node) const

Retrieves the device of the node.

Return
Device of the node.
Parameters
  • node: Node object specifying the target node.

std::string dump(const std::string &format) const

Dumps the internal graph structure.

Return
A string that represents the internal graph using the given format.
Parameters
  • format: Name of the format. Available options: “dot” … Graphviz’s dot format.
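To give a rough idea of what a dot dump contains, the sketch below emits one vertex per operator and one edge per argument. to_dot() and DotOp are hypothetical, and the vertex names and labels are illustrative only; they are not the exact text produced by dump("dot").

```cpp
#include <cassert>
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical description of one operator for this sketch.
struct DotOp {
  std::string name;               // label shown on the vertex
  std::vector<std::size_t> args;  // indices of operators feeding this one
};

// Renders the operators as a directed graph in Graphviz's dot language.
std::string to_dot(const std::vector<DotOp> &ops) {
  std::ostringstream out;
  out << "digraph ComputationGraph {\n";
  for (std::size_t i = 0; i < ops.size(); ++i) {
    out << "  n" << i << " [label=\"" << ops[i].name << "\"];\n";
  }
  for (std::size_t i = 0; i < ops.size(); ++i) {
    for (std::size_t a : ops[i].args) {
      out << "  n" << a << " -> n" << i << ";\n";  // edge: argument -> consumer
    }
  }
  out << "}\n";
  return out.str();
}
```

The resulting string can be saved to a file and rendered with Graphviz, for example with the command dot -Tpng graph.dot -o graph.png.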

std::uint32_t num_operators() const

Returns the number of operators in the computation graph.

Return
Number of operators.