Merge pull request #202 from anarkiwi/moreiq
test/workarounds for I/Q inference
anarkiwi committed Feb 1, 2024
2 parents 8a713cb + 2870f38 commit a11b1f2
Showing 4 changed files with 140 additions and 45 deletions.
2 changes: 2 additions & 0 deletions lib/image_inference_impl.cc
@@ -577,6 +577,7 @@ void image_inference_impl::run_inference_() {

// attempt to re-use the existing connection. may fail if an http 1.1 server
// has dropped the connection we want to reuse in the meantime.
// TODO: handle case where model server is up but blocks us forever.
if (inference_connected_) {
try {
boost::beast::flat_buffer buffer;
@@ -677,6 +678,7 @@ int image_inference_impl::general_work(int noutput_items,
const auto rel = tag.offset - in_first;
in_first += rel;

// TODO: process leftover untagged items.
if (rel > 0) {
process_items_(rel, in);
}
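The general_work hunk above (and the matching loops in iq_inference_impl.cc and retune_fft_impl.cc below) walks the stream in tag-delimited chunks; the new TODO notes that items after the final rx_freq tag are not processed in the same call. A minimal sketch of that pattern, with illustrative names rather than the block's actual members:

// Sketch of processing a stream in tag-delimited chunks. "tags" is assumed
// to be sorted by offset; items after the last tag are deliberately left
// for a later pass (the TODO added in this commit).
#include <cstddef>
#include <cstdint>
#include <vector>

struct stream_tag { uint64_t offset; };

template <typename Process, typename OnTag>
void split_on_tags(const std::vector<stream_tag> &tags, uint64_t in_first,
                   const float *in, Process process_items, OnTag on_tag) {
  for (const auto &t : tags) {
    // Items before this tag still belong to the previous tuning.
    const uint64_t rel = t.offset - in_first;
    if (rel > 0) {
      process_items(rel, in); // consume items for the old frequency
      in += rel;
    }
    in_first += rel;
    on_tag(t); // note the new rx_freq before continuing
  }
  // Items after the last tag are not processed here (see the TODO above).
}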
52 changes: 26 additions & 26 deletions lib/iq_inference_impl.cc
@@ -311,9 +311,9 @@ void iq_inference_impl::run_inference_() {
size_t rendered_predictions = 0;

for (auto model_name : model_names_) {
const std::string_view body(reinterpret_cast<char const *>(
output_item.samples,
output_item.sample_count * sizeof(gr_complex)));
const std::string_view body(
reinterpret_cast<char const *>(output_item.samples),
output_item.sample_count * sizeof(gr_complex));
boost::beast::http::request<boost::beast::http::string_body> req{
boost::beast::http::verb::post, "/predictions/" + model_name, 11};
req.keep_alive(true);
@@ -326,9 +326,12 @@
req.body() = body;
req.prepare_payload();
std::string results;
// TODO: troubleshoot test flask server hang after one request.
inference_connected_ = false;

// attempt to re-use the existing connection. may fail if an http 1.1 server
// has dropped the connection we want to reuse in the meantime.
// TODO: handle case where model server is up but blocks us forever.
if (inference_connected_) {
try {
boost::beast::flat_buffer buffer;
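For context, a sketch of the reuse-then-reconnect flow the surrounding comments describe; the `inference_connected_ = false;` workaround above simply forces the reconnect path every time until the test-server hang is understood. The helper and its signature are illustrative assumptions, not the block's actual interface:

// Sketch: POST a prepared request, preferring the kept-alive socket and
// reconnecting once if the server has dropped the idle connection.
#include <boost/asio/connect.hpp>
#include <boost/asio/io_context.hpp>
#include <boost/asio/ip/tcp.hpp>
#include <boost/beast/core.hpp>
#include <boost/beast/http.hpp>
#include <string>

namespace http = boost::beast::http;
using tcp = boost::asio::ip::tcp;

http::response<http::string_body>
post_with_reuse(boost::asio::io_context &ioc, tcp::socket &sock,
                bool &connected, const std::string &host,
                const std::string &port,
                http::request<http::string_body> &req) {
  req.prepare_payload();
  http::response<http::string_body> res;
  if (connected) {
    try {
      boost::beast::flat_buffer buffer;
      http::write(sock, req);
      http::read(sock, buffer, res);
      return res;
    } catch (std::exception &) {
      // The server likely dropped the idle HTTP/1.1 connection; fall through.
      connected = false;
    }
  }
  tcp::resolver resolver(ioc);
  boost::asio::connect(sock, resolver.resolve(host, port));
  connected = true;
  boost::beast::flat_buffer buffer;
  http::write(sock, req);
  http::read(sock, buffer, res);
  return res;
}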
@@ -377,7 +380,8 @@ void iq_inference_impl::run_inference_() {
for (auto &prediction_ref : prediction_class.value().items()) {
auto prediction = prediction_ref.value();
prediction["model"] = model_name;
float conf = prediction["conf"];
// TODO: gate on minimum confidence.
// float conf = prediction["conf"];
results_json[prediction_class.key()].emplace_back(prediction);
}
}
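The TODO above defers gating predictions on a minimum confidence. One way that gate could look with the same JSON API (the threshold parameter is an assumption for illustration):

// Sketch: keep only predictions whose "conf" meets a minimum threshold.
#include <nlohmann/json.hpp>
#include <string>

nlohmann::json filter_predictions(const nlohmann::json &predictions,
                                  const std::string &model_name,
                                  float min_confidence) {
  nlohmann::json results;
  for (const auto &prediction_class : predictions.items()) {
    for (const auto &prediction_ref : prediction_class.value().items()) {
      auto prediction = prediction_ref.value();
      prediction["model"] = model_name;
      // value() tolerates a missing "conf" field by falling back to 0.
      const float conf = prediction.value("conf", 0.0f);
      if (conf >= min_confidence) {
        results[prediction_class.key()].emplace_back(prediction);
      }
    }
  }
  return results;
}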
@@ -437,10 +441,6 @@ void iq_inference_impl::process_items_(size_t power_in_count,
delete_output_item_(output_item);
d_logger->error("inference queue full");
}
// volk_32f_accumulator_s32f(
// total_.get(), (const float *)&samples_lookback_[j * vlen_], vlen_ *
// 2);
// d_logger->info("max: {}, total: {}, {}", power_max, *total_, j);
}
}

@@ -456,6 +456,23 @@ int iq_inference_impl::general_work(int noutput_items,
const float *power_in = static_cast<const float *>(input_items[1]);
std::vector<tag_t> all_tags, rx_freq_tags;
std::vector<double> rx_times;
size_t leftover = 0;

while (!json_q_.empty()) {
std::string json;
json_q_.pop(json);
out_buf_.insert(out_buf_.end(), json.begin(), json.end());
}

if (!out_buf_.empty()) {
auto out = static_cast<char *>(output_items[0]);
leftover = std::min(out_buf_.size(), (size_t)noutput_items);
auto from = out_buf_.begin();
auto to = from + leftover;
std::copy(from, to, out);
out_buf_.erase(from, to);
}

get_tags_in_window(all_tags, 1, 0, power_in_count);
get_tags(tag_, all_tags, rx_freq_tags, rx_times, power_in_count);

@@ -475,6 +492,7 @@ int iq_inference_impl::general_work(int noutput_items,
const auto rel = tag.offset - in_first;
in_first += rel;

// TODO: process leftover untagged items.
if (rel > 0) {
process_items_(rel, power_read, power_in);
}
@@ -488,24 +506,6 @@ int iq_inference_impl::general_work(int noutput_items,

consume(0, samples_in_count);
consume(1, power_in_count);

size_t leftover = 0;

while (!json_q_.empty()) {
std::string json;
json_q_.pop(json);
out_buf_.insert(out_buf_.end(), json.begin(), json.end());
}

if (!out_buf_.empty()) {
auto out = static_cast<char *>(output_items[0]);
leftover = std::min(out_buf_.size(), (size_t)noutput_items);
auto from = out_buf_.begin();
auto to = from + leftover;
std::copy(from, to, out);
out_buf_.erase(from, to);
}

return leftover;
}
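The reshuffled general_work above now drains queued JSON into out_buf_ before anything else and copies at most noutput_items bytes to the output port, returning that count. A simplified sketch of the bounded drain, using standard containers in place of the block's members:

// Sketch: drain queued JSON strings into a staging buffer, then emit at
// most "space" bytes per call, carrying the remainder to the next call.
#include <algorithm>
#include <deque>
#include <queue>
#include <string>

size_t drain_output(std::queue<std::string> &json_q, std::deque<char> &out_buf,
                    char *out, size_t space) {
  while (!json_q.empty()) {
    const std::string &json = json_q.front();
    out_buf.insert(out_buf.end(), json.begin(), json.end());
    json_q.pop();
  }
  const size_t n = std::min(out_buf.size(), space);
  std::copy(out_buf.begin(), out_buf.begin() + n, out);
  out_buf.erase(out_buf.begin(), out_buf.begin() + n);
  return n; // number of output items produced ("leftover" in the block)
}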

27 changes: 16 additions & 11 deletions lib/retune_fft_impl.cc
@@ -280,17 +280,21 @@ void retune_fft_impl::reset_items_() {
peak_[i] = fft_min_;
}
sample_count_ = 0;
fft_count_ = 0;
}

void retune_fft_impl::calc_peaks_() {
volk_32f_s32f_multiply_32f(mean_.get(), (const float *)sample_.get(),
1 / float(sample_count_), nfft_);
for (size_t k = 0; k < nfft_; ++k) {
float mean = std::min(std::max(mean_[k], fft_min_), fft_max_);
peak_[k] = std::max(mean, peak_[k]);
sample_[k] = 0;
if (sample_count_) {
volk_32f_s32f_multiply_32f(mean_.get(), (const float *)sample_.get(),
1 / float(sample_count_), nfft_);
for (size_t k = 0; k < nfft_; ++k) {
float mean = std::min(std::max(mean_[k], fft_min_), fft_max_);
peak_[k] = std::max(mean, peak_[k]);
sample_[k] = 0;
}

sample_count_ = 0;
}
sample_count_ = 0;
}

retune_fft_impl::~retune_fft_impl() { close_(); }
@@ -342,7 +346,7 @@ void retune_fft_impl::write_items_(const input_type *in) {
void retune_fft_impl::sum_items_(const input_type *in) {
volk_32f_x2_add_32f(sample_.get(), (const float *)sample_.get(), in, nfft_);
++sample_count_;
if (peak_fft_range_ && sample_count_ && sample_count_ == peak_fft_range_) {
if (peak_fft_range_ && sample_count_ == peak_fft_range_) {
calc_peaks_();
}
}
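calc_peaks_ now runs only when sample_count_ is nonzero, so the volk scale factor 1 / sample_count_ can no longer divide by zero when sum_items_ (or a retune) triggers it before any FFTs were accumulated. A reduced sketch of the guarded mean and peak-hold step:

// Sketch: average accumulated FFT bins, clamp, and fold into a peak hold.
// Skipping the work when count == 0 avoids a divide-by-zero scale factor.
#include <algorithm>
#include <cstddef>
#include <volk/volk.h>

void calc_peaks(float *sample, float *mean, float *peak, size_t nfft,
                size_t &count, float fft_min, float fft_max) {
  if (!count) {
    return; // nothing accumulated since the last reset
  }
  volk_32f_s32f_multiply_32f(mean, sample, 1.0f / float(count), nfft);
  for (size_t k = 0; k < nfft; ++k) {
    const float m = std::min(std::max(mean[k], fft_min), fft_max);
    peak[k] = std::max(m, peak[k]);
    sample[k] = 0.0f; // restart accumulation for the next window
  }
  count = 0;
}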
@@ -386,6 +390,7 @@ void retune_fft_impl::process_items_(size_t c, const input_type *&in,
fft_output += nfft_;
write_items_(in);
sum_items_(in);
++fft_count_;
++produced;
if (need_retune_(1)) {
if (!pre_fft_) {
@@ -480,15 +485,14 @@ void retune_fft_impl::write_buckets_(double host_now, uint64_t rx_freq) {
}

void retune_fft_impl::process_buckets_(uint64_t rx_freq, double rx_time) {
if (last_rx_freq_ && sample_count_) {
if (last_rx_freq_ && fft_count_) {
reopen_(rx_time, rx_freq);
if (sample_count_) {
if (!peak_fft_range_) {
calc_peaks_();
}
write_buckets_(rx_time, rx_freq);
}
reset_items_();
fft_count_ = 0;
skip_fft_count_ = skip_tune_step_fft_;
write_step_fft_count_ = write_step_fft_;
last_rx_freq_ = rx_freq;
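process_buckets_ now gates on fft_count_, which only reset_items_ clears, rather than sample_count_, which calc_peaks_ may have just zeroed when peak_fft_range_ is in use. A toy sketch of why the two counters diverge (the structure is illustrative; only the member names mirror the diff):

// Sketch: sample_count resets every peak window; fft_count only resets at
// retune time, so it still says whether any FFTs arrived at this frequency.
#include <cstddef>

struct scan_counters {
  size_t sample_count = 0; // FFTs summed since the last calc_peaks_()
  size_t fft_count = 0;    // FFTs seen since the last reset_items_()
};

void on_fft(scan_counters &c, size_t peak_fft_range) {
  ++c.sample_count;
  ++c.fft_count;
  if (peak_fft_range && c.sample_count == peak_fft_range) {
    c.sample_count = 0; // calc_peaks_() zeroes this mid-scan
  }
}

bool should_write_buckets(const scan_counters &c, bool have_last_freq) {
  // Gating on fft_count avoids skipping a write when calc_peaks_() reset
  // sample_count just before the retune arrived.
  return have_last_freq && c.fft_count != 0;
}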
@@ -513,6 +517,7 @@ void retune_fft_impl::process_tags_(const input_type *in, size_t in_count,
const auto rel = tag.offset - in_first;
in_first += rel;

// TODO: process leftover untagged items.
if (rel > 0) {
process_items_(rel, in, fft_output, produced);
}
104 changes: 96 additions & 8 deletions python/iqtlabs/qa_iq_inference.py
@@ -203,33 +203,121 @@
# limitations under the License.
#

import concurrent.futures
import json
import os
import pmt
import time
import tempfile
from flask import Flask, request
from gnuradio import gr, gr_unittest
from gnuradio import analog, blocks

# from gnuradio import blocks
try:
from gnuradio.iqtlabs import iq_inference
from gnuradio.iqtlabs import iq_inference, retune_pre_fft, tuneable_test_source
except ImportError:
import os
import sys

dirname, filename = os.path.split(os.path.abspath(__file__))
sys.path.append(os.path.join(dirname, "bindings"))
from gnuradio.iqtlabs import iq_inference
from gnuradio.iqtlabs import iq_inference, retune_pre_fft, tuneable_test_source


class qa_iq_inference(gr_unittest.TestCase):

def setUp(self):
self.tb = gr.top_block()
self.pid = os.fork()

def tearDown(self):
self.tb = None
if self.pid:
os.kill(self.pid, 15)

def test_instance(self):
instance = iq_inference(
"rx_freq", 1024, 512, 0.1, "server", "model", 0.1, 0, int(20e6)
def simulate_torchserve(self, port, model_name, result):
app = Flask(__name__)

# nosemgrep:github.workflows.config.useless-inner-function
@app.route(f"/predictions/{model_name}", methods=["POST"])
def predictions_test():
print("got %s, count %u" % (type(request.data), len(request.data)))
return json.dumps(result, indent=2), 200

try:
app.run(host="127.0.0.1", port=port)
except RuntimeError:
return

def run_flowgraph(self, tmpdir, fft_size, samp_rate, port, model_name):
test_file = os.path.join(tmpdir, "samples")
freq_divisor = 1e9
new_freq = 1e9 / 2
delay = 500

source = tuneable_test_source(freq_divisor)
strobe = blocks.message_strobe(pmt.to_pmt({"freq": new_freq}), delay)
throttle = blocks.throttle(gr.sizeof_gr_complex, samp_rate, True)
fs = blocks.file_sink(gr.sizeof_char, os.path.join(tmpdir, test_file), False)
c2r = blocks.complex_to_real(1)
stream2vector_power = blocks.stream_to_vector(gr.sizeof_float, fft_size)
stream2vector_samples = blocks.stream_to_vector(gr.sizeof_gr_complex, fft_size)

iq_inf = iq_inference(
"rx_freq", fft_size, 512, -1e9, f"localhost:{port}", model_name, 0.8, 10001, int(samp_rate)
)
instance.stop()

self.tb.msg_connect((strobe, "strobe"), (source, "cmd"))
self.tb.connect((source, 0), (throttle, 0))
self.tb.connect((throttle, 0), (c2r, 0))
self.tb.connect((throttle, 0), (stream2vector_samples, 0))
self.tb.connect((c2r, 0), (stream2vector_power, 0))
self.tb.connect((stream2vector_samples, 0), (iq_inf, 0))
self.tb.connect((stream2vector_power, 0), (iq_inf, 1))
self.tb.connect((iq_inf, 0), (fs, 0))

self.tb.start()
test_time = 10
time.sleep(test_time)
self.tb.stop()
self.tb.wait()
return test_file

def test_bad_instance(self):
port = 11002
model_name = "testmodel"
predictions_result = ["cant", "parse", {"this": 0}]
if self.pid == 0:
self.simulate_torchserve(port, model_name, predictions_result)
return
fft_size = 1024
samp_rate = 4e6
with tempfile.TemporaryDirectory() as tmpdir:
self.run_flowgraph(tmpdir, fft_size, samp_rate, port, model_name)

def test_instance(self):
port = 11001
model_name = "testmodel"
px = 100
predictions_result = {"modulation": [{"conf": 0.9}]}
if self.pid == 0:
self.simulate_torchserve(port, model_name, predictions_result)
return
fft_size = 1024
samp_rate = 4e6
with tempfile.TemporaryDirectory() as tmpdir:
test_file = self.run_flowgraph(
tmpdir, fft_size, samp_rate, port, model_name
)
self.assertTrue(os.stat(test_file).st_size)
with open(test_file) as f:
content = f.read()
json_raw_all = content.split("\n\n")
self.assertTrue(json_raw_all)
for json_raw in json_raw_all:
if not json_raw:
continue
result = json.loads(json_raw)
print(result)


if __name__ == "__main__":
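The new test reads the file sink back, splits the capture on blank lines, and parses each chunk, which documents the block's output framing: JSON records separated by "\n\n". A sketch of the same consumer logic in C++ (nlohmann::json is assumed here, matching the JSON API used in run_inference_; the capture is assumed to end on a record boundary):

// Sketch: split a captured output stream on blank-line separators and parse
// each chunk as JSON, mirroring what qa_iq_inference.py does in Python.
#include <fstream>
#include <iostream>
#include <nlohmann/json.hpp>
#include <sstream>
#include <string>

int main(int argc, char **argv) {
  if (argc < 2) {
    return 1;
  }
  std::ifstream f(argv[1], std::ios::binary);
  std::stringstream ss;
  ss << f.rdbuf();
  const std::string content = ss.str();

  size_t start = 0;
  while (start < content.size()) {
    size_t end = content.find("\n\n", start);
    if (end == std::string::npos) {
      end = content.size();
    }
    const std::string chunk = content.substr(start, end - start);
    if (!chunk.empty()) {
      const auto record = nlohmann::json::parse(chunk);
      std::cout << record.dump(2) << std::endl;
    }
    start = end + 2;
  }
  return 0;
}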
