-
Notifications
You must be signed in to change notification settings - Fork 27
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Segment anything 2 pipeline image (#185)
* feat(pipeline): add SAM2 image segmentation prototype This commit introduces a prototype implementation of the [Segment Anything v2](https://github.com/facebookresearch/segment-anything-2) (SAM2) pipeline within the AI worker. The prototype demonstrates the basic functionality needed to perform segmentation on an image. Note that video segmentation is not yet implemented. Additionally, the dependencies were updated quickly, which may temporarily break other pipelines. * revert Dockerfile, requirements, add sam2 Dockerfile * refactor: enhance SAM2 input handling and error management This commit allows nested arrays to be supplied as JSON strings for SAM2 input. It also implements robust error handling to return a 400 error with a descriptive message when incorrect parameters are provided. * refactor: improve SAM2 return time This commit ensures that we return the masks, iou_predictions and low_res_masks in json format. * Sam2 -> SegmentAnything2 * update go bindings * update multipart.go binding with NewSegmentAnything2Writer * update worker and multipart methods * predictions -> scores, mask -> logits * add sam2 specific multipartwriter fields * add segment-anything-2 to containerHostPorts * fix pipeline name in worker.go * revert Dockerfile, requirements, add sam2 Dockerfile * Sam2 -> SegmentAnything2 * predictions -> scores, mask -> logits * feat: replace JSON.dump with str This commit replaces the JSON.dump method with a simple str method since it is highly unlikely that the string contains invalid data. Co-authored-by: Peter Schroedl <peter_schroedl@me.com> * move pipeline-specific dockerfile * update openapi yaml * add segment anything specific readme * update go bindings * refactor: move SAM2 docker This commit moves the SAM2 docker file inside the docker container. * refactor: add FastAPI descriptions This commit cleansup the codebase and adds FastAPI parameter and pipeline descriptions. * refactor: improve sam2 route function name This commit improves the sam2 route function name so that it is more pythonic and shows up nicer in the OpenAPI spec pipeline summary. * chore(worker): update golang bindings This commit updates the golang bindings so that the runner changes are reflected. * refactor(runner): add media_type This commit adds the media type content MIME type to the segment anything 2 pipeline. * chore(worker): remove debug patch This commit removes the debug patch which was accidentally added to the code. * feat(runnner): add SAM2 model download command This commit adds the SAM2 model download command so that orchestrators can pre-download the model. * refactor(worker): change SAM2 multipart reader param order This commit ensures that the parameters are in the same order as the pipeline parameters. * determine docker image in createContainer * fix: fix examples This commit fixes the example scripts. --------- Co-authored-by: Rick Staa <rick.staa@outlook.com> Co-authored-by: Elite Encoder <john@eliteencoder.net> Co-authored-by: Peter Schroedl <peter@livepeer.org>
- Loading branch information
1 parent
c887845
commit c4d02e9
Showing
16 changed files
with
2,004 additions
and
68 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import logging | ||
from typing import List, Optional, Tuple | ||
|
||
import PIL | ||
from app.pipelines.base import Pipeline | ||
from app.pipelines.utils import get_torch_device, get_model_dir | ||
from app.routes.util import InferenceError | ||
from PIL import ImageFile | ||
from sam2.sam2_image_predictor import SAM2ImagePredictor | ||
|
||
ImageFile.LOAD_TRUNCATED_IMAGES = True | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class SegmentAnything2Pipeline(Pipeline): | ||
def __init__(self, model_id: str): | ||
self.model_id = model_id | ||
kwargs = {"cache_dir": get_model_dir()} | ||
|
||
torch_device = get_torch_device() | ||
|
||
self.tm = SAM2ImagePredictor.from_pretrained( | ||
model_id=model_id, | ||
device=torch_device, | ||
**kwargs, | ||
) | ||
|
||
def __call__( | ||
self, image: PIL.Image, **kwargs | ||
) -> Tuple[List[PIL.Image], List[Optional[bool]]]: | ||
try: | ||
self.tm.set_image(image) | ||
prediction = self.tm.predict(**kwargs) | ||
except Exception as e: | ||
raise InferenceError(original_exception=e) | ||
|
||
return prediction | ||
|
||
def __str__(self) -> str: | ||
return f"Segment Anything 2 model_id={self.model_id}" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,179 @@ | ||
import logging | ||
import os | ||
from typing import Annotated | ||
|
||
import numpy as np | ||
from app.dependencies import get_pipeline | ||
from app.pipelines.base import Pipeline | ||
from app.routes.util import ( | ||
HTTPError, | ||
InferenceError, | ||
MasksResponse, | ||
http_error, | ||
json_str_to_np_array, | ||
) | ||
from fastapi import APIRouter, Depends, File, Form, UploadFile, status | ||
from fastapi.responses import JSONResponse | ||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer | ||
from PIL import Image, ImageFile | ||
|
||
ImageFile.LOAD_TRUNCATED_IMAGES = True | ||
|
||
router = APIRouter() | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
RESPONSES = { | ||
status.HTTP_400_BAD_REQUEST: {"model": HTTPError}, | ||
status.HTTP_401_UNAUTHORIZED: {"model": HTTPError}, | ||
status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError}, | ||
} | ||
|
||
|
||
# TODO: Make model_id and other None properties optional once Go codegen tool supports | ||
# OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373. | ||
@router.post( | ||
"/segment-anything-2", | ||
response_model=MasksResponse, | ||
responses=RESPONSES, | ||
description="Segment objects in an image.", | ||
) | ||
@router.post( | ||
"/segment-anything-2/", | ||
response_model=MasksResponse, | ||
responses=RESPONSES, | ||
include_in_schema=False, | ||
) | ||
async def segment_anything_2( | ||
image: Annotated[ | ||
UploadFile, File(description="Image to segment.", media_type="image/*") | ||
], | ||
model_id: Annotated[ | ||
str, Form(description="Hugging Face model ID used for image generation.") | ||
] = "", | ||
point_coords: Annotated[ | ||
str, | ||
Form( | ||
description=( | ||
"Nx2 array of point prompts to the model, where each point is in (X,Y) " | ||
"in pixels." | ||
) | ||
), | ||
] = None, | ||
point_labels: Annotated[ | ||
str, | ||
Form( | ||
description=( | ||
"Labels for the point prompts, where 1 indicates a foreground point " | ||
"and 0 indicates a background point." | ||
) | ||
), | ||
] = None, | ||
box: Annotated[ | ||
str, | ||
Form( | ||
description=( | ||
"A length 4 array given as a box prompt to the model, in XYXY format." | ||
) | ||
), | ||
] = None, | ||
mask_input: Annotated[ | ||
str, | ||
Form( | ||
description=( | ||
"A low-resolution mask input to the model, typically from a previous " | ||
"prediction iteration, with the form 1xHxW (H=W=256 for SAM)." | ||
) | ||
), | ||
] = None, | ||
multimask_output: Annotated[ | ||
bool, | ||
Form( | ||
description=( | ||
"If true, the model will return three masks for ambiguous input " | ||
"prompts, often producing better masks than a single prediction." | ||
) | ||
), | ||
] = True, | ||
return_logits: Annotated[ | ||
bool, | ||
Form( | ||
description=( | ||
"If true, returns un-thresholded mask logits instead of a binary mask." | ||
) | ||
), | ||
] = True, | ||
normalize_coords: Annotated[ | ||
bool, | ||
Form( | ||
description=( | ||
"If true, the point coordinates will be normalized to the range [0,1], " | ||
"with point_coords expected to be with respect to image dimensions." | ||
) | ||
), | ||
] = True, | ||
pipeline: Pipeline = Depends(get_pipeline), | ||
token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)), | ||
): | ||
auth_token = os.environ.get("AUTH_TOKEN") | ||
if auth_token: | ||
if not token or token.credentials != auth_token: | ||
return JSONResponse( | ||
status_code=status.HTTP_401_UNAUTHORIZED, | ||
headers={"WWW-Authenticate": "Bearer"}, | ||
content=http_error("Invalid bearer token"), | ||
) | ||
|
||
if model_id != "" and model_id != pipeline.model_id: | ||
return JSONResponse( | ||
status_code=status.HTTP_400_BAD_REQUEST, | ||
content=http_error( | ||
f"pipeline configured with {pipeline.model_id} but called with " | ||
f"{model_id}" | ||
), | ||
) | ||
|
||
try: | ||
point_coords = json_str_to_np_array(point_coords, var_name="point_coords") | ||
point_labels = json_str_to_np_array(point_labels, var_name="point_labels") | ||
box = json_str_to_np_array(box, var_name="box") | ||
mask_input = json_str_to_np_array(mask_input, var_name="mask_input") | ||
except ValueError as e: | ||
return JSONResponse( | ||
status_code=status.HTTP_400_BAD_REQUEST, | ||
content=http_error(str(e)), | ||
) | ||
|
||
try: | ||
image = Image.open(image.file).convert("RGB") | ||
masks, scores, low_res_mask_logits = pipeline( | ||
image, | ||
point_coords=point_coords, | ||
point_labels=point_labels, | ||
box=box, | ||
mask_input=mask_input, | ||
multimask_output=multimask_output, | ||
return_logits=return_logits, | ||
normalize_coords=normalize_coords, | ||
) | ||
except Exception as e: | ||
logger.error(f"Segment Anything 2 error: {e}") | ||
logger.exception(e) | ||
if isinstance(e, InferenceError): | ||
return JSONResponse( | ||
status_code=status.HTTP_400_BAD_REQUEST, | ||
content=http_error(str(e)), | ||
) | ||
|
||
return JSONResponse( | ||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | ||
content=http_error("Segment Anything 2 error"), | ||
) | ||
|
||
# Return masks sorted by descending score as string. | ||
sorted_ind = np.argsort(scores)[::-1] | ||
return { | ||
"masks": str(masks[sorted_ind].tolist()), | ||
"scores": str(scores[sorted_ind].tolist()), | ||
"logits": str(low_res_mask_logits[sorted_ind].tolist()), | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
FROM livepeer/ai-runner:base | ||
|
||
RUN pip install --no-cache-dir torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 xformers==0.0.27 git+https://github.com/facebookresearch/segment-anything-2.git@main#egg=sam-2 | ||
|
||
CMD ["uvicorn", "app.main:app", "--log-config", "app/cfg/uvicorn_logging_config.json", "--host", "0.0.0.0", "--port", "8000"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# Runner Docker Images | ||
|
||
This folder contains Dockerfiles for pipelines supported by the Livepeer AI network. The list is maintained by the Livepeer community and audited by the [Core AI team](https://explorer.livepeer.org/treasury/42084921863832634370966409987770520882792921083596034115019946998721416745190). In the future, we will enable custom pipelines to be used with the Livepeer AI network. | ||
|
||
## Building a Pipeline-Specific Container | ||
|
||
> [!NOTE] | ||
> We are transitioning our existing pipelines to this new structure. As a result, the base container is currently somewhat bloated. In the future, the base image will contain only the necessary dependencies to run any pipeline. | ||
All pipeline-specific containers are built on top of the base container found in the main [runner](../) folder and on [Docker Hub](https://hub.docker.com/r/livepeer/ai-runner). The base container includes the minimum dependencies to run any pipeline, while pipeline-specific containers add the necessary dependencies for their respective pipelines. This structure allows for faster build times, less dependency bloat, and easier maintenance. | ||
|
||
### Steps to Build a Pipeline-Specific Container | ||
|
||
To build a pipeline-specific container, you need to build the base container first. The base container is tagged as `base`, and the pipeline-specific container is built from the Dockerfile in the pipeline-specific folder. For example, to build the `segment-anything-2` pipeline-specific container, follow these steps: | ||
|
||
1. **Navigate to the `ai-worker/runner` Directory**: | ||
|
||
```bash | ||
cd ai-worker/runner | ||
``` | ||
|
||
2. **Build the Base Container**: | ||
|
||
```bash | ||
docker build -t livepeer/ai-runner:base . | ||
``` | ||
|
||
This command builds the base container and tags it as `livepeer/ai-runner:base`. | ||
|
||
3. **Build the `segment-anything-2` Pipeline-Specific Container**: | ||
|
||
```bash | ||
docker build -f docker/Dockerfile.segment_anything_2 -t livepeer/ai-runner:segment-anything-2 . | ||
``` | ||
|
||
This command builds the `segment-anything-2` pipeline-specific container using the Dockerfile located at [docker/Dockerfile.segment_anything_2](docker/Dockerfile.segment_anything_2) and tags it as `livepeer/ai-runner:segment-anything-2`. |
Oops, something went wrong.