Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@ You can load models by setting the `OCR_SERVICE_TESSERACT_LANG` variable, you ca

**For performance reasons it is recommended that you load only one model at a time, as processing time will increase slightly per model loaded.**

## API

## API specification

By default, the service listens on port `8090` and returns the content extraction result in JSON format.
Expand Down
2 changes: 1 addition & 1 deletion env/ocr_service.env
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# The default images for ocr-service:
# - cogstacksystems/cogstack-ocr-service:latest
OCR_SERVICE_IMAGE_RELEASE_VERSION=1.0.6
OCR_SERVICE_IMAGE_RELEASE_VERSION=1.0.7
OCR_SERVICE_DOCKER_IMAGE="cogstacksystems/cogstack-ocr-service:${OCR_SERVICE_IMAGE_RELEASE_VERSION:-latest}"

OCR_SERVICE_CPU_THREADS=1
Expand Down
3 changes: 2 additions & 1 deletion ocr_service/api/health.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from fastapi import APIRouter
from fastapi.responses import ORJSONResponse

from ocr_service.dto.info_response import InfoResponse
from ocr_service.utils.utils import get_app_info

health_api = APIRouter(prefix="/api")
Expand All @@ -11,6 +12,6 @@ def health() -> ORJSONResponse:
return ORJSONResponse(content={"status": "healthy"})


@health_api.get("/info")
@health_api.get("/info", response_model=InfoResponse, response_class=ORJSONResponse)
def info() -> ORJSONResponse:
    """Return the application info from ``get_app_info`` wrapped in an ORJSON response."""
    app_info = get_app_info()
    return ORJSONResponse(content=app_info)
44 changes: 25 additions & 19 deletions ocr_service/api/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@
import orjson
from fastapi import APIRouter, File, Request, UploadFile
from fastapi.responses import ORJSONResponse, Response
from pydantic import ValidationError
from starlette.datastructures import FormData

from ocr_service.dto.process_request import ProcessRequest
from ocr_service.dto.process_response import ProcessResponse
from ocr_service.processor.processor import Processor
from ocr_service.settings import settings
from ocr_service.utils.utils import build_response, setup_logging
Expand All @@ -18,7 +21,7 @@
log = setup_logging("api", log_level=settings.LOG_LEVEL)


@process_api.post("/process")
@process_api.post("/process", response_model=ProcessResponse, response_class=ORJSONResponse)
def process(request: Request, file: Optional[UploadFile] = File(default=None)) -> ORJSONResponse:
"""
Processes raw binary input stream, file, or
Expand Down Expand Up @@ -52,26 +55,29 @@ def process(request: Request, file: Optional[UploadFile] = File(default=None)) -
if isinstance(record, list) and len(record) > 0:
record = record[0]

footer = record.get("footer", {}) # type: ignore
log.info("Stream contains valid JSON.")

# JSON with base64 field
if isinstance(record, dict) and "binary_data" in record:
encoded: str = record.get("binary_data", {})

if encoded not in (None, "", {}):
try:
stream = base64.b64decode(encoded, validate=True)
log.info("binary_data successfully base64-decoded")
except Exception:
log.warning("binary_data is not valid base64; forcing bytes")
stream = bytes(encoded) if isinstance(encoded, bytes | bytearray) \
else str(encoded).encode("utf-8")
else:
stream = b""
if not isinstance(record, dict):
return ORJSONResponse(content={"detail": "Invalid JSON payload"}, status_code=422)

try:
payload = ProcessRequest.model_validate(record)
except ValidationError as exc:
return ORJSONResponse(content={"detail": exc.errors()}, status_code=422)

footer = payload.footer or {}
encoded: str = payload.binary_data

if encoded not in (None, "", {}):
try:
stream = base64.b64decode(encoded, validate=True)
log.info("binary_data successfully base64-decoded")
except Exception:
log.warning("binary_data is not valid base64; forcing bytes")
stream = bytes(encoded) if isinstance(encoded, bytes | bytearray) \
else str(encoded).encode("utf-8")
else:
log.info("JSON found but no binary_data; using raw body")
stream = raw_body
stream = b""

except Exception:
log.warning("Stream does not contain valid JSON." + str(traceback.format_exc()))
Expand Down Expand Up @@ -112,7 +118,7 @@ def process(request: Request, file: Optional[UploadFile] = File(default=None)) -
return ORJSONResponse(content=response, status_code=code, media_type="application/json")


@process_api.post("/process_file")
@process_api.post("/process_file", response_model=ProcessResponse, response_class=ORJSONResponse)
def process_file(request: Request, file: UploadFile = File(...)) -> ORJSONResponse:

file_name: str = file.filename if file.filename else ""
Expand Down
10 changes: 10 additions & 0 deletions ocr_service/dto/info_response.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from pydantic import BaseModel, Field


class InfoResponse(BaseModel):
    """Response payload for the /api/info endpoint.

    Every field is declared with ``Field(...)`` and no default value, so all
    four must be present when the model is validated.
    """

    # Name of the service.
    service_app_name: str = Field(..., description="Service name.")
    # Service version, kept as an unconstrained string.
    service_version: str = Field(..., description="Service version string.")
    # Tesseract model path/prefix, as stated by the field description.
    service_model: str = Field(..., description="Tesseract model path/prefix.")
    # Reserved field; currently just a required string.
    config: str = Field(..., description="Reserved config field.")
4 changes: 2 additions & 2 deletions ocr_service/dto/process_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ class ProcessContext(BaseModel):
state shared between the document conversion and OCR stages.
"""

