Skip to content

Commit

Permalink
Segment anything 2 pipeline image (#185)
Browse files Browse the repository at this point in the history
* feat(pipeline): add SAM2 image segmentation prototype

This commit introduces a prototype implementation of the
[Segment Anything v2](https://github.com/facebookresearch/segment-anything-2)
(SAM2) pipeline within the AI worker. The prototype demonstrates the basic
functionality needed to perform segmentation on an image. Note that video
segmentation is not yet implemented. Additionally, the dependencies were
updated quickly, which may temporarily break other pipelines.

* revert Dockerfile, requirements, add sam2 Dockerfile

* refactor: enhance SAM2 input handling and error management

This commit allows nested arrays to be supplied as JSON strings for SAM2
input. It also implements robust error handling to return a 400 error with
a descriptive message when incorrect parameters are provided.

* refactor: improve SAM2 return time

This commit ensures that we return the masks, iou_predictions and
low_res_masks in json format.

* Sam2 -> SegmentAnything2

* update go bindings

* update multipart.go binding with NewSegmentAnything2Writer

* update worker and multipart methods

* predictions -> scores, mask -> logits

* add sam2 specific multipartwriter fields

* add segment-anything-2 to containerHostPorts

* fix pipeline name in worker.go

* revert Dockerfile, requirements, add sam2 Dockerfile

* Sam2 -> SegmentAnything2

* predictions -> scores, mask -> logits

* feat: replace JSON.dump with str

This commit replaces the JSON.dump method with a simple str method since
it is highly unlikely that the string contains invalid data.

Co-authored-by: Peter Schroedl <peter_schroedl@me.com>

* move pipeline-specific dockerfile

* update openapi yaml

* add segment anything specific readme

* update go bindings

* refactor: move SAM2 docker

This commit moves the SAM2 docker file inside the docker container.

* refactor: add FastAPI descriptions

This commit cleansup the codebase and adds FastAPI parameter and
pipeline descriptions.

* refactor: improve sam2 route function name

This commit improves the sam2 route function name so that it is more
pythonic and shows up nicer in the OpenAPI spec pipeline summary.

* chore(worker): update golang bindings

This commit updates the golang bindings so that the runner changes are
reflected.

* refactor(runner): add media_type

This commit adds the media type content MIME type to the segment
anything 2 pipeline.

* chore(worker): remove debug patch

This commit removes the debug patch which was accidentally added to the
code.

* feat(runnner): add SAM2 model download command

This commit adds the SAM2 model download command so that orchestrators
can pre-download the model.

* refactor(worker): change SAM2 multipart reader param order

This commit ensures that the parameters are in the same order as the
pipeline parameters.

* determine docker image in createContainer

* fix: fix examples

This commit fixes the example scripts.

---------

Co-authored-by: Rick Staa <rick.staa@outlook.com>
Co-authored-by: Elite Encoder <john@eliteencoder.net>
Co-authored-by: Peter Schroedl <peter@livepeer.org>
  • Loading branch information
4 people authored Sep 4, 2024
1 parent c887845 commit c4d02e9
Show file tree
Hide file tree
Showing 16 changed files with 2,004 additions and 68 deletions.
8 changes: 8 additions & 0 deletions runner/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ def load_pipeline(pipeline: str, model_id: str) -> any:
from app.pipelines.upscale import UpscalePipeline

return UpscalePipeline(model_id)
case "segment-anything-2":
from app.pipelines.segment_anything_2 import SegmentAnything2Pipeline

return SegmentAnything2Pipeline(model_id)
case _:
raise EnvironmentError(
f"{pipeline} is not a valid pipeline for model {model_id}"
Expand Down Expand Up @@ -82,6 +86,10 @@ def load_route(pipeline: str) -> any:
from app.routes import upscale

return upscale.router
case "segment-anything-2":
from app.routes import segment_anything_2

return segment_anything_2.router
case _:
raise EnvironmentError(f"{pipeline} is not a valid pipeline")

Expand Down
41 changes: 41 additions & 0 deletions runner/app/pipelines/segment_anything_2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import logging
from typing import List, Optional, Tuple

import PIL
from app.pipelines.base import Pipeline
from app.pipelines.utils import get_torch_device, get_model_dir
from app.routes.util import InferenceError
from PIL import ImageFile
from sam2.sam2_image_predictor import SAM2ImagePredictor

ImageFile.LOAD_TRUNCATED_IMAGES = True

logger = logging.getLogger(__name__)


class SegmentAnything2Pipeline(Pipeline):
def __init__(self, model_id: str):
self.model_id = model_id
kwargs = {"cache_dir": get_model_dir()}

torch_device = get_torch_device()

self.tm = SAM2ImagePredictor.from_pretrained(
model_id=model_id,
device=torch_device,
**kwargs,
)

def __call__(
self, image: PIL.Image, **kwargs
) -> Tuple[List[PIL.Image], List[Optional[bool]]]:
try:
self.tm.set_image(image)
prediction = self.tm.predict(**kwargs)
except Exception as e:
raise InferenceError(original_exception=e)

return prediction

def __str__(self) -> str:
return f"Segment Anything 2 model_id={self.model_id}"
179 changes: 179 additions & 0 deletions runner/app/routes/segment_anything_2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
import logging
import os
from typing import Annotated

import numpy as np
from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
from app.routes.util import (
HTTPError,
InferenceError,
MasksResponse,
http_error,
json_str_to_np_array,
)
from fastapi import APIRouter, Depends, File, Form, UploadFile, status
from fastapi.responses import JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from PIL import Image, ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True

router = APIRouter()

logger = logging.getLogger(__name__)

RESPONSES = {
status.HTTP_400_BAD_REQUEST: {"model": HTTPError},
status.HTTP_401_UNAUTHORIZED: {"model": HTTPError},
status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError},
}


# TODO: Make model_id and other None properties optional once Go codegen tool supports
# OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373.
@router.post(
"/segment-anything-2",
response_model=MasksResponse,
responses=RESPONSES,
description="Segment objects in an image.",
)
@router.post(
"/segment-anything-2/",
response_model=MasksResponse,
responses=RESPONSES,
include_in_schema=False,
)
async def segment_anything_2(
image: Annotated[
UploadFile, File(description="Image to segment.", media_type="image/*")
],
model_id: Annotated[
str, Form(description="Hugging Face model ID used for image generation.")
] = "",
point_coords: Annotated[
str,
Form(
description=(
"Nx2 array of point prompts to the model, where each point is in (X,Y) "
"in pixels."
)
),
] = None,
point_labels: Annotated[
str,
Form(
description=(
"Labels for the point prompts, where 1 indicates a foreground point "
"and 0 indicates a background point."
)
),
] = None,
box: Annotated[
str,
Form(
description=(
"A length 4 array given as a box prompt to the model, in XYXY format."
)
),
] = None,
mask_input: Annotated[
str,
Form(
description=(
"A low-resolution mask input to the model, typically from a previous "
"prediction iteration, with the form 1xHxW (H=W=256 for SAM)."
)
),
] = None,
multimask_output: Annotated[
bool,
Form(
description=(
"If true, the model will return three masks for ambiguous input "
"prompts, often producing better masks than a single prediction."
)
),
] = True,
return_logits: Annotated[
bool,
Form(
description=(
"If true, returns un-thresholded mask logits instead of a binary mask."
)
),
] = True,
normalize_coords: Annotated[
bool,
Form(
description=(
"If true, the point coordinates will be normalized to the range [0,1], "
"with point_coords expected to be with respect to image dimensions."
)
),
] = True,
pipeline: Pipeline = Depends(get_pipeline),
token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
):
auth_token = os.environ.get("AUTH_TOKEN")
if auth_token:
if not token or token.credentials != auth_token:
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
headers={"WWW-Authenticate": "Bearer"},
content=http_error("Invalid bearer token"),
)

if model_id != "" and model_id != pipeline.model_id:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content=http_error(
f"pipeline configured with {pipeline.model_id} but called with "
f"{model_id}"
),
)

