Skip to content

Commit 02d3d77

Browse files
authored
Simplify process_task_inner in worker (#934)
## Motivation This is a functional NOOP refactor the worker, that largely aims to - Remove uses of `unreachable!` macro - Remove unnecessary messagedestination fetch attempt in dispatch function - Just generally simplify things, remove code, etc. ## Solution - This mostly refactors `process_queue_task_inner` so that the task type is `match`ed on only once. Overall this cuts down on code duplication too.
2 parents ef78d66 + 5c03733 commit 02d3d77

File tree

1 file changed

+78
-105
lines changed

1 file changed

+78
-105
lines changed

server/svix-server/src/worker.rs

Lines changed: 78 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,7 @@ use crate::core::webhook_http_client::{
1919
};
2020
use crate::db::models::{endpoint, message, messageattempt, messagedestination};
2121
use crate::error::{Error, ErrorType, HttpError, Result};
22-
use crate::queue::{
23-
MessageTask, MessageTaskBatch, QueueTask, TaskQueueConsumer, TaskQueueProducer,
24-
};
22+
use crate::queue::{MessageTask, QueueTask, TaskQueueConsumer, TaskQueueProducer};
2523
use crate::v1::utils::get_unix_timestamp;
2624
use crate::{ctx, err_cache, err_generic, err_validation};
2725

@@ -35,7 +33,8 @@ use rand::Rng;
3533

3634
use sea_orm::prelude::DateTimeUtc;
3735
use sea_orm::{
38-
ActiveModelTrait, ColumnTrait, DatabaseConnection, EntityTrait, QueryFilter, Set, TryIntoModel,
36+
ActiveModelBehavior, ActiveModelTrait, ColumnTrait, DatabaseConnection, EntityTrait,
37+
QueryFilter, Set, TryIntoModel,
3938
};
4039
use serde::{Deserialize, Serialize};
4140
use tokio::time::sleep;
@@ -644,24 +643,12 @@ async fn dispatch_message_task(
644643
msg_task: MessageTask,
645644
payload: &str,
646645
endp: CreateMessageEndpoint,
647-
msg_dest: Option<messagedestination::Model>,
646+
msg_dest: messagedestination::Model,
648647
) -> Result<()> {
649-
let WorkerContext { cfg, db, .. } = worker_context;
648+
let WorkerContext { cfg, .. } = worker_context;
650649

651650
tracing::trace!("Dispatch start");
652651

653-
let msg_dest = if let Some(msg_dest) = msg_dest {
654-
msg_dest
655-
} else {
656-
ctx!(
657-
messagedestination::Entity::secure_find_by_msg(msg_task.msg_id.clone())
658-
.filter(messagedestination::Column::EndpId.eq(endp.id.clone()))
659-
.one(*db)
660-
.await
661-
)?
662-
.ok_or_else(|| err_generic!("Msg dest not found {} {}", msg_task.msg_id, endp.id))?
663-
};
664-
665652
if (msg_dest.status != MessageStatus::Pending && msg_dest.status != MessageStatus::Sending)
666653
&& (msg_task.trigger_type != MessageAttemptTriggerType::Manual)
667654
{
@@ -735,32 +722,50 @@ async fn process_queue_task_inner(
735722
queue_task: QueueTask,
736723
) -> Result<()> {
737724
let WorkerContext { db, cache, .. }: WorkerContext<'_> = worker_context;
725+
let span = tracing::Span::current();
738726

739-
if queue_task == QueueTask::HealthCheck {
740-
return Ok(());
741-
}
727+
let (msg, force_endpoint, destination, trigger_type, attempt_count) = match queue_task {
728+
QueueTask::HealthCheck => return Ok(()),
729+
QueueTask::MessageV1(task) => {
730+
let msg = ctx!(
731+
message::Entity::find_by_id(task.msg_id.clone())
732+
.one(db)
733+
.await
734+
)?
735+
.ok_or_else(|| err_generic!("Unexpected: message doesn't exist"))?;
742736

743-
let span = tracing::Span::current();
737+
let destination = ctx!(
738+
messagedestination::Entity::secure_find_by_msg(task.msg_id.clone())
739+
.filter(messagedestination::Column::EndpId.eq(task.endpoint_id.clone()))
740+
.one(db)
741+
.await
742+
)?
743+
.ok_or_else(|| {
744+
err_generic!(format!(
745+
"MessageDestination not found for message {}",
746+
&task.msg_id
747+
))
748+
})?;
744749

745-
let (msg_id, trigger_type) = match &queue_task {
746-
QueueTask::MessageBatch(MessageTaskBatch {
747-
msg_id,
748-
trigger_type,
749-
..
750-
}) => (msg_id, trigger_type),
751-
QueueTask::MessageV1(MessageTask {
752-
msg_id,
753-
trigger_type,
754-
..
755-
}) => (msg_id, trigger_type),
756-
757-
QueueTask::HealthCheck => unreachable!(),
750+
(
751+
msg,
752+
Some(task.endpoint_id),
753+
Some(destination),
754+
task.trigger_type,
755+
task.attempt_count,
756+
)
757+
}
758+
QueueTask::MessageBatch(task) => {
759+
let msg = ctx!(message::Entity::find_by_id(task.msg_id).one(db).await)?
760+
.ok_or_else(|| err_generic!("Unexpected: message doesn't exist"))?;
761+
(msg, None, None, task.trigger_type, 0)
762+
}
758763
};
759764

760-
span.record("msg_id", &msg_id.0);
765+
span.record("msg_id", &msg.id.0);
766+
span.record("app_id", &msg.app_id.0);
767+
span.record("org_id", &msg.org_id.0);
761768

762-
let msg = ctx!(message::Entity::find_by_id(msg_id.clone()).one(db).await)?
763-
.ok_or_else(|| err_generic!("Unexpected: message doesn't exist"))?;
764769
let payload = match msg
765770
.payload
766771
.as_ref()
@@ -773,9 +778,6 @@ async fn process_queue_task_inner(
773778
}
774779
};
775780

776-
span.record("app_id", &msg.app_id.0);
777-
span.record("org_id", &msg.org_id.0);
778-
779781
let create_message_app = match CreateMessageApp::layered_fetch(
780782
cache,
781783
db,
@@ -794,60 +796,26 @@ async fn process_queue_task_inner(
794796
};
795797

796798
let endpoints: Vec<CreateMessageEndpoint> = create_message_app
797-
.filtered_endpoints(*trigger_type, &msg.event_type, msg.channels.as_ref())
799+
.filtered_endpoints(trigger_type, &msg.event_type, msg.channels.as_ref())
798800
.iter()
799-
.filter(|endpoint| match &queue_task {
800-
QueueTask::HealthCheck => unreachable!(),
801-
QueueTask::MessageV1(task) => task.endpoint_id == endpoint.id,
802-
QueueTask::MessageBatch(_) => true,
801+
.filter(|endpoint| match force_endpoint.as_ref() {
802+
Some(endp_id) => endp_id == &endpoint.id,
803+
None => true,
803804
})
804805
.cloned()
805806
.collect();
806807

807-
let futures: Vec<_> = match &queue_task {
808-
QueueTask::HealthCheck => unreachable!(),
809-
810-
QueueTask::MessageV1(task) => {
811-
let endpoint = match endpoints.into_iter().next() {
812-
Some(ep) => ep,
813-
None => {
814-
return Ok(());
815-
}
816-
};
817-
818-
let destination = ctx!(
819-
messagedestination::Entity::secure_find_by_msg(task.msg_id.clone())
820-
.filter(messagedestination::Column::EndpId.eq(endpoint.id.clone()))
821-
.one(db)
822-
.await
823-
)?
824-
.ok_or_else(|| {
825-
err_generic!(format!(
826-
"MessageDestination not found for message {}",
827-
&task.msg_id
828-
))
829-
})?;
830-
831-
vec![dispatch_message_task(
832-
&worker_context,
833-
&msg,
834-
&create_message_app,
835-
task.clone(),
836-
&payload,
837-
endpoint,
838-
Some(destination),
839-
)]
840-
}
841-
842-
QueueTask::MessageBatch(task) => {
808+
let destinations = match destination {
809+
Some(d) => vec![d],
810+
None => {
843811
let destinations: Vec<_> = endpoints
844812
.iter()
845813
.map(|endpoint| messagedestination::ActiveModel {
846814
msg_id: Set(msg.id.clone()),
847815
endp_id: Set(endpoint.id.clone()),
848816
next_attempt: Set(Some(Utc::now().into())),
849817
status: Set(MessageStatus::Sending),
850-
..Default::default()
818+
..messagedestination::ActiveModel::new()
851819
})
852820
.collect();
853821

@@ -857,32 +825,37 @@ async fn process_queue_task_inner(
857825
.await
858826
)?;
859827

860-
endpoints
828+
let dests: std::result::Result<_, _> = destinations
861829
.into_iter()
862-
.zip(destinations)
863-
.map(|(endpoint, destination)| {
864-
let task = MessageTask {
865-
msg_id: msg_id.clone(),
866-
app_id: task.app_id.clone(),
867-
endpoint_id: endpoint.id.clone(),
868-
attempt_count: 0,
869-
trigger_type: *trigger_type,
870-
};
871-
872-
dispatch_message_task(
873-
&worker_context,
874-
&msg,
875-
&create_message_app,
876-
task,
877-
&payload,
878-
endpoint,
879-
destination.try_into_model().ok(),
880-
)
881-
})
882-
.collect()
830+
.map(|d| d.try_into_model())
831+
.collect();
832+
ctx!(dests)?
883833
}
884834
};
885835

836+
let futures = endpoints
837+
.into_iter()
838+
.zip(destinations)
839+
.map(|(endpoint, destination)| {
840+
let task = MessageTask {
841+
msg_id: msg.id.clone(),
842+
app_id: create_message_app.id.clone(),
843+
endpoint_id: endpoint.id.clone(),
844+
attempt_count,
845+
trigger_type,
846+
};
847+
848+
dispatch_message_task(
849+
&worker_context,
850+
&msg,
851+
&create_message_app,
852+
task,
853+
&payload,
854+
endpoint,
855+
destination,
856+
)
857+
});
858+
886859
let join = future::join_all(futures).await;
887860

888861
let errs: Vec<_> = join.iter().filter(|x| x.is_err()).collect();

0 commit comments

Comments
 (0)