diff --git a/lib/op-attrs/ffi/include/flexflow/op-attrs.h b/lib/op-attrs/ffi/include/flexflow/op-attrs.h index 08b8e26f83..991fc9fd88 100644 --- a/lib/op-attrs/ffi/include/flexflow/op-attrs.h +++ b/lib/op-attrs/ffi/include/flexflow/op-attrs.h @@ -13,7 +13,8 @@ typedef enum { FLEXFLOW_OPATTRS_ERROR_CODE_INVALID_ACTIVATION_VALUE, FLEXFLOW_OPATTRS_ERROR_CODE_INVALID_POOL_OP_VALUE, FLEXFLOW_OPATTRS_ERROR_CODE_INVALID_AGGREGATE_OP_VALUE, - FLEXFLOW_OPATTRS_ERROR_CODE_INVALID_OP_TYPE_VALUE + FLEXFLOW_OPATTRS_ERROR_CODE_INVALID_OP_TYPE_VALUE, + FLEXFLOW_OPATTRS_ERROR_CODE_INVALID_LOSS_FUNCTION_VALUE, } flexflow_opattrs_error_code_t; FF_NEW_OPAQUE_TYPE(flexflow_opattrs_error_t); @@ -23,11 +24,64 @@ flexflow_error_t flexflow_opattrs_error_unwrap(flexflow_error_t, flexflow_error_t flexflow_opattrs_error_is_ok(flexflow_opattrs_error_t, bool *); flexflow_error_t flexflow_opattrs_error_get_string(flexflow_opattrs_error_t, char **); +flexflow_error_t + flexflow_opattrs_error_get_error_code(flexflow_opattrs_error_t, + flexflow_opattrs_error_code_t *); flexflow_error_t flexflow_opattrs_error_destroy(flexflow_opattrs_error_t); // FF_NEW_OPAQUE_TYPE(flexflow_regularizer_attrs_t); +FF_NEW_OPAQUE_TYPE(flexflow_dim_ordered_t); +FF_NEW_OPAQUE_TYPE(flexflow_ff_dim_t); +FF_NEW_OPAQUE_TYPE(flexflow_parallel_dim_t); +FF_NEW_OPAQUE_TYPE(flexflow_parallel_tensor_dims_t); +FF_NEW_OPAQUE_TYPE(flexflow_parallel_tensor_shape_t); +FF_NEW_OPAQUE_TYPE(flexflow_tensor_shape_t); +FF_NEW_OPAQUE_TYPE( + flexflow_parallel_tesor_shape_list_t); // std::vector +FF_NEW_OPAQUE_TYPE(flexflow_tensor_shape_list_t); // std::vector + +// ops +FF_NEW_OPAQUE_TYPE(flexflow_aggregate_specattrs_t); +FF_NEW_OPAQUE_TYPE(flexflow_aggregate_attrs_t); +FF_NEW_OPAQUE_TYPE(flexflow_multihead_attention_attrs_t); +FF_NEW_OPAQUE_TYPE(flexflow_multihead_attention_inputs_parallel_tensor_shape_t); +FF_NEW_OPAQUE_TYPE(flexflow_multihead_attention_inputs_tensor_shape_t); +FF_NEW_OPAQUE_TYPE(flexflow_batchmatmul_attrs_t); 
+FF_NEW_OPAQUE_TYPE(flexflow_batchnorm_attrs_t); +FF_NEW_OPAQUE_TYPE(flexflow_broadcast_attrs_t); +FF_NEW_OPAQUE_TYPE(flexflow_cast_attrs_t); +FF_NEW_OPAQUE_TYPE(flexflow_combine_attrs_t); +FF_NEW_OPAQUE_TYPE(flexflow_concat_attrs_t); +FF_NEW_OPAQUE_TYPE(flexflow_conv2d_attrs_t); +FF_NEW_OPAQUE_TYPE(flexflow_dropout_attrs_t); +FF_NEW_OPAQUE_TYPE(flexflow_element_sclar_unary_attrs_t); +FF_NEW_OPAQUE_TYPE(flexflow_element_unary_attrs_t); +FF_NEW_OPAQUE_TYPE(flexflow_embedding_attrs_t); +FF_NEW_OPAQUE_TYPE(flexflow_flat_attrs_t); +FF_NEW_OPAQUE_TYPE(flexflow_gather_attrs_t); +FF_NEW_OPAQUE_TYPE(flexflow_group_by_attrs_t); +FF_NEW_OPAQUE_TYPE(flexflow_input_attrs_t); +FF_NEW_OPAQUE_TYPE(flexflow_layernorm_attrs_t); +FF_NEW_OPAQUE_TYPE(flexflow_l1_regularizer_attrs_t); +FF_NEW_OPAQUE_TYPE(flexflow_l2_regularizer_attrs_t); +FF_NEW_OPAQUE_TYPE(flexflow_linear_attrs_t); +FF_NEW_OPAQUE_TYPE(flexflow_sparse_categorical_crossentropy_loss_attrs_t); +FF_NEW_OPAQUE_TYPE(flexflow_other_loss_attrs_t); +FF_NEW_OPAQUE_TYPE(flexflow_loss_attrs_t); +FF_NEW_OPAQUE_TYPE(flexflow_noop_attrs_t); +FF_NEW_OPAQUE_TYPE(flexflow_pool2d_attrs_t); +FF_NEW_OPAQUE_TYPE(flexflow_reduce_attrs_t); +FF_NEW_OPAQUE_TYPE(flexflow_reduction_attrs_t); +FF_NEW_OPAQUE_TYPE(flexflow_repartition_attrs_t); +FF_NEW_OPAQUE_TYPE(flexflow_replicate_attrs_t); +FF_NEW_OPAQUE_TYPE(flexflow_reshape_attrs_t); +FF_NEW_OPAQUE_TYPE(flexflow_reverse_attrs_t); +FF_NEW_OPAQUE_TYPE(flexflow_softmax_attrs_t); +FF_NEW_OPAQUE_TYPE(flexflow_split_attrs_t); +FF_NEW_OPAQUE_TYPE(flexflow_topk_attrs_t); +FF_NEW_OPAQUE_TYPE(flexflow_transpose_attrs_t); typedef enum { FLEXFLOW_DATATYPE_BOOL, @@ -62,6 +116,14 @@ typedef enum { FLEXFLOW_AGGREGATE_OP_AVG, } flexflow_aggregate_op_t; +typedef enum { + FLEXFLOW_LOSS_FUNCTION_CATEGORICAL_CROSSENTROPY, + FLEXFLOW_LOSS_FUNCTION_SPARSE_CATEGORICAL_CROSSENTROPY, + FLEXFLOW_LOSS_FUNCTION_MEAN_SQUARED_ERROR_AVG_REDUCE, + FLEXFLOW_LOSS_FUNCTION_MEAN_SQUARED_ERROR_SUM_REDUCE, + 
FLEXFLOW_LOSS_FUNCTION_IDENTITY, +} flexflow_loss_function_t; + typedef enum { // does _not_ have to stay synchronized with op-attrs/op.h FLEXFLOW_OP_TYPE_NOOP, FLEXFLOW_OP_TYPE_INPUT, @@ -155,7 +217,9 @@ typedef enum { // does _not_ have to stay synchronized with op-attrs/op.h typedef struct { flexflow_op_type_t op_type; void *data; -} flexflow_operator_attrs_t; +} flexflow_operator_attrs; + +FF_NEW_OPAQUE_TYPE(flexflow_operator_attrs_t); flexflow_opattrs_error_t flexflow_get_datatype_size(flexflow_datatype_t, int *out); diff --git a/lib/op-attrs/ffi/internal/internal/op-attrs.h b/lib/op-attrs/ffi/internal/internal/op-attrs.h index df0d6ce61c..b2ac652d04 100644 --- a/lib/op-attrs/ffi/internal/internal/op-attrs.h +++ b/lib/op-attrs/ffi/internal/internal/op-attrs.h @@ -2,18 +2,253 @@ #define _FLEXFLOW_OPATTRS_FFI_INTERNAL_INTERNAL_OPATTRS_H #include "flexflow/op-attrs.h" +#include "flexflow/utils.h" #include "internal/opaque.h" #include "op-attrs/activation.h" #include "op-attrs/datatype.h" +#include "op-attrs/dim_ordered.h" +#include "op-attrs/ff_dim.h" +#include "op-attrs/get_op_type.h" +#include "op-attrs/get_output_shapes.h" #include "op-attrs/op.h" +#include "op-attrs/ops/aggreagate.h" +#include "op-attrs/ops/aggregate_spec.h" +#include "op-attrs/ops/attention.h" +#include "op-attrs/ops/batch_matmul.h" +#include "op-attrs/ops/batch_norm.h" +#include "op-attrs/ops/broadcast.h" +#include "op-attrs/ops/cast.h" +#include "op-attrs/ops/combine.h" +#include "op-attrs/ops/concat.h" +#include "op-attrs/ops/conv2d.h" +#include "op-attrs/ops/dropout.h" +#include "op-attrs/ops/element_binary.h" +#include "op-attrs/ops/element_unary.h" #include "op-attrs/ops/embedding.h" +#include "op-attrs/ops/flat.h" +#include "op-attrs/ops/gather.h" +#include "op-attrs/ops/group_by.h" +#include "op-attrs/ops/input.h" +#include "op-attrs/ops/layer_norm.h" #include "op-attrs/ops/linear.h" +#include "op-attrs/ops/loss_function.h" +#include "op-attrs/ops/noop.h" #include 
"op-attrs/ops/pool_2d.h" +#include "op-attrs/ops/reduce.h" +#include "op-attrs/ops/reduction.h" +#include "op-attrs/ops/repartition.h" +#include "op-attrs/ops/replicate.h" +#include "op-attrs/ops/reshape.h" +#include "op-attrs/ops/reverse.h" +#include "op-attrs/ops/softmax.h" +#include "op-attrs/ops/split.h" +#include "op-attrs/ops/topk.h" +#include "op-attrs/ops/transpose.h" +#include "op-attrs/parallel_dim.h" +#include "op-attrs/parallel_tensor_dims.h" +#include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/param_sync.h" using namespace FlexFlow; REGISTER_OPAQUE(flexflow_regularizer_attrs_t, optional); +REGISTER_OPAQUE(flexflow_ff_dim_t, ff_dim_t); +// REGISTER_OPAQUE(flexflow_dim_ordered_t, DimOrdered); Note:how to define +// DimOrdered +REGISTER_OPAQUE(flexflow_parallel_dim_t, ParallelDim); +REGISTER_OPAQUE(flexflow_parallel_tensor_dims_t, ParallelTensorDims); +REGISTER_OPAQUE(flexflow_parallel_tensor_shape_t, ParallelTensorShape); +REGISTER_OPAQUE(flexflow_tensor_shape_t, TensorShape); +REGISTER_OPAQUE(flexflow_parallel_tesor_shape_list_t, + std::vector) +REGISTER_OPAQUE(flexflow_tensor_shape_list_t, std::vector) + +// ops +REGISTER_OPAQUE(flexflow_aggregate_specattrs_t, AggregateSpecAttrs); +REGISTER_OPAQUE(flexflow_aggregate_attrs_t, AggregateAttrs); +REGISTER_OPAQUE(flexflow_multihead_attention_attrs_t, MultiHeadAttentionAttrs); +REGISTER_OPAQUE(flexflow_multihead_attention_inputs_parallel_tensor_shape_t, + MultiHeadAttentionInputs); +REGISTER_OPAQUE(flexflow_multihead_attention_inputs_tensor_shape_t, + MultiHeadAttentionInputs); +REGISTER_OPAQUE(flexflow_batchmatmul_attrs_t, BatchMatmulAttrs); +REGISTER_OPAQUE(flexflow_batchnorm_attrs_t, BatchNormAttrs); +REGISTER_OPAQUE(flexflow_broadcast_attrs_t, BroadcastAttrs); +REGISTER_OPAQUE(flexflow_cast_attrs_t, CastAttrs); +REGISTER_OPAQUE(flexflow_combine_attrs_t, CombineAttrs); +REGISTER_OPAQUE(flexflow_concat_attrs_t, ConcatAttrs); +REGISTER_OPAQUE(flexflow_conv2d_attrs_t, Conv2DAttrs); 
+REGISTER_OPAQUE(flexflow_dropout_attrs_t, DropoutAttrs); +REGISTER_OPAQUE(flexflow_element_sclar_unary_attrs_t, ElementScalarUnaryAttrs); +REGISTER_OPAQUE(flexflow_element_unary_attrs_t, ElementUnaryAttrs); +REGISTER_OPAQUE(flexflow_embedding_attrs_t, EmbeddingAttrs); +REGISTER_OPAQUE(flexflow_flat_attrs_t, FlatAttrs); +REGISTER_OPAQUE(flexflow_gather_attrs_t, GatherAttrs); +REGISTER_OPAQUE(flexflow_group_by_attrs_t, GroupByAttrs); +REGISTER_OPAQUE(flexflow_input_attrs_t, InputAttrs); +REGISTER_OPAQUE(flexflow_layernorm_attrs_t, LayerNormAttrs); +REGISTER_OPAQUE(flexflow_l1_regularizer_attrs_t, L1RegularizerAttrs); +REGISTER_OPAQUE(flexflow_l2_regularizer_attrs_t, L2RegularizerAttrs); +REGISTER_OPAQUE(flexflow_linear_attrs_t, LinearAttrs); +REGISTER_OPAQUE(flexflow_sparse_categorical_crossentropy_loss_attrs_t, + SparseCategoricalCrossEntropyLossAttrs); +REGISTER_OPAQUE(flexflow_other_loss_attrs_t, OtherLossAttrs); +REGISTER_OPAQUE(flexflow_loss_attrs_t, LossAttrs); +REGISTER_OPAQUE(flexflow_noop_attrs_t, NoopAttrs); +REGISTER_OPAQUE(flexflow_pool2d_attrs_t, Pool2DAttrs); +REGISTER_OPAQUE(flexflow_reduce_attrs_t, ReduceAttrs); +REGISTER_OPAQUE(flexflow_reduction_attrs_t, ReductionAttrs); +REGISTER_OPAQUE(flexflow_repartition_attrs_t, RepartitionAttrs); +REGISTER_OPAQUE(flexflow_replicate_attrs_t, ReplicateAttrs); +REGISTER_OPAQUE(flexflow_reshape_attrs_t, ReshapeAttrs); +REGISTER_OPAQUE(flexflow_reverse_attrs_t, ReverseAttrs); +REGISTER_OPAQUE(flexflow_softmax_attrs_t, SoftmaxAttrs); +REGISTER_OPAQUE(flexflow_split_attrs_t, SplitAttrs); +REGISTER_OPAQUE(flexflow_topk_attrs_t, TopKAttrs); + +REGISTER_OPAQUE(flexflow_operator_attrs_t, flexflow_operator_attrs); + +flexflow_error_t + flexflow_get_output_shape(flexflow_aggregate_specattrs_t, + flexflow_parallel_tensor_shape_t, + flexflow_parallel_tensor_shape_t *out, + flexflow_parallel_tensor_shape_t, + flexflow_parallel_tensor_shape_t, + flexflow_parallel_tensor_shape_t, + flexflow_parallel_tensor_shape_t *, + int 
num_exp_preds); +flexflow_error_t flexflow_is_valid(flexflow_aggregate_attrs_t, + flexflow_parallel_tensor_shape_t, + bool *out, + flexflow_parallel_tensor_shape_t, + flexflow_parallel_tensor_shape_t, + flexflow_parallel_tensor_shape_t, + flexflow_parallel_tensor_shape_t *, + int num_exp_preds); + +flexflow_error_t + flexflow_get_output_shape(flexflow_aggregate_attrs_t, + flexflow_parallel_tensor_shape_t, + flexflow_parallel_tensor_shape_t *out, + flexflow_parallel_tensor_shape_t, + flexflow_parallel_tensor_shape_t, + flexflow_parallel_tensor_shape_t, + flexflow_parallel_tensor_shape_t *, + int num_exp_preds); + +flexflow_error_t flexflow_get_kProjSize(flexflow_multihead_attention_attrs_t, + int *out); +flexflow_error_t flexflow_get_vProjSize(flexflow_multihead_attention_attrs_t, + int *out); +flexflow_error_t flexflow_get_kProjSize(flexflow_multihead_attention_attrs_t, + int *out); +flexflow_error_t flexflow_get_oProjSize(flexflow_multihead_attention_attrs_t, + int *out); + +flexflow_error_t flexflow_get_qSize( + flexflow_multihead_attention_inputs_parallel_tensor_shape_t, int *out); +flexflow_error_t flexflow_get_kSize( + flexflow_multihead_attention_inputs_parallel_tensor_shape_t, int *out); +flexflow_error_t flexflow_get_vSize( + flexflow_multihead_attention_inputs_parallel_tensor_shape_t, int *out); +flexflow_error_t flexflow_get_oSize( + flexflow_multihead_attention_inputs_parallel_tensor_shape_t, int *out); + +flexflow_error_t flexflow_get_qoSeqLength( + flexflow_multihead_attention_inputs_parallel_tensor_shape_t, int *out); +flexflow_error_t flexflow_get_kvSeqLength( + flexflow_multihead_attention_inputs_parallel_tensor_shape_t, int *out); + +flexflow_error_t flexflow_get_num_samples( + flexflow_multihead_attention_inputs_parallel_tensor_shape_t, int *out); + +flexflow_error_t flexflow_get_weights_shape( + flexflow_multihead_attention_attrs_t, + flexflow_multihead_attention_inputs_tensor_shape_t, + flexflow_tensor_shape_t *out); + +flexflow_error_t 
flexflow_get_weights_shape( + flexflow_multihead_attention_attrs_t, + flexflow_multihead_attention_inputs_parallel_tensor_shape_t, + flexflow_parallel_tensor_shape_t *out); + +flexflow_error_t flexflow_get_output_shape( + flexflow_multihead_attention_attrs_t, + flexflow_multihead_attention_inputs_tensor_shape_t, + flexflow_parallel_tensor_shape_t *out); + +flexflow_error_t flexflow_get_output_shape( + flexflow_multihead_attention_attrs_t, + flexflow_multihead_attention_inputs_tensor_shape_t, + flexflow_tensor_shape_t *out); + +flexflow_error_t + flexflow_get_output_shape(flexflow_batchnorm_attrs_t, + flexflow_parallel_tensor_shape_t *out); + +flexflow_error_t flexflow_get_kernel_shape(flexflow_conv2d_attrs_t, + flexflow_tensor_shape_t *out, + flexflow_tensor_shape_t); + +flexflow_error_t flexflow_get_bias_shape(flexflow_conv2d_attrs_t, + flexflow_tensor_shape_t *out, + flexflow_tensor_shape_t); + +flexflow_error_t flexflow_get_weights_shape(flexflow_embedding_attrs_t, + flexflow_tensor_shape_t *out, + flexflow_tensor_shape_t); + +flexflow_error_t + flexflow_parse_loss_function_name(char **, flexflow_loss_function_t *out); + +flexflow_error_t flexflow_get_loss_function(flexflow_other_loss_attrs_t, + flexflow_loss_function_t *out); + +flexflow_error_t flexflow_get_loss_function( + flexflow_sparse_categorical_crossentropy_loss_attrs_t, + flexflow_loss_function_t *out); + +flexflow_error_t flexflow_get_loss_function(flexflow_loss_attrs_t, + flexflow_loss_function_t *out); + +// TODO(lambda): how to define inner_to_outer_idxs, outer_to_inner_idxs, how +// to define DimOrdered outer_to_inner(op-attrs/include/op-attrs/dim_ordered.h) + +// Note(lambda): have to define all +// get_output_shape(op-attrs/include/op-attrs/get_output_shapes.h)? 
+ +flexflow_error_t flexflow_is_valid(flexflow_parallel_dim_t, bool *out); + +flexflow_error_t flexflow_is_replica_dim(flexflow_parallel_dim_t, bool *out); + +flexflow_error_t flexflow_is_valid(flexflow_parallel_tensor_dims_t, bool *out); + +flexflow_error_t flexflow_get_piece_dims(flexflow_parallel_tensor_dims_t, + flexflow_tensor_dims_t *out); + +flexflow_error_t + flexflow_get_tensor_dims_unsafe(flexflow_parallel_tensor_dims_t, + flexflow_tensor_dims_t *out); + +flexflow_error_t flexflow_get_piece_shape(flexflow_parallel_tensor_shape_t, + flexflow_tensor_shape_t *out); + +flexflow_error_t + flexflow_get_num_replica_dims(flexflow_parallel_tensor_shape_t, int *out); + +flexflow_error_t flexflow_get_num_replicas(flexflow_parallel_tensor_shape_t, + int *out); + +flexflow_error_t flexflow_is_valid(flexflow_parallel_tensor_shape_t, bool *out); + +flexflow_error_t + flexflow_get_tensor_shape_unsafe(flexflow_parallel_tensor_shape_t, + flexflow_tensor_shape_t *out); + +flexflow_error_t + flexflow_get_tensor_shape_unsafe(flexflow_parallel_tensor_shape_t *input, + int num_input, + flexflow_tensor_shape_list_t *out); optional to_internal(flexflow_param_sync_t); flexflow_param_sync_t to_external(optional); diff --git a/lib/op-attrs/ffi/src/op-attrs.cc b/lib/op-attrs/ffi/src/op-attrs.cc index 828574dc25..c28c4b624e 100644 --- a/lib/op-attrs/ffi/src/op-attrs.cc +++ b/lib/op-attrs/ffi/src/op-attrs.cc @@ -1,14 +1,77 @@ #include "flexflow/op-attrs.h" +#include "flexflow/utils.h" #include "internal/enums.h" #include "internal/error.h" #include "internal/op-attrs.h" #include "op-attrs/op.h" #include "op-attrs/ops/embedding.h" +#include "op-attrs/ops/loss_functions.h" #include "utils/bidict.h" +#include "utils/exception.h" flexflow_utils_exception_t make_opattrs_exception(flexflow_opattrs_error_code_t); +flexflow_error_t flexflow_opattrs_error_wrap(flexflow_opattrs_error_t e) { + return flexflow_error_wrap(FLEXFLOW_ERROR_SOURCE_OPATTRS, *unwrap_opaque(e)); +} + +flexflow_error_t 
flexflow_opattrs_error_unwrap(flexflow_error_t err, + flexflow_opattrs_error_t *out) { + return flexflow_error_unwrap(err, FLEXFLOW_ERROR_SOURCE_OPATTRS, out); +} + +flexflow_error_t flexflow_opattrs_error_is_ok(flexflow_opattrs_error_t err, + bool *out) { + *out = false; + return status_ok(); +} + +flexflow_error_t flexflow_opattrs_error_get_string(flexflow_opattrs_error_t err, + char **m_out) { + flexflow_opattrs_error_code_t err_code; + flexflow_opattrs_error_get_error_code(err, &err_code); + char **out = m_out; + switch (err_code) { + case FLEXFLOW_OPATTRS_ERROR_CODE_INVALID_PARAM_SYNC_VALUE: + *out = strdup("Invalid param sync value"); + break; + case FLEXFLOW_OPATTRS_ERROR_CODE_INVALID_DATATYPE_VALUE: + *out = strdup("Invalid datatype value"); + break; + case FLEXFLOW_OPATTRS_ERROR_CODE_INVALID_ACTIVATION_VALUE: + *out = strdup("Invalid activation value"); + break; + case FLEXFLOW_OPATTRS_ERROR_CODE_INVALID_POOL_OP_VALUE: + *out = strdup("Invalid pool op value"); + break; + case FLEXFLOW_OPATTRS_ERROR_CODE_INVALID_AGGREGATE_OP_VALUE: + *out = strdup("Invalid aggregate op value"); + break; + case FLEXFLOW_OPATTRS_ERROR_CODE_INVALID_OP_TYPE_VALUE: + *out = strdup("Invalid op type value"); + break; + case FLEXFLOW_OPATTRS_ERROR_CODE_INVALID_LOSS_FUNCTION_VALUE: + *out = strdup("Invalid loss function value"); + break; + default: + *out = strdup("Unknown error"); + } + return status_ok(); +} + +flexflow_error_t + flexflow_opattrs_error_get_error_code(flexflow_opattrs_error_t err, + flexflow_opattrs_error_code_t *out) { + internal_flexflow_opattrs_error_t const *unwrapped = unwrap_opaque(err); + *out = unwrapped->code; + return status_ok(); +} + +flexflow_error_t flexflow_opattrs_error_destroy(flexflow_opattrs_error_t err) { + return status_ok(); // Note(lambda): this follows + // https://github.com/lockshaw/FlexFlow/blob/expanded-ffi/lib/pcg/ffi/src/pcg.cc#L71-#L72 +} + REGISTER_FFI_ENUM(flexflow_param_sync_t, ParamSync, 
FLEXFLOW_OPATTRS_ERROR_CODE_INVALID_PARAM_SYNC_VALUE, @@ -46,6 +109,18 @@ REGISTER_FFI_ENUM(flexflow_aggregate_op_t, {{FLEXFLOW_AGGREGATE_OP_SUM, AggregateOp::SUM}, {FLEXFLOW_AGGREGATE_OP_AVG, AggregateOp::AVG}}); +REGISTER_FFI_ENUM(flexflow_loss_function_t, + LossFunction, + FLEXFLOW_OPATTRS_ERROR_CODE_INVALID_LOSS_FUNCTION_VALUE, + {{FLEXFLOW_LOSS_FUNCTION_CATEGORICAL_CROSSENTROPY, + LossFunction::CATEGORICAL_CROSSENTROPY}, + {FLEXFLOW_LOSS_FUNCTION_SPARSE_CATEGORICAL_CROSSENTROPY, + LossFunction::SPARSE_CATEGORICAL_CROSSENTROPY}, + {FLEXFLOW_LOSS_FUNCTION_MEAN_SQUARED_ERROR_AVG_REDUCE, + LossFunction::MEAN_SQUARED_ERROR_AVG_REDUCE}, + {FLEXFLOW_LOSS_FUNCTION_MEAN_SQUARED_ERROR_SUM_REDUCE, + LossFunction::MEAN_SQUARED_ERROR_SUM_REDUCE}, + {FLEXFLOW_LOSS_FUNCTION_IDENTITY, LossFunction::IDENTITY}}); + REGISTER_FFI_ENUM(flexflow_op_type_t, OperatorType, FLEXFLOW_OPATTRS_ERROR_CODE_INVALID_OP_TYPE_VALUE, @@ -141,9 +216,344 @@ REGISTER_FFI_ENUM(flexflow_op_type_t, flexflow_error_t make_opattrs_error(flexflow_opattrs_error_code_t); +flexflow_error_t flexflow_get_output_shape( + flexflow_aggregate_specattrs_t aggregate_spec_attrs, + flexflow_parallel_tensor_shape_t gate_preds, + flexflow_parallel_tensor_shape_t *out, + flexflow_parallel_tensor_shape_t gate_assign, + flexflow_parallel_tensor_shape_t true_gate_assign, + flexflow_parallel_tensor_shape_t gate_gradients_full, + flexflow_parallel_tensor_shape_t *exp_preds, + int num_exp_preds) { + return handle_errors(out, [&] { + return get_output_shape(deref_opaque(aggregate_spec_attrs), + deref_opaque(gate_preds), + deref_opaque(gate_assign), + deref_opaque(true_gate_assign), + deref_opaque(gate_gradients_full), + c_deref_opaque_list(exp_preds, num_exp_preds)); + }); +} + +flexflow_error_t flexflow_is_valid( + flexflow_aggregate_attrs_t aggregate_attrs, + flexflow_parallel_tensor_shape_t gate_preds, + bool *out, + flexflow_parallel_tensor_shape_t gate_assign, + flexflow_parallel_tensor_shape_t true_gate_assign, + flexflow_parallel_tensor_shape_t full_gate_gradients, + flexflow_parallel_tensor_shape_t *exp_preds, int num_exp_preds) 
{ + return handle_errors(out, [&]) { + return is_valid(deref_opaque(aggregate_attrs), + deref_opaque(gate_preds), + deref_opaque(gate_assign), + deref_opaque(true_gate_assign), + deref_opaque(full_gate_gradients), + c_deref_opaque_list(exp_preds, num_exp_preds)); + } +} + +flexflow_error_t flexflow_get_output_shape( + flexflow_aggregate_attrs_t aggregate_attrs, + flexflow_parallel_tensor_shape_t gate_preds, + flexflow_parallel_tensor_shape_t *out, + flexflow_parallel_tensor_shape_t gate_assign, + flexflow_parallel_tensor_shape_t true_gate_assign, + flexflow_parallel_tensor_shape_t full_gate_gradients, + flexflow_parallel_tensor_shape_t *exp_preds, + int num_exp_preds) { + return handle_errors(out, [&]) { + return get_out_shape(deref_opaque(aggregate_attrs), + deref_opaque(gate_preds), + deref_opaque(gate_assign), + deref_opaque(true_gate_assign), + deref_opaque(full_gate_gradients), + c_deref_opaque_list(exp_preds, num_exp_preds)); + } +} + +flexflow_error_t flexflow_get_kProjSize( + flexflow_multihead_attention_attrs_t multi_head_attention_attrs, int *out) { + return handle_errors(out, [&]) { + return get_kProjSize(deref_opaque(multi_head_attention_attrs)); + } +} + +flexflow_error_t flexflow_get_vProjSize( + flexflow_multihead_attention_attrs_t multi_head_attention_attrs, int *out) { + return handle_errors(out, [&]) { + return get_vProjSize(deref_opaque(multi_head_attention_attrs)); + } +} + +flexflow_error_t flexflow_get_kProjSize( + flexflow_multihead_attention_attrs_t multi_head_attention_attrs, int *out) { + return handle_errors(out, [&]) { + return get_kProjSize(deref_opaque(multi_head_attention_attrs)); + } +} + +flexflow_error_t flexflow_get_oProjSize( + flexflow_multihead_attention_attrs_t multi_head_attention_attrs, int *out) { + return handle_errors(out, [&]) { + return get_oProjSize(deref_opaque(multi_head_attention_attrs)); + } +} + +flexflow_error_t flexflow_get_qSize( + flexflow_multihead_attention_inputs_parallel_tensor_shape_t + 
multi_head_attention_inputs, + int *out) { + return handle_errors(out, [&]) { + return get_qSize(deref_opaque(multi_head_attention_inputs)); + } +} + +flexflow_error_t flexflow_get_kSize( + flexflow_multihead_attention_inputs_parallel_tensor_shape_t + multi_head_attention_inputs, + int *out) { + return handle_errors(out, [&]) { + return get_kSize(deref_opaque(multi_head_attention_inputs)); + } +} + +flexflow_error_t flexflow_get_vSize( + flexflow_multihead_attention_inputs_parallel_tensor_shape_t + multi_head_attention_inputs, + int *out) { + return handle_errors(out, [&]) { + return get_vSize(deref_opaque(multi_head_attention_inputs)); + } +} + +flexflow_error_t flexflow_get_oSize( + flexflow_multihead_attention_inputs_parallel_tensor_shape_t + multi_head_attention_inputs, + int *out) { + return handle_errors(out, [&]) { + return get_oSize(deref_opaque(multi_head_attention_inputs)); + } +} + +flexflow_error_t flexflow_get_qoSeqLength( + flexflow_multihead_attention_inputs_parallel_tensor_shape_t + multi_head_attention_inputs, + int *out) { + return handle_errors(out, [&]) { + return get_qoSeqLength(deref_opaque(multi_head_attention_inputs)); + } +} + +flexflow_error_t flexflow_get_kvSeqLength( + flexflow_multihead_attention_inputs_parallel_tensor_shape_t + multi_head_attention_inputs, + int *out) { + return handle_errors(out, [&]) { + return get_kvSeqLength(deref_opaque(multi_head_attention_inputs)); + } +} + +flexflow_error_t flexflow_get_num_samples( + flexflow_multihead_attention_inputs_parallel_tensor_shape_t + multi_head_attention_inputs, + int *out) { + return handle_errors(out, [&]) { + return get_num_samples(deref_opaque(multi_head_attention_inputs)); + } +} + +flexflow_error_t flexflow_get_weights_shape( + flexflow_multihead_attention_attrs_t multi_head_attention_attrs, + flexflow_multihead_attention_inputs_tensor_shape_t + multi_head_attention_inputs, + flexflow_tensor_shape_t *out) { + return handle_errors(out, [&]) { + return 
get_weights_shape(deref_opaque(multi_head_attention_attrs), + deref_opaque(multi_head_attention_inputs)); + } +} + +flexflow_error_t flexflow_get_weights_shape( + flexflow_multihead_attention_attrs_t multi_head_attention_attrs, + flexflow_multihead_attention_inputs_parallel_tensor_shape_t + multi_head_attention_inputs, + flexflow_parallel_tensor_shape_t *out) { + return handle_errors(out, [&]) { + return get_weights_shape(deref_opaque(multi_head_attention_attrs), + deref_opaque(multi_head_attention_inputs)); + } +} + +flexflow_error_t flexflow_get_output_shape( + flexflow_multihead_attention_attrs_t multi_head_attention_attrs, + flexflow_multihead_attention_inputs_tensor_shape_t + multi_head_attention_inputs, + flexflow_parallel_tensor_shape_t *out, ) { + return handle_errors(out, [&]) { + return get_output_shape(deref_opaque(multi_head_attention_attrs), + deref_opaque(multi_head_attention_inputs)); + } +} + +flexflow_error_t flexflow_get_output_shape( + flexflow_multihead_attention_attrs_t multi_head_attention_attrs, + flexflow_multihead_attention_inputs_tensor_shape_t + multi_head_attention_inputs, + flexflow_tensor_shape_t *out) { + return handle_errors(out, [&]) { + return get_output_shape(deref_opaque(multi_head_attention_attrs), + deref_opaque(multi_head_attention_inputs)); + } +} + +flexflow_error_t + flexflow_get_output_shape(flexflow_batchnorm_attrs_t batchnorm_attrs, + flexflow_parallel_tensor_shape_t *out) { + return handle_errors(out, [&]) { + return get_output_shape(deref_opaque(batchnorm_attrs)); + } +} + +flexflow_error_t + flexflow_get_kernel_shape(flexflow_conv2d_attrs_t conv2d_attrs, + flexflow_tensor_shape_t *out, + flexflow_tensor_shape_t input_shape) { + return handle_errors(out, [&]) { + return get_kernel_shape(deref_opaque(conv2d_attrs), + deref_opaque(input_shape)); + } +} + +flexflow_error_t flexflow_get_bias_shape(flexflow_conv2d_attrs_t conv2d_attrs, + flexflow_tensor_shape_t *out, + flexflow_tensor_shape_t input_shape) { + return 
handle_errors(out, [&]) { + return get_bias_shape(deref_opaque(conv2d_attrs), + deref_opaque(input_shape)); + } +} + +flexflow_error_t + flexflow_get_weights_shape(flexflow_embedding_attrs_t embedding_attrs, + flexflow_tensor_shape_t *out, + flexflow_tensor_shape_t input_shape) { + return handle_errors(out, [&]) { + return get_weights_shape(deref_opaque(embedding_attrs), + deref_opaque(input_shape)); + } +} + +flexflow_error_t + flexflow_parse_loss_function_name(char **raw_name, + flexflow_loss_function_t *out) { + NOT_IMPLEMENTED(); // Note(lambda):how to implement the function +} + +flexflow_error_t flexflow_is_valid(flexflow_parallel_dim_t parallel_dim_t, + bool *out) { + return handle_errors(out, [&]) { + return is_valid(deref_opaque(parallel_dim_t)); + } +} + +flexflow_error_t flexflow_is_replica_dim(flexflow_parallel_dim_t parallel_dim_t, + bool *out) { + return handle_errors(out, [&]) { + return is_replica_dim(deref_opaque(parallel_dim_t)); + } +} + +flexflow_error_t + flexflow_is_valid(flexflow_parallel_tensor_dims_t parallel_tensor_dims_t, + bool *out) { + return handle_errors(out, [&]) { + return is_valid(deref_opaque(parallel_tensor_dims_t)); + } +} + +flexflow_error_t flexflow_get_piece_dims( + flexflow_parallel_tensor_dims_t parallel_tensor_dims_t, + flexflow_tensor_dims_t *out) { + return handle_errors(out, [&]) { + return get_piece_dims(deref_opaque(parallel_tensor_dims_t)); + } +} + +flexflow_error_t flexflow_get_tensor_dims_unsafe( + flexflow_parallel_tensor_dims_t tensor_dims_t, + flexflow_tensor_dims_t *out) { + return handle_errors(out, [&]) { + return get_tensor_dims_unsafe(deref_opaque(tensor_dims_t)); + } +} + +flexflow_error_t flexflow_get_piece_shape( + flexflow_parallel_tensor_shape_t parallel_tensor_shape, + flexflow_tensor_shape_t *out) { + return handle_errors(out, [&]) { + return get_piece_shape(deref_opaque(parallel_tensor_shape)); + } +} + +flexflow_error_t flexflow_get_num_replica_dims( + flexflolw_parallel_tensor_shape_t 
parallel_tensor_shape, int *out) { + return handle_errors(out, [&]) { + return get_num_replica_dims(deref_opaque(parallel_tensor_shape)); + } +} + +flexflow_error_t flexflow_get_num_replicas( + flexflow_parallel_tensor_shape_t parallel_tensor_shape, int *out) { + return handle_errors(out, [&]) { + return get_num_replicas(deref_opaque(parallel_tensor_shape)); + } +} + +flexflow_error_t + flexflow_is_valid(flexflow_parallel_tensor_shape_t parallel_tensor_shape, + bool *out) { + return handle_errors(out, [&]) { + return is_valid(deref_opaque(parallel_tensor_shape)); + } +} + +flexflow_error_t flexflow_get_tensor_shape_unsafe( + flexflow_parallel_tensor_shape_t parallel_tensor_shape, + flexflow_tensor_shape_t *out) { + return handle_errors(out, [&]) { + return get_tensor_shape_unsafe(deref_opaque(parallel_tensor_shape)); + } +} + +flexflow_error_t + flexflow_get_tensor_shape_unsafe(flexflow_parallel_tesor_shape_t *input, + int num_input, + flexflow_tensor_shape_list_t *out) { + return handle_errors(out, [&]) { + return get_tensor_shape_unsafe(c_deref_opaque_list(input, num_input)); + } +} + +flexflow_opattrs_error_t + flexflow_get_datatype_size(flexflow_datatype_t datatype, int *out) { + return handle_errors(out, [&]) { + return size_of(to_internal(datatype)); + } +} + +flexflow_opattrs_error_t + flexflow_operator_attrs_get_op_type(flexflow_operator_attrs_t op_attrs, + flexflow_op_type_t *out) { + return handle_errors(out, [&]) { + return deref_opaque(op_attrs).op_type; + } +} + ParamSync to_internal(flexflow_param_sync_t e) { return to_internal_impl(e); } + flexflow_param_sync_t to_external(ParamSync i) { return to_external_impl(i); } @@ -158,6 +568,7 @@ flexflow_datatype_t to_external(DataType i) { optional to_internal(flexflow_activation_t e) { return to_internal_impl(e); } + flexflow_activation_t to_external(optional i) { return to_external_impl(i); }