Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from slowapi.util import get_remote_address
from prometheus_fastapi_instrumentator import Instrumentator
from PIL import Image
from transformers import AutoTokenizer

def get_api_key(request: Request):
    """Rate-limiting key for a request: the API_KEY header, or the client IP when absent."""
    header_key = request.headers.get("API_KEY")
    if header_key is not None:
        return header_key
    # No API key supplied — fall back to the caller's remote address.
    return get_remote_address(request)
Expand Down Expand Up @@ -83,6 +84,14 @@ class ValidatorInfo(BaseModel):
all_uid_info: dict = {}
sha: str = ""

class ChatCompletion(BaseModel):
    """Request body for the OpenAI-style /api/v1/chat/completions endpoint."""

    # Name of the target model; must be one of the service's served models.
    model: str
    # Chat turns, each shaped like {"role": "...", "content": "..."}.
    messages: list[dict]
    # Sampling parameters forwarded verbatim to the generation pipeline.
    temperature: float = 1.0
    top_p: float = 1.0
    max_tokens: int = 128


class ImageGenerationService:
def __init__(self):
self.subtensor = bt.subtensor("finney")
Expand Down Expand Up @@ -126,6 +135,12 @@ def __init__(self):
Thread(target=self.sync_metagraph_periodically, daemon=True).start()
Thread(target=self.recheck_validators, daemon=True).start()
Thread(target=self.update_model_config, daemon=True).start()
self.tokenizer_config = self.model_config.find_one({"name": "tokenizer"})
print(self.tokenizer_config, flush=True)
self.tokenizers = {
k: AutoTokenizer.from_pretrained(v) for k, v in self.tokenizer_config["data"].items()
}
print(self.tokenizers, flush=True)

def update_model_config(self):
while True:
Expand Down Expand Up @@ -568,6 +583,27 @@ async def controlnet_api(self, request: Request, data: ImageToImage):
generate_data["pipeline_params"][key] = value

return await self.generate(Prompt(**generate_data))

async def chat_completions(self, request: Request, data: ChatCompletion):
# Get API_KEY from header
api_key = request.headers.get("API_KEY")
self.check_auth(api_key)
if data.model not in self.model_list:
raise HTTPException(status_code=404, detail="Model not found")
messages_str = self.tokenizers[data.model].apply_chat_template(data.messages, tokenize=False)
print(f"Chat message str: {messages_str}", flush=True)
generate_data = {
"key": api_key,
"prompt_input": messages_str,
"model_name": data.model,
"pipeline_params": {
"temperature": data.temperature,
"top_p": data.top_p,
"max_tokens": data.max_tokens
}
}
response = await self.generate(TextPrompt(**generate_data))
return response['prompt_output']

def base64_to_pil_image(self, base64_image):
image = base64.b64decode(base64_image)
Expand Down Expand Up @@ -632,3 +668,8 @@ async def instantid_api(request: Request, data: ImageToImage):
@limiter.limit(API_RATE_LIMIT)  # Update the rate limit
async def controlnet_api(request: Request, data: ImageToImage):
    # Thin route wrapper: delegate to the service-level handler.
    return await app.controlnet_api(request, data)

@app.app.post("/api/v1/chat/completions")
@limiter.limit(API_RATE_LIMIT)
async def chat_completions_api(request: Request, data: ChatCompletion):
    """OpenAI-compatible chat route; delegates to the service instance."""
    result = await app.chat_completions(request, data)
    return result
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@ tqdm==4.66.1
httpx==0.27.0
prometheus_fastapi_instrumentator==6.0.0
pymongo==4.7.3
slowapi==0.1.9
slowapi==0.1.9
transformers==4.41.2
jinja2==3.1.4