Skip to content

TensorFlow Python to C++ Python到C++的迁移 #29

Open
@Shellbye

Description

@Shellbye

算法工程师们的给出的demo往往是用Python写的,但是在实际的项目中,考虑的性能的因素,我们常常需要使用C++来对demo进行一次重写,TF官方的文档中,这部分东西比较少,所以我对自己项目中用到的一些东西进行一点简单的记述。下面的代码也许不能直接拿来运行,但是大致意思是对的。

with gfile.FastGFile('model.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    g = tf.Graph()
    tf.import_graph_def(graph_def, name='')
    with tf.Session(graph=g) as sess:
        _in = sess.graph.get_tensor_by_name("data/in:0")
        _out = sess.graph.get_tensor_by_name("out:0")
        in_batch = np.zeros((1, 16))  # np array (1, 16) as input
        feed_dict = {inp: in_batch}
        results = sess.run(_out, feed_dict=feed_dict)

以上为Python版本的加载模型和运行启动代码,下面是C++版本的

tensorflow::GraphDef graph_def;
Status load_graph_status = ReadBinaryProto(tensorflow::Env::Default(), “./model.pb”, &graph_def);
if (load_graph_status.ok()) {
    //session config
    std::unique_ptr<tensorflow::Session> session;
    tensorflow::SessionOptions session_options;
    session_options.config.mutable_gpu_options()->set_allow_growth(true);
    session->reset(tensorflow::NewSession(session_options));

    Status session_create_status = (*session)->Create(graph_def);
    if (!session_create_status.ok()) {
        // error
    }
    tensorflow::Tensor _in(tensorflow::DT_INT32, tensorflow::TensorShape({1, 16}));
    auto inp_mapped = inp.tensor<int, 2>();  //  这里是C++版本比较奇特的地方,要通过这种方式才能完成输入数据的准备
    for (int i = 0; i < 1; i++) {
        for (int j = 0; j < 16; j++) {
            inp_mapped(i, j) = 0;  // input data
        }
    }

    std::vector<std::pair<std::string, tensorflow::Tensor>> feed_dict = {
            {"data/in:0", _in}
    };
    std::vector<Tensor> _out;
    Status run_status = session_->Run(feed_dict, {"out:0",}, {}, _out);
}

Metadata

Metadata

Assignees

No one assigned

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions