Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AMQP changes for C++ Claim Based Security - also fixed session close issue that affected reliability #1833

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 12 additions & 14 deletions sdk/core/azure_core_amqp/src/cbs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ use std::fmt::Debug;
use super::session::AmqpSession;

#[cfg(all(feature = "fe2o3-amqp", not(target_arch = "wasm32")))]
type CbsImplementation = super::fe2o3::cbs::Fe2o3ClaimsBasedSecurity;
type CbsImplementation<'a> = super::fe2o3::cbs::Fe2o3ClaimsBasedSecurity<'a>;

#[cfg(any(not(any(feature = "fe2o3-amqp")), target_arch = "wasm32"))]
type CbsImplementation = super::noop::NoopAmqpClaimsBasedSecurity;
type CbsImplementation<'a> = super::noop::NoopAmqpClaimsBasedSecurity<'a>;

pub trait AmqpClaimsBasedSecurityApis {
/// Asynchronously attaches the Claims-Based Security (CBS) node to the AMQP session.
Expand Down Expand Up @@ -52,11 +52,18 @@ pub trait AmqpClaimsBasedSecurityApis {
}

#[derive(Debug)]
pub struct AmqpClaimsBasedSecurity {
implementation: CbsImplementation,
pub struct AmqpClaimsBasedSecurity<'a> {
implementation: CbsImplementation<'a>,
}

impl AmqpClaimsBasedSecurityApis for AmqpClaimsBasedSecurity {
impl<'a> AmqpClaimsBasedSecurity<'a> {
pub fn new(session: &'a AmqpSession) -> Result<Self> {
Ok(Self {
implementation: CbsImplementation::new(session)?,
})
}
}
impl<'a> AmqpClaimsBasedSecurityApis for AmqpClaimsBasedSecurity<'a> {
async fn authorize_path(
&self,
path: impl Into<String> + Debug,
Expand All @@ -67,16 +74,7 @@ impl AmqpClaimsBasedSecurityApis for AmqpClaimsBasedSecurity {
.authorize_path(path, secret, expires_on)
.await
}

async fn attach(&self) -> Result<()> {
self.implementation.attach().await
}
}

impl AmqpClaimsBasedSecurity {
pub fn new(session: AmqpSession) -> Result<Self> {
Ok(Self {
implementation: CbsImplementation::new(session)?,
})
}
}
24 changes: 11 additions & 13 deletions sdk/core/azure_core_amqp/src/fe2o3/cbs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,38 +9,36 @@ use azure_core::error::Result;
use fe2o3_amqp_cbs::token::CbsToken;
use fe2o3_amqp_types::primitives::Timestamp;
use std::borrow::BorrowMut;
use std::{
fmt::Debug,
sync::{Arc, OnceLock},
};
use std::{fmt::Debug, sync::OnceLock};
use tracing::{debug, trace};

#[derive(Debug)]
pub(crate) struct Fe2o3ClaimsBasedSecurity {
pub(crate) struct Fe2o3ClaimsBasedSecurity<'a> {
cbs: OnceLock<Mutex<fe2o3_amqp_cbs::client::CbsClient>>,
session: Arc<Mutex<fe2o3_amqp::session::SessionHandle<()>>>,
session: &'a AmqpSession,
}

impl Fe2o3ClaimsBasedSecurity {
pub fn new(session: AmqpSession) -> Result<Self> {
impl<'a> Fe2o3ClaimsBasedSecurity<'a> {
pub fn new(session: &'a AmqpSession) -> Result<Self> {
Ok(Self {
cbs: OnceLock::new(),
session: session.implementation.get()?,
session,
})
}
}

impl Fe2o3ClaimsBasedSecurity {}
impl<'a> Fe2o3ClaimsBasedSecurity<'a> {}

impl Drop for Fe2o3ClaimsBasedSecurity {
impl<'a> Drop for Fe2o3ClaimsBasedSecurity<'a> {
fn drop(&mut self) {
debug!("Dropping Fe2o3ClaimsBasedSecurity.");
}
}

impl AmqpClaimsBasedSecurityApis for Fe2o3ClaimsBasedSecurity {
impl<'a> AmqpClaimsBasedSecurityApis for Fe2o3ClaimsBasedSecurity<'a> {
async fn attach(&self) -> Result<()> {
let mut session = self.session.lock().await;
let session = self.session.implementation.get()?;
let mut session = session.lock().await;
let cbs_client = fe2o3_amqp_cbs::client::CbsClient::builder()
.client_node_addr("rust_amqp_cbs")
.attach(session.borrow_mut())
Expand Down
88 changes: 41 additions & 47 deletions sdk/core/azure_core_amqp/src/fe2o3/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,61 +51,55 @@ impl AmqpConnectionApis for Fe2o3AmqpConnection {
.container_id(id)
.max_frame_size(65536);

let options = options.ok_or_else(|| {
azure_core::Error::new(
azure_core::error::ErrorKind::Other,
"Connection options are not set.",
)
})?;

if let Some(frame_size) = options.max_frame_size {
builder = builder.max_frame_size(frame_size);
}
if let Some(options) = options {
if let Some(frame_size) = options.max_frame_size {
builder = builder.max_frame_size(frame_size);
}

if let Some(channel_max) = options.channel_max {
builder = builder.channel_max(channel_max);
}
if let Some(idle_timeout) = options.idle_timeout {
builder = builder.idle_time_out(idle_timeout.whole_milliseconds() as u32);
}
if let Some(outgoing_locales) = options.outgoing_locales.as_ref() {
for locale in outgoing_locales {
builder = builder.add_outgoing_locales(locale.as_str());
if let Some(channel_max) = options.channel_max {
builder = builder.channel_max(channel_max);
}
}
if let Some(incoming_locales) = options.incoming_locales {
for locale in incoming_locales {
builder = builder.add_incoming_locales(locale.as_str());
if let Some(idle_timeout) = options.idle_timeout {
builder = builder.idle_time_out(idle_timeout.whole_milliseconds() as u32);
}
}
if let Some(offered_capabilities) = options.offered_capabilities.as_ref() {
for capability in offered_capabilities {
let capability: fe2o3_amqp_types::primitives::Symbol =
capability.clone().into();
builder = builder.add_offered_capabilities(capability);
if let Some(outgoing_locales) = options.outgoing_locales.as_ref() {
for locale in outgoing_locales {
builder = builder.add_outgoing_locales(locale.as_str());
}
}
}
if let Some(desired_capabilities) = options.desired_capabilities.as_ref() {
for capability in desired_capabilities {
let capability: fe2o3_amqp_types::primitives::Symbol =
capability.clone().into();
builder = builder.add_desired_capabilities(capability);
if let Some(incoming_locales) = options.incoming_locales {
for locale in incoming_locales {
builder = builder.add_incoming_locales(locale.as_str());
}
}
}
if let Some(properties) = options.properties.as_ref() {
let mut fields = fe2o3_amqp::types::definitions::Fields::new();
for property in properties.iter() {
let k = fe2o3_amqp_types::primitives::Symbol::from(property.0);
let v = fe2o3_amqp_types::primitives::Value::from(property.1);
if let Some(offered_capabilities) = options.offered_capabilities.as_ref() {
for capability in offered_capabilities {
let capability: fe2o3_amqp_types::primitives::Symbol =
capability.clone().into();
builder = builder.add_offered_capabilities(capability);
}
}
if let Some(desired_capabilities) = options.desired_capabilities.as_ref() {
for capability in desired_capabilities {
let capability: fe2o3_amqp_types::primitives::Symbol =
capability.clone().into();
builder = builder.add_desired_capabilities(capability);
}
}
if let Some(properties) = options.properties.as_ref() {
let mut fields = fe2o3_amqp::types::definitions::Fields::new();
for property in properties.iter() {
let k = fe2o3_amqp_types::primitives::Symbol::from(property.0);
let v = fe2o3_amqp_types::primitives::Value::from(property.1);

fields.insert(k, v);
fields.insert(k, v);
}
builder = builder.properties(fields);
}
if let Some(buffer_size) = options.buffer_size {
builder = builder.buffer_size(buffer_size);
}
builder = builder.properties(fields);
}
if let Some(buffer_size) = options.buffer_size {
builder = builder.buffer_size(buffer_size);
}

self.connection
.set(Mutex::new(builder.open(url).await.map_err(AmqpOpen::from)?))
.map_err(|_| {
Expand Down
15 changes: 10 additions & 5 deletions sdk/core/azure_core_amqp/src/noop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use super::{
value::{AmqpOrderedMap, AmqpSymbol, AmqpValue},
};
use azure_core::{credentials::AccessToken, error::Result};
use std::marker::PhantomData;

#[derive(Debug, Default)]
pub(crate) struct NoopAmqpConnection {}
Expand All @@ -31,7 +32,9 @@ pub(crate) struct NoopAmqpReceiver {}
pub(crate) struct NoopAmqpSession {}

#[derive(Debug, Default)]
pub(crate) struct NoopAmqpClaimsBasedSecurity {}
pub(crate) struct NoopAmqpClaimsBasedSecurity<'a> {
phantom: PhantomData<&'a AmqpSession>,
}

impl NoopAmqpConnection {
pub fn new() -> Self {
Expand Down Expand Up @@ -81,13 +84,15 @@ impl AmqpSessionApis for NoopAmqpSession {
}
}

impl NoopAmqpClaimsBasedSecurity {
pub fn new(session: AmqpSession) -> Result<Self> {
Ok(Self {})
impl<'a> NoopAmqpClaimsBasedSecurity<'a> {
pub fn new(session: &'a AmqpSession) -> Result<Self> {
Ok(Self {
phantom: PhantomData,
})
}
}

impl AmqpClaimsBasedSecurityApis for NoopAmqpClaimsBasedSecurity {
impl<'a> AmqpClaimsBasedSecurityApis for NoopAmqpClaimsBasedSecurity<'a> {
async fn attach(&self) -> Result<()> {
unimplemented!();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ impl ConsumerClient {
let session = AmqpSession::new();
session.begin(connection, None).await?;

let cbs = AmqpClaimsBasedSecurity::new(session)?;
let cbs = AmqpClaimsBasedSecurity::new(&session)?;
cbs.attach().await?;

debug!("Get Token.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ impl ProducerClient {
let session = AmqpSession::new();
session.begin(connection, None).await?;

let cbs = AmqpClaimsBasedSecurity::new(session)?;
let cbs = AmqpClaimsBasedSecurity::new(&session)?;
cbs.attach().await?;

debug!("Get Token.");
Expand Down