diff --git a/datafusion-cli/src/functions.rs b/datafusion-cli/src/functions.rs index a45d57e8e952d..c2e4e16569cc8 100644 --- a/datafusion-cli/src/functions.rs +++ b/datafusion-cli/src/functions.rs @@ -421,7 +421,7 @@ impl TableFunctionImpl for ParquetMetadataFunc { compression_arr.push(format!("{:?}", column.compression())); // need to collect into Vec to format let encodings: Vec<_> = column.encodings().collect(); - encodings_arr.push(format!("{:?}", encodings)); + encodings_arr.push(format!("{encodings:?}")); index_page_offset_arr.push(column.index_page_offset()); dictionary_page_offset_arr.push(column.dictionary_page_offset()); data_page_offset_arr.push(column.data_page_offset()); diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 0d060db3bf147..a93146b079828 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -445,15 +445,31 @@ impl DataFrame { /// # Ok(()) /// # } /// ``` - pub fn drop_columns(self, columns: &[&str]) -> Result { + pub fn drop_columns(self, columns: &[T]) -> Result + where + T: Into + Clone, + { let fields_to_drop = columns .iter() - .flat_map(|name| { - self.plan - .schema() - .qualified_fields_with_unqualified_name(name) + .flat_map(|col| { + let column: Column = col.clone().into(); + match column.relation.as_ref() { + Some(_) => { + // qualified_field_from_column returns Result<(Option<&TableReference>, &FieldRef)> + vec![self.plan.schema().qualified_field_from_column(&column)] + } + None => { + // qualified_fields_with_unqualified_name returns Vec<(Option<&TableReference>, &FieldRef)> + self.plan + .schema() + .qualified_fields_with_unqualified_name(&column.name) + .into_iter() + .map(Ok) + .collect::>() + } + } }) - .collect::>(); + .collect::, _>>()?; let expr: Vec = self .plan .schema() @@ -2463,6 +2479,48 @@ impl DataFrame { .collect() } + /// Find qualified columns for this dataframe from names + /// + /// # Arguments + /// * `names` - Unqualified names to find. + /// + /// # Example + /// ``` + /// # use datafusion::prelude::*; + /// # use datafusion::error::Result; + /// # use datafusion_common::ScalarValue; + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// let ctx = SessionContext::new(); + /// ctx.register_csv("first_table", "tests/data/example.csv", CsvReadOptions::new()) + /// .await?; + /// let df = ctx.table("first_table").await?; + /// ctx.register_csv("second_table", "tests/data/example.csv", CsvReadOptions::new()) + /// .await?; + /// let df2 = ctx.table("second_table").await?; + /// let join_expr = df.find_qualified_columns(&["a"])?.iter() + /// .zip(df2.find_qualified_columns(&["a"])?.iter()) + /// .map(|(col1, col2)| col(*col1).eq(col(*col2))) + /// .collect::>(); + /// let df3 = df.join_on(df2, JoinType::Inner, join_expr)?; + /// # Ok(()) + /// # } + /// ``` + pub fn find_qualified_columns( + &self, + names: &[&str], + ) -> Result, &FieldRef)>> { + let schema = self.logical_plan().schema(); + names + .iter() + .map(|name| { + schema + .qualified_field_from_column(&Column::from_name(*name)) + .map_err(|_| plan_datafusion_err!("Column '{}' not found", name)) + }) + .collect() + } + /// Helper for creating DataFrame. /// # Example /// ``` diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index c09db371912b0..1ae6ef5c4a8b5 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -534,7 +534,8 @@ async fn drop_columns_with_nonexistent_columns() -> Result<()> { async fn drop_columns_with_empty_array() -> Result<()> { // build plan using Table API let t = test_table().await?; - let t2 = t.drop_columns(&[])?; + let drop_columns = vec![] as Vec<&str>; + let t2 = t.drop_columns(&drop_columns)?; let plan = t2.logical_plan().clone(); // build query using SQL @@ -549,6 +550,107 @@ async fn drop_columns_with_empty_array() -> Result<()> { Ok(()) } +#[tokio::test] +async fn drop_columns_qualified() -> Result<()> { + // build plan using Table API + let mut t = test_table().await?; + t = t.select_columns(&["c1", "c2", "c11"])?; + let mut t2 = test_table_with_name("another_table").await?; + t2 = t2.select_columns(&["c1", "c2", "c11"])?; + let mut t3 = t.join_on( + t2, + JoinType::Inner, + [col("aggregate_test_100.c1").eq(col("another_table.c1"))], + )?; + t3 = t3.drop_columns(&["another_table.c2", "another_table.c11"])?; + + let plan = t3.logical_plan().clone(); + + let sql = "SELECT aggregate_test_100.c1, aggregate_test_100.c2, aggregate_test_100.c11, another_table.c1 FROM (SELECT c1, c2, c11 FROM aggregate_test_100) INNER JOIN (SELECT c1, c2, c11 FROM another_table) ON aggregate_test_100.c1 = another_table.c1"; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx, "aggregate_test_100").await?; + register_aggregate_csv(&ctx, "another_table").await?; + let sql_plan = ctx.sql(sql).await?.into_unoptimized_plan(); + + // the two plans should be identical + assert_same_plan(&plan, &sql_plan); + + Ok(()) +} + +#[tokio::test] +async fn drop_columns_qualified_find_qualified() -> Result<()> { + // build plan using Table API + let mut t = test_table().await?; + t = t.select_columns(&["c1", "c2", "c11"])?; + let mut t2 = test_table_with_name("another_table").await?; + t2 = t2.select_columns(&["c1", "c2", "c11"])?; + let mut t3 = t.join_on( + t2.clone(), + JoinType::Inner, + [col("aggregate_test_100.c1").eq(col("another_table.c1"))], + )?; + t3 = t3.drop_columns(&t2.find_qualified_columns(&["c2", "c11"])?)?; + + let plan = t3.logical_plan().clone(); + + let sql = "SELECT aggregate_test_100.c1, aggregate_test_100.c2, aggregate_test_100.c11, another_table.c1 FROM (SELECT c1, c2, c11 FROM aggregate_test_100) INNER JOIN (SELECT c1, c2, c11 FROM another_table) ON aggregate_test_100.c1 = another_table.c1"; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx, "aggregate_test_100").await?; + register_aggregate_csv(&ctx, "another_table").await?; + let sql_plan = ctx.sql(sql).await?.into_unoptimized_plan(); + + // the two plans should be identical + assert_same_plan(&plan, &sql_plan); + + Ok(()) +} + +#[tokio::test] +async fn test_find_qualified_names() -> Result<()> { + let t = test_table().await?; + let column_names = ["c1", "c2", "c3"]; + let columns = t.find_qualified_columns(&column_names)?; + + // Expected results for each column + let binding = TableReference::bare("aggregate_test_100"); + let expected = [ + (Some(&binding), "c1"), + (Some(&binding), "c2"), + (Some(&binding), "c3"), + ]; + + // Verify we got the expected number of results + assert_eq!( + columns.len(), + expected.len(), + "Expected {} columns, got {}", + expected.len(), + columns.len() + ); + + // Iterate over the results and check each one individually + for (i, (actual, expected)) in columns.iter().zip(expected.iter()).enumerate() { + let (actual_table_ref, actual_field_ref) = actual; + let (expected_table_ref, expected_field_name) = expected; + + // Check table reference + assert_eq!( + actual_table_ref, expected_table_ref, + "Column {i}: expected table reference {expected_table_ref:?}, got {actual_table_ref:?}" + ); + + // Check field name + assert_eq!( + actual_field_ref.name(), + *expected_field_name, + "Column {i}: expected field name '{expected_field_name}', got '{actual_field_ref}'" + ); + } + + Ok(()) +} + #[tokio::test] async fn drop_with_quotes() -> Result<()> { // define data with a column name that has a "." in it: @@ -594,7 +696,7 @@ async fn drop_with_periods() -> Result<()> { let ctx = SessionContext::new(); ctx.register_batch("t", batch)?; - let df = ctx.table("t").await?.drop_columns(&["f.c1"])?; + let df = ctx.table("t").await?.drop_columns(&["\"f.c1\""])?; let df_results = df.collect().await?; diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index e4980728b18a0..ad68e8cdd5547 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -410,7 +410,7 @@ pub mod test { #[test] fn test_decimal32_to_i32() { - let cases: [(i32, i8, Either); _] = [ + let cases: [(i32, i8, Either); 10] = [ (123, 0, Either::Left(123)), (1230, 1, Either::Left(123)), (123000, 3, Either::Left(123)), @@ -456,7 +456,7 @@ pub mod test { #[test] fn test_decimal64_to_i64() { - let cases: [(i64, i8, Either); _] = [ + let cases: [(i64, i8, Either); 8] = [ (123, 0, Either::Left(123)), (1234567890, 2, Either::Left(12345678)), (-1234567890, 2, Either::Left(-12345678)),