0%

ov::Node

摘要

  • ov::Node 类 学习笔记

ov::Node 详解

ov::Node 是 OpenVINO 中的基础类,表示计算图中的一个节点(操作或算子)。在 OpenVINO 的推理和模型操作中,ov::Node 是用于操作计算图结构和属性的核心组件。每个 ov::Node 都对应一个计算操作(如加法、卷积、激活函数)或数据结构(如常量、参数)。


核心功能

  1. 表示模型图中的操作节点
    每个 ov::Node 对应一个算子或数据节点。

  2. 支持动态图操作
    可以通过 ov::Node 动态修改模型结构。

  3. 支持输入输出连接
    管理算子的输入和输出。

  4. 提供元信息
    包括节点名称、类型、属性等信息。


创建和使用 ov::Node

ov::Node 通常通过高层接口如 ov::Core 读取模型后获得,而不是直接创建。下面是一些常见操作。

加载模型并获取节点

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
#include <openvino/openvino.hpp>
#include <iostream>

int main() {
ov::Core core;

// 读取模型
auto model = core.read_model("model.xml");

// 遍历所有节点
for (const auto& node : model->get_ops()) {
std::cout << "Node: " << node->get_friendly_name() << ", Type: "
<< node->get_type_name() << std::endl;
}

return 0;
}

主要方法和属性

1. 获取节点名称

1
2
std::string name = node->get_friendly_name();
std::cout << "Node name: " << name << std::endl;
  • get_friendly_name() 返回节点的友好名称。

2. 获取节点类型

1
2
std::string type = node->get_type_name();
std::cout << "Node type: " << type << std::endl;
  • get_type_name() 返回节点的操作类型,例如 "Convolution""Relu"

3. 获取节点输入

1
2
3
4
for (const auto& input : node->inputs()) {
std::cout << "Input index: " << input.get_index()
<< ", Shape: " << input.get_shape() << std::endl;
}
  • inputs() 返回所有输入端口。
  • 可以通过 get_shape()get_element_type() 获取输入的形状和数据类型。

4. 获取节点输出

1
2
3
4
for (const auto& output : node->outputs()) {
std::cout << "Output index: " << output.get_index()
<< ", Shape: " << output.get_shape() << std::endl;
}
  • outputs() 返回所有输出端口。

5. 设置节点属性

某些节点(如激活函数)支持修改属性:

1
2
3
4
5
6
7
if (node->get_type_name() == "Relu") {
// 动态转换为具体的算子类型
auto relu_node = std::dynamic_pointer_cast<ov::op::v0::Relu>(node);
if (relu_node) {
// 修改属性,例如某些激活函数支持修改 slope 等参数
}
}

动态修改模型图

通过 ov::Node,可以动态修改模型的计算图,包括替换节点、添加新节点等操作。

1. 替换节点

可以用新的节点替换旧节点:

1
2
3
4
5
6
7
8
auto relu_node = std::dynamic_pointer_cast<ov::op::v0::Relu>(node);
if (relu_node) {
// 创建新的节点,例如替换为 Sigmoid
auto sigmoid_node = std::make_shared<ov::op::v0::Sigmoid>(relu_node->input(0).get_source_output());

// 替换节点
ov::replace_node(relu_node, sigmoid_node);
}

2. 添加新节点

可以在图中插入新节点:

1
2
auto add_node = std::make_shared<ov::op::v1::Add>(
node->output(0), some_other_node->output(0));

节点间连接

1. 获取输入连接

1
2
3
4
for (const auto& input : node->inputs()) {
auto source_output = input.get_source_output();
std::cout << "Connected from node: " << source_output.get_node()->get_friendly_name() << std::endl;
}

2. 获取输出连接

1
2
3
4
5
for (const auto& output : node->outputs()) {
for (const auto& target_input : output.get_target_inputs()) {
std::cout << "Connected to node: " << target_input.get_node()->get_friendly_name() << std::endl;
}
}

实际应用示例

以下示例展示了如何修改模型中的某个节点:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
#include <openvino/openvino.hpp>
#include <iostream>

int main() {
ov::Core core;

// 加载模型
auto model = core.read_model("model.xml");

// 查找所有 ReLU 节点并替换为 Sigmoid
for (const auto& node : model->get_ops()) {
if (node->get_type_name() == "Relu") {
auto relu_node = std::dynamic_pointer_cast<ov::op::v0::Relu>(node);
if (relu_node) {
// 创建新的 Sigmoid 节点
auto sigmoid_node = std::make_shared<ov::op::v0::Sigmoid>(relu_node->input(0).get_source_output());

// 替换 ReLU 节点
ov::replace_node(relu_node, sigmoid_node);
std::cout << "Replaced ReLU with Sigmoid" << std::endl;
}
}
}

// 保存修改后的模型
core.compile_model(model, "CPU")->export_model("modified_model.xml");

return 0;
}

常见节点类型

以下是 OpenVINO 中一些常见的节点类型:

  • ov::op::v0::Parameter:模型的输入节点。
  • ov::op::v0::Constant:常量节点。
  • ov::op::v0::Relu:ReLU 激活函数。
  • ov::op::v0::Convolution:卷积操作。
  • ov::op::v0::Add:加法操作。
  • ov::op::v0::Multiply:乘法操作。
  • ov::op::v0::Softmax:Softmax 激活函数。

注意事项

  1. 动态类型检查
    使用 std::dynamic_pointer_cast 确保节点类型匹配。

  2. 操作合法性
    修改模型图时,确保输入输出连接正确,否则可能导致模型不合法。

  3. 性能优化
    修改节点可能影响性能,建议使用 OpenVINO 的性能工具进行优化分析。


总结

ov::Node 是 OpenVINO 模型图的核心类,提供了访问和操作计算图的强大功能。通过对节点的操作,开发者可以实现模型的动态修改、优化和自定义算子操作。这对于模型量化、剪枝和图优化等高级任务尤为重要。

感谢老板支持!敬礼(^^ゞ