try:
point_coords = json_str_to_np_array(point_coords, var_name="point_coords")
point_labels = json_str_to_np_array(point_labels, var_name="point_labels")
box = json_str_to_np_array(box, var_name="box")
mask_input = json_str_to_np_array(mask_input, var_name="mask_input")
except ValueError as e:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content=http_error(str(e)),
)

try:
image = Image.open(image.file).convert("RGB")
masks, scores, low_res_mask_logits = pipeline(
image,
point_coords=point_coords,
point_labels=point_labels,
box=box,
mask_input=mask_input,
multimask_output=multimask_output,
return_logits=return_logits,
normalize_coords=normalize_coords,
)
except Exception as e:
logger.error(f"Segment Anything 2 error: {e}")
logger.exception(e)
if isinstance(e, InferenceError):
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content=http_error(str(e)),
)

return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content=http_error("Segment Anything 2 error"),
)

# Return masks sorted by descending score as string.
sorted_ind = np.argsort(scores)[::-1]
return {
"masks": str(masks[sorted_ind].tolist()),
"scores": str(scores[sorted_ind].tolist()),
"logits": str(low_res_mask_logits[sorted_ind].tolist()),
}
60 changes: 59 additions & 1 deletion runner/app/routes/util.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import base64
import io
import json
import os
from typing import List
from typing import List, Optional

import numpy as np
from fastapi import UploadFile
from PIL import Image
from pydantic import BaseModel, Field
Expand Down Expand Up @@ -30,6 +32,18 @@ class VideoResponse(BaseModel):
frames: List[List[Media]] = Field(..., description="The generated video frames.")


