feat: add LLM Pipeline (#137)
This commit adds a new LLM pipeline and an /llm route to the ai-worker.
kyriediculous authored Sep 30, 2024
1 parent 2d158a3 commit 6b00498
Showing 14 changed files with 1,004 additions and 58 deletions.
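
For reference, here is a minimal smoke test of the new /llm route added by this commit. It is a sketch, not part of the diff: the host, port, and bearer token are assumptions, while the form fields mirror runner/app/routes/llm.py below.

# Hypothetical smoke test for the /llm route; host, port, and token are assumptions.
import json
import requests

resp = requests.post(
    "http://localhost:8000/llm",
    headers={"Authorization": "Bearer my-token"},  # only required when AUTH_TOKEN is set on the worker
    data={
        "prompt": "What is the capital of France?",
        "system_msg": "You are a helpful assistant.",
        "history": json.dumps([]),  # prior turns as a JSON array of {"role", "content"} objects
        "temperature": 0.7,
        "max_tokens": 256,
        "stream": "false",
    },
)
resp.raise_for_status()
print(resp.json())  # -> {"response": "...", "tokens_used": ...}
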
18 changes: 18 additions & 0 deletions dev/check_torch_cuda.py
@@ -0,0 +1,18 @@
import torch
import subprocess

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
print(f"CUDA version: {torch.version.cuda}")

# Check system CUDA version
try:
    nvcc_output = subprocess.check_output(["nvcc", "--version"]).decode("utf-8")
    cuda_version = nvcc_output.split("release ")[-1].split(",")[0]
    print(f"System CUDA version: {cuda_version}")
except (subprocess.CalledProcessError, FileNotFoundError):
    print("Unable to check system CUDA version")

# Print the current device
print(f"Current device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
6 changes: 6 additions & 0 deletions runner/app/main.py
@@ -54,6 +54,9 @@ def load_pipeline(pipeline: str, model_id: str) -> any:
            from app.pipelines.segment_anything_2 import SegmentAnything2Pipeline

            return SegmentAnything2Pipeline(model_id)
        case "llm":
            from app.pipelines.llm import LLMPipeline
            return LLMPipeline(model_id)
        case _:
            raise EnvironmentError(
                f"{pipeline} is not a valid pipeline for model {model_id}"
@@ -88,6 +91,9 @@ def load_route(pipeline: str) -> any:
            from app.routes import segment_anything_2

            return segment_anything_2.router
        case "llm":
            from app.routes import llm
            return llm.router
        case _:
            raise EnvironmentError(f"{pipeline} is not a valid pipeline")

201 changes: 201 additions & 0 deletions runner/app/pipelines/llm.py
@@ -0,0 +1,201 @@
import asyncio
import logging
import os
import psutil
from typing import Dict, Any, List, Optional, AsyncGenerator, Union

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, BitsAndBytesConfig
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from app.pipelines.base import Pipeline
from app.pipelines.utils import get_model_dir, get_torch_device
from huggingface_hub import file_download, snapshot_download
from threading import Thread

logger = logging.getLogger(__name__)


def get_max_memory():
    num_gpus = torch.cuda.device_count()
    gpu_memory = {i: f"{torch.cuda.get_device_properties(i).total_memory // 1024**3}GiB" for i in range(num_gpus)}
    cpu_memory = f"{psutil.virtual_memory().available // 1024**3}GiB"
    max_memory = {**gpu_memory, "cpu": cpu_memory}

    logger.info(f"Max memory configuration: {max_memory}")
    return max_memory


def load_model_8bit(model_id: str, **kwargs):
    max_memory = get_max_memory()

    quantization_config = BitsAndBytesConfig(
        load_in_8bit=True,
        llm_int8_threshold=6.0,
        llm_int8_has_fp16_weight=False,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs)

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=quantization_config,
        device_map="auto",
        max_memory=max_memory,
        offload_folder="offload",
        low_cpu_mem_usage=True,
        **kwargs
    )

    return tokenizer, model


def load_model_fp16(model_id: str, **kwargs):
    device = get_torch_device()
    max_memory = get_max_memory()

    # Check for fp16 variant
    local_model_path = os.path.join(
        get_model_dir(), file_download.repo_folder_name(repo_id=model_id, repo_type="model"))
    has_fp16_variant = any(".fp16.safetensors" in fname for _, _,
                           files in os.walk(local_model_path) for fname in files)

    if device != "cpu" and has_fp16_variant:
        logger.info("Loading fp16 variant for %s", model_id)
        kwargs["torch_dtype"] = torch.float16
        kwargs["variant"] = "fp16"
    elif device != "cpu":
        kwargs["torch_dtype"] = torch.bfloat16

    tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs)

    config = AutoModelForCausalLM.from_pretrained(model_id, **kwargs).config

    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(config)

    checkpoint_dir = snapshot_download(
        model_id, cache_dir=get_model_dir(), local_files_only=True)

    model = load_checkpoint_and_dispatch(
        model,
        checkpoint_dir,
        device_map="auto",
        max_memory=max_memory,
        # Adjust based on your model architecture
        no_split_module_classes=["LlamaDecoderLayer"],
        dtype=kwargs.get("torch_dtype", torch.float32),
        offload_folder="offload",
        offload_state_dict=True,
    )

    return tokenizer, model


class LLMPipeline(Pipeline):
    def __init__(self, model_id: str):
        self.model_id = model_id
        kwargs = {
            "cache_dir": get_model_dir(),
            "local_files_only": True,
        }
        self.device = get_torch_device()

        # Generate the correct folder name
        folder_path = file_download.repo_folder_name(
            repo_id=model_id, repo_type="model")
        self.local_model_path = os.path.join(get_model_dir(), folder_path)
        self.checkpoint_dir = snapshot_download(
            model_id, cache_dir=get_model_dir(), local_files_only=True)

        logger.info(f"Local model path: {self.local_model_path}")
        logger.info(f"Directory contents: {os.listdir(self.local_model_path)}")

        use_8bit = os.getenv("USE_8BIT", "").strip().lower() == "true"

        if use_8bit:
            logger.info("Using 8-bit quantization")
            self.tokenizer, self.model = load_model_8bit(model_id, **kwargs)
        else:
            logger.info("Using fp16/bf16 precision")
            self.tokenizer, self.model = load_model_fp16(model_id, **kwargs)

        logger.info(
            f"Model loaded and distributed. Device map: {self.model.hf_device_map}"
        )

        # Set up generation config
        self.generation_config = self.model.generation_config

        self.terminators = [
            self.tokenizer.eos_token_id,
            self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
        ]

        # Optional: Add optimizations
        sfast_enabled = os.getenv("SFAST", "").strip().lower() == "true"
        if sfast_enabled:
            logger.info(
                "LLMPipeline will be dynamically compiled with stable-fast for %s",
                model_id,
            )
            from app.pipelines.optim.sfast import compile_model
            self.model = compile_model(self.model)

    async def __call__(self, prompt: str, history: Optional[List[tuple]] = None, system_msg: Optional[str] = None, **kwargs) -> AsyncGenerator[Union[str, Dict[str, Any]], None]:
        conversation = []
        if system_msg:
            conversation.append({"role": "system", "content": system_msg})
        if history:
            conversation.extend(history)
        conversation.append({"role": "user", "content": prompt})

        input_ids = self.tokenizer.apply_chat_template(
            conversation, return_tensors="pt").to(self.model.device)
        attention_mask = torch.ones_like(input_ids)

        max_new_tokens = kwargs.get("max_tokens", 256)
        temperature = kwargs.get("temperature", 0.7)

        streamer = TextIteratorStreamer(
            self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

        generate_kwargs = self.generation_config.to_dict()
        generate_kwargs.update({
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "streamer": streamer,
            "max_new_tokens": max_new_tokens,
            "do_sample": temperature > 0,
            "temperature": temperature,
            "eos_token_id": self.tokenizer.eos_token_id,
            "pad_token_id": self.tokenizer.eos_token_id,
        })

        thread = Thread(target=self.model_generate_wrapper, kwargs=generate_kwargs)
        thread.start()

        total_tokens = 0
        try:
            for text in streamer:
                total_tokens += 1
                yield text
                await asyncio.sleep(0)  # Allow other tasks to run
        except Exception as e:
            logger.error(f"Error during streaming: {str(e)}")
            raise

        input_length = input_ids.size(1)
        yield {"tokens_used": input_length + total_tokens}

    def model_generate_wrapper(self, **kwargs):
        try:
            logger.debug("Entering model.generate")
            with torch.cuda.amp.autocast():  # Use automatic mixed precision
                self.model.generate(**kwargs)
            logger.debug("Exiting model.generate")
        except Exception as e:
            logger.error(f"Error in model.generate: {str(e)}", exc_info=True)
            raise

    def __str__(self):
        return f"LLMPipeline(model_id={self.model_id})"
118 changes: 118 additions & 0 deletions runner/app/routes/llm.py
@@ -0,0 +1,118 @@
import logging
import os
from typing import Annotated
from fastapi import APIRouter, Depends, Form, status
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
from app.routes.util import HTTPError, LLMResponse, http_error
import json

router = APIRouter()

logger = logging.getLogger(__name__)

RESPONSES = {
    status.HTTP_200_OK: {"model": LLMResponse},
    status.HTTP_400_BAD_REQUEST: {"model": HTTPError},
    status.HTTP_401_UNAUTHORIZED: {"model": HTTPError},
    status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError},
}


@router.post(
"/llm",
response_model=LLMResponse,
responses=RESPONSES,
operation_id="genLLM",
description="Generate text using a language model.",
summary="LLM",
tags=["generate"],
openapi_extra={"x-speakeasy-name-override": "llm"},
)
@router.post("/llm/", response_model=LLMResponse, responses=RESPONSES, include_in_schema=False)
async def llm(
    prompt: Annotated[str, Form()],
    model_id: Annotated[str, Form()] = "",
    system_msg: Annotated[str, Form()] = "",
    temperature: Annotated[float, Form()] = 0.7,
    max_tokens: Annotated[int, Form()] = 256,
    history: Annotated[str, Form()] = "[]",  # We'll parse this as JSON
    stream: Annotated[bool, Form()] = False,
    pipeline: Pipeline = Depends(get_pipeline),
    token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
):
    auth_token = os.environ.get("AUTH_TOKEN")
    if auth_token:
        if not token or token.credentials != auth_token:
            return JSONResponse(
                status_code=status.HTTP_401_UNAUTHORIZED,
                headers={"WWW-Authenticate": "Bearer"},
                content=http_error("Invalid bearer token"),
            )

    if model_id != "" and model_id != pipeline.model_id:
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content=http_error(
                f"pipeline configured with {pipeline.model_id} but called with "
                f"{model_id}"
            ),
        )

    try:
        history_list = json.loads(history)
        if not isinstance(history_list, list):
            raise ValueError("History must be a JSON array")

        generator = pipeline(
            prompt=prompt,
            history=history_list,
            system_msg=system_msg if system_msg else None,
            temperature=temperature,
            max_tokens=max_tokens
        )

        if stream:
            return StreamingResponse(stream_generator(generator), media_type="text/event-stream")
        else:
            full_response = ""
            async for chunk in generator:
                if isinstance(chunk, dict):
                    tokens_used = chunk["tokens_used"]
                    break
                full_response += chunk

            return LLMResponse(response=full_response, tokens_used=tokens_used)

    except json.JSONDecodeError:
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "Invalid JSON format for history"}
        )
    except ValueError as ve:
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": str(ve)}
        )
    except Exception as e:
        logger.error(f"LLM processing error: {str(e)}")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"detail": "Internal server error during LLM processing."}
        )


async def stream_generator(generator):
    try:
        async for chunk in generator:
            if isinstance(chunk, dict):  # This is the final result
                yield f"data: {json.dumps(chunk)}\n\n"
                break
            else:
                yield f"data: {json.dumps({'chunk': chunk})}\n\n"
        yield "data: [DONE]\n\n"
    except Exception as e:
        logger.error(f"Streaming error: {str(e)}")
        yield f"data: {json.dumps({'error': str(e)})}\n\n"
5 changes: 5 additions & 0 deletions runner/app/routes/util.py
@@ -58,6 +58,11 @@ class TextResponse(BaseModel):
    chunks: List[chunk] = Field(..., description="The generated text chunks.")


class LLMResponse(BaseModel):
    response: str
    tokens_used: int


class APIError(BaseModel):
"""API error response model."""

3 changes: 3 additions & 0 deletions runner/dl_checkpoints.sh
@@ -76,6 +76,9 @@ function download_restricted_models() {

    # Download text-to-image and image-to-image models.
    huggingface-cli download black-forest-labs/FLUX.1-dev --include "*.safetensors" "*.json" "*.txt" "*.model" --exclude ".onnx" ".onnx_data" --cache-dir models ${TOKEN_FLAG:+"$TOKEN_FLAG"}
    # Download LLM models (Warning: large model size)
    huggingface-cli download meta-llama/Meta-Llama-3.1-8B-Instruct --include "*.json" "*.bin" "*.safetensors" "*.txt" --cache-dir models

}

# Enable HF transfer acceleration.