model_config = ConfigDict(arbitrary_types_allowed=True)

stream: bytes
"""Raw document bytes provided by the caller."""

Expand All @@ -36,8 +38,6 @@ class ProcessContext(BaseModel):
pdf_stream: bytes = b""
"""Intermediate PDF bytes used for downstream conversion/OCR."""

model_config = ConfigDict(arbitrary_types_allowed=True)

_checks: TextChecks | None = PrivateAttr(default=None)
"""Lazy text-type detection cache. Initialized on first access."""

Expand Down
12 changes: 12 additions & 0 deletions ocr_service/dto/process_request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typing import Any

from pydantic import BaseModel, ConfigDict, Field


class ProcessRequest(BaseModel):
    """JSON payloads sent to /api/process.

    ``binary_data`` is required; ``footer`` is optional and defaults to
    ``None``.
    """

    # Keys that are not declared below are ignored instead of being rejected.
    model_config = ConfigDict(extra="ignore")

    # Required string, described as the base64-encoded document bytes.
    binary_data: str = Field(..., description="Base64-encoded document bytes.")
    # Optional free-form dict, described as a passthrough payload.
    footer: dict[str, Any] | None = Field(default=None, description="Optional passthrough payload.")
19 changes: 19 additions & 0 deletions ocr_service/dto/process_response.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from typing import Any

from pydantic import BaseModel, Field


class ProcessResult(BaseModel):
    """Inner OCR result payload for /api/process endpoints.

    ``text``, ``success`` and ``timestamp`` are required; ``footer`` defaults
    to ``None`` and ``metadata`` defaults to a fresh empty dict.
    """

    # Extracted/OCR text (required).
    text: str = Field(..., description="Extracted/OCR text.")
    # Optional free-form dict, described as a passthrough payload.
    footer: dict[str, Any] | None = Field(default=None, description="Optional passthrough payload.")
    # default_factory gives each instance its own dict rather than a shared one.
    metadata: dict[str, Any] = Field(default_factory=dict, description="Document metadata.")
    # The success flag is typed as a string, not a bool.
    success: str = Field(..., description="Success flag encoded as a string.")
    # Processing timestamp, typed as a string (format not specified here).
    timestamp: str = Field(..., description="Processing timestamp.")


class ProcessResponse(BaseModel):
    """Response payload for /api/process endpoints.

    Wraps a single required ``ProcessResult`` under the ``result`` key.
    """

    # Required nested model holding the OCR output.
    result: ProcessResult = Field(..., description="OCR processing result.")
8 changes: 6 additions & 2 deletions ocr_service/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,18 @@
from sys import platform
from typing import Any, Literal

from pydantic import Field, computed_field, field_validator
from pydantic import AliasChoices, Field, computed_field, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
model_config = SettingsConfigDict(env_file=None, extra="ignore", validate_assignment=True)

OCR_SERVICE_VERSION: str = Field("1.0.6", min_length=1)
OCR_SERVICE_VERSION: str = Field(
"dev",
min_length=1,
validation_alias=AliasChoices("OCR_SERVICE_VERSION", "OCR_SERVICE_IMAGE_RELEASE_VERSION"),
)
OCR_SERVICE_LOG_LEVEL: int = Field(10, ge=0, le=50)
OCR_SERVICE_DEBUG_MODE: bool = Field(False)
OCR_TMP_DIR: str | None = None
Expand Down
25 changes: 24 additions & 1 deletion ocr_service/tests/test_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
import traceback
import unittest

import orjson
from fastapi import FastAPI
from fastapi.testclient import TestClient

from ocr_service.settings import settings
from ocr_service.app import create_app
from ocr_service.settings import settings
from ocr_service.tests.utils_helpers import DOCS, WSGIEnvironInjector, get_file, lev_similarity
from ocr_service.utils.utils import sync_port_mapping

Expand Down Expand Up @@ -210,3 +211,25 @@ def test_process_record_binary_data_json_payload(self):
self.log.info("Testing test_process_record_binary_data_json_payload")
payload: bytes = get_file("payloads/sample_base64_record_nifi.json")
self._test_json_payload_json_b64_binary_data(payload=payload.decode())

def test_process_record_binary_data_invalid_payload_returns_422(self):
    """A JSON record whose ``binary_data`` is null must be answered with HTTP 422."""
    self.log.info("Testing invalid record payload returns 422")

    request_body = orjson.dumps({"binary_data": None})
    json_headers = {"Content-Type": "application/json"}
    resp = self.client.post(
        self.ENDPOINT_PROCESS_SINGLE,
        content=request_body,
        headers=json_headers,
    )

    self.assertEqual(resp.status_code, 422)
    # The error body must carry a list under "detail".
    error_body = resp.json()
    self.assertIn("detail", error_body)
    self.assertIsInstance(error_body["detail"], list)

def test_process_record_binary_data_missing_binary_data_returns_422(self):
    """A JSON record without a ``binary_data`` key must be answered with HTTP 422."""
    self.log.info("Testing missing binary_data returns 422")

    request_body = orjson.dumps({"footer": {"source": "test"}})
    json_headers = {"Content-Type": "application/json"}
    resp = self.client.post(
        self.ENDPOINT_PROCESS_SINGLE,
        content=request_body,
        headers=json_headers,
    )

    self.assertEqual(resp.status_code, 422)
    # The error body must carry a list under "detail".
    error_body = resp.json()
    self.assertIn("detail", error_body)
    self.assertIsInstance(error_body["detail"], list)