class MasksResponse(BaseModel):
"""Response model for object segmentation."""

masks: str = Field(..., description="The generated masks.")
scores: str = Field(
..., description="The model's confidence scores for each generated mask."
)
logits: str = Field(
..., description="The raw, unnormalized predictions (logits) for the masks."
)


class chunk(BaseModel):
"""A chunk of text with a timestamp."""

Expand All @@ -56,6 +70,22 @@ class HTTPError(BaseModel):
detail: APIError = Field(..., description="Detailed error information.")


class InferenceError(Exception):
"""Exception raised for errors during model inference."""

def __init__(self, message="Error during model execution", original_exception=None):
"""Initialize the exception.
Args:
message: The error message.
original_exception: The original exception that caused the error.
"""
if original_exception:
message = f"{message}: {original_exception}"
super().__init__(message)
self.original_exception = original_exception


def http_error(msg: str) -> HTTPError:
"""Create an HTTP error response with the specified message.
Expand Down Expand Up @@ -118,3 +148,31 @@ def file_exceeds_max_size(
except Exception as e:
print(f"Error checking file size: {e}")
return False


def json_str_to_np_array(
data: Optional[str], var_name: Optional[str] = None
) -> Optional[np.ndarray]:
"""Converts a JSON string to a NumPy array.
Args:
data: The JSON string to convert.
var_name: The name of the variable being converted. Used in error messages.
Returns:
The NumPy array if the conversion is successful, None otherwise.
Raises:
ValueError: If an error occurs during JSON parsing.
"""
if data:
try:
array = np.array(json.loads(data))
return array
except json.JSONDecodeError as e:
error_message = "Error parsing JSON"
if var_name:
error_message += f" for {var_name}"
error_message += f": {e}"
raise ValueError(error_message)
return None
3 changes: 3 additions & 0 deletions runner/dl_checkpoints.sh
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ function download_all_models() {

# Download image-to-video models.
huggingface-cli download stabilityai/stable-video-diffusion-img2vid-xt --include "*.fp16.safetensors" "*.json" --cache-dir models

# Custom pipeline models.
huggingface-cli download facebook/sam2-hiera-large --include "*.pt" "*.yaml" --cache-dir models
}

# Enable HF transfer acceleration.
Expand Down
5 changes: 5 additions & 0 deletions runner/docker/Dockerfile.segment_anything_2
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
FROM livepeer/ai-runner:base

RUN pip install --no-cache-dir torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 xformers==0.0.27 git+https://github.com/facebookresearch/segment-anything-2.git@main#egg=sam-2

CMD ["uvicorn", "app.main:app", "--log-config", "app/cfg/uvicorn_logging_config.json", "--host", "0.0.0.0", "--port", "8000"]
36 changes: 36 additions & 0 deletions runner/docker/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Runner Docker Images

This folder contains Dockerfiles for pipelines supported by the Livepeer AI network. The list is maintained by the Livepeer community and audited by the [Core AI team](https://explorer.livepeer.org/treasury/42084921863832634370966409987770520882792921083596034115019946998721416745190). In the future, we will enable custom pipelines to be used with the Livepeer AI network.

## Building a Pipeline-Specific Container

> [!NOTE]
> We are transitioning our existing pipelines to this new structure. As a result, the base container is currently somewhat bloated. In the future, the base image will contain only the necessary dependencies to run any pipeline.
All pipeline-specific containers are built on top of the base container found in the main [runner](../) folder and on [Docker Hub](https://hub.docker.com/r/livepeer/ai-runner). The base container includes the minimum dependencies to run any pipeline, while pipeline-specific containers add the necessary dependencies for their respective pipelines. This structure allows for faster build times, less dependency bloat, and easier maintenance.

### Steps to Build a Pipeline-Specific Container

To build a pipeline-specific container, you need to build the base container first. The base container is tagged as `base`, and the pipeline-specific container is built from the Dockerfile in the pipeline-specific folder. For example, to build the `segment-anything-2` pipeline-specific container, follow these steps:

1. **Navigate to the `ai-worker/runner` Directory**:

```bash
cd ai-worker/runner
```

2. **Build the Base Container**:

```bash
docker build -t livepeer/ai-runner:base .
```

This command builds the base container and tags it as `livepeer/ai-runner:base`.

3. **Build the `segment-anything-2` Pipeline-Specific Container**:

```bash
docker build -f docker/Dockerfile.segment_anything_2 -t livepeer/ai-runner:segment-anything-2 .
```

This command builds the `segment-anything-2` pipeline-specific container using the Dockerfile located at [docker/Dockerfile.segment_anything_2](docker/Dockerfile.segment_anything_2) and tags it as `livepeer/ai-runner:segment-anything-2`.
Loading

0 comments on commit c4d02e9

Please sign in to comment.