diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 90bd1415003cd..0647b43e7bbc3 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -1094,6 +1094,29 @@ pub struct Signature { pub parameter_names: Option>, } +/// A helper enum used by [`Signature::from_parameter_variants`] +/// to accept either concrete [`DataType`]s (for [`TypeSignature::Exact`] and +/// [`TypeSignature::Uniform`]) or [`Coercion`] rules (for [`TypeSignature::Coercible`]). +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] +pub enum ParameterKind { + /// A concrete [`DataType`] expected by the signature. + DataType(DataType), + /// A [`Coercion`] rule that will be stored on the signature. + Coercion(Coercion), +} + +impl From for ParameterKind { + fn from(value: DataType) -> Self { + Self::DataType(value) + } +} + +impl From for ParameterKind { + fn from(value: Coercion) -> Self { + Self::Coercion(value) + } +} + impl Signature { /// Creates a new Signature from a given type signature and volatility. pub fn new(type_signature: TypeSignature, volatility: Volatility) -> Self { @@ -1366,6 +1389,162 @@ impl Signature { Ok(()) } + + /// Construct a signature with multiple variants directly from parameter specifications. + /// + /// This is the recommended way to define functions that accept multiple signatures + /// (e.g., optional parameters) as it eliminates duplication and makes the variants + /// explicit. + /// + /// # Example + /// ``` + /// # use datafusion_expr_common::signature::{Signature, Volatility, Coercion, TypeSignatureClass}; + /// # use datafusion_common::types::{logical_string, logical_int64, NativeType}; + /// # use datafusion_common::Result; + /// # fn example() -> Result<()> { + /// let string = Coercion::new_exact(TypeSignatureClass::Native(logical_string())); + /// let int64 = Coercion::new_exact(TypeSignatureClass::Native(logical_int64())); + /// + /// // substr(str, pos) OR substr(str, pos, length) + /// let sig = Signature::from_parameter_variants( + /// vec![ + /// vec![("str", string.clone()), ("start_pos", int64.clone())], + /// vec![("str", string.clone()), ("start_pos", int64.clone()), ("length", int64.clone())], + /// ], + /// Volatility::Immutable + /// )?; + /// # Ok(()) + /// # } + /// ``` + /// + /// # Parameter Name Inference + /// The parameter names for the signature are inferred from the **longest variant**. + /// This ensures all parameters are documented. Shorter variants are treated as + /// having optional trailing parameters. + /// + /// # Supported TypeSignatures + /// This method can generate: + /// - [`TypeSignature::Nullary`] - via empty parameter list `vec![]` + /// - [`TypeSignature::Exact`] - via [`DataType`] parameters + /// - [`TypeSignature::Coercible`] - via [`Coercion`] parameters + /// - [`TypeSignature::OneOf`] - when multiple variants are provided + /// + /// For other signature types (e.g., [`TypeSignature::Variadic`], [`TypeSignature::Uniform`], + /// [`TypeSignature::Numeric`], [`TypeSignature::String`], [`TypeSignature::Comparable`], + /// [`TypeSignature::Any`], [`TypeSignature::ArraySignature`], [`TypeSignature::UserDefined`]), + /// use the corresponding constructor methods like [`Signature::variadic`], [`Signature::uniform`], + /// [`Signature::numeric`], etc. + /// + /// # Errors + /// Returns an error if: + /// - No variants are provided + /// - Parameter names are invalid + /// - Type kinds are inconsistent within a variant + pub fn from_parameter_variants( + variants: Vec>, + volatility: Volatility, + ) -> Result + where + N: AsRef, + P: Clone + Into, + { + if variants.is_empty() { + return plan_err!("At least one variant must be provided"); + } + + let parameter_names = Self::extract_parameter_names(&variants); + let type_signatures = Self::build_type_signatures(&variants)?; + let type_signature = Self::consolidate_signatures(type_signatures)?; + + let mut sig = Self::new(type_signature, volatility); + sig.parameter_names = Some(parameter_names); + Ok(sig) + } + + /// Extract parameter names from the longest variant + fn extract_parameter_names(variants: &[Vec<(N, P)>]) -> Vec + where + N: AsRef, + { + variants + .iter() + .max_by_key(|v| v.len()) + .expect("variants is non-empty") + .iter() + .map(|(name, _)| name.as_ref().to_string()) + .collect() + } + + /// Build TypeSignature for each variant + fn build_type_signatures(variants: &[Vec<(N, P)>]) -> Result> + where + P: Clone + Into, + { + variants + .iter() + .map(|params| Self::build_variant_signature(params)) + .collect() + } + + /// Build a TypeSignature for a single variant + fn build_variant_signature(params: &[(N, P)]) -> Result + where + P: Clone + Into, + { + if params.is_empty() { + return Ok(TypeSignature::Nullary); + } + + match params[0].1.clone().into() { + ParameterKind::DataType(_) => Self::build_exact_signature(params), + ParameterKind::Coercion(_) => Self::build_coercible_signature(params), + } + } + + /// Build an Exact TypeSignature from DataType parameters + fn build_exact_signature(params: &[(N, P)]) -> Result + where + P: Clone + Into, + { + let types: Result> = params + .iter() + .map(|(_, p)| match p.clone().into() { + ParameterKind::DataType(dt) => Ok(dt), + ParameterKind::Coercion(_) => { + plan_err!("Cannot mix DataType and Coercion in same variant") + } + }) + .collect(); + Ok(TypeSignature::Exact(types?)) + } + + /// Build a Coercible TypeSignature from Coercion parameters + fn build_coercible_signature(params: &[(N, P)]) -> Result + where + P: Clone + Into, + { + let coercions: Result> = params + .iter() + .map(|(_, p)| match p.clone().into() { + ParameterKind::Coercion(c) => Ok(c), + ParameterKind::DataType(_) => { + plan_err!("Cannot mix DataType and Coercion in same variant") + } + }) + .collect(); + Ok(TypeSignature::Coercible(coercions?)) + } + + /// Consolidate multiple TypeSignatures into a single one (or OneOf) + fn consolidate_signatures( + mut signatures: Vec, + ) -> Result { + match signatures.len() { + 0 => internal_err!("No type signatures provided"), + 1 => Ok(signatures.pop().unwrap()), + _ => Ok(TypeSignature::OneOf(signatures)), + } + } } #[cfg(test)] @@ -1965,4 +2144,222 @@ mod tests { let sig = TypeSignature::UserDefined; assert_eq!(sig.arity(), Arity::Variable); } + + #[test] + fn test_signature_from_parameter_variants_single_variant() { + // Test with a single variant (Exact signature) + let sig = Signature::from_parameter_variants( + vec![vec![("count", DataType::Int32), ("name", DataType::Utf8)]], + Volatility::Immutable, + ) + .unwrap(); + + assert_eq!( + sig.type_signature, + TypeSignature::Exact(vec![DataType::Int32, DataType::Utf8]) + ); + assert_eq!( + sig.parameter_names, + Some(vec!["count".to_string(), "name".to_string()]) + ); + assert_eq!(sig.volatility, Volatility::Immutable); + } + + #[test] + fn test_signature_from_parameter_variants_two_variants() { + // Test with two variants creating a OneOf signature + let sig = Signature::from_parameter_variants( + vec![ + vec![("str", DataType::Utf8), ("pos", DataType::Int64)], + vec![ + ("str", DataType::Utf8), + ("pos", DataType::Int64), + ("len", DataType::Int64), + ], + ], + Volatility::Immutable, + ) + .unwrap(); + + // Should create a OneOf signature with parameter names from longest variant + match &sig.type_signature { + TypeSignature::OneOf(sigs) => { + assert_eq!(sigs.len(), 2); + assert_eq!( + sigs[0], + TypeSignature::Exact(vec![DataType::Utf8, DataType::Int64]) + ); + assert_eq!( + sigs[1], + TypeSignature::Exact(vec![ + DataType::Utf8, + DataType::Int64, + DataType::Int64 + ]) + ); + } + other => panic!("Expected OneOf, got {:?}", other), + } + + // Names should come from the longest variant + assert_eq!( + sig.parameter_names, + Some(vec![ + "str".to_string(), + "pos".to_string(), + "len".to_string() + ]) + ); + } + + #[test] + fn test_signature_from_parameter_variants_with_nullary() { + // Test with a Nullary (no arguments) variant + let sig = Signature::from_parameter_variants( + vec![vec![], vec![("flag", DataType::Boolean)]], + Volatility::Stable, + ) + .unwrap(); + + // Should create a OneOf with Nullary and single-arg signatures + match &sig.type_signature { + TypeSignature::OneOf(sigs) => { + assert_eq!(sigs.len(), 2); + assert_eq!(sigs[0], TypeSignature::Nullary); + assert_eq!(sigs[1], TypeSignature::Exact(vec![DataType::Boolean])); + } + other => panic!("Expected OneOf, got {:?}", other), + } + + // Names should come from the longest variant + assert_eq!(sig.parameter_names, Some(vec!["flag".to_string()])); + } + + #[test] + fn test_signature_from_parameter_variants_empty_error() { + // Test that an empty variant list returns an error + let result = Signature::from_parameter_variants::<&str, DataType>( + vec![], + Volatility::Immutable, + ); + + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("At least one variant must be provided") + ); + } + + #[test] + fn test_signature_from_parameter_variants_with_coercions() { + // Test with Coercion-based variants for Coercible signatures + let string_coercion = + Coercion::new_exact(TypeSignatureClass::Native(logical_string())); + let int64_coercion = + Coercion::new_exact(TypeSignatureClass::Native(logical_int64())); + + let sig = Signature::from_parameter_variants( + vec![ + vec![ + ("str", string_coercion.clone()), + ("pos", int64_coercion.clone()), + ], + vec![ + ("str", string_coercion.clone()), + ("pos", int64_coercion.clone()), + ("len", int64_coercion), + ], + ], + Volatility::Immutable, + ) + .unwrap(); + + // Should create OneOf of Coercible signatures + match &sig.type_signature { + TypeSignature::OneOf(sigs) => { + assert_eq!(sigs.len(), 2); + // First variant has 2 coercions + match &sigs[0] { + TypeSignature::Coercible(coercions) => assert_eq!(coercions.len(), 2), + other => panic!("Expected Coercible, got {:?}", other), + } + // Second variant has 3 coercions + match &sigs[1] { + TypeSignature::Coercible(coercions) => assert_eq!(coercions.len(), 3), + other => panic!("Expected Coercible, got {:?}", other), + } + } + other => panic!("Expected OneOf, got {:?}", other), + } + + // Parameter names from longest variant + assert_eq!( + sig.parameter_names, + Some(vec![ + "str".to_string(), + "pos".to_string(), + "len".to_string() + ]) + ); + } + + #[test] + fn test_signature_from_parameter_variants_mixed_volatility() { + // Test that volatility is set correctly + let volatile_sig = Signature::from_parameter_variants( + vec![vec![("x", DataType::Float64)]], + Volatility::Volatile, + ) + .unwrap(); + + assert_eq!(volatile_sig.volatility, Volatility::Volatile); + + let stable_sig = Signature::from_parameter_variants( + vec![vec![("y", DataType::Int32)]], + Volatility::Stable, + ) + .unwrap(); + + assert_eq!(stable_sig.volatility, Volatility::Stable); + } + + #[test] + fn test_signature_from_parameter_variants_mixed_types_error() { + // Test that mixing DataType and Coercion in same variant returns error + let coercion = Coercion::new_exact(TypeSignatureClass::Native(logical_string())); + + let result = Signature::from_parameter_variants( + vec![vec![ + ("str", ParameterKind::DataType(DataType::Utf8)), + ("pos", ParameterKind::Coercion(coercion)), + ]], + Volatility::Immutable, + ); + + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Cannot mix DataType and Coercion") + ); + } + + #[test] + fn test_signature_from_parameter_variants_single_variant_single_param() { + // Test with a single parameter in a single variant + let sig = Signature::from_parameter_variants( + vec![vec![("value", DataType::Float32)]], + Volatility::Immutable, + ) + .unwrap(); + + assert_eq!( + sig.type_signature, + TypeSignature::Exact(vec![DataType::Float32]) + ); + assert_eq!(sig.parameter_names, Some(vec!["value".to_string()])); + } } diff --git a/datafusion/functions/src/unicode/substr.rs b/datafusion/functions/src/unicode/substr.rs index cc1d53b3aad67..930d5a7503ead 100644 --- a/datafusion/functions/src/unicode/substr.rs +++ b/datafusion/functions/src/unicode/substr.rs @@ -32,8 +32,8 @@ use datafusion_common::types::{ }; use datafusion_common::{Result, exec_err}; use datafusion_expr::{ - Coercion, ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature, - TypeSignatureClass, Volatility, + Coercion, ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignatureClass, + Volatility, }; use datafusion_macros::user_doc; @@ -80,24 +80,20 @@ impl SubstrFunc { vec![TypeSignatureClass::Native(logical_int32())], NativeType::Int64, ); + Self { - signature: Signature::one_of( + signature: Signature::from_parameter_variants( vec![ - TypeSignature::Coercible(vec![string.clone(), int64.clone()]), - TypeSignature::Coercible(vec![ - string.clone(), - int64.clone(), - int64.clone(), - ]), + vec![("str", string.clone()), ("start_pos", int64.clone())], + vec![ + ("str", string.clone()), + ("start_pos", int64.clone()), + ("length", int64.clone()), + ], ], Volatility::Immutable, ) - .with_parameter_names(vec![ - "str".to_string(), - "start_pos".to_string(), - "length".to_string(), - ]) - .expect("valid parameter names"), + .expect("valid parameter variants"), aliases: vec![String::from("substring")], } }