From ea97886e51091a1eecbb9387d8fa3a364a1e3de2 Mon Sep 17 00:00:00 2001 From: TheGupta2012 Date: Fri, 19 Sep 2025 16:11:30 +0530 Subject: [PATCH 1/7] add qbraid specific changes --- .github/workflows/upload-s3-production.yml | 2 +- .github/workflows/upload-s3-staging.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/upload-s3-production.yml b/.github/workflows/upload-s3-production.yml index 8a62ed1..33b4cb3 100644 --- a/.github/workflows/upload-s3-production.yml +++ b/.github/workflows/upload-s3-production.yml @@ -88,4 +88,4 @@ jobs: - name: Copy wheel file to S3 lab-extensions bucket run: | aws s3 rm s3://qbraid-lab-extensions/production/ --recursive --exclude "*" --include "lab_notebook_intelligence*.whl" - aws s3 cp ./ s3://qbraid-lab-extensions/production/ --recursive --exclude "*" --include "lab_notebook_intelligence*.whl" \ No newline at end of file + aws s3 cp ./ s3://qbraid-lab-extensions/production/ --recursive --exclude "*" --include "lab_notebook_intelligence*.whl" diff --git a/.github/workflows/upload-s3-staging.yml b/.github/workflows/upload-s3-staging.yml index da74db5..0d3ac45 100644 --- a/.github/workflows/upload-s3-staging.yml +++ b/.github/workflows/upload-s3-staging.yml @@ -89,4 +89,4 @@ jobs: - name: Copy wheel file to S3 lab-extensions bucket run: | aws s3 rm s3://qbraid-lab-extensions/staging/ --recursive --exclude "*" --include "lab_notebook_intelligence*.whl" - aws s3 cp ./ s3://qbraid-lab-extensions/staging/ --recursive --exclude "*" --include "lab_notebook_intelligence*.whl" \ No newline at end of file + aws s3 cp ./ s3://qbraid-lab-extensions/staging/ --recursive --exclude "*" --include "lab_notebook_intelligence*.whl" From 48fa5de17eb0648fe80300a7c7beedc91398caac Mon Sep 17 00:00:00 2001 From: TheGupta2012 Date: Wed, 24 Sep 2025 15:09:59 +0530 Subject: [PATCH 2/7] add formatting and mintlify MCP server for qbraid docs --- .copier-answers.yml | 2 +- .github/workflows/build.yml | 6 - .github/workflows/check-release.yml | 3 +- .github/workflows/format.yml | 41 ++ .prettierignore | 2 +- lab_notebook_intelligence/__init__.py | 26 +- .../ai_service_manager.py | 322 ++++++--- lab_notebook_intelligence/api.py | 497 ++++++++----- .../base_chat_participant.py | 357 ++++++--- .../built_in_toolsets.py | 120 ++-- lab_notebook_intelligence/config.py | 79 +- lab_notebook_intelligence/extension.py | 676 ++++++++++++------ lab_notebook_intelligence/github_copilot.py | 445 +++++++----- .../github_copilot_chat_participant.py | 10 +- .../github_copilot_llm_provider.py | 103 ++- .../litellm_compatible_llm_provider.py | 76 +- .../llm_providers/ollama_llm_provider.py | 132 ++-- .../openai_compatible_llm_provider.py | 76 +- lab_notebook_intelligence/mcp_manager.py | 183 +++-- lab_notebook_intelligence/prompts.py | 17 +- lab_notebook_intelligence/util.py | 70 +- src/index.ts | 92 ++- 22 files changed, 2283 insertions(+), 1052 deletions(-) create mode 100644 .github/workflows/format.yml diff --git a/.copier-answers.yml b/.copier-answers.yml index 7d4168d..ae24d66 100644 --- a/.copier-answers.yml +++ b/.copier-answers.yml @@ -8,7 +8,7 @@ has_settings: true kind: server labextension_name: '@notebook-intelligence/notebook-intelligence' project_short_description: Notebook Intelligence extension for JupyterLab -python_name: notebook_intelligence +python_name: lab_notebook_intelligence repository: https://github.com/notebook-intelligence/notebook-intelligence test: false diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 
0be6ab5..f1ffd35 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -24,12 +24,6 @@ jobs: - name: Install dependencies run: python -m pip install -U "jupyterlab>=4.0.0,<5" - - name: Lint the extension - run: | - set -eux - jlpm - jlpm run lint:check - - name: Build the extension run: | set -eux diff --git a/.github/workflows/check-release.yml b/.github/workflows/check-release.yml index 466ad29..71a436c 100644 --- a/.github/workflows/check-release.yml +++ b/.github/workflows/check-release.yml @@ -20,11 +20,10 @@ jobs: - name: Check Release uses: jupyter-server/jupyter_releaser/.github/actions/check-release@v2 with: - token: ${{ secrets.GITHUB_TOKEN }} - name: Upload Distributions uses: actions/upload-artifact@v4 with: - name: notebook_intelligence-releaser-dist-${{ github.run_number }} + name: lab_notebook_intelligence-releaser-dist-${{ github.run_number }} path: .jupyter_releaser_checkout/dist diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml new file mode 100644 index 0000000..581daf5 --- /dev/null +++ b/.github/workflows/format.yml @@ -0,0 +1,41 @@ +name: Check Code Formatting + +on: + pull_request: + branches: ['main'] + types: [opened, reopened, ready_for_review, synchronize] + workflow_dispatch: + +jobs: + check-format: + name: Check Code Formatting + if: github.event.pull_request.draft == false + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v5 + + - name: Base Setup + uses: jupyterlab/maintainer-tools/.github/actions/base-setup@v1 + + - name: Lint the extension + run: | + set -eux + jlpm + jlpm run lint:check + + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: '3.12' + + - name: Install Python dependencies + run: | + python -m pip install --upgrade pip + pip install black isort + + - name: Check Python code formatting + run: | + black --check lab_notebook_intelligence + isort --check lab_notebook_intelligence \ No newline at end of file diff --git a/.prettierignore b/.prettierignore index b23cfc8..76f0fbc 100644 --- a/.prettierignore +++ b/.prettierignore @@ -3,4 +3,4 @@ node_modules **/lib **/package.json !/package.json -notebook_intelligence +lab_notebook_intelligence diff --git a/lab_notebook_intelligence/__init__.py b/lab_notebook_intelligence/__init__.py index 2aed1ce..4776804 100644 --- a/lab_notebook_intelligence/__init__.py +++ b/lab_notebook_intelligence/__init__.py @@ -1,5 +1,5 @@ # Copyright (c) Mehmet Bektas - + try: from ._version import __version__ except ImportError: @@ -7,24 +7,26 @@ # in editable mode with pip. It is highly recommended to install # the package from a stable release or in editable mode: https://pip.pypa.io/en/stable/topics/local-project-installs/#editable-installs import warnings - warnings.warn("Importing 'lab_notebook_intelligence' outside a proper installation.") + + warnings.warn( + "Importing 'lab_notebook_intelligence' outside a proper installation." 
+ ) __version__ = "dev" import logging -logging.basicConfig(format='%(asctime)s - %(name)s - %(filename)s - %(levelname)s - %(message)s', level=logging.INFO) -from .extension import NotebookIntelligence +logging.basicConfig( + format="%(asctime)s - %(name)s - %(filename)s - %(levelname)s - %(message)s", + level=logging.INFO, +) + from .api import * +from .extension import NotebookIntelligence + def _jupyter_labextension_paths(): - return [{ - "src": "labextension", - "dest": "@qbraid/lab-notebook-intelligence" - }] + return [{"src": "labextension", "dest": "@qbraid/lab-notebook-intelligence"}] def _jupyter_server_extension_points(): - return [{ - "module": "lab_notebook_intelligence", - "app": NotebookIntelligence - }] + return [{"module": "lab_notebook_intelligence", "app": NotebookIntelligence}] diff --git a/lab_notebook_intelligence/ai_service_manager.py b/lab_notebook_intelligence/ai_service_manager.py index 2622923..be8722b 100644 --- a/lab_notebook_intelligence/ai_service_manager.py +++ b/lab_notebook_intelligence/ai_service_manager.py @@ -1,31 +1,94 @@ # Copyright (c) Mehmet Bektas import json -from os import path +import logging import os import sys +from os import path from typing import Dict -import logging + from lab_notebook_intelligence import github_copilot -from lab_notebook_intelligence.api import ButtonData, ChatModel, EmbeddingModel, InlineCompletionModel, LLMProvider, ChatParticipant, ChatRequest, ChatResponse, CompletionContext, ContextRequest, Host, CompletionContextProvider, MCPServer, MarkdownData, NotebookIntelligenceExtension, TelemetryEvent, TelemetryListener, Tool, Toolset +from lab_notebook_intelligence.api import (ButtonData, ChatModel, + ChatParticipant, ChatRequest, + ChatResponse, CompletionContext, + CompletionContextProvider, + ContextRequest, EmbeddingModel, + Host, InlineCompletionModel, + LLMProvider, MarkdownData, + MCPServer, + NotebookIntelligenceExtension, + TelemetryEvent, TelemetryListener, + Tool, Toolset) from lab_notebook_intelligence.base_chat_participant import BaseChatParticipant from lab_notebook_intelligence.config import NBIConfig -from lab_notebook_intelligence.github_copilot_chat_participant import GithubCopilotChatParticipant -from lab_notebook_intelligence.llm_providers.github_copilot_llm_provider import GitHubCopilotLLMProvider -from lab_notebook_intelligence.llm_providers.litellm_compatible_llm_provider import LiteLLMCompatibleLLMProvider -from lab_notebook_intelligence.llm_providers.ollama_llm_provider import OllamaLLMProvider -from lab_notebook_intelligence.llm_providers.openai_compatible_llm_provider import OpenAICompatibleLLMProvider +from lab_notebook_intelligence.github_copilot_chat_participant import \ + GithubCopilotChatParticipant +from lab_notebook_intelligence.llm_providers.github_copilot_llm_provider import \ + GitHubCopilotLLMProvider +from lab_notebook_intelligence.llm_providers.litellm_compatible_llm_provider import \ + LiteLLMCompatibleLLMProvider +from lab_notebook_intelligence.llm_providers.ollama_llm_provider import \ + OllamaLLMProvider +from lab_notebook_intelligence.llm_providers.openai_compatible_llm_provider import \ + OpenAICompatibleLLMProvider from lab_notebook_intelligence.mcp_manager import MCPManager log = logging.getLogger(__name__) -DEFAULT_CHAT_PARTICIPANT_ID = 'default' -RESERVED_LLM_PROVIDER_IDS = set([ - 'openai', 'anthropic', 'chat', 'copilot', 'jupyter', 'jupyterlab', 'jlab', 'notebook', 'intelligence', 'nb', 'nbi', 'ai', 'config', 'settings', 'ui', 'cell', 'code', 'file', 'data', 'new' -]) 
-RESERVED_PARTICIPANT_IDS = set([ - 'chat', 'copilot', 'jupyter', 'jupyterlab', 'jlab', 'notebook', 'intelligence', 'nb', 'nbi', 'terminal', 'vscode', 'workspace', 'help', 'ai', 'config', 'settings', 'ui', 'cell', 'code', 'file', 'data', 'new', 'run', 'search' -]) +DEFAULT_CHAT_PARTICIPANT_ID = "default" +RESERVED_LLM_PROVIDER_IDS = set( + [ + "openai", + "anthropic", + "chat", + "copilot", + "jupyter", + "jupyterlab", + "jlab", + "notebook", + "intelligence", + "nb", + "nbi", + "ai", + "config", + "settings", + "ui", + "cell", + "code", + "file", + "data", + "new", + ] +) +RESERVED_PARTICIPANT_IDS = set( + [ + "chat", + "copilot", + "jupyter", + "jupyterlab", + "jlab", + "notebook", + "intelligence", + "nb", + "nbi", + "terminal", + "vscode", + "workspace", + "help", + "ai", + "config", + "settings", + "ui", + "cell", + "code", + "file", + "data", + "new", + "run", + "search", + ] +) + class AIServiceManager(Host): def __init__(self, options: dict = {}): @@ -35,7 +98,9 @@ def __init__(self, options: dict = {}): self.telemetry_listeners: Dict[str, TelemetryListener] = {} self._extension_toolsets: Dict[str, list[Toolset]] = {} self._options = options.copy() - self._nbi_config = NBIConfig({"server_root_dir": self._options.get('server_root_dir', '')}) + self._nbi_config = NBIConfig( + {"server_root_dir": self._options.get("server_root_dir", "")} + ) self._openai_compatible_llm_provider = OpenAICompatibleLLMProvider() self._litellm_compatible_llm_provider = LiteLLMCompatibleLLMProvider() self._ollama_llm_provider = OllamaLLMProvider() @@ -45,7 +110,7 @@ def __init__(self, options: dict = {}): @property def nbi_config(self) -> NBIConfig: return self._nbi_config - + @property def ollama_llm_provider(self) -> OllamaLLMProvider: return self._ollama_llm_provider @@ -66,37 +131,65 @@ def initialize(self): def update_models_from_config(self): using_github_copilot_service = self.nbi_config.using_github_copilot_service if using_github_copilot_service: - github_copilot.login_with_existing_credentials(self._nbi_config.store_github_access_token) - github_copilot.enable_github_login_status_change_updater(using_github_copilot_service) + github_copilot.login_with_existing_credentials( + self._nbi_config.store_github_access_token + ) + github_copilot.enable_github_login_status_change_updater( + using_github_copilot_service + ) chat_model_cfg = self.nbi_config.chat_model - chat_model_provider_id = chat_model_cfg.get('provider', 'none') - chat_model_id = chat_model_cfg.get('model', 'none') + chat_model_provider_id = chat_model_cfg.get("provider", "none") + chat_model_id = chat_model_cfg.get("model", "none") chat_model_provider = self.get_llm_provider(chat_model_provider_id) - self._chat_model = chat_model_provider.get_chat_model(chat_model_id) if chat_model_provider is not None else None + self._chat_model = ( + chat_model_provider.get_chat_model(chat_model_id) + if chat_model_provider is not None + else None + ) inline_completion_model_cfg = self.nbi_config.inline_completion_model - inline_completion_model_provider_id = inline_completion_model_cfg.get('provider', 'none') - inline_completion_model_id = inline_completion_model_cfg.get('model', 'none') - inline_completion_model_provider = self.get_llm_provider(inline_completion_model_provider_id) - self._inline_completion_model = inline_completion_model_provider.get_inline_completion_model(inline_completion_model_id) if inline_completion_model_provider is not None else None + inline_completion_model_provider_id = inline_completion_model_cfg.get( + "provider", 
"none" + ) + inline_completion_model_id = inline_completion_model_cfg.get("model", "none") + inline_completion_model_provider = self.get_llm_provider( + inline_completion_model_provider_id + ) + self._inline_completion_model = ( + inline_completion_model_provider.get_inline_completion_model( + inline_completion_model_id + ) + if inline_completion_model_provider is not None + else None + ) self._embedding_model = None if self._chat_model is not None: - properties = chat_model_cfg.get('properties', []) + properties = chat_model_cfg.get("properties", []) for property in properties: - self._chat_model.set_property_value(property['id'], property['value']) + self._chat_model.set_property_value(property["id"], property["value"]) if self._inline_completion_model is not None: - properties = inline_completion_model_cfg.get('properties', []) + properties = inline_completion_model_cfg.get("properties", []) for property in properties: - self._inline_completion_model.set_property_value(property['id'], property['value']) - - is_github_copilot_chat_model = isinstance(chat_model_provider, GitHubCopilotLLMProvider) - default_chat_participant = GithubCopilotChatParticipant() if is_github_copilot_chat_model else BaseChatParticipant() + self._inline_completion_model.set_property_value( + property["id"], property["value"] + ) + + is_github_copilot_chat_model = isinstance( + chat_model_provider, GitHubCopilotLLMProvider + ) + default_chat_participant = ( + GithubCopilotChatParticipant() + if is_github_copilot_chat_model + else BaseChatParticipant() + ) self._default_chat_participant = default_chat_participant - self.chat_participants[DEFAULT_CHAT_PARTICIPANT_ID] = self._default_chat_participant + self.chat_participants[DEFAULT_CHAT_PARTICIPANT_ID] = ( + self._default_chat_participant + ) def update_mcp_servers(self): self._mcp_manager.update_mcp_servers(self.nbi_config.mcp) @@ -111,9 +204,9 @@ def initialize_extensions(self): log.info(f"Loading NBI extension from '{extension_dir}'...") metadata_path = path.join(extension_dir, "extension.json") if path.exists(metadata_path) and path.isfile(metadata_path): - with open(metadata_path, 'r') as file: + with open(metadata_path, "r") as file: data = json.load(file) - class_name = data['class'] + class_name = data["class"] extension = self.load_extension(class_name) if extension: extension.activate(self) @@ -121,15 +214,18 @@ def initialize_extensions(self): self._extensions.append(extension) except Exception as e: log.error(f"Failed to load NBI extension from '{extension_dir}'!\n{e}") - + def load_extension(self, extension_class: str) -> NotebookIntelligenceExtension: import importlib + try: parts = extension_class.split(".") module_name = ".".join(parts[0:-1]) class_name = parts[-1] ExtensionClass = getattr(importlib.import_module(module_name), class_name) - if ExtensionClass is not None and issubclass(ExtensionClass, NotebookIntelligenceExtension): + if ExtensionClass is not None and issubclass( + ExtensionClass, NotebookIntelligenceExtension + ): instance = ExtensionClass() return instance except Exception as e: @@ -155,22 +251,32 @@ def register_llm_provider(self, provider: LLMProvider) -> None: return self.llm_providers[provider.id] = provider - def register_completion_context_provider(self, provider: CompletionContextProvider) -> None: + def register_completion_context_provider( + self, provider: CompletionContextProvider + ) -> None: if provider.id in self.completion_context_providers: - log.error(f"Completion Context Provider ID '{provider.id}' is already in 
use!") + log.error( + f"Completion Context Provider ID '{provider.id}' is already in use!" + ) return self.completion_context_providers[provider.id] = provider def register_telemetry_listener(self, listener: TelemetryListener) -> None: if listener.name in self.telemetry_listeners: - log.error(f"Notebook Intelligence telemetry listener '{listener.name}' already exists!") + log.error( + f"Notebook Intelligence telemetry listener '{listener.name}' already exists!" + ) return - log.warning(f"Notebook Intelligence telemetry listener '{listener.name}' registered. Make sure it is from a trusted source.") + log.warning( + f"Notebook Intelligence telemetry listener '{listener.name}' registered. Make sure it is from a trusted source." + ) self.telemetry_listeners[listener.name] = listener def register_toolset(self, toolset: Toolset) -> None: if toolset.provider is None: - log.error(f"Toolset '{toolset.id}' has no provider! It cannot be registered.") + log.error( + f"Toolset '{toolset.id}' has no provider! It cannot be registered." + ) return provider_id = toolset.provider.id if provider_id not in self._extension_toolsets: @@ -185,11 +291,11 @@ def default_chat_participant(self) -> ChatParticipant: @property def chat_model(self) -> ChatModel: return self._chat_model - + @property def inline_completion_model(self) -> InlineCompletionModel: return self._inline_completion_model - + @property def embedding_model(self) -> EmbeddingModel: return self._embedding_model @@ -197,20 +303,20 @@ def embedding_model(self) -> EmbeddingModel: @staticmethod def parse_prompt(prompt: str) -> tuple[str, str, str]: participant = DEFAULT_CHAT_PARTICIPANT_ID - command = '' - input = '' + command = "" + input = "" prompt = prompt.lstrip() - parts = prompt.split(' ') - parts = [part for part in parts if part.strip() != ''] + parts = prompt.split(" ") + parts = [part for part in parts if part.strip() != ""] if len(parts) > 0: - if parts[0].startswith('@'): + if parts[0].startswith("@"): participant = parts[0][1:] parts = parts[1:] if len(parts) > 0: - if parts[0].startswith('/'): + if parts[0].startswith("/"): command = parts[0][1:] parts = parts[1:] @@ -218,12 +324,12 @@ def parse_prompt(prompt: str) -> tuple[str, str, str]: input = " ".join(parts) return [participant, command, input] - + def get_llm_provider(self, provider_id: str) -> LLMProvider: return self.llm_providers.get(provider_id) - + def get_llm_provider_for_model_ref(self, model_ref: str) -> LLMProvider: - parts = model_ref.split('::') + parts = model_ref.split("::") if len(parts) < 2: return None @@ -232,16 +338,16 @@ def get_llm_provider_for_model_ref(self, model_ref: str) -> LLMProvider: return self.get_llm_provider(provider_id) def get_chat_model(self, model_ref: str) -> ChatModel: - return self._get_provider_model(model_ref, 'chat') - + return self._get_provider_model(model_ref, "chat") + def get_inline_completion_model(self, model_ref: str) -> ChatModel: - return self._get_provider_model(model_ref, 'inline-completion') - + return self._get_provider_model(model_ref, "inline-completion") + def get_embedding_model(self, model_ref: str) -> ChatModel: - return self._get_provider_model(model_ref, 'embedding') - + return self._get_provider_model(model_ref, "embedding") + def _get_provider_model(self, model_ref: str, model_type: str) -> ChatModel: - parts = model_ref.split('::') + parts = model_ref.split("::") if len(parts) < 2: return None @@ -252,54 +358,99 @@ def _get_provider_model(self, model_ref: str, model_type: str) -> ChatModel: if llm_provider is None: 
return None - model_list = llm_provider.chat_models if model_type == 'chat' else llm_provider.inline_completion_models if model_type == 'inline-completion' else llm_provider.embedding_models + model_list = ( + llm_provider.chat_models + if model_type == "chat" + else ( + llm_provider.inline_completion_models + if model_type == "inline-completion" + else llm_provider.embedding_models + ) + ) for model in model_list: if model.id == model_id: return model return None - + @property def chat_model_ids(self) -> list[ChatModel]: model_ids = [] for provider in self.llm_providers.values(): - model_ids += [{"provider": provider.id, "id": model.id, "name": model.name, "context_window": model.context_window, "properties": [property.to_dict() for property in model.properties]} for model in provider.chat_models] + model_ids += [ + { + "provider": provider.id, + "id": model.id, + "name": model.name, + "context_window": model.context_window, + "properties": [property.to_dict() for property in model.properties], + } + for model in provider.chat_models + ] return model_ids @property def inline_completion_model_ids(self) -> list[InlineCompletionModel]: model_ids = [] for provider in self.llm_providers.values(): - model_ids += [{"provider": provider.id, "id": model.id, "name": model.name, "context_window": model.context_window, "properties": [property.to_dict() for property in model.properties]} for model in provider.inline_completion_models] + model_ids += [ + { + "provider": provider.id, + "id": model.id, + "name": model.name, + "context_window": model.context_window, + "properties": [property.to_dict() for property in model.properties], + } + for model in provider.inline_completion_models + ] return model_ids - + @property def embedding_model_ids(self) -> list[EmbeddingModel]: model_ids = [] for provider in self.llm_providers.values(): - model_ids += [{"id": f"{provider.id}::{model.id}", "name": f"{provider.name} / {model.name}", "context_window": model.context_window} for model in provider.embedding_models] + model_ids += [ + { + "id": f"{provider.id}::{model.id}", + "name": f"{provider.name} / {model.name}", + "context_window": model.context_window, + } + for model in provider.embedding_models + ] return model_ids def get_chat_participant(self, prompt: str) -> ChatParticipant: (participant_id, command, input) = AIServiceManager.parse_prompt(prompt) return self.chat_participants.get(participant_id, DEFAULT_CHAT_PARTICIPANT_ID) - async def handle_chat_request(self, request: ChatRequest, response: ChatResponse, options: dict = {}) -> None: + async def handle_chat_request( + self, request: ChatRequest, response: ChatResponse, options: dict = {} + ) -> None: if self.chat_model is None: response.stream(MarkdownData("Chat model is not set!")) - response.stream(ButtonData("Configure", "lab-notebook-intelligence:open-configuration-dialog")) + response.stream( + ButtonData( + "Configure", "lab-notebook-intelligence:open-configuration-dialog" + ) + ) response.finish() return request.host = self - (participant_id, command, prompt) = AIServiceManager.parse_prompt(request.prompt) - participant = self.chat_participants.get(participant_id, DEFAULT_CHAT_PARTICIPANT_ID) + (participant_id, command, prompt) = AIServiceManager.parse_prompt( + request.prompt + ) + participant = self.chat_participants.get( + participant_id, DEFAULT_CHAT_PARTICIPANT_ID + ) request.command = command request.prompt = prompt - response.participant_id = participant_id + response.participant_id = participant_id return await 
participant.handle_chat_request(request, response, options) - async def get_completion_context(self, request: ContextRequest) -> CompletionContext: + async def get_completion_context( + self, request: ContextRequest + ) -> CompletionContext: cancel_token = request.cancel_token context = CompletionContext([]) @@ -312,27 +463,32 @@ async def get_completion_context(self, request: ContextRequest) -> CompletionCon if cancel_token.is_cancel_requested: return context provider = self.completion_context_providers.get(provider) - if provider.id not in allowed_context_providers and '*' not in allowed_context_providers: + if ( + provider.id not in allowed_context_providers + and "*" not in allowed_context_providers + ): continue try: provider_context = provider.handle_completion_context_request(request) if provider_context.items: context.items += provider_context.items except Exception as e: - log.error(f"Error while getting completion context from provider '{provider.id}'!\n{e}") + log.error( + f"Error while getting completion context from provider '{provider.id}'!\n{e}" + ) return context - + async def emit_telemetry_event(self, event: TelemetryEvent): for listener in self.telemetry_listeners.values(): listener.on_telemetry_event(event) def get_mcp_servers(self): return self._mcp_manager.get_mcp_servers() - + def get_mcp_server(self, server_name: str) -> MCPServer: return self._mcp_manager.get_mcp_server(server_name) - + def get_mcp_server_tool(self, server_name: str, tool_name: str) -> Tool: mcp_server = self._mcp_manager.get_mcp_server(server_name) if mcp_server is not None: @@ -342,7 +498,7 @@ def get_mcp_server_tool(self, server_name: str, tool_name: str) -> Tool: def get_extension_toolsets(self) -> Dict[str, list[Toolset]]: return self._extension_toolsets - + def get_extension_toolset(self, extension_id: str, toolset_id: str) -> Toolset: if extension_id not in self._extension_toolsets: return None @@ -351,10 +507,12 @@ def get_extension_toolset(self, extension_id: str, toolset_id: str) -> Toolset: for toolset in extension_toolsets: if toolset_id == toolset.id: return toolset - + return None - def get_extension_tool(self, extension_id: str, toolset_id: str, tool_name: str) -> Tool: + def get_extension_tool( + self, extension_id: str, toolset_id: str, tool_name: str + ) -> Tool: if extension_id not in self._extension_toolsets: return None extension_toolsets = self._extension_toolsets[extension_id] @@ -364,7 +522,7 @@ def get_extension_tool(self, extension_id: str, toolset_id: str, tool_name: str) if tool.name == tool_name: return tool return None - + def get_extension(self, extension_id: str) -> NotebookIntelligenceExtension: for extension in self._extensions: if extension.id == extension_id: diff --git a/lab_notebook_intelligence/api.py b/lab_notebook_intelligence/api.py index ad18a3f..58d699b 100644 --- a/lab_notebook_intelligence/api.py +++ b/lab_notebook_intelligence/api.py @@ -1,52 +1,58 @@ # Copyright (c) Mehmet Bektas import asyncio -from typing import Any, Callable, Dict, Union +import logging +import uuid from dataclasses import asdict, dataclass from enum import Enum -import uuid +from typing import Any, Callable, Dict, Union + from fuzzy_json import loads as fuzzy_json_loads -import logging from mcp.server.fastmcp.tools import Tool as MCPToolClass from lab_notebook_intelligence.config import NBIConfig log = logging.getLogger(__name__) + class RequestDataType(str, Enum): - ChatRequest = 'chat-request' - ChatUserInput = 'chat-user-input' - ClearChatHistory = 'clear-chat-history' - 
RunUICommandResponse = 'run-ui-command-response' - GenerateCode = 'generate-code' - CancelChatRequest = 'cancel-chat-request' - InlineCompletionRequest = 'inline-completion-request' - CancelInlineCompletionRequest = 'cancel-inline-completion-request' + ChatRequest = "chat-request" + ChatUserInput = "chat-user-input" + ClearChatHistory = "clear-chat-history" + RunUICommandResponse = "run-ui-command-response" + GenerateCode = "generate-code" + CancelChatRequest = "cancel-chat-request" + InlineCompletionRequest = "inline-completion-request" + CancelInlineCompletionRequest = "cancel-inline-completion-request" + class BackendMessageType(str, Enum): - StreamMessage = 'stream-message' - StreamEnd = 'stream-end' - RunUICommand = 'run-ui-command' - GitHubCopilotLoginStatusChange = 'github-copilot-login-status-change' + StreamMessage = "stream-message" + StreamEnd = "stream-end" + RunUICommand = "run-ui-command" + GitHubCopilotLoginStatusChange = "github-copilot-login-status-change" + class ResponseStreamDataType(str, Enum): - LLMRaw = 'llm-raw' - Markdown = 'markdown' - MarkdownPart = 'markdown-part' - Image = 'image' - HTMLFrame = 'html-frame' - Button = 'button' - Anchor = 'anchor' - Progress = 'progress' - Confirmation = 'confirmation' + LLMRaw = "llm-raw" + Markdown = "markdown" + MarkdownPart = "markdown-part" + Image = "image" + HTMLFrame = "html-frame" + Button = "button" + Anchor = "anchor" + Progress = "progress" + Confirmation = "confirmation" def __str__(self) -> str: return self.value + class BuiltinToolset(str, Enum): - NotebookEdit = 'nbi-notebook-edit' - NotebookExecute = 'nbi-notebook-execute' - PythonFileEdit = 'nbi-python-file-edit' + NotebookEdit = "nbi-notebook-edit" + NotebookExecute = "nbi-notebook-execute" + PythonFileEdit = "nbi-python-file-edit" + class Signal: def __init__(self): @@ -58,6 +64,7 @@ def connect(self, listener: Callable) -> None: def disconnect(self, listener: Callable) -> None: self._listeners.remove(listener) + class SignalImpl(Signal): def __init__(self): super().__init__() @@ -66,6 +73,7 @@ def emit(self, *args, **kwargs) -> None: for listener in self._listeners: listener(*args, **kwargs) + class CancelToken: def __init__(self): self._cancellation_signal = Signal() @@ -79,93 +87,104 @@ def is_cancel_requested(self) -> bool: def cancellation_signal(self) -> Signal: return self._cancellation_signal + @dataclass class RequestToolSelection: built_in_toolsets: list[str] = None mcp_server_tools: dict[str, list[str]] = None extension_tools: dict[str, dict[str, list[str]]] = None + @dataclass class ChatRequest: - host: 'Host' = None - chat_mode: 'ChatMode' = None + host: "Host" = None + chat_mode: "ChatMode" = None tool_selection: RequestToolSelection = None - command: str = '' - prompt: str = '' + command: str = "" + prompt: str = "" chat_history: list[dict] = None cancel_token: CancelToken = None + @dataclass class ResponseStreamData: @property def data_type(self) -> ResponseStreamDataType: raise NotImplemented + @dataclass class MarkdownData(ResponseStreamData): - content: str = '' + content: str = "" detail: dict = None @property def data_type(self) -> ResponseStreamDataType: return ResponseStreamDataType.Markdown + @dataclass class MarkdownPartData(ResponseStreamData): - content: str = '' + content: str = "" @property def data_type(self) -> ResponseStreamDataType: return ResponseStreamDataType.MarkdownPart + @dataclass class ImageData(ResponseStreamData): - content: str = '' + content: str = "" @property def data_type(self) -> ResponseStreamDataType: return 
ResponseStreamDataType.Image + @dataclass class HTMLFrameData(ResponseStreamData): - source: str = '' + source: str = "" height: int = 30 @property def data_type(self) -> ResponseStreamDataType: return ResponseStreamDataType.HTMLFrame + @dataclass class AnchorData(ResponseStreamData): - uri: str = '' - title: str = '' + uri: str = "" + title: str = "" @property def data_type(self) -> ResponseStreamDataType: return ResponseStreamDataType.Anchor + @dataclass class ButtonData(ResponseStreamData): - title: str = '' - commandId: str = '' + title: str = "" + commandId: str = "" args: Dict[str, str] = None @property def data_type(self) -> ResponseStreamDataType: return ResponseStreamDataType.Button + @dataclass class ProgressData(ResponseStreamData): - title: str = '' + title: str = "" @property def data_type(self) -> ResponseStreamDataType: return ResponseStreamDataType.Progress + @dataclass class ConfirmationData(ResponseStreamData): - title: str = '' - message: str = '' + title: str = "" + message: str = "" confirmArgs: dict = None cancelArgs: dict = None confirmLabel: str = None @@ -175,26 +194,30 @@ class ConfirmationData(ResponseStreamData): def data_type(self) -> ResponseStreamDataType: return ResponseStreamDataType.Confirmation + class ContextRequestType(Enum): - InlineCompletion = 'inline-completion' - NewPythonFile = 'new-python-file' - NewNotebook = 'new-notebook' + InlineCompletion = "inline-completion" + NewPythonFile = "new-python-file" + NewNotebook = "new-notebook" + class ContextType(Enum): - Custom = 'custom' - Provider = 'provider' - CurrentFile = 'current-file' + Custom = "custom" + Provider = "provider" + CurrentFile = "current-file" + @dataclass class ContextRequest: type: ContextRequestType - prefix: str = '' - suffix: str = '' - language: str = '' - filename: str = '' - participant: 'ChatParticipant' = None + prefix: str = "" + suffix: str = "" + language: str = "" + filename: str = "" + participant: "ChatParticipant" = None cancel_token: CancelToken = None + @dataclass class ContextItem: type: ContextType @@ -205,15 +228,17 @@ class ContextItem: startLine: int = None endLine: int = None + @dataclass class CompletionContext: items: list[ContextItem] + class ChatResponse: def __init__(self): self._user_input_signal: SignalImpl = SignalImpl() self._run_ui_command_response_signal: SignalImpl = SignalImpl() - self.participant_id = '' + self.participant_id = "" @property def message_id(self) -> str: @@ -221,10 +246,10 @@ def message_id(self) -> str: def stream(self, data: ResponseStreamData, finish: bool = False) -> None: raise NotImplemented - + def finish(self) -> None: raise NotImplemented - + @property def user_input_signal(self) -> Signal: return self._user_input_signal @@ -233,11 +258,12 @@ def on_user_input(self, data: dict) -> None: self._user_input_signal.emit(data) @staticmethod - async def wait_for_chat_user_input(response: 'ChatResponse', callback_id: str): + async def wait_for_chat_user_input(response: "ChatResponse", callback_id: str): resp = {"data": None} + def _on_user_input(data: dict): - if data['callback_id'] == callback_id: - resp["data"] = data['data'] + if data["callback_id"] == callback_id: + resp["data"] = data["data"] response.user_input_signal.connect(_on_user_input) @@ -249,29 +275,35 @@ def _on_user_input(data: dict): async def run_ui_command(self, command: str, args: dict = {}) -> None: raise NotImplemented - + @property def run_ui_command_response_signal(self) -> Signal: return self._run_ui_command_response_signal - + def 
on_run_ui_command_response(self, data: dict) -> None: self._run_ui_command_response_signal.emit(data) @staticmethod - async def wait_for_run_ui_command_response(response: 'ChatResponse', callback_id: str): + async def wait_for_run_ui_command_response( + response: "ChatResponse", callback_id: str + ): resp = {"result": None} + def _on_ui_command_response(data: dict): - if data['callback_id'] == callback_id: - resp["result"] = data['result'] + if data["callback_id"] == callback_id: + resp["result"] = data["result"] response.run_ui_command_response_signal.connect(_on_ui_command_response) while True: if resp["result"] is not None: - response.run_ui_command_response_signal.disconnect(_on_ui_command_response) + response.run_ui_command_response_signal.disconnect( + _on_ui_command_response + ) return resp["result"] await asyncio.sleep(0.1) + @dataclass class ToolPreInvokeResponse: message: str = None @@ -279,10 +311,12 @@ class ToolPreInvokeResponse: confirmationTitle: str = None confirmationMessage: str = None + @dataclass class ChatCommand: - name: str = '' - description: str = '' + name: str = "" + description: str = "" + class Tool: @property @@ -292,27 +326,44 @@ def name(self) -> str: @property def title(self) -> str: raise NotImplemented - + @property def tags(self) -> list[str]: raise NotImplemented - + @property def description(self) -> str: raise NotImplemented - + @property def schema(self) -> dict: raise NotImplemented - def pre_invoke(self, request: ChatRequest, tool_args: dict) -> Union[ToolPreInvokeResponse, None]: + def pre_invoke( + self, request: ChatRequest, tool_args: dict + ) -> Union[ToolPreInvokeResponse, None]: return None - async def handle_tool_call(self, request: ChatRequest, response: ChatResponse, tool_context: dict, tool_args: dict) -> str: + async def handle_tool_call( + self, + request: ChatRequest, + response: ChatResponse, + tool_context: dict, + tool_args: dict, + ) -> str: raise NotImplemented + class Toolset: - def __init__(self, id: str, name: str, description: str, provider: Union['NotebookIntelligenceExtension', None], tools: list[Tool] = [], instructions: str = None): + def __init__( + self, + id: str, + name: str, + description: str, + provider: Union["NotebookIntelligenceExtension", None], + tools: list[Tool] = [], + instructions: str = None, + ): self.id = id self.name = name self.description = description @@ -326,8 +377,18 @@ def add_tool(self, tool: Tool) -> None: def remove_tool(self, tool: Tool) -> None: self.tools.remove(tool) + class SimpleTool(Tool): - def __init__(self, tool_function: Callable, name: str, description: str, schema: dict, title: str = None, auto_approve: bool = False, has_var_args: bool = False): + def __init__( + self, + tool_function: Callable, + name: str, + description: str, + schema: dict, + title: str = None, + auto_approve: bool = False, + has_var_args: bool = False, + ): super().__init__() self._tool_function = tool_function self._name = name @@ -344,38 +405,49 @@ def name(self) -> str: @property def title(self) -> str: return self._title if self._title is not None else self._name - + @property def tags(self) -> list[str]: return [] - + @property def description(self) -> str: return self._description - + @property def schema(self) -> dict: return self._schema - def pre_invoke(self, request: ChatRequest, tool_args: dict) -> Union[ToolPreInvokeResponse, None]: + def pre_invoke( + self, request: ChatRequest, tool_args: dict + ) -> Union[ToolPreInvokeResponse, None]: confirmationTitle = None confirmationMessage = None if not 
self._auto_approve: confirmationTitle = "Approve" confirmationMessage = "Are you sure you want to call this tool?" - return ToolPreInvokeResponse(f"Calling tool '{self.name}'", confirmationTitle, confirmationMessage) - - async def handle_tool_call(self, request: ChatRequest, response: ChatResponse, tool_context: dict, tool_args: dict) -> str: + return ToolPreInvokeResponse( + f"Calling tool '{self.name}'", confirmationTitle, confirmationMessage + ) + + async def handle_tool_call( + self, + request: ChatRequest, + response: ChatResponse, + tool_context: dict, + tool_args: dict, + ) -> str: fn_args = tool_args.copy() if self._has_var_args: fn_args.update({"request": request, "response": response}) return await self._tool_function(**fn_args) + class MCPServer: @property def name(self) -> str: return NotImplemented - + async def connect(self): return NotImplemented @@ -384,7 +456,7 @@ async def disconnect(self): async def update_tool_list(self): return NotImplemented - + def get_tools(self) -> list[Tool]: return NotImplemented @@ -394,6 +466,7 @@ def get_tool(self, tool_name: str) -> Tool: async def call_tool(self, tool_name: str, tool_args: dict): return NotImplemented + def auto_approve(tool: SimpleTool): """ Decorator to set auto_approve to True for a tool. @@ -401,6 +474,7 @@ def auto_approve(tool: SimpleTool): tool._auto_approve = True return tool + def tool(tool_function: Callable) -> SimpleTool: mcp_tool = MCPToolClass.from_function(tool_function) has_var_args = False @@ -416,11 +490,20 @@ def tool(tool_function: Callable) -> SimpleTool: "name": mcp_tool.name, "description": mcp_tool.description, "strict": False, - "parameters": mcp_tool.parameters + "parameters": mcp_tool.parameters, }, } - return SimpleTool(tool_function, mcp_tool.name, mcp_tool.description, schema, mcp_tool.name, auto_approve, has_var_args) + return SimpleTool( + tool_function, + mcp_tool.name, + mcp_tool.description, + schema, + mcp_tool.name, + auto_approve, + has_var_args, + ) + class ChatMode: def __init__(self, id: str, name: str, instructions: str = None): @@ -428,6 +511,7 @@ def __init__(self, id: str, name: str, instructions: str = None): self.name = name self.instructions = instructions + class ChatParticipant: @property def id(self) -> str: @@ -440,15 +524,15 @@ def name(self) -> str: @property def description(self) -> str: raise NotImplemented - + @property def icon_path(self) -> str: return None - + @property def commands(self) -> list[ChatCommand]: return [] - + @property def tools(self) -> list[Tool]: return [] @@ -458,48 +542,68 @@ def allowed_context_providers(self) -> set[str]: # any context provider can be used return set(["*"]) - async def handle_chat_request(self, request: ChatRequest, response: ChatResponse, options: dict = {}) -> None: + async def handle_chat_request( + self, request: ChatRequest, response: ChatResponse, options: dict = {} + ) -> None: raise NotImplemented - - async def handle_chat_request_with_tools(self, request: ChatRequest, response: ChatResponse, options: dict = {}, tool_context: dict = {}, tool_choice = 'auto') -> None: + + async def handle_chat_request_with_tools( + self, + request: ChatRequest, + response: ChatResponse, + options: dict = {}, + tool_context: dict = {}, + tool_choice="auto", + ) -> None: tools = self.tools messages = request.chat_history.copy() system_prompt = options.get("system_prompt") if system_prompt is not None: - messages = [ - {"role": "system", "content": system_prompt} - ] + messages + messages = [{"role": "system", "content": system_prompt}] + 
messages if len(tools) == 0: - request.host.chat_model.completions(messages, tools=None, cancel_token=request.cancel_token, response=response) + request.host.chat_model.completions( + messages, + tools=None, + cancel_token=request.cancel_token, + response=response, + ) return openai_tools = [tool.schema for tool in tools] - tool_call_rounds = [] # TODO overrides options arg - options = {'tool_choice': tool_choice} + options = {"tool_choice": tool_choice} async def _tool_call_loop(tool_call_rounds: list): try: if request.cancel_token.is_cancel_requested: return - tool_response = request.host.chat_model.completions(messages, openai_tools, cancel_token=request.cancel_token, options=options) + tool_response = request.host.chat_model.completions( + messages, + openai_tools, + cancel_token=request.cancel_token, + options=options, + ) # after first call, set tool_choice to auto - options['tool_choice'] = 'auto' + options["tool_choice"] = "auto" - for choice in tool_response['choices']: - if choice['message'].get('tool_calls', None) is not None: - for tool_call in choice['message']['tool_calls']: + for choice in tool_response["choices"]: + if choice["message"].get("tool_calls", None) is not None: + for tool_call in choice["message"]["tool_calls"]: tool_call_rounds.append(tool_call) - elif choice['message'].get('content', None) is not None: - response.stream(MarkdownData(tool_response['choices'][0]['message']['content'])) + elif choice["message"].get("content", None) is not None: + response.stream( + MarkdownData( + tool_response["choices"][0]["message"]["content"] + ) + ) - messages.append(choice['message']) + messages.append(choice["message"]) had_tool_call = len(tool_call_rounds) > 0 @@ -510,30 +614,41 @@ async def _tool_call_loop(tool_call_rounds: list): tool_call = tool_call_rounds[0] if "id" not in tool_call: - tool_call['id'] = uuid.uuid4().hex + tool_call["id"] = uuid.uuid4().hex tool_call_rounds = tool_call_rounds[1:] - tool_name = tool_call['function']['name'] + tool_name = tool_call["function"]["name"] print("Tool name is : ", tool_name) tool_to_call = self._get_tool_by_name(tool_name) if tool_to_call is None: - log.error(f"Tool not found: {tool_name}, args: {tool_call['function']['arguments']}") - response.stream(MarkdownData("Oops! Failed to find requested tool. Please try again with a different prompt.")) + log.error( + f"Tool not found: {tool_name}, args: {tool_call['function']['arguments']}" + ) + response.stream( + MarkdownData( + "Oops! Failed to find requested tool. Please try again with a different prompt." 
+ ) + ) response.finish() return - + print("Tool to call is : ", tool_to_call) - if type(tool_call['function']['arguments']) is dict: - args = tool_call['function']['arguments'] - elif not tool_call['function']['arguments'].startswith('{'): - args = tool_call['function']['arguments'] + if type(tool_call["function"]["arguments"]) is dict: + args = tool_call["function"]["arguments"] + elif not tool_call["function"]["arguments"].startswith("{"): + args = tool_call["function"]["arguments"] else: - args = fuzzy_json_loads(tool_call['function']['arguments']) + args = fuzzy_json_loads(tool_call["function"]["arguments"]) - tool_properties = tool_to_call.schema["function"]["parameters"]["properties"] + tool_properties = tool_to_call.schema["function"]["parameters"][ + "properties" + ] if type(args) is str: - if len(tool_properties) == 1 and tool_call['function']['arguments'] is not None: + if ( + len(tool_properties) == 1 + and tool_call["function"]["arguments"] is not None + ): tool_property = list(tool_properties.keys())[0] args = {tool_property: args} else: @@ -542,25 +657,48 @@ async def _tool_call_loop(tool_call_rounds: list): tool_pre_invoke_response = tool_to_call.pre_invoke(request, args) if tool_pre_invoke_response is not None: if tool_pre_invoke_response.message is not None: - response.stream(MarkdownData(f"✓ {tool_pre_invoke_response.message}...", tool_pre_invoke_response.detail)) + response.stream( + MarkdownData( + f"✓ {tool_pre_invoke_response.message}...", + tool_pre_invoke_response.detail, + ) + ) if tool_pre_invoke_response.confirmationMessage is not None: - response.stream(ConfirmationData( - title=tool_pre_invoke_response.confirmationTitle, - message=tool_pre_invoke_response.confirmationMessage, - confirmArgs={"id": response.message_id, "data": { "callback_id": tool_call['id'], "data": {"confirmed": True}}}, - cancelArgs={"id": response.message_id, "data": { "callback_id": tool_call['id'], "data": {"confirmed": False}}}, - )) - user_input = await ChatResponse.wait_for_chat_user_input(response, tool_call['id']) - if user_input['confirmed'] == False: + response.stream( + ConfirmationData( + title=tool_pre_invoke_response.confirmationTitle, + message=tool_pre_invoke_response.confirmationMessage, + confirmArgs={ + "id": response.message_id, + "data": { + "callback_id": tool_call["id"], + "data": {"confirmed": True}, + }, + }, + cancelArgs={ + "id": response.message_id, + "data": { + "callback_id": tool_call["id"], + "data": {"confirmed": False}, + }, + }, + ) + ) + user_input = await ChatResponse.wait_for_chat_user_input( + response, tool_call["id"] + ) + if user_input["confirmed"] == False: response.finish() return - tool_call_response = await tool_to_call.handle_tool_call(request, response, tool_context, args) + tool_call_response = await tool_to_call.handle_tool_call( + request, response, tool_context, args + ) function_call_result_message = { "role": "tool", "content": str(tool_call_response), - "tool_call_id": tool_call['id'] + "tool_call_id": tool_call["id"], } messages.append(function_call_result_message) @@ -577,26 +715,34 @@ async def _tool_call_loop(tool_call_rounds: list): return except Exception as e: log.error(f"Error in tool call loop: {str(e)}") - response.stream(MarkdownData(f"Oops! I am sorry, there was a problem generating response with tools. Please try again. You can check server logs for more details.")) + response.stream( + MarkdownData( + f"Oops! I am sorry, there was a problem generating response with tools. Please try again. 
You can check server logs for more details." + ) + ) response.finish() return await _tool_call_loop(tool_call_rounds) - + def _get_tool_by_name(self, name: str) -> Tool: for tool in self.tools: if tool.name == name: return tool return None + class CompletionContextProvider: @property def id(self) -> str: raise NotImplemented - def handle_completion_context_request(self, request: ContextRequest) -> CompletionContext: + def handle_completion_context_request( + self, request: ContextRequest + ) -> CompletionContext: raise NotImplemented + @dataclass class LLMProviderProperty: id: str @@ -608,6 +754,7 @@ class LLMProviderProperty: def to_dict(self): return asdict(self) + class LLMPropertyProvider: def __init__(self): self._properties = [] @@ -627,23 +774,24 @@ def set_property_value(self, property_id: str, value: str): if prop.id == property_id: prop.value = value + class AIModel(LLMPropertyProvider): - def __init__(self, provider: 'LLMProvider'): + def __init__(self, provider: "LLMProvider"): super().__init__() self._provider = provider @property def id(self) -> str: raise NotImplemented - + @property def name(self) -> str: raise NotImplemented @property - def provider(self) -> 'LLMProvider': + def provider(self) -> "LLMProvider": return self._provider - + @property def context_window(self) -> int: raise NotImplemented @@ -652,18 +800,36 @@ def context_window(self) -> int: def supports_tools(self) -> bool: return False + class ChatModel(AIModel): - def completions(self, messages: list[dict], tools: list[dict] = None, response: ChatResponse = None, cancel_token: CancelToken = None, options: dict = {}) -> Any: + def completions( + self, + messages: list[dict], + tools: list[dict] = None, + response: ChatResponse = None, + cancel_token: CancelToken = None, + options: dict = {}, + ) -> Any: raise NotImplemented + class InlineCompletionModel(AIModel): - def inline_completions(prefix, suffix, language, filename, context: CompletionContext, cancel_token: CancelToken) -> str: + def inline_completions( + prefix, + suffix, + language, + filename, + context: CompletionContext, + cancel_token: CancelToken, + ) -> str: raise NotImplemented + class EmbeddingModel(AIModel): def embeddings(self, inputs: list[str]) -> Any: raise NotImplemented + class LLMProvider(LLMPropertyProvider): def __init__(self): super().__init__() @@ -671,7 +837,7 @@ def __init__(self): @property def id(self) -> str: raise NotImplemented - + @property def name(self) -> str: raise NotImplemented @@ -679,11 +845,11 @@ def name(self) -> str: @property def chat_models(self) -> list[ChatModel]: raise NotImplemented - + @property def inline_completion_models(self) -> list[InlineCompletionModel]: raise NotImplemented - + @property def embedding_models(self) -> list[EmbeddingModel]: raise NotImplemented @@ -693,41 +859,44 @@ def get_chat_model(self, model_id: str) -> ChatModel: if model.id == model_id: return model return None - + def get_inline_completion_model(self, model_id: str) -> InlineCompletionModel: for model in self.inline_completion_models: if model.id == model_id: return model return None - + def get_embedding_model(self, model_id: str) -> EmbeddingModel: for model in self.embedding_models: if model.id == model_id: return model return None + class TelemetryEventType(str, Enum): - InlineCompletionRequest = 'inline-completion-request' - ExplainThisRequest = 'explain-this-request' - FixThisCodeRequest = 'fix-this-code-request' - ExplainThisOutputRequest = 'explain-this-output-request' - TroubleshootThisOutputRequest = 
'troubleshoot-this-output-request' - GenerateCodeRequest = 'generate-code-request' - ChatRequest = 'chat-request' - InlineChatRequest = 'inline-chat-request' - ChatResponse = 'chat-response' - InlineChatResponse = 'inline-chat-response' - InlineCompletionResponse = 'inline-completion-response' + InlineCompletionRequest = "inline-completion-request" + ExplainThisRequest = "explain-this-request" + FixThisCodeRequest = "fix-this-code-request" + ExplainThisOutputRequest = "explain-this-output-request" + TroubleshootThisOutputRequest = "troubleshoot-this-output-request" + GenerateCodeRequest = "generate-code-request" + ChatRequest = "chat-request" + InlineChatRequest = "inline-chat-request" + ChatResponse = "chat-response" + InlineChatResponse = "inline-chat-response" + InlineCompletionResponse = "inline-completion-response" + class TelemetryEvent: @property def type(self) -> TelemetryEventType: raise NotImplemented - + @property def data(self) -> dict: return None + class TelemetryListener: @property def name(self) -> str: @@ -736,6 +905,7 @@ def name(self) -> str: def on_telemetry_event(self, event: TelemetryEvent): raise NotImplemented + class Host: def register_llm_provider(self, provider: LLMProvider) -> None: raise NotImplemented @@ -743,9 +913,11 @@ def register_llm_provider(self, provider: LLMProvider) -> None: def register_chat_participant(self, participant: ChatParticipant) -> None: raise NotImplemented - def register_completion_context_provider(self, provider: CompletionContextProvider) -> None: + def register_completion_context_provider( + self, provider: CompletionContextProvider + ) -> None: raise NotImplemented - + def register_telemetry_listener(self, listener: TelemetryListener) -> None: raise NotImplemented @@ -759,15 +931,15 @@ def nbi_config(self) -> NBIConfig: @property def default_chat_participant(self) -> ChatParticipant: raise NotImplemented - + @property def chat_model(self) -> ChatModel: raise NotImplemented - + @property def inline_completion_model(self) -> InlineCompletionModel: raise NotImplemented - + @property def embedding_model(self) -> EmbeddingModel: raise NotImplemented @@ -781,9 +953,12 @@ def get_mcp_server_tool(self, server_name: str, tool_name: str) -> Tool: def get_extension_toolset(self, extension_id: str, toolset_id: str) -> Toolset: return NotImplemented - def get_extension_tool(self, extension_id: str, toolset_id: str, tool_name: str) -> Tool: + def get_extension_tool( + self, extension_id: str, toolset_id: str, tool_name: str + ) -> Tool: return NotImplemented + class NotebookIntelligenceExtension: @property def id(self) -> str: @@ -792,7 +967,7 @@ def id(self) -> str: @property def name(self) -> str: raise NotImplemented - + @property def provider(self) -> str: raise NotImplemented diff --git a/lab_notebook_intelligence/base_chat_participant.py b/lab_notebook_intelligence/base_chat_participant.py index ffeabaa..2334d15 100644 --- a/lab_notebook_intelligence/base_chat_participant.py +++ b/lab_notebook_intelligence/base_chat_participant.py @@ -1,14 +1,17 @@ # Copyright (c) Mehmet Bektas -import os -from typing import Union -import json -from lab_notebook_intelligence.api import ChatCommand, ChatParticipant, ChatRequest, ChatResponse, MarkdownData, ProgressData, Tool, ToolPreInvokeResponse -from lab_notebook_intelligence.prompts import Prompts import base64 +import json import logging -from lab_notebook_intelligence.built_in_toolsets import built_in_toolsets +import os +from typing import Union +from lab_notebook_intelligence.api import (ChatCommand, 
ChatParticipant, + ChatRequest, ChatResponse, + MarkdownData, ProgressData, Tool, + ToolPreInvokeResponse) +from lab_notebook_intelligence.built_in_toolsets import built_in_toolsets +from lab_notebook_intelligence.prompts import Prompts from lab_notebook_intelligence.util import extract_llm_generated_code log = logging.getLogger(__name__) @@ -16,6 +19,7 @@ ICON_SVG = '' ICON_URL = f"data:image/svg+xml;base64,{base64.b64encode(ICON_SVG.encode('utf-8')).decode('utf-8')}" + class SecuredExtensionTool(Tool): def __init__(self, extension_tool: Tool): super().__init__() @@ -28,11 +32,11 @@ def name(self) -> str: @property def title(self) -> str: return self._ext_tool.title - + @property def tags(self) -> list[str]: return self._ext_tool.tags - + @property def description(self) -> str: return self._ext_tool.description @@ -40,19 +44,30 @@ def description(self) -> str: @property def schema(self) -> dict: return self._ext_tool.schema - - def pre_invoke(self, request: ChatRequest, tool_args: dict) -> Union[ToolPreInvokeResponse, None]: + + def pre_invoke( + self, request: ChatRequest, tool_args: dict + ) -> Union[ToolPreInvokeResponse, None]: confirmationTitle = "Approve" confirmationMessage = "Are you sure you want to call this extension tool?" return ToolPreInvokeResponse( - message = f"Calling extension tool '{self.name}'", - detail = {"title": "Parameters", "content": json.dumps(tool_args)}, - confirmationTitle = confirmationTitle, - confirmationMessage = confirmationMessage + message=f"Calling extension tool '{self.name}'", + detail={"title": "Parameters", "content": json.dumps(tool_args)}, + confirmationTitle=confirmationTitle, + confirmationMessage=confirmationMessage, + ) + + async def handle_tool_call( + self, + request: ChatRequest, + response: ChatResponse, + tool_context: dict, + tool_args: dict, + ) -> str: + return await self._ext_tool.handle_tool_call( + request, response, tool_context, tool_args ) - async def handle_tool_call(self, request: ChatRequest, response: ChatResponse, tool_context: dict, tool_args: dict) -> str: - return await self._ext_tool.handle_tool_call(request, response, tool_context, tool_args) class CreateNewNotebookTool(Tool): def __init__(self, auto_approve: bool = False): @@ -66,15 +81,17 @@ def name(self) -> str: @property def title(self) -> str: return "Create new notebook with the provided code and markdown cells" - + @property def tags(self) -> list[str]: return ["default-participant-tool"] - + @property def description(self) -> str: - return "This tool creates a new notebook with the provided code and markdown cells" - + return ( + "This tool creates a new notebook with the provided code and markdown cells" + ) + @property def schema(self) -> dict: return { @@ -93,14 +110,14 @@ def schema(self) -> dict: "properties": { "cell_type": { "type": "string", - "enum": ["code", "markdown"] + "enum": ["code", "markdown"], }, "source": { "type": "string", - "description": "The content of the cell" - } - } - } + "description": "The content of the cell", + }, + }, + }, } }, "required": [], @@ -108,32 +125,51 @@ def schema(self) -> dict: }, }, } - - def pre_invoke(self, request: ChatRequest, tool_args: dict) -> Union[ToolPreInvokeResponse, None]: + + def pre_invoke( + self, request: ChatRequest, tool_args: dict + ) -> Union[ToolPreInvokeResponse, None]: confirmationTitle = None confirmationMessage = None if not self._auto_approve: confirmationTitle = "Approve" confirmationMessage = "Are you sure you want to call this tool?" 
- return ToolPreInvokeResponse(f"Calling tool '{self.name}'", confirmationTitle, confirmationMessage) + return ToolPreInvokeResponse( + f"Calling tool '{self.name}'", confirmationTitle, confirmationMessage + ) - async def handle_tool_call(self, request: ChatRequest, response: ChatResponse, tool_context: dict, tool_args: dict) -> str: - cell_sources = tool_args.get('cell_sources', []) - - ui_cmd_response = await response.run_ui_command('lab-notebook-intelligence:create-new-notebook-from-py', {'code': ''}) - file_path = ui_cmd_response['path'] + async def handle_tool_call( + self, + request: ChatRequest, + response: ChatResponse, + tool_context: dict, + tool_args: dict, + ) -> str: + cell_sources = tool_args.get("cell_sources", []) + + ui_cmd_response = await response.run_ui_command( + "lab-notebook-intelligence:create-new-notebook-from-py", {"code": ""} + ) + file_path = ui_cmd_response["path"] for cell_source in cell_sources: - cell_type = cell_source.get('cell_type') - if cell_type == 'markdown': - source = cell_source.get('source', '') - ui_cmd_response = await response.run_ui_command('lab-notebook-intelligence:add-markdown-cell-to-notebook', {'markdown': source, 'path': file_path}) - elif cell_type == 'code': - source = cell_source.get('source', '') - ui_cmd_response = await response.run_ui_command('lab-notebook-intelligence:add-code-cell-to-notebook', {'code': source, 'path': file_path}) + cell_type = cell_source.get("cell_type") + if cell_type == "markdown": + source = cell_source.get("source", "") + ui_cmd_response = await response.run_ui_command( + "lab-notebook-intelligence:add-markdown-cell-to-notebook", + {"markdown": source, "path": file_path}, + ) + elif cell_type == "code": + source = cell_source.get("source", "") + ui_cmd_response = await response.run_ui_command( + "lab-notebook-intelligence:add-code-cell-to-notebook", + {"code": source, "path": file_path}, + ) return "Notebook created successfully at {file_path}" + class AddMarkdownCellToNotebookTool(Tool): def __init__(self, auto_approve: bool = False): self._auto_approve = auto_approve @@ -146,15 +182,15 @@ def name(self) -> str: @property def title(self) -> str: return "Add markdown cell to notebook" - + @property def tags(self) -> list[str]: return ["default-participant-tool"] - + @property def description(self) -> str: return "This is a tool that adds markdown cell to a notebook" - + @property def schema(self) -> dict: return { @@ -173,7 +209,7 @@ def schema(self) -> dict: "markdown_cell_source": { "type": "string", "description": "Markdown to add to the notebook", - } + }, }, "required": ["notebook_file_path", "markdown_cell_source"], "additionalProperties": False, @@ -181,23 +217,37 @@ def schema(self) -> dict: }, } - def pre_invoke(self, request: ChatRequest, tool_args: dict) -> Union[ToolPreInvokeResponse, None]: + def pre_invoke( + self, request: ChatRequest, tool_args: dict + ) -> Union[ToolPreInvokeResponse, None]: confirmationTitle = None confirmationMessage = None if not self._auto_approve: confirmationTitle = "Approve" confirmationMessage = "Are you sure you want to call this tool?" 
- return ToolPreInvokeResponse(f"Calling tool '{self.name}'", confirmationTitle, confirmationMessage) + return ToolPreInvokeResponse( + f"Calling tool '{self.name}'", confirmationTitle, confirmationMessage + ) - async def handle_tool_call(self, request: ChatRequest, response: ChatResponse, tool_context: dict, tool_args: dict) -> str: - notebook_file_path = tool_args.get('notebook_file_path', '') + async def handle_tool_call( + self, + request: ChatRequest, + response: ChatResponse, + tool_context: dict, + tool_args: dict, + ) -> str: + notebook_file_path = tool_args.get("notebook_file_path", "") server_root_dir = request.host.nbi_config.server_root_dir if notebook_file_path.startswith(server_root_dir): notebook_file_path = os.path.relpath(notebook_file_path, server_root_dir) - source = tool_args.get('markdown_cell_source') - ui_cmd_response = await response.run_ui_command('lab-notebook-intelligence:add-markdown-cell-to-notebook', {'markdown': source, 'path': notebook_file_path}) + source = tool_args.get("markdown_cell_source") + ui_cmd_response = await response.run_ui_command( + "lab-notebook-intelligence:add-markdown-cell-to-notebook", + {"markdown": source, "path": notebook_file_path}, + ) return f"Added markdown cell to notebook" + class AddCodeCellTool(Tool): def __init__(self, auto_approve: bool = False): self._auto_approve = auto_approve @@ -210,15 +260,15 @@ def name(self) -> str: @property def title(self) -> str: return "Add code cell to notebook" - + @property def tags(self) -> list[str]: return ["default-participant-tool"] - + @property def description(self) -> str: return "This is a tool that adds code cell to a notebook" - + @property def schema(self) -> dict: return { @@ -237,7 +287,7 @@ def schema(self) -> dict: "code_cell_source": { "type": "string", "description": "Code to add to the notebook", - } + }, }, "required": ["notebook_file_path", "code_cell_source"], "additionalProperties": False, @@ -245,23 +295,37 @@ def schema(self) -> dict: }, } - def pre_invoke(self, request: ChatRequest, tool_args: dict) -> Union[ToolPreInvokeResponse, None]: + def pre_invoke( + self, request: ChatRequest, tool_args: dict + ) -> Union[ToolPreInvokeResponse, None]: confirmationTitle = None confirmationMessage = None if not self._auto_approve: confirmationTitle = "Approve" confirmationMessage = "Are you sure you want to call this tool?" 
- return ToolPreInvokeResponse(f"Calling tool '{self.name}'", confirmationTitle, confirmationMessage) + return ToolPreInvokeResponse( + f"Calling tool '{self.name}'", confirmationTitle, confirmationMessage + ) - async def handle_tool_call(self, request: ChatRequest, response: ChatResponse, tool_context: dict, tool_args: dict) -> str: - notebook_file_path = tool_args.get('notebook_file_path', '') + async def handle_tool_call( + self, + request: ChatRequest, + response: ChatResponse, + tool_context: dict, + tool_args: dict, + ) -> str: + notebook_file_path = tool_args.get("notebook_file_path", "") server_root_dir = request.host.nbi_config.server_root_dir if notebook_file_path.startswith(server_root_dir): notebook_file_path = os.path.relpath(notebook_file_path, server_root_dir) - source = tool_args.get('code_cell_source') - ui_cmd_response = await response.run_ui_command('lab-notebook-intelligence:add-code-cell-to-notebook', {'code': source, 'path': notebook_file_path}) + source = tool_args.get("code_cell_source") + ui_cmd_response = await response.run_ui_command( + "lab-notebook-intelligence:add-code-cell-to-notebook", + {"code": source, "path": notebook_file_path}, + ) return "Added code cell added to notebook" + # Fallback tool to handle tool errors class PythonTool(AddCodeCellTool): @property @@ -271,15 +335,15 @@ def name(self) -> str: @property def title(self) -> str: return "Add code cell to notebook" - + @property def tags(self) -> list[str]: return ["default-participant-tool"] - + @property def description(self) -> str: return "This is a tool that adds code cell to a notebook" - + @property def schema(self) -> dict: return { @@ -302,11 +366,21 @@ def schema(self) -> dict: }, } - async def handle_tool_call(self, request: ChatRequest, response: ChatResponse, tool_context: dict, tool_args: dict) -> str: - code = tool_args.get('code_cell_source') - ui_cmd_response = await response.run_ui_command('lab-notebook-intelligence:add-code-cell-to-notebook', {'code': code, 'path': tool_context.get('file_path')}) + async def handle_tool_call( + self, + request: ChatRequest, + response: ChatResponse, + tool_context: dict, + tool_args: dict, + ) -> str: + code = tool_args.get("code_cell_source") + ui_cmd_response = await response.run_ui_command( + "lab-notebook-intelligence:add-code-cell-to-notebook", + {"code": code, "path": tool_context.get("file_path")}, + ) return {"result": "Code cell added to notebook"} + class BaseChatParticipant(ChatParticipant): def __init__(self): super().__init__() @@ -315,7 +389,7 @@ def __init__(self): @property def id(self) -> str: return "default" - + @property def name(self) -> str: return "AI Assistant" @@ -323,17 +397,17 @@ def name(self) -> str: @property def description(self) -> str: return "AI Assistant" - + @property def icon_path(self) -> str: return ICON_URL - + @property def commands(self) -> list[ChatCommand]: return [ - ChatCommand(name='newNotebook', description='Create a new notebook'), - ChatCommand(name='newPythonFile', description='Create a new Python file'), - ChatCommand(name='clear', description='Clears chat history'), + ChatCommand(name="newNotebook", description="Create a new notebook"), + ChatCommand(name="newPythonFile", description="Create a new Python file"), + ChatCommand(name="clear", description="Clears chat history"), ] @property @@ -341,14 +415,21 @@ def tools(self) -> list[Tool]: tool_list = [] chat_mode = self._current_chat_request.chat_mode if chat_mode.id == "ask": - tool_list = [AddMarkdownCellToNotebookTool(), AddCodeCellTool(), 
PythonTool()] + tool_list = [ + AddMarkdownCellToNotebookTool(), + AddCodeCellTool(), + PythonTool(), + ] elif chat_mode.id == "agent": tool_selection = self._current_chat_request.tool_selection host = self._current_chat_request.host for toolset in tool_selection.built_in_toolsets: built_in_toolset = built_in_toolsets[toolset] tool_list += built_in_toolset.tools - for server_name, mcp_server_tool_list in tool_selection.mcp_server_tools.items(): + for ( + server_name, + mcp_server_tool_list, + ) in tool_selection.mcp_server_tools.items(): for tool_name in mcp_server_tool_list: mcp_server_tool = host.get_mcp_server_tool(server_name, tool_name) if mcp_server_tool is not None: @@ -356,7 +437,9 @@ def tools(self) -> list[Tool]: for ext_id, ext_toolsets in tool_selection.extension_tools.items(): for toolset_id, toolset_tools in ext_toolsets.items(): for tool_name in toolset_tools: - ext_tool = host.get_extension_tool(ext_id, toolset_id, tool_name) + ext_tool = host.get_extension_tool( + ext_id, toolset_id, tool_name + ) if ext_tool is not None: tool_list.append(SecuredExtensionTool(ext_tool)) return tool_list @@ -365,7 +448,7 @@ def tools(self) -> list[Tool]: def allowed_context_providers(self) -> set[str]: # any context provider can be used return set(["*"]) - + def chat_prompt(self, model_provider: str, model_name: str) -> str: return Prompts.generic_chat_prompt(model_provider, model_name) @@ -373,25 +456,46 @@ async def generate_code_cell(self, request: ChatRequest) -> str: chat_model = request.host.chat_model messages = request.chat_history.copy() messages.pop() - messages.insert(0, {"role": "system", "content": f"You are an assistant that creates Python code which will be used in a Jupyter notebook. Generate only Python code and some comments for the code. You should return the code directly, without wrapping it inside ```."}) - messages.append({"role": "user", "content": f"Generate code for: {request.prompt}"}) + messages.insert( + 0, + { + "role": "system", + "content": f"You are an assistant that creates Python code which will be used in a Jupyter notebook. Generate only Python code and some comments for the code. You should return the code directly, without wrapping it inside ```.", + }, + ) + messages.append( + {"role": "user", "content": f"Generate code for: {request.prompt}"} + ) generated = chat_model.completions(messages) - code = generated['choices'][0]['message']['content'] - + code = generated["choices"][0]["message"]["content"] + return extract_llm_generated_code(code) - + async def generate_markdown_for_code(self, request: ChatRequest, code: str) -> str: chat_model = request.host.chat_model messages = request.chat_history.copy() messages.pop() - messages.insert(0, {"role": "system", "content": f"You are an assistant that explains the provided code using markdown. Don't include any code, just narrative markdown text. Keep it concise, only generate few lines. First create a title that suits the code and then explain the code briefly. You should return the markdown directly, without wrapping it inside ```."}) - messages.append({"role": "user", "content": f"Generate markdown that explains this code: {code}"}) + messages.insert( + 0, + { + "role": "system", + "content": f"You are an assistant that explains the provided code using markdown. Don't include any code, just narrative markdown text. Keep it concise, only generate few lines. First create a title that suits the code and then explain the code briefly. 
You should return the markdown directly, without wrapping it inside ```.", + }, + ) + messages.append( + { + "role": "user", + "content": f"Generate markdown that explains this code: {code}", + } + ) generated = chat_model.completions(messages) - markdown = generated['choices'][0]['message']['content'] + markdown = generated["choices"][0]["message"]["content"] return extract_llm_generated_code(markdown) - async def handle_chat_request(self, request: ChatRequest, response: ChatResponse, options: dict = {}) -> None: + async def handle_chat_request( + self, request: ChatRequest, response: ChatResponse, options: dict = {} + ) -> None: self._current_chat_request = request if request.chat_mode.id == "ask": return await self.handle_ask_mode_chat_request(request, response, options) @@ -405,9 +509,14 @@ async def handle_chat_request(self, request: ChatRequest, response: ChatResponse if built_in_toolset.instructions is not None: system_prompt += built_in_toolset.instructions + "\n" - for extension_id, toolsets in request.tool_selection.extension_tools.items(): + for ( + extension_id, + toolsets, + ) in request.tool_selection.extension_tools.items(): for toolset_id in toolsets.keys(): - ext_toolset = request.host.get_extension_toolset(extension_id, toolset_id) + ext_toolset = request.host.get_extension_toolset( + extension_id, toolset_id + ) if ext_toolset is not None and ext_toolset.instructions is not None: system_prompt += ext_toolset.instructions + "\n" @@ -422,53 +531,89 @@ async def handle_chat_request(self, request: ChatRequest, response: ChatResponse await self.handle_chat_request_with_tools(request, response, options) - async def handle_ask_mode_chat_request(self, request: ChatRequest, response: ChatResponse, options: dict = {}) -> None: + async def handle_ask_mode_chat_request( + self, request: ChatRequest, response: ChatResponse, options: dict = {} + ) -> None: chat_model = request.host.chat_model - if request.command == 'newNotebook': + if request.command == "newNotebook": # create a new notebook - ui_cmd_response = await response.run_ui_command('lab-notebook-intelligence:create-new-notebook-from-py', {'code': ''}) - file_path = ui_cmd_response['path'] + ui_cmd_response = await response.run_ui_command( + "lab-notebook-intelligence:create-new-notebook-from-py", {"code": ""} + ) + file_path = ui_cmd_response["path"] code = await self.generate_code_cell(request) markdown = await self.generate_markdown_for_code(request, code) - ui_cmd_response = await response.run_ui_command('lab-notebook-intelligence:add-markdown-cell-to-notebook', {'markdown': markdown, 'path': file_path}) - ui_cmd_response = await response.run_ui_command('lab-notebook-intelligence:add-code-cell-to-notebook', {'code': code, 'path': file_path}) - - response.stream(MarkdownData(f"Notebook '{file_path}' created and opened successfully")) + ui_cmd_response = await response.run_ui_command( + "lab-notebook-intelligence:add-markdown-cell-to-notebook", + {"markdown": markdown, "path": file_path}, + ) + ui_cmd_response = await response.run_ui_command( + "lab-notebook-intelligence:add-code-cell-to-notebook", + {"code": code, "path": file_path}, + ) + + response.stream( + MarkdownData(f"Notebook '{file_path}' created and opened successfully") + ) response.finish() return - elif request.command == 'newPythonFile': + elif request.command == "newPythonFile": # create a new python file messages = request.chat_history.copy() messages.pop() - messages.insert(0, {"role": "system", "content": f"You are an assistant that creates Python code. 
You should return the code directly, without wrapping it inside ```."}) - messages.append({"role": "user", "content": f"Generate code for: {request.prompt}"}) + messages.insert( + 0, + { + "role": "system", + "content": f"You are an assistant that creates Python code. You should return the code directly, without wrapping it inside ```.", + }, + ) + messages.append( + {"role": "user", "content": f"Generate code for: {request.prompt}"} + ) generated = chat_model.completions(messages) - code = generated['choices'][0]['message']['content'] + code = generated["choices"][0]["message"]["content"] code = extract_llm_generated_code(code) - ui_cmd_response = await response.run_ui_command('lab-notebook-intelligence:create-new-file', {'code': code }) - file_path = ui_cmd_response['path'] + ui_cmd_response = await response.run_ui_command( + "lab-notebook-intelligence:create-new-file", {"code": code} + ) + file_path = ui_cmd_response["path"] response.stream(MarkdownData(f"File '{file_path}' created successfully")) response.finish() return - elif request.command == 'settings': - ui_cmd_response = await response.run_ui_command('lab-notebook-intelligence:open-configuration-dialog') + elif request.command == "settings": + ui_cmd_response = await response.run_ui_command( + "lab-notebook-intelligence:open-configuration-dialog" + ) response.stream(MarkdownData(f"Opened the settings dialog")) response.finish() return messages = [ - {"role": "system", "content": options.get("system_prompt", self.chat_prompt(chat_model.provider.name, chat_model.name))}, + { + "role": "system", + "content": options.get( + "system_prompt", + self.chat_prompt(chat_model.provider.name, chat_model.name), + ), + }, ] + request.chat_history try: if chat_model.provider.id != "github-copilot": response.stream(ProgressData("Thinking...")) - chat_model.completions(messages, response=response, cancel_token=request.cancel_token) + chat_model.completions( + messages, response=response, cancel_token=request.cancel_token + ) except Exception as e: log.error(f"Error while handling chat request!\n{e}") - response.stream(MarkdownData(f"Oops! There was a problem handling chat request. Please try again with a different prompt.")) + response.stream( + MarkdownData( + f"Oops! There was a problem handling chat request. Please try again with a different prompt." + ) + ) response.finish() @staticmethod diff --git a/lab_notebook_intelligence/built_in_toolsets.py b/lab_notebook_intelligence/built_in_toolsets.py index de12970..afdb990 100644 --- a/lab_notebook_intelligence/built_in_toolsets.py +++ b/lab_notebook_intelligence/built_in_toolsets.py @@ -1,34 +1,40 @@ # Copyright (c) Mehmet Bektas -from lab_notebook_intelligence.api import ChatResponse, Toolset import logging + import lab_notebook_intelligence.api as nbapi -from lab_notebook_intelligence.api import BuiltinToolset +from lab_notebook_intelligence.api import BuiltinToolset, ChatResponse, Toolset log = logging.getLogger(__name__) + @nbapi.auto_approve @nbapi.tool async def create_new_notebook(**args) -> str: - """Creates a new empty notebook. 
- """ + """Creates a new empty notebook.""" response = args["response"] - ui_cmd_response = await response.run_ui_command('lab-notebook-intelligence:create-new-notebook-from-py', {'code': ''}) - file_path = ui_cmd_response['path'] + ui_cmd_response = await response.run_ui_command( + "lab-notebook-intelligence:create-new-notebook-from-py", {"code": ""} + ) + file_path = ui_cmd_response["path"] return f"Created new notebook at {file_path}" + @nbapi.auto_approve @nbapi.tool -async def rename_notebook(new_name: str, **args) -> str: +async def rename_notebook(new_name: str, **args) -> str: """Renames the notebook. Args: new_name: New name for the notebook """ response = args["response"] - ui_cmd_response = await response.run_ui_command('lab-notebook-intelligence:rename-notebook', {'newName': new_name}) + ui_cmd_response = await response.run_ui_command( + "lab-notebook-intelligence:rename-notebook", {"newName": new_name} + ) return str(ui_cmd_response) + @nbapi.auto_approve @nbapi.tool async def add_markdown_cell(source: str, **args) -> str: @@ -37,10 +43,14 @@ async def add_markdown_cell(source: str, **args) -> str: source: Markdown source """ response = args["response"] - ui_cmd_response = await response.run_ui_command('lab-notebook-intelligence:add-markdown-cell-to-active-notebook', {'source': source}) + ui_cmd_response = await response.run_ui_command( + "lab-notebook-intelligence:add-markdown-cell-to-active-notebook", + {"source": source}, + ) return "Added markdown cell to notebook" + @nbapi.auto_approve @nbapi.tool async def add_code_cell(source: str, **args) -> str: @@ -49,20 +59,25 @@ async def add_code_cell(source: str, **args) -> str: source: Python code source """ response = args["response"] - ui_cmd_response = await response.run_ui_command('lab-notebook-intelligence:add-code-cell-to-active-notebook', {'source': source}) + ui_cmd_response = await response.run_ui_command( + "lab-notebook-intelligence:add-code-cell-to-active-notebook", {"source": source} + ) return "Added code cell to notebook" + @nbapi.auto_approve @nbapi.tool async def get_number_of_cells(**args) -> str: - """Get number of cells for the active notebook. 
- """ + """Get number of cells for the active notebook.""" response = args["response"] - ui_cmd_response = await response.run_ui_command('lab-notebook-intelligence:get-number-of-cells', {}) + ui_cmd_response = await response.run_ui_command( + "lab-notebook-intelligence:get-number-of-cells", {} + ) return str(ui_cmd_response) + @nbapi.auto_approve @nbapi.tool async def get_cell_type_and_source(cell_index: int, **args) -> str: @@ -72,7 +87,9 @@ async def get_cell_type_and_source(cell_index: int, **args) -> str: cell_index: Zero based cell index """ response = args["response"] - ui_cmd_response = await response.run_ui_command('lab-notebook-intelligence:get-cell-type-and-source', {"cellIndex": cell_index }) + ui_cmd_response = await response.run_ui_command( + "lab-notebook-intelligence:get-cell-type-and-source", {"cellIndex": cell_index} + ) return str(ui_cmd_response) @@ -86,13 +103,18 @@ async def get_cell_output(cell_index: int, **args) -> str: cell_index: Zero based cell index """ response = args["response"] - ui_cmd_response = await response.run_ui_command('lab-notebook-intelligence:get-cell-output', {"cellIndex": cell_index}) + ui_cmd_response = await response.run_ui_command( + "lab-notebook-intelligence:get-cell-output", {"cellIndex": cell_index} + ) return str(ui_cmd_response) + @nbapi.auto_approve @nbapi.tool -async def set_cell_type_and_source(cell_index: int, cell_type: str, source: str, **args) -> str: +async def set_cell_type_and_source( + cell_index: int, cell_type: str, source: str, **args +) -> str: """Set cell type and source for the cell at index for the active notebook. Args: @@ -101,10 +123,14 @@ async def set_cell_type_and_source(cell_index: int, cell_type: str, source: str, source: Markdown or Python code source """ response = args["response"] - ui_cmd_response = await response.run_ui_command('lab-notebook-intelligence:set-cell-type-and-source', {"cellIndex": cell_index, "cellType": cell_type, "source": source}) + ui_cmd_response = await response.run_ui_command( + "lab-notebook-intelligence:set-cell-type-and-source", + {"cellIndex": cell_index, "cellType": cell_type, "source": source}, + ) return str(ui_cmd_response) + @nbapi.auto_approve @nbapi.tool async def delete_cell(cell_index: int, **args) -> str: @@ -115,10 +141,13 @@ async def delete_cell(cell_index: int, **args) -> str: """ response = args["response"] - ui_cmd_response = await response.run_ui_command('lab-notebook-intelligence:delete-cell-at-index', {"cellIndex": cell_index}) + ui_cmd_response = await response.run_ui_command( + "lab-notebook-intelligence:delete-cell-at-index", {"cellIndex": cell_index} + ) return f"Deleted the cell at index: {cell_index}" + @nbapi.auto_approve @nbapi.tool async def insert_cell(cell_index: int, cell_type: str, source: str, **args) -> str: @@ -130,10 +159,14 @@ async def insert_cell(cell_index: int, cell_type: str, source: str, **args) -> s source: Markdown or Python code source """ response = args["response"] - ui_cmd_response = await response.run_ui_command('lab-notebook-intelligence:insert-cell-at-index', {"cellIndex": cell_index, "cellType": cell_type, "source": source}) + ui_cmd_response = await response.run_ui_command( + "lab-notebook-intelligence:insert-cell-at-index", + {"cellIndex": cell_index, "cellType": cell_type, "source": source}, + ) return str(ui_cmd_response) + @nbapi.auto_approve @nbapi.tool async def run_cell(cell_index: int, **args) -> str: @@ -144,20 +177,23 @@ async def run_cell(cell_index: int, **args) -> str: """ response = args["response"] - 
ui_cmd_response = await response.run_ui_command('lab-notebook-intelligence:run-cell-at-index', {"cellIndex": cell_index}) + ui_cmd_response = await response.run_ui_command( + "lab-notebook-intelligence:run-cell-at-index", {"cellIndex": cell_index} + ) return f"Ran the cell at index: {cell_index}" + @nbapi.auto_approve @nbapi.tool async def save_notebook(**args) -> str: - """Save the changes in active notebook to disk. - """ + """Save the changes in active notebook to disk.""" response: ChatResponse = args["response"] - ui_cmd_response = await response.run_ui_command('docmanager:save') + ui_cmd_response = await response.run_ui_command("docmanager:save") return f"Save the notebook" + @nbapi.auto_approve @nbapi.tool async def create_new_python_file(code: str, **args) -> str: @@ -166,21 +202,26 @@ async def create_new_python_file(code: str, **args) -> str: code: Python code source """ response = args["response"] - ui_cmd_response = await response.run_ui_command('lab-notebook-intelligence:create-new-file', {'code': code}) - file_path = ui_cmd_response['path'] + ui_cmd_response = await response.run_ui_command( + "lab-notebook-intelligence:create-new-file", {"code": code} + ) + file_path = ui_cmd_response["path"] return f"Created new Python file at {file_path}" + @nbapi.auto_approve @nbapi.tool async def get_file_content(**args) -> str: - """Returns the content of the current file. - """ + """Returns the content of the current file.""" response = args["response"] - ui_cmd_response = await response.run_ui_command('lab-notebook-intelligence:get-current-file-content', {}) + ui_cmd_response = await response.run_ui_command( + "lab-notebook-intelligence:get-current-file-content", {} + ) return f"Received the file content" + @nbapi.auto_approve @nbapi.tool async def set_file_content(content: str, **args) -> str: @@ -189,10 +230,13 @@ async def set_file_content(content: str, **args) -> str: content: File content """ response = args["response"] - ui_cmd_response = await response.run_ui_command('lab-notebook-intelligence:set-current-file-content', {"content": content}) + ui_cmd_response = await response.run_ui_command( + "lab-notebook-intelligence:set-current-file-content", {"content": content} + ) return f"Set the file content" + NOTEBOOK_EDIT_INSTRUCTIONS = """ You are an assistant that creates and edits Jupyter notebooks. Notebooks are made up of source code cells and markdown cells. Markdown cells have source in markdown format and code cells have source in a specified programming language. If no programming language is specified, then use Python for the language of the code. 
@@ -246,30 +290,24 @@ async def set_file_content(content: str, **args) -> str: set_cell_type_and_source, delete_cell, insert_cell, - save_notebook + save_notebook, ], - instructions=NOTEBOOK_EDIT_INSTRUCTIONS + instructions=NOTEBOOK_EDIT_INSTRUCTIONS, ), BuiltinToolset.NotebookExecute: Toolset( id=BuiltinToolset.NotebookExecute, name="Notebook execute", description="Notebook execute", provider=None, - tools=[ - run_cell - ], - instructions=NOTEBOOK_EXECUTE_INSTRUCTIONS + tools=[run_cell], + instructions=NOTEBOOK_EXECUTE_INSTRUCTIONS, ), BuiltinToolset.PythonFileEdit: Toolset( id=BuiltinToolset.PythonFileEdit, name="Python file edit", description="Python file edit", provider=None, - tools=[ - create_new_python_file, - get_file_content, - set_file_content - ], - instructions=PYTHON_FILE_EDIT_INSTRUCTIONS + tools=[create_new_python_file, get_file_content, set_file_content], + instructions=PYTHON_FILE_EDIT_INSTRUCTIONS, ), } diff --git a/lab_notebook_intelligence/config.py b/lab_notebook_intelligence/config.py index 7d32f47..a031a3e 100644 --- a/lab_notebook_intelligence/config.py +++ b/lab_notebook_intelligence/config.py @@ -7,15 +7,20 @@ log = logging.getLogger(__name__) + class NBIConfig: def __init__(self, options: dict = {}): self.options = options - self.deprecated_env_config_file = os.path.join(sys.prefix, "share", "jupyter", "nbi-config.json") - self.deprecated_user_config_file = os.path.join(os.path.expanduser('~'), ".jupyter", "nbi-config.json") + self.deprecated_env_config_file = os.path.join( + sys.prefix, "share", "jupyter", "nbi-config.json" + ) + self.deprecated_user_config_file = os.path.join( + os.path.expanduser("~"), ".jupyter", "nbi-config.json" + ) self.nbi_env_dir = os.path.join(sys.prefix, "share", "jupyter", "nbi") - self.nbi_user_dir = os.path.join(os.path.expanduser('~'), ".jupyter", "nbi") + self.nbi_user_dir = os.path.join(os.path.expanduser("~"), ".jupyter", "nbi") self.env_config_file = os.path.join(self.nbi_env_dir, "config.json") self.user_config_file = os.path.join(self.nbi_user_dir, "config.json") self.env_mcp_file = os.path.join(self.nbi_env_dir, "mcp.json") @@ -28,59 +33,68 @@ def __init__(self, options: dict = {}): # TODO: Remove after 12/2025 if os.path.exists(self.deprecated_env_config_file): - log.warning(f"Deprecated config file found: {self.deprecated_env_config_file}. Use {self.env_config_file} and {self.env_mcp_file} instead.") + log.warning( + f"Deprecated config file found: {self.deprecated_env_config_file}. Use {self.env_config_file} and {self.env_mcp_file} instead." + ) if os.path.exists(self.deprecated_user_config_file): - log.warning(f"Deprecated config file found: {self.deprecated_user_config_file}. Use {self.user_config_file} and {self.user_mcp_file} instead.") - if self.env_mcp.get("participants") is not None or self.user_mcp.get("participants") is not None: - log.warning("MCP participants configuration is deprecated. Users should use Agent mode to select MCP tools.") + log.warning( + f"Deprecated config file found: {self.deprecated_user_config_file}. Use {self.user_config_file} and {self.user_mcp_file} instead." + ) + if ( + self.env_mcp.get("participants") is not None + or self.user_mcp.get("participants") is not None + ): + log.warning( + "MCP participants configuration is deprecated. Users should use Agent mode to select MCP tools." 
+ ) @property def server_root_dir(self): - return self.options.get('server_root_dir', '') + return self.options.get("server_root_dir", "") def load(self): if os.path.exists(self.env_config_file): - with open(self.env_config_file, 'r') as file: + with open(self.env_config_file, "r") as file: self.env_config = json.load(file) elif os.path.exists(self.deprecated_env_config_file): - with open(self.deprecated_env_config_file, 'r') as file: + with open(self.deprecated_env_config_file, "r") as file: self.env_config = json.load(file) self.env_mcp = {} - if 'mcp' in self.env_config: - self.env_mcp = self.env_config.get('mcp', {}) - del self.env_config['mcp'] + if "mcp" in self.env_config: + self.env_mcp = self.env_config.get("mcp", {}) + del self.env_config["mcp"] else: self.env_config = {} if os.path.exists(self.user_config_file): - with open(self.user_config_file, 'r') as file: + with open(self.user_config_file, "r") as file: self.user_config = json.load(file) elif os.path.exists(self.deprecated_user_config_file): - with open(self.deprecated_user_config_file, 'r') as file: + with open(self.deprecated_user_config_file, "r") as file: self.user_config = json.load(file) self.user_mcp = {} - if 'mcp' in self.user_config: - self.user_mcp = self.user_config.get('mcp', {}) - del self.user_config['mcp'] + if "mcp" in self.user_config: + self.user_mcp = self.user_config.get("mcp", {}) + del self.user_config["mcp"] else: self.user_config = {} if os.path.exists(self.env_mcp_file): - with open(self.env_mcp_file, 'r') as file: + with open(self.env_mcp_file, "r") as file: self.env_mcp = json.load(file) if os.path.exists(self.user_mcp_file): - with open(self.user_mcp_file, 'r') as file: + with open(self.user_mcp_file, "r") as file: self.user_mcp = json.load(file) def save(self): # TODO: save only diff os.makedirs(self.nbi_user_dir, exist_ok=True) - with open(self.user_config_file, 'w') as file: + with open(self.user_config_file, "w") as file: json.dump(self.user_config, file, indent=2) - with open(self.user_mcp_file, 'w') as file: + with open(self.user_mcp_file, "w") as file: json.dump(self.user_mcp, file, indent=2) def get(self, key, default=None): @@ -92,19 +106,24 @@ def set(self, key, value): @property def default_chat_mode(self): - return self.get('default_chat_mode', 'ask') + return self.get("default_chat_mode", "ask") @property def chat_model(self): - return self.get('chat_model', {'provider': 'github-copilot', 'model': 'gpt-4.1'}) + return self.get( + "chat_model", {"provider": "github-copilot", "model": "gpt-4.1"} + ) @property def inline_completion_model(self): - return self.get('inline_completion_model', {'provider': 'github-copilot', 'model': 'gpt-4o-copilot'}) + return self.get( + "inline_completion_model", + {"provider": "github-copilot", "model": "gpt-4o-copilot"}, + ) @property def embedding_model(self): - return self.get('embedding_model', {}) + return self.get("embedding_model", {}) @property def mcp(self): @@ -114,9 +133,11 @@ def mcp(self): @property def store_github_access_token(self): - return self.get('store_github_access_token', False) + return self.get("store_github_access_token", False) @property def using_github_copilot_service(self) -> bool: - return self.chat_model.get("provider") == 'github-copilot' or \ - self.inline_completion_model.get("provider") == 'github-copilot' + return ( + self.chat_model.get("provider") == "github-copilot" + or self.inline_completion_model.get("provider") == "github-copilot" + ) diff --git a/lab_notebook_intelligence/extension.py 
b/lab_notebook_intelligence/extension.py index 5c2f812..59750cf 100644 --- a/lab_notebook_intelligence/extension.py +++ b/lab_notebook_intelligence/extension.py @@ -1,51 +1,79 @@ # Copyright (c) Mehmet Bektas import asyncio -from dataclasses import dataclass -import json -from os import path import datetime as dt +import json +import logging import os -from typing import Union -import uuid import threading -import logging -import tiktoken +import uuid +from dataclasses import dataclass +from os import path +from typing import Union -from jupyter_server.extension.application import ExtensionApp +import tiktoken +import tornado from jupyter_server.base.handlers import APIHandler +from jupyter_server.extension.application import ExtensionApp from jupyter_server.utils import url_path_join -import tornado from tornado import websocket from traitlets import Unicode -from lab_notebook_intelligence.api import BuiltinToolset, CancelToken, ChatMode, ChatResponse, ChatRequest, ContextRequest, ContextRequestType, RequestDataType, RequestToolSelection, ResponseStreamData, ResponseStreamDataType, BackendMessageType, SignalImpl -from lab_notebook_intelligence.ai_service_manager import AIServiceManager + import lab_notebook_intelligence.github_copilot as github_copilot +from lab_notebook_intelligence.ai_service_manager import AIServiceManager +from lab_notebook_intelligence.api import (BackendMessageType, BuiltinToolset, + CancelToken, ChatMode, ChatRequest, + ChatResponse, ContextRequest, + ContextRequestType, RequestDataType, + RequestToolSelection, + ResponseStreamData, + ResponseStreamDataType, SignalImpl) from lab_notebook_intelligence.built_in_toolsets import built_in_toolsets from lab_notebook_intelligence.util import ThreadSafeWebSocketConnector ai_service_manager: AIServiceManager = None log = logging.getLogger(__name__) -tiktoken_encoding = tiktoken.encoding_for_model('gpt-4o') +tiktoken_encoding = tiktoken.encoding_for_model("gpt-4o") + class GetCapabilitiesHandler(APIHandler): - notebook_execute_tool = 'enabled' + notebook_execute_tool = "enabled" @tornado.web.authenticated def get(self): ai_service_manager.update_models_from_config() nbi_config = ai_service_manager.nbi_config llm_providers = ai_service_manager.llm_providers.values() - notebook_execute_tool_enabled = self.notebook_execute_tool == 'enabled' or (self.notebook_execute_tool == 'env_enabled' and os.getenv('NBI_NOTEBOOK_EXECUTE_TOOL', 'disabled') == 'enabled') - allowed_builtin_toolsets = [{"id": toolset.id, "name": toolset.name} for toolset in built_in_toolsets.values() if toolset.id != BuiltinToolset.NotebookExecute or notebook_execute_tool_enabled] + notebook_execute_tool_enabled = self.notebook_execute_tool == "enabled" or ( + self.notebook_execute_tool == "env_enabled" + and os.getenv("NBI_NOTEBOOK_EXECUTE_TOOL", "disabled") == "enabled" + ) + allowed_builtin_toolsets = [ + {"id": toolset.id, "name": toolset.name} + for toolset in built_in_toolsets.values() + if toolset.id != BuiltinToolset.NotebookExecute + or notebook_execute_tool_enabled + ] mcp_servers = ai_service_manager.get_mcp_servers() - mcp_server_tools = [{"id": mcp_server.name, "tools": [{"name": tool.name, "description": tool.description} for tool in mcp_server.get_tools()]} for mcp_server in mcp_servers] + mcp_server_tools = [ + { + "id": mcp_server.name, + "tools": [ + {"name": tool.name, "description": tool.description} + for tool in mcp_server.get_tools() + ], + } + for mcp_server in mcp_servers + ] mcp_server_tools = [tool for tool in mcp_server_tools if 
len(tool["tools"]) > 0] # sort by server id mcp_server_tools.sort(key=lambda server: server["id"]) extensions = [] - for extension_id, toolsets in ai_service_manager.get_extension_toolsets().items(): + for ( + extension_id, + toolsets, + ) in ai_service_manager.get_extension_toolsets().items(): ts = [] for toolset in toolsets: tools = [] @@ -53,28 +81,30 @@ def get(self): tools.append({"name": tool.name, "description": tool.description}) # sort by tool name tools.sort(key=lambda tool: tool["name"]) - ts.append({ - "id": toolset.id, - "name": toolset.name, - "description": toolset.description, - "tools": tools - }) + ts.append( + { + "id": toolset.id, + "name": toolset.name, + "description": toolset.description, + "tools": tools, + } + ) # sort by toolset name ts.sort(key=lambda toolset: toolset["name"]) extension = ai_service_manager.get_extension(extension_id) - extensions.append({ - "id": extension_id, - "name": extension.name, - "toolsets": ts - }) + extensions.append( + {"id": extension_id, "name": extension.name, "toolsets": ts} + ) # sort by extension id extensions.sort(key=lambda extension: extension["id"]) response = { - "user_home_dir": os.path.expanduser('~'), + "user_home_dir": os.path.expanduser("~"), "nbi_user_config_dir": nbi_config.nbi_user_dir, "using_github_copilot_service": nbi_config.using_github_copilot_service, - "llm_providers": [{"id": provider.id, "name": provider.name} for provider in llm_providers], + "llm_providers": [ + {"id": provider.id, "name": provider.name} for provider in llm_providers + ], "chat_models": ai_service_manager.chat_model_ids, "inline_completion_models": ai_service_manager.inline_completion_model_ids, "embedding_models": ai_service_manager.embedding_model_ids, @@ -86,26 +116,36 @@ def get(self): "tool_config": { "builtinToolsets": allowed_builtin_toolsets, "mcpServers": mcp_server_tools, - "extensions": extensions + "extensions": extensions, }, - "default_chat_mode": nbi_config.default_chat_mode + "default_chat_mode": nbi_config.default_chat_mode, } for participant_id in ai_service_manager.chat_participants: participant = ai_service_manager.chat_participants[participant_id] - response["chat_participants"].append({ - "id": participant.id, - "name": participant.name, - "description": participant.description, - "iconPath": participant.icon_path, - "commands": [command.name for command in participant.commands] - }) + response["chat_participants"].append( + { + "id": participant.id, + "name": participant.name, + "description": participant.description, + "iconPath": participant.icon_path, + "commands": [command.name for command in participant.commands], + } + ) self.finish(json.dumps(response)) + class ConfigHandler(APIHandler): @tornado.web.authenticated def post(self): data = json.loads(self.request.body) - valid_keys = set(["default_chat_mode", "chat_model", "inline_completion_model", "store_github_access_token"]) + valid_keys = set( + [ + "default_chat_mode", + "chat_model", + "inline_completion_model", + "store_github_access_token", + ] + ) for key in data: if key in valid_keys: ai_service_manager.nbi_config.set(key, data[key]) @@ -118,6 +158,7 @@ def post(self): ai_service_manager.update_models_from_config() self.finish(json.dumps({})) + class UpdateProviderModelsHandler(APIHandler): @tornado.web.authenticated def post(self): @@ -126,14 +167,23 @@ def post(self): ai_service_manager.ollama_llm_provider.update_chat_model_list() self.finish(json.dumps({})) + class ReloadMCPServersHandler(APIHandler): @tornado.web.authenticated def post(self): 
ai_service_manager.nbi_config.load() ai_service_manager.update_mcp_servers() - self.finish(json.dumps({ - "mcpServers": [{"id": server.name} for server in ai_service_manager.get_mcp_servers()] - })) + self.finish( + json.dumps( + { + "mcpServers": [ + {"id": server.name} + for server in ai_service_manager.get_mcp_servers() + ] + } + ) + ) + class MCPConfigFileHandler(APIHandler): @tornado.web.authenticated @@ -156,15 +206,16 @@ def post(self): self.finish(json.dumps({"status": "error", "message": str(e)})) return + class CreateDynamicMCPConfigHandler(APIHandler): @tornado.web.authenticated def post(self): try: # Get the directory where JupyterLab was started (user's working directory) user_root_dir = NotebookIntelligence.root_dir - + print(f"Creating dynamic MCP config for directory: {user_root_dir}") - + # Create dynamic MCP config with filesystem servers dynamic_mcp_config = { "mcpServers": { @@ -173,13 +224,22 @@ def post(self): "args": [ "-y", "@modelcontextprotocol/server-filesystem", - user_root_dir - ] - } + user_root_dir, + ], + }, + "qbraid-web-search": { + "command": "uv", + "args": ["tool", "run", "web-browser-mcp-server"], + "env": { + "REQUEST_TIMEOUT": "60", + }, + }, + # add the MCP server for accessing docs.qbraid.com + "qbraid-docs-search": {"url": "https://docs.qbraid.com/mcp"}, } } - - # Add qBraid environments MCP server + + # Add qBraid environments MCP server qbraid_envs_dir = os.path.expanduser("~/.qbraid/environments/") print(f"qBraid environments directory: {qbraid_envs_dir}") if os.path.exists(qbraid_envs_dir): @@ -189,11 +249,11 @@ def post(self): "args": [ "-y", "@modelcontextprotocol/server-filesystem", - qbraid_envs_dir - ] + qbraid_envs_dir, + ], } print(f"Added qBraid environments MCP server") - + # TODO: Uncomment and fix the code below to add individual Python execution servers for each qBraid environment # # Discover individual environments and add Python execution servers @@ -202,12 +262,12 @@ def post(self): # for env_name in os.listdir(qbraid_envs_dir): # env_path = os.path.join(qbraid_envs_dir, env_name) # python_executable = os.path.join(env_path, "bin", "python") - + # # Check if this is a valid environment with Python - # if (os.path.isdir(env_path) and - # os.path.exists(python_executable) and + # if (os.path.isdir(env_path) and + # os.path.exists(python_executable) and # not env_name.startswith('.')): - + # # Add Python execution server for this environment # server_name = f"python-{env_name}" # dynamic_mcp_config["mcpServers"][server_name] = { @@ -222,36 +282,44 @@ def post(self): # } # env_count += 1 # print(f"Added Python execution server for environment: {env_name}") - + # print(f"Discovered and added {env_count} qBraid environment Python servers") # except Exception as e: # print(f"Error discovering qBraid environments: {e}") - + else: print(f"qBraid environments directory not found: {qbraid_envs_dir}") - + # Save to user's MCP config (this will merge with existing config) ai_service_manager.nbi_config.user_mcp = dynamic_mcp_config ai_service_manager.nbi_config.save() ai_service_manager.nbi_config.load() ai_service_manager.update_mcp_servers() - - self.finish(json.dumps({ - "status": "ok", - "message": f"Dynamic MCP config created for directory: {user_root_dir}" - })) + + self.finish( + json.dumps( + { + "status": "ok", + "message": f"Dynamic MCP config created for directory: {user_root_dir}", + } + ) + ) except Exception as e: self.finish(json.dumps({"status": "error", "message": str(e)})) return + class 
EmitTelemetryEventHandler(APIHandler): @tornado.web.authenticated def post(self): event = json.loads(self.request.body) - thread = threading.Thread(target=asyncio.run, args=(ai_service_manager.emit_telemetry_event(event),)) + thread = threading.Thread( + target=asyncio.run, args=(ai_service_manager.emit_telemetry_event(event),) + ) thread.start() self.finish(json.dumps({})) + class GetGitHubLoginStatusHandler(APIHandler): # The following decorator should be present on all verb methods (head, get, post, # patch, put, delete, options) to ensure only authorized user can request the @@ -260,34 +328,42 @@ class GetGitHubLoginStatusHandler(APIHandler): def get(self): self.finish(json.dumps(github_copilot.get_login_status())) + class PostGitHubLoginHandler(APIHandler): @tornado.web.authenticated def post(self): device_verification_info = github_copilot.login() if device_verification_info is None: self.set_status(500) - self.finish(json.dumps({ - "error": "Failed to get device verification info from GitHub Copilot" - })) + self.finish( + json.dumps( + { + "error": "Failed to get device verification info from GitHub Copilot" + } + ) + ) return self.finish(json.dumps(device_verification_info)) + class GetGitHubLogoutHandler(APIHandler): @tornado.web.authenticated def get(self): self.finish(json.dumps(github_copilot.logout())) + class ChatHistory: """ History of chat messages, key is chat id, value is list of messages keep the last 10 messages in the same chat participant """ + MAX_MESSAGES = 10 def __init__(self): self.messages = {} - def clear(self, chatId = None): + def clear(self, chatId=None): if chatId is None: self.messages = {} return True @@ -304,21 +380,28 @@ def add_message(self, chatId, message): # clear the chat history if participant changed if message["role"] == "user": existing_messages = self.messages[chatId] - prev_user_message = next((m for m in reversed(existing_messages) if m["role"] == "user"), None) + prev_user_message = next( + (m for m in reversed(existing_messages) if m["role"] == "user"), None + ) if prev_user_message is not None: - (current_participant, command, prompt) = AIServiceManager.parse_prompt(message["content"]) - (prev_participant, command, prompt) = AIServiceManager.parse_prompt(prev_user_message["content"]) + (current_participant, command, prompt) = AIServiceManager.parse_prompt( + message["content"] + ) + (prev_participant, command, prompt) = AIServiceManager.parse_prompt( + prev_user_message["content"] + ) if current_participant != prev_participant: self.messages[chatId] = [] self.messages[chatId].append(message) # limit number of messages kept in history if len(self.messages[chatId]) > ChatHistory.MAX_MESSAGES: - self.messages[chatId] = self.messages[chatId][-ChatHistory.MAX_MESSAGES:] + self.messages[chatId] = self.messages[chatId][-ChatHistory.MAX_MESSAGES :] def get_history(self, chatId): return self.messages.get(chatId, []) + class WebsocketCopilotResponseEmitter(ChatResponse): def __init__(self, chatId, messageId, websocket_handler, chat_history): super().__init__() @@ -337,10 +420,14 @@ def message_id(self) -> str: return self.messageId def stream(self, data: Union[ResponseStreamData, dict]): - data_type = ResponseStreamDataType.LLMRaw if type(data) is dict else data.data_type + data_type = ( + ResponseStreamDataType.LLMRaw if type(data) is dict else data.data_type + ) if data_type == ResponseStreamDataType.Markdown: - self.chat_history.add_message(self.chatId, {"role": "assistant", "content": data.content}) + self.chat_history.add_message( + 
self.chatId, {"role": "assistant", "content": data.content} + ) data = { "choices": [ { @@ -348,10 +435,10 @@ def stream(self, data: Union[ResponseStreamData, dict]): "nbiContent": { "type": data_type, "content": data.content, - "detail": data.detail + "detail": data.detail, }, "content": "", - "role": "assistant" + "role": "assistant", } } ] @@ -361,12 +448,9 @@ def stream(self, data: Union[ResponseStreamData, dict]): "choices": [ { "delta": { - "nbiContent": { - "type": data_type, - "content": data.content - }, + "nbiContent": {"type": data_type, "content": data.content}, "content": "", - "role": "assistant" + "role": "assistant", } } ] @@ -378,13 +462,13 @@ def stream(self, data: Union[ResponseStreamData, dict]): "delta": { "nbiContent": { "type": data_type, - "content" : { + "content": { "source": data.source, - "height": data.height - } + "height": data.height, + }, }, "content": "", - "role": "assistant" + "role": "assistant", } } ] @@ -396,13 +480,10 @@ def stream(self, data: Union[ResponseStreamData, dict]): "delta": { "nbiContent": { "type": data_type, - "content": { - "uri": data.uri, - "title": data.title - } + "content": {"uri": data.uri, "title": data.title}, }, "content": "", - "role": "assistant" + "role": "assistant", } } ] @@ -417,11 +498,11 @@ def stream(self, data: Union[ResponseStreamData, dict]): "content": { "title": data.title, "commandId": data.commandId, - "args": data.args if data.args is not None else {} - } + "args": data.args if data.args is not None else {}, + }, }, "content": "", - "role": "assistant" + "role": "assistant", } } ] @@ -431,12 +512,9 @@ def stream(self, data: Union[ResponseStreamData, dict]): "choices": [ { "delta": { - "nbiContent": { - "type": data_type, - "content": data.title - }, + "nbiContent": {"type": data_type, "content": data.title}, "content": "", - "role": "assistant" + "role": "assistant", } } ] @@ -451,14 +529,30 @@ def stream(self, data: Union[ResponseStreamData, dict]): "content": { "title": data.title, "message": data.message, - "confirmArgs": data.confirmArgs if data.confirmArgs is not None else {}, - "cancelArgs": data.cancelArgs if data.cancelArgs is not None else {}, - "confirmLabel": data.confirmLabel if data.confirmLabel is not None else "Approve", - "cancelLabel": data.cancelLabel if data.cancelLabel is not None else "Cancel" - } + "confirmArgs": ( + data.confirmArgs + if data.confirmArgs is not None + else {} + ), + "cancelArgs": ( + data.cancelArgs + if data.cancelArgs is not None + else {} + ), + "confirmLabel": ( + data.confirmLabel + if data.confirmLabel is not None + else "Approve" + ), + "cancelLabel": ( + data.cancelLabel + if data.cancelLabel is not None + else "Cancel" + ), + }, }, "content": "", - "role": "assistant" + "role": "assistant", } } ] @@ -469,12 +563,9 @@ def stream(self, data: Union[ResponseStreamData, dict]): "choices": [ { "delta": { - "nbiContent": { - "type": data_type, - "content": data.content - }, + "nbiContent": {"type": data_type, "content": data.content}, "content": "", - "role": "assistant" + "role": "assistant", } } ] @@ -482,45 +573,57 @@ def stream(self, data: Union[ResponseStreamData, dict]): part = content if part is not None: self.streamed_contents.append(part) - else: # ResponseStreamDataType.LLMRaw + else: # ResponseStreamDataType.LLMRaw if len(data.get("choices", [])) > 0: part = data["choices"][0].get("delta", {}).get("content", "") if part is not None: self.streamed_contents.append(part) - self.websocket_handler.write_message({ - "id": self.messageId, - "participant": 
self.participant_id, - "type": BackendMessageType.StreamMessage, - "data": data, - "created": dt.datetime.now().isoformat() - }) + self.websocket_handler.write_message( + { + "id": self.messageId, + "participant": self.participant_id, + "type": BackendMessageType.StreamMessage, + "data": data, + "created": dt.datetime.now().isoformat(), + } + ) def finish(self) -> None: - self.chat_history.add_message(self.chatId, {"role": "assistant", "content": "".join(self.streamed_contents)}) + self.chat_history.add_message( + self.chatId, + {"role": "assistant", "content": "".join(self.streamed_contents)}, + ) self.streamed_contents = [] - self.websocket_handler.write_message({ - "id": self.messageId, - "participant": self.participant_id, - "type": BackendMessageType.StreamEnd, - "data": {} - }) + self.websocket_handler.write_message( + { + "id": self.messageId, + "participant": self.participant_id, + "type": BackendMessageType.StreamEnd, + "data": {}, + } + ) async def run_ui_command(self, command: str, args: dict = {}) -> None: callback_id = str(uuid.uuid4()) - self.websocket_handler.write_message({ - "id": self.messageId, - "participant": self.participant_id, - "type": BackendMessageType.RunUICommand, - "data": { - "callback_id": callback_id, - "commandId": command, - "args": args + self.websocket_handler.write_message( + { + "id": self.messageId, + "participant": self.participant_id, + "type": BackendMessageType.RunUICommand, + "data": { + "callback_id": callback_id, + "commandId": command, + "args": args, + }, } - }) - response = await ChatResponse.wait_for_run_ui_command_response(self, callback_id) + ) + response = await ChatResponse.wait_for_run_ui_command_response( + self, callback_id + ) return response + class CancelTokenImpl(CancelToken): def __init__(self): super().__init__() @@ -530,11 +633,13 @@ def cancel_request(self) -> None: self._cancellation_requested = True self._cancellation_signal.emit() + @dataclass class MessageCallbackHandlers: response_emitter: WebsocketCopilotResponseEmitter cancel_token: CancelTokenImpl + class WebsocketCopilotHandler(websocket.WebSocketHandler): def __init__(self, application, request, **kwargs): super().__init__(application, request, **kwargs) @@ -549,27 +654,35 @@ def open(self): def on_message(self, message): msg = json.loads(message) - messageId = msg['id'] - messageType = msg['type'] + messageId = msg["id"] + messageType = msg["type"] if messageType == RequestDataType.ChatRequest: - data = msg['data'] - chatId = data['chatId'] - prompt = data['prompt'] - language = data['language'] - filename = data['filename'] - additionalContext = data.get('additionalContext', []) - chat_mode = ChatMode('agent', 'Agent') if data.get('chatMode', 'ask') == 'agent' else ChatMode('ask', 'Ask') - toolSelections = data.get('toolSelections', {}) + data = msg["data"] + chatId = data["chatId"] + prompt = data["prompt"] + language = data["language"] + filename = data["filename"] + additionalContext = data.get("additionalContext", []) + chat_mode = ( + ChatMode("agent", "Agent") + if data.get("chatMode", "ask") == "agent" + else ChatMode("ask", "Ask") + ) + toolSelections = data.get("toolSelections", {}) tool_selection = RequestToolSelection( - built_in_toolsets=toolSelections.get('builtinToolsets', []), - mcp_server_tools=toolSelections.get('mcpServers', {}), - extension_tools=toolSelections.get('extensions', {}) + built_in_toolsets=toolSelections.get("builtinToolsets", []), + mcp_server_tools=toolSelections.get("mcpServers", {}), + 
extension_tools=toolSelections.get("extensions", {}), ) request_chat_history = self.chat_history.get_history(chatId).copy() - token_limit = 100 if ai_service_manager.chat_model is None else ai_service_manager.chat_model.context_window - token_budget = 0.8 * token_limit + token_limit = ( + 100 + if ai_service_manager.chat_model is None + else ai_service_manager.chat_model.context_window + ) + token_budget = 0.8 * token_limit for context in additionalContext: file_path = context["filePath"] @@ -578,95 +691,213 @@ def on_message(self, message): start_line = context["startLine"] end_line = context["endLine"] current_cell_contents = context["currentCellContents"] - current_cell_input = current_cell_contents["input"] if current_cell_contents is not None else "" - current_cell_output = current_cell_contents["output"] if current_cell_contents is not None else "" - current_cell_context = f"This is a Jupyter notebook and currently selected cell input is: ```{current_cell_input}``` and currently selected cell output is: ```{current_cell_output}```. If user asks a question about 'this' cell then assume that user is referring to currently selected cell." if current_cell_contents is not None else "" + current_cell_input = ( + current_cell_contents["input"] + if current_cell_contents is not None + else "" + ) + current_cell_output = ( + current_cell_contents["output"] + if current_cell_contents is not None + else "" + ) + current_cell_context = ( + f"This is a Jupyter notebook and currently selected cell input is: ```{current_cell_input}``` and currently selected cell output is: ```{current_cell_output}```. If user asks a question about 'this' cell then assume that user is referring to currently selected cell." + if current_cell_contents is not None + else "" + ) context_content = context["content"] token_count = len(tiktoken_encoding.encode(context_content)) if token_count > token_budget: - context_content = context_content[:int(token_budget)] + "..." + context_content = context_content[: int(token_budget)] + "..." - request_chat_history.append({"role": "user", "content": f"Use this as additional context: ```{context_content}```. It is from current file: '{filename}' at path '{file_path}', lines: {start_line} - {end_line}. {current_cell_context}"}) - self.chat_history.add_message(chatId, {"role": "user", "content": f"This file was provided as additional context: '{filename}' at path '{file_path}', lines: {start_line} - {end_line}. {current_cell_context}"}) + request_chat_history.append( + { + "role": "user", + "content": f"Use this as additional context: ```{context_content}```. It is from current file: '{filename}' at path '{file_path}', lines: {start_line} - {end_line}. {current_cell_context}", + } + ) + self.chat_history.add_message( + chatId, + { + "role": "user", + "content": f"This file was provided as additional context: '{filename}' at path '{file_path}', lines: {start_line} - {end_line}. 
{current_cell_context}", + }, + ) self.chat_history.add_message(chatId, {"role": "user", "content": prompt}) request_chat_history.append({"role": "user", "content": prompt}) - response_emitter = WebsocketCopilotResponseEmitter(chatId, messageId, self, self.chat_history) + response_emitter = WebsocketCopilotResponseEmitter( + chatId, messageId, self, self.chat_history + ) cancel_token = CancelTokenImpl() - self._messageCallbackHandlers[messageId] = MessageCallbackHandlers(response_emitter, cancel_token) - thread = threading.Thread(target=asyncio.run, args=(ai_service_manager.handle_chat_request(ChatRequest(chat_mode=chat_mode, tool_selection=tool_selection, prompt=prompt, chat_history=request_chat_history, cancel_token=cancel_token), response_emitter),)) + self._messageCallbackHandlers[messageId] = MessageCallbackHandlers( + response_emitter, cancel_token + ) + thread = threading.Thread( + target=asyncio.run, + args=( + ai_service_manager.handle_chat_request( + ChatRequest( + chat_mode=chat_mode, + tool_selection=tool_selection, + prompt=prompt, + chat_history=request_chat_history, + cancel_token=cancel_token, + ), + response_emitter, + ), + ), + ) thread.start() elif messageType == RequestDataType.GenerateCode: - data = msg['data'] - chatId = data['chatId'] - prompt = data['prompt'] - prefix = data['prefix'] - suffix = data['suffix'] - existing_code = data['existingCode'] - language = data['language'] - filename = data['filename'] - chat_mode = ChatMode('ask', 'Ask') - if prefix != '': - self.chat_history.add_message(chatId, {"role": "user", "content": f"This code section comes before the code section you will generate, use as context. Leading content: ```{prefix}```"}) - if suffix != '': - self.chat_history.add_message(chatId, {"role": "user", "content": f"This code section comes after the code section you will generate, use as context. Trailing content: ```{suffix}```"}) - if existing_code != '': - self.chat_history.add_message(chatId, {"role": "user", "content": f"You are asked to modify the existing code. Generate a replacement for this existing code : ```{existing_code}```"}) - self.chat_history.add_message(chatId, {"role": "user", "content": f"Generate code for: {prompt}"}) - response_emitter = WebsocketCopilotResponseEmitter(chatId, messageId, self, self.chat_history) + data = msg["data"] + chatId = data["chatId"] + prompt = data["prompt"] + prefix = data["prefix"] + suffix = data["suffix"] + existing_code = data["existingCode"] + language = data["language"] + filename = data["filename"] + chat_mode = ChatMode("ask", "Ask") + if prefix != "": + self.chat_history.add_message( + chatId, + { + "role": "user", + "content": f"This code section comes before the code section you will generate, use as context. Leading content: ```{prefix}```", + }, + ) + if suffix != "": + self.chat_history.add_message( + chatId, + { + "role": "user", + "content": f"This code section comes after the code section you will generate, use as context. Trailing content: ```{suffix}```", + }, + ) + if existing_code != "": + self.chat_history.add_message( + chatId, + { + "role": "user", + "content": f"You are asked to modify the existing code. 
Generate a replacement for this existing code : ```{existing_code}```", + }, + ) + self.chat_history.add_message( + chatId, {"role": "user", "content": f"Generate code for: {prompt}"} + ) + response_emitter = WebsocketCopilotResponseEmitter( + chatId, messageId, self, self.chat_history + ) cancel_token = CancelTokenImpl() - self._messageCallbackHandlers[messageId] = MessageCallbackHandlers(response_emitter, cancel_token) - existing_code_message = " Update the existing code section and return a modified version. Don't just return the update, recreate the existing code section with the update." if existing_code != '' else '' - thread = threading.Thread(target=asyncio.run, args=(ai_service_manager.handle_chat_request(ChatRequest(chat_mode=chat_mode, prompt=prompt, chat_history=self.chat_history.get_history(chatId), cancel_token=cancel_token), response_emitter, options={"system_prompt": f"You are an assistant that generates code for '{language}' language. You generate code between existing leading and trailing code sections.{existing_code_message} Be concise and return only code as a response. Don't include leading content or trailing content in your response, they are provided only for context. You can reuse methods and symbols defined in leading and trailing content."}),)) + self._messageCallbackHandlers[messageId] = MessageCallbackHandlers( + response_emitter, cancel_token + ) + existing_code_message = ( + " Update the existing code section and return a modified version. Don't just return the update, recreate the existing code section with the update." + if existing_code != "" + else "" + ) + thread = threading.Thread( + target=asyncio.run, + args=( + ai_service_manager.handle_chat_request( + ChatRequest( + chat_mode=chat_mode, + prompt=prompt, + chat_history=self.chat_history.get_history(chatId), + cancel_token=cancel_token, + ), + response_emitter, + options={ + "system_prompt": f"You are an assistant that generates code for '{language}' language. You generate code between existing leading and trailing code sections.{existing_code_message} Be concise and return only code as a response. Don't include leading content or trailing content in your response, they are provided only for context. You can reuse methods and symbols defined in leading and trailing content." 
+ }, + ), + ), + ) thread.start() elif messageType == RequestDataType.InlineCompletionRequest: - data = msg['data'] - chatId = data['chatId'] - prefix = data['prefix'] - suffix = data['suffix'] - language = data['language'] - filename = data['filename'] + data = msg["data"] + chatId = data["chatId"] + prefix = data["prefix"] + suffix = data["suffix"] + language = data["language"] + filename = data["filename"] chat_history = ChatHistory() - response_emitter = WebsocketCopilotResponseEmitter(chatId, messageId, self, chat_history) + response_emitter = WebsocketCopilotResponseEmitter( + chatId, messageId, self, chat_history + ) cancel_token = CancelTokenImpl() - self._messageCallbackHandlers[messageId] = MessageCallbackHandlers(response_emitter, cancel_token) + self._messageCallbackHandlers[messageId] = MessageCallbackHandlers( + response_emitter, cancel_token + ) - thread = threading.Thread(target=asyncio.run, args=(WebsocketCopilotHandler.handle_inline_completions(prefix, suffix, language, filename, response_emitter, cancel_token),)) + thread = threading.Thread( + target=asyncio.run, + args=( + WebsocketCopilotHandler.handle_inline_completions( + prefix, + suffix, + language, + filename, + response_emitter, + cancel_token, + ), + ), + ) thread.start() elif messageType == RequestDataType.ChatUserInput: handlers = self._messageCallbackHandlers.get(messageId) if handlers is None: return - handlers.response_emitter.on_user_input(msg['data']) + handlers.response_emitter.on_user_input(msg["data"]) elif messageType == RequestDataType.ClearChatHistory: self.chat_history.clear() elif messageType == RequestDataType.RunUICommandResponse: handlers = self._messageCallbackHandlers.get(messageId) if handlers is None: return - handlers.response_emitter.on_run_ui_command_response(msg['data']) - elif messageType == RequestDataType.CancelChatRequest or messageType == RequestDataType.CancelInlineCompletionRequest: + handlers.response_emitter.on_run_ui_command_response(msg["data"]) + elif ( + messageType == RequestDataType.CancelChatRequest + or messageType == RequestDataType.CancelInlineCompletionRequest + ): handlers = self._messageCallbackHandlers.get(messageId) if handlers is None: return handlers.cancel_token.cancel_request() - + def on_close(self): pass - async def handle_inline_completions(prefix, suffix, language, filename, response_emitter, cancel_token): + async def handle_inline_completions( + prefix, suffix, language, filename, response_emitter, cancel_token + ): if ai_service_manager.inline_completion_model is None: response_emitter.finish() return - context = await ai_service_manager.get_completion_context(ContextRequest(ContextRequestType.InlineCompletion, prefix, suffix, language, filename, participant=ai_service_manager.get_chat_participant(prefix), cancel_token=cancel_token)) + context = await ai_service_manager.get_completion_context( + ContextRequest( + ContextRequestType.InlineCompletion, + prefix, + suffix, + language, + filename, + participant=ai_service_manager.get_chat_participant(prefix), + cancel_token=cancel_token, + ) + ) if cancel_token.is_cancel_requested: response_emitter.finish() return - completions = ai_service_manager.inline_completion_model.inline_completions(prefix, suffix, language, filename, context, cancel_token) + completions = ai_service_manager.inline_completion_model.inline_completions( + prefix, suffix, language, filename, context, cancel_token + ) if cancel_token.is_cancel_requested: response_emitter.finish() return @@ -674,6 +905,7 @@ async def 
handle_inline_completions(prefix, suffix, language, filename, response response_emitter.stream({"completions": completions}) response_emitter.finish() + class NotebookIntelligence(ExtensionApp): name = "lab_notebook_intelligence" default_url = "/lab-notebook-intelligence" @@ -684,7 +916,7 @@ class NotebookIntelligence(ExtensionApp): template_paths = [] settings = {} handlers = [] - root_dir = '' + root_dir = "" notebook_execute_tool = Unicode( default_value="enabled", @@ -704,11 +936,13 @@ def initialize_settings(self): def initialize_handlers(self): NotebookIntelligence.root_dir = self.serverapp.root_dir - server_root_dir = os.path.expanduser(self.serverapp.web_app.settings["server_root_dir"]) + server_root_dir = os.path.expanduser( + self.serverapp.web_app.settings["server_root_dir"] + ) self.initialize_ai_service(server_root_dir) self._setup_handlers(self.serverapp.web_app) self.serverapp.log.info(f"Registered {self.name} server extension") - + def initialize_ai_service(self, server_root_dir: str): global ai_service_manager ai_service_manager = AIServiceManager({"server_root_dir": server_root_dir}) @@ -724,17 +958,39 @@ def _setup_handlers(self, web_app): host_pattern = ".*$" base_url = web_app.settings["base_url"] - route_pattern_capabilities = url_path_join(base_url, "lab-notebook-intelligence", "capabilities") - route_pattern_config = url_path_join(base_url, "lab-notebook-intelligence", "config") - route_pattern_update_provider_models = url_path_join(base_url, "lab-notebook-intelligence", "update-provider-models") - route_pattern_reload_mcp_servers = url_path_join(base_url, "lab-notebook-intelligence", "reload-mcp-servers") - route_pattern_mcp_config_file = url_path_join(base_url, "lab-notebook-intelligence", "mcp-config-file") - route_pattern_create_dynamic_mcp_config = url_path_join(base_url, "lab-notebook-intelligence", "create-dynamic-mcp-config") - route_pattern_emit_telemetry_event = url_path_join(base_url, "lab-notebook-intelligence", "emit-telemetry-event") - route_pattern_github_login_status = url_path_join(base_url, "lab-notebook-intelligence", "gh-login-status") - route_pattern_github_login = url_path_join(base_url, "lab-notebook-intelligence", "gh-login") - route_pattern_github_logout = url_path_join(base_url, "lab-notebook-intelligence", "gh-logout") - route_pattern_copilot = url_path_join(base_url, "lab-notebook-intelligence", "copilot") + route_pattern_capabilities = url_path_join( + base_url, "lab-notebook-intelligence", "capabilities" + ) + route_pattern_config = url_path_join( + base_url, "lab-notebook-intelligence", "config" + ) + route_pattern_update_provider_models = url_path_join( + base_url, "lab-notebook-intelligence", "update-provider-models" + ) + route_pattern_reload_mcp_servers = url_path_join( + base_url, "lab-notebook-intelligence", "reload-mcp-servers" + ) + route_pattern_mcp_config_file = url_path_join( + base_url, "lab-notebook-intelligence", "mcp-config-file" + ) + route_pattern_create_dynamic_mcp_config = url_path_join( + base_url, "lab-notebook-intelligence", "create-dynamic-mcp-config" + ) + route_pattern_emit_telemetry_event = url_path_join( + base_url, "lab-notebook-intelligence", "emit-telemetry-event" + ) + route_pattern_github_login_status = url_path_join( + base_url, "lab-notebook-intelligence", "gh-login-status" + ) + route_pattern_github_login = url_path_join( + base_url, "lab-notebook-intelligence", "gh-login" + ) + route_pattern_github_logout = url_path_join( + base_url, "lab-notebook-intelligence", "gh-logout" + ) + 
route_pattern_copilot = url_path_join( + base_url, "lab-notebook-intelligence", "copilot" + ) GetCapabilitiesHandler.notebook_execute_tool = self.notebook_execute_tool NotebookIntelligence.handlers = [ (route_pattern_capabilities, GetCapabilitiesHandler), diff --git a/lab_notebook_intelligence/github_copilot.py b/lab_notebook_intelligence/github_copilot.py index 5b3368c..1490b5b 100644 --- a/lab_notebook_intelligence/github_copilot.py +++ b/lab_notebook_intelligence/github_copilot.py @@ -3,28 +3,44 @@ # GitHub auth and inline completion sections are derivative of https://github.com/B00TK1D/copilot-api import base64 +import datetime as dt +import json +import logging +import os +import secrets +import threading +import time +import uuid from enum import Enum -import os, json, time, requests, threading from typing import Any -import uuid -import secrets + +import requests import sseclient -import datetime as dt -import logging -from lab_notebook_intelligence.api import BackendMessageType, CancelToken, ChatResponse, CompletionContext, MarkdownData -from lab_notebook_intelligence.util import decrypt_with_password, encrypt_with_password, ThreadSafeWebSocketConnector + +from lab_notebook_intelligence.api import (BackendMessageType, CancelToken, + ChatResponse, CompletionContext, + MarkdownData) +from lab_notebook_intelligence.util import (ThreadSafeWebSocketConnector, + decrypt_with_password, + encrypt_with_password) from ._version import __version__ as NBI_VERSION log = logging.getLogger(__name__) GHE_SUBDOMAIN = os.getenv("NBI_GHE_SUBDOMAIN", "") -GH_WEB_BASE_URL = "https://github.com" if GHE_SUBDOMAIN == "" else f"https://{GHE_SUBDOMAIN}.ghe.com" -GH_REST_API_BASE_URL = "https://api.github.com" if GHE_SUBDOMAIN == "" else f"https://api.{GHE_SUBDOMAIN}.ghe.com" - -EDITOR_VERSION = f"NotebookIntelligence/{NBI_VERSION}" -EDITOR_PLUGIN_VERSION = f"NotebookIntelligence/{NBI_VERSION}" -USER_AGENT = f"NotebookIntelligence/{NBI_VERSION}" +GH_WEB_BASE_URL = ( + "https://github.com" if GHE_SUBDOMAIN == "" else f"https://{GHE_SUBDOMAIN}.ghe.com" +) +GH_REST_API_BASE_URL = ( + "https://api.github.com" + if GHE_SUBDOMAIN == "" + else f"https://api.{GHE_SUBDOMAIN}.ghe.com" +) + +EDITOR_VERSION = f"LabNotebookIntelligence/{NBI_VERSION}" +EDITOR_PLUGIN_VERSION = f"LabNotebookIntelligence/{NBI_VERSION}" +USER_AGENT = f"LabNotebookIntelligence/{NBI_VERSION}" CLIENT_ID = "Iv1.b507a08c87ecfe98" MACHINE_ID = secrets.token_hex(33)[0:65] @@ -34,18 +50,20 @@ ACCESS_TOKEN_THREAD_SLEEP_INTERVAL = 5 TOKEN_THREAD_SLEEP_INTERVAL = 3 TOKEN_FETCH_INTERVAL = 15 -NL = '\n' +NL = "\n" -LoginStatus = Enum('LoginStatus', ['NOT_LOGGED_IN', 'ACTIVATING_DEVICE', 'LOGGING_IN', 'LOGGED_IN']) +LoginStatus = Enum( + "LoginStatus", ["NOT_LOGGED_IN", "ACTIVATING_DEVICE", "LOGGING_IN", "LOGGED_IN"] +) github_auth = { "verification_uri": None, "user_code": None, "device_code": None, "access_token": None, - "status" : LoginStatus.NOT_LOGGED_IN, + "status": LoginStatus.NOT_LOGGED_IN, "token": None, - "token_expires_at": dt.datetime.now() + "token_expires_at": dt.datetime.now(), } stop_requested = False @@ -58,74 +76,88 @@ websocket_connector: ThreadSafeWebSocketConnector = None github_login_status_change_updater_enabled = False + def enable_github_login_status_change_updater(enabled: bool): global github_login_status_change_updater_enabled github_login_status_change_updater_enabled = enabled + def emit_github_login_status_change(): if github_login_status_change_updater_enabled and websocket_connector is not None: - 
websocket_connector.write_message({ - "type": BackendMessageType.GitHubCopilotLoginStatusChange, - "data": { - "status": github_auth["status"].name + websocket_connector.write_message( + { + "type": BackendMessageType.GitHubCopilotLoginStatusChange, + "data": {"status": github_auth["status"].name}, } - }) + ) + def get_login_status(): global github_auth - response = { - "status": github_auth["status"].name - } + response = {"status": github_auth["status"].name} if github_auth["status"] is LoginStatus.ACTIVATING_DEVICE: - response.update({ - "verification_uri": github_auth["verification_uri"], - "user_code": github_auth["user_code"] - }) + response.update( + { + "verification_uri": github_auth["verification_uri"], + "user_code": github_auth["user_code"], + } + ) return response -deprecated_user_data_file = os.path.join(os.path.expanduser('~'), ".jupyter", "nbi-data.json") -user_data_file = os.path.join(os.path.expanduser('~'), ".jupyter", "nbi", "user-data.json") -access_token_password = os.getenv("NBI_GH_ACCESS_TOKEN_PASSWORD", "nbi-access-token-password") + +deprecated_user_data_file = os.path.join( + os.path.expanduser("~"), ".jupyter", "nbi-data.json" +) +user_data_file = os.path.join( + os.path.expanduser("~"), ".jupyter", "nbi", "user-data.json" +) +access_token_password = os.getenv( + "NBI_GH_ACCESS_TOKEN_PASSWORD", "nbi-access-token-password" +) + def read_stored_github_access_token() -> str: try: if os.path.exists(user_data_file): - with open(user_data_file, 'r') as file: + with open(user_data_file, "r") as file: user_data = json.load(file) elif os.path.exists(deprecated_user_data_file): - with open(deprecated_user_data_file, 'r') as file: + with open(deprecated_user_data_file, "r") as file: user_data = json.load(file) else: user_data = {} - base64_access_token = user_data.get('github_access_token') + base64_access_token = user_data.get("github_access_token") if base64_access_token is not None: - base64_bytes = base64.b64decode(base64_access_token.encode('utf-8')) - return decrypt_with_password(access_token_password, base64_bytes).decode('utf-8') + base64_bytes = base64.b64decode(base64_access_token.encode("utf-8")) + return decrypt_with_password(access_token_password, base64_bytes).decode( + "utf-8" + ) except Exception as e: log.error(f"Failed to read GitHub access token: {e}") return None + def write_github_access_token(access_token: str) -> bool: try: - encrypted_access_token = encrypt_with_password(access_token_password, access_token.encode()) + encrypted_access_token = encrypt_with_password( + access_token_password, access_token.encode() + ) base64_bytes = base64.b64encode(encrypted_access_token) - base64_access_token = base64_bytes.decode('utf-8') + base64_access_token = base64_bytes.decode("utf-8") if os.path.exists(user_data_file): - with open(user_data_file, 'r') as file: + with open(user_data_file, "r") as file: user_data = json.load(file) else: user_data = {} - user_data.update({ - 'github_access_token': base64_access_token - }) - with open(user_data_file, 'w') as file: + user_data.update({"github_access_token": base64_access_token}) + with open(user_data_file, "w") as file: json.dump(user_data, file, indent=4) return True except Exception as e: @@ -133,20 +165,21 @@ def write_github_access_token(access_token: str) -> bool: return False + def delete_stored_github_access_token() -> bool: try: if os.path.exists(user_data_file): - with open(user_data_file, 'r') as file: + with open(user_data_file, "r") as file: user_data = json.load(file) else: user_data = {} try: - del 
user_data['github_access_token'] + del user_data["github_access_token"] except: pass - with open(user_data_file, 'w') as file: + with open(user_data_file, "w") as file: json.dump(user_data, file, indent=4) return True except Exception as e: @@ -154,6 +187,7 @@ def delete_stored_github_access_token() -> bool: return False + def login_with_existing_credentials(store_access_token: bool): global github_access_token_provided, remember_github_access_token @@ -170,66 +204,71 @@ def login_with_existing_credentials(store_access_token: bool): login() if os.path.exists(deprecated_user_data_file): # TODO: remove after 12/2025 - log.warning(f"Deprecated user data file found: {deprecated_user_data_file}. Removing it now. Use {user_data_file} instead.") + log.warning( + f"Deprecated user data file found: {deprecated_user_data_file}. Removing it now. Use {user_data_file} instead." + ) store_github_access_token() os.remove(deprecated_user_data_file) + def store_github_access_token(): access_token = github_auth["access_token"] if access_token is not None: if not write_github_access_token(access_token): log.error("Failed to store GitHub access token") + def login(): login_info = get_device_verification_info() if login_info is not None: wait_for_tokens() return login_info + def logout(): global github_auth, github_access_token_provided github_access_token_provided = None - github_auth.update({ - "verification_uri": None, - "user_code": None, - "device_code": None, - "access_token": None, - "status" : LoginStatus.NOT_LOGGED_IN, - "token": None - }) + github_auth.update( + { + "verification_uri": None, + "user_code": None, + "device_code": None, + "access_token": None, + "status": LoginStatus.NOT_LOGGED_IN, + "token": None, + } + ) emit_github_login_status_change() - return { - "status": github_auth["status"].name - } + return {"status": github_auth["status"].name} + def handle_stop_request(): global stop_requested stop_requested = True + def get_device_verification_info(): global github_auth - data = { - "client_id": CLIENT_ID, - "scope": "read:user" - } + data = {"client_id": CLIENT_ID, "scope": "read:user"} try: - resp = requests.post(f'{GH_WEB_BASE_URL}/login/device/code', + resp = requests.post( + f"{GH_WEB_BASE_URL}/login/device/code", headers={ - 'accept': 'application/json', - 'editor-version': EDITOR_VERSION, - 'editor-plugin-version': EDITOR_PLUGIN_VERSION, - 'content-type': 'application/json', - 'user-agent': USER_AGENT, - 'accept-encoding': 'gzip,deflate,br' + "accept": "application/json", + "editor-version": EDITOR_VERSION, + "editor-plugin-version": EDITOR_PLUGIN_VERSION, + "content-type": "application/json", + "user-agent": USER_AGENT, + "accept-encoding": "gzip,deflate,br", }, - data=json.dumps(data) + data=json.dumps(data), ) resp_json = resp.json() - github_auth["verification_uri"] = resp_json.get('verification_uri') - github_auth["user_code"] = resp_json.get('user_code') - github_auth["device_code"] = resp_json.get('device_code') + github_auth["verification_uri"] = resp_json.get("verification_uri") + github_auth["user_code"] = resp_json.get("user_code") + github_auth["device_code"] = resp_json.get("device_code") github_auth["status"] = LoginStatus.ACTIVATING_DEVICE emit_github_login_status_change() @@ -240,9 +279,10 @@ def get_device_verification_info(): # user needs to visit the verification_uri and enter the user_code return { "verification_uri": github_auth["verification_uri"], - "user_code": github_auth["user_code"] + "user_code": github_auth["user_code"], } + def 
wait_for_user_access_token_thread_func(): global github_auth, get_access_code_thread @@ -254,29 +294,35 @@ def wait_for_user_access_token_thread_func(): while True: # terminate thread if logged out or stop requested - if stop_requested or github_auth["access_token"] is not None or github_auth["device_code"] is None or github_auth["status"] == LoginStatus.NOT_LOGGED_IN: + if ( + stop_requested + or github_auth["access_token"] is not None + or github_auth["device_code"] is None + or github_auth["status"] == LoginStatus.NOT_LOGGED_IN + ): get_access_code_thread = None break data = { "client_id": CLIENT_ID, "device_code": github_auth["device_code"], - "grant_type": "urn:ietf:params:oauth:grant-type:device_code" + "grant_type": "urn:ietf:params:oauth:grant-type:device_code", } try: - resp = requests.post(f'{GH_WEB_BASE_URL}/login/oauth/access_token', + resp = requests.post( + f"{GH_WEB_BASE_URL}/login/oauth/access_token", headers={ - 'accept': 'application/json', - 'editor-version': EDITOR_VERSION, - 'editor-plugin-version': EDITOR_PLUGIN_VERSION, - 'content-type': 'application/json', - 'user-agent': USER_AGENT, - 'accept-encoding': 'gzip,deflate,br' + "accept": "application/json", + "editor-version": EDITOR_VERSION, + "editor-plugin-version": EDITOR_PLUGIN_VERSION, + "content-type": "application/json", + "user-agent": USER_AGENT, + "accept-encoding": "gzip,deflate,br", }, - data=json.dumps(data) + data=json.dumps(data), ) resp_json = resp.json() - access_token = resp_json.get('access_token') + access_token = resp_json.get("access_token") if access_token: github_auth["access_token"] = access_token @@ -290,6 +336,7 @@ def wait_for_user_access_token_thread_func(): time.sleep(ACCESS_TOKEN_THREAD_SLEEP_INTERVAL) + def get_token(): global github_auth, github_access_token_provided, API_ENDPOINT, PROXY_ENDPOINT, TOKEN_REFRESH_INTERVAL access_token = github_auth["access_token"] @@ -301,12 +348,15 @@ def get_token(): emit_github_login_status_change() try: - resp = requests.get(f'{GH_REST_API_BASE_URL}/copilot_internal/v2/token', headers={ - 'authorization': f'token {access_token}', - 'editor-version': EDITOR_VERSION, - 'editor-plugin-version': EDITOR_PLUGIN_VERSION, - 'user-agent': USER_AGENT - }) + resp = requests.get( + f"{GH_REST_API_BASE_URL}/copilot_internal/v2/token", + headers={ + "authorization": f"token {access_token}", + "editor-version": EDITOR_VERSION, + "editor-plugin-version": EDITOR_PLUGIN_VERSION, + "user-agent": USER_AGENT, + }, + ) resp_json = resp.json() @@ -315,30 +365,33 @@ def get_token(): logout() wait_for_tokens() return - + if resp.status_code != 200: log.error(f"Failed to get token from GitHub Copilot: {resp_json}") return - token = resp_json.get('token') + token = resp_json.get("token") github_auth["token"] = token - expires_at = resp_json.get('expires_at') + expires_at = resp_json.get("expires_at") if expires_at is not None: github_auth["token_expires_at"] = dt.datetime.fromtimestamp(expires_at) else: - github_auth["token_expires_at"] = dt.datetime.now() + dt.timedelta(seconds=TOKEN_REFRESH_INTERVAL) + github_auth["token_expires_at"] = dt.datetime.now() + dt.timedelta( + seconds=TOKEN_REFRESH_INTERVAL + ) github_auth["verification_uri"] = None github_auth["user_code"] = None github_auth["status"] = LoginStatus.LOGGED_IN emit_github_login_status_change() - endpoints = resp_json.get('endpoints', {}) - API_ENDPOINT = endpoints.get('api', API_ENDPOINT) - PROXY_ENDPOINT = endpoints.get('proxy', PROXY_ENDPOINT) - TOKEN_REFRESH_INTERVAL = resp_json.get('refresh_in', 
TOKEN_REFRESH_INTERVAL) + endpoints = resp_json.get("endpoints", {}) + API_ENDPOINT = endpoints.get("api", API_ENDPOINT) + PROXY_ENDPOINT = endpoints.get("proxy", PROXY_ENDPOINT) + TOKEN_REFRESH_INTERVAL = resp_json.get("refresh_in", TOKEN_REFRESH_INTERVAL) except Exception as e: log.error(f"Failed to get token from GitHub Copilot: {e}") + def get_token_thread_func(): global github_auth, get_token_thread, last_token_fetch_time while True: @@ -348,174 +401,210 @@ def get_token_thread_func(): return token = github_auth["token"] # update token if 10 seconds or less left to expiration - if github_auth["access_token"] is not None and (token is None or (dt.datetime.now() - github_auth["token_expires_at"]).total_seconds() > -10): - if (dt.datetime.now() - last_token_fetch_time).total_seconds() > TOKEN_FETCH_INTERVAL: + if github_auth["access_token"] is not None and ( + token is None + or (dt.datetime.now() - github_auth["token_expires_at"]).total_seconds() + > -10 + ): + if ( + dt.datetime.now() - last_token_fetch_time + ).total_seconds() > TOKEN_FETCH_INTERVAL: log.info("Refreshing GitHub token") get_token() last_token_fetch_time = dt.datetime.now() time.sleep(TOKEN_THREAD_SLEEP_INTERVAL) + def wait_for_tokens(): global get_access_code_thread, get_token_thread if get_access_code_thread is None: - get_access_code_thread = threading.Thread(target=wait_for_user_access_token_thread_func) + get_access_code_thread = threading.Thread( + target=wait_for_user_access_token_thread_func + ) get_access_code_thread.start() if get_token_thread is None: get_token_thread = threading.Thread(target=get_token_thread_func) get_token_thread.start() + def generate_copilot_headers(): global github_auth - token = github_auth['token'] + token = github_auth["token"] return { - 'authorization': f'Bearer {token}', - 'editor-version': EDITOR_VERSION, - 'editor-plugin-version': EDITOR_PLUGIN_VERSION, - 'user-agent': USER_AGENT, - 'content-type': 'application/json', - 'openai-intent': 'conversation-panel', - 'openai-organization': 'github-copilot', - 'copilot-integration-id': 'vscode-chat', - 'x-request-id': str(uuid.uuid4()), - 'vscode-sessionid': str(uuid.uuid4()), - 'vscode-machineid': MACHINE_ID, + "authorization": f"Bearer {token}", + "editor-version": EDITOR_VERSION, + "editor-plugin-version": EDITOR_PLUGIN_VERSION, + "user-agent": USER_AGENT, + "content-type": "application/json", + "openai-intent": "conversation-panel", + "openai-organization": "github-copilot", + "copilot-integration-id": "vscode-chat", + "x-request-id": str(uuid.uuid4()), + "vscode-sessionid": str(uuid.uuid4()), + "vscode-machineid": MACHINE_ID, } -def inline_completions(model_id, prefix, suffix, language, filename, context: CompletionContext, cancel_token: CancelToken) -> str: + +def inline_completions( + model_id, + prefix, + suffix, + language, + filename, + context: CompletionContext, + cancel_token: CancelToken, +) -> str: global github_auth - token = github_auth['token'] + token = github_auth["token"] prompt = f"# Path: {filename}" if cancel_token.is_cancel_requested: - return '' + return "" if context is not None: for item in context.items: context_file = f"Compare this snippet from {item.filePath if item.filePath is not None else 'undefined'}:{NL}{item.content}{NL}" - prompt += "\n# " + "\n# ".join(context_file.split('\n')) + prompt += "\n# " + "\n# ".join(context_file.split("\n")) prompt += f"{NL}{prefix}" try: if cancel_token.is_cancel_requested: - return '' - resp = requests.post(f"{PROXY_ENDPOINT}/v1/engines/{model_id}/completions", - 
headers={'authorization': f'Bearer {token}'}, - json={ - 'prompt': prompt, - 'suffix': suffix, - 'min_tokens': 500, - 'max_tokens': 2000, - 'temperature': 0, - 'top_p': 1, - 'n': 1, - 'stop': ['', '```'], - 'nwo': 'NotebookIntelligence', - 'stream': True, - 'extra': { - 'language': language, - 'next_indent': 0, - 'trim_by_indentation': True - } - } + return "" + resp = requests.post( + f"{PROXY_ENDPOINT}/v1/engines/{model_id}/completions", + headers={"authorization": f"Bearer {token}"}, + json={ + "prompt": prompt, + "suffix": suffix, + "min_tokens": 500, + "max_tokens": 2000, + "temperature": 0, + "top_p": 1, + "n": 1, + "stop": ["", "```"], + "nwo": "LabNotebookIntelligence", + "stream": True, + "extra": { + "language": language, + "next_indent": 0, + "trim_by_indentation": True, + }, + }, ) except Exception as e: log.error(f"Failed to get inline completions: {e}") - return '' + return "" if cancel_token.is_cancel_requested: - return '' + return "" - result = '' + result = "" decoded_response = resp.content.decode() - resp_text = decoded_response.split('\n') + resp_text = decoded_response.split("\n") for line in resp_text: - if line.startswith('data: {'): + if line.startswith("data: {"): json_completion = json.loads(line[6:]) - completion = json_completion.get('choices')[0].get('text') + completion = json_completion.get("choices")[0].get("text") if completion: result += completion # else: # result += '\n' - + return result + def _aggregate_streaming_response(client: sseclient.SSEClient) -> dict: final_tool_calls = [] - final_content = '' + final_content = "" def _format_llm_response(): for tool_call in final_tool_calls: - if 'arguments' in tool_call['function'] and tool_call['function']['arguments'] == '': - tool_call['function']['arguments'] = '{}' + if ( + "arguments" in tool_call["function"] + and tool_call["function"]["arguments"] == "" + ): + tool_call["function"]["arguments"] = "{}" return { "choices": [ { "message": { - "tool_calls": final_tool_calls if len(final_tool_calls) > 0 else None, + "tool_calls": ( + final_tool_calls if len(final_tool_calls) > 0 else None + ), "content": final_content, - "role": "assistant" + "role": "assistant", } } ] } for event in client.events(): - if event.data == '[DONE]': + if event.data == "[DONE]": return _format_llm_response() chunk = json.loads(event.data) - if len(chunk['choices']) == 0: + if len(chunk["choices"]) == 0: continue - content_chunk = chunk['choices'][0]['delta'].get('content') + content_chunk = chunk["choices"][0]["delta"].get("content") if content_chunk: final_content += content_chunk - for tool_call in chunk['choices'][0]['delta'].get('tool_calls', []): - if 'index' not in tool_call: + for tool_call in chunk["choices"][0]["delta"].get("tool_calls", []): + if "index" not in tool_call: continue - index = tool_call['index'] + index = tool_call["index"] if index >= len(final_tool_calls): tc = tool_call.copy() - if 'arguments' not in tc: - tc['function']['arguments'] = '' + if "arguments" not in tc: + tc["function"]["arguments"] = "" final_tool_calls.append(tc) else: - if 'arguments' in tool_call['function']: - final_tool_calls[index]['function']['arguments'] += tool_call['function']['arguments'] + if "arguments" in tool_call["function"]: + final_tool_calls[index]["function"]["arguments"] += tool_call[ + "function" + ]["arguments"] return _format_llm_response() -def completions(model_id, messages, tools = None, response: ChatResponse = None, cancel_token: CancelToken = None, options: dict = {}) -> Any: + +def completions( + model_id, 
+ messages, + tools=None, + response: ChatResponse = None, + cancel_token: CancelToken = None, + options: dict = {}, +) -> Any: aggregate = response is None try: data = { - 'model': model_id, - 'messages': messages, - 'tools': tools, - 'temperature': 0, - 'top_p': 1, - 'n': 1, - 'nwo': 'NotebookIntelligence', - 'stream': True + "model": model_id, + "messages": messages, + "tools": tools, + "temperature": 0, + "top_p": 1, + "n": 1, + "nwo": "LabNotebookIntelligence", + "stream": True, } - if not (model_id == 'gpt-5' or model_id == 'gpt-5-mini'): - data['stop'] = [''] + if not (model_id == "gpt-5" or model_id == "gpt-5-mini"): + data["stop"] = [""] - if 'tool_choice' in options: - data['tool_choice'] = options['tool_choice'] + if "tool_choice" in options: + data["tool_choice"] = options["tool_choice"] if cancel_token is not None and cancel_token.is_cancel_requested: if response is not None: @@ -524,9 +613,9 @@ def completions(model_id, messages, tools = None, response: ChatResponse = None, request = requests.post( f"{API_ENDPOINT}/chat/completions", - headers = generate_copilot_headers(), - json = data, - stream = True + headers=generate_copilot_headers(), + json=data, + stream=True, ) if request.status_code != 200: @@ -544,7 +633,7 @@ def completions(model_id, messages, tools = None, response: ChatResponse = None, for event in client.events(): if cancel_token is not None and cancel_token.is_cancel_requested: response.finish() - if event.data == '[DONE]': + if event.data == "[DONE]": response.finish() else: response.stream(json.loads(event.data)) diff --git a/lab_notebook_intelligence/github_copilot_chat_participant.py b/lab_notebook_intelligence/github_copilot_chat_participant.py index 8bf3800..4628945 100644 --- a/lab_notebook_intelligence/github_copilot_chat_participant.py +++ b/lab_notebook_intelligence/github_copilot_chat_participant.py @@ -1,9 +1,9 @@ # Copyright (c) Mehmet Bektas -from lab_notebook_intelligence.base_chat_participant import BaseChatParticipant -from lab_notebook_intelligence.prompts import Prompts import base64 +from lab_notebook_intelligence.base_chat_participant import BaseChatParticipant +from lab_notebook_intelligence.prompts import Prompts COPILOT_ICON_SVG = '' COPILOT_ICON_URL = f"data:image/svg+xml;base64,{base64.b64encode(COPILOT_ICON_SVG.encode('utf-8')).decode('utf-8')}" @@ -16,7 +16,7 @@ def __init__(self): @property def id(self) -> str: return "default" - + @property def name(self) -> str: return "GitHub Copilot" @@ -24,10 +24,10 @@ def name(self) -> str: @property def description(self) -> str: return "GitHub Copilot" - + @property def icon_path(self) -> str: return COPILOT_ICON_URL - + def chat_prompt(self, model_provider: str, model_name: str) -> str: return Prompts.github_copilot_chat_prompt(model_provider, model_name) diff --git a/lab_notebook_intelligence/llm_providers/github_copilot_llm_provider.py b/lab_notebook_intelligence/llm_providers/github_copilot_llm_provider.py index 1abf5bc..8386aa0 100644 --- a/lab_notebook_intelligence/llm_providers/github_copilot_llm_provider.py +++ b/lab_notebook_intelligence/llm_providers/github_copilot_llm_provider.py @@ -1,17 +1,30 @@ # Copyright (c) Mehmet Bektas +import logging from typing import Any -from lab_notebook_intelligence.api import ChatModel, EmbeddingModel, InlineCompletionModel, LLMProvider, CancelToken, ChatResponse, CompletionContext -from lab_notebook_intelligence.github_copilot import generate_copilot_headers, completions, inline_completions -import logging +from lab_notebook_intelligence.api import 
(CancelToken, ChatModel, + ChatResponse, CompletionContext, + EmbeddingModel, + InlineCompletionModel, LLMProvider) +from lab_notebook_intelligence.github_copilot import (completions, + generate_copilot_headers, + inline_completions) log = logging.getLogger(__name__) GH_COPILOT_EXCLUDED_MODELS = set(["o1"]) + class GitHubCopilotChatModel(ChatModel): - def __init__(self, provider: LLMProvider, model_id: str, model_name: str, context_window: int, supports_tools: bool): + def __init__( + self, + provider: LLMProvider, + model_id: str, + model_name: str, + context_window: int, + supports_tools: bool, + ): super().__init__(provider) self._model_id = model_id self._model_name = model_name @@ -21,11 +34,11 @@ def __init__(self, provider: LLMProvider, model_id: str, model_name: str, contex @property def id(self) -> str: return self._model_id - + @property def name(self) -> str: return self._model_name - + @property def context_window(self) -> int: return self._context_window @@ -34,8 +47,18 @@ def context_window(self) -> int: def supports_tools(self) -> bool: return self._supports_tools - def completions(self, messages: list[dict], tools: list[dict] = None, response: ChatResponse = None, cancel_token: CancelToken = None, options: dict = {}) -> Any: - return completions(self._model_id, messages, tools, response, cancel_token, options) + def completions( + self, + messages: list[dict], + tools: list[dict] = None, + response: ChatResponse = None, + cancel_token: CancelToken = None, + options: dict = {}, + ) -> Any: + return completions( + self._model_id, messages, tools, response, cancel_token, options + ) + class GitHubCopilotInlineCompletionModel(InlineCompletionModel): def __init__(self, provider: LLMProvider, model_id: str, model_name: str): @@ -46,17 +69,28 @@ def __init__(self, provider: LLMProvider, model_id: str, model_name: str): @property def id(self) -> str: return self._model_id - + @property def name(self) -> str: return self._model_name - + @property def context_window(self) -> int: return 4096 - def inline_completions(self, prefix, suffix, language, filename, context: CompletionContext, cancel_token: CancelToken) -> str: - return inline_completions(self._model_id, prefix, suffix, language, filename, context, cancel_token) + def inline_completions( + self, + prefix, + suffix, + language, + filename, + context: CompletionContext, + cancel_token: CancelToken, + ) -> str: + return inline_completions( + self._model_id, prefix, suffix, language, filename, context, cancel_token + ) + class GitHubCopilotLLMProvider(LLMProvider): def __init__(self): @@ -66,20 +100,36 @@ def __init__(self): GitHubCopilotChatModel(self, "gpt-4o", "GPT-4o", 128000, True), GitHubCopilotChatModel(self, "o3-mini", "o3-mini", 200000, True), GitHubCopilotChatModel(self, "gpt-5", "GPT-5", 128000, True), - GitHubCopilotChatModel(self, "claude-sonnet-4", "Claude Sonnet 4", 80000, True), - GitHubCopilotChatModel(self, "claude-3.7-sonnet", "Claude 3.7 Sonnet", 200000, True), - GitHubCopilotChatModel(self, "claude-3.5-sonnet", "Claude 3.5 Sonnet", 90000, True), - GitHubCopilotChatModel(self, "gemini-2.5-pro", "Gemini 2.5 Pro", 128000, True), - GitHubCopilotChatModel(self, "gemini-2.0-flash-001", "Gemini 2.0 Flash", 1000000, False), + GitHubCopilotChatModel( + self, "claude-sonnet-4", "Claude Sonnet 4", 80000, True + ), + GitHubCopilotChatModel( + self, "claude-3.7-sonnet", "Claude 3.7 Sonnet", 200000, True + ), + GitHubCopilotChatModel( + self, "claude-3.5-sonnet", "Claude 3.5 Sonnet", 90000, True + ), + 
GitHubCopilotChatModel( + self, "gemini-2.5-pro", "Gemini 2.5 Pro", 128000, True + ), + GitHubCopilotChatModel( + self, "gemini-2.0-flash-001", "Gemini 2.0 Flash", 1000000, False + ), ] - self._inline_completion_model_gpt41 = GitHubCopilotInlineCompletionModel(self, "gpt-41-copilot", "GPT-4.1 Copilot") - self._inline_completion_model_gpt4o = GitHubCopilotInlineCompletionModel(self, "gpt-4o-copilot", "GPT-4o Copilot") - self._inline_completion_model_codex = GitHubCopilotInlineCompletionModel(self, "copilot-codex", "Copilot Codex") + self._inline_completion_model_gpt41 = GitHubCopilotInlineCompletionModel( + self, "gpt-41-copilot", "GPT-4.1 Copilot" + ) + self._inline_completion_model_gpt4o = GitHubCopilotInlineCompletionModel( + self, "gpt-4o-copilot", "GPT-4o Copilot" + ) + self._inline_completion_model_codex = GitHubCopilotInlineCompletionModel( + self, "copilot-codex", "Copilot Codex" + ) @property def id(self) -> str: return "github-copilot" - + @property def name(self) -> str: return "GitHub Copilot" @@ -87,12 +137,15 @@ def name(self) -> str: @property def chat_models(self) -> list[ChatModel]: return self._chat_models - + @property def inline_completion_models(self) -> list[InlineCompletionModel]: - return [self._inline_completion_model_gpt41, self._inline_completion_model_gpt4o, self._inline_completion_model_codex] - + return [ + self._inline_completion_model_gpt41, + self._inline_completion_model_gpt4o, + self._inline_completion_model_codex, + ] + @property def embedding_models(self) -> list[EmbeddingModel]: return [] - diff --git a/lab_notebook_intelligence/llm_providers/litellm_compatible_llm_provider.py b/lab_notebook_intelligence/llm_providers/litellm_compatible_llm_provider.py index 8f6e819..b8e4572 100644 --- a/lab_notebook_intelligence/llm_providers/litellm_compatible_llm_provider.py +++ b/lab_notebook_intelligence/llm_providers/litellm_compatible_llm_provider.py @@ -2,30 +2,41 @@ import json from typing import Any -from lab_notebook_intelligence.api import ChatModel, EmbeddingModel, InlineCompletionModel, LLMProvider, CancelToken, ChatResponse, CompletionContext, LLMProviderProperty + import litellm +from lab_notebook_intelligence.api import (CancelToken, ChatModel, + ChatResponse, CompletionContext, + EmbeddingModel, + InlineCompletionModel, LLMProvider, + LLMProviderProperty) + DEFAULT_CONTEXT_WINDOW = 4096 + class LiteLLMCompatibleChatModel(ChatModel): def __init__(self, provider: "LiteLLMCompatibleLLMProvider"): super().__init__(provider) self._provider = provider self._properties = [ - LLMProviderProperty("model_id", "Model", "Model (must support streaming)", "", False), + LLMProviderProperty( + "model_id", "Model", "Model (must support streaming)", "", False + ), LLMProviderProperty("base_url", "Base URL", "Base URL", "", False), LLMProviderProperty("api_key", "API key", "API key", "", True), - LLMProviderProperty("context_window", "Context window", "Context window length", "", True), + LLMProviderProperty( + "context_window", "Context window", "Context window length", "", True + ), ] @property def id(self) -> str: return "litellm-compatible-chat-model" - + @property def name(self) -> str: return self.get_property("model_id").value - + @property def context_window(self) -> int: try: @@ -36,7 +47,14 @@ def context_window(self) -> int: except: return DEFAULT_CONTEXT_WINDOW - def completions(self, messages: list[dict], tools: list[dict] = None, response: ChatResponse = None, cancel_token: CancelToken = None, options: dict = {}) -> Any: + def completions( + self, + 
messages: list[dict], + tools: list[dict] = None, + response: ChatResponse = None, + cancel_token: CancelToken = None, + options: dict = {}, + ) -> Any: stream = response is not None model_id = self.get_property("model_id").value base_url = self.get_property("base_url").value @@ -54,20 +72,25 @@ def completions(self, messages: list[dict], tools: list[dict] = None, response: if stream: for chunk in litellm_resp: - response.stream({ - "choices": [{ - "delta": { - "role": chunk.choices[0].delta.role, - "content": chunk.choices[0].delta.content + response.stream( + { + "choices": [ + { + "delta": { + "role": chunk.choices[0].delta.role, + "content": chunk.choices[0].delta.content, + } } - }] - }) + ] + } + ) response.finish() return else: json_resp = json.loads(litellm_resp.model_dump_json()) return json_resp - + + class LiteLLMCompatibleInlineCompletionModel(InlineCompletionModel): def __init__(self, provider: "LiteLLMCompatibleLLMProvider"): super().__init__(provider) @@ -76,17 +99,19 @@ def __init__(self, provider: "LiteLLMCompatibleLLMProvider"): LLMProviderProperty("model_id", "Model", "Model", "", False), LLMProviderProperty("base_url", "Base URL", "Base URL", "", False), LLMProviderProperty("api_key", "API key", "API key", "", True), - LLMProviderProperty("context_window", "Context window", "Context window length", "", True), + LLMProviderProperty( + "context_window", "Context window", "Context window length", "", True + ), ] @property def id(self) -> str: return "litellm-compatible-inline-completion-model" - + @property def name(self) -> str: return "Inline Completion Model" - + @property def context_window(self) -> int: try: @@ -97,7 +122,15 @@ def context_window(self) -> int: except: return DEFAULT_CONTEXT_WINDOW - def inline_completions(self, prefix, suffix, language, filename, context: CompletionContext, cancel_token: CancelToken) -> str: + def inline_completions( + self, + prefix, + suffix, + language, + filename, + context: CompletionContext, + cancel_token: CancelToken, + ) -> str: model_id = self.get_property("model_id").value base_url = self.get_property("base_url").value api_key_prop = self.get_property("api_key") @@ -113,6 +146,7 @@ def inline_completions(self, prefix, suffix, language, filename, context: Comple return litellm_resp.choices[0].message.content + class LiteLLMCompatibleLLMProvider(LLMProvider): def __init__(self): super().__init__() @@ -122,7 +156,7 @@ def __init__(self): @property def id(self) -> str: return "litellm-compatible" - + @property def name(self) -> str: return "LiteLLM Compatible" @@ -130,11 +164,11 @@ def name(self) -> str: @property def chat_models(self) -> list[ChatModel]: return [self._chat_model] - + @property def inline_completion_models(self) -> list[InlineCompletionModel]: return [self._inline_completion_model] - + @property def embedding_models(self) -> list[EmbeddingModel]: return [] diff --git a/lab_notebook_intelligence/llm_providers/ollama_llm_provider.py b/lab_notebook_intelligence/llm_providers/ollama_llm_provider.py index 53d2418..139f455 100644 --- a/lab_notebook_intelligence/llm_providers/ollama_llm_provider.py +++ b/lab_notebook_intelligence/llm_providers/ollama_llm_provider.py @@ -1,24 +1,37 @@ # Copyright (c) Mehmet Bektas import json +import logging from typing import Any -from lab_notebook_intelligence.api import ChatModel, EmbeddingModel, InlineCompletionModel, LLMProvider, CancelToken, ChatResponse, CompletionContext + import ollama -import logging +from lab_notebook_intelligence.api import (CancelToken, ChatModel, + 
ChatResponse, CompletionContext, + EmbeddingModel, + InlineCompletionModel, LLMProvider) from lab_notebook_intelligence.util import extract_llm_generated_code log = logging.getLogger(__name__) OLLAMA_EMBEDDING_FAMILIES = set(["nomic-bert", "bert"]) -QWEN_INLINE_COMPL_PROMPT = """<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>""" -DEEPSEEK_INLINE_COMPL_PROMPT = """<|fim▁begin|>{prefix}<|fim▁hole|>{suffix}<|fim▁end|>""" +QWEN_INLINE_COMPL_PROMPT = ( + """<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>""" +) +DEEPSEEK_INLINE_COMPL_PROMPT = ( + """<|fim▁begin|>{prefix}<|fim▁hole|>{suffix}<|fim▁end|>""" +) CODELLAMA_INLINE_COMPL_PROMPT = """
 {prefix} <SUF>{suffix} <MID>"""
-STARCODER_INLINE_COMPL_PROMPT = """<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>"""
+STARCODER_INLINE_COMPL_PROMPT = (
+    """{prefix}{suffix}"""
+)
 CODESTRAL_INLINE_COMPL_PROMPT = """[SUFFIX]{suffix}[PREFIX]{prefix}"""
 
+
 class OllamaChatModel(ChatModel):
-    def __init__(self, provider: LLMProvider, model_id: str, model_name: str, context_window: int):
+    def __init__(
+        self, provider: LLMProvider, model_id: str, model_name: str, context_window: int
+    ):
         super().__init__(provider)
         self._model_id = model_id
         self._model_name = model_name
@@ -27,19 +40,26 @@ def __init__(self, provider: LLMProvider, model_id: str, model_name: str, contex
     @property
     def id(self) -> str:
         return self._model_id
-    
+
     @property
     def name(self) -> str:
         return self._model_name
-    
+
     @property
     def context_window(self) -> int:
         return self._context_window
 
-    def completions(self, messages: list[dict], tools: list[dict] = None, response: ChatResponse = None, cancel_token: CancelToken = None, options: dict = {}) -> Any:
+    def completions(
+        self,
+        messages: list[dict],
+        tools: list[dict] = None,
+        response: ChatResponse = None,
+        cancel_token: CancelToken = None,
+        options: dict = {},
+    ) -> Any:
         stream = response is not None
         completion_args = {
-            "model": self._model_id, 
+            "model": self._model_id,
             "messages": messages.copy(),
             "stream": stream,
         }
@@ -50,30 +70,35 @@ def completions(self, messages: list[dict], tools: list[dict] = None, response:
 
         if stream:
             for chunk in ollama_response:
-                response.stream({
-                        "choices": [{
-                            "delta": {
-                                "role": chunk['message']['role'],
-                                "content": chunk['message']['content']
+                response.stream(
+                    {
+                        "choices": [
+                            {
+                                "delta": {
+                                    "role": chunk["message"]["role"],
+                                    "content": chunk["message"]["content"],
+                                }
                             }
-                        }]
-                    })
+                        ]
+                    }
+                )
             response.finish()
             return
         else:
             json_resp = json.loads(ollama_response.model_dump_json())
 
-            return {
-                'choices': [
-                    {
-                        'message': json_resp['message']
-                    }
-                ]
-            }
+            return {"choices": [{"message": json_resp["message"]}]}
 
 
 class OllamaInlineCompletionModel(InlineCompletionModel):
-    def __init__(self, provider: LLMProvider, model_id: str, model_name: str, context_window: int, prompt_template: str):
+    def __init__(
+        self,
+        provider: LLMProvider,
+        model_id: str,
+        model_name: str,
+        context_window: int,
+        prompt_template: str,
+    ):
         super().__init__(provider)
         self._model_id = model_id
         self._model_name = model_name
@@ -83,16 +108,24 @@ def __init__(self, provider: LLMProvider, model_id: str, model_name: str, contex
     @property
     def id(self) -> str:
         return self._model_id
-    
+
     @property
     def name(self) -> str:
         return self._model_name
-    
+
     @property
     def context_window(self) -> int:
         return self._context_window
 
-    def inline_completions(self, prefix, suffix, language, filename, context: CompletionContext, cancel_token: CancelToken) -> str:
+    def inline_completions(
+        self,
+        prefix,
+        suffix,
+        language,
+        filename,
+        context: CompletionContext,
+        cancel_token: CancelToken,
+    ) -> str:
         has_suffix = suffix.strip() != ""
         if has_suffix:
             prompt = self._prompt_template.format(prefix=prefix, suffix=suffix.strip())
@@ -101,13 +134,13 @@ def inline_completions(self, prefix, suffix, language, filename, context: Comple
 
         try:
             generate_args = {
-                "model": self._model_id, 
+                "model": self._model_id,
                 "prompt": prompt,
                 "raw": True,
                 "options": {
-                    'num_predict': 128,
+                    "num_predict": 128,
                     "temperature": 0,
-                    "stop" : [
+                    "stop": [
                         "<|end▁of▁sentence|>",
                         "<|end▁of▁sentence|>",
                         "<|EOT|>",
@@ -128,6 +161,7 @@ def inline_completions(self, prefix, suffix, language, filename, context: Comple
             log.error(f"Error occurred while generating using completions ollama: {e}")
             return ""
 
+
 class OllamaLLMProvider(LLMProvider):
     def __init__(self):
         super().__init__()
@@ -137,7 +171,7 @@ def __init__(self):
     @property
     def id(self) -> str:
         return "ollama"
-    
+
     @property
     def name(self) -> str:
         return "Ollama"
@@ -149,17 +183,35 @@ def chat_models(self) -> list[ChatModel]:
     @property
     def inline_completion_models(self) -> list[InlineCompletionModel]:
         return [
-            OllamaInlineCompletionModel(self, "deepseek-coder-v2", "deepseek-coder-v2", 163840, DEEPSEEK_INLINE_COMPL_PROMPT),
-            OllamaInlineCompletionModel(self, "qwen2.5-coder", "qwen2.5-coder", 32768, QWEN_INLINE_COMPL_PROMPT),
-            OllamaInlineCompletionModel(self, "codestral", "codestral", 32768, CODESTRAL_INLINE_COMPL_PROMPT),
-            OllamaInlineCompletionModel(self, "starcoder2", "starcoder2", 16384, STARCODER_INLINE_COMPL_PROMPT),
-            OllamaInlineCompletionModel(self, "codellama:7b-code", "codellama:7b-code", 16384, CODELLAMA_INLINE_COMPL_PROMPT),
+            OllamaInlineCompletionModel(
+                self,
+                "deepseek-coder-v2",
+                "deepseek-coder-v2",
+                163840,
+                DEEPSEEK_INLINE_COMPL_PROMPT,
+            ),
+            OllamaInlineCompletionModel(
+                self, "qwen2.5-coder", "qwen2.5-coder", 32768, QWEN_INLINE_COMPL_PROMPT
+            ),
+            OllamaInlineCompletionModel(
+                self, "codestral", "codestral", 32768, CODESTRAL_INLINE_COMPL_PROMPT
+            ),
+            OllamaInlineCompletionModel(
+                self, "starcoder2", "starcoder2", 16384, STARCODER_INLINE_COMPL_PROMPT
+            ),
+            OllamaInlineCompletionModel(
+                self,
+                "codellama:7b-code",
+                "codellama:7b-code",
+                16384,
+                CODELLAMA_INLINE_COMPL_PROMPT,
+            ),
         ]
-    
+
     @property
     def embedding_models(self) -> list[EmbeddingModel]:
         return []
-    
+
     def update_chat_model_list(self):
         try:
             response = ollama.list()
@@ -178,5 +230,5 @@ def update_chat_model_list(self):
                     )
                 except Exception as e:
                     log.error(f"Error getting Ollama model info {model}: {e}")
-        except Exception as e:          
+        except Exception as e:
             log.error(f"Error updating supported Ollama models: {e}")
diff --git a/lab_notebook_intelligence/llm_providers/openai_compatible_llm_provider.py b/lab_notebook_intelligence/llm_providers/openai_compatible_llm_provider.py
index caf6713..9917a55 100644
--- a/lab_notebook_intelligence/llm_providers/openai_compatible_llm_provider.py
+++ b/lab_notebook_intelligence/llm_providers/openai_compatible_llm_provider.py
@@ -2,30 +2,41 @@
 
 import json
 from typing import Any
-from lab_notebook_intelligence.api import ChatModel, EmbeddingModel, InlineCompletionModel, LLMProvider, CancelToken, ChatResponse, CompletionContext, LLMProviderProperty
+
 from openai import OpenAI
 
+from lab_notebook_intelligence.api import (CancelToken, ChatModel,
+                                           ChatResponse, CompletionContext,
+                                           EmbeddingModel,
+                                           InlineCompletionModel, LLMProvider,
+                                           LLMProviderProperty)
+
 DEFAULT_CONTEXT_WINDOW = 4096
 
+
 class OpenAICompatibleChatModel(ChatModel):
     def __init__(self, provider: "OpenAICompatibleLLMProvider"):
         super().__init__(provider)
         self._provider = provider
         self._properties = [
             LLMProviderProperty("api_key", "API key", "API key", "", False),
-            LLMProviderProperty("model_id", "Model", "Model (must support streaming)", "", False),
+            LLMProviderProperty(
+                "model_id", "Model", "Model (must support streaming)", "", False
+            ),
             LLMProviderProperty("base_url", "Base URL", "Base URL", "", True),
-            LLMProviderProperty("context_window", "Context window", "Context window length", "", True),
+            LLMProviderProperty(
+                "context_window", "Context window", "Context window length", "", True
+            ),
         ]
 
     @property
     def id(self) -> str:
         return "openai-compatible-chat-model"
-    
+
     @property
     def name(self) -> str:
         return self.get_property("model_id").value
-    
+
     @property
     def context_window(self) -> int:
         try:
@@ -36,7 +47,14 @@ def context_window(self) -> int:
         except:
             return DEFAULT_CONTEXT_WINDOW
 
-    def completions(self, messages: list[dict], tools: list[dict] = None, response: ChatResponse = None, cancel_token: CancelToken = None, options: dict = {}) -> Any:
+    def completions(
+        self,
+        messages: list[dict],
+        tools: list[dict] = None,
+        response: ChatResponse = None,
+        cancel_token: CancelToken = None,
+        options: dict = {},
+    ) -> Any:
         stream = response is not None
         model_id = self.get_property("model_id").value
         base_url_prop = self.get_property("base_url")
@@ -55,20 +73,25 @@ def completions(self, messages: list[dict], tools: list[dict] = None, response:
 
         if stream:
             for chunk in resp:
-                response.stream({
-                        "choices": [{
-                            "delta": {
-                                "role": chunk.choices[0].delta.role,
-                                "content": chunk.choices[0].delta.content
+                response.stream(
+                    {
+                        "choices": [
+                            {
+                                "delta": {
+                                    "role": chunk.choices[0].delta.role,
+                                    "content": chunk.choices[0].delta.content,
+                                }
                             }
-                        }]
-                    })
+                        ]
+                    }
+                )
             response.finish()
             return
         else:
             json_resp = json.loads(resp.model_dump_json())
             return json_resp
-    
+
+
 class OpenAICompatibleInlineCompletionModel(InlineCompletionModel):
     def __init__(self, provider: "OpenAICompatibleLLMProvider"):
         super().__init__(provider)
@@ -77,17 +100,19 @@ def __init__(self, provider: "OpenAICompatibleLLMProvider"):
             LLMProviderProperty("api_key", "API key", "API key", "", False),
             LLMProviderProperty("model_id", "Model", "Model", "", False),
             LLMProviderProperty("base_url", "Base URL", "Base URL", "", True),
-            LLMProviderProperty("context_window", "Context window", "Context window length", "", True),
+            LLMProviderProperty(
+                "context_window", "Context window", "Context window length", "", True
+            ),
         ]
 
     @property
     def id(self) -> str:
         return "openai-compatible-inline-completion-model"
-    
+
     @property
     def name(self) -> str:
         return "Inline Completion Model"
-    
+
     @property
     def context_window(self) -> int:
         try:
@@ -98,7 +123,15 @@ def context_window(self) -> int:
         except:
             return DEFAULT_CONTEXT_WINDOW
 
-    def inline_completions(self, prefix, suffix, language, filename, context: CompletionContext, cancel_token: CancelToken) -> str:
+    def inline_completions(
+        self,
+        prefix,
+        suffix,
+        language,
+        filename,
+        context: CompletionContext,
+        cancel_token: CancelToken,
+    ) -> str:
         model_id = self.get_property("model_id").value
         base_url_prop = self.get_property("base_url")
         base_url = base_url_prop.value if base_url_prop is not None else None
@@ -115,6 +148,7 @@ def inline_completions(self, prefix, suffix, language, filename, context: Comple
 
         return resp.choices[0].text
 
+
 class OpenAICompatibleLLMProvider(LLMProvider):
     def __init__(self):
         super().__init__()
@@ -124,7 +158,7 @@ def __init__(self):
     @property
     def id(self) -> str:
         return "openai-compatible"
-    
+
     @property
     def name(self) -> str:
         return "OpenAI Compatible"
@@ -132,11 +166,11 @@ def name(self) -> str:
     @property
     def chat_models(self) -> list[ChatModel]:
         return [self._chat_model]
-    
+
     @property
     def inline_completion_models(self) -> list[InlineCompletionModel]:
         return [self._inline_completion_model]
-    
+
     @property
     def embedding_models(self) -> list[EmbeddingModel]:
         return []
diff --git a/lab_notebook_intelligence/mcp_manager.py b/lab_notebook_intelligence/mcp_manager.py
index 5b1b57c..e8c84d5 100644
--- a/lab_notebook_intelligence/mcp_manager.py
+++ b/lab_notebook_intelligence/mcp_manager.py
@@ -1,34 +1,44 @@
 # Copyright (c) Mehmet Bektas 
 
 import asyncio
-from dataclasses import dataclass
 import json
+import logging
 import threading
+from dataclasses import dataclass
 from typing import Any, Union
+
+from fastmcp import Client
 from fastmcp.client import StdioTransport, StreamableHttpTransport
 from mcp import StdioServerParameters
-from mcp.client.stdio import get_default_environment as mcp_get_default_environment
-from mcp.types import TextContent, ImageContent
-from lab_notebook_intelligence.api import ChatCommand, ChatRequest, ChatResponse, HTMLFrameData, ImageData, MCPServer, MarkdownData, ProgressData, Tool, ToolPreInvokeResponse
+from mcp.client.stdio import \
+    get_default_environment as mcp_get_default_environment
+from mcp.types import ImageContent, TextContent
+
+from lab_notebook_intelligence.api import (ChatCommand, ChatRequest,
+                                           ChatResponse, HTMLFrameData,
+                                           ImageData, MarkdownData, MCPServer,
+                                           ProgressData, Tool,
+                                           ToolPreInvokeResponse)
 from lab_notebook_intelligence.base_chat_participant import BaseChatParticipant
-import logging
-from fastmcp import Client
 
 log = logging.getLogger(__name__)
 
-MCP_ICON_SRC = 'iVBORw0KGgoAAAANSUhEUgAAAMgAAADICAIAAAAiOjnJAAAPBUlEQVR4nOydf2wT5f/AW7pZGLOjE7K5DAfWIYMWM7rpWJTZkCxEh4hdcGKahYhkYRojGPQfUnCJMRpDlpD5hyEknZlWY2AL2UaiDubYLFkDzsE2CBlBrc5ldGm6VLjdnm++6Sf77DO758fdPb279v36kzz3ft73vhfPdffjfRkIIQMAKM0ytRMAUhMQC+ACiAVwAcQCuABiAVwAsQAugFgAF0AsgAsgFsAFEAvgAogFcAHEArgAYgFcALEALoBYABdALIALIBbABRAL4AKIBXABxAK4kKF2Anrlt99+6+vru3r16s2bN+/cuTM1NRWJRGZnZ1esWGGxWPLy8mw22+bNm8vLy7dt27Zy5Uq18002RnhLh4mhoaGvvvqqo6Pjxo0blJuYzebKykq32/3qq6+uXr2ac4KaAQEUiKL4zTffVFRUyCm12Wyur6+/du2a2nuTDEAsMu3t7Xa7Xan/ySaTae/evbdu3VJ7t/gCYuG4e/duTU2NUkotJCsry+v1CoKg9i7yAsRaEr/fb7VaeVg1zzPPPDM+Pq72jnIBxEqAKIoffPABV6XmWbVqVU9Pj9p7rDwg1mIEQairq0uOVXFMJlNbW5va+60wINb/cP/+/d27dyfTqlR1C8T6L2pZNe/W2bNn1a6BYsAF0v/w4MGDvXv3tre3q5hDdnb25cuXt2zZomIOSgFiGeRYZbVan3/++a1bt27YsCE3NzcjI0MQhD/++OP69es///xzIBAQBIEpYHFxcTAYfPjhh1kz0RxqL5nqI+EMaDQa3W53V1cX/kJUOBw+derUxo0bmYIfOnQoiXvPi3QXS4JVLpdraGiIfgpBEHw+X35+Pr21vb29PHc6GaS1WKxWmc3mU6dOSZtrenra7XZTTuRwOERRVHp3k0r6isVqldVq7evrkznpsWPHKKfz+/0K7ag6pKlYEqwKBoOKTH38+HGaGe12uyLTqUU6iqWiVXEaGhpo5tX1rZ60E0t1qxBCsVhs06ZNxKk9Ho+y8yaT9BJLC1bFGRgYMBqN+NktFkssFuMxexJII7G0Y1Wc2tpaYg4XLlzglwBX0uUtHdZr61ar9fvvv9+6dSu/lN5//33imJ6eHn4JcCUt3tLhbdXc3NyPP/7Y29sbDofXr1//8ssvP/7448StysrKHA7Hr7/+ihkzMDBAmYPmUHvJ5A7vM+CtW7fKy8sXRjCZTI2Njffv3ydu6/V6icnI23vVSHGxeFs1NjZWUFCQMFRdXR1x8wsXLhBTmpyclFcDdUhlsVS0Ks63336LjzA5OUnM6pdffpFdCRVIWbFUt8pgMOzatYsYZ/ny5fggOr1MmppiacEqg8FQWFhIDEWMc/78eXnFUIcUvNzA+2/AmzdvulyuUChEHDk3N6fIGD2SamJpxyqDwUDzkHEkEsEPWLFiBWVu2kLtJVNJNHIGnOe7777DB5yYmCAGuXr1quzCqEDqXCDV1FplMBh27979yiuv4McMDw8T42RmZg4ODobD4QcPHsQfNszNzS0oKKB/JFUd1DZbGbS2VlVVVUWjUWJY4gVSDBaLZfv27UePHu3s7NTgvepUEEunViGESktLJUm1GIvF8vrrr2vqwoTuxdKvVcFgUJJFODZt2nTmzBktNLHRt1j6tQoh5PF4JMlDprS09NKlS1KLqgw6FkvXVgUCAZPJJEkbWt54443p6Wmp1ZWLXsXStVWCIDgcDkm2sFFcXKzWrUZdiqVrqxBCjY2NkjyRgsViUaXXiP7E0rtVH3/8sSRDpKNKjySdiaV3qz755BNJbsgl+W7pSSywSg4mk+ncuXOSCi8F3YiV5lZlZ2c/9thjhYWFFotFThCmdiZy0IdY6WmV0+lsamq6dOnS1NTUwmhTU1M9PT1er/epp55ijblhw4ZIJMJSe4noQKx0s8poNHo8HsqlJRAI1NbWMl0SS07/La2LlW5WlZeXS/gmSiAQoL8wlpz+W5oWK92sOnz4sOTbfLFY7ODBg5QTlZaW8u6/pV2x0soqo9HY0tIiqU7/w4cffkg5I+/+WxoVK92sOn36tKQ6JYDSLYfDodSMCdGiWGCVTCjPifIbFGLQnFhglXwo+2/V19crPvU82hILrFKKvr4+Yv+tnJwcmgYT0tCQWGCVstD03/rhhx84za4VscAqxbl8+TIxk2PHjnGaXRNigVUY4p+j3rFjR35+fkFBQXV1Nf3zVU888QQ+merqavpMmFBfLLAKQzQafeGFF/4dp7a2lubnEfFznvn5+fTJMKGyWGAVhmg0WlVVtVS0t956ixihs7OTmBWn5+LVFAuswoC3Kh6Q+Cl8mve2b9y4QZ8VPaqJBVZhIFoV5+TJk/g4oihmZmbig3C6TKqOWGAVBkqrDAbDkSNHiNGIH+Ln9P60CmKBVRjorTIYDCdOnCAGfOSRR/BBUkQssAoDk1U0ToiiSHwGMBVOhWAVBlarKioqiDHv3r1LjKP7H+9gFQZWq4qKisbHx4lhz58/TwzF6RH4JIkFVmHgZBVC6OjRo/hQ+r5AClZh4GcVQqi4uBgfbefOnfSpMsFdLLAKA1er+vv7iQG9Xi99tkzwFQuswsDVKoRQwpuMi+DXRoujWGAVBt5W9fb2EmPm5OTw6/3HSyywCgNvq6LR6JNPPkkMe+DAAfqYrHARC6zCwNsqhND+/ftpIvf39zOFZUJ5sQRB2LNnD33hWK0aHx8HqzBQtvguKytjCsuK8mIx9S5ntWpqaqqoqIg+Pli1FLy//aSwWNPT01lZWZT7xmqVIAgul4v+wIBVS8F7uVJerO7ubsp9k/CN+BMnTtAfGLBqKUwmE9dfV3EUFuvMmTM0+ybBqmvXrtE36wGrMLz99ttMwaWhwoolwSpBEMrKyigLB1Zh2LhxI1NxJKOwWJFIJCcnB7NjEqxCCPl8PsrCgVUYLBYLp4dk/o3yfxVi2k1Ls0oQBJvNRlM4sAqDyWTq7Oxkii8HLtex9u3b9+8dy8vLk2AVQujs2bM0hausrASrliJ12nH7fD6n0xnvS5GXl3fo0KFQKCQt1I4dO4iFKywsnJiYoI8JVvGG79MNsVgsHA7LiRAKhWj+GOzq6qKPCVYlAfVfscfT0tJCrJ3b7aYPCFYlB62LVVNTQ6zdyMgIZTSwKmloWixRFFetWoUvH32/FLAqmWharJGREWIFfT4fTSiwKsloWqxz584Ri0jzxyBYlXw0LdZnn32GL6LNZiMG+fTTT+mPClilFJoW68iRI/g6vvjii/gIXV1dxB6v84BVCqJpsYj9yg8ePIiP4HQ6KY8KWKUsmharvr4eX83GxkbM5jSdC+KAVYqjabGILwU0NDRgNqd5BQqs4oSmxXrnnXfwNa2trcVsPjw8TDwqYBUnNC1WU1MTvqxbtmzBbC6KYn5+PmZzsIofmhbryy+/xFc2MzNzZmYGE+HkyZNglSpoWqwrV64Q69vd3Y2JIIpiXV1dwqMCVnFF02LNzMwQm/7u378fH0QQhO
bm5oXnRKfTyfQNGbBKApoWCyFUUVGBr3J2dva9e/eIcURRHB0dvXLlCuvzhmCVNLQuFk3R+X1pCKySjNbFCgaDxHJnZWXdvn1b8anBKjloXSyEEM3HQquqqpRt9QRWyUQHYmEuGSzk3XffVWpGsEo+OhDr3r17+Jdg51Gko2YkEgGr5KMDsZiOxOHDh+WcE0Oh0NNPPw1WyUcfYk1PT+Nvzixk+/btxO+tJaS9vZ1+FrAKjz7EYmrfEP878b333qN/hXVoaIipCyFYRUQ3YiGEdu3axXTss7KyPB5PR0fHUq/eh0KhL774wuVy0T9lClZRYvx/uXTC33//7XQ6f//9d9YNMzMzS0pKbDbbmjVrMjIyYrHYn3/+OTo6eufOHQlpFBUVXbx4cd26dfSbHD9+nL5rnMlkam1tfe211yTkpiHUNpuN/v7+5cuXq1guWKso0ZlYCCG/30/f2k9ZwCp69CcWQqitrS35boFVTOhSrPi6lcxzYnFxMVjFhF7FQggNDAwUFhby1Ok/VFdXszZjSnOr9C0WQmhiYoLYjkYOZrP5o48+EkWRKSuwSvdixWlra1uzZo3iVj377LPDw8OsyYBVcVJBrPjzCF6v12q1KqKU3W73+/2sCxVYtZAUEStONBptbm622+3SfDIajTU1NR0dHRKUAqsWkVJizRMMBpuamiorK81mM/EYW63WPXv2tLS0/PXXX5JnBKsWoadbOhL4559/RkdHr1+/HgqFJiYmZmZmYrHYypUrLRbLo48+un79+s2bN69bt27ZsmVyZknHOzZE1DZb98BalRAQSxZg1VKAWNIBqzCAWBIBq/CAWFIAq4iAWMyAVTSAWGyAVZSAWAyAVfSAWLSAVUyAWFSAVayAWGTAKgmAWAQoW5KAVYtI8ZvQMvnpp59cLpcoijSD0+XuMh0g1pLMzs7a7faxsTGawWDVImQ9LpLanD59GqySDKxYS1JSUjI6OkocBlYlBFasxIyMjIBVcgCxEnPx4kXiGLAKA4iVGOLnqMEqPCBWYiYnJ/ED3G43WIUBxErM3NwcfsDq1auTlYsuAbESk5ubix/g8/kGBgaSlY7+ALESY7PZ8AOi0ejOnTvBraUAsRLz3HPPEcdEIhFwayngAmli5ubm1q5dGwqFiCMtFkt3d/e2bduSkpdugBUrMcuWLXvzzTdpRsK6lRBYsZYkHA7bbLZwOEwzGNatRcCKtSRWq7W5uZlyMKxbiwCxcHg8ngMHDlAOBrcWAqdCArOzsx6P5+uvv6YcD+fEOLBiEcjIyGhtbU34KfyEwLoVB8QiA25JAMSiAtxiBX5jMSDh91ZfX5/D4eCclxYBsdhgdcvpdA4ODnJOSovAqZAN1nNiMBgEsQAqWN0KBAKcM9IiIJYUmNwSBIF/RpoDxJIIvVslJSVJyUhbwI93WRB/yxcUFNy+fVvdr8KqAqxYssCvW0aj8fPPP09Dq0AsBYi71dDQsOjfs7OzW1tbX3rpJZXyUhk4FSrG4OCg3+8fGxt76KGHysvL9+3bt3btWrWTUg0QC+ACnAoBLoBYABdALIALIBbABRAL4AKIBXABxAK4AGIBXACxAC6AWAAXQCyACyAWwAUQC+ACiAVwAcQCuABiAVz4vwAAAP//b8cbMGXTzMEAAAAASUVORK5CYII='
+MCP_ICON_SRC = "iVBORw0KGgoAAAANSUhEUgAAAMgAAADICAIAAAAiOjnJAAAPBUlEQVR4nOydf2wT5f/AW7pZGLOjE7K5DAfWIYMWM7rpWJTZkCxEh4hdcGKahYhkYRojGPQfUnCJMRpDlpD5hyEknZlWY2AL2UaiDubYLFkDzsE2CBlBrc5ldGm6VLjdnm++6Sf77DO758fdPb279v36kzz3ft73vhfPdffjfRkIIQMAKM0ytRMAUhMQC+ACiAVwAcQCuABiAVwAsQAugFgAF0AsgAsgFsAFEAvgAogFcAHEArgAYgFcALEALoBYABdALIALIBbABRAL4AKIBXABxAK4kKF2Anrlt99+6+vru3r16s2bN+/cuTM1NRWJRGZnZ1esWGGxWPLy8mw22+bNm8vLy7dt27Zy5Uq18002RnhLh4mhoaGvvvqqo6Pjxo0blJuYzebKykq32/3qq6+uXr2ac4KaAQEUiKL4zTffVFRUyCm12Wyur6+/du2a2nuTDEAsMu3t7Xa7Xan/ySaTae/evbdu3VJ7t/gCYuG4e/duTU2NUkotJCsry+v1CoKg9i7yAsRaEr/fb7VaeVg1zzPPPDM+Pq72jnIBxEqAKIoffPABV6XmWbVqVU9Pj9p7rDwg1mIEQairq0uOVXFMJlNbW5va+60wINb/cP/+/d27dyfTqlR1C8T6L2pZNe/W2bNn1a6BYsAF0v/w4MGDvXv3tre3q5hDdnb25cuXt2zZomIOSgFiGeRYZbVan3/++a1bt27YsCE3NzcjI0MQhD/++OP69es///xzIBAQBIEpYHFxcTAYfPjhh1kz0RxqL5nqI+EMaDQa3W53V1cX/kJUOBw+derUxo0bmYIfOnQoiXvPi3QXS4JVLpdraGiIfgpBEHw+X35+Pr21vb29PHc6GaS1WKxWmc3mU6dOSZtrenra7XZTTuRwOERRVHp3k0r6isVqldVq7evrkznpsWPHKKfz+/0K7ag6pKlYEqwKBoOKTH38+HGaGe12uyLTqUU6iqWiVXEaGhpo5tX1rZ60E0t1qxBCsVhs06ZNxKk9Ho+y8yaT9BJLC1bFGRgYMBqN+NktFkssFuMxexJII7G0Y1Wc2tpaYg4XLlzglwBX0uUtHdZr61ar9fvvv9+6dSu/lN5//33imJ6eHn4JcCUt3tLhbdXc3NyPP/7Y29sbDofXr1//8ssvP/7448StysrKHA7Hr7/+ihkzMDBAmYPmUHvJ5A7vM+CtW7fKy8sXRjCZTI2Njffv3ydu6/V6icnI23vVSHGxeFs1NjZWUFCQMFRdXR1x8wsXLhBTmpyclFcDdUhlsVS0Ks63336LjzA5OUnM6pdffpFdCRVIWbFUt8pgMOzatYsYZ/ny5fggOr1MmppiacEqg8FQWFhIDEWMc/78eXnFUIcUvNzA+2/AmzdvulyuUChEHDk3N6fIGD2SamJpxyqDwUDzkHEkEsEPWLFiBWVu2kLtJVNJNHIGnOe7777DB5yYmCAGuXr1quzCqEDqXCDV1FplMBh27979yiuv4McMDw8T42RmZg4ODobD4QcPHsQfNszNzS0oKKB/JFUd1DZbGbS2VlVVVUWjUWJY4gVSDBaLZfv27UePHu3s7NTgvepUEEunViGESktLJUm1GIvF8vrrr2vqwoTuxdKvVcFgUJJFODZt2nTmzBktNLHRt1j6tQoh5PF4JMlDprS09NKlS1KLqgw6FkvXVgUCAZPJJEkbWt54443p6Wmp1ZWLXsXStVWCIDgcDkm2sFFcXKzWrUZdiqVrqxBCjY2NkjyRgsViUaXXiP7E0rtVH3/8sSRDpKNKjySdiaV3qz755BNJbsgl+W7pSSywSg4mk+ncuXOSCi8F3YiV5lZlZ2c/9thjhYWFFotFThCmdiZy0IdY6WmV0+lsamq6dOnS1NTUwmhTU1M9PT1er/epp55ijblhw4ZIJMJSe4noQKx0s8poNHo8HsqlJRAI1NbWMl0SS07/La2LlW5WlZeXS/gmSiAQoL8wlpz+W5oWK92sOnz4sOTbfLFY7ODBg5QTlZaW8u6/pV2x0soqo9HY0tIiqU7/w4cffkg5I+/+WxoVK92sOn36tKQ6JYDSLYfDodSMCdGiWGCVTCjPifIbFGLQnFhglXwo+2/V19crPvU82hILrFKKvr4+Yv+tnJwcmgYT0tCQWGCVstD03/rhhx84za4VscAqxbl8+TIxk2PHjnGaXRNigVUY4p+j3rFjR35+fkFBQXV1Nf3zVU888QQ+merqavpMmFBfLLAKQzQafeGFF/4dp7a2lubnEfFznvn5+fTJMKGyWGAVhmg0WlVVtVS0t956ixihs7OTmBWn5+LVFAuswoC3Kh6Q+Cl8mve2b9y4QZ8VPaqJBVZhIFoV5+TJk/g4oihmZmbig3C6TKqOWGAVBkqrDAbDkSNHiNGIH+Ln9P60CmKBVRjorTIYDCdOnCAGfOSRR/BBUkQssAoDk1U0ToiiSHwGMBVOhWAVBlarKioqiDHv3r1LjKP7H+9gFQZWq4qKisbHx4lhz58/TwzF6RH4JIkFVmHgZBVC6OjRo/hQ+r5AClZh4GcVQqi4uBgfbefOnfSpMsFdLLAKA1er+vv7iQG9Xi99tkzwFQuswsDVKoRQwpuMi+DXRoujWGAVBt5W9fb2EmPm5OTw6/3HSyywCgNvq6LR6JNPPkkMe+DAAfqYrHARC6zCwNsqhND+/ftpIvf39zOFZUJ5sQRB2LNnD33hWK0aHx8HqzBQtvguKytjCsuK8mIx9S5ntWpqaqqoqIg+Pli1FLy//aSwWNPT01lZWZT7xmqVIAgul4v+wIBVS8F7uVJerO7ubsp9k/CN+BMnTtAfGLBqKUwmE9dfV3EUFuvMmTM0+ybBqmvXrtE36wGrMLz99ttMwaWhwoolwSpBEMrKyigLB1Zh2LhxI1NxJKOwWJFIJCcnB7NjEqxCCPl8PsrCgVUYLBYLp4dk/o3yfxVi2k1Ls0oQBJvNRlM4sAqDyWTq7Oxkii8HLtex9u3b9+8dy8vLk2AVQujs2bM0hausrASrliJ12nH7fD6n0xnvS5GXl3fo0KFQKCQt1I4dO4iFKywsnJiYoI8JVvGG79MNsVgsHA7LiRAKhWj+GOzq6qKPCVYlAfVfscfT0tJCrJ3b7aYPCFYlB62LVVNTQ6zdyMgIZTSwKmloWixRFFetWoUvH32/FLAqmWharJGREWIFfT4fTSiwKsloWqxz584Ri0jzxyBYlXw0LdZnn32GL6LNZiMG+fTTT+mPClilFJoW68iRI/g6vvjii/gIXV1dxB6v84BVCqJpsYj9yg8ePIiP4HQ6KY8KWKUsmharvr4eX83GxkbM5jSdC+KAVYqjabGILwU0NDRgNqd5BQqs4oSmxXrnnXfwNa2trcVsPjw8TDwqYBUnNC1WU1MTvqxbtmzBbC6KYn5+PmZzsIofmhbryy+/xFc2MzNzZmYGE+HkyZNglSpoWqwrV64Q69vd3Y2JIIpiXV1dwqMCVnFF02LNzMwQm/7u378fH0QQhO
bm5oXnRKfTyfQNGbBKApoWCyFUUVGBr3J2dva9e/eIcURRHB0dvXLlCuvzhmCVNLQuFk3R+X1pCKySjNbFCgaDxHJnZWXdvn1b8anBKjloXSyEEM3HQquqqpRt9QRWyUQHYmEuGSzk3XffVWpGsEo+OhDr3r17+Jdg51Gko2YkEgGr5KMDsZiOxOHDh+WcE0Oh0NNPPw1WyUcfYk1PT+Nvzixk+/btxO+tJaS9vZ1+FrAKjz7EYmrfEP878b333qN/hXVoaIipCyFYRUQ3YiGEdu3axXTss7KyPB5PR0fHUq/eh0KhL774wuVy0T9lClZRYvx/uXTC33//7XQ6f//9d9YNMzMzS0pKbDbbmjVrMjIyYrHYn3/+OTo6eufOHQlpFBUVXbx4cd26dfSbHD9+nL5rnMlkam1tfe211yTkpiHUNpuN/v7+5cuXq1guWKso0ZlYCCG/30/f2k9ZwCp69CcWQqitrS35boFVTOhSrPi6lcxzYnFxMVjFhF7FQggNDAwUFhby1Ok/VFdXszZjSnOr9C0WQmhiYoLYjkYOZrP5o48+EkWRKSuwSvdixWlra1uzZo3iVj377LPDw8OsyYBVcVJBrPjzCF6v12q1KqKU3W73+/2sCxVYtZAUEStONBptbm622+3SfDIajTU1NR0dHRKUAqsWkVJizRMMBpuamiorK81mM/EYW63WPXv2tLS0/PXXX5JnBKsWoadbOhL4559/RkdHr1+/HgqFJiYmZmZmYrHYypUrLRbLo48+un79+s2bN69bt27ZsmVyZknHOzZE1DZb98BalRAQSxZg1VKAWNIBqzCAWBIBq/CAWFIAq4iAWMyAVTSAWGyAVZSAWAyAVfSAWLSAVUyAWFSAVayAWGTAKgmAWAQoW5KAVYtI8ZvQMvnpp59cLpcoijSD0+XuMh0g1pLMzs7a7faxsTGawWDVImQ9LpLanD59GqySDKxYS1JSUjI6OkocBlYlBFasxIyMjIBVcgCxEnPx4kXiGLAKA4iVGOLnqMEqPCBWYiYnJ/ED3G43WIUBxErM3NwcfsDq1auTlYsuAbESk5ubix/g8/kGBgaSlY7+ALESY7PZ8AOi0ejOnTvBraUAsRLz3HPPEcdEIhFwayngAmli5ubm1q5dGwqFiCMtFkt3d/e2bduSkpdugBUrMcuWLXvzzTdpRsK6lRBYsZYkHA7bbLZwOEwzGNatRcCKtSRWq7W5uZlyMKxbiwCxcHg8ngMHDlAOBrcWAqdCArOzsx6P5+uvv6YcD+fEOLBiEcjIyGhtbU34KfyEwLoVB8QiA25JAMSiAtxiBX5jMSDh91ZfX5/D4eCclxYBsdhgdcvpdA4ODnJOSovAqZAN1nNiMBgEsQAqWN0KBAKcM9IiIJYUmNwSBIF/RpoDxJIIvVslJSVJyUhbwI93WRB/yxcUFNy+fVvdr8KqAqxYssCvW0aj8fPPP09Dq0AsBYi71dDQsOjfs7OzW1tbX3rpJZXyUhk4FSrG4OCg3+8fGxt76KGHysvL9+3bt3btWrWTUg0QC+ACnAoBLoBYABdALIALIBbABRAL4AKIBXABxAK4AGIBXACxAC6AWAAXQCyACyAWwAUQC+ACiAVwAcQCuABiAVz4vwAAAP//b8cbMGXTzMEAAAAASUVORK5CYII="
 MCP_ICON_URL = f"data:image/png;base64,{MCP_ICON_SRC}"
 MCP_TOOL_TIMEOUT = 60
+WHITELISTED_MCP_TOOLS = {"SearchQBraid"}
 
 
 class MCPTool(Tool):
-    def __init__(self, server: 'MCPServer', name, description, schema, auto_approve=False):
+    def __init__(
+        self, server: "MCPServer", name, description, schema, auto_approve=False
+    ):
         super().__init__()
         self._server = server
         self._name = name
         self._description = description
         self._schema = schema
-        self._auto_approve = auto_approve
+        self._auto_approve = auto_approve or self._name in WHITELISTED_MCP_TOOLS
 
     @property
     def name(self) -> str:
@@ -37,11 +47,11 @@ def name(self) -> str:
     @property
     def title(self) -> str:
         return self._name
-    
+
     @property
     def tags(self) -> list[str]:
         return ["mcp-tool"]
-    
+
     @property
     def description(self) -> str:
         return self._description
@@ -54,22 +64,35 @@ def schema(self) -> dict:
                 "name": self._name,
                 "description": self._description,
                 "strict": False,
-                "parameters": self._schema
+                "parameters": self._schema,
             },
         }
-    
-    def pre_invoke(self, request: ChatRequest, tool_args: dict) -> Union[ToolPreInvokeResponse, None]:
+
+    def pre_invoke(
+        self, request: ChatRequest, tool_args: dict
+    ) -> Union[ToolPreInvokeResponse, None]:
         confirmationTitle = None
         confirmationMessage = None
         if not self._auto_approve:
             confirmationTitle = "Approve"
             confirmationMessage = "Are you sure you want to call this MCP tool?"
-        return ToolPreInvokeResponse(f"Calling MCP tool '{self.name}'", detail={"title": "Parameters", "content": json.dumps(tool_args)}, confirmationTitle=confirmationTitle, confirmationMessage=confirmationMessage)
-
-    async def handle_tool_call(self, request: ChatRequest, response: ChatResponse, tool_context: dict, tool_args: dict) -> str:
+        return ToolPreInvokeResponse(
+            f"Calling MCP tool '{self.name}'",
+            detail={"title": "Parameters", "content": json.dumps(tool_args)},
+            confirmationTitle=confirmationTitle,
+            confirmationMessage=confirmationMessage,
+        )
+
+    async def handle_tool_call(
+        self,
+        request: ChatRequest,
+        response: ChatResponse,
+        tool_context: dict,
+        tool_args: dict,
+    ) -> str:
         call_args = {}
 
-        for key in self._schema['properties']:
+        for key in self._schema["properties"]:
             if key in tool_args:
                 call_args[key] = tool_args.get(key)
 
@@ -80,7 +103,11 @@ async def handle_tool_call(self, request: ChatRequest, response: ChatResponse, t
                     text_contents = []
                     for content in result.content:
                         if type(content) is ImageContent:
-                            response.stream(ImageData(f"data:{content.mimeType};base64,{content.data}"))
+                            response.stream(
+                                ImageData(
+                                    f"data:{content.mimeType};base64,{content.data}"
+                                )
+                            )
                         elif type(content) is TextContent:
                             text_contents.append(content.text)
 
@@ -95,16 +122,26 @@ async def handle_tool_call(self, request: ChatRequest, response: ChatResponse, t
         except Exception as e:
             return f"Error occurred while calling MCP tool: {str(e)}"
 
+
 @dataclass
 class StreamableHttpServerParameters:
     url: str
     headers: dict[str, Any] | None = None
 
+
 class MCPServerImpl(MCPServer):
-    def __init__(self, name: str, stdio_params: StdioServerParameters = None, streamable_http_params: StreamableHttpServerParameters = None, auto_approve_tools: list[str] = []):
+    def __init__(
+        self,
+        name: str,
+        stdio_params: StdioServerParameters = None,
+        streamable_http_params: StreamableHttpServerParameters = None,
+        auto_approve_tools: list[str] = [],
+    ):
         self._name: str = name
         self._stdio_params: StdioServerParameters = stdio_params
-        self._streamable_http_params: StreamableHttpServerParameters = streamable_http_params
+        self._streamable_http_params: StreamableHttpServerParameters = (
+            streamable_http_params
+        )
         self._auto_approve_tools: set[str] = set(auto_approve_tools)
         self._tried_to_get_tool_list = False
         self._mcp_tools = []
@@ -117,20 +154,26 @@ def name(self) -> str:
 
     def _create_client(self) -> Client:
         if self._stdio_params is not None:
-            return Client(transport=StdioTransport(
-                command=self._stdio_params.command,
-                args=self._stdio_params.args,
-                env=self._stdio_params.env
-            ))
+            return Client(
+                transport=StdioTransport(
+                    command=self._stdio_params.command,
+                    args=self._stdio_params.args,
+                    env=self._stdio_params.env,
+                )
+            )
         elif self._streamable_http_params is not None:
-            return Client(transport=StreamableHttpTransport(
-                url=self._streamable_http_params.url,
-                headers=self._streamable_http_params.headers
-            ))
+            return Client(
+                transport=StreamableHttpTransport(
+                    url=self._streamable_http_params.url,
+                    headers=self._streamable_http_params.headers,
+                )
+            )
 
     async def get_client(self) -> Client:
         if self._stdio_params is None and self._streamable_http_params is None:
-            raise ValueError("Failed to create MCP client. Either stdio_params or sse_params must be provided")
+            raise ValueError(
+                "Failed to create MCP client. Either stdio_params or streamable_http_params must be provided"
+            )
         if self._client is None:
             self._client = self._create_client()
         else:
@@ -156,7 +199,16 @@ async def call_tool(self, tool_name: str, tool_args: dict):
 
     # TODO: optimize this
     def get_tools(self) -> list[Tool]:
-        return [MCPTool(self, tool.name, tool.description, tool.inputSchema, auto_approve=(tool.name in self._auto_approve_tools)) for tool in self._mcp_tools]
+        return [
+            MCPTool(
+                self,
+                tool.name,
+                tool.description,
+                tool.inputSchema,
+                auto_approve=(tool.name in self._auto_approve_tools),
+            )
+            for tool in self._mcp_tools
+        ]
 
     def get_tool(self, tool_name: str) -> Tool:
         for tool in self.get_tools():
@@ -164,8 +216,11 @@ def get_tool(self, tool_name: str) -> Tool:
                 return tool
         return None
 
+
 class MCPChatParticipant(BaseChatParticipant):
-    def __init__(self, id: str, name: str, servers: list[MCPServer], nbi_tools: list[str] = []):
+    def __init__(
+        self, id: str, name: str, servers: list[MCPServer], nbi_tools: list[str] = []
+    ):
         super().__init__()
         self._id = id
         self._name = name
@@ -184,7 +239,7 @@ def name(self) -> str:
     @property
     def description(self) -> str:
         return self._name
-    
+
     @property
     def icon_path(self) -> str:
         return MCP_ICON_URL
@@ -203,12 +258,14 @@ def tools(self) -> list[Tool]:
             if tool is not None:
                 mcp_tools.append(tool)
         return mcp_tools
-    
+
     @property
     def servers(self) -> list[MCPServer]:
         return self._servers
-    
-    async def handle_chat_request(self, request: ChatRequest, response: ChatResponse, options: dict = {}) -> None:
+
+    async def handle_chat_request(
+        self, request: ChatRequest, response: ChatResponse, options: dict = {}
+    ) -> None:
         response.stream(ProgressData("Thinking..."))
 
         if request.command == "info":
@@ -223,6 +280,7 @@ async def handle_chat_request(self, request: ChatRequest, response: ChatResponse
         else:
             await self.handle_chat_request_with_tools(request, response, options)
 
+
 class MCPManager:
     def __init__(self, mcp_config: dict):
         self.update_mcp_servers(mcp_config)
@@ -245,10 +303,21 @@ def update_mcp_servers(self, mcp_config):
             participant_servers = self.create_servers(server_names, servers_config)
 
             if len(participant_servers) > 0:
-                self._mcp_participants.append(MCPChatParticipant(f"mcp-{participant_id}", participant_name, participant_servers, nbi_tools))
+                self._mcp_participants.append(
+                    MCPChatParticipant(
+                        f"mcp-{participant_id}",
+                        participant_name,
+                        participant_servers,
+                        nbi_tools,
+                    )
+                )
                 self._mcp_servers += participant_servers
 
-        enabled_server_names = [server_name for server_name in servers_config.keys() if servers_config.get(server_name, {}).get("disabled", False) == False]
+        enabled_server_names = [
+            server_name
+            for server_name in servers_config.keys()
+            if servers_config.get(server_name, {}).get("disabled", False) == False
+        ]
         unused_server_names = set(enabled_server_names)
 
         for participant in self._mcp_participants:
@@ -260,7 +329,9 @@ def update_mcp_servers(self, mcp_config):
             unused_servers = self.create_servers(unused_server_names, servers_config)
             mcp_participant_config = participants_config.get("mcp", {})
             nbi_tools = mcp_participant_config.get("nbiTools", [])
-            self._mcp_participants.append(MCPChatParticipant("mcp", "MCP", unused_servers, nbi_tools))
+            self._mcp_participants.append(
+                MCPChatParticipant("mcp", "MCP", unused_servers, nbi_tools)
+            )
             self._mcp_servers += unused_servers
 
         thread = threading.Thread(target=self.init_tool_lists, args=())
@@ -271,11 +342,15 @@ def create_servers(self, server_names: list[str], servers_config: dict):
         for server_name in server_names:
             server_config = servers_config.get(server_name, None)
             if server_config is None:
-                log.error(f"Server '{server_name}' not found in MCP servers configuration")
+                log.error(
+                    f"Server '{server_name}' not found in MCP servers configuration"
+                )
                 continue
 
             if server_config.get("disabled", False) == True:
-                log.info(f"MCP Server '{server_name}' is disabled in MCP servers configuration. Skipping it.")
+                log.info(
+                    f"MCP Server '{server_name}' is disabled in MCP servers configuration. Skipping it."
+                )
                 continue
 
             mcp_server = self.create_mcp_server(server_name, server_config)
@@ -286,7 +361,7 @@ def create_servers(self, server_names: list[str], servers_config: dict):
             servers.append(mcp_server)
 
         return servers
-    
+
     def create_mcp_server(self, server_name: str, server_config: dict):
         auto_approve_tools = server_config.get("autoApprove", [])
 
@@ -299,19 +374,23 @@ def create_mcp_server(self, server_name: str, server_config: dict):
                 server_env = mcp_get_default_environment()
                 server_env.update(env)
 
-            return MCPServerImpl(server_name, stdio_params=StdioServerParameters(
-                command = command,
-                args = args,
-                env = server_env
-                ), auto_approve_tools = auto_approve_tools)
+            return MCPServerImpl(
+                server_name,
+                stdio_params=StdioServerParameters(
+                    command=command, args=args, env=server_env
+                ),
+                auto_approve_tools=auto_approve_tools,
+            )
         elif "url" in server_config:
             server_url = server_config["url"]
             headers = server_config.get("headers", None)
 
             return MCPServerImpl(
                 server_name,
-                streamable_http_params=StreamableHttpServerParameters(url=server_url, headers=headers),
-                auto_approve_tools=auto_approve_tools
+                streamable_http_params=StreamableHttpServerParameters(
+                    url=server_url, headers=headers
+                ),
+                auto_approve_tools=auto_approve_tools,
             )
 
         log.error(f"Invalid MCP server configuration for: {server_name}")
@@ -326,13 +405,13 @@ async def init_tool_lists_async(self):
                 await server.update_tool_list()
             except Exception as e:
                 log.error(f"Error initializing tool list for server {server.name}: {e}")
-    
+
     def init_tool_lists(self):
         asyncio.run(self.init_tool_lists_async())
 
     def get_mcp_servers(self):
         return self._mcp_servers
-    
+
     def get_mcp_server(self, server_name: str):
         for server in self._mcp_servers:
             if server.name == server_name:
diff --git a/lab_notebook_intelligence/prompts.py b/lab_notebook_intelligence/prompts.py
index a1acc0a..4741a25 100644
--- a/lab_notebook_intelligence/prompts.py
+++ b/lab_notebook_intelligence/prompts.py
@@ -39,11 +39,24 @@
 You can only give one reply for each conversation turn.
 """
 
+
 class Prompts:
     @staticmethod
     def generic_chat_prompt(model_provider: str, model_name: str) -> str:
-        return CHAT_SYSTEM_PROMPT.format(AI_ASSISTANT_NAME="Notebook Intelligence", IDE_NAME=IDE_NAME, OS_TYPE=OS_TYPE, MODEL_NAME=model_name, MODEL_PROVIDER=model_provider)
+        return CHAT_SYSTEM_PROMPT.format(
+            AI_ASSISTANT_NAME="Lab Notebook Intelligence",
+            IDE_NAME=IDE_NAME,
+            OS_TYPE=OS_TYPE,
+            MODEL_NAME=model_name,
+            MODEL_PROVIDER=model_provider,
+        )
 
     @staticmethod
     def github_copilot_chat_prompt(model_provider: str, model_name: str) -> str:
-        return CHAT_SYSTEM_PROMPT.format(AI_ASSISTANT_NAME="GitHub Copilot", IDE_NAME=IDE_NAME, OS_TYPE=OS_TYPE, MODEL_NAME=model_name, MODEL_PROVIDER=model_provider)
+        return CHAT_SYSTEM_PROMPT.format(
+            AI_ASSISTANT_NAME="GitHub Copilot",
+            IDE_NAME=IDE_NAME,
+            OS_TYPE=OS_TYPE,
+            MODEL_NAME=model_name,
+            MODEL_PROVIDER=model_provider,
+        )
diff --git a/lab_notebook_intelligence/util.py b/lab_notebook_intelligence/util.py
index b7664b9..58a09f9 100644
--- a/lab_notebook_intelligence/util.py
+++ b/lab_notebook_intelligence/util.py
@@ -1,39 +1,42 @@
 # Copyright (c) Mehmet Bektas 
 
-import os
+import asyncio
 import base64
+import os
+
+from cryptography.fernet import Fernet
 from cryptography.hazmat.primitives import hashes
 from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
-from cryptography.fernet import Fernet
-import asyncio
 from tornado import ioloop
 
+
 def extract_llm_generated_code(code: str) -> str:
-        if code.endswith("```"):
-            code = code[:-3]
+    if code.endswith("```"):
+        code = code[:-3]
+
+    lines = code.split("\n")
+    if len(lines) < 2:
+        return code
 
-        lines = code.split("\n")
-        if len(lines) < 2:
-            return code
+    num_lines = len(lines)
+    start_line = -1
+    end_line = num_lines
 
-        num_lines = len(lines)
-        start_line = -1
-        end_line = num_lines
+    for i in range(num_lines):
+        if start_line == -1:
+            if lines[i].lstrip().startswith("```"):
+                start_line = i
+                continue
+        else:
+            if lines[i].lstrip().startswith("```"):
+                end_line = i
+                break
 
-        for i in range(num_lines):
-            if start_line == -1:
-                if lines[i].lstrip().startswith("```"):
-                    start_line = i
-                    continue
-            else:
-                if lines[i].lstrip().startswith("```"):
-                    end_line = i
-                    break
+    if start_line != -1:
+        lines = lines[start_line + 1 : end_line]
 
-        if start_line != -1:
-            lines = lines[start_line+1:end_line]
+    return "\n".join(lines)
 
-        return "\n".join(lines)
 
 def encrypt_with_password(password: str, data: bytes) -> bytes:
     salt = os.urandom(16)
@@ -49,6 +52,7 @@ def encrypt_with_password(password: str, data: bytes) -> bytes:
 
     return salt + encrypted_data
 
+
 def decrypt_with_password(password: str, encrypted_data_with_salt: bytes) -> bytes:
     salt = encrypted_data_with_salt[:16]
     encrypted_data = encrypted_data_with_salt[16:]
@@ -64,13 +68,15 @@ def decrypt_with_password(password: str, encrypted_data_with_salt: bytes) -> byt
 
     return decrypted_data
 
-class ThreadSafeWebSocketConnector():
-  def __init__(self, websocket_handler):
-    self.io_loop = ioloop.IOLoop.current()
-    self.websocket_handler = websocket_handler
 
-  def write_message(self, message: dict):
-    def _write_message():
-        self.websocket_handler.write_message(message)
-    asyncio.set_event_loop(self.io_loop.asyncio_loop)
-    self.io_loop.asyncio_loop.call_soon_threadsafe(_write_message)
+class ThreadSafeWebSocketConnector:
+    def __init__(self, websocket_handler):
+        self.io_loop = ioloop.IOLoop.current()
+        self.websocket_handler = websocket_handler
+
+    def write_message(self, message: dict):
+        def _write_message():
+            self.websocket_handler.write_message(message)
+
+        asyncio.set_event_loop(self.io_loop.asyncio_loop)
+        self.io_loop.asyncio_loop.call_soon_threadsafe(_write_message)
diff --git a/src/index.ts b/src/index.ts
index 693e0b7..1573f45 100644
--- a/src/index.ts
+++ b/src/index.ts
@@ -327,8 +327,19 @@ class ActiveDocumentWatcher {
 class NBIInlineCompletionProvider
   implements IInlineCompletionProvider
 {
+  private _NBCellContextState: {
+    activeCellIndex: number;
+    preContent: string;
+    postContent: string;
+  };
+
   constructor(telemetryEmitter: TelemetryEmitter) {
     this._telemetryEmitter = telemetryEmitter;
+    this._NBCellContextState = {
+      activeCellIndex: -1,
+      preContent: '',
+      postContent: ''
+    };
   }
 
   get schema(): ISettingRegistry.IProperty {
@@ -340,6 +351,60 @@ class NBIInlineCompletionProvider
     };
   }
 
+  private calculateCellContext(context: IInlineCompletionContext): {
+    preContent: string;
+    postContent: string;
+  } {
+    if (!(context.widget instanceof NotebookPanel)) {
+      return { preContent: '', postContent: '' };
+    }
+
+    const notebook = context.widget.content;
+    const activeCellIndex = notebook.activeCellIndex;
+    // Check if the active cell has changed
+    if (this._NBCellContextState.activeCellIndex === activeCellIndex) {
+      // Return cached content if the active cell hasn't changed
+      return {
+        preContent: this._NBCellContextState.preContent,
+        postContent: this._NBCellContextState.postContent
+      };
+    }
+
+    // Recalculate pre and post content
+    let preContent = '';
+    let postContent = '';
+    // const activeCell = notebook.activeCell;
+    let activeCellReached = false;
+
+    for (let i = 0; i < notebook.widgets.length; i++) {
+      const cell = notebook.widgets[i];
+      const cellModel = cell.model.sharedModel;
+
+      if (i === activeCellIndex) {
+        activeCellReached = true;
+      } else if (!activeCellReached) {
+        if (cellModel.cell_type === 'code') {
+          preContent += cellModel.source + '\n';
+        } else if (cellModel.cell_type === 'markdown') {
+          preContent += markdownToComment(cellModel.source) + '\n';
+        }
+      } else {
+        if (cellModel.cell_type === 'code') {
+          postContent += cellModel.source + '\n';
+        } else if (cellModel.cell_type === 'markdown') {
+          postContent += markdownToComment(cellModel.source) + '\n';
+        }
+      }
+    }
+
+    // Update the state
+    this._NBCellContextState.activeCellIndex = activeCellIndex;
+    this._NBCellContextState.preContent = preContent;
+    this._NBCellContextState.postContent = postContent;
+
+    return { preContent, postContent };
+  }
+
   fetch(
     request: CompletionHandler.IRequest,
     context: IInlineCompletionContext
@@ -348,36 +413,13 @@ class NBIInlineCompletionProvider
     let postContent = '';
     const preCursor = request.text.substring(0, request.offset);
     const postCursor = request.text.substring(request.offset);
-    let language = ActiveDocumentWatcher.activeDocumentInfo.language;
+    const language = ActiveDocumentWatcher.activeDocumentInfo.language;
 
     let editorType = 'file-editor';
 
     if (context.widget instanceof NotebookPanel) {
       editorType = 'notebook';
-      const activeCell = context.widget.content.activeCell;
-      if (activeCell.model.sharedModel.cell_type === 'markdown') {
-        language = 'markdown';
-      }
-      let activeCellReached = false;
-
-      for (const cell of context.widget.content.widgets) {
-        const cellModel = cell.model.sharedModel;
-        if (cell === activeCell) {
-          activeCellReached = true;
-        } else if (!activeCellReached) {
-          if (cellModel.cell_type === 'code') {
-            preContent += cellModel.source + '\n';
-          } else if (cellModel.cell_type === 'markdown') {
-            preContent += markdownToComment(cellModel.source) + '\n';
-          }
-        } else {
-          if (cellModel.cell_type === 'code') {
-            postContent += cellModel.source + '\n';
-          } else if (cellModel.cell_type === 'markdown') {
-            postContent += markdownToComment(cellModel.source) + '\n';
-          }
-        }
-      }
+      ({ preContent, postContent } = this.calculateCellContext(context));
     }
 
     const nbiConfig = NBIAPI.config;

From 43f4f60dbf08862bbf33c80c23643c71e86ca7ee Mon Sep 17 00:00:00 2001
From: TheGupta2012 
Date: Wed, 24 Sep 2025 15:20:33 +0530
Subject: [PATCH 3/7] add jlpm deps

---
 .github/workflows/format.yml | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml
index 581daf5..8f77a52 100644
--- a/.github/workflows/format.yml
+++ b/.github/workflows/format.yml
@@ -18,6 +18,9 @@ jobs:
 
       - name: Base Setup
         uses: jupyterlab/maintainer-tools/.github/actions/base-setup@v1
+
+      - name: Install dependencies
+        run: python -m pip install -U "jupyterlab>=4.0.0,<5"
 
       - name: Lint the extension
         run: |

From 17c44872dedb877854c03640627279cccb935d01 Mon Sep 17 00:00:00 2001
From: TheGupta2012 
Date: Wed, 24 Sep 2025 15:23:03 +0530
Subject: [PATCH 4/7] formatting python

---
 .../ai_service_manager.py                     | 57 ++++++++++++-------
 .../base_chat_participant.py                  | 14 +++--
 lab_notebook_intelligence/extension.py        | 22 ++++---
 lab_notebook_intelligence/github_copilot.py   | 18 ++++--
 .../github_copilot_llm_provider.py            | 21 ++++---
 .../litellm_compatible_llm_provider.py        | 15 +++--
 .../llm_providers/ollama_llm_provider.py      | 13 +++--
 .../openai_compatible_llm_provider.py         | 15 +++--
 lab_notebook_intelligence/mcp_manager.py      | 20 ++++---
 9 files changed, 129 insertions(+), 66 deletions(-)

diff --git a/lab_notebook_intelligence/ai_service_manager.py b/lab_notebook_intelligence/ai_service_manager.py
index be8722b..4431e21 100644
--- a/lab_notebook_intelligence/ai_service_manager.py
+++ b/lab_notebook_intelligence/ai_service_manager.py
@@ -8,29 +8,44 @@
 from typing import Dict
 
 from lab_notebook_intelligence import github_copilot
-from lab_notebook_intelligence.api import (ButtonData, ChatModel,
-                                           ChatParticipant, ChatRequest,
-                                           ChatResponse, CompletionContext,
-                                           CompletionContextProvider,
-                                           ContextRequest, EmbeddingModel,
-                                           Host, InlineCompletionModel,
-                                           LLMProvider, MarkdownData,
-                                           MCPServer,
-                                           NotebookIntelligenceExtension,
-                                           TelemetryEvent, TelemetryListener,
-                                           Tool, Toolset)
+from lab_notebook_intelligence.api import (
+    ButtonData,
+    ChatModel,
+    ChatParticipant,
+    ChatRequest,
+    ChatResponse,
+    CompletionContext,
+    CompletionContextProvider,
+    ContextRequest,
+    EmbeddingModel,
+    Host,
+    InlineCompletionModel,
+    LLMProvider,
+    MarkdownData,
+    MCPServer,
+    NotebookIntelligenceExtension,
+    TelemetryEvent,
+    TelemetryListener,
+    Tool,
+    Toolset,
+)
 from lab_notebook_intelligence.base_chat_participant import BaseChatParticipant
 from lab_notebook_intelligence.config import NBIConfig
-from lab_notebook_intelligence.github_copilot_chat_participant import \
-    GithubCopilotChatParticipant
-from lab_notebook_intelligence.llm_providers.github_copilot_llm_provider import \
-    GitHubCopilotLLMProvider
-from lab_notebook_intelligence.llm_providers.litellm_compatible_llm_provider import \
-    LiteLLMCompatibleLLMProvider
-from lab_notebook_intelligence.llm_providers.ollama_llm_provider import \
-    OllamaLLMProvider
-from lab_notebook_intelligence.llm_providers.openai_compatible_llm_provider import \
-    OpenAICompatibleLLMProvider
+from lab_notebook_intelligence.github_copilot_chat_participant import (
+    GithubCopilotChatParticipant,
+)
+from lab_notebook_intelligence.llm_providers.github_copilot_llm_provider import (
+    GitHubCopilotLLMProvider,
+)
+from lab_notebook_intelligence.llm_providers.litellm_compatible_llm_provider import (
+    LiteLLMCompatibleLLMProvider,
+)
+from lab_notebook_intelligence.llm_providers.ollama_llm_provider import (
+    OllamaLLMProvider,
+)
+from lab_notebook_intelligence.llm_providers.openai_compatible_llm_provider import (
+    OpenAICompatibleLLMProvider,
+)
 from lab_notebook_intelligence.mcp_manager import MCPManager
 
 log = logging.getLogger(__name__)
diff --git a/lab_notebook_intelligence/base_chat_participant.py b/lab_notebook_intelligence/base_chat_participant.py
index 2334d15..96e5f3f 100644
--- a/lab_notebook_intelligence/base_chat_participant.py
+++ b/lab_notebook_intelligence/base_chat_participant.py
@@ -6,10 +6,16 @@
 import os
 from typing import Union
 
-from lab_notebook_intelligence.api import (ChatCommand, ChatParticipant,
-                                           ChatRequest, ChatResponse,
-                                           MarkdownData, ProgressData, Tool,
-                                           ToolPreInvokeResponse)
+from lab_notebook_intelligence.api import (
+    ChatCommand,
+    ChatParticipant,
+    ChatRequest,
+    ChatResponse,
+    MarkdownData,
+    ProgressData,
+    Tool,
+    ToolPreInvokeResponse,
+)
 from lab_notebook_intelligence.built_in_toolsets import built_in_toolsets
 from lab_notebook_intelligence.prompts import Prompts
 from lab_notebook_intelligence.util import extract_llm_generated_code
diff --git a/lab_notebook_intelligence/extension.py b/lab_notebook_intelligence/extension.py
index 59750cf..d51b5cd 100644
--- a/lab_notebook_intelligence/extension.py
+++ b/lab_notebook_intelligence/extension.py
@@ -21,13 +21,21 @@
 
 import lab_notebook_intelligence.github_copilot as github_copilot
 from lab_notebook_intelligence.ai_service_manager import AIServiceManager
-from lab_notebook_intelligence.api import (BackendMessageType, BuiltinToolset,
-                                           CancelToken, ChatMode, ChatRequest,
-                                           ChatResponse, ContextRequest,
-                                           ContextRequestType, RequestDataType,
-                                           RequestToolSelection,
-                                           ResponseStreamData,
-                                           ResponseStreamDataType, SignalImpl)
+from lab_notebook_intelligence.api import (
+    BackendMessageType,
+    BuiltinToolset,
+    CancelToken,
+    ChatMode,
+    ChatRequest,
+    ChatResponse,
+    ContextRequest,
+    ContextRequestType,
+    RequestDataType,
+    RequestToolSelection,
+    ResponseStreamData,
+    ResponseStreamDataType,
+    SignalImpl,
+)
 from lab_notebook_intelligence.built_in_toolsets import built_in_toolsets
 from lab_notebook_intelligence.util import ThreadSafeWebSocketConnector
 
diff --git a/lab_notebook_intelligence/github_copilot.py b/lab_notebook_intelligence/github_copilot.py
index 1490b5b..41013ea 100644
--- a/lab_notebook_intelligence/github_copilot.py
+++ b/lab_notebook_intelligence/github_copilot.py
@@ -17,12 +17,18 @@
 import requests
 import sseclient
 
-from lab_notebook_intelligence.api import (BackendMessageType, CancelToken,
-                                           ChatResponse, CompletionContext,
-                                           MarkdownData)
-from lab_notebook_intelligence.util import (ThreadSafeWebSocketConnector,
-                                            decrypt_with_password,
-                                            encrypt_with_password)
+from lab_notebook_intelligence.api import (
+    BackendMessageType,
+    CancelToken,
+    ChatResponse,
+    CompletionContext,
+    MarkdownData,
+)
+from lab_notebook_intelligence.util import (
+    ThreadSafeWebSocketConnector,
+    decrypt_with_password,
+    encrypt_with_password,
+)
 
 from ._version import __version__ as NBI_VERSION
 
diff --git a/lab_notebook_intelligence/llm_providers/github_copilot_llm_provider.py b/lab_notebook_intelligence/llm_providers/github_copilot_llm_provider.py
index 8386aa0..c7f69f3 100644
--- a/lab_notebook_intelligence/llm_providers/github_copilot_llm_provider.py
+++ b/lab_notebook_intelligence/llm_providers/github_copilot_llm_provider.py
@@ -3,13 +3,20 @@
 import logging
 from typing import Any
 
-from lab_notebook_intelligence.api import (CancelToken, ChatModel,
-                                           ChatResponse, CompletionContext,
-                                           EmbeddingModel,
-                                           InlineCompletionModel, LLMProvider)
-from lab_notebook_intelligence.github_copilot import (completions,
-                                                      generate_copilot_headers,
-                                                      inline_completions)
+from lab_notebook_intelligence.api import (
+    CancelToken,
+    ChatModel,
+    ChatResponse,
+    CompletionContext,
+    EmbeddingModel,
+    InlineCompletionModel,
+    LLMProvider,
+)
+from lab_notebook_intelligence.github_copilot import (
+    completions,
+    generate_copilot_headers,
+    inline_completions,
+)
 
 log = logging.getLogger(__name__)
 
diff --git a/lab_notebook_intelligence/llm_providers/litellm_compatible_llm_provider.py b/lab_notebook_intelligence/llm_providers/litellm_compatible_llm_provider.py
index b8e4572..f01bdd3 100644
--- a/lab_notebook_intelligence/llm_providers/litellm_compatible_llm_provider.py
+++ b/lab_notebook_intelligence/llm_providers/litellm_compatible_llm_provider.py
@@ -5,11 +5,16 @@
 
 import litellm
 
-from lab_notebook_intelligence.api import (CancelToken, ChatModel,
-                                           ChatResponse, CompletionContext,
-                                           EmbeddingModel,
-                                           InlineCompletionModel, LLMProvider,
-                                           LLMProviderProperty)
+from lab_notebook_intelligence.api import (
+    CancelToken,
+    ChatModel,
+    ChatResponse,
+    CompletionContext,
+    EmbeddingModel,
+    InlineCompletionModel,
+    LLMProvider,
+    LLMProviderProperty,
+)
 
 DEFAULT_CONTEXT_WINDOW = 4096
 
diff --git a/lab_notebook_intelligence/llm_providers/ollama_llm_provider.py b/lab_notebook_intelligence/llm_providers/ollama_llm_provider.py
index 139f455..0dbb101 100644
--- a/lab_notebook_intelligence/llm_providers/ollama_llm_provider.py
+++ b/lab_notebook_intelligence/llm_providers/ollama_llm_provider.py
@@ -6,10 +6,15 @@
 
 import ollama
 
-from lab_notebook_intelligence.api import (CancelToken, ChatModel,
-                                           ChatResponse, CompletionContext,
-                                           EmbeddingModel,
-                                           InlineCompletionModel, LLMProvider)
+from lab_notebook_intelligence.api import (
+    CancelToken,
+    ChatModel,
+    ChatResponse,
+    CompletionContext,
+    EmbeddingModel,
+    InlineCompletionModel,
+    LLMProvider,
+)
 from lab_notebook_intelligence.util import extract_llm_generated_code
 
 log = logging.getLogger(__name__)
diff --git a/lab_notebook_intelligence/llm_providers/openai_compatible_llm_provider.py b/lab_notebook_intelligence/llm_providers/openai_compatible_llm_provider.py
index 9917a55..6a58654 100644
--- a/lab_notebook_intelligence/llm_providers/openai_compatible_llm_provider.py
+++ b/lab_notebook_intelligence/llm_providers/openai_compatible_llm_provider.py
@@ -5,11 +5,16 @@
 
 from openai import OpenAI
 
-from lab_notebook_intelligence.api import (CancelToken, ChatModel,
-                                           ChatResponse, CompletionContext,
-                                           EmbeddingModel,
-                                           InlineCompletionModel, LLMProvider,
-                                           LLMProviderProperty)
+from lab_notebook_intelligence.api import (
+    CancelToken,
+    ChatModel,
+    ChatResponse,
+    CompletionContext,
+    EmbeddingModel,
+    InlineCompletionModel,
+    LLMProvider,
+    LLMProviderProperty,
+)
 
 DEFAULT_CONTEXT_WINDOW = 4096
 
diff --git a/lab_notebook_intelligence/mcp_manager.py b/lab_notebook_intelligence/mcp_manager.py
index e8c84d5..9bbc7e1 100644
--- a/lab_notebook_intelligence/mcp_manager.py
+++ b/lab_notebook_intelligence/mcp_manager.py
@@ -10,15 +10,21 @@
 from fastmcp import Client
 from fastmcp.client import StdioTransport, StreamableHttpTransport
 from mcp import StdioServerParameters
-from mcp.client.stdio import \
-    get_default_environment as mcp_get_default_environment
+from mcp.client.stdio import get_default_environment as mcp_get_default_environment
 from mcp.types import ImageContent, TextContent
 
-from lab_notebook_intelligence.api import (ChatCommand, ChatRequest,
-                                           ChatResponse, HTMLFrameData,
-                                           ImageData, MarkdownData, MCPServer,
-                                           ProgressData, Tool,
-                                           ToolPreInvokeResponse)
+from lab_notebook_intelligence.api import (
+    ChatCommand,
+    ChatRequest,
+    ChatResponse,
+    HTMLFrameData,
+    ImageData,
+    MarkdownData,
+    MCPServer,
+    ProgressData,
+    Tool,
+    ToolPreInvokeResponse,
+)
 from lab_notebook_intelligence.base_chat_participant import BaseChatParticipant
 
 log = logging.getLogger(__name__)

From 833d3a76d60aff7b54f799eaeeb6cbdc3798b81e Mon Sep 17 00:00:00 2001
From: TheGupta2012 
Date: Fri, 26 Sep 2025 13:16:47 +0530
Subject: [PATCH 5/7] add context files and improve UI

---
 .github/workflows/format.yml           |   2 +-
 lab_notebook_intelligence/extension.py |  31 ++--
 src/api.ts                             |   2 +-
 src/chat-sidebar.tsx                   | 124 ++++++++++++---
 src/context.tsx                        | 102 +++++++++++++
 src/index.ts                           |  10 ++
 src/tokens.ts                          |   3 +-
 src/utils.ts                           |   2 -
 style/base.css                         | 204 ++++++++++++++++++++++++-
 9 files changed, 436 insertions(+), 44 deletions(-)
 create mode 100644 src/context.tsx

diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml
index 8f77a52..4faf93b 100644
--- a/.github/workflows/format.yml
+++ b/.github/workflows/format.yml
@@ -36,7 +36,7 @@ jobs:
       - name: Install Python dependencies
         run: |
           python -m pip install --upgrade pip
-          pip install black isort 
+          pip install black "isort>=6.0.0"
 
       - name: Check Python code formatting
         run: |
diff --git a/lab_notebook_intelligence/extension.py b/lab_notebook_intelligence/extension.py
index d51b5cd..9f3d489 100644
--- a/lab_notebook_intelligence/extension.py
+++ b/lab_notebook_intelligence/extension.py
@@ -694,42 +694,41 @@ def on_message(self, message):
 
             for context in additionalContext:
                 file_path = context["filePath"]
+
                 file_path = path.join(NotebookIntelligence.root_dir, file_path)
                 filename = path.basename(file_path)
-                start_line = context["startLine"]
-                end_line = context["endLine"]
-                current_cell_contents = context["currentCellContents"]
-                current_cell_input = (
-                    current_cell_contents["input"]
-                    if current_cell_contents is not None
-                    else ""
-                )
-                current_cell_output = (
-                    current_cell_contents["output"]
-                    if current_cell_contents is not None
-                    else ""
-                )
+                start_line = context.get("startLine", 0)
+                end_line = context.get("endLine", 0)
+                current_cell_contents = context.get("currentCellContents", {})
+                current_cell_input = current_cell_contents.get("input", "")
+                current_cell_output = current_cell_contents.get("output", "")
+
                 current_cell_context = (
                     f"This is a Jupyter notebook and currently selected cell input is: ```{current_cell_input}``` and currently selected cell output is: ```{current_cell_output}```. If user asks a question about 'this' cell then assume that user is referring to currently selected cell."
                     if current_cell_contents is not None
                     else ""
                 )
-                context_content = context["content"]
+                context_content = context.get("content", "")
                 token_count = len(tiktoken_encoding.encode(context_content))
                 if token_count > token_budget:
                     context_content = context_content[: int(token_budget)] + "..."
+                msg_content = f"Use this as additional context: ```{context_content}```. It is from current file: '{filename}' at path '{file_path}'"
+                if start_line >= 0 and end_line > 0:
+                    msg_content += f", lines: {start_line} - {end_line}."
+                if current_cell_context != "":
+                    msg_content += f" {current_cell_context}"
 
                 request_chat_history.append(
                     {
                         "role": "user",
-                        "content": f"Use this as additional context: ```{context_content}```. It is from current file: '{filename}' at path '{file_path}', lines: {start_line} - {end_line}. {current_cell_context}",
+                        "content": msg_content,
                     }
                 )
                 self.chat_history.add_message(
                     chatId,
                     {
                         "role": "user",
-                        "content": f"This file was provided as additional context: '{filename}' at path '{file_path}', lines: {start_line} - {end_line}. {current_cell_context}",
+                        "content": f"This file was provided as additional context: '{filename}' at path '{file_path}'. {current_cell_context}",
                     },
                 )
 
diff --git a/src/api.ts b/src/api.ts
index b743987..84a91fb 100644
--- a/src/api.ts
+++ b/src/api.ts
@@ -129,7 +129,7 @@ export class NBIAPI {
     };
 
     this._webSocket.onclose = msg => {
-      console.log(`Websocket is closed: ${msg.reason}. Reconnecting...`);
+      console.log(`Websocket is closed: ${msg}. Reconnecting...`);
       setTimeout(() => {
         NBIAPI.initializeWebsocket();
       }, 1000);
diff --git a/src/chat-sidebar.tsx b/src/chat-sidebar.tsx
index 64e7932..38b371b 100644
--- a/src/chat-sidebar.tsx
+++ b/src/chat-sidebar.tsx
@@ -33,6 +33,7 @@ import {
   TelemetryEventType
 } from './tokens';
 import { JupyterFrontEnd } from '@jupyterlab/application';
+import { FileBrowserDropdown } from './context';
 import { MarkdownRenderer as OriginalMarkdownRenderer } from './markdown-renderer';
 const MarkdownRenderer = memo(OriginalMarkdownRenderer);
 
@@ -50,7 +51,8 @@ import {
   VscSettingsGear,
   VscPassFilled,
   VscTools,
-  VscTrash
+  VscTrash,
+  VscAttach
 } from 'react-icons/vsc';
 
 import { MdOutlineCheckBoxOutlineBlank, MdCheckBox } from 'react-icons/md';
@@ -96,6 +98,7 @@ export interface IChatSidebarOptions {
   openFile: (path: string) => void;
   getApp: () => JupyterFrontEnd;
   getTelemetryEmitter: () => ITelemetryEmitter;
+  getFileContent: (filePath: string) => Promise<string | null>;
 }
 
 export class ChatSidebar extends ReactWidget {
@@ -115,6 +118,7 @@ export class ChatSidebar extends ReactWidget {
         openFile={this._options.openFile}
         getApp={this._options.getApp}
         getTelemetryEmitter={this._options.getTelemetryEmitter}
+        getFileContent={this._options.getFileContent}
       />
     );
   }
@@ -744,6 +748,26 @@ function SidebarComponent(props: any) {
   const [currentChatModel, setCurrentChatModel] = useState(
     NBIAPI.config.chatModel.model
   );
+  const [showFileBrowser, setShowFileBrowser] = useState(false);
+  const [selectedFiles, setSelectedFiles] = useState([]);
+
+  const selectContextFiles = () => {
+    setShowFileBrowser(true);
+  };
+
+  const handleFilesSelected = (filePaths: string[]) => {
+    // append the new files to the existing selected files
+    const finalFiles = [...selectedFiles, ...filePaths];
+    const uniqueFiles = Array.from(new Set(finalFiles));
+    setSelectedFiles(uniqueFiles);
+    setShowFileBrowser(false);
+    console.log(`Selected files: ${filePaths}`);
+  };
+
+  const handleContextCancel = () => {
+    setSelectedFiles([]);
+    setShowFileBrowser(false);
+  };
 
   // Load available models when config changes
   useEffect(() => {
@@ -1306,6 +1330,7 @@ function SidebarComponent(props: any) {
     const app = props.getApp();
     const additionalContext: IContextItem[] = [];
     if (contextOn && activeDocumentInfo?.filename) {
+      // TODO: update for selected files
       const selection = activeDocumentInfo.selection;
       const textSelected =
         selection &&
@@ -1326,6 +1351,20 @@ function SidebarComponent(props: any) {
       });
     }
 
+    if (selectedFiles.length > 0) {
+      // parse through each selected file and add to context
+      for (const filePath of selectedFiles) {
+        additionalContext.push({
+          type: ContextType.File,
+          content: await props.getFileContent(filePath),
+          filePath: filePath,
+          currentCellContents: undefined
+        });
+      }
+    }
+
+    console.log('Complete context : ', additionalContext);
+
     submitCompletionRequest(
       {
         messageId: lastMessageId.current,
@@ -1870,30 +1909,71 @@ function SidebarComponent(props: any) {
             spellCheck={false}
             value={prompt}
           />
-          {activeDocumentInfo?.filename && (
-            
-
+ {activeDocumentInfo?.filename && ( +
+
+
{currentFileContextTitle}
+ {contextOn ? ( +
setContextOn(!contextOn)} + > + +
+ ) : ( +
setContextOn(!contextOn)} + > + +
+ )} +
+
+ )} +
+ +
+ {showFileBrowser && ( +
+
+ )} +
+ {selectedFiles.length > 0 && ( +
+
    + {selectedFiles.map(file => ( +
  • + {file} + +
  • + ))} +
+
+ )}
- )} +
{chatMode === 'ask' && (
diff --git a/src/context.tsx b/src/context.tsx
new file mode 100644
index 0000000..ca8ff26
--- /dev/null
+++ b/src/context.tsx
@@ -0,0 +1,102 @@
+import React, { useEffect, useState } from 'react';
+import {
+  VscArrowLeft,
+  VscFile,
+  VscFolder,
+  VscArrowRight
+} from 'react-icons/vsc';
+
+import { ContentsManager } from '@jupyterlab/services';
+
+const contentsManager = new ContentsManager();
+
+export function FileBrowserDropdown(props: {
+  onFilesSelected: (filePaths: string[]) => void;
+  onCancelSelection: () => void;
+}) {
+  const [currentPath, setCurrentPath] = useState(null);
+  const [parentPaths, setParentPaths] = useState([]); // Stack of parent paths
+  const [files, setFiles] = useState([]);
+  const [selectedFiles, setSelectedFiles] = useState<Set<string>>(new Set());
+
+  useEffect(() => {
+    // Fetch files when the component mounts or the path changes
+    const fetchFiles = async (path: string = '') => {
+      const response = await contentsManager.get(path);
+      setFiles(response.content); // Returns an array of files and directories
+    };
+    fetchFiles(currentPath?.path);
+  }, [currentPath]);
+
+  const handleFileClick = (file: any) => {
+    if (file.type === 'directory') {
+      // Navigate into the directory
+      setParentPaths([...parentPaths, currentPath]); // Append currentPath to parentPaths
+      setCurrentPath(file); // Set the clicked directory as the current path
+    } else {
+      // Toggle file selection
+      const updatedSelectedFiles = new Set(selectedFiles);
+      if (updatedSelectedFiles.has(file.path)) {
+        updatedSelectedFiles.delete(file.path);
+      } else {
+        updatedSelectedFiles.add(file.path);
+      }
+      setSelectedFiles(updatedSelectedFiles);
+    }
+  };
+
+  const handleBackClick = () => {
+    if (parentPaths.length > 0) {
+      const newParentPaths = [...parentPaths];
+      const lastParentPath = newParentPaths.pop(); // Remove the last element
+      setParentPaths(newParentPaths); // Update the parentPaths stack
+      setCurrentPath(lastParentPath); // Set the last parent path as the current path
+    }
+  };
+
+  const handleConfirmSelection = () => {
+    props.onFilesSelected(Array.from(selectedFiles));
+  };
+
+  const handleCancelSelection = () => {
+    // Clear selection and close dropdown
+    setSelectedFiles(new Set());
+    props.onCancelSelection();
+    // You might want to add a prop to control the visibility of this dropdown
+    // and call a function here to close it.
+  };
+
+  return (
+
+ + {currentPath?.path || '/'} +
+
    + {files.map(file => ( +
  • handleFileClick(file)} + > + {file.type === 'directory' ? ( + + ) : ( + + )} + {file.name} +
  • + ))} +
+ {parentPaths.length > 0 && ( + + )} + + +
+  );
+}
diff --git a/src/index.ts b/src/index.ts
index 1573f45..5b4cf94 100644
--- a/src/index.ts
+++ b/src/index.ts
@@ -789,6 +789,16 @@ const plugin: JupyterFrontEndPlugin = {
       },
       getTelemetryEmitter(): ITelemetryEmitter {
         return telemetryEmitter;
+      },
+      getFileContent: async (filepath: string): Promise<string | null> => {
+        try {
+          const contentManager = new ContentsManager();
+          const file = await contentManager.get(filepath); // Fetch file content
+          return JSON.stringify(file.content); // Return the file content
+        } catch (error) {
+          console.error('Failed to get file content:', error);
+          return null; // Return null if an error occurs
+        }
       }
     });
     panel.addWidget(sidebar);
diff --git a/src/tokens.ts b/src/tokens.ts
index acfb6f6..2c6400a 100644
--- a/src/tokens.ts
+++ b/src/tokens.ts
@@ -49,7 +49,8 @@ export enum ResponseStreamDataType {
 
 export enum ContextType {
   Custom = 'custom',
-  CurrentFile = 'current-file'
+  CurrentFile = 'current-file',
+  File = 'file'
 }
 
 export interface IContextItem {
diff --git a/src/utils.ts b/src/utils.ts
index c5fd57b..131b031 100644
--- a/src/utils.ts
+++ b/src/utils.ts
@@ -1,5 +1,3 @@
-// Copyright (c) Mehmet Bektas
-
 import { CodeCell } from '@jupyterlab/cells';
 import { PartialJSONObject } from '@lumino/coreutils';
 import { CodeEditor } from '@jupyterlab/codeeditor';
diff --git a/style/base.css b/style/base.css
index b2b7a2a..63e42c7 100644
--- a/style/base.css
+++ b/style/base.css
@@ -8,7 +8,7 @@
 .sidebar {
   display: flex;
   flex-direction: column;
-  min-width: 25%;
+  min-width: 40%;
   height: 100%;
 }
 
@@ -77,6 +77,12 @@
   max-height: 200px;
 }
 
+.user-input-context-and-add-context {
+  display: flex; /* Align children horizontally */
+  align-items: center; /* Vertically align items */
+  gap: 10px; /* Add spacing between the two elements */
+}
+
 .sidebar-user-input .user-input-context-row {
   display: flex;
   flex-direction: row;
@@ -87,6 +93,7 @@
   display: flex;
   align-items: center;
   border: 1px solid var(--jp-border-color1);
+  margin: 5px;
   border-radius: 4px;
   padding: 2px;
   gap: 5px;
@@ -119,6 +126,35 @@
   gap: 10px;
 }
 
+.sidebar-add-context {
+  margin: 10px 0;
+  display: flex;
+  align-items: center;
+  justify-content: flex-start;
+  border-radius: 8px;
+  flex-shrink: 0;
+}
+
+.add-context-button {
+  display: flex;
+  align-items: center;
+  background-color: var(--jp-layout-color1);
+  border: 1px solid var(--jp-border-color2);
+  padding: 5px 10px;
+  cursor: pointer;
+  font-size: 12px;
+  color: var(--jp-ui-font-color2);
+  transition: background-color 0.2s;
+}
+
+.add-context-button:hover {
+  background-color: var(--jp-layout-color2);
+}
+
+.add-context-button svg {
+  margin-right: 5px;
+}
+
 .chat-mode-widgets-container {
   display: flex;
   flex-direction: row;
@@ -871,3 +907,169 @@ svg.access-token-warning {
   text-overflow: ellipsis;
   white-space: nowrap;
 }
+
+.file-browser-path {
+  display: flex; /* Use flexbox for alignment */
+  align-items: center; /* Vertically center the arrow and text */
+  gap: 8px; /* Add spacing between the arrow and the text */
+  padding: 4px 8px;
+  background-color: var(--jp-layout-color2); /* Subtle background color */
+  border: 1px solid var(--jp-border-color1); /* Border for better visibility */
+  border-radius: 4px; /* Rounded corners */
+  font-size: 12px; /* Adjust font size */
+  color: var(--jp-ui-font-color1); /* Dynamic text color */
+  font-weight: bold; /* Make the text stand out */
+}
+
+.file-browser-path .current-path {
+  overflow: hidden; /* Prevent overflow */
+  text-overflow: ellipsis; /* Add ellipsis for long paths */
+  white-space: nowrap; /* Prevent wrapping */
+}
+
+.file-browser-dropdown button {
+  background-color: var(--jp-brand-color1);
+  color: white;
+  border: none;
+  border-radius: 4px;
+  padding: 8px 12px;
+  margin: 5px;
+  cursor: pointer;
+  font-size: 12px;
+  transition: background-color 0.2s ease;
+}
+
+.file-browser-dropdown button:hover {
+  background-color: var(--jp-brand-color2);
+}
+
+.file-browser-dropdown button:active {
+  background-color: var(--jp-brand-color3);
+}
+
+.file-browser-dropdown button:disabled {
+  background-color: #ccc;
+  cursor: not-allowed;
+}
+
+.file-browser-dropdown-container {
+  position: fixed; /* Fixed positioning to align relative to the viewport */
+  top: 50px; /* Adjust this value to place it below the JupyterLab header */
+  left: 30%; /* Center horizontally */
+  transform: translateX(-50%); /* Offset by 50% of its width to center */
+  padding: 5%;
+  width: 60%;
+  margin-left: 20%;
+  margin-bottom: 20%;
+  max-height: 200px;
+  overflow-y: auto;
+  background-color: var(--jp-layout-color1);
+  border: 1px solid var(--jp-border-color1);
+  border-radius: 4px;
+  z-index: 1000;
+  box-shadow: 0 4px 6px rgba(0, 0, 0, 10%);
+}
+
+.file-browser-dropdown-container ul {
+  list-style: none;
+  padding: 0;
+  margin: 0;
+}
+
+.file-browser-dropdown-container li {
+  display: flex;
+  align-items: center;
+  gap: 8px;
+  padding: 8px 12px;
+  cursor: pointer;
+  border-bottom: 1px solid var(--jp-border-color2);
+  transition: background-color 0.2s ease;
+}
+
+.file-browser-dropdown-container li:hover {
+  background-color: var(--jp-layout-color2);
+}
+
+.file-browser-dropdown-container li:last-child {
+  border-bottom: none;
+}
+
+.file-browser-dropdown-container li.directory::before {
+  font-size: 16px;
+  color: var(--jp-ui-font-color2);
+}
+
+.file-browser-dropdown-container .file-browser-item {
+  padding: 8px 12px;
+  cursor: pointer;
+  border-bottom: 1px solid var(--jp-border-color2);
+  transition: background-color 0.2s ease;
+}
+
+.file-browser-dropdown-container .file-browser-item:hover {
+  background-color: var(--jp-layout-color2);
+}
+
+.file-browser-dropdown-container .file-browser-item.selected {
+  background-color: var(--jp-brand-color1);
+  color: white;
+}
+
+.file-browser-dropdown-container .file-browser-item:last-child {
+  border-bottom: none;
+}
+
+.selected-files-container {
+  margin: 5px;
+  padding: 5px;
+  border: 1px solid var(--jp-border-color1); /* Dynamic border color */
+  background-color: var(--jp-layout-color1); /* Dynamic background color */
+  border-radius: 5px;
+  display: flex;
+  flex-wrap: wrap; /* Allow wrapping if files exceed container width */
+  gap: 10px; /* Add spacing between items */
+}
+
+.selected-files-list {
+  display: flex; /* Use flexbox for horizontal alignment */
+  flex-wrap: wrap; /* Allow wrapping to the next line */
+  list-style: none; /* Remove bullets */
+  padding: 0;
+  margin: 0;
+  gap: 10px; /* Add spacing between items */
+}
+
+.selected-file-item {
+  display: flex;
+  align-items: center;
+  justify-content: space-between;
+  padding: 3px;
+  border: 1px solid var(--jp-border-color2); /* Dynamic border color */
+  border-radius: 3px;
+  background-color: var(--jp-layout-color2); /* Dynamic background color */
+  white-space: nowrap; /* Prevent text from wrapping */
+}
+
+.selected-file-name {
+  font-size: 12px;
+  color: var(--jp-ui-font-color1); /* Dynamic text color */
+  overflow: hidden;
+  text-overflow: ellipsis;
+  white-space: nowrap;
+  margin-right: 5px;
+}
+
+.remove-file-button {
+  background: none;
+  border: none;
+  color: var(--jp-warn-color0); /* Dynamic warning color */
+  font-size: 12px;
+  cursor: pointer;
+  padding: 0 5px;
+  transition: color 0.2s;
+}
+
+.remove-file-button:hover {
+  color: var(--jp-warn-color1); /* Dynamic hover warning color */
+  background-color: var(--jp-layout-color1);
+}
From 6d11b760712fa8279d4130666430d5071886c72d Mon Sep 17 00:00:00 2001
From: TheGupta2012
Date: Fri, 26 Sep 2025 14:58:23 +0530
Subject: [PATCH 6/7] fix css format issue

---
 style/base.css | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/style/base.css b/style/base.css
index 63e42c7..a557e45 100644
--- a/style/base.css
+++ b/style/base.css
@@ -967,7 +967,7 @@ svg.access-token-warning {
   border: 1px solid var(--jp-border-color1);
   border-radius: 4px;
   z-index: 1000;
-  box-shadow: 0 4px 6px rgba(0, 0, 0, 10%);
+  box-shadow: 0 4px 6px rgb(0 0 0 / 10%);
 }
 
 .file-browser-dropdown-container ul {
From f6d117e9242ef52ecacfb84b91cace4295eb8229 Mon Sep 17 00:00:00 2001
From: TheGupta2012
Date: Fri, 26 Sep 2025 15:02:25 +0530
Subject: [PATCH 7/7] fix black and isort deps

---
 lab_notebook_intelligence/__init__.py              |  4 +-
 .../ai_service_manager.py                          | 64 +++++--------------
 lab_notebook_intelligence/api.py                   | 28 ++------
 .../base_chat_participant.py                       | 32 +++-------
 .../built_in_toolsets.py                           |  4 +-
 lab_notebook_intelligence/config.py                |  4 +-
 lab_notebook_intelligence/extension.py             | 44 ++++---------
 lab_notebook_intelligence/github_copilot.py        | 54 +++++-----------
 .../github_copilot_llm_provider.py                 | 20 ++----
 .../litellm_compatible_llm_provider.py             |  4 +-
 .../llm_providers/ollama_llm_provider.py           | 16 ++---
 .../openai_compatible_llm_provider.py              |  4 +-
 lab_notebook_intelligence/mcp_manager.py           | 24 ++-----
 pyproject.toml                                     | 22 +++++++
 14 files changed, 99 insertions(+), 225 deletions(-)

diff --git a/lab_notebook_intelligence/__init__.py b/lab_notebook_intelligence/__init__.py
index 4776804..d295b08 100644
--- a/lab_notebook_intelligence/__init__.py
+++ b/lab_notebook_intelligence/__init__.py
@@ -8,9 +8,7 @@
     # the package from a stable release or in editable mode: https://pip.pypa.io/en/stable/topics/local-project-installs/#editable-installs
     import warnings
 
-    warnings.warn(
-        "Importing 'lab_notebook_intelligence' outside a proper installation."
- ) + warnings.warn("Importing 'lab_notebook_intelligence' outside a proper installation.") __version__ = "dev" import logging diff --git a/lab_notebook_intelligence/ai_service_manager.py b/lab_notebook_intelligence/ai_service_manager.py index 4431e21..c4c285c 100644 --- a/lab_notebook_intelligence/ai_service_manager.py +++ b/lab_notebook_intelligence/ai_service_manager.py @@ -113,9 +113,7 @@ def __init__(self, options: dict = {}): self.telemetry_listeners: Dict[str, TelemetryListener] = {} self._extension_toolsets: Dict[str, list[Toolset]] = {} self._options = options.copy() - self._nbi_config = NBIConfig( - {"server_root_dir": self._options.get("server_root_dir", "")} - ) + self._nbi_config = NBIConfig({"server_root_dir": self._options.get("server_root_dir", "")}) self._openai_compatible_llm_provider = OpenAICompatibleLLMProvider() self._litellm_compatible_llm_provider = LiteLLMCompatibleLLMProvider() self._ollama_llm_provider = OllamaLLMProvider() @@ -149,9 +147,7 @@ def update_models_from_config(self): github_copilot.login_with_existing_credentials( self._nbi_config.store_github_access_token ) - github_copilot.enable_github_login_status_change_updater( - using_github_copilot_service - ) + github_copilot.enable_github_login_status_change_updater(using_github_copilot_service) chat_model_cfg = self.nbi_config.chat_model chat_model_provider_id = chat_model_cfg.get("provider", "none") @@ -164,17 +160,13 @@ def update_models_from_config(self): ) inline_completion_model_cfg = self.nbi_config.inline_completion_model - inline_completion_model_provider_id = inline_completion_model_cfg.get( - "provider", "none" - ) + inline_completion_model_provider_id = inline_completion_model_cfg.get("provider", "none") inline_completion_model_id = inline_completion_model_cfg.get("model", "none") inline_completion_model_provider = self.get_llm_provider( inline_completion_model_provider_id ) self._inline_completion_model = ( - inline_completion_model_provider.get_inline_completion_model( - inline_completion_model_id - ) + inline_completion_model_provider.get_inline_completion_model(inline_completion_model_id) if inline_completion_model_provider is not None else None ) @@ -188,13 +180,9 @@ def update_models_from_config(self): if self._inline_completion_model is not None: properties = inline_completion_model_cfg.get("properties", []) for property in properties: - self._inline_completion_model.set_property_value( - property["id"], property["value"] - ) + self._inline_completion_model.set_property_value(property["id"], property["value"]) - is_github_copilot_chat_model = isinstance( - chat_model_provider, GitHubCopilotLLMProvider - ) + is_github_copilot_chat_model = isinstance(chat_model_provider, GitHubCopilotLLMProvider) default_chat_participant = ( GithubCopilotChatParticipant() if is_github_copilot_chat_model @@ -202,9 +190,7 @@ def update_models_from_config(self): ) self._default_chat_participant = default_chat_participant - self.chat_participants[DEFAULT_CHAT_PARTICIPANT_ID] = ( - self._default_chat_participant - ) + self.chat_participants[DEFAULT_CHAT_PARTICIPANT_ID] = self._default_chat_participant def update_mcp_servers(self): self._mcp_manager.update_mcp_servers(self.nbi_config.mcp) @@ -266,21 +252,15 @@ def register_llm_provider(self, provider: LLMProvider) -> None: return self.llm_providers[provider.id] = provider - def register_completion_context_provider( - self, provider: CompletionContextProvider - ) -> None: + def register_completion_context_provider(self, provider: CompletionContextProvider) -> 
None: if provider.id in self.completion_context_providers: - log.error( - f"Completion Context Provider ID '{provider.id}' is already in use!" - ) + log.error(f"Completion Context Provider ID '{provider.id}' is already in use!") return self.completion_context_providers[provider.id] = provider def register_telemetry_listener(self, listener: TelemetryListener) -> None: if listener.name in self.telemetry_listeners: - log.error( - f"Notebook Intelligence telemetry listener '{listener.name}' already exists!" - ) + log.error(f"Notebook Intelligence telemetry listener '{listener.name}' already exists!") return log.warning( f"Notebook Intelligence telemetry listener '{listener.name}' registered. Make sure it is from a trusted source." @@ -289,9 +269,7 @@ def register_telemetry_listener(self, listener: TelemetryListener) -> None: def register_toolset(self, toolset: Toolset) -> None: if toolset.provider is None: - log.error( - f"Toolset '{toolset.id}' has no provider! It cannot be registered." - ) + log.error(f"Toolset '{toolset.id}' has no provider! It cannot be registered.") return provider_id = toolset.provider.id if provider_id not in self._extension_toolsets: @@ -445,27 +423,19 @@ async def handle_chat_request( if self.chat_model is None: response.stream(MarkdownData("Chat model is not set!")) response.stream( - ButtonData( - "Configure", "lab-notebook-intelligence:open-configuration-dialog" - ) + ButtonData("Configure", "lab-notebook-intelligence:open-configuration-dialog") ) response.finish() return request.host = self - (participant_id, command, prompt) = AIServiceManager.parse_prompt( - request.prompt - ) - participant = self.chat_participants.get( - participant_id, DEFAULT_CHAT_PARTICIPANT_ID - ) + (participant_id, command, prompt) = AIServiceManager.parse_prompt(request.prompt) + participant = self.chat_participants.get(participant_id, DEFAULT_CHAT_PARTICIPANT_ID) request.command = command request.prompt = prompt response.participant_id = participant_id return await participant.handle_chat_request(request, response, options) - async def get_completion_context( - self, request: ContextRequest - ) -> CompletionContext: + async def get_completion_context(self, request: ContextRequest) -> CompletionContext: cancel_token = request.cancel_token context = CompletionContext([]) @@ -525,9 +495,7 @@ def get_extension_toolset(self, extension_id: str, toolset_id: str) -> Toolset: return None - def get_extension_tool( - self, extension_id: str, toolset_id: str, tool_name: str - ) -> Tool: + def get_extension_tool(self, extension_id: str, toolset_id: str, tool_name: str) -> Tool: if extension_id not in self._extension_toolsets: return None extension_toolsets = self._extension_toolsets[extension_id] diff --git a/lab_notebook_intelligence/api.py b/lab_notebook_intelligence/api.py index 58d699b..6a839ad 100644 --- a/lab_notebook_intelligence/api.py +++ b/lab_notebook_intelligence/api.py @@ -284,9 +284,7 @@ def on_run_ui_command_response(self, data: dict) -> None: self._run_ui_command_response_signal.emit(data) @staticmethod - async def wait_for_run_ui_command_response( - response: "ChatResponse", callback_id: str - ): + async def wait_for_run_ui_command_response(response: "ChatResponse", callback_id: str): resp = {"result": None} def _on_ui_command_response(data: dict): @@ -297,9 +295,7 @@ def _on_ui_command_response(data: dict): while True: if resp["result"] is not None: - response.run_ui_command_response_signal.disconnect( - _on_ui_command_response - ) + 
response.run_ui_command_response_signal.disconnect(_on_ui_command_response) return resp["result"] await asyncio.sleep(0.1) @@ -598,9 +594,7 @@ async def _tool_call_loop(tool_call_rounds: list): tool_call_rounds.append(tool_call) elif choice["message"].get("content", None) is not None: response.stream( - MarkdownData( - tool_response["choices"][0]["message"]["content"] - ) + MarkdownData(tool_response["choices"][0]["message"]["content"]) ) messages.append(choice["message"]) @@ -641,9 +635,7 @@ async def _tool_call_loop(tool_call_rounds: list): else: args = fuzzy_json_loads(tool_call["function"]["arguments"]) - tool_properties = tool_to_call.schema["function"]["parameters"][ - "properties" - ] + tool_properties = tool_to_call.schema["function"]["parameters"]["properties"] if type(args) is str: if ( len(tool_properties) == 1 @@ -737,9 +729,7 @@ class CompletionContextProvider: def id(self) -> str: raise NotImplemented - def handle_completion_context_request( - self, request: ContextRequest - ) -> CompletionContext: + def handle_completion_context_request(self, request: ContextRequest) -> CompletionContext: raise NotImplemented @@ -913,9 +903,7 @@ def register_llm_provider(self, provider: LLMProvider) -> None: def register_chat_participant(self, participant: ChatParticipant) -> None: raise NotImplemented - def register_completion_context_provider( - self, provider: CompletionContextProvider - ) -> None: + def register_completion_context_provider(self, provider: CompletionContextProvider) -> None: raise NotImplemented def register_telemetry_listener(self, listener: TelemetryListener) -> None: @@ -953,9 +941,7 @@ def get_mcp_server_tool(self, server_name: str, tool_name: str) -> Tool: def get_extension_toolset(self, extension_id: str, toolset_id: str) -> Toolset: return NotImplemented - def get_extension_tool( - self, extension_id: str, toolset_id: str, tool_name: str - ) -> Tool: + def get_extension_tool(self, extension_id: str, toolset_id: str, tool_name: str) -> Tool: return NotImplemented diff --git a/lab_notebook_intelligence/base_chat_participant.py b/lab_notebook_intelligence/base_chat_participant.py index 96e5f3f..fdeeaea 100644 --- a/lab_notebook_intelligence/base_chat_participant.py +++ b/lab_notebook_intelligence/base_chat_participant.py @@ -70,9 +70,7 @@ async def handle_tool_call( tool_context: dict, tool_args: dict, ) -> str: - return await self._ext_tool.handle_tool_call( - request, response, tool_context, tool_args - ) + return await self._ext_tool.handle_tool_call(request, response, tool_context, tool_args) class CreateNewNotebookTool(Tool): @@ -94,9 +92,7 @@ def tags(self) -> list[str]: @property def description(self) -> str: - return ( - "This tool creates a new notebook with the provided code and markdown cells" - ) + return "This tool creates a new notebook with the provided code and markdown cells" @property def schema(self) -> dict: @@ -443,9 +439,7 @@ def tools(self) -> list[Tool]: for ext_id, ext_toolsets in tool_selection.extension_tools.items(): for toolset_id, toolset_tools in ext_toolsets.items(): for tool_name in toolset_tools: - ext_tool = host.get_extension_tool( - ext_id, toolset_id, tool_name - ) + ext_tool = host.get_extension_tool(ext_id, toolset_id, tool_name) if ext_tool is not None: tool_list.append(SecuredExtensionTool(ext_tool)) return tool_list @@ -469,9 +463,7 @@ async def generate_code_cell(self, request: ChatRequest) -> str: "content": f"You are an assistant that creates Python code which will be used in a Jupyter notebook. 
Generate only Python code and some comments for the code. You should return the code directly, without wrapping it inside ```.", }, ) - messages.append( - {"role": "user", "content": f"Generate code for: {request.prompt}"} - ) + messages.append({"role": "user", "content": f"Generate code for: {request.prompt}"}) generated = chat_model.completions(messages) code = generated["choices"][0]["message"]["content"] @@ -520,9 +512,7 @@ async def handle_chat_request( toolsets, ) in request.tool_selection.extension_tools.items(): for toolset_id in toolsets.keys(): - ext_toolset = request.host.get_extension_toolset( - extension_id, toolset_id - ) + ext_toolset = request.host.get_extension_toolset(extension_id, toolset_id) if ext_toolset is not None and ext_toolset.instructions is not None: system_prompt += ext_toolset.instructions + "\n" @@ -560,9 +550,7 @@ async def handle_ask_mode_chat_request( {"code": code, "path": file_path}, ) - response.stream( - MarkdownData(f"Notebook '{file_path}' created and opened successfully") - ) + response.stream(MarkdownData(f"Notebook '{file_path}' created and opened successfully")) response.finish() return elif request.command == "newPythonFile": @@ -576,9 +564,7 @@ async def handle_ask_mode_chat_request( "content": f"You are an assistant that creates Python code. You should return the code directly, without wrapping it inside ```.", }, ) - messages.append( - {"role": "user", "content": f"Generate code for: {request.prompt}"} - ) + messages.append({"role": "user", "content": f"Generate code for: {request.prompt}"}) generated = chat_model.completions(messages) code = generated["choices"][0]["message"]["content"] code = extract_llm_generated_code(code) @@ -610,9 +596,7 @@ async def handle_ask_mode_chat_request( try: if chat_model.provider.id != "github-copilot": response.stream(ProgressData("Thinking...")) - chat_model.completions( - messages, response=response, cancel_token=request.cancel_token - ) + chat_model.completions(messages, response=response, cancel_token=request.cancel_token) except Exception as e: log.error(f"Error while handling chat request!\n{e}") response.stream( diff --git a/lab_notebook_intelligence/built_in_toolsets.py b/lab_notebook_intelligence/built_in_toolsets.py index afdb990..f17a03a 100644 --- a/lab_notebook_intelligence/built_in_toolsets.py +++ b/lab_notebook_intelligence/built_in_toolsets.py @@ -112,9 +112,7 @@ async def get_cell_output(cell_index: int, **args) -> str: @nbapi.auto_approve @nbapi.tool -async def set_cell_type_and_source( - cell_index: int, cell_type: str, source: str, **args -) -> str: +async def set_cell_type_and_source(cell_index: int, cell_type: str, source: str, **args) -> str: """Set cell type and source for the cell at index for the active notebook. 
Args: diff --git a/lab_notebook_intelligence/config.py b/lab_notebook_intelligence/config.py index a031a3e..4e931e6 100644 --- a/lab_notebook_intelligence/config.py +++ b/lab_notebook_intelligence/config.py @@ -110,9 +110,7 @@ def default_chat_mode(self): @property def chat_model(self): - return self.get( - "chat_model", {"provider": "github-copilot", "model": "gpt-4.1"} - ) + return self.get("chat_model", {"provider": "github-copilot", "model": "gpt-4.1"}) @property def inline_completion_model(self): diff --git a/lab_notebook_intelligence/extension.py b/lab_notebook_intelligence/extension.py index 9f3d489..77911bd 100644 --- a/lab_notebook_intelligence/extension.py +++ b/lab_notebook_intelligence/extension.py @@ -59,8 +59,7 @@ def get(self): allowed_builtin_toolsets = [ {"id": toolset.id, "name": toolset.name} for toolset in built_in_toolsets.values() - if toolset.id != BuiltinToolset.NotebookExecute - or notebook_execute_tool_enabled + if toolset.id != BuiltinToolset.NotebookExecute or notebook_execute_tool_enabled ] mcp_servers = ai_service_manager.get_mcp_servers() mcp_server_tools = [ @@ -100,9 +99,7 @@ def get(self): # sort by toolset name ts.sort(key=lambda toolset: toolset["name"]) extension = ai_service_manager.get_extension(extension_id) - extensions.append( - {"id": extension_id, "name": extension.name, "toolsets": ts} - ) + extensions.append({"id": extension_id, "name": extension.name, "toolsets": ts}) # sort by extension id extensions.sort(key=lambda extension: extension["id"]) @@ -185,8 +182,7 @@ def post(self): json.dumps( { "mcpServers": [ - {"id": server.name} - for server in ai_service_manager.get_mcp_servers() + {"id": server.name} for server in ai_service_manager.get_mcp_servers() ] } ) @@ -344,11 +340,7 @@ def post(self): if device_verification_info is None: self.set_status(500) self.finish( - json.dumps( - { - "error": "Failed to get device verification info from GitHub Copilot" - } - ) + json.dumps({"error": "Failed to get device verification info from GitHub Copilot"}) ) return self.finish(json.dumps(device_verification_info)) @@ -428,9 +420,7 @@ def message_id(self) -> str: return self.messageId def stream(self, data: Union[ResponseStreamData, dict]): - data_type = ( - ResponseStreamDataType.LLMRaw if type(data) is dict else data.data_type - ) + data_type = ResponseStreamDataType.LLMRaw if type(data) is dict else data.data_type if data_type == ResponseStreamDataType.Markdown: self.chat_history.add_message( @@ -538,14 +528,10 @@ def stream(self, data: Union[ResponseStreamData, dict]): "title": data.title, "message": data.message, "confirmArgs": ( - data.confirmArgs - if data.confirmArgs is not None - else {} + data.confirmArgs if data.confirmArgs is not None else {} ), "cancelArgs": ( - data.cancelArgs - if data.cancelArgs is not None - else {} + data.cancelArgs if data.cancelArgs is not None else {} ), "confirmLabel": ( data.confirmLabel @@ -626,9 +612,7 @@ async def run_ui_command(self, command: str, args: dict = {}) -> None: }, } ) - response = await ChatResponse.wait_for_run_ui_command_response( - self, callback_id - ) + response = await ChatResponse.wait_for_run_ui_command_response(self, callback_id) return response @@ -943,9 +927,7 @@ def initialize_settings(self): def initialize_handlers(self): NotebookIntelligence.root_dir = self.serverapp.root_dir - server_root_dir = os.path.expanduser( - self.serverapp.web_app.settings["server_root_dir"] - ) + server_root_dir = os.path.expanduser(self.serverapp.web_app.settings["server_root_dir"]) 
self.initialize_ai_service(server_root_dir) self._setup_handlers(self.serverapp.web_app) self.serverapp.log.info(f"Registered {self.name} server extension") @@ -968,9 +950,7 @@ def _setup_handlers(self, web_app): route_pattern_capabilities = url_path_join( base_url, "lab-notebook-intelligence", "capabilities" ) - route_pattern_config = url_path_join( - base_url, "lab-notebook-intelligence", "config" - ) + route_pattern_config = url_path_join(base_url, "lab-notebook-intelligence", "config") route_pattern_update_provider_models = url_path_join( base_url, "lab-notebook-intelligence", "update-provider-models" ) @@ -995,9 +975,7 @@ def _setup_handlers(self, web_app): route_pattern_github_logout = url_path_join( base_url, "lab-notebook-intelligence", "gh-logout" ) - route_pattern_copilot = url_path_join( - base_url, "lab-notebook-intelligence", "copilot" - ) + route_pattern_copilot = url_path_join(base_url, "lab-notebook-intelligence", "copilot") GetCapabilitiesHandler.notebook_execute_tool = self.notebook_execute_tool NotebookIntelligence.handlers = [ (route_pattern_capabilities, GetCapabilitiesHandler), diff --git a/lab_notebook_intelligence/github_copilot.py b/lab_notebook_intelligence/github_copilot.py index 41013ea..bbedf43 100644 --- a/lab_notebook_intelligence/github_copilot.py +++ b/lab_notebook_intelligence/github_copilot.py @@ -39,9 +39,7 @@ "https://github.com" if GHE_SUBDOMAIN == "" else f"https://{GHE_SUBDOMAIN}.ghe.com" ) GH_REST_API_BASE_URL = ( - "https://api.github.com" - if GHE_SUBDOMAIN == "" - else f"https://api.{GHE_SUBDOMAIN}.ghe.com" + "https://api.github.com" if GHE_SUBDOMAIN == "" else f"https://api.{GHE_SUBDOMAIN}.ghe.com" ) EDITOR_VERSION = f"LabNotebookIntelligence/{NBI_VERSION}" @@ -58,9 +56,7 @@ TOKEN_FETCH_INTERVAL = 15 NL = "\n" -LoginStatus = Enum( - "LoginStatus", ["NOT_LOGGED_IN", "ACTIVATING_DEVICE", "LOGGING_IN", "LOGGED_IN"] -) +LoginStatus = Enum("LoginStatus", ["NOT_LOGGED_IN", "ACTIVATING_DEVICE", "LOGGING_IN", "LOGGED_IN"]) github_auth = { "verification_uri": None, @@ -113,15 +109,9 @@ def get_login_status(): return response -deprecated_user_data_file = os.path.join( - os.path.expanduser("~"), ".jupyter", "nbi-data.json" -) -user_data_file = os.path.join( - os.path.expanduser("~"), ".jupyter", "nbi", "user-data.json" -) -access_token_password = os.getenv( - "NBI_GH_ACCESS_TOKEN_PASSWORD", "nbi-access-token-password" -) +deprecated_user_data_file = os.path.join(os.path.expanduser("~"), ".jupyter", "nbi-data.json") +user_data_file = os.path.join(os.path.expanduser("~"), ".jupyter", "nbi", "user-data.json") +access_token_password = os.getenv("NBI_GH_ACCESS_TOKEN_PASSWORD", "nbi-access-token-password") def read_stored_github_access_token() -> str: @@ -139,9 +129,7 @@ def read_stored_github_access_token() -> str: if base64_access_token is not None: base64_bytes = base64.b64decode(base64_access_token.encode("utf-8")) - return decrypt_with_password(access_token_password, base64_bytes).decode( - "utf-8" - ) + return decrypt_with_password(access_token_password, base64_bytes).decode("utf-8") except Exception as e: log.error(f"Failed to read GitHub access token: {e}") @@ -150,9 +138,7 @@ def read_stored_github_access_token() -> str: def write_github_access_token(access_token: str) -> bool: try: - encrypted_access_token = encrypt_with_password( - access_token_password, access_token.encode() - ) + encrypted_access_token = encrypt_with_password(access_token_password, access_token.encode()) base64_bytes = base64.b64encode(encrypted_access_token) base64_access_token = 
base64_bytes.decode("utf-8") @@ -409,12 +395,9 @@ def get_token_thread_func(): # update token if 10 seconds or less left to expiration if github_auth["access_token"] is not None and ( token is None - or (dt.datetime.now() - github_auth["token_expires_at"]).total_seconds() - > -10 + or (dt.datetime.now() - github_auth["token_expires_at"]).total_seconds() > -10 ): - if ( - dt.datetime.now() - last_token_fetch_time - ).total_seconds() > TOKEN_FETCH_INTERVAL: + if (dt.datetime.now() - last_token_fetch_time).total_seconds() > TOKEN_FETCH_INTERVAL: log.info("Refreshing GitHub token") get_token() last_token_fetch_time = dt.datetime.now() @@ -425,9 +408,7 @@ def get_token_thread_func(): def wait_for_tokens(): global get_access_code_thread, get_token_thread if get_access_code_thread is None: - get_access_code_thread = threading.Thread( - target=wait_for_user_access_token_thread_func - ) + get_access_code_thread = threading.Thread(target=wait_for_user_access_token_thread_func) get_access_code_thread.start() if get_token_thread is None: @@ -532,19 +513,14 @@ def _aggregate_streaming_response(client: sseclient.SSEClient) -> dict: def _format_llm_response(): for tool_call in final_tool_calls: - if ( - "arguments" in tool_call["function"] - and tool_call["function"]["arguments"] == "" - ): + if "arguments" in tool_call["function"] and tool_call["function"]["arguments"] == "": tool_call["function"]["arguments"] = "{}" return { "choices": [ { "message": { - "tool_calls": ( - final_tool_calls if len(final_tool_calls) > 0 else None - ), + "tool_calls": (final_tool_calls if len(final_tool_calls) > 0 else None), "content": final_content, "role": "assistant", } @@ -577,9 +553,9 @@ def _format_llm_response(): final_tool_calls.append(tc) else: if "arguments" in tool_call["function"]: - final_tool_calls[index]["function"]["arguments"] += tool_call[ - "function" - ]["arguments"] + final_tool_calls[index]["function"]["arguments"] += tool_call["function"][ + "arguments" + ] return _format_llm_response() diff --git a/lab_notebook_intelligence/llm_providers/github_copilot_llm_provider.py b/lab_notebook_intelligence/llm_providers/github_copilot_llm_provider.py index c7f69f3..a7ef2cf 100644 --- a/lab_notebook_intelligence/llm_providers/github_copilot_llm_provider.py +++ b/lab_notebook_intelligence/llm_providers/github_copilot_llm_provider.py @@ -62,9 +62,7 @@ def completions( cancel_token: CancelToken = None, options: dict = {}, ) -> Any: - return completions( - self._model_id, messages, tools, response, cancel_token, options - ) + return completions(self._model_id, messages, tools, response, cancel_token, options) class GitHubCopilotInlineCompletionModel(InlineCompletionModel): @@ -107,18 +105,10 @@ def __init__(self): GitHubCopilotChatModel(self, "gpt-4o", "GPT-4o", 128000, True), GitHubCopilotChatModel(self, "o3-mini", "o3-mini", 200000, True), GitHubCopilotChatModel(self, "gpt-5", "GPT-5", 128000, True), - GitHubCopilotChatModel( - self, "claude-sonnet-4", "Claude Sonnet 4", 80000, True - ), - GitHubCopilotChatModel( - self, "claude-3.7-sonnet", "Claude 3.7 Sonnet", 200000, True - ), - GitHubCopilotChatModel( - self, "claude-3.5-sonnet", "Claude 3.5 Sonnet", 90000, True - ), - GitHubCopilotChatModel( - self, "gemini-2.5-pro", "Gemini 2.5 Pro", 128000, True - ), + GitHubCopilotChatModel(self, "claude-sonnet-4", "Claude Sonnet 4", 80000, True), + GitHubCopilotChatModel(self, "claude-3.7-sonnet", "Claude 3.7 Sonnet", 200000, True), + GitHubCopilotChatModel(self, "claude-3.5-sonnet", "Claude 3.5 Sonnet", 90000, True), + 
GitHubCopilotChatModel(self, "gemini-2.5-pro", "Gemini 2.5 Pro", 128000, True), GitHubCopilotChatModel( self, "gemini-2.0-flash-001", "Gemini 2.0 Flash", 1000000, False ), diff --git a/lab_notebook_intelligence/llm_providers/litellm_compatible_llm_provider.py b/lab_notebook_intelligence/llm_providers/litellm_compatible_llm_provider.py index f01bdd3..aabd9e6 100644 --- a/lab_notebook_intelligence/llm_providers/litellm_compatible_llm_provider.py +++ b/lab_notebook_intelligence/llm_providers/litellm_compatible_llm_provider.py @@ -24,9 +24,7 @@ def __init__(self, provider: "LiteLLMCompatibleLLMProvider"): super().__init__(provider) self._provider = provider self._properties = [ - LLMProviderProperty( - "model_id", "Model", "Model (must support streaming)", "", False - ), + LLMProviderProperty("model_id", "Model", "Model (must support streaming)", "", False), LLMProviderProperty("base_url", "Base URL", "Base URL", "", False), LLMProviderProperty("api_key", "API key", "API key", "", True), LLMProviderProperty( diff --git a/lab_notebook_intelligence/llm_providers/ollama_llm_provider.py b/lab_notebook_intelligence/llm_providers/ollama_llm_provider.py index 0dbb101..4332b8f 100644 --- a/lab_notebook_intelligence/llm_providers/ollama_llm_provider.py +++ b/lab_notebook_intelligence/llm_providers/ollama_llm_provider.py @@ -20,23 +20,15 @@ log = logging.getLogger(__name__) OLLAMA_EMBEDDING_FAMILIES = set(["nomic-bert", "bert"]) -QWEN_INLINE_COMPL_PROMPT = ( - """<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>""" -) -DEEPSEEK_INLINE_COMPL_PROMPT = ( - """<|fim▁begin|>{prefix}<|fim▁hole|>{suffix}<|fim▁end|>""" -) +QWEN_INLINE_COMPL_PROMPT = """<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>""" +DEEPSEEK_INLINE_COMPL_PROMPT = """<|fim▁begin|>{prefix}<|fim▁hole|>{suffix}<|fim▁end|>""" CODELLAMA_INLINE_COMPL_PROMPT = """
<PRE> {prefix} <SUF>{suffix} <MID>"""
-STARCODER_INLINE_COMPL_PROMPT = (
-    """{prefix}{suffix}"""
-)
+STARCODER_INLINE_COMPL_PROMPT = """{prefix}{suffix}"""
 CODESTRAL_INLINE_COMPL_PROMPT = """[SUFFIX]{suffix}[PREFIX]{prefix}"""
 
 
 class OllamaChatModel(ChatModel):
-    def __init__(
-        self, provider: LLMProvider, model_id: str, model_name: str, context_window: int
-    ):
+    def __init__(self, provider: LLMProvider, model_id: str, model_name: str, context_window: int):
         super().__init__(provider)
         self._model_id = model_id
         self._model_name = model_name
diff --git a/lab_notebook_intelligence/llm_providers/openai_compatible_llm_provider.py b/lab_notebook_intelligence/llm_providers/openai_compatible_llm_provider.py
index 6a58654..9ce491e 100644
--- a/lab_notebook_intelligence/llm_providers/openai_compatible_llm_provider.py
+++ b/lab_notebook_intelligence/llm_providers/openai_compatible_llm_provider.py
@@ -25,9 +25,7 @@ def __init__(self, provider: "OpenAICompatibleLLMProvider"):
         self._provider = provider
         self._properties = [
             LLMProviderProperty("api_key", "API key", "API key", "", False),
-            LLMProviderProperty(
-                "model_id", "Model", "Model (must support streaming)", "", False
-            ),
+            LLMProviderProperty("model_id", "Model", "Model (must support streaming)", "", False),
             LLMProviderProperty("base_url", "Base URL", "Base URL", "", True),
             LLMProviderProperty(
                 "context_window", "Context window", "Context window length", "", True
diff --git a/lab_notebook_intelligence/mcp_manager.py b/lab_notebook_intelligence/mcp_manager.py
index 9bbc7e1..be713e1 100644
--- a/lab_notebook_intelligence/mcp_manager.py
+++ b/lab_notebook_intelligence/mcp_manager.py
@@ -36,9 +36,7 @@
 
 
 class MCPTool(Tool):
-    def __init__(
-        self, server: "MCPServer", name, description, schema, auto_approve=False
-    ):
+    def __init__(self, server: "MCPServer", name, description, schema, auto_approve=False):
         super().__init__()
         self._server = server
         self._name = name
@@ -110,9 +108,7 @@ async def handle_tool_call(
                     for content in result.content:
                         if type(content) is ImageContent:
                             response.stream(
-                                ImageData(
-                                    f"data:{content.mimeType};base64,{content.data}"
-                                )
+                                ImageData(f"data:{content.mimeType};base64,{content.data}")
                             )
                         elif type(content) is TextContent:
                             text_contents.append(content.text)
@@ -145,9 +141,7 @@ def __init__(
     ):
         self._name: str = name
         self._stdio_params: StdioServerParameters = stdio_params
-        self._streamable_http_params: StreamableHttpServerParameters = (
-            streamable_http_params
-        )
+        self._streamable_http_params: StreamableHttpServerParameters = streamable_http_params
         self._auto_approve_tools: set[str] = set(auto_approve_tools)
         self._tried_to_get_tool_list = False
         self._mcp_tools = []
@@ -224,9 +218,7 @@ def get_tool(self, tool_name: str) -> Tool:
 
 
 class MCPChatParticipant(BaseChatParticipant):
-    def __init__(
-        self, id: str, name: str, servers: list[MCPServer], nbi_tools: list[str] = []
-    ):
+    def __init__(self, id: str, name: str, servers: list[MCPServer], nbi_tools: list[str] = []):
         super().__init__()
         self._id = id
         self._name = name
@@ -348,9 +340,7 @@ def create_servers(self, server_names: list[str], servers_config: dict):
         for server_name in server_names:
             server_config = servers_config.get(server_name, None)
             if server_config is None:
-                log.error(
-                    f"Server '{server_name}' not found in MCP servers configuration"
-                )
+                log.error(f"Server '{server_name}' not found in MCP servers configuration")
                 continue
 
             if server_config.get("disabled", False) == True:
@@ -382,9 +372,7 @@ def create_mcp_server(self, server_name: str, server_config: dict):
 
             return MCPServerImpl(
                 server_name,
-                stdio_params=StdioServerParameters(
-                    command=command, args=args, env=server_env
-                ),
+                stdio_params=StdioServerParameters(command=command, args=args, env=server_env),
                 auto_approve_tools=auto_approve_tools,
             )
         elif "url" in server_config:
diff --git a/pyproject.toml b/pyproject.toml
index 0400740..333308d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -83,3 +83,25 @@ before-build-python = ["jlpm clean:all"]
 
 [tool.check-wheel-contents]
 ignore = ["W002"]
+
+[tool.black]
+line-length = 100
+target-version = ['py310', 'py311', 'py312', 'py313']
+include = '\.pyi?$'
+exclude = '''
+/(
+    \.git
+  | __pycache__
+  | \.tox
+  | \.venv
+  | dist
+)/
+'''
+
+[tool.isort]
+profile = "black"
+multi_line_output = 3
+include_trailing_comma = true
+force_grid_wrap = 0
+use_parentheses = true
+line_length = 100