Open
Description
算法工程师们的给出的demo往往是用Python写的,但是在实际的项目中,考虑的性能的因素,我们常常需要使用C++来对demo进行一次重写,TF官方的文档中,这部分东西比较少,所以我对自己项目中用到的一些东西进行一点简单的记述。下面的代码也许不能直接拿来运行,但是大致意思是对的。
with gfile.FastGFile('model.pb', 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
g = tf.Graph()
tf.import_graph_def(graph_def, name='')
with tf.Session(graph=g) as sess:
_in = sess.graph.get_tensor_by_name("data/in:0")
_out = sess.graph.get_tensor_by_name("out:0")
in_batch = np.zeros((1, 16)) # np array (1, 16) as input
feed_dict = {inp: in_batch}
results = sess.run(_out, feed_dict=feed_dict)
以上为Python版本的加载模型和运行启动代码,下面是C++版本的
tensorflow::GraphDef graph_def;
Status load_graph_status = ReadBinaryProto(tensorflow::Env::Default(), “./model.pb”, &graph_def);
if (load_graph_status.ok()) {
//session config
std::unique_ptr<tensorflow::Session> session;
tensorflow::SessionOptions session_options;
session_options.config.mutable_gpu_options()->set_allow_growth(true);
session->reset(tensorflow::NewSession(session_options));
Status session_create_status = (*session)->Create(graph_def);
if (!session_create_status.ok()) {
// error
}
tensorflow::Tensor _in(tensorflow::DT_INT32, tensorflow::TensorShape({1, 16}));
auto inp_mapped = inp.tensor<int, 2>(); // 这里是C++版本比较奇特的地方,要通过这种方式才能完成输入数据的准备
for (int i = 0; i < 1; i++) {
for (int j = 0; j < 16; j++) {
inp_mapped(i, j) = 0; // input data
}
}
std::vector<std::pair<std::string, tensorflow::Tensor>> feed_dict = {
{"data/in:0", _in}
};
std::vector<Tensor> _out;
Status run_status = session_->Run(feed_dict, {"out:0",}, {}, _out);
}