Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datafusion-cli/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ impl TableFunctionImpl for ParquetMetadataFunc {
compression_arr.push(format!("{:?}", column.compression()));
// need to collect into Vec to format
let encodings: Vec<_> = column.encodings().collect();
encodings_arr.push(format!("{:?}", encodings));
encodings_arr.push(format!("{encodings:?}"));
index_page_offset_arr.push(column.index_page_offset());
dictionary_page_offset_arr.push(column.dictionary_page_offset());
data_page_offset_arr.push(column.data_page_offset());
Expand Down
70 changes: 64 additions & 6 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -445,15 +445,31 @@ impl DataFrame {
/// # Ok(())
/// # }
/// ```
pub fn drop_columns(self, columns: &[&str]) -> Result<DataFrame> {
pub fn drop_columns<T>(self, columns: &[T]) -> Result<DataFrame>
where
T: Into<Column> + Clone,
{
let fields_to_drop = columns
.iter()
.flat_map(|name| {
self.plan
.schema()
.qualified_fields_with_unqualified_name(name)
.flat_map(|col| {
let column: Column = col.clone().into();
match column.relation.as_ref() {
Some(_) => {
// qualified_field_from_column returns Result<(Option<&TableReference>, &FieldRef)>
vec![self.plan.schema().qualified_field_from_column(&column)]
}
None => {
// qualified_fields_with_unqualified_name returns Vec<(Option<&TableReference>, &FieldRef)>
self.plan
.schema()
.qualified_fields_with_unqualified_name(&column.name)
.into_iter()
.map(Ok)
.collect::<Vec<_>>()
}
}
})
.collect::<Vec<_>>();
.collect::<Result<Vec<_>, _>>()?;
let expr: Vec<Expr> = self
.plan
.schema()
Expand Down Expand Up @@ -2463,6 +2479,48 @@ impl DataFrame {
.collect()
}

/// Find qualified columns for this dataframe from names
///
/// Each entry of `names` is looked up in the schema of this dataframe's
/// logical plan. The result holds one `(qualifier, field)` pair per name,
/// in the same order as `names`.
///
/// # Arguments
/// * `names` - Unqualified names to find.
///
/// # Errors
/// Returns a plan error of the form `Column '<name>' not found` for the first
/// name whose schema lookup fails. The underlying lookup error is discarded.
///
/// # Example
/// ```
/// # use datafusion::prelude::*;
/// # use datafusion::error::Result;
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
/// let ctx = SessionContext::new();
/// ctx.register_csv("first_table", "tests/data/example.csv", CsvReadOptions::new())
///     .await?;
/// let df = ctx.table("first_table").await?;
/// ctx.register_csv("second_table", "tests/data/example.csv", CsvReadOptions::new())
///     .await?;
/// let df2 = ctx.table("second_table").await?;
/// let join_expr = df.find_qualified_columns(&["a"])?.iter()
///     .zip(df2.find_qualified_columns(&["a"])?.iter())
///     .map(|(col1, col2)| col(*col1).eq(col(*col2)))
///     .collect::<Vec<Expr>>();
/// let df3 = df.join_on(df2, JoinType::Inner, join_expr)?;
/// # Ok(())
/// # }
/// ```
pub fn find_qualified_columns(
    &self,
    names: &[&str],
) -> Result<Vec<(Option<&TableReference>, &FieldRef)>> {
    let schema = self.logical_plan().schema();
    names
        .iter()
        .map(|&name| {
            schema
                .qualified_field_from_column(&Column::from_name(name))
                // Every lookup failure is reported uniformly as "not found";
                // the original error from the schema is intentionally dropped.
                .map_err(|_| plan_datafusion_err!("Column '{name}' not found"))
        })
        // Collecting into `Result` stops at the first failing name.
        .collect()
}

/// Helper for creating DataFrame.
/// # Example
/// ```
Expand Down
106 changes: 104 additions & 2 deletions datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,8 @@ async fn drop_columns_with_nonexistent_columns() -> Result<()> {
async fn drop_columns_with_empty_array() -> Result<()> {
// build plan using Table API
let t = test_table().await?;
let t2 = t.drop_columns(&[])?;
let drop_columns = vec![] as Vec<&str>;
let t2 = t.drop_columns(&drop_columns)?;
let plan = t2.logical_plan().clone();

// build query using SQL
Expand All @@ -549,6 +550,107 @@ async fn drop_columns_with_empty_array() -> Result<()> {
Ok(())
}

/// Dropping columns by fully qualified name after a join must remove only the
/// columns of the named table, leaving same-named columns of the other side.
#[tokio::test]
async fn drop_columns_qualified() -> Result<()> {
    // Table API side: project both inputs, join them, then drop the
    // qualified columns belonging to `another_table`.
    let left = test_table().await?.select_columns(&["c1", "c2", "c11"])?;
    let right = test_table_with_name("another_table")
        .await?
        .select_columns(&["c1", "c2", "c11"])?;
    let joined = left
        .join_on(
            right,
            JoinType::Inner,
            [col("aggregate_test_100.c1").eq(col("another_table.c1"))],
        )?
        .drop_columns(&["another_table.c2", "another_table.c11"])?;
    let plan = joined.logical_plan().clone();

    // SQL side: the equivalent query with the surviving columns spelled out.
    let ctx = SessionContext::new();
    register_aggregate_csv(&ctx, "aggregate_test_100").await?;
    register_aggregate_csv(&ctx, "another_table").await?;
    let sql_plan = ctx
        .sql("SELECT aggregate_test_100.c1, aggregate_test_100.c2, aggregate_test_100.c11, another_table.c1 FROM (SELECT c1, c2, c11 FROM aggregate_test_100) INNER JOIN (SELECT c1, c2, c11 FROM another_table) ON aggregate_test_100.c1 = another_table.c1")
        .await?
        .into_unoptimized_plan();

    // Both routes must produce the same logical plan.
    assert_same_plan(&plan, &sql_plan);

    Ok(())
}

/// Same scenario as dropping by qualified name, but the columns to drop are
/// obtained from `find_qualified_columns` on the right-hand dataframe.
#[tokio::test]
async fn drop_columns_qualified_find_qualified() -> Result<()> {
    // Table API side: project both inputs and join them.
    let left = test_table().await?.select_columns(&["c1", "c2", "c11"])?;
    let right = test_table_with_name("another_table")
        .await?
        .select_columns(&["c1", "c2", "c11"])?;
    // `right` is cloned into the join because it is still needed below to
    // resolve the qualified columns.
    let joined = left.join_on(
        right.clone(),
        JoinType::Inner,
        [col("aggregate_test_100.c1").eq(col("another_table.c1"))],
    )?;
    let to_drop = right.find_qualified_columns(&["c2", "c11"])?;
    let pruned = joined.drop_columns(&to_drop)?;
    let plan = pruned.logical_plan().clone();

    // SQL side: the equivalent query with the surviving columns spelled out.
    let ctx = SessionContext::new();
    register_aggregate_csv(&ctx, "aggregate_test_100").await?;
    register_aggregate_csv(&ctx, "another_table").await?;
    let sql_plan = ctx
        .sql("SELECT aggregate_test_100.c1, aggregate_test_100.c2, aggregate_test_100.c11, another_table.c1 FROM (SELECT c1, c2, c11 FROM aggregate_test_100) INNER JOIN (SELECT c1, c2, c11 FROM another_table) ON aggregate_test_100.c1 = another_table.c1")
        .await?
        .into_unoptimized_plan();

    // Both routes must produce the same logical plan.
    assert_same_plan(&plan, &sql_plan);

    Ok(())
}

/// `find_qualified_columns` must return one `(qualifier, field)` pair per
/// requested name, in request order, qualified with the test table's name.
#[tokio::test]
async fn test_find_qualified_names() -> Result<()> {
    let t = test_table().await?;
    let column_names = ["c1", "c2", "c3"];
    let columns = t.find_qualified_columns(&column_names)?;

    // Expected results for each column
    let binding = TableReference::bare("aggregate_test_100");
    let expected = [
        (Some(&binding), "c1"),
        (Some(&binding), "c2"),
        (Some(&binding), "c3"),
    ];

    // Verify we got the expected number of results; `zip` below would
    // otherwise silently ignore missing or surplus entries.
    assert_eq!(
        columns.len(),
        expected.len(),
        "Expected {} columns, got {}",
        expected.len(),
        columns.len()
    );

    // Iterate over the results and check each one individually
    for (i, (actual, expected)) in columns.iter().zip(expected.iter()).enumerate() {
        let (actual_table_ref, actual_field_ref) = actual;
        let (expected_table_ref, expected_field_name) = expected;

        // Check table reference
        assert_eq!(
            actual_table_ref, expected_table_ref,
            "Column {i}: expected table reference {expected_table_ref:?}, got {actual_table_ref:?}"
        );

        // Check field name. The failure message reports the name that was
        // actually compared rather than the whole field.
        let actual_field_name = actual_field_ref.name();
        assert_eq!(
            actual_field_name,
            *expected_field_name,
            "Column {i}: expected field name '{expected_field_name}', got '{actual_field_name}'"
        );
    }

    Ok(())
}

#[tokio::test]
async fn drop_with_quotes() -> Result<()> {
// define data with a column name that has a "." in it:
Expand Down Expand Up @@ -594,7 +696,7 @@ async fn drop_with_periods() -> Result<()> {
let ctx = SessionContext::new();
ctx.register_batch("t", batch)?;

let df = ctx.table("t").await?.drop_columns(&["f.c1"])?;
let df = ctx.table("t").await?.drop_columns(&["\"f.c1\""])?;

let df_results = df.collect().await?;

Expand Down
4 changes: 2 additions & 2 deletions datafusion/functions/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ pub mod test {

#[test]
fn test_decimal32_to_i32() {
let cases: [(i32, i8, Either<i32, String>); _] = [
let cases: [(i32, i8, Either<i32, String>); 10] = [
(123, 0, Either::Left(123)),
(1230, 1, Either::Left(123)),
(123000, 3, Either::Left(123)),
Expand Down Expand Up @@ -456,7 +456,7 @@ pub mod test {

#[test]
fn test_decimal64_to_i64() {
let cases: [(i64, i8, Either<i64, String>); _] = [
let cases: [(i64, i8, Either<i64, String>); 8] = [
(123, 0, Either::Left(123)),
(1234567890, 2, Either::Left(12345678)),
(-1234567890, 2, Either::Left(-12345678)),
Expand Down