Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

60 changes: 49 additions & 11 deletions crates/mqtt-client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@ use std::{
task::{Context, Poll},
};
use thiserror::Error;
use tracing::{field, instrument, span, Level, Span};

#[derive(Clone)]
pub struct Message {
pub topic: Arc<str>,
pub payload: Arc<[u8]>,
pub retained: bool,
pub span: Span,
}

impl Message {
Expand All @@ -34,6 +36,10 @@ impl Message {
pub fn retained(&self) -> bool {
self.retained
}

pub fn span(&self) -> &Span {
&self.span
}
}

#[derive(Clone)]
Expand All @@ -56,6 +62,7 @@ impl Stream for Subscription {

#[derive(Clone)]
pub struct HassMqttClient {
client_id: Arc<str>,
sender: flume::Sender<command::Command>,
}

Expand All @@ -65,8 +72,9 @@ impl HassMqttClient {
T: command::ClientCommand,
command::Command: command::FromClientCommand<T>,
{
let span = Span::current();
let cmd = Arc::new(cmd);
let (msg, receiver) = command::Command::from_command(cmd.clone());
let (msg, receiver) = command::Command::from_command(cmd.clone(), span);
self
.sender
.send_async(msg)
Expand Down Expand Up @@ -96,11 +104,20 @@ impl ConnectError {
}

impl HassMqttClient {
#[instrument(
level = Level::DEBUG,
name = "HassMqttClient::new",
skip_all,
fields(
provider.name = T::NAME,
)
err,
)]
pub async fn new<T: MqttProvider>(options: HassMqttOptions) -> Result<Self, ConnectError> {
let sender = InnerClient::spawn::<T>(options)
let (sender, client_id) = InnerClient::spawn::<T>(options)
.await
.map_err(ConnectError::new)?;
Ok(Self { sender })
Ok(Self { sender, client_id })
}
}

Expand All @@ -110,7 +127,12 @@ impl HassMqttClient {
domain: impl Into<Arc<str>>,
entity_id: impl Into<Arc<str>>,
) -> EntityTopicBuilder {
EntityTopicBuilder::new(self, domain, entity_id)
self._entity(domain.into(), entity_id.into())
}

fn _entity(&self, domain: Arc<str>, entity_id: Arc<str>) -> EntityTopicBuilder {
let span = span!(Level::DEBUG, "HassMqttClient::entity", client.id = %self.client_id, entity.domain = %domain, entity.id = %entity_id, entity.topic = field::Empty);
EntityTopicBuilder::new(self, domain, entity_id, span)
}
}

Expand All @@ -125,16 +147,24 @@ pub struct PublishMessageError {
}

impl HassMqttClient {
#[instrument(
level = Level::DEBUG,
name = "HassMqttClient::publish_message",
skip_all,
fields(
client.id = %self.client_id,
message.topic = %topic,
message.retained = retained,
message.qos = %qos,
message.payload.len = payload.len(),
))]
pub(crate) async fn publish_message(
&self,
topic: impl Into<Arc<str>>,
payload: impl Into<Arc<[u8]>>,
topic: Arc<str>,
payload: Arc<[u8]>,
retained: bool,
qos: QosLevel,
) -> Result<(), PublishMessageError> {
let topic = topic.into();
let payload = payload.into();

self
.command(command::publish(topic.clone(), payload, retained, qos))
.await
Expand All @@ -159,12 +189,20 @@ pub struct SubscribeError {
}

impl HassMqttClient {
#[instrument(
level = Level::DEBUG,
name = "HassMqttClient::subscribe",
skip_all,
fields(
client.id = %self.client_id,
subscription.topic = %topic,
subscription.qos,
))]
pub(crate) async fn subscribe(
&self,
topic: impl Into<Arc<str>>,
topic: Arc<str>,
qos: QosLevel,
) -> Result<Subscription, SubscribeError> {
let topic = topic.into();
let result = self
.command(command::subscribe(topic.clone(), qos))
.await
Expand Down
18 changes: 9 additions & 9 deletions crates/mqtt-client/src/client/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use async_trait::async_trait;
use hass_mqtt_provider::MqttClient;
use std::sync::Arc;
use tokio::sync::oneshot;
use tracing::{Instrument, Span};

pub(super) use entity::EntityCommand;
pub(super) use publish::PublishCommand;
Expand All @@ -32,7 +33,7 @@ pub(crate) type CommandResultSender<T> = oneshot::Sender<CommandResult<T>>;
pub(crate) type CommandResultReceiver<T> = oneshot::Receiver<CommandResult<T>>;

pub(crate) trait FromClientCommand<T: ClientCommand>: Sized {
fn from_command(command: Arc<T>) -> (Self, CommandResultReceiver<T>);
fn from_command(command: Arc<T>, span: Span) -> (Self, CommandResultReceiver<T>);
}

macro_rules! commands {
Expand All @@ -41,26 +42,26 @@ macro_rules! commands {
}) => {
#[allow(clippy::enum_variant_names)]
$vis enum $name {
$($variant(Arc<$variant>, CommandResultSender<$variant>),)*
$($variant(Arc<$variant>, CommandResultSender<$variant>, Span),)*
}

$(
impl FromClientCommand<$variant> for $name {
fn from_command(command: Arc<$variant>) -> (Self, CommandResultReceiver<$variant>) {
fn from_command(command: Arc<$variant>, span: Span) -> (Self, CommandResultReceiver<$variant>) {
let (tx, rx) = oneshot::channel();

(Self::$variant(command, tx), rx)
(Self::$variant(command, tx, span), rx)
}
}
)*

impl $name {
pub(super) fn from_command<T>(command: Arc<T>) -> (Self, CommandResultReceiver<T>)
pub(super) fn from_command<T>(command: Arc<T>, span: Span) -> (Self, CommandResultReceiver<T>)
where
T: ClientCommand,
Self: FromClientCommand<T>,
{
<Self as FromClientCommand<T>>::from_command(command)
<Self as FromClientCommand<T>>::from_command(command, span)
}

pub(super) async fn run<T: MqttClient>(
Expand All @@ -70,9 +71,8 @@ macro_rules! commands {
) {
match self {
$(
Self::$variant(command, tx) => {
// TODO: tracing
let result = command.run(client, mqtt).await;
Self::$variant(command, tx, span) => {
let result = command.run(client, mqtt).instrument(span).await;

if let Err(_err) = tx.send(result) {
// TODO: log
Expand Down
4 changes: 2 additions & 2 deletions crates/mqtt-client/src/client/command/publish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use super::{ClientCommand, InnerClient};
use crate::client::QosLevel;
use async_trait::async_trait;
use hass_dyn_error::DynError;
use hass_mqtt_provider::{MqttClient, MqttMessage, MqttMessageBuilder};
use hass_mqtt_provider::{MqttBuildableMessage, MqttClient, MqttMessageBuilder};
use std::sync::Arc;
use thiserror::Error;

Expand Down Expand Up @@ -44,7 +44,7 @@ impl ClientCommand for PublishCommand {
_client: &mut InnerClient,
mqtt: &T,
) -> Result<Self::Result, Self::Error> {
let msg = <T::Message as MqttMessage>::builder()
let msg = <T::Message as MqttBuildableMessage>::builder()
.topic(&*self.topic)
.payload(&*self.payload)
.retain(self.retained)
Expand Down
90 changes: 66 additions & 24 deletions crates/mqtt-client/src/client/inner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@ use crate::{
};
use futures::{pin_mut, StreamExt};
use hass_dyn_error::DynError;
use hass_mqtt_provider::{MqttClient, MqttMessage, MqttProvider};
use std::{thread, time::Duration};
use hass_mqtt_provider::{MqttClient, MqttMessage, MqttProvider, MqttReceivedMessage};
use std::{sync::Arc, thread, time::Duration};
use thiserror::Error;
use tokio::select;
use tracing::{field, instrument, span, Level, Span};

type RouteId = generational_arena::Index;

Expand Down Expand Up @@ -70,6 +71,15 @@ impl InnerClient {
}
}

#[instrument(
level = Level::DEBUG,
name = "InnerClient::run",
skip_all,
fields(
provider.name = %<T::Provider>::NAME,
client.id = %client.client_id(),
)
)]
async fn run<T: MqttClient>(mut self, client: T, receiver: flume::Receiver<Command>) {
// TODO: don't use the events helper, use select instead
let receiver = receiver.into_stream().fuse();
Expand Down Expand Up @@ -101,17 +111,23 @@ impl InnerClient {
cmd.run(self, client).await
}

async fn handle_message<T: MqttClient>(&mut self, msg: T::Message, _client: &T) {
async fn handle_message<T: MqttClient>(&mut self, msg: MqttReceivedMessage<T>, _client: &T) {
let client_span = Span::current();

let topic = msg.topic();
let matches = self.router.matches(topic);
if matches.len() == 0 {
return;
}

let message_span = msg.span().clone();
message_span.follows_from(client_span);

let message = Message {
topic: topic.into(),
payload: msg.payload().into(),
retained: msg.retained(),
span: message_span,
};

let mut to_remove = Vec::new();
Expand All @@ -126,14 +142,35 @@ impl InnerClient {
}
}

#[instrument(
level = Level::DEBUG,
name = "InnerClient::spawn"
skip_all,
fields(
provider.name = %P::NAME,
)
)]
pub(super) async fn spawn<P: MqttProvider>(
options: HassMqttOptions,
) -> Result<flume::Sender<Command>, ConnectError> {
) -> Result<(flume::Sender<Command>, Arc<str>), ConnectError> {
let spawn_span = Span::current().id();
let (result_sender, result_receiver) = tokio::sync::oneshot::channel();

thread::Builder::new()
.name(format!("mqtt-{}-hass", options.application_name.slug()))
.spawn(move || {
let span = {
let span = span!(
parent: None,
Level::DEBUG,
"InnerClient::thread",
provider.name = %P::NAME,
client.id = field::Empty,
);
span.follows_from(spawn_span);
span.entered()
};

let (sender, receiver) = flume::unbounded();
let rt = match tokio::runtime::Builder::new_current_thread()
.build()
Expand All @@ -147,29 +184,34 @@ impl InnerClient {
};

let guard = rt.enter();
rt.block_on(async move {
let HassMqttConnection {
topics,
client: mqtt_client,
} = match <P as MqttProviderExt>::create_client(&options)
.await
.map_err(ConnectError::connect)
{
Ok(c) => c,
Err(e) => {
let _ = result_sender.send(Err(e));
return;
}
};

let client = InnerClient::new(topics);

let _ = result_sender.send(Ok(sender));
client.run(mqtt_client, receiver).await;
rt.block_on({
let span = &span;
async move {
let HassMqttConnection {
topics,
client: mqtt_client,
client_id,
} = match <P as MqttProviderExt>::create_client(&options)
.await
.map_err(ConnectError::connect)
{
Ok(c) => c,
Err(e) => {
let _ = result_sender.send(Err(e));
return;
}
};

span.record("client.id", &client_id);
let client = InnerClient::new(topics);

let _ = result_sender.send(Ok((sender, client_id.into())));
client.run(mqtt_client, receiver).await;
}
});

// ensure it lives til this point
drop(guard);
drop((guard, span));
})
.map_err(ConnectError::spawn_thread)?;

Expand Down
Loading