diff --git a/Cargo.lock b/Cargo.lock index 7884f92e..321e8b94 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11326,7 +11326,6 @@ dependencies = [ "tonic 0.12.3", "tonic-build 0.11.0", "tonic-build 0.12.3", - "torii-sqlite-types", ] [[package]] @@ -11444,6 +11443,7 @@ dependencies = [ "ipfs-api-backend-hyper", "katana-runner", "once_cell", + "rayon", "reqwest", "scarb", "serde", @@ -11458,6 +11458,7 @@ dependencies = [ "tokio", "tokio-util", "torii-indexer", + "torii-proto", "torii-sqlite-types", "tracing", ] @@ -11468,10 +11469,12 @@ version = "1.5.7" dependencies = [ "anyhow", "chrono", + "crypto-bigint", "dojo-types 1.5.0-alpha.2", "serde", "sqlx", "starknet 0.12.0", + "torii-proto", ] [[package]] diff --git a/crates/graphql/src/constants.rs b/crates/graphql/src/constants.rs index ac86adb9..c25927c5 100644 --- a/crates/graphql/src/constants.rs +++ b/crates/graphql/src/constants.rs @@ -3,16 +3,6 @@ pub const DATETIME_FORMAT: &str = "%Y-%m-%dT%H:%M:%SZ"; pub const DEFAULT_LIMIT: u64 = 10; pub const BOOLEAN_TRUE: i64 = 1; -pub const ENTITY_TABLE: &str = "entities"; -pub const EVENT_TABLE: &str = "events"; -pub const EVENT_MESSAGE_TABLE: &str = "event_messages"; -pub const MODEL_TABLE: &str = "models"; -pub const TRANSACTION_TABLE: &str = "transactions"; -pub const TRANSACTION_CALLS_TABLE: &str = "transaction_calls"; -pub const TOKEN_TRANSFER_TABLE: &str = "token_transfers"; -pub const METADATA_TABLE: &str = "metadata"; -pub const CONTROLLER_TABLE: &str = "controllers"; - pub const ID_COLUMN: &str = "id"; pub const EVENT_ID_COLUMN: &str = "event_id"; pub const ENTITY_ID_COLUMN: &str = "internal_entity_id"; diff --git a/crates/graphql/src/object/controller.rs b/crates/graphql/src/object/controller.rs index 2a7c9b19..0412294c 100644 --- a/crates/graphql/src/object/controller.rs +++ b/crates/graphql/src/object/controller.rs @@ -1,7 +1,8 @@ use async_graphql::dynamic::Field; +use torii_sqlite::types::Table; use super::{BasicObject, ResolvableObject, TypeMapping}; -use crate::constants::{CONTROLLER_NAMES, CONTROLLER_TABLE, CONTROLLER_TYPE_NAME, ID_COLUMN}; +use crate::constants::{CONTROLLER_NAMES, CONTROLLER_TYPE_NAME, ID_COLUMN}; use crate::mapping::CONTROLLER_MAPPING; use crate::object::{resolve_many, resolve_one}; @@ -25,7 +26,7 @@ impl BasicObject for ControllerObject { impl ResolvableObject for ControllerObject { fn resolvers(&self) -> Vec { let resolve_one = resolve_one( - CONTROLLER_TABLE, + Table::Controllers, ID_COLUMN, self.name().0, self.type_name(), @@ -33,7 +34,7 @@ impl ResolvableObject for ControllerObject { ); let resolve_many = resolve_many( - CONTROLLER_TABLE, + Table::Controllers, ID_COLUMN, self.name().1, self.type_name(), diff --git a/crates/graphql/src/object/entity.rs b/crates/graphql/src/object/entity.rs index 77bdbd2f..d2a97b60 100644 --- a/crates/graphql/src/object/entity.rs +++ b/crates/graphql/src/object/entity.rs @@ -8,12 +8,12 @@ use dojo_types::schema::Ty; use sqlx::{Pool, Sqlite}; use tokio_stream::StreamExt; use torii_sqlite::simple_broker::SimpleBroker; -use torii_sqlite::types::Entity; +use torii_sqlite::types::{Entity, Table}; use super::inputs::keys_input::keys_argument; use super::{BasicObject, ResolvableObject, TypeMapping, ValueMapping}; use crate::constants::{ - DATETIME_FORMAT, ENTITY_NAMES, ENTITY_TABLE, ENTITY_TYPE_NAME, EVENT_ID_COLUMN, ID_COLUMN, + DATETIME_FORMAT, ENTITY_NAMES, ENTITY_TYPE_NAME, EVENT_ID_COLUMN, ID_COLUMN, }; use crate::mapping::ENTITY_TYPE_MAPPING; use crate::object::{resolve_many, resolve_one}; @@ -43,7 +43,7 @@ impl BasicObject for 
EntityObject { impl ResolvableObject for EntityObject { fn resolvers(&self) -> Vec { let resolve_one = resolve_one( - ENTITY_TABLE, + Table::Entities, ID_COLUMN, self.name().0, self.type_name(), @@ -51,7 +51,7 @@ impl ResolvableObject for EntityObject { ); let mut resolve_many = resolve_many( - ENTITY_TABLE, + Table::Entities, EVENT_ID_COLUMN, self.name().1, self.type_name(), diff --git a/crates/graphql/src/object/event.rs b/crates/graphql/src/object/event.rs index b12c3ffd..6d02b30a 100644 --- a/crates/graphql/src/object/event.rs +++ b/crates/graphql/src/object/event.rs @@ -5,7 +5,7 @@ use async_graphql::{Name, Result, Value}; use tokio_stream::{Stream, StreamExt}; use torii_sqlite::constants::SQL_FELT_DELIMITER; use torii_sqlite::simple_broker::SimpleBroker; -use torii_sqlite::types::Event; +use torii_sqlite::types::{Event, Table}; use super::inputs::keys_input::{keys_argument, parse_keys_argument}; use super::{resolve_many, BasicObject, ResolvableObject, TypeMapping}; @@ -33,7 +33,7 @@ impl BasicObject for EventObject { impl ResolvableObject for EventObject { fn resolvers(&self) -> Vec { let mut resolve_many = resolve_many( - EVENT_TABLE, + Table::Events, ID_COLUMN, self.name().1, self.type_name(), diff --git a/crates/graphql/src/object/event_message.rs b/crates/graphql/src/object/event_message.rs index 134b8a8e..79578c32 100644 --- a/crates/graphql/src/object/event_message.rs +++ b/crates/graphql/src/object/event_message.rs @@ -8,12 +8,12 @@ use dojo_types::schema::Ty; use sqlx::{Pool, Sqlite}; use tokio_stream::StreamExt; use torii_sqlite::simple_broker::SimpleBroker; -use torii_sqlite::types::EventMessage; +use torii_sqlite::types::{EventMessage, Table}; use super::inputs::keys_input::keys_argument; use super::{BasicObject, ResolvableObject, TypeMapping, ValueMapping}; use crate::constants::{ - DATETIME_FORMAT, EVENT_ID_COLUMN, EVENT_MESSAGE_NAMES, EVENT_MESSAGE_TABLE, + DATETIME_FORMAT, EVENT_ID_COLUMN, EVENT_MESSAGE_NAMES, EVENT_MESSAGE_TYPE_NAME, ID_COLUMN, }; use crate::mapping::ENTITY_TYPE_MAPPING; @@ -45,7 +45,7 @@ impl BasicObject for EventMessageObject { impl ResolvableObject for EventMessageObject { fn resolvers(&self) -> Vec { let resolve_one = resolve_one( - EVENT_MESSAGE_TABLE, + Table::EventMessages, ID_COLUMN, self.name().0, self.type_name(), @@ -53,7 +53,7 @@ impl ResolvableObject for EventMessageObject { ); let mut resolve_many = resolve_many( - EVENT_MESSAGE_TABLE, + Table::EventMessages, EVENT_ID_COLUMN, self.name().1, self.type_name(), diff --git a/crates/grpc/server/src/lib.rs b/crates/grpc/server/src/lib.rs index 03b4f63c..d5b5b001 100644 --- a/crates/grpc/server/src/lib.rs +++ b/crates/grpc/server/src/lib.rs @@ -109,7 +109,6 @@ impl DojoWorld
 {
         sql: Sql,
         provider: Arc<P>,
         world_address: Felt,
-        model_cache: Arc<ModelCache>,
         cross_messaging_tx: Option<UnboundedSender<Message>>,
         config: GrpcConfig,
     ) -> Self {
@@ -165,54 +164,22 @@ impl<P: Provider + Sync> DojoWorld<P> {
 impl<P: Provider + Sync> DojoWorld<P> {
     pub async fn world(&self) -> Result<proto::types::WorldMetadata, Error> {
-        let world_address = sqlx::query_scalar(&format!(
-            "SELECT contract_address FROM contracts WHERE id = '{:#x}'",
-            self.world_address
-        ))
-        .fetch_one(&self.sql.pool)
-        .await?;
-
-        #[derive(FromRow)]
-        struct ModelDb {
-            id: String,
-            namespace: String,
-            name: String,
-            class_hash: String,
-            contract_address: String,
-            packed_size: u32,
-            unpacked_size: u32,
-            layout: String,
-        }
-
-        let models: Vec<ModelDb> = sqlx::query_as(
-            "SELECT id, namespace, name, class_hash, contract_address, packed_size, \
-             unpacked_size, layout FROM models",
-        )
-        .fetch_all(&self.sql.pool)
-        .await?;
-
-        let mut models_metadata = Vec::with_capacity(models.len());
-        for model in models {
-            let schema = self
-                .model_cache
-                .model(&Felt::from_str(&model.id).map_err(ParseError::FromStr)?)
-                .await?
-                .schema;
-            models_metadata.push(proto::types::ModelMetadata {
-                namespace: model.namespace,
-                name: model.name,
-                class_hash: model.class_hash,
-                contract_address: model.contract_address,
-                packed_size: model.packed_size,
-                unpacked_size: model.unpacked_size,
-                layout: model.layout.as_bytes().to_vec(),
-                schema: serde_json::to_vec(&schema).unwrap(),
-            });
-        }
+        let models = self.sql.models(&[]).await?.iter().map(|m| {
+            proto::types::ModelMetadata {
+                namespace: m.namespace.clone(),
+                name: m.name.clone(),
+                class_hash: format!("{:#x}", m.class_hash),
+                contract_address: format!("{:#x}", m.contract_address),
+                packed_size: m.packed_size,
+                unpacked_size: m.unpacked_size,
+                layout: serde_json::to_vec(&m.layout).unwrap(),
+                schema: serde_json::to_vec(&m.schema).unwrap(),
+            }
+        }).collect::<Vec<_>>();

         Ok(proto::types::WorldMetadata {
-            world_address,
-            models: models_metadata,
+            world_address: format!("{:#x}", self.world_address),
+            models,
         })
     }

@@ -1041,200 +1008,6 @@ impl<P: Provider + Sync> DojoWorld<P>

{ next_cursor: page.next_cursor.unwrap_or_default(), }) } - - async fn retrieve_events( - &self, - query: &proto::types::EventQuery, - ) -> Result { - let limit = if query.limit > 0 { - query.limit + 1 - } else { - SQL_DEFAULT_LIMIT as u32 + 1 - }; - - let mut bind_values = Vec::new(); - let mut conditions = Vec::new(); - - let keys_pattern = if let Some(keys_clause) = &query.keys { - build_keys_pattern(keys_clause)? - } else { - String::new() - }; - - if !keys_pattern.is_empty() { - conditions.push("keys REGEXP ?"); - bind_values.push(keys_pattern); - } - - if !query.cursor.is_empty() { - conditions.push("id >= ?"); - bind_values.push(decode_cursor(&query.cursor)?); - } - - let mut events_query = r#" - SELECT id, keys, data, transaction_hash - FROM events - "# - .to_string(); - - if !conditions.is_empty() { - events_query = format!("{} WHERE {}", events_query, conditions.join(" AND ")); - } - - events_query = format!("{} ORDER BY id LIMIT ?", events_query); - bind_values.push(limit.to_string()); - - let mut row_events = sqlx::query_as(&events_query); - for value in &bind_values { - row_events = row_events.bind(value); - } - let mut row_events: Vec<(String, String, String, String)> = - row_events.fetch_all(&self.sql.pool).await?; - - let next_cursor = if row_events.len() > (limit - 1) as usize { - encode_cursor(&row_events.pop().unwrap().0)? - } else { - String::new() - }; - - let events = row_events - .iter() - .map(|(_, keys, data, transaction_hash)| { - map_row_to_event(&(keys, data, transaction_hash)) - }) - .collect::, Error>>()?; - - Ok(RetrieveEventsResponse { - events, - next_cursor, - }) - } - - async fn retrieve_controllers( - &self, - contract_addresses: Vec, - ) -> Result { - let query = if contract_addresses.is_empty() { - "SELECT address, username, deployed_at FROM controllers".to_string() - } else { - format!( - "SELECT address, username, deployed_at FROM controllers WHERE address IN ({})", - contract_addresses - .iter() - .map(|_| "?".to_string()) - .collect::>() - .join(", ") - ) - }; - - let mut db_query = sqlx::query_as::<_, (String, String, DateTime)>(&query); - for address in &contract_addresses { - db_query = db_query.bind(format!("{:#x}", address)); - } - - let rows = db_query.fetch_all(&self.sql.pool).await?; - - let controllers = rows - .into_iter() - .map( - |(address, username, deployed_at)| proto::types::Controller { - address: address.parse::().unwrap().to_bytes_be().to_vec(), - username, - deployed_at_timestamp: deployed_at.timestamp() as u64, - }, - ) - .collect(); - - Ok(RetrieveControllersResponse { controllers }) - } -} - -fn process_event_field(data: &str) -> Result>, Error> { - Ok(data - .trim_end_matches('/') - .split('/') - .filter(|&d| !d.is_empty()) - .map(|d| { - Felt::from_str(d) - .map_err(ParseError::FromStr) - .map(|f| f.to_bytes_be().to_vec()) - }) - .collect::, _>>()?) -} - -fn map_row_to_event(row: &(&str, &str, &str)) -> Result { - let keys = process_event_field(row.0)?; - let data = process_event_field(row.1)?; - let transaction_hash = Felt::from_str(row.2) - .map_err(ParseError::FromStr)? 
- .to_bytes_be() - .to_vec(); - - Ok(proto::types::Event { - keys, - data, - transaction_hash, - }) -} - -fn map_row_to_entity( - row: &SqliteRow, - schemas: &[Ty], - dont_include_hashed_keys: bool, -) -> Result { - let hashed_keys = Felt::from_str(&row.get::("id")).map_err(ParseError::FromStr)?; - let model_ids = row - .get::("model_ids") - .split(',') - .map(|id| Felt::from_str(id).map_err(ParseError::FromStr)) - .collect::, _>>()?; - - let models = schemas - .iter() - .filter(|schema| model_ids.contains(&compute_selector_from_tag(&schema.name()))) - .map(|schema| { - let mut ty = schema.clone(); - map_row_to_ty("", &schema.name(), &mut ty, row)?; - Ok(ty.as_struct().unwrap().clone().into()) - }) - .collect::, Error>>()?; - - Ok(proto::types::Entity { - hashed_keys: if !dont_include_hashed_keys { - hashed_keys.to_bytes_be().to_vec() - } else { - vec![] - }, - models, - }) -} - -// this builds a sql safe regex pattern to match against for keys -fn build_keys_pattern(clause: &proto::types::KeysClause) -> Result { - const KEY_PATTERN: &str = "0x[0-9a-fA-F]+"; - - let keys = if clause.keys.is_empty() { - vec![KEY_PATTERN.to_string()] - } else { - clause - .keys - .iter() - .map(|bytes| { - if bytes.is_empty() { - return Ok(KEY_PATTERN.to_string()); - } - Ok(format!("{:#x}", Felt::from_bytes_be_slice(bytes))) - }) - .collect::, Error>>()? - }; - let mut keys_pattern = format!("^{}", keys.join("/")); - - if clause.pattern_matching == proto::types::PatternMatching::VariableLen as i32 { - keys_pattern += &format!("(/{})*", KEY_PATTERN); - } - keys_pattern += "/$"; - - Ok(keys_pattern) } // builds a composite clause for a query @@ -1874,7 +1647,6 @@ pub async fn new( sql: Sql, provider: Arc
<P>
, world_address: Felt, - model_cache: Arc, cross_messaging_tx: UnboundedSender, config: GrpcConfig, ) -> Result< @@ -1896,7 +1668,6 @@ pub async fn new( sql, provider, world_address, - model_cache, Some(cross_messaging_tx), config, ); diff --git a/crates/proto/Cargo.toml b/crates/proto/Cargo.toml index e30067cc..c010815f 100644 --- a/crates/proto/Cargo.toml +++ b/crates/proto/Cargo.toml @@ -27,8 +27,7 @@ strum.workspace = true strum_macros.workspace = true serde_json.workspace = true thiserror.workspace = true -torii-sqlite-types = { workspace = true, optional = true } [features] client = [] -server = ["torii-sqlite-types"] +server = [] diff --git a/crates/proto/src/lib.rs b/crates/proto/src/lib.rs index 295e813c..790a537d 100644 --- a/crates/proto/src/lib.rs +++ b/crates/proto/src/lib.rs @@ -24,8 +24,6 @@ use core::fmt; use std::collections::HashMap; use std::str::FromStr; -#[cfg(feature = "server")] -use crypto_bigint::Encoding; use crypto_bigint::U256; use dojo_types::primitive::Primitive; use dojo_types::schema::Ty; @@ -57,7 +55,7 @@ pub struct Message { #[derive(Debug, Serialize, Deserialize, PartialEq, Hash, Eq, Clone)] pub struct Pagination { pub cursor: Option, - pub limit: u32, + pub limit: Option, pub direction: PaginationDirection, pub order_by: Vec, } @@ -66,7 +64,7 @@ impl From for proto::types::Pagination { fn from(value: Pagination) -> Self { Self { cursor: value.cursor.unwrap_or_default(), - limit: value.limit, + limit: value.limit.unwrap_or_default(), direction: value.direction as i32, order_by: value.order_by.into_iter().map(|o| o.into()).collect(), } @@ -76,26 +74,6 @@ impl From for proto::types::Pagination { impl From for Pagination { fn from(value: proto::types::Pagination) -> Self { Self { - cursor: if value.cursor.is_empty() { - None - } else { - Some(value.cursor) - }, - limit: value.limit, - direction: match value.direction { - 0 => PaginationDirection::Forward, - 1 => PaginationDirection::Backward, - _ => unreachable!(), - }, - order_by: value.order_by.into_iter().map(|o| o.into()).collect(), - } - } -} - -#[cfg(feature = "server")] -impl From for torii_sqlite_types::Pagination { - fn from(value: proto::types::Pagination) -> Self { - torii_sqlite_types::Pagination { cursor: if value.cursor.is_empty() { None } else { @@ -107,15 +85,11 @@ impl From for torii_sqlite_types::Pagination { Some(value.limit) }, direction: match value.direction { - 0 => torii_sqlite_types::PaginationDirection::Forward, - 1 => torii_sqlite_types::PaginationDirection::Backward, + 0 => PaginationDirection::Forward, + 1 => PaginationDirection::Backward, _ => unreachable!(), }, - order_by: value - .order_by - .into_iter() - .map(|order_by| order_by.into()) - .collect(), + order_by: value.order_by.into_iter().map(|o| o.into()).collect(), } } } @@ -175,29 +149,6 @@ impl TryFrom for Token { } } -#[cfg(feature = "server")] -impl From for proto::types::Token { - fn from(value: torii_sqlite_types::Token) -> Self { - Self { - token_id: if value.token_id.is_empty() { - U256::ZERO.to_be_bytes().to_vec() - } else { - U256::from_be_hex(value.token_id.trim_start_matches("0x")) - .to_be_bytes() - .to_vec() - }, - contract_address: Felt::from_str(&value.contract_address) - .unwrap() - .to_bytes_be() - .to_vec(), - name: value.name, - symbol: value.symbol, - decimals: value.decimals as u32, - metadata: value.metadata.as_bytes().to_vec(), - } - } -} - #[derive(Debug, Serialize, Deserialize, PartialEq, Hash, Eq, Clone)] pub struct TokenCollection { pub contract_address: Felt, @@ -222,23 +173,6 @@ impl TryFrom 
for TokenCollection { } } -#[cfg(feature = "server")] -impl From for proto::types::TokenCollection { - fn from(value: torii_sqlite_types::TokenCollection) -> Self { - Self { - contract_address: Felt::from_str(&value.contract_address) - .unwrap() - .to_bytes_be() - .to_vec(), - name: value.name, - symbol: value.symbol, - decimals: value.decimals as u32, - count: value.count, - metadata: value.metadata.as_bytes().to_vec(), - } - } -} - #[derive(Debug, Serialize, Deserialize, PartialEq, Hash, Eq, Clone)] pub struct TokenBalance { pub balance: U256, @@ -259,33 +193,6 @@ impl TryFrom for TokenBalance { } } -#[cfg(feature = "server")] -impl From for proto::types::TokenBalance { - fn from(value: torii_sqlite_types::TokenBalance) -> Self { - let id = value.token_id.split(':').collect::>(); - - Self { - balance: U256::from_be_hex(value.balance.trim_start_matches("0x")) - .to_be_bytes() - .to_vec(), - account_address: Felt::from_str(&value.account_address) - .unwrap() - .to_bytes_be() - .to_vec(), - contract_address: Felt::from_str(&value.contract_address) - .unwrap() - .to_bytes_be() - .to_vec(), - token_id: if id.len() == 2 { - U256::from_be_hex(id[1].trim_start_matches("0x")) - .to_be_bytes() - .to_vec() - } else { - U256::ZERO.to_be_bytes().to_vec() - }, - } - } -} #[derive(Debug, Serialize, Deserialize, PartialEq, Hash, Eq, Clone)] pub struct IndexerUpdate { @@ -337,21 +244,6 @@ impl From for OrderBy { } } -#[cfg(feature = "server")] -impl From for torii_sqlite_types::OrderBy { - fn from(value: proto::types::OrderBy) -> Self { - torii_sqlite_types::OrderBy { - model: value.model, - member: value.member, - direction: match value.direction { - 0 => torii_sqlite_types::OrderDirection::Asc, - 1 => torii_sqlite_types::OrderDirection::Desc, - _ => unreachable!(), - }, - } - } -} - #[derive(Debug, Serialize, Deserialize, PartialEq, Hash, Eq, Clone)] pub enum OrderDirection { Asc, @@ -793,7 +685,7 @@ impl From for Event { #[derive(Debug, Serialize, Deserialize, PartialEq, Hash, Eq, Clone)] pub struct EventQuery { - pub keys: KeysClause, + pub keys: Option, pub limit: u32, pub cursor: Option, } @@ -801,7 +693,7 @@ pub struct EventQuery { impl From for proto::types::EventQuery { fn from(value: EventQuery) -> Self { Self { - keys: Some(value.keys.into()), + keys: value.keys.map(|k| k.into()), limit: value.limit, cursor: value.cursor.unwrap_or_default(), } diff --git a/crates/runner/src/lib.rs b/crates/runner/src/lib.rs index def3102c..92abec2c 100644 --- a/crates/runner/src/lib.rs +++ b/crates/runner/src/lib.rs @@ -343,7 +343,6 @@ impl Runner { db.clone(), provider.clone(), world_address, - model_cache, cross_messaging_tx, GrpcConfig { subscription_buffer_size: self.args.grpc.subscription_buffer_size, diff --git a/crates/sqlite/sqlite/Cargo.toml b/crates/sqlite/sqlite/Cargo.toml index 27ae891a..2b07b2ea 100644 --- a/crates/sqlite/sqlite/Cargo.toml +++ b/crates/sqlite/sqlite/Cargo.toml @@ -34,12 +34,17 @@ sqlx.workspace = true starknet-crypto.workspace = true starknet.workspace = true thiserror.workspace = true -tokio = { version = "1.32.0", features = [ "macros", "sync" ], default-features = true } +tokio = { version = "1.32.0", features = [ + "macros", + "sync", +], default-features = true } # tokio-stream = "0.1.11" ipfs-api-backend-hyper.workspace = true tokio-util.workspace = true tracing.workspace = true flate2.workspace = true +torii-proto.workspace = true +rayon.workspace = true dashmap.workspace = true [dev-dependencies] diff --git a/crates/sqlite/sqlite/src/cache.rs 
b/crates/sqlite/sqlite/src/cache.rs
index f1b02a04..a7e8fb19 100644
--- a/crates/sqlite/sqlite/src/cache.rs
+++ b/crates/sqlite/sqlite/src/cache.rs
@@ -13,8 +13,8 @@ use starknet::core::utils::get_selector_from_name;
 use starknet::providers::{Provider, ProviderError};
 use starknet_crypto::Felt;
 use tokio::sync::{Mutex, RwLock};
+use torii_sqlite_types::Table;

-use crate::constants::TOKENS_TABLE;
 use crate::error::{Error, ParseError};
 use crate::utils::I256;

@@ -118,8 +118,11 @@ impl ModelCache {
             layout,
             schema,
         ): (String, String, String, String, u32, u32, String, String) = sqlx::query_as(
-            "SELECT namespace, name, class_hash, contract_address, packed_size, unpacked_size, \
-             layout, schema FROM models WHERE id = ?",
+            &format!(
+                "SELECT namespace, name, class_hash, contract_address, packed_size, unpacked_size, \
+                 layout, schema FROM {table} WHERE id = ?",
+                table = Table::Models
+            ),
         )
         .bind(format!("{:#x}", selector))
         .fetch_one(&self.pool)
         .await?;
@@ -178,7 +181,7 @@ impl LocalCache {
     pub async fn new(pool: Pool<Sqlite>) -> Self {
         // read existing token_id's from balances table and cache them
         let token_id_registry: Vec<String> =
-            sqlx::query_scalar(&format!("SELECT id FROM {TOKENS_TABLE}"))
+            sqlx::query_scalar(&format!("SELECT id FROM {table}", table = Table::Tokens))
                 .fetch_all(&pool)
                 .await
                 .expect("Should be able to read token_id's from blances table");
diff --git a/crates/sqlite/sqlite/src/constants.rs b/crates/sqlite/sqlite/src/constants.rs
index c174aeef..686a749f 100644
--- a/crates/sqlite/sqlite/src/constants.rs
+++ b/crates/sqlite/sqlite/src/constants.rs
@@ -1,7 +1,4 @@
 pub const QUERY_QUEUE_BATCH_SIZE: usize = 1000;
-pub const TOKEN_BALANCE_TABLE: &str = "token_balances";
-pub const TOKEN_TRANSFER_TABLE: &str = "token_transfers";
-pub const TOKENS_TABLE: &str = "tokens";
 pub const WORLD_CONTRACT_TYPE: &str = "WORLD";
 pub const SQL_FELT_DELIMITER: &str = "/";
 pub const REQ_MAX_RETRIES: u8 = 3;
diff --git a/crates/sqlite/sqlite/src/cursor.rs b/crates/sqlite/sqlite/src/cursor.rs
new file mode 100644
index 00000000..6b98c984
--- /dev/null
+++ b/crates/sqlite/sqlite/src/cursor.rs
@@ -0,0 +1,132 @@
+use base64::prelude::BASE64_URL_SAFE_NO_PAD;
+use base64::Engine;
+use flate2::read::DeflateDecoder;
+use flate2::write::DeflateEncoder;
+use flate2::Compression;
+use sqlx::sqlite::SqliteRow;
+use torii_proto::{OrderDirection, Pagination, PaginationDirection};
+use std::io::prelude::*;
+use sqlx::Row;
+
+use crate::error::{Error, QueryError};
+
+pub(crate) fn build_cursor_conditions(
+    pagination: &Pagination,
+    cursor_values: Option<&[String]>,
+    table_name: &str,
+) -> Result<(Vec<String>, Vec<String>), Error> {
+    let mut conditions = Vec::new();
+    let mut binds = Vec::new();
+
+    if let Some(values) = cursor_values {
+        let expected_len = if pagination.order_by.is_empty() {
+            1
+        } else {
+            pagination.order_by.len() + 1
+        };
+        if values.len() != expected_len {
+            return Err(Error::Query(QueryError::InvalidCursor(
+                "Invalid cursor values length".to_string(),
+            )));
+        }
+
+        if pagination.order_by.is_empty() {
+            let operator = if pagination.direction == PaginationDirection::Forward {
+                "<"
+            } else {
+                ">"
+            };
+            conditions.push(format!("{}.event_id {} ?", table_name, operator));
+            binds.push(values[0].clone());
+        } else {
+            for (i, (ob, val)) in pagination.order_by.iter().zip(values).enumerate() {
+                let operator = match (&ob.direction, &pagination.direction) {
+                    (OrderDirection::Asc, PaginationDirection::Forward) => ">",
+                    (OrderDirection::Asc, PaginationDirection::Backward) => "<",
+                    (OrderDirection::Desc, PaginationDirection::Forward) => "<",
+                    (OrderDirection::Desc, PaginationDirection::Backward) => ">",
+                };
+
+                let condition = if i == 0 {
+                    format!("[{}.{}] {} ?", ob.model, ob.member, operator)
+                } else {
+                    let prev = (0..i)
+                        .map(|j| {
+                            let prev_ob = &pagination.order_by[j];
+                            format!("[{}.{}] = ?", prev_ob.model, prev_ob.member)
+                        })
+                        .collect::<Vec<_>>()
+                        .join(" AND ");
+                    format!("({} AND [{}.{}] {} ?)", prev, ob.model, ob.member, operator)
+                };
+                conditions.push(condition);
+                binds.push(val.clone());
+            }
+            let operator = if pagination.direction == PaginationDirection::Forward {
+                "<"
+            } else {
+                ">"
+            };
+            conditions.push(format!("{}.event_id {} ?", table_name, operator));
+            binds.push(values.last().unwrap().clone());
+        }
+    }
+    Ok((conditions, binds))
+}
+
+pub(crate) fn build_cursor_values(
+    pagination: &Pagination,
+    row: &SqliteRow,
+) -> Result<Vec<String>, Error> {
+    if pagination.order_by.is_empty() {
+        Ok(vec![row.try_get("event_id")?])
+    } else {
+        let mut values: Vec<String> = pagination
+            .order_by
+            .iter()
+            .map(|ob| row.try_get::<String, _>(&format!("{}.{}", ob.model, ob.member)))
+            .collect::<Result<Vec<_>, _>>()?;
+        values.push(row.try_get("event_id")?);
+        Ok(values)
+    }
+}
+
+/// Compresses a string using Deflate and then encodes it using Base64 (no padding).
+pub(crate) fn encode_cursor(value: &str) -> Result<String, Error> {
+    let mut encoder = DeflateEncoder::new(Vec::new(), Compression::default());
+    encoder.write_all(value.as_bytes()).map_err(|e| {
+        Error::Query(QueryError::InvalidCursor(format!(
+            "Cursor compression error: {}",
+            e
+        )))
+    })?;
+    let compressed_bytes = encoder.finish().map_err(|e| {
+        Error::Query(QueryError::InvalidCursor(format!(
+            "Cursor compression finish error: {}",
+            e
+        )))
+    })?;
+
+    Ok(BASE64_URL_SAFE_NO_PAD.encode(&compressed_bytes))
+}
+
+/// Decodes a Base64 (no padding) string and then decompresses it using Deflate.
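+/// Inverse of [`encode_cursor`]: a cursor built from the (hypothetical) value string
+/// "0x42/0x1" decodes back to exactly "0x42/0x1".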
+pub(crate) fn decode_cursor(encoded_cursor: &str) -> Result<String, Error> {
+    let compressed_cursor_bytes = BASE64_URL_SAFE_NO_PAD.decode(encoded_cursor).map_err(|e| {
+        Error::Query(QueryError::InvalidCursor(format!(
+            "Base64 decode error: {}",
+            e
+        )))
+    })?;
+
+    let mut decoder = DeflateDecoder::new(&compressed_cursor_bytes[..]);
+    let mut decompressed_str = String::new();
+    decoder.read_to_string(&mut decompressed_str).map_err(|e| {
+        Error::Query(QueryError::InvalidCursor(format!(
+            "Decompression error: {}",
+            e
+        )))
+    })?;
+
+    Ok(decompressed_str)
+}
diff --git a/crates/sqlite/sqlite/src/entities.rs b/crates/sqlite/sqlite/src/entities.rs
new file mode 100644
index 00000000..ec780d27
--- /dev/null
+++ b/crates/sqlite/sqlite/src/entities.rs
@@ -0,0 +1,608 @@
+use std::collections::HashSet;
+use std::str::FromStr;
+
+use dojo_types::naming::compute_selector_from_tag;
+use dojo_types::primitive::Primitive;
+use dojo_types::schema::Ty;
+use dojo_world::contracts::naming::compute_selector_from_names;
+use futures_util::future::try_join_all;
+use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
+use starknet_crypto::Felt;
+use torii_proto::proto::types::member_value::ValueType;
+use torii_proto::schema::Entity;
+use torii_sqlite_types::EntityType;
+
+use crate::cursor::{build_cursor_conditions, build_cursor_values, decode_cursor, encode_cursor};
+use crate::utils::{build_query, combine_where_clauses, map_row_to_entity};
+use crate::{error::Error, Sql};
+use crate::constants::{SQL_DEFAULT_LIMIT, SQL_MAX_JOINS};
+use crate::error::{ParseError, QueryError};
+use torii_proto::{ComparisonOperator, LogicalOperator, OrderDirection, Page, Pagination, PaginationDirection};
+
+impl Sql {
+    #[allow(clippy::too_many_arguments)]
+    pub async fn entities(
+        &self,
+        schemas: &[Ty],
+        table_name: &str,
+        model_relation_table: &str,
+        entity_relation_column: &str,
+        where_clause: Option<&str>,
+        having_clause: Option<&str>,
+        pagination: Pagination,
+        bind_values: Vec<String>,
+    ) -> Result<Page<Entity>, Error> {
+        // Helper function to collect columns
+        fn collect_columns(table_prefix: &str, path: &str, ty: &Ty, selections: &mut Vec<String>) {
+            match ty {
+                Ty::Struct(s) => {
+                    for child in &s.children {
+                        let new_path = if path.is_empty() {
+                            child.name.clone()
+                        } else {
+                            format!("{}.{}", path, child.name)
+                        };
+                        collect_columns(table_prefix, &new_path, &child.ty, selections);
+                    }
+                }
+                Ty::Tuple(t) => {
+                    for (i, child) in t.iter().enumerate() {
+                        let new_path = if path.is_empty() {
+                            format!("{}", i)
+                        } else {
+                            format!("{}.{}", path, i)
+                        };
+                        collect_columns(table_prefix, &new_path, child, selections);
+                    }
+                }
+                Ty::Enum(e) => {
+                    selections.push(format!(
+                        "[{table_prefix}].[{path}] as \"{table_prefix}.{path}\"",
+                    ));
+
+                    for option in &e.options {
+                        if let Ty::Tuple(t) = &option.ty {
+                            if t.is_empty() {
+                                continue;
+                            }
+                        }
+                        let variant_path = format!("{}.{}", path, option.name);
+                        collect_columns(table_prefix, &variant_path, &option.ty, selections);
+                    }
+                }
+                Ty::Array(_) | Ty::Primitive(_) | Ty::ByteArray(_) => {
+                    selections.push(format!(
+                        "[{table_prefix}].[{path}] as \"{table_prefix}.{path}\"",
+                    ));
+                }
+            }
+        }
+
+        let original_limit = pagination.limit.unwrap_or(SQL_DEFAULT_LIMIT as u32);
+        let fetch_limit = original_limit + 1;
+        let mut has_more_pages = false;
+
+        // Build order by clause with proper model joining
+        let order_by_models: HashSet<String> = pagination
+            .order_by
+            .iter()
+            .map(|ob| ob.model.clone())
+            .collect();
+
+        let order_clause = if pagination.order_by.is_empty() {
+            format!("{table_name}.event_id DESC")
+        } else {
+            pagination
+                .order_by
+                .iter()
+                .map(|ob| {
+                    let direction = match (&ob.direction, &pagination.direction) {
+                        (OrderDirection::Asc, PaginationDirection::Forward) => "ASC",
+                        (OrderDirection::Asc, PaginationDirection::Backward) => "DESC",
+                        (OrderDirection::Desc, PaginationDirection::Forward) => "DESC",
+                        (OrderDirection::Desc, PaginationDirection::Backward) => "ASC",
+                    };
+                    format!("[{}].[{}] {direction}", ob.model, ob.member)
+                })
+                .chain(std::iter::once(format!("{table_name}.event_id DESC")))
+                .collect::<Vec<_>>()
+                .join(", ")
+        };
+
+        // Parse cursor
+        let cursor_values: Option<Vec<String>> = pagination
+            .cursor
+            .as_ref()
+            .map(|cursor_str| {
+                let decompressed_str = decode_cursor(cursor_str)?;
+                Ok(decompressed_str.split('/').map(|s| s.to_string()).collect())
+            })
+            .transpose()
+            .map_err(|e: Error| Error::Query(QueryError::InvalidCursor(e.to_string())))?;
+
+        // Build cursor conditions
+        let (cursor_conditions, cursor_binds) =
+            build_cursor_conditions(&pagination, cursor_values.as_deref(), table_name)?;
+
+        // Combine WHERE clauses
+        let combined_where = combine_where_clauses(where_clause, &cursor_conditions);
+
+        // Process schemas in chunks
+        let mut all_rows = Vec::new();
+        let mut next_cursor = None;
+
+        for chunk in schemas.chunks(SQL_MAX_JOINS) {
+            let mut selections = vec![
+                format!("{}.id", table_name),
+                format!("{}.keys", table_name),
+                format!("{}.event_id", table_name),
+                format!(
+                    "group_concat({}.model_id) as model_ids",
+                    model_relation_table
+                ),
+            ];
+            let mut joins = Vec::new();
+
+            // Add schema joins
+            for model in chunk {
+                let model_table = model.name();
+                let join_type = if order_by_models.contains(&model_table) {
+                    "INNER"
+                } else {
+                    "LEFT"
+                };
+                joins.push(format!(
+                    "{join_type} JOIN [{model_table}] ON {table_name}.id = \
+                     [{model_table}].{entity_relation_column}",
+                ));
+                collect_columns(&model_table, "", model, &mut selections);
+            }
+
+            joins.push(format!(
+                "JOIN {model_relation_table} ON {table_name}.id = {model_relation_table}.entity_id",
+            ));
+
+            // Build and execute query
+            let query = build_query(
+                &selections,
+                table_name,
+                &joins,
+                &combined_where,
+                having_clause,
+                &order_clause,
+            );
+
+            let mut stmt = sqlx::query(&query);
+            for value in bind_values.iter().chain(cursor_binds.iter()) {
+                stmt = stmt.bind(value);
+            }
+
+            stmt = stmt.bind(fetch_limit);
+
+            let mut rows = stmt.fetch_all(&self.pool).await?;
+            let has_more = rows.len() >= fetch_limit as usize;
+
+            if pagination.direction == PaginationDirection::Backward {
+                rows.reverse();
+            }
+            if has_more {
+                // mark that there are more pages beyond the limit
+                has_more_pages = true;
+                rows.truncate(original_limit as usize);
+            }
+
+            all_rows.extend(rows);
+            if has_more {
+                break;
+            }
+        }
+
+        // Only generate a next cursor when there are more pages
+        if has_more_pages {
+            if let Some(last_row) = all_rows.last() {
+                let cursor_values_str = build_cursor_values(&pagination, last_row)?.join("/");
+                next_cursor = Some(encode_cursor(&cursor_values_str)?);
+            }
+        }
+
+        let entities: Vec<Entity> = all_rows
+            .par_iter()
+            .map(|row| map_row_to_entity(row, schemas))
+            .collect::<Result<Vec<_>, _>>()?;
+        Ok(Page {
+            items: entities,
+            next_cursor,
+        })
+    }
+
+    async fn fetch_historical_entities(
+        &self,
+        table: &str,
+        model_relation_table: &str,
+        where_clause: &str,
+        having_clause: &str,
+        mut bind_values: Vec<String>,
+        pagination: Pagination,
+    ) -> Result<Page<Entity>, Error> {
+        if !pagination.order_by.is_empty() {
+            return Err(QueryError::UnsupportedQuery(
+                "Order by is not supported for historical entities".to_string(),
+            )
+            .into());
+        }
+
+        let mut conditions = Vec::new();
+        if !where_clause.is_empty() {
+            conditions.push(where_clause.to_string());
+        }
+
+        let order_direction = match pagination.direction {
+            PaginationDirection::Forward => "ASC",
+            PaginationDirection::Backward => "DESC",
+        };
+
+        // Add cursor condition if present
+        if let Some(ref cursor) = pagination.cursor {
+            let decoded_cursor = decode_cursor(cursor)?;
+
+            let operator = match pagination.direction {
+                PaginationDirection::Forward => ">=",
+                PaginationDirection::Backward => "<=",
+            };
+            conditions.push(format!("{table}.event_id {operator} ?"));
+            bind_values.push(decoded_cursor);
+        }
+
+        let where_clause = if !conditions.is_empty() {
+            format!("WHERE {}", conditions.join(" AND "))
+        } else {
+            String::new()
+        };
+
+        let limit = pagination.limit.unwrap_or(SQL_DEFAULT_LIMIT as u32);
+        let query_limit = limit + 1;
+
+        let query_str = format!(
+            "SELECT {table}.id, {table}.data, {table}.model_id, {table}.event_id, \
+             group_concat({model_relation_table}.model_id) as model_ids
+            FROM {table}
+            JOIN {model_relation_table} ON {table}.id = {model_relation_table}.entity_id
+            {where_clause}
+            GROUP BY {table}.event_id
+            {}
+            ORDER BY {table}.event_id {order_direction}
+            LIMIT ?
+            ",
+            if !having_clause.is_empty() {
+                format!("HAVING {}", having_clause)
+            } else {
+                String::new()
+            }
+        );
+
+        let mut query = sqlx::query_as(&query_str);
+        for value in bind_values {
+            query = query.bind(value);
+        }
+        query = query.bind(query_limit);
+
+        let db_entities: Vec<(String, String, String, String, String)> =
+            query.fetch_all(&self.pool).await?;
+
+        let has_more = db_entities.len() == query_limit as usize;
+        let results_to_take = if has_more {
+            limit as usize
+        } else {
+            db_entities.len()
+        };
+
+        let entities = db_entities
+            .iter()
+            .take(results_to_take)
+            .map(|(id, data, model_id, _, _)| async {
+                let hashed_keys = Felt::from_str(id)
+                    .map_err(ParseError::FromStr)?;
+                let model = self
+                    .model_cache
+                    .model(&Felt::from_str(model_id).map_err(ParseError::FromStr)?)
+                    .await?;
+                let mut schema = model.schema;
+                schema.from_json_value(
+                    serde_json::from_str(data).map_err(ParseError::FromJsonStr)?,
+                )?;
+
+                Ok::<_, Error>(Entity {
+                    hashed_keys,
+                    models: vec![schema.as_struct().unwrap().clone().into()],
+                })
+            })
+            // Collect the futures into a Vec
+            .collect::<Vec<_>>();
+
+        // Execute all the async mapping operations concurrently
+        let entities: Vec<Entity> = try_join_all(entities).await?;
+
+        let next_cursor = if has_more {
+            db_entities
+                .last()
+                .map(|(_, _, _, event_id, _)| encode_cursor(event_id))
+                .transpose()?
+        } else {
+            None
+        };
+
+        Ok(Page {
+            items: entities,
+            next_cursor,
+        })
+    }
+
+    /// Unified entrypoint for entity queries, covering all clause types and pagination.
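+    ///
+    /// `*_historical` tables are routed to [`Self::fetch_historical_entities`], everything
+    /// else to [`Self::entities`]; the optional clause (hashed keys, keys pattern, member
+    /// comparison, or a composite of these) only determines the WHERE/HAVING fragments and
+    /// bind values passed down.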
+    #[allow(clippy::too_many_arguments)]
+    pub async fn query_entities(
+        &self,
+        table: &str,
+        entity_type: EntityType,
+        query: &torii_proto::Query,
+        model_cache: &crate::cache::ModelCache,
+    ) -> Result<Page<Entity>, Error> {
+        let no_hashed_keys = query.no_hashed_keys;
+        let models = query.models.clone();
+        let pagination = query.pagination.clone();
+
+        // Helper for model selectors
+        let model_selectors = models.iter().map(|m| compute_selector_from_tag(m)).collect::<Vec<_>>();
+        let schemas = model_cache.models(&model_selectors).await?.iter().map(|m| m.schema.clone()).collect::<Vec<_>>();
+        let having_clause = model_selectors.iter().map(|model| format!("INSTR(model_ids, '{:#x}') > 0", model)).collect::<Vec<_>>().join(" OR ");
+
+        let model_relation_table = entity_type.relation_table();
+        let entity_relation_column = entity_type.relation_column();
+
+        let page = match &query.clause {
+            None => {
+                // All entities
+                if table.ends_with("_historical") {
+                    self.fetch_historical_entities(
+                        table,
+                        &entity_type.relation_table(),
+                        "",
+                        &having_clause,
+                        vec![],
+                        pagination,
+                    ).await
+                } else {
+                    self.entities(
+                        &schemas,
+                        table,
+                        entity_type.relation_table(),
+                        entity_type.relation_column(),
+                        None,
+                        if !having_clause.is_empty() { Some(&having_clause) } else { None },
+                        pagination,
+                        vec![],
+                    ).await
+                }
+            }
+            Some(clause) => {
+                match clause {
+                    torii_proto::Clause::HashedKeys(hashed_keys) => {
+                        let where_clause = if !hashed_keys.is_empty() {
+                            let ids = hashed_keys.iter().map(|_| format!("{table}.id = ?")).collect::<Vec<_>>();
+                            ids.join(" OR ")
+                        } else {
+                            String::new()
+                        };
+                        let bind_values = hashed_keys.iter().map(|key| format!("{:#x}", key)).collect::<Vec<_>>();
+                        if table.ends_with("_historical") {
+                            self.fetch_historical_entities(
+                                table,
+                                model_relation_table,
+                                &where_clause,
+                                &having_clause,
+                                bind_values,
+                                pagination,
+                            ).await
+                        } else {
+                            self.entities(
+                                &schemas,
+                                table,
+                                model_relation_table,
+                                entity_relation_column,
+                                if !where_clause.is_empty() { Some(&where_clause) } else { None },
+                                if !having_clause.is_empty() { Some(&having_clause) } else { None },
+                                pagination,
+                                bind_values,
+                            ).await
+                        }
+                    }
+                    torii_proto::Clause::Keys(keys) => {
+                        let keys_pattern = crate::utils::build_keys_pattern(keys)?;
+                        let model_selectors: Vec<String> = keys.models.iter().map(|model| format!("{:#x}", compute_selector_from_tag(model))).collect();
+                        let mut bind_values = vec![keys_pattern];
+                        let where_clause = if model_selectors.is_empty() {
+                            format!("{table}.keys REGEXP ?")
+                        } else {
+                            let model_selectors_len = model_selectors.len();
+                            bind_values.extend(model_selectors.clone());
+                            bind_values.extend(model_selectors);
+                            format!(
+                                "({table}.keys REGEXP ? AND {model_relation_table}.model_id IN ({})) OR \
+                                 {model_relation_table}.model_id NOT IN ({})",
+                                vec!["?"; model_selectors_len].join(", "),
+                                vec!["?"; model_selectors_len].join(", "),
+                            )
+                        };
+                        if table.ends_with("_historical") {
+                            self.fetch_historical_entities(
+                                table,
+                                model_relation_table,
+                                &where_clause,
+                                &having_clause,
+                                bind_values,
+                                pagination,
+                            ).await
+                        } else {
+                            self.entities(
+                                &schemas,
+                                table,
+                                model_relation_table,
+                                entity_relation_column,
+                                Some(&where_clause),
+                                if !having_clause.is_empty() { Some(&having_clause) } else { None },
+                                pagination,
+                                bind_values,
+                            ).await
+                        }
+                    }
+                    torii_proto::Clause::Member(member) => {
+                        let comparison_operator = ComparisonOperator::from_repr(member.operator as usize).expect("invalid comparison operator");
+                        fn prepare_comparison(value: &torii_proto::proto::types::MemberValue, bind_values: &mut Vec<String>) -> Result<String, Error> {
+                            match &value.value_type {
+                                Some(ValueType::String(value)) => {
+                                    bind_values.push(value.to_string());
+                                    Ok("?".to_string())
+                                }
+                                Some(ValueType::Primitive(value)) => {
+                                    let primitive: Primitive = (value.clone()).try_into()?;
+                                    bind_values.push(primitive.to_sql_value());
+                                    Ok("?".to_string())
+                                }
+                                Some(ValueType::List(values)) => Ok(format!(
+                                    "({})",
+                                    values.values.iter().map(|v| prepare_comparison(v, bind_values)).collect::<Result<Vec<_>, Error>>()?.join(", ")
+                                )),
+                                None => Err(QueryError::MissingParam("value_type".into()).into()),
+                            }
+                        }
+                        let (namespace, model) = member.model.split_once('-').ok_or(QueryError::InvalidNamespacedModel(member.model.clone()))?;
+                        let models_query = format!(
+                            r#"
+                            SELECT group_concat({model_relation_table}.model_id) as model_ids
+                            FROM {table}
+                            JOIN {model_relation_table} ON {table}.id = {model_relation_table}.entity_id
+                            GROUP BY {table}.id
+                            HAVING INSTR(model_ids, '{:#x}') > 0
+                            LIMIT 1
+                        "#, compute_selector_from_names(namespace, model));
+                        let models_str: Option<String> = sqlx::query_scalar(&models_query).fetch_optional(&self.pool).await?;
+                        if models_str.is_none() {
+                            return Ok(Page { items: Vec::new(), next_cursor: None });
+                        }
+                        let models_str = models_str.unwrap();
+                        let model_ids = models_str.split(',').filter_map(|id| {
+                            let model_id = Felt::from_str(id).unwrap();
+                            if model_selectors.is_empty() || model_selectors.contains(&model_id) {
+                                Some(model_id)
+                            } else {
+                                None
+                            }
+                        }).collect::<Vec<_>>();
+                        let schemas = model_cache.models(&model_ids).await?.into_iter().map(|m| m.schema).collect::<Vec<_>>();
+                        let mut bind_values = Vec::new();
+                        let value = prepare_comparison(member.value.as_ref().ok_or(QueryError::MissingParam("value".into()))?, &mut bind_values)?;
+                        let where_clause = format!("[{}].[{}] {comparison_operator} {value}", member.model, member.member);
+                        self.entities(
+                            &schemas,
+                            table,
+                            model_relation_table,
+                            entity_relation_column,
+                            Some(&where_clause),
+                            None,
+                            pagination,
+                            bind_values,
+                        ).await
+                    }
+                    torii_proto::Clause::Composite(composite) => {
+                        // Use the same build_composite_clause as in the server
+                        fn build_composite_clause(
+                            table: &str,
+                            model_relation_table: &str,
+                            composite: &torii_proto::proto::types::CompositeClause,
+                        ) -> Result<(String, Vec<String>), Error> {
+                            let is_or = composite.operator == LogicalOperator::Or as i32;
+                            let mut where_clauses = Vec::new();
+                            let mut bind_values = Vec::new();
+                            for clause in &composite.clauses {
+                                match clause.clause_type.as_ref().unwrap() {
+                                    ClauseType::HashedKeys(hashed_keys) => {
+                                        let ids = hashed_keys.hashed_keys.iter().map(|id| {
+                                            bind_values.push(Felt::from_bytes_be_slice(id).to_string());
+                                            "?".to_string()
+                                        }).collect::<Vec<_>>().join(", ");
+                                        where_clauses.push(format!("({table}.id IN ({}))", ids));
+                                    }
+                                    ClauseType::Keys(keys) => {
+                                        let keys_pattern = crate::utils::build_keys_pattern(keys)?;
+                                        bind_values.push(keys_pattern);
+                                        let model_selectors: Vec<String> = keys.models.iter().map(|model| format!("{:#x}", compute_selector_from_tag(model))).collect();
+                                        if model_selectors.is_empty() {
+                                            where_clauses.push(format!("({table}.keys REGEXP ?)"));
+                                        } else {
+                                            let placeholders = vec!["?"; model_selectors.len()].join(", ");
+                                            where_clauses.push(format!(
+                                                "({table}.keys REGEXP ? AND {model_relation_table}.model_id IN ({})) OR \
+                                                 {model_relation_table}.model_id NOT IN ({})",
+                                                placeholders, placeholders
+                                            ));
+                                            bind_values.extend(model_selectors.clone());
+                                            bind_values.extend(model_selectors);
+                                        }
+                                    }
+                                    ClauseType::Member(member) => {
+                                        let comparison_operator = ComparisonOperator::from_repr(member.operator as usize).expect("invalid comparison operator");
+                                        let value = member.value.clone().ok_or(QueryError::MissingParam("value".into()))?;
+                                        fn prepare_comparison(value: &torii_proto::proto::types::MemberValue, bind_values: &mut Vec<String>) -> Result<String, Error> {
+                                            match &value.value_type {
+                                                Some(ValueType::String(value)) => {
+                                                    bind_values.push(value.to_string());
+                                                    Ok("?".to_string())
+                                                }
+                                                Some(ValueType::Primitive(value)) => {
+                                                    let primitive: Primitive = (value.clone()).try_into()?;
+                                                    bind_values.push(primitive.to_sql_value());
+                                                    Ok("?".to_string())
+                                                }
+                                                Some(ValueType::List(values)) => Ok(format!(
+                                                    "({})",
+                                                    values.values.iter().map(|v| prepare_comparison(v, bind_values)).collect::<Result<Vec<_>, Error>>()?.join(", ")
+                                                )),
+                                                None => Err(QueryError::MissingParam("value_type".into()).into()),
+                                            }
+                                        }
+                                        let value = prepare_comparison(&value, &mut bind_values)?;
+                                        let model = member.model.clone();
+                                        where_clauses.push(format!("([{model}].[{}] {comparison_operator} {value})", member.member));
+                                    }
+                                    ClauseType::Composite(nested) => {
+                                        let (nested_where, nested_values) = build_composite_clause(table, model_relation_table, nested)?;
+                                        if !nested_where.is_empty() {
+                                            where_clauses.push(nested_where);
+                                        }
+                                        bind_values.extend(nested_values);
+                                    }
+                                }
+                            }
+                            let where_clause = if !where_clauses.is_empty() {
+                                where_clauses.join(if is_or { " OR " } else { " AND " })
+                            } else {
+                                String::new()
+                            };
+                            Ok((where_clause, bind_values))
+                        }
+                        let (where_clause, bind_values) = build_composite_clause(table, model_relation_table, &composite)?;
+                        self.entities(
+                            &schemas,
+                            table,
+                            model_relation_table,
+                            entity_relation_column,
+                            if where_clause.is_empty() { None } else { Some(&where_clause) },
+                            if having_clause.is_empty() { None } else { Some(&having_clause) },
+                            pagination,
+                            bind_values,
+                        ).await
+                    }
+                }
+            }
+        }?;
+        Ok(Page {
+            items: page.items,
+            next_cursor: page.next_cursor,
+        })
+    }
+}
diff --git a/crates/sqlite/sqlite/src/executor/erc.rs b/crates/sqlite/sqlite/src/executor/erc.rs
index d773b8bb..b1fc5a0e 100644
--- a/crates/sqlite/sqlite/src/executor/erc.rs
+++ b/crates/sqlite/sqlite/src/executor/erc.rs
@@ -6,10 +6,11 @@ use starknet::core::types::{BlockId, BlockTag, FunctionCall, U256};
 use starknet::core::utils::get_selector_from_name;
 use starknet::providers::Provider;
 use starknet_crypto::Felt;
+use torii_sqlite_types::Table;
 use tracing::{debug, warn};

 use super::{ApplyBalanceDiffQuery, BrokerMessage, Executor};
-use crate::constants::{SQL_FELT_DELIMITER, TOKEN_BALANCE_TABLE};
+use crate::constants::SQL_FELT_DELIMITER;
 use crate::error::Error;
 use crate::executor::LOG_TARGET;
 use crate::simple_broker::SimpleBroker;
@@ -126,7 +127,8 @@ impl Executor<'_, P> {
     ) ->
Result<(), Error> { let tx = &mut self.transaction; let balance: Option<(String,)> = sqlx::query_as(&format!( - "SELECT balance FROM {TOKEN_BALANCE_TABLE} WHERE id = ?" + "SELECT balance FROM {table} WHERE id = ?", + table = Table::TokenBalances )) .bind(id) .fetch_optional(&mut **tx) @@ -188,8 +190,9 @@ impl Executor<'_, P> { // write the new balance to the database let token_balance: TokenBalance = sqlx::query_as(&format!( - "INSERT INTO {TOKEN_BALANCE_TABLE} (id, contract_address, account_address, \ + "INSERT INTO {table} (id, contract_address, account_address, \ token_id, balance) VALUES (?, ?, ?, ?, ?) ON CONFLICT DO UPDATE SET balance = EXCLUDED.balance RETURNING *", + table = Table::TokenBalances )) .bind(id) .bind(contract_address) diff --git a/crates/sqlite/sqlite/src/executor/mod.rs b/crates/sqlite/sqlite/src/executor/mod.rs index 70546c77..472f4f76 100644 --- a/crates/sqlite/sqlite/src/executor/mod.rs +++ b/crates/sqlite/sqlite/src/executor/mod.rs @@ -16,10 +16,9 @@ use tokio::sync::broadcast::{Receiver, Sender}; use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; use tokio::sync::oneshot; use tokio::time::Instant; -use torii_sqlite_types::OptimisticToken; +use torii_sqlite_types::{OptimisticToken, Table}; use tracing::{debug, error, info, warn}; -use crate::constants::TOKENS_TABLE; use crate::error::ParseError; use crate::executor::error::{ExecutorError, ExecutorQueryError}; use crate::simple_broker::SimpleBroker; @@ -635,7 +634,8 @@ impl Executor<'_, P> { QueryType::RegisterNftToken(register_nft_token) => { // Check if we already have the metadata for this contract let res = sqlx::query_as::<_, (String, String)>(&format!( - "SELECT name, symbol FROM {TOKENS_TABLE} WHERE contract_address = ? LIMIT 1" + "SELECT name, symbol FROM {table} WHERE contract_address = ? 
LIMIT 1", + table = Table::Tokens )) .bind(felt_to_sql_string(®ister_nft_token.contract_address)) .fetch_one(&mut **tx) diff --git a/crates/sqlite/sqlite/src/lib.rs b/crates/sqlite/sqlite/src/lib.rs index e1e616c1..17d0e1a3 100644 --- a/crates/sqlite/sqlite/src/lib.rs +++ b/crates/sqlite/sqlite/src/lib.rs @@ -2,20 +2,27 @@ use std::collections::{HashMap, HashSet}; use std::str::FromStr; use std::sync::Arc; +use anyhow::{anyhow, Result}; +use chrono::{DateTime, Utc}; +use constants::SQL_DEFAULT_LIMIT; +use cursor::{decode_cursor, encode_cursor}; use dojo_types::naming::get_tag; use dojo_types::primitive::SqlType; use dojo_types::schema::{Struct, Ty}; use dojo_world::config::WorldMetadata; use dojo_world::contracts::abigen::model::Layout; use dojo_world::contracts::naming::compute_selector_from_names; +use error::{Error, ParseError}; use executor::{EntityQuery, StoreTransactionQuery}; +use rayon::iter::IntoParallelRefIterator; use sqlx::{Pool, Sqlite}; use starknet::core::types::{Event, Felt}; use starknet_crypto::poseidon_hash_many; use tokio::sync::mpsc::UnboundedSender; +use torii_proto::{Controller, Page}; use tokio::sync::Semaphore; +use utils::{build_keys_pattern, felts_to_sql_string, map_row_to_event}; use torii_sqlite_types::{ContractCursor, HookEvent, ParsedCall}; -use utils::felts_to_sql_string; use crate::constants::SQL_FELT_DELIMITER; use crate::error::{Error, ParseError}; @@ -34,6 +41,9 @@ pub mod executor; pub mod model; pub mod simple_broker; pub mod utils; +pub mod entities; +pub mod tokens; +mod cursor; use cache::{LocalCache, Model, ModelCache}; pub use torii_sqlite_types as types; @@ -567,6 +577,10 @@ impl Sql { self.model_cache.model(&selector).await } + pub async fn models(&self, selectors: &[Felt]) -> Result, Error> { + self.model_cache.models(selectors).await + } + #[allow(clippy::too_many_arguments)] pub fn store_transaction( &mut self, @@ -1161,4 +1175,113 @@ impl Sql { Ok(()) } + + pub async fn events( + &self, + query: &torii_proto::EventQuery, + ) -> Result, Error> { + let limit = if query.limit > 0 { + query.limit + 1 + } else { + SQL_DEFAULT_LIMIT as u32 + 1 + }; + + let mut bind_values = Vec::new(); + let mut conditions = Vec::new(); + + let keys_pattern = if let Some(keys_clause) = &query.keys { + build_keys_pattern(keys_clause)? + } else { + String::new() + }; + + if !keys_pattern.is_empty() { + conditions.push("keys REGEXP ?"); + bind_values.push(keys_pattern); + } + + if let Some(cursor) = &query.cursor { + conditions.push("id >= ?"); + bind_values.push(decode_cursor(&cursor)?); + } + + let mut events_query = r#" + SELECT id, keys, data, transaction_hash + FROM events + "# + .to_string(); + + if !conditions.is_empty() { + events_query = format!("{} WHERE {}", events_query, conditions.join(" AND ")); + } + + events_query = format!("{} ORDER BY id LIMIT ?", events_query); + bind_values.push(limit.to_string()); + + let mut row_events = sqlx::query_as(&events_query); + for value in &bind_values { + row_events = row_events.bind(value); + } + let mut row_events: Vec<(String, String, String, String)> = + row_events.fetch_all(&self.pool).await?; + + let next_cursor = if row_events.len() > (limit - 1) as usize { + Some(encode_cursor(&row_events.pop().unwrap().0)?) 
+ } else { + None + }; + + let events = row_events + .par_iter() + .map(|(_, keys, data, transaction_hash)| { + map_row_to_event(&(keys, data, transaction_hash)) + }) + .collect::, Error>>()?; + + Ok(Page { + items: events, + next_cursor, + }) + } + + pub async fn controllers( + &self, + contract_addresses: Vec, + ) -> Result, Error> { + let query = if contract_addresses.is_empty() { + "SELECT address, username, deployed_at FROM controllers".to_string() + } else { + format!( + "SELECT address, username, deployed_at FROM controllers WHERE address IN ({})", + contract_addresses + .iter() + .map(|_| "?".to_string()) + .collect::>() + .join(", ") + ) + }; + + let mut db_query = sqlx::query_as::<_, (String, String, DateTime)>(&query); + for address in &contract_addresses { + db_query = db_query.bind(format!("{:#x}", address)); + } + + let rows = db_query.fetch_all(&self.pool).await?; + + let controllers = rows + .into_iter() + .map( + |(address, username, deployed_at)| Controller { + address: Felt::from_str(&address).unwrap(), + username, + deployed_at: deployed_at.timestamp() as u64, + }, + ) + .collect(); + + Ok(Page { + items: controllers, + next_cursor: None, + }) + } } diff --git a/crates/sqlite/sqlite/src/model.rs b/crates/sqlite/sqlite/src/model.rs index 8410cbc4..7f832039 100644 --- a/crates/sqlite/sqlite/src/model.rs +++ b/crates/sqlite/sqlite/src/model.rs @@ -1,10 +1,4 @@ -use base64::prelude::BASE64_URL_SAFE_NO_PAD; -use base64::Engine; -use flate2::read::DeflateDecoder; -use flate2::write::DeflateEncoder; -use flate2::Compression; -use std::collections::HashSet; -use std::io::prelude::*; + use std::str::FromStr; use async_trait::async_trait; @@ -18,10 +12,9 @@ use sqlx::sqlite::SqliteRow; use sqlx::{Pool, Row, Sqlite}; use starknet::core::types::Felt; +use crate::error::ParseError; + use super::error::{self, Error}; -use crate::constants::{SQL_DEFAULT_LIMIT, SQL_MAX_JOINS}; -use crate::error::{ParseError, QueryError}; -use crate::types::{OrderDirection, Page, Pagination, PaginationDirection}; #[derive(Debug)] pub struct ModelSQLReader { @@ -295,350 +288,3 @@ pub fn map_row_to_ty( Ok(()) } - -#[allow(clippy::too_many_arguments)] -pub async fn fetch_entities( - pool: &Pool, - schemas: &[Ty], - table_name: &str, - model_relation_table: &str, - entity_relation_column: &str, - where_clause: Option<&str>, - having_clause: Option<&str>, - pagination: Pagination, - bind_values: Vec, -) -> Result, Error> { - // Helper function to collect columns - fn collect_columns(table_prefix: &str, path: &str, ty: &Ty, selections: &mut Vec) { - match ty { - Ty::Struct(s) => { - for child in &s.children { - let new_path = if path.is_empty() { - child.name.clone() - } else { - format!("{}.{}", path, child.name) - }; - collect_columns(table_prefix, &new_path, &child.ty, selections); - } - } - Ty::Tuple(t) => { - for (i, child) in t.iter().enumerate() { - let new_path = if path.is_empty() { - format!("{}", i) - } else { - format!("{}.{}", path, i) - }; - collect_columns(table_prefix, &new_path, child, selections); - } - } - Ty::Enum(e) => { - selections.push(format!( - "[{table_prefix}].[{path}] as \"{table_prefix}.{path}\"", - )); - - for option in &e.options { - if let Ty::Tuple(t) = &option.ty { - if t.is_empty() { - continue; - } - } - let variant_path = format!("{}.{}", path, option.name); - collect_columns(table_prefix, &variant_path, &option.ty, selections); - } - } - Ty::Array(_) | Ty::Primitive(_) | Ty::ByteArray(_) => { - selections.push(format!( - "[{table_prefix}].[{path}] as 
\"{table_prefix}.{path}\"", - )); - } - } - } - - let original_limit = pagination.limit.unwrap_or(SQL_DEFAULT_LIMIT as u32); - let fetch_limit = original_limit + 1; - let mut has_more_pages = false; - - // Build order by clause with proper model joining - let order_by_models: HashSet = pagination - .order_by - .iter() - .map(|ob| ob.model.clone()) - .collect(); - - let order_clause = if pagination.order_by.is_empty() { - format!("{table_name}.event_id DESC") - } else { - pagination - .order_by - .iter() - .map(|ob| { - let direction = match (&ob.direction, &pagination.direction) { - (OrderDirection::Asc, PaginationDirection::Forward) => "ASC", - (OrderDirection::Asc, PaginationDirection::Backward) => "DESC", - (OrderDirection::Desc, PaginationDirection::Forward) => "DESC", - (OrderDirection::Desc, PaginationDirection::Backward) => "ASC", - }; - format!("[{}].[{}] {direction}", ob.model, ob.member) - }) - .chain(std::iter::once(format!("{table_name}.event_id DESC"))) - .collect::>() - .join(", ") - }; - - // Parse cursor - let cursor_values: Option> = pagination - .cursor - .as_ref() - .map(|cursor_str| { - let decompressed_str = decode_cursor(cursor_str)?; - Ok(decompressed_str.split('/').map(|s| s.to_string()).collect()) - }) - .transpose() - .map_err(|e: Error| Error::Query(QueryError::InvalidCursor(e.to_string())))?; - - // Build cursor conditions - let (cursor_conditions, cursor_binds) = - build_cursor_conditions(&pagination, cursor_values.as_deref(), table_name)?; - - // Combine WHERE clauses - let combined_where = combine_where_clauses(where_clause, &cursor_conditions); - - // Process schemas in chunks - let mut all_rows = Vec::new(); - let mut next_cursor = None; - - for chunk in schemas.chunks(SQL_MAX_JOINS) { - let mut selections = vec![ - format!("{}.id", table_name), - format!("{}.keys", table_name), - format!("{}.event_id", table_name), - format!( - "group_concat({}.model_id) as model_ids", - model_relation_table - ), - ]; - let mut joins = Vec::new(); - - // Add schema joins - for model in chunk { - let model_table = model.name(); - let join_type = if order_by_models.contains(&model_table) { - "INNER" - } else { - "LEFT" - }; - joins.push(format!( - "{join_type} JOIN [{model_table}] ON {table_name}.id = \ - [{model_table}].{entity_relation_column}", - )); - collect_columns(&model_table, "", model, &mut selections); - } - - joins.push(format!( - "JOIN {model_relation_table} ON {table_name}.id = {model_relation_table}.entity_id", - )); - - // Build and execute query - let query = build_query( - &selections, - table_name, - &joins, - &combined_where, - having_clause, - &order_clause, - ); - - let mut stmt = sqlx::query(&query); - for value in bind_values.iter().chain(cursor_binds.iter()) { - stmt = stmt.bind(value); - } - - stmt = stmt.bind(fetch_limit); - - let mut rows = stmt.fetch_all(pool).await?; - let has_more = rows.len() >= fetch_limit as usize; - - if pagination.direction == PaginationDirection::Backward { - rows.reverse(); - } - if has_more { - // mark that there are more pages beyond the limit - has_more_pages = true; - rows.truncate(original_limit as usize); - } - - all_rows.extend(rows); - if has_more { - break; - } - } - - // Helper functions - // Replace generation of next cursor to only when there are more pages - if has_more_pages { - if let Some(last_row) = all_rows.last() { - let cursor_values_str = build_cursor_values(&pagination, last_row)?.join("/"); - next_cursor = Some(encode_cursor(&cursor_values_str)?); - } - } - - Ok(Page { - items: all_rows, - 
-
-// Helper functions
-fn build_cursor_conditions(
-    pagination: &Pagination,
-    cursor_values: Option<&[String]>,
-    table_name: &str,
-) -> Result<(Vec<String>, Vec<String>), Error> {
-    let mut conditions = Vec::new();
-    let mut binds = Vec::new();
-
-    if let Some(values) = cursor_values {
-        let expected_len = if pagination.order_by.is_empty() {
-            1
-        } else {
-            pagination.order_by.len() + 1
-        };
-        if values.len() != expected_len {
-            return Err(Error::Query(QueryError::InvalidCursor(
-                "Invalid cursor values length".to_string(),
-            )));
-        }
-
-        if pagination.order_by.is_empty() {
-            let operator = if pagination.direction == PaginationDirection::Forward {
-                "<"
-            } else {
-                ">"
-            };
-            conditions.push(format!("{}.event_id {} ?", table_name, operator));
-            binds.push(values[0].clone());
-        } else {
-            for (i, (ob, val)) in pagination.order_by.iter().zip(values).enumerate() {
-                let operator = match (&ob.direction, &pagination.direction) {
-                    (OrderDirection::Asc, PaginationDirection::Forward) => ">",
-                    (OrderDirection::Asc, PaginationDirection::Backward) => "<",
-                    (OrderDirection::Desc, PaginationDirection::Forward) => "<",
-                    (OrderDirection::Desc, PaginationDirection::Backward) => ">",
-                };
-
-                let condition = if i == 0 {
-                    format!("[{}.{}] {} ?", ob.model, ob.member, operator)
-                } else {
-                    let prev = (0..i)
-                        .map(|j| {
-                            let prev_ob = &pagination.order_by[j];
-                            format!("[{}.{}] = ?", prev_ob.model, prev_ob.member)
-                        })
-                        .collect::<Vec<_>>()
-                        .join(" AND ");
-                    format!("({} AND [{}.{}] {} ?)", prev, ob.model, ob.member, operator)
-                };
-                conditions.push(condition);
-                binds.push(val.clone());
-            }
-            let operator = if pagination.direction == PaginationDirection::Forward {
-                "<"
-            } else {
-                ">"
-            };
-            conditions.push(format!("{}.event_id {} ?", table_name, operator));
-            binds.push(values.last().unwrap().clone());
-        }
-    }
-    Ok((conditions, binds))
-}
-
-fn combine_where_clauses(base: Option<&str>, cursor_conditions: &[String]) -> String {
-    let mut parts = Vec::new();
-    if let Some(base_where) = base {
-        parts.push(base_where.to_string());
-    }
-    parts.extend(cursor_conditions.iter().cloned());
-    parts.join(" AND ")
-}
-
-fn build_query(
-    selections: &[String],
-    table_name: &str,
-    joins: &[String],
-    where_clause: &str,
-    having_clause: Option<&str>,
-    order_clause: &str,
-) -> String {
-    let mut query = format!(
-        "SELECT {} FROM [{}] {}",
-        selections.join(", "),
-        table_name,
-        joins.join(" ")
-    );
-    if !where_clause.is_empty() {
-        query.push_str(&format!(" WHERE {}", where_clause));
-    }
-
-    query.push_str(&format!(" GROUP BY {}.id", table_name));
-
-    if let Some(having) = having_clause {
-        query.push_str(&format!(" HAVING {}", having));
-    }
-    query.push_str(&format!(" ORDER BY {} LIMIT ?", order_clause));
-    query
-}
-
-fn build_cursor_values(pagination: &Pagination, row: &SqliteRow) -> Result<Vec<String>, Error> {
-    if pagination.order_by.is_empty() {
-        Ok(vec![row.try_get("event_id")?])
-    } else {
-        let mut values: Vec<String> = pagination
-            .order_by
-            .iter()
-            .map(|ob| row.try_get::<String, _>(&format!("{}.{}", ob.model, ob.member)))
-            .collect::<Result<Vec<_>, _>>()?;
-        values.push(row.try_get("event_id")?);
-        Ok(values)
-    }
-}
-
-/// Compresses a string using Deflate and then encodes it using Base64 (no padding).
-pub fn encode_cursor(value: &str) -> Result<String, Error> {
-    let mut encoder = DeflateEncoder::new(Vec::new(), Compression::default());
-    encoder.write_all(value.as_bytes()).map_err(|e| {
-        Error::Query(QueryError::InvalidCursor(format!(
-            "Cursor compression error: {}",
-            e
-        )))
-    })?;
-    let compressed_bytes = encoder.finish().map_err(|e| {
-        Error::Query(QueryError::InvalidCursor(format!(
-            "Cursor compression finish error: {}",
-            e
-        )))
-    })?;
-
-    Ok(BASE64_URL_SAFE_NO_PAD.encode(&compressed_bytes))
-}
-
-/// Decodes a Base64 (no padding) string and then decompresses it using Deflate.
-pub fn decode_cursor(encoded_cursor: &str) -> Result<String, Error> {
-    let compressed_cursor_bytes = BASE64_URL_SAFE_NO_PAD.decode(encoded_cursor).map_err(|e| {
-        Error::Query(QueryError::InvalidCursor(format!(
-            "Base64 decode error: {}",
-            e
-        )))
-    })?;
-
-    let mut decoder = DeflateDecoder::new(&compressed_cursor_bytes[..]);
-    let mut decompressed_str = String::new();
-    decoder.read_to_string(&mut decompressed_str).map_err(|e| {
-        Error::Query(QueryError::InvalidCursor(format!(
-            "Decompression error: {}",
-            e
-        )))
-    })?;
-
-    Ok(decompressed_str)
-}
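Note: the cursor codec removed here appears to move into a dedicated `crate::cursor` module, judging by the `cursor::{decode_cursor, encode_cursor}` import in the new tokens.rs below. Its behaviour is a straight Deflate-compress then unpadded Base64url-encode round trip; a hypothetical test:

#[test]
fn cursor_round_trip() {
    // Cursor values are joined with '/' before encoding, as in fetch_entities.
    let plain = "0x4a/0x1b2";
    let encoded = encode_cursor(plain).unwrap();
    assert_eq!(decode_cursor(&encoded).unwrap(), plain);
}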
diff --git a/crates/sqlite/sqlite/src/tokens.rs b/crates/sqlite/sqlite/src/tokens.rs
new file mode 100644
index 00000000..ca52ed18
--- /dev/null
+++ b/crates/sqlite/sqlite/src/tokens.rs
@@ -0,0 +1,128 @@
+use crypto_bigint::U256;
+use starknet_crypto::Felt;
+use torii_proto::{Page, Token, TokenBalance};
+
+use crate::{
+    constants::SQL_DEFAULT_LIMIT,
+    cursor::{decode_cursor, encode_cursor},
+    error::Error,
+    utils::u256_to_sql_string,
+    Sql,
+};
+
+impl Sql {
+    pub async fn tokens(
+        &self,
+        contract_addresses: Vec<Felt>,
+        token_ids: Vec<U256>,
+        limit: Option<u32>,
+        cursor: Option<String>,
+    ) -> Result<Page<Token>, Error> {
+        let mut query = "SELECT * FROM tokens".to_string();
+        let mut bind_values = Vec::new();
+        let mut conditions = Vec::new();
+
+        if !contract_addresses.is_empty() {
+            let placeholders = vec!["?"; contract_addresses.len()].join(", ");
+            conditions.push(format!("contract_address IN ({})", placeholders));
+            bind_values.extend(contract_addresses.iter().map(|addr| format!("{:#x}", addr)));
+        }
+
+        if !token_ids.is_empty() {
+            let placeholders = vec!["?"; token_ids.len()].join(", ");
+            conditions.push(format!("token_id IN ({})", placeholders));
+            bind_values.extend(token_ids.iter().map(|id| u256_to_sql_string(&(*id).into())));
+        }
+
+        if let Some(cursor) = cursor {
+            bind_values.push(decode_cursor(&cursor)?);
+            conditions.push("id >= ?".to_string());
+        }
+
+        if !conditions.is_empty() {
+            query += &format!(" WHERE {}", conditions.join(" AND "));
+        }
+
+        query += " ORDER BY id LIMIT ?";
+        bind_values.push((limit.unwrap_or(SQL_DEFAULT_LIMIT as u32) + 1).to_string());
+
+        let mut query = sqlx::query_as(&query);
+        for value in bind_values {
+            query = query.bind(value);
+        }
+
+        let mut tokens: Vec<crate::types::Token> = query.fetch_all(&self.pool).await?;
+        let next_cursor = if tokens.len() > limit.unwrap_or(SQL_DEFAULT_LIMIT as u32) as usize {
+            Some(encode_cursor(&tokens.pop().unwrap().id)?)
+        } else {
+            None
+        };
+
+        let tokens = tokens.iter().map(|token| token.clone().into()).collect();
+        Ok(Page {
+            items: tokens,
+            next_cursor,
+        })
+    }
+
+    pub async fn token_balances(
+        &self,
+        account_addresses: Vec<Felt>,
+        contract_addresses: Vec<Felt>,
+        token_ids: Vec<U256>,
+        limit: Option<u32>,
+        cursor: Option<String>,
+    ) -> Result<Page<TokenBalance>, Error> {
+        let mut query = "SELECT * FROM token_balances".to_string();
+        let mut bind_values = Vec::new();
+        let mut conditions = Vec::new();
+
+        if !account_addresses.is_empty() {
+            let placeholders = vec!["?"; account_addresses.len()].join(", ");
+            conditions.push(format!("account_address IN ({})", placeholders));
+            bind_values.extend(account_addresses.iter().map(|addr| format!("{:#x}", addr)));
+        }
+
+        if !contract_addresses.is_empty() {
+            let placeholders = vec!["?"; contract_addresses.len()].join(", ");
+            conditions.push(format!("contract_address IN ({})", placeholders));
+            bind_values.extend(contract_addresses.iter().map(|addr| format!("{:#x}", addr)));
+        }
+
+        if !token_ids.is_empty() {
+            let placeholders = vec!["?"; token_ids.len()].join(", ");
+            conditions.push(format!(
+                "SUBSTR(token_id, INSTR(token_id, ':') + 1) IN ({})",
+                placeholders
+            ));
+            bind_values.extend(token_ids.iter().map(|id| u256_to_sql_string(&(*id).into())));
+        }
+
+        if let Some(cursor) = cursor {
+            bind_values.push(decode_cursor(&cursor)?);
+            conditions.push("id >= ?".to_string());
+        }
+
+        if !conditions.is_empty() {
+            query += &format!(" WHERE {}", conditions.join(" AND "));
+        }
+
+        query += " ORDER BY id LIMIT ?";
+        bind_values.push((limit.unwrap_or(SQL_DEFAULT_LIMIT as u32) + 1).to_string());
+
+        let mut query = sqlx::query_as(&query);
+        for value in bind_values {
+            query = query.bind(value);
+        }
+
+        let mut balances: Vec<crate::types::TokenBalance> = query.fetch_all(&self.pool).await?;
+        let next_cursor = if balances.len() > limit.unwrap_or(SQL_DEFAULT_LIMIT as u32) as usize {
+            Some(encode_cursor(&balances.pop().unwrap().id)?)
+        } else {
+            None
+        };
+
+        let balances = balances
+            .iter()
+            .map(|balance| balance.clone().into())
+            .collect();
+        Ok(Page {
+            items: balances,
+            next_cursor,
+        })
+    }
+}
\ No newline at end of file
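Note: for reference, with one contract filter and a cursor, the query string `tokens` assembles comes out as follows (illustration only):

let query = "SELECT * FROM tokens WHERE contract_address IN (?) AND id >= ? ORDER BY id LIMIT ?";
// Bind order matches condition order: contract address(es), then the decoded
// cursor, then limit + 1. The popped extra row's id becomes the next cursor,
// and `id >= ?` makes the following page start at exactly that row.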
diff --git a/crates/sqlite/sqlite/src/utils.rs b/crates/sqlite/sqlite/src/utils.rs
index d62dc841..8299f24c 100644
--- a/crates/sqlite/sqlite/src/utils.rs
+++ b/crates/sqlite/sqlite/src/utils.rs
@@ -6,10 +6,13 @@ use std::time::Duration;
 use base64::engine::general_purpose::STANDARD;
 use base64::Engine;
 use chrono::{DateTime, Utc};
+use dojo_types::naming::compute_selector_from_tag;
+use dojo_types::schema::Ty;
 use futures_util::TryStreamExt;
 use ipfs_api_backend_hyper::{IpfsApi, IpfsClient, TryFromUri};
 use once_cell::sync::Lazy;
 use reqwest::Client;
+use sqlx::sqlite::SqliteRow;
 use sqlx::{Column, Row, TypeInfo};
 use starknet::core::types::U256;
 use starknet_crypto::Felt;
@@ -20,8 +23,123 @@ use crate::constants::{
     IPFS_CLIENT_PASSWORD, IPFS_CLIENT_URL, IPFS_CLIENT_USERNAME, REQ_MAX_RETRIES,
     SQL_FELT_DELIMITER,
 };
+use crate::error::{Error, ParseError};
+use crate::model::map_row_to_ty;
 use crate::error::HttpError;
 
+fn process_event_field(data: &str) -> Result<Vec<Felt>, Error> {
+    Ok(data
+        .trim_end_matches('/')
+        .split('/')
+        .filter(|&d| !d.is_empty())
+        .map(|d| Felt::from_str(d).map_err(ParseError::FromStr))
+        .collect::<Result<Vec<_>, _>>()?)
+}
+
+pub(crate) fn map_row_to_event(row: &(&str, &str, &str)) -> Result<torii_proto::Event, Error> {
+    let keys = process_event_field(row.0)?;
+    let data = process_event_field(row.1)?;
+    let transaction_hash = Felt::from_str(row.2).map_err(ParseError::FromStr)?;
+
+    Ok(torii_proto::Event {
+        keys,
+        data,
+        transaction_hash,
+    })
+}
+
+/// Builds a SQL-safe regex pattern for matching stored event keys against a keys clause.
+pub(crate) fn build_keys_pattern(clause: &torii_proto::KeysClause) -> Result<String, Error> {
+    const KEY_PATTERN: &str = "0x[0-9a-fA-F]+";
+
+    let keys = if clause.keys.is_empty() {
+        vec![KEY_PATTERN.to_string()]
+    } else {
+        clause
+            .keys
+            .iter()
+            .map(|key| {
+                if let Some(key) = key {
+                    Ok(format!("{:#x}", key))
+                } else {
+                    Ok(KEY_PATTERN.to_string())
+                }
+            })
+            .collect::<Result<Vec<_>, Error>>()?
+    };
+    let mut keys_pattern = format!("^{}", keys.join("/"));
+
+    if clause.pattern_matching == torii_proto::PatternMatching::VariableLen {
+        keys_pattern += &format!("(/{})*", KEY_PATTERN);
+    }
+    keys_pattern += "/$";
+
+    Ok(keys_pattern)
+}
+
+pub(crate) fn map_row_to_entity(
+    row: &SqliteRow,
+    schemas: &[Ty],
+) -> Result<torii_proto::schema::Entity, Error> {
+    let hashed_keys =
+        Felt::from_str(&row.get::<String, _>("id")).map_err(ParseError::FromStr)?;
+    let model_ids = row
+        .get::<String, _>("model_ids")
+        .split(',')
+        .map(|id| Felt::from_str(id).map_err(ParseError::FromStr))
+        .collect::<Result<Vec<_>, _>>()?;
+
+    let models = schemas
+        .iter()
+        .filter(|schema| model_ids.contains(&compute_selector_from_tag(&schema.name())))
+        .map(|schema| {
+            let mut ty = schema.clone();
+            map_row_to_ty("", &schema.name(), &mut ty, row)?;
+            Ok(ty.as_struct().unwrap().clone().into())
+        })
+        .collect::<Result<Vec<_>, Error>>()?;
+
+    Ok(torii_proto::schema::Entity {
+        hashed_keys,
+        models,
+    })
+}
+
+pub(crate) fn combine_where_clauses(base: Option<&str>, cursor_conditions: &[String]) -> String {
+    let mut parts = Vec::new();
+    if let Some(base_where) = base {
+        parts.push(base_where.to_string());
+    }
+    parts.extend(cursor_conditions.iter().cloned());
+    parts.join(" AND ")
+}
+
+pub(crate) fn build_query(
+    selections: &[String],
+    table_name: &str,
+    joins: &[String],
+    where_clause: &str,
+    having_clause: Option<&str>,
+    order_clause: &str,
+) -> String {
+    let mut query = format!(
+        "SELECT {} FROM [{}] {}",
+        selections.join(", "),
+        table_name,
+        joins.join(" ")
+    );
+    if !where_clause.is_empty() {
+        query.push_str(&format!(" WHERE {}", where_clause));
+    }
+
+    query.push_str(&format!(" GROUP BY {}.id", table_name));
+
+    if let Some(having) = having_clause {
+        query.push_str(&format!(" HAVING {}", having));
+    }
+    query.push_str(&format!(" ORDER BY {} LIMIT ?", order_clause));
+    query
+}
+
 pub fn must_utc_datetime_from_timestamp(timestamp: u64) -> DateTime<Utc> {
     let naive_dt = DateTime::from_timestamp(timestamp as i64, 0)
         .expect("Failed to convert timestamp to NaiveDateTime");
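Note: a worked example of `build_keys_pattern` may help; keys are stored as a '/'-delimited string with a trailing '/', hence the anchored "/$" suffix. The exact `KeysClause` field layout (in particular `models`) is assumed here:

let clause = torii_proto::KeysClause {
    keys: vec![Some(Felt::ONE), None],
    pattern_matching: torii_proto::PatternMatching::VariableLen,
    models: vec![],
};
// Some(key) matches that key literally, None matches any single key, and
// VariableLen additionally allows any number of trailing keys.
assert_eq!(
    build_keys_pattern(&clause).unwrap(),
    "^0x1/0x[0-9a-fA-F]+(/0x[0-9a-fA-F]+)*/$"
);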
diff --git a/crates/sqlite/types/Cargo.toml b/crates/sqlite/types/Cargo.toml
index 2b9cee9b..b748e85c 100644
--- a/crates/sqlite/types/Cargo.toml
+++ b/crates/sqlite/types/Cargo.toml
@@ -12,3 +12,5 @@ dojo-types.workspace = true
 sqlx.workspace = true
 serde.workspace = true
 anyhow.workspace = true
+torii-proto.workspace = true
+crypto-bigint.workspace = true
\ No newline at end of file
diff --git a/crates/sqlite/types/src/lib.rs b/crates/sqlite/types/src/lib.rs
index 3c1924fe..2791b0a1 100644
--- a/crates/sqlite/types/src/lib.rs
+++ b/crates/sqlite/types/src/lib.rs
@@ -3,6 +3,7 @@ use std::collections::HashSet;
 use std::str::FromStr;
 
 use chrono::{DateTime, Utc};
+use crypto_bigint::U256;
 use dojo_types::schema::Ty;
 use serde::{Deserialize, Serialize};
 use sqlx::FromRow;
@@ -157,6 +158,20 @@ pub struct TokenCollection {
     pub metadata: String,
 }
 
+impl From<TokenCollection> for torii_proto::TokenCollection {
+    fn from(value: TokenCollection) -> Self {
+        Self {
+            contract_address: Felt::from_str(&value.contract_address).unwrap(),
+            name: value.name,
+            symbol: value.symbol,
+            decimals: value.decimals as u8,
+            count: value.count,
+            metadata: value.metadata,
+        }
+    }
+}
+
 #[derive(FromRow, Deserialize, Debug, Clone)]
 #[serde(rename_all = "camelCase")]
 pub struct OptimisticTokenBalance {
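Note: one caveat on the conversions in the next hunk. If I read crypto-bigint correctly, `U256::from_be_hex` expects a full-width hex string (64 characters for U256) and panics on shorter input, so the code below assumes token ids and balances are stored zero-padded. A defensive variant would pad first (sketch; helper name is mine):

use crypto_bigint::U256;

fn u256_from_hex_padded(hex: &str) -> U256 {
    // Left-pad to the 64 nibbles U256::from_be_hex expects.
    let padded = format!("{:0>64}", hex.trim_start_matches("0x"));
    U256::from_be_hex(&padded)
}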
@@ -299,35 +314,96 @@ pub enum HookEvent {
     ModelDeleted { model_tag: String },
 }
 
-#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
-pub struct Page<T> {
-    pub items: Vec<T>,
-    pub next_cursor: Option<String>,
+impl From<Token> for torii_proto::Token {
+    fn from(value: Token) -> Self {
+        Self {
+            token_id: if value.token_id.is_empty() {
+                U256::ZERO
+            } else {
+                U256::from_be_hex(value.token_id.trim_start_matches("0x"))
+            },
+            contract_address: Felt::from_str(&value.contract_address).unwrap(),
+            name: value.name,
+            symbol: value.symbol,
+            decimals: value.decimals as u8,
+            metadata: value.metadata,
+        }
+    }
 }
 
-#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
-pub enum PaginationDirection {
-    Forward,
-    Backward,
+impl From<TokenBalance> for torii_proto::TokenBalance {
+    fn from(value: TokenBalance) -> Self {
+        let id = value.token_id.split(':').collect::<Vec<&str>>();
+
+        Self {
+            balance: U256::from_be_hex(value.balance.trim_start_matches("0x")),
+            account_address: Felt::from_str(&value.account_address).unwrap(),
+            contract_address: Felt::from_str(&value.contract_address).unwrap(),
+            token_id: if id.len() == 2 {
+                U256::from_be_hex(id[1].trim_start_matches("0x"))
+            } else {
+                U256::ZERO
+            },
+        }
+    }
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
-pub struct Pagination {
-    pub cursor: Option<String>,
-    pub limit: Option<u32>,
-    pub direction: PaginationDirection,
-    pub order_by: Vec<OrderBy>,
+pub enum EntityType {
+    Entity,
+    EventMessage,
 }
 
-#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
-pub enum OrderDirection {
-    Asc,
-    Desc,
+impl EntityType {
+    pub fn relation_table(&self) -> &str {
+        match self {
+            EntityType::Entity => "entity_model",
+            EntityType::EventMessage => "event_model",
+        }
+    }
+
+    pub fn relation_column(&self) -> &str {
+        match self {
+            EntityType::Entity => "entity_id",
+            EntityType::EventMessage => "event_message_id",
+        }
+    }
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
-pub struct OrderBy {
-    pub model: String,
-    pub member: String,
-    pub direction: OrderDirection,
-}
+pub enum Table {
+    Entities,
+    EntitiesHistorical,
+    EventMessages,
+    EventMessagesHistorical,
+    Models,
+    Events,
+    Tokens,
+    TokenBalances,
+    Contracts,
+    Controllers,
+    Transactions,
+    Metadata,
+}
+
+impl std::fmt::Display for Table {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Table::Entities => write!(f, "entities"),
+            Table::EntitiesHistorical => write!(f, "entities_historical"),
+            Table::EventMessages => write!(f, "event_messages"),
+            Table::EventMessagesHistorical => write!(f, "event_messages_historical"),
+            Table::Models => write!(f, "models"),
+            Table::Events => write!(f, "events"),
+            Table::Tokens => write!(f, "tokens"),
+            Table::TokenBalances => write!(f, "token_balances"),
+            Table::Contracts => write!(f, "contracts"),
+            Table::Controllers => write!(f, "controllers"),
+            Table::Transactions => write!(f, "transactions"),
+            Table::Metadata => write!(f, "metadata"),
+        }
+    }
+}
\ No newline at end of file
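Note: since `Table` replaces the old string table constants, call sites can interpolate the enum directly; `Display` yields the raw table name (usage sketch):

let query = format!("SELECT * FROM [{}]", Table::Entities);
assert_eq!(query, "SELECT * FROM [entities]");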