diff --git a/src/google/adk/cli/adk_web_server.py b/src/google/adk/cli/adk_web_server.py index 5d71591466..38ebded4b5 100644 --- a/src/google/adk/cli/adk_web_server.py +++ b/src/google/adk/cli/adk_web_server.py @@ -29,6 +29,7 @@ from typing import Literal from typing import Optional +from fastapi import Body from fastapi import FastAPI from fastapi import HTTPException from fastapi import Query @@ -868,6 +869,21 @@ async def create_session( return session + @app.patch( + "/apps/{app_name}/users/{user_id}/sessions/{session_id}/title", + response_model_exclude_none=True, + ) + async def update_session_title( + app_name: str, + user_id: str, + session_id: str, + title: Optional[str] = Body(None, embed=True), + ) -> dict[str, str]: + await self.session_service.update_session_title( + app_name=app_name, user_id=user_id, session_id=session_id, title=title + ) + return {"status": "success"} + @app.delete("/apps/{app_name}/users/{user_id}/sessions/{session_id}") async def delete_session( app_name: str, user_id: str, session_id: str diff --git a/src/google/adk/cli/utils/local_storage.py b/src/google/adk/cli/utils/local_storage.py index 12207e8070..19381a033f 100644 --- a/src/google/adk/cli/utils/local_storage.py +++ b/src/google/adk/cli/utils/local_storage.py @@ -157,6 +157,7 @@ async def create_session( user_id: str, state: Optional[dict[str, object]] = None, session_id: Optional[str] = None, + title: Optional[str] = None, ) -> Session: service = await self._get_service(app_name) return await service.create_session( @@ -164,6 +165,7 @@ async def create_session( user_id=user_id, state=state, session_id=session_id, + title=title, ) @override @@ -206,6 +208,23 @@ async def delete_session( app_name=app_name, user_id=user_id, session_id=session_id ) + @override + async def update_session_title( + self, + *, + app_name: str, + user_id: str, + session_id: str, + title: Optional[str], + ) -> None: + service = await self._get_service(app_name) + await service.update_session_title( + app_name=app_name, + user_id=user_id, + session_id=session_id, + title=title, + ) + @override async def append_event(self, session: Session, event: Event) -> Event: service = await self._get_service(session.app_name) diff --git a/src/google/adk/sessions/base_session_service.py b/src/google/adk/sessions/base_session_service.py index f2f6f9f22d..9002baaa46 100644 --- a/src/google/adk/sessions/base_session_service.py +++ b/src/google/adk/sessions/base_session_service.py @@ -56,6 +56,7 @@ async def create_session( user_id: str, state: Optional[dict[str, Any]] = None, session_id: Optional[str] = None, + title: Optional[str] = None, ) -> Session: """Creates a new session. @@ -102,6 +103,24 @@ async def delete_session( ) -> None: """Deletes a session.""" + @abc.abstractmethod + async def update_session_title( + self, + *, + app_name: str, + user_id: str, + session_id: str, + title: Optional[str], + ) -> None: + """Updates the title of a session. + + Args: + app_name: The name of the app. + user_id: The id of the user. + session_id: The id of the session. + title: The new title for the session. If None, clears the title. + """ + async def append_event(self, session: Session, event: Event) -> Event: """Appends an event to a session object.""" if event.partial: diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index 3cc9bb6a68..b66ab71866 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -222,6 +222,7 @@ async def create_session( user_id: str, state: Optional[dict[str, Any]] = None, session_id: Optional[str] = None, + title: Optional[str] = None, ) -> Session: # 1. Populate states. # 2. Build storage session object @@ -268,12 +269,15 @@ async def create_session( storage_user_state.state = storage_user_state.state | user_state_delta # Store the session - storage_session = schema.StorageSession( - app_name=app_name, - user_id=user_id, - id=session_id, - state=session_state, - ) + storage_session_kwargs = { + "app_name": app_name, + "user_id": user_id, + "id": session_id, + "state": session_state, + } + if hasattr(schema.StorageSession, "title"): + storage_session_kwargs["title"] = title + storage_session = schema.StorageSession(**storage_session_kwargs) sql_session.add(storage_session) await sql_session.commit() @@ -408,6 +412,30 @@ async def delete_session( await sql_session.execute(stmt) await sql_session.commit() + @override + async def update_session_title( + self, + *, + app_name: str, + user_id: str, + session_id: str, + title: Optional[str], + ) -> None: + await self._prepare_tables() + schema = self._get_schema_classes() + async with self.database_session_factory() as sql_session: + storage_session = await sql_session.get( + schema.StorageSession, (app_name, user_id, session_id) + ) + if storage_session is None: + raise ValueError( + f"Session not found: app_name={app_name}, user_id={user_id}," + f" session_id={session_id}" + ) + if hasattr(storage_session, "title"): + storage_session.title = title + await sql_session.commit() + @override async def append_event(self, session: Session, event: Event) -> Event: await self._prepare_tables() diff --git a/src/google/adk/sessions/in_memory_session_service.py b/src/google/adk/sessions/in_memory_session_service.py index 6ba7f0bb01..523406ec5a 100644 --- a/src/google/adk/sessions/in_memory_session_service.py +++ b/src/google/adk/sessions/in_memory_session_service.py @@ -58,12 +58,14 @@ async def create_session( user_id: str, state: Optional[dict[str, Any]] = None, session_id: Optional[str] = None, + title: Optional[str] = None, ) -> Session: return self._create_session_impl( app_name=app_name, user_id=user_id, state=state, session_id=session_id, + title=title, ) def create_session_sync( @@ -73,6 +75,7 @@ def create_session_sync( user_id: str, state: Optional[dict[str, Any]] = None, session_id: Optional[str] = None, + title: Optional[str] = None, ) -> Session: logger.warning('Deprecated. Please migrate to the async method.') return self._create_session_impl( @@ -80,6 +83,7 @@ def create_session_sync( user_id=user_id, state=state, session_id=session_id, + title=title, ) def _create_session_impl( @@ -89,6 +93,7 @@ def _create_session_impl( user_id: str, state: Optional[dict[str, Any]] = None, session_id: Optional[str] = None, + title: Optional[str] = None, ) -> Session: if session_id and self._get_session_impl( app_name=app_name, user_id=user_id, session_id=session_id @@ -116,6 +121,7 @@ def _create_session_impl( id=session_id, state=session_state or {}, last_update_time=time.time(), + title=title, ) if app_name not in self.sessions: @@ -286,6 +292,23 @@ def _delete_session_impl( self.sessions[app_name][user_id].pop(session_id) + @override + async def update_session_title( + self, + *, + app_name: str, + user_id: str, + session_id: str, + title: Optional[str], + ) -> None: + session = self.sessions.get(app_name, {}).get(user_id, {}).get(session_id) + if session is None: + raise ValueError( + f'Session not found: app_name={app_name}, user_id={user_id},' + f' session_id={session_id}' + ) + session.title = title + @override async def append_event(self, session: Session, event: Event) -> Event: if event.partial: diff --git a/src/google/adk/sessions/migration/migrate_from_sqlalchemy_sqlite.py b/src/google/adk/sessions/migration/migrate_from_sqlalchemy_sqlite.py index a0dd3a84a1..0231888c02 100644 --- a/src/google/adk/sessions/migration/migrate_from_sqlalchemy_sqlite.py +++ b/src/google/adk/sessions/migration/migrate_from_sqlalchemy_sqlite.py @@ -91,13 +91,14 @@ def migrate(source_db_url: str, dest_db_path: str): sessions = source_session.query(v0_schema.StorageSession).all() for item in sessions: dest_cursor.execute( - "INSERT INTO sessions (app_name, user_id, id, state, create_time," - " update_time) VALUES (?, ?, ?, ?, ?, ?)", + "INSERT INTO sessions (app_name, user_id, id, state, title," + " create_time, update_time) VALUES (?, ?, ?, ?, ?, ?, ?)", ( item.app_name, item.user_id, item.id, json.dumps(item.state), + item.title, item.create_time.replace(tzinfo=timezone.utc).timestamp(), item.update_time.replace(tzinfo=timezone.utc).timestamp(), ), diff --git a/src/google/adk/sessions/schemas/v1.py b/src/google/adk/sessions/schemas/v1.py index df309287fa..3710fa6f17 100644 --- a/src/google/adk/sessions/schemas/v1.py +++ b/src/google/adk/sessions/schemas/v1.py @@ -88,6 +88,9 @@ class StorageSession(Base): state: Mapped[MutableDict[str, Any]] = mapped_column( MutableDict.as_mutable(DynamicJSON), default={} ) + title: Mapped[Optional[str]] = mapped_column( + String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True + ) create_time: Mapped[datetime] = mapped_column( PreciseTimestamp, default=func.now() @@ -139,6 +142,7 @@ def to_session( state=state, events=events, last_update_time=self.update_timestamp_tz, + title=self.title, ) diff --git a/src/google/adk/sessions/session.py b/src/google/adk/sessions/session.py index e674dd3778..08982a3085 100644 --- a/src/google/adk/sessions/session.py +++ b/src/google/adk/sessions/session.py @@ -15,6 +15,7 @@ from __future__ import annotations from typing import Any +from typing import Optional from pydantic import alias_generators from pydantic import BaseModel @@ -48,3 +49,5 @@ class Session(BaseModel): call/response, etc.""" last_update_time: float = 0.0 """The last update time of the session.""" + title: Optional[str] = None + """The title of the session.""" diff --git a/src/google/adk/sessions/sqlite_session_service.py b/src/google/adk/sessions/sqlite_session_service.py index e0d44b3872..28b50b0e15 100644 --- a/src/google/adk/sessions/sqlite_session_service.py +++ b/src/google/adk/sessions/sqlite_session_service.py @@ -64,6 +64,7 @@ user_id TEXT NOT NULL, id TEXT NOT NULL, state TEXT NOT NULL, + title TEXT, create_time REAL NOT NULL, update_time REAL NOT NULL, PRIMARY KEY (app_name, user_id, id) @@ -121,6 +122,7 @@ async def create_session( user_id: str, state: Optional[dict[str, Any]] = None, session_id: Optional[str] = None, + title: Optional[str] = None, ) -> Session: if session_id: session_id = session_id.strip() @@ -160,14 +162,15 @@ async def create_session( # Store the session await db.execute( """ - INSERT INTO sessions (app_name, user_id, id, state, create_time, update_time) - VALUES (?, ?, ?, ?, ?, ?) + INSERT INTO sessions (app_name, user_id, id, state, title, create_time, update_time) + VALUES (?, ?, ?, ?, ?, ?, ?) """, ( app_name, user_id, session_id, json.dumps(session_state), + title, now, now, ), @@ -185,6 +188,7 @@ async def create_session( state=merged_state, events=[], last_update_time=now, + title=title, ) @override @@ -198,7 +202,7 @@ async def get_session( ) -> Optional[Session]: async with self._get_db_connection() as db: async with db.execute( - "SELECT state, update_time FROM sessions WHERE app_name=? AND" + "SELECT state, title, update_time FROM sessions WHERE app_name=? AND" " user_id=? AND id=?", (app_name, user_id, session_id), ) as cursor: @@ -206,6 +210,7 @@ async def get_session( if session_row is None: return None session_state = json.loads(session_row["state"]) + title = session_row["title"] last_update_time = session_row["update_time"] # Build events query @@ -248,6 +253,7 @@ async def get_session( state=merged_state, events=events, last_update_time=last_update_time, + title=title, ) @override @@ -259,13 +265,13 @@ async def list_sessions( # Fetch sessions if user_id: session_rows = await db.execute_fetchall( - "SELECT id, user_id, state, update_time FROM sessions WHERE" + "SELECT id, user_id, state, title, update_time FROM sessions WHERE" " app_name=? AND user_id=?", (app_name, user_id), ) else: session_rows = await db.execute_fetchall( - "SELECT id, user_id, state, update_time FROM sessions WHERE" + "SELECT id, user_id, state, title, update_time FROM sessions WHERE" " app_name=?", (app_name,), ) @@ -291,6 +297,7 @@ async def list_sessions( for row in session_rows: session_user_id = row["user_id"] session_state = json.loads(row["state"]) + title = row["title"] user_state = user_states_map.get(session_user_id, {}) merged_state = _merge_state(app_state, user_state, session_state) sessions_list.append( @@ -301,10 +308,32 @@ async def list_sessions( state=merged_state, events=[], last_update_time=row["update_time"], + title=title, ) ) return ListSessionsResponse(sessions=sessions_list) + @override + async def update_session_title( + self, + *, + app_name: str, + user_id: str, + session_id: str, + title: Optional[str], + ) -> None: + async with self._get_db_connection() as db: + cursor = await db.execute( + "UPDATE sessions SET title=? WHERE app_name=? AND user_id=? AND id=?", + (title, app_name, user_id, session_id), + ) + if cursor.rowcount == 0: + raise ValueError( + f"Session not found: app_name={app_name}, user_id={user_id}," + f" session_id={session_id}" + ) + await db.commit() + @override async def delete_session( self, *, app_name: str, user_id: str, session_id: str diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index cce7e99b32..13a6caffd0 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -83,6 +83,7 @@ async def create_session( user_id: str, state: Optional[dict[str, Any]] = None, session_id: Optional[str] = None, + title: Optional[str] = None, **kwargs: Any, ) -> Session: """Creates a new session. @@ -92,6 +93,7 @@ async def create_session( user_id: The ID of the user. state: The initial state of the session. session_id: The ID of the session. + title: The title of the session. **kwargs: Additional arguments to pass to the session creation. E.g. set expire_time='2025-10-01T00:00:00Z' to set the session expiration time. See https://cloud.google.com/vertex-ai/generative-ai/docs/reference/rest/v1beta1/projects.locations.reasoningEngines.sessions @@ -109,6 +111,8 @@ async def create_session( reasoning_engine_id = self._get_reasoning_engine_id(app_name) config = {'session_state': state} if state else {} + if title: + config['display_name'] = title config.update(kwargs) async with self._get_api_client() as api_client: api_response = await api_client.agent_engines.sessions.create( @@ -120,12 +124,14 @@ async def create_session( get_session_response = api_response.response session_id = get_session_response.name.split('/')[-1] + final_title = getattr(get_session_response, 'display_name', None) session = Session( app_name=app_name, user_id=user_id, id=session_id, state=getattr(get_session_response, 'session_state', None) or {}, last_update_time=get_session_response.update_time.timestamp(), + title=final_title, ) return session @@ -169,12 +175,14 @@ async def get_session( ) update_timestamp = get_session_response.update_time.timestamp() + title = getattr(get_session_response, 'display_name', None) session = Session( app_name=app_name, user_id=user_id, id=session_id, state=getattr(get_session_response, 'session_state', None) or {}, last_update_time=update_timestamp, + title=title, ) # Preserve the entire event stream that Vertex returns rather than trying # to discard events written milliseconds after the session resource was @@ -207,6 +215,7 @@ async def list_sessions( ) for api_session in sessions_iterator: + title = getattr(api_session, 'display_name', None) sessions.append( Session( app_name=app_name, @@ -214,11 +223,35 @@ async def list_sessions( id=api_session.name.split('/')[-1], state=getattr(api_session, 'session_state', None) or {}, last_update_time=api_session.update_time.timestamp(), + title=title, ) ) return ListSessionsResponse(sessions=sessions) + @override + async def update_session_title( + self, + *, + app_name: str, + user_id: str, + session_id: str, + title: Optional[str], + ) -> None: + reasoning_engine_id = self._get_reasoning_engine_id(app_name) + + async with self._get_api_client() as api_client: + try: + await api_client.agent_engines.sessions.update( + name=( + f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}' + ), + display_name=title, + ) + except Exception as e: + logger.error('Error updating session title %s: %s', session_id, e) + raise + async def delete_session( self, *, app_name: str, user_id: str, session_id: str ) -> None: diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index 45aa3feede..3915094838 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -143,6 +143,52 @@ async def test_create_and_list_sessions(service_type, tmp_path): assert session.state == {'key': 'value' + session.id} +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'service_type', + [ + SessionServiceType.IN_MEMORY, + SessionServiceType.DATABASE, + SessionServiceType.SQLITE, + ], +) +async def test_session_title(service_type, tmp_path): + session_service = get_session_service(service_type, tmp_path) + app_name = 'my_app' + user_id = 'test_user' + + session = await session_service.create_session( + app_name=app_name, user_id=user_id, title='Test Title' + ) + assert session.title == 'Test Title' + + got_session = await session_service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id + ) + assert got_session.title == 'Test Title' + + await session_service.update_session_title( + app_name=app_name, + user_id=user_id, + session_id=session.id, + title='New Title', + ) + + updated_session = await session_service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id + ) + assert updated_session.title == 'New Title' + + await session_service.update_session_title( + app_name=app_name, user_id=user_id, session_id=session.id, title=None + ) + + cleared_session = await session_service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id + ) + assert cleared_session.title is None + + @pytest.mark.asyncio @pytest.mark.parametrize( 'service_type', diff --git a/tests/unittests/sessions/test_vertex_ai_session_service.py b/tests/unittests/sessions/test_vertex_ai_session_service.py index 14d2b15b6e..81c937b065 100644 --- a/tests/unittests/sessions/test_vertex_ai_session_service.py +++ b/tests/unittests/sessions/test_vertex_ai_session_service.py @@ -275,6 +275,7 @@ def __init__(self) -> None: self.agent_engines.sessions.list.side_effect = self._list_sessions self.agent_engines.sessions.delete.side_effect = self._delete_session self.agent_engines.sessions.create.side_effect = self._create_session + self.agent_engines.sessions.update.side_effect = self._update_session self.agent_engines.sessions.events.list.side_effect = self._list_events self.agent_engines.sessions.events.append.side_effect = self._append_event self.last_create_session_config: dict[str, Any] = {} @@ -318,6 +319,15 @@ async def _delete_session(self, name: str): session_id = name.split('/')[-1] self.session_dict.pop(session_id) + async def _update_session(self, name: str, **kwargs: Any): + session_id = name.split('/')[-1] + if session_id not in self.session_dict: + raise api_core_exceptions.NotFound(f'Session not found: {session_id}') + if 'display_name' in kwargs: + self.session_dict[session_id]['display_name'] = kwargs['display_name'] + if 'session_state' in kwargs: + self.session_dict[session_id]['session_state'] = kwargs['session_state'] + async def _create_session( self, name: str, user_id: str, config: dict[str, Any] ): @@ -331,6 +341,7 @@ async def _create_session( ), 'user_id': user_id, 'session_state': config.get('session_state', {}), + 'display_name': config.get('display_name'), 'update_time': '2024-12-12T12:12:12.123456Z', } return _convert_to_object({ @@ -703,3 +714,102 @@ async def test_append_event(): assert len(retrieved_session.events) == 2 event_to_append.id = retrieved_session.events[1].id assert retrieved_session.events[1] == event_to_append + + +@pytest.mark.asyncio +@pytest.mark.usefixtures('mock_get_api_client') +async def test_create_session_with_title(mock_api_client_instance): + session_service = mock_vertex_ai_session_service() + + session = await session_service.create_session( + app_name='123', user_id='user', title='Test Title' + ) + assert session.title == 'Test Title' + assert ( + mock_api_client_instance.last_create_session_config['display_name'] + == 'Test Title' + ) + + got_session = await session_service.get_session( + app_name='123', user_id='user', session_id=session.id + ) + assert got_session.title == 'Test Title' + + +@pytest.mark.asyncio +@pytest.mark.usefixtures('mock_get_api_client') +async def test_get_session_with_title(mock_api_client_instance): + session_service = mock_vertex_ai_session_service() + + mock_api_client_instance.session_dict['1']['session_state'] = { + 'key': {'value': 'test_value'}, + } + mock_api_client_instance.session_dict['1']['display_name'] = 'Existing Title' + + session = await session_service.get_session( + app_name='123', user_id='user', session_id='1' + ) + assert session.title == 'Existing Title' + + +@pytest.mark.asyncio +@pytest.mark.usefixtures('mock_get_api_client') +async def test_list_sessions_with_title(mock_api_client_instance): + session_service = mock_vertex_ai_session_service() + + mock_api_client_instance.session_dict['1']['session_state'] = { + 'key': {'value': 'test_value'}, + } + mock_api_client_instance.session_dict['1']['display_name'] = 'Session 1 Title' + mock_api_client_instance.session_dict['2']['display_name'] = 'Session 2 Title' + + sessions = await session_service.list_sessions(app_name='123', user_id='user') + assert len(sessions.sessions) == 2 + assert sessions.sessions[0].title == 'Session 1 Title' + assert sessions.sessions[1].title == 'Session 2 Title' + + +@pytest.mark.asyncio +@pytest.mark.usefixtures('mock_get_api_client') +async def test_update_session_title(mock_api_client_instance): + session_service = mock_vertex_ai_session_service() + + await session_service.update_session_title( + app_name='123', user_id='user', session_id='1', title='New Title' + ) + + updated_session = await session_service.get_session( + app_name='123', user_id='user', session_id='1' + ) + assert updated_session.title == 'New Title' + + +@pytest.mark.asyncio +@pytest.mark.usefixtures('mock_get_api_client') +async def test_update_session_title_to_none(mock_api_client_instance): + session_service = mock_vertex_ai_session_service() + + mock_api_client_instance.session_dict['1']['session_state'] = { + 'key': {'value': 'test_value'}, + } + mock_api_client_instance.session_dict['1']['display_name'] = 'Old Title' + + await session_service.update_session_title( + app_name='123', user_id='user', session_id='1', title=None + ) + + updated_session = await session_service.get_session( + app_name='123', user_id='user', session_id='1' + ) + assert updated_session.title is None + + +@pytest.mark.asyncio +@pytest.mark.usefixtures('mock_get_api_client') +async def test_update_session_title_not_found(mock_api_client_instance): + session_service = mock_vertex_ai_session_service() + + with pytest.raises(api_core_exceptions.NotFound): + await session_service.update_session_title( + app_name='123', user_id='user', session_id='nonexistent', title='Title' + )