diff --git a/runner/app/main.py b/runner/app/main.py index 6f511420..e3b4078a 100644 --- a/runner/app/main.py +++ b/runner/app/main.py @@ -52,6 +52,10 @@ def load_pipeline(pipeline: str, model_id: str) -> any: from app.pipelines.upscale import UpscalePipeline return UpscalePipeline(model_id) + case "segment-anything-2": + from app.pipelines.segment_anything_2 import SegmentAnything2Pipeline + + return SegmentAnything2Pipeline(model_id) case _: raise EnvironmentError( f"{pipeline} is not a valid pipeline for model {model_id}" @@ -82,6 +86,10 @@ def load_route(pipeline: str) -> any: from app.routes import upscale return upscale.router + case "segment-anything-2": + from app.routes import segment_anything_2 + + return segment_anything_2.router case _: raise EnvironmentError(f"{pipeline} is not a valid pipeline") diff --git a/runner/app/pipelines/segment_anything_2.py b/runner/app/pipelines/segment_anything_2.py new file mode 100644 index 00000000..64c4080d --- /dev/null +++ b/runner/app/pipelines/segment_anything_2.py @@ -0,0 +1,41 @@ +import logging +from typing import List, Optional, Tuple + +import PIL +from app.pipelines.base import Pipeline +from app.pipelines.utils import get_torch_device, get_model_dir +from app.routes.util import InferenceError +from PIL import ImageFile +from sam2.sam2_image_predictor import SAM2ImagePredictor + +ImageFile.LOAD_TRUNCATED_IMAGES = True + +logger = logging.getLogger(__name__) + + +class SegmentAnything2Pipeline(Pipeline): + def __init__(self, model_id: str): + self.model_id = model_id + kwargs = {"cache_dir": get_model_dir()} + + torch_device = get_torch_device() + + self.tm = SAM2ImagePredictor.from_pretrained( + model_id=model_id, + device=torch_device, + **kwargs, + ) + + def __call__( + self, image: PIL.Image, **kwargs + ) -> Tuple[List[PIL.Image], List[Optional[bool]]]: + try: + self.tm.set_image(image) + prediction = self.tm.predict(**kwargs) + except Exception as e: + raise InferenceError(original_exception=e) + + return prediction + + def __str__(self) -> str: + return f"Segment Anything 2 model_id={self.model_id}" diff --git a/runner/app/routes/segment_anything_2.py b/runner/app/routes/segment_anything_2.py new file mode 100644 index 00000000..70436432 --- /dev/null +++ b/runner/app/routes/segment_anything_2.py @@ -0,0 +1,179 @@ +import logging +import os +from typing import Annotated + +import numpy as np +from app.dependencies import get_pipeline +from app.pipelines.base import Pipeline +from app.routes.util import ( + HTTPError, + InferenceError, + MasksResponse, + http_error, + json_str_to_np_array, +) +from fastapi import APIRouter, Depends, File, Form, UploadFile, status +from fastapi.responses import JSONResponse +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from PIL import Image, ImageFile + +ImageFile.LOAD_TRUNCATED_IMAGES = True + +router = APIRouter() + +logger = logging.getLogger(__name__) + +RESPONSES = { + status.HTTP_400_BAD_REQUEST: {"model": HTTPError}, + status.HTTP_401_UNAUTHORIZED: {"model": HTTPError}, + status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError}, +} + + +# TODO: Make model_id and other None properties optional once Go codegen tool supports +# OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373. +@router.post( + "/segment-anything-2", + response_model=MasksResponse, + responses=RESPONSES, + description="Segment objects in an image.", +) +@router.post( + "/segment-anything-2/", + response_model=MasksResponse, + responses=RESPONSES, + include_in_schema=False, +) +async def segment_anything_2( + image: Annotated[ + UploadFile, File(description="Image to segment.", media_type="image/*") + ], + model_id: Annotated[ + str, Form(description="Hugging Face model ID used for image generation.") + ] = "", + point_coords: Annotated[ + str, + Form( + description=( + "Nx2 array of point prompts to the model, where each point is in (X,Y) " + "in pixels." + ) + ), + ] = None, + point_labels: Annotated[ + str, + Form( + description=( + "Labels for the point prompts, where 1 indicates a foreground point " + "and 0 indicates a background point." + ) + ), + ] = None, + box: Annotated[ + str, + Form( + description=( + "A length 4 array given as a box prompt to the model, in XYXY format." + ) + ), + ] = None, + mask_input: Annotated[ + str, + Form( + description=( + "A low-resolution mask input to the model, typically from a previous " + "prediction iteration, with the form 1xHxW (H=W=256 for SAM)." + ) + ), + ] = None, + multimask_output: Annotated[ + bool, + Form( + description=( + "If true, the model will return three masks for ambiguous input " + "prompts, often producing better masks than a single prediction." + ) + ), + ] = True, + return_logits: Annotated[ + bool, + Form( + description=( + "If true, returns un-thresholded mask logits instead of a binary mask." + ) + ), + ] = True, + normalize_coords: Annotated[ + bool, + Form( + description=( + "If true, the point coordinates will be normalized to the range [0,1], " + "with point_coords expected to be with respect to image dimensions." + ) + ), + ] = True, + pipeline: Pipeline = Depends(get_pipeline), + token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)), +): + auth_token = os.environ.get("AUTH_TOKEN") + if auth_token: + if not token or token.credentials != auth_token: + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + headers={"WWW-Authenticate": "Bearer"}, + content=http_error("Invalid bearer token"), + ) + + if model_id != "" and model_id != pipeline.model_id: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content=http_error( + f"pipeline configured with {pipeline.model_id} but called with " + f"{model_id}" + ), + ) + + try: + point_coords = json_str_to_np_array(point_coords, var_name="point_coords") + point_labels = json_str_to_np_array(point_labels, var_name="point_labels") + box = json_str_to_np_array(box, var_name="box") + mask_input = json_str_to_np_array(mask_input, var_name="mask_input") + except ValueError as e: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content=http_error(str(e)), + ) + + try: + image = Image.open(image.file).convert("RGB") + masks, scores, low_res_mask_logits = pipeline( + image, + point_coords=point_coords, + point_labels=point_labels, + box=box, + mask_input=mask_input, + multimask_output=multimask_output, + return_logits=return_logits, + normalize_coords=normalize_coords, + ) + except Exception as e: + logger.error(f"Segment Anything 2 error: {e}") + logger.exception(e) + if isinstance(e, InferenceError): + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content=http_error(str(e)), + ) + + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content=http_error("Segment Anything 2 error"), + ) + + # Return masks sorted by descending score as string. + sorted_ind = np.argsort(scores)[::-1] + return { + "masks": str(masks[sorted_ind].tolist()), + "scores": str(scores[sorted_ind].tolist()), + "logits": str(low_res_mask_logits[sorted_ind].tolist()), + } diff --git a/runner/app/routes/util.py b/runner/app/routes/util.py index 1fad9bf9..8a319e84 100644 --- a/runner/app/routes/util.py +++ b/runner/app/routes/util.py @@ -1,8 +1,10 @@ import base64 import io +import json import os -from typing import List +from typing import List, Optional +import numpy as np from fastapi import UploadFile from PIL import Image from pydantic import BaseModel, Field @@ -30,6 +32,18 @@ class VideoResponse(BaseModel): frames: List[List[Media]] = Field(..., description="The generated video frames.") +class MasksResponse(BaseModel): + """Response model for object segmentation.""" + + masks: str = Field(..., description="The generated masks.") + scores: str = Field( + ..., description="The model's confidence scores for each generated mask." + ) + logits: str = Field( + ..., description="The raw, unnormalized predictions (logits) for the masks." + ) + + class chunk(BaseModel): """A chunk of text with a timestamp.""" @@ -56,6 +70,22 @@ class HTTPError(BaseModel): detail: APIError = Field(..., description="Detailed error information.") +class InferenceError(Exception): + """Exception raised for errors during model inference.""" + + def __init__(self, message="Error during model execution", original_exception=None): + """Initialize the exception. + + Args: + message: The error message. + original_exception: The original exception that caused the error. + """ + if original_exception: + message = f"{message}: {original_exception}" + super().__init__(message) + self.original_exception = original_exception + + def http_error(msg: str) -> HTTPError: """Create an HTTP error response with the specified message. @@ -118,3 +148,31 @@ def file_exceeds_max_size( except Exception as e: print(f"Error checking file size: {e}") return False + + +def json_str_to_np_array( + data: Optional[str], var_name: Optional[str] = None +) -> Optional[np.ndarray]: + """Converts a JSON string to a NumPy array. + + Args: + data: The JSON string to convert. + var_name: The name of the variable being converted. Used in error messages. + + Returns: + The NumPy array if the conversion is successful, None otherwise. + + Raises: + ValueError: If an error occurs during JSON parsing. + """ + if data: + try: + array = np.array(json.loads(data)) + return array + except json.JSONDecodeError as e: + error_message = "Error parsing JSON" + if var_name: + error_message += f" for {var_name}" + error_message += f": {e}" + raise ValueError(error_message) + return None diff --git a/runner/dl_checkpoints.sh b/runner/dl_checkpoints.sh index 9fe40837..2e646434 100755 --- a/runner/dl_checkpoints.sh +++ b/runner/dl_checkpoints.sh @@ -60,6 +60,9 @@ function download_all_models() { # Download image-to-video models. huggingface-cli download stabilityai/stable-video-diffusion-img2vid-xt --include "*.fp16.safetensors" "*.json" --cache-dir models + + # Custom pipeline models. + huggingface-cli download facebook/sam2-hiera-large --include "*.pt" "*.yaml" --cache-dir models } # Enable HF transfer acceleration. diff --git a/runner/docker/Dockerfile.segment_anything_2 b/runner/docker/Dockerfile.segment_anything_2 new file mode 100644 index 00000000..0f9faf0b --- /dev/null +++ b/runner/docker/Dockerfile.segment_anything_2 @@ -0,0 +1,5 @@ +FROM livepeer/ai-runner:base + +RUN pip install --no-cache-dir torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 xformers==0.0.27 git+https://github.com/facebookresearch/segment-anything-2.git@main#egg=sam-2 + +CMD ["uvicorn", "app.main:app", "--log-config", "app/cfg/uvicorn_logging_config.json", "--host", "0.0.0.0", "--port", "8000"] diff --git a/runner/docker/README.md b/runner/docker/README.md new file mode 100644 index 00000000..d49aee57 --- /dev/null +++ b/runner/docker/README.md @@ -0,0 +1,36 @@ +# Runner Docker Images + +This folder contains Dockerfiles for pipelines supported by the Livepeer AI network. The list is maintained by the Livepeer community and audited by the [Core AI team](https://explorer.livepeer.org/treasury/42084921863832634370966409987770520882792921083596034115019946998721416745190). In the future, we will enable custom pipelines to be used with the Livepeer AI network. + +## Building a Pipeline-Specific Container + +> [!NOTE] +> We are transitioning our existing pipelines to this new structure. As a result, the base container is currently somewhat bloated. In the future, the base image will contain only the necessary dependencies to run any pipeline. + +All pipeline-specific containers are built on top of the base container found in the main [runner](../) folder and on [Docker Hub](https://hub.docker.com/r/livepeer/ai-runner). The base container includes the minimum dependencies to run any pipeline, while pipeline-specific containers add the necessary dependencies for their respective pipelines. This structure allows for faster build times, less dependency bloat, and easier maintenance. + +### Steps to Build a Pipeline-Specific Container + +To build a pipeline-specific container, you need to build the base container first. The base container is tagged as `base`, and the pipeline-specific container is built from the Dockerfile in the pipeline-specific folder. For example, to build the `segment-anything-2` pipeline-specific container, follow these steps: + +1. **Navigate to the `ai-worker/runner` Directory**: + + ```bash + cd ai-worker/runner + ``` + +2. **Build the Base Container**: + + ```bash + docker build -t livepeer/ai-runner:base . + ``` + + This command builds the base container and tags it as `livepeer/ai-runner:base`. + +3. **Build the `segment-anything-2` Pipeline-Specific Container**: + + ```bash + docker build -f docker/Dockerfile.segment_anything_2 -t livepeer/ai-runner:segment-anything-2 . + ``` + + This command builds the `segment-anything-2` pipeline-specific container using the Dockerfile located at [docker/Dockerfile.segment_anything_2](docker/Dockerfile.segment_anything_2) and tags it as `livepeer/ai-runner:segment-anything-2`. diff --git a/runner/gateway.openapi.yaml b/runner/gateway.openapi.yaml index b50514f8..5a3a1834 100644 --- a/runner/gateway.openapi.yaml +++ b/runner/gateway.openapi.yaml @@ -235,6 +235,50 @@ paths: $ref: '#/components/schemas/HTTPValidationError' security: - HTTPBearer: [] + /segment-anything-2: + post: + summary: Segment Anything 2 + description: Segment objects in an image. + operationId: segment_anything_2 + requestBody: + content: + multipart/form-data: + schema: + $ref: '#/components/schemas/Body_segment_anything_2_segment_anything_2_post' + required: true + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/MasksResponse' + '400': + description: Bad Request + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPError' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPError' + '500': + description: Internal Server Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPError' + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' + security: + - HTTPBearer: [] components: schemas: APIError: @@ -392,6 +436,61 @@ components: - image - model_id title: Body_image_to_video_image_to_video_post + Body_segment_anything_2_segment_anything_2_post: + properties: + image: + type: string + format: binary + title: Image + description: Image to segment. + model_id: + type: string + title: Model Id + description: Hugging Face model ID used for image generation. + default: '' + point_coords: + type: string + title: Point Coords + description: Nx2 array of point prompts to the model, where each point is + in (X,Y) in pixels. + point_labels: + type: string + title: Point Labels + description: Labels for the point prompts, where 1 indicates a foreground + point and 0 indicates a background point. + box: + type: string + title: Box + description: A length 4 array given as a box prompt to the model, in XYXY + format. + mask_input: + type: string + title: Mask Input + description: A low-resolution mask input to the model, typically from a + previous prediction iteration, with the form 1xHxW (H=W=256 for SAM). + multimask_output: + type: boolean + title: Multimask Output + description: If true, the model will return three masks for ambiguous input + prompts, often producing better masks than a single prediction. + default: true + return_logits: + type: boolean + title: Return Logits + description: If true, returns un-thresholded mask logits instead of a binary + mask. + default: true + normalize_coords: + type: boolean + title: Normalize Coords + description: If true, the point coordinates will be normalized to the range + [0,1], with point_coords expected to be with respect to image dimensions. + default: true + type: object + required: + - image + - model_id + title: Body_segment_anything_2_segment_anything_2_post Body_upscale_upscale_post: properties: prompt: @@ -463,6 +562,27 @@ components: - images title: ImageResponse description: Response model for image generation. + MasksResponse: + properties: + masks: + type: string + title: Masks + description: The generated masks. + scores: + type: string + title: Scores + description: The model's confidence scores for each generated mask. + logits: + type: string + title: Logits + description: The raw, unnormalized predictions (logits) for the masks. + type: object + required: + - masks + - scores + - logits + title: MasksResponse + description: Response model for object segmentation. Media: properties: url: diff --git a/runner/gen_openapi.py b/runner/gen_openapi.py index 13042ad3..e0c86e56 100644 --- a/runner/gen_openapi.py +++ b/runner/gen_openapi.py @@ -9,6 +9,7 @@ health, image_to_image, image_to_video, + segment_anything_2, text_to_image, upscale, ) @@ -121,6 +122,7 @@ def write_openapi(fname: str, entrypoint: str = "runner", version: str = "0.0.0" app.include_router(image_to_video.router) app.include_router(upscale.router) app.include_router(audio_to_text.router) + app.include_router(segment_anything_2.router) use_route_names_as_operation_ids(app) diff --git a/runner/openapi.json b/runner/openapi.json new file mode 100644 index 00000000..c1dd7b4c --- /dev/null +++ b/runner/openapi.json @@ -0,0 +1,1019 @@ +{ + "openapi": "3.1.0", + "info": { + "title": "Livepeer AI Runner", + "description": "An application to run AI pipelines", + "version": "0.1.0" + }, + "servers": [ + { + "url": "https://dream-gateway.livepeer.cloud", + "description": "Livepeer Cloud Community Gateway" + } + ], + "paths": { + "/health": { + "get": { + "summary": "Health", + "operationId": "health", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HealthCheck" + } + } + } + } + } + } + }, + "/text-to-image": { + "post": { + "summary": "Text To Image", + "operationId": "text_to_image", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/TextToImageParams" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ImageResponse" + } + } + } + }, + "400": { + "description": "Bad Request", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPError" + } + } + } + }, + "401": { + "description": "Unauthorized", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPError" + } + } + } + }, + "500": { + "description": "Internal Server Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPError" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + }, + "security": [ + { + "HTTPBearer": [] + } + ] + } + }, + "/image-to-image": { + "post": { + "summary": "Image To Image", + "operationId": "image_to_image", + "requestBody": { + "content": { + "multipart/form-data": { + "schema": { + "$ref": "#/components/schemas/Body_image_to_image_image_to_image_post" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ImageResponse" + } + } + } + }, + "400": { + "description": "Bad Request", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPError" + } + } + } + }, + "401": { + "description": "Unauthorized", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPError" + } + } + } + }, + "500": { + "description": "Internal Server Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPError" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + }, + "security": [ + { + "HTTPBearer": [] + } + ] + } + }, + "/image-to-video": { + "post": { + "summary": "Image To Video", + "operationId": "image_to_video", + "requestBody": { + "content": { + "multipart/form-data": { + "schema": { + "$ref": "#/components/schemas/Body_image_to_video_image_to_video_post" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/VideoResponse" + } + } + } + }, + "400": { + "description": "Bad Request", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPError" + } + } + } + }, + "401": { + "description": "Unauthorized", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPError" + } + } + } + }, + "500": { + "description": "Internal Server Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPError" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + }, + "security": [ + { + "HTTPBearer": [] + } + ] + } + }, + "/upscale": { + "post": { + "summary": "Upscale", + "operationId": "upscale", + "requestBody": { + "content": { + "multipart/form-data": { + "schema": { + "$ref": "#/components/schemas/Body_upscale_upscale_post" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ImageResponse" + } + } + } + }, + "400": { + "description": "Bad Request", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPError" + } + } + } + }, + "401": { + "description": "Unauthorized", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPError" + } + } + } + }, + "500": { + "description": "Internal Server Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPError" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + }, + "security": [ + { + "HTTPBearer": [] + } + ] + } + }, + "/audio-to-text": { + "post": { + "summary": "Audio To Text", + "operationId": "audio_to_text", + "requestBody": { + "content": { + "multipart/form-data": { + "schema": { + "$ref": "#/components/schemas/Body_audio_to_text_audio_to_text_post" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/TextResponse" + } + } + } + }, + "400": { + "description": "Bad Request", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPError" + } + } + } + }, + "401": { + "description": "Unauthorized", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPError" + } + } + } + }, + "413": { + "description": "Request Entity Too Large", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPError" + } + } + } + }, + "500": { + "description": "Internal Server Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPError" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + }, + "security": [ + { + "HTTPBearer": [] + } + ] + } + }, + "/segment-anything-2": { + "post": { + "summary": "Segmentanything2", + "operationId": "SegmentAnything2", + "requestBody": { + "content": { + "multipart/form-data": { + "schema": { + "$ref": "#/components/schemas/Body_SegmentAnything2_segment_anything_2_post" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/MasksResponse" + } + } + } + }, + "400": { + "description": "Bad Request", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPError" + } + } + } + }, + "401": { + "description": "Unauthorized", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPError" + } + } + } + }, + "500": { + "description": "Internal Server Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPError" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + }, + "security": [ + { + "HTTPBearer": [] + } + ] + } + } + }, + "components": { + "schemas": { + "APIError": { + "properties": { + "msg": { + "type": "string", + "title": "Msg" + } + }, + "type": "object", + "required": [ + "msg" + ], + "title": "APIError" + }, + "Body_SegmentAnything2_segment_anything_2_post": { + "properties": { + "image": { + "type": "string", + "format": "binary", + "title": "Image" + }, + "model_id": { + "type": "string", + "title": "Model Id", + "default": "" + }, + "point_coords": { + "type": "string", + "title": "Point Coords" + }, + "point_labels": { + "type": "string", + "title": "Point Labels" + }, + "box": { + "type": "string", + "title": "Box" + }, + "mask_input": { + "type": "string", + "title": "Mask Input" + }, + "multimask_output": { + "type": "boolean", + "title": "Multimask Output", + "default": true + }, + "return_logits": { + "type": "boolean", + "title": "Return Logits", + "default": true + }, + "normalize_coords": { + "type": "boolean", + "title": "Normalize Coords", + "default": true + } + }, + "type": "object", + "required": [ + "image" + ], + "title": "Body_SegmentAnything2_segment_anything_2_post" + }, + "Body_audio_to_text_audio_to_text_post": { + "properties": { + "audio": { + "type": "string", + "format": "binary", + "title": "Audio", + "description": "Uploaded audio file to be transcribed." + }, + "model_id": { + "type": "string", + "title": "Model Id", + "description": "Hugging Face model ID used for transcription.", + "default": "" + } + }, + "type": "object", + "required": [ + "audio" + ], + "title": "Body_audio_to_text_audio_to_text_post" + }, + "Body_image_to_image_image_to_image_post": { + "properties": { + "prompt": { + "type": "string", + "title": "Prompt", + "description": "Text prompt(s) to guide image generation." + }, + "image": { + "type": "string", + "format": "binary", + "title": "Image", + "description": "Uploaded image to modify with the pipeline." + }, + "model_id": { + "type": "string", + "title": "Model Id", + "description": "Hugging Face model ID used for image generation.", + "default": "" + }, + "strength": { + "type": "number", + "title": "Strength", + "description": "Degree of transformation applied to the reference image (0 to 1).", + "default": 0.8 + }, + "guidance_scale": { + "type": "number", + "title": "Guidance Scale", + "description": "Encourages model to generate images closely linked to the text prompt (higher values may reduce image quality).", + "default": 7.5 + }, + "image_guidance_scale": { + "type": "number", + "title": "Image Guidance Scale", + "description": "Degree to which the generated image is pushed towards the initial image.", + "default": 1.5 + }, + "negative_prompt": { + "type": "string", + "title": "Negative Prompt", + "description": "Text prompt(s) to guide what to exclude from image generation. Ignored if guidance_scale < 1.", + "default": "" + }, + "safety_check": { + "type": "boolean", + "title": "Safety Check", + "description": "Perform a safety check to estimate if generated images could be offensive or harmful.", + "default": true + }, + "seed": { + "type": "integer", + "title": "Seed", + "description": "Seed for random number generation." + }, + "num_inference_steps": { + "type": "integer", + "title": "Num Inference Steps", + "description": "Number of denoising steps. More steps usually lead to higher quality images but slower inference. Modulated by strength.", + "default": 100 + }, + "num_images_per_prompt": { + "type": "integer", + "title": "Num Images Per Prompt", + "description": "Number of images to generate per prompt.", + "default": 1 + } + }, + "type": "object", + "required": [ + "prompt", + "image" + ], + "title": "Body_image_to_image_image_to_image_post" + }, + "Body_image_to_video_image_to_video_post": { + "properties": { + "image": { + "type": "string", + "format": "binary", + "title": "Image", + "description": "Uploaded image to generate a video from." + }, + "model_id": { + "type": "string", + "title": "Model Id", + "description": "Hugging Face model ID used for video generation.", + "default": "" + }, + "height": { + "type": "integer", + "title": "Height", + "description": "The height in pixels of the generated video.", + "default": 576 + }, + "width": { + "type": "integer", + "title": "Width", + "description": "The width in pixels of the generated video.", + "default": 1024 + }, + "fps": { + "type": "integer", + "title": "Fps", + "description": "The frames per second of the generated video.", + "default": 6 + }, + "motion_bucket_id": { + "type": "integer", + "title": "Motion Bucket Id", + "description": "Used for conditioning the amount of motion for the generation. The higher the number the more motion will be in the video.", + "default": 127 + }, + "noise_aug_strength": { + "type": "number", + "title": "Noise Aug Strength", + "description": "Amount of noise added to the conditioning image. Higher values reduce resemblance to the conditioning image and increase motion.", + "default": 0.02 + }, + "safety_check": { + "type": "boolean", + "title": "Safety Check", + "description": "Perform a safety check to estimate if generated images could be offensive or harmful.", + "default": true + }, + "seed": { + "type": "integer", + "title": "Seed", + "description": "Seed for random number generation." + }, + "num_inference_steps": { + "type": "integer", + "title": "Num Inference Steps", + "description": "Number of denoising steps. More steps usually lead to higher quality images but slower inference. Modulated by strength.", + "default": 25 + } + }, + "type": "object", + "required": [ + "image" + ], + "title": "Body_image_to_video_image_to_video_post" + }, + "Body_upscale_upscale_post": { + "properties": { + "prompt": { + "type": "string", + "title": "Prompt", + "description": "Text prompt(s) to guide upscaled image generation." + }, + "image": { + "type": "string", + "format": "binary", + "title": "Image", + "description": "Uploaded image to modify with the pipeline." + }, + "model_id": { + "type": "string", + "title": "Model Id", + "description": "Hugging Face model ID used for upscaled image generation.", + "default": "" + }, + "safety_check": { + "type": "boolean", + "title": "Safety Check", + "description": "Perform a safety check to estimate if generated images could be offensive or harmful.", + "default": true + }, + "seed": { + "type": "integer", + "title": "Seed", + "description": "Seed for random number generation." + }, + "num_inference_steps": { + "type": "integer", + "title": "Num Inference Steps", + "description": "Number of denoising steps. More steps usually lead to higher quality images but slower inference. Modulated by strength.", + "default": 75 + } + }, + "type": "object", + "required": [ + "prompt", + "image" + ], + "title": "Body_upscale_upscale_post" + }, + "HTTPError": { + "properties": { + "detail": { + "$ref": "#/components/schemas/APIError" + } + }, + "type": "object", + "required": [ + "detail" + ], + "title": "HTTPError" + }, + "HTTPValidationError": { + "properties": { + "detail": { + "items": { + "$ref": "#/components/schemas/ValidationError" + }, + "type": "array", + "title": "Detail" + } + }, + "type": "object", + "title": "HTTPValidationError" + }, + "HealthCheck": { + "properties": { + "status": { + "type": "string", + "title": "Status", + "default": "OK" + } + }, + "type": "object", + "title": "HealthCheck" + }, + "ImageResponse": { + "properties": { + "images": { + "items": { + "$ref": "#/components/schemas/Media" + }, + "type": "array", + "title": "Images" + } + }, + "type": "object", + "required": [ + "images" + ], + "title": "ImageResponse" + }, + "MasksResponse": { + "properties": { + "masks": { + "type": "string", + "title": "Masks" + }, + "scores": { + "type": "string", + "title": "Scores" + }, + "logits": { + "type": "string", + "title": "Logits" + } + }, + "type": "object", + "required": [ + "masks", + "scores", + "logits" + ], + "title": "MasksResponse" + }, + "Media": { + "properties": { + "url": { + "type": "string", + "title": "Url" + }, + "seed": { + "type": "integer", + "title": "Seed" + }, + "nsfw": { + "type": "boolean", + "title": "Nsfw" + } + }, + "type": "object", + "required": [ + "url", + "seed", + "nsfw" + ], + "title": "Media" + }, + "TextResponse": { + "properties": { + "text": { + "type": "string", + "title": "Text" + }, + "chunks": { + "items": { + "$ref": "#/components/schemas/chunk" + }, + "type": "array", + "title": "Chunks" + } + }, + "type": "object", + "required": [ + "text", + "chunks" + ], + "title": "TextResponse" + }, + "TextToImageParams": { + "properties": { + "model_id": { + "type": "string", + "title": "Model Id", + "description": "Hugging Face model ID used for image generation.", + "default": "" + }, + "prompt": { + "type": "string", + "title": "Prompt", + "description": "Text prompt(s) to guide image generation. Separate multiple prompts with '|' if supported by the model." + }, + "height": { + "type": "integer", + "title": "Height", + "description": "The height in pixels of the generated image.", + "default": 576 + }, + "width": { + "type": "integer", + "title": "Width", + "description": "The width in pixels of the generated image.", + "default": 1024 + }, + "guidance_scale": { + "type": "number", + "title": "Guidance Scale", + "description": "Encourages model to generate images closely linked to the text prompt (higher values may reduce image quality).", + "default": 7.5 + }, + "negative_prompt": { + "type": "string", + "title": "Negative Prompt", + "description": "Text prompt(s) to guide what to exclude from image generation. Ignored if guidance_scale < 1.", + "default": "" + }, + "safety_check": { + "type": "boolean", + "title": "Safety Check", + "description": "Perform a safety check to estimate if generated images could be offensive or harmful.", + "default": true + }, + "seed": { + "type": "integer", + "title": "Seed", + "description": "Seed for random number generation." + }, + "num_inference_steps": { + "type": "integer", + "title": "Num Inference Steps", + "description": "Number of denoising steps. More steps usually lead to higher quality images but slower inference. Modulated by strength.", + "default": 50 + }, + "num_images_per_prompt": { + "type": "integer", + "title": "Num Images Per Prompt", + "description": "Number of images to generate per prompt.", + "default": 1 + } + }, + "type": "object", + "required": [ + "prompt" + ], + "title": "TextToImageParams" + }, + "ValidationError": { + "properties": { + "loc": { + "items": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "integer" + } + ] + }, + "type": "array", + "title": "Location" + }, + "msg": { + "type": "string", + "title": "Message" + }, + "type": { + "type": "string", + "title": "Error Type" + } + }, + "type": "object", + "required": [ + "loc", + "msg", + "type" + ], + "title": "ValidationError" + }, + "VideoResponse": { + "properties": { + "frames": { + "items": { + "items": { + "$ref": "#/components/schemas/Media" + }, + "type": "array" + }, + "type": "array", + "title": "Frames" + } + }, + "type": "object", + "required": [ + "frames" + ], + "title": "VideoResponse" + }, + "chunk": { + "properties": { + "timestamp": { + "items": {}, + "type": "array", + "title": "Timestamp" + }, + "text": { + "type": "string", + "title": "Text" + } + }, + "type": "object", + "required": [ + "timestamp", + "text" + ], + "title": "chunk" + } + }, + "securitySchemes": { + "HTTPBearer": { + "type": "http", + "scheme": "bearer" + } + } + } +} \ No newline at end of file diff --git a/runner/openapi.yaml b/runner/openapi.yaml index f7facff3..dfd87a18 100644 --- a/runner/openapi.yaml +++ b/runner/openapi.yaml @@ -246,6 +246,50 @@ paths: $ref: '#/components/schemas/HTTPValidationError' security: - HTTPBearer: [] + /segment-anything-2: + post: + summary: Segment Anything 2 + description: Segment objects in an image. + operationId: segment_anything_2 + requestBody: + content: + multipart/form-data: + schema: + $ref: '#/components/schemas/Body_segment_anything_2_segment_anything_2_post' + required: true + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/MasksResponse' + '400': + description: Bad Request + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPError' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPError' + '500': + description: Internal Server Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPError' + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' + security: + - HTTPBearer: [] components: schemas: APIError: @@ -400,6 +444,60 @@ components: required: - image title: Body_image_to_video_image_to_video_post + Body_segment_anything_2_segment_anything_2_post: + properties: + image: + type: string + format: binary + title: Image + description: Image to segment. + model_id: + type: string + title: Model Id + description: Hugging Face model ID used for image generation. + default: '' + point_coords: + type: string + title: Point Coords + description: Nx2 array of point prompts to the model, where each point is + in (X,Y) in pixels. + point_labels: + type: string + title: Point Labels + description: Labels for the point prompts, where 1 indicates a foreground + point and 0 indicates a background point. + box: + type: string + title: Box + description: A length 4 array given as a box prompt to the model, in XYXY + format. + mask_input: + type: string + title: Mask Input + description: A low-resolution mask input to the model, typically from a + previous prediction iteration, with the form 1xHxW (H=W=256 for SAM). + multimask_output: + type: boolean + title: Multimask Output + description: If true, the model will return three masks for ambiguous input + prompts, often producing better masks than a single prediction. + default: true + return_logits: + type: boolean + title: Return Logits + description: If true, returns un-thresholded mask logits instead of a binary + mask. + default: true + normalize_coords: + type: boolean + title: Normalize Coords + description: If true, the point coordinates will be normalized to the range + [0,1], with point_coords expected to be with respect to image dimensions. + default: true + type: object + required: + - image + title: Body_segment_anything_2_segment_anything_2_post Body_upscale_upscale_post: properties: prompt: @@ -478,6 +576,27 @@ components: - images title: ImageResponse description: Response model for image generation. + MasksResponse: + properties: + masks: + type: string + title: Masks + description: The generated masks. + scores: + type: string + title: Scores + description: The model's confidence scores for each generated mask. + logits: + type: string + title: Logits + description: The raw, unnormalized predictions (logits) for the masks. + type: object + required: + - masks + - scores + - logits + title: MasksResponse + description: Response model for object segmentation. Media: properties: url: diff --git a/worker/container.go b/worker/container.go index 121044d6..1ad88425 100644 --- a/worker/container.go +++ b/worker/container.go @@ -27,10 +27,11 @@ type RunnerEndpoint struct { } type RunnerContainerConfig struct { - Type RunnerContainerType - Pipeline string - ModelID string - Endpoint RunnerEndpoint + Type RunnerContainerType + Pipeline string + ModelID string + Endpoint RunnerEndpoint + ContainerImageID string // For managed containers only ID string diff --git a/worker/docker.go b/worker/docker.go index ca4626eb..f42c0e49 100644 --- a/worker/docker.go +++ b/worker/docker.go @@ -31,17 +31,23 @@ const containerCreator = "ai-worker" // This only works right now on a single GPU because if there is another container // using the GPU we stop it so we don't have to worry about having enough ports var containerHostPorts = map[string]string{ - "text-to-image": "8000", - "image-to-image": "8100", - "image-to-video": "8200", - "upscale": "8300", - "audio-to-text": "8400", + "text-to-image": "8000", + "image-to-image": "8100", + "image-to-video": "8200", + "upscale": "8300", + "audio-to-text": "8400", + "segment-anything-2": "8500", +} + +// Mapping for per pipeline container images. +var pipelineToImage = map[string]string{ + "segment-anything-2": "livepeer/ai-runner:segment-anything-2", } type DockerManager struct { - containerImageID string - gpus []string - modelDir string + defaultImage string + gpus []string + modelDir string dockerClient *client.Client // gpu ID => container name @@ -51,7 +57,7 @@ type DockerManager struct { mu *sync.Mutex } -func NewDockerManager(containerImageID string, gpus []string, modelDir string) (*DockerManager, error) { +func NewDockerManager(defaultImage string, gpus []string, modelDir string) (*DockerManager, error) { dockerClient, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation()) if err != nil { return nil, err @@ -65,13 +71,13 @@ func NewDockerManager(containerImageID string, gpus []string, modelDir string) ( cancel() return &DockerManager{ - containerImageID: containerImageID, - gpus: gpus, - modelDir: modelDir, - dockerClient: dockerClient, - gpuContainers: make(map[string]string), - containers: make(map[string]*RunnerContainer), - mu: &sync.Mutex{}, + defaultImage: defaultImage, + gpus: gpus, + modelDir: modelDir, + dockerClient: dockerClient, + gpuContainers: make(map[string]string), + containers: make(map[string]*RunnerContainer), + mu: &sync.Mutex{}, }, nil } @@ -162,8 +168,12 @@ func (m *DockerManager) createContainer(ctx context.Context, pipeline string, mo // NOTE: We currently allow only one container per GPU for each pipeline. containerHostPort := containerHostPorts[pipeline][:3] + gpu containerName := dockerContainerName(pipeline, modelID, containerHostPort) + containerImage := m.defaultImage + if pipelineSpecificImage, ok := pipelineToImage[pipeline]; ok { + containerImage = pipelineSpecificImage + } - slog.Info("Starting managed container", slog.String("gpu", gpu), slog.String("name", containerName), slog.String("modelID", modelID)) + slog.Info("Starting managed container", slog.String("gpu", gpu), slog.String("name", containerName), slog.String("modelID", modelID), slog.String("containerImage", containerImage)) // Add optimization flags as environment variables. envVars := []string{ @@ -175,7 +185,7 @@ func (m *DockerManager) createContainer(ctx context.Context, pipeline string, mo } containerConfig := &container.Config{ - Image: m.containerImageID, + Image: containerImage, Env: envVars, Volumes: map[string]struct{}{ containerModelDir: {}, diff --git a/worker/multipart.go b/worker/multipart.go index 865b9114..d3d75fb3 100644 --- a/worker/multipart.go +++ b/worker/multipart.go @@ -240,3 +240,71 @@ func NewAudioToTextMultipartWriter(w io.Writer, req AudioToTextMultipartRequestB return mw, nil } + +func NewSegmentAnything2MultipartWriter(w io.Writer, req SegmentAnything2MultipartRequestBody) (*multipart.Writer, error) { + mw := multipart.NewWriter(w) + writer, err := mw.CreateFormFile("image", req.Image.Filename()) + if err != nil { + return nil, err + } + imageSize := req.Image.FileSize() + imageRdr, err := req.Image.Reader() + if err != nil { + return nil, err + } + copied, err := io.Copy(writer, imageRdr) + if err != nil { + return nil, err + } + if copied != imageSize { + return nil, fmt.Errorf("failed to copy image to multipart request imageBytes=%v copiedBytes=%v", imageSize, copied) + } + + // Handle input fields. + if req.ModelId != nil { + if err := mw.WriteField("model_id", *req.ModelId); err != nil { + return nil, err + } + } + if req.PointCoords != nil { + if err := mw.WriteField("point_coords", *req.PointCoords); err != nil { + return nil, err + } + } + if req.PointLabels != nil { + if err := mw.WriteField("point_labels", *req.PointLabels); err != nil { + return nil, err + } + } + if req.Box != nil { + if err := mw.WriteField("box", *req.Box); err != nil { + return nil, err + } + } + if req.MaskInput != nil { + if err := mw.WriteField("mask_input", *req.MaskInput); err != nil { + return nil, err + } + } + if req.MultimaskOutput != nil { + if err := mw.WriteField("multimask_output", strconv.FormatBool(*req.MultimaskOutput)); err != nil { + return nil, err + } + } + if req.ReturnLogits != nil { + if err := mw.WriteField("return_logits", strconv.FormatBool(*req.ReturnLogits)); err != nil { + return nil, err + } + } + if req.NormalizeCoords != nil { + if err := mw.WriteField("normalize_coords", strconv.FormatBool(*req.NormalizeCoords)); err != nil { + return nil, err + } + } + + if err := mw.Close(); err != nil { + return nil, err + } + + return mw, nil +} diff --git a/worker/runner.gen.go b/worker/runner.gen.go index 3ae798f5..3ae7e692 100644 --- a/worker/runner.gen.go +++ b/worker/runner.gen.go @@ -110,6 +110,36 @@ type BodyImageToVideoImageToVideoPost struct { Width *int `json:"width,omitempty"` } +// BodySegmentAnything2SegmentAnything2Post defines model for Body_segment_anything_2_segment_anything_2_post. +type BodySegmentAnything2SegmentAnything2Post struct { + // Box A length 4 array given as a box prompt to the model, in XYXY format. + Box *string `json:"box,omitempty"` + + // Image Image to segment. + Image openapi_types.File `json:"image"` + + // MaskInput A low-resolution mask input to the model, typically from a previous prediction iteration, with the form 1xHxW (H=W=256 for SAM). + MaskInput *string `json:"mask_input,omitempty"` + + // ModelId Hugging Face model ID used for image generation. + ModelId *string `json:"model_id,omitempty"` + + // MultimaskOutput If true, the model will return three masks for ambiguous input prompts, often producing better masks than a single prediction. + MultimaskOutput *bool `json:"multimask_output,omitempty"` + + // NormalizeCoords If true, the point coordinates will be normalized to the range [0,1], with point_coords expected to be with respect to image dimensions. + NormalizeCoords *bool `json:"normalize_coords,omitempty"` + + // PointCoords Nx2 array of point prompts to the model, where each point is in (X,Y) in pixels. + PointCoords *string `json:"point_coords,omitempty"` + + // PointLabels Labels for the point prompts, where 1 indicates a foreground point and 0 indicates a background point. + PointLabels *string `json:"point_labels,omitempty"` + + // ReturnLogits If true, returns un-thresholded mask logits instead of a binary mask. + ReturnLogits *bool `json:"return_logits,omitempty"` +} + // BodyUpscaleUpscalePost defines model for Body_upscale_upscale_post. type BodyUpscaleUpscalePost struct { // Image Uploaded image to modify with the pipeline. @@ -153,6 +183,18 @@ type ImageResponse struct { Images []Media `json:"images"` } +// MasksResponse Response model for object segmentation. +type MasksResponse struct { + // Logits The raw, unnormalized predictions (logits) for the masks. + Logits string `json:"logits"` + + // Masks The generated masks. + Masks string `json:"masks"` + + // Scores The model's confidence scores for each generated mask. + Scores string `json:"scores"` +} + // Media A media object containing information about the generated media. type Media struct { // Nsfw Whether the media was flagged as NSFW. @@ -249,6 +291,9 @@ type ImageToImageMultipartRequestBody = BodyImageToImageImageToImagePost // ImageToVideoMultipartRequestBody defines body for ImageToVideo for multipart/form-data ContentType. type ImageToVideoMultipartRequestBody = BodyImageToVideoImageToVideoPost +// SegmentAnything2MultipartRequestBody defines body for SegmentAnything2 for multipart/form-data ContentType. +type SegmentAnything2MultipartRequestBody = BodySegmentAnything2SegmentAnything2Post + // TextToImageJSONRequestBody defines body for TextToImage for application/json ContentType. type TextToImageJSONRequestBody = TextToImageParams @@ -402,6 +447,9 @@ type ClientInterface interface { // ImageToVideoWithBody request with any body ImageToVideoWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) + // SegmentAnything2WithBody request with any body + SegmentAnything2WithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) + // TextToImageWithBody request with any body TextToImageWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) @@ -459,6 +507,18 @@ func (c *Client) ImageToVideoWithBody(ctx context.Context, contentType string, b return c.Client.Do(req) } +func (c *Client) SegmentAnything2WithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewSegmentAnything2RequestWithBody(c.Server, contentType, body) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + func (c *Client) TextToImageWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) { req, err := NewTextToImageRequestWithBody(c.Server, contentType, body) if err != nil { @@ -609,6 +669,35 @@ func NewImageToVideoRequestWithBody(server string, contentType string, body io.R return req, nil } +// NewSegmentAnything2RequestWithBody generates requests for SegmentAnything2 with any type of body +func NewSegmentAnything2RequestWithBody(server string, contentType string, body io.Reader) (*http.Request, error) { + var err error + + serverURL, err := url.Parse(server) + if err != nil { + return nil, err + } + + operationPath := fmt.Sprintf("/segment-anything-2") + if operationPath[0] == '/' { + operationPath = "." + operationPath + } + + queryURL, err := serverURL.Parse(operationPath) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("POST", queryURL.String(), body) + if err != nil { + return nil, err + } + + req.Header.Add("Content-Type", contentType) + + return req, nil +} + // NewTextToImageRequest calls the generic TextToImage builder with application/json body func NewTextToImageRequest(server string, body TextToImageJSONRequestBody) (*http.Request, error) { var bodyReader io.Reader @@ -733,6 +822,9 @@ type ClientWithResponsesInterface interface { // ImageToVideoWithBodyWithResponse request with any body ImageToVideoWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*ImageToVideoResponse, error) + // SegmentAnything2WithBodyWithResponse request with any body + SegmentAnything2WithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*SegmentAnything2Response, error) + // TextToImageWithBodyWithResponse request with any body TextToImageWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*TextToImageResponse, error) @@ -843,6 +935,32 @@ func (r ImageToVideoResponse) StatusCode() int { return 0 } +type SegmentAnything2Response struct { + Body []byte + HTTPResponse *http.Response + JSON200 *MasksResponse + JSON400 *HTTPError + JSON401 *HTTPError + JSON422 *HTTPValidationError + JSON500 *HTTPError +} + +// Status returns HTTPResponse.Status +func (r SegmentAnything2Response) Status() string { + if r.HTTPResponse != nil { + return r.HTTPResponse.Status + } + return http.StatusText(0) +} + +// StatusCode returns HTTPResponse.StatusCode +func (r SegmentAnything2Response) StatusCode() int { + if r.HTTPResponse != nil { + return r.HTTPResponse.StatusCode + } + return 0 +} + type TextToImageResponse struct { Body []byte HTTPResponse *http.Response @@ -931,6 +1049,15 @@ func (c *ClientWithResponses) ImageToVideoWithBodyWithResponse(ctx context.Conte return ParseImageToVideoResponse(rsp) } +// SegmentAnything2WithBodyWithResponse request with arbitrary body returning *SegmentAnything2Response +func (c *ClientWithResponses) SegmentAnything2WithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*SegmentAnything2Response, error) { + rsp, err := c.SegmentAnything2WithBody(ctx, contentType, body, reqEditors...) + if err != nil { + return nil, err + } + return ParseSegmentAnything2Response(rsp) +} + // TextToImageWithBodyWithResponse request with arbitrary body returning *TextToImageResponse func (c *ClientWithResponses) TextToImageWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*TextToImageResponse, error) { rsp, err := c.TextToImageWithBody(ctx, contentType, body, reqEditors...) @@ -1152,6 +1279,60 @@ func ParseImageToVideoResponse(rsp *http.Response) (*ImageToVideoResponse, error return response, nil } +// ParseSegmentAnything2Response parses an HTTP response from a SegmentAnything2WithResponse call +func ParseSegmentAnything2Response(rsp *http.Response) (*SegmentAnything2Response, error) { + bodyBytes, err := io.ReadAll(rsp.Body) + defer func() { _ = rsp.Body.Close() }() + if err != nil { + return nil, err + } + + response := &SegmentAnything2Response{ + Body: bodyBytes, + HTTPResponse: rsp, + } + + switch { + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: + var dest MasksResponse + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON200 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 400: + var dest HTTPError + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON400 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 401: + var dest HTTPError + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON401 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 422: + var dest HTTPValidationError + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON422 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 500: + var dest HTTPError + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON500 = &dest + + } + + return response, nil +} + // ParseTextToImageResponse parses an HTTP response from a TextToImageWithResponse call func ParseTextToImageResponse(rsp *http.Response) (*TextToImageResponse, error) { bodyBytes, err := io.ReadAll(rsp.Body) @@ -1274,6 +1455,9 @@ type ServerInterface interface { // Image To Video // (POST /image-to-video) ImageToVideo(w http.ResponseWriter, r *http.Request) + // Segment Anything 2 + // (POST /segment-anything-2) + SegmentAnything2(w http.ResponseWriter, r *http.Request) // Text To Image // (POST /text-to-image) TextToImage(w http.ResponseWriter, r *http.Request) @@ -1310,6 +1494,12 @@ func (_ Unimplemented) ImageToVideo(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotImplemented) } +// Segment Anything 2 +// (POST /segment-anything-2) +func (_ Unimplemented) SegmentAnything2(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotImplemented) +} + // Text To Image // (POST /text-to-image) func (_ Unimplemented) TextToImage(w http.ResponseWriter, r *http.Request) { @@ -1397,6 +1587,23 @@ func (siw *ServerInterfaceWrapper) ImageToVideo(w http.ResponseWriter, r *http.R handler.ServeHTTP(w, r.WithContext(ctx)) } +// SegmentAnything2 operation middleware +func (siw *ServerInterfaceWrapper) SegmentAnything2(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + ctx = context.WithValue(ctx, HTTPBearerScopes, []string{}) + + handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + siw.Handler.SegmentAnything2(w, r) + })) + + for _, middleware := range siw.HandlerMiddlewares { + handler = middleware(handler) + } + + handler.ServeHTTP(w, r.WithContext(ctx)) +} + // TextToImage operation middleware func (siw *ServerInterfaceWrapper) TextToImage(w http.ResponseWriter, r *http.Request) { ctx := r.Context() @@ -1556,6 +1763,9 @@ func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handl r.Group(func(r chi.Router) { r.Post(options.BaseURL+"/image-to-video", wrapper.ImageToVideo) }) + r.Group(func(r chi.Router) { + r.Post(options.BaseURL+"/segment-anything-2", wrapper.SegmentAnything2) + }) r.Group(func(r chi.Router) { r.Post(options.BaseURL+"/text-to-image", wrapper.TextToImage) }) @@ -1569,47 +1779,57 @@ func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handl // Base64 encoded, gzipped, json marshaled Swagger object var swaggerSpec = []string{ - "H4sIAAAAAAAC/+xa62/bOBL/VwjdAdsFUttJN9dDvqXdPoJru0HtbD/0AoMWRxK3EqnlI66v5//9wKEk", - "6xkre2kW2/pTYpOc+c2PnAfH/BKEMsulAGF0cPYl0GECGcV/zy8vXigllfufgQ4Vzw2XIjhzIwTcEFGg", - "cyk0kEwySCfBUZArmYMyHFBGpuPu8kUCxfIMtKYxuHWGmxSCs+Ctjt2nTe4+aKO4iIPt9ihQ8LvlClhw", - "9hGlXu+WVECrdXL1G4Qm2B4FzyTbLKllXC6NXBr4bFqfcqmNg9jEjXO6yK/yVFIGjOA4iXgKxEiyAmIU", - "FW7mCpizJpIqoyY4C1ZcULWp2XeOkjsWHgXI4JIzrzWiNnXrg6MWhNc2jrmIyUsaFqyTi5+J1cBIJFWF", - "A6c3ePVT2V5yvek1esdROMQ9z2gMbqr/p/Wxn/3YckZFCEsdUgehRsjTyWmbkRcilFbRGHTBh5EkBgGK", - "GiCoRpMwlRrSDUm5+ATMzTAJEIee5EpmuSGPEh4noMgNTa2TRDdEAbNhIYL8bmnKzebHOqevCpxkjjgr", - "CoTNVqAcBbj2lnPkZRvpkPNoQ9bcJAgt5zmkXMDth+kCxfccJs/uLTwed3n8GWIFCGad8NDDKHkskXJN", - "cqsTpHBNFdM4iwtuOE39nEkbH9lP0/2cfQ+xgDz6+B8FAmJq+A0s/VHYA2KxOzSP9I942CxnQNYJNe4T", - "fA5Ty4BESmZdSOQiFlI5PiPS3B7ybzubPQnJcR32uwIaufTQ+tDbzDuTXuag+mw4bpvwDoknMirdo+4x", - "OajCvAYQm5ELP/kSVAcOFwZiv5eIR0SgAE0zkOsmmtlsGA8DIbl2e4wLJ+StVOD/J1ZbmjofBooeXDhs", - "4ZilKStriE7lGhSpUDgxzKZ4jlcboo0CEZukY185n8wRdZ91dXrHnIrbzuTwnmoagdkswwTCTw3yjLLQ", - "Zu8SlIsQhBK/jOAyPIra8AyjYNT2ZE1CaVPmUpeMIhDaHTKpSEJVFtm0DnPupT5HMBXYlZQpUIFoAViX", - "kTkUbqmoYDIj3tsHqHCTe/ku96rBwmzyz4HgJSOfAX3I5FIQmucp34V8BeUe+515NHMjx42wPi91diJV", - "K1Xm5Qb6MN/OmSNS396secMZyPbH/qwZtRztH0c9dVekaAYanVxDKAVDyhqRHnXU+Xg54AsJ8DhphprT", - "p71a/UzCBcn5Z0j1CKWvvfA+vaOTahXTqJePMfkPZtT7SVEext1TVCbd7OXKhp/AtFEcnzxtw7gqFbot", - "5u5LB8pRTjNphXEb4GX6qjFpJincMx9e3VDhuu7fzMXjYuWap6kLIFzgUGcL3/ppzxB0w7B6upBcw5La", - "eDng6rOTtnHnlQm4mFDGdg7eMNgXJOR1o7QryjoFGrJVioXJ4FpCBSNchAqoLu1upA0EcG5jMhw09qfE", - "k9O/cEY85KqSiTVnrdN7PDv5qS8e4sw7hcMPKLurtZWR9iSi4WwylIhsjsVp9bc/9fxZ15z7CcqFbeyP", - "XyD2+PfT0++o4h3F5qH03RdO7lZq9rppj0+/XiwuB5p6bmhkV4+BoTzFPlma/hIFZx+/BH9XEAVnwd+m", - "u37itGgmTqsG3fa6W7c7UcAKzVxUlfukw0Ghtmb7zpwBW3+lKWcorrJ6yBRuIMOvbrOkLW+7w+It2QGh", - "StEN2lBH2xbQhxtoapLnpQM08WpDjW1Gl+CXfzUuLzihr8u3K613Cnr0Y7R9XxyB7jl53zgcg42XngSh", - "+1vBbfd0q0dtxltgnNa3wDco+ragkyN1/Rg1Le6hxGvq9sFJ5gaIn+jKR0O5Lx5F7Qa6kta08juu69Ik", - "dLTuqvmQgClLca9wTTWJUhrHwAjV5N385YdGaHdixocrtwluxGfE+r2p0jiq/rEq7Rd+9f4NWSegagJJ", - "SIWLwDQMQWvfNC8VXKl0b5fa4hztoSBttf3029Wzjy5n3elkY4f4toMdJlZ82nuwUYyfOvp04/T66X7u", - "VbVP91HgpI9BUOfYMbGXZOMnFTZeN1ff5i9ufCHRrS6pot7Yb7XHf59tkE4H/ZY2yKFp/v00zU+/6545", - "mUNOkefMpobnKRTLtL9J/vDfH9zR0DbPpSoA+2ZVUb4erh1/ShejE81GdjGKA9NKOM2E0pN19hb7qQwb", - "lT4Vm+L20j4PXzoQr2uZ+I0MUU1PLi4ePOwqEXzg0Hfi/Be7qYiZLNy3+/Kys8OrKmbWmBpxwfiVM5B3", - "KoP62tatHx/wd4V9VUjZhXdzG4XQHev9dgFU/lDhQeyp/wuodc4ahPQw5muxnvofB/Dgu1iGwYgSwzPQ", - "hmZ5l6bhUg0FFB6EUvdXa2680DQgsxzuCC75rpG3qGTt4c/UJzpgNSY9UR0GMWSFVnGzmbvN9GS4+/Az", - "oApU9fII45z/qhKSGJMHWyfD3ap6dqH4kc/7pIvCygpyflE1F3WNyTf8BnIA5cbfWyFQ0Q0o7WXdzCbH", - "kyeOWpmDoDkPzoInk+PJzO0kNQninuLzl8dGPi63s+yKtragehNUey+EdUNZjbujgagvWPk0aCGLzXaU", - "gzbPJNvgRUMKAwK1+CRIlZm6LPSYUUN3T7f2OdG41zzb5qa7JIhfeBdBFk5msxau2i5Mf9OOgrGgGvcJ", - "1N3KbBaviZFNyW7aUfDTPULYtZJ69D+jjLz3++H1Hj+M3itBrUmk4v8BhoqPnzyM4sJY8kIYVycupCRv", - "qIo96ycn9wqi01PrwtlNIVXf7fShNv9CGFCCpmQO6gZUiaAW07CEqEezj9fb66NA2yyjalN6NllIgr7t", - "lk4TbMLhTRgQfTMW+B5d8BV9rt4FHOty27pRBUS0Bis9FxGrX4L6Q+J5nqeb8uegxkMNjIvU1fWuSKjV", - "jk1esAQsKsGvHCRHPN944DDZ7FMe4uRwnDyEqLuGKP9UcyF9w6Ll1Vi3D3v1q753NuOdGcvvh3Lm4Z/A", - "H9iZm5eOgzMfnPkrOLN3LXRmV9yPyNCvWl11dOVaE1133bjWnLnVi/+/+0Gz/XPIvAdn/UacFdvgzcRb", - "POMY9tIrP4FQURTTq035TBF/fjaaKNAytWW3rumxxfKvnHN7H6UcHPfguN+I45ZetPWrnBiNi5qaqpbj", - "81RaRp7LLLOCmw15RQ2s6SYoXk5go1OfTadMAc0ex350khbLJ6Fbjr9NDMifG2wzDImtBGmcN6U5n67A", - "0GnZnw+219v/BQAA//8cUcrpEDoAAA==", + "H4sIAAAAAAAC/+xbe2/bxrL/KgveCzQBZFt26ubCQP9w0jY2bpIasdO0yDWEFTkityF32X3YVnP93Q9m", + "dknxack5jovT6i+L5OzMb2Z2Hvvw5yhWRakkSGuio8+RiTMoOP08Pjv9UWul8XcCJtaitELJ6Ai/MMBP", + "TIMplTTACpVAvhtNolKrErQVQDwKk/aHX2QQhhdgDE8Bx1lhc4iOojcmxadliQ/GaiHT6PZ2Emn4wwkN", + "SXT0kbherobUQOtxav47xDa6nUQvVLKccZcINbNqZuHGdp5KZSxCbOMmmj7y92WueAIJo+9sIXJgVrE5", + "MKu5RMo5JKjNQumC2+gomgvJ9bKh3zFx7mk4iciCM5F4qQvuchwfTToQTlyaCpmyn3gcrM5Of2DOQMIW", + "Stc4iLxlV0+arDWuV71h3s1MOGZ7UfAUkNT/6DwOWz91IuEyhpmJOUJoGOT57mHXIj/KWDnNUzDBHlax", + "FCRoboGRGMPiXBnIlywX8hMkSGEzYIielVoVpWVPMpFmoNkVzx1y4kumIXFxYMH+cDwXdvm0adNXASc7", + "J5y1CaQr5qDRBDT2jnnkeVuFyMViya6FzQhaKUrIhYS7J9MpsR+YTN66d9hxv2/HHyDVQGCuMxF7GJUd", + "K6TCsNKZjEx4zXViiEpIYQXPPc1uFx9bb6aHmfseYoC88fSfRBJSbsUVzPxUWAPiYjVpnpinNNmcSIBd", + "Z9ziE9zEuUuALbQq+pDYaSqVRnsuWNs97P/cdPosZvtN2G8DNHbmoQ2hd4UPJjMrQQ/psN9V4S0ZnqlF", + "FR7NiClBB/VaQFzBTj3xGegeHCEtpN6XhEcuQAOpZqE0bTTT6TieBKQSBn1MA3fZG6XB/2bOOJ5jDAOn", + "CA4BGwKzUmXuLDO5ugbNahTIJnE5zeP5khmrQaY26+lX0bNzQj2kXdO8m8yKu+bkuE8NX4BdzuIM4k8t", + "41ntoGu9M9CYIRhnfhijYTQVjRUFZcFFN5INi5XLEyxdarEAaXCSKc0yrouFy5swzz3XlwSmBjtXKgcu", + "CS1A0rfIOYSw1FwmqmA+2kdMgcSD9q581bLCdPd/RpKXWvgK6FOmUJLxsszFKuVrqHzsPfNkil/2W2n9", + "vJLZy1SdUllWDvRpvlszNyh9a6vmlUhAdR+Hq+aiE2jfTQb6roXmBRgKcgOxkgmZrJXpSUbTHj+NxEIG", + "Is3aqebw+aBUT8mEZKW4gdxsIPTEMx+Su3FRrXMa9/wpJ39hRX2YEuVh3L9EFQqpZ3MXfwLbRbF/8LwL", + "430lEF0s8CWCQpPzQjlp0QGep+8as3aRIp/59IqfQujizwLzcRh5LfIcE4iQ9Knnwjee7AWBbinWLBdK", + "GJhxl85GQn160FXuuFaBBjOeJKsAbynsGxJ20mrtQlunwUAxz6kxGR3LuEyYkLEGbiq9W2WDABy7lI0n", + "jfUl8eDwP7gibmtVZYlrkXRm7/704NuhfEiU90qHH4h3X2qnIq0pROPVZKwQGUgLkHbG5dJmQqazg6FX", + "wwVprm4G9g5YTlONfcu41nzJUnEFknHDOJurm2o5FiKSMugELfXrb7/+xnzebtrlhboZXf/0hZ9WlSHo", + "8KW1gJtPMyFLZwf1U9c7GozKHaU/JGZE3FHKLksRU/zSUoGzUsOVUM7gj0TENFrYMAMnq7UhRdD+zcnN", + "B/bk5PsP3x8cfkeT9/z4TauPeYOSTwnmVytoX7rmKlyO8W4+zZSztSHvyByn2Nk5mKws6OuPBus0FiBs", + "/5ChIVy8mIvUoTG96f20MhOmFhYkPiYuRr3mYC3oMNJmXGJuEjLNoeGGllYVcvazRz6UZiROqlz8CbNY", + "KZ2Y+6lXKiEto5FCcgumLrU131VDy2UK7ON0sn8ZpgiNDnIZ3JQQW08+B0+gweBLfOXdl4gCs6qSpl3b", + "giz20uswpGhTWD8Y3t4chChXi6BVcEQnFq4z0MCAxwE+E+g49uTXyW9PV3mytXgisi6y1QTzwHI+h3wA", + "2Gt6X/c+LWgVmn0mZCJisj9HUki1cjIJ1NgZTFskcx5/apL04XqxQ3D9NJ7lKhX2HrPFDzPMyR2MAJOp", + "HHshmp6eFxPSWOwP1AIhUo6j701073wQvfbS+37eqMrco1SMVRtX0lZI/Xe4rvxVm2oPkzGDbsmXb1et", + "6SafH/6D9lc2suZ2o2Vd83q/jY3BMB2I6ZOLi7ORIyT8tOEZUgKWi5xOZfL850V09PFz9N8aFtFR9F97", + "q9OrvXB0tVcfB91e9neJkBUkQbKQ9T7Rbs8GQWxD95U6I7r+wnORELta6zFVhIWCXt2lSZff7QqL12QF", + "hKos6dBE22UwhBt4brOXVQC08RrLrWtnl+jn/21tlRHB0JnSaiNnJWBAPmXbd2EK9OfJu9bkGG05BwqE", + "GT547IYnjt7IGW8gEbzpAr8dPuSCXq00zWnU1njAJNi3m3uZxI+tVjUjVmn2Fl2raH49YU422stV82vY", + "Ez/0ad0vUbfcTCvdzqG9Vlrrih4/MsFgwo6VHnMt2eMbTLhyIRIqNJ6ccFNz2RbZSoye8dqj5wDMVOTB", + "qpcd7Hf6l2bSwMqxwA+VM2MlLRd+K0o29rPnCleSbfPhuL7DpVlc98V8yMBWG3te4DU3bJHzNIUE1+Fv", + "z3/60CrdyGbzcoSewC++42nuwtYSN9pNcTofZv7+3evQoa9UiLnECsvjGIzxR/CVgPc6X+tVRzTGQyGz", + "Nf1J7hrwI/Yk9wpTOm++K3HFmZPro4XYeNKNsxeRN7PXSy+qm70mEXLfBEHTxmiJtUa2nijoeNkefVe8", + "4PcLRWnzjGvulf273hh4yEOV3nn8HYcq2yP4f84R/OE/+gSenUPJyc60CVnSTp/flKKdgm/+/xucGsaV", + "pdIBcL1VtV1W/mVnIr1stuGZSJgwnYLTLigDVWftYi5XcWslx+UyrE678+FzD+LlbbN3jknMQC0O1ydX", + "nQhdlxyacf7FipQwswt8u64uox5eVKBsWGqDBeQvIgF1rzZo6BC8c5WBbims60KqM32kbTVC91zPdRug", + "6tqDB7FmfRegNm3WMsiAxXwvNtD/0wea+JjLKBlxZkUBxvKi7JtpvFUjBiGCiOv6bg2/B0kjPKvPPcaV", + "vRvGu6h5rbGfbRIisIYlvaF6FqSUFTst7PIcnemNcXJxcfYCuAZd32OmPOdf1Uwya8voFnngqmrAC+HK", + "kI9JzMLaSXZ8Wm8em+ayV1xBCaDx+zsnJQm6Am08r6vp7v7uMzStKkHyUkRH0bPd/d0pepLbjHDv0WXa", + "Hat2KndWu94dF9Q3jBu3j/05SujGcWoQ6tOkumh8oYKz0eRg7AuVLGmhoaQFSVJ8EeTa7mEV2km45auL", + "4OuCaLO7wbdtp2MRpBc+RMgKB9NpB1fDC3u/GzTBpqBa6wmS3alsjpaJC5ezFdkk+vYBIay2Cgfkv+AJ", + "e+f94eXuP47c95I7mykt/oSEBO8/exzBQVn2o7TYJ14oxV5znXqrHxw8KIjenmkfzoqE1fuqh4/l/FNp", + "QUues3PQV6ArBI2cRi1EM5t9vLy9nETGFQXXyyqy2YViFNs4dC+jTVZaCQOhb+cCvwcbfcWYa+7ybhpy", + "t02lAkTShjo9zIj1Sd9wSjwuy3xZHfe1rn1SXuTY12OT0Ogd23ahFjB0gl85SW5wGfSR02R7H3qbJ8fz", + "5DZF3TdF+YtWF8pvWHSimvr28ah+NXRrd/Ngpvb7sYJ5/ELdIwdze9GxDeZtMH+FYPahRcEcjhp3qps9", + "OwfjAX3uacPBFl3s4nIsigPxceB78JUj+R6Xlh45otvHiNuI3kb0w0V0FZFVlLEDH9W4ZN+g737VOSuj", + "At04GjP9sG5sud4Z0f/eqr+9qbvtp7cB+zcJWDrcarfT4fLdeJS+9wR1rWXzZfWvTHSpxBq2+jeFfsSG", + "4V+5/g5eJdwG7jZw/yaBW0XRrR+FbAwN6vw/QnWQ8DJXLmEvVVE4KeySveIWrvkyCveh6PjCHO3tJRp4", + "sZP6r7t5GL4b43A6cRzhf25p83CMbc3IEN0eL8XeHCzfq07dotvL238FAAD//x9j1/Y0RgAA", } // GetSwagger returns the content of the embedded swagger specification file diff --git a/worker/worker.go b/worker/worker.go index ca707327..06b0b946 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -40,14 +40,13 @@ func (sb EnvValue) String() string { type OptimizationFlags map[string]EnvValue type Worker struct { - manager *DockerManager - + manager *DockerManager externalContainers map[string]*RunnerContainer mu *sync.Mutex } -func NewWorker(containerImageID string, gpus []string, modelDir string) (*Worker, error) { - manager, err := NewDockerManager(containerImageID, gpus, modelDir) +func NewWorker(defaultImage string, gpus []string, modelDir string) (*Worker, error) { + manager, err := NewDockerManager(defaultImage, gpus, modelDir) if err != nil { return nil, err } @@ -304,6 +303,54 @@ func (w *Worker) AudioToText(ctx context.Context, req AudioToTextMultipartReques return resp.JSON200, nil } +func (w *Worker) SegmentAnything2(ctx context.Context, req SegmentAnything2MultipartRequestBody) (*MasksResponse, error) { + c, err := w.borrowContainer(ctx, "segment-anything-2", *req.ModelId) + if err != nil { + return nil, err + } + defer w.returnContainer(c) + + var buf bytes.Buffer + mw, err := NewSegmentAnything2MultipartWriter(&buf, req) + if err != nil { + return nil, err + } + + resp, err := c.Client.SegmentAnything2WithBodyWithResponse(ctx, mw.FormDataContentType(), &buf) + if err != nil { + return nil, err + } + + if resp.JSON422 != nil { + val, err := json.Marshal(resp.JSON422) + if err != nil { + return nil, err + } + slog.Error("segment anything 2 container returned 422", slog.String("err", string(val))) + return nil, errors.New("segment anything 2 container returned 422") + } + + if resp.JSON400 != nil { + val, err := json.Marshal(resp.JSON400) + if err != nil { + return nil, err + } + slog.Error("segment anything 2 container returned 400", slog.String("err", string(val))) + return nil, errors.New("segment anything 2 container returned 400") + } + + if resp.JSON500 != nil { + val, err := json.Marshal(resp.JSON500) + if err != nil { + return nil, err + } + slog.Error("segment anything 2 container returned 500", slog.String("err", string(val))) + return nil, errors.New("segment anything 2 container returned 500") + } + + return resp.JSON200, nil +} + func (w *Worker) Warm(ctx context.Context, pipeline string, modelID string, endpoint RunnerEndpoint, optimizationFlags OptimizationFlags) error { if endpoint.URL == "" { return w.manager.Warm(ctx, pipeline, modelID, optimizationFlags)