diff --git a/datafusion/core/benches/topk_aggregate.rs b/datafusion/core/benches/topk_aggregate.rs index a4ae479de4d27..be193f873713b 100644 --- a/datafusion/core/benches/topk_aggregate.rs +++ b/datafusion/core/benches/topk_aggregate.rs @@ -20,23 +20,21 @@ mod data_utils; use arrow::util::pretty::pretty_format_batches; use criterion::{Criterion, criterion_group, criterion_main}; use data_utils::make_data; -use datafusion::physical_plan::{ExecutionPlan, collect, displayable}; +use datafusion::physical_plan::{collect, displayable}; use datafusion::prelude::SessionContext; use datafusion::{datasource::MemTable, error::Result}; -use datafusion_execution::TaskContext; use datafusion_execution::config::SessionConfig; use std::hint::black_box; use std::sync::Arc; use tokio::runtime::Runtime; async fn create_context( - limit: usize, partition_cnt: i32, sample_cnt: i32, asc: bool, use_topk: bool, use_view: bool, -) -> Result<(Arc<dyn ExecutionPlan>, Arc<TaskContext>)> { +) -> Result<SessionContext> { let (schema, parts) = make_data(partition_cnt, sample_cnt, asc, use_view).unwrap(); let mem_table = Arc::new(MemTable::try_new(schema, parts).unwrap()); @@ -46,32 +44,32 @@ async fn create_context( opts.optimizer.enable_topk_aggregation = use_topk; let ctx = SessionContext::new_with_config(cfg); let _ = ctx.register_table("traces", mem_table)?; + + Ok(ctx) +} + +fn run(rt: &Runtime, ctx: SessionContext, limit: usize, use_topk: bool, asc: bool) { + black_box(rt.block_on(async { aggregate(ctx, limit, use_topk, asc).await })).unwrap(); +} + +async fn aggregate( + ctx: SessionContext, + limit: usize, + use_topk: bool, + asc: bool, +) -> Result<()> { let sql = format!( "select max(timestamp_ms) from traces group by trace_id order by max(timestamp_ms) desc limit {limit};" ); let df = ctx.sql(sql.as_str()).await?; - let physical_plan = df.create_physical_plan().await?; - let actual_phys_plan = displayable(physical_plan.as_ref()).indent(true).to_string(); + let plan = df.create_physical_plan().await?; + let actual_phys_plan = 
displayable(plan.as_ref()).indent(true).to_string(); assert_eq!( actual_phys_plan.contains(&format!("lim=[{limit}]")), use_topk ); - Ok((physical_plan, ctx.task_ctx())) -} - -#[expect(clippy::needless_pass_by_value)] -fn run(rt: &Runtime, plan: Arc<dyn ExecutionPlan>, ctx: Arc<TaskContext>, asc: bool) { - black_box(rt.block_on(async { aggregate(plan.clone(), ctx.clone(), asc).await })) - .unwrap(); -} - -async fn aggregate( - plan: Arc<dyn ExecutionPlan>, - ctx: Arc<TaskContext>, - asc: bool, -) -> Result<()> { - let batches = collect(plan, ctx).await?; + let batches = collect(plan, ctx.task_ctx()).await?; assert_eq!(batches.len(), 1); let batch = batches.first().unwrap(); assert_eq!(batch.num_rows(), 10); @@ -107,106 +105,70 @@ fn criterion_benchmark(c: &mut Criterion) { let partitions = 10; let samples = 1_000_000; + let ctx = rt + .block_on(create_context(partitions, samples, false, false, false)) + .unwrap(); c.bench_function( format!("aggregate {} time-series rows", partitions * samples).as_str(), - |b| { - b.iter(|| { - let real = rt.block_on(async { - create_context(limit, partitions, samples, false, false, false) - .await - .unwrap() - }); - run(&rt, real.0.clone(), real.1.clone(), false) - }) - }, + |b| b.iter(|| run(&rt, ctx.clone(), limit, false, false)), ); + let ctx = rt + .block_on(create_context(partitions, samples, true, false, false)) + .unwrap(); c.bench_function( format!("aggregate {} worst-case rows", partitions * samples).as_str(), - |b| { - b.iter(|| { - let asc = rt.block_on(async { - create_context(limit, partitions, samples, true, false, false) - .await - .unwrap() - }); - run(&rt, asc.0.clone(), asc.1.clone(), true) - }) - }, + |b| b.iter(|| run(&rt, ctx.clone(), limit, false, true)), ); + let ctx = rt + .block_on(create_context(partitions, samples, false, true, false)) + .unwrap(); c.bench_function( format!( "top k={limit} aggregate {} time-series rows", partitions * samples ) .as_str(), - |b| { - b.iter(|| { - let topk_real = rt.block_on(async { - create_context(limit, partitions, samples, false, 
true, false) - .await - .unwrap() - }); - run(&rt, topk_real.0.clone(), topk_real.1.clone(), false) - }) - }, + |b| b.iter(|| run(&rt, ctx.clone(), limit, true, false)), ); + let ctx = rt + .block_on(create_context(partitions, samples, true, true, false)) + .unwrap(); c.bench_function( format!( "top k={limit} aggregate {} worst-case rows", partitions * samples ) .as_str(), - |b| { - b.iter(|| { - let topk_asc = rt.block_on(async { - create_context(limit, partitions, samples, true, true, false) - .await - .unwrap() - }); - run(&rt, topk_asc.0.clone(), topk_asc.1.clone(), true) - }) - }, + |b| b.iter(|| run(&rt, ctx.clone(), limit, true, true)), ); // Utf8View schema,time-series rows + let ctx = rt + .block_on(create_context(partitions, samples, false, true, true)) + .unwrap(); c.bench_function( format!( "top k={limit} aggregate {} time-series rows [Utf8View]", partitions * samples ) .as_str(), - |b| { - b.iter(|| { - let topk_real = rt.block_on(async { - create_context(limit, partitions, samples, false, true, true) - .await - .unwrap() - }); - run(&rt, topk_real.0.clone(), topk_real.1.clone(), false) - }) - }, + |b| b.iter(|| run(&rt, ctx.clone(), limit, true, false)), ); // Utf8View schema,worst-case rows + let ctx = rt + .block_on(create_context(partitions, samples, true, true, true)) + .unwrap(); c.bench_function( format!( "top k={limit} aggregate {} worst-case rows [Utf8View]", partitions * samples ) .as_str(), - |b| { - b.iter(|| { - let topk_asc = rt.block_on(async { - create_context(limit, partitions, samples, true, true, true) - .await - .unwrap() - }); - run(&rt, topk_asc.0.clone(), topk_asc.1.clone(), true) - }) - }, + |b| b.iter(|| run(&rt, ctx.clone(), limit, true, true)), ); }