Skip to content

Commit

Permalink
Support for Azure OpenAI API
Browse files Browse the repository at this point in the history
Authored-by: Pablo Valdunciel <pablo.valdunciel@docyet.com>
  • Loading branch information
pabvald authored Sep 24, 2024
1 parent 76b528d commit 7ae06bb
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 1 deletion.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Support for [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference). Requires two environment variables to be set: `AZURE_OPENAI_API_KEY` and `AZURE_OPENAI_HOST` (i.e. `https://<resource-name>.openai.azure.com`).

### Fixed

## [0.56.1]
Expand Down
11 changes: 11 additions & 0 deletions src/llm_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,17 @@ Requires two environment variables to be set:
"""
struct DatabricksOpenAISchema <: AbstractOpenAISchema end

"""
AzureOpenAISchema
AzureOpenAISchema() allows user to call Azure OpenAI API. [API Reference](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)
Requires two environment variables to be set:
- `AZURE_OPENAI_API_KEY`: Azure token
- `AZURE_OPENAI_HOST`: Address of the Azure resource (`"https://<resource>.openai.azure.com"`)
"""
struct AzureOpenAISchema <: AbstractOpenAISchema end

"""
FireworksOpenAISchema
Expand Down
51 changes: 51 additions & 0 deletions src/llm_openai.jl
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,34 @@ function OpenAI.create_chat(schema::DatabricksOpenAISchema,
kwargs...)
end
end
"""
    OpenAI.create_chat(schema::AzureOpenAISchema, api_key, model, conversation; kwargs...)

Dispatch `create_chat` to the Azure OpenAI API for `AzureOpenAISchema`.

The effective credentials are resolved from the module globals when set
(`AZURE_OPENAI_API_KEY`, `AZURE_OPENAI_HOST`), otherwise the `api_key` / `url`
arguments are used. The deployment endpoint is derived from `model`.
"""
function OpenAI.create_chat(schema::AzureOpenAISchema,
        api_key::AbstractString,
        model::AbstractString,
        conversation;
        api_version::String = "2023-03-15-preview",
        http_kwargs::NamedTuple = NamedTuple(),
        streamcallback::Any = nothing,
        url::String = "https://<resource-name>.openai.azure.com",
        kwargs...)

    # Globals (preferences / ENV) take precedence over the passed-in values.
    effective_key = isempty(AZURE_OPENAI_API_KEY) ? api_key : AZURE_OPENAI_API_KEY
    host = isempty(AZURE_OPENAI_HOST) ? url : AZURE_OPENAI_HOST
    # Azure routes requests per-deployment, so the model name is part of the base URL.
    deployment_url = host * "/openai/deployments/$model"

    provider = OpenAI.AzureProvider(;
        api_key = effective_key,
        base_url = deployment_url,
        api_version = api_version)

    # Call the request endpoint directly (Azure requires the `api-version` query param).
    return OpenAI.openai_request(
        "chat/completions",
        provider;
        method = "POST",
        http_kwargs = http_kwargs,
        messages = conversation,
        query = Dict("api-version" => provider.api_version),
        streamcallback = streamcallback,
        kwargs...
    )
end

