diff --git a/.github/workflows/prow-github.yml b/.github/workflows/prow-github.yml new file mode 100644 index 00000000..0c5f11fd --- /dev/null +++ b/.github/workflows/prow-github.yml @@ -0,0 +1,37 @@ +# Run specified actions or jobs for issue and PR comments + +name: "Prow github actions" +on: + issue_comment: + types: [created] + +# Grant additional permissions to the GITHUB_TOKEN +permissions: + # Allow labeling issues + issues: write + # Allow adding a review to a pull request + pull-requests: write + +jobs: + prow-execute: + runs-on: ubuntu-latest + steps: + - uses: jpmcb/prow-github-actions@v2.0.0 + with: + github-token: "${{ secrets.GITHUB_TOKEN }}" + prow-commands: "/assign + /unassign + /approve + /retitle + /area + /kind + /priority + /remove + /lgtm + /close + /reopen + /lock + /milestone + /hold + /cc + /uncc" diff --git a/.github/workflows/prow-pr-automerge.yml b/.github/workflows/prow-pr-automerge.yml new file mode 100644 index 00000000..c9bb0972 --- /dev/null +++ b/.github/workflows/prow-pr-automerge.yml @@ -0,0 +1,18 @@ +# This Github workflow will check every 5m for PRs with the lgtm label and will attempt to automatically merge them. +# If the hold label is present, it will block automatic merging. + +name: "Prow merge on lgtm label" +on: + schedule: + - cron: "*/5 * * * *" # every 5 minutes + +jobs: + auto-merge: + runs-on: ubuntu-latest + steps: + - uses: jpmcb/prow-github-actions@v2.0.0 + with: + jobs: 'lgtm' + github-token: "${{ secrets.GITHUB_TOKEN }}" + merge-method: 'squash' + diff --git a/.github/workflows/prow-pr-remove-lgtm.yml b/.github/workflows/prow-pr-remove-lgtm.yml new file mode 100644 index 00000000..caf208f3 --- /dev/null +++ b/.github/workflows/prow-pr-remove-lgtm.yml @@ -0,0 +1,11 @@ +name: Run Jobs on PR +on: pull_request + +jobs: + execute: + runs-on: ubuntu-latest + steps: + - uses: jpmcb/prow-github-actions@v2.0.0 + with: + jobs: lgtm + github-token: '${{ secrets.GITHUB_TOKEN }}' diff --git a/.github/workflows/stale.yaml b/.github/workflows/stale.yaml new file mode 100644 index 00000000..622a4ffd --- /dev/null +++ b/.github/workflows/stale.yaml @@ -0,0 +1,38 @@ +name: 'Mark stale issues' + +on: + schedule: + - cron: '0 1 * * *' + +jobs: + stale-issues: + runs-on: ubuntu-latest + permissions: + issues: write + steps: + - name: 'Mark stale issues' + uses: actions/stale@v9 + with: + days-before-issue-stale: 90 + days-before-pr-stale: -1 + days-before-close: -1 + stale-issue-label: 'lifecycle/stale' + exempt-issue-labels: 'lifecycle/rotten' + + - name: 'Mark rotten issues' + uses: actions/stale@v9 + with: + days-before-issue-stale: 30 + days-before-pr-stale: -1 + days-before-close: -1 + stale-issue-label: 'lifecycle/rotten' + only-labels: 'lifecycle/stale' + labels-to-remove-when-stale: 'lifecycle/stale' + + - name: 'Close rotten issues' + uses: actions/stale@v9 + with: + days-before-stale: -1 + days-before-issue-close: 30 + days-before-pr-close: -1 + stale-issue-label: 'lifecycle/rotten' diff --git a/.github/workflows/unstale.yaml b/.github/workflows/unstale.yaml new file mode 100644 index 00000000..319ecf14 --- /dev/null +++ b/.github/workflows/unstale.yaml @@ -0,0 +1,27 @@ +name: 'Unstale Issue' + +on: + issues: + types: [ reopened ] + issue_comment: + types: [ created ] + +jobs: + remove-stale: + runs-on: ubuntu-latest + permissions: + issues: write + if: >- + github.event.issue.state == 'open' && + (contains(github.event.issue.labels.*.name, 'lifecycle/stale') || + contains(github.event.issue.labels.*.name, 'lifecycle/rotten')) + steps: + - name: 
'Checkout repository' + uses: actions/checkout@v5 + + - name: 'Remove stale labels' + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + echo "Removing 'stale' label from issue #${{ github.event.issue.number }}" + gh issue edit ${{ github.event.issue.number }} --remove-label "lifecycle/stale,lifecycle/rotten" diff --git a/.gitignore b/.gitignore index 401369b3..ac6e6006 100644 --- a/.gitignore +++ b/.gitignore @@ -2,5 +2,6 @@ bin lib vendor .vscode +.devcontainer # MacOSX .DS_Store \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index 36a8836e..58a6ccbe 100644 --- a/Dockerfile +++ b/Dockerfile @@ -53,7 +53,6 @@ RUN microdnf install -y dnf && \ COPY --from=builder /workspace/bin/llm-d-inference-sim /app/llm-d-inference-sim -# USER 65532:65532 -USER root +USER 65532:65532 ENTRYPOINT ["/app/llm-d-inference-sim"] diff --git a/Makefile b/Makefile index 819091f4..c4fddb1d 100644 --- a/Makefile +++ b/Makefile @@ -25,6 +25,16 @@ IMAGE_TAG_BASE ?= $(IMAGE_REGISTRY)/$(PROJECT_NAME) SIM_TAG ?= dev IMG = $(IMAGE_TAG_BASE):$(SIM_TAG) +ifeq ($(TARGETOS),darwin) +ifeq ($(TARGETARCH),amd64) +TOKENIZER_ARCH = x86_64 +else +TOKENIZER_ARCH = $(TARGETARCH) +endif +else +TOKENIZER_ARCH = $(TARGETARCH) +endif + CONTAINER_TOOL := $(shell { command -v docker >/dev/null 2>&1 && echo docker; } || { command -v podman >/dev/null 2>&1 && echo podman; } || echo "") BUILDER := $(shell command -v buildah >/dev/null 2>&1 && echo buildah || echo $(CONTAINER_TOOL)) PLATFORMS ?= linux/amd64 # linux/arm64 # linux/s390x,linux/ppc64le @@ -36,7 +46,7 @@ SRC = $(shell find . -type f -name '*.go') help: ## Print help @awk 'BEGIN {FS = ":.*##"; printf "\nUsage:\n make \033[36m\033[0m\n"} /^[a-zA-Z_0-9-]+:.*?##/ { printf " \033[36m%-15s\033[0m %s\n", $$1, $$2 } /^##@/ { printf "\n\033[1m%s\033[0m\n", substr($$0, 5) } ' $(MAKEFILE_LIST) -LDFLAGS ?= -extldflags '-L$(shell pwd)/lib' +GO_LDFLAGS := -extldflags '-L$(shell pwd)/lib $(LDFLAGS)' CGO_ENABLED=1 TOKENIZER_LIB = lib/libtokenizers.a # Extract TOKENIZER_VERSION from Dockerfile @@ -48,7 +58,7 @@ $(TOKENIZER_LIB): ## Download the HuggingFace tokenizer bindings. @echo "Downloading HuggingFace tokenizer bindings for version $(TOKENIZER_VERSION)..." 
mkdir -p lib - curl -L https://github.com/daulet/tokenizers/releases/download/$(TOKENIZER_VERSION)/libtokenizers.$(TARGETOS)-$(TARGETARCH).tar.gz | tar -xz -C lib + curl -L https://github.com/daulet/tokenizers/releases/download/$(TOKENIZER_VERSION)/libtokenizers.$(TARGETOS)-$(TOKENIZER_ARCH).tar.gz | tar -xz -C lib ranlib lib/*.a ##@ Development @@ -67,7 +77,11 @@ format: ## Format Go source files .PHONY: test test: check-ginkgo download-tokenizer download-zmq ## Run tests @printf "\033[33;1m==== Running tests ====\033[0m\n" - CGO_ENABLED=1 ginkgo -ldflags="$(LDFLAGS)" -v -r +ifdef GINKGO_FOCUS + CGO_ENABLED=1 ginkgo -ldflags="$(GO_LDFLAGS)" -v -r --focus="$(GINKGO_FOCUS)" +else + CGO_ENABLED=1 ginkgo -ldflags="$(GO_LDFLAGS)" -v -r +endif .PHONY: post-deploy-test post-deploy-test: ## Run post deployment tests @@ -84,7 +98,7 @@ lint: check-golangci-lint ## Run lint .PHONY: build build: check-go download-tokenizer download-zmq @printf "\033[33;1m==== Building ====\033[0m\n" - go build -ldflags="$(LDFLAGS)" -o bin/$(PROJECT_NAME) cmd/$(PROJECT_NAME)/main.go + go build -ldflags="$(GO_LDFLAGS)" -o bin/$(PROJECT_NAME) cmd/$(PROJECT_NAME)/main.go ##@ Container Build/Push @@ -92,8 +106,8 @@ build: check-go download-tokenizer download-zmq image-build: check-container-tool ## Build Docker image ## Build Docker image using $(CONTAINER_TOOL) @printf "\033[33;1m==== Building Docker image $(IMG) ====\033[0m\n" $(CONTAINER_TOOL) build \ - --platform $(TARGETOS)/$(TARGETARCH) \ - --build-arg TARGETOS=$(TARGETOS)\ + --platform linux/$(TARGETARCH) \ + --build-arg TARGETOS=linux \ --build-arg TARGETARCH=$(TARGETARCH)\ -t $(IMG) . @@ -160,7 +174,7 @@ check-ginkgo: .PHONY: check-golangci-lint check-golangci-lint: @command -v golangci-lint >/dev/null 2>&1 || { \ - echo "❌ golangci-lint is not installed. Install from https://golangci-lint.run/usage/install/"; exit 1; } + echo "❌ golangci-lint is not installed. Install from https://golangci-lint.run/docs/welcome/install/"; exit 1; } .PHONY: check-container-tool check-container-tool: diff --git a/OWNERS b/OWNERS new file mode 100644 index 00000000..8b464bc5 --- /dev/null +++ b/OWNERS @@ -0,0 +1,15 @@ +approvers: +- mayabar +- irar2 +- shmuelk +- elevran +- kfirtoledo +- nilig + +reviewers: +- mayabar +- irar2 +- shmuelk +- elevran +- kfirtoledo +- nilig diff --git a/README.md b/README.md index c40e7e28..8b1bd80c 100644 --- a/README.md +++ b/README.md @@ -116,15 +116,25 @@ For more details see the The following environment variables can be used to change the image tag: `REGISTRY`, `SIM_TAG`, `IMAGE_TAG_BASE` or `IMG`. +Note: On macOS, use `make image-build TARGETOS=linux` to pull the correct base image. + ### Running To run the vLLM Simulator image under Docker, run: ```bash @@ -186,6 +200,13 @@ To run the vLLM simulator in a Kubernetes cluster, run: kubectl apply -f manifests/deployment.yaml ``` +When testing locally with kind, build the Docker image with `make image-build`, then load it into the cluster: +```shell +kind load --name kind docker-image ghcr.io/llm-d/llm-d-inference-sim:dev +``` + +Update the `manifests/deployment.yaml` file to use the `dev` tag. 
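For reference, after that update the image line in `manifests/deployment.yaml` points at the `dev` tag. The excerpt below is illustrative only and is not part of this change set; the full manifest shown later in this diff also defines ports and environment variables:

```yaml
# Illustrative excerpt of manifests/deployment.yaml using the locally built dev tag
containers:
  - name: vllm-sim
    image: ghcr.io/llm-d/llm-d-inference-sim:dev
    imagePullPolicy: IfNotPresent
```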
+ To verify the deployment is available, run: ```bash kubectl get deployment vllm-llama3-8b-instruct diff --git a/go.mod b/go.mod index 65cbb3fb..70fc81bf 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( github.com/buaazp/fasthttprouter v0.1.1 github.com/go-logr/logr v1.4.2 github.com/google/uuid v1.6.0 - github.com/llm-d/llm-d-kv-cache-manager v0.2.2-0.20250810103202-0adf0940f60a + github.com/llm-d/llm-d-kv-cache-manager v0.2.1 github.com/onsi/ginkgo/v2 v2.23.4 github.com/onsi/gomega v1.37.0 github.com/openai/openai-go v0.1.0-beta.10 diff --git a/go.sum b/go.sum index 93f6363c..56ae979d 100644 --- a/go.sum +++ b/go.sum @@ -11,8 +11,6 @@ github.com/buaazp/fasthttprouter v0.1.1/go.mod h1:h/Ap5oRVLeItGKTVBb+heQPks+HdIU github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/daulet/tokenizers v1.20.2 h1:tlq/vIOiBTKDPets3596aFvmJYLn3XI6LFKq4q9LKhQ= -github.com/daulet/tokenizers v1.20.2/go.mod h1:tGnMdZthXdcWY6DGD07IygpwJqiPvG85FQUnhs/wSCs= github.com/daulet/tokenizers v1.22.1 h1:3wzAFIxfgRuqGKka8xdkeTbctDmmqOOs12GofqdorpM= github.com/daulet/tokenizers v1.22.1/go.mod h1:tGnMdZthXdcWY6DGD07IygpwJqiPvG85FQUnhs/wSCs= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -68,10 +66,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= -github.com/llm-d/llm-d-kv-cache-manager v0.2.0 h1:7MXFPjy3P8nZ7HbB1LWhhVLHvNTLbZglkD/ZcT7UU1k= -github.com/llm-d/llm-d-kv-cache-manager v0.2.0/go.mod h1:ZTqwsnIVC6R5YuTUrYofPIUnCeZ9RvXn1UQAdxLYl1Y= -github.com/llm-d/llm-d-kv-cache-manager v0.2.2-0.20250810103202-0adf0940f60a h1:PXR37HLgYYfolzWQA2uQOEiJlj3IV9YSvgaEFqCRSa8= -github.com/llm-d/llm-d-kv-cache-manager v0.2.2-0.20250810103202-0adf0940f60a/go.mod h1:g2UlYKNJ4S860SAQ/QoRnytAFfnp8f1luW4IuZSMwCE= +github.com/llm-d/llm-d-kv-cache-manager v0.2.1 h1:PKIjJPUF9ILLFBNvZRa0QQ/liTQjBKwWChzcenEdM08= +github.com/llm-d/llm-d-kv-cache-manager v0.2.1/go.mod h1:s1xaE4ImkihWaLg2IQh4VN6L1PgN5RD1u1VarPey6dw= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= diff --git a/manifests/config_with_fake.yaml b/manifests/config_with_fake.yaml new file mode 100644 index 00000000..81d14c54 --- /dev/null +++ b/manifests/config_with_fake.yaml @@ -0,0 +1,16 @@ +model: "Qwen/Qwen2-0.5B" +max-loras: 2 +max-cpu-loras: 5 +max-num-seqs: 5 +mode: "random" +time-to-first-token: 2000 +inter-token-latency: 1000 +kv-cache-transfer-latency: 100 +seed: 100100100 +fake-metrics: + running-requests: 16 + waiting-requests: 3 + kv-cache-usage: 0.3 + loras: + - '{"running":"lora1,lora2","waiting":"lora3","timestamp":1257894567}' + - '{"running":"lora1,lora3","waiting":"","timestamp":1257894569}' diff --git a/manifests/deployment.yaml b/manifests/deployment.yaml index 75c001aa..aa23f3d5 100644 --- a/manifests/deployment.yaml +++ b/manifests/deployment.yaml @@ -25,6 +25,17 @@ spec: 
image: ghcr.io/llm-d/llm-d-inference-sim:latest imagePullPolicy: IfNotPresent name: vllm-sim + env: + - name: POD_NAME + valueFrom: + fieldRef: + apiVersion: v1 + fieldPath: metadata.name + - name: POD_NAMESPACE + valueFrom: + fieldRef: + apiVersion: v1 + fieldPath: metadata.namespace ports: - containerPort: 8000 name: http diff --git a/manifests/invalid-config.yaml b/manifests/invalid-config.yaml new file mode 100644 index 00000000..515bf4a5 --- /dev/null +++ b/manifests/invalid-config.yaml @@ -0,0 +1,9 @@ +port: 8001 +model: "Qwen/Qwen2-0.5B" +max-num-seqs: 5 +mode: "random" +time-to-first-token: 2000 +inter-token-latency: 1000 +kv-cache-transfer-latency: 100 +seed: 100100100 +zmq-max-connect-attempts: -111 diff --git a/pkg/common/config.go b/pkg/common/config.go index 181deb30..1e8add97 100644 --- a/pkg/common/config.go +++ b/pkg/common/config.go @@ -34,99 +34,139 @@ const ( vLLMDefaultPort = 8000 ModeRandom = "random" ModeEcho = "echo" + // Failure type constants + FailureTypeRateLimit = "rate_limit" + FailureTypeInvalidAPIKey = "invalid_api_key" + FailureTypeContextLength = "context_length" + FailureTypeServerError = "server_error" + FailureTypeInvalidRequest = "invalid_request" + FailureTypeModelNotFound = "model_not_found" + dummy = "dummy" ) type Configuration struct { // Port defines on which port the simulator runs - Port int `yaml:"port"` + Port int `yaml:"port" json:"port"` // Model defines the current base model name - Model string `yaml:"model"` + Model string `yaml:"model" json:"model"` // ServedModelNames is one or many model names exposed by the API - ServedModelNames []string `yaml:"served-model-name"` + ServedModelNames []string `yaml:"served-model-name" json:"served-model-name"` // MaxLoras defines maximum number of loaded LoRAs - MaxLoras int `yaml:"max-loras"` + MaxLoras int `yaml:"max-loras" json:"max-loras"` // MaxCPULoras defines maximum number of LoRAs to store in CPU memory - MaxCPULoras int `yaml:"max-cpu-loras"` + MaxCPULoras int `yaml:"max-cpu-loras" json:"max-cpu-loras"` // MaxNumSeqs is maximum number of sequences per iteration (the maximum // number of inference requests that could be processed at the same time) - MaxNumSeqs int `yaml:"max-num-seqs"` + MaxNumSeqs int `yaml:"max-num-seqs" json:"max-num-seqs"` // MaxModelLen is the model's context window, the maximum number of tokens // in a single request including input and output. Default value is 1024. 
- MaxModelLen int `yaml:"max-model-len"` + MaxModelLen int `yaml:"max-model-len" json:"max-model-len"` // LoraModulesString is a list of LoRA adapters as strings - LoraModulesString []string `yaml:"lora-modules"` + LoraModulesString []string `yaml:"lora-modules" json:"lora-modules"` // LoraModules is a list of LoRA adapters LoraModules []LoraModule // TimeToFirstToken time before the first token will be returned, in milliseconds - TimeToFirstToken int `yaml:"time-to-first-token"` + TimeToFirstToken int `yaml:"time-to-first-token" json:"time-to-first-token"` // TimeToFirstTokenStdDev standard deviation for time before the first token will be returned, // in milliseconds, optional, default is 0, can't be more than 30% of TimeToFirstToken, will not // cause the actual time to first token to differ by more than 70% from TimeToFirstToken - TimeToFirstTokenStdDev int `yaml:"time-to-first-token-std-dev"` + TimeToFirstTokenStdDev int `yaml:"time-to-first-token-std-dev" json:"time-to-first-token-std-dev"` // InterTokenLatency time between generated tokens, in milliseconds - InterTokenLatency int `yaml:"inter-token-latency"` + InterTokenLatency int `yaml:"inter-token-latency" json:"inter-token-latency"` // InterTokenLatencyStdDev standard deviation for time between generated tokens, in milliseconds, // optional, default is 0, can't be more than 30% of InterTokenLatency, will not cause the actual // inter token latency to differ by more than 70% from InterTokenLatency - InterTokenLatencyStdDev int `yaml:"inter-token-latency-std-dev"` + InterTokenLatencyStdDev int `yaml:"inter-token-latency-std-dev" json:"inter-token-latency-std-dev"` // KVCacheTransferLatency time to "transfer" kv-cache from another vLLM instance in case P/D is activated, // in milliseconds - KVCacheTransferLatency int `yaml:"kv-cache-transfer-latency"` + KVCacheTransferLatency int `yaml:"kv-cache-transfer-latency" json:"kv-cache-transfer-latency"` // KVCacheTransferLatencyStdDev standard deviation for time to "transfer" kv-cache from another // vLLM instance in case P/D is activated, in milliseconds, optional, default is 0, can't be more // than 30% of KVCacheTransferLatency, will not cause the actual latency to differ by more than 70% from // KVCacheTransferLatency - KVCacheTransferLatencyStdDev int `yaml:"kv-cache-transfer-latency-std-dev"` + KVCacheTransferLatencyStdDev int `yaml:"kv-cache-transfer-latency-std-dev" json:"kv-cache-transfer-latency-std-dev"` // Mode defines the simulator response generation mode, valid values: echo, random - Mode string `yaml:"mode"` + Mode string `yaml:"mode" json:"mode"` // Seed defines random seed for operations - Seed int64 `yaml:"seed"` + Seed int64 `yaml:"seed" json:"seed"` // MaxToolCallIntegerParam defines the maximum possible value of integer parameters in a tool call, // optional, defaults to 100 - MaxToolCallIntegerParam int `yaml:"max-tool-call-integer-param"` + MaxToolCallIntegerParam int `yaml:"max-tool-call-integer-param" json:"max-tool-call-integer-param"` // MinToolCallIntegerParam defines the minimum possible value of integer parameters in a tool call, // optional, defaults to 0 - MinToolCallIntegerParam int `yaml:"min-tool-call-integer-param"` + MinToolCallIntegerParam int `yaml:"min-tool-call-integer-param" json:"min-tool-call-integer-param"` // MaxToolCallNumberParam defines the maximum possible value of number (float) parameters in a tool call, // optional, defaults to 100 - MaxToolCallNumberParam float64 `yaml:"max-tool-call-number-param"` + MaxToolCallNumberParam float64 
`yaml:"max-tool-call-number-param" json:"max-tool-call-number-param"` // MinToolCallNumberParam defines the minimum possible value of number (float) parameters in a tool call, // optional, defaults to 0 - MinToolCallNumberParam float64 `yaml:"min-tool-call-number-param"` + MinToolCallNumberParam float64 `yaml:"min-tool-call-number-param" json:"min-tool-call-number-param"` // MaxToolCallArrayParamLength defines the maximum possible length of array parameters in a tool call, // optional, defaults to 5 - MaxToolCallArrayParamLength int `yaml:"max-tool-call-array-param-length"` + MaxToolCallArrayParamLength int `yaml:"max-tool-call-array-param-length" json:"max-tool-call-array-param-length"` // MinToolCallArrayParamLength defines the minimum possible length of array parameters in a tool call, // optional, defaults to 1 - MinToolCallArrayParamLength int `yaml:"min-tool-call-array-param-length"` + MinToolCallArrayParamLength int `yaml:"min-tool-call-array-param-length" json:"min-tool-call-array-param-length"` // ToolCallNotRequiredParamProbability is the probability to add a parameter, that is not required, // in a tool call, optional, defaults to 50 - ToolCallNotRequiredParamProbability int `yaml:"tool-call-not-required-param-probability"` + ToolCallNotRequiredParamProbability int `yaml:"tool-call-not-required-param-probability" json:"tool-call-not-required-param-probability"` // ObjectToolCallNotRequiredParamProbability is the probability to add a field, that is not required, // in an object in a tool call, optional, defaults to 50 - ObjectToolCallNotRequiredParamProbability int `yaml:"object-tool-call-not-required-field-probability"` + ObjectToolCallNotRequiredParamProbability int `yaml:"object-tool-call-not-required-field-probability" json:"object-tool-call-not-required-field-probability"` // EnableKVCache defines if kv cache feature will be enabled - EnableKVCache bool `yaml:"enable-kvcache"` + EnableKVCache bool `yaml:"enable-kvcache" json:"enable-kvcache"` // KVCacheSize is the maximum number of token blocks in kv cache, the default value is 1024 - KVCacheSize int `yaml:"kv-cache-size"` + KVCacheSize int `yaml:"kv-cache-size" json:"kv-cache-size"` // TokenizersCacheDir is the directory for caching tokenizers - TokenizersCacheDir string `yaml:"tokenizers-cache-dir"` + TokenizersCacheDir string `yaml:"tokenizers-cache-dir" json:"tokenizers-cache-dir"` // TokenBlockSize is token block size for contiguous chunks of tokens, possible values: 8,16,32,64,128, defaults to 16 - TokenBlockSize int `yaml:"block-size"` + TokenBlockSize int `yaml:"block-size" json:"block-size"` // HashSeed is the seed for hash generation (if not set, is read from PYTHONHASHSEED environment variable) - HashSeed string `yaml:"hash-seed"` + HashSeed string `yaml:"hash-seed" json:"hash-seed"` // ZMQEndpoint is the ZMQ address to publish events, the default value is tcp://localhost:5557 - ZMQEndpoint string `yaml:"zmq-endpoint"` + ZMQEndpoint string `yaml:"zmq-endpoint" json:"zmq-endpoint"` + // ZMQMaxConnectAttempts defines the maximum number (10) of retries when ZMQ connection fails + ZMQMaxConnectAttempts uint `yaml:"zmq-max-connect-attempts" json:"zmq-max-connect-attempts"` + // EventBatchSize is the maximum number of kv-cache events to be sent together, defaults to 16 - EventBatchSize int `yaml:"event-batch-size"` + EventBatchSize int `yaml:"event-batch-size" json:"event-batch-size"` + + // FakeMetrics is a set of metrics to send to Prometheus instead of the real data + FakeMetrics *Metrics `yaml:"fake-metrics" 
json:"fake-metrics"` + + // FailureInjectionRate is the probability (0-100) of injecting failures + FailureInjectionRate int `yaml:"failure-injection-rate" json:"failure-injection-rate"` + // FailureTypes is a list of specific failure types to inject (empty means all types) + FailureTypes []string `yaml:"failure-types" json:"failure-types"` +} + +type Metrics struct { + // LoraMetrics + LoraMetrics []LorasMetrics `json:"loras"` + LorasString []string `yaml:"loras"` + // RunningRequests is the number of inference requests that are currently being processed + RunningRequests int64 `yaml:"running-requests" json:"running-requests"` + // WaitingRequests is the number of inference requests that are waiting to be processed + WaitingRequests int64 `yaml:"waiting-requests" json:"waiting-requests"` + // KVCacheUsagePercentage is the fraction of KV-cache blocks currently in use (from 0 to 1) + KVCacheUsagePercentage float32 `yaml:"kv-cache-usage" json:"kv-cache-usage"` +} + +type LorasMetrics struct { + // RunningLoras is a comma separated list of running LoRAs + RunningLoras string `json:"running"` + // WaitingLoras is a comma separated list of waiting LoRAs + WaitingLoras string `json:"waiting"` + // Timestamp is the timestamp of the metric + Timestamp float64 `json:"timestamp"` } type LoraModule struct { @@ -168,6 +208,29 @@ func (c *Configuration) unmarshalLoras() error { return nil } +func (c *Configuration) unmarshalFakeMetrics(fakeMetricsString string) error { + var metrics *Metrics + if err := json.Unmarshal([]byte(fakeMetricsString), &metrics); err != nil { + return err + } + c.FakeMetrics = metrics + return nil +} + +func (c *Configuration) unmarshalLoraFakeMetrics() error { + if c.FakeMetrics != nil { + c.FakeMetrics.LoraMetrics = make([]LorasMetrics, 0) + for _, jsonStr := range c.FakeMetrics.LorasString { + var lora LorasMetrics + if err := json.Unmarshal([]byte(jsonStr), &lora); err != nil { + return err + } + c.FakeMetrics.LoraMetrics = append(c.FakeMetrics.LoraMetrics, lora) + } + } + return nil +} + func newConfig() *Configuration { return &Configuration{ Port: vLLMDefaultPort, @@ -199,7 +262,14 @@ func (c *Configuration) load(configFile string) error { return fmt.Errorf("failed to unmarshal configuration: %s", err) } - return c.unmarshalLoras() + if err := c.unmarshalLoras(); err != nil { + return err + } + if err := c.unmarshalLoraFakeMetrics(); err != nil { + return err + } + + return nil } func (c *Configuration) validate() error { @@ -299,6 +369,39 @@ func (c *Configuration) validate() error { if c.EventBatchSize < 1 { return errors.New("event batch size cannot less than 1") } + + if c.FailureInjectionRate < 0 || c.FailureInjectionRate > 100 { + return errors.New("failure injection rate should be between 0 and 100") + } + + validFailureTypes := map[string]bool{ + FailureTypeRateLimit: true, + FailureTypeInvalidAPIKey: true, + FailureTypeContextLength: true, + FailureTypeServerError: true, + FailureTypeInvalidRequest: true, + FailureTypeModelNotFound: true, + } + for _, failureType := range c.FailureTypes { + if !validFailureTypes[failureType] { + return fmt.Errorf("invalid failure type '%s', valid types are: %s, %s, %s, %s, %s, %s", failureType, + FailureTypeRateLimit, FailureTypeInvalidAPIKey, FailureTypeContextLength, + FailureTypeServerError, FailureTypeInvalidRequest, FailureTypeModelNotFound) + } + } + + if c.ZMQMaxConnectAttempts > 10 { + return errors.New("zmq retries times cannot be more than 10") + } + + if c.FakeMetrics != nil { + if c.FakeMetrics.RunningRequests < 0 
|| c.FakeMetrics.WaitingRequests < 0 { + return errors.New("fake metrics request counters cannot be negative") + } + if c.FakeMetrics.KVCacheUsagePercentage < 0 || c.FakeMetrics.KVCacheUsagePercentage > 1 { + return errors.New("fake metrics KV cache usage must be between 0 and 1") + } + } return nil } @@ -316,6 +419,7 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) { servedModelNames := getParamValueFromArgs("served-model-name") loraModuleNames := getParamValueFromArgs("lora-modules") + fakeMetrics := getParamValueFromArgs("fake-metrics") f := pflag.NewFlagSet("llm-d-inference-sim flags", pflag.ContinueOnError) @@ -326,7 +430,7 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) { f.IntVar(&config.MaxCPULoras, "max-cpu-loras", config.MaxCPULoras, "Maximum number of LoRAs to store in CPU memory") f.IntVar(&config.MaxModelLen, "max-model-len", config.MaxModelLen, "Model's context window, maximum number of tokens in a single request including input and output") - f.StringVar(&config.Mode, "mode", config.Mode, "Simulator mode, echo - returns the same text that was sent in the request, for chat completion returns the last message, random - returns random sentence from a bank of pre-defined sentences") + f.StringVar(&config.Mode, "mode", config.Mode, "Simulator mode: echo - returns the same text that was sent in the request, for chat completion returns the last message; random - returns random sentence from a bank of pre-defined sentences") f.IntVar(&config.InterTokenLatency, "inter-token-latency", config.InterTokenLatency, "Time to generate one token (in milliseconds)") f.IntVar(&config.TimeToFirstToken, "time-to-first-token", config.TimeToFirstToken, "Time to first token (in milliseconds)") f.IntVar(&config.KVCacheTransferLatency, "kv-cache-transfer-latency", config.KVCacheTransferLatency, "Time for KV-cache transfer from a remote vLLM (in milliseconds)") @@ -350,17 +454,27 @@ f.StringVar(&config.TokenizersCacheDir, "tokenizers-cache-dir", config.TokenizersCacheDir, "Directory for caching tokenizers") f.StringVar(&config.HashSeed, "hash-seed", config.HashSeed, "Seed for hash generation (if not set, is read from PYTHONHASHSEED environment variable)") f.StringVar(&config.ZMQEndpoint, "zmq-endpoint", config.ZMQEndpoint, "ZMQ address to publish events") + f.UintVar(&config.ZMQMaxConnectAttempts, "zmq-max-connect-attempts", config.ZMQMaxConnectAttempts, "Maximum number of times to try ZMQ connect") f.IntVar(&config.EventBatchSize, "event-batch-size", config.EventBatchSize, "Maximum number of kv-cache events to be sent together") + f.IntVar(&config.FailureInjectionRate, "failure-injection-rate", config.FailureInjectionRate, "Probability (0-100) of injecting failures") + + failureTypes := getParamValueFromArgs("failure-types") + var dummyFailureTypes multiString + f.Var(&dummyFailureTypes, "failure-types", "List of specific failure types to inject (rate_limit, invalid_api_key, context_length, server_error, invalid_request, model_not_found)") + f.Lookup("failure-types").NoOptDefVal = dummy + + // These values were manually parsed above in getParamValueFromArgs, we leave this in order to get these flags in --help var dummyString string f.StringVar(&dummyString, "config", "", "The path to a yaml configuration file. 
The command line values overwrite the configuration file values") var dummyMultiString multiString f.Var(&dummyMultiString, "served-model-name", "Model names exposed by the API (a list of space-separated strings)") f.Var(&dummyMultiString, "lora-modules", "List of LoRA adapters (a list of space-separated JSON strings)") + f.Var(&dummyMultiString, "fake-metrics", "A set of metrics to report to Prometheus instead of the real metrics") // In order to allow empty arguments, we set a dummy NoOptDefVal for these flags - f.Lookup("served-model-name").NoOptDefVal = "dummy" - f.Lookup("lora-modules").NoOptDefVal = "dummy" + f.Lookup("served-model-name").NoOptDefVal = dummy + f.Lookup("lora-modules").NoOptDefVal = dummy + f.Lookup("fake-metrics").NoOptDefVal = dummy flagSet := flag.NewFlagSet("simFlagSet", flag.ExitOnError) klog.InitFlags(flagSet) @@ -381,9 +495,17 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) { return nil, err } } + if fakeMetrics != nil { + if err := config.unmarshalFakeMetrics(fakeMetrics[0]); err != nil { + return nil, err + } + } if servedModelNames != nil { config.ServedModelNames = servedModelNames } + if failureTypes != nil { + config.FailureTypes = failureTypes + } if config.HashSeed == "" { hashSeed := os.Getenv("PYTHONHASHSEED") diff --git a/pkg/common/config_test.go b/pkg/common/config_test.go index 6e768c27..770716a6 100644 --- a/pkg/common/config_test.go +++ b/pkg/common/config_test.go @@ -51,7 +51,6 @@ func createDefaultConfig(model string) *Configuration { c.KVCacheTransferLatency = 100 c.Seed = 100100100 c.LoraModules = []LoraModule{} - return c } @@ -104,12 +103,14 @@ var _ = Describe("Simulator configuration", func() { "{\"name\":\"lora4\",\"path\":\"/path/to/lora4\"}", } c.EventBatchSize = 5 + c.ZMQMaxConnectAttempts = 1 test = testCase{ name: "config file with command line args", args: []string{"cmd", "--model", model, "--config", "../../manifests/config.yaml", "--port", "8002", "--served-model-name", "alias1", "alias2", "--seed", "100", "--lora-modules", "{\"name\":\"lora3\",\"path\":\"/path/to/lora3\"}", "{\"name\":\"lora4\",\"path\":\"/path/to/lora4\"}", "--event-batch-size", "5", + "--zmq-max-connect-attempts", "1", }, expectedConfig: c, } @@ -122,6 +123,7 @@ var _ = Describe("Simulator configuration", func() { c.LoraModulesString = []string{ "{\"name\":\"lora3\",\"path\":\"/path/to/lora3\"}", } + c.ZMQMaxConnectAttempts = 0 test = testCase{ name: "config file with command line args with different format", args: []string{"cmd", "--model", model, "--config", "../../manifests/config.yaml", "--port", "8002", @@ -173,7 +175,7 @@ var _ = Describe("Simulator configuration", func() { } tests = append(tests, test) - // Config from config.yaml file plus command line args with time to copy cache + // Config from basic-config.yaml file plus command line args with time to copy cache c = createDefaultConfig(qwenModelName) c.Port = 8001 // basic config file does not contain properties related to lora @@ -181,12 +183,82 @@ var _ = Describe("Simulator configuration", func() { c.MaxCPULoras = 1 c.KVCacheTransferLatency = 50 test = testCase{ - name: "config file with command line args with time to transfer kv-cache", + name: "basic config file with command line args with time to transfer kv-cache", args: []string{"cmd", "--config", "../../manifests/basic-config.yaml", "--kv-cache-transfer-latency", "50"}, expectedConfig: c, } tests = append(tests, test) + // Config from config_with_fake.yaml file + c = createDefaultConfig(qwenModelName) + c.FakeMetrics = 
&Metrics{ + RunningRequests: 16, + WaitingRequests: 3, + KVCacheUsagePercentage: float32(0.3), + LoraMetrics: []LorasMetrics{ + {RunningLoras: "lora1,lora2", WaitingLoras: "lora3", Timestamp: 1257894567}, + {RunningLoras: "lora1,lora3", WaitingLoras: "", Timestamp: 1257894569}, + }, + LorasString: []string{ + "{\"running\":\"lora1,lora2\",\"waiting\":\"lora3\",\"timestamp\":1257894567}", + "{\"running\":\"lora1,lora3\",\"waiting\":\"\",\"timestamp\":1257894569}", + }, + } + test = testCase{ + name: "config with fake metrics file", + args: []string{"cmd", "--config", "../../manifests/config_with_fake.yaml"}, + expectedConfig: c, + } + tests = append(tests, test) + + // Fake metrics from command line + c = newConfig() + c.Model = model + c.ServedModelNames = []string{c.Model} + c.MaxCPULoras = 1 + c.Seed = 100 + c.FakeMetrics = &Metrics{ + RunningRequests: 10, + WaitingRequests: 30, + KVCacheUsagePercentage: float32(0.4), + LoraMetrics: []LorasMetrics{ + {RunningLoras: "lora4,lora2", WaitingLoras: "lora3", Timestamp: 1257894567}, + {RunningLoras: "lora4,lora3", WaitingLoras: "", Timestamp: 1257894569}, + }, + LorasString: nil, + } + test = testCase{ + name: "metrics from command line", + args: []string{"cmd", "--model", model, "--seed", "100", + "--fake-metrics", + "{\"running-requests\":10,\"waiting-requests\":30,\"kv-cache-usage\":0.4,\"loras\":[{\"running\":\"lora4,lora2\",\"waiting\":\"lora3\",\"timestamp\":1257894567},{\"running\":\"lora4,lora3\",\"waiting\":\"\",\"timestamp\":1257894569}]}", + }, + expectedConfig: c, + } + tests = append(tests, test) + + // Fake metrics from both the config file and command line + c = createDefaultConfig(qwenModelName) + c.FakeMetrics = &Metrics{ + RunningRequests: 10, + WaitingRequests: 30, + KVCacheUsagePercentage: float32(0.4), + LoraMetrics: []LorasMetrics{ + {RunningLoras: "lora4,lora2", WaitingLoras: "lora3", Timestamp: 1257894567}, + {RunningLoras: "lora4,lora3", WaitingLoras: "", Timestamp: 1257894569}, + }, + LorasString: nil, + } + test = testCase{ + name: "metrics from config file and command line", + args: []string{"cmd", "--config", "../../manifests/config_with_fake.yaml", + "--fake-metrics", + "{\"running-requests\":10,\"waiting-requests\":30,\"kv-cache-usage\":0.4,\"loras\":[{\"running\":\"lora4,lora2\",\"waiting\":\"lora3\",\"timestamp\":1257894567},{\"running\":\"lora4,lora3\",\"waiting\":\"\",\"timestamp\":1257894569}]}", + }, + expectedConfig: c, + } + tests = append(tests, test) + for _, test := range tests { When(test.name, func() { It("should create correct configuration", func() { @@ -298,6 +370,37 @@ var _ = Describe("Simulator configuration", func() { args: []string{"cmd", "--event-batch-size", "-35", "--config", "../../manifests/config.yaml"}, }, + { + name: "invalid failure injection rate > 100", + args: []string{"cmd", "--model", "test-model", "--failure-injection-rate", "150"}, + }, + { + name: "invalid failure injection rate < 0", + args: []string{"cmd", "--model", "test-model", "--failure-injection-rate", "-10"}, + }, + { + name: "invalid failure type", + args: []string{"cmd", "--model", "test-model", "--failure-injection-rate", "50", + "--failure-types", "invalid_type"}, + }, + { + name: "invalid fake metrics: negative running requests", + args: []string{"cmd", "--fake-metrics", "{\"running-requests\":-10,\"waiting-requests\":30,\"kv-cache-usage\":0.4}", + "--config", "../../manifests/config.yaml"}, + }, + { + name: "invalid fake metrics: kv cache usage", + args: []string{"cmd", "--fake-metrics", 
"{\"running-requests\":10,\"waiting-requests\":30,\"kv-cache-usage\":40}", + "--config", "../../manifests/config.yaml"}, + }, + { + name: "invalid (negative) zmq-max-connect-attempts for argument", + args: []string{"cmd", "zmq-max-connect-attempts", "-1", "--config", "../../manifests/config.yaml"}, + }, + { + name: "invalid (negative) zmq-max-connect-attempts for config file", + args: []string{"cmd", "--config", "../../manifests/invalid-config.yaml"}, + }, } for _, test := range invalidTests { diff --git a/pkg/common/publisher.go b/pkg/common/publisher.go index d7d6e325..883c05a2 100644 --- a/pkg/common/publisher.go +++ b/pkg/common/publisher.go @@ -23,6 +23,7 @@ import ( "errors" "fmt" "sync/atomic" + "time" zmq "github.com/pebbe/zmq4" "github.com/vmihailenco/msgpack/v5" @@ -38,24 +39,34 @@ type Publisher struct { // NewPublisher creates a new ZMQ publisher. // endpoint is the ZMQ address to bind to (e.g., "tcp://*:5557"). -func NewPublisher(endpoint string) (*Publisher, error) { +// retries is the maximum number of connection attempts. +func NewPublisher(endpoint string, retries uint) (*Publisher, error) { socket, err := zmq.NewSocket(zmq.PUB) if err != nil { return nil, fmt.Errorf("failed to create ZMQ PUB socket: %w", err) } - if err := socket.Connect(endpoint); err != nil { - errClose := socket.Close() - return nil, errors.Join( - fmt.Errorf("failed to connect to %s: %w", endpoint, err), - errClose, - ) + // Retry connection with specified retry times and intervals + for i := uint(0); i <= retries; i++ { + err = socket.Connect(endpoint) + if err == nil { + return &Publisher{ + socket: socket, + endpoint: endpoint, + }, nil + } + + // If not the last attempt, wait before retrying + if i < retries { + time.Sleep(1 * time.Second) + } } - return &Publisher{ - socket: socket, - endpoint: endpoint, - }, nil + errClose := socket.Close() + return nil, errors.Join( + fmt.Errorf("failed to connect to %s after %d retries: %w", endpoint, retries+1, err), + errClose, + ) } // PublishEvent publishes a KV cache event batch to the ZMQ topic. 
diff --git a/pkg/common/publisher_test.go b/pkg/common/publisher_test.go index 5df18940..a9d6582b 100644 --- a/pkg/common/publisher_test.go +++ b/pkg/common/publisher_test.go @@ -29,9 +29,11 @@ import ( ) const ( - topic = "test-topic" - endpoint = "tcp://localhost:5557" - data = "Hello" + topic = "test-topic" + subEndpoint = "tcp://*:5557" + pubEndpoint = "tcp://localhost:5557" + data = "Hello" + retries = 0 ) var _ = Describe("Publisher", func() { @@ -40,7 +42,7 @@ var _ = Describe("Publisher", func() { Expect(err).NotTo(HaveOccurred()) sub, err := zctx.NewSocket(zmq.SUB) Expect(err).NotTo(HaveOccurred()) - err = sub.Bind(endpoint) + err = sub.Bind(subEndpoint) Expect(err).NotTo(HaveOccurred()) err = sub.SetSubscribe(topic) Expect(err).NotTo(HaveOccurred()) @@ -49,7 +51,7 @@ var _ = Describe("Publisher", func() { time.Sleep(100 * time.Millisecond) - pub, err := NewPublisher(endpoint) + pub, err := NewPublisher(pubEndpoint, retries) Expect(err).NotTo(HaveOccurred()) ctx, cancel := context.WithCancel(context.Background()) @@ -77,4 +79,40 @@ var _ = Describe("Publisher", func() { Expect(err).NotTo(HaveOccurred()) Expect(payload).To(Equal(data)) }) + It("should fail when connection attempts exceed maximum retries", func() { + // Use invalid address format, which will cause connection to fail + invalidEndpoint := "invalid-address-format" + + pub, err := NewPublisher(invalidEndpoint, 2) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("failed to connect")) + Expect(err.Error()).To(ContainSubstring("after 3 retries")) // 2 retries = 3 total attempts + + if pub != nil { + //nolint + pub.Close() + } + }) + It("should retry connection successfully", func() { + // Step 1: Try to connect to a temporarily non-existent service + // This will trigger the retry mechanism + go func() { + // Delay starting the server to simulate service recovery + time.Sleep(2 * time.Second) + + // Start subscriber as server + sub, err := zmq.NewSocket(zmq.SUB) + Expect(err).NotTo(HaveOccurred()) + //nolint + defer sub.Close() + err = sub.Bind(subEndpoint) + Expect(err).NotTo(HaveOccurred()) + }() + + // Step 2: Publisher will retry connection and eventually succeed + pub, err := NewPublisher(pubEndpoint, 5) // 5 retries + Expect(err).NotTo(HaveOccurred()) // Should eventually succeed + //nolint + defer pub.Close() + }) }) diff --git a/pkg/common/utils.go b/pkg/common/utils.go index 309dc9d5..2cb4ad66 100644 --- a/pkg/common/utils.go +++ b/pkg/common/utils.go @@ -244,6 +244,13 @@ func RandomNorm(mean float64, stddev float64) float64 { return value } +// GenerateUUIDString generates a UUID string under a lock +func GenerateUUIDString() string { + randMutex.Lock() + defer randMutex.Unlock() + return uuid.NewString() +} + // Regular expression for the response tokenization var re *regexp.Regexp diff --git a/pkg/kv-cache/block_cache.go b/pkg/kv-cache/block_cache.go index e66c7224..56d2253b 100644 --- a/pkg/kv-cache/block_cache.go +++ b/pkg/kv-cache/block_cache.go @@ -48,7 +48,7 @@ func newBlockCache(config *common.Configuration, logger logr.Logger) (*blockCach // TODO read size of channel from config eChan := make(chan EventData, 10000) - publisher, err := common.NewPublisher(config.ZMQEndpoint) + publisher, err := common.NewPublisher(config.ZMQEndpoint, config.ZMQMaxConnectAttempts) if err != nil { return nil, err } diff --git a/pkg/kv-cache/kv_cache_sender.go b/pkg/kv-cache/kv_cache_sender.go index f8af3638..2b7bee14 100644 --- a/pkg/kv-cache/kv_cache_sender.go +++ b/pkg/kv-cache/kv_cache_sender.go 
@@ -16,7 +16,6 @@ limitations under the License. package kvcache import ( - "bytes" "context" "fmt" "time" @@ -90,24 +89,14 @@ func (s *KVEventSender) Run(ctx context.Context) error { } // Encode eventData's hash value to msgpack.RawMessage + var payload []byte var err error - var payload bytes.Buffer - enc := msgpack.NewEncoder(&payload) - enc.UseArrayEncodedStructs(true) switch eventData.action { case eventActionStore: - bs := &kvevents.BlockStoredEvent{ - TypeField: BlockStored, - BlockStored: &kvevents.BlockStored{BlockHashes: eventData.hashValues}, - } - err = enc.Encode(bs) + payload, err = msgpack.Marshal(storedToTaggedUnion(kvevents.BlockStored{BlockHashes: eventData.hashValues})) case eventActionRemove: - br := &kvevents.BlockRemovedEvent{ - TypeField: BlockRemoved, - BlockRemoved: &kvevents.BlockRemoved{BlockHashes: eventData.hashValues}, - } - err = enc.Encode(br) + payload, err = msgpack.Marshal(removedToTaggedUnion(kvevents.BlockRemoved{BlockHashes: eventData.hashValues})) default: return fmt.Errorf("invalid event action %d", eventData.action) } @@ -115,7 +104,7 @@ func (s *KVEventSender) Run(ctx context.Context) error { return fmt.Errorf("failed to marshal value: %w", err) } - s.batch = append(s.batch, payload.Bytes()) + s.batch = append(s.batch, payload) // check if batch is big enough to be sent if len(s.batch) >= s.maxBatchSize { @@ -139,6 +128,24 @@ func (s *KVEventSender) Run(ctx context.Context) error { } } +func storedToTaggedUnion(bs kvevents.BlockStored) []any { + return []any{ + BlockStored, + bs.BlockHashes, + bs.ParentBlockHash, + bs.TokenIds, + bs.BlockSize, + bs.LoraID, + } +} + +func removedToTaggedUnion(br kvevents.BlockRemoved) []any { + return []any{ + BlockRemoved, + br.BlockHashes, + } +} + // helper to publish collected batch if not empty func (s *KVEventSender) publishHelper(ctx context.Context) error { if len(s.batch) == 0 { diff --git a/pkg/kv-cache/kv_cache_test.go b/pkg/kv-cache/kv_cache_test.go index cc259c5b..7731196e 100644 --- a/pkg/kv-cache/kv_cache_test.go +++ b/pkg/kv-cache/kv_cache_test.go @@ -33,10 +33,11 @@ import ( ) const ( - req1ID = "req1" - req2ID = "req2" - req3ID = "req3" - endpoint = "tcp://localhost:5557" + req1ID = "req1" + req2ID = "req2" + req3ID = "req3" + subEndpoint = "tcp://*:5557" + pubEndpoint = "tcp://localhost:5557" ) type ActionType int @@ -200,11 +201,12 @@ var _ = Describe("KV cache", Ordered, func() { time.Sleep(300 * time.Millisecond) config := &common.Configuration{ - Port: 1234, - Model: "model", - KVCacheSize: test.cacheSize, - ZMQEndpoint: endpoint, - EventBatchSize: 1, + Port: 1234, + Model: "model", + KVCacheSize: test.cacheSize, + ZMQEndpoint: pubEndpoint, + ZMQMaxConnectAttempts: 3, + EventBatchSize: 1, } sub, topic := createSub(config) @@ -303,10 +305,11 @@ var _ = Describe("KV cache", Ordered, func() { It("should send events correctly", func() { config := &common.Configuration{ - Port: 1234, - Model: "model", - KVCacheSize: 4, - ZMQEndpoint: endpoint, + Port: 1234, + Model: "model", + KVCacheSize: 4, + ZMQEndpoint: pubEndpoint, + ZMQMaxConnectAttempts: 3, } sub, topic := createSub(config) @@ -412,10 +415,11 @@ var _ = Describe("KV cache", Ordered, func() { for _, testCase := range testCases { It(testCase.name, func() { config := common.Configuration{ - Port: 1234, - Model: "model", - KVCacheSize: testCase.cacheSize, - ZMQEndpoint: endpoint, + Port: 1234, + Model: "model", + KVCacheSize: testCase.cacheSize, + ZMQEndpoint: pubEndpoint, + ZMQMaxConnectAttempts: 3, } blockCache, err := newBlockCache(&config, 
GinkgoLogr) Expect(err).NotTo(HaveOccurred()) @@ -496,22 +500,25 @@ func parseEvent(parts [][]byte, expectedTopic string, expectedSeq uint64) ([]uin Expect(err).NotTo(HaveOccurred()) for _, rawEvent := range eventBatch.Events { var taggedUnion []msgpack.RawMessage - err = msgpack.Unmarshal(rawEvent, &taggedUnion) + err := msgpack.Unmarshal(rawEvent, &taggedUnion) Expect(err).NotTo(HaveOccurred()) Expect(len(taggedUnion)).To(BeNumerically(">", 1)) + payloadBytes, err := msgpack.Marshal(taggedUnion[1:]) + Expect(err).NotTo(HaveOccurred()) + var tag string err = msgpack.Unmarshal(taggedUnion[0], &tag) Expect(err).NotTo(HaveOccurred()) switch tag { case BlockStored: - var bs kvevents.BlockStoredEvent - err = msgpack.Unmarshal(rawEvent, &bs) + var bs kvevents.BlockStored + err = msgpack.Unmarshal(payloadBytes, &bs) stored = append(stored, bs.BlockHashes...) case BlockRemoved: - var br kvevents.BlockRemovedEvent - err = msgpack.Unmarshal(rawEvent, &br) + var br kvevents.BlockRemoved + err = msgpack.Unmarshal(payloadBytes, &br) removed = append(removed, br.BlockHashes...) default: @@ -528,7 +535,7 @@ func createSub(config *common.Configuration) (*zmq.Socket, string) { Expect(err).NotTo(HaveOccurred()) sub, err := zctx.NewSocket(zmq.SUB) Expect(err).NotTo(HaveOccurred()) - err = sub.Bind(endpoint) + err = sub.Bind(subEndpoint) Expect(err).NotTo(HaveOccurred()) topic := createTopic(config) err = sub.SetSubscribe(topic) diff --git a/pkg/llm-d-inference-sim/failures.go b/pkg/llm-d-inference-sim/failures.go new file mode 100644 index 00000000..69daf36e --- /dev/null +++ b/pkg/llm-d-inference-sim/failures.go @@ -0,0 +1,88 @@ +/* +Copyright 2025 The llm-d-inference-sim Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package llmdinferencesim + +import ( + "fmt" + + "github.com/llm-d/llm-d-inference-sim/pkg/common" + openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api" +) + +const ( + // Error message templates + rateLimitMessageTemplate = "Rate limit reached for %s in organization org-xxx on requests per min (RPM): Limit 3, Used 3, Requested 1." + modelNotFoundMessageTemplate = "The model '%s-nonexistent' does not exist" +) + +var predefinedFailures = map[string]openaiserverapi.CompletionError{ + common.FailureTypeRateLimit: openaiserverapi.NewCompletionError(rateLimitMessageTemplate, 429, nil), + common.FailureTypeInvalidAPIKey: openaiserverapi.NewCompletionError("Incorrect API key provided.", 401, nil), + common.FailureTypeContextLength: openaiserverapi.NewCompletionError( + "This model's maximum context length is 4096 tokens. 
However, your messages resulted in 4500 tokens.", + 400, stringPtr("messages")), + common.FailureTypeServerError: openaiserverapi.NewCompletionError( + "The server is overloaded or not ready yet.", 503, nil), + common.FailureTypeInvalidRequest: openaiserverapi.NewCompletionError( + "Invalid request: missing required parameter 'model'.", 400, stringPtr("model")), + common.FailureTypeModelNotFound: openaiserverapi.NewCompletionError(modelNotFoundMessageTemplate, + 404, stringPtr("model")), +} + +// shouldInjectFailure determines whether to inject a failure based on configuration +func shouldInjectFailure(config *common.Configuration) bool { + if config.FailureInjectionRate == 0 { + return false + } + + return common.RandomInt(1, 100) <= config.FailureInjectionRate +} + +// getRandomFailure returns a random failure from configured types or all types if none specified +func getRandomFailure(config *common.Configuration) openaiserverapi.CompletionError { + var availableFailures []string + if len(config.FailureTypes) == 0 { + // Use all failure types if none specified + for failureType := range predefinedFailures { + availableFailures = append(availableFailures, failureType) + } + } else { + availableFailures = config.FailureTypes + } + + if len(availableFailures) == 0 { + // Fallback to server_error if no valid types + return predefinedFailures[common.FailureTypeServerError] + } + + randomIndex := common.RandomInt(0, len(availableFailures)-1) + randomType := availableFailures[randomIndex] + + // Customize message with current model name + failure := predefinedFailures[randomType] + if randomType == common.FailureTypeRateLimit && config.Model != "" { + failure.Message = fmt.Sprintf(rateLimitMessageTemplate, config.Model) + } else if randomType == common.FailureTypeModelNotFound && config.Model != "" { + failure.Message = fmt.Sprintf(modelNotFoundMessageTemplate, config.Model) + } + + return failure +} + +func stringPtr(s string) *string { + return &s +} diff --git a/pkg/llm-d-inference-sim/failures_test.go b/pkg/llm-d-inference-sim/failures_test.go new file mode 100644 index 00000000..5ff48034 --- /dev/null +++ b/pkg/llm-d-inference-sim/failures_test.go @@ -0,0 +1,334 @@ +/* +Copyright 2025 The llm-d-inference-sim Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package llmdinferencesim + +import ( + "context" + "errors" + "net/http" + "strings" + "time" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" + "github.com/openai/openai-go" + "github.com/openai/openai-go/option" + + "github.com/llm-d/llm-d-inference-sim/pkg/common" + openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api" +) + +var _ = Describe("Failures", func() { + Describe("getRandomFailure", Ordered, func() { + BeforeAll(func() { + common.InitRandom(time.Now().UnixNano()) + }) + + It("should return a failure from all types when none specified", func() { + config := &common.Configuration{ + Model: "test-model", + FailureTypes: []string{}, + } + failure := getRandomFailure(config) + Expect(failure.Code).To(BeNumerically(">=", 400)) + Expect(failure.Message).ToNot(BeEmpty()) + Expect(failure.Type).ToNot(BeEmpty()) + }) + + It("should return rate limit failure when specified", func() { + config := &common.Configuration{ + Model: "test-model", + FailureTypes: []string{common.FailureTypeRateLimit}, + } + failure := getRandomFailure(config) + Expect(failure.Code).To(Equal(429)) + Expect(failure.Type).To(Equal(openaiserverapi.ErrorCodeToType(429))) + Expect(strings.Contains(failure.Message, "test-model")).To(BeTrue()) + }) + + It("should return invalid API key failure when specified", func() { + config := &common.Configuration{ + FailureTypes: []string{common.FailureTypeInvalidAPIKey}, + } + failure := getRandomFailure(config) + Expect(failure.Code).To(Equal(401)) + Expect(failure.Type).To(Equal(openaiserverapi.ErrorCodeToType(401))) + Expect(failure.Message).To(Equal("Incorrect API key provided.")) + }) + + It("should return context length failure when specified", func() { + config := &common.Configuration{ + FailureTypes: []string{common.FailureTypeContextLength}, + } + failure := getRandomFailure(config) + Expect(failure.Code).To(Equal(400)) + Expect(failure.Type).To(Equal(openaiserverapi.ErrorCodeToType(400))) + Expect(failure.Param).ToNot(BeNil()) + Expect(*failure.Param).To(Equal("messages")) + }) + + It("should return server error when specified", func() { + config := &common.Configuration{ + FailureTypes: []string{common.FailureTypeServerError}, + } + failure := getRandomFailure(config) + Expect(failure.Code).To(Equal(503)) + Expect(failure.Type).To(Equal(openaiserverapi.ErrorCodeToType(503))) + }) + + It("should return model not found failure when specified", func() { + config := &common.Configuration{ + Model: "test-model", + FailureTypes: []string{common.FailureTypeModelNotFound}, + } + failure := getRandomFailure(config) + Expect(failure.Code).To(Equal(404)) + Expect(failure.Type).To(Equal(openaiserverapi.ErrorCodeToType(404))) + Expect(strings.Contains(failure.Message, "test-model-nonexistent")).To(BeTrue()) + }) + + It("should return server error as fallback for empty types", func() { + config := &common.Configuration{ + FailureTypes: []string{}, + } + // This test is probabilistic since it randomly selects, but we can test structure + failure := getRandomFailure(config) + Expect(failure.Code).To(BeNumerically(">=", 400)) + Expect(failure.Type).ToNot(BeEmpty()) + }) + }) + Describe("Simulator with failure injection", func() { + var ( + client *http.Client + ctx context.Context + ) + + AfterEach(func() { + if ctx != nil { + ctx.Done() + } + }) + + Context("with 100% failure injection rate", func() { + BeforeEach(func() { + ctx = context.Background() + var err error + client, err = startServerWithArgs(ctx, "failure", []string{ + "cmd", "--model", model, + "--failure-injection-rate", "100", + }, nil) + Expect(err).ToNot(HaveOccurred()) + }) + + It("should always return an 
error response for chat completions", func() { + openaiClient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client), + ) + + _, err := openaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{ + Model: model, + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + }) + + Expect(err).To(HaveOccurred()) + + var openaiError *openai.Error + ok := errors.As(err, &openaiError) + Expect(ok).To(BeTrue()) + Expect(openaiError.StatusCode).To(BeNumerically(">=", 400)) + Expect(openaiError.Type).ToNot(BeEmpty()) + Expect(openaiError.Message).ToNot(BeEmpty()) + }) + + It("should always return an error response for text completions", func() { + openaiClient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client), + ) + + _, err := openaiClient.Completions.New(ctx, openai.CompletionNewParams{ + Model: openai.CompletionNewParamsModel(model), + Prompt: openai.CompletionNewParamsPromptUnion{ + OfString: openai.String(userMessage), + }, + }) + + Expect(err).To(HaveOccurred()) + + var openaiError *openai.Error + ok := errors.As(err, &openaiError) + Expect(ok).To(BeTrue()) + Expect(openaiError.StatusCode).To(BeNumerically(">=", 400)) + Expect(openaiError.Type).ToNot(BeEmpty()) + Expect(openaiError.Message).ToNot(BeEmpty()) + }) + }) + + Context("with specific failure types", func() { + BeforeEach(func() { + ctx = context.Background() + var err error + client, err = startServerWithArgs(ctx, "failure", []string{ + "cmd", "--model", model, + "--failure-injection-rate", "100", + "--failure-types", common.FailureTypeRateLimit, + }, nil) + Expect(err).ToNot(HaveOccurred()) + }) + + It("should return only rate limit errors", func() { + openaiClient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client), + ) + + _, err := openaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{ + Model: model, + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + }) + + Expect(err).To(HaveOccurred()) + + var openaiError *openai.Error + ok := errors.As(err, &openaiError) + Expect(ok).To(BeTrue()) + Expect(openaiError.StatusCode).To(Equal(429)) + Expect(openaiError.Type).To(Equal(openaiserverapi.ErrorCodeToType(429))) + Expect(strings.Contains(openaiError.Message, model)).To(BeTrue()) + }) + }) + + Context("with multiple specific failure types", func() { + BeforeEach(func() { + ctx = context.Background() + var err error + client, err = startServerWithArgs(ctx, "failure", []string{ + "cmd", "--model", model, + "--failure-injection-rate", "100", + "--failure-types", common.FailureTypeInvalidAPIKey, common.FailureTypeServerError, + }, nil) + Expect(err).ToNot(HaveOccurred()) + }) + + It("should return only specified error types", func() { + openaiClient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client), + ) + + // Make multiple requests to verify we get the expected error types + for i := 0; i < 10; i++ { + _, err := openaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{ + Model: model, + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + }) + + Expect(err).To(HaveOccurred()) + + var openaiError *openai.Error + ok := errors.As(err, &openaiError) + Expect(ok).To(BeTrue()) + + // Should only be one of the specified types + Expect(openaiError.StatusCode == 401 || openaiError.StatusCode == 503).To(BeTrue()) + Expect(openaiError.Type == 
openaiserverapi.ErrorCodeToType(401) || + openaiError.Type == openaiserverapi.ErrorCodeToType(503)).To(BeTrue()) + } + }) + }) + + Context("with 0% failure injection rate", func() { + BeforeEach(func() { + ctx = context.Background() + var err error + client, err = startServerWithArgs(ctx, "failure", []string{ + "cmd", "--model", model, + "--failure-injection-rate", "0", + }, nil) + Expect(err).ToNot(HaveOccurred()) + }) + + It("should never return errors and behave like random mode", func() { + openaiClient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client), + ) + + resp, err := openaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{ + Model: model, + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + }) + + Expect(err).ToNot(HaveOccurred()) + Expect(resp.Choices).To(HaveLen(1)) + Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty()) + Expect(resp.Model).To(Equal(model)) + }) + }) + + Context("testing all predefined failure types", func() { + DescribeTable("should return correct error for each failure type", + func(failureType string, expectedStatusCode int, expectedErrorType string) { + ctx := context.Background() + client, err := startServerWithArgs(ctx, "failure", []string{ + "cmd", "--model", model, + "--failure-injection-rate", "100", + "--failure-types", failureType, + }, nil) + Expect(err).ToNot(HaveOccurred()) + + openaiClient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client), + ) + + _, err = openaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{ + Model: model, + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + }) + + Expect(err).To(HaveOccurred()) + + var openaiError *openai.Error + ok := errors.As(err, &openaiError) + Expect(ok).To(BeTrue()) + Expect(openaiError.StatusCode).To(Equal(expectedStatusCode)) + Expect(openaiError.Type).To(Equal(expectedErrorType)) + // Note: OpenAI Go client doesn't directly expose the error code field, + // but we can verify via status code and type + }, + Entry("rate_limit", common.FailureTypeRateLimit, 429, openaiserverapi.ErrorCodeToType(429)), + Entry("invalid_api_key", common.FailureTypeInvalidAPIKey, 401, openaiserverapi.ErrorCodeToType(401)), + Entry("context_length", common.FailureTypeContextLength, 400, openaiserverapi.ErrorCodeToType(400)), + Entry("server_error", common.FailureTypeServerError, 503, openaiserverapi.ErrorCodeToType(503)), + Entry("invalid_request", common.FailureTypeInvalidRequest, 400, openaiserverapi.ErrorCodeToType(400)), + Entry("model_not_found", common.FailureTypeModelNotFound, 404, openaiserverapi.ErrorCodeToType(404)), + ) + }) + }) +}) diff --git a/pkg/llm-d-inference-sim/lora_test.go b/pkg/llm-d-inference-sim/lora_test.go index 682b8411..7ec37d0d 100644 --- a/pkg/llm-d-inference-sim/lora_test.go +++ b/pkg/llm-d-inference-sim/lora_test.go @@ -37,7 +37,7 @@ var _ = Describe("LoRAs", func() { client, err := startServerWithArgs(ctx, "", []string{"cmd", "--model", model, "--mode", common.ModeEcho, "--lora-modules", "{\"name\":\"lora3\",\"path\":\"/path/to/lora3\"}", - "{\"name\":\"lora4\",\"path\":\"/path/to/lora4\"}"}) + "{\"name\":\"lora4\",\"path\":\"/path/to/lora4\"}"}, nil) Expect(err).NotTo(HaveOccurred()) openaiclient := openai.NewClient( diff --git a/pkg/llm-d-inference-sim/metrics.go b/pkg/llm-d-inference-sim/metrics.go index c869ecd8..5b065648 100644 --- a/pkg/llm-d-inference-sim/metrics.go +++ 
b/pkg/llm-d-inference-sim/metrics.go @@ -19,9 +19,9 @@ limitations under the License. package llmdinferencesim import ( + "context" "strconv" "strings" - "sync/atomic" "time" "github.com/prometheus/client_golang/prometheus" @@ -96,25 +96,40 @@ func (s *VllmSimulator) createAndRegisterPrometheus() error { return nil } -// setInitialPrometheusMetrics send default values to prometheus +// setInitialPrometheusMetrics sends the default values to prometheus or +// the fake metrics if set func (s *VllmSimulator) setInitialPrometheusMetrics() { + var nRunningReqs, nWaitingReqs, kvCacheUsage float64 + if s.config.FakeMetrics != nil { + nRunningReqs = float64(s.config.FakeMetrics.RunningRequests) + nWaitingReqs = float64(s.config.FakeMetrics.WaitingRequests) + kvCacheUsage = float64(s.config.FakeMetrics.KVCacheUsagePercentage) + } modelName := s.getDisplayedModelName(s.config.Model) - s.loraInfo.WithLabelValues( - strconv.Itoa(s.config.MaxLoras), - "", - "").Set(float64(time.Now().Unix())) - - s.nRunningReqs = 0 - s.runningRequests.WithLabelValues( - modelName).Set(float64(s.nRunningReqs)) - s.waitingRequests.WithLabelValues( - modelName).Set(float64(0)) - s.kvCacheUsagePercentage.WithLabelValues( - modelName).Set(float64(0)) + s.runningRequests.WithLabelValues(modelName).Set(nRunningReqs) + s.waitingRequests.WithLabelValues(modelName).Set(nWaitingReqs) + s.kvCacheUsagePercentage.WithLabelValues(modelName).Set(kvCacheUsage) + + if s.config.FakeMetrics != nil && len(s.config.FakeMetrics.LoraMetrics) != 0 { + for _, metrics := range s.config.FakeMetrics.LoraMetrics { + s.loraInfo.WithLabelValues( + strconv.Itoa(s.config.MaxLoras), + metrics.RunningLoras, + metrics.WaitingLoras).Set(metrics.Timestamp) + } + } else { + s.loraInfo.WithLabelValues( + strconv.Itoa(s.config.MaxLoras), + "", + "").Set(float64(time.Now().Unix())) + } } // reportLoras sets information about loaded LoRA adapters func (s *VllmSimulator) reportLoras() { + if s.config.FakeMetrics != nil { + return + } if s.loraInfo == nil { // Happens in the tests return @@ -138,18 +153,61 @@ func (s *VllmSimulator) reportLoras() { // reportRunningRequests sets information about running completion requests func (s *VllmSimulator) reportRunningRequests() { + if s.config.FakeMetrics != nil { + return + } if s.runningRequests != nil { - nRunningReqs := atomic.LoadInt64(&(s.nRunningReqs)) s.runningRequests.WithLabelValues( - s.getDisplayedModelName(s.config.Model)).Set(float64(nRunningReqs)) + s.getDisplayedModelName(s.config.Model)).Set(float64(s.nRunningReqs)) } } // reportWaitingRequests sets information about waiting completion requests func (s *VllmSimulator) reportWaitingRequests() { + if s.config.FakeMetrics != nil { + return + } if s.waitingRequests != nil { - nWaitingReqs := atomic.LoadInt64(&(s.nWaitingReqs)) s.waitingRequests.WithLabelValues( - s.getDisplayedModelName(s.config.Model)).Set(float64(nWaitingReqs)) + s.getDisplayedModelName(s.config.Model)).Set(float64(s.nWaitingReqs)) + } +} + +func (s *VllmSimulator) unregisterPrometheus() { + prometheus.Unregister(s.loraInfo) + prometheus.Unregister(s.runningRequests) + prometheus.Unregister(s.waitingRequests) + prometheus.Unregister(s.kvCacheUsagePercentage) +} + +// startMetricsUpdaters starts the various metrics updaters +func (s *VllmSimulator) startMetricsUpdaters(ctx context.Context) { + go s.waitingRequestsUpdater(ctx) + go s.runningRequestsUpdater(ctx) +} + +// waitingRequestsUpdater updates the waiting requests metric by listening on the relevant channel +func (s *VllmSimulator) 
waitingRequestsUpdater(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case inc := <-s.waitingReqChan: + s.nWaitingReqs += inc + s.reportWaitingRequests() + } + } +} + +// runningRequestsUpdater updates the running requests metric by listening on the relevant channel +func (s *VllmSimulator) runningRequestsUpdater(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case inc := <-s.runReqChan: + s.nRunningReqs += inc + s.reportRunningRequests() + } } } diff --git a/pkg/llm-d-inference-sim/metrics_test.go b/pkg/llm-d-inference-sim/metrics_test.go new file mode 100644 index 00000000..0d4e1f3c --- /dev/null +++ b/pkg/llm-d-inference-sim/metrics_test.go @@ -0,0 +1,276 @@ +/* +Copyright 2025 The llm-d-inference-sim Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package llmdinferencesim + +import ( + "context" + "io" + "net/http" + "regexp" + "strconv" + "strings" + "sync" + "time" + + "github.com/llm-d/llm-d-inference-sim/pkg/common" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "github.com/openai/openai-go" + "github.com/openai/openai-go/option" +) + +var _ = Describe("Simulator metrics", Ordered, func() { + It("Should send correct running and waiting requests metrics", func() { + modelName := "testmodel" + // Three requests, only two can run in parallel, we expect + // two running requests and one waiting request in the metrics + ctx := context.TODO() + args := []string{"cmd", "--model", modelName, "--mode", common.ModeRandom, + "--time-to-first-token", "3000", "--max-num-seqs", "2"} + + s, client, err := startServerWithArgsAndMetrics(ctx, common.ModeRandom, args, nil, true) + Expect(err).NotTo(HaveOccurred()) + defer s.unregisterPrometheus() + + openaiclient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client)) + + params := openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + Model: modelName, + } + + var wg sync.WaitGroup + wg.Add(1) + + for range 3 { + go func() { + defer GinkgoRecover() + _, err := openaiclient.Chat.Completions.New(ctx, params) + Expect(err).NotTo(HaveOccurred()) + }() + } + + go func() { + defer wg.Done() + defer GinkgoRecover() + + time.Sleep(300 * time.Millisecond) + metricsResp, err := client.Get("http://localhost/metrics") + Expect(err).NotTo(HaveOccurred()) + Expect(metricsResp.StatusCode).To(Equal(http.StatusOK)) + + data, err := io.ReadAll(metricsResp.Body) + Expect(err).NotTo(HaveOccurred()) + metrics := string(data) + Expect(metrics).To(ContainSubstring("vllm:num_requests_running{model_name=\"testmodel\"} 2")) + Expect(metrics).To(ContainSubstring("vllm:num_requests_waiting{model_name=\"testmodel\"} 1")) + }() + + wg.Wait() + }) + + It("Should send correct lora metrics", func() { + ctx := context.TODO() + args := []string{"cmd", "--model", model, "--mode", common.ModeRandom, + "--time-to-first-token", "3000", + "--lora-modules", "{\"name\":\"lora1\",\"path\":\"/path/to/lora1\"}", + 
"{\"name\":\"lora2\",\"path\":\"/path/to/lora2\"}"} + + s, client, err := startServerWithArgsAndMetrics(ctx, common.ModeRandom, args, nil, true) + Expect(err).NotTo(HaveOccurred()) + defer s.unregisterPrometheus() + + openaiclient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client)) + + params1 := openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + Model: "lora1", + } + + _, err = openaiclient.Chat.Completions.New(ctx, params1) + Expect(err).NotTo(HaveOccurred()) + + params2 := openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + Model: "lora2", + } + + _, err = openaiclient.Chat.Completions.New(ctx, params2) + Expect(err).NotTo(HaveOccurred()) + + metricsResp, err := client.Get("http://localhost/metrics") + Expect(err).NotTo(HaveOccurred()) + Expect(metricsResp.StatusCode).To(Equal(http.StatusOK)) + + data, err := io.ReadAll(metricsResp.Body) + Expect(err).NotTo(HaveOccurred()) + metrics := string(data) + + // We sent two sequentual requests to two different LoRAs, we expect to see (in this order) + // 1. running_lora_adapter = lora1 + // 2. running_lora_adapter = lora2 + // 3. running_lora_adapter = {} + lora1 := "vllm:lora_requests_info{max_lora=\"1\",running_lora_adapters=\"lora1\",waiting_lora_adapters=\"\"}" + lora2 := "vllm:lora_requests_info{max_lora=\"1\",running_lora_adapters=\"lora2\",waiting_lora_adapters=\"\"}" + empty := "vllm:lora_requests_info{max_lora=\"1\",running_lora_adapters=\"\",waiting_lora_adapters=\"\"}" + + Expect(metrics).To(ContainSubstring(lora1)) + Expect(metrics).To(ContainSubstring(lora2)) + Expect(metrics).To(ContainSubstring(empty)) + + // Check the order + lora1Timestamp := extractTimestamp(metrics, lora1) + lora2Timestamp := extractTimestamp(metrics, lora2) + noLorasTimestamp := extractTimestamp(metrics, empty) + + Expect(lora1Timestamp < lora2Timestamp).To(BeTrue()) + Expect(lora2Timestamp < noLorasTimestamp).To(BeTrue()) + }) + + It("Should send correct lora metrics for parallel requests", func() { + ctx := context.TODO() + args := []string{"cmd", "--model", model, "--mode", common.ModeRandom, + "--time-to-first-token", "2000", + "--lora-modules", "{\"name\":\"lora1\",\"path\":\"/path/to/lora1\"}", + "{\"name\":\"lora2\",\"path\":\"/path/to/lora2\"}"} + + s, client, err := startServerWithArgsAndMetrics(ctx, common.ModeRandom, args, nil, true) + Expect(err).NotTo(HaveOccurred()) + + defer s.unregisterPrometheus() + + openaiclient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client)) + + params1 := openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + Model: "lora1", + } + + params2 := openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + Model: "lora2", + } + + var wg sync.WaitGroup + wg.Add(1) + + go func() { + time.Sleep(1 * time.Second) + defer wg.Done() + defer GinkgoRecover() + _, err := openaiclient.Chat.Completions.New(ctx, params2) + Expect(err).NotTo(HaveOccurred()) + }() + + _, err = openaiclient.Chat.Completions.New(ctx, params1) + Expect(err).NotTo(HaveOccurred()) + + wg.Wait() + + metricsResp, err := client.Get("http://localhost/metrics") + Expect(err).NotTo(HaveOccurred()) + Expect(metricsResp.StatusCode).To(Equal(http.StatusOK)) + + data, err := io.ReadAll(metricsResp.Body) + 
Expect(err).NotTo(HaveOccurred()) + metrics := string(data) + + // We sent two parallel requests: first to lora1 and then to lora2 (with a delay), we expect + // to see (in this order) + // 1. running_lora_adapter = lora1 + // 2. running_lora_adapter = lora2,lora1 (the order of LoRAs doesn't matter here) + // 3. running_lora_adapter = lora2 + // 4. running_lora_adapter = {} + lora1 := "vllm:lora_requests_info{max_lora=\"1\",running_lora_adapters=\"lora1\",waiting_lora_adapters=\"\"}" + lora12 := "vllm:lora_requests_info{max_lora=\"1\",running_lora_adapters=\"lora1,lora2\",waiting_lora_adapters=\"\"}" + lora21 := "vllm:lora_requests_info{max_lora=\"1\",running_lora_adapters=\"lora2,lora1\",waiting_lora_adapters=\"\"}" + lora2 := "vllm:lora_requests_info{max_lora=\"1\",running_lora_adapters=\"lora2\",waiting_lora_adapters=\"\"}" + empty := "vllm:lora_requests_info{max_lora=\"1\",running_lora_adapters=\"\",waiting_lora_adapters=\"\"}" + + Expect(metrics).To(ContainSubstring(lora1)) + Expect(metrics).To(Or(ContainSubstring(lora12), ContainSubstring(lora21))) + Expect(metrics).To(ContainSubstring(lora2)) + Expect(metrics).To(ContainSubstring(empty)) + + // Check the order + lora1Timestamp := extractTimestamp(metrics, lora1) + lora2Timestamp := extractTimestamp(metrics, lora2) + noLorasTimestamp := extractTimestamp(metrics, empty) + var twoLorasTimestamp float64 + if strings.Contains(metrics, lora12) { + twoLorasTimestamp = extractTimestamp(metrics, lora12) + } else { + twoLorasTimestamp = extractTimestamp(metrics, lora21) + } + Expect(lora1Timestamp < twoLorasTimestamp).To(BeTrue()) + Expect(twoLorasTimestamp < lora2Timestamp).To(BeTrue()) + Expect(lora2Timestamp < noLorasTimestamp).To(BeTrue()) + }) + + Context("fake metrics", func() { + It("Should respond with fake metrics to /metrics", func() { + ctx := context.TODO() + args := []string{"cmd", "--model", model, "--mode", common.ModeRandom, + "--fake-metrics", + "{\"running-requests\":10,\"waiting-requests\":30,\"kv-cache-usage\":0.4,\"loras\":[{\"running\":\"lora4,lora2\",\"waiting\":\"lora3\",\"timestamp\":1257894567},{\"running\":\"lora4,lora3\",\"waiting\":\"\",\"timestamp\":1257894569}]}", + } + + s, client, err := startServerWithArgsAndMetrics(ctx, common.ModeRandom, args, nil, true) + Expect(err).NotTo(HaveOccurred()) + + defer s.unregisterPrometheus() + + resp, err := client.Get("http://localhost/metrics") + Expect(err).NotTo(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(http.StatusOK)) + + data, err := io.ReadAll(resp.Body) + Expect(err).NotTo(HaveOccurred()) + metrics := string(data) + Expect(metrics).To(ContainSubstring("vllm:num_requests_running{model_name=\"my_model\"} 10")) + Expect(metrics).To(ContainSubstring("vllm:num_requests_waiting{model_name=\"my_model\"} 30")) + Expect(metrics).To(ContainSubstring("vllm:gpu_cache_usage_perc{model_name=\"my_model\"} 0.4")) + Expect(metrics).To(ContainSubstring("vllm:lora_requests_info{max_lora=\"1\",running_lora_adapters=\"lora4,lora2\",waiting_lora_adapters=\"lora3\"} 1.257894567e+09")) + Expect(metrics).To(ContainSubstring("vllm:lora_requests_info{max_lora=\"1\",running_lora_adapters=\"lora4,lora3\",waiting_lora_adapters=\"\"} 1.257894569e+09")) + }) + }) +}) + +func extractTimestamp(metrics string, key string) float64 { + re := regexp.MustCompile(key + ` (\S+)`) + result := re.FindStringSubmatch(metrics) + Expect(len(result)).To(BeNumerically(">", 1)) + f, err := strconv.ParseFloat(result[1], 64) + Expect(err).NotTo(HaveOccurred()) + return f +} diff --git 
a/pkg/llm-d-inference-sim/seed_test.go b/pkg/llm-d-inference-sim/seed_test.go index 190ae7bf..505b4938 100644 --- a/pkg/llm-d-inference-sim/seed_test.go +++ b/pkg/llm-d-inference-sim/seed_test.go @@ -33,7 +33,7 @@ var _ = Describe("Simulator with seed", func() { func() { ctx := context.TODO() client, err := startServerWithArgs(ctx, common.ModeRandom, - []string{"cmd", "--model", model, "--mode", common.ModeRandom, "--seed", "100"}) + []string{"cmd", "--model", model, "--mode", common.ModeRandom, "--seed", "100"}, nil) Expect(err).NotTo(HaveOccurred()) openaiclient := openai.NewClient( diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index 77b0738a..323a0162 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ -20,16 +20,16 @@ package llmdinferencesim import ( "context" "encoding/json" + "errors" "fmt" "net" + "os" "strings" "sync" - "sync/atomic" "time" "github.com/buaazp/fasthttprouter" "github.com/go-logr/logr" - "github.com/google/uuid" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/valyala/fasthttp" @@ -46,6 +46,13 @@ const ( textCompletionObject = "text_completion" chatCompletionObject = "chat.completion" chatCompletionChunkObject = "chat.completion.chunk" + + podHeader = "x-inference-pod" + namespaceHeader = "x-inference-namespace" + podNameEnv = "POD_NAME" + podNsEnv = "POD_NAMESPACE" + + maxNumberOfRequests = 1000 ) // VllmSimulator simulates vLLM server supporting OpenAI API @@ -63,8 +70,12 @@ type VllmSimulator struct { waitingLoras sync.Map // nRunningReqs is the number of inference requests that are currently being processed nRunningReqs int64 + // runReqChan is a channel to update nRunningReqs + runReqChan chan int64 // nWaitingReqs is the number of inference requests that are waiting to be processed nWaitingReqs int64 + // waitingReqChan is a channel to update nWaitingReqs + waitingReqChan chan int64 // loraInfo is prometheus gauge loraInfo *prometheus.GaugeVec // runningRequests is prometheus gauge @@ -79,20 +90,28 @@ type VllmSimulator struct { toolsValidator *openaiserverapi.Validator // kv cache functionality kvcacheHelper *kvcache.KVCacheHelper + // namespace where simulator is running + namespace string + // pod name of simulator + pod string } // New creates a new VllmSimulator instance with the given logger func New(logger logr.Logger) (*VllmSimulator, error) { - toolsValidtor, err := openaiserverapi.CreateValidator() + toolsValidator, err := openaiserverapi.CreateValidator() if err != nil { return nil, fmt.Errorf("failed to create tools validator: %s", err) } return &VllmSimulator{ logger: logger, - reqChan: make(chan *openaiserverapi.CompletionReqCtx, 1000), - toolsValidator: toolsValidtor, + reqChan: make(chan *openaiserverapi.CompletionReqCtx, maxNumberOfRequests), + toolsValidator: toolsValidator, kvcacheHelper: nil, // kvcache helper will be created only if required after reading configuration + namespace: os.Getenv(podNsEnv), + pod: os.Getenv(podNameEnv), + runReqChan: make(chan int64, maxNumberOfRequests), + waitingReqChan: make(chan int64, maxNumberOfRequests), }, nil } @@ -103,8 +122,14 @@ func (s *VllmSimulator) Start(ctx context.Context) error { if err != nil { return err } + s.config = config + err = s.showConfig(s.logger) + if err != nil { + return err + } + for _, lora := range config.LoraModules { s.loraAdaptors.Store(lora.Name, "") } @@ -130,13 +155,16 @@ func (s *VllmSimulator) Start(ctx 
context.Context) error { for i := 1; i <= s.config.MaxNumSeqs; i++ { go s.reqProcessingWorker(ctx, i) } + + s.startMetricsUpdaters(ctx) + listener, err := s.newListener() if err != nil { return err } - // start the http server - return s.startServer(listener) + // start the http server with context support + return s.startServer(ctx, listener) } func (s *VllmSimulator) newListener() (net.Listener, error) { @@ -149,7 +177,7 @@ func (s *VllmSimulator) newListener() (net.Listener, error) { } // startServer starts http server on port defined in command line -func (s *VllmSimulator) startServer(listener net.Listener) error { +func (s *VllmSimulator) startServer(ctx context.Context, listener net.Listener) error { r := fasthttprouter.New() // support completion APIs @@ -172,13 +200,33 @@ func (s *VllmSimulator) startServer(listener net.Listener) error { Logger: s, } - defer func() { - if err := listener.Close(); err != nil { - s.logger.Error(err, "server listener close failed") - } + // Start server in a goroutine + serverErr := make(chan error, 1) + go func() { + s.logger.Info("HTTP server starting") + serverErr <- server.Serve(listener) }() - return server.Serve(listener) + // Wait for either context cancellation or server error + select { + case <-ctx.Done(): + s.logger.Info("Shutdown signal received, shutting down HTTP server gracefully") + + // Gracefully shutdown the server + if err := server.Shutdown(); err != nil { + s.logger.Error(err, "Error during server shutdown") + return err + } + + s.logger.Info("HTTP server stopped") + return nil + + case err := <-serverErr: + if err != nil { + s.logger.Error(err, "HTTP server failed") + } + return err + } } // Print prints to a log, implementation of fasthttp.Logger @@ -188,7 +236,7 @@ func (s *VllmSimulator) Printf(format string, args ...interface{}) { // readRequest reads and parses data from the body of the given request according the type defined by isChatCompletion func (s *VllmSimulator) readRequest(ctx *fasthttp.RequestCtx, isChatCompletion bool) (openaiserverapi.CompletionRequest, error) { - requestID := uuid.NewString() + requestID := common.GenerateUUIDString() if isChatCompletion { var req openaiserverapi.ChatCompletionRequest @@ -246,20 +294,20 @@ func (s *VllmSimulator) HandleUnloadLora(ctx *fasthttp.RequestCtx) { s.unloadLora(ctx) } -func (s *VllmSimulator) validateRequest(req openaiserverapi.CompletionRequest) (string, string, int) { +func (s *VllmSimulator) validateRequest(req openaiserverapi.CompletionRequest) (string, int) { if !s.isValidModel(req.GetModel()) { - return fmt.Sprintf("The model `%s` does not exist.", req.GetModel()), "NotFoundError", fasthttp.StatusNotFound + return fmt.Sprintf("The model `%s` does not exist.", req.GetModel()), fasthttp.StatusNotFound } if req.GetMaxCompletionTokens() != nil && *req.GetMaxCompletionTokens() <= 0 { - return "Max completion tokens and max tokens should be positive", "Invalid request", fasthttp.StatusBadRequest + return "Max completion tokens and max tokens should be positive", fasthttp.StatusBadRequest } if req.IsDoRemoteDecode() && req.IsStream() { - return "Prefill does not support streaming", "Invalid request", fasthttp.StatusBadRequest + return "Prefill does not support streaming", fasthttp.StatusBadRequest } - return "", "", fasthttp.StatusOK + return "", fasthttp.StatusOK } // isValidModel checks if the given model is the base model or one of "loaded" LoRAs @@ -291,6 +339,13 @@ func (s *VllmSimulator) isLora(model string) bool { // handleCompletions general completion requests 
handler, support both text and chat completion APIs func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatCompletion bool) { + // Check if we should inject a failure + if shouldInjectFailure(s.config) { + failure := getRandomFailure(s.config) + s.sendCompletionError(ctx, failure, true) + return + } + vllmReq, err := s.readRequest(ctx, isChatCompletion) if err != nil { s.logger.Error(err, "failed to read and parse request body") @@ -298,9 +353,9 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple return } - errMsg, errType, errCode := s.validateRequest(vllmReq) + errMsg, errCode := s.validateRequest(vllmReq) if errMsg != "" { - s.sendCompletionError(ctx, errMsg, errType, errCode) + s.sendCompletionError(ctx, openaiserverapi.NewCompletionError(errMsg, errCode, nil), false) return } @@ -327,8 +382,9 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple completionTokens := vllmReq.GetMaxCompletionTokens() isValid, actualCompletionTokens, totalTokens := common.ValidateContextWindow(promptTokens, completionTokens, s.config.MaxModelLen) if !isValid { - s.sendCompletionError(ctx, fmt.Sprintf("This model's maximum context length is %d tokens. However, you requested %d tokens (%d in the messages, %d in the completion). Please reduce the length of the messages or completion", - s.config.MaxModelLen, totalTokens, promptTokens, actualCompletionTokens), "BadRequestError", fasthttp.StatusBadRequest) + message := fmt.Sprintf("This model's maximum context length is %d tokens. However, you requested %d tokens (%d in the messages, %d in the completion). Please reduce the length of the messages or completion", + s.config.MaxModelLen, totalTokens, promptTokens, actualCompletionTokens) + s.sendCompletionError(ctx, openaiserverapi.NewCompletionError(message, fasthttp.StatusBadRequest, nil), false) return } @@ -340,9 +396,8 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple IsChatCompletion: isChatCompletion, Wg: &wg, } + s.waitingReqChan <- 1 s.reqChan <- reqCtx - atomic.StoreInt64(&(s.nWaitingReqs), int64(len(s.reqChan))) - s.reportWaitingRequests() wg.Wait() } @@ -357,8 +412,8 @@ func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) { s.logger.Info("reqProcessingWorker worker exiting: reqChan closed") return } - atomic.StoreInt64(&(s.nWaitingReqs), int64(len(s.reqChan))) - s.reportWaitingRequests() + + s.waitingReqChan <- -1 req := reqCtx.CompletionReq model := req.GetModel() @@ -381,8 +436,8 @@ func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) { // TODO - check if this request went to the waiting queue - add it to waiting map s.reportLoras() } - atomic.AddInt64(&(s.nRunningReqs), 1) - s.reportRunningRequests() + + s.runReqChan <- 1 var responseTokens []string var finishReason string @@ -453,9 +508,7 @@ func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) { // decrease model usage reference number func (s *VllmSimulator) responseSentCallback(model string) { - - atomic.AddInt64(&(s.nRunningReqs), -1) - s.reportRunningRequests() + s.runReqChan <- -1 // Only LoRA models require reference-count handling. 
if !s.isLora(model) { @@ -483,22 +536,25 @@ func (s *VllmSimulator) responseSentCallback(model string) { } // sendCompletionError sends an error response for the current completion request -func (s *VllmSimulator) sendCompletionError(ctx *fasthttp.RequestCtx, msg string, errType string, code int) { - compErr := openaiserverapi.CompletionError{ - Object: "error", - Message: msg, - Type: errType, - Code: code, - Param: nil, +// isInjected indicates if this is an injected failure for logging purposes +func (s *VllmSimulator) sendCompletionError(ctx *fasthttp.RequestCtx, + compErr openaiserverapi.CompletionError, isInjected bool) { + if isInjected { + s.logger.Info("Injecting failure", "type", compErr.Type, "message", compErr.Message) + } else { + s.logger.Error(nil, compErr.Message) } - s.logger.Error(nil, compErr.Message) - data, err := json.Marshal(compErr) + errorResp := openaiserverapi.ErrorResponse{ + Error: compErr, + } + + data, err := json.Marshal(errorResp) if err != nil { ctx.Error(err.Error(), fasthttp.StatusInternalServerError) } else { ctx.SetContentType("application/json") - ctx.SetStatusCode(code) + ctx.SetStatusCode(compErr.Code) ctx.SetBody(data) } } @@ -534,7 +590,7 @@ func (s *VllmSimulator) HandleError(_ *fasthttp.RequestCtx, err error) { func (s *VllmSimulator) createCompletionResponse(isChatCompletion bool, respTokens []string, toolCalls []openaiserverapi.ToolCall, finishReason *string, usageData *openaiserverapi.Usage, modelName string, doRemoteDecode bool) openaiserverapi.CompletionResponse { baseResp := openaiserverapi.BaseCompletionResponse{ - ID: chatComplIDPrefix + uuid.NewString(), + ID: chatComplIDPrefix + common.GenerateUUIDString(), Created: time.Now().Unix(), Model: modelName, Usage: usageData, @@ -599,9 +655,15 @@ func (s *VllmSimulator) sendResponse(isChatCompletion bool, ctx *fasthttp.Reques totalMillisToWait := s.getTimeToFirstToken(doRemotePrefill) + s.getTotalInterTokenLatency(numOfTokens) time.Sleep(time.Duration(totalMillisToWait) * time.Millisecond) - // TODO - maybe add pod id to response header for testing ctx.Response.Header.SetContentType("application/json") ctx.Response.Header.SetStatusCode(fasthttp.StatusOK) + // Add pod and namespace information to response headers for testing/debugging + if s.pod != "" { + ctx.Response.Header.Add(podHeader, s.pod) + } + if s.namespace != "" { + ctx.Response.Header.Add(namespaceHeader, s.namespace) + } ctx.Response.SetBody(data) s.responseSentCallback(modelName) @@ -691,3 +753,36 @@ func (s *VllmSimulator) getDisplayedModelName(reqModel string) string { } return s.config.ServedModelNames[0] } + +func (s *VllmSimulator) showConfig(tgtLgr logr.Logger) error { + if tgtLgr == logr.Discard() { + return errors.New("target logger is nil, cannot show configuration") + } + cfgJSON, err := json.Marshal(s.config) + if err != nil { + return fmt.Errorf("failed to marshal configuration to JSON: %w", err) + } + + // clean LoraModulesString field + var m map[string]interface{} + err = json.Unmarshal(cfgJSON, &m) + if err != nil { + return fmt.Errorf("failed to unmarshal JSON to map: %w", err) + } + m["lora-modules"] = m["LoraModules"] + delete(m, "LoraModules") + delete(m, "LoraModulesString") + + // clean fake-metrics field + if field, ok := m["fake-metrics"].(map[string]interface{}); ok { + delete(field, "LorasString") + } + + // show in JSON + cfgJSON, err = json.MarshalIndent(m, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal configuration to JSON: %w", err) + } + tgtLgr.Info("Configuration:", "", 
string(cfgJSON)) + return nil +} diff --git a/pkg/llm-d-inference-sim/simulator_test.go b/pkg/llm-d-inference-sim/simulator_test.go index fb8c0e8f..9e4c882b 100644 --- a/pkg/llm-d-inference-sim/simulator_test.go +++ b/pkg/llm-d-inference-sim/simulator_test.go @@ -25,6 +25,7 @@ import ( "net/http" "os" "strings" + "time" "github.com/llm-d/llm-d-inference-sim/pkg/common" . "github.com/onsi/ginkgo/v2" @@ -44,10 +45,16 @@ const invalidMaxTokensErrMsg = "Max completion tokens and max tokens should be p var userMsgTokens int64 func startServer(ctx context.Context, mode string) (*http.Client, error) { - return startServerWithArgs(ctx, mode, nil) + return startServerWithArgs(ctx, mode, nil, nil) } -func startServerWithArgs(ctx context.Context, mode string, args []string) (*http.Client, error) { +func startServerWithArgs(ctx context.Context, mode string, args []string, envs map[string]string) (*http.Client, error) { + _, client, err := startServerWithArgsAndMetrics(ctx, mode, args, envs, false) + return client, err +} + +func startServerWithArgsAndMetrics(ctx context.Context, mode string, args []string, envs map[string]string, + setMetrics bool) (*VllmSimulator, *http.Client, error) { oldArgs := os.Args defer func() { os.Args = oldArgs @@ -58,15 +65,30 @@ func startServerWithArgs(ctx context.Context, mode string, args []string) (*http } else { os.Args = []string{"cmd", "--model", model, "--mode", mode} } + + if envs != nil { + for k, v := range envs { + err := os.Setenv(k, v) + Expect(err).NotTo(HaveOccurred()) + } + + defer func() { + for k := range envs { + err := os.Unsetenv(k) + Expect(err).NotTo(HaveOccurred()) + } + }() + } + logger := klog.Background() s, err := New(logger) if err != nil { - return nil, err + return nil, nil, err } config, err := common.ParseCommandParamsAndLoadConfig() if err != nil { - return nil, err + return nil, nil, err } s.config = config @@ -76,6 +98,13 @@ func startServerWithArgs(ctx context.Context, mode string, args []string) (*http common.InitRandom(s.config.Seed) + if setMetrics { + err = s.createAndRegisterPrometheus() + if err != nil { + return nil, nil, err + } + } + // calculate number of tokens for user message, // must be activated after parseCommandParamsAndLoadConfig since it initializes the random engine userMsgTokens = int64(len(common.Tokenize(userMessage))) @@ -85,16 +114,18 @@ func startServerWithArgs(ctx context.Context, mode string, args []string) (*http go s.reqProcessingWorker(ctx, i) } + s.startMetricsUpdaters(ctx) + listener := fasthttputil.NewInmemoryListener() // start the http server go func() { - if err := s.startServer(listener); err != nil { + if err := s.startServer(ctx, listener); err != nil { logger.Error(err, "error starting server") } }() - return &http.Client{ + return s, &http.Client{ Transport: &http.Transport{ DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { return listener.Dial() @@ -402,12 +433,219 @@ var _ = Describe("Simulator", func() { Expect(resp.StatusCode).To(Equal(http.StatusOK)) }) + Context("namespace and pod headers", func() { + It("Should not include namespace and pod headers in chat completion response when env is not set", func() { + ctx := context.TODO() + + client, err := startServer(ctx, common.ModeRandom) + Expect(err).NotTo(HaveOccurred()) + + openaiclient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client)) + + params := openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + Model: 
model, + } + + var httpResp *http.Response + resp, err := openaiclient.Chat.Completions.New(ctx, params, option.WithResponseInto(&httpResp)) + Expect(err).NotTo(HaveOccurred()) + Expect(resp).NotTo(BeNil()) + + // Check for namespace and pod headers + namespaceHeader := httpResp.Header.Get(namespaceHeader) + podHeader := httpResp.Header.Get(podHeader) + + Expect(namespaceHeader).To(BeEmpty(), "Expected namespace header not to be present") + Expect(podHeader).To(BeEmpty(), "Expected pod header not to be present") + }) + + It("Should include namespace and pod headers in chat completion response", func() { + ctx := context.TODO() + + testNamespace := "test-namespace" + testPod := "test-pod" + envs := map[string]string{ + podNameEnv: testPod, + podNsEnv: testNamespace, + } + client, err := startServerWithArgs(ctx, common.ModeRandom, nil, envs) + Expect(err).NotTo(HaveOccurred()) + + openaiclient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client)) + + params := openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + Model: model, + } + + var httpResp *http.Response + resp, err := openaiclient.Chat.Completions.New(ctx, params, option.WithResponseInto(&httpResp)) + Expect(err).NotTo(HaveOccurred()) + Expect(resp).NotTo(BeNil()) + + // Check for namespace and pod headers + namespaceHeader := httpResp.Header.Get(namespaceHeader) + podHeader := httpResp.Header.Get(podHeader) + + Expect(namespaceHeader).To(Equal(testNamespace), "Expected namespace header to be present") + Expect(podHeader).To(Equal(testPod), "Expected pod header to be present") + }) + + It("Should include namespace and pod headers in chat completion streaming response", func() { + ctx := context.TODO() + + testNamespace := "stream-test-namespace" + testPod := "stream-test-pod" + envs := map[string]string{ + podNameEnv: testPod, + podNsEnv: testNamespace, + } + client, err := startServerWithArgs(ctx, common.ModeRandom, nil, envs) + Expect(err).NotTo(HaveOccurred()) + + openaiclient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client)) + + params := openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + Model: model, + StreamOptions: openai.ChatCompletionStreamOptionsParam{IncludeUsage: param.NewOpt(true)}, + } + + var httpResp *http.Response + resp, err := openaiclient.Chat.Completions.New(ctx, params, option.WithResponseInto(&httpResp)) + Expect(err).NotTo(HaveOccurred()) + Expect(resp).NotTo(BeNil()) + + // Check for namespace and pod headers + namespaceHeader := httpResp.Header.Get(namespaceHeader) + podHeader := httpResp.Header.Get(podHeader) + + Expect(namespaceHeader).To(Equal(testNamespace), "Expected namespace header to be present") + Expect(podHeader).To(Equal(testPod), "Expected pod header to be present") + }) + + It("Should not include namespace and pod headers in chat completion streaming response when env is not set", func() { + ctx := context.TODO() + + client, err := startServer(ctx, common.ModeRandom) + Expect(err).NotTo(HaveOccurred()) + + openaiclient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client)) + + params := openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + Model: model, + StreamOptions: openai.ChatCompletionStreamOptionsParam{IncludeUsage: param.NewOpt(true)}, + } + + var httpResp 
*http.Response + resp, err := openaiclient.Chat.Completions.New(ctx, params, option.WithResponseInto(&httpResp)) + Expect(err).NotTo(HaveOccurred()) + Expect(resp).NotTo(BeNil()) + + // Check for namespace and pod headers + namespaceHeader := httpResp.Header.Get(namespaceHeader) + podHeader := httpResp.Header.Get(podHeader) + + Expect(namespaceHeader).To(BeEmpty(), "Expected namespace header not to be present") + Expect(podHeader).To(BeEmpty(), "Expected pod header not to be present") + }) + + It("Should include namespace and pod headers in completion response", func() { + ctx := context.TODO() + + testNamespace := "test-namespace" + testPod := "test-pod" + envs := map[string]string{ + podNameEnv: testPod, + podNsEnv: testNamespace, + } + client, err := startServerWithArgs(ctx, common.ModeRandom, nil, envs) + Expect(err).NotTo(HaveOccurred()) + + openaiclient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client)) + + params := openai.CompletionNewParams{ + Prompt: openai.CompletionNewParamsPromptUnion{ + OfString: openai.String(userMessage), + }, + Model: openai.CompletionNewParamsModel(model), + } + var httpResp *http.Response + resp, err := openaiclient.Completions.New(ctx, params, option.WithResponseInto(&httpResp)) + Expect(err).NotTo(HaveOccurred()) + Expect(resp).NotTo(BeNil()) + + // Check for namespace and pod headers + namespaceHeader := httpResp.Header.Get(namespaceHeader) + podHeader := httpResp.Header.Get(podHeader) + + Expect(namespaceHeader).To(Equal(testNamespace), "Expected namespace header to be present") + Expect(podHeader).To(Equal(testPod), "Expected pod header to be present") + }) + + It("Should include namespace and pod headers in completion streaming response", func() { + ctx := context.TODO() + + testNamespace := "stream-test-namespace" + testPod := "stream-test-pod" + envs := map[string]string{ + podNameEnv: testPod, + podNsEnv: testNamespace, + } + client, err := startServerWithArgs(ctx, common.ModeRandom, nil, envs) + Expect(err).NotTo(HaveOccurred()) + + openaiclient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client)) + + params := openai.CompletionNewParams{ + Prompt: openai.CompletionNewParamsPromptUnion{ + OfString: openai.String(userMessage), + }, + Model: openai.CompletionNewParamsModel(model), + StreamOptions: openai.ChatCompletionStreamOptionsParam{IncludeUsage: param.NewOpt(true)}, + } + var httpResp *http.Response + resp, err := openaiclient.Completions.New(ctx, params, option.WithResponseInto(&httpResp)) + Expect(err).NotTo(HaveOccurred()) + Expect(resp).NotTo(BeNil()) + + // Check for namespace and pod headers + namespaceHeader := httpResp.Header.Get(namespaceHeader) + podHeader := httpResp.Header.Get(podHeader) + + Expect(namespaceHeader).To(Equal(testNamespace), "Expected namespace header to be present") + Expect(podHeader).To(Equal(testPod), "Expected pod header to be present") + }) + }) + Context("max-model-len context window validation", func() { It("Should reject requests exceeding context window", func() { ctx := context.TODO() // Start server with max-model-len=10 args := []string{"cmd", "--model", model, "--mode", common.ModeRandom, "--max-model-len", "10"} - client, err := startServerWithArgs(ctx, common.ModeRandom, args) + client, err := startServerWithArgs(ctx, common.ModeRandom, args, nil) Expect(err).NotTo(HaveOccurred()) // Test with raw HTTP to verify the error response format @@ -457,7 +695,7 @@ var _ = Describe("Simulator", func() { ctx := context.TODO() // Start server 
with max-model-len=50 args := []string{"cmd", "--model", model, "--mode", common.ModeEcho, "--max-model-len", "50"} - client, err := startServerWithArgs(ctx, common.ModeEcho, args) + client, err := startServerWithArgs(ctx, common.ModeEcho, args, nil) Expect(err).NotTo(HaveOccurred()) openaiclient := openai.NewClient( @@ -483,7 +721,7 @@ var _ = Describe("Simulator", func() { ctx := context.TODO() // Start server with max-model-len=10 args := []string{"cmd", "--model", model, "--mode", common.ModeRandom, "--max-model-len", "10"} - client, err := startServerWithArgs(ctx, common.ModeRandom, args) + client, err := startServerWithArgs(ctx, common.ModeRandom, args, nil) Expect(err).NotTo(HaveOccurred()) // Test with raw HTTP for text completion @@ -523,6 +761,8 @@ var _ = Describe("Simulator", func() { KVCacheTransferLatency: 2048, KVCacheTransferLatencyStdDev: 2048, } + + common.InitRandom(time.Now().UnixNano()) }) DescribeTable("should calculate inter token latency correctly", @@ -530,8 +770,8 @@ var _ = Describe("Simulator", func() { simulator.config.InterTokenLatency = interTokenLatency simulator.config.InterTokenLatencyStdDev = stddev interToken := simulator.getInterTokenLatency() - Expect(interToken).To(BeNumerically(">=", float32(interTokenLatency)*0.3)) - Expect(interToken).To(BeNumerically("<=", float32(interTokenLatency)*1.7)) + Expect(interToken).To(BeNumerically(">=", int(float32(interTokenLatency)*0.3))) + Expect(interToken).To(BeNumerically("<=", int(float32(interTokenLatency)*1.7))) }, func(interTokenLatency int, stddev int) string { return fmt.Sprintf("interTokenLatency: %d stddev: %d", interTokenLatency, stddev) @@ -547,8 +787,8 @@ var _ = Describe("Simulator", func() { simulator.config.InterTokenLatency = interTokenLatency simulator.config.InterTokenLatencyStdDev = stddev latency := simulator.getTotalInterTokenLatency(numberOfTokens) - Expect(latency).To(BeNumerically(">=", float32(interTokenLatency)*0.3*float32(numberOfTokens))) - Expect(latency).To(BeNumerically("<=", float32(interTokenLatency)*1.7*float32(numberOfTokens))) + Expect(latency).To(BeNumerically(">=", int(float32(interTokenLatency)*0.3*float32(numberOfTokens)))) + Expect(latency).To(BeNumerically("<=", int(float32(interTokenLatency)*1.7*float32(numberOfTokens)))) }, func(interTokenLatency int, stddev int, numberOfTokens int) string { return fmt.Sprintf("interTokenLatency: %d stddev: %d, numberOfTokens: %d", interTokenLatency, @@ -569,11 +809,11 @@ var _ = Describe("Simulator", func() { simulator.config.KVCacheTransferLatencyStdDev = kvCacheLatencyStdDev timeToFirst := simulator.getTimeToFirstToken(doREmotePrefill) if doREmotePrefill { - Expect(timeToFirst).To(BeNumerically(">=", float32(kvCacheLatency)*0.3)) - Expect(timeToFirst).To(BeNumerically("<=", float32(kvCacheLatency)*1.7)) + Expect(timeToFirst).To(BeNumerically(">=", int(float32(kvCacheLatency)*0.3))) + Expect(timeToFirst).To(BeNumerically("<=", int(float32(kvCacheLatency)*1.7))) } else { - Expect(timeToFirst).To(BeNumerically(">=", float32(timeToFirstToken)*0.3)) - Expect(timeToFirst).To(BeNumerically("<=", float32(timeToFirstToken)*1.7)) + Expect(timeToFirst).To(BeNumerically(">=", int(float32(timeToFirstToken)*0.3))) + Expect(timeToFirst).To(BeNumerically("<=", int(float32(timeToFirstToken)*1.7))) } }, func(timeToFirstToken int, timeToFirstTokenStdDev int, diff --git a/pkg/llm-d-inference-sim/streaming.go b/pkg/llm-d-inference-sim/streaming.go index b173924c..969f29af 100644 --- a/pkg/llm-d-inference-sim/streaming.go +++ 
b/pkg/llm-d-inference-sim/streaming.go @@ -22,7 +22,6 @@ import ( "fmt" "time" - "github.com/google/uuid" "github.com/llm-d/llm-d-inference-sim/pkg/common" openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api" "github.com/valyala/fasthttp" @@ -45,6 +44,14 @@ func (s *VllmSimulator) sendStreamingResponse(context *streamingContext, respons context.ctx.SetContentType("text/event-stream") context.ctx.SetStatusCode(fasthttp.StatusOK) + // Add pod and namespace information to response headers for testing/debugging + if s.pod != "" { + context.ctx.Response.Header.Add(podHeader, s.pod) + } + if s.namespace != "" { + context.ctx.Response.Header.Add(namespaceHeader, s.namespace) + } + context.ctx.SetBodyStreamWriter(func(w *bufio.Writer) { context.creationTime = time.Now().Unix() @@ -146,7 +153,7 @@ func (s *VllmSimulator) sendTokenChunks(context *streamingContext, w *bufio.Writ // supports both modes (text and chat) func (s *VllmSimulator) createUsageChunk(context *streamingContext, usageData *openaiserverapi.Usage) openaiserverapi.CompletionRespChunk { baseChunk := openaiserverapi.BaseCompletionResponse{ - ID: chatComplIDPrefix + uuid.NewString(), + ID: chatComplIDPrefix + common.GenerateUUIDString(), Created: context.creationTime, Model: context.model, Usage: usageData, @@ -171,7 +178,7 @@ func (s *VllmSimulator) createUsageChunk(context *streamingContext, usageData *o func (s *VllmSimulator) createTextCompletionChunk(context *streamingContext, token string, finishReason *string) openaiserverapi.CompletionRespChunk { return &openaiserverapi.TextCompletionResponse{ BaseCompletionResponse: openaiserverapi.BaseCompletionResponse{ - ID: chatComplIDPrefix + uuid.NewString(), + ID: chatComplIDPrefix + common.GenerateUUIDString(), Created: context.creationTime, Model: context.model, Object: textCompletionObject, @@ -191,7 +198,7 @@ func (s *VllmSimulator) createChatCompletionChunk(context *streamingContext, tok role string, finishReason *string) openaiserverapi.CompletionRespChunk { chunk := openaiserverapi.ChatCompletionRespChunk{ BaseCompletionResponse: openaiserverapi.BaseCompletionResponse{ - ID: chatComplIDPrefix + uuid.NewString(), + ID: chatComplIDPrefix + common.GenerateUUIDString(), Created: context.creationTime, Model: context.model, Object: chatCompletionChunkObject, diff --git a/pkg/llm-d-inference-sim/tools_test.go b/pkg/llm-d-inference-sim/tools_test.go index b4e1d009..c996db59 100644 --- a/pkg/llm-d-inference-sim/tools_test.go +++ b/pkg/llm-d-inference-sim/tools_test.go @@ -525,7 +525,7 @@ var _ = Describe("Simulator for request with tools", func() { "--min-tool-call-number-param", fmt.Sprint(min), "--max-tool-call-number-param", fmt.Sprint(max), } - client, err := startServerWithArgs(ctx, common.ModeEcho, serverArgs) + client, err := startServerWithArgs(ctx, common.ModeEcho, serverArgs, nil) Expect(err).NotTo(HaveOccurred()) openaiclient := openai.NewClient( @@ -787,7 +787,7 @@ var _ = Describe("Simulator for request with tools", func() { serverArgs := []string{"cmd", "--model", model, "--mode", common.ModeRandom, "--tool-call-not-required-param-probability", strconv.Itoa(probability), } - client, err := startServerWithArgs(ctx, common.ModeEcho, serverArgs) + client, err := startServerWithArgs(ctx, common.ModeEcho, serverArgs, nil) Expect(err).NotTo(HaveOccurred()) openaiclient := openai.NewClient( @@ -832,7 +832,7 @@ var _ = Describe("Simulator for request with tools", func() { "--min-tool-call-integer-param", strconv.Itoa(min), "--max-tool-call-integer-param", 
strconv.Itoa(max), } - client, err := startServerWithArgs(ctx, common.ModeEcho, serverArgs) + client, err := startServerWithArgs(ctx, common.ModeEcho, serverArgs, nil) Expect(err).NotTo(HaveOccurred()) openaiclient := openai.NewClient( diff --git a/pkg/openai-server-api/response.go b/pkg/openai-server-api/response.go index a8f4a652..d32784e3 100644 --- a/pkg/openai-server-api/response.go +++ b/pkg/openai-server-api/response.go @@ -21,6 +21,8 @@ import ( "encoding/json" "errors" "strings" + + "github.com/valyala/fasthttp" ) // CompletionResponse interface representing both completion response types (text and chat) @@ -208,14 +210,53 @@ type ChatRespChunkChoice struct { // CompletionError defines the simulator's response in case of an error type CompletionError struct { - // Object is a type of this Object, "error" - Object string `json:"object"` // Message is an error Message Message string `json:"message"` // Type is a type of the error Type string `json:"type"` - // Params is the error's parameters + // Param is the error's parameter Param *string `json:"param"` - // Code is http status Code - Code int `json:"code"` + // Code is the error code integer (same as HTTP status code) + Code int `json:"code,omitempty"` +} + +// NewCompletionError creates a new CompletionError +func NewCompletionError(message string, code int, param *string) CompletionError { + return CompletionError{ + Message: message, + Code: code, + Type: ErrorCodeToType(code), + Param: param, + } +} + +// ErrorResponse wraps the error in the expected OpenAI format +type ErrorResponse struct { + Error CompletionError `json:"error"` +} + +// ErrorCodeToType maps error code to error type according to https://www.npmjs.com/package/openai +func ErrorCodeToType(code int) string { + errorType := "" + switch code { + case fasthttp.StatusBadRequest: + errorType = "BadRequestError" + case fasthttp.StatusUnauthorized: + errorType = "AuthenticationError" + case fasthttp.StatusForbidden: + errorType = "PermissionDeniedError" + case fasthttp.StatusNotFound: + errorType = "NotFoundError" + case fasthttp.StatusUnprocessableEntity: + errorType = "UnprocessableEntityError" + case fasthttp.StatusTooManyRequests: + errorType = "RateLimitError" + default: + if code >= fasthttp.StatusInternalServerError { + errorType = "InternalServerError" + } else { + errorType = "APIConnectionError" + } + } + return errorType }
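A note on the reworked error path: sendCompletionError now serializes the OpenAI-style envelope, i.e. a top-level "error" object built from CompletionError via NewCompletionError, with the type string derived from the HTTP status code by ErrorCodeToType. A minimal sketch of the resulting JSON, using only the helpers added in pkg/openai-server-api/response.go (the standalone main wrapper is illustrative, not part of the change):

package main

import (
	"encoding/json"
	"fmt"

	openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api"
)

func main() {
	// Build a 429 error the same way an injected rate-limit failure would be reported.
	compErr := openaiserverapi.NewCompletionError("rate limit exceeded", 429, nil)

	data, err := json.Marshal(openaiserverapi.ErrorResponse{Error: compErr})
	if err != nil {
		panic(err)
	}
	fmt.Println(string(data))
	// {"error":{"message":"rate limit exceeded","type":"RateLimitError","param":null,"code":429}}

	// The type string depends only on the status code.
	fmt.Println(openaiserverapi.ErrorCodeToType(503)) // InternalServerError
}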
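The failure-injection flags exercised above (--failure-injection-rate, --failure-types) can be verified from any OpenAI-compatible client, not only from the Ginkgo suite. A minimal sketch against a simulator assumed to be listening on http://localhost:8000/v1 and started with a 100% injection rate limited to the rate-limit failure type (URL, port and model name are placeholders; the openai-go calls mirror the tests):

package main

import (
	"context"
	"errors"
	"fmt"
	"log"

	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
)

func main() {
	// Placeholder base URL; point it at wherever the simulator is listening.
	client := openai.NewClient(option.WithBaseURL("http://localhost:8000/v1"))

	_, err := client.Chat.Completions.New(context.Background(), openai.ChatCompletionNewParams{
		Model: "test-model", // placeholder model name
		Messages: []openai.ChatCompletionMessageParamUnion{
			openai.UserMessage("hello"),
		},
	})

	var apiErr *openai.Error
	if !errors.As(err, &apiErr) {
		log.Fatalf("expected an injected failure, got: %v", err)
	}
	// With only the rate-limit failure type enabled, every request fails with 429 / RateLimitError.
	fmt.Println(apiErr.StatusCode, apiErr.Type, apiErr.Message)
}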
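The --fake-metrics flag takes a single JSON object whose gauge values are reported verbatim instead of the live counters; the field names and the exposition lines below are copied from the fake-metrics test. A small sketch of the flag value and the /metrics output it produces (the displayed model name is whatever the served model name resolves to, my_model in the test setup):

package main

import "fmt"

func main() {
	// Flag value copied from the fake-metrics test: kv-cache-usage is a fraction,
	// LoRA timestamps are Unix seconds.
	fakeMetrics := `{"running-requests":10,"waiting-requests":30,"kv-cache-usage":0.4,` +
		`"loras":[{"running":"lora4,lora2","waiting":"lora3","timestamp":1257894567}]}`

	// Passed to the simulator as: --fake-metrics '<json>'
	fmt.Println(fakeMetrics)

	// With these values /metrics reports, among others:
	//   vllm:num_requests_running{model_name="my_model"} 10
	//   vllm:num_requests_waiting{model_name="my_model"} 30
	//   vllm:gpu_cache_usage_perc{model_name="my_model"} 0.4
	//   vllm:lora_requests_info{max_lora="1",running_lora_adapters="lora4,lora2",waiting_lora_adapters="lora3"} 1.257894567e+09
}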
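The x-inference-pod / x-inference-namespace headers are emitted only when the POD_NAME and POD_NAMESPACE environment variables are set on the simulator (typically injected through the Kubernetes downward API); otherwise they are omitted, as the negative tests check. A minimal sketch of reading them with openai-go's option.WithResponseInto, the same mechanism the tests use (base URL and model name are placeholders):

package main

import (
	"context"
	"fmt"
	"net/http"

	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
)

func main() {
	client := openai.NewClient(option.WithBaseURL("http://localhost:8000/v1")) // placeholder URL

	var httpResp *http.Response
	_, err := client.Chat.Completions.New(context.Background(), openai.ChatCompletionNewParams{
		Model: "test-model", // placeholder model name
		Messages: []openai.ChatCompletionMessageParamUnion{
			openai.UserMessage("hello"),
		},
	}, option.WithResponseInto(&httpResp))
	if err != nil {
		panic(err)
	}

	// Empty values mean POD_NAME / POD_NAMESPACE were not set on the simulator side.
	fmt.Println("pod:", httpResp.Header.Get("x-inference-pod"))
	fmt.Println("namespace:", httpResp.Header.Get("x-inference-namespace"))
}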