Merge pull request #263 from anarkiwi/inf2
Add experimental inference write annotation.
anarkiwi committed May 6, 2024
2 parents c70666f + 149ced8 commit da8d559
Showing 8 changed files with 60 additions and 43 deletions.
11 changes: 11 additions & 0 deletions lib/base_impl.cc
@@ -332,5 +332,16 @@ pmt::pmt_t base_impl::tune_rx_msg(FREQ_T tune_freq, bool tag_now) {
}
return tune_rx;
}

pmt::pmt_t base_impl::string_to_pmt(const std::string &s) {
return pmt::cons(pmt::make_dict(),
pmt::make_blob((const uint8_t *)s.c_str(), s.length()));
}

std::string base_impl::pmt_to_string(const pmt::pmt_t &msg) {
auto blob = pmt::cdr(msg);
return std::string(reinterpret_cast<const char *>(pmt::blob_data(blob)),
pmt::blob_length(blob));
}
} /* namespace iqtlabs */
} /* namespace gr */
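A minimal standalone sketch of how these two helpers round-trip a string through the GNU Radio PDU convention (a pmt pair of a metadata dict and a u8 blob). The free-function names below are illustrative stand-ins for the new base_impl members; only the pmt calls match the diff.

// Sketch only: mirrors base_impl::string_to_pmt() / pmt_to_string() outside the class.
#include <pmt/pmt.h>
#include <cassert>
#include <cstdint>
#include <string>

static pmt::pmt_t string_to_pdu(const std::string &s) {
  // A PDU is a pair: (metadata dict, payload blob).
  return pmt::cons(pmt::make_dict(),
                   pmt::make_blob((const uint8_t *)s.c_str(), s.length()));
}

static std::string pdu_to_string(const pmt::pmt_t &pdu) {
  auto blob = pmt::cdr(pdu); // payload half of the pair
  return std::string(reinterpret_cast<const char *>(pmt::blob_data(blob)),
                     pmt::blob_length(blob));
}

int main() {
  const std::string json = "{\"predictions\":{}}";
  assert(pdu_to_string(string_to_pdu(json)) == json); // lossless round trip
  return 0;
}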
2 changes: 2 additions & 0 deletions lib/base_impl.h
@@ -256,6 +256,8 @@ class base_impl {
std::vector<tag_t> &rx_freq_tags, std::vector<TIME_T> &rx_times,
COUNT_T in_count);
pmt::pmt_t tune_rx_msg(COUNT_T tune_freq, bool tag_now);
pmt::pmt_t string_to_pmt(const std::string &s);
std::string pmt_to_string(const pmt::pmt_t &pmt);
};
} /* namespace iqtlabs */
} /* namespace gr */
6 changes: 1 addition & 5 deletions lib/image_inference_impl.cc
@@ -596,11 +596,7 @@ void image_inference_impl::run_inference_() {
const std::string output_json_str = output_json.dump();
json_q_.push(output_json_str + "\n\n");
delete_output_item_(output_item);
auto pdu =
pmt::cons(pmt::make_dict(),
pmt::init_u8vector(output_json_str.length(),
(const uint8_t *)output_json_str.c_str()));
message_port_pub(INFERENCE_KEY, pdu);
message_port_pub(INFERENCE_KEY, string_to_pmt(output_json_str));
}
}
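For context, the port this refactored call publishes to is an ordinary GNU Radio message port. A hedged sketch of the wiring, assuming INFERENCE_KEY is a pmt symbol (its definition and the constructor registration are outside this diff):

// Assumption: INFERENCE_KEY is defined elsewhere in the block, e.g.:
//   const pmt::pmt_t INFERENCE_KEY = pmt::string_to_symbol("inference");
// In the block constructor the output port is registered once:
//   message_port_register_out(INFERENCE_KEY);
// run_inference_() then publishes each JSON result as a PDU built by the new helper:
//   message_port_pub(INFERENCE_KEY, string_to_pmt(output_json_str));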

6 changes: 1 addition & 5 deletions lib/iq_inference_impl.cc
@@ -372,11 +372,7 @@ void iq_inference_impl::run_inference_() {
const std::string output_json_str = output_json.dump();
json_q_.push(output_json_str + "\n\n");
delete_output_item_(output_item);
auto pdu =
pmt::cons(pmt::make_dict(),
pmt::init_u8vector(output_json_str.length(),
(const uint8_t *)output_json_str.c_str()));
message_port_pub(INFERENCE_KEY, pdu);
message_port_pub(INFERENCE_KEY, string_to_pmt(output_json_str));
}
}

8 changes: 1 addition & 7 deletions lib/iq_inference_standalone_impl.cc
@@ -203,7 +203,6 @@
*/

#include "iq_inference_standalone_impl.h"
#include "base_impl.h"
#include <boost/algorithm/string.hpp>
#include <boost/lexical_cast.hpp>
#include <gnuradio/io_signature.h>
@@ -260,16 +259,11 @@ int iq_inference_standalone_impl::work(int noutput_items,
torchserve_client_->send_inference_request(results, error);
torchserve_client_->disconnect();
d_logger->info("results {}, error {}", results, error);
auto pdu =
pmt::cons(pmt::make_dict(),
pmt::init_u8vector(results.length(),
(const uint8_t *)results.c_str()));
message_port_pub(INFERENCE_KEY, pdu);
message_port_pub(INFERENCE_KEY, string_to_pmt(results));
}
in += vlen_;
}
return noutput_items;
}

} /* namespace iqtlabs */
} /* namespace gr */
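On the consuming side, a downstream block receives these PDUs through a registered input message port and unwraps the payload with pmt_to_string. A hedged sketch; the handler and port names are assumptions, and write_freq_samples_impl's actual registration is outside this diff:

// Assumed consumer-side registration, typically done in the block constructor:
//   message_port_register_in(INFERENCE_KEY);
//   set_msg_handler(INFERENCE_KEY,
//                   [this](const pmt::pmt_t &msg) { recv_inference_(msg); });
// recv_inference_() can then recover the JSON text via pmt_to_string(msg).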
3 changes: 2 additions & 1 deletion lib/iq_inference_standalone_impl.h
@@ -205,6 +205,7 @@
#ifndef INCLUDED_IQTLABS_IQ_INFERENCE_STANDALONE_IMPL_H
#define INCLUDED_IQTLABS_IQ_INFERENCE_STANDALONE_IMPL_H

#include "base_impl.h"
#include "iqtlabs_types.h"
#include "torchserve_client.h"
#include <boost/scoped_ptr.hpp>
@@ -213,7 +214,7 @@
namespace gr {
namespace iqtlabs {

class iq_inference_standalone_impl : public iq_inference_standalone {
class iq_inference_standalone_impl : public iq_inference_standalone, base_impl {
private:
boost::scoped_ptr<torchserve_client> torchserve_client_;
std::vector<std::string> model_names_;
65 changes: 41 additions & 24 deletions lib/write_freq_samples_impl.cc
@@ -252,30 +252,46 @@ write_freq_samples_impl::write_freq_samples_impl(
write_freq_samples_impl::~write_freq_samples_impl() {}

void write_freq_samples_impl::recv_inference_(const pmt::pmt_t &msg) {
// inference_item_type inference_item;
// try {
//  nlohmann::json inference_results = nlohmann::json::parse(msg_str);
//  if (inference_results.contains("predictions")) {
// auto metadata = inference_results["metadata"];
// inference_item.sample_clock = std::stod(metadata["sample_clock"]);
// inference_item.sample_count = std::stoi(metadata["sample_count"]);
// double half_samp_rate = std::stod(metadata["sample_rate"]) / 2;
// inference_item.freq_lower_edge = last_rx_freq_ - half_samp_rate;
// inference_item.freq_upper_edge = last_rx_freq_ + half_samp_rate;
//  auto predictions = inference_results["predictions"];
//  for (auto &prediction_class : predictions.items()) {
// std::string prediction_name = prediction_class.key();
//  for (auto &prediction : prediction_class.value()) {
// if (prediction_name != "No signal") {
// continue;
// }
// inference_item.description = prediction_name;
// }
// }
// inference_q_.push(inference_item);
// } catch (std::exception &ex) {
//  d_logger->error("invalid json: " + std::string(ex.what()));
// }
// TODO: non-rotate not supported.
// Among other things, need to delineate inference results for current
// window and adjust sample clock.
if (rotate_) {
return;
}
const std::string msg_str = pmt_to_string(msg);
d_logger->info("inference results: {}", msg_str);
try {
nlohmann::json inference_results = nlohmann::json::parse(msg_str);
const auto metadata = inference_results["metadata"];
const TIME_T sample_clock =
std::stod((std::string)metadata["sample_clock"]);
const int sample_count = std::stoi((std::string)metadata["sample_count"]);
const FREQ_T sample_rate = std::stod((std::string)metadata["sample_rate"]);
if (inference_results.contains("predictions")) {
auto predictions = inference_results["predictions"];
for (auto &prediction_class : predictions.items()) {
// TODO: make configurable.
if (prediction_class.key() == "No signal") {
continue;
}
for (auto &prediction : prediction_class.value()) {
// TODO: add confidence and model to description.
inference_item_type inference_item;
inference_item.sample_start = sample_clock;
inference_item.sample_count = sample_count;
inference_item.freq_lower_edge = last_rx_freq_ - (sample_rate / 2);
inference_item.freq_upper_edge = last_rx_freq_ + (sample_rate / 2);
inference_item.description = prediction_class.key();
inference_item.label = inference_item.description;
if (!inference_q_.push(inference_item)) {
d_logger->error("inference annotation queue full");
}
}
}
}
} catch (std::exception &ex) {
d_logger->error("invalid json: {}", ex.what());
}
}
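The handler above expects a message shaped like the JSON the inference blocks publish. A hedged sketch of that shape and of the same nlohmann::json access pattern; the field names follow the parsing code above, while the values (and the per-prediction contents) are purely illustrative:

// Illustrative only: numeric metadata fields arrive as strings, which is why the
// handler uses std::stod/std::stoi on them.
#include <nlohmann/json.hpp>
#include <iostream>
#include <string>

int main() {
  const char *msg = R"({
    "metadata": {
      "sample_clock": "1700000000.5",
      "sample_count": "4096",
      "sample_rate": "20480000"
    },
    "predictions": {
      "No signal": [ {} ],
      "wifi":      [ {} ]
    }
  })";
  auto results = nlohmann::json::parse(msg);
  const auto metadata = results["metadata"];
  const double sample_rate = std::stod((std::string)metadata["sample_rate"]);
  for (auto &prediction_class : results["predictions"].items()) {
    if (prediction_class.key() == "No signal") {
      continue; // skipped by the block as well
    }
    std::cout << prediction_class.key() << ": "
              << prediction_class.value().size() << " detection(s), band "
              << sample_rate << " Hz wide" << std::endl;
  }
  return 0;
}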

bool write_freq_samples_impl::stop() {
@@ -321,6 +337,7 @@ void write_freq_samples_impl::close_() {
sigmf_record_t record =
create_sigmf(final_samples_path, open_time_, datatype_, samp_rate_,
last_rx_freq_, gain_);
d_logger->info("writing {} annotations", inference_q_.read_available());
// TODO: handle annotations for the rotate case.
while (!inference_q_.empty()) {
inference_item_type inference_item;
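For reference, each queued inference_item ultimately becomes a SigMF annotation in the recording's metadata. A hedged sketch of one such annotation segment built with nlohmann::json; the core:* keys follow the SigMF annotation spec, but exactly which keys create_sigmf() emits (and how description vs. label are mapped) is outside this diff:

// Hedged sketch: one SigMF-style annotation assembled from the fields an
// inference_item carries. Values are illustrative.
#include <nlohmann/json.hpp>
#include <iostream>

int main() {
  nlohmann::json annotation = {
      {"core:sample_start", 1700000000},      // inference_item.sample_start
      {"core:sample_count", 4096},            // inference_item.sample_count
      {"core:freq_lower_edge", 2400000000.0}, // last_rx_freq_ - sample_rate / 2
      {"core:freq_upper_edge", 2420480000.0}, // last_rx_freq_ + sample_rate / 2
      {"core:label", "wifi"}                  // inference_item.label
  };
  std::cout << annotation.dump(2) << std::endl;
  return 0;
}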
2 changes: 1 addition & 1 deletion lib/write_freq_samples_impl.h
@@ -217,7 +217,7 @@
namespace gr {
namespace iqtlabs {

#define MAX_ANNOTATIONS 128
#define MAX_ANNOTATIONS 1024

typedef struct inference_item {
COUNT_T sample_count;
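The annotation queue sized by MAX_ANNOTATIONS behaves like a fixed-capacity lock-free queue: push() fails when it is full, and read_available()/empty() report its occupancy. A hedged sketch of that pattern using boost::lockfree::spsc_queue; the actual member declaration is outside the visible hunk, so the concrete queue type is an assumption:

// Assumption: inference_q_ is a fixed-capacity boost lock-free queue, consistent
// with push() returning false when full and read_available() in close_().
#include <boost/lockfree/spsc_queue.hpp>
#include <cstdint>
#include <iostream>

struct inference_item_type {
  uint64_t sample_count; // stands in for COUNT_T; remaining members elided
};

int main() {
  boost::lockfree::spsc_queue<inference_item_type> q(1024); // MAX_ANNOTATIONS
  if (!q.push(inference_item_type{4096})) {
    std::cerr << "inference annotation queue full" << std::endl;
  }
  std::cout << "queued annotations: " << q.read_available() << std::endl;
  inference_item_type item;
  while (q.pop(item)) {
    // write each queued item out as an annotation
  }
  return 0;
}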