# Extend OpenAI create_embeddings to allow for testing
function OpenAI.create_embeddings(schema::AbstractOpenAISchema,
Expand Down Expand Up @@ -367,6 +395,29 @@ function OpenAI.create_embeddings(schema::FireworksOpenAISchema,
base_url = url)
OpenAI.create_embeddings(provider, docs, model; kwargs...)
end
"""
    OpenAI.create_embeddings(schema::AzureOpenAISchema, api_key, docs, model; kwargs...)

Dispatch `create_embeddings` to the Azure OpenAI API for `AzureOpenAISchema`.

Credentials come from the module globals when set (`AZURE_OPENAI_API_KEY`,
`AZURE_OPENAI_HOST`), otherwise from the `api_key` / `url` arguments. The
deployment endpoint is derived from `model`.
"""
function OpenAI.create_embeddings(schema::AzureOpenAISchema,
        api_key::AbstractString,
        docs,
        model::AbstractString;
        api_version::String = "2023-03-15-preview",
        url::String = "https://<resource-name>.openai.azure.com",
        kwargs...)

    # Globals (preferences / ENV) take precedence over the passed-in values.
    effective_key = isempty(AZURE_OPENAI_API_KEY) ? api_key : AZURE_OPENAI_API_KEY
    host = isempty(AZURE_OPENAI_HOST) ? url : AZURE_OPENAI_HOST
    # Azure routes requests per-deployment, so the model name is part of the base URL.
    provider = OpenAI.AzureProvider(;
        api_key = effective_key,
        base_url = host * "/openai/deployments/$model",
        api_version = api_version)

    # Call the request endpoint directly (Azure requires the `api-version` query param).
    return OpenAI.openai_request(
        "embeddings",
        provider;
        method = "POST",
        input = docs,
        query = Dict("api-version" => provider.api_version),
        kwargs...
    )
end

## Temporary fix -- it will be moved upstream
function OpenAI.create_embeddings(provider::AbstractCustomProvider,
Expand Down
15 changes: 14 additions & 1 deletion src/user_preferences.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ Check your preferences by calling `get_preferences(key::String)`.
# Available Preferences (for `set_preferences!`)
- `OPENAI_API_KEY`: The API key for the OpenAI API. See [OpenAI's documentation](https://platform.openai.com/docs/quickstart?context=python) for more information.
- `AZURE_OPENAI_API_KEY`: The API key for the Azure OpenAI API. See [Azure OpenAI's documentation](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference) for more information.
- `AZURE_OPENAI_HOST`: The host for the Azure OpenAI API. See [Azure OpenAI's documentation](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference) for more information.
- `MISTRALAI_API_KEY`: The API key for the Mistral AI API. See [Mistral AI's documentation](https://docs.mistral.ai/) for more information.
- `COHERE_API_KEY`: The API key for the Cohere API. See [Cohere's documentation](https://docs.cohere.com/docs/the-cohere-platform) for more information.
- `DATABRICKS_API_KEY`: The API key for the Databricks Foundation Model API. See [Databricks' documentation](https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html) for more information.
Expand Down Expand Up @@ -39,6 +41,8 @@ Define your `register_model!()` calls in your `startup.jl` file to make them ava
# Available ENV Variables
- `OPENAI_API_KEY`: The API key for the OpenAI API.
- `AZURE_OPENAI_API_KEY`: The API key for the Azure OpenAI API.
- `AZURE_OPENAI_HOST`: The host for the Azure OpenAI API. This is the URL built as `https://<resource-name>.openai.azure.com`.
- `MISTRALAI_API_KEY`: The API key for the Mistral AI API.
- `COHERE_API_KEY`: The API key for the Cohere API.
- `LOCAL_SERVER`: The URL of the local server to use for `ai*` calls. Defaults to `http://localhost:10897/v1`. This server is called when you call `model="local"`
Expand All @@ -62,6 +66,8 @@ const PREFERENCES = nothing
"Keys that are allowed to be set via `set_preferences!`"
const ALLOWED_PREFERENCES = ["MISTRALAI_API_KEY",
"OPENAI_API_KEY",
"AZURE_OPENAI_API_KEY",
"AZURE_OPENAI_HOST",
"COHERE_API_KEY",
"DATABRICKS_API_KEY",
"DATABRICKS_HOST",
Expand Down Expand Up @@ -138,6 +144,8 @@ global MODEL_IMAGE_GENERATION::String = @load_preference("MODEL_IMAGE_GENERATION
# First, load from preferences, then from environment variables
# Instantiate empty global variables
global OPENAI_API_KEY::String = ""
global AZURE_OPENAI_API_KEY::String = ""
global AZURE_OPENAI_HOST::String = ""
global MISTRALAI_API_KEY::String = ""
global COHERE_API_KEY::String = ""
global DATABRICKS_API_KEY::String = ""
Expand All @@ -163,7 +171,12 @@ function load_api_keys!()
# Note: Disable this warning by setting OPENAI_API_KEY to anything
isempty(OPENAI_API_KEY) &&
@warn "OPENAI_API_KEY variable not set! OpenAI models will not be available - set API key directly via `PromptingTools.OPENAI_API_KEY=<api-key>`!"

global AZURE_OPENAI_API_KEY
AZURE_OPENAI_API_KEY = @load_preference("AZURE_OPENAI_API_KEY",
default=get(ENV, "AZURE_OPENAI_API_KEY", ""))
global AZURE_OPENAI_HOST
AZURE_OPENAI_HOST = @load_preference("AZURE_OPENAI_HOST",
default=get(ENV, "AZURE_OPENAI_HOST", ""))
global MISTRALAI_API_KEY
MISTRALAI_API_KEY = @load_preference("MISTRALAI_API_KEY",
default=get(ENV, "MISTRALAI_API_KEY", ""))
Expand Down

0 comments on commit 7ae06bb

Please sign in to comment.