|
| 1 | +use crate::TableEntryTableProvider; |
| 2 | +use ahash::{HashMap, HashSet}; |
| 3 | +use async_trait::async_trait; |
| 4 | +use datafusion::catalog::{CatalogProvider, SchemaProvider, TableProvider}; |
| 5 | +use datafusion::common::{DataFusionError, Result as DataFusionResult, TableReference, exec_err}; |
| 6 | +use parking_lot::Mutex; |
| 7 | +use re_redap_client::ConnectionClient; |
| 8 | +use std::any::Any; |
| 9 | +use std::iter; |
| 10 | +use std::sync::Arc; |
| 11 | +use tokio::runtime::Handle as RuntimeHandle; |
| 12 | + |
// These are to match the defaults in datafusion.
pub const DEFAULT_CATALOG_NAME: &str = "datafusion";
// The default schema is represented internally as `None` in the schema maps
// below; this name is only materialized when presenting names to DataFusion.
const DEFAULT_SCHEMA_NAME: &str = "public";
| 16 | + |
/// `DataFusion` catalog provider for interacting with Rerun gRPC services.
///
/// Tables are stored on the server in a flat namespace with a string
/// representation of the catalog, schema, and table delimited by a
/// period. This matches typical SQL style naming conventions. If the
/// catalog or schema is not specified, it will be assumed to use
/// the defaults. For example, a table named `my_table` will be stored
/// within the `datafusion` catalog and `public` schema. If a table is
/// specified with more than three levels, it will also be stored in
/// the default catalog and schema. This matches how `DataFusion` will
/// store such table names.
#[derive(Debug)]
pub struct RedapCatalogProvider {
    /// Catalog name, normalized so the default catalog is `None`.
    catalog_name: Option<String>,
    /// Client used to fetch table names from the server.
    client: ConnectionClient,
    /// Cached schema providers, keyed by schema name (`None` = default schema).
    schemas: Mutex<HashMap<Option<String>, Arc<RedapSchemaProvider>>>,
    /// Runtime handle used to block on the async client calls.
    runtime: RuntimeHandle,
}
| 35 | + |
| 36 | +fn get_table_refs( |
| 37 | + client: &ConnectionClient, |
| 38 | + runtime: &RuntimeHandle, |
| 39 | +) -> DataFusionResult<Vec<TableReference>> { |
| 40 | + runtime.block_on(async { |
| 41 | + Ok::<Vec<_>, DataFusionError>( |
| 42 | + client |
| 43 | + .clone() |
| 44 | + .get_table_names() |
| 45 | + .await |
| 46 | + .map_err(|err| DataFusionError::External(Box::new(err)))? |
| 47 | + .into_iter() |
| 48 | + .map(TableReference::from) |
| 49 | + .collect(), |
| 50 | + ) |
| 51 | + }) |
| 52 | +} |
| 53 | + |
| 54 | +pub fn get_all_catalog_names( |
| 55 | + client: &ConnectionClient, |
| 56 | + runtime: &RuntimeHandle, |
| 57 | +) -> DataFusionResult<Vec<String>> { |
| 58 | + let catalog_names = get_table_refs(client, runtime)? |
| 59 | + .into_iter() |
| 60 | + .filter_map(|reference| reference.catalog().map(|c| c.to_owned())) |
| 61 | + .collect::<HashSet<String>>(); |
| 62 | + |
| 63 | + Ok(catalog_names.into_iter().collect()) |
| 64 | +} |
| 65 | + |
| 66 | +impl RedapCatalogProvider { |
| 67 | + pub fn new(name: Option<&str>, client: ConnectionClient, runtime: RuntimeHandle) -> Self { |
| 68 | + let name = if let Some(inner_name) = name |
| 69 | + && inner_name == DEFAULT_CATALOG_NAME |
| 70 | + { |
| 71 | + None |
| 72 | + } else { |
| 73 | + name |
| 74 | + }; |
| 75 | + let default_schema = Arc::new(RedapSchemaProvider { |
| 76 | + catalog_name: name.map(ToOwned::to_owned), |
| 77 | + schema_name: None, |
| 78 | + client: client.clone(), |
| 79 | + runtime: runtime.clone(), |
| 80 | + in_memory_tables: Default::default(), |
| 81 | + }); |
| 82 | + let schemas: HashMap<_, _> = iter::once((None, default_schema)).collect(); |
| 83 | + |
| 84 | + Self { |
| 85 | + catalog_name: name.map(ToOwned::to_owned), |
| 86 | + client, |
| 87 | + schemas: Mutex::new(schemas), |
| 88 | + runtime, |
| 89 | + } |
| 90 | + } |
| 91 | + |
| 92 | + fn update_from_server(&self) -> DataFusionResult<()> { |
| 93 | + let table_names = get_table_refs(&self.client, &self.runtime)?; |
| 94 | + |
| 95 | + let schema_names: HashSet<_> = table_names |
| 96 | + .into_iter() |
| 97 | + .filter(|table_ref| table_ref.catalog() == self.catalog_name.as_deref()) |
| 98 | + .map(|table_ref| table_ref.schema().map(|s| s.to_owned())) |
| 99 | + .collect(); |
| 100 | + |
| 101 | + let mut schemas = self.schemas.lock(); |
| 102 | + |
| 103 | + schemas.retain(|k, _| schema_names.contains(k) || k.is_none()); |
| 104 | + for schema_name in schema_names { |
| 105 | + let _ = schemas.entry(schema_name.clone()).or_insert( |
| 106 | + RedapSchemaProvider { |
| 107 | + catalog_name: self.catalog_name.clone(), |
| 108 | + schema_name, |
| 109 | + client: self.client.clone(), |
| 110 | + runtime: self.runtime.clone(), |
| 111 | + in_memory_tables: Default::default(), |
| 112 | + } |
| 113 | + .into(), |
| 114 | + ); |
| 115 | + } |
| 116 | + |
| 117 | + Ok(()) |
| 118 | + } |
| 119 | + |
| 120 | + fn get_schema_names(&self) -> DataFusionResult<Vec<String>> { |
| 121 | + self.update_from_server()?; |
| 122 | + |
| 123 | + let schemas = self.schemas.lock(); |
| 124 | + Ok(schemas |
| 125 | + .keys() |
| 126 | + .map(|k| k.as_deref().unwrap_or(DEFAULT_SCHEMA_NAME).to_owned()) |
| 127 | + .collect()) |
| 128 | + } |
| 129 | +} |
| 130 | + |
| 131 | +impl CatalogProvider for RedapCatalogProvider { |
| 132 | + fn as_any(&self) -> &dyn Any { |
| 133 | + self |
| 134 | + } |
| 135 | + |
| 136 | + fn schema_names(&self) -> Vec<String> { |
| 137 | + self.get_schema_names().unwrap_or_else(|err| { |
| 138 | + log::error!("Error attempting to get table references from server: {err}"); |
| 139 | + vec![] |
| 140 | + }) |
| 141 | + } |
| 142 | + |
| 143 | + fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>> { |
| 144 | + if let Err(err) = self.update_from_server() { |
| 145 | + log::error!("Error updating table references from server: {err}"); |
| 146 | + return None; |
| 147 | + } |
| 148 | + |
| 149 | + let schemas = self.schemas.lock(); |
| 150 | + |
| 151 | + let schema_name = if name == DEFAULT_SCHEMA_NAME { |
| 152 | + None |
| 153 | + } else { |
| 154 | + Some(name.to_owned()) |
| 155 | + }; |
| 156 | + |
| 157 | + schemas |
| 158 | + .get(&schema_name) |
| 159 | + .map(|s| Arc::clone(s) as Arc<dyn SchemaProvider>) |
| 160 | + } |
| 161 | +} |
| 162 | + |
/// `DataFusion` schema provider for interacting with Rerun gRPC services.
///
/// For a detailed description of how tables are named on the server
/// vs represented in the catalog and schema providers, see
/// [`RedapCatalogProvider`].
///
/// When the user calls `register_table` on this provider, it will
/// register the table *only for the current session*. To persist
/// tables, instead they must be registered via the [`ConnectionClient`]
/// `register_table`. It is expected for this behavior to change in
/// the future.
#[derive(Debug)]
struct RedapSchemaProvider {
    /// Parent catalog name, or `None` for the default catalog.
    catalog_name: Option<String>,
    /// Schema name, or `None` for the default schema.
    schema_name: Option<String>,
    /// Client used to resolve tables on the server.
    client: ConnectionClient,
    /// Runtime handle used to block on the async client calls.
    runtime: RuntimeHandle,
    /// Session-local tables added via `register_table`; never persisted.
    in_memory_tables: Mutex<HashMap<String, Arc<dyn TableProvider>>>,
}
| 182 | + |
| 183 | +#[async_trait] |
| 184 | +impl SchemaProvider for RedapSchemaProvider { |
| 185 | + fn owner_name(&self) -> Option<&str> { |
| 186 | + self.catalog_name.as_deref() |
| 187 | + } |
| 188 | + |
| 189 | + fn as_any(&self) -> &dyn Any { |
| 190 | + self |
| 191 | + } |
| 192 | + |
| 193 | + fn table_names(&self) -> Vec<String> { |
| 194 | + let table_refs = get_table_refs(&self.client, &self.runtime).unwrap_or_else(|err| { |
| 195 | + log::error!("Error getting table references: {err}"); |
| 196 | + vec![] |
| 197 | + }); |
| 198 | + |
| 199 | + let mut table_names = table_refs |
| 200 | + .into_iter() |
| 201 | + .filter(|table_ref| { |
| 202 | + table_ref.catalog() == self.catalog_name.as_deref() |
| 203 | + && table_ref.schema() == self.schema_name.as_deref() |
| 204 | + }) |
| 205 | + .map(|table_ref| table_ref.table().to_owned()) |
| 206 | + .collect::<Vec<_>>(); |
| 207 | + |
| 208 | + table_names.extend(self.in_memory_tables.lock().keys().cloned()); |
| 209 | + |
| 210 | + table_names |
| 211 | + } |
| 212 | + |
| 213 | + async fn table( |
| 214 | + &self, |
| 215 | + table_name: &str, |
| 216 | + ) -> DataFusionResult<Option<Arc<dyn TableProvider>>, DataFusionError> { |
| 217 | + if let Some(table) = self.in_memory_tables.lock().get(table_name) { |
| 218 | + return Ok(Some(Arc::clone(table))); |
| 219 | + } |
| 220 | + |
| 221 | + let table_name = match (&self.catalog_name, &self.schema_name) { |
| 222 | + (Some(catalog_name), Some(schema_name)) => { |
| 223 | + format!("{catalog_name}.{schema_name}.{table_name}") |
| 224 | + } |
| 225 | + (None, Some(schema_name)) => format!("{schema_name}.{table_name}"), |
| 226 | + _ => table_name.to_owned(), |
| 227 | + }; |
| 228 | + TableEntryTableProvider::new(self.client.clone(), table_name) |
| 229 | + .into_provider() |
| 230 | + .await |
| 231 | + .map(Some) |
| 232 | + } |
| 233 | + |
| 234 | + fn register_table( |
| 235 | + &self, |
| 236 | + name: String, |
| 237 | + table: Arc<dyn TableProvider>, |
| 238 | + ) -> DataFusionResult<Option<Arc<dyn TableProvider>>> { |
| 239 | + let server_tables = get_table_refs(&self.client, &self.runtime)?; |
| 240 | + if server_tables.into_iter().any(|table_ref| { |
| 241 | + table_ref.catalog() == self.catalog_name.as_deref() |
| 242 | + && table_ref.schema() == self.schema_name.as_deref() |
| 243 | + && table_ref.table() == name |
| 244 | + }) { |
| 245 | + return exec_err!("{name} already exists on the server catalog"); |
| 246 | + } |
| 247 | + |
| 248 | + self.in_memory_tables.lock().insert(name, table); |
| 249 | + Ok(None) |
| 250 | + } |
| 251 | + |
| 252 | + fn deregister_table(&self, name: &str) -> DataFusionResult<Option<Arc<dyn TableProvider>>> { |
| 253 | + Ok(self.in_memory_tables.lock().remove(name)) |
| 254 | + } |
| 255 | + |
| 256 | + fn table_exist(&self, name: &str) -> bool { |
| 257 | + self.table_names().into_iter().any(|t| t == name) |
| 258 | + } |
| 259 | +} |
0 commit comments