diff --git a/cmd/bricksllm/main.go b/cmd/bricksllm/main.go index 66c337c..bd183f2 100644 --- a/cmd/bricksllm/main.go +++ b/cmd/bricksllm/main.go @@ -341,7 +341,7 @@ func main() { handler := message.NewHandler(rec, log, ace, ce, vllme, aoe, v, uv, m, um, rlm, accessCache, userAccessCache) - eventConsumer := message.NewConsumer(eventMessageChan, log, 4, handler.HandleEventWithRequestAndResponse) + eventConsumer := message.NewConsumer(eventMessageChan, log, 4, handler.HandleEventWithRequestAndResponse, cfg.OpenAiUrls) eventConsumer.StartEventMessageConsumers() detector, err := amazon.NewClient(cfg.AmazonRequestTimeout, cfg.AmazonConnectionTimeout, log, cfg.AmazonRegion) @@ -352,7 +352,7 @@ func main() { scanner := pii.NewScanner(detector) cd := custompolicy.NewOpenAiDetector(cfg.CustomPolicyDetectionTimeout, cfg.OpenAiApiKey) - ps, err := proxy.NewProxyServer(log, *modePtr, *privacyPtr, c, m, rm, a, psm, cpm, store, ce, ace, aoe, v, rec, messageBus, rlm, cfg.ProxyTimeout, accessCache, userAccessCache, pm, scanner, cd, die, um, cfg.RemoveUserAgent) + ps, err := proxy.NewProxyServer(log, *modePtr, *privacyPtr, c, m, rm, a, psm, cpm, store, ce, ace, aoe, v, rec, messageBus, rlm, cfg.ProxyTimeout, accessCache, userAccessCache, pm, scanner, cd, die, um, cfg.RemoveUserAgent, cfg.OpenAiUrls) if err != nil { log.Sugar().Fatalf("error creating proxy http server: %v", err) } diff --git a/docs/proxy.yaml b/docs/proxy.yaml index 2e6c492..bc5ce0a 100644 --- a/docs/proxy.yaml +++ b/docs/proxy.yaml @@ -37,7 +37,7 @@ paths: 200: description: Service is up and running. - /api/providers/openai/v1/chat/completions: + /api/providers/{openaiProviderName}/v1/chat/completions: post: parameters: - in: header @@ -55,12 +55,19 @@ paths: schema: type: string description: Timeout for the request. Format can be `1s`, `1m`, `1h`, etc. + - in: path + name: openaiProviderName + required: true + schema: + type: string + description: Name of OpenAI provider + example: openai tags: - OpenAI summary: OpenAI Chat Completions description: This endpoint is set up for proxying OpenAI chat completion requests. Documentation for this endpoint can be found [here](https://platform.openai.com/docs/api-reference/chat). - /api/providers/openai/v1/embeddings: + /api/providers/{openaiProviderName}/v1/embeddings: post: parameters: - in: header @@ -78,12 +85,19 @@ paths: schema: type: string description: Timeout for the request. Format can be `1s`, `1m`, `1h`, etc. + - in: path + name: openaiProviderName + required: true + schema: + type: string + description: Name of OpenAI provider + example: openai tags: - OpenAI summary: Call OpenAI embeddings description: This endpoint is set up for proxying OpenAI embedding requests. Documentation for this endpoint can be found [here](https://platform.openai.com/docs/api-reference/embeddings/create). - /api/providers/openai/v1/moderations: + /api/providers/{openaiProviderName}/v1/moderations: post: parameters: - in: header @@ -101,12 +115,19 @@ paths: schema: type: string description: Timeout for the request. Format can be `1s`, `1m`, `1h`, etc. + - in: path + name: openaiProviderName + required: true + schema: + type: string + description: Name of OpenAI provider + example: openai tags: - OpenAI summary: Call OpenAI moderations description: This endpoint is set up for proxying OpenAI moderation requests. Documentation for this endpoint can be found [here](https://platform.openai.com/docs/api-reference/moderations/create). 
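Every OpenAI route in this spec now takes a required {openaiProviderName} segment, so a caller selects the upstream by name instead of being pinned to api.openai.com. A minimal client sketch in Go; the proxy port 8002, the Bearer-token auth header, and the "openai-eu" name (taken from the OPENAI_API_URLS default introduced later in this diff) are assumptions, not part of the change itself:

package main

import (
	"bytes"
	"fmt"
	"net/http"
	"os"
)

func main() {
	body := []byte(`{"model": "gpt-4o-mini", "messages": [{"role": "user", "content": "Hello"}]}`)

	// "openai-eu" must match a key configured in OPENAI_API_URLS; the proxy
	// resolves it to the corresponding upstream base URL.
	req, err := http.NewRequest(http.MethodPost,
		"http://localhost:8002/api/providers/openai-eu/v1/chat/completions",
		bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+os.Getenv("BRICKSLLM_API_KEY"))

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	fmt.Println(resp.Status)
}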
- /api/providers/openai/v1/models: + /api/providers/{openaiProviderName}/v1/models: get: parameters: - in: header @@ -119,12 +140,19 @@ paths: schema: type: string description: Timeout for the request. Format can be `1s`, `1m`, `1h`, etc. + - in: path + name: openaiProviderName + required: true + schema: + type: string + description: Name of OpenAI provider + example: openai tags: - OpenAI summary: Get OpenAI models description: This endpoint is set up for retrieving OpenAI models. Documentation for this endpoint can be found [here](https://platform.openai.com/docs/api-reference/models/list). - /api/providers/openai/v1/models/{model}: + /api/providers/{openaiProviderName}/v1/models/{model}: get: tags: - OpenAI @@ -146,6 +174,13 @@ paths: schema: type: string description: Timeout for the request. Format can be `1s`, `1m`, `1h`, etc. + - in: path + name: openaiProviderName + required: true + schema: + type: string + description: Name of OpenAI provider + example: openai - in: path name: model required: true @@ -153,7 +188,7 @@ schema: type: string description: Model identifier - /api/providers/openai/v1/files: + /api/providers/{openaiProviderName}/v1/files: get: parameters: - in: header @@ -171,6 +206,13 @@ paths: schema: type: string description: Timeout for the request. Format can be `1s`, `1m`, `1h`, etc. + - in: path + name: openaiProviderName + required: true + schema: + type: string + description: Name of OpenAI provider + example: openai tags: - OpenAI summary: List files @@ -198,7 +240,7 @@ paths: summary: Upload a file description: This endpoint is set up for creating an OpenAI file. Documentation for this endpoint can be found [here](https://platform.openai.com/docs/api-reference/files/create). - /api/providers/openai/v1/files/{file_id}: + /api/providers/{openaiProviderName}/v1/files/{file_id}: post: tags: - OpenAI @@ -220,6 +262,13 @@ paths: schema: type: string description: Timeout for the request. Format can be `1s`, `1m`, `1h`, etc. + - in: path + name: openaiProviderName + required: true + schema: + type: string + description: Name of OpenAI provider + example: openai - in: path name: file_id required: true @@ -254,7 +303,7 @@ paths: summary: Retrieve a file description: This endpoint is set up for retrieving an OpenAI file. Documentation for this endpoint can be found [here](https://platform.openai.com/docs/api-reference/files/retrieve). - /api/providers/openai/v1/files/{file_id}/content: + /api/providers/{openaiProviderName}/v1/files/{file_id}/content: get: parameters: - in: header @@ -272,6 +321,13 @@ paths: schema: type: string description: Timeout for the request. Format can be `1s`, `1m`, `1h`, etc. + - in: path + name: openaiProviderName + required: true + schema: + type: string + description: Name of OpenAI provider + example: openai - in: path name: file_id required: true @@ -282,7 +338,7 @@ paths: summary: Retrieve file content description: This endpoint is set up for retrieving the content of an OpenAI file. Documentation for this endpoint can be found [here](https://platform.openai.com/docs/api-reference/files/retrieve-contents). - /api/providers/openai/v1/batches: + /api/providers/{openaiProviderName}/v1/batches: post: parameters: - in: header @@ -300,6 +356,13 @@ paths: schema: type: string description: Timeout for the request. Format can be `1s`, `1m`, `1h`, etc.
+ - in: path + name: openaiProviderName + required: true + schema: + type: string + description: Name of OpenAI provider + example: openai tags: - OpenAI summary: Create a batch @@ -327,7 +390,7 @@ paths: summary: List batches description: This endpoint is set up for listing batches. Documentation for this endpoint can be found [here](https://platform.openai.com/docs/api-reference/batch/list). - /api/providers/openai/v1/batches/{batch_id}: + /api/providers/{openaiProviderName}/v1/batches/{batch_id}: get: tags: - OpenAI @@ -349,6 +412,13 @@ paths: schema: type: string description: Timeout for the request. Format can be `1s`, `1m`, `1h`, etc. + - in: path + name: openaiProviderName + required: true + schema: + type: string + description: Name of OpenAI provider + example: openai - in: path name: batch_id required: true @@ -377,7 +447,7 @@ paths: summary: Cancel a batch description: This endpoint is set up for canceling a batch. Documentation for this endpoint can be found [here](https://platform.openai.com/docs/api-reference/batch/cancel). - /api/providers/openai/v1/images/generations: + /api/providers/{openaiProviderName}/v1/images/generations: post: parameters: - in: header @@ -395,12 +465,19 @@ paths: schema: type: string description: Timeout for the request. Format can be `1s`, `1m`, `1h`, etc. + - in: path + name: openaiProviderName + required: true + schema: + type: string + description: Name of OpenAI provider + example: openai tags: - OpenAI summary: Generate images description: This endpoint is set up for generating OpenAI images. Documentation for this endpoint can be found [here](https://platform.openai.com/docs/api-reference/images/create). - /api/providers/openai/v1/images/edits: + /api/providers/{openaiProviderName}/v1/images/edits: post: parameters: - in: header @@ -418,12 +495,19 @@ paths: schema: type: string description: Timeout for the request. Format can be `1s`, `1m`, `1h`, etc. + - in: path + name: openaiProviderName + required: true + schema: + type: string + description: Name of OpenAI provider + example: openai tags: - OpenAI summary: Edit images description: This endpoint is set up for editing OpenAI generated images. Documentation for this endpoint can be found [here](https://platform.openai.com/docs/api-reference/images/createEdit). - /api/providers/openai/v1/images/variations: + /api/providers/{openaiProviderName}/v1/images/variations: post: parameters: - in: header @@ -441,12 +525,19 @@ paths: schema: type: string description: Timeout for the request. Format can be `1s`, `1m`, `1h`, etc. + - in: path + name: openaiProviderName + required: true + schema: + type: string + description: Name of OpenAI provider + example: openai tags: - OpenAI summary: Create image variations description: This endpoint is set up for creating OpenAI image variations. Documentation for this endpoint can be found [here](https://platform.openai.com/docs/api-reference/images/createVariation). - /api/providers/openai/v1/audio/speech: + /api/providers/{openaiProviderName}/v1/audio/speech: post: parameters: - in: header @@ -464,12 +555,19 @@ paths: schema: type: string description: Timeout for the request. Format can be `1s`, `1m`, `1h`, etc. + - in: path + name: openaiProviderName + required: true + schema: + type: string + description: Name of OpenAI provider + example: openai tags: - OpenAI summary: Create speech description: This endpoint is set up for creating speeches. 
Documentation for this endpoint can be found [here](https://platform.openai.com/docs/api-reference/audio/createSpeech). - /api/providers/openai/v1/audio/transcriptions: + /api/providers/{openaiProviderName}/v1/audio/transcriptions: post: parameters: - in: header @@ -487,12 +585,19 @@ paths: schema: type: string description: Timeout for the request. Format can be `1s`, `1m`, `1h`, etc. + - in: path + name: openaiProviderName + required: true + schema: + type: string + description: Name of OpenAI provider + example: openai tags: - OpenAI summary: Create transcriptions description: This endpoint is set up for creating transcriptions. Documentation for this endpoint can be found [here](https://platform.openai.com/docs/api-reference/audio/createTranscription). - /api/providers/openai/v1/audios/translations: + /api/providers/{openaiProviderName}/v1/audios/translations: post: parameters: - in: header @@ -510,12 +615,19 @@ paths: schema: type: string description: Timeout for the request. Format can be `1s`, `1m`, `1h`, etc. + - in: path + name: openaiProviderName + required: true + schema: + type: string + description: Name of OpenAI provider + example: openai tags: - OpenAI summary: Create translations description: This endpoint is set up for creating translations. Documentation for this endpoint can be found [here](https://platform.openai.com/docs/api-reference/audio/createTranslation). - /api/providers/openai/v1/assistants: + /api/providers/{openaiProviderName}/v1/assistants: post: parameters: - in: header @@ -533,6 +645,13 @@ paths: schema: type: string description: Timeout for the request. Format can be `1s`, `1m`, `1h`, etc. + - in: path + name: openaiProviderName + required: true + schema: + type: string + description: Name of OpenAI provider + example: openai tags: - OpenAI summary: Create assistant @@ -560,7 +679,7 @@ paths: summary: List assistants description: This endpoint is set up for listing OpenAI assistants. Documentation for this endpoint can be found [here](https://platform.openai.com/docs/api-reference/assistants/listAssistants). - /api/providers/openai/v1/assistants/{assistant_id}: + /api/providers/{openaiProviderName}/v1/assistants/{assistant_id}: get: tags: - OpenAI @@ -582,6 +701,13 @@ paths: schema: type: string description: Timeout for the request. Format can be `1s`, `1m`, `1h`, etc. + - in: path + name: openaiProviderName + required: true + schema: + type: string + description: Name of OpenAI provider + example: openai - in: path name: assistant_id required: true @@ -642,7 +768,7 @@ paths: summary: Delete assistant description: This endpoint is set up for deleting an OpenAI assistant. Documentation for this endpoint can be found [here](https://platform.openai.com/docs/api-reference/assistants/deleteAssistant). - /api/providers/openai/v1/assistants/{assistant_id}/files: + /api/providers/{openaiProviderName}/v1/assistants/{assistant_id}/files: post: parameters: - in: path @@ -665,6 +791,13 @@ paths: schema: type: string description: Timeout for the request. Format can be `1s`, `1m`, `1h`, etc. + - in: path + name: openaiProviderName + required: true + schema: + type: string + description: Name of OpenAI provider + example: openai tags: - OpenAI summary: Create assistant file @@ -697,7 +830,7 @@ paths: summary: List assistant files description: This endpoint is set up for retrieving OpenAI assistant files. Documentation for this endpoint can be found [here](https://platform.openai.com/docs/api-reference/assistants/listAssistantFiles).
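The spec edits above all repeat one pattern: a required provider-name path parameter per route. A hedged sketch of what that implies on the Gin side — one literal route registered per configured name, so that c.FullPath() later matches the fmt.Sprintf comparisons seen in the middleware changes below. This is illustrative only, not the PR's actual registration code:

package main

import (
	"net/http"

	"github.com/gin-gonic/gin"
)

func main() {
	// Assumed shape of cfg.OpenAiUrls once OPENAI_API_URLS has been parsed.
	openAiUrls := map[string]string{
		"openai":    "https://api.openai.com",
		"openai-eu": "https://eu.api.openai.com",
	}

	r := gin.New()

	// One literal route per configured provider name, so c.FullPath() returns
	// e.g. "/api/providers/openai-eu/v1/chat/completions" and per-name string
	// comparisons can match it.
	for name, base := range openAiUrls {
		upstream := base // capture the loop variable per iteration
		r.POST("/api/providers/"+name+"/v1/chat/completions", func(c *gin.Context) {
			// A real proxy would forward the request body here; this sketch
			// only reports the upstream it would target.
			c.JSON(http.StatusOK, gin.H{"upstream": upstream + "/v1/chat/completions"})
		})
	}

	_ = r.Run(":8002")
}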
- /api/providers/openai/v1/assistants/{assistant_id}/files/{file_id}: + /api/providers/{openaiProviderName}/v1/assistants/{assistant_id}/files/{file_id}: get: parameters: - in: header @@ -715,6 +848,13 @@ paths: schema: type: string description: Timeout for the request. Format can be `1s`, `1m`, `1h`, etc. + - in: path + name: openaiProviderName + required: true + schema: + type: string + description: Name of OpenAI provider + example: openai - in: path name: assistant_id required: true @@ -762,7 +902,7 @@ paths: summary: Delete assistant file description: This endpoint is set up for deleting an OpenAI assistant file. Documentation for this endpoint can be found [here](https://platform.openai.com/docs/api-reference/assistants/deleteAssistantFile). - /api/providers/openai/v1/threads: + /api/providers/{openaiProviderName}/v1/threads: post: parameters: - in: header @@ -780,12 +920,19 @@ paths: schema: type: string description: Timeout for the request. Format can be `1s`, `1m`, `1h`, etc. + - in: path + name: openaiProviderName + required: true + schema: + type: string + description: Name of OpenAI provider + example: openai tags: - OpenAI summary: Create thread description: This endpoint is set up for creating an OpenAI thread. Documentation for this endpoint can be found [here](https://platform.openai.com/docs/api-reference/threads/createThread). - /api/providers/openai/v1/threads/{thread_id}: + /api/providers/{openaiProviderName}/v1/threads/{thread_id}: get: tags: - OpenAI @@ -807,6 +954,13 @@ paths: schema: type: string description: Timeout for the request. Format can be `1s`, `1m`, `1h`, etc. + - in: path + name: openaiProviderName + required: true + schema: + type: string + description: Name of OpenAI provider + example: openai - in: path name: thread_id required: true @@ -862,7 +1016,7 @@ paths: summary: Delete thread description: This endpoint is set up for deleting an OpenAI thread. Documentation for this endpoint can be found [here](https://platform.openai.com/docs/api-reference/threads/deleteThread). - /api/providers/openai/v1/threads/{thread_id}/messages: + /api/providers/{openaiProviderName}/v1/threads/{thread_id}/messages: post: parameters: - in: header @@ -880,6 +1034,13 @@ paths: schema: type: string description: Timeout for the request. Format can be `1s`, `1m`, `1h`, etc. + - in: path + name: openaiProviderName + required: true + schema: + type: string + description: Name of OpenAI provider + example: openai - in: path name: thread_id required: true @@ -918,7 +1079,7 @@ paths: summary: List messages description: This endpoint is set up for listing OpenAI messages. Documentation for this endpoint can be found [here](https://platform.openai.com/docs/api-reference/messages/listMessages). - /api/providers/openai/v1/threads/{thread_id}/messages/{message_id}: + /api/providers/{openaiProviderName}/v1/threads/{thread_id}/messages/{message_id}: get: tags: - OpenAI @@ -940,6 +1101,13 @@ paths: schema: type: string description: Timeout for the request. Format can be `1s`, `1m`, `1h`, etc. + - in: path + name: openaiProviderName + required: true + schema: + type: string + description: Name of OpenAI provider + example: openai - in: path name: message_id required: true @@ -982,7 +1150,7 @@ paths: summary: Modify message description: This endpoint is set up for modifying an OpenAI message. Documentation for this endpoint can be found [here](https://platform.openai.com/docs/api-reference/messages/modifyMessage). - ? 
/api/providers/openai/v1/threads/{thread_id}/messages/{message_id}/files/{file_id} + ? /api/providers/{openaiProviderName}/v1/threads/{thread_id}/messages/{message_id}/files/{file_id} : get: tags: - OpenAI @@ -1004,6 +1172,13 @@ paths: schema: type: string description: Timeout for the request. Format can be `1s`, `1m`, `1h`, etc. + - in: path + name: openaiProviderName + required: true + schema: + type: string + description: Name of OpenAI provider + example: openai - in: path name: file_id required: true @@ -1020,9 +1195,16 @@ paths: schema: type: string - /api/providers/openai/v1/threads/{thread_id}/messages/{message_id}/files: + /api/providers/{openaiProviderName}/v1/threads/{thread_id}/messages/{message_id}/files: get: parameters: + - in: path + name: openaiProviderName + required: true + schema: + type: string + description: Name of OpenAI provider + example: openai - in: path name: thread_id required: true @@ -1053,7 +1235,7 @@ paths: summary: List message files description: This endpoint is set up for retrieving OpenAI message files. Documentation for this endpoint can be found [here](https://platform.openai.com/docs/api-reference/messages/listMessageFiles). - /api/providers/openai/v1/threads/{thread_id}/runs: + /api/providers/{openaiProviderName}/v1/threads/{thread_id}/runs: post: parameters: - in: header @@ -1071,6 +1253,13 @@ paths: schema: type: string description: Timeout for the request. Format can be `1s`, `1m`, `1h`, etc. + - in: path + name: openaiProviderName + required: true + schema: + type: string + description: Name of OpenAI provider + example: openai - in: path name: thread_id required: true @@ -1108,7 +1297,7 @@ paths: summary: List runs description: This endpoint is set up for retrieving OpenAI runs. Documentation for this endpoint can be found [here](https://platform.openai.com/docs/api-reference/runs/listRuns). - /api/providers/openai/v1/threads/{thread_id}/runs/{run_id}: + /api/providers/{openaiProviderName}/v1/threads/{thread_id}/runs/{run_id}: get: tags: - OpenAI @@ -1130,6 +1319,13 @@ paths: schema: type: string description: Timeout for the request. Format can be `1s`, `1m`, `1h`, etc. + - in: path + name: openaiProviderName + required: true + schema: + type: string + description: Name of OpenAI provider + example: openai - in: path name: run_id required: true @@ -1174,7 +1370,7 @@ paths: summary: Modify run description: This endpoint is set up for modifying an OpenAI run. Documentation for this endpoint can be found [here](https://platform.openai.com/docs/api-reference/runs/modifyRun). - /api/providers/openai/v1/threads/{thread_id}/runs/{run_id}/cancel: + /api/providers/{openaiProviderName}/v1/threads/{thread_id}/runs/{run_id}/cancel: post: parameters: - in: header @@ -1192,6 +1388,13 @@ paths: schema: type: string description: Timeout for the request. Format can be `1s`, `1m`, `1h`, etc. + - in: path + name: openaiProviderName + required: true + schema: + type: string + description: Name of OpenAI provider + example: openai - in: path name: thread_id required: true @@ -1207,7 +1410,7 @@ paths: summary: Cancel a run description: This endpoint is set up for cancelling an OpenAI run. Documentation for this endpoint can be found [here](https://platform.openai.com/docs/api-reference/runs/cancelRun). - /api/providers/openai/v1/threads/runs: + /api/providers/{openaiProviderName}/v1/threads/runs: post: parameters: - in: header @@ -1225,12 +1428,19 @@ paths: schema: type: string description: Timeout for the request. Format can be `1s`, `1m`, `1h`, etc. 
+ - in: path + name: openaiProviderName + required: true + schema: + type: string + description: Name of OpenAI provider + example: openai tags: - OpenAI summary: Create thread and run description: This endpoint is set up for creating an OpenAI thread and run. Documentation for this endpoint can be found [here](https://platform.openai.com/docs/api-reference/runs/createThreadAndRun). - /api/providers/openai/v1/threads/{thread_id}/runs/{run_id}/steps/{step_id}: + /api/providers/{openaiProviderName}/v1/threads/{thread_id}/runs/{run_id}/steps/{step_id}: get: tags: - OpenAI @@ -1252,6 +1462,13 @@ paths: schema: type: string description: Timeout for the request. Format can be `1s`, `1m`, `1h`, etc. + - in: path + name: openaiProviderName + required: true + schema: + type: string + description: Name of OpenAI provider + example: openai - in: path name: step_id required: true @@ -1269,7 +1486,7 @@ paths: schema: type: string - /api/providers/openai/v1/threads/{thread_id}/runs/{run_id}/steps: + /api/providers/{openaiProviderName}/v1/threads/{thread_id}/runs/{run_id}/steps: get: parameters: - in: header @@ -1287,6 +1504,13 @@ paths: schema: type: string description: Timeout for the request. Format can be `1s`, `1m`, `1h`, etc. + - in: path + name: openaiProviderName + required: true + schema: + type: string + description: Name of OpenAI provider + example: openai - in: path name: thread_id required: true diff --git a/go.mod b/go.mod index d917dea..df95320 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( github.com/aws/aws-sdk-go-v2/config v1.27.7 github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.16.2 github.com/aws/aws-sdk-go-v2/service/comprehend v1.31.2 - github.com/caarlos0/env v3.5.0+incompatible + github.com/caarlos0/env/v11 v11.3.1 github.com/cenkalti/backoff/v4 v4.3.0 github.com/fatih/color v1.15.0 github.com/gin-gonic/gin v1.9.1 diff --git a/go.sum b/go.sum index 4e61f2c..ec4bad0 100644 --- a/go.sum +++ b/go.sum @@ -59,8 +59,8 @@ github.com/bsm/gomega v1.26.0/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= -github.com/caarlos0/env v3.5.0+incompatible h1:Yy0UN8o9Wtr/jGHZDpCBLpNrzcFLLM2yixi/rBrKyJs= -github.com/caarlos0/env v3.5.0+incompatible/go.mod h1:tdCsowwCzMLdkqRYDlHpZCp2UooDD3MspDBjZ2AD02Y= +github.com/caarlos0/env/v11 v11.3.1 h1:cArPWC15hWmEt+gWk7YBi7lEXTXCvpaSdCiZE2X5mCA= +github.com/caarlos0/env/v11 v11.3.1/go.mod h1:qupehSf/Y0TUTsxKywqRt/vJjN5nz6vauiYEUUr8P4U= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= diff --git a/internal/authenticator/authenticator.go b/internal/authenticator/authenticator.go index bec413c..8b9c575 100644 --- a/internal/authenticator/authenticator.go +++ b/internal/authenticator/authenticator.go @@ -163,13 +163,21 @@ func (a *Authenticator) getProviderSettingsThatCanAccessCustomRoute(path string, return selected } -func canAccessPath(provider string, path string) bool { +func canAccessPath(provider string, path string, openAiUrls map[string]string) bool { if provider == "bedrock" && !strings.HasPrefix(path, "/api/providers/bedrock") { return 
false } - if provider == "openai" && !strings.HasPrefix(path, "/api/providers/openai") { - return false + if provider == "openai" { + matchedPaths := false + for openAiName := range openAiUrls { + if strings.HasPrefix(path, "/api/providers/"+openAiName) { + matchedPaths = true + } + } + if !matchedPaths { + return false + } } if provider == "azure" && !strings.HasPrefix(path, "/api/providers/azure/openai") { @@ -204,7 +212,7 @@ func anonymize(input string) string { return string(input[0:5]) + "**********************************************" } -func (a *Authenticator) AuthenticateHttpRequest(req *http.Request) (*key.ResponseKey, []*provider.Setting, error) { +func (a *Authenticator) AuthenticateHttpRequest(req *http.Request, openAiUrls map[string]string) (*key.ResponseKey, []*provider.Setting, error) { raw, err := getApiKey(req) if err != nil { return nil, nil, err @@ -255,7 +263,7 @@ func (a *Authenticator) AuthenticateHttpRequest(req *http.Request) (*key.Respons continue } - if canAccessPath(setting.Provider, req.URL.Path) { + if canAccessPath(setting.Provider, req.URL.Path, openAiUrls) { selected = append(selected, setting) } diff --git a/internal/config/config.go b/internal/config/config.go index 7abedde..f40afa8 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -6,7 +6,7 @@ import ( "path/filepath" "time" - "github.com/caarlos0/env" + "github.com/caarlos0/env/v11" "github.com/joho/godotenv" "github.com/knadh/koanf/parsers/json" @@ -16,41 +16,42 @@ import ( ) type Config struct { - PostgresqlHosts string `koanf:"postgresql_hosts" env:"POSTGRESQL_HOSTS" envSeparator:":" envDefault:"localhost"` - PostgresqlDbName string `koanf:"postgresql_db_name" env:"POSTGRESQL_DB_NAME"` - PostgresqlUsername string `koanf:"postgresql_username" env:"POSTGRESQL_USERNAME"` - PostgresqlPassword string `koanf:"postgresql_password" env:"POSTGRESQL_PASSWORD"` - PostgresqlSslMode string `koanf:"postgresql_ssl_mode" env:"POSTGRESQL_SSL_MODE" envDefault:"disable"` - PostgresqlPort string `koanf:"postgresql_port" env:"POSTGRESQL_PORT" envDefault:"5432"` - RedisHosts string `koanf:"redis_hosts" env:"REDIS_HOSTS" envSeparator:":" envDefault:"localhost"` - RedisPort string `koanf:"redis_port" env:"REDIS_PORT" envDefault:"6379"` - RedisUsername string `koanf:"redis_username" env:"REDIS_USERNAME"` - RedisPassword string `koanf:"redis_password" env:"REDIS_PASSWORD"` - RedisDBStartIndex int `koanf:"redis_db_start_index" env:"REDIS_DB_START_INDEX" envDefault:"0"` - RedisReadTimeout time.Duration `koanf:"redis_read_time_out" env:"REDIS_READ_TIME_OUT" envDefault:"1s"` - RedisWriteTimeout time.Duration `koanf:"redis_write_time_out" env:"REDIS_WRITE_TIME_OUT" envDefault:"500ms"` - PostgresqlReadTimeout time.Duration `koanf:"postgresql_read_time_out" env:"POSTGRESQL_READ_TIME_OUT" envDefault:"10m"` - PostgresqlWriteTimeout time.Duration `koanf:"postgresql_write_time_out" env:"POSTGRESQL_WRITE_TIME_OUT" envDefault:"5s"` - InMemoryDbUpdateInterval time.Duration `koanf:"in_memory_db_update_interval" env:"IN_MEMORY_DB_UPDATE_INTERVAL" envDefault:"5s"` - TelemetryProvider string `koanf:"telemetry_provider" env:"TELEMETRY_PROVIDER" envDefault:"statsd"` - StatsEnabled bool `koanf:"stats_enabled" env:"STATS_ENABLED" envDefault:"true"` - StatsAddress string `koanf:"stats_address" env:"STATS_ADDRESS" envDefault:"127.0.0.1:8125"` - PrometheusEnabled bool `koanf:"prometheus_enabled" env:"PROMETHEUS_ENABLED" envDefault:"true"` - PrometheusPort string `koanf:"prometheus_port" env:"PROMETHEUS_PORT" envDefault:"2112"` 
- AdminPass string `koanf:"admin_pass" env:"ADMIN_PASS"` - ProxyTimeout time.Duration `koanf:"proxy_timeout" env:"PROXY_TIMEOUT" envDefault:"600s"` - NumberOfEventMessageConsumers int `koanf:"number_of_event_message_consumers" env:"NUMBER_OF_EVENT_MESSAGE_CONSUMERS" envDefault:"3"` - OpenAiApiKey string `koanf:"openai_api_key" env:"OPENAI_API_KEY"` - CustomPolicyDetectionTimeout time.Duration `koanf:"custom_policy_detection_timeout" env:"CUSTOM_POLICY_DETECTION_TIMEOUT" envDefault:"10m"` - AmazonRegion string `koanf:"amazon_region" env:"AMAZON_REGION" envDefault:"us-west-2"` - AmazonRequestTimeout time.Duration `koanf:"amazon_request_timeout" env:"AMAZON_REQUEST_TIMEOUT" envDefault:"5s"` - AmazonConnectionTimeout time.Duration `koanf:"amazon_connection_timeout" env:"AMAZON_CONNECTION_TIMEOUT" envDefault:"10s"` - RemoveUserAgent bool `koanf:"remove_user_agent" env:"REMOVE_USER_AGENT" envDefault:"false"` - EnableEncrytion bool `koanf:"enable_encryption" env:"ENABLE_ENCRYPTION" envDefault:"false"` - EncryptionEndpoint string `koanf:"encryption_endpoint" env:"ENCRYPTION_ENDPOINT"` - DecryptionEndpoint string `koanf:"decryption_endpoint" env:"DECRYPTION_ENDPOINT"` - EncryptionTimeout time.Duration `koanf:"encryption_timeout" env:"ENCRYPTION_TIMEOUT" envDefault:"5s"` - Audience string `koanf:"audience" env:"AUDIENCE"` + PostgresqlHosts string `koanf:"postgresql_hosts" env:"POSTGRESQL_HOSTS" envSeparator:":" envDefault:"localhost"` + PostgresqlDbName string `koanf:"postgresql_db_name" env:"POSTGRESQL_DB_NAME"` + PostgresqlUsername string `koanf:"postgresql_username" env:"POSTGRESQL_USERNAME"` + PostgresqlPassword string `koanf:"postgresql_password" env:"POSTGRESQL_PASSWORD"` + PostgresqlSslMode string `koanf:"postgresql_ssl_mode" env:"POSTGRESQL_SSL_MODE" envDefault:"disable"` + PostgresqlPort string `koanf:"postgresql_port" env:"POSTGRESQL_PORT" envDefault:"5432"` + RedisHosts string `koanf:"redis_hosts" env:"REDIS_HOSTS" envSeparator:":" envDefault:"localhost"` + RedisPort string `koanf:"redis_port" env:"REDIS_PORT" envDefault:"6379"` + RedisUsername string `koanf:"redis_username" env:"REDIS_USERNAME"` + RedisPassword string `koanf:"redis_password" env:"REDIS_PASSWORD"` + RedisDBStartIndex int `koanf:"redis_db_start_index" env:"REDIS_DB_START_INDEX" envDefault:"0"` + RedisReadTimeout time.Duration `koanf:"redis_read_time_out" env:"REDIS_READ_TIME_OUT" envDefault:"1s"` + RedisWriteTimeout time.Duration `koanf:"redis_write_time_out" env:"REDIS_WRITE_TIME_OUT" envDefault:"500ms"` + PostgresqlReadTimeout time.Duration `koanf:"postgresql_read_time_out" env:"POSTGRESQL_READ_TIME_OUT" envDefault:"10m"` + PostgresqlWriteTimeout time.Duration `koanf:"postgresql_write_time_out" env:"POSTGRESQL_WRITE_TIME_OUT" envDefault:"5s"` + InMemoryDbUpdateInterval time.Duration `koanf:"in_memory_db_update_interval" env:"IN_MEMORY_DB_UPDATE_INTERVAL" envDefault:"5s"` + TelemetryProvider string `koanf:"telemetry_provider" env:"TELEMETRY_PROVIDER" envDefault:"statsd"` + StatsEnabled bool `koanf:"stats_enabled" env:"STATS_ENABLED" envDefault:"true"` + StatsAddress string `koanf:"stats_address" env:"STATS_ADDRESS" envDefault:"127.0.0.1:8125"` + PrometheusEnabled bool `koanf:"prometheus_enabled" env:"PROMETHEUS_ENABLED" envDefault:"true"` + PrometheusPort string `koanf:"prometheus_port" env:"PROMETHEUS_PORT" envDefault:"2112"` + AdminPass string `koanf:"admin_pass" env:"ADMIN_PASS"` + ProxyTimeout time.Duration `koanf:"proxy_timeout" env:"PROXY_TIMEOUT" envDefault:"600s"` + NumberOfEventMessageConsumers int 
`koanf:"number_of_event_message_consumers" env:"NUMBER_OF_EVENT_MESSAGE_CONSUMERS" envDefault:"3"` + OpenAiApiKey string `koanf:"openai_api_key" env:"OPENAI_API_KEY"` + OpenAiUrls map[string]string `koanf:"openai_api_urls" env:"OPENAI_API_URLS" envKeyValSeparator:"=" envSeparator:"," envDefault:"openai=https://api.openai.com,openai-eu=https://eu.api.openai.com"` + CustomPolicyDetectionTimeout time.Duration `koanf:"custom_policy_detection_timeout" env:"CUSTOM_POLICY_DETECTION_TIMEOUT" envDefault:"10m"` + AmazonRegion string `koanf:"amazon_region" env:"AMAZON_REGION" envDefault:"us-west-2"` + AmazonRequestTimeout time.Duration `koanf:"amazon_request_timeout" env:"AMAZON_REQUEST_TIMEOUT" envDefault:"5s"` + AmazonConnectionTimeout time.Duration `koanf:"amazon_connection_timeout" env:"AMAZON_CONNECTION_TIMEOUT" envDefault:"10s"` + RemoveUserAgent bool `koanf:"remove_user_agent" env:"REMOVE_USER_AGENT" envDefault:"false"` + EnableEncrytion bool `koanf:"enable_encryption" env:"ENABLE_ENCRYPTION" envDefault:"false"` + EncryptionEndpoint string `koanf:"encryption_endpoint" env:"ENCRYPTION_ENDPOINT"` + DecryptionEndpoint string `koanf:"decryption_endpoint" env:"DECRYPTION_ENDPOINT"` + EncryptionTimeout time.Duration `koanf:"encryption_timeout" env:"ENCRYPTION_TIMEOUT" envDefault:"5s"` + Audience string `koanf:"audience" env:"AUDIENCE"` } func prepareDotEnv(envFilePath string) error { diff --git a/internal/message/consumer.go b/internal/message/consumer.go index 115f590..c3a7581 100644 --- a/internal/message/consumer.go +++ b/internal/message/consumer.go @@ -11,7 +11,8 @@ type Consumer struct { done chan bool log *zap.Logger numOfEventConsumers int - handle func(Message) error + handle func(Message, map[string]string) error + openAiUrls map[string]string } type recorder interface { @@ -20,13 +21,14 @@ type recorder interface { RecordEvent(e *event.Event) error } -func NewConsumer(mc <-chan Message, log *zap.Logger, num int, handle func(Message) error) *Consumer { +func NewConsumer(mc <-chan Message, log *zap.Logger, num int, handle func(Message, map[string]string) error, openAiUrls map[string]string) *Consumer { return &Consumer{ messageChan: mc, done: make(chan bool), log: log, numOfEventConsumers: num, handle: handle, + openAiUrls: openAiUrls, } } @@ -40,7 +42,7 @@ func (c *Consumer) StartEventMessageConsumers() { return case m := <-c.messageChan: - err := c.handle(m) + err := c.handle(m, c.openAiUrls) if err != nil { continue } diff --git a/internal/message/handler.go b/internal/message/handler.go index e216820..fbdee51 100644 --- a/internal/message/handler.go +++ b/internal/message/handler.go @@ -2,6 +2,7 @@ package message import ( "errors" + "fmt" "net/http" "strings" "time" @@ -313,7 +314,7 @@ func (h *Handler) handleUserValidationResult(u *user.User, cost float64) error { return nil } -func (h *Handler) HandleEventWithRequestAndResponse(m Message) error { +func (h *Handler) HandleEventWithRequestAndResponse(m Message, openAiUrls map[string]string) error { e, ok := m.Data.(*event.EventWithRequestAndContent) if !ok { telemetry.Incr("bricksllm.message.handler.handle_event_with_request_and_response.message_data_parsing_error", nil, 1) @@ -322,7 +323,7 @@ func (h *Handler) HandleEventWithRequestAndResponse(m Message) error { } if e.Key != nil && !e.Key.Revoked && e.Event != nil { - err := h.decorateEvent(m) + err := h.decorateEvent(m, openAiUrls) if err != nil { telemetry.Incr("bricksllm.message.handler.handle_event_with_request_and_response.decorate_event_error", nil, 1) h.log.Debug("error when 
decorating event", zap.Error(err)) @@ -403,7 +404,7 @@ func (h *Handler) HandleEventWithRequestAndResponse(m Message) error { return nil } -func (h *Handler) decorateEvent(m Message) error { +func (h *Handler) decorateEvent(m Message, openAiUrls map[string]string) error { telemetry.Incr("bricksllm.message.handler.decorate_event.request", nil, 1) e, ok := m.Data.(*event.EventWithRequestAndContent) @@ -413,23 +414,25 @@ func (h *Handler) decorateEvent(m Message) error { return errors.New("message data cannot be parsed as event with request and response") } - if e.Event.Path == "/api/providers/openai/v1/audio/speech" { - csr, ok := e.Request.(*goopenai.CreateSpeechRequest) - if !ok { - telemetry.Incr("bricksllm.message.handler.decorate_event.event_request_parsing_error", nil, 1) - h.log.Debug("event contains request that cannot be converted to anthropic completion request", zap.Any("data", m.Data)) - return errors.New("event request data cannot be parsed as anthropic completion request") - } - - if e.Event.Status == http.StatusOK { - cost, err := h.e.EstimateSpeechCost(csr.Input, string(csr.Model)) - if err != nil { - telemetry.Incr("bricksllm.message.handler.decorate_event.estimate_prompt_cost", nil, 1) - h.log.Debug("event contains request that cannot be converted to anthropic completion request", zap.Error(err)) - return err + for openAiName := range openAiUrls { + if e.Event.Path == fmt.Sprintf("/api/providers/%s/v1/audio/speech", openAiName) { + csr, ok := e.Request.(*goopenai.CreateSpeechRequest) + if !ok { + telemetry.Incr("bricksllm.message.handler.decorate_event.event_request_parsing_error", nil, 1) + h.log.Debug("event contains request that cannot be converted to anthropic completion request", zap.Any("data", m.Data)) + return errors.New("event request data cannot be parsed as anthropic completion request") } - e.Event.CostInUsd = cost + if e.Event.Status == http.StatusOK { + cost, err := h.e.EstimateSpeechCost(csr.Input, string(csr.Model)) + if err != nil { + telemetry.Incr("bricksllm.message.handler.decorate_event.estimate_prompt_cost", nil, 1) + h.log.Debug("event contains request that cannot be converted to anthropic completion request", zap.Error(err)) + return err + } + + e.Event.CostInUsd = cost + } } } @@ -607,42 +610,44 @@ func (h *Handler) decorateEvent(m Message) error { } } - if e.Event.Path == "/api/providers/openai/v1/chat/completions" { - ccr, ok := e.Request.(*goopenai.ChatCompletionRequest) - if !ok { - telemetry.Incr("bricksllm.message.handler.decorate_event.event_request_parsing_error", nil, 1) - h.log.Debug("event contains data that cannot be converted to openai completion request", zap.Any("data", m.Data)) - return errors.New("event request data cannot be parsed as openai completion request") - } - - if ccr.Stream { - tks, cost, err := h.e.EstimateChatCompletionPromptCostWithTokenCounts(ccr) - if err != nil { - telemetry.Incr("bricksllm.message.handler.decorate_event.estimate_chat_completion_prompt_cost_with_token_counts", nil, 1) - return err + for openAiName := range openAiUrls { + if e.Event.Path == fmt.Sprintf("/api/providers/%s/v1/chat/completions", openAiName) { + ccr, ok := e.Request.(*goopenai.ChatCompletionRequest) + if !ok { + telemetry.Incr("bricksllm.message.handler.decorate_event.event_request_parsing_error", nil, 1) + h.log.Debug("event contains data that cannot be converted to openai completion request", zap.Any("data", m.Data)) + return errors.New("event request data cannot be parsed as openai completion request") } - completiontks, 
completionCost, err := h.e.EstimateChatCompletionStreamCostWithTokenCounts(e.Event.Model, e.Content) - if err != nil { - telemetry.Incr("bricksllm.message.handler.decorate_event.estimate_chat_completion_stream_cost_with_token_counts", nil, 1) - return err - } + if ccr.Stream { + tks, cost, err := h.e.EstimateChatCompletionPromptCostWithTokenCounts(ccr) + if err != nil { + telemetry.Incr("bricksllm.message.handler.decorate_event.estimate_chat_completion_prompt_cost_with_token_counts", nil, 1) + return err + } - e.Event.PromptTokenCount = tks - e.Event.CompletionTokenCount = completiontks + completiontks, completionCost, err := h.e.EstimateChatCompletionStreamCostWithTokenCounts(e.Event.Model, e.Content) + if err != nil { + telemetry.Incr("bricksllm.message.handler.decorate_event.estimate_chat_completion_stream_cost_with_token_counts", nil, 1) + return err + } - if e.Event.Status == http.StatusOK { - e.Event.CostInUsd = cost + completionCost + e.Event.PromptTokenCount = tks + e.Event.CompletionTokenCount = completiontks - if e.CostMap != nil { - newCost, err := provider.EstimateTotalCostWithCostMaps(e.Event.Model, tks, completiontks, 1000, e.CostMap.PromptCostPerModel, e.CostMap.CompletionCostPerModel) - if err != nil { - h.log.Debug("error when estimating total cost with cost maps", zap.Error(err)) - telemetry.Incr("bricksllm.proxy.decorate_event.estimate_total_cost_with_cost_maps_error", nil, 1) - } + if e.Event.Status == http.StatusOK { + e.Event.CostInUsd = cost + completionCost - if newCost != 0 { - e.Event.CostInUsd = newCost + if e.CostMap != nil { + newCost, err := provider.EstimateTotalCostWithCostMaps(e.Event.Model, tks, completiontks, 1000, e.CostMap.PromptCostPerModel, e.CostMap.CompletionCostPerModel) + if err != nil { + h.log.Debug("error when estimating total cost with cost maps", zap.Error(err)) + telemetry.Incr("bricksllm.proxy.decorate_event.estimate_total_cost_with_cost_maps_error", nil, 1) + } + + if newCost != 0 { + e.Event.CostInUsd = newCost + } } } } diff --git a/internal/route/route.go b/internal/route/route.go index b4b3ebf..27a6279 100644 --- a/internal/route/route.go +++ b/internal/route/route.go @@ -243,7 +243,7 @@ func InitializeBackoff(strategy string, dur time.Duration) backoff.BackOff { return backoff.NewConstantBackOff(dur) } -func (r *Route) RunStepsV2(req *Request, rec recorder, log *zap.Logger, kc *key.ResponseKey) (*Response, error) { +func (r *Route) RunStepsV2(req *Request, rec recorder, log *zap.Logger, kc *key.ResponseKey, openAiUrls map[string]string) (*Response, error) { if len(r.Steps) == 0 { return nil, errors.New("steps are empty") } @@ -320,7 +320,7 @@ func (r *Route) RunStepsV2(req *Request, rec recorder, log *zap.Logger, kc *key. } }() - hreq, err := req.createHttpRequest(ctx, step.Provider, r.ShouldRunEmbeddings(), step.Params, bs) + hreq, err := req.createHttpRequest(ctx, step.Provider, r.ShouldRunEmbeddings(), step.Params, bs, openAiUrls) if err != nil { return err } @@ -389,7 +389,7 @@ func (r *Route) RunStepsV2(req *Request, rec recorder, log *zap.Logger, kc *key. 
return nil, errors.New("no responses") } -func (r *Route) RunSteps(req *Request, rec recorder, log *zap.Logger) (*Response, error) { +func (r *Route) RunSteps(req *Request, rec recorder, log *zap.Logger, openAiUrls map[string]string) (*Response, error) { if len(r.Steps) == 0 { return nil, errors.New("steps are empty") } @@ -484,7 +484,7 @@ func (r *Route) RunSteps(req *Request, rec recorder, log *zap.Logger) (*Response } } - url := buildRequestUrl(step.Provider, r.ShouldRunEmbeddings(), resourceName, step.Params) + url := buildRequestUrl(step.Provider, r.ShouldRunEmbeddings(), resourceName, step.Params, openAiUrls) if len(url) == 0 { return nil, errors.New("only azure openai, openai chat completion and embeddings models are supported") @@ -648,13 +648,14 @@ type Response struct { Response *http.Response } -func buildRequestUrl(provider string, runEmbeddings bool, resourceName string, params map[string]string) string { - if provider == "openai" && runEmbeddings { - return "https://api.openai.com/v1/embeddings" - } - - if provider == "openai" && !runEmbeddings { - return "https://api.openai.com/v1/chat/completions" +func buildRequestUrl(provider string, runEmbeddings bool, resourceName string, params, openAiUrls map[string]string) string { + // Configure URL based on provider + if openAiUrl, ok := openAiUrls[provider]; ok { + if runEmbeddings { + return openAiUrl + "/v1/embeddings" + } else { + return openAiUrl + "/v1/chat/completions" + } } deploymentId := params["deploymentId"] @@ -680,7 +681,7 @@ func setHttpRequestAuthHeader(provider string, req *http.Request, key string) { req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key)) } -func (r *Request) createHttpRequest(ctx context.Context, provider string, runEmbeddings bool, params map[string]string, data []byte) (*http.Request, error) { +func (r *Request) createHttpRequest(ctx context.Context, provider string, runEmbeddings bool, params map[string]string, data []byte, openAiUrls map[string]string) (*http.Request, error) { resourceName := "" if provider == "azure" { val, err := r.GetSettingValue("azure", "resourceName") @@ -696,7 +697,7 @@ func (r *Request) createHttpRequest(ctx context.Context, provider string, runEmb return nil, err } - url := buildRequestUrl(provider, runEmbeddings, resourceName, params) + url := buildRequestUrl(provider, runEmbeddings, resourceName, params, openAiUrls) if len(url) == 0 { return nil, errors.New("request url is empty") } diff --git a/internal/server/web/proxy/audio.go b/internal/server/web/proxy/audio.go index af6e2a9..559be7f 100644 --- a/internal/server/web/proxy/audio.go +++ b/internal/server/web/proxy/audio.go @@ -20,7 +20,7 @@ import ( "go.uber.org/zap/zapcore" ) -func getSpeechHandler(prod bool, client http.Client) gin.HandlerFunc { +func getSpeechHandler(prod bool, client http.Client, openAiUrl string) gin.HandlerFunc { return func(c *gin.Context) { log := util.GetLogFromCtx(c) telemetry.Incr("bricksllm.proxy.get_speech_handler.requests", nil, 1) @@ -33,7 +33,7 @@ func getSpeechHandler(prod bool, client http.Client) gin.HandlerFunc { ctx, cancel := context.WithTimeout(context.Background(), c.GetDuration("requestTimeout")) defer cancel() - req, err := http.NewRequestWithContext(ctx, c.Request.Method, "https://api.openai.com/v1/audio/speech", c.Request.Body) + req, err := http.NewRequestWithContext(ctx, c.Request.Method, openAiUrl+"/v1/audio/speech", c.Request.Body) if err != nil { logError(log, "error when creating openai http request", prod, err) JSON(c, http.StatusInternalServerError, 
"[BricksLLM] failed to create openai http request") @@ -167,7 +167,7 @@ func getContentType(format string) string { return "text/plain; charset=utf-8" } -func getTranscriptionsHandler(prod bool, client http.Client, e estimator) gin.HandlerFunc { +func getTranscriptionsHandler(prod bool, client http.Client, e estimator, openAiUrl string) gin.HandlerFunc { return func(c *gin.Context) { log := util.GetLogFromCtx(c) telemetry.Incr("bricksllm.proxy.get_transcriptions_handler.requests", nil, 1) @@ -180,7 +180,7 @@ func getTranscriptionsHandler(prod bool, client http.Client, e estimator) gin.Ha ctx, cancel := context.WithTimeout(context.Background(), c.GetDuration("requestTimeout")) defer cancel() - req, err := http.NewRequestWithContext(ctx, c.Request.Method, "https://api.openai.com/v1/audio/transcriptions", c.Request.Body) + req, err := http.NewRequestWithContext(ctx, c.Request.Method, openAiUrl+"/v1/audio/transcriptions", c.Request.Body) if err != nil { logError(log, "error when creating transcriptions openai http request", prod, err) JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to create openai transcriptions http request") @@ -331,7 +331,7 @@ func getTranscriptionsHandler(prod bool, client http.Client, e estimator) gin.Ha } } -func getTranslationsHandler(prod bool, client http.Client, e estimator) gin.HandlerFunc { +func getTranslationsHandler(prod bool, client http.Client, e estimator, openAiUrl string) gin.HandlerFunc { return func(c *gin.Context) { log := util.GetLogFromCtx(c) telemetry.Incr("bricksllm.proxy.get_translations_handler.requests", nil, 1) @@ -344,7 +344,7 @@ func getTranslationsHandler(prod bool, client http.Client, e estimator) gin.Hand ctx, cancel := context.WithTimeout(context.Background(), c.GetDuration("requestTimeout")) defer cancel() - req, err := http.NewRequestWithContext(ctx, c.Request.Method, "https://api.openai.com/v1/audio/translations", c.Request.Body) + req, err := http.NewRequestWithContext(ctx, c.Request.Method, openAiUrl+"/v1/audio/translations", c.Request.Body) if err != nil { logError(log, "error when creating translations openai http request", prod, err) JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to create openai translations http request") diff --git a/internal/server/web/proxy/chat_completion.go b/internal/server/web/proxy/chat_completion.go index 19c6245..2f77ec9 100644 --- a/internal/server/web/proxy/chat_completion.go +++ b/internal/server/web/proxy/chat_completion.go @@ -17,7 +17,7 @@ import ( goopenai "github.com/sashabaranov/go-openai" ) -func getChatCompletionHandler(prod, private bool, client http.Client, e estimator) gin.HandlerFunc { +func getChatCompletionHandler(prod, private bool, client http.Client, e estimator, openAiUrl string) gin.HandlerFunc { return func(c *gin.Context) { log := util.GetLogFromCtx(c) telemetry.Incr("bricksllm.proxy.get_chat_completion_handler.requests", nil, 1) @@ -30,7 +30,7 @@ func getChatCompletionHandler(prod, private bool, client http.Client, e estimato ctx, cancel := context.WithTimeout(context.Background(), c.GetDuration("requestTimeout")) defer cancel() - req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://api.openai.com/v1/chat/completions", c.Request.Body) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, openAiUrl+"/v1/chat/completions", c.Request.Body) if err != nil { logError(log, "error when creating openai http request", prod, err) JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to create azure openai http request") diff --git 
a/internal/server/web/proxy/embedding.go b/internal/server/web/proxy/embedding.go index 1d689e3..dade156 100644 --- a/internal/server/web/proxy/embedding.go +++ b/internal/server/web/proxy/embedding.go @@ -30,7 +30,7 @@ type EmbeddingResponseBase64 struct { Usage goopenai.Usage `json:"usage"` } -func getEmbeddingHandler(prod, private bool, client http.Client, e estimator) gin.HandlerFunc { +func getEmbeddingHandler(prod, private bool, client http.Client, e estimator, openAiUrl string) gin.HandlerFunc { return func(c *gin.Context) { log := util.GetLogFromCtx(c) telemetry.Incr("bricksllm.proxy.get_embedding_handler.requests", nil, 1) @@ -50,7 +50,7 @@ func getEmbeddingHandler(prod, private bool, client http.Client, e estimator) gi ctx, cancel := context.WithTimeout(context.Background(), c.GetDuration("requestTimeout")) defer cancel() - req, err := http.NewRequestWithContext(ctx, c.Request.Method, "https://api.openai.com/v1/embeddings", c.Request.Body) + req, err := http.NewRequestWithContext(ctx, c.Request.Method, openAiUrl+"/v1/embeddings", c.Request.Body) if err != nil { logError(log, "error when creating openai http request", prod, err) JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to create openai http request") diff --git a/internal/server/web/proxy/middleware.go b/internal/server/web/proxy/middleware.go index 9ce91ab..db8f1d2 100644 --- a/internal/server/web/proxy/middleware.go +++ b/internal/server/web/proxy/middleware.go @@ -67,7 +67,7 @@ type deepinfraEstimator interface { } type authenticator interface { - AuthenticateHttpRequest(req *http.Request) (*key.ResponseKey, []*provider.Setting, error) + AuthenticateHttpRequest(req *http.Request, openAiUrls map[string]string) (*key.ResponseKey, []*provider.Setting, error) } type validator interface { @@ -169,7 +169,7 @@ type CustomPolicyDetector interface { Detect(input []string, requirements []string) (bool, error) } -func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManager, a authenticator, prod, private bool, log *zap.Logger, pub publisher, prefix string, ac accessCache, uac userAccessCache, client http.Client, scanner Scanner, cd CustomPolicyDetector, um userManager, removeUserAgent bool) gin.HandlerFunc { +func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManager, a authenticator, prod, private bool, log *zap.Logger, pub publisher, prefix string, ac accessCache, uac userAccessCache, client http.Client, scanner Scanner, cd CustomPolicyDetector, um userManager, removeUserAgent bool, openAiUrls map[string]string) gin.HandlerFunc { return func(c *gin.Context) { if c == nil || c.Request == nil { JSON(c, http.StatusInternalServerError, "[BricksLLM] request is empty") @@ -304,7 +304,7 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag return } - kc, settings, err := a.AuthenticateHttpRequest(c.Request) + kc, settings, err := a.AuthenticateHttpRequest(c.Request, openAiUrls) enrichedEvent.Key = kc _, ok := err.(notAuthorizedError) if ok { @@ -762,389 +762,392 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag policyInput = er } - if c.FullPath() == "/api/providers/openai/v1/chat/completions" { - ccr := &goopenai.ChatCompletionRequest{} - // this is a hack around an open issue in go-openai. 
- // https://github.com/sashabaranov/go-openai/issues/884 - cleaned, err := sjson.Delete(string(body), "response_format.json_schema") - if err != nil { - logWithCid.Warn("removing response_format.json_schema", zap.Error(err)) - } - err = json.Unmarshal([]byte(cleaned), ccr) - if err != nil { - logError(logWithCid, "error when unmarshalling chat completion request", prod, err) - return - } - - userId = ccr.User + for openAiName := range openAiUrls { + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/chat/completions", openAiName) { + ccr := &goopenai.ChatCompletionRequest{} + // this is a hack around an open issue in go-openai. + // https://github.com/sashabaranov/go-openai/issues/884 + cleaned, err := sjson.Delete(string(body), "response_format.json_schema") + if err != nil { + logWithCid.Warn("removing response_format.json_schema", zap.Error(err)) + } + err = json.Unmarshal([]byte(cleaned), ccr) + if err != nil { + logError(logWithCid, "error when unmarshalling chat completion request", prod, err) + return + } - enrichedEvent.Request = ccr + userId = ccr.User - c.Set("model", ccr.Model) + enrichedEvent.Request = ccr - logRequest(logWithCid, prod, private, ccr) + c.Set("model", ccr.Model) - if ccr.Stream { - c.Set("stream", true) - } + logRequest(logWithCid, prod, private, ccr) - policyInput = ccr - } + if ccr.Stream { + c.Set("stream", true) + } - if c.FullPath() == "/api/providers/openai/v1/embeddings" { - er := &goopenai.EmbeddingRequest{} - err = json.Unmarshal(body, er) - if err != nil { - logError(logWithCid, "error when unmarshalling embedding request", prod, err) - return + policyInput = ccr } - userId = er.User + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/embeddings", openAiName) { + er := &goopenai.EmbeddingRequest{} + err = json.Unmarshal(body, er) + if err != nil { + logError(logWithCid, "error when unmarshalling embedding request", prod, err) + return + } - c.Set("model", string(er.Model)) - c.Set("encoding_format", string(er.EncodingFormat)) + userId = er.User - logEmbeddingRequest(logWithCid, prod, private, er) + c.Set("model", string(er.Model)) + c.Set("encoding_format", string(er.EncodingFormat)) - policyInput = er - } + logEmbeddingRequest(logWithCid, prod, private, er) - if c.FullPath() == "/api/providers/openai/v1/images/generations" && c.Request.Method == http.MethodPost { - ir := &goopenai.ImageRequest{} - err := json.Unmarshal(body, ir) - if err != nil { - logError(logWithCid, "error when unmarshalling create image request", prod, err) - return + policyInput = er } - c.Set("model", ir.Model) + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/images/generations", openAiName) && c.Request.Method == http.MethodPost { + ir := &goopenai.ImageRequest{} + err := json.Unmarshal(body, ir) + if err != nil { + logError(logWithCid, "error when unmarshalling create image request", prod, err) + return + } + + c.Set("model", ir.Model) + + if len(ir.Model) == 0 { + c.Set("model", "dall-e-2") + } - if len(ir.Model) == 0 { - c.Set("model", "dall-e-2") + logCreateImageRequest(logWithCid, ir, prod, private) } - c.Set("model", ir.Model) - logCreateImageRequest(logWithCid, ir, prod, private) - } + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/images/edits", openAiName) && c.Request.Method == http.MethodPost { + prompt := c.PostForm("prompt") + model := c.PostForm("model") + size := c.PostForm("size") + user := c.PostForm("user") - if c.FullPath() == "/api/providers/openai/v1/images/edits" && c.Request.Method == http.MethodPost { - prompt
:= c.PostForm("model") - model := c.PostForm("model") - size := c.PostForm("size") - user := c.PostForm("user") + userId = user - userId = user + responseFormat := c.PostForm("response_format") + n, _ := strconv.Atoi(c.PostForm("n")) - responseFormat := c.PostForm("response_format") - n, _ := strconv.Atoi(c.PostForm("n")) + c.Set("model", model) - c.Set("model", model) + if len(model) == 0 { + c.Set("model", "dall-e-2") + } - if len(model) == 0 { - c.Set("model", "dall-e-2") + logEditImageRequest(logWithCid, prompt, model, n, size, responseFormat, user, prod, private) } - logEditImageRequest(logWithCid, prompt, model, n, size, responseFormat, user, prod, private) - } + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/images/variations", openAiName) && c.Request.Method == http.MethodPost { + model := c.PostForm("model") + size := c.PostForm("size") + user := c.PostForm("user") - if c.FullPath() == "/api/providers/openai/v1/images/variations" && c.Request.Method == http.MethodPost { - model := c.PostForm("model") - size := c.PostForm("size") - user := c.PostForm("user") + userId = user - userId = user + responseFormat := c.PostForm("response_format") + n, _ := strconv.Atoi(c.PostForm("n")) - responseFormat := c.PostForm("response_format") - n, _ := strconv.Atoi(c.PostForm("n")) + c.Set("model", model) - c.Set("model", model) + if len(model) == 0 { + c.Set("model", "dall-e-2") + } - if len(model) == 0 { - c.Set("model", "dall-e-2") + logImageVariationsRequest(logWithCid, model, n, size, responseFormat, user, prod) } - logImageVariationsRequest(logWithCid, model, n, size, responseFormat, user, prod) - } + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/audio/speech", openAiName) && c.Request.Method == http.MethodPost { + sr := &goopenai.CreateSpeechRequest{} + err := json.Unmarshal(body, sr) + if err != nil { + logError(logWithCid, "error when unmarshalling create speech request", prod, err) + return + } - if c.FullPath() == "/api/providers/openai/v1/audio/speech" && c.Request.Method == http.MethodPost { - sr := &goopenai.CreateSpeechRequest{} - err := json.Unmarshal(body, sr) - if err != nil { - logError(logWithCid, "error when unmarshalling create speech request", prod, err) - return + enrichedEvent.Request = sr + + c.Set("model", string(sr.Model)) + + logCreateSpeechRequest(logWithCid, sr, prod, private) } - enrichedEvent.Request = sr + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/audio/transcriptions", openAiName) && c.Request.Method == http.MethodPost { + model := c.PostForm("model") + language := c.PostForm("language") + prompt := c.PostForm("prompt") + responseFormat := c.PostForm("response_format") + temperature := c.PostForm("temperature") - c.Set("model", string(sr.Model)) + c.Set("model", model) - logCreateSpeechRequest(logWithCid, sr, prod, private) - } + converted, _ := strconv.ParseFloat(temperature, 64) + logCreateTranscriptionRequest(logWithCid, model, language, prompt, responseFormat, converted, prod, private) + } - if c.FullPath() == "/api/providers/openai/v1/audio/transcriptions" && c.Request.Method == http.MethodPost { - model := c.PostForm("model") - language := c.PostForm("language") - prompt := c.PostForm("prompt") - responseFormat := c.PostForm("response_format") - temperature := c.PostForm("temperature") + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/audio/translations", openAiName) && c.Request.Method == http.MethodPost { + model := c.PostForm("model") + prompt := c.PostForm("prompt") + responseFormat := c.PostForm("response_format") + 
temperature := c.PostForm("temperature") - c.Set("model", model) + c.Set("model", model) - converted, _ := strconv.ParseFloat(temperature, 64) - logCreateTranscriptionRequest(logWithCid, model, language, prompt, responseFormat, converted, prod, private) - } + converted, _ := strconv.ParseFloat(temperature, 64) + logCreateTranslationRequest(logWithCid, model, prompt, responseFormat, converted, prod, private) + } - if c.FullPath() == "/api/providers/openai/v1/audio/translations" && c.Request.Method == http.MethodPost { - model := c.PostForm("model") - prompt := c.PostForm("prompt") - responseFormat := c.PostForm("response_format") - temperature := c.PostForm("temperature") + if len(kc.AllowedPaths) != 0 && !containsPath(kc.AllowedPaths, c.FullPath(), c.Request.Method) { + telemetry.Incr("bricksllm.proxy.get_middleware.path_not_allowed", nil, 1) + JSON(c, http.StatusForbidden, "[BricksLLM] path is not allowed") + c.Abort() + return + } - c.Set("model", model) + model := c.GetString("model") + if !isModelAllowed(model, settings) { + telemetry.Incr("bricksllm.proxy.get_middleware.model_not_allowed", nil, 1) + JSON(c, http.StatusForbidden, "[BricksLLM] model is not allowed") + c.Abort() + return + } - converted, _ := strconv.ParseFloat(temperature, 64) - logCreateTranslationRequest(logWithCid, model, prompt, responseFormat, converted, prod, private) - } + aid := c.Param("assistant_id") + fid := c.Param("file_id") + tid := c.Param("thread_id") + mid := c.Param("message_id") + rid := c.Param("run_id") + sid := c.Param("step_id") + md := c.Param("model") + qm := map[string]string{} - if len(kc.AllowedPaths) != 0 && !containsPath(kc.AllowedPaths, c.FullPath(), c.Request.Method) { - telemetry.Incr("bricksllm.proxy.get_middleware.path_not_allowed", nil, 1) - JSON(c, http.StatusForbidden, "[BricksLLM] path is not allowed") - c.Abort() - return - } + if val, ok := c.GetQuery("limit"); ok { + qm["limit"] = val + } - model := c.GetString("model") - if !isModelAllowed(model, settings) { - telemetry.Incr("bricksllm.proxy.get_middleware.model_not_allowed", nil, 1) - JSON(c, http.StatusForbidden, "[BricksLLM] model is not allowed") - c.Abort() - return - } + if val, ok := c.GetQuery("order"); ok { + qm["order"] = val + } - aid := c.Param("assistant_id") - fid := c.Param("file_id") - tid := c.Param("thread_id") - mid := c.Param("message_id") - rid := c.Param("run_id") - sid := c.Param("step_id") - md := c.Param("model") - qm := map[string]string{} - - if val, ok := c.GetQuery("limit"); ok { - qm["limit"] = val - } + if val, ok := c.GetQuery("after"); ok { + qm["after"] = val + } - if val, ok := c.GetQuery("order"); ok { - qm["order"] = val - } + if val, ok := c.GetQuery("before"); ok { + qm["before"] = val + } - if val, ok := c.GetQuery("after"); ok { - qm["after"] = val - } + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/assistants", openAiName) && c.Request.Method == http.MethodPost { + logCreateAssistantRequest(logWithCid, body, prod, private) - if val, ok := c.GetQuery("before"); ok { - qm["before"] = val - } + ar := &goopenai.AssistantRequest{} - if c.FullPath() == "/api/providers/openai/v1/assistants" && c.Request.Method == http.MethodPost { - logCreateAssistantRequest(logWithCid, body, prod, private) + err = json.Unmarshal(body, ar) + if err != nil { + logError(logWithCid, "error when unmarshalling assistant request", prod, err) + } - ar := &goopenai.AssistantRequest{} + if err == nil { + c.Set("model", ar.Model) - err = json.Unmarshal(body, ar) - if err != nil { - logError(logWithCid, "error 
when unmarshalling assistant request", prod, err) + policyInput = ar + } } - if err == nil { - c.Set("model", ar.Model) - - policyInput = ar + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/assistants/:assistant_id", openAiName) && c.Request.Method == http.MethodGet { + logRetrieveAssistantRequest(logWithCid, prod, aid) } - } - if c.FullPath() == "/api/providers/openai/v1/assistants/:assistant_id" && c.Request.Method == http.MethodGet { - logRetrieveAssistantRequest(logWithCid, prod, aid) - } + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/assistants/:assistant_id", openAiName) && c.Request.Method == http.MethodPost { + logModifyAssistantRequest(logWithCid, body, prod, private, aid) + } - if c.FullPath() == "/api/providers/openai/v1/assistants/:assistant_id" && c.Request.Method == http.MethodPost { - logModifyAssistantRequest(logWithCid, body, prod, private, aid) - } + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/assistants/:assistant_id", openAiName) && c.Request.Method == http.MethodDelete { + logDeleteAssistantRequest(logWithCid, prod, aid) + } - if c.FullPath() == "/api/providers/openai/v1/assistants/:assistant_id" && c.Request.Method == http.MethodDelete { - logDeleteAssistantRequest(logWithCid, prod, aid) - } + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/assistants", openAiName) && c.Request.Method == http.MethodGet { + logListAssistantsRequest(logWithCid, prod) + } - if c.FullPath() == "/api/providers/openai/v1/assistants" && c.Request.Method == http.MethodGet { - logListAssistantsRequest(logWithCid, prod) - } + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/assistants/:assistant_id/files", openAiName) && c.Request.Method == http.MethodPost { + logCreateAssistantFileRequest(logWithCid, body, prod, aid) + } - if c.FullPath() == "/api/providers/openai/v1/assistants/:assistant_id/files" && c.Request.Method == http.MethodPost { - logCreateAssistantFileRequest(logWithCid, body, prod, aid) - } + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/assistants/:assistant_id/files/:file_id", openAiName) && c.Request.Method == http.MethodGet { + logRetrieveAssistantFileRequest(logWithCid, prod, fid, aid) + } - if c.FullPath() == "/api/providers/openai/v1/assistants/:assistant_id/files/:file_id" && c.Request.Method == http.MethodGet { - logRetrieveAssistantFileRequest(logWithCid, prod, fid, aid) - } + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/assistants/:assistant_id/files/:file_id", openAiName) && c.Request.Method == http.MethodDelete { + logDeleteAssistantFileRequest(logWithCid, prod, fid, aid) + } - if c.FullPath() == "/api/providers/openai/v1/assistants/:assistant_id/files/:file_id" && c.Request.Method == http.MethodDelete { - logDeleteAssistantFileRequest(logWithCid, prod, fid, aid) - } + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/assistants/:assistant_id/files", openAiName) && c.Request.Method == http.MethodGet { + logListAssistantFilesRequest(logWithCid, prod, aid, qm) + } - if c.FullPath() == "/api/providers/openai/v1/assistants/:assistant_id/files" && c.Request.Method == http.MethodGet { - logListAssistantFilesRequest(logWithCid, prod, aid, qm) - } + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads", openAiName) && c.Request.Method == http.MethodPost { + logCreateThreadRequest(logWithCid, body, prod, private) - if c.FullPath() == "/api/providers/openai/v1/threads" && c.Request.Method == http.MethodPost { - logCreateThreadRequest(logWithCid, body, prod, private) + tr := &openai.ThreadRequest{} - tr := 
&openai.ThreadRequest{} + err = json.Unmarshal(body, tr) + if err != nil { + logError(logWithCid, "error when unmarshalling create thread request", prod, err) + } - err = json.Unmarshal(body, tr) - if err != nil { - logError(logWithCid, "error when unmarshalling create thread request", prod, err) + if err == nil { + policyInput = tr + } } - if err == nil { - policyInput = tr + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id", openAiName) && c.Request.Method == http.MethodGet { + logRetrieveThreadRequest(logWithCid, prod, tid) } - } - if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id" && c.Request.Method == http.MethodGet { - logRetrieveThreadRequest(logWithCid, prod, tid) - } + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id", openAiName) && c.Request.Method == http.MethodPost { + logModifyThreadRequest(logWithCid, body, prod, tid) + } - if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id" && c.Request.Method == http.MethodPost { - logModifyThreadRequest(logWithCid, body, prod, tid) - } + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id", openAiName) && c.Request.Method == http.MethodDelete { + logDeleteThreadRequest(logWithCid, prod, tid) + } - if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id" && c.Request.Method == http.MethodDelete { - logDeleteThreadRequest(logWithCid, prod, tid) - } + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/messages", openAiName) && c.Request.Method == http.MethodPost { + logCreateMessageRequest(logWithCid, body, prod, private) - if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/messages" && c.Request.Method == http.MethodPost { - logCreateMessageRequest(logWithCid, body, prod, private) + mr := &openai.MessageRequest{} + err := json.Unmarshal(body, mr) + if err != nil { + logError(logWithCid, "error when unmarshalling create message request", prod, err) + } - mr := &openai.MessageRequest{} - err := json.Unmarshal(body, mr) - if err != nil { - logError(logWithCid, "error when unmarshalling create message request", prod, err) + if err == nil { + policyInput = mr + } } - if err == nil { - policyInput = mr + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/messages/:message_id", openAiName) && c.Request.Method == http.MethodGet { + logRetrieveMessageRequest(logWithCid, prod, mid, tid) } - } - if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/messages/:message_id" && c.Request.Method == http.MethodGet { - logRetrieveMessageRequest(logWithCid, prod, mid, tid) - } + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/messages/:message_id", openAiName) && c.Request.Method == http.MethodPost { + logModifyMessageRequest(logWithCid, body, prod, private, tid, mid) + } - if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/messages/:message_id" && c.Request.Method == http.MethodPost { - logModifyMessageRequest(logWithCid, body, prod, private, tid, mid) - } + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/messages", openAiName) && c.Request.Method == http.MethodGet { + logListMessagesRequest(logWithCid, prod, aid) + } - if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/messages" && c.Request.Method == http.MethodGet { - logListMessagesRequest(logWithCid, prod, aid) - } + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/messages/:message_id/files/:file_id", openAiName) && c.Request.Method == 
http.MethodGet { + logRetrieveMessageFileRequest(logWithCid, prod, mid, tid, fid) + } - if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/messages/:message_id/files/:file_id" && c.Request.Method == http.MethodGet { - logRetrieveMessageFileRequest(logWithCid, prod, mid, tid, fid) - } + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/messages/:message_id/files", openAiName) && c.Request.Method == http.MethodGet { + logListMessageFilesRequest(logWithCid, prod, tid, mid, qm) + } - if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/messages/:message_id/files" && c.Request.Method == http.MethodGet { - logListMessageFilesRequest(logWithCid, prod, tid, mid, qm) - } + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/runs", openAiName) && c.Request.Method == http.MethodPost { + logCreateRunRequest(logWithCid, body, prod, private) - if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs" && c.Request.Method == http.MethodPost { - logCreateRunRequest(logWithCid, body, prod, private) + rr := &goopenai.RunRequest{} + err := json.Unmarshal(body, rr) + if err != nil { + logError(logWithCid, "error when unmarshalling create run request", prod, err) + } - rr := &goopenai.RunRequest{} - err := json.Unmarshal(body, rr) - if err != nil { - logError(logWithCid, "error when unmarshalling create run request", prod, err) + if err == nil { + c.Set("model", rr.Model) + policyInput = rr + } } - if err == nil { - c.Set("model", rr.Model) - policyInput = rr + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/runs/:run_id", openAiName) && c.Request.Method == http.MethodGet { + logRetrieveRunRequest(logWithCid, prod, tid, rid) } - } - if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs/:run_id" && c.Request.Method == http.MethodGet { - logRetrieveRunRequest(logWithCid, prod, tid, rid) - } + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/runs/:run_id", openAiName) && c.Request.Method == http.MethodPost { + logModifyRunRequest(logWithCid, body, prod, tid, rid) + } - if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs/:run_id" && c.Request.Method == http.MethodPost { - logModifyRunRequest(logWithCid, body, prod, tid, rid) - } + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/runs", openAiName) && c.Request.Method == http.MethodGet { + logListRunsRequest(logWithCid, prod, tid, qm) + } - if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs" && c.Request.Method == http.MethodGet { - logListRunsRequest(logWithCid, prod, tid, qm) - } + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/runs/:run_id/submit_tool_outputs", openAiName) && c.Request.Method == http.MethodPost { + logSubmitToolOutputsRequest(logWithCid, body, prod, tid, rid) + } - if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs/:run_id/submit_tool_outputs" && c.Request.Method == http.MethodPost { - logSubmitToolOutputsRequest(logWithCid, body, prod, tid, rid) - } + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/runs/:run_id/cancel", openAiName) && c.Request.Method == http.MethodPost { + logCancelARunRequest(logWithCid, prod, tid, rid) + } - if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs/:run_id/cancel" && c.Request.Method == http.MethodPost { - logCancelARunRequest(logWithCid, prod, tid, rid) - } + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/runs", openAiName) 
&& c.Request.Method == http.MethodPost { + logCreateThreadAndRunRequest(logWithCid, body, prod, private) - if c.FullPath() == "/api/providers/openai/v1/threads/runs" && c.Request.Method == http.MethodPost { - logCreateThreadAndRunRequest(logWithCid, body, prod, private) + r := &openai.CreateThreadAndRunRequest{} + err := json.Unmarshal(body, r) + if err != nil { + logError(logWithCid, "error when unmarshalling create thread and run request", prod, err) + } - r := &openai.CreateThreadAndRunRequest{} - err := json.Unmarshal(body, r) - if err != nil { - logError(logWithCid, "error when unmarshalling create thread and run request", prod, err) + if err == nil { + c.Set("model", r.Model) + policyInput = r + } } - if err == nil { - c.Set("model", r.Model) - policyInput = r + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/runs/:run_id/steps/:step_id", openAiName) && c.Request.Method == http.MethodGet { + logRetrieveRunStepRequest(logWithCid, prod, tid, rid, sid) } - } - if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs/:run_id/steps/:step_id" && c.Request.Method == http.MethodGet { - logRetrieveRunStepRequest(logWithCid, prod, tid, rid, sid) - } + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/runs/:run_id/steps", openAiName) && c.Request.Method == http.MethodGet { + logListRunStepsRequest(logWithCid, prod, tid, rid, qm) + } - if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs/:run_id/steps" && c.Request.Method == http.MethodGet { - logListRunStepsRequest(logWithCid, prod, tid, rid, qm) - } + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/moderations", openAiName) && c.Request.Method == http.MethodPost { + logCreateModerationRequest(logWithCid, body, prod, private) + } - if c.FullPath() == "/api/providers/openai/v1/moderations" && c.Request.Method == http.MethodPost { - logCreateModerationRequest(logWithCid, body, prod, private) - } + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/models/:model", openAiName) && c.Request.Method == http.MethodGet { + logRetrieveModelRequest(logWithCid, prod, md) + } - if c.FullPath() == "/api/providers/openai/v1/models/:model" && c.Request.Method == http.MethodGet { - logRetrieveModelRequest(logWithCid, prod, md) - } + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/models/:model", openAiName) && c.Request.Method == http.MethodDelete { + logDeleteModelRequest(logWithCid, prod, md) + } - if c.FullPath() == "/api/providers/openai/v1/models/:model" && c.Request.Method == http.MethodDelete { - logDeleteModelRequest(logWithCid, prod, md) - } + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/files", openAiName) && c.Request.Method == http.MethodGet { + logListFilesRequest(logWithCid, prod, qm) + } - if c.FullPath() == "/api/providers/openai/v1/files" && c.Request.Method == http.MethodGet { - logListFilesRequest(logWithCid, prod, qm) - } + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/files", openAiName) && c.Request.Method == http.MethodPost { + purpose := c.PostForm("purpose") + logUploadFileRequest(logWithCid, prod, purpose) + } - if c.FullPath() == "/api/providers/openai/v1/files" && c.Request.Method == http.MethodPost { - purpose := c.PostForm("purpose") - logUploadFileRequest(logWithCid, prod, purpose) - } + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/files/:file_id", openAiName) && c.Request.Method == http.MethodDelete { + logDeleteFileRequest(logWithCid, prod, fid) + } - if c.FullPath() == "/api/providers/openai/v1/files/:file_id" && 
c.Request.Method == http.MethodDelete { - logDeleteFileRequest(logWithCid, prod, fid) - } + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/files/:file_id", openAiName) && c.Request.Method == http.MethodGet { + logRetrieveFileRequest(logWithCid, prod, fid) + } - if c.FullPath() == "/api/providers/openai/v1/files/:file_id" && c.Request.Method == http.MethodGet { - logRetrieveFileRequest(logWithCid, prod, fid) - } + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/files/:file_id/content", openAiName) && c.Request.Method == http.MethodGet { + logRetrieveFileContentRequest(logWithCid, prod, fid) + } - if c.FullPath() == "/api/providers/openai/v1/files/:file_id/content" && c.Request.Method == http.MethodGet { - logRetrieveFileContentRequest(logWithCid, prod, fid) } if ac.GetAccessStatus(kc.KeyId) { diff --git a/internal/server/web/proxy/proxy.go b/internal/server/web/proxy/proxy.go index e6078f4..d7ec0c1 100644 --- a/internal/server/web/proxy/proxy.go +++ b/internal/server/web/proxy/proxy.go @@ -36,8 +36,9 @@ type PoliciesManager interface { } type ProxyServer struct { - server *http.Server - log *zap.Logger + server *http.Server + log *zap.Logger + openAiUrls map[string]string } type recorder interface { @@ -79,14 +80,14 @@ func CorsMiddleware() gin.HandlerFunc { } } -func NewProxyServer(log *zap.Logger, mode, privacyMode string, c cache, m KeyManager, rm routeManager, a authenticator, psm ProviderSettingsManager, cpm CustomProvidersManager, ks keyStorage, e estimator, ae anthropicEstimator, aoe azureEstimator, v validator, r recorder, pub publisher, rlm rateLimitManager, timeout time.Duration, ac accessCache, uac userAccessCache, pm PoliciesManager, scanner Scanner, cd CustomPolicyDetector, die deepinfraEstimator, um userManager, removeAgentHeaders bool) (*ProxyServer, error) { +func NewProxyServer(log *zap.Logger, mode, privacyMode string, c cache, m KeyManager, rm routeManager, a authenticator, psm ProviderSettingsManager, cpm CustomProvidersManager, ks keyStorage, e estimator, ae anthropicEstimator, aoe azureEstimator, v validator, r recorder, pub publisher, rlm rateLimitManager, timeout time.Duration, ac accessCache, uac userAccessCache, pm PoliciesManager, scanner Scanner, cd CustomPolicyDetector, die deepinfraEstimator, um userManager, removeAgentHeaders bool, openAiUrls map[string]string) (*ProxyServer, error) { router := gin.New() prod := mode == "production" private := privacyMode == "strict" router.Use(CorsMiddleware()) router.Use(getTimeoutMiddleware(timeout)) - router.Use(getMiddleware(cpm, rm, pm, a, prod, private, log, pub, "proxy", ac, uac, http.Client{}, scanner, cd, um, removeAgentHeaders)) + router.Use(getMiddleware(cpm, rm, pm, a, prod, private, log, pub, "proxy", ac, uac, http.Client{}, scanner, cd, um, removeAgentHeaders, openAiUrls)) client := http.Client{} @@ -96,82 +97,105 @@ func NewProxyServer(log *zap.Logger, mode, privacyMode string, c cache, m KeyMan // health check router.GET("/api/health", getGetHealthCheckHandler()) - // audios - router.POST("/api/providers/openai/v1/audio/speech", getSpeechHandler(prod, client)) - router.POST("/api/providers/openai/v1/audio/transcriptions", getTranscriptionsHandler(prod, client, e)) - router.POST("/api/providers/openai/v1/audio/translations", getTranslationsHandler(prod, client, e)) - - // completions - router.POST("/api/providers/openai/v1/chat/completions", getChatCompletionHandler(prod, private, client, e)) - - // embeddings - router.POST("/api/providers/openai/v1/embeddings", getEmbeddingHandler(prod, private, 
-
- // moderations
- router.POST("/api/providers/openai/v1/moderations", getPassThroughHandler(prod, private, client))
-
- // models
- router.GET("/api/providers/openai/v1/models", getPassThroughHandler(prod, private, client))
- router.GET("/api/providers/openai/v1/models/:model", getPassThroughHandler(prod, private, client))
- router.DELETE("/api/providers/openai/v1/models/:model", getPassThroughHandler(prod, private, client))
-
- // assistants
- router.POST("/api/providers/openai/v1/assistants", getPassThroughHandler(prod, private, client))
- router.GET("/api/providers/openai/v1/assistants/:assistant_id", getPassThroughHandler(prod, private, client))
- router.POST("/api/providers/openai/v1/assistants/:assistant_id", getPassThroughHandler(prod, private, client))
- router.DELETE("/api/providers/openai/v1/assistants/:assistant_id", getPassThroughHandler(prod, private, client))
- router.GET("/api/providers/openai/v1/assistants", getPassThroughHandler(prod, private, client))
-
- // assistant files
- router.POST("/api/providers/openai/v1/assistants/:assistant_id/files", getPassThroughHandler(prod, private, client))
- router.GET("/api/providers/openai/v1/assistants/:assistant_id/files/:file_id", getPassThroughHandler(prod, private, client))
- router.DELETE("/api/providers/openai/v1/assistants/:assistant_id/files/:file_id", getPassThroughHandler(prod, private, client))
- router.GET("/api/providers/openai/v1/assistants/:assistant_id/files", getPassThroughHandler(prod, private, client))
-
- // threads
- router.POST("/api/providers/openai/v1/threads", getPassThroughHandler(prod, private, client))
- router.GET("/api/providers/openai/v1/threads/:thread_id", getPassThroughHandler(prod, private, client))
- router.POST("/api/providers/openai/v1/threads/:thread_id", getPassThroughHandler(prod, private, client))
- router.DELETE("/api/providers/openai/v1/threads/:thread_id", getPassThroughHandler(prod, private, client))
-
- // messages
- router.POST("/api/providers/openai/v1/threads/:thread_id/messages", getPassThroughHandler(prod, private, client))
- router.GET("/api/providers/openai/v1/threads/:thread_id/messages/:message_id", getPassThroughHandler(prod, private, client))
- router.POST("/api/providers/openai/v1/threads/:thread_id/messages/:message_id", getPassThroughHandler(prod, private, client))
- router.GET("/api/providers/openai/v1/threads/:thread_id/messages", getPassThroughHandler(prod, private, client))
-
- // message files
- router.GET("/api/providers/openai/v1/threads/:thread_id/messages/:message_id/files/:file_id", getPassThroughHandler(prod, private, client))
- router.GET("/api/providers/openai/v1/threads/:thread_id/messages/:message_id/files", getPassThroughHandler(prod, private, client))
-
- // runs
- router.POST("/api/providers/openai/v1/threads/:thread_id/runs", getPassThroughHandler(prod, private, client))
- router.GET("/api/providers/openai/v1/threads/:thread_id/runs/:run_id", getPassThroughHandler(prod, private, client))
- router.POST("/api/providers/openai/v1/threads/:thread_id/runs/:run_id", getPassThroughHandler(prod, private, client))
- router.GET("/api/providers/openai/v1/threads/:thread_id/runs", getPassThroughHandler(prod, private, client))
- router.POST("/api/providers/openai/v1/threads/:thread_id/runs/:run_id/submit_tool_outputs", getPassThroughHandler(prod, private, client))
- router.POST("/api/providers/openai/v1/threads/:thread_id/runs/:run_id/cancel", getPassThroughHandler(prod, private, client))
- router.POST("/api/providers/openai/v1/threads/runs", getPassThroughHandler(prod, private, client))
- router.GET("/api/providers/openai/v1/threads/:thread_id/runs/:run_id/steps/:step_id", getPassThroughHandler(prod, private, client))
- router.GET("/api/providers/openai/v1/threads/:thread_id/runs/:run_id/steps", getPassThroughHandler(prod, private, client))
-
- // files
- router.GET("/api/providers/openai/v1/files", getPassThroughHandler(prod, private, client))
- router.POST("/api/providers/openai/v1/files", getPassThroughHandler(prod, private, client))
- router.DELETE("/api/providers/openai/v1/files/:file_id", getPassThroughHandler(prod, private, client))
- router.GET("/api/providers/openai/v1/files/:file_id", getPassThroughHandler(prod, private, client))
- router.GET("/api/providers/openai/v1/files/:file_id/content", getPassThroughHandler(prod, private, client))
-
- // batch
- router.POST("/api/providers/openai/v1/batches", getPassThroughHandler(prod, private, client))
- router.GET("/api/providers/openai/v1/batches/:batch_id", getPassThroughHandler(prod, private, client))
- router.POST("/api/providers/openai/v1/batches/:batch_id/cancel", getPassThroughHandler(prod, private, client))
- router.GET("/api/providers/openai/v1/batches", getPassThroughHandler(prod, private, client))
-
- // images
- router.POST("/api/providers/openai/v1/images/generations", getPassThroughHandler(prod, private, client))
- router.POST("/api/providers/openai/v1/images/edits", getPassThroughHandler(prod, private, client))
- router.POST("/api/providers/openai/v1/images/variations", getPassThroughHandler(prod, private, client))
+ // Register the full set of OpenAI routes once per configured provider name and upstream URL
+ for openAiName, openAiUrl := range openAiUrls {
+
+ // audios
+ router.POST(fmt.Sprintf("/api/providers/%s/v1/audio/speech", openAiName), getSpeechHandler(prod, client, openAiUrl))
+ router.POST(fmt.Sprintf("/api/providers/%s/v1/audio/transcriptions", openAiName), getTranscriptionsHandler(prod, client, e, openAiUrl))
+ router.POST(fmt.Sprintf("/api/providers/%s/v1/audio/translations", openAiName), getTranslationsHandler(prod, client, e, openAiUrl))
+
+ // completions
+ router.POST(fmt.Sprintf("/api/providers/%s/v1/chat/completions", openAiName), getChatCompletionHandler(prod, private, client, e, openAiUrl))
+
+ // embeddings
+ router.POST(fmt.Sprintf("/api/providers/%s/v1/embeddings", openAiName), getEmbeddingHandler(prod, private, client, e, openAiUrl))
+
+ // moderations
+ router.POST(fmt.Sprintf("/api/providers/%s/v1/moderations", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+
+ // models
+ router.GET(fmt.Sprintf("/api/providers/%s/v1/models", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+ router.GET(fmt.Sprintf("/api/providers/%s/v1/models/:model", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+ router.DELETE(fmt.Sprintf("/api/providers/%s/v1/models/:model", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+
+ // assistants
+ router.POST(fmt.Sprintf("/api/providers/%s/v1/assistants", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+ router.GET(fmt.Sprintf("/api/providers/%s/v1/assistants/:assistant_id", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+ router.POST(fmt.Sprintf("/api/providers/%s/v1/assistants/:assistant_id", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+ router.DELETE(fmt.Sprintf("/api/providers/%s/v1/assistants/:assistant_id", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+ router.GET(fmt.Sprintf("/api/providers/%s/v1/assistants", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+
+ // assistant files
+ router.POST(fmt.Sprintf("/api/providers/%s/v1/assistants/:assistant_id/files", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+ router.GET(fmt.Sprintf("/api/providers/%s/v1/assistants/:assistant_id/files/:file_id", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+ router.DELETE(fmt.Sprintf("/api/providers/%s/v1/assistants/:assistant_id/files/:file_id", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+ router.GET(fmt.Sprintf("/api/providers/%s/v1/assistants/:assistant_id/files", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+
+ // threads
+ router.POST(fmt.Sprintf("/api/providers/%s/v1/threads", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+ router.GET(fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+ router.POST(fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+ router.DELETE(fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+
+ // messages
+ router.POST(fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/messages", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+ router.GET(fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/messages/:message_id", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+ router.POST(fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/messages/:message_id", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+ router.GET(fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/messages", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+
+ // message files
+ router.GET(fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/messages/:message_id/files/:file_id", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+ router.GET(fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/messages/:message_id/files", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+
+ // runs
+ router.POST(fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/runs", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+ router.GET(fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/runs/:run_id", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+ router.POST(fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/runs/:run_id", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+ router.GET(fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/runs", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+ router.POST(fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/runs/:run_id/submit_tool_outputs", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+ router.POST(fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/runs/:run_id/cancel", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+ router.POST(fmt.Sprintf("/api/providers/%s/v1/threads/runs", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+ router.GET(fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/runs/:run_id/steps/:step_id", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+ router.GET(fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/runs/:run_id/steps", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+
+ // files
+ router.GET(fmt.Sprintf("/api/providers/%s/v1/files", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+ router.POST(fmt.Sprintf("/api/providers/%s/v1/files", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+ router.DELETE(fmt.Sprintf("/api/providers/%s/v1/files/:file_id", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+ router.GET(fmt.Sprintf("/api/providers/%s/v1/files/:file_id", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+ router.GET(fmt.Sprintf("/api/providers/%s/v1/files/:file_id/content", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+
+ // batch
+ router.POST(fmt.Sprintf("/api/providers/%s/v1/batches", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+ router.GET(fmt.Sprintf("/api/providers/%s/v1/batches/:batch_id", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+ router.POST(fmt.Sprintf("/api/providers/%s/v1/batches/:batch_id/cancel", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+ router.GET(fmt.Sprintf("/api/providers/%s/v1/batches", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+
+ // images
+ router.POST(fmt.Sprintf("/api/providers/%s/v1/images/generations", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+ router.POST(fmt.Sprintf("/api/providers/%s/v1/images/edits", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+ router.POST(fmt.Sprintf("/api/providers/%s/v1/images/variations", openAiName), getPassThroughHandler(prod, private, client, openAiName, openAiUrl))
+
+ // vector store
+ router.POST(fmt.Sprintf("/api/providers/%s/v1/vector_stores", openAiName), getCreateVectorStoreHandler(prod, client, openAiUrl))
+ router.GET(fmt.Sprintf("/api/providers/%s/v1/vector_stores", openAiName), getListVectorStoresHandler(prod, client, openAiUrl))
+ router.GET(fmt.Sprintf("/api/providers/%s/v1/vector_stores/:vector_store_id", openAiName), getGetVectorStoreHandler(prod, client, openAiUrl))
+ router.POST(fmt.Sprintf("/api/providers/%s/v1/vector_stores/:vector_store_id", openAiName), getModifyVectorStoreHandler(prod, client, openAiUrl))
+ router.DELETE(fmt.Sprintf("/api/providers/%s/v1/vector_stores/:vector_store_id", openAiName), getDeleteVectorStoreHandler(prod, client, openAiUrl))
+
+ // vector store files
+ router.POST(fmt.Sprintf("/api/providers/%s/v1/vector_stores/:vector_store_id/files", openAiName), getCreateVectorStoreFileHandler(prod, client, openAiUrl))
+ router.GET(fmt.Sprintf("/api/providers/%s/v1/vector_stores/:vector_store_id/files", openAiName), getListVectorStoreFilesHandler(prod, client, openAiUrl))
+ router.GET(fmt.Sprintf("/api/providers/%s/v1/vector_stores/:vector_store_id/files/:file_id", openAiName), getGetVectorStoreFileHandler(prod, client, openAiUrl))
+ router.DELETE(fmt.Sprintf("/api/providers/%s/v1/vector_stores/:vector_store_id/files/:file_id", openAiName), getDeleteVectorStoreFileHandler(prod, client, openAiUrl))
+
+ // vector store file batches
+ router.POST(fmt.Sprintf("/api/providers/%s/v1/vector_stores/:vector_store_id/file_batches", openAiName), getCreateVectorStoreFileBatchHandler(prod, client, openAiUrl))
+ router.GET(fmt.Sprintf("/api/providers/%s/v1/vector_stores/:vector_store_id/file_batches/:batch_id", openAiName), getGetVectorStoreFileBatchHandler(prod, client, openAiUrl))
+ router.POST(fmt.Sprintf("/api/providers/%s/v1/vector_stores/:vector_store_id/file_batches/:batch_id/cancel", openAiName), getCancelVectorStoreFileBatchHandler(prod, client, openAiUrl))
+ router.GET(fmt.Sprintf("/api/providers/%s/v1/vector_stores/:vector_store_id/file_batches/:batch_id/files", openAiName), getListVectorStoreFileBatchFilesHandler(prod, client, openAiUrl))
+ }

// azure
router.POST("/api/providers/azure/openai/deployments/:deployment_id/chat/completions", getAzureChatCompletionHandler(prod, private, client, aoe))
@@ -199,26 +223,7 @@ func NewProxyServer(log *zap.Logger, mode, privacyMode string, c cache, m KeyMan
router.POST("/api/custom/providers/:provider/*wildcard", getCustomProviderHandler(prod, client))

// custom route
- router.POST("/api/routes/*route", getRouteHandler(prod, c, aoe, e, client, r))
-
- // vector store
- router.POST("/api/providers/openai/v1/vector_stores", getCreateVectorStoreHandler(prod, client))
- router.GET("/api/providers/openai/v1/vector_stores", getListVectorStoresHandler(prod, client))
- router.GET("/api/providers/openai/v1/vector_stores/:vector_store_id", getGetVectorStoreHandler(prod, client))
- router.POST("/api/providers/openai/v1/vector_stores/:vector_store_id", getModifyVectorStoreHandler(prod, client))
- router.DELETE("/api/providers/openai/v1/vector_stores/:vector_store_id", getDeleteVectorStoreHandler(prod, client))
-
- // vector store files
- router.POST("/api/providers/openai/v1/vector_stores/:vector_store_id/files", getCreateVectorStoreFileHandler(prod, client))
- router.GET("/api/providers/openai/v1/vector_stores/:vector_store_id/files", getListVectorStoreFilesHandler(prod, client))
- router.GET("/api/providers/openai/v1/vector_stores/:vector_store_id/files/:file_id", getGetVectorStoreFileHandler(prod, client))
- router.DELETE("/api/providers/openai/v1/vector_stores/:vector_store_id/files/:file_id", getDeleteVectorStoreFileHandler(prod, client))
-
- // vector store file batches
- router.POST("/api/providers/openai/v1/vector_stores/:vector_store_id/file_batches", getCreateVectorStoreFileBatchHandler(prod, client))
- router.GET("/api/providers/openai/v1/vector_stores/:vector_store_id/file_batches/:batch_id", getGetVectorStoreFileBatchHandler(prod, client))
- router.POST("/api/providers/openai/v1/vector_stores/:vector_store_id/file_batches/:batch_id/cancel", getCancelVectorStoreFileBatchHandler(prod, client))
- router.GET("/api/providers/openai/v1/vector_stores/:vector_store_id/file_batches/:batch_id/files", getListVectorStoreFileBatchFilesHandler(prod, client))
+ router.POST("/api/routes/*route", getRouteHandler(prod, c, aoe, e, client, r, openAiUrls))

srv := &http.Server{
Addr: ":8002",
@@ -226,8 +231,9 @@ func NewProxyServer(log *zap.Logger, mode, privacyMode string, c cache, m KeyMan
}

return &ProxyServer{
- log: log,
- server: srv,
+ log: log,
+ server: srv,
+ openAiUrls: openAiUrls,
}, nil
}

@@ -279,7 +285,7 @@ func writeFieldToBuffer(fields []string, c *gin.Context, writer *multipart.Write
return nil
}

-func getPassThroughHandler(prod, private bool, client http.Client) gin.HandlerFunc {
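+// getPassThroughHandler proxies the matched request to the upstream URL of the named OpenAI-compatible provider, rebuilding multipart bodies where needed and logging per-endpoint responses.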
+func getPassThroughHandler(prod, private bool, client http.Client, openAiName, openAiUrl string) gin.HandlerFunc {
return func(c *gin.Context) {
log := util.GetLogFromCtx(c)
@@ -297,7 +303,7 @@ func getPassThroughHandler(prod, private bool, client http.Client) gin.HandlerFu
ctx, cancel := context.WithTimeout(context.Background(), c.GetDuration("requestTimeout"))
defer cancel()

- targetUrl, err := buildProxyUrl(c)
+ targetUrl, err := buildProxyUrl(c, openAiName, openAiUrl)
if err != nil {
telemetry.Incr("bricksllm.proxy.get_pass_through_handler.proxy_url_not_found", tags, 1)
logError(log, "error when building proxy url", prod, err)
@@ -317,7 +323,7 @@ func getPassThroughHandler(prod, private bool, client http.Client) gin.HandlerFu

copyHttpHeaders(c.Request, req, c.GetBool("removeUserAgent"))

- if c.FullPath() == "/api/providers/openai/v1/files" && c.Request.Method == http.MethodPost {
+ if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/files", openAiName) && c.Request.Method == http.MethodPost {
purpose := c.PostForm("purpose")

var b bytes.Buffer
@@ -364,7 +370,7 @@ func getPassThroughHandler(prod, private bool, client http.Client) gin.HandlerFu
req.Body = io.NopCloser(&b)
}

- if c.FullPath() == "/api/providers/openai/v1/images/edits" && c.Request.Method == http.MethodPost {
+ if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/images/edits", openAiName) && c.Request.Method == http.MethodPost {
var b bytes.Buffer
writer := multipart.NewWriter(&b)
@@ -445,7 +451,7 @@ func getPassThroughHandler(prod, private bool, client http.Client) gin.HandlerFu
req.Body = io.NopCloser(&b)
}

- if c.FullPath() == "/api/providers/openai/v1/images/variations" && c.Request.Method == http.MethodPost {
+ if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/images/variations", openAiName) && c.Request.Method == http.MethodPost {
var b bytes.Buffer
writer := multipart.NewWriter(&b)
@@ -525,163 +531,163 @@ func getPassThroughHandler(prod, private bool, client http.Client) gin.HandlerFu
telemetry.Incr("bricksllm.proxy.get_pass_through_handler.success", tags, 1)
telemetry.Timing("bricksllm.proxy.get_pass_through_handler.success_latency", dur, tags, 1)

- if c.FullPath() == "/api/providers/openai/v1/assistants" && c.Request.Method == http.MethodPost {
+ if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/assistants", openAiName) && c.Request.Method == http.MethodPost {
logAssistantResponse(log, bytes, prod, private)
}

- if c.FullPath() == "/api/providers/openai/v1/assistants/:assistant_id" && c.Request.Method == http.MethodGet {
+ if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/assistants/:assistant_id", openAiName) && c.Request.Method == http.MethodGet {
logAssistantResponse(log, bytes, prod, private)
}

- if c.FullPath() == "/api/providers/openai/v1/assistants/:assistant_id" && c.Request.Method == http.MethodPost {
+ if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/assistants/:assistant_id", openAiName) && c.Request.Method == http.MethodPost {
logAssistantResponse(log, bytes, prod, private)
}

- if c.FullPath() == "/api/providers/openai/v1/assistants/:assistant_id" && c.Request.Method == http.MethodDelete {
+ if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/assistants/:assistant_id", openAiName) && c.Request.Method == http.MethodDelete {
logDeleteAssistantResponse(log, bytes, prod)
}

- if c.FullPath() == "/api/providers/openai/v1/assistants" && c.Request.Method == http.MethodGet {
+ if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/assistants", openAiName) && c.Request.Method == http.MethodGet {
logListAssistantsResponse(log, bytes, prod, private)
}

- if c.FullPath() == "/api/providers/openai/v1/assistants/:assistant_id/files" && c.Request.Method == http.MethodPost {
+ if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/assistants/:assistant_id/files", openAiName) && c.Request.Method == http.MethodPost {
logAssistantFileResponse(log, bytes, prod)
}

- if c.FullPath() == "/api/providers/openai/v1/assistants/:assistant_id/files/:file_id" && c.Request.Method == http.MethodGet {
+ if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/assistants/:assistant_id/files/:file_id", openAiName) && c.Request.Method == http.MethodGet {
logAssistantFileResponse(log, bytes, prod)
}

- if c.FullPath() == "/api/providers/openai/v1/assistants/:assistant_id/files/:file_id" && c.Request.Method == http.MethodDelete {
+ if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/assistants/:assistant_id/files/:file_id", openAiName) && c.Request.Method == http.MethodDelete {
logDeleteAssistantFileResponse(log, bytes, prod)
}

- if c.FullPath() == "/api/providers/openai/v1/assistants/:assistant_id/files" && c.Request.Method == http.MethodGet {
+ if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/assistants/:assistant_id/files", openAiName) && c.Request.Method == http.MethodGet {
logListAssistantFilesResponse(log, bytes, prod)
}

- if c.FullPath() == "/api/providers/openai/v1/threads" && c.Request.Method == http.MethodPost {
+ if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads", openAiName) && c.Request.Method == http.MethodPost {
logThreadResponse(log, bytes, prod)
}

- if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id" && c.Request.Method == http.MethodGet {
+ if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id", openAiName) && c.Request.Method == http.MethodGet {
logThreadResponse(log, bytes, prod)
}

- if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id" && c.Request.Method == http.MethodPost {
+ if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id", openAiName) && c.Request.Method == http.MethodPost {
logThreadResponse(log, bytes, prod)
}

- if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id" && c.Request.Method == http.MethodDelete {
+ if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id", openAiName) && c.Request.Method == http.MethodDelete {
logDeleteThreadResponse(log, bytes, prod)
}

- if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/messages" && c.Request.Method == http.MethodPost {
+ if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/messages", openAiName) && c.Request.Method == http.MethodPost {
logMessageResponse(log, bytes, prod, private)
}

- if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/messages/:message_id" && c.Request.Method == http.MethodGet {
+ if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/messages/:message_id", openAiName) && c.Request.Method == http.MethodGet {
logMessageResponse(log, bytes, prod, private)
}

- if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/messages/:message_id" && c.Request.Method == http.MethodPost {
+ if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/messages/:message_id", openAiName) && c.Request.Method == http.MethodPost {
logMessageResponse(log, bytes, prod, private)
}

- if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/messages" && c.Request.Method == http.MethodGet {
+ if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/messages", openAiName) && c.Request.Method == http.MethodGet {
logListMessagesResponse(log, bytes, prod, private)
}

- if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/messages/:message_id/files/:file_id" && c.Request.Method == http.MethodGet {
+ if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/messages/:message_id/files/:file_id", openAiName) && c.Request.Method == http.MethodGet {
logRetrieveMessageFileResponse(log, bytes, prod)
}

- if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/messages/:message_id/files" && c.Request.Method == http.MethodGet {
+ if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/messages/:message_id/files", openAiName) && c.Request.Method == http.MethodGet {
logListMessageFilesResponse(log, bytes, prod)
}

- if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs" && c.Request.Method == http.MethodPost {
+ if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/runs", openAiName) && c.Request.Method == http.MethodPost {
logRunResponse(log, bytes, prod, private)
}

- if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs/:run_id" && c.Request.Method == http.MethodGet {
+ if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/runs/:run_id", openAiName) && c.Request.Method == http.MethodGet {
logRunResponse(log, bytes, prod, private)
}

- if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs/:run_id" && c.Request.Method == http.MethodPost {
+ if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/runs/:run_id", openAiName) && c.Request.Method == http.MethodPost {
logRunResponse(log, bytes, prod, private)
}

- if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs" && c.Request.Method == http.MethodGet {
+ if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/runs", openAiName) && c.Request.Method == http.MethodGet {
logListRunsResponse(log, bytes, prod, private)
}

- if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs/:run_id/submit_tool_outputs" && c.Request.Method == http.MethodPost {
+ if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/runs/:run_id/submit_tool_outputs", openAiName) && c.Request.Method == http.MethodPost {
logRunResponse(log, bytes, prod, private)
}

- if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs/:run_id/cancel" && c.Request.Method == http.MethodPost {
+ if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/runs/:run_id/cancel", openAiName) && c.Request.Method == http.MethodPost {
logRunResponse(log, bytes, prod, private)
}

- if c.FullPath() == "/api/providers/openai/v1/threads/runs" && c.Request.Method == http.MethodPost {
+ if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/runs", openAiName) && c.Request.Method == http.MethodPost {
logRunResponse(log, bytes, prod, private)
}

- if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs/:run_id/steps/:step_id" && c.Request.Method == http.MethodGet {
+ if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/runs/:run_id/steps/:step_id", openAiName) && c.Request.Method == http.MethodGet {
logRetrieveRunStepResponse(log, bytes, prod)
}

- if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs/:run_id/steps" && c.Request.Method == http.MethodGet {
+ if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/runs/:run_id/steps", openAiName) && c.Request.Method == http.MethodGet {
logListRunStepsResponse(log, bytes, prod)
}

- if c.FullPath() == "/api/providers/openai/v1/moderations" && c.Request.Method == http.MethodPost {
+ if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/moderations", openAiName) && c.Request.Method == http.MethodPost {
logCreateModerationResponse(log, bytes, prod)
}

- if c.FullPath() == "/api/providers/openai/v1/models" && c.Request.Method == http.MethodGet {
+ if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/models", openAiName) && c.Request.Method == http.MethodGet {
logListModelsResponse(log, bytes, prod)
}

- if c.FullPath() == "/api/providers/openai/v1/models/:model" && c.Request.Method == http.MethodGet {
+ if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/models/:model", openAiName) && c.Request.Method == http.MethodGet {
logRetrieveModelResponse(log, bytes, prod)
}

- if c.FullPath() == "/api/providers/openai/v1/models/:model" && c.Request.Method == http.MethodDelete {
+ if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/models/:model", openAiName) && c.Request.Method == http.MethodDelete {
logDeleteModelResponse(log, bytes, prod)
}

- if c.FullPath() == "/api/providers/openai/v1/files" && c.Request.Method == http.MethodGet {
+ if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/files", openAiName) && c.Request.Method == http.MethodGet {
logListFilesResponse(log, bytes, prod)
}

- if c.FullPath() == "/api/providers/openai/v1/files" && c.Request.Method == http.MethodPost {
+ if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/files", openAiName) && c.Request.Method == http.MethodPost {
logUploadFileResponse(log, bytes, prod)
}

- if c.FullPath() == "/api/providers/openai/v1/files/:file_id" && c.Request.Method == http.MethodDelete {
+ if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/files/:file_id", openAiName) && c.Request.Method == http.MethodDelete {
logDeleteFileResponse(log, bytes, prod)
}

- if c.FullPath() == "/api/providers/openai/v1/files/:file_id" && c.Request.Method == http.MethodGet {
+ if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/files/:file_id", openAiName) && c.Request.Method == http.MethodGet {
logRetrieveFileResponse(log, bytes, prod)
}

- if c.FullPath() == "/api/providers/openai/v1/files/:file_id/content" && c.Request.Method == http.MethodGet {
+ if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/files/:file_id/content", openAiName) && c.Request.Method == http.MethodGet {
logRetrieveFileContentResponse(log, prod)
}

- if c.FullPath() == "/api/providers/openai/v1/images/generations" && c.Request.Method == http.MethodPost {
+ if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/images/generations", openAiName) && c.Request.Method == http.MethodPost {
logImageResponse(log, bytes, prod, private)
}

- if c.FullPath() == "/api/providers/openai/v1/images/edits" && c.Request.Method == http.MethodPost {
+ if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/images/edits", openAiName) && c.Request.Method == http.MethodPost {
logImageResponse(log, bytes, prod, private)
}

- if c.FullPath() == "/api/providers/openai/v1/images/variations" && c.Request.Method == http.MethodPost {
+ if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/images/variations", openAiName) && c.Request.Method == http.MethodPost {
logImageResponse(log, bytes, prod, private)
}
}
@@ -714,193 +720,193 @@ func getPassThroughHandler(prod, private bool, client http.Client) gin.HandlerFu
}
}

-func buildProxyUrl(c *gin.Context) (string, error) {
- if c.FullPath() == "/api/providers/openai/v1/assistants" && c.Request.Method == http.MethodPost {
- return "https://api.openai.com/v1/assistants", nil
*gin.Context) (string, error) { - if c.FullPath() == "/api/providers/openai/v1/assistants" && c.Request.Method == http.MethodPost { - return "https://api.openai.com/v1/assistants", nil +func buildProxyUrl(c *gin.Context, openAiName, openAiUrl string) (string, error) { + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/assistants", openAiName) && c.Request.Method == http.MethodPost { + return openAiUrl + "/v1/assistants", nil } - if c.FullPath() == "/api/providers/openai/v1/assistants/:assistant_id" && c.Request.Method == http.MethodGet { - return "https://api.openai.com/v1/assistants/" + c.Param("assistant_id"), nil + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/assistants/:assistant_id", openAiName) && c.Request.Method == http.MethodGet { + return openAiUrl + "/v1/assistants/" + c.Param("assistant_id"), nil } - if c.FullPath() == "/api/providers/openai/v1/assistants/:assistant_id" && c.Request.Method == http.MethodPost { - return "https://api.openai.com/v1/assistants/" + c.Param("assistant_id"), nil + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/assistants/:assistant_id", openAiName) && c.Request.Method == http.MethodPost { + return openAiUrl + "/v1/assistants/" + c.Param("assistant_id"), nil } - if c.FullPath() == "/api/providers/openai/v1/assistants/:assistant_id" && c.Request.Method == http.MethodDelete { - return "https://api.openai.com/v1/assistants/" + c.Param("assistant_id"), nil + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/assistants/:assistant_id", openAiName) && c.Request.Method == http.MethodDelete { + return openAiUrl + "/v1/assistants/" + c.Param("assistant_id"), nil } - if c.FullPath() == "/api/providers/openai/v1/assistants" && c.Request.Method == http.MethodGet { - return "https://api.openai.com/v1/assistants", nil + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/assistants", openAiName) && c.Request.Method == http.MethodGet { + return openAiUrl + "/v1/assistants", nil } - if c.FullPath() == "/api/providers/openai/v1/assistants/:assistant_id/files" && c.Request.Method == http.MethodPost { - return "https://api.openai.com/v1/assistants/" + c.Param("assistant_id") + "/files", nil + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/assistants/:assistant_id/files", openAiName) && c.Request.Method == http.MethodPost { + return openAiUrl + "/v1/assistants/" + c.Param("assistant_id") + "/files", nil } - if c.FullPath() == "/api/providers/openai/v1/assistants/:assistant_id/files/:file_id" && c.Request.Method == http.MethodGet { - return "https://api.openai.com/v1/assistants/" + c.Param("assistant_id") + "/files/" + c.Param("file_id"), nil + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/assistants/:assistant_id/files/:file_id", openAiName) && c.Request.Method == http.MethodGet { + return openAiUrl + "/v1/assistants/" + c.Param("assistant_id") + "/files/" + c.Param("file_id"), nil } - if c.FullPath() == "/api/providers/openai/v1/assistants/:assistant_id/files/:file_id" && c.Request.Method == http.MethodDelete { - return "https://api.openai.com/v1/assistants/" + c.Param("assistant_id") + "/files/" + c.Param("file_id"), nil + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/assistants/:assistant_id/files/:file_id", openAiName) && c.Request.Method == http.MethodDelete { + return openAiUrl + "/v1/assistants/" + c.Param("assistant_id") + "/files/" + c.Param("file_id"), nil } - if c.FullPath() == "/api/providers/openai/v1/assistants/:assistant_id/files" && c.Request.Method == http.MethodGet { - return 
"https://api.openai.com/v1/assistants/" + c.Param("assistant_id") + "/files", nil + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/assistants/:assistant_id/files", openAiName) && c.Request.Method == http.MethodGet { + return openAiUrl + "/v1/assistants/" + c.Param("assistant_id") + "/files", nil } - if c.FullPath() == "/api/providers/openai/v1/threads" && c.Request.Method == http.MethodPost { - return "https://api.openai.com/v1/threads", nil + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads", openAiName) && c.Request.Method == http.MethodPost { + return openAiUrl + "/v1/threads", nil } - if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id" && c.Request.Method == http.MethodGet { - return "https://api.openai.com/v1/threads/" + c.Param("thread_id"), nil + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id", openAiName) && c.Request.Method == http.MethodGet { + return openAiUrl + "/v1/threads/" + c.Param("thread_id"), nil } - if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id" && c.Request.Method == http.MethodPost { - return "https://api.openai.com/v1/threads/" + c.Param("thread_id"), nil + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id", openAiName) && c.Request.Method == http.MethodPost { + return openAiUrl + "/v1/threads/" + c.Param("thread_id"), nil } - if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id" && c.Request.Method == http.MethodDelete { - return "https://api.openai.com/v1/threads/" + c.Param("thread_id"), nil + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id", openAiName) && c.Request.Method == http.MethodDelete { + return openAiUrl + "/v1/threads/" + c.Param("thread_id"), nil } - if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/messages" && c.Request.Method == http.MethodPost { - return "https://api.openai.com/v1/threads/" + c.Param("thread_id") + "/messages", nil + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/messages", openAiName) && c.Request.Method == http.MethodPost { + return openAiUrl + "/v1/threads/" + c.Param("thread_id") + "/messages", nil } - if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/messages/:message_id" && c.Request.Method == http.MethodGet { - return "https://api.openai.com/v1/threads/" + c.Param("thread_id") + "/messages/" + c.Param("message_id"), nil + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/messages/:message_id", openAiName) && c.Request.Method == http.MethodGet { + return openAiUrl + "/v1/threads/" + c.Param("thread_id") + "/messages/" + c.Param("message_id"), nil } - if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/messages/:message_id" && c.Request.Method == http.MethodPost { - return "https://api.openai.com/v1/threads/" + c.Param("thread_id") + "/messages/" + c.Param("message_id"), nil + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/messages/:message_id", openAiName) && c.Request.Method == http.MethodPost { + return openAiUrl + "/v1/threads/" + c.Param("thread_id") + "/messages/" + c.Param("message_id"), nil } - if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/messages" && c.Request.Method == http.MethodGet { - return "https://api.openai.com/v1/threads/" + c.Param("thread_id") + "/messages", nil + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/messages", openAiName) && c.Request.Method == http.MethodGet { + return openAiUrl + "/v1/threads/" + 
c.Param("thread_id") + "/messages", nil } - if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/messages/:message_id/files/:file_id" && c.Request.Method == http.MethodGet { - return "https://api.openai.com/v1/threads/" + c.Param("thread_id") + "/messages/" + c.Param("message_id") + "/files/" + c.Param("file_id"), nil + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/messages/:message_id/files/:file_id", openAiName) && c.Request.Method == http.MethodGet { + return openAiUrl + "/v1/threads/" + c.Param("thread_id") + "/messages/" + c.Param("message_id") + "/files/" + c.Param("file_id"), nil } - if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/messages/:message_id/files" && c.Request.Method == http.MethodGet { - return "https://api.openai.com/v1/threads/" + c.Param("thread_id") + "/messages/" + c.Param("message_id") + "/files", nil + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/messages/:message_id/files", openAiName) && c.Request.Method == http.MethodGet { + return openAiUrl + "/v1/threads/" + c.Param("thread_id") + "/messages/" + c.Param("message_id") + "/files", nil } - if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs" && c.Request.Method == http.MethodPost { - return "https://api.openai.com/v1/threads/" + c.Param("thread_id") + "/runs", nil + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/runs", openAiName) && c.Request.Method == http.MethodPost { + return openAiUrl + "/v1/threads/" + c.Param("thread_id") + "/runs", nil } - if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs/:run_id" && c.Request.Method == http.MethodGet { - return "https://api.openai.com/v1/threads/" + c.Param("thread_id") + "/runs/" + c.Param("run_id"), nil + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/runs/:run_id", openAiName) && c.Request.Method == http.MethodGet { + return openAiUrl + "/v1/threads/" + c.Param("thread_id") + "/runs/" + c.Param("run_id"), nil } - if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs/:run_id" && c.Request.Method == http.MethodPost { - return "https://api.openai.com/v1/threads/" + c.Param("thread_id") + "/runs/" + c.Param("run_id"), nil + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/runs/:run_id", openAiName) && c.Request.Method == http.MethodPost { + return openAiUrl + "/v1/threads/" + c.Param("thread_id") + "/runs/" + c.Param("run_id"), nil } - if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs" && c.Request.Method == http.MethodGet { - return "https://api.openai.com/v1/threads/" + c.Param("thread_id") + "/runs", nil + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/runs", openAiName) && c.Request.Method == http.MethodGet { + return openAiUrl + "/v1/threads/" + c.Param("thread_id") + "/runs", nil } - if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs/:run_id/submit_tool_outputs" && c.Request.Method == http.MethodPost { - return "https://api.openai.com/v1/threads/" + c.Param("thread_id") + "/runs/" + c.Param("run_id") + "/submit_tool_outputs", nil + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/runs/:run_id/submit_tool_outputs", openAiName) && c.Request.Method == http.MethodPost { + return openAiUrl + "/v1/threads/" + c.Param("thread_id") + "/runs/" + c.Param("run_id") + "/submit_tool_outputs", nil } - if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs/:run_id/cancel" && 
-	if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs" && c.Request.Method == http.MethodPost {
-		return "https://api.openai.com/v1/threads/" + c.Param("thread_id") + "/runs", nil
+	if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/runs", openAiName) && c.Request.Method == http.MethodPost {
+		return openAiUrl + "/v1/threads/" + c.Param("thread_id") + "/runs", nil
 	}
-	if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs/:run_id" && c.Request.Method == http.MethodGet {
-		return "https://api.openai.com/v1/threads/" + c.Param("thread_id") + "/runs/" + c.Param("run_id"), nil
+	if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/runs/:run_id", openAiName) && c.Request.Method == http.MethodGet {
+		return openAiUrl + "/v1/threads/" + c.Param("thread_id") + "/runs/" + c.Param("run_id"), nil
 	}
-	if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs/:run_id" && c.Request.Method == http.MethodPost {
-		return "https://api.openai.com/v1/threads/" + c.Param("thread_id") + "/runs/" + c.Param("run_id"), nil
+	if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/runs/:run_id", openAiName) && c.Request.Method == http.MethodPost {
+		return openAiUrl + "/v1/threads/" + c.Param("thread_id") + "/runs/" + c.Param("run_id"), nil
 	}
-	if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs" && c.Request.Method == http.MethodGet {
-		return "https://api.openai.com/v1/threads/" + c.Param("thread_id") + "/runs", nil
+	if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/runs", openAiName) && c.Request.Method == http.MethodGet {
+		return openAiUrl + "/v1/threads/" + c.Param("thread_id") + "/runs", nil
 	}
-	if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs/:run_id/submit_tool_outputs" && c.Request.Method == http.MethodPost {
-		return "https://api.openai.com/v1/threads/" + c.Param("thread_id") + "/runs/" + c.Param("run_id") + "/submit_tool_outputs", nil
+	if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/runs/:run_id/submit_tool_outputs", openAiName) && c.Request.Method == http.MethodPost {
+		return openAiUrl + "/v1/threads/" + c.Param("thread_id") + "/runs/" + c.Param("run_id") + "/submit_tool_outputs", nil
 	}
-	if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs/:run_id/cancel" && c.Request.Method == http.MethodPost {
-		return "https://api.openai.com/v1/threads/" + c.Param("thread_id") + "/runs/" + c.Param("run_id") + "/cancel", nil
+	if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/runs/:run_id/cancel", openAiName) && c.Request.Method == http.MethodPost {
+		return openAiUrl + "/v1/threads/" + c.Param("thread_id") + "/runs/" + c.Param("run_id") + "/cancel", nil
 	}
-	if c.FullPath() == "/api/providers/openai/v1/threads/runs" && c.Request.Method == http.MethodPost {
-		return "https://api.openai.com/v1/threads/runs", nil
+	if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/runs", openAiName) && c.Request.Method == http.MethodPost {
+		return openAiUrl + "/v1/threads/runs", nil
 	}
-	if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs/:run_id/steps/:step_id" && c.Request.Method == http.MethodGet {
-		return "https://api.openai.com/v1/threads/" + c.Param("thread_id") + "/runs/" + c.Param("run_id") + "/steps/" + c.Param("step_id"), nil
+	if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/runs/:run_id/steps/:step_id", openAiName) && c.Request.Method == http.MethodGet {
+		return openAiUrl + "/v1/threads/" + c.Param("thread_id") + "/runs/" + c.Param("run_id") + "/steps/" + c.Param("step_id"), nil
 	}
-	if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs/:run_id/steps" && c.Request.Method == http.MethodGet {
-		return "https://api.openai.com/v1/threads/" + c.Param("thread_id") + "/runs/" + c.Param("run_id") + "/steps", nil
+	if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/threads/:thread_id/runs/:run_id/steps", openAiName) && c.Request.Method == http.MethodGet {
+		return openAiUrl + "/v1/threads/" + c.Param("thread_id") + "/runs/" + c.Param("run_id") + "/steps", nil
 	}
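+	// moderation, model, file, and batch routes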
"https://api.openai.com/v1/files", nil + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/files", openAiName) && c.Request.Method == http.MethodPost { + return openAiUrl + "/v1/files", nil } - if c.FullPath() == "/api/providers/openai/v1/files/:file_id" && c.Request.Method == http.MethodDelete { - return "https://api.openai.com/v1/files/" + c.Param("file_id"), nil + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/files/:file_id", openAiName) && c.Request.Method == http.MethodDelete { + return openAiUrl + "/v1/files/" + c.Param("file_id"), nil } - if c.FullPath() == "/api/providers/openai/v1/files/:file_id" && c.Request.Method == http.MethodGet { - return "https://api.openai.com/v1/files/" + c.Param("file_id"), nil + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/files/:file_id", openAiName) && c.Request.Method == http.MethodGet { + return openAiUrl + "/v1/files/" + c.Param("file_id"), nil } - if c.FullPath() == "/api/providers/openai/v1/files/:file_id/content" && c.Request.Method == http.MethodGet { - return "https://api.openai.com/v1/files/" + c.Param("file_id") + "/content", nil + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/files/:file_id/content", openAiName) && c.Request.Method == http.MethodGet { + return openAiUrl + "/v1/files/" + c.Param("file_id") + "/content", nil } - if c.FullPath() == "/api/providers/openai/v1/batches" && c.Request.Method == http.MethodPost { - return "https://api.openai.com/v1/batches", nil + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/batches", openAiName) && c.Request.Method == http.MethodPost { + return openAiUrl + "/v1/batches", nil } - if c.FullPath() == "/api/providers/openai/v1/batches/:batch_id" && c.Request.Method == http.MethodGet { - return "https://api.openai.com/v1/batches/" + c.Param("batch_id"), nil + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/batches/:batch_id", openAiName) && c.Request.Method == http.MethodGet { + return openAiUrl + "/v1/batches/" + c.Param("batch_id"), nil } - if c.FullPath() == "/api/providers/openai/v1/batches/:batch_id/cancel" && c.Request.Method == http.MethodPost { - return "https://api.openai.com/v1/batches/" + c.Param("batch_id") + "/cancel", nil + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/batches/:batch_id/cancel", openAiName) && c.Request.Method == http.MethodPost { + return openAiUrl + "/v1/batches/" + c.Param("batch_id") + "/cancel", nil } - if c.FullPath() == "/api/providers/openai/v1/batches" && c.Request.Method == http.MethodGet { - return "https://api.openai.com/v1/batches", nil + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/batches", openAiName) && c.Request.Method == http.MethodGet { + return openAiUrl + "/v1/batches", nil } - if c.FullPath() == "/api/providers/openai/v1/images/generations" && c.Request.Method == http.MethodPost { - return "https://api.openai.com/v1/images/generations", nil + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/images/generations", openAiName) && c.Request.Method == http.MethodPost { + return openAiUrl + "/v1/images/generations", nil } - if c.FullPath() == "/api/providers/openai/v1/images/edits" && c.Request.Method == http.MethodPost { - return "https://api.openai.com/v1/images/edits", nil + if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/images/edits", openAiName) && c.Request.Method == http.MethodPost { + return openAiUrl + "/v1/images/edits", nil } - if c.FullPath() == "/api/providers/openai/v1/images/variations" && c.Request.Method == http.MethodPost { - return "https://api.openai.com/v1/images/variations", 
-	if c.FullPath() == "/api/providers/openai/v1/images/generations" && c.Request.Method == http.MethodPost {
-		return "https://api.openai.com/v1/images/generations", nil
+	if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/images/generations", openAiName) && c.Request.Method == http.MethodPost {
+		return openAiUrl + "/v1/images/generations", nil
 	}
-	if c.FullPath() == "/api/providers/openai/v1/images/edits" && c.Request.Method == http.MethodPost {
-		return "https://api.openai.com/v1/images/edits", nil
+	if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/images/edits", openAiName) && c.Request.Method == http.MethodPost {
+		return openAiUrl + "/v1/images/edits", nil
 	}
-	if c.FullPath() == "/api/providers/openai/v1/images/variations" && c.Request.Method == http.MethodPost {
-		return "https://api.openai.com/v1/images/variations", nil
+	if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/images/variations", openAiName) && c.Request.Method == http.MethodPost {
+		return openAiUrl + "/v1/images/variations", nil
 	}
-	if c.FullPath() == "/api/providers/openai/v1/audio/speech" && c.Request.Method == http.MethodPost {
-		return "https://api.openai.com/v1/audio/speech", nil
+	if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/audio/speech", openAiName) && c.Request.Method == http.MethodPost {
+		return openAiUrl + "/v1/audio/speech", nil
 	}
-	if c.FullPath() == "/api/providers/openai/v1/audio/transcriptions" && c.Request.Method == http.MethodPost {
-		return "https://api.openai.com/v1/audio/transcriptions", nil
+	if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/audio/transcriptions", openAiName) && c.Request.Method == http.MethodPost {
+		return openAiUrl + "/v1/audio/transcriptions", nil
 	}
-	if c.FullPath() == "/api/providers/openai/v1/audio/translations" && c.Request.Method == http.MethodPost {
-		return "https://api.openai.com/v1/audio/translations", nil
+	if c.FullPath() == fmt.Sprintf("/api/providers/%s/v1/audio/translations", openAiName) && c.Request.Method == http.MethodPost {
+		return openAiUrl + "/v1/audio/translations", nil
 	}
 
 	return "", errors.New("cannot find corresponding OpenAI target proxy")
@@ -915,80 +921,100 @@ var (
 func (ps *ProxyServer) Run() {
 	go func() {
-		ps.log.Info("proxy server listening at 8002")
-
-		// health check
-		ps.log.Info("PORT 8002 | GET | /api/health is ready")
-
-		// audio
-		ps.log.Info("PORT 8002 | POST | /api/providers/openai/v1/audio/speech is ready for creating openai speeches")
-		ps.log.Info("PORT 8002 | POST | /api/providers/openai/v1/audio/transcriptions is ready for creating openai transcriptions")
-		ps.log.Info("PORT 8002 | POST | /api/providers/openai/v1/audio/translations is ready for creating openai translations")
-
-		// chat completions
-		ps.log.Info("PORT 8002 | POST | /api/providers/openai/v1/chat/completions is ready for forwarding chat completion requests to openai")
-
-		// embeddings
-		ps.log.Info("PORT 8002 | POST | /api/providers/openai/v1/embeddings is ready for forwarding embeddings requests to openai")
-
-		// moderations
-		ps.log.Info("PORT 8002 | POST | /api/providers/openai/v1/moderations is ready for forwarding moderation requests to openai")
-
-		// models
-		ps.log.Info("PORT 8002 | GET | /api/providers/openai/v1/models is ready for listing openai models")
-		ps.log.Info("PORT 8002 | GET | /api/providers/openai/v1/models/:model is ready for retrieving an openai model")
-
-		// files
-		ps.log.Info("PORT 8002 | GET | /api/providers/openai/v1/files is ready for listing files from openai")
-		ps.log.Info("PORT 8002 | POST | /api/providers/openai/v1/files is ready for uploading files to openai")
-		ps.log.Info("PORT 8002 | GET | /api/providers/openai/v1/files/:file_id is ready for retrieving a file metadata from openai")
-		ps.log.Info("PORT 8002 | GET | /api/providers/openai/v1/files/:file_id/content is ready for retrieving a file's content from openai")
-
-		// assistants
-		ps.log.Info("PORT 8002 | POST | /api/providers/openai/v1/assistants is ready for creating openai assistants")
-		ps.log.Info("PORT 8002 | GET | /api/providers/openai/v1/assistants/:assistant_id is ready for retrieving an openai assistant")
-		ps.log.Info("PORT 8002 | POST | /api/providers/openai/v1/assistants/:assistant_id is ready for modifying an openai assistant")
-		ps.log.Info("PORT 8002 | DELETE | /api/providers/openai/v1/assistants/:assistant_id is ready for deleting an openai assistant")
ps.log.Info("PORT 8002 | GET | /api/providers/openai/v1/assistants is ready for retrieving openai assistants") - - // assistant files - ps.log.Info("PORT 8002 | POST | /api/providers/openai/v1/assistants/:assistant_id/files is ready for creating openai assistant file") - ps.log.Info("PORT 8002 | GET | /api/providers/openai/v1/assistants/:assistant_id/files/:file_id is ready for retrieving openai assistant file") - ps.log.Info("PORT 8002 | DELETE | /api/providers/openai/v1/assistants/:assistant_id/files/:file_id is ready for deleting openai assistant file") - ps.log.Info("PORT 8002 | GET | /api/providers/openai/v1/assistants/:assistant_id/files is ready for retireving openai assistant files") - - // threads - ps.log.Info("PORT 8002 | POST | /api/providers/openai/v1/threads is ready for creating an openai thread") - ps.log.Info("PORT 8002 | GET | /api/providers/openai/v1/threads/:thread_id is ready for retrieving an openai thread") - ps.log.Info("PORT 8002 | POST | /api/providers/openai/v1/threads/:thread_id is ready for modifying an openai thread") - ps.log.Info("PORT 8002 | GET | /api/providers/openai/v1/threads/:thread_id is ready for deleting an openai thread") - - // messages - ps.log.Info("PORT 8002 | POST | /api/providers/openai/v1/threads/:thread_id/messages is ready for creating an openai message") - ps.log.Info("PORT 8002 | GET | /api/providers/openai/v1/threads/:thread_id/messages/:message_id is ready for retrieving an openai message") - ps.log.Info("PORT 8002 | POST | /api/providers/openai/v1/threads/:thread_id/messages/:message_id is ready for modifying an openai message") - ps.log.Info("PORT 8002 | GET | /api/providers/openai/v1/threads/:thread_id/messages is ready for retrieving openai messages") - - // message files - ps.log.Info("PORT 8002 | GET | /api/providers/openai/v1/threads/:thread_id/messages/:message_id/files/:file_id is ready for retrieving an openai message file") - ps.log.Info("PORT 8002 | GET | /api/providers/openai/v1/threads/:thread_id/messages/:message_id/files is ready for retrieving openai message files") - - // runs - ps.log.Info("PORT 8002 | POST | /api/providers/openai/v1/threads/:thread_id/runs is ready for creating an openai run") - ps.log.Info("PORT 8002 | GET | /api/providers/openai/v1/threads/:thread_id/runs/:run_id is ready for retrieving an openai run") - ps.log.Info("PORT 8002 | POST | /api/providers/openai/v1/threads/:thread_id/runs/:run_id is ready for modifying an openai run") - ps.log.Info("PORT 8002 | GET | /api/providers/openai/v1/threads/:thread_id/runs is ready for retrieving openai runs") - ps.log.Info("PORT 8002 | POST | /api/providers/openai/v1/threads/:thread_id/runs/:run_id/submit_tool_outputs is ready for submitting tool outputs to an openai run") - ps.log.Info("PORT 8002 | POST | /api/providers/openai/v1/threads/:thread_id/runs/:run_id/cancel is ready for cancelling an openai run") - ps.log.Info("PORT 8002 | POST | /api/providers/openai/v1/threads/runs is ready for creating an openai thread and run") - ps.log.Info("PORT 8002 | GET | /api/providers/openai/v1/threads/:thread_id/runs/:run_id/steps/:step_id is ready for retrieving an openai run step") - ps.log.Info("PORT 8002 | GET | /api/providers/openai/v1/threads/:thread_id/runs/:run_id/steps is ready for retrieving openai run steps") - - // images - ps.log.Info("PORT 8002 | POST | /api/providers/openai/v1/images/generations is ready for generating openai images") - ps.log.Info("PORT 8002 | POST | /api/providers/openai/v1/images/edits is ready for editting openi images") - 
ps.log.Info("PORT 8002 | POST | /api/providers/openai/v1/images/variations is ready for generating openai image variations") - + for openAiName := range ps.openAiUrls { + ps.log.Info("proxy server listening at 8002") + + // health check + ps.log.Info("PORT 8002 | GET | /api/health is ready") + + // audio + ps.log.Info(fmt.Sprintf("PORT 8002 | POST | /api/providers/%s/v1/audio/speech is ready for creating openai speeches", openAiName)) + ps.log.Info(fmt.Sprintf("PORT 8002 | POST | /api/providers/%s/v1/audio/transcriptions is ready for creating openai transcriptions", openAiName)) + ps.log.Info(fmt.Sprintf("PORT 8002 | POST | /api/providers/%s/v1/audio/translations is ready for creating openai translations", openAiName)) + + // chat completions + ps.log.Info(fmt.Sprintf("PORT 8002 | POST | /api/providers/%s/v1/chat/completions is ready for forwarding chat completion requests to openai", openAiName)) + + // embeddings + ps.log.Info(fmt.Sprintf("PORT 8002 | POST | /api/providers/%s/v1/embeddings is ready for forwarding embeddings requests to openai", openAiName)) + + // moderations + ps.log.Info(fmt.Sprintf("PORT 8002 | POST | /api/providers/%s/v1/moderations is ready for forwarding moderation requests to openai", openAiName)) + + // models + ps.log.Info(fmt.Sprintf("PORT 8002 | GET | /api/providers/%s/v1/models is ready for listing openai models", openAiName)) + ps.log.Info(fmt.Sprintf("PORT 8002 | GET | /api/providers/%s/v1/models/:model is ready for retrieving an openai model", openAiName)) + + // files + ps.log.Info(fmt.Sprintf("PORT 8002 | GET | /api/providers/%s/v1/files is ready for listing files from openai", openAiName)) + ps.log.Info(fmt.Sprintf("PORT 8002 | POST | /api/providers/%s/v1/files is ready for uploading files to openai", openAiName)) + ps.log.Info(fmt.Sprintf("PORT 8002 | GET | /api/providers/%s/v1/files/:file_id is ready for retrieving a file metadata from openai", openAiName)) + ps.log.Info(fmt.Sprintf("PORT 8002 | GET | /api/providers/%s/v1/files/:file_id/content is ready for retrieving a file's content from openai", openAiName)) + + // assistants + ps.log.Info(fmt.Sprintf("PORT 8002 | POST | /api/providers/%s/v1/assistants is ready for creating openai assistants", openAiName)) + ps.log.Info(fmt.Sprintf("PORT 8002 | GET | /api/providers/%s/v1/assistants/:assistant_id is ready for retrieving an openai assistant", openAiName)) + ps.log.Info(fmt.Sprintf("PORT 8002 | POST | /api/providers/%s/v1/assistants/:assistant_id is ready for modifying an openai assistant", openAiName)) + ps.log.Info(fmt.Sprintf("PORT 8002 | DELETE | /api/providers/%s/v1/assistants/:assistant_id is ready for deleting an openai assistant", openAiName)) + ps.log.Info(fmt.Sprintf("PORT 8002 | GET | /api/providers/%s/v1/assistants is ready for retrieving openai assistants", openAiName)) + + // assistant files + ps.log.Info(fmt.Sprintf("PORT 8002 | POST | /api/providers/%s/v1/assistants/:assistant_id/files is ready for creating openai assistant file", openAiName)) + ps.log.Info(fmt.Sprintf("PORT 8002 | GET | /api/providers/%s/v1/assistants/:assistant_id/files/:file_id is ready for retrieving openai assistant file", openAiName)) + ps.log.Info(fmt.Sprintf("PORT 8002 | DELETE | /api/providers/%s/v1/assistants/:assistant_id/files/:file_id is ready for deleting openai assistant file", openAiName)) + ps.log.Info(fmt.Sprintf("PORT 8002 | GET | /api/providers/%s/v1/assistants/:assistant_id/files is ready for retireving openai assistant files", openAiName)) + + // threads + ps.log.Info(fmt.Sprintf("PORT 8002 | POST 
+			// audio
+			ps.log.Info(fmt.Sprintf("PORT 8002 | POST | /api/providers/%s/v1/audio/speech is ready for creating openai speeches", openAiName))
+			ps.log.Info(fmt.Sprintf("PORT 8002 | POST | /api/providers/%s/v1/audio/transcriptions is ready for creating openai transcriptions", openAiName))
+			ps.log.Info(fmt.Sprintf("PORT 8002 | POST | /api/providers/%s/v1/audio/translations is ready for creating openai translations", openAiName))
+
+			// chat completions
+			ps.log.Info(fmt.Sprintf("PORT 8002 | POST | /api/providers/%s/v1/chat/completions is ready for forwarding chat completion requests to openai", openAiName))
+
+			// embeddings
+			ps.log.Info(fmt.Sprintf("PORT 8002 | POST | /api/providers/%s/v1/embeddings is ready for forwarding embeddings requests to openai", openAiName))
+
+			// moderations
+			ps.log.Info(fmt.Sprintf("PORT 8002 | POST | /api/providers/%s/v1/moderations is ready for forwarding moderation requests to openai", openAiName))
+
+			// models
+			ps.log.Info(fmt.Sprintf("PORT 8002 | GET | /api/providers/%s/v1/models is ready for listing openai models", openAiName))
+			ps.log.Info(fmt.Sprintf("PORT 8002 | GET | /api/providers/%s/v1/models/:model is ready for retrieving an openai model", openAiName))
+
+			// files
+			ps.log.Info(fmt.Sprintf("PORT 8002 | GET | /api/providers/%s/v1/files is ready for listing files from openai", openAiName))
+			ps.log.Info(fmt.Sprintf("PORT 8002 | POST | /api/providers/%s/v1/files is ready for uploading files to openai", openAiName))
+			ps.log.Info(fmt.Sprintf("PORT 8002 | GET | /api/providers/%s/v1/files/:file_id is ready for retrieving file metadata from openai", openAiName))
+			ps.log.Info(fmt.Sprintf("PORT 8002 | GET | /api/providers/%s/v1/files/:file_id/content is ready for retrieving a file's content from openai", openAiName))
+
+			// assistants
+			ps.log.Info(fmt.Sprintf("PORT 8002 | POST | /api/providers/%s/v1/assistants is ready for creating openai assistants", openAiName))
+			ps.log.Info(fmt.Sprintf("PORT 8002 | GET | /api/providers/%s/v1/assistants/:assistant_id is ready for retrieving an openai assistant", openAiName))
+			ps.log.Info(fmt.Sprintf("PORT 8002 | POST | /api/providers/%s/v1/assistants/:assistant_id is ready for modifying an openai assistant", openAiName))
+			ps.log.Info(fmt.Sprintf("PORT 8002 | DELETE | /api/providers/%s/v1/assistants/:assistant_id is ready for deleting an openai assistant", openAiName))
+			ps.log.Info(fmt.Sprintf("PORT 8002 | GET | /api/providers/%s/v1/assistants is ready for retrieving openai assistants", openAiName))
+
+			// assistant files
+			ps.log.Info(fmt.Sprintf("PORT 8002 | POST | /api/providers/%s/v1/assistants/:assistant_id/files is ready for creating openai assistant file", openAiName))
+			ps.log.Info(fmt.Sprintf("PORT 8002 | GET | /api/providers/%s/v1/assistants/:assistant_id/files/:file_id is ready for retrieving openai assistant file", openAiName))
+			ps.log.Info(fmt.Sprintf("PORT 8002 | DELETE | /api/providers/%s/v1/assistants/:assistant_id/files/:file_id is ready for deleting openai assistant file", openAiName))
+			ps.log.Info(fmt.Sprintf("PORT 8002 | GET | /api/providers/%s/v1/assistants/:assistant_id/files is ready for retrieving openai assistant files", openAiName))
+
+			// threads
+			ps.log.Info(fmt.Sprintf("PORT 8002 | POST | /api/providers/%s/v1/threads is ready for creating an openai thread", openAiName))
+			ps.log.Info(fmt.Sprintf("PORT 8002 | GET | /api/providers/%s/v1/threads/:thread_id is ready for retrieving an openai thread", openAiName))
+			ps.log.Info(fmt.Sprintf("PORT 8002 | POST | /api/providers/%s/v1/threads/:thread_id is ready for modifying an openai thread", openAiName))
+			ps.log.Info(fmt.Sprintf("PORT 8002 | DELETE | /api/providers/%s/v1/threads/:thread_id is ready for deleting an openai thread", openAiName))
+
+			// messages
+			ps.log.Info(fmt.Sprintf("PORT 8002 | POST | /api/providers/%s/v1/threads/:thread_id/messages is ready for creating an openai message", openAiName))
+			ps.log.Info(fmt.Sprintf("PORT 8002 | GET | /api/providers/%s/v1/threads/:thread_id/messages/:message_id is ready for retrieving an openai message", openAiName))
+			ps.log.Info(fmt.Sprintf("PORT 8002 | POST | /api/providers/%s/v1/threads/:thread_id/messages/:message_id is ready for modifying an openai message", openAiName))
+			ps.log.Info(fmt.Sprintf("PORT 8002 | GET | /api/providers/%s/v1/threads/:thread_id/messages is ready for retrieving openai messages", openAiName))
+
+			// message files
+			ps.log.Info(fmt.Sprintf("PORT 8002 | GET | /api/providers/%s/v1/threads/:thread_id/messages/:message_id/files/:file_id is ready for retrieving an openai message file", openAiName))
+			ps.log.Info(fmt.Sprintf("PORT 8002 | GET | /api/providers/%s/v1/threads/:thread_id/messages/:message_id/files is ready for retrieving openai message files", openAiName))
+
+			// runs
+			ps.log.Info(fmt.Sprintf("PORT 8002 | POST | /api/providers/%s/v1/threads/:thread_id/runs is ready for creating an openai run", openAiName))
+			ps.log.Info(fmt.Sprintf("PORT 8002 | GET | /api/providers/%s/v1/threads/:thread_id/runs/:run_id is ready for retrieving an openai run", openAiName))
+			ps.log.Info(fmt.Sprintf("PORT 8002 | POST | /api/providers/%s/v1/threads/:thread_id/runs/:run_id is ready for modifying an openai run", openAiName))
+			ps.log.Info(fmt.Sprintf("PORT 8002 | GET | /api/providers/%s/v1/threads/:thread_id/runs is ready for retrieving openai runs", openAiName))
+			ps.log.Info(fmt.Sprintf("PORT 8002 | POST | /api/providers/%s/v1/threads/:thread_id/runs/:run_id/submit_tool_outputs is ready for submitting tool outputs to an openai run", openAiName))
+			ps.log.Info(fmt.Sprintf("PORT 8002 | POST | /api/providers/%s/v1/threads/:thread_id/runs/:run_id/cancel is ready for cancelling an openai run", openAiName))
+			ps.log.Info(fmt.Sprintf("PORT 8002 | POST | /api/providers/%s/v1/threads/runs is ready for creating an openai thread and run", openAiName))
+			ps.log.Info(fmt.Sprintf("PORT 8002 | GET | /api/providers/%s/v1/threads/:thread_id/runs/:run_id/steps/:step_id is ready for retrieving an openai run step", openAiName))
+			ps.log.Info(fmt.Sprintf("PORT 8002 | GET | /api/providers/%s/v1/threads/:thread_id/runs/:run_id/steps is ready for retrieving openai run steps", openAiName))
+
+			// images
+			ps.log.Info(fmt.Sprintf("PORT 8002 | POST | /api/providers/%s/v1/images/generations is ready for generating openai images", openAiName))
+			ps.log.Info(fmt.Sprintf("PORT 8002 | POST | /api/providers/%s/v1/images/edits is ready for editing openai images", openAiName))
+			ps.log.Info(fmt.Sprintf("PORT 8002 | POST | /api/providers/%s/v1/images/variations is ready for generating openai image variations", openAiName))
+
+			// vector store
+			ps.log.Info(fmt.Sprintf("PORT 8002 | POST | /api/providers/%s/v1/vector_stores is ready for creating an openai vector store", openAiName))
+			ps.log.Info(fmt.Sprintf("PORT 8002 | GET | /api/providers/%s/v1/vector_stores is ready for listing openai vector stores", openAiName))
+			ps.log.Info(fmt.Sprintf("PORT 8002 | GET | /api/providers/%s/v1/vector_stores/:vector_store_id is ready for getting an openai vector store", openAiName))
+			ps.log.Info(fmt.Sprintf("PORT 8002 | POST | /api/providers/%s/v1/vector_stores/:vector_store_id is ready for modifying an openai vector store", openAiName))
+			ps.log.Info(fmt.Sprintf("PORT 8002 | DELETE | /api/providers/%s/v1/vector_stores/:vector_store_id is ready for deleting an openai vector store", openAiName))
+
+			// vector store files
+			ps.log.Info(fmt.Sprintf("PORT 8002 | POST | /api/providers/%s/v1/vector_stores/:vector_store_id/files is ready for creating an openai vector store file", openAiName))
+			ps.log.Info(fmt.Sprintf("PORT 8002 | GET | /api/providers/%s/v1/vector_stores/:vector_store_id/files is ready for listing openai vector store files", openAiName))
+			ps.log.Info(fmt.Sprintf("PORT 8002 | GET | /api/providers/%s/v1/vector_stores/:vector_store_id/files/:file_id is ready for getting an openai vector store file", openAiName))
+			ps.log.Info(fmt.Sprintf("PORT 8002 | DELETE | /api/providers/%s/v1/vector_stores/:vector_store_id/files/:file_id is ready for deleting an openai vector store file", openAiName))
+
+			// vector store file batches
+			ps.log.Info(fmt.Sprintf("PORT 8002 | POST | /api/providers/%s/v1/vector_stores/:vector_store_id/file_batches is ready for creating an openai vector store file batch", openAiName))
+			ps.log.Info(fmt.Sprintf("PORT 8002 | GET | /api/providers/%s/v1/vector_stores/:vector_store_id/file_batches/:batch_id is ready for getting an openai vector store file batch", openAiName))
+			ps.log.Info(fmt.Sprintf("PORT 8002 | POST | /api/providers/%s/v1/vector_stores/:vector_store_id/file_batches/:batch_id/cancel is ready for cancelling an openai vector store file batch", openAiName))
+			ps.log.Info(fmt.Sprintf("PORT 8002 | GET | /api/providers/%s/v1/vector_stores/:vector_store_id/file_batches/:batch_id/files is ready for listing openai vector store file batch files", openAiName))
+		}
 
 		// azure
 		ps.log.Info("PORT 8002 | POST | /api/providers/azure/openai/deployments/:deployment_id/chat/completions is ready for forwarding completion requests to azure openai")
 		ps.log.Info("PORT 8002 | POST | /api/providers/azure/openai/deployments/:deployment_id/embeddings is ready for forwarding embeddings requests to azure openai")
@@ -1016,25 +1042,6 @@ func (ps *ProxyServer) Run() {
 		// custom route
 		ps.log.Info("PORT 8002 | POST | /api/routes/*route is ready for forwarding requests to a custom route")
 
-		// vector store
-		ps.log.Info("PORT 8002 | POST | /api/providers/openai/v1/vector_stores is ready for creating an openai vector store")
-		ps.log.Info("PORT 8002 | GET | /api/providers/openai/v1/vector_stores is ready for listing openai vector stores")
-		ps.log.Info("PORT 8002 | GET | /api/providers/openai/v1/vector_stores/:vector_store_id is ready for getting an openai vector store")
-		ps.log.Info("PORT 8002 | POST | /api/providers/openai/v1/vector_stores/:vector_store_id is ready for modifying an openai vector store")
-		ps.log.Info("PORT 8002 | DELETE | /api/providers/openai/v1/vector_stores/:vector_store_id is ready for deleting an openai vector store")
-
-		// vector store files
-		ps.log.Info("PORT 8002 | POST | /api/providers/openai/v1/vector_stores/:vector_store_id/files is ready for creating an openai vector store file")
-		ps.log.Info("PORT 8002 | GET | /api/providers/openai/v1/vector_stores/:vector_store_id/files is ready for listing openai vector store files")
-		ps.log.Info("PORT 8002 | GET | /api/providers/openai/v1/vector_stores/:vector_store_id/files/:file_id is ready for getting an openai vector store file")
-		ps.log.Info("PORT 8002 | DELETE | /api/providers/openai/v1/vector_stores/:vector_store_id/files/:file_id is ready for deleting an openai vector store file")
-
-		// vector store file batches
-		ps.log.Info("PORT 8002 | POST | /api/providers/openai/v1/vector_stores/:vector_store_id/file_batches is ready for creating an openai vector store file batch")
-		ps.log.Info("PORT 8002 | GET | /api/providers/openai/v1/vector_stores/:vector_store_id/file_batches/:batch_id is ready for getting an openai vector store file batch")
-		ps.log.Info("PORT 8002 | POST | /api/providers/openai/v1/vector_stores/:vector_store_id/file_batches/:batch_id/cancel is ready for cancelling an openai vector store file batch")
-		ps.log.Info("PORT 8002 | GET | /api/providers/openai/v1/vector_stores/:vector_store_id/file_batches/:batch_id/files is ready for listing openai vector store file batch files")
-
 		if err := ps.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
 			ps.log.Sugar().Fatalf("error proxy server listening: %v", err)
 			return
diff --git a/internal/server/web/proxy/route.go b/internal/server/web/proxy/route.go
index 0dd3531..861429b 100644
--- a/internal/server/web/proxy/route.go
+++ b/internal/server/web/proxy/route.go
@@ -25,7 +25,7 @@ type cache interface {
 	GetBytes(key string) ([]byte, error)
 }
 
-func getRouteHandler(prod bool, ca cache, aoe azureEstimator, e estimator, client http.Client, rec recorder) gin.HandlerFunc {
+func getRouteHandler(prod bool, ca cache, aoe azureEstimator, e estimator, client http.Client, rec recorder, openAiUrls map[string]string) gin.HandlerFunc {
 	return func(c *gin.Context) {
 		log := util.GetLogFromCtx(c)
 		trueStart := time.Now()
@@ -107,7 +107,7 @@ func getRouteHandler(prod bool, ca cache, aoe azureEstimator, e estimator, clien
 			rreq.Request = bs
 		}
 
-		runRes, err := rc.RunStepsV2(rreq, rec, log, kc)
+		runRes, err := rc.RunStepsV2(rreq, rec, log, kc, openAiUrls)
 		if err != nil {
 			telemetry.Incr("bricksllm.proxy.get_route_handeler.run_steps_error", tags, 1)
 			logError(log, "error when running steps", prod, err)
diff --git a/internal/server/web/proxy/vector_store.go b/internal/server/web/proxy/vector_store.go
index 423a821..5e91467 100644
--- a/internal/server/web/proxy/vector_store.go
+++ b/internal/server/web/proxy/vector_store.go
@@ -13,7 +13,7 @@ import (
 	goopenai "github.com/sashabaranov/go-openai"
 )
 
-func getCreateVectorStoreHandler(prod bool, client http.Client) gin.HandlerFunc {
+func getCreateVectorStoreHandler(prod bool, client http.Client, openAiUrl string) gin.HandlerFunc {
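+	// openAiUrl is the base URL of the provider this handler proxies to; it
+	// replaces the previously hardcoded https://api.openai.com host.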
http request") @@ -94,7 +94,7 @@ func getCreateVectorStoreHandler(prod bool, client http.Client) gin.HandlerFunc } } -func getListVectorStoresHandler(prod bool, client http.Client) gin.HandlerFunc { +func getListVectorStoresHandler(prod bool, client http.Client, openAiUrl string) gin.HandlerFunc { return func(c *gin.Context) { log := util.GetLogFromCtx(c) telemetry.Incr("bricksllm.proxy.get_list_vector_stores_handler.requests", nil, 1) @@ -107,7 +107,7 @@ func getListVectorStoresHandler(prod bool, client http.Client) gin.HandlerFunc { ctx, cancel := context.WithTimeout(context.Background(), c.GetDuration("requestTimeout")) defer cancel() - req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.openai.com/v1/vector_stores", c.Request.Body) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, openAiUrl+"/v1/vector_stores", c.Request.Body) if err != nil { logError(log, "error when creating openai http request", prod, err) JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to create azure openai http request") @@ -175,7 +175,7 @@ func getListVectorStoresHandler(prod bool, client http.Client) gin.HandlerFunc { } } -func getGetVectorStoreHandler(prod bool, client http.Client) gin.HandlerFunc { +func getGetVectorStoreHandler(prod bool, client http.Client, openAiUrl string) gin.HandlerFunc { return func(c *gin.Context) { log := util.GetLogFromCtx(c) telemetry.Incr("bricksllm.proxy.get_get_vector_store_handler.requests", nil, 1) @@ -188,7 +188,7 @@ func getGetVectorStoreHandler(prod bool, client http.Client) gin.HandlerFunc { ctx, cancel := context.WithTimeout(context.Background(), c.GetDuration("requestTimeout")) defer cancel() - req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.openai.com/v1/vector_stores/"+c.Param("vector_store_id"), c.Request.Body) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, openAiUrl+"/v1/vector_stores/"+c.Param("vector_store_id"), c.Request.Body) if err != nil { logError(log, "error when creating openai http request", prod, err) JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to create azure openai http request") @@ -256,7 +256,7 @@ func getGetVectorStoreHandler(prod bool, client http.Client) gin.HandlerFunc { } } -func getModifyVectorStoreHandler(prod bool, client http.Client) gin.HandlerFunc { +func getModifyVectorStoreHandler(prod bool, client http.Client, openAiUrl string) gin.HandlerFunc { return func(c *gin.Context) { log := util.GetLogFromCtx(c) telemetry.Incr("bricksllm.proxy.get_modify_vector_store_handler.requests", nil, 1) @@ -269,7 +269,7 @@ func getModifyVectorStoreHandler(prod bool, client http.Client) gin.HandlerFunc ctx, cancel := context.WithTimeout(context.Background(), c.GetDuration("requestTimeout")) defer cancel() - req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://api.openai.com/v1/vector_stores/"+c.Param("vector_store_id"), c.Request.Body) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, openAiUrl+"/v1/vector_stores/"+c.Param("vector_store_id"), c.Request.Body) if err != nil { logError(log, "error when creating openai http request", prod, err) JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to create azure openai http request") @@ -337,7 +337,7 @@ func getModifyVectorStoreHandler(prod bool, client http.Client) gin.HandlerFunc } } -func getDeleteVectorStoreHandler(prod bool, client http.Client) gin.HandlerFunc { +func getDeleteVectorStoreHandler(prod bool, client http.Client, openAiUrl string) gin.HandlerFunc { return 
 	return func(c *gin.Context) {
 		log := util.GetLogFromCtx(c)
 		telemetry.Incr("bricksllm.proxy.get_delete_vector_store_handler.requests", nil, 1)
@@ -350,7 +350,7 @@ func getDeleteVectorStoreHandler(prod bool, client http.Client) gin.HandlerFunc
 		ctx, cancel := context.WithTimeout(context.Background(), c.GetDuration("requestTimeout"))
 		defer cancel()
-		req, err := http.NewRequestWithContext(ctx, http.MethodDelete, "https://api.openai.com/v1/vector_stores/"+c.Param("vector_store_id"), c.Request.Body)
+		req, err := http.NewRequestWithContext(ctx, http.MethodDelete, openAiUrl+"/v1/vector_stores/"+c.Param("vector_store_id"), c.Request.Body)
 		if err != nil {
 			logError(log, "error when creating openai http request", prod, err)
 			JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to create azure openai http request")
diff --git a/internal/server/web/proxy/vector_store_file.go b/internal/server/web/proxy/vector_store_file.go
index 91282c6..07c7490 100644
--- a/internal/server/web/proxy/vector_store_file.go
+++ b/internal/server/web/proxy/vector_store_file.go
@@ -13,7 +13,7 @@ import (
 	goopenai "github.com/sashabaranov/go-openai"
 )
 
-func getCreateVectorStoreFileHandler(prod bool, client http.Client) gin.HandlerFunc {
+func getCreateVectorStoreFileHandler(prod bool, client http.Client, openAiUrl string) gin.HandlerFunc {
 	return func(c *gin.Context) {
 		log := util.GetLogFromCtx(c)
 		telemetry.Incr("bricksllm.proxy.get_create_vector_store_file_handler.requests", nil, 1)
@@ -26,7 +26,7 @@ func getCreateVectorStoreFileHandler(prod bool, client http.Client) gin.HandlerF
 		ctx, cancel := context.WithTimeout(context.Background(), c.GetDuration("requestTimeout"))
 		defer cancel()
-		req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://api.openai.com/v1/vector_stores/"+c.Param("vector_store_id")+"/files", c.Request.Body)
+		req, err := http.NewRequestWithContext(ctx, http.MethodPost, openAiUrl+"/v1/vector_stores/"+c.Param("vector_store_id")+"/files", c.Request.Body)
 		if err != nil {
 			logError(log, "error when creating openai http request", prod, err)
 			JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to create azure openai http request")
@@ -94,7 +94,7 @@ func getCreateVectorStoreFileHandler(prod bool, client http.Client) gin.HandlerF
 	}
 }
 
-func getListVectorStoreFilesHandler(prod bool, client http.Client) gin.HandlerFunc {
+func getListVectorStoreFilesHandler(prod bool, client http.Client, openAiUrl string) gin.HandlerFunc {
 	return func(c *gin.Context) {
 		log := util.GetLogFromCtx(c)
 		telemetry.Incr("bricksllm.proxy.get_list_vector_store_files_handler.requests", nil, 1)
@@ -107,7 +107,7 @@ func getListVectorStoreFilesHandler(prod bool, client http.Client) gin.HandlerFu
 		ctx, cancel := context.WithTimeout(context.Background(), c.GetDuration("requestTimeout"))
 		defer cancel()
-		req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.openai.com/v1/vector_stores/"+c.Param("vector_store_id")+"/files", c.Request.Body)
+		req, err := http.NewRequestWithContext(ctx, http.MethodGet, openAiUrl+"/v1/vector_stores/"+c.Param("vector_store_id")+"/files", c.Request.Body)
 		if err != nil {
 			logError(log, "error when creating openai http request", prod, err)
 			JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to create azure openai http request")
@@ -175,7 +175,7 @@ func getListVectorStoreFilesHandler(prod bool, client http.Client) gin.HandlerFu
 	}
 }
 
-func getGetVectorStoreFileHandler(prod bool, client http.Client) gin.HandlerFunc {
+func getGetVectorStoreFileHandler(prod bool, client http.Client, openAiUrl string) gin.HandlerFunc {
 	return func(c *gin.Context) {
 		log := util.GetLogFromCtx(c)
 		telemetry.Incr("bricksllm.proxy.get_get_vector_store_file_handler.requests", nil, 1)
@@ -188,7 +188,7 @@ func getGetVectorStoreFileHandler(prod bool, client http.Client) gin.HandlerFunc
 		ctx, cancel := context.WithTimeout(context.Background(), c.GetDuration("requestTimeout"))
 		defer cancel()
-		req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.openai.com/v1/vector_stores/"+c.Param("vector_store_id")+"/files/"+c.Param("file_id"), c.Request.Body)
+		req, err := http.NewRequestWithContext(ctx, http.MethodGet, openAiUrl+"/v1/vector_stores/"+c.Param("vector_store_id")+"/files/"+c.Param("file_id"), c.Request.Body)
 		if err != nil {
 			logError(log, "error when creating openai http request", prod, err)
 			JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to create azure openai http request")
@@ -256,7 +256,7 @@ func getGetVectorStoreFileHandler(prod bool, client http.Client) gin.HandlerFunc
 	}
 }
 
-func getDeleteVectorStoreFileHandler(prod bool, client http.Client) gin.HandlerFunc {
+func getDeleteVectorStoreFileHandler(prod bool, client http.Client, openAiUrl string) gin.HandlerFunc {
 	return func(c *gin.Context) {
 		log := util.GetLogFromCtx(c)
 		telemetry.Incr("bricksllm.proxy.get_delete_vector_store_file_handler.requests", nil, 1)
@@ -269,7 +269,7 @@ func getDeleteVectorStoreFileHandler(prod bool, client http.Client) gin.HandlerF
 		ctx, cancel := context.WithTimeout(context.Background(), c.GetDuration("requestTimeout"))
 		defer cancel()
-		req, err := http.NewRequestWithContext(ctx, http.MethodDelete, "https://api.openai.com/v1/vector_stores/"+c.Param("vector_store_id")+"/files/"+c.Param("file_id"), c.Request.Body)
+		req, err := http.NewRequestWithContext(ctx, http.MethodDelete, openAiUrl+"/v1/vector_stores/"+c.Param("vector_store_id")+"/files/"+c.Param("file_id"), c.Request.Body)
 		if err != nil {
 			logError(log, "error when creating openai http request", prod, err)
 			JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to create azure openai http request")
diff --git a/internal/server/web/proxy/vector_store_file_batch.go b/internal/server/web/proxy/vector_store_file_batch.go
index 16e80eb..f33b103 100644
--- a/internal/server/web/proxy/vector_store_file_batch.go
+++ b/internal/server/web/proxy/vector_store_file_batch.go
@@ -13,7 +13,7 @@ import (
 	goopenai "github.com/sashabaranov/go-openai"
 )
 
-func getCreateVectorStoreFileBatchHandler(prod bool, client http.Client) gin.HandlerFunc {
+func getCreateVectorStoreFileBatchHandler(prod bool, client http.Client, openAiUrl string) gin.HandlerFunc {
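+	// As in vector_store.go, openAiUrl supplies the upstream host that was
+	// previously hardcoded to https://api.openai.com.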
 	return func(c *gin.Context) {
 		log := util.GetLogFromCtx(c)
 		telemetry.Incr("bricksllm.proxy.get_create_vector_store_file_batch_handler.requests", nil, 1)
@@ -26,7 +26,7 @@ func getCreateVectorStoreFileBatchHandler(prod bool, client http.Client) gin.Han
 		ctx, cancel := context.WithTimeout(context.Background(), c.GetDuration("requestTimeout"))
 		defer cancel()
-		req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://api.openai.com/v1/vector_stores/"+c.Param("vector_store_id")+"/file_batches", c.Request.Body)
+		req, err := http.NewRequestWithContext(ctx, http.MethodPost, openAiUrl+"/v1/vector_stores/"+c.Param("vector_store_id")+"/file_batches", c.Request.Body)
 		if err != nil {
 			logError(log, "error when creating openai http request", prod, err)
 			JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to create azure openai http request")
@@ -94,7 +94,7 @@ func getCreateVectorStoreFileBatchHandler(prod bool, client http.Client) gin.Han
 	}
 }
 
-func getGetVectorStoreFileBatchHandler(prod bool, client http.Client) gin.HandlerFunc {
+func getGetVectorStoreFileBatchHandler(prod bool, client http.Client, openAiUrl string) gin.HandlerFunc {
 	return func(c *gin.Context) {
 		log := util.GetLogFromCtx(c)
 		telemetry.Incr("bricksllm.proxy.get_get_vector_store_file_batch_handler.requests", nil, 1)
@@ -107,7 +107,7 @@ func getGetVectorStoreFileBatchHandler(prod bool, client http.Client) gin.Handle
 		ctx, cancel := context.WithTimeout(context.Background(), c.GetDuration("requestTimeout"))
 		defer cancel()
-		req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.openai.com/v1/vector_stores/"+c.Param("vector_store_id")+"/file_batches/"+c.Param("batch_id"), c.Request.Body)
+		req, err := http.NewRequestWithContext(ctx, http.MethodGet, openAiUrl+"/v1/vector_stores/"+c.Param("vector_store_id")+"/file_batches/"+c.Param("batch_id"), c.Request.Body)
 		if err != nil {
 			logError(log, "error when creating openai http request", prod, err)
 			JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to create openai http request")
@@ -175,7 +175,7 @@ func getGetVectorStoreFileBatchHandler(prod bool, client http.Client) gin.Handle
 	}
 }
 
-func getCancelVectorStoreFileBatchHandler(prod bool, client http.Client) gin.HandlerFunc {
+func getCancelVectorStoreFileBatchHandler(prod bool, client http.Client, openAiUrl string) gin.HandlerFunc {
 	return func(c *gin.Context) {
 		log := util.GetLogFromCtx(c)
 		telemetry.Incr("bricksllm.proxy.get_cancel_vector_store_file_batch_handler.requests", nil, 1)
@@ -188,7 +188,7 @@ func getCancelVectorStoreFileBatchHandler(prod bool, client http.Client) gin.Han
 		ctx, cancel := context.WithTimeout(context.Background(), c.GetDuration("requestTimeout"))
 		defer cancel()
-		req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://api.openai.com/v1/vector_stores/"+c.Param("vector_store_id")+"/file_batches/"+c.Param("batch_id")+"/cancel", c.Request.Body)
+		req, err := http.NewRequestWithContext(ctx, http.MethodPost, openAiUrl+"/v1/vector_stores/"+c.Param("vector_store_id")+"/file_batches/"+c.Param("batch_id")+"/cancel", c.Request.Body)
 		if err != nil {
 			logError(log, "error when creating openai http request", prod, err)
 			JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to create azure openai http request")
@@ -256,7 +256,7 @@ func getCancelVectorStoreFileBatchHandler(prod bool, client http.Client) gin.Han
 	}
 }
 
-func getListVectorStoreFileBatchFilesHandler(prod bool, client http.Client) gin.HandlerFunc {
+func getListVectorStoreFileBatchFilesHandler(prod bool, client http.Client, openAiUrl string) gin.HandlerFunc {
 	return func(c *gin.Context) {
 		log := util.GetLogFromCtx(c)
 		telemetry.Incr("bricksllm.proxy.get_list_vector_store_file_batch_files_handler.requests", nil, 1)
@@ -269,7 +269,7 @@ func getListVectorStoreFileBatchFilesHandler(prod bool, client http.Client) gin.
 		ctx, cancel := context.WithTimeout(context.Background(), c.GetDuration("requestTimeout"))
 		defer cancel()
-		req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.openai.com/v1/vector_stores/"+c.Param("vector_store_id")+"/file_batches/"+c.Param("batch_id")+"/files", c.Request.Body)
+		req, err := http.NewRequestWithContext(ctx, http.MethodGet, openAiUrl+"/v1/vector_stores/"+c.Param("vector_store_id")+"/file_batches/"+c.Param("batch_id")+"/files", c.Request.Body)
 		if err != nil {
 			logError(log, "error when creating openai http request", prod, err)
 			JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to create azure openai http request")
diff --git a/internal/testing/openai_chat_completion_integration_test.go b/internal/testing/openai_chat_completion_integration_test.go
index c47444e..943e248 100644
--- a/internal/testing/openai_chat_completion_integration_test.go
+++ b/internal/testing/openai_chat_completion_integration_test.go
@@ -13,7 +13,7 @@ import (
 	"github.com/bricks-cloud/bricksllm/internal/key"
 	"github.com/bricks-cloud/bricksllm/internal/provider"
 	"github.com/bricks-cloud/bricksllm/internal/provider/anthropic"
-	"github.com/caarlos0/env"
+	"github.com/caarlos0/env/v11"
 	goopenai "github.com/sashabaranov/go-openai"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"