From 638d0f74c554850b000a8465a4322ead5a9990ec Mon Sep 17 00:00:00 2001 From: Shmuel Kallner Date: Wed, 6 Aug 2025 16:17:19 +0300 Subject: [PATCH 01/46] Add definition of new action input (#123) Signed-off-by: Shmuel Kallner Signed-off-by: Sergey Marunich --- .github/actions/docker-build-and-push/action.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/actions/docker-build-and-push/action.yml b/.github/actions/docker-build-and-push/action.yml index ffeeabe6..ec923e03 100644 --- a/.github/actions/docker-build-and-push/action.yml +++ b/.github/actions/docker-build-and-push/action.yml @@ -13,6 +13,10 @@ inputs: registry: required: true description: Container registry (e.g., ghcr.io/llm-d) + prerelease: + required: true + description: indicates whether or not this is a pre-release (not a release) build + runs: using: "composite" steps: From 9ffe9574dfcaa38846bb8faaa41d0c47c1f38559 Mon Sep 17 00:00:00 2001 From: Ira Rosen Date: Thu, 7 Aug 2025 11:01:16 +0300 Subject: [PATCH 02/46] KV cache and tokenization related configuration (#125) Signed-off-by: Ira Signed-off-by: Sergey Marunich --- README.md | 7 +++++- pkg/common/config.go | 37 ++++++++++++++++++++++++++++ pkg/common/config_test.go | 10 ++++++++ pkg/kv-cache/kv_cache.go | 18 ++++++++------ pkg/llm-d-inference-sim/simulator.go | 2 +- 5 files changed, 64 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 7628756b..fb6636c3 100644 --- a/README.md +++ b/README.md @@ -117,7 +117,12 @@ For more details see the 100 { return errors.New("ObjectToolCallNotRequiredParamProbability should be between 0 and 100") } + + if c.TokenBlockSize != 8 && c.TokenBlockSize != 16 && c.TokenBlockSize != 32 && + c.TokenBlockSize != 64 && c.TokenBlockSize != 128 { + return errors.New("token block size should be one of the following: 8, 16, 32, 64, 128") + } + + if c.KVCacheSize < 0 { + return errors.New("KV cache size cannot be negative") + } return nil } @@ -313,7 +337,13 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) { f.IntVar(&config.MinToolCallArrayParamLength, "min-tool-call-array-param-length", config.MinToolCallArrayParamLength, "Minimum possible length of array parameters in a tool call") f.IntVar(&config.ToolCallNotRequiredParamProbability, "tool-call-not-required-param-probability", config.ToolCallNotRequiredParamProbability, "Probability to add a parameter, that is not required, in a tool call") f.IntVar(&config.ObjectToolCallNotRequiredParamProbability, "object-tool-call-not-required-field-probability", config.ObjectToolCallNotRequiredParamProbability, "Probability to add a field, that is not required, in an object in a tool call") + f.BoolVar(&config.EnableKVCache, "enable-kvcache", config.EnableKVCache, "Defines if KV cache feature is enabled") + f.IntVar(&config.KVCacheSize, "kv-cache-size", config.KVCacheSize, "Maximum number of token blocks in kv cache") + f.IntVar(&config.TokenBlockSize, "block-size", config.TokenBlockSize, "Token block size for contiguous chunks of tokens, possible values: 8,16,32,64,128") + f.StringVar(&config.TokenizersCacheDir, "tokenizers-cache-dir", config.TokenizersCacheDir, "Directory for caching tokenizers") + f.StringVar(&config.HashSeed, "hash-seed", config.HashSeed, "Seed for hash generation (if not set, is read from PYTHONHASHSEED environment variable)") + f.StringVar(&config.ZMQEndpoint, "zmq-endpoint", config.ZMQEndpoint, "ZMQ address to publish events") // These values were manually parsed above in getParamValueFromArgs, we leave this in order to get these 
flags in --help var dummyString string @@ -348,6 +378,13 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) { config.ServedModelNames = servedModelNames } + if config.HashSeed == "" { + hashSeed := os.Getenv("PYTHONHASHSEED") + if hashSeed != "" { + config.HashSeed = hashSeed + } + } + if err := config.validate(); err != nil { return nil, err } diff --git a/pkg/common/config_test.go b/pkg/common/config_test.go index a813138e..2fd8446a 100644 --- a/pkg/common/config_test.go +++ b/pkg/common/config_test.go @@ -281,6 +281,16 @@ var _ = Describe("Simulator configuration", func() { args: []string{"cmd", "--kv-cache-transfer-latency-std-dev", "-35", "--config", "../../manifests/config.yaml"}, }, + { + name: "invalid (negative) kv-cache-size", + args: []string{"cmd", "--kv-cache-size", "-35", + "--config", "../../manifests/config.yaml"}, + }, + { + name: "invalid block-size", + args: []string{"cmd", "--block-size", "35", + "--config", "../../manifests/config.yaml"}, + }, } for _, test := range invalidTests { diff --git a/pkg/kv-cache/kv_cache.go b/pkg/kv-cache/kv_cache.go index 7493d747..e68f750d 100644 --- a/pkg/kv-cache/kv_cache.go +++ b/pkg/kv-cache/kv_cache.go @@ -21,16 +21,12 @@ import ( "fmt" "github.com/go-logr/logr" + "github.com/llm-d/llm-d-inference-sim/pkg/common" openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api" "github.com/llm-d/llm-d-kv-cache-manager/pkg/kvcache/kvblock" "github.com/llm-d/llm-d-kv-cache-manager/pkg/tokenization" ) -const ( - // TODO move it to configuration - maxBlocks = 100 -) - type KVCacheHelper struct { tokenizer tokenization.Tokenizer tokensProcessor kvblock.TokenProcessor // turns tokens to kv block keys @@ -38,12 +34,18 @@ type KVCacheHelper struct { blockCache *blockCache } -func NewKVCacheHelper(logger logr.Logger) (*KVCacheHelper, error) { - // TODO update config by command line params +func NewKVCacheHelper(config *common.Configuration, logger logr.Logger) (*KVCacheHelper, error) { tokenProcConfig := kvblock.DefaultTokenProcessorConfig() + tokenProcConfig.BlockSize = config.TokenBlockSize + if config.HashSeed != "" { + tokenProcConfig.HashSeed = config.HashSeed + } tokensProcessor := kvblock.NewChunkedTokenDatabase(tokenProcConfig) tokenizationConfig := tokenization.DefaultConfig() + if config.TokenizersCacheDir != "" { + tokenizationConfig.TokenizersCacheDir = config.TokenizersCacheDir + } tokenizer, err := tokenization.NewCachedHFTokenizer(tokenizationConfig.HFTokenizerConfig) if err != nil { @@ -53,7 +55,7 @@ func NewKVCacheHelper(logger logr.Logger) (*KVCacheHelper, error) { return &KVCacheHelper{ tokenizer: tokenizer, tokensProcessor: tokensProcessor, - blockCache: newBlockCache(maxBlocks, logger), + blockCache: newBlockCache(config.KVCacheSize, logger), logger: logger, }, nil } diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index b1d9ca54..77b0738a 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ -118,7 +118,7 @@ func (s *VllmSimulator) Start(ctx context.Context) error { } if s.config.EnableKVCache { - s.kvcacheHelper, err = kvcache.NewKVCacheHelper(s.logger) + s.kvcacheHelper, err = kvcache.NewKVCacheHelper(s.config, s.logger) if err != nil { return err } From a5a7d81c311b01b1ddd1015ba5bf7a79052ab0fb Mon Sep 17 00:00:00 2001 From: Shmuel Kallner Date: Thu, 7 Aug 2025 11:01:36 +0300 Subject: [PATCH 03/46] Another attempt at adding a latest tag only on release builds (#124) Signed-off-by: Shmuel Kallner Signed-off-by: 
Sergey Marunich --- .github/actions/docker-build-and-push/action.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/actions/docker-build-and-push/action.yml b/.github/actions/docker-build-and-push/action.yml index ec923e03..be358ece 100644 --- a/.github/actions/docker-build-and-push/action.yml +++ b/.github/actions/docker-build-and-push/action.yml @@ -36,7 +36,7 @@ runs: - name: Build image run: | - if [[ ${{ inputs.prerelease }} ]]; then + if [[ ${{ inputs.prerelease }} == "true" ]]; then LATEST_TAG="" else LATEST_TAG="-t ${{ inputs.registry }}/${{ inputs.image-name }}:latest" From 951f4a38ed25afa821748e11cf91094ebf37dcc8 Mon Sep 17 00:00:00 2001 From: Ira Rosen Date: Tue, 12 Aug 2025 10:05:24 +0300 Subject: [PATCH 04/46] Publish kv-cache events (#126) * Publish kv-cache events Signed-off-by: Ira * Fix lint errors Signed-off-by: Ira * Review fixes Signed-off-by: Ira * Sleep to allow prevous sub to close Signed-off-by: Ira --------- Signed-off-by: Ira Signed-off-by: Sergey Marunich --- Makefile | 2 +- go.mod | 7 +- go.sum | 29 +- pkg/common/config.go | 7 + pkg/common/config_test.go | 7 + pkg/common/publisher.go | 11 +- pkg/common/publisher_test.go | 4 +- pkg/kv-cache/block_cache.go | 24 +- pkg/kv-cache/block_cache_test.go | 329 ------------------- pkg/kv-cache/kv_cache.go | 8 +- pkg/kv-cache/kv_cache_sender.go | 38 ++- pkg/kv-cache/kv_cache_test.go | 537 +++++++++++++++++++++++++++++++ 12 files changed, 614 insertions(+), 389 deletions(-) delete mode 100644 pkg/kv-cache/block_cache_test.go create mode 100644 pkg/kv-cache/kv_cache_test.go diff --git a/Makefile b/Makefile index b546a1f4..0f761f54 100644 --- a/Makefile +++ b/Makefile @@ -46,7 +46,7 @@ $(TOKENIZER_LIB): ## Download the HuggingFace tokenizer bindings. @echo "Downloading HuggingFace tokenizer bindings..." 
mkdir -p lib - curl -L https://github.com/daulet/tokenizers/releases/download/v1.20.2/libtokenizers.$(TARGETOS)-$(TARGETARCH).tar.gz | tar -xz -C lib + curl -L https://github.com/daulet/tokenizers/releases/download/v1.22.1/libtokenizers.$(TARGETOS)-$(TARGETARCH).tar.gz | tar -xz -C lib ranlib lib/*.a ##@ Development diff --git a/go.mod b/go.mod index 613112f8..65cbb3fb 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( github.com/buaazp/fasthttprouter v0.1.1 github.com/go-logr/logr v1.4.2 github.com/google/uuid v1.6.0 - github.com/llm-d/llm-d-kv-cache-manager v0.2.0 + github.com/llm-d/llm-d-kv-cache-manager v0.2.2-0.20250810103202-0adf0940f60a github.com/onsi/ginkgo/v2 v2.23.4 github.com/onsi/gomega v1.37.0 github.com/openai/openai-go v0.1.0-beta.10 @@ -17,7 +17,6 @@ require ( github.com/santhosh-tekuri/jsonschema/v5 v5.3.1 github.com/spf13/pflag v1.0.6 github.com/valyala/fasthttp v1.59.0 - github.com/vmihailenco/msgpack v4.0.4+incompatible github.com/vmihailenco/msgpack/v5 v5.4.1 gopkg.in/yaml.v3 v3.0.1 k8s.io/klog/v2 v2.130.1 @@ -27,7 +26,7 @@ require ( github.com/andybalholm/brotli v1.1.1 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect - github.com/daulet/tokenizers v1.20.2 // indirect + github.com/daulet/tokenizers v1.22.1 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/emicklei/go-restful/v3 v3.11.0 // indirect @@ -37,7 +36,6 @@ require ( github.com/go-openapi/swag v0.23.0 // indirect github.com/go-task/slim-sprig/v3 v3.0.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect - github.com/golang/protobuf v1.5.2 // indirect github.com/google/gnostic-models v0.6.9 // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/google/pprof v0.0.0-20250403155104-27863c87afa6 // indirect @@ -69,7 +67,6 @@ require ( golang.org/x/text v0.23.0 // indirect golang.org/x/time v0.9.0 // indirect golang.org/x/tools v0.31.0 // indirect - google.golang.org/appengine v1.6.8 // indirect google.golang.org/protobuf v1.36.5 // indirect gopkg.in/evanphx/json-patch.v4 v4.12.0 // indirect gopkg.in/inf.v0 v0.9.1 // indirect diff --git a/go.sum b/go.sum index b8be9af0..93f6363c 100644 --- a/go.sum +++ b/go.sum @@ -13,6 +13,8 @@ github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/daulet/tokenizers v1.20.2 h1:tlq/vIOiBTKDPets3596aFvmJYLn3XI6LFKq4q9LKhQ= github.com/daulet/tokenizers v1.20.2/go.mod h1:tGnMdZthXdcWY6DGD07IygpwJqiPvG85FQUnhs/wSCs= +github.com/daulet/tokenizers v1.22.1 h1:3wzAFIxfgRuqGKka8xdkeTbctDmmqOOs12GofqdorpM= +github.com/daulet/tokenizers v1.22.1/go.mod h1:tGnMdZthXdcWY6DGD07IygpwJqiPvG85FQUnhs/wSCs= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= @@ -37,12 +39,8 @@ github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1v github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= -github.com/golang/protobuf 
v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= -github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= -github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/google/gnostic-models v0.6.9 h1:MU/8wDLif2qCXZmzncUQ/BOfxWfthHi63KqpoNbWqVw= github.com/google/gnostic-models v0.6.9/go.mod h1:CiWsm0s6BSQd1hRn8/QmxqB6BesYcbSZxsz9b0KuDBw= -github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= @@ -72,6 +70,8 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0 github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/llm-d/llm-d-kv-cache-manager v0.2.0 h1:7MXFPjy3P8nZ7HbB1LWhhVLHvNTLbZglkD/ZcT7UU1k= github.com/llm-d/llm-d-kv-cache-manager v0.2.0/go.mod h1:ZTqwsnIVC6R5YuTUrYofPIUnCeZ9RvXn1UQAdxLYl1Y= +github.com/llm-d/llm-d-kv-cache-manager v0.2.2-0.20250810103202-0adf0940f60a h1:PXR37HLgYYfolzWQA2uQOEiJlj3IV9YSvgaEFqCRSa8= +github.com/llm-d/llm-d-kv-cache-manager v0.2.2-0.20250810103202-0adf0940f60a/go.mod h1:g2UlYKNJ4S860SAQ/QoRnytAFfnp8f1luW4IuZSMwCE= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -137,8 +137,6 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasthttp v1.59.0 h1:Qu0qYHfXvPk1mSLNqcFtEk6DpxgA26hy6bmydotDpRI= github.com/valyala/fasthttp v1.59.0/go.mod h1:GTxNb9Bc6r2a9D0TWNSPwDz78UxnTGBViY3xZNEqyYU= -github.com/vmihailenco/msgpack v4.0.4+incompatible h1:dSLoQfGFAo3F6OoNhwUmLwVgaUXK79GlxNBwueZn0xI= -github.com/vmihailenco/msgpack v4.0.4+incompatible/go.mod h1:fy3FlTQTDXWkZ7Bh6AcGMlsjHatGryHQYUTf1ShIgkk= github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8= github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok= github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= @@ -149,7 +147,6 @@ github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZ github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs= go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= @@ -157,16 +154,12 @@ go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod 
h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/oauth2 v0.27.0 h1:da9Vo7/tDv5RH/7nZDz1eMGS/q1Vv1N/7FCrBhI9I3M= @@ -174,24 +167,15 @@ golang.org/x/oauth2 v0.27.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20= golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y= golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= 
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY= @@ -200,17 +184,12 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.31.0 h1:0EedkvKDbh+qistFTd0Bcwe/YLh4vHwWEkiI0toFIBU= golang.org/x/tools v0.31.0/go.mod h1:naFTU+Cev749tSJRXJlna0T3WxKvb1kWEx15xA4SdmQ= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/appengine v1.6.8 h1:IhEN5q69dyKagZPYMSdIjS2HqprW324FRQZJcGqPAsM= -google.golang.org/appengine v1.6.8/go.mod h1:1jJ3jBArFh5pcgW8gCtRJnepW8FzD1V44FJffLiz/Ds= -google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM= google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/pkg/common/config.go b/pkg/common/config.go index b613e621..181deb30 100644 --- a/pkg/common/config.go +++ b/pkg/common/config.go @@ -125,6 +125,8 @@ type Configuration struct { // ZMQEndpoint is the ZMQ address to publish events, the default value is tcp://localhost:5557 ZMQEndpoint string `yaml:"zmq-endpoint"` + // EventBatchSize is the maximum number of kv-cache events to be sent together, defaults to 16 + EventBatchSize int `yaml:"event-batch-size"` } type LoraModule struct { @@ -183,6 +185,7 @@ func newConfig() *Configuration { KVCacheSize: 1024, TokenBlockSize: 16, ZMQEndpoint: "tcp://localhost:5557", + EventBatchSize: 16, } } @@ -293,6 +296,9 @@ func (c *Configuration) validate() error { if c.KVCacheSize < 0 { return errors.New("KV cache size cannot be negative") } + if c.EventBatchSize < 1 { + return errors.New("event batch size cannot less than 1") + } return nil } @@ -344,6 +350,7 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) { f.StringVar(&config.TokenizersCacheDir, "tokenizers-cache-dir", config.TokenizersCacheDir, "Directory for caching tokenizers") f.StringVar(&config.HashSeed, "hash-seed", config.HashSeed, "Seed for hash generation (if not set, is read from PYTHONHASHSEED environment variable)") f.StringVar(&config.ZMQEndpoint, "zmq-endpoint", config.ZMQEndpoint, "ZMQ address to publish events") + f.IntVar(&config.EventBatchSize, "event-batch-size", config.EventBatchSize, "Maximum number of kv-cache events to be sent together") // These values were manually parsed above in getParamValueFromArgs, 
we leave this in order to get these flags in --help var dummyString string diff --git a/pkg/common/config_test.go b/pkg/common/config_test.go index 2fd8446a..6e768c27 100644 --- a/pkg/common/config_test.go +++ b/pkg/common/config_test.go @@ -103,11 +103,13 @@ var _ = Describe("Simulator configuration", func() { "{\"name\":\"lora3\",\"path\":\"/path/to/lora3\"}", "{\"name\":\"lora4\",\"path\":\"/path/to/lora4\"}", } + c.EventBatchSize = 5 test = testCase{ name: "config file with command line args", args: []string{"cmd", "--model", model, "--config", "../../manifests/config.yaml", "--port", "8002", "--served-model-name", "alias1", "alias2", "--seed", "100", "--lora-modules", "{\"name\":\"lora3\",\"path\":\"/path/to/lora3\"}", "{\"name\":\"lora4\",\"path\":\"/path/to/lora4\"}", + "--event-batch-size", "5", }, expectedConfig: c, } @@ -291,6 +293,11 @@ var _ = Describe("Simulator configuration", func() { args: []string{"cmd", "--block-size", "35", "--config", "../../manifests/config.yaml"}, }, + { + name: "invalid (negative) event-batch-size", + args: []string{"cmd", "--event-batch-size", "-35", + "--config", "../../manifests/config.yaml"}, + }, } for _, test := range invalidTests { diff --git a/pkg/common/publisher.go b/pkg/common/publisher.go index 868e15fa..d7d6e325 100644 --- a/pkg/common/publisher.go +++ b/pkg/common/publisher.go @@ -17,6 +17,7 @@ limitations under the License. package common import ( + "bytes" "context" "encoding/binary" "errors" @@ -24,7 +25,7 @@ import ( "sync/atomic" zmq "github.com/pebbe/zmq4" - "github.com/vmihailenco/msgpack" + "github.com/vmihailenco/msgpack/v5" "k8s.io/klog/v2" ) @@ -62,7 +63,11 @@ func NewPublisher(endpoint string) (*Publisher, error) { func (p *Publisher) PublishEvent(ctx context.Context, topic string, batch interface{}) error { logger := klog.FromContext(ctx).V(0) - payload, err := msgpack.Marshal(batch) + // Use an encoder configured for struct as array + var payload bytes.Buffer + enc := msgpack.NewEncoder(&payload) + enc.UseArrayEncodedStructs(true) + err := enc.Encode(batch) if err != nil { return fmt.Errorf("failed to marshal event batch: %w", err) } @@ -73,7 +78,7 @@ func (p *Publisher) PublishEvent(ctx context.Context, topic string, batch interf binary.BigEndian.PutUint64(seqBytes, seq) // send topic, sequence, payload - if _, err := p.socket.SendMessage(topic, seqBytes, payload); err != nil { + if _, err := p.socket.SendMessage(topic, seqBytes, payload.Bytes()); err != nil { return fmt.Errorf("failed to send message to topic %s: %w", topic, err) } diff --git a/pkg/common/publisher_test.go b/pkg/common/publisher_test.go index 46f772de..5df18940 100644 --- a/pkg/common/publisher_test.go +++ b/pkg/common/publisher_test.go @@ -25,7 +25,7 @@ import ( . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" zmq "github.com/pebbe/zmq4" - "github.com/vmihailenco/msgpack" + "github.com/vmihailenco/msgpack/v5" ) const ( @@ -44,6 +44,8 @@ var _ = Describe("Publisher", func() { Expect(err).NotTo(HaveOccurred()) err = sub.SetSubscribe(topic) Expect(err).NotTo(HaveOccurred()) + //nolint + defer sub.Close() time.Sleep(100 * time.Millisecond) diff --git a/pkg/kv-cache/block_cache.go b/pkg/kv-cache/block_cache.go index 4e8ffecc..e66c7224 100644 --- a/pkg/kv-cache/block_cache.go +++ b/pkg/kv-cache/block_cache.go @@ -23,11 +23,11 @@ import ( "time" "github.com/go-logr/logr" + "github.com/llm-d/llm-d-inference-sim/pkg/common" ) const ( capacityError = "the kv cache does not have sufficient capacity to store this request" - batchSize = 3 delay = time.Second ) @@ -44,20 +44,24 @@ type blockCache struct { } // newBlockCache creates a new blockCache with the specified maximum number of blocks -func newBlockCache(maxBlocks int, logger logr.Logger) *blockCache { +func newBlockCache(config *common.Configuration, logger logr.Logger) (*blockCache, error) { // TODO read size of channel from config eChan := make(chan EventData, 10000) + publisher, err := common.NewPublisher(config.ZMQEndpoint) + if err != nil { + return nil, err + } + return &blockCache{ requestToBlocks: make(map[string][]uint64), usedBlocks: make(map[uint64]int), unusedBlocks: make(map[uint64]time.Time), - maxBlocks: maxBlocks, + maxBlocks: config.KVCacheSize, eventChan: eChan, - // TODO - create topic name from pod ip + model name - eventSender: NewKVEventSender(&Publisher{}, "topic1", eChan, batchSize, delay, logger), - logger: logger, - } + eventSender: NewKVEventSender(publisher, createTopic(config), eChan, config.EventBatchSize, delay, logger), + logger: logger, + }, nil } func (b *blockCache) start(ctx context.Context) { @@ -128,7 +132,7 @@ func (bc *blockCache) startRequest(requestID string, blocks []uint64) error { } delete(bc.unusedBlocks, oldestUnusedHash) - bc.eventChan <- EventData{action: eventActionRemove, hashValues: []uint64{block}} + bc.eventChan <- EventData{action: eventActionRemove, hashValues: []uint64{oldestUnusedHash}} } // Add the new block @@ -214,3 +218,7 @@ func (bc *blockCache) getBlockInfo(blockHash uint64) (int, bool) { return 0, false } + +func createTopic(config *common.Configuration) string { + return fmt.Sprintf("kv@$localhost:%d@%s", config.Port, config.Model) +} diff --git a/pkg/kv-cache/block_cache_test.go b/pkg/kv-cache/block_cache_test.go deleted file mode 100644 index 589b3188..00000000 --- a/pkg/kv-cache/block_cache_test.go +++ /dev/null @@ -1,329 +0,0 @@ -/* -Copyright 2025 The llm-d-inference-sim Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package kvcache - -import ( - "context" - "fmt" - "sync" - "time" - - "github.com/llm-d/llm-d-inference-sim/pkg/common" - . "github.com/onsi/ginkgo/v2" - . 
"github.com/onsi/gomega" -) - -const ( - req1ID = "req1" - req2ID = "req2" - req3ID = "req3" -) - -type ActionType int - -const ( - actionStartRequest ActionType = iota - actionFinishRequest -) - -type testRequest struct { - id string - blocks []uint64 -} - -type expectedBlockInfo struct { - exists bool - refCount int -} - -type testAction struct { - action ActionType - request testRequest - isError bool - errMsg string - expectedActiveRequests int - expectedTotalBlocks int - expectedUnusedBlocks int - expectedBlocksInfo map[uint64]expectedBlockInfo -} - -func newStartAction(request testRequest) testAction { - return testAction{ - action: actionStartRequest, - request: request, - isError: false, - expectedActiveRequests: -1, - expectedTotalBlocks: -1, - expectedUnusedBlocks: -1, - } -} -func newInvalidTestAction(action ActionType, request testRequest, errMsg string) testAction { - return testAction{ - action: action, - request: request, - isError: true, - errMsg: errMsg, - expectedActiveRequests: -1, - expectedTotalBlocks: -1, - expectedUnusedBlocks: -1, - } -} -func newTestActionWithExpectedValues(action ActionType, request testRequest, expectedActiveRequests int, - expectedTotalBlocks int, expectedUnusedBlocks int, expectedBlocksInfo map[uint64]expectedBlockInfo) testAction { - return testAction{ - action: action, - request: request, - isError: false, - expectedActiveRequests: expectedActiveRequests, - expectedTotalBlocks: expectedTotalBlocks, - expectedUnusedBlocks: expectedUnusedBlocks, - expectedBlocksInfo: expectedBlocksInfo, - } -} - -type testCase struct { - name string - cacheSize int - actions []testAction -} - -type threadTestCase struct { - name string - cacheSize int - numGoroutines int - numOperations int - minBlockLen int - maxBlockLen int - maxHashValue uint64 - shouldUseAllCache bool -} - -var _ = Describe("Block cache", Ordered, func() { - common.InitRandom(time.Now().UnixNano()) - - Context("general tests", func() { - // check single request processing, ensure cache is valid after request processing started - // and after the processing was finished - req1 := testRequest{req1ID, []uint64{1, 2}} - req2 := testRequest{req2ID, []uint64{3, 4}} - req2_1 := testRequest{req2ID, []uint64{1, 3}} - req3 := testRequest{req3ID, []uint64{5, 6}} - - testCases := []testCase{{ - name: "single request", - cacheSize: 3, - actions: []testAction{ - newTestActionWithExpectedValues(actionStartRequest, req1, 1, 2, 0, nil), - newTestActionWithExpectedValues(actionFinishRequest, req1, 0, 2, 2, nil), - }, - }, { - name: "two requests", - cacheSize: 5, - actions: []testAction{ - newStartAction(req1), - newTestActionWithExpectedValues(actionStartRequest, req2, 2, 4, 0, nil), - newTestActionWithExpectedValues(actionFinishRequest, req1, 1, 4, 2, nil), - newTestActionWithExpectedValues(actionFinishRequest, req2, 0, 4, 4, nil), - }, - }, { - name: "reusing blocks", - cacheSize: 5, - actions: []testAction{ - newStartAction(req1), - // Check block '1' reference count (should be 2) - newTestActionWithExpectedValues(actionStartRequest, req2_1, 2, 3, 0, map[uint64]expectedBlockInfo{1: {true, 2}}), - // Check block '1' reference count (should be 1) - newTestActionWithExpectedValues(actionFinishRequest, req1, 1, 3, 1, map[uint64]expectedBlockInfo{1: {true, 1}}), - }, - }, { - name: "block eviction", - cacheSize: 4, - actions: []testAction{ - newStartAction(req1), - newStartAction(req2), - newTestActionWithExpectedValues(actionFinishRequest, req2, -1, -1, -1, map[uint64]expectedBlockInfo{3: {true, 0}}), - 
newTestActionWithExpectedValues(actionStartRequest, req3, -1, -1, -1, map[uint64]expectedBlockInfo{ - 5: {true, 1}, - 3: {false, 0}, - }), - }, - }, { - name: "cache full, no eviction", - cacheSize: 4, - actions: []testAction{ - newStartAction(req1), - newStartAction(req2), - newInvalidTestAction(actionStartRequest, req3, capacityError), - }, - }} - - for _, test := range testCases { - It(test.name, func() { - ctx, cancel := context.WithCancel(context.Background()) - - wg := sync.WaitGroup{} - wg.Add(1) - - blockCache := newBlockCache(test.cacheSize, GinkgoLogr) - - go func() { - blockCache.start(ctx) - wg.Done() - }() - - defer func() { - cancel() - wg.Wait() // wait for goroutine to exit - }() - - for _, action := range test.actions { - var err error - - switch action.action { - case actionStartRequest: - err = blockCache.startRequest(action.request.id, action.request.blocks) - case actionFinishRequest: - err = blockCache.finishRequest(action.request.id) - } - - if action.isError { - Expect(err).To(HaveOccurred()) - if len(action.errMsg) > 0 { - Expect(err.Error()).To(Equal(action.errMsg)) - } - continue - } - - // ensure that error does not accured - Expect(err).NotTo(HaveOccurred()) - - // check cache info if required - if action.expectedActiveRequests >= 0 || action.expectedTotalBlocks >= 0 || action.expectedUnusedBlocks >= 0 { - activeRequests, totalBlocks, unusedBlocks := blockCache.getStats() - if action.expectedActiveRequests >= 0 { - Expect(activeRequests).To(Equal(action.expectedActiveRequests)) - } - if action.expectedTotalBlocks >= 0 { - Expect(totalBlocks).To(Equal(action.expectedTotalBlocks)) - } - if action.expectedUnusedBlocks >= 0 { - Expect(unusedBlocks).To(Equal(action.expectedUnusedBlocks)) - } - } - - // check specific blocks info if required - if len(action.expectedBlocksInfo) > 0 { - for block, expectedInfo := range action.expectedBlocksInfo { - refCount, exists := blockCache.getBlockInfo(block) - if expectedInfo.exists { - Expect(exists).To(BeTrue()) - } else { - Expect(exists).To(BeFalse()) - } - if expectedInfo.refCount >= 0 { - Expect(refCount).To(Equal(expectedInfo.refCount)) - } - } - } - } - }) - } - }) - - Context("thread safety", func() { - testCases := []threadTestCase{{ - name: "run add/remove requests in parallel, use partial cache", - cacheSize: 1000, - numGoroutines: 50, - numOperations: 100, - minBlockLen: 2, - maxBlockLen: 10, - maxHashValue: 100, - shouldUseAllCache: false, - }, { - name: "run add/remove requests in parallel, use all cache", - cacheSize: 100, - numGoroutines: 50, - numOperations: 10, - minBlockLen: 2, - maxBlockLen: 10, - maxHashValue: 100, - shouldUseAllCache: true, - }} - - for _, testCase := range testCases { - It(testCase.name, func() { - blockCache := newBlockCache(testCase.cacheSize, GinkgoLogr) - var wg sync.WaitGroup - - // Start multiple goroutines performing concurrent operations - for i := range testCase.numGoroutines { - wg.Add(1) - go func(id int) { - defer wg.Done() - - for j := range testCase.numOperations { - reqID := fmt.Sprintf("req_%d_%d", id, j) - blocks := createRandomArray(testCase.minBlockLen, testCase.maxBlockLen, testCase.maxHashValue) - - err := blockCache.startRequest(reqID, blocks) - if err != nil { - // some operations may fail due to cache being full, which is expected - Expect(err.Error()).To(Equal(capacityError)) - continue - } - - time.Sleep(time.Duration(common.RandomInt(1, 100)) * time.Microsecond) - - err = blockCache.finishRequest(reqID) - Expect(err).NotTo(HaveOccurred()) - } - }(i) - } - - 
wg.Wait() - - activeReqs, totalBlocks, unusedBlocks := blockCache.getStats() - fmt.Printf("Thread safety test completed. Final stats: Active requests: %d, Total blocks: %d, Unused blocks: %d\n", - activeReqs, totalBlocks, unusedBlocks) - if testCase.shouldUseAllCache { - Expect(totalBlocks).To(Equal(testCase.cacheSize)) - } - Expect(totalBlocks).To(Equal(unusedBlocks)) - }) - } - }) -}) - -func createRandomArray(minArrLen, maxArrLen int, maxValue uint64) []uint64 { - // Random length between a and b (inclusive) - length := common.RandomInt(minArrLen, maxArrLen) - - // Create array with random values - arr := make([]uint64, 0) - seen := make(map[uint64]struct{}) - - for len(arr) < length { - val := uint64(common.RandomInt(0, int(maxValue))) - if _, exists := seen[val]; !exists { - seen[val] = struct{}{} - arr = append(arr, val) - } - } - - return arr -} diff --git a/pkg/kv-cache/kv_cache.go b/pkg/kv-cache/kv_cache.go index e68f750d..dbbd7645 100644 --- a/pkg/kv-cache/kv_cache.go +++ b/pkg/kv-cache/kv_cache.go @@ -47,15 +47,17 @@ func NewKVCacheHelper(config *common.Configuration, logger logr.Logger) (*KVCach tokenizationConfig.TokenizersCacheDir = config.TokenizersCacheDir } tokenizer, err := tokenization.NewCachedHFTokenizer(tokenizationConfig.HFTokenizerConfig) - if err != nil { return nil, fmt.Errorf("failed to create tokenizer: %w", err) } - + blockCache, err := newBlockCache(config, logger) + if err != nil { + return nil, fmt.Errorf("failed to create block cache: %w", err) + } return &KVCacheHelper{ tokenizer: tokenizer, tokensProcessor: tokensProcessor, - blockCache: newBlockCache(config.KVCacheSize, logger), + blockCache: blockCache, logger: logger, }, nil } diff --git a/pkg/kv-cache/kv_cache_sender.go b/pkg/kv-cache/kv_cache_sender.go index d15b7723..f8af3638 100644 --- a/pkg/kv-cache/kv_cache_sender.go +++ b/pkg/kv-cache/kv_cache_sender.go @@ -16,11 +16,13 @@ limitations under the License. 
package kvcache import ( + "bytes" "context" "fmt" "time" "github.com/go-logr/logr" + "github.com/llm-d/llm-d-inference-sim/pkg/common" "github.com/llm-d/llm-d-kv-cache-manager/pkg/kvcache/kvevents" "github.com/vmihailenco/msgpack/v5" ) @@ -32,21 +34,18 @@ const ( eventActionRemove ) +const ( + BlockStored = "BlockStored" + BlockRemoved = "BlockRemoved" +) + type EventData struct { action EventAction hashValues []uint64 } -type Publisher struct{} - -func (p *Publisher) PublishEvent(ctx context.Context, topic string, batch interface{}) error { - // mock implementation - fmt.Printf("Publish batch %#v\n", batch) - return nil -} - type KVEventSender struct { - publisher *Publisher + publisher *common.Publisher topic string eventChan chan EventData maxBatchSize int @@ -55,7 +54,8 @@ type KVEventSender struct { logger logr.Logger } -func NewKVEventSender(publisher *Publisher, topic string, ch chan EventData, maxBatchSize int, delay time.Duration, logger logr.Logger) *KVEventSender { +func NewKVEventSender(publisher *common.Publisher, topic string, ch chan EventData, maxBatchSize int, + delay time.Duration, logger logr.Logger) *KVEventSender { return &KVEventSender{ publisher: publisher, topic: topic, @@ -90,14 +90,24 @@ func (s *KVEventSender) Run(ctx context.Context) error { } // Encode eventData's hash value to msgpack.RawMessage - var blockPayloadBytes msgpack.RawMessage var err error + var payload bytes.Buffer + enc := msgpack.NewEncoder(&payload) + enc.UseArrayEncodedStructs(true) switch eventData.action { case eventActionStore: - blockPayloadBytes, err = msgpack.Marshal(kvevents.BlockStored{BlockHashes: eventData.hashValues}) + bs := &kvevents.BlockStoredEvent{ + TypeField: BlockStored, + BlockStored: &kvevents.BlockStored{BlockHashes: eventData.hashValues}, + } + err = enc.Encode(bs) case eventActionRemove: - blockPayloadBytes, err = msgpack.Marshal(kvevents.BlockRemoved{BlockHashes: eventData.hashValues}) + br := &kvevents.BlockRemovedEvent{ + TypeField: BlockRemoved, + BlockRemoved: &kvevents.BlockRemoved{BlockHashes: eventData.hashValues}, + } + err = enc.Encode(br) default: return fmt.Errorf("invalid event action %d", eventData.action) } @@ -105,7 +115,7 @@ func (s *KVEventSender) Run(ctx context.Context) error { return fmt.Errorf("failed to marshal value: %w", err) } - s.batch = append(s.batch, blockPayloadBytes) + s.batch = append(s.batch, payload.Bytes()) // check if batch is big enough to be sent if len(s.batch) >= s.maxBatchSize { diff --git a/pkg/kv-cache/kv_cache_test.go b/pkg/kv-cache/kv_cache_test.go new file mode 100644 index 00000000..cc259c5b --- /dev/null +++ b/pkg/kv-cache/kv_cache_test.go @@ -0,0 +1,537 @@ +/* +Copyright 2025 The llm-d-inference-sim Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package kvcache + +import ( + "context" + "encoding/binary" + "fmt" + "sync" + "time" + + zmq "github.com/pebbe/zmq4" + "github.com/vmihailenco/msgpack/v5" + + "github.com/llm-d/llm-d-inference-sim/pkg/common" + "github.com/llm-d/llm-d-kv-cache-manager/pkg/kvcache/kvevents" + . 
"github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +const ( + req1ID = "req1" + req2ID = "req2" + req3ID = "req3" + endpoint = "tcp://localhost:5557" +) + +type ActionType int + +const ( + actionStartRequest ActionType = iota + actionFinishRequest +) + +type testRequest struct { + id string + blocks []uint64 +} + +type expectedBlockInfo struct { + exists bool + refCount int +} + +type testAction struct { + action ActionType + request testRequest + isError bool + errMsg string + expectedActiveRequests int + expectedTotalBlocks int + expectedUnusedBlocks int + expectedBlocksInfo map[uint64]expectedBlockInfo +} + +func newStartAction(request testRequest) testAction { + return testAction{ + action: actionStartRequest, + request: request, + isError: false, + expectedActiveRequests: -1, + expectedTotalBlocks: -1, + expectedUnusedBlocks: -1, + } +} +func newInvalidTestAction(action ActionType, request testRequest, errMsg string) testAction { + return testAction{ + action: action, + request: request, + isError: true, + errMsg: errMsg, + expectedActiveRequests: -1, + expectedTotalBlocks: -1, + expectedUnusedBlocks: -1, + } +} +func newTestActionWithExpectedValues(action ActionType, request testRequest, expectedActiveRequests int, + expectedTotalBlocks int, expectedUnusedBlocks int, expectedBlocksInfo map[uint64]expectedBlockInfo) testAction { + return testAction{ + action: action, + request: request, + isError: false, + expectedActiveRequests: expectedActiveRequests, + expectedTotalBlocks: expectedTotalBlocks, + expectedUnusedBlocks: expectedUnusedBlocks, + expectedBlocksInfo: expectedBlocksInfo, + } +} + +type testCase struct { + name string + cacheSize int + actions []testAction + expectedRemovedBlocks int + expectedStoredBlocks int +} + +type threadTestCase struct { + name string + cacheSize int + numGoroutines int + numOperations int + minBlockLen int + maxBlockLen int + maxHashValue uint64 + shouldUseAllCache bool +} + +var _ = Describe("KV cache", Ordered, func() { + common.InitRandom(time.Now().UnixNano()) + + Context("general tests", func() { + // check single request processing, ensure cache is valid after request processing started + // and after the processing was finished + req1 := testRequest{req1ID, []uint64{1, 2}} + req2 := testRequest{req2ID, []uint64{3, 4}} + req2_1 := testRequest{req2ID, []uint64{1, 3}} + req3 := testRequest{req3ID, []uint64{5, 6}} + + testCases := []testCase{ + { + name: "single request", + cacheSize: 3, + actions: []testAction{ + newTestActionWithExpectedValues(actionStartRequest, req1, 1, 2, 0, nil), + newTestActionWithExpectedValues(actionFinishRequest, req1, 0, 2, 2, nil), + }, + expectedRemovedBlocks: 0, + expectedStoredBlocks: 2, + }, + { + name: "two requests", + cacheSize: 5, + actions: []testAction{ + newStartAction(req1), + newTestActionWithExpectedValues(actionStartRequest, req2, 2, 4, 0, nil), + newTestActionWithExpectedValues(actionFinishRequest, req1, 1, 4, 2, nil), + newTestActionWithExpectedValues(actionFinishRequest, req2, 0, 4, 4, nil), + }, + expectedRemovedBlocks: 0, + expectedStoredBlocks: 4, + }, + { + name: "reusing blocks", + cacheSize: 5, + actions: []testAction{ + newStartAction(req1), + // Check block '1' reference count (should be 2) + newTestActionWithExpectedValues(actionStartRequest, req2_1, 2, 3, 0, map[uint64]expectedBlockInfo{1: {true, 2}}), + // Check block '1' reference count (should be 1) + newTestActionWithExpectedValues(actionFinishRequest, req1, 1, 3, 1, map[uint64]expectedBlockInfo{1: {true, 1}}), + }, + 
expectedRemovedBlocks: 0, + expectedStoredBlocks: 3, + }, + { + name: "block eviction", + cacheSize: 4, + actions: []testAction{ + newStartAction(req1), + newStartAction(req2), + newTestActionWithExpectedValues(actionFinishRequest, req2, -1, -1, -1, map[uint64]expectedBlockInfo{3: {true, 0}}), + newTestActionWithExpectedValues(actionStartRequest, req3, -1, -1, -1, map[uint64]expectedBlockInfo{ + 5: {true, 1}, + 3: {false, 0}, + }), + }, + expectedRemovedBlocks: 2, + expectedStoredBlocks: 6, + }, + { + name: "cache full, no eviction", + cacheSize: 4, + actions: []testAction{ + newStartAction(req1), + newStartAction(req2), + newInvalidTestAction(actionStartRequest, req3, capacityError), + }, + expectedRemovedBlocks: 0, + expectedStoredBlocks: 4, + }, + } + + for _, test := range testCases { + It(test.name, func() { + time.Sleep(300 * time.Millisecond) + + config := &common.Configuration{ + Port: 1234, + Model: "model", + KVCacheSize: test.cacheSize, + ZMQEndpoint: endpoint, + EventBatchSize: 1, + } + + sub, topic := createSub(config) + //nolint + defer sub.Close() + + ctx, cancel := context.WithCancel(context.Background()) + + wg := sync.WaitGroup{} + wg.Add(1) + + blockCache, err := newBlockCache(config, GinkgoLogr) + Expect(err).NotTo(HaveOccurred()) + + go func() { + blockCache.start(ctx) + wg.Done() + }() + + defer func() { + cancel() + wg.Wait() // wait for goroutine to exit + }() + + go func() { + // Make sure that the subscriber listens before the events are published + time.Sleep(time.Second) + + for _, action := range test.actions { + var err error + switch action.action { + case actionStartRequest: + err = blockCache.startRequest(action.request.id, action.request.blocks) + case actionFinishRequest: + err = blockCache.finishRequest(action.request.id) + } + + if action.isError { + Expect(err).To(HaveOccurred()) + if len(action.errMsg) > 0 { + Expect(err.Error()).To(Equal(action.errMsg)) + } + continue + } + + // ensure that error has not occurred + Expect(err).NotTo(HaveOccurred()) + + // check cache info if required + if action.expectedActiveRequests >= 0 || action.expectedTotalBlocks >= 0 || action.expectedUnusedBlocks >= 0 { + activeRequests, totalBlocks, unusedBlocks := blockCache.getStats() + if action.expectedActiveRequests >= 0 { + Expect(activeRequests).To(Equal(action.expectedActiveRequests)) + } + if action.expectedTotalBlocks >= 0 { + Expect(totalBlocks).To(Equal(action.expectedTotalBlocks)) + } + if action.expectedUnusedBlocks >= 0 { + Expect(unusedBlocks).To(Equal(action.expectedUnusedBlocks)) + } + } + + // check specific blocks info if required + if len(action.expectedBlocksInfo) > 0 { + for block, expectedInfo := range action.expectedBlocksInfo { + refCount, exists := blockCache.getBlockInfo(block) + if expectedInfo.exists { + Expect(exists).To(BeTrue()) + } else { + Expect(exists).To(BeFalse()) + } + if expectedInfo.refCount >= 0 { + Expect(refCount).To(Equal(expectedInfo.refCount)) + } + } + } + } + }() + + storedCount := 0 + removedCount := 0 + for i := range test.expectedRemovedBlocks + test.expectedStoredBlocks { + parts, err := sub.RecvMessageBytes(0) + Expect(err).NotTo(HaveOccurred()) + stored, removed := parseEvent(parts, topic, uint64(i+1)) + storedCount += len(stored) + removedCount += len(removed) + } + Expect(removedCount).To(Equal(test.expectedRemovedBlocks)) + Expect(storedCount).To(Equal(test.expectedStoredBlocks)) + }) + } + }) + + Context("events", func() { + + It("should send events correctly", func() { + config := &common.Configuration{ + Port: 
1234, + Model: "model", + KVCacheSize: 4, + ZMQEndpoint: endpoint, + } + + sub, topic := createSub(config) + //nolint + defer sub.Close() + + ctx, cancel := context.WithCancel(context.Background()) + + wg := sync.WaitGroup{} + wg.Add(1) + + blockCache, err := newBlockCache(config, GinkgoLogr) + Expect(err).NotTo(HaveOccurred()) + + go func() { + blockCache.start(ctx) + wg.Done() + }() + + defer func() { + cancel() + wg.Wait() // wait for goroutine to exit + }() + + expectedRemovedBlocks := []uint64{2, 4} + expectedStoredBlocks := []uint64{1, 2, 3, 4, 5, 6} + + go func() { + // Make sure that the subscriber listens before the events are published + time.Sleep(time.Second) + + req1 := testRequest{"req1", []uint64{1, 2}} + req2 := testRequest{"req2", []uint64{3, 4}} + req3 := testRequest{"req3", []uint64{1, 3}} + req4 := testRequest{"req4", []uint64{5, 6}} + + // blocks 1 and 2 stored + err = blockCache.startRequest(req1.id, req1.blocks) + Expect(err).NotTo(HaveOccurred()) + // blocks 3 and 4 stored + err = blockCache.startRequest(req2.id, req2.blocks) + Expect(err).NotTo(HaveOccurred()) + // no new blocks stored, reuse of 1 and 3 + err = blockCache.startRequest(req3.id, req3.blocks) + Expect(err).NotTo(HaveOccurred()) + // no space left - should fail + err = blockCache.startRequest(req4.id, req4.blocks) + Expect(err).To(HaveOccurred()) + + err = blockCache.finishRequest(req1.id) + Expect(err).NotTo(HaveOccurred()) + err = blockCache.finishRequest(req2.id) + Expect(err).NotTo(HaveOccurred()) + // now 2 and 4 are not in use + + // blocks 2 and 4 should be removed, and 5 and 6 stored + err = blockCache.startRequest(req4.id, req4.blocks) + Expect(err).NotTo(HaveOccurred()) + }() + + removedBlocks := make([]uint64, 0) + storedBlocks := make([]uint64, 0) + count := uint64(1) + for { + parts, err := sub.RecvMessageBytes(0) + Expect(err).NotTo(HaveOccurred()) + stored, removed := parseEvent(parts, topic, count) + storedBlocks = append(storedBlocks, stored...) + removedBlocks = append(removedBlocks, removed...) 
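				// Drain messages until every expected stored/removed hash has been seen;
				// parseEvent (defined below) checks the [topic, seq, payload] framing and the
				// sequence number of each frame before returning the decoded block hashes.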
+ count++ + + if len(removedBlocks) == len(expectedRemovedBlocks) && len(storedBlocks) == len(expectedStoredBlocks) { + break + } + } + Expect(removedBlocks).To(Equal(expectedRemovedBlocks)) + Expect(storedBlocks).To(Equal(expectedStoredBlocks)) + }) + + }) + + Context("thread safety", func() { + testCases := []threadTestCase{{ + name: "run add/remove requests in parallel, use partial cache", + cacheSize: 1000, + numGoroutines: 50, + numOperations: 100, + minBlockLen: 2, + maxBlockLen: 10, + maxHashValue: 100, + shouldUseAllCache: false, + }, { + name: "run add/remove requests in parallel, use all cache", + cacheSize: 100, + numGoroutines: 50, + numOperations: 10, + minBlockLen: 2, + maxBlockLen: 10, + maxHashValue: 100, + shouldUseAllCache: true, + }} + + for _, testCase := range testCases { + It(testCase.name, func() { + config := common.Configuration{ + Port: 1234, + Model: "model", + KVCacheSize: testCase.cacheSize, + ZMQEndpoint: endpoint, + } + blockCache, err := newBlockCache(&config, GinkgoLogr) + Expect(err).NotTo(HaveOccurred()) + var wg sync.WaitGroup + + // Start multiple goroutines performing concurrent operations + for i := range testCase.numGoroutines { + wg.Add(1) + go func(id int) { + defer wg.Done() + + for j := range testCase.numOperations { + reqID := fmt.Sprintf("req_%d_%d", id, j) + blocks := createRandomArray(testCase.minBlockLen, testCase.maxBlockLen, testCase.maxHashValue) + + err := blockCache.startRequest(reqID, blocks) + if err != nil { + // some operations may fail due to cache being full, which is expected + Expect(err.Error()).To(Equal(capacityError)) + continue + } + + time.Sleep(time.Duration(common.RandomInt(1, 100)) * time.Microsecond) + + err = blockCache.finishRequest(reqID) + Expect(err).NotTo(HaveOccurred()) + } + }(i) + } + + wg.Wait() + + activeReqs, totalBlocks, unusedBlocks := blockCache.getStats() + fmt.Printf("Thread safety test completed. 
Final stats: Active requests: %d, Total blocks: %d, Unused blocks: %d\n", + activeReqs, totalBlocks, unusedBlocks) + if testCase.shouldUseAllCache { + Expect(totalBlocks).To(Equal(testCase.cacheSize)) + } + Expect(totalBlocks).To(Equal(unusedBlocks)) + }) + } + }) +}) + +func createRandomArray(minArrLen, maxArrLen int, maxValue uint64) []uint64 { + // Random length between a and b (inclusive) + length := common.RandomInt(minArrLen, maxArrLen) + + // Create array with random values + arr := make([]uint64, 0) + seen := make(map[uint64]struct{}) + + for len(arr) < length { + val := uint64(common.RandomInt(0, int(maxValue))) + if _, exists := seen[val]; !exists { + seen[val] = struct{}{} + arr = append(arr, val) + } + } + + return arr +} + +func parseEvent(parts [][]byte, expectedTopic string, expectedSeq uint64) ([]uint64, []uint64) { + // The message should be [topic, seq, payload] + Expect(parts).To(HaveLen(3)) + + Expect(string(parts[0])).To(Equal(expectedTopic)) + + seq := binary.BigEndian.Uint64(parts[1]) + Expect(seq).To(Equal(expectedSeq)) + + removed := make([]uint64, 0) + stored := make([]uint64, 0) + + var eventBatch kvevents.EventBatch + err := msgpack.Unmarshal(parts[2], &eventBatch) + Expect(err).NotTo(HaveOccurred()) + for _, rawEvent := range eventBatch.Events { + var taggedUnion []msgpack.RawMessage + err = msgpack.Unmarshal(rawEvent, &taggedUnion) + Expect(err).NotTo(HaveOccurred()) + Expect(len(taggedUnion)).To(BeNumerically(">", 1)) + + var tag string + err = msgpack.Unmarshal(taggedUnion[0], &tag) + Expect(err).NotTo(HaveOccurred()) + + switch tag { + case BlockStored: + var bs kvevents.BlockStoredEvent + err = msgpack.Unmarshal(rawEvent, &bs) + stored = append(stored, bs.BlockHashes...) + case BlockRemoved: + var br kvevents.BlockRemovedEvent + err = msgpack.Unmarshal(rawEvent, &br) + removed = append(removed, br.BlockHashes...) + + default: + Fail("unexpected tag " + tag) + continue + } + Expect(err).NotTo(HaveOccurred()) + } + return stored, removed +} + +func createSub(config *common.Configuration) (*zmq.Socket, string) { + zctx, err := zmq.NewContext() + Expect(err).NotTo(HaveOccurred()) + sub, err := zctx.NewSocket(zmq.SUB) + Expect(err).NotTo(HaveOccurred()) + err = sub.Bind(endpoint) + Expect(err).NotTo(HaveOccurred()) + topic := createTopic(config) + err = sub.SetSubscribe(topic) + Expect(err).NotTo(HaveOccurred()) + return sub, topic +} From 69301924783120bff6c66e44d07e1c9ff6f0ff22 Mon Sep 17 00:00:00 2001 From: Sergey Marunich Date: Wed, 13 Aug 2025 17:52:48 -0400 Subject: [PATCH 05/46] Add failure injection mode to simulator Introduces a 'failure' mode to the simulator, allowing random injection of OpenAI API-compatible error responses for testing error handling. Adds configuration options for failure injection rate and specific failure types, implements error response logic, and updates documentation and tests to cover the new functionality. 
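
The simulator-side wiring is not included in this excerpt, so the following is only a
rough sketch of how a request handler might short-circuit when failure injection fires.
The helper name and the response writing are illustrative; only ShouldInjectFailure,
GetRandomFailure and FailureSpec come from this change.

    package example

    import (
        "encoding/json"

        "github.com/llm-d/llm-d-inference-sim/pkg/common"
        "github.com/valyala/fasthttp"
    )

    // maybeInjectFailure writes an OpenAI-compatible error response and returns true
    // when a failure should be injected; otherwise it returns false and the normal
    // simulation path continues.
    func maybeInjectFailure(config *common.Configuration, ctx *fasthttp.RequestCtx) bool {
        if !common.ShouldInjectFailure(config) {
            return false
        }
        failure := common.GetRandomFailure(config)
        ctx.SetStatusCode(failure.StatusCode)
        ctx.SetContentType("application/json")
        // Error envelope in the OpenAI style: {"error": {"message", "type", "param", "code"}}.
        body, _ := json.Marshal(map[string]interface{}{
            "error": map[string]interface{}{
                "message": failure.Message,
                "type":    failure.ErrorType,
                "param":   failure.Param,
                "code":    failure.ErrorCode,
            },
        })
        ctx.SetBody(body)
        return true
    }
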
Signed-off-by: Sergey Marunich --- README.md | 6 +- pkg/common/config.go | 51 +++- pkg/common/failures.go | 122 ++++++++++ pkg/common/failures_test.go | 134 +++++++++++ pkg/llm-d-inference-sim/simulator.go | 34 ++- pkg/llm-d-inference-sim/simulator_test.go | 271 +++++++++++++++++++++- pkg/openai-server-api/response.go | 13 +- 7 files changed, 609 insertions(+), 22 deletions(-) create mode 100644 pkg/common/failures.go create mode 100644 pkg/common/failures_test.go diff --git a/README.md b/README.md index fb6636c3..c69d3cf1 100644 --- a/README.md +++ b/README.md @@ -29,9 +29,10 @@ In addition, it supports a subset of vLLM's Prometheus metrics. These metrics ar The simulated inference has no connection with the model and LoRA adapters specified in the command line parameters or via the /v1/load_lora_adapter HTTP REST endpoint. The /v1/models endpoint returns simulated results based on those same command line parameters and those loaded via the /v1/load_lora_adapter HTTP REST endpoint. -The simulator supports two modes of operation: +The simulator supports three modes of operation: - `echo` mode: the response contains the same text that was received in the request. For `/v1/chat/completions` the last message for the role=`user` is used. - `random` mode: the response is randomly chosen from a set of pre-defined sentences. +- `failure` mode: randomly injects OpenAI API compatible error responses for testing error handling. Timing of the response is defined by the `time-to-first-token` and `inter-token-latency` parameters. In case P/D is enabled for a request, `kv-cache-transfer-latency` will be used instead of `time-to-first-token`. @@ -101,6 +102,7 @@ For more details see the 100 { + return errors.New("failure injection rate should be between 0 and 100") + } + + validFailureTypes := map[string]bool{ + "rate_limit": true, + "invalid_api_key": true, + "context_length": true, + "server_error": true, + "invalid_request": true, + "model_not_found": true, + } + for _, failureType := range c.FailureTypes { + if !validFailureTypes[failureType] { + return fmt.Errorf("invalid failure type '%s', valid types are: rate_limit, invalid_api_key, context_length, server_error, invalid_request, model_not_found", failureType) + } + } + return nil } @@ -326,7 +353,7 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) { f.IntVar(&config.MaxCPULoras, "max-cpu-loras", config.MaxCPULoras, "Maximum number of LoRAs to store in CPU memory") f.IntVar(&config.MaxModelLen, "max-model-len", config.MaxModelLen, "Model's context window, maximum number of tokens in a single request including input and output") - f.StringVar(&config.Mode, "mode", config.Mode, "Simulator mode, echo - returns the same text that was sent in the request, for chat completion returns the last message, random - returns random sentence from a bank of pre-defined sentences") + f.StringVar(&config.Mode, "mode", config.Mode, "Simulator mode: echo - returns the same text that was sent in the request, for chat completion returns the last message; random - returns random sentence from a bank of pre-defined sentences; failure - randomly injects API errors") f.IntVar(&config.InterTokenLatency, "inter-token-latency", config.InterTokenLatency, "Time to generate one token (in milliseconds)") f.IntVar(&config.TimeToFirstToken, "time-to-first-token", config.TimeToFirstToken, "Time to first token (in milliseconds)") f.IntVar(&config.KVCacheTransferLatency, "kv-cache-transfer-latency", config.KVCacheTransferLatency, "Time for KV-cache transfer from a 
remote vLLM (in milliseconds)") @@ -351,6 +378,13 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) { f.StringVar(&config.HashSeed, "hash-seed", config.HashSeed, "Seed for hash generation (if not set, is read from PYTHONHASHSEED environment variable)") f.StringVar(&config.ZMQEndpoint, "zmq-endpoint", config.ZMQEndpoint, "ZMQ address to publish events") f.IntVar(&config.EventBatchSize, "event-batch-size", config.EventBatchSize, "Maximum number of kv-cache events to be sent together") + + f.IntVar(&config.FailureInjectionRate, "failure-injection-rate", config.FailureInjectionRate, "Probability (0-100) of injecting failures when in failure mode") + + failureTypes := getParamValueFromArgs("failure-types") + var dummyFailureTypes multiString + f.Var(&dummyFailureTypes, "failure-types", "List of specific failure types to inject (rate_limit, invalid_api_key, context_length, server_error, invalid_request, model_not_found)") + f.Lookup("failure-types").NoOptDefVal = "dummy" // These values were manually parsed above in getParamValueFromArgs, we leave this in order to get these flags in --help var dummyString string @@ -384,6 +418,9 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) { if servedModelNames != nil { config.ServedModelNames = servedModelNames } + if failureTypes != nil { + config.FailureTypes = failureTypes + } if config.HashSeed == "" { hashSeed := os.Getenv("PYTHONHASHSEED") diff --git a/pkg/common/failures.go b/pkg/common/failures.go new file mode 100644 index 00000000..e4fc0dd6 --- /dev/null +++ b/pkg/common/failures.go @@ -0,0 +1,122 @@ +/* +Copyright 2025 The llm-d-inference-sim Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package common + +import ( + "fmt" + "math/rand" + "time" +) + +type FailureSpec struct { + StatusCode int + ErrorType string + ErrorCode string + Message string + Param *string +} + +var predefinedFailures = map[string]FailureSpec{ + "rate_limit": { + StatusCode: 429, + ErrorType: "rate_limit_exceeded", + ErrorCode: "rate_limit_exceeded", + Message: "Rate limit reached for model in organization org-xxx on requests per min (RPM): Limit 3, Used 3, Requested 1.", + Param: nil, + }, + "invalid_api_key": { + StatusCode: 401, + ErrorType: "invalid_request_error", + ErrorCode: "invalid_api_key", + Message: "Incorrect API key provided", + Param: nil, + }, + "context_length": { + StatusCode: 400, + ErrorType: "invalid_request_error", + ErrorCode: "context_length_exceeded", + Message: "This model's maximum context length is 4096 tokens. 
However, your messages resulted in 4500 tokens.", + Param: stringPtr("messages"), + }, + "server_error": { + StatusCode: 503, + ErrorType: "server_error", + ErrorCode: "server_error", + Message: "The server is overloaded or not ready yet.", + Param: nil, + }, + "invalid_request": { + StatusCode: 400, + ErrorType: "invalid_request_error", + ErrorCode: "invalid_request_error", + Message: "Invalid request: missing required parameter 'model'.", + Param: stringPtr("model"), + }, + "model_not_found": { + StatusCode: 404, + ErrorType: "invalid_request_error", + ErrorCode: "model_not_found", + Message: "The model 'gpt-nonexistent' does not exist", + Param: stringPtr("model"), + }, +} + +// ShouldInjectFailure determines whether to inject a failure based on configuration +func ShouldInjectFailure(config *Configuration) bool { + if config.Mode != ModeFailure { + return false + } + + rand.Seed(time.Now().UnixNano()) + return rand.Intn(100) < config.FailureInjectionRate +} + +// GetRandomFailure returns a random failure from configured types or all types if none specified +func GetRandomFailure(config *Configuration) FailureSpec { + rand.Seed(time.Now().UnixNano()) + + var availableFailures []string + if len(config.FailureTypes) == 0 { + // Use all failure types if none specified + for failureType := range predefinedFailures { + availableFailures = append(availableFailures, failureType) + } + } else { + availableFailures = config.FailureTypes + } + + if len(availableFailures) == 0 { + // Fallback to server_error if no valid types + return predefinedFailures["server_error"] + } + + randomType := availableFailures[rand.Intn(len(availableFailures))] + + // Customize message with current model name + failure := predefinedFailures[randomType] + if randomType == "rate_limit" && config.Model != "" { + failure.Message = fmt.Sprintf("Rate limit reached for %s in organization org-xxx on requests per min (RPM): Limit 3, Used 3, Requested 1.", config.Model) + } else if randomType == "model_not_found" && config.Model != "" { + failure.Message = fmt.Sprintf("The model '%s-nonexistent' does not exist", config.Model) + } + + return failure +} + +func stringPtr(s string) *string { + return &s +} \ No newline at end of file diff --git a/pkg/common/failures_test.go b/pkg/common/failures_test.go new file mode 100644 index 00000000..6fb50287 --- /dev/null +++ b/pkg/common/failures_test.go @@ -0,0 +1,134 @@ +/* +Copyright 2025 The llm-d-inference-sim Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package common_test + +import ( + "strings" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" + + "github.com/llm-d/llm-d-inference-sim/pkg/common" +) + +var _ = Describe("Failures", func() { + Describe("ShouldInjectFailure", func() { + It("should not inject failure when not in failure mode", func() { + config := &common.Configuration{ + Mode: common.ModeRandom, + FailureInjectionRate: 100, + } + Expect(common.ShouldInjectFailure(config)).To(BeFalse()) + }) + + It("should not inject failure when rate is 0", func() { + config := &common.Configuration{ + Mode: common.ModeFailure, + FailureInjectionRate: 0, + } + Expect(common.ShouldInjectFailure(config)).To(BeFalse()) + }) + + It("should inject failure when in failure mode with 100% rate", func() { + config := &common.Configuration{ + Mode: common.ModeFailure, + FailureInjectionRate: 100, + } + Expect(common.ShouldInjectFailure(config)).To(BeTrue()) + }) + }) + + Describe("GetRandomFailure", func() { + It("should return a failure from all types when none specified", func() { + config := &common.Configuration{ + Model: "test-model", + FailureTypes: []string{}, + } + failure := common.GetRandomFailure(config) + Expect(failure.StatusCode).To(BeNumerically(">=", 400)) + Expect(failure.Message).ToNot(BeEmpty()) + Expect(failure.ErrorType).ToNot(BeEmpty()) + }) + + It("should return rate limit failure when specified", func() { + config := &common.Configuration{ + Model: "test-model", + FailureTypes: []string{"rate_limit"}, + } + failure := common.GetRandomFailure(config) + Expect(failure.StatusCode).To(Equal(429)) + Expect(failure.ErrorType).To(Equal("rate_limit_exceeded")) + Expect(failure.ErrorCode).To(Equal("rate_limit_exceeded")) + Expect(strings.Contains(failure.Message, "test-model")).To(BeTrue()) + }) + + It("should return invalid API key failure when specified", func() { + config := &common.Configuration{ + FailureTypes: []string{"invalid_api_key"}, + } + failure := common.GetRandomFailure(config) + Expect(failure.StatusCode).To(Equal(401)) + Expect(failure.ErrorType).To(Equal("invalid_request_error")) + Expect(failure.ErrorCode).To(Equal("invalid_api_key")) + Expect(failure.Message).To(Equal("Incorrect API key provided")) + }) + + It("should return context length failure when specified", func() { + config := &common.Configuration{ + FailureTypes: []string{"context_length"}, + } + failure := common.GetRandomFailure(config) + Expect(failure.StatusCode).To(Equal(400)) + Expect(failure.ErrorType).To(Equal("invalid_request_error")) + Expect(failure.ErrorCode).To(Equal("context_length_exceeded")) + Expect(failure.Param).ToNot(BeNil()) + Expect(*failure.Param).To(Equal("messages")) + }) + + It("should return server error when specified", func() { + config := &common.Configuration{ + FailureTypes: []string{"server_error"}, + } + failure := common.GetRandomFailure(config) + Expect(failure.StatusCode).To(Equal(503)) + Expect(failure.ErrorType).To(Equal("server_error")) + Expect(failure.ErrorCode).To(Equal("server_error")) + }) + + It("should return model not found failure when specified", func() { + config := &common.Configuration{ + Model: "test-model", + FailureTypes: []string{"model_not_found"}, + } + failure := common.GetRandomFailure(config) + Expect(failure.StatusCode).To(Equal(404)) + Expect(failure.ErrorType).To(Equal("invalid_request_error")) + Expect(failure.ErrorCode).To(Equal("model_not_found")) + Expect(strings.Contains(failure.Message, "test-model-nonexistent")).To(BeTrue()) + }) + + It("should return server error as fallback for empty types", func() { + config := &common.Configuration{ + FailureTypes: 
[]string{}, + } + // This test is probabilistic since it randomly selects, but we can test structure + failure := common.GetRandomFailure(config) + Expect(failure.StatusCode).To(BeNumerically(">=", 400)) + Expect(failure.ErrorType).ToNot(BeEmpty()) + }) + }) +}) \ No newline at end of file diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index 77b0738a..475e4c66 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ -291,6 +291,13 @@ func (s *VllmSimulator) isLora(model string) bool { // handleCompletions general completion requests handler, support both text and chat completion APIs func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatCompletion bool) { + // Check if we should inject a failure + if common.ShouldInjectFailure(s.config) { + failure := common.GetRandomFailure(s.config) + s.sendFailureResponse(ctx, failure) + return + } + vllmReq, err := s.readRequest(ctx, isChatCompletion) if err != nil { s.logger.Error(err, "failed to read and parse request body") @@ -485,15 +492,15 @@ func (s *VllmSimulator) responseSentCallback(model string) { // sendCompletionError sends an error response for the current completion request func (s *VllmSimulator) sendCompletionError(ctx *fasthttp.RequestCtx, msg string, errType string, code int) { compErr := openaiserverapi.CompletionError{ - Object: "error", Message: msg, Type: errType, - Code: code, + Code: "", Param: nil, } + errorResp := openaiserverapi.ErrorResponse{Error: compErr} s.logger.Error(nil, compErr.Message) - data, err := json.Marshal(compErr) + data, err := json.Marshal(errorResp) if err != nil { ctx.Error(err.Error(), fasthttp.StatusInternalServerError) } else { @@ -503,6 +510,27 @@ func (s *VllmSimulator) sendCompletionError(ctx *fasthttp.RequestCtx, msg string } } +// sendFailureResponse sends a predefined failure response for testing +func (s *VllmSimulator) sendFailureResponse(ctx *fasthttp.RequestCtx, failure common.FailureSpec) { + compErr := openaiserverapi.CompletionError{ + Message: failure.Message, + Type: failure.ErrorType, + Code: failure.ErrorCode, + Param: failure.Param, + } + errorResp := openaiserverapi.ErrorResponse{Error: compErr} + s.logger.Info("Injecting failure", "type", failure.ErrorType, "message", failure.Message) + + data, err := json.Marshal(errorResp) + if err != nil { + ctx.Error(err.Error(), fasthttp.StatusInternalServerError) + } else { + ctx.SetContentType("application/json") + ctx.SetStatusCode(failure.StatusCode) + ctx.SetBody(data) + } +} + // HandleModels handles /v1/models request according the data stored in the simulator func (s *VllmSimulator) HandleModels(ctx *fasthttp.RequestCtx) { modelsResp := s.createModelsResponse() diff --git a/pkg/llm-d-inference-sim/simulator_test.go b/pkg/llm-d-inference-sim/simulator_test.go index fb8c0e8f..89e60e4e 100644 --- a/pkg/llm-d-inference-sim/simulator_test.go +++ b/pkg/llm-d-inference-sim/simulator_test.go @@ -31,7 +31,6 @@ import ( . 
"github.com/onsi/gomega" "github.com/openai/openai-go" "github.com/openai/openai-go/option" - "github.com/openai/openai-go/packages/param" "github.com/valyala/fasthttp/fasthttputil" "k8s.io/klog/v2" ) @@ -120,7 +119,7 @@ var _ = Describe("Simulator", func() { openai.UserMessage(userMessage), }, Model: model, - StreamOptions: openai.ChatCompletionStreamOptionsParam{IncludeUsage: param.NewOpt(true)}, + StreamOptions: openai.ChatCompletionStreamOptionsParam{IncludeUsage: openai.Bool(true)}, } stream := openaiclient.Chat.Completions.NewStreaming(ctx, params) defer func() { @@ -183,7 +182,7 @@ var _ = Describe("Simulator", func() { OfString: openai.String(userMessage), }, Model: openai.CompletionNewParamsModel(model), - StreamOptions: openai.ChatCompletionStreamOptionsParam{IncludeUsage: param.NewOpt(true)}, + StreamOptions: openai.ChatCompletionStreamOptionsParam{IncludeUsage: openai.Bool(true)}, } stream := openaiclient.Completions.NewStreaming(ctx, params) defer func() { @@ -246,11 +245,11 @@ var _ = Describe("Simulator", func() { // if maxTokens and maxCompletionTokens are passsed // maxCompletionTokens is used if maxTokens != 0 { - params.MaxTokens = param.NewOpt(int64(maxTokens)) + params.MaxTokens = openai.Int(int64(maxTokens)) numTokens = maxTokens } if maxCompletionTokens != 0 { - params.MaxCompletionTokens = param.NewOpt(int64(maxCompletionTokens)) + params.MaxCompletionTokens = openai.Int(int64(maxCompletionTokens)) numTokens = maxCompletionTokens } resp, err := openaiclient.Chat.Completions.New(ctx, params) @@ -329,7 +328,7 @@ var _ = Describe("Simulator", func() { } numTokens := 0 if maxTokens != 0 { - params.MaxTokens = param.NewOpt(int64(maxTokens)) + params.MaxTokens = openai.Int(int64(maxTokens)) numTokens = maxTokens } resp, err := openaiclient.Completions.New(ctx, params) @@ -589,4 +588,264 @@ var _ = Describe("Simulator", func() { Entry(nil, 10000, 0, 1000, 0, false), ) }) + + Describe("Failure injection mode", func() { + var ( + client *http.Client + ctx context.Context + ) + + AfterEach(func() { + if ctx != nil { + ctx.Done() + } + }) + + Context("with 100% failure injection rate", func() { + BeforeEach(func() { + ctx = context.Background() + var err error + client, err = startServerWithArgs(ctx, "failure", []string{ + "cmd", "--model", model, + "--mode", "failure", + "--failure-injection-rate", "100", + }) + Expect(err).ToNot(HaveOccurred()) + }) + + It("should always return an error response for chat completions", func() { + openaiClient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client), + ) + + _, err := openaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{ + Model: model, + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + }) + + Expect(err).To(HaveOccurred()) + + var openaiError *openai.Error + ok := errors.As(err, &openaiError) + Expect(ok).To(BeTrue()) + Expect(openaiError.StatusCode).To(BeNumerically(">=", 400)) + Expect(openaiError.Type).ToNot(BeEmpty()) + Expect(openaiError.Message).ToNot(BeEmpty()) + }) + + It("should always return an error response for text completions", func() { + openaiClient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client), + ) + + _, err := openaiClient.Completions.New(ctx, openai.CompletionNewParams{ + Model: openai.CompletionNewParamsModel(model), + Prompt: openai.CompletionNewParamsPromptUnion{ + OfString: openai.String(userMessage), + }, + }) + + Expect(err).To(HaveOccurred()) + + var openaiError 
*openai.Error + ok := errors.As(err, &openaiError) + Expect(ok).To(BeTrue()) + Expect(openaiError.StatusCode).To(BeNumerically(">=", 400)) + Expect(openaiError.Type).ToNot(BeEmpty()) + Expect(openaiError.Message).ToNot(BeEmpty()) + }) + }) + + Context("with specific failure types", func() { + BeforeEach(func() { + ctx = context.Background() + var err error + client, err = startServerWithArgs(ctx, "failure", []string{ + "cmd", "--model", model, + "--mode", "failure", + "--failure-injection-rate", "100", + "--failure-types", "rate_limit", + }) + Expect(err).ToNot(HaveOccurred()) + }) + + It("should return only rate limit errors", func() { + openaiClient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client), + ) + + _, err := openaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{ + Model: model, + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + }) + + Expect(err).To(HaveOccurred()) + + var openaiError *openai.Error + ok := errors.As(err, &openaiError) + Expect(ok).To(BeTrue()) + Expect(openaiError.StatusCode).To(Equal(429)) + Expect(openaiError.Type).To(Equal("rate_limit_exceeded")) + Expect(strings.Contains(openaiError.Message, model)).To(BeTrue()) + }) + }) + + Context("with multiple specific failure types", func() { + BeforeEach(func() { + ctx = context.Background() + var err error + client, err = startServerWithArgs(ctx, "failure", []string{ + "cmd", "--model", model, + "--mode", "failure", + "--failure-injection-rate", "100", + "--failure-types", "invalid_api_key", "server_error", + }) + Expect(err).ToNot(HaveOccurred()) + }) + + It("should return only specified error types", func() { + openaiClient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client), + ) + + // Make multiple requests to verify we get the expected error types + for i := 0; i < 10; i++ { + _, err := openaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{ + Model: model, + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + }) + + Expect(err).To(HaveOccurred()) + + var openaiError *openai.Error + ok := errors.As(err, &openaiError) + Expect(ok).To(BeTrue()) + + // Should only be one of the specified types + Expect(openaiError.StatusCode == 401 || openaiError.StatusCode == 503).To(BeTrue()) + Expect(openaiError.Type == "invalid_request_error" || openaiError.Type == "server_error").To(BeTrue()) + } + }) + }) + + Context("with 0% failure injection rate", func() { + BeforeEach(func() { + ctx = context.Background() + var err error + client, err = startServerWithArgs(ctx, "failure", []string{ + "cmd", "--model", model, + "--mode", "failure", + "--failure-injection-rate", "0", + }) + Expect(err).ToNot(HaveOccurred()) + }) + + It("should never return errors and behave like random mode", func() { + openaiClient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client), + ) + + resp, err := openaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{ + Model: model, + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + }) + + Expect(err).ToNot(HaveOccurred()) + Expect(resp.Choices).To(HaveLen(1)) + Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty()) + Expect(resp.Model).To(Equal(model)) + }) + }) + + Context("testing all predefined failure types", func() { + DescribeTable("should return correct error for each failure type", + func(failureType string, expectedStatusCode int, 
expectedErrorType string, expectedErrorCode string) { + ctx := context.Background() + client, err := startServerWithArgs(ctx, "failure", []string{ + "cmd", "--model", model, + "--mode", "failure", + "--failure-injection-rate", "100", + "--failure-types", failureType, + }) + Expect(err).ToNot(HaveOccurred()) + + openaiClient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client), + ) + + _, err = openaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{ + Model: model, + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + }) + + Expect(err).To(HaveOccurred()) + + var openaiError *openai.Error + ok := errors.As(err, &openaiError) + Expect(ok).To(BeTrue()) + Expect(openaiError.StatusCode).To(Equal(expectedStatusCode)) + Expect(openaiError.Type).To(Equal(expectedErrorType)) + // Note: OpenAI Go client doesn't directly expose the error code field, + // but we can verify via status code and type + }, + Entry("rate_limit", "rate_limit", 429, "rate_limit_exceeded", "rate_limit_exceeded"), + Entry("invalid_api_key", "invalid_api_key", 401, "invalid_request_error", "invalid_api_key"), + Entry("context_length", "context_length", 400, "invalid_request_error", "context_length_exceeded"), + Entry("server_error", "server_error", 503, "server_error", "server_error"), + Entry("invalid_request", "invalid_request", 400, "invalid_request_error", "invalid_request_error"), + Entry("model_not_found", "model_not_found", 404, "invalid_request_error", "model_not_found"), + ) + }) + + Context("configuration validation", func() { + It("should fail with invalid failure injection rate > 100", func() { + ctx := context.Background() + _, err := startServerWithArgs(ctx, "failure", []string{ + "cmd", "--model", model, + "--failure-injection-rate", "150", + }) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("failure injection rate should be between 0 and 100")) + }) + + It("should fail with invalid failure injection rate < 0", func() { + ctx := context.Background() + _, err := startServerWithArgs(ctx, "failure", []string{ + "cmd", "--model", model, + "--mode", "failure", + "--failure-injection-rate", "-10", + }) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("failure injection rate should be between 0 and 100")) + }) + + It("should fail with invalid failure type", func() { + ctx := context.Background() + _, err := startServerWithArgs(ctx, "failure", []string{ + "cmd", "--model", model, + "--mode", "failure", + "--failure-injection-rate", "50", + "--failure-types", "invalid_type", + }) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("invalid failure type 'invalid_type'")) + }) + }) + }) }) diff --git a/pkg/openai-server-api/response.go b/pkg/openai-server-api/response.go index a8f4a652..9e8549b3 100644 --- a/pkg/openai-server-api/response.go +++ b/pkg/openai-server-api/response.go @@ -208,14 +208,17 @@ type ChatRespChunkChoice struct { // CompletionError defines the simulator's response in case of an error type CompletionError struct { - // Object is a type of this Object, "error" - Object string `json:"object"` // Message is an error Message Message string `json:"message"` // Type is a type of the error Type string `json:"type"` - // Params is the error's parameters + // Param is the error's parameter Param *string `json:"param"` - // Code is http status Code - Code int `json:"code"` + // Code is the error code string + Code string `json:"code,omitempty"` +} + 
+// ErrorResponse wraps the error in the expected OpenAI format +type ErrorResponse struct { + Error CompletionError `json:"error"` } From 5ec92b836a6db4c2ab7cfcd8c7902741778a101d Mon Sep 17 00:00:00 2001 From: Sergey Marunich Date: Thu, 14 Aug 2025 11:46:44 -0400 Subject: [PATCH 06/46] Refactor failure injection and update simulator error handling Failure injection is now controlled by a dedicated 'failure-injection-rate' parameter instead of a separate 'failure' mode. Failure type constants are centralized, and error handling in the simulator is refactored to use a unified method for sending error responses. Documentation and tests are updated to reflect these changes, and the OpenAI error response format now includes an 'object' field. Signed-off-by: Sergey Marunich --- Dockerfile | 2 +- README.md | 6 +- go.sum | 8 +- pkg/common/config.go | 27 ++++-- .../failures.go | 48 +++++----- .../failures_test.go | 46 ++++------ pkg/llm-d-inference-sim/simulator.go | 92 ++++++++++++------- pkg/llm-d-inference-sim/simulator_test.go | 12 +-- pkg/openai-server-api/response.go | 3 +- 9 files changed, 133 insertions(+), 111 deletions(-) rename pkg/{common => llm-d-inference-sim}/failures.go (68%) rename pkg/{common => llm-d-inference-sim}/failures_test.go (75%) diff --git a/Dockerfile b/Dockerfile index 87ba47d7..2af4a795 100644 --- a/Dockerfile +++ b/Dockerfile @@ -23,7 +23,7 @@ COPY . . # HuggingFace tokenizer bindings RUN mkdir -p lib -RUN curl -L https://github.com/daulet/tokenizers/releases/download/v1.20.2/libtokenizers.${TARGETOS}-${TARGETARCH}.tar.gz | tar -xz -C lib +RUN curl -L https://github.com/daulet/tokenizers/releases/download/v1.22.1/libtokenizers.${TARGETOS}-${TARGETARCH}.tar.gz | tar -xz -C lib RUN ranlib lib/*.a # Build diff --git a/README.md b/README.md index c69d3cf1..550366b5 100644 --- a/README.md +++ b/README.md @@ -29,10 +29,11 @@ In addition, it supports a subset of vLLM's Prometheus metrics. These metrics ar The simulated inference has no connection with the model and LoRA adapters specified in the command line parameters or via the /v1/load_lora_adapter HTTP REST endpoint. The /v1/models endpoint returns simulated results based on those same command line parameters and those loaded via the /v1/load_lora_adapter HTTP REST endpoint. -The simulator supports three modes of operation: +The simulator supports two modes of operation: - `echo` mode: the response contains the same text that was received in the request. For `/v1/chat/completions` the last message for the role=`user` is used. - `random` mode: the response is randomly chosen from a set of pre-defined sentences. -- `failure` mode: randomly injects OpenAI API compatible error responses for testing error handling. + +Additionally, the simulator can inject OpenAI API compatible error responses for testing error handling using the `failure-injection-rate` parameter. Timing of the response is defined by the `time-to-first-token` and `inter-token-latency` parameters. In case P/D is enabled for a request, `kv-cache-transfer-latency` will be used instead of `time-to-first-token`. 
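For example (illustrative values): with `time-to-first-token` of 2000 ms and `inter-token-latency` of 100 ms, a streamed response of 20 tokens is expected to take roughly 2000 + 19*100 = 3900 ms.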
@@ -102,7 +103,6 @@ For more details see the =", 400)) Expect(failure.Message).ToNot(BeEmpty()) Expect(failure.ErrorType).ToNot(BeEmpty()) @@ -67,9 +61,9 @@ var _ = Describe("Failures", func() { It("should return rate limit failure when specified", func() { config := &common.Configuration{ Model: "test-model", - FailureTypes: []string{"rate_limit"}, + FailureTypes: []string{common.FailureTypeRateLimit}, } - failure := common.GetRandomFailure(config) + failure := llmdinferencesim.GetRandomFailure(config) Expect(failure.StatusCode).To(Equal(429)) Expect(failure.ErrorType).To(Equal("rate_limit_exceeded")) Expect(failure.ErrorCode).To(Equal("rate_limit_exceeded")) @@ -78,9 +72,9 @@ var _ = Describe("Failures", func() { It("should return invalid API key failure when specified", func() { config := &common.Configuration{ - FailureTypes: []string{"invalid_api_key"}, + FailureTypes: []string{common.FailureTypeInvalidAPIKey}, } - failure := common.GetRandomFailure(config) + failure := llmdinferencesim.GetRandomFailure(config) Expect(failure.StatusCode).To(Equal(401)) Expect(failure.ErrorType).To(Equal("invalid_request_error")) Expect(failure.ErrorCode).To(Equal("invalid_api_key")) @@ -89,9 +83,9 @@ var _ = Describe("Failures", func() { It("should return context length failure when specified", func() { config := &common.Configuration{ - FailureTypes: []string{"context_length"}, + FailureTypes: []string{common.FailureTypeContextLength}, } - failure := common.GetRandomFailure(config) + failure := llmdinferencesim.GetRandomFailure(config) Expect(failure.StatusCode).To(Equal(400)) Expect(failure.ErrorType).To(Equal("invalid_request_error")) Expect(failure.ErrorCode).To(Equal("context_length_exceeded")) @@ -101,9 +95,9 @@ var _ = Describe("Failures", func() { It("should return server error when specified", func() { config := &common.Configuration{ - FailureTypes: []string{"server_error"}, + FailureTypes: []string{common.FailureTypeServerError}, } - failure := common.GetRandomFailure(config) + failure := llmdinferencesim.GetRandomFailure(config) Expect(failure.StatusCode).To(Equal(503)) Expect(failure.ErrorType).To(Equal("server_error")) Expect(failure.ErrorCode).To(Equal("server_error")) @@ -112,9 +106,9 @@ var _ = Describe("Failures", func() { It("should return model not found failure when specified", func() { config := &common.Configuration{ Model: "test-model", - FailureTypes: []string{"model_not_found"}, + FailureTypes: []string{common.FailureTypeModelNotFound}, } - failure := common.GetRandomFailure(config) + failure := llmdinferencesim.GetRandomFailure(config) Expect(failure.StatusCode).To(Equal(404)) Expect(failure.ErrorType).To(Equal("invalid_request_error")) Expect(failure.ErrorCode).To(Equal("model_not_found")) @@ -126,7 +120,7 @@ var _ = Describe("Failures", func() { FailureTypes: []string{}, } // This test is probabilistic since it randomly selects, but we can test structure - failure := common.GetRandomFailure(config) + failure := llmdinferencesim.GetRandomFailure(config) Expect(failure.StatusCode).To(BeNumerically(">=", 400)) Expect(failure.ErrorType).ToNot(BeEmpty()) }) diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index 475e4c66..00ae329f 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ -292,9 +292,9 @@ func (s *VllmSimulator) isLora(model string) bool { // handleCompletions general completion requests handler, support both text and chat completion APIs func (s *VllmSimulator) handleCompletions(ctx 
*fasthttp.RequestCtx, isChatCompletion bool) { // Check if we should inject a failure - if common.ShouldInjectFailure(s.config) { - failure := common.GetRandomFailure(s.config) - s.sendFailureResponse(ctx, failure) + if ShouldInjectFailure(s.config) { + failure := GetRandomFailure(s.config) + s.sendCompletionError(ctx, failure, true) return } @@ -307,7 +307,13 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple errMsg, errType, errCode := s.validateRequest(vllmReq) if errMsg != "" { - s.sendCompletionError(ctx, errMsg, errType, errCode) + s.sendCompletionError(ctx, FailureSpec{ + StatusCode: errCode, + ErrorType: errType, + ErrorCode: "", + Message: errMsg, + Param: nil, + }, false) return } @@ -334,8 +340,14 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple completionTokens := vllmReq.GetMaxCompletionTokens() isValid, actualCompletionTokens, totalTokens := common.ValidateContextWindow(promptTokens, completionTokens, s.config.MaxModelLen) if !isValid { - s.sendCompletionError(ctx, fmt.Sprintf("This model's maximum context length is %d tokens. However, you requested %d tokens (%d in the messages, %d in the completion). Please reduce the length of the messages or completion", - s.config.MaxModelLen, totalTokens, promptTokens, actualCompletionTokens), "BadRequestError", fasthttp.StatusBadRequest) + s.sendCompletionError(ctx, FailureSpec{ + StatusCode: fasthttp.StatusBadRequest, + ErrorType: "BadRequestError", + ErrorCode: "", + Message: fmt.Sprintf("This model's maximum context length is %d tokens. However, you requested %d tokens (%d in the messages, %d in the completion). Please reduce the length of the messages or completion", + s.config.MaxModelLen, totalTokens, promptTokens, actualCompletionTokens), + Param: nil, + }, false) return } @@ -490,43 +502,53 @@ func (s *VllmSimulator) responseSentCallback(model string) { } // sendCompletionError sends an error response for the current completion request -func (s *VllmSimulator) sendCompletionError(ctx *fasthttp.RequestCtx, msg string, errType string, code int) { - compErr := openaiserverapi.CompletionError{ - Message: msg, - Type: errType, - Code: "", - Param: nil, - } - errorResp := openaiserverapi.ErrorResponse{Error: compErr} - s.logger.Error(nil, compErr.Message) - - data, err := json.Marshal(errorResp) - if err != nil { - ctx.Error(err.Error(), fasthttp.StatusInternalServerError) +// The first parameter can be either a string message or a FailureSpec +// isInjected indicates if this is an injected failure for logging purposes +func (s *VllmSimulator) sendCompletionError(ctx *fasthttp.RequestCtx, errorInfo interface{}, isInjected bool) { + var compErr openaiserverapi.CompletionError + var statusCode int + + switch v := errorInfo.(type) { + case string: + // Legacy call with string message (backward compatibility) + compErr = openaiserverapi.CompletionError{ + Message: v, + Type: "BadRequestError", + Code: "", + Param: nil, + } + statusCode = fasthttp.StatusBadRequest + case FailureSpec: + // New call with FailureSpec + compErr = openaiserverapi.CompletionError{ + Message: v.Message, + Type: v.ErrorType, + Code: v.ErrorCode, + Param: v.Param, + } + statusCode = v.StatusCode + default: + // For calls with msg, errType, and code - need to be updated in calling code + panic("sendCompletionError called with unexpected type") + } + + errorResp := openaiserverapi.ErrorResponse{ + Object: "error", + Error: compErr, + } + + if isInjected { + s.logger.Info("Injecting failure", "type", 
compErr.Type, "message", compErr.Message) } else { - ctx.SetContentType("application/json") - ctx.SetStatusCode(code) - ctx.SetBody(data) - } -} - -// sendFailureResponse sends a predefined failure response for testing -func (s *VllmSimulator) sendFailureResponse(ctx *fasthttp.RequestCtx, failure common.FailureSpec) { - compErr := openaiserverapi.CompletionError{ - Message: failure.Message, - Type: failure.ErrorType, - Code: failure.ErrorCode, - Param: failure.Param, + s.logger.Error(nil, compErr.Message) } - errorResp := openaiserverapi.ErrorResponse{Error: compErr} - s.logger.Info("Injecting failure", "type", failure.ErrorType, "message", failure.Message) data, err := json.Marshal(errorResp) if err != nil { ctx.Error(err.Error(), fasthttp.StatusInternalServerError) } else { ctx.SetContentType("application/json") - ctx.SetStatusCode(failure.StatusCode) + ctx.SetStatusCode(statusCode) ctx.SetBody(data) } } diff --git a/pkg/llm-d-inference-sim/simulator_test.go b/pkg/llm-d-inference-sim/simulator_test.go index 89e60e4e..467185fe 100644 --- a/pkg/llm-d-inference-sim/simulator_test.go +++ b/pkg/llm-d-inference-sim/simulator_test.go @@ -43,7 +43,8 @@ const invalidMaxTokensErrMsg = "Max completion tokens and max tokens should be p var userMsgTokens int64 func startServer(ctx context.Context, mode string) (*http.Client, error) { - return startServerWithArgs(ctx, mode, nil) + // Disable failure injection for tests by default + return startServerWithArgs(ctx, mode, []string{"cmd", "--model", model, "--mode", mode, "--failure-injection-rate", "0"}) } func startServerWithArgs(ctx context.Context, mode string, args []string) (*http.Client, error) { @@ -55,7 +56,7 @@ func startServerWithArgs(ctx context.Context, mode string, args []string) (*http if args != nil { os.Args = args } else { - os.Args = []string{"cmd", "--model", model, "--mode", mode} + os.Args = []string{"cmd", "--model", model, "--mode", mode, "--failure-injection-rate", "0"} } logger := klog.Background() @@ -607,7 +608,6 @@ var _ = Describe("Simulator", func() { var err error client, err = startServerWithArgs(ctx, "failure", []string{ "cmd", "--model", model, - "--mode", "failure", "--failure-injection-rate", "100", }) Expect(err).ToNot(HaveOccurred()) @@ -666,7 +666,6 @@ var _ = Describe("Simulator", func() { var err error client, err = startServerWithArgs(ctx, "failure", []string{ "cmd", "--model", model, - "--mode", "failure", "--failure-injection-rate", "100", "--failure-types", "rate_limit", }) @@ -703,7 +702,6 @@ var _ = Describe("Simulator", func() { var err error client, err = startServerWithArgs(ctx, "failure", []string{ "cmd", "--model", model, - "--mode", "failure", "--failure-injection-rate", "100", "--failure-types", "invalid_api_key", "server_error", }) @@ -744,7 +742,6 @@ var _ = Describe("Simulator", func() { var err error client, err = startServerWithArgs(ctx, "failure", []string{ "cmd", "--model", model, - "--mode", "failure", "--failure-injection-rate", "0", }) Expect(err).ToNot(HaveOccurred()) @@ -776,7 +773,6 @@ var _ = Describe("Simulator", func() { ctx := context.Background() client, err := startServerWithArgs(ctx, "failure", []string{ "cmd", "--model", model, - "--mode", "failure", "--failure-injection-rate", "100", "--failure-types", failureType, }) @@ -828,7 +824,6 @@ var _ = Describe("Simulator", func() { ctx := context.Background() _, err := startServerWithArgs(ctx, "failure", []string{ "cmd", "--model", model, - "--mode", "failure", "--failure-injection-rate", "-10", }) Expect(err).To(HaveOccurred()) @@ 
-839,7 +834,6 @@ var _ = Describe("Simulator", func() { ctx := context.Background() _, err := startServerWithArgs(ctx, "failure", []string{ "cmd", "--model", model, - "--mode", "failure", "--failure-injection-rate", "50", "--failure-types", "invalid_type", }) diff --git a/pkg/openai-server-api/response.go b/pkg/openai-server-api/response.go index 9e8549b3..f816b06d 100644 --- a/pkg/openai-server-api/response.go +++ b/pkg/openai-server-api/response.go @@ -220,5 +220,6 @@ type CompletionError struct { // ErrorResponse wraps the error in the expected OpenAI format type ErrorResponse struct { - Error CompletionError `json:"error"` + Object string `json:"object"` + Error CompletionError `json:"error"` } From 8e0eefa6c50f54d323b99ae016d2268850a2e376 Mon Sep 17 00:00:00 2001 From: Sergey Marunich Date: Thu, 14 Aug 2025 11:50:05 -0400 Subject: [PATCH 07/46] Make tokenizer version configurable from Dockerfile Extracts TOKENIZER_VERSION from the Dockerfile and uses it in the download-tokenizer target. This allows the Makefile to automatically use the correct tokenizer version specified in the Dockerfile, improving maintainability and consistency. Signed-off-by: Sergey Marunich --- Makefile | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/Makefile b/Makefile index 0f761f54..60d105ad 100644 --- a/Makefile +++ b/Makefile @@ -39,14 +39,16 @@ help: ## Print help LDFLAGS ?= -extldflags '-L$(shell pwd)/lib' CGO_ENABLED=1 TOKENIZER_LIB = lib/libtokenizers.a +# Extract TOKENIZER_VERSION from Dockerfile +TOKENIZER_VERSION := $(shell grep '^ARG TOKENIZER_VERSION=' Dockerfile | cut -d'=' -f2) .PHONY: download-tokenizer download-tokenizer: $(TOKENIZER_LIB) $(TOKENIZER_LIB): ## Download the HuggingFace tokenizer bindings. - @echo "Downloading HuggingFace tokenizer bindings..." + @echo "Downloading HuggingFace tokenizer bindings for version $(TOKENIZER_VERSION)..." mkdir -p lib - curl -L https://github.com/daulet/tokenizers/releases/download/v1.22.1/libtokenizers.$(TARGETOS)-$(TARGETARCH).tar.gz | tar -xz -C lib + curl -L https://github.com/daulet/tokenizers/releases/download/$(TOKENIZER_VERSION)/libtokenizers.$(TARGETOS)-$(TARGETARCH).tar.gz | tar -xz -C lib ranlib lib/*.a ##@ Development @@ -224,4 +226,4 @@ download-zmq: ## Install ZMQ dependencies based on OS/ARCH exit 1; \ fi; \ echo "✅ ZMQ dependencies installed."; \ - fi + fi \ No newline at end of file From 75dcb722476e815ad45c5ea9ce98bc1225119d0b Mon Sep 17 00:00:00 2001 From: Sergey Marunich Date: Wed, 13 Aug 2025 17:52:48 -0400 Subject: [PATCH 08/46] Add failure injection mode to simulator Introduces a 'failure' mode to the simulator, allowing random injection of OpenAI API-compatible error responses for testing error handling. Adds configuration options for failure injection rate and specific failure types, implements error response logic, and updates documentation and tests to cover the new functionality. 
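The injection decision itself is a simple probability check (see ShouldInjectFailure below); a self-contained sketch of the same idea, using an assumed 30% rate:

    // Sketch of the failure-injection probability check: a request is failed
    // when a random number in [0,100) falls below the configured
    // failure-injection-rate. The 30% rate here is an example value.
    package main

    import (
        "fmt"
        "math/rand"
    )

    func shouldInject(rate int) bool {
        return rand.Intn(100) < rate
    }

    func main() {
        injected := 0
        for i := 0; i < 10000; i++ {
            if shouldInject(30) {
                injected++
            }
        }
        fmt.Printf("injected %d of 10000 requests (expected ~3000)\n", injected)
    }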
Signed-off-by: Sergey Marunich --- README.md | 4 +- pkg/common/config.go | 27 ++--- pkg/common/failures.go | 122 ++++++++++++++++++++ pkg/common/failures_test.go | 134 ++++++++++++++++++++++ pkg/llm-d-inference-sim/simulator.go | 92 ++++++--------- pkg/llm-d-inference-sim/simulator_test.go | 12 +- pkg/openai-server-api/response.go | 3 +- 7 files changed, 314 insertions(+), 80 deletions(-) create mode 100644 pkg/common/failures.go create mode 100644 pkg/common/failures_test.go diff --git a/README.md b/README.md index 550366b5..8b9f22be 100644 --- a/README.md +++ b/README.md @@ -29,9 +29,10 @@ In addition, it supports a subset of vLLM's Prometheus metrics. These metrics ar The simulated inference has no connection with the model and LoRA adapters specified in the command line parameters or via the /v1/load_lora_adapter HTTP REST endpoint. The /v1/models endpoint returns simulated results based on those same command line parameters and those loaded via the /v1/load_lora_adapter HTTP REST endpoint. -The simulator supports two modes of operation: +The simulator supports three modes of operation: - `echo` mode: the response contains the same text that was received in the request. For `/v1/chat/completions` the last message for the role=`user` is used. - `random` mode: the response is randomly chosen from a set of pre-defined sentences. +- `failure` mode: randomly injects OpenAI API compatible error responses for testing error handling. Additionally, the simulator can inject OpenAI API compatible error responses for testing error handling using the `failure-injection-rate` parameter. @@ -103,6 +104,7 @@ For more details see the =", 400)) + Expect(failure.Message).ToNot(BeEmpty()) + Expect(failure.ErrorType).ToNot(BeEmpty()) + }) + + It("should return rate limit failure when specified", func() { + config := &common.Configuration{ + Model: "test-model", + FailureTypes: []string{"rate_limit"}, + } + failure := common.GetRandomFailure(config) + Expect(failure.StatusCode).To(Equal(429)) + Expect(failure.ErrorType).To(Equal("rate_limit_exceeded")) + Expect(failure.ErrorCode).To(Equal("rate_limit_exceeded")) + Expect(strings.Contains(failure.Message, "test-model")).To(BeTrue()) + }) + + It("should return invalid API key failure when specified", func() { + config := &common.Configuration{ + FailureTypes: []string{"invalid_api_key"}, + } + failure := common.GetRandomFailure(config) + Expect(failure.StatusCode).To(Equal(401)) + Expect(failure.ErrorType).To(Equal("invalid_request_error")) + Expect(failure.ErrorCode).To(Equal("invalid_api_key")) + Expect(failure.Message).To(Equal("Incorrect API key provided")) + }) + + It("should return context length failure when specified", func() { + config := &common.Configuration{ + FailureTypes: []string{"context_length"}, + } + failure := common.GetRandomFailure(config) + Expect(failure.StatusCode).To(Equal(400)) + Expect(failure.ErrorType).To(Equal("invalid_request_error")) + Expect(failure.ErrorCode).To(Equal("context_length_exceeded")) + Expect(failure.Param).ToNot(BeNil()) + Expect(*failure.Param).To(Equal("messages")) + }) + + It("should return server error when specified", func() { + config := &common.Configuration{ + FailureTypes: []string{"server_error"}, + } + failure := common.GetRandomFailure(config) + Expect(failure.StatusCode).To(Equal(503)) + Expect(failure.ErrorType).To(Equal("server_error")) + Expect(failure.ErrorCode).To(Equal("server_error")) + }) + + It("should return model not found failure when specified", func() { + config := 
&common.Configuration{ + Model: "test-model", + FailureTypes: []string{"model_not_found"}, + } + failure := common.GetRandomFailure(config) + Expect(failure.StatusCode).To(Equal(404)) + Expect(failure.ErrorType).To(Equal("invalid_request_error")) + Expect(failure.ErrorCode).To(Equal("model_not_found")) + Expect(strings.Contains(failure.Message, "test-model-nonexistent")).To(BeTrue()) + }) + + It("should return server error as fallback for empty types", func() { + config := &common.Configuration{ + FailureTypes: []string{}, + } + // This test is probabilistic since it randomly selects, but we can test structure + failure := common.GetRandomFailure(config) + Expect(failure.StatusCode).To(BeNumerically(">=", 400)) + Expect(failure.ErrorType).ToNot(BeEmpty()) + }) + }) +}) \ No newline at end of file diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index 00ae329f..475e4c66 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ -292,9 +292,9 @@ func (s *VllmSimulator) isLora(model string) bool { // handleCompletions general completion requests handler, support both text and chat completion APIs func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatCompletion bool) { // Check if we should inject a failure - if ShouldInjectFailure(s.config) { - failure := GetRandomFailure(s.config) - s.sendCompletionError(ctx, failure, true) + if common.ShouldInjectFailure(s.config) { + failure := common.GetRandomFailure(s.config) + s.sendFailureResponse(ctx, failure) return } @@ -307,13 +307,7 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple errMsg, errType, errCode := s.validateRequest(vllmReq) if errMsg != "" { - s.sendCompletionError(ctx, FailureSpec{ - StatusCode: errCode, - ErrorType: errType, - ErrorCode: "", - Message: errMsg, - Param: nil, - }, false) + s.sendCompletionError(ctx, errMsg, errType, errCode) return } @@ -340,14 +334,8 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple completionTokens := vllmReq.GetMaxCompletionTokens() isValid, actualCompletionTokens, totalTokens := common.ValidateContextWindow(promptTokens, completionTokens, s.config.MaxModelLen) if !isValid { - s.sendCompletionError(ctx, FailureSpec{ - StatusCode: fasthttp.StatusBadRequest, - ErrorType: "BadRequestError", - ErrorCode: "", - Message: fmt.Sprintf("This model's maximum context length is %d tokens. However, you requested %d tokens (%d in the messages, %d in the completion). Please reduce the length of the messages or completion", - s.config.MaxModelLen, totalTokens, promptTokens, actualCompletionTokens), - Param: nil, - }, false) + s.sendCompletionError(ctx, fmt.Sprintf("This model's maximum context length is %d tokens. However, you requested %d tokens (%d in the messages, %d in the completion). 
Please reduce the length of the messages or completion", + s.config.MaxModelLen, totalTokens, promptTokens, actualCompletionTokens), "BadRequestError", fasthttp.StatusBadRequest) return } @@ -502,53 +490,43 @@ func (s *VllmSimulator) responseSentCallback(model string) { } // sendCompletionError sends an error response for the current completion request -// The first parameter can be either a string message or a FailureSpec -// isInjected indicates if this is an injected failure for logging purposes -func (s *VllmSimulator) sendCompletionError(ctx *fasthttp.RequestCtx, errorInfo interface{}, isInjected bool) { - var compErr openaiserverapi.CompletionError - var statusCode int - - switch v := errorInfo.(type) { - case string: - // Legacy call with string message (backward compatibility) - compErr = openaiserverapi.CompletionError{ - Message: v, - Type: "BadRequestError", - Code: "", - Param: nil, - } - statusCode = fasthttp.StatusBadRequest - case FailureSpec: - // New call with FailureSpec - compErr = openaiserverapi.CompletionError{ - Message: v.Message, - Type: v.ErrorType, - Code: v.ErrorCode, - Param: v.Param, - } - statusCode = v.StatusCode - default: - // For calls with msg, errType, and code - need to be updated in calling code - panic("sendCompletionError called with unexpected type") - } - - errorResp := openaiserverapi.ErrorResponse{ - Object: "error", - Error: compErr, - } - - if isInjected { - s.logger.Info("Injecting failure", "type", compErr.Type, "message", compErr.Message) +func (s *VllmSimulator) sendCompletionError(ctx *fasthttp.RequestCtx, msg string, errType string, code int) { + compErr := openaiserverapi.CompletionError{ + Message: msg, + Type: errType, + Code: "", + Param: nil, + } + errorResp := openaiserverapi.ErrorResponse{Error: compErr} + s.logger.Error(nil, compErr.Message) + + data, err := json.Marshal(errorResp) + if err != nil { + ctx.Error(err.Error(), fasthttp.StatusInternalServerError) } else { - s.logger.Error(nil, compErr.Message) + ctx.SetContentType("application/json") + ctx.SetStatusCode(code) + ctx.SetBody(data) + } +} + +// sendFailureResponse sends a predefined failure response for testing +func (s *VllmSimulator) sendFailureResponse(ctx *fasthttp.RequestCtx, failure common.FailureSpec) { + compErr := openaiserverapi.CompletionError{ + Message: failure.Message, + Type: failure.ErrorType, + Code: failure.ErrorCode, + Param: failure.Param, } + errorResp := openaiserverapi.ErrorResponse{Error: compErr} + s.logger.Info("Injecting failure", "type", failure.ErrorType, "message", failure.Message) data, err := json.Marshal(errorResp) if err != nil { ctx.Error(err.Error(), fasthttp.StatusInternalServerError) } else { ctx.SetContentType("application/json") - ctx.SetStatusCode(statusCode) + ctx.SetStatusCode(failure.StatusCode) ctx.SetBody(data) } } diff --git a/pkg/llm-d-inference-sim/simulator_test.go b/pkg/llm-d-inference-sim/simulator_test.go index 467185fe..89e60e4e 100644 --- a/pkg/llm-d-inference-sim/simulator_test.go +++ b/pkg/llm-d-inference-sim/simulator_test.go @@ -43,8 +43,7 @@ const invalidMaxTokensErrMsg = "Max completion tokens and max tokens should be p var userMsgTokens int64 func startServer(ctx context.Context, mode string) (*http.Client, error) { - // Disable failure injection for tests by default - return startServerWithArgs(ctx, mode, []string{"cmd", "--model", model, "--mode", mode, "--failure-injection-rate", "0"}) + return startServerWithArgs(ctx, mode, nil) } func startServerWithArgs(ctx context.Context, mode string, args []string) 
(*http.Client, error) { @@ -56,7 +55,7 @@ func startServerWithArgs(ctx context.Context, mode string, args []string) (*http if args != nil { os.Args = args } else { - os.Args = []string{"cmd", "--model", model, "--mode", mode, "--failure-injection-rate", "0"} + os.Args = []string{"cmd", "--model", model, "--mode", mode} } logger := klog.Background() @@ -608,6 +607,7 @@ var _ = Describe("Simulator", func() { var err error client, err = startServerWithArgs(ctx, "failure", []string{ "cmd", "--model", model, + "--mode", "failure", "--failure-injection-rate", "100", }) Expect(err).ToNot(HaveOccurred()) @@ -666,6 +666,7 @@ var _ = Describe("Simulator", func() { var err error client, err = startServerWithArgs(ctx, "failure", []string{ "cmd", "--model", model, + "--mode", "failure", "--failure-injection-rate", "100", "--failure-types", "rate_limit", }) @@ -702,6 +703,7 @@ var _ = Describe("Simulator", func() { var err error client, err = startServerWithArgs(ctx, "failure", []string{ "cmd", "--model", model, + "--mode", "failure", "--failure-injection-rate", "100", "--failure-types", "invalid_api_key", "server_error", }) @@ -742,6 +744,7 @@ var _ = Describe("Simulator", func() { var err error client, err = startServerWithArgs(ctx, "failure", []string{ "cmd", "--model", model, + "--mode", "failure", "--failure-injection-rate", "0", }) Expect(err).ToNot(HaveOccurred()) @@ -773,6 +776,7 @@ var _ = Describe("Simulator", func() { ctx := context.Background() client, err := startServerWithArgs(ctx, "failure", []string{ "cmd", "--model", model, + "--mode", "failure", "--failure-injection-rate", "100", "--failure-types", failureType, }) @@ -824,6 +828,7 @@ var _ = Describe("Simulator", func() { ctx := context.Background() _, err := startServerWithArgs(ctx, "failure", []string{ "cmd", "--model", model, + "--mode", "failure", "--failure-injection-rate", "-10", }) Expect(err).To(HaveOccurred()) @@ -834,6 +839,7 @@ var _ = Describe("Simulator", func() { ctx := context.Background() _, err := startServerWithArgs(ctx, "failure", []string{ "cmd", "--model", model, + "--mode", "failure", "--failure-injection-rate", "50", "--failure-types", "invalid_type", }) diff --git a/pkg/openai-server-api/response.go b/pkg/openai-server-api/response.go index f816b06d..9e8549b3 100644 --- a/pkg/openai-server-api/response.go +++ b/pkg/openai-server-api/response.go @@ -220,6 +220,5 @@ type CompletionError struct { // ErrorResponse wraps the error in the expected OpenAI format type ErrorResponse struct { - Object string `json:"object"` - Error CompletionError `json:"error"` + Error CompletionError `json:"error"` } From d7bb17562bc5b76245b5217ece0df531fce5380b Mon Sep 17 00:00:00 2001 From: Sergey Marunich Date: Thu, 14 Aug 2025 11:46:44 -0400 Subject: [PATCH 09/46] Refactor failure injection and update simulator error handling Failure injection is now controlled by a dedicated 'failure-injection-rate' parameter instead of a separate 'failure' mode. Failure type constants are centralized, and error handling in the simulator is refactored to use a unified method for sending error responses. Documentation and tests are updated to reflect these changes, and the OpenAI error response format now includes an 'object' field. 
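After this change an error response carries a top-level "object" field; a minimal sketch of the resulting wire format (local stand-ins for the response types, example values only):

    // Illustrative only: shows the error format once the top-level "object"
    // field is added. Types mirror ErrorResponse/CompletionError from the diff.
    package main

    import (
        "encoding/json"
        "fmt"
    )

    type completionError struct {
        Message string  `json:"message"`
        Type    string  `json:"type"`
        Param   *string `json:"param"`
        Code    string  `json:"code,omitempty"`
    }

    type errorResponse struct {
        Object string          `json:"object"`
        Error  completionError `json:"error"`
    }

    func main() {
        b, _ := json.Marshal(errorResponse{
            Object: "error",
            Error: completionError{
                Message: "The server is overloaded or not ready yet.",
                Type:    "server_error",
                Code:    "server_error",
            },
        })
        fmt.Println(string(b))
        // Prints: {"object":"error","error":{"message":"The server is overloaded or not ready yet.","type":"server_error","param":null,"code":"server_error"}}
    }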
Signed-off-by: Sergey Marunich --- README.md | 6 +- pkg/common/config.go | 27 +++-- pkg/common/failures.go | 122 -------------------- pkg/common/failures_test.go | 134 ---------------------- pkg/llm-d-inference-sim/simulator.go | 92 +++++++++------ pkg/llm-d-inference-sim/simulator_test.go | 12 +- pkg/openai-server-api/response.go | 3 +- 7 files changed, 82 insertions(+), 314 deletions(-) delete mode 100644 pkg/common/failures.go delete mode 100644 pkg/common/failures_test.go diff --git a/README.md b/README.md index 8b9f22be..acf3a3bb 100644 --- a/README.md +++ b/README.md @@ -29,10 +29,11 @@ In addition, it supports a subset of vLLM's Prometheus metrics. These metrics ar The simulated inference has no connection with the model and LoRA adapters specified in the command line parameters or via the /v1/load_lora_adapter HTTP REST endpoint. The /v1/models endpoint returns simulated results based on those same command line parameters and those loaded via the /v1/load_lora_adapter HTTP REST endpoint. -The simulator supports three modes of operation: +The simulator supports two modes of operation: - `echo` mode: the response contains the same text that was received in the request. For `/v1/chat/completions` the last message for the role=`user` is used. - `random` mode: the response is randomly chosen from a set of pre-defined sentences. -- `failure` mode: randomly injects OpenAI API compatible error responses for testing error handling. + +Additionally, the simulator can inject OpenAI API compatible error responses for testing error handling using the `failure-injection-rate` parameter. Additionally, the simulator can inject OpenAI API compatible error responses for testing error handling using the `failure-injection-rate` parameter. @@ -104,7 +105,6 @@ For more details see the =", 400)) - Expect(failure.Message).ToNot(BeEmpty()) - Expect(failure.ErrorType).ToNot(BeEmpty()) - }) - - It("should return rate limit failure when specified", func() { - config := &common.Configuration{ - Model: "test-model", - FailureTypes: []string{"rate_limit"}, - } - failure := common.GetRandomFailure(config) - Expect(failure.StatusCode).To(Equal(429)) - Expect(failure.ErrorType).To(Equal("rate_limit_exceeded")) - Expect(failure.ErrorCode).To(Equal("rate_limit_exceeded")) - Expect(strings.Contains(failure.Message, "test-model")).To(BeTrue()) - }) - - It("should return invalid API key failure when specified", func() { - config := &common.Configuration{ - FailureTypes: []string{"invalid_api_key"}, - } - failure := common.GetRandomFailure(config) - Expect(failure.StatusCode).To(Equal(401)) - Expect(failure.ErrorType).To(Equal("invalid_request_error")) - Expect(failure.ErrorCode).To(Equal("invalid_api_key")) - Expect(failure.Message).To(Equal("Incorrect API key provided")) - }) - - It("should return context length failure when specified", func() { - config := &common.Configuration{ - FailureTypes: []string{"context_length"}, - } - failure := common.GetRandomFailure(config) - Expect(failure.StatusCode).To(Equal(400)) - Expect(failure.ErrorType).To(Equal("invalid_request_error")) - Expect(failure.ErrorCode).To(Equal("context_length_exceeded")) - Expect(failure.Param).ToNot(BeNil()) - Expect(*failure.Param).To(Equal("messages")) - }) - - It("should return server error when specified", func() { - config := &common.Configuration{ - FailureTypes: []string{"server_error"}, - } - failure := common.GetRandomFailure(config) - Expect(failure.StatusCode).To(Equal(503)) - Expect(failure.ErrorType).To(Equal("server_error")) - 
Expect(failure.ErrorCode).To(Equal("server_error")) - }) - - It("should return model not found failure when specified", func() { - config := &common.Configuration{ - Model: "test-model", - FailureTypes: []string{"model_not_found"}, - } - failure := common.GetRandomFailure(config) - Expect(failure.StatusCode).To(Equal(404)) - Expect(failure.ErrorType).To(Equal("invalid_request_error")) - Expect(failure.ErrorCode).To(Equal("model_not_found")) - Expect(strings.Contains(failure.Message, "test-model-nonexistent")).To(BeTrue()) - }) - - It("should return server error as fallback for empty types", func() { - config := &common.Configuration{ - FailureTypes: []string{}, - } - // This test is probabilistic since it randomly selects, but we can test structure - failure := common.GetRandomFailure(config) - Expect(failure.StatusCode).To(BeNumerically(">=", 400)) - Expect(failure.ErrorType).ToNot(BeEmpty()) - }) - }) -}) \ No newline at end of file diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index 475e4c66..00ae329f 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ -292,9 +292,9 @@ func (s *VllmSimulator) isLora(model string) bool { // handleCompletions general completion requests handler, support both text and chat completion APIs func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatCompletion bool) { // Check if we should inject a failure - if common.ShouldInjectFailure(s.config) { - failure := common.GetRandomFailure(s.config) - s.sendFailureResponse(ctx, failure) + if ShouldInjectFailure(s.config) { + failure := GetRandomFailure(s.config) + s.sendCompletionError(ctx, failure, true) return } @@ -307,7 +307,13 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple errMsg, errType, errCode := s.validateRequest(vllmReq) if errMsg != "" { - s.sendCompletionError(ctx, errMsg, errType, errCode) + s.sendCompletionError(ctx, FailureSpec{ + StatusCode: errCode, + ErrorType: errType, + ErrorCode: "", + Message: errMsg, + Param: nil, + }, false) return } @@ -334,8 +340,14 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple completionTokens := vllmReq.GetMaxCompletionTokens() isValid, actualCompletionTokens, totalTokens := common.ValidateContextWindow(promptTokens, completionTokens, s.config.MaxModelLen) if !isValid { - s.sendCompletionError(ctx, fmt.Sprintf("This model's maximum context length is %d tokens. However, you requested %d tokens (%d in the messages, %d in the completion). Please reduce the length of the messages or completion", - s.config.MaxModelLen, totalTokens, promptTokens, actualCompletionTokens), "BadRequestError", fasthttp.StatusBadRequest) + s.sendCompletionError(ctx, FailureSpec{ + StatusCode: fasthttp.StatusBadRequest, + ErrorType: "BadRequestError", + ErrorCode: "", + Message: fmt.Sprintf("This model's maximum context length is %d tokens. However, you requested %d tokens (%d in the messages, %d in the completion). 
Please reduce the length of the messages or completion", + s.config.MaxModelLen, totalTokens, promptTokens, actualCompletionTokens), + Param: nil, + }, false) return } @@ -490,43 +502,53 @@ func (s *VllmSimulator) responseSentCallback(model string) { } // sendCompletionError sends an error response for the current completion request -func (s *VllmSimulator) sendCompletionError(ctx *fasthttp.RequestCtx, msg string, errType string, code int) { - compErr := openaiserverapi.CompletionError{ - Message: msg, - Type: errType, - Code: "", - Param: nil, - } - errorResp := openaiserverapi.ErrorResponse{Error: compErr} - s.logger.Error(nil, compErr.Message) - - data, err := json.Marshal(errorResp) - if err != nil { - ctx.Error(err.Error(), fasthttp.StatusInternalServerError) +// The first parameter can be either a string message or a FailureSpec +// isInjected indicates if this is an injected failure for logging purposes +func (s *VllmSimulator) sendCompletionError(ctx *fasthttp.RequestCtx, errorInfo interface{}, isInjected bool) { + var compErr openaiserverapi.CompletionError + var statusCode int + + switch v := errorInfo.(type) { + case string: + // Legacy call with string message (backward compatibility) + compErr = openaiserverapi.CompletionError{ + Message: v, + Type: "BadRequestError", + Code: "", + Param: nil, + } + statusCode = fasthttp.StatusBadRequest + case FailureSpec: + // New call with FailureSpec + compErr = openaiserverapi.CompletionError{ + Message: v.Message, + Type: v.ErrorType, + Code: v.ErrorCode, + Param: v.Param, + } + statusCode = v.StatusCode + default: + // For calls with msg, errType, and code - need to be updated in calling code + panic("sendCompletionError called with unexpected type") + } + + errorResp := openaiserverapi.ErrorResponse{ + Object: "error", + Error: compErr, + } + + if isInjected { + s.logger.Info("Injecting failure", "type", compErr.Type, "message", compErr.Message) } else { - ctx.SetContentType("application/json") - ctx.SetStatusCode(code) - ctx.SetBody(data) - } -} - -// sendFailureResponse sends a predefined failure response for testing -func (s *VllmSimulator) sendFailureResponse(ctx *fasthttp.RequestCtx, failure common.FailureSpec) { - compErr := openaiserverapi.CompletionError{ - Message: failure.Message, - Type: failure.ErrorType, - Code: failure.ErrorCode, - Param: failure.Param, + s.logger.Error(nil, compErr.Message) } - errorResp := openaiserverapi.ErrorResponse{Error: compErr} - s.logger.Info("Injecting failure", "type", failure.ErrorType, "message", failure.Message) data, err := json.Marshal(errorResp) if err != nil { ctx.Error(err.Error(), fasthttp.StatusInternalServerError) } else { ctx.SetContentType("application/json") - ctx.SetStatusCode(failure.StatusCode) + ctx.SetStatusCode(statusCode) ctx.SetBody(data) } } diff --git a/pkg/llm-d-inference-sim/simulator_test.go b/pkg/llm-d-inference-sim/simulator_test.go index 89e60e4e..467185fe 100644 --- a/pkg/llm-d-inference-sim/simulator_test.go +++ b/pkg/llm-d-inference-sim/simulator_test.go @@ -43,7 +43,8 @@ const invalidMaxTokensErrMsg = "Max completion tokens and max tokens should be p var userMsgTokens int64 func startServer(ctx context.Context, mode string) (*http.Client, error) { - return startServerWithArgs(ctx, mode, nil) + // Disable failure injection for tests by default + return startServerWithArgs(ctx, mode, []string{"cmd", "--model", model, "--mode", mode, "--failure-injection-rate", "0"}) } func startServerWithArgs(ctx context.Context, mode string, args []string) (*http.Client, error) 
{ @@ -55,7 +56,7 @@ func startServerWithArgs(ctx context.Context, mode string, args []string) (*http if args != nil { os.Args = args } else { - os.Args = []string{"cmd", "--model", model, "--mode", mode} + os.Args = []string{"cmd", "--model", model, "--mode", mode, "--failure-injection-rate", "0"} } logger := klog.Background() @@ -607,7 +608,6 @@ var _ = Describe("Simulator", func() { var err error client, err = startServerWithArgs(ctx, "failure", []string{ "cmd", "--model", model, - "--mode", "failure", "--failure-injection-rate", "100", }) Expect(err).ToNot(HaveOccurred()) @@ -666,7 +666,6 @@ var _ = Describe("Simulator", func() { var err error client, err = startServerWithArgs(ctx, "failure", []string{ "cmd", "--model", model, - "--mode", "failure", "--failure-injection-rate", "100", "--failure-types", "rate_limit", }) @@ -703,7 +702,6 @@ var _ = Describe("Simulator", func() { var err error client, err = startServerWithArgs(ctx, "failure", []string{ "cmd", "--model", model, - "--mode", "failure", "--failure-injection-rate", "100", "--failure-types", "invalid_api_key", "server_error", }) @@ -744,7 +742,6 @@ var _ = Describe("Simulator", func() { var err error client, err = startServerWithArgs(ctx, "failure", []string{ "cmd", "--model", model, - "--mode", "failure", "--failure-injection-rate", "0", }) Expect(err).ToNot(HaveOccurred()) @@ -776,7 +773,6 @@ var _ = Describe("Simulator", func() { ctx := context.Background() client, err := startServerWithArgs(ctx, "failure", []string{ "cmd", "--model", model, - "--mode", "failure", "--failure-injection-rate", "100", "--failure-types", failureType, }) @@ -828,7 +824,6 @@ var _ = Describe("Simulator", func() { ctx := context.Background() _, err := startServerWithArgs(ctx, "failure", []string{ "cmd", "--model", model, - "--mode", "failure", "--failure-injection-rate", "-10", }) Expect(err).To(HaveOccurred()) @@ -839,7 +834,6 @@ var _ = Describe("Simulator", func() { ctx := context.Background() _, err := startServerWithArgs(ctx, "failure", []string{ "cmd", "--model", model, - "--mode", "failure", "--failure-injection-rate", "50", "--failure-types", "invalid_type", }) diff --git a/pkg/openai-server-api/response.go b/pkg/openai-server-api/response.go index 9e8549b3..f816b06d 100644 --- a/pkg/openai-server-api/response.go +++ b/pkg/openai-server-api/response.go @@ -220,5 +220,6 @@ type CompletionError struct { // ErrorResponse wraps the error in the expected OpenAI format type ErrorResponse struct { - Error CompletionError `json:"error"` + Object string `json:"object"` + Error CompletionError `json:"error"` } From c35dbca70351ed802c08649b6dbe6a8f94a93829 Mon Sep 17 00:00:00 2001 From: Ira Rosen Date: Thu, 7 Aug 2025 11:01:16 +0300 Subject: [PATCH 10/46] KV cache and tokenization related configuration (#125) Signed-off-by: Ira Signed-off-by: Sergey Marunich --- README.md | 6 ---- pkg/common/config.go | 59 +++------------------------------------ pkg/common/config_test.go | 7 ----- pkg/kv-cache/kv_cache.go | 8 ++---- 4 files changed, 7 insertions(+), 73 deletions(-) diff --git a/README.md b/README.md index acf3a3bb..fb6636c3 100644 --- a/README.md +++ b/README.md @@ -33,10 +33,6 @@ The simulator supports two modes of operation: - `echo` mode: the response contains the same text that was received in the request. For `/v1/chat/completions` the last message for the role=`user` is used. - `random` mode: the response is randomly chosen from a set of pre-defined sentences. 
-Additionally, the simulator can inject OpenAI API compatible error responses for testing error handling using the `failure-injection-rate` parameter. - -Additionally, the simulator can inject OpenAI API compatible error responses for testing error handling using the `failure-injection-rate` parameter. - Timing of the response is defined by the `time-to-first-token` and `inter-token-latency` parameters. In case P/D is enabled for a request, `kv-cache-transfer-latency` will be used instead of `time-to-first-token`. For a request with `stream=true`: `time-to-first-token` or `kv-cache-transfer-latency` defines the delay before the first token is returned, `inter-token-latency` defines the delay between subsequent tokens in the stream. @@ -126,8 +122,6 @@ For more details see the 100 { - return errors.New("failure injection rate should be between 0 and 100") - } - - validFailureTypes := map[string]bool{ - FailureTypeRateLimit: true, - FailureTypeInvalidAPIKey: true, - FailureTypeContextLength: true, - FailureTypeServerError: true, - FailureTypeInvalidRequest: true, - FailureTypeModelNotFound: true, - } - for _, failureType := range c.FailureTypes { - if !validFailureTypes[failureType] { - return fmt.Errorf("invalid failure type '%s', valid types are: rate_limit, invalid_api_key, context_length, server_error, invalid_request, model_not_found", failureType) - } - } - return nil } @@ -360,7 +320,7 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) { f.IntVar(&config.MaxCPULoras, "max-cpu-loras", config.MaxCPULoras, "Maximum number of LoRAs to store in CPU memory") f.IntVar(&config.MaxModelLen, "max-model-len", config.MaxModelLen, "Model's context window, maximum number of tokens in a single request including input and output") - f.StringVar(&config.Mode, "mode", config.Mode, "Simulator mode: echo - returns the same text that was sent in the request, for chat completion returns the last message; random - returns random sentence from a bank of pre-defined sentences") + f.StringVar(&config.Mode, "mode", config.Mode, "Simulator mode, echo - returns the same text that was sent in the request, for chat completion returns the last message, random - returns random sentence from a bank of pre-defined sentences") f.IntVar(&config.InterTokenLatency, "inter-token-latency", config.InterTokenLatency, "Time to generate one token (in milliseconds)") f.IntVar(&config.TimeToFirstToken, "time-to-first-token", config.TimeToFirstToken, "Time to first token (in milliseconds)") f.IntVar(&config.KVCacheTransferLatency, "kv-cache-transfer-latency", config.KVCacheTransferLatency, "Time for KV-cache transfer from a remote vLLM (in milliseconds)") @@ -384,14 +344,6 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) { f.StringVar(&config.TokenizersCacheDir, "tokenizers-cache-dir", config.TokenizersCacheDir, "Directory for caching tokenizers") f.StringVar(&config.HashSeed, "hash-seed", config.HashSeed, "Seed for hash generation (if not set, is read from PYTHONHASHSEED environment variable)") f.StringVar(&config.ZMQEndpoint, "zmq-endpoint", config.ZMQEndpoint, "ZMQ address to publish events") - f.IntVar(&config.EventBatchSize, "event-batch-size", config.EventBatchSize, "Maximum number of kv-cache events to be sent together") - - f.IntVar(&config.FailureInjectionRate, "failure-injection-rate", config.FailureInjectionRate, "Probability (0-100) of injecting failures when in failure mode") - - failureTypes := getParamValueFromArgs("failure-types") - var dummyFailureTypes multiString - 
f.Var(&dummyFailureTypes, "failure-types", "List of specific failure types to inject (rate_limit, invalid_api_key, context_length, server_error, invalid_request, model_not_found)") - f.Lookup("failure-types").NoOptDefVal = "dummy" // These values were manually parsed above in getParamValueFromArgs, we leave this in order to get these flags in --help var dummyString string @@ -425,9 +377,6 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) { if servedModelNames != nil { config.ServedModelNames = servedModelNames } - if failureTypes != nil { - config.FailureTypes = failureTypes - } if config.HashSeed == "" { hashSeed := os.Getenv("PYTHONHASHSEED") diff --git a/pkg/common/config_test.go b/pkg/common/config_test.go index 6e768c27..2fd8446a 100644 --- a/pkg/common/config_test.go +++ b/pkg/common/config_test.go @@ -103,13 +103,11 @@ var _ = Describe("Simulator configuration", func() { "{\"name\":\"lora3\",\"path\":\"/path/to/lora3\"}", "{\"name\":\"lora4\",\"path\":\"/path/to/lora4\"}", } - c.EventBatchSize = 5 test = testCase{ name: "config file with command line args", args: []string{"cmd", "--model", model, "--config", "../../manifests/config.yaml", "--port", "8002", "--served-model-name", "alias1", "alias2", "--seed", "100", "--lora-modules", "{\"name\":\"lora3\",\"path\":\"/path/to/lora3\"}", "{\"name\":\"lora4\",\"path\":\"/path/to/lora4\"}", - "--event-batch-size", "5", }, expectedConfig: c, } @@ -293,11 +291,6 @@ var _ = Describe("Simulator configuration", func() { args: []string{"cmd", "--block-size", "35", "--config", "../../manifests/config.yaml"}, }, - { - name: "invalid (negative) event-batch-size", - args: []string{"cmd", "--event-batch-size", "-35", - "--config", "../../manifests/config.yaml"}, - }, } for _, test := range invalidTests { diff --git a/pkg/kv-cache/kv_cache.go b/pkg/kv-cache/kv_cache.go index dbbd7645..e68f750d 100644 --- a/pkg/kv-cache/kv_cache.go +++ b/pkg/kv-cache/kv_cache.go @@ -47,17 +47,15 @@ func NewKVCacheHelper(config *common.Configuration, logger logr.Logger) (*KVCach tokenizationConfig.TokenizersCacheDir = config.TokenizersCacheDir } tokenizer, err := tokenization.NewCachedHFTokenizer(tokenizationConfig.HFTokenizerConfig) + if err != nil { return nil, fmt.Errorf("failed to create tokenizer: %w", err) } - blockCache, err := newBlockCache(config, logger) - if err != nil { - return nil, fmt.Errorf("failed to create block cache: %w", err) - } + return &KVCacheHelper{ tokenizer: tokenizer, tokensProcessor: tokensProcessor, - blockCache: blockCache, + blockCache: newBlockCache(config.KVCacheSize, logger), logger: logger, }, nil } From 2eca8e6560e75bac39188d4495ff199f88380a85 Mon Sep 17 00:00:00 2001 From: Ira Rosen Date: Tue, 12 Aug 2025 10:05:24 +0300 Subject: [PATCH 11/46] Publish kv-cache events (#126) * Publish kv-cache events Signed-off-by: Ira * Fix lint errors Signed-off-by: Ira * Review fixes Signed-off-by: Ira * Sleep to allow prevous sub to close Signed-off-by: Ira --------- Signed-off-by: Ira Signed-off-by: Sergey Marunich --- Makefile | 8 +++----- go.sum | 8 ++++---- pkg/common/config.go | 7 +++++++ pkg/common/config_test.go | 7 +++++++ pkg/kv-cache/kv_cache.go | 8 +++++--- 5 files changed, 26 insertions(+), 12 deletions(-) diff --git a/Makefile b/Makefile index 60d105ad..0f761f54 100644 --- a/Makefile +++ b/Makefile @@ -39,16 +39,14 @@ help: ## Print help LDFLAGS ?= -extldflags '-L$(shell pwd)/lib' CGO_ENABLED=1 TOKENIZER_LIB = lib/libtokenizers.a -# Extract TOKENIZER_VERSION from Dockerfile -TOKENIZER_VERSION := $(shell grep '^ARG 
TOKENIZER_VERSION=' Dockerfile | cut -d'=' -f2) .PHONY: download-tokenizer download-tokenizer: $(TOKENIZER_LIB) $(TOKENIZER_LIB): ## Download the HuggingFace tokenizer bindings. - @echo "Downloading HuggingFace tokenizer bindings for version $(TOKENIZER_VERSION)..." + @echo "Downloading HuggingFace tokenizer bindings..." mkdir -p lib - curl -L https://github.com/daulet/tokenizers/releases/download/$(TOKENIZER_VERSION)/libtokenizers.$(TARGETOS)-$(TARGETARCH).tar.gz | tar -xz -C lib + curl -L https://github.com/daulet/tokenizers/releases/download/v1.22.1/libtokenizers.$(TARGETOS)-$(TARGETARCH).tar.gz | tar -xz -C lib ranlib lib/*.a ##@ Development @@ -226,4 +224,4 @@ download-zmq: ## Install ZMQ dependencies based on OS/ARCH exit 1; \ fi; \ echo "✅ ZMQ dependencies installed."; \ - fi \ No newline at end of file + fi diff --git a/go.sum b/go.sum index c34d530f..93f6363c 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,3 @@ -github.com/alicebob/miniredis/v2 v2.35.0 h1:QwLphYqCEAo1eu1TqPRN2jgVMPBweeQcR21jeqDCONI= -github.com/alicebob/miniredis/v2 v2.35.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM= github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA= github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= @@ -13,6 +11,8 @@ github.com/buaazp/fasthttprouter v0.1.1/go.mod h1:h/Ap5oRVLeItGKTVBb+heQPks+HdIU github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/daulet/tokenizers v1.20.2 h1:tlq/vIOiBTKDPets3596aFvmJYLn3XI6LFKq4q9LKhQ= +github.com/daulet/tokenizers v1.20.2/go.mod h1:tGnMdZthXdcWY6DGD07IygpwJqiPvG85FQUnhs/wSCs= github.com/daulet/tokenizers v1.22.1 h1:3wzAFIxfgRuqGKka8xdkeTbctDmmqOOs12GofqdorpM= github.com/daulet/tokenizers v1.22.1/go.mod h1:tGnMdZthXdcWY6DGD07IygpwJqiPvG85FQUnhs/wSCs= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -68,6 +68,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/llm-d/llm-d-kv-cache-manager v0.2.0 h1:7MXFPjy3P8nZ7HbB1LWhhVLHvNTLbZglkD/ZcT7UU1k= +github.com/llm-d/llm-d-kv-cache-manager v0.2.0/go.mod h1:ZTqwsnIVC6R5YuTUrYofPIUnCeZ9RvXn1UQAdxLYl1Y= github.com/llm-d/llm-d-kv-cache-manager v0.2.2-0.20250810103202-0adf0940f60a h1:PXR37HLgYYfolzWQA2uQOEiJlj3IV9YSvgaEFqCRSa8= github.com/llm-d/llm-d-kv-cache-manager v0.2.2-0.20250810103202-0adf0940f60a/go.mod h1:g2UlYKNJ4S860SAQ/QoRnytAFfnp8f1luW4IuZSMwCE= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= @@ -145,8 +147,6 @@ github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZ github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= -github.com/yuin/gopher-lua 
v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs= go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= diff --git a/pkg/common/config.go b/pkg/common/config.go index b613e621..181deb30 100644 --- a/pkg/common/config.go +++ b/pkg/common/config.go @@ -125,6 +125,8 @@ type Configuration struct { // ZMQEndpoint is the ZMQ address to publish events, the default value is tcp://localhost:5557 ZMQEndpoint string `yaml:"zmq-endpoint"` + // EventBatchSize is the maximum number of kv-cache events to be sent together, defaults to 16 + EventBatchSize int `yaml:"event-batch-size"` } type LoraModule struct { @@ -183,6 +185,7 @@ func newConfig() *Configuration { KVCacheSize: 1024, TokenBlockSize: 16, ZMQEndpoint: "tcp://localhost:5557", + EventBatchSize: 16, } } @@ -293,6 +296,9 @@ func (c *Configuration) validate() error { if c.KVCacheSize < 0 { return errors.New("KV cache size cannot be negative") } + if c.EventBatchSize < 1 { + return errors.New("event batch size cannot less than 1") + } return nil } @@ -344,6 +350,7 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) { f.StringVar(&config.TokenizersCacheDir, "tokenizers-cache-dir", config.TokenizersCacheDir, "Directory for caching tokenizers") f.StringVar(&config.HashSeed, "hash-seed", config.HashSeed, "Seed for hash generation (if not set, is read from PYTHONHASHSEED environment variable)") f.StringVar(&config.ZMQEndpoint, "zmq-endpoint", config.ZMQEndpoint, "ZMQ address to publish events") + f.IntVar(&config.EventBatchSize, "event-batch-size", config.EventBatchSize, "Maximum number of kv-cache events to be sent together") // These values were manually parsed above in getParamValueFromArgs, we leave this in order to get these flags in --help var dummyString string diff --git a/pkg/common/config_test.go b/pkg/common/config_test.go index 2fd8446a..6e768c27 100644 --- a/pkg/common/config_test.go +++ b/pkg/common/config_test.go @@ -103,11 +103,13 @@ var _ = Describe("Simulator configuration", func() { "{\"name\":\"lora3\",\"path\":\"/path/to/lora3\"}", "{\"name\":\"lora4\",\"path\":\"/path/to/lora4\"}", } + c.EventBatchSize = 5 test = testCase{ name: "config file with command line args", args: []string{"cmd", "--model", model, "--config", "../../manifests/config.yaml", "--port", "8002", "--served-model-name", "alias1", "alias2", "--seed", "100", "--lora-modules", "{\"name\":\"lora3\",\"path\":\"/path/to/lora3\"}", "{\"name\":\"lora4\",\"path\":\"/path/to/lora4\"}", + "--event-batch-size", "5", }, expectedConfig: c, } @@ -291,6 +293,11 @@ var _ = Describe("Simulator configuration", func() { args: []string{"cmd", "--block-size", "35", "--config", "../../manifests/config.yaml"}, }, + { + name: "invalid (negative) event-batch-size", + args: []string{"cmd", "--event-batch-size", "-35", + "--config", "../../manifests/config.yaml"}, + }, } for _, test := range invalidTests { diff --git a/pkg/kv-cache/kv_cache.go b/pkg/kv-cache/kv_cache.go index e68f750d..dbbd7645 100644 --- a/pkg/kv-cache/kv_cache.go +++ b/pkg/kv-cache/kv_cache.go @@ -47,15 +47,17 @@ func NewKVCacheHelper(config *common.Configuration, logger logr.Logger) (*KVCach tokenizationConfig.TokenizersCacheDir = config.TokenizersCacheDir } tokenizer, err := tokenization.NewCachedHFTokenizer(tokenizationConfig.HFTokenizerConfig) - if err != nil { return nil, fmt.Errorf("failed to create 
tokenizer: %w", err) } - + blockCache, err := newBlockCache(config, logger) + if err != nil { + return nil, fmt.Errorf("failed to create block cache: %w", err) + } return &KVCacheHelper{ tokenizer: tokenizer, tokensProcessor: tokensProcessor, - blockCache: newBlockCache(config.KVCacheSize, logger), + blockCache: blockCache, logger: logger, }, nil } From 28fb65b1f9ddf78851187d3826965884cd67499b Mon Sep 17 00:00:00 2001 From: Maya Barnea Date: Thu, 14 Aug 2025 14:16:32 +0300 Subject: [PATCH 12/46] Use same version of tokenizer in both Dockerfile and Makefile (#132) * - Use same version of tokenizer in both Dockerfile and Makefile - Fixes in readme file Signed-off-by: Maya Barnea * updates according PR's review Signed-off-by: Maya Barnea --------- Signed-off-by: Maya Barnea Signed-off-by: Sergey Marunich --- Dockerfile | 4 +++- Makefile | 6 ++++-- README.md | 6 ++++-- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/Dockerfile b/Dockerfile index 2af4a795..36a8836e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -23,7 +23,9 @@ COPY . . # HuggingFace tokenizer bindings RUN mkdir -p lib -RUN curl -L https://github.com/daulet/tokenizers/releases/download/v1.22.1/libtokenizers.${TARGETOS}-${TARGETARCH}.tar.gz | tar -xz -C lib +# Ensure that the TOKENIZER_VERSION matches the one used in the imported llm-d-kv-cache-manager version +ARG TOKENIZER_VERSION=v1.22.1 +RUN curl -L https://github.com/daulet/tokenizers/releases/download/${TOKENIZER_VERSION}/libtokenizers.${TARGETOS}-${TARGETARCH}.tar.gz | tar -xz -C lib RUN ranlib lib/*.a # Build diff --git a/Makefile b/Makefile index 0f761f54..819091f4 100644 --- a/Makefile +++ b/Makefile @@ -39,14 +39,16 @@ help: ## Print help LDFLAGS ?= -extldflags '-L$(shell pwd)/lib' CGO_ENABLED=1 TOKENIZER_LIB = lib/libtokenizers.a +# Extract TOKENIZER_VERSION from Dockerfile +TOKENIZER_VERSION := $(shell grep '^ARG TOKENIZER_VERSION=' Dockerfile | cut -d'=' -f2) .PHONY: download-tokenizer download-tokenizer: $(TOKENIZER_LIB) $(TOKENIZER_LIB): ## Download the HuggingFace tokenizer bindings. - @echo "Downloading HuggingFace tokenizer bindings..." + @echo "Downloading HuggingFace tokenizer bindings for version $(TOKENIZER_VERSION)..." 
mkdir -p lib - curl -L https://github.com/daulet/tokenizers/releases/download/v1.22.1/libtokenizers.$(TARGETOS)-$(TARGETARCH).tar.gz | tar -xz -C lib + curl -L https://github.com/daulet/tokenizers/releases/download/$(TOKENIZER_VERSION)/libtokenizers.$(TARGETOS)-$(TARGETARCH).tar.gz | tar -xz -C lib ranlib lib/*.a ##@ Development diff --git a/README.md b/README.md index fb6636c3..c40e7e28 100644 --- a/README.md +++ b/README.md @@ -116,13 +116,15 @@ For more details see the In addition, as we are using klog, the following parameters are available: diff --git a/pkg/common/config.go b/pkg/common/config.go index 181deb30..ebeaee6a 100644 --- a/pkg/common/config.go +++ b/pkg/common/config.go @@ -34,6 +34,14 @@ const ( vLLMDefaultPort = 8000 ModeRandom = "random" ModeEcho = "echo" + + // Failure type constants + FailureTypeRateLimit = "rate_limit" + FailureTypeInvalidAPIKey = "invalid_api_key" + FailureTypeContextLength = "context_length" + FailureTypeServerError = "server_error" + FailureTypeInvalidRequest = "invalid_request" + FailureTypeModelNotFound = "model_not_found" ) type Configuration struct { @@ -127,6 +135,11 @@ type Configuration struct { ZMQEndpoint string `yaml:"zmq-endpoint"` // EventBatchSize is the maximum number of kv-cache events to be sent together, defaults to 16 EventBatchSize int `yaml:"event-batch-size"` + + // FailureInjectionRate is the probability (0-100) of injecting failures + FailureInjectionRate int `yaml:"failure-injection-rate"` + // FailureTypes is a list of specific failure types to inject (empty means all types) + FailureTypes []string `yaml:"failure-types"` } type LoraModule struct { @@ -182,10 +195,12 @@ func newConfig() *Configuration { MinToolCallArrayParamLength: 1, ToolCallNotRequiredParamProbability: 50, ObjectToolCallNotRequiredParamProbability: 50, - KVCacheSize: 1024, - TokenBlockSize: 16, - ZMQEndpoint: "tcp://localhost:5557", - EventBatchSize: 16, + KVCacheSize: 1024, + TokenBlockSize: 16, + ZMQEndpoint: "tcp://localhost:5557", + EventBatchSize: 16, + FailureInjectionRate: 10, + FailureTypes: []string{}, } } @@ -299,6 +314,25 @@ func (c *Configuration) validate() error { if c.EventBatchSize < 1 { return errors.New("event batch size cannot less than 1") } + + if c.FailureInjectionRate < 0 || c.FailureInjectionRate > 100 { + return errors.New("failure injection rate should be between 0 and 100") + } + + validFailureTypes := map[string]bool{ + FailureTypeRateLimit: true, + FailureTypeInvalidAPIKey: true, + FailureTypeContextLength: true, + FailureTypeServerError: true, + FailureTypeInvalidRequest: true, + FailureTypeModelNotFound: true, + } + for _, failureType := range c.FailureTypes { + if !validFailureTypes[failureType] { + return fmt.Errorf("invalid failure type '%s', valid types are: rate_limit, invalid_api_key, context_length, server_error, invalid_request, model_not_found", failureType) + } + } + return nil } @@ -326,7 +360,7 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) { f.IntVar(&config.MaxCPULoras, "max-cpu-loras", config.MaxCPULoras, "Maximum number of LoRAs to store in CPU memory") f.IntVar(&config.MaxModelLen, "max-model-len", config.MaxModelLen, "Model's context window, maximum number of tokens in a single request including input and output") - f.StringVar(&config.Mode, "mode", config.Mode, "Simulator mode, echo - returns the same text that was sent in the request, for chat completion returns the last message, random - returns random sentence from a bank of pre-defined sentences") + f.StringVar(&config.Mode, "mode", 
config.Mode, "Simulator mode: echo - returns the same text that was sent in the request, for chat completion returns the last message; random - returns random sentence from a bank of pre-defined sentences") f.IntVar(&config.InterTokenLatency, "inter-token-latency", config.InterTokenLatency, "Time to generate one token (in milliseconds)") f.IntVar(&config.TimeToFirstToken, "time-to-first-token", config.TimeToFirstToken, "Time to first token (in milliseconds)") f.IntVar(&config.KVCacheTransferLatency, "kv-cache-transfer-latency", config.KVCacheTransferLatency, "Time for KV-cache transfer from a remote vLLM (in milliseconds)") @@ -351,6 +385,13 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) { f.StringVar(&config.HashSeed, "hash-seed", config.HashSeed, "Seed for hash generation (if not set, is read from PYTHONHASHSEED environment variable)") f.StringVar(&config.ZMQEndpoint, "zmq-endpoint", config.ZMQEndpoint, "ZMQ address to publish events") f.IntVar(&config.EventBatchSize, "event-batch-size", config.EventBatchSize, "Maximum number of kv-cache events to be sent together") + + f.IntVar(&config.FailureInjectionRate, "failure-injection-rate", config.FailureInjectionRate, "Probability (0-100) of injecting failures") + + failureTypes := getParamValueFromArgs("failure-types") + var dummyFailureTypes multiString + f.Var(&dummyFailureTypes, "failure-types", "List of specific failure types to inject (rate_limit, invalid_api_key, context_length, server_error, invalid_request, model_not_found)") + f.Lookup("failure-types").NoOptDefVal = "dummy" // These values were manually parsed above in getParamValueFromArgs, we leave this in order to get these flags in --help var dummyString string @@ -384,6 +425,16 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) { if servedModelNames != nil { config.ServedModelNames = servedModelNames } + if failureTypes != nil { + config.FailureTypes = failureTypes + } + + if config.HashSeed == "" { + hashSeed := os.Getenv("PYTHONHASHSEED") + if hashSeed != "" { + config.HashSeed = hashSeed + } + } if config.HashSeed == "" { hashSeed := os.Getenv("PYTHONHASHSEED") From f5ae85b987c1096932bef161e4614c65d3da3bcc Mon Sep 17 00:00:00 2001 From: Sergey Marunich Date: Thu, 14 Aug 2025 13:14:47 -0400 Subject: [PATCH 14/46] Set default failure injection rate to 0 Signed-off-by: Sergey Marunich --- README.md | 2 +- pkg/common/config.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 76de8f3d..4486b8e6 100644 --- a/README.md +++ b/README.md @@ -124,7 +124,7 @@ For more details see the Date: Thu, 14 Aug 2025 17:21:29 -0400 Subject: [PATCH 15/46] rebase duplicates Signed-off-by: Sergey Marunich --- pkg/common/config.go | 7 ------- 1 file changed, 7 deletions(-) diff --git a/pkg/common/config.go b/pkg/common/config.go index 4728f224..2d813228 100644 --- a/pkg/common/config.go +++ b/pkg/common/config.go @@ -436,13 +436,6 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) { } } - if config.HashSeed == "" { - hashSeed := os.Getenv("PYTHONHASHSEED") - if hashSeed != "" { - config.HashSeed = hashSeed - } - } - if err := config.validate(); err != nil { return nil, err } From 106e27619f25c2544c2d6ed2d50f140858d11b9b Mon Sep 17 00:00:00 2001 From: Ira Rosen Date: Thu, 7 Aug 2025 11:01:16 +0300 Subject: [PATCH 16/46] re-base the changes Signed-off-by: Sergey Marunich KV cache and tokenization related configuration (#125) Signed-off-by: Ira Publish kv-cache events (#126) * Publish kv-cache events 
Signed-off-by: Ira * Fix lint errors Signed-off-by: Ira * Review fixes Signed-off-by: Ira * Sleep to allow prevous sub to close Signed-off-by: Ira --------- Signed-off-by: Ira Signed-off-by: Sergey Marunich Use same version of tokenizer in both Dockerfile and Makefile (#132) * - Use same version of tokenizer in both Dockerfile and Makefile - Fixes in readme file Signed-off-by: Maya Barnea * updates according PR's review Signed-off-by: Maya Barnea --------- Signed-off-by: Maya Barnea Signed-off-by: Sergey Marunich Replaces usage of param.NewOpt with openai.Int for MaxTokens and openai.Bool with param.NewOpt for IncludeUsage in simulator_test.go to align with updated API usage. Signed-off-by: Sergey Marunich --- README.md | 5 +-- pkg/common/config.go | 54 +++-------------------- pkg/llm-d-inference-sim/simulator_test.go | 16 +++---- 3 files changed, 14 insertions(+), 61 deletions(-) diff --git a/README.md b/README.md index 4486b8e6..c40e7e28 100644 --- a/README.md +++ b/README.md @@ -33,8 +33,6 @@ The simulator supports two modes of operation: - `echo` mode: the response contains the same text that was received in the request. For `/v1/chat/completions` the last message for the role=`user` is used. - `random` mode: the response is randomly chosen from a set of pre-defined sentences. -Additionally, the simulator can inject OpenAI API compatible error responses for testing error handling using the `failure-injection-rate` parameter. - Timing of the response is defined by the `time-to-first-token` and `inter-token-latency` parameters. In case P/D is enabled for a request, `kv-cache-transfer-latency` will be used instead of `time-to-first-token`. For a request with `stream=true`: `time-to-first-token` or `kv-cache-transfer-latency` defines the delay before the first token is returned, `inter-token-latency` defines the delay between subsequent tokens in the stream. 
@@ -118,14 +116,13 @@ For more details see the 100 { - return errors.New("failure injection rate should be between 0 and 100") - } - - validFailureTypes := map[string]bool{ - FailureTypeRateLimit: true, - FailureTypeInvalidAPIKey: true, - FailureTypeContextLength: true, - FailureTypeServerError: true, - FailureTypeInvalidRequest: true, - FailureTypeModelNotFound: true, - } - for _, failureType := range c.FailureTypes { - if !validFailureTypes[failureType] { - return fmt.Errorf("invalid failure type '%s', valid types are: rate_limit, invalid_api_key, context_length, server_error, invalid_request, model_not_found", failureType) - } - } - return nil } @@ -360,7 +326,7 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) { f.IntVar(&config.MaxCPULoras, "max-cpu-loras", config.MaxCPULoras, "Maximum number of LoRAs to store in CPU memory") f.IntVar(&config.MaxModelLen, "max-model-len", config.MaxModelLen, "Model's context window, maximum number of tokens in a single request including input and output") - f.StringVar(&config.Mode, "mode", config.Mode, "Simulator mode: echo - returns the same text that was sent in the request, for chat completion returns the last message; random - returns random sentence from a bank of pre-defined sentences") + f.StringVar(&config.Mode, "mode", config.Mode, "Simulator mode, echo - returns the same text that was sent in the request, for chat completion returns the last message, random - returns random sentence from a bank of pre-defined sentences") f.IntVar(&config.InterTokenLatency, "inter-token-latency", config.InterTokenLatency, "Time to generate one token (in milliseconds)") f.IntVar(&config.TimeToFirstToken, "time-to-first-token", config.TimeToFirstToken, "Time to first token (in milliseconds)") f.IntVar(&config.KVCacheTransferLatency, "kv-cache-transfer-latency", config.KVCacheTransferLatency, "Time for KV-cache transfer from a remote vLLM (in milliseconds)") @@ -385,13 +351,6 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) { f.StringVar(&config.HashSeed, "hash-seed", config.HashSeed, "Seed for hash generation (if not set, is read from PYTHONHASHSEED environment variable)") f.StringVar(&config.ZMQEndpoint, "zmq-endpoint", config.ZMQEndpoint, "ZMQ address to publish events") f.IntVar(&config.EventBatchSize, "event-batch-size", config.EventBatchSize, "Maximum number of kv-cache events to be sent together") - - f.IntVar(&config.FailureInjectionRate, "failure-injection-rate", config.FailureInjectionRate, "Probability (0-100) of injecting failures") - - failureTypes := getParamValueFromArgs("failure-types") - var dummyFailureTypes multiString - f.Var(&dummyFailureTypes, "failure-types", "List of specific failure types to inject (rate_limit, invalid_api_key, context_length, server_error, invalid_request, model_not_found)") - f.Lookup("failure-types").NoOptDefVal = "dummy" // These values were manually parsed above in getParamValueFromArgs, we leave this in order to get these flags in --help var dummyString string @@ -425,9 +384,6 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) { if servedModelNames != nil { config.ServedModelNames = servedModelNames } - if failureTypes != nil { - config.FailureTypes = failureTypes - } if config.HashSeed == "" { hashSeed := os.Getenv("PYTHONHASHSEED") diff --git a/pkg/llm-d-inference-sim/simulator_test.go b/pkg/llm-d-inference-sim/simulator_test.go index 467185fe..c6629354 100644 --- a/pkg/llm-d-inference-sim/simulator_test.go +++ b/pkg/llm-d-inference-sim/simulator_test.go @@ -31,6 
+31,7 @@ import ( . "github.com/onsi/gomega" "github.com/openai/openai-go" "github.com/openai/openai-go/option" + "github.com/openai/openai-go/packages/param" "github.com/valyala/fasthttp/fasthttputil" "k8s.io/klog/v2" ) @@ -43,8 +44,7 @@ const invalidMaxTokensErrMsg = "Max completion tokens and max tokens should be p var userMsgTokens int64 func startServer(ctx context.Context, mode string) (*http.Client, error) { - // Disable failure injection for tests by default - return startServerWithArgs(ctx, mode, []string{"cmd", "--model", model, "--mode", mode, "--failure-injection-rate", "0"}) + return startServerWithArgs(ctx, mode, nil) } func startServerWithArgs(ctx context.Context, mode string, args []string) (*http.Client, error) { @@ -56,7 +56,7 @@ func startServerWithArgs(ctx context.Context, mode string, args []string) (*http if args != nil { os.Args = args } else { - os.Args = []string{"cmd", "--model", model, "--mode", mode, "--failure-injection-rate", "0"} + os.Args = []string{"cmd", "--model", model, "--mode", mode} } logger := klog.Background() @@ -120,7 +120,7 @@ var _ = Describe("Simulator", func() { openai.UserMessage(userMessage), }, Model: model, - StreamOptions: openai.ChatCompletionStreamOptionsParam{IncludeUsage: openai.Bool(true)}, + StreamOptions: openai.ChatCompletionStreamOptionsParam{IncludeUsage: param.NewOpt(true)}, } stream := openaiclient.Chat.Completions.NewStreaming(ctx, params) defer func() { @@ -183,7 +183,7 @@ var _ = Describe("Simulator", func() { OfString: openai.String(userMessage), }, Model: openai.CompletionNewParamsModel(model), - StreamOptions: openai.ChatCompletionStreamOptionsParam{IncludeUsage: openai.Bool(true)}, + StreamOptions: openai.ChatCompletionStreamOptionsParam{IncludeUsage: param.NewOpt(true)}, } stream := openaiclient.Completions.NewStreaming(ctx, params) defer func() { @@ -246,11 +246,11 @@ var _ = Describe("Simulator", func() { // if maxTokens and maxCompletionTokens are passsed // maxCompletionTokens is used if maxTokens != 0 { - params.MaxTokens = openai.Int(int64(maxTokens)) + params.MaxTokens = param.NewOpt(int64(maxTokens)) numTokens = maxTokens } if maxCompletionTokens != 0 { - params.MaxCompletionTokens = openai.Int(int64(maxCompletionTokens)) + params.MaxCompletionTokens = param.NewOpt(int64(maxCompletionTokens)) numTokens = maxCompletionTokens } resp, err := openaiclient.Chat.Completions.New(ctx, params) @@ -329,7 +329,7 @@ var _ = Describe("Simulator", func() { } numTokens := 0 if maxTokens != 0 { - params.MaxTokens = openai.Int(int64(maxTokens)) + params.MaxTokens = param.NewOpt(int64(maxTokens)) numTokens = maxTokens } resp, err := openaiclient.Completions.New(ctx, params) From 516222626eb14c2682159571f8bb15a61a24c3f2 Mon Sep 17 00:00:00 2001 From: Sergey Marunich Date: Thu, 14 Aug 2025 17:48:41 -0400 Subject: [PATCH 17/46] Update option constructors in simulator tests Replaces usage of param.NewOpt with openai.Int for MaxTokens and openai.Bool with param.NewOpt for IncludeUsage in simulator_test.go to align with updated API usage. 
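For reference, a minimal sketch (not part of the patch) of the two option-constructor spellings this series switches between, assuming the openai-go client and `param` package already imported by `simulator_test.go`; the model name and token count below are placeholders, not values from the series:

```go
package main

import (
	"github.com/openai/openai-go"
	"github.com/openai/openai-go/packages/param"
)

// buildParams shows the generic param.NewOpt wrapper next to the typed
// helpers (openai.Int / openai.Bool) used elsewhere in these tests.
func buildParams() openai.ChatCompletionNewParams {
	return openai.ChatCompletionNewParams{
		Messages: []openai.ChatCompletionMessageParamUnion{
			openai.UserMessage("hello"),
		},
		Model: "placeholder-model",
		// Typed-helper form: MaxTokens: openai.Int(int64(64))
		MaxTokens: param.NewOpt(int64(64)),
		// Typed-helper form: IncludeUsage: openai.Bool(true)
		StreamOptions: openai.ChatCompletionStreamOptionsParam{IncludeUsage: param.NewOpt(true)},
	}
}

func main() { _ = buildParams() }
```

Both spellings produce the same optional value types, so the choice is a matter of aligning with the client API version in use.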
--- pkg/common/config.go | 63 +++++++++++++++++++++++++++++++++++++++----- 1 file changed, 57 insertions(+), 6 deletions(-) diff --git a/pkg/common/config.go b/pkg/common/config.go index 181deb30..22a87dde 100644 --- a/pkg/common/config.go +++ b/pkg/common/config.go @@ -34,6 +34,14 @@ const ( vLLMDefaultPort = 8000 ModeRandom = "random" ModeEcho = "echo" + + // Failure type constants + FailureTypeRateLimit = "rate_limit" + FailureTypeInvalidAPIKey = "invalid_api_key" + FailureTypeContextLength = "context_length" + FailureTypeServerError = "server_error" + FailureTypeInvalidRequest = "invalid_request" + FailureTypeModelNotFound = "model_not_found" ) type Configuration struct { @@ -127,6 +135,11 @@ type Configuration struct { ZMQEndpoint string `yaml:"zmq-endpoint"` // EventBatchSize is the maximum number of kv-cache events to be sent together, defaults to 16 EventBatchSize int `yaml:"event-batch-size"` + + // FailureInjectionRate is the probability (0-100) of injecting failures + FailureInjectionRate int `yaml:"failure-injection-rate"` + // FailureTypes is a list of specific failure types to inject (empty means all types) + FailureTypes []string `yaml:"failure-types"` } type LoraModule struct { @@ -182,10 +195,12 @@ func newConfig() *Configuration { MinToolCallArrayParamLength: 1, ToolCallNotRequiredParamProbability: 50, ObjectToolCallNotRequiredParamProbability: 50, - KVCacheSize: 1024, - TokenBlockSize: 16, - ZMQEndpoint: "tcp://localhost:5557", - EventBatchSize: 16, + KVCacheSize: 1024, + TokenBlockSize: 16, + ZMQEndpoint: "tcp://localhost:5557", + EventBatchSize: 16, + FailureInjectionRate: 10, + FailureTypes: []string{}, } } @@ -299,6 +314,25 @@ func (c *Configuration) validate() error { if c.EventBatchSize < 1 { return errors.New("event batch size cannot less than 1") } + + if c.FailureInjectionRate < 0 || c.FailureInjectionRate > 100 { + return errors.New("failure injection rate should be between 0 and 100") + } + + validFailureTypes := map[string]bool{ + FailureTypeRateLimit: true, + FailureTypeInvalidAPIKey: true, + FailureTypeContextLength: true, + FailureTypeServerError: true, + FailureTypeInvalidRequest: true, + FailureTypeModelNotFound: true, + } + for _, failureType := range c.FailureTypes { + if !validFailureTypes[failureType] { + return fmt.Errorf("invalid failure type '%s', valid types are: rate_limit, invalid_api_key, context_length, server_error, invalid_request, model_not_found", failureType) + } + } + return nil } @@ -326,7 +360,7 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) { f.IntVar(&config.MaxCPULoras, "max-cpu-loras", config.MaxCPULoras, "Maximum number of LoRAs to store in CPU memory") f.IntVar(&config.MaxModelLen, "max-model-len", config.MaxModelLen, "Model's context window, maximum number of tokens in a single request including input and output") - f.StringVar(&config.Mode, "mode", config.Mode, "Simulator mode, echo - returns the same text that was sent in the request, for chat completion returns the last message, random - returns random sentence from a bank of pre-defined sentences") + f.StringVar(&config.Mode, "mode", config.Mode, "Simulator mode: echo - returns the same text that was sent in the request, for chat completion returns the last message; random - returns random sentence from a bank of pre-defined sentences") f.IntVar(&config.InterTokenLatency, "inter-token-latency", config.InterTokenLatency, "Time to generate one token (in milliseconds)") f.IntVar(&config.TimeToFirstToken, "time-to-first-token", config.TimeToFirstToken, "Time 
to first token (in milliseconds)") f.IntVar(&config.KVCacheTransferLatency, "kv-cache-transfer-latency", config.KVCacheTransferLatency, "Time for KV-cache transfer from a remote vLLM (in milliseconds)") @@ -351,6 +385,13 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) { f.StringVar(&config.HashSeed, "hash-seed", config.HashSeed, "Seed for hash generation (if not set, is read from PYTHONHASHSEED environment variable)") f.StringVar(&config.ZMQEndpoint, "zmq-endpoint", config.ZMQEndpoint, "ZMQ address to publish events") f.IntVar(&config.EventBatchSize, "event-batch-size", config.EventBatchSize, "Maximum number of kv-cache events to be sent together") + + f.IntVar(&config.FailureInjectionRate, "failure-injection-rate", config.FailureInjectionRate, "Probability (0-100) of injecting failures") + + failureTypes := getParamValueFromArgs("failure-types") + var dummyFailureTypes multiString + f.Var(&dummyFailureTypes, "failure-types", "List of specific failure types to inject (rate_limit, invalid_api_key, context_length, server_error, invalid_request, model_not_found)") + f.Lookup("failure-types").NoOptDefVal = "dummy" // These values were manually parsed above in getParamValueFromArgs, we leave this in order to get these flags in --help var dummyString string @@ -384,6 +425,16 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) { if servedModelNames != nil { config.ServedModelNames = servedModelNames } + if failureTypes != nil { + config.FailureTypes = failureTypes + } + + if config.HashSeed == "" { + hashSeed := os.Getenv("PYTHONHASHSEED") + if hashSeed != "" { + config.HashSeed = hashSeed + } + } if config.HashSeed == "" { hashSeed := os.Getenv("PYTHONHASHSEED") @@ -422,4 +473,4 @@ func getParamValueFromArgs(param string) []string { } } return values -} +} \ No newline at end of file From 5182187d6d97e1fd779bb5f95831b9ea2f42f3ba Mon Sep 17 00:00:00 2001 From: Sergey Marunich Date: Thu, 14 Aug 2025 17:56:09 -0400 Subject: [PATCH 18/46] Document failure injection options in README Added descriptions for `failure-injection-rate` and `failure-types` configuration options to clarify their usage and defaults. 
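As a usage illustration only (not part of this patch), a small Go sketch that prints a command line exercising the documented options; the binary path and model name are assumptions, while the flag names, value list, and defaults come from the README text added below:

```go
package main

import (
	"fmt"
	"strings"
)

func main() {
	// Flags documented in this change: failure-injection-rate (0-100, defaults to 0)
	// and failure-types (subset of rate_limit, invalid_api_key, context_length,
	// server_error, invalid_request, model_not_found; empty means all types).
	cmd := []string{
		"./bin/llm-d-inference-sim", // assumed local build output path
		"--model", "my_model",       // placeholder model name
		"--failure-injection-rate", "100",
		"--failure-types", "rate_limit", "server_error",
	}
	fmt.Println(strings.Join(cmd, " "))
}
```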
--- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index c40e7e28..05dad5b6 100644 --- a/README.md +++ b/README.md @@ -124,6 +124,8 @@ For more details see the Date: Sun, 17 Aug 2025 12:39:22 +0300 Subject: [PATCH 20/46] use newer version of kvcache-manager, update code accordingly (#133) Signed-off-by: Maya Barnea --- go.mod | 2 +- go.sum | 12 ++++++------ pkg/kv-cache/kv_cache_sender.go | 19 ++++--------------- pkg/kv-cache/kv_cache_test.go | 13 ++++++++----- 4 files changed, 19 insertions(+), 27 deletions(-) diff --git a/go.mod b/go.mod index 65cbb3fb..30ba1417 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( github.com/buaazp/fasthttprouter v0.1.1 github.com/go-logr/logr v1.4.2 github.com/google/uuid v1.6.0 - github.com/llm-d/llm-d-kv-cache-manager v0.2.2-0.20250810103202-0adf0940f60a + github.com/llm-d/llm-d-kv-cache-manager v0.2.2-0.20250814115305-d5a8ca882318 github.com/onsi/ginkgo/v2 v2.23.4 github.com/onsi/gomega v1.37.0 github.com/openai/openai-go v0.1.0-beta.10 diff --git a/go.sum b/go.sum index 93f6363c..378898fd 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/alicebob/miniredis/v2 v2.35.0 h1:QwLphYqCEAo1eu1TqPRN2jgVMPBweeQcR21jeqDCONI= +github.com/alicebob/miniredis/v2 v2.35.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM= github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA= github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= @@ -11,8 +13,6 @@ github.com/buaazp/fasthttprouter v0.1.1/go.mod h1:h/Ap5oRVLeItGKTVBb+heQPks+HdIU github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/daulet/tokenizers v1.20.2 h1:tlq/vIOiBTKDPets3596aFvmJYLn3XI6LFKq4q9LKhQ= -github.com/daulet/tokenizers v1.20.2/go.mod h1:tGnMdZthXdcWY6DGD07IygpwJqiPvG85FQUnhs/wSCs= github.com/daulet/tokenizers v1.22.1 h1:3wzAFIxfgRuqGKka8xdkeTbctDmmqOOs12GofqdorpM= github.com/daulet/tokenizers v1.22.1/go.mod h1:tGnMdZthXdcWY6DGD07IygpwJqiPvG85FQUnhs/wSCs= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -68,10 +68,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= -github.com/llm-d/llm-d-kv-cache-manager v0.2.0 h1:7MXFPjy3P8nZ7HbB1LWhhVLHvNTLbZglkD/ZcT7UU1k= -github.com/llm-d/llm-d-kv-cache-manager v0.2.0/go.mod h1:ZTqwsnIVC6R5YuTUrYofPIUnCeZ9RvXn1UQAdxLYl1Y= -github.com/llm-d/llm-d-kv-cache-manager v0.2.2-0.20250810103202-0adf0940f60a h1:PXR37HLgYYfolzWQA2uQOEiJlj3IV9YSvgaEFqCRSa8= -github.com/llm-d/llm-d-kv-cache-manager v0.2.2-0.20250810103202-0adf0940f60a/go.mod h1:g2UlYKNJ4S860SAQ/QoRnytAFfnp8f1luW4IuZSMwCE= +github.com/llm-d/llm-d-kv-cache-manager v0.2.2-0.20250814115305-d5a8ca882318 h1:4V1tDOzD0EzatsdOjJnEt7+dJDQPTozfUU4g29dCrTY= +github.com/llm-d/llm-d-kv-cache-manager v0.2.2-0.20250814115305-d5a8ca882318/go.mod h1:g2UlYKNJ4S860SAQ/QoRnytAFfnp8f1luW4IuZSMwCE= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= 
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -147,6 +145,8 @@ github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZ github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= +github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs= go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= diff --git a/pkg/kv-cache/kv_cache_sender.go b/pkg/kv-cache/kv_cache_sender.go index f8af3638..941f8422 100644 --- a/pkg/kv-cache/kv_cache_sender.go +++ b/pkg/kv-cache/kv_cache_sender.go @@ -16,7 +16,6 @@ limitations under the License. package kvcache import ( - "bytes" "context" "fmt" "time" @@ -90,24 +89,14 @@ func (s *KVEventSender) Run(ctx context.Context) error { } // Encode eventData's hash value to msgpack.RawMessage + var payload []byte var err error - var payload bytes.Buffer - enc := msgpack.NewEncoder(&payload) - enc.UseArrayEncodedStructs(true) switch eventData.action { case eventActionStore: - bs := &kvevents.BlockStoredEvent{ - TypeField: BlockStored, - BlockStored: &kvevents.BlockStored{BlockHashes: eventData.hashValues}, - } - err = enc.Encode(bs) + payload, err = msgpack.Marshal(kvevents.BlockStored{BlockHashes: eventData.hashValues}.ToTaggedUnion()) case eventActionRemove: - br := &kvevents.BlockRemovedEvent{ - TypeField: BlockRemoved, - BlockRemoved: &kvevents.BlockRemoved{BlockHashes: eventData.hashValues}, - } - err = enc.Encode(br) + payload, err = msgpack.Marshal(kvevents.BlockRemoved{BlockHashes: eventData.hashValues}.ToTaggedUnion()) default: return fmt.Errorf("invalid event action %d", eventData.action) } @@ -115,7 +104,7 @@ func (s *KVEventSender) Run(ctx context.Context) error { return fmt.Errorf("failed to marshal value: %w", err) } - s.batch = append(s.batch, payload.Bytes()) + s.batch = append(s.batch, payload) // check if batch is big enough to be sent if len(s.batch) >= s.maxBatchSize { diff --git a/pkg/kv-cache/kv_cache_test.go b/pkg/kv-cache/kv_cache_test.go index cc259c5b..4bea7c45 100644 --- a/pkg/kv-cache/kv_cache_test.go +++ b/pkg/kv-cache/kv_cache_test.go @@ -496,22 +496,25 @@ func parseEvent(parts [][]byte, expectedTopic string, expectedSeq uint64) ([]uin Expect(err).NotTo(HaveOccurred()) for _, rawEvent := range eventBatch.Events { var taggedUnion []msgpack.RawMessage - err = msgpack.Unmarshal(rawEvent, &taggedUnion) + err := msgpack.Unmarshal(rawEvent, &taggedUnion) Expect(err).NotTo(HaveOccurred()) Expect(len(taggedUnion)).To(BeNumerically(">", 1)) + payloadBytes, err := msgpack.Marshal(taggedUnion[1:]) + Expect(err).NotTo(HaveOccurred()) + var tag string err = msgpack.Unmarshal(taggedUnion[0], &tag) Expect(err).NotTo(HaveOccurred()) switch tag { case BlockStored: - var bs kvevents.BlockStoredEvent - err = msgpack.Unmarshal(rawEvent, &bs) + var bs kvevents.BlockStored + err = msgpack.Unmarshal(payloadBytes, &bs) stored = append(stored, bs.BlockHashes...) 
case BlockRemoved: - var br kvevents.BlockRemovedEvent - err = msgpack.Unmarshal(rawEvent, &br) + var br kvevents.BlockRemoved + err = msgpack.Unmarshal(payloadBytes, &br) removed = append(removed, br.BlockHashes...) default: From 471fac0edff6c743752a23189a7760d8b06dbd11 Mon Sep 17 00:00:00 2001 From: Nina Polshakova Date: Mon, 18 Aug 2025 02:28:10 -0400 Subject: [PATCH 21/46] Add support to echo the sim's pod name and namespace (#128) * add pod name and ns headers Signed-off-by: npolshakova * add pod name and ns env Signed-off-by: npolshakova * Signed-off-by: npolshakova feedback Signed-off-by: npolshakova * reuse env var Signed-off-by: npolshakova * feedback Signed-off-by: npolshakova * add unset env tests Signed-off-by: npolshakova --------- Signed-off-by: npolshakova --- Makefile | 6 +- README.md | 9 + manifests/deployment.yaml | 11 ++ pkg/llm-d-inference-sim/lora_test.go | 2 +- pkg/llm-d-inference-sim/seed_test.go | 2 +- pkg/llm-d-inference-sim/simulator.go | 20 +- pkg/llm-d-inference-sim/simulator_test.go | 230 +++++++++++++++++++++- pkg/llm-d-inference-sim/streaming.go | 8 + pkg/llm-d-inference-sim/tools_test.go | 6 +- 9 files changed, 280 insertions(+), 14 deletions(-) diff --git a/Makefile b/Makefile index 819091f4..2e0ceae7 100644 --- a/Makefile +++ b/Makefile @@ -36,7 +36,7 @@ SRC = $(shell find . -type f -name '*.go') help: ## Print help @awk 'BEGIN {FS = ":.*##"; printf "\nUsage:\n make \033[36m\033[0m\n"} /^[a-zA-Z_0-9-]+:.*?##/ { printf " \033[36m%-15s\033[0m %s\n", $$1, $$2 } /^##@/ { printf "\n\033[1m%s\033[0m\n", substr($$0, 5) } ' $(MAKEFILE_LIST) -LDFLAGS ?= -extldflags '-L$(shell pwd)/lib' +GO_LDFLAGS := -extldflags '-L$(shell pwd)/lib $(LDFLAGS)' CGO_ENABLED=1 TOKENIZER_LIB = lib/libtokenizers.a # Extract TOKENIZER_VERSION from Dockerfile @@ -67,7 +67,7 @@ format: ## Format Go source files .PHONY: test test: check-ginkgo download-tokenizer download-zmq ## Run tests @printf "\033[33;1m==== Running tests ====\033[0m\n" - CGO_ENABLED=1 ginkgo -ldflags="$(LDFLAGS)" -v -r + CGO_ENABLED=1 ginkgo -ldflags="$(GO_LDFLAGS)" -v -r .PHONY: post-deploy-test post-deploy-test: ## Run post deployment tests @@ -84,7 +84,7 @@ lint: check-golangci-lint ## Run lint .PHONY: build build: check-go download-tokenizer download-zmq @printf "\033[33;1m==== Building ====\033[0m\n" - go build -ldflags="$(LDFLAGS)" -o bin/$(PROJECT_NAME) cmd/$(PROJECT_NAME)/main.go + go build -ldflags="$(GO_LDFLAGS)" -o bin/$(PROJECT_NAME) cmd/$(PROJECT_NAME)/main.go ##@ Container Build/Push diff --git a/README.md b/README.md index c40e7e28..a7e0d5cc 100644 --- a/README.md +++ b/README.md @@ -156,6 +156,8 @@ make image-build Please note that the default image tag is `ghcr.io/llm-d/llm-d-inference-sim:dev`.
The following environment variables can be used to change the image tag: `REGISTRY`, `SIM_TAG`, `IMAGE_TAG_BASE` or `IMG`. +Note: On macOS, use `make image-build TARGETOS=linux` to pull the correct base image. + ### Running To run the vLLM Simulator image under Docker, run: ```bash @@ -186,6 +188,13 @@ To run the vLLM simulator in a Kubernetes cluster, run: kubectl apply -f manifests/deployment.yaml ``` +When testing locally with kind, build the docker image with `make build-image` then load into the cluster: +```shell +kind load --name kind docker-image ghcr.io/llm-d/llm-d-inference-sim:dev +``` + +Update the `deployment.yaml` file to use the dev tag. + To verify the deployment is available, run: ```bash kubectl get deployment vllm-llama3-8b-instruct diff --git a/manifests/deployment.yaml b/manifests/deployment.yaml index 75c001aa..aa23f3d5 100644 --- a/manifests/deployment.yaml +++ b/manifests/deployment.yaml @@ -25,6 +25,17 @@ spec: image: ghcr.io/llm-d/llm-d-inference-sim:latest imagePullPolicy: IfNotPresent name: vllm-sim + env: + - name: POD_NAME + valueFrom: + fieldRef: + apiVersion: v1 + fieldPath: metadata.name + - name: POD_NAMESPACE + valueFrom: + fieldRef: + apiVersion: v1 + fieldPath: metadata.namespace ports: - containerPort: 8000 name: http diff --git a/pkg/llm-d-inference-sim/lora_test.go b/pkg/llm-d-inference-sim/lora_test.go index 682b8411..7ec37d0d 100644 --- a/pkg/llm-d-inference-sim/lora_test.go +++ b/pkg/llm-d-inference-sim/lora_test.go @@ -37,7 +37,7 @@ var _ = Describe("LoRAs", func() { client, err := startServerWithArgs(ctx, "", []string{"cmd", "--model", model, "--mode", common.ModeEcho, "--lora-modules", "{\"name\":\"lora3\",\"path\":\"/path/to/lora3\"}", - "{\"name\":\"lora4\",\"path\":\"/path/to/lora4\"}"}) + "{\"name\":\"lora4\",\"path\":\"/path/to/lora4\"}"}, nil) Expect(err).NotTo(HaveOccurred()) openaiclient := openai.NewClient( diff --git a/pkg/llm-d-inference-sim/seed_test.go b/pkg/llm-d-inference-sim/seed_test.go index 190ae7bf..505b4938 100644 --- a/pkg/llm-d-inference-sim/seed_test.go +++ b/pkg/llm-d-inference-sim/seed_test.go @@ -33,7 +33,7 @@ var _ = Describe("Simulator with seed", func() { func() { ctx := context.TODO() client, err := startServerWithArgs(ctx, common.ModeRandom, - []string{"cmd", "--model", model, "--mode", common.ModeRandom, "--seed", "100"}) + []string{"cmd", "--model", model, "--mode", common.ModeRandom, "--seed", "100"}, nil) Expect(err).NotTo(HaveOccurred()) openaiclient := openai.NewClient( diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index 77b0738a..b40d444d 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ -22,6 +22,7 @@ import ( "encoding/json" "fmt" "net" + "os" "strings" "sync" "sync/atomic" @@ -46,6 +47,11 @@ const ( textCompletionObject = "text_completion" chatCompletionObject = "chat.completion" chatCompletionChunkObject = "chat.completion.chunk" + + podHeader = "x-inference-pod" + namespaceHeader = "x-inference-namespace" + podNameEnv = "POD_NAME" + podNsEnv = "POD_NAMESPACE" ) // VllmSimulator simulates vLLM server supporting OpenAI API @@ -79,6 +85,10 @@ type VllmSimulator struct { toolsValidator *openaiserverapi.Validator // kv cache functionality kvcacheHelper *kvcache.KVCacheHelper + // namespace where simulator is running + namespace string + // pod name of simulator + pod string } // New creates a new VllmSimulator instance with the given logger @@ -93,6 +103,8 @@ func New(logger logr.Logger) (*VllmSimulator, error) { 
reqChan: make(chan *openaiserverapi.CompletionReqCtx, 1000), toolsValidator: toolsValidtor, kvcacheHelper: nil, // kvcache helper will be created only if required after reading configuration + namespace: os.Getenv(podNsEnv), + pod: os.Getenv(podNameEnv), }, nil } @@ -599,9 +611,15 @@ func (s *VllmSimulator) sendResponse(isChatCompletion bool, ctx *fasthttp.Reques totalMillisToWait := s.getTimeToFirstToken(doRemotePrefill) + s.getTotalInterTokenLatency(numOfTokens) time.Sleep(time.Duration(totalMillisToWait) * time.Millisecond) - // TODO - maybe add pod id to response header for testing ctx.Response.Header.SetContentType("application/json") ctx.Response.Header.SetStatusCode(fasthttp.StatusOK) + // Add pod and namespace information to response headers for testing/debugging + if s.pod != "" { + ctx.Response.Header.Add(podHeader, s.pod) + } + if s.namespace != "" { + ctx.Response.Header.Add(namespaceHeader, s.namespace) + } ctx.Response.SetBody(data) s.responseSentCallback(modelName) diff --git a/pkg/llm-d-inference-sim/simulator_test.go b/pkg/llm-d-inference-sim/simulator_test.go index fb8c0e8f..931e3718 100644 --- a/pkg/llm-d-inference-sim/simulator_test.go +++ b/pkg/llm-d-inference-sim/simulator_test.go @@ -44,10 +44,10 @@ const invalidMaxTokensErrMsg = "Max completion tokens and max tokens should be p var userMsgTokens int64 func startServer(ctx context.Context, mode string) (*http.Client, error) { - return startServerWithArgs(ctx, mode, nil) + return startServerWithArgs(ctx, mode, nil, nil) } -func startServerWithArgs(ctx context.Context, mode string, args []string) (*http.Client, error) { +func startServerWithArgs(ctx context.Context, mode string, args []string, envs map[string]string) (*http.Client, error) { oldArgs := os.Args defer func() { os.Args = oldArgs @@ -58,6 +58,21 @@ func startServerWithArgs(ctx context.Context, mode string, args []string) (*http } else { os.Args = []string{"cmd", "--model", model, "--mode", mode} } + + if envs != nil { + for k, v := range envs { + err := os.Setenv(k, v) + Expect(err).NotTo(HaveOccurred()) + } + + defer func() { + for k := range envs { + err := os.Unsetenv(k) + Expect(err).NotTo(HaveOccurred()) + } + }() + } + logger := klog.Background() s, err := New(logger) @@ -402,12 +417,217 @@ var _ = Describe("Simulator", func() { Expect(resp.StatusCode).To(Equal(http.StatusOK)) }) + It("Should not include namespace and pod headers in chat completion response when env is not set", func() { + ctx := context.TODO() + + client, err := startServerWithArgs(ctx, common.ModeRandom, nil, nil) + Expect(err).NotTo(HaveOccurred()) + + openaiclient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client)) + + params := openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + Model: model, + } + + var httpResp *http.Response + resp, err := openaiclient.Chat.Completions.New(ctx, params, option.WithResponseInto(&httpResp)) + Expect(err).NotTo(HaveOccurred()) + Expect(resp).NotTo(BeNil()) + + // Check for namespace and pod headers + namespaceHeader := httpResp.Header.Get(namespaceHeader) + podHeader := httpResp.Header.Get(podHeader) + + Expect(namespaceHeader).To(BeEmpty(), "Expected namespace header not to be present") + Expect(podHeader).To(BeEmpty(), "Expected pod header not to be present") + }) + + It("Should include namespace and pod headers in chat completion response", func() { + ctx := context.TODO() + + testNamespace := "test-namespace" + testPod := "test-pod" + 
envs := map[string]string{ + podNameEnv: testPod, + podNsEnv: testNamespace, + } + client, err := startServerWithArgs(ctx, common.ModeRandom, nil, envs) + Expect(err).NotTo(HaveOccurred()) + + openaiclient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client)) + + params := openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + Model: model, + } + + var httpResp *http.Response + resp, err := openaiclient.Chat.Completions.New(ctx, params, option.WithResponseInto(&httpResp)) + Expect(err).NotTo(HaveOccurred()) + Expect(resp).NotTo(BeNil()) + + // Check for namespace and pod headers + namespaceHeader := httpResp.Header.Get(namespaceHeader) + podHeader := httpResp.Header.Get(podHeader) + + Expect(namespaceHeader).To(Equal(testNamespace), "Expected namespace header to be present") + Expect(podHeader).To(Equal(testPod), "Expected pod header to be present") + }) + + It("Should include namespace and pod headers in chat completion streaming response", func() { + ctx := context.TODO() + + testNamespace := "stream-test-namespace" + testPod := "stream-test-pod" + envs := map[string]string{ + podNameEnv: testPod, + podNsEnv: testNamespace, + } + client, err := startServerWithArgs(ctx, common.ModeRandom, nil, envs) + Expect(err).NotTo(HaveOccurred()) + + openaiclient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client)) + + params := openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + Model: model, + StreamOptions: openai.ChatCompletionStreamOptionsParam{IncludeUsage: param.NewOpt(true)}, + } + + var httpResp *http.Response + resp, err := openaiclient.Chat.Completions.New(ctx, params, option.WithResponseInto(&httpResp)) + Expect(err).NotTo(HaveOccurred()) + Expect(resp).NotTo(BeNil()) + + // Check for namespace and pod headers + namespaceHeader := httpResp.Header.Get(namespaceHeader) + podHeader := httpResp.Header.Get(podHeader) + + Expect(namespaceHeader).To(Equal(testNamespace), "Expected namespace header to be present") + Expect(podHeader).To(Equal(testPod), "Expected pod header to be present") + }) + + It("Should not include namespace and pod headers in chat completion streaming response when env is not set", func() { + ctx := context.TODO() + + client, err := startServerWithArgs(ctx, common.ModeRandom, nil, nil) + Expect(err).NotTo(HaveOccurred()) + + openaiclient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client)) + + params := openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + Model: model, + StreamOptions: openai.ChatCompletionStreamOptionsParam{IncludeUsage: param.NewOpt(true)}, + } + + var httpResp *http.Response + resp, err := openaiclient.Chat.Completions.New(ctx, params, option.WithResponseInto(&httpResp)) + Expect(err).NotTo(HaveOccurred()) + Expect(resp).NotTo(BeNil()) + + // Check for namespace and pod headers + namespaceHeader := httpResp.Header.Get(namespaceHeader) + podHeader := httpResp.Header.Get(podHeader) + + Expect(namespaceHeader).To(BeEmpty(), "Expected namespace header not to be present") + Expect(podHeader).To(BeEmpty(), "Expected pod header not to be present") + }) + + It("Should include namespace and pod headers in completion response", func() { + ctx := context.TODO() + + testNamespace := "test-namespace" + testPod := "test-pod" + envs := map[string]string{ + 
podNameEnv: testPod, + podNsEnv: testNamespace, + } + client, err := startServerWithArgs(ctx, common.ModeRandom, nil, envs) + Expect(err).NotTo(HaveOccurred()) + + openaiclient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client)) + + params := openai.CompletionNewParams{ + Prompt: openai.CompletionNewParamsPromptUnion{ + OfString: openai.String(userMessage), + }, + Model: openai.CompletionNewParamsModel(model), + } + var httpResp *http.Response + resp, err := openaiclient.Completions.New(ctx, params, option.WithResponseInto(&httpResp)) + Expect(err).NotTo(HaveOccurred()) + Expect(resp).NotTo(BeNil()) + + // Check for namespace and pod headers + namespaceHeader := httpResp.Header.Get(namespaceHeader) + podHeader := httpResp.Header.Get(podHeader) + + Expect(namespaceHeader).To(Equal(testNamespace), "Expected namespace header to be present") + Expect(podHeader).To(Equal(testPod), "Expected pod header to be present") + }) + + It("Should include namespace and pod headers in completion streaming response", func() { + ctx := context.TODO() + + testNamespace := "stream-test-namespace" + testPod := "stream-test-pod" + envs := map[string]string{ + podNameEnv: testPod, + podNsEnv: testNamespace, + } + client, err := startServerWithArgs(ctx, common.ModeRandom, nil, envs) + Expect(err).NotTo(HaveOccurred()) + + openaiclient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client)) + + params := openai.CompletionNewParams{ + Prompt: openai.CompletionNewParamsPromptUnion{ + OfString: openai.String(userMessage), + }, + Model: openai.CompletionNewParamsModel(model), + StreamOptions: openai.ChatCompletionStreamOptionsParam{IncludeUsage: param.NewOpt(true)}, + } + var httpResp *http.Response + resp, err := openaiclient.Completions.New(ctx, params, option.WithResponseInto(&httpResp)) + Expect(err).NotTo(HaveOccurred()) + Expect(resp).NotTo(BeNil()) + + // Check for namespace and pod headers + namespaceHeader := httpResp.Header.Get(namespaceHeader) + podHeader := httpResp.Header.Get(podHeader) + + Expect(namespaceHeader).To(Equal(testNamespace), "Expected namespace header to be present") + Expect(podHeader).To(Equal(testPod), "Expected pod header to be present") + }) + Context("max-model-len context window validation", func() { It("Should reject requests exceeding context window", func() { ctx := context.TODO() // Start server with max-model-len=10 args := []string{"cmd", "--model", model, "--mode", common.ModeRandom, "--max-model-len", "10"} - client, err := startServerWithArgs(ctx, common.ModeRandom, args) + client, err := startServerWithArgs(ctx, common.ModeRandom, args, nil) Expect(err).NotTo(HaveOccurred()) // Test with raw HTTP to verify the error response format @@ -457,7 +677,7 @@ var _ = Describe("Simulator", func() { ctx := context.TODO() // Start server with max-model-len=50 args := []string{"cmd", "--model", model, "--mode", common.ModeEcho, "--max-model-len", "50"} - client, err := startServerWithArgs(ctx, common.ModeEcho, args) + client, err := startServerWithArgs(ctx, common.ModeEcho, args, nil) Expect(err).NotTo(HaveOccurred()) openaiclient := openai.NewClient( @@ -483,7 +703,7 @@ var _ = Describe("Simulator", func() { ctx := context.TODO() // Start server with max-model-len=10 args := []string{"cmd", "--model", model, "--mode", common.ModeRandom, "--max-model-len", "10"} - client, err := startServerWithArgs(ctx, common.ModeRandom, args) + client, err := startServerWithArgs(ctx, common.ModeRandom, args, nil) 
Expect(err).NotTo(HaveOccurred()) // Test with raw HTTP for text completion diff --git a/pkg/llm-d-inference-sim/streaming.go b/pkg/llm-d-inference-sim/streaming.go index b173924c..c9789204 100644 --- a/pkg/llm-d-inference-sim/streaming.go +++ b/pkg/llm-d-inference-sim/streaming.go @@ -45,6 +45,14 @@ func (s *VllmSimulator) sendStreamingResponse(context *streamingContext, respons context.ctx.SetContentType("text/event-stream") context.ctx.SetStatusCode(fasthttp.StatusOK) + // Add pod and namespace information to response headers for testing/debugging + if s.pod != "" { + context.ctx.Response.Header.Add(podHeader, s.pod) + } + if s.namespace != "" { + context.ctx.Response.Header.Add(namespaceHeader, s.namespace) + } + context.ctx.SetBodyStreamWriter(func(w *bufio.Writer) { context.creationTime = time.Now().Unix() diff --git a/pkg/llm-d-inference-sim/tools_test.go b/pkg/llm-d-inference-sim/tools_test.go index b4e1d009..c996db59 100644 --- a/pkg/llm-d-inference-sim/tools_test.go +++ b/pkg/llm-d-inference-sim/tools_test.go @@ -525,7 +525,7 @@ var _ = Describe("Simulator for request with tools", func() { "--min-tool-call-number-param", fmt.Sprint(min), "--max-tool-call-number-param", fmt.Sprint(max), } - client, err := startServerWithArgs(ctx, common.ModeEcho, serverArgs) + client, err := startServerWithArgs(ctx, common.ModeEcho, serverArgs, nil) Expect(err).NotTo(HaveOccurred()) openaiclient := openai.NewClient( @@ -787,7 +787,7 @@ var _ = Describe("Simulator for request with tools", func() { serverArgs := []string{"cmd", "--model", model, "--mode", common.ModeRandom, "--tool-call-not-required-param-probability", strconv.Itoa(probability), } - client, err := startServerWithArgs(ctx, common.ModeEcho, serverArgs) + client, err := startServerWithArgs(ctx, common.ModeEcho, serverArgs, nil) Expect(err).NotTo(HaveOccurred()) openaiclient := openai.NewClient( @@ -832,7 +832,7 @@ var _ = Describe("Simulator for request with tools", func() { "--min-tool-call-integer-param", strconv.Itoa(min), "--max-tool-call-integer-param", strconv.Itoa(max), } - client, err := startServerWithArgs(ctx, common.ModeEcho, serverArgs) + client, err := startServerWithArgs(ctx, common.ModeEcho, serverArgs, nil) Expect(err).NotTo(HaveOccurred()) openaiclient := openai.NewClient( From a080a176a76c99e25332f93193fa73cb3a99835d Mon Sep 17 00:00:00 2001 From: Ira Rosen Date: Tue, 19 Aug 2025 09:46:33 +0300 Subject: [PATCH 22/46] Create UUID string under a lock (#143) Signed-off-by: Ira --- pkg/common/utils.go | 7 +++++++ pkg/llm-d-inference-sim/simulator.go | 5 ++--- pkg/llm-d-inference-sim/streaming.go | 7 +++---- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/pkg/common/utils.go b/pkg/common/utils.go index 309dc9d5..2cb4ad66 100644 --- a/pkg/common/utils.go +++ b/pkg/common/utils.go @@ -244,6 +244,13 @@ func RandomNorm(mean float64, stddev float64) float64 { return value } +// GenerateUUIDString generates a UUID string under a lock +func GenerateUUIDString() string { + randMutex.Lock() + defer randMutex.Unlock() + return uuid.NewString() +} + // Regular expression for the response tokenization var re *regexp.Regexp diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index b40d444d..da5f53e1 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ -30,7 +30,6 @@ import ( "github.com/buaazp/fasthttprouter" "github.com/go-logr/logr" - "github.com/google/uuid" "github.com/prometheus/client_golang/prometheus" 
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/valyala/fasthttp" @@ -200,7 +199,7 @@ func (s *VllmSimulator) Printf(format string, args ...interface{}) { // readRequest reads and parses data from the body of the given request according the type defined by isChatCompletion func (s *VllmSimulator) readRequest(ctx *fasthttp.RequestCtx, isChatCompletion bool) (openaiserverapi.CompletionRequest, error) { - requestID := uuid.NewString() + requestID := common.GenerateUUIDString() if isChatCompletion { var req openaiserverapi.ChatCompletionRequest @@ -546,7 +545,7 @@ func (s *VllmSimulator) HandleError(_ *fasthttp.RequestCtx, err error) { func (s *VllmSimulator) createCompletionResponse(isChatCompletion bool, respTokens []string, toolCalls []openaiserverapi.ToolCall, finishReason *string, usageData *openaiserverapi.Usage, modelName string, doRemoteDecode bool) openaiserverapi.CompletionResponse { baseResp := openaiserverapi.BaseCompletionResponse{ - ID: chatComplIDPrefix + uuid.NewString(), + ID: chatComplIDPrefix + common.GenerateUUIDString(), Created: time.Now().Unix(), Model: modelName, Usage: usageData, diff --git a/pkg/llm-d-inference-sim/streaming.go b/pkg/llm-d-inference-sim/streaming.go index c9789204..969f29af 100644 --- a/pkg/llm-d-inference-sim/streaming.go +++ b/pkg/llm-d-inference-sim/streaming.go @@ -22,7 +22,6 @@ import ( "fmt" "time" - "github.com/google/uuid" "github.com/llm-d/llm-d-inference-sim/pkg/common" openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api" "github.com/valyala/fasthttp" @@ -154,7 +153,7 @@ func (s *VllmSimulator) sendTokenChunks(context *streamingContext, w *bufio.Writ // supports both modes (text and chat) func (s *VllmSimulator) createUsageChunk(context *streamingContext, usageData *openaiserverapi.Usage) openaiserverapi.CompletionRespChunk { baseChunk := openaiserverapi.BaseCompletionResponse{ - ID: chatComplIDPrefix + uuid.NewString(), + ID: chatComplIDPrefix + common.GenerateUUIDString(), Created: context.creationTime, Model: context.model, Usage: usageData, @@ -179,7 +178,7 @@ func (s *VllmSimulator) createUsageChunk(context *streamingContext, usageData *o func (s *VllmSimulator) createTextCompletionChunk(context *streamingContext, token string, finishReason *string) openaiserverapi.CompletionRespChunk { return &openaiserverapi.TextCompletionResponse{ BaseCompletionResponse: openaiserverapi.BaseCompletionResponse{ - ID: chatComplIDPrefix + uuid.NewString(), + ID: chatComplIDPrefix + common.GenerateUUIDString(), Created: context.creationTime, Model: context.model, Object: textCompletionObject, @@ -199,7 +198,7 @@ func (s *VllmSimulator) createChatCompletionChunk(context *streamingContext, tok role string, finishReason *string) openaiserverapi.CompletionRespChunk { chunk := openaiserverapi.ChatCompletionRespChunk{ BaseCompletionResponse: openaiserverapi.BaseCompletionResponse{ - ID: chatComplIDPrefix + uuid.NewString(), + ID: chatComplIDPrefix + common.GenerateUUIDString(), Created: context.creationTime, Model: context.model, Object: chatCompletionChunkObject, From 430992528240034801f2b0e965fd4feca38515f0 Mon Sep 17 00:00:00 2001 From: Ira Rosen Date: Tue, 19 Aug 2025 14:13:17 +0300 Subject: [PATCH 23/46] Support fake metrics (#144) * Support fake metrics Signed-off-by: Ira * Readme Signed-off-by: Ira * Removed commented out code Signed-off-by: Ira --------- Signed-off-by: Ira --- README.md | 10 +++ manifests/config_with_fake.yaml | 16 +++++ pkg/common/config.go | 78 ++++++++++++++++++++- 
pkg/common/config_test.go | 85 ++++++++++++++++++++++- pkg/llm-d-inference-sim/metrics.go | 47 +++++++++---- pkg/llm-d-inference-sim/simulator_test.go | 42 ++++++++++- 6 files changed, 257 insertions(+), 21 deletions(-) create mode 100644 manifests/config_with_fake.yaml diff --git a/README.md b/README.md index a7e0d5cc..3874db1a 100644 --- a/README.md +++ b/README.md @@ -125,6 +125,16 @@ For more details see the
 1 { + return errors.New("fake metrics KV cache usage must be between 0 and 1") + } } return nil } @@ -316,6 +380,7 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) { servedModelNames := getParamValueFromArgs("served-model-name") loraModuleNames := getParamValueFromArgs("lora-modules") + fakeMetrics := getParamValueFromArgs("fake-metrics") f := pflag.NewFlagSet("llm-d-inference-sim flags", pflag.ContinueOnError) @@ -358,9 +423,11 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) { var dummyMultiString multiString f.Var(&dummyMultiString, "served-model-name", "Model names exposed by the API (a list of space-separated strings)") f.Var(&dummyMultiString, "lora-modules", "List of LoRA adapters (a list of space-separated JSON strings)") + f.Var(&dummyMultiString, "fake-metrics", "A set of metrics to send to Prometheus instead of the real data") // In order to allow empty arguments, we set a dummy NoOptDefVal for these flags - f.Lookup("served-model-name").NoOptDefVal = "dummy" - f.Lookup("lora-modules").NoOptDefVal = "dummy" + f.Lookup("served-model-name").NoOptDefVal = dummy + f.Lookup("lora-modules").NoOptDefVal = dummy + f.Lookup("fake-metrics").NoOptDefVal = dummy flagSet := flag.NewFlagSet("simFlagSet", flag.ExitOnError) klog.InitFlags(flagSet) @@ -381,6 +448,11 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) { return nil, err } } + if fakeMetrics != nil { + if err := config.unmarshalFakeMetrics(fakeMetrics[0]); err != nil { + return nil, err + } + } if servedModelNames != nil { config.ServedModelNames = servedModelNames } diff --git a/pkg/common/config_test.go b/pkg/common/config_test.go index 6e768c27..d4ec9677 100644 --- a/pkg/common/config_test.go +++ b/pkg/common/config_test.go @@ -51,7 +51,6 @@ func createDefaultConfig(model string) *Configuration { c.KVCacheTransferLatency = 100 c.Seed = 100100100 c.LoraModules = []LoraModule{} - return c } @@ -173,7 +172,7 @@ var _ = Describe("Simulator configuration", func() { } tests = append(tests, test) - // Config from config.yaml file plus command line args with time to copy cache + // Config from basic-config.yaml file plus command line args with time to copy cache c = createDefaultConfig(qwenModelName) c.Port = 8001 // basic config file does not contain properties related to lora c.MaxCPULoras = 1 c.KVCacheTransferLatency = 50 test = testCase{ - name: "config file with command line args with time to transfer kv-cache", + name: "basic config file with command line args with time to transfer kv-cache", args: []string{"cmd", "--config", "../../manifests/basic-config.yaml", "--kv-cache-transfer-latency", "50"}, expectedConfig: c, } tests = append(tests, test) + // Config from config_with_fake.yaml file + c = createDefaultConfig(qwenModelName) + c.FakeMetrics = &Metrics{ + RunningRequests: 16, + WaitingRequests: 3, + KVCacheUsagePercentage: float32(0.3), + LoraMetrics: []LorasMetrics{ + {RunningLoras: "lora1,lora2", WaitingLoras: "lora3", Timestamp: 1257894567}, + {RunningLoras: "lora1,lora3", WaitingLoras: "", Timestamp: 1257894569}, + }, + LorasString: []string{ + "{\"running\":\"lora1,lora2\",\"waiting\":\"lora3\",\"timestamp\":1257894567}", + "{\"running\":\"lora1,lora3\",\"waiting\":\"\",\"timestamp\":1257894569}", + }, + } + test = testCase{ + name: "config with fake metrics file", + args: []string{"cmd", "--config", "../../manifests/config_with_fake.yaml"}, + expectedConfig: c, + } + tests = append(tests, test) +
+ // Fake metrics from command line + c = newConfig() + c.Model = model + c.ServedModelNames = []string{c.Model} + c.MaxCPULoras = 1 + c.Seed = 100 + c.FakeMetrics = &Metrics{ + RunningRequests: 10, + WaitingRequests: 30, + KVCacheUsagePercentage: float32(0.4), + LoraMetrics: []LorasMetrics{ + {RunningLoras: "lora4,lora2", WaitingLoras: "lora3", Timestamp: 1257894567}, + {RunningLoras: "lora4,lora3", WaitingLoras: "", Timestamp: 1257894569}, + }, + LorasString: nil, + } + test = testCase{ + name: "metrics from command line", + args: []string{"cmd", "--model", model, "--seed", "100", + "--fake-metrics", + "{\"running-requests\":10,\"waiting-requests\":30,\"kv-cache-usage\":0.4,\"loras\":[{\"running\":\"lora4,lora2\",\"waiting\":\"lora3\",\"timestamp\":1257894567},{\"running\":\"lora4,lora3\",\"waiting\":\"\",\"timestamp\":1257894569}]}", + }, + expectedConfig: c, + } + tests = append(tests, test) + + // Fake metrics from both the config file and command line + c = createDefaultConfig(qwenModelName) + c.FakeMetrics = &Metrics{ + RunningRequests: 10, + WaitingRequests: 30, + KVCacheUsagePercentage: float32(0.4), + LoraMetrics: []LorasMetrics{ + {RunningLoras: "lora4,lora2", WaitingLoras: "lora3", Timestamp: 1257894567}, + {RunningLoras: "lora4,lora3", WaitingLoras: "", Timestamp: 1257894569}, + }, + LorasString: nil, + } + test = testCase{ + name: "metrics from config file and command line", + args: []string{"cmd", "--config", "../../manifests/config_with_fake.yaml", + "--fake-metrics", + "{\"running-requests\":10,\"waiting-requests\":30,\"kv-cache-usage\":0.4,\"loras\":[{\"running\":\"lora4,lora2\",\"waiting\":\"lora3\",\"timestamp\":1257894567},{\"running\":\"lora4,lora3\",\"waiting\":\"\",\"timestamp\":1257894569}]}", + }, + expectedConfig: c, + } + tests = append(tests, test) + for _, test := range tests { When(test.name, func() { It("should create correct configuration", func() { @@ -298,6 +367,16 @@ var _ = Describe("Simulator configuration", func() { args: []string{"cmd", "--event-batch-size", "-35", "--config", "../../manifests/config.yaml"}, }, + { + name: "invalid fake metrics: negative running requests", + args: []string{"cmd", "--fake-metrics", "{\"running-requests\":-10,\"waiting-requests\":30,\"kv-cache-usage\":0.4}", + "--config", "../../manifests/config.yaml"}, + }, + { + name: "invalid fake metrics: kv cache usage", + args: []string{"cmd", "--fake-metrics", "{\"running-requests\":10,\"waiting-requests\":30,\"kv-cache-usage\":40}", + "--config", "../../manifests/config.yaml"}, + }, } for _, test := range invalidTests { diff --git a/pkg/llm-d-inference-sim/metrics.go b/pkg/llm-d-inference-sim/metrics.go index c869ecd8..3fa60a2e 100644 --- a/pkg/llm-d-inference-sim/metrics.go +++ b/pkg/llm-d-inference-sim/metrics.go @@ -96,25 +96,40 @@ func (s *VllmSimulator) createAndRegisterPrometheus() error { return nil } -// setInitialPrometheusMetrics send default values to prometheus +// setInitialPrometheusMetrics sends the default values to prometheus or +// the fake metrics if set func (s *VllmSimulator) setInitialPrometheusMetrics() { + var nRunningReqs, nWaitingReqs, kvCacheUsage float64 + if s.config.FakeMetrics != nil { + nRunningReqs = float64(s.config.FakeMetrics.RunningRequests) + nWaitingReqs = float64(s.config.FakeMetrics.WaitingRequests) + kvCacheUsage = float64(s.config.FakeMetrics.KVCacheUsagePercentage) + } modelName := s.getDisplayedModelName(s.config.Model) - s.loraInfo.WithLabelValues( - strconv.Itoa(s.config.MaxLoras), - "", - "").Set(float64(time.Now().Unix())) - - 
s.nRunningReqs = 0 - s.runningRequests.WithLabelValues( - modelName).Set(float64(s.nRunningReqs)) - s.waitingRequests.WithLabelValues( - modelName).Set(float64(0)) - s.kvCacheUsagePercentage.WithLabelValues( - modelName).Set(float64(0)) + s.runningRequests.WithLabelValues(modelName).Set(nRunningReqs) + s.waitingRequests.WithLabelValues(modelName).Set(nWaitingReqs) + s.kvCacheUsagePercentage.WithLabelValues(modelName).Set(kvCacheUsage) + + if s.config.FakeMetrics != nil && len(s.config.FakeMetrics.LoraMetrics) != 0 { + for _, metrics := range s.config.FakeMetrics.LoraMetrics { + s.loraInfo.WithLabelValues( + strconv.Itoa(s.config.MaxLoras), + metrics.RunningLoras, + metrics.WaitingLoras).Set(metrics.Timestamp) + } + } else { + s.loraInfo.WithLabelValues( + strconv.Itoa(s.config.MaxLoras), + "", + "").Set(float64(time.Now().Unix())) + } } // reportLoras sets information about loaded LoRA adapters func (s *VllmSimulator) reportLoras() { + if s.config.FakeMetrics != nil { + return + } if s.loraInfo == nil { // Happens in the tests return @@ -138,6 +153,9 @@ func (s *VllmSimulator) reportLoras() { // reportRunningRequests sets information about running completion requests func (s *VllmSimulator) reportRunningRequests() { + if s.config.FakeMetrics != nil { + return + } if s.runningRequests != nil { nRunningReqs := atomic.LoadInt64(&(s.nRunningReqs)) s.runningRequests.WithLabelValues( @@ -147,6 +165,9 @@ func (s *VllmSimulator) reportRunningRequests() { // reportWaitingRequests sets information about waiting completion requests func (s *VllmSimulator) reportWaitingRequests() { + if s.config.FakeMetrics != nil { + return + } if s.waitingRequests != nil { nWaitingReqs := atomic.LoadInt64(&(s.nWaitingReqs)) s.waitingRequests.WithLabelValues( diff --git a/pkg/llm-d-inference-sim/simulator_test.go b/pkg/llm-d-inference-sim/simulator_test.go index 931e3718..b27d97a2 100644 --- a/pkg/llm-d-inference-sim/simulator_test.go +++ b/pkg/llm-d-inference-sim/simulator_test.go @@ -48,6 +48,10 @@ func startServer(ctx context.Context, mode string) (*http.Client, error) { } func startServerWithArgs(ctx context.Context, mode string, args []string, envs map[string]string) (*http.Client, error) { + return startServerWithArgsAndMetrics(ctx, mode, args, envs, false) +} + +func startServerWithArgsAndMetrics(ctx context.Context, mode string, args []string, envs map[string]string, setMetrics bool) (*http.Client, error) { oldArgs := os.Args defer func() { os.Args = oldArgs @@ -91,6 +95,13 @@ func startServerWithArgs(ctx context.Context, mode string, args []string, envs m common.InitRandom(s.config.Seed) + if setMetrics { + err = s.createAndRegisterPrometheus() + if err != nil { + return nil, err + } + } + // calculate number of tokens for user message, // must be activated after parseCommandParamsAndLoadConfig since it initializes the random engine userMsgTokens = int64(len(common.Tokenize(userMessage))) @@ -420,7 +431,7 @@ var _ = Describe("Simulator", func() { It("Should not include namespace and pod headers in chat completion response when env is not set", func() { ctx := context.TODO() - client, err := startServerWithArgs(ctx, common.ModeRandom, nil, nil) + client, err := startServer(ctx, common.ModeRandom) Expect(err).NotTo(HaveOccurred()) openaiclient := openai.NewClient( @@ -523,7 +534,7 @@ var _ = Describe("Simulator", func() { It("Should not include namespace and pod headers in chat completion streaming response when env is not set", func() { ctx := context.TODO() - client, err := startServerWithArgs(ctx, 
common.ModeRandom, nil, nil) + client, err := startServer(ctx, common.ModeRandom) Expect(err).NotTo(HaveOccurred()) openaiclient := openai.NewClient( @@ -809,4 +820,31 @@ var _ = Describe("Simulator", func() { Entry(nil, 10000, 0, 1000, 0, false), ) }) + + Context("fake metrics", func() { + It("Should respond with fake metrics to /metrics", func() { + ctx := context.TODO() + args := []string{"cmd", "--model", model, "--mode", common.ModeRandom, + "--fake-metrics", + "{\"running-requests\":10,\"waiting-requests\":30,\"kv-cache-usage\":0.4,\"loras\":[{\"running\":\"lora4,lora2\",\"waiting\":\"lora3\",\"timestamp\":1257894567},{\"running\":\"lora4,lora3\",\"waiting\":\"\",\"timestamp\":1257894569}]}", + } + + client, err := startServerWithArgsAndMetrics(ctx, common.ModeRandom, args, nil, true) + Expect(err).NotTo(HaveOccurred()) + + resp, err := client.Get("http://localhost/metrics") + Expect(err).NotTo(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(http.StatusOK)) + + data, err := io.ReadAll(resp.Body) + Expect(err).NotTo(HaveOccurred()) + metrics := string(data) + Expect(metrics).To(ContainSubstring("vllm:num_requests_running{model_name=\"my_model\"} 10")) + Expect(metrics).To(ContainSubstring("vllm:num_requests_waiting{model_name=\"my_model\"} 30")) + Expect(metrics).To(ContainSubstring("vllm:gpu_cache_usage_perc{model_name=\"my_model\"} 0.4")) + Expect(metrics).To(ContainSubstring("vllm:lora_requests_info{max_lora=\"1\",running_lora_adapters=\"lora4,lora2\",waiting_lora_adapters=\"lora3\"} 1.257894567e+09")) + Expect(metrics).To(ContainSubstring("vllm:lora_requests_info{max_lora=\"1\",running_lora_adapters=\"lora4,lora3\",waiting_lora_adapters=\"\"} 1.257894569e+09")) + + }) + }) }) From efa82a551f578a98236838223d3e7bf98abdaadc Mon Sep 17 00:00:00 2001 From: Shmuel Kallner Date: Tue, 19 Aug 2025 14:20:41 +0300 Subject: [PATCH 24/46] Makefile fixes for MacOS (#146) Signed-off-by: Shmuel Kallner --- Makefile | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/Makefile b/Makefile index 2e0ceae7..7f4ea750 100644 --- a/Makefile +++ b/Makefile @@ -25,6 +25,16 @@ IMAGE_TAG_BASE ?= $(IMAGE_REGISTRY)/$(PROJECT_NAME) SIM_TAG ?= dev IMG = $(IMAGE_TAG_BASE):$(SIM_TAG) +ifeq ($(TARGETOS),darwin) +ifeq ($(TARGETARCH),amd64) +TOKENIZER_ARCH = x86_64 +else +TOKENIZER_ARCH = $(TARGETARCH) +endif +else +TOKENIZER_ARCH = $(TARGETARCH) +endif + CONTAINER_TOOL := $(shell { command -v docker >/dev/null 2>&1 && echo docker; } || { command -v podman >/dev/null 2>&1 && echo podman; } || echo "") BUILDER := $(shell command -v buildah >/dev/null 2>&1 && echo buildah || echo $(CONTAINER_TOOL)) PLATFORMS ?= linux/amd64 # linux/arm64 # linux/s390x,linux/ppc64le @@ -48,7 +58,7 @@ $(TOKENIZER_LIB): ## Download the HuggingFace tokenizer bindings. @echo "Downloading HuggingFace tokenizer bindings for version $(TOKENIZER_VERSION)..." 
mkdir -p lib - curl -L https://github.com/daulet/tokenizers/releases/download/$(TOKENIZER_VERSION)/libtokenizers.$(TARGETOS)-$(TARGETARCH).tar.gz | tar -xz -C lib + curl -L https://github.com/daulet/tokenizers/releases/download/$(TOKENIZER_VERSION)/libtokenizers.$(TARGETOS)-$(TOKENIZER_ARCH).tar.gz | tar -xz -C lib ranlib lib/*.a ##@ Development @@ -92,8 +102,8 @@ build: check-go download-tokenizer download-zmq image-build: check-container-tool ## Build Docker image ## Build Docker image using $(CONTAINER_TOOL) @printf "\033[33;1m==== Building Docker image $(IMG) ====\033[0m\n" $(CONTAINER_TOOL) build \ - --platform $(TARGETOS)/$(TARGETARCH) \ - --build-arg TARGETOS=$(TARGETOS)\ + --platform linux/$(TARGETARCH) \ + --build-arg TARGETOS=linux \ --build-arg TARGETARCH=$(TARGETARCH)\ -t $(IMG) . From 54efd5b62f5c63d87f844de02898c9f7555cca71 Mon Sep 17 00:00:00 2001 From: Maya Barnea Date: Tue, 19 Aug 2025 17:01:23 +0300 Subject: [PATCH 25/46] - return to kv-cache-manager version v0.2.1 (#147) - fix serialization of BlockStored and BlockRemoved structures to be compatible to v0.2.1 Signed-off-by: Maya Barnea --- go.mod | 2 +- go.sum | 8 ++------ pkg/kv-cache/kv_cache_sender.go | 22 ++++++++++++++++++++-- 3 files changed, 23 insertions(+), 9 deletions(-) diff --git a/go.mod b/go.mod index 30ba1417..70fc81bf 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( github.com/buaazp/fasthttprouter v0.1.1 github.com/go-logr/logr v1.4.2 github.com/google/uuid v1.6.0 - github.com/llm-d/llm-d-kv-cache-manager v0.2.2-0.20250814115305-d5a8ca882318 + github.com/llm-d/llm-d-kv-cache-manager v0.2.1 github.com/onsi/ginkgo/v2 v2.23.4 github.com/onsi/gomega v1.37.0 github.com/openai/openai-go v0.1.0-beta.10 diff --git a/go.sum b/go.sum index 378898fd..56ae979d 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,3 @@ -github.com/alicebob/miniredis/v2 v2.35.0 h1:QwLphYqCEAo1eu1TqPRN2jgVMPBweeQcR21jeqDCONI= -github.com/alicebob/miniredis/v2 v2.35.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM= github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA= github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= @@ -68,8 +66,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= -github.com/llm-d/llm-d-kv-cache-manager v0.2.2-0.20250814115305-d5a8ca882318 h1:4V1tDOzD0EzatsdOjJnEt7+dJDQPTozfUU4g29dCrTY= -github.com/llm-d/llm-d-kv-cache-manager v0.2.2-0.20250814115305-d5a8ca882318/go.mod h1:g2UlYKNJ4S860SAQ/QoRnytAFfnp8f1luW4IuZSMwCE= +github.com/llm-d/llm-d-kv-cache-manager v0.2.1 h1:PKIjJPUF9ILLFBNvZRa0QQ/liTQjBKwWChzcenEdM08= +github.com/llm-d/llm-d-kv-cache-manager v0.2.1/go.mod h1:s1xaE4ImkihWaLg2IQh4VN6L1PgN5RD1u1VarPey6dw= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -145,8 +143,6 @@ github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZ github.com/xyproto/randomstring v1.0.5/go.mod 
h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= -github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs= go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= diff --git a/pkg/kv-cache/kv_cache_sender.go b/pkg/kv-cache/kv_cache_sender.go index 941f8422..2b7bee14 100644 --- a/pkg/kv-cache/kv_cache_sender.go +++ b/pkg/kv-cache/kv_cache_sender.go @@ -94,9 +94,9 @@ func (s *KVEventSender) Run(ctx context.Context) error { switch eventData.action { case eventActionStore: - payload, err = msgpack.Marshal(kvevents.BlockStored{BlockHashes: eventData.hashValues}.ToTaggedUnion()) + payload, err = msgpack.Marshal(storedToTaggedUnion(kvevents.BlockStored{BlockHashes: eventData.hashValues})) case eventActionRemove: - payload, err = msgpack.Marshal(kvevents.BlockRemoved{BlockHashes: eventData.hashValues}.ToTaggedUnion()) + payload, err = msgpack.Marshal(removedToTaggedUnion(kvevents.BlockRemoved{BlockHashes: eventData.hashValues})) default: return fmt.Errorf("invalid event action %d", eventData.action) } @@ -128,6 +128,24 @@ func (s *KVEventSender) Run(ctx context.Context) error { } } +func storedToTaggedUnion(bs kvevents.BlockStored) []any { + return []any{ + BlockStored, + bs.BlockHashes, + bs.ParentBlockHash, + bs.TokenIds, + bs.BlockSize, + bs.LoraID, + } +} + +func removedToTaggedUnion(br kvevents.BlockRemoved) []any { + return []any{ + BlockRemoved, + br.BlockHashes, + } +} + // helper to publish collected batch if not empty func (s *KVEventSender) publishHelper(ctx context.Context) error { if len(s.batch) == 0 { From c3aae8d3b1e95ce62e47cdf7a2b11dc53e7ad79a Mon Sep 17 00:00:00 2001 From: Maya Barnea Date: Wed, 20 Aug 2025 09:44:54 +0300 Subject: [PATCH 26/46] present kv-cache related configuration parameters in readme file (#149) Signed-off-by: Maya Barnea --- README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/README.md b/README.md index 3874db1a..940ef95e 100644 --- a/README.md +++ b/README.md @@ -116,7 +116,6 @@ For more details see the - `fake-metrics`: represents a predefined set of metrics to be sent to Prometheus as a substitute for the actual data. When specified, only these fake metrics will be reported — real metrics and fake metrics will never be reported simultaneously. The set should include values for - `running-requests` - `waiting-requests` From ad487ee0df580998730da149bfb26cab9951e9f9 Mon Sep 17 00:00:00 2001 From: Maya Barnea Date: Wed, 20 Aug 2025 10:14:15 +0300 Subject: [PATCH 27/46] updated readme file - added environment variables (#151) Signed-off-by: Maya Barnea --- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 940ef95e..328bc608 100644 --- a/README.md +++ b/README.md @@ -148,7 +148,9 @@ In addition, as we are using klog, the following parameters are available: - `v`: number for the log level verbosity - `vmodule`: comma-separated list of pattern=N settings for file-filtered logging ---- +## Environment variables +- `POD_NAME`: the simulator pod name. 
If defined, the response will contain the HTTP header `x-inference-pod` with this value +- `POD_NAMESPACE`: the simulator pod namespace. If defined, the response will contain the HTTP header `x-inference-namespace` with this value ## Migrating from releases prior to v0.2.0 - `max-running-requests` was replaced by `max-num-seqs` From 21957bc7bf3553cf3fa4197e0032824d0cbdd707 Mon Sep 17 00:00:00 2001 From: Qifan Deng <20884468+pancak3@users.noreply.github.com> Date: Wed, 20 Aug 2025 19:05:12 +1000 Subject: [PATCH 28/46] Fix zmq endpoints in test cases (#150) * Using the IP address 127.0.0.1 instead of localhost in test cases for zmq to prevent potential name resolution issues Signed-off-by: Qifan Deng * Ignore vscode devcontainer config Signed-off-by: Qifan Deng * Fix a formatting error introduced by commit 9235047 Signed-off-by: Qifan Deng --------- Signed-off-by: Qifan Deng --- .gitignore | 1 + pkg/common/publisher_test.go | 11 ++++++----- pkg/kv-cache/kv_cache_test.go | 17 +++++++++-------- 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/.gitignore b/.gitignore index 401369b3..ac6e6006 100644 --- a/.gitignore +++ b/.gitignore @@ -2,5 +2,6 @@ bin lib vendor .vscode +.devcontainer # MacOSX .DS_Store \ No newline at end of file diff --git a/pkg/common/publisher_test.go b/pkg/common/publisher_test.go index 5df18940..8f4609d5 100644 --- a/pkg/common/publisher_test.go +++ b/pkg/common/publisher_test.go @@ -29,9 +29,10 @@ import ( ) const ( - topic = "test-topic" - endpoint = "tcp://localhost:5557" - data = "Hello" + topic = "test-topic" + subEndpoint = "tcp://*:5557" + pubEndpoint = "tcp://localhost:5557" + data = "Hello" ) var _ = Describe("Publisher", func() { @@ -40,7 +41,7 @@ var _ = Describe("Publisher", func() { Expect(err).NotTo(HaveOccurred()) sub, err := zctx.NewSocket(zmq.SUB) Expect(err).NotTo(HaveOccurred()) - err = sub.Bind(endpoint) + err = sub.Bind(subEndpoint) Expect(err).NotTo(HaveOccurred()) err = sub.SetSubscribe(topic) Expect(err).NotTo(HaveOccurred()) @@ -49,7 +50,7 @@ var _ = Describe("Publisher", func() { time.Sleep(100 * time.Millisecond) - pub, err := NewPublisher(endpoint) + pub, err := NewPublisher(pubEndpoint) Expect(err).NotTo(HaveOccurred()) ctx, cancel := context.WithCancel(context.Background()) diff --git a/pkg/kv-cache/kv_cache_test.go b/pkg/kv-cache/kv_cache_test.go index 4bea7c45..8f57c516 100644 --- a/pkg/kv-cache/kv_cache_test.go +++ b/pkg/kv-cache/kv_cache_test.go @@ -33,10 +33,11 @@ import ( ) const ( - req1ID = "req1" - req2ID = "req2" - req3ID = "req3" - endpoint = "tcp://localhost:5557" + req1ID = "req1" + req2ID = "req2" + req3ID = "req3" + subEndpoint = "tcp://*:5557" + pubEndpoint = "tcp://localhost:5557" ) type ActionType int @@ -203,7 +204,7 @@ var _ = Describe("KV cache", Ordered, func() { Port: 1234, Model: "model", KVCacheSize: test.cacheSize, - ZMQEndpoint: endpoint, + ZMQEndpoint: pubEndpoint, EventBatchSize: 1, } @@ -306,7 +307,7 @@ var _ = Describe("KV cache", Ordered, func() { Port: 1234, Model: "model", KVCacheSize: 4, - ZMQEndpoint: endpoint, + ZMQEndpoint: pubEndpoint, } sub, topic := createSub(config) @@ -415,7 +416,7 @@ var _ = Describe("KV cache", Ordered, func() { Port: 1234, Model: "model", KVCacheSize: testCase.cacheSize, - ZMQEndpoint: endpoint, + ZMQEndpoint: pubEndpoint, } blockCache, err := newBlockCache(&config, GinkgoLogr) Expect(err).NotTo(HaveOccurred()) @@ -531,7 +532,7 @@ func createSub(config *common.Configuration) (*zmq.Socket, string) { Expect(err).NotTo(HaveOccurred()) sub, err := 
zctx.NewSocket(zmq.SUB) Expect(err).NotTo(HaveOccurred()) - err = sub.Bind(endpoint) + err = sub.Bind(subEndpoint) Expect(err).NotTo(HaveOccurred()) topic := createTopic(config) err = sub.SetSubscribe(topic) From 03050c7745dfb3ed03ea24950bf572db0af7478d Mon Sep 17 00:00:00 2001 From: Maya Barnea Date: Wed, 20 Aug 2025 18:32:04 +0300 Subject: [PATCH 29/46] change user to not be root in the dockerfile (#153) Signed-off-by: Maya Barnea --- Dockerfile | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index 36a8836e..58a6ccbe 100644 --- a/Dockerfile +++ b/Dockerfile @@ -53,7 +53,6 @@ RUN microdnf install -y dnf && \ COPY --from=builder /workspace/bin/llm-d-inference-sim /app/llm-d-inference-sim -# USER 65532:65532 -USER root +USER 65532:65532 ENTRYPOINT ["/app/llm-d-inference-sim"] From 34c29cad24a19f8662e4a009a0437244b7eb5725 Mon Sep 17 00:00:00 2001 From: Zhengke Zhou Date: Thu, 21 Aug 2025 16:18:18 +0800 Subject: [PATCH 30/46] Add ZMQ connection retry configuration (#152) * Add ZMQ connection retry configuration Signed-off-by: zhengkezhou1 * add test & update readme Signed-off-by: zhengkezhou1 * add retries test Signed-off-by: zhengkezhou1 * more tests & rename Command line parameters Signed-off-by: zhengkezhou1 --------- Signed-off-by: zhengkezhou1 --- README.md | 1 + manifests/invalid-config.yaml | 9 ++++++++ pkg/common/config.go | 6 ++++++ pkg/common/config_test.go | 11 ++++++++++ pkg/common/publisher.go | 33 +++++++++++++++++++---------- pkg/common/publisher_test.go | 39 ++++++++++++++++++++++++++++++++++- pkg/kv-cache/block_cache.go | 2 +- pkg/kv-cache/kv_cache_test.go | 29 ++++++++++++++------------ 8 files changed, 104 insertions(+), 26 deletions(-) create mode 100644 manifests/invalid-config.yaml diff --git a/README.md b/README.md index 328bc608..275f8e13 100644 --- a/README.md +++ b/README.md @@ -122,6 +122,7 @@ For more details see the 10 { + return errors.New("zmq retries times cannot be more than 10") + } if c.FakeMetrics != nil { if c.FakeMetrics.RunningRequests < 0 || c.FakeMetrics.WaitingRequests < 0 { @@ -415,6 +420,7 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) { f.StringVar(&config.TokenizersCacheDir, "tokenizers-cache-dir", config.TokenizersCacheDir, "Directory for caching tokenizers") f.StringVar(&config.HashSeed, "hash-seed", config.HashSeed, "Seed for hash generation (if not set, is read from PYTHONHASHSEED environment variable)") f.StringVar(&config.ZMQEndpoint, "zmq-endpoint", config.ZMQEndpoint, "ZMQ address to publish events") + f.UintVar(&config.ZMQMaxConnectAttempts, "zmq-max-connect-attempts", config.ZMQMaxConnectAttempts, "Maximum number of times to retry ZMQ requests") f.IntVar(&config.EventBatchSize, "event-batch-size", config.EventBatchSize, "Maximum number of kv-cache events to be sent together") // These values were manually parsed above in getParamValueFromArgs, we leave this in order to get these flags in --help diff --git a/pkg/common/config_test.go b/pkg/common/config_test.go index d4ec9677..f50c40a9 100644 --- a/pkg/common/config_test.go +++ b/pkg/common/config_test.go @@ -103,12 +103,14 @@ var _ = Describe("Simulator configuration", func() { "{\"name\":\"lora4\",\"path\":\"/path/to/lora4\"}", } c.EventBatchSize = 5 + c.ZMQMaxConnectAttempts = 1 test = testCase{ name: "config file with command line args", args: []string{"cmd", "--model", model, "--config", "../../manifests/config.yaml", "--port", "8002", "--served-model-name", "alias1", "alias2", "--seed", "100", "--lora-modules", 
"{\"name\":\"lora3\",\"path\":\"/path/to/lora3\"}", "{\"name\":\"lora4\",\"path\":\"/path/to/lora4\"}", "--event-batch-size", "5", + "--zmq-max-connect-attempts", "1", }, expectedConfig: c, } @@ -121,6 +123,7 @@ var _ = Describe("Simulator configuration", func() { c.LoraModulesString = []string{ "{\"name\":\"lora3\",\"path\":\"/path/to/lora3\"}", } + c.ZMQMaxConnectAttempts = 0 test = testCase{ name: "config file with command line args with different format", args: []string{"cmd", "--model", model, "--config", "../../manifests/config.yaml", "--port", "8002", @@ -377,6 +380,14 @@ var _ = Describe("Simulator configuration", func() { args: []string{"cmd", "--fake-metrics", "{\"running-requests\":10,\"waiting-requests\":30,\"kv-cache-usage\":40}", "--config", "../../manifests/config.yaml"}, }, + { + name: "invalid (negative) zmq-max-connect-attempts for argument", + args: []string{"cmd", "zmq-max-connect-attempts", "-1", "--config", "../../manifests/config.yaml"}, + }, + { + name: "invalid (negative) zmq-max-connect-attempts for config file", + args: []string{"cmd", "--config", "../../manifests/invalid-config.yaml"}, + }, } for _, test := range invalidTests { diff --git a/pkg/common/publisher.go b/pkg/common/publisher.go index d7d6e325..883c05a2 100644 --- a/pkg/common/publisher.go +++ b/pkg/common/publisher.go @@ -23,6 +23,7 @@ import ( "errors" "fmt" "sync/atomic" + "time" zmq "github.com/pebbe/zmq4" "github.com/vmihailenco/msgpack/v5" @@ -38,24 +39,34 @@ type Publisher struct { // NewPublisher creates a new ZMQ publisher. // endpoint is the ZMQ address to bind to (e.g., "tcp://*:5557"). -func NewPublisher(endpoint string) (*Publisher, error) { +// retries is the maximum number of connection attempts. +func NewPublisher(endpoint string, retries uint) (*Publisher, error) { socket, err := zmq.NewSocket(zmq.PUB) if err != nil { return nil, fmt.Errorf("failed to create ZMQ PUB socket: %w", err) } - if err := socket.Connect(endpoint); err != nil { - errClose := socket.Close() - return nil, errors.Join( - fmt.Errorf("failed to connect to %s: %w", endpoint, err), - errClose, - ) + // Retry connection with specified retry times and intervals + for i := uint(0); i <= retries; i++ { + err = socket.Connect(endpoint) + if err == nil { + return &Publisher{ + socket: socket, + endpoint: endpoint, + }, nil + } + + // If not the last attempt, wait before retrying + if i < retries { + time.Sleep(1 * time.Second) + } } - return &Publisher{ - socket: socket, - endpoint: endpoint, - }, nil + errClose := socket.Close() + return nil, errors.Join( + fmt.Errorf("failed to connect to %s after %d retries: %w", endpoint, retries+1, err), + errClose, + ) } // PublishEvent publishes a KV cache event batch to the ZMQ topic. 
diff --git a/pkg/common/publisher_test.go b/pkg/common/publisher_test.go index 8f4609d5..a9d6582b 100644 --- a/pkg/common/publisher_test.go +++ b/pkg/common/publisher_test.go @@ -33,6 +33,7 @@ const ( subEndpoint = "tcp://*:5557" pubEndpoint = "tcp://localhost:5557" data = "Hello" + retries = 0 ) var _ = Describe("Publisher", func() { @@ -50,7 +51,7 @@ var _ = Describe("Publisher", func() { time.Sleep(100 * time.Millisecond) - pub, err := NewPublisher(pubEndpoint) + pub, err := NewPublisher(pubEndpoint, retries) Expect(err).NotTo(HaveOccurred()) ctx, cancel := context.WithCancel(context.Background()) @@ -78,4 +79,40 @@ var _ = Describe("Publisher", func() { Expect(err).NotTo(HaveOccurred()) Expect(payload).To(Equal(data)) }) + It("should fail when connection attempts exceed maximum retries", func() { + // Use invalid address format, which will cause connection to fail + invalidEndpoint := "invalid-address-format" + + pub, err := NewPublisher(invalidEndpoint, 2) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("failed to connect")) + Expect(err.Error()).To(ContainSubstring("after 3 retries")) // 2 retries = 3 total attempts + + if pub != nil { + //nolint + pub.Close() + } + }) + It("should retry connection successfully", func() { + // Step 1: Try to connect to a temporarily non-existent service + // This will trigger the retry mechanism + go func() { + // Delay starting the server to simulate service recovery + time.Sleep(2 * time.Second) + + // Start subscriber as server + sub, err := zmq.NewSocket(zmq.SUB) + Expect(err).NotTo(HaveOccurred()) + //nolint + defer sub.Close() + err = sub.Bind(subEndpoint) + Expect(err).NotTo(HaveOccurred()) + }() + + // Step 2: Publisher will retry connection and eventually succeed + pub, err := NewPublisher(pubEndpoint, 5) // 5 retries + Expect(err).NotTo(HaveOccurred()) // Should eventually succeed + //nolint + defer pub.Close() + }) }) diff --git a/pkg/kv-cache/block_cache.go b/pkg/kv-cache/block_cache.go index e66c7224..56d2253b 100644 --- a/pkg/kv-cache/block_cache.go +++ b/pkg/kv-cache/block_cache.go @@ -48,7 +48,7 @@ func newBlockCache(config *common.Configuration, logger logr.Logger) (*blockCach // TODO read size of channel from config eChan := make(chan EventData, 10000) - publisher, err := common.NewPublisher(config.ZMQEndpoint) + publisher, err := common.NewPublisher(config.ZMQEndpoint, config.ZMQMaxConnectAttempts) if err != nil { return nil, err } diff --git a/pkg/kv-cache/kv_cache_test.go b/pkg/kv-cache/kv_cache_test.go index 8f57c516..7731196e 100644 --- a/pkg/kv-cache/kv_cache_test.go +++ b/pkg/kv-cache/kv_cache_test.go @@ -201,11 +201,12 @@ var _ = Describe("KV cache", Ordered, func() { time.Sleep(300 * time.Millisecond) config := &common.Configuration{ - Port: 1234, - Model: "model", - KVCacheSize: test.cacheSize, - ZMQEndpoint: pubEndpoint, - EventBatchSize: 1, + Port: 1234, + Model: "model", + KVCacheSize: test.cacheSize, + ZMQEndpoint: pubEndpoint, + ZMQMaxConnectAttempts: 3, + EventBatchSize: 1, } sub, topic := createSub(config) @@ -304,10 +305,11 @@ var _ = Describe("KV cache", Ordered, func() { It("should send events correctly", func() { config := &common.Configuration{ - Port: 1234, - Model: "model", - KVCacheSize: 4, - ZMQEndpoint: pubEndpoint, + Port: 1234, + Model: "model", + KVCacheSize: 4, + ZMQEndpoint: pubEndpoint, + ZMQMaxConnectAttempts: 3, } sub, topic := createSub(config) @@ -413,10 +415,11 @@ var _ = Describe("KV cache", Ordered, func() { for _, testCase := range testCases { It(testCase.name, 
func() { config := common.Configuration{ - Port: 1234, - Model: "model", - KVCacheSize: testCase.cacheSize, - ZMQEndpoint: pubEndpoint, + Port: 1234, + Model: "model", + KVCacheSize: testCase.cacheSize, + ZMQEndpoint: pubEndpoint, + ZMQMaxConnectAttempts: 3, } blockCache, err := newBlockCache(&config, GinkgoLogr) Expect(err).NotTo(HaveOccurred()) From e112efe8e07b10028dc408d0bf4e60b858a059ce Mon Sep 17 00:00:00 2001 From: Shmuel Kallner Date: Thu, 21 Aug 2025 12:44:43 +0300 Subject: [PATCH 31/46] Added CI automation (#155) * Added an OWNERS file to control who can review and approve PRs Signed-off-by: Shmuel Kallner * Added Prow automation Signed-off-by: Shmuel Kallner * Added automated marking of issues as stale Signed-off-by: Shmuel Kallner --------- Signed-off-by: Shmuel Kallner --- .github/workflows/prow-github.yml | 37 ++++++++++++++++++++++ .github/workflows/prow-pr-automerge.yml | 18 +++++++++++ .github/workflows/prow-pr-remove-lgtm.yml | 11 +++++++ .github/workflows/re-run-action.yml | 16 ++++++++++ .github/workflows/stale.yaml | 38 +++++++++++++++++++++++ .github/workflows/unstale.yaml | 27 ++++++++++++++++ OWNERS | 15 +++++++++ 7 files changed, 162 insertions(+) create mode 100644 .github/workflows/prow-github.yml create mode 100644 .github/workflows/prow-pr-automerge.yml create mode 100644 .github/workflows/prow-pr-remove-lgtm.yml create mode 100644 .github/workflows/re-run-action.yml create mode 100644 .github/workflows/stale.yaml create mode 100644 .github/workflows/unstale.yaml create mode 100644 OWNERS diff --git a/.github/workflows/prow-github.yml b/.github/workflows/prow-github.yml new file mode 100644 index 00000000..0c5f11fd --- /dev/null +++ b/.github/workflows/prow-github.yml @@ -0,0 +1,37 @@ +# Run specified actions or jobs for issue and PR comments + +name: "Prow github actions" +on: + issue_comment: + types: [created] + +# Grant additional permissions to the GITHUB_TOKEN +permissions: + # Allow labeling issues + issues: write + # Allow adding a review to a pull request + pull-requests: write + +jobs: + prow-execute: + runs-on: ubuntu-latest + steps: + - uses: jpmcb/prow-github-actions@v2.0.0 + with: + github-token: "${{ secrets.GITHUB_TOKEN }}" + prow-commands: "/assign + /unassign + /approve + /retitle + /area + /kind + /priority + /remove + /lgtm + /close + /reopen + /lock + /milestone + /hold + /cc + /uncc" diff --git a/.github/workflows/prow-pr-automerge.yml b/.github/workflows/prow-pr-automerge.yml new file mode 100644 index 00000000..c9bb0972 --- /dev/null +++ b/.github/workflows/prow-pr-automerge.yml @@ -0,0 +1,18 @@ +# This Github workflow will check every 5m for PRs with the lgtm label and will attempt to automatically merge them. +# If the hold label is present, it will block automatic merging. 
+ +name: "Prow merge on lgtm label" +on: + schedule: + - cron: "*/5 * * * *" # every 5 minutes + +jobs: + auto-merge: + runs-on: ubuntu-latest + steps: + - uses: jpmcb/prow-github-actions@v2.0.0 + with: + jobs: 'lgtm' + github-token: "${{ secrets.GITHUB_TOKEN }}" + merge-method: 'squash' + diff --git a/.github/workflows/prow-pr-remove-lgtm.yml b/.github/workflows/prow-pr-remove-lgtm.yml new file mode 100644 index 00000000..caf208f3 --- /dev/null +++ b/.github/workflows/prow-pr-remove-lgtm.yml @@ -0,0 +1,11 @@ +name: Run Jobs on PR +on: pull_request + +jobs: + execute: + runs-on: ubuntu-latest + steps: + - uses: jpmcb/prow-github-actions@v2.0.0 + with: + jobs: lgtm + github-token: '${{ secrets.GITHUB_TOKEN }}' diff --git a/.github/workflows/re-run-action.yml b/.github/workflows/re-run-action.yml new file mode 100644 index 00000000..cb2914e8 --- /dev/null +++ b/.github/workflows/re-run-action.yml @@ -0,0 +1,16 @@ +name: Re-Run PR tests + +on: + issue_comment: + types: [created] + +jobs: + rerun_pr_tests: + name: rerun_pr_tests + if: ${{ github.event.issue.pull_request }} + runs-on: ubuntu-20.04 + steps: + - uses: estroz/rerun-actions@main + with: + repo_token: ${{ secrets.GITHUB_TOKEN }} + comment_id: ${{ github.event.comment.id }} diff --git a/.github/workflows/stale.yaml b/.github/workflows/stale.yaml new file mode 100644 index 00000000..622a4ffd --- /dev/null +++ b/.github/workflows/stale.yaml @@ -0,0 +1,38 @@ +name: 'Mark stale issues' + +on: + schedule: + - cron: '0 1 * * *' + +jobs: + stale-issues: + runs-on: ubuntu-latest + permissions: + issues: write + steps: + - name: 'Mark stale issues' + uses: actions/stale@v9 + with: + days-before-issue-stale: 90 + days-before-pr-stale: -1 + days-before-close: -1 + stale-issue-label: 'lifecycle/stale' + exempt-issue-labels: 'lifecycle/rotten' + + - name: 'Mark rotten issues' + uses: actions/stale@v9 + with: + days-before-issue-stale: 30 + days-before-pr-stale: -1 + days-before-close: -1 + stale-issue-label: 'lifecycle/rotten' + only-labels: 'lifecycle/stale' + labels-to-remove-when-stale: 'lifecycle/stale' + + - name: 'Close rotten issues' + uses: actions/stale@v9 + with: + days-before-stale: -1 + days-before-issue-close: 30 + days-before-pr-close: -1 + stale-issue-label: 'lifecycle/rotten' diff --git a/.github/workflows/unstale.yaml b/.github/workflows/unstale.yaml new file mode 100644 index 00000000..319ecf14 --- /dev/null +++ b/.github/workflows/unstale.yaml @@ -0,0 +1,27 @@ +name: 'Unstale Issue' + +on: + issues: + types: [ reopened ] + issue_comment: + types: [ created ] + +jobs: + remove-stale: + runs-on: ubuntu-latest + permissions: + issues: write + if: >- + github.event.issue.state == 'open' && + (contains(github.event.issue.labels.*.name, 'lifecycle/stale') || + contains(github.event.issue.labels.*.name, 'lifecycle/rotten')) + steps: + - name: 'Checkout repository' + uses: actions/checkout@v5 + + - name: 'Remove stale labels' + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + echo "Removing 'stale' label from issue #${{ github.event.issue.number }}" + gh issue edit ${{ github.event.issue.number }} --remove-label "lifecycle/stale,lifecycle/rotten" diff --git a/OWNERS b/OWNERS new file mode 100644 index 00000000..8b464bc5 --- /dev/null +++ b/OWNERS @@ -0,0 +1,15 @@ +approvers: +- mayabar +- irar2 +- shmuelk +- elevran +- kfirtoledo +- nilig + +reviewers: +- mayabar +- irar2 +- shmuelk +- elevran +- kfirtoledo +- nilig From 4076bd2e5542d5f0a6fc0d4802a5cb98926b917d Mon Sep 17 00:00:00 2001 From: Maya Barnea Date: Thu, 21 Aug 
2025 13:16:57 +0300 Subject: [PATCH 32/46] small changes in texts (#156) Signed-off-by: Maya Barnea --- README.md | 6 +++--- pkg/common/config.go | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 275f8e13..f274b2b1 100644 --- a/README.md +++ b/README.md @@ -122,13 +122,13 @@ For more details see the Date: Sun, 24 Aug 2025 01:09:01 -0400 Subject: [PATCH 33/46] Fix server interrupt (#161) * pass ctx to startServer Signed-off-by: npolshakova * fix sim test to use ctx Signed-off-by: npolshakova --------- Signed-off-by: npolshakova --- pkg/llm-d-inference-sim/simulator.go | 34 ++++++++++++++++++++--- pkg/llm-d-inference-sim/simulator_test.go | 2 +- 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index da5f53e1..d2d61ca2 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ -146,8 +146,8 @@ func (s *VllmSimulator) Start(ctx context.Context) error { return err } - // start the http server - return s.startServer(listener) + // start the http server with context support + return s.startServer(ctx, listener) } func (s *VllmSimulator) newListener() (net.Listener, error) { @@ -160,7 +160,7 @@ func (s *VllmSimulator) newListener() (net.Listener, error) { } // startServer starts http server on port defined in command line -func (s *VllmSimulator) startServer(listener net.Listener) error { +func (s *VllmSimulator) startServer(ctx context.Context, listener net.Listener) error { r := fasthttprouter.New() // support completion APIs @@ -189,7 +189,33 @@ func (s *VllmSimulator) startServer(listener net.Listener) error { } }() - return server.Serve(listener) + // Start server in a goroutine + serverErr := make(chan error, 1) + go func() { + s.logger.Info("HTTP server starting") + serverErr <- server.Serve(listener) + }() + + // Wait for either context cancellation or server error + select { + case <-ctx.Done(): + s.logger.Info("Shutdown signal received, shutting down HTTP server gracefully") + + // Gracefully shutdown the server + if err := server.Shutdown(); err != nil { + s.logger.Error(err, "Error during server shutdown") + return err + } + + s.logger.Info("HTTP server stopped") + return nil + + case err := <-serverErr: + if err != nil { + s.logger.Error(err, "HTTP server failed") + } + return err + } } // Print prints to a log, implementation of fasthttp.Logger diff --git a/pkg/llm-d-inference-sim/simulator_test.go b/pkg/llm-d-inference-sim/simulator_test.go index b27d97a2..a86eb3f5 100644 --- a/pkg/llm-d-inference-sim/simulator_test.go +++ b/pkg/llm-d-inference-sim/simulator_test.go @@ -115,7 +115,7 @@ func startServerWithArgsAndMetrics(ctx context.Context, mode string, args []stri // start the http server go func() { - if err := s.startServer(listener); err != nil { + if err := s.startServer(ctx, listener); err != nil { logger.Error(err, "error starting server") } }() From 1cdd97ebeb6043be05caa7c20268766957d9dc16 Mon Sep 17 00:00:00 2001 From: Qifan Deng <20884468+pancak3@users.noreply.github.com> Date: Sun, 24 Aug 2025 18:01:00 +1000 Subject: [PATCH 34/46] Show final config in simulator default logger at Info level (#154) * Show final config in simulator default logger at Info level Signed-off-by: Qifan Deng * Remove unnecessary local var and update show config prompt Signed-off-by: Qifan Deng * Resolve conflict due to new arg of zmq max retries Signed-off-by: Qifan Deng * Clean fields when showing final configuration
Signed-off-by: Qifan Deng * Simplify function syntax Signed-off-by: Qifan Deng * Fix golangci-lint installation link in makefile Signed-off-by: Qifan Deng * Fix err fmt when logger is invalid Signed-off-by: Qifan Deng --------- Signed-off-by: Qifan Deng --- Makefile | 2 +- pkg/common/config.go | 67 ++++++++++++++-------------- pkg/llm-d-inference-sim/simulator.go | 40 +++++++++++++++++ 3 files changed, 75 insertions(+), 34 deletions(-) diff --git a/Makefile b/Makefile index 7f4ea750..04691cf7 100644 --- a/Makefile +++ b/Makefile @@ -170,7 +170,7 @@ check-ginkgo: .PHONY: check-golangci-lint check-golangci-lint: @command -v golangci-lint >/dev/null 2>&1 || { \ - echo "❌ golangci-lint is not installed. Install from https://golangci-lint.run/usage/install/"; exit 1; } + echo "❌ golangci-lint is not installed. Install from https://golangci-lint.run/docs/welcome/install/"; exit 1; } .PHONY: check-container-tool check-container-tool: diff --git a/pkg/common/config.go b/pkg/common/config.go index 5f7357c3..3d5f6ac1 100644 --- a/pkg/common/config.go +++ b/pkg/common/config.go @@ -39,100 +39,101 @@ const ( type Configuration struct { // Port defines on which port the simulator runs - Port int `yaml:"port"` + Port int `yaml:"port" json:"port"` // Model defines the current base model name - Model string `yaml:"model"` + Model string `yaml:"model" json:"model"` // ServedModelNames is one or many model names exposed by the API - ServedModelNames []string `yaml:"served-model-name"` + ServedModelNames []string `yaml:"served-model-name" json:"served-model-name"` // MaxLoras defines maximum number of loaded LoRAs - MaxLoras int `yaml:"max-loras"` + MaxLoras int `yaml:"max-loras" json:"max-loras"` // MaxCPULoras defines maximum number of LoRAs to store in CPU memory - MaxCPULoras int `yaml:"max-cpu-loras"` + MaxCPULoras int `yaml:"max-cpu-loras" json:"max-cpu-loras"` // MaxNumSeqs is maximum number of sequences per iteration (the maximum // number of inference requests that could be processed at the same time) - MaxNumSeqs int `yaml:"max-num-seqs"` + MaxNumSeqs int `yaml:"max-num-seqs" json:"max-num-seqs"` // MaxModelLen is the model's context window, the maximum number of tokens // in a single request including input and output. Default value is 1024. 
- MaxModelLen int `yaml:"max-model-len"` + MaxModelLen int `yaml:"max-model-len" json:"max-model-len"` // LoraModulesString is a list of LoRA adapters as strings - LoraModulesString []string `yaml:"lora-modules"` + LoraModulesString []string `yaml:"lora-modules" json:"lora-modules"` // LoraModules is a list of LoRA adapters LoraModules []LoraModule // TimeToFirstToken time before the first token will be returned, in milliseconds - TimeToFirstToken int `yaml:"time-to-first-token"` + TimeToFirstToken int `yaml:"time-to-first-token" json:"time-to-first-token"` // TimeToFirstTokenStdDev standard deviation for time before the first token will be returned, // in milliseconds, optional, default is 0, can't be more than 30% of TimeToFirstToken, will not // cause the actual time to first token to differ by more than 70% from TimeToFirstToken - TimeToFirstTokenStdDev int `yaml:"time-to-first-token-std-dev"` + TimeToFirstTokenStdDev int `yaml:"time-to-first-token-std-dev" json:"time-to-first-token-std-dev"` // InterTokenLatency time between generated tokens, in milliseconds - InterTokenLatency int `yaml:"inter-token-latency"` + InterTokenLatency int `yaml:"inter-token-latency" json:"inter-token-latency"` // InterTokenLatencyStdDev standard deviation for time between generated tokens, in milliseconds, // optional, default is 0, can't be more than 30% of InterTokenLatency, will not cause the actual // inter token latency to differ by more than 70% from InterTokenLatency - InterTokenLatencyStdDev int `yaml:"inter-token-latency-std-dev"` + InterTokenLatencyStdDev int `yaml:"inter-token-latency-std-dev" json:"inter-token-latency-std-dev"` // KVCacheTransferLatency time to "transfer" kv-cache from another vLLM instance in case P/D is activated, // in milliseconds - KVCacheTransferLatency int `yaml:"kv-cache-transfer-latency"` + KVCacheTransferLatency int `yaml:"kv-cache-transfer-latency" json:"kv-cache-transfer-latency"` // KVCacheTransferLatencyStdDev standard deviation for time to "transfer" kv-cache from another // vLLM instance in case P/D is activated, in milliseconds, optional, default is 0, can't be more // than 30% of KVCacheTransferLatency, will not cause the actual latency to differ by more than 70% from // KVCacheTransferLatency - KVCacheTransferLatencyStdDev int `yaml:"kv-cache-transfer-latency-std-dev"` + KVCacheTransferLatencyStdDev int `yaml:"kv-cache-transfer-latency-std-dev" json:"kv-cache-transfer-latency-std-dev"` // Mode defines the simulator response generation mode, valid values: echo, random - Mode string `yaml:"mode"` + Mode string `yaml:"mode" json:"mode"` // Seed defines random seed for operations - Seed int64 `yaml:"seed"` + Seed int64 `yaml:"seed" json:"seed"` // MaxToolCallIntegerParam defines the maximum possible value of integer parameters in a tool call, // optional, defaults to 100 - MaxToolCallIntegerParam int `yaml:"max-tool-call-integer-param"` + MaxToolCallIntegerParam int `yaml:"max-tool-call-integer-param" json:"max-tool-call-integer-param"` // MinToolCallIntegerParam defines the minimum possible value of integer parameters in a tool call, // optional, defaults to 0 - MinToolCallIntegerParam int `yaml:"min-tool-call-integer-param"` + MinToolCallIntegerParam int `yaml:"min-tool-call-integer-param" json:"min-tool-call-integer-param"` // MaxToolCallNumberParam defines the maximum possible value of number (float) parameters in a tool call, // optional, defaults to 100 - MaxToolCallNumberParam float64 `yaml:"max-tool-call-number-param"` + MaxToolCallNumberParam float64 
`yaml:"max-tool-call-number-param" json:"max-tool-call-number-param"` // MinToolCallNumberParam defines the minimum possible value of number (float) parameters in a tool call, // optional, defaults to 0 - MinToolCallNumberParam float64 `yaml:"min-tool-call-number-param"` + MinToolCallNumberParam float64 `yaml:"min-tool-call-number-param" json:"min-tool-call-number-param"` // MaxToolCallArrayParamLength defines the maximum possible length of array parameters in a tool call, // optional, defaults to 5 - MaxToolCallArrayParamLength int `yaml:"max-tool-call-array-param-length"` + MaxToolCallArrayParamLength int `yaml:"max-tool-call-array-param-length" json:"max-tool-call-array-param-length"` // MinToolCallArrayParamLength defines the minimum possible length of array parameters in a tool call, // optional, defaults to 1 - MinToolCallArrayParamLength int `yaml:"min-tool-call-array-param-length"` + MinToolCallArrayParamLength int `yaml:"min-tool-call-array-param-length" json:"min-tool-call-array-param-length"` // ToolCallNotRequiredParamProbability is the probability to add a parameter, that is not required, // in a tool call, optional, defaults to 50 - ToolCallNotRequiredParamProbability int `yaml:"tool-call-not-required-param-probability"` + ToolCallNotRequiredParamProbability int `yaml:"tool-call-not-required-param-probability" json:"tool-call-not-required-param-probability"` // ObjectToolCallNotRequiredParamProbability is the probability to add a field, that is not required, // in an object in a tool call, optional, defaults to 50 - ObjectToolCallNotRequiredParamProbability int `yaml:"object-tool-call-not-required-field-probability"` + ObjectToolCallNotRequiredParamProbability int `yaml:"object-tool-call-not-required-field-probability" json:"object-tool-call-not-required-field-probability"` // EnableKVCache defines if kv cache feature will be enabled - EnableKVCache bool `yaml:"enable-kvcache"` + EnableKVCache bool `yaml:"enable-kvcache" json:"enable-kvcache"` // KVCacheSize is the maximum number of token blocks in kv cache, the default value is 1024 - KVCacheSize int `yaml:"kv-cache-size"` + KVCacheSize int `yaml:"kv-cache-size" json:"kv-cache-size"` // TokenizersCacheDir is the directory for caching tokenizers - TokenizersCacheDir string `yaml:"tokenizers-cache-dir"` + TokenizersCacheDir string `yaml:"tokenizers-cache-dir" json:"tokenizers-cache-dir"` // TokenBlockSize is token block size for contiguous chunks of tokens, possible values: 8,16,32,64,128, defaults to 16 - TokenBlockSize int `yaml:"block-size"` + TokenBlockSize int `yaml:"block-size" json:"block-size"` // HashSeed is the seed for hash generation (if not set, is read from PYTHONHASHSEED environment variable) - HashSeed string `yaml:"hash-seed"` + HashSeed string `yaml:"hash-seed" json:"hash-seed"` // ZMQEndpoint is the ZMQ address to publish events, the default value is tcp://localhost:5557 - ZMQEndpoint string `yaml:"zmq-endpoint"` + ZMQEndpoint string `yaml:"zmq-endpoint" json:"zmq-endpoint"` // ZMQMaxConnectAttempts defines the maximum number (10) of retries when ZMQ connection fails - ZMQMaxConnectAttempts uint `yaml:"zmq-max-connect-attempts"` + ZMQMaxConnectAttempts uint `yaml:"zmq-max-connect-attempts" json:"zmq-max-connect-attempts"` + // EventBatchSize is the maximum number of kv-cache events to be sent together, defaults to 16 - EventBatchSize int `yaml:"event-batch-size"` + EventBatchSize int `yaml:"event-batch-size" json:"event-batch-size"` // FakeMetrics is a set of metrics to send to Prometheus instead of the real 
data - FakeMetrics *Metrics `yaml:"fake-metrics"` + FakeMetrics *Metrics `yaml:"fake-metrics" json:"fake-metrics"` } type Metrics struct { diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index d2d61ca2..da629e30 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ -20,6 +20,7 @@ package llmdinferencesim import ( "context" "encoding/json" + "errors" "fmt" "net" "os" @@ -114,8 +115,14 @@ func (s *VllmSimulator) Start(ctx context.Context) error { if err != nil { return err } + s.config = config + err = s.showConfig(s.logger) + if err != nil { + return err + } + for _, lora := range config.LoraModules { s.loraAdaptors.Store(lora.Name, "") } @@ -734,3 +741,36 @@ func (s *VllmSimulator) getDisplayedModelName(reqModel string) string { } return s.config.ServedModelNames[0] } + +func (s *VllmSimulator) showConfig(tgtLgr logr.Logger) error { + if tgtLgr == logr.Discard() { + return errors.New("target logger is nil, cannot show configuration") + } + cfgJSON, err := json.Marshal(s.config) + if err != nil { + return fmt.Errorf("failed to marshal configuration to JSON: %w", err) + } + + // clean LoraModulesString field + var m map[string]interface{} + err = json.Unmarshal(cfgJSON, &m) + if err != nil { + return fmt.Errorf("failed to unmarshal JSON to map: %w", err) + } + m["lora-modules"] = m["LoraModules"] + delete(m, "LoraModules") + delete(m, "LoraModulesString") + + // clean fake-metrics field + if field, ok := m["fake-metrics"].(map[string]interface{}); ok { + delete(field, "LorasString") + } + + // show in JSON + cfgJSON, err = json.MarshalIndent(m, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal configuration to JSON: %w", err) + } + tgtLgr.Info("Configuration:", "", string(cfgJSON)) + return nil +} From bdc5ecb6e6b331f99000640e14aa0f509e18f44b Mon Sep 17 00:00:00 2001 From: Qifan Deng <20884468+pancak3@users.noreply.github.com> Date: Sun, 24 Aug 2025 18:31:44 +1000 Subject: [PATCH 35/46] Cast bounds type in tests to func def: latency, interToken, and timeToFirst (to int) (#163) * Cast bounds type in tests to func def: latency, interToken, and timeToFirst (to int) Signed-off-by: Qifan Deng * Use float 32 Signed-off-by: Qifan Deng --------- Signed-off-by: Qifan Deng --- pkg/llm-d-inference-sim/simulator_test.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/pkg/llm-d-inference-sim/simulator_test.go b/pkg/llm-d-inference-sim/simulator_test.go index a86eb3f5..cf9fd468 100644 --- a/pkg/llm-d-inference-sim/simulator_test.go +++ b/pkg/llm-d-inference-sim/simulator_test.go @@ -761,8 +761,8 @@ var _ = Describe("Simulator", func() { simulator.config.InterTokenLatency = interTokenLatency simulator.config.InterTokenLatencyStdDev = stddev interToken := simulator.getInterTokenLatency() - Expect(interToken).To(BeNumerically(">=", float32(interTokenLatency)*0.3)) - Expect(interToken).To(BeNumerically("<=", float32(interTokenLatency)*1.7)) + Expect(interToken).To(BeNumerically(">=", int(float32(interTokenLatency)*0.3))) + Expect(interToken).To(BeNumerically("<=", int(float32(interTokenLatency)*1.7))) }, func(interTokenLatency int, stddev int) string { return fmt.Sprintf("interTokenLatency: %d stddev: %d", interTokenLatency, stddev) @@ -778,8 +778,8 @@ var _ = Describe("Simulator", func() { simulator.config.InterTokenLatency = interTokenLatency simulator.config.InterTokenLatencyStdDev = stddev latency := simulator.getTotalInterTokenLatency(numberOfTokens) - 
Expect(latency).To(BeNumerically(">=", float32(interTokenLatency)*0.3*float32(numberOfTokens))) - Expect(latency).To(BeNumerically("<=", float32(interTokenLatency)*1.7*float32(numberOfTokens))) + Expect(latency).To(BeNumerically(">=", int(float32(interTokenLatency)*0.3*float32(numberOfTokens)))) + Expect(latency).To(BeNumerically("<=", int(float32(interTokenLatency)*1.7*float32(numberOfTokens)))) }, func(interTokenLatency int, stddev int, numberOfTokens int) string { return fmt.Sprintf("interTokenLatency: %d stddev: %d, numberOfTokens: %d", interTokenLatency, @@ -800,11 +800,11 @@ var _ = Describe("Simulator", func() { simulator.config.KVCacheTransferLatencyStdDev = kvCacheLatencyStdDev timeToFirst := simulator.getTimeToFirstToken(doREmotePrefill) if doREmotePrefill { - Expect(timeToFirst).To(BeNumerically(">=", float32(kvCacheLatency)*0.3)) - Expect(timeToFirst).To(BeNumerically("<=", float32(kvCacheLatency)*1.7)) + Expect(timeToFirst).To(BeNumerically(">=", int(float32(kvCacheLatency)*0.3))) + Expect(timeToFirst).To(BeNumerically("<=", int(float32(kvCacheLatency)*1.7))) } else { - Expect(timeToFirst).To(BeNumerically(">=", float32(timeToFirstToken)*0.3)) - Expect(timeToFirst).To(BeNumerically("<=", float32(timeToFirstToken)*1.7)) + Expect(timeToFirst).To(BeNumerically(">=", int(float32(timeToFirstToken)*0.3))) + Expect(timeToFirst).To(BeNumerically("<=", int(float32(timeToFirstToken)*1.7))) } }, func(timeToFirstToken int, timeToFirstTokenStdDev int, From 703735de5edc8571d82fa9d8ce3d78eb6db9c3ee Mon Sep 17 00:00:00 2001 From: Qifan Deng <20884468+pancak3@users.noreply.github.com> Date: Sun, 24 Aug 2025 19:34:50 +1000 Subject: [PATCH 36/46] Remove unnecessary deferral of server close (#162) Signed-off-by: Qifan Deng --- pkg/llm-d-inference-sim/simulator.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index da629e30..d9813996 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ -190,12 +190,6 @@ func (s *VllmSimulator) startServer(ctx context.Context, listener net.Listener) Logger: s, } - defer func() { - if err := listener.Close(); err != nil { - s.logger.Error(err, "server listener close failed") - } - }() - // Start server in a goroutine serverErr := make(chan error, 1) go func() { From 06f6e42c0379cf4d64a592788ee0dca378a28919 Mon Sep 17 00:00:00 2001 From: Qifan Deng <20884468+pancak3@users.noreply.github.com> Date: Mon, 25 Aug 2025 15:15:22 +1000 Subject: [PATCH 37/46] Fix: Rand generator is not set in a test suite, which results in accessing a nil pointer at runtime if only that test suite is run (#166) * Add make flag to filter test case Signed-off-by: Qifan Deng * Init random generator in Check random latencies test suite Signed-off-by: Qifan Deng --------- Signed-off-by: Qifan Deng --- Makefile | 4 ++++ pkg/llm-d-inference-sim/simulator_test.go | 3 +++ 2 files changed, 7 insertions(+) diff --git a/Makefile b/Makefile index 04691cf7..c4fddb1d 100644 --- a/Makefile +++ b/Makefile @@ -77,7 +77,11 @@ format: ## Format Go source files .PHONY: test test: check-ginkgo download-tokenizer download-zmq ## Run tests @printf "\033[33;1m==== Running tests ====\033[0m\n" +ifdef GINKGO_FOCUS + CGO_ENABLED=1 ginkgo -ldflags="$(GO_LDFLAGS)" -v -r --focus="$(GINKGO_FOCUS)" +else CGO_ENABLED=1 ginkgo -ldflags="$(GO_LDFLAGS)" -v -r +endif .PHONY: post-deploy-test post-deploy-test: ## Run post deployment tests diff --git a/pkg/llm-d-inference-sim/simulator_test.go
b/pkg/llm-d-inference-sim/simulator_test.go index cf9fd468..88d87759 100644 --- a/pkg/llm-d-inference-sim/simulator_test.go +++ b/pkg/llm-d-inference-sim/simulator_test.go @@ -25,6 +25,7 @@ import ( "net/http" "os" "strings" + "time" "github.com/llm-d/llm-d-inference-sim/pkg/common" . "github.com/onsi/ginkgo/v2" @@ -754,6 +755,8 @@ var _ = Describe("Simulator", func() { KVCacheTransferLatency: 2048, KVCacheTransferLatencyStdDev: 2048, } + + common.InitRandom(time.Now().UnixNano()) }) DescribeTable("should calculate inter token latency correctly", From bfa02ff6eb8e11111bd3b0e7c15e0c399d7db546 Mon Sep 17 00:00:00 2001 From: Sergey Marunich Date: Mon, 25 Aug 2025 18:48:04 -0400 Subject: [PATCH 38/46] Refactor failure type usage and error response format Signed-off-by: Sergey Marunich --- pkg/common/config.go | 6 +-- pkg/common/config_test.go | 13 ++++++ pkg/llm-d-inference-sim/failures.go | 12 +++--- pkg/llm-d-inference-sim/simulator.go | 3 +- pkg/llm-d-inference-sim/simulator_test.go | 48 ++++------------------- pkg/openai-server-api/response.go | 3 +- 6 files changed, 32 insertions(+), 53 deletions(-) diff --git a/pkg/common/config.go b/pkg/common/config.go index ce4f77cc..8866da75 100644 --- a/pkg/common/config.go +++ b/pkg/common/config.go @@ -199,8 +199,6 @@ func newConfig() *Configuration { TokenBlockSize: 16, ZMQEndpoint: "tcp://localhost:5557", EventBatchSize: 16, - FailureInjectionRate: 0, - FailureTypes: []string{}, } } @@ -329,7 +327,9 @@ func (c *Configuration) validate() error { } for _, failureType := range c.FailureTypes { if !validFailureTypes[failureType] { - return fmt.Errorf("invalid failure type '%s', valid types are: rate_limit, invalid_api_key, context_length, server_error, invalid_request, model_not_found", failureType) + return fmt.Errorf("invalid failure type '%s', valid types are: %s, %s, %s, %s, %s, %s", failureType, + FailureTypeRateLimit, FailureTypeInvalidAPIKey, FailureTypeContextLength, + FailureTypeServerError, FailureTypeInvalidRequest, FailureTypeModelNotFound) } } return nil diff --git a/pkg/common/config_test.go b/pkg/common/config_test.go index 6e768c27..024267cf 100644 --- a/pkg/common/config_test.go +++ b/pkg/common/config_test.go @@ -298,6 +298,19 @@ var _ = Describe("Simulator configuration", func() { args: []string{"cmd", "--event-batch-size", "-35", "--config", "../../manifests/config.yaml"}, }, + { + name: "invalid failure injection rate > 100", + args: []string{"cmd", "--model", "test-model", "--failure-injection-rate", "150"}, + }, + { + name: "invalid failure injection rate < 0", + args: []string{"cmd", "--model", "test-model", "--failure-injection-rate", "-10"}, + }, + { + name: "invalid failure type", + args: []string{"cmd", "--model", "test-model", "--failure-injection-rate", "50", + "--failure-types", "invalid_type"}, + }, } for _, test := range invalidTests { diff --git a/pkg/llm-d-inference-sim/failures.go b/pkg/llm-d-inference-sim/failures.go index faee379e..625c7785 100644 --- a/pkg/llm-d-inference-sim/failures.go +++ b/pkg/llm-d-inference-sim/failures.go @@ -24,8 +24,8 @@ import ( const ( // Error message templates - RateLimitMessageTemplate = "Rate limit reached for %s in organization org-xxx on requests per min (RPM): Limit 3, Used 3, Requested 1." - ModelNotFoundMessageTemplate = "The model '%s-nonexistent' does not exist" + rateLimitMessageTemplate = "Rate limit reached for %s in organization org-xxx on requests per min (RPM): Limit 3, Used 3, Requested 1." 
+ modelNotFoundMessageTemplate = "The model '%s-nonexistent' does not exist" ) type FailureSpec struct { @@ -41,7 +41,7 @@ var predefinedFailures = map[string]FailureSpec{ StatusCode: 429, ErrorType: "rate_limit_exceeded", ErrorCode: "rate_limit_exceeded", - Message: "Rate limit reached for model in organization org-xxx on requests per min (RPM): Limit 3, Used 3, Requested 1.", + Message: rateLimitMessageTemplate, Param: nil, }, common.FailureTypeInvalidAPIKey: { @@ -76,7 +76,7 @@ var predefinedFailures = map[string]FailureSpec{ StatusCode: 404, ErrorType: "invalid_request_error", ErrorCode: "model_not_found", - Message: "The model 'gpt-nonexistent' does not exist", + Message: modelNotFoundMessageTemplate, Param: stringPtr("model"), }, } @@ -113,9 +113,9 @@ func GetRandomFailure(config *common.Configuration) FailureSpec { // Customize message with current model name failure := predefinedFailures[randomType] if randomType == common.FailureTypeRateLimit && config.Model != "" { - failure.Message = fmt.Sprintf(RateLimitMessageTemplate, config.Model) + failure.Message = fmt.Sprintf(rateLimitMessageTemplate, config.Model) } else if randomType == common.FailureTypeModelNotFound && config.Model != "" { - failure.Message = fmt.Sprintf(ModelNotFoundMessageTemplate, config.Model) + failure.Message = fmt.Sprintf(modelNotFoundMessageTemplate, config.Model) } return failure diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index 00ae329f..4e38b65b 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ -533,8 +533,7 @@ func (s *VllmSimulator) sendCompletionError(ctx *fasthttp.RequestCtx, errorInfo } errorResp := openaiserverapi.ErrorResponse{ - Object: "error", - Error: compErr, + Error: compErr, } if isInjected { diff --git a/pkg/llm-d-inference-sim/simulator_test.go b/pkg/llm-d-inference-sim/simulator_test.go index c6629354..a21d2670 100644 --- a/pkg/llm-d-inference-sim/simulator_test.go +++ b/pkg/llm-d-inference-sim/simulator_test.go @@ -667,7 +667,7 @@ var _ = Describe("Simulator", func() { client, err = startServerWithArgs(ctx, "failure", []string{ "cmd", "--model", model, "--failure-injection-rate", "100", - "--failure-types", "rate_limit", + "--failure-types", common.FailureTypeRateLimit, }) Expect(err).ToNot(HaveOccurred()) }) @@ -703,7 +703,7 @@ var _ = Describe("Simulator", func() { client, err = startServerWithArgs(ctx, "failure", []string{ "cmd", "--model", model, "--failure-injection-rate", "100", - "--failure-types", "invalid_api_key", "server_error", + "--failure-types", common.FailureTypeInvalidAPIKey, common.FailureTypeServerError, }) Expect(err).ToNot(HaveOccurred()) }) @@ -800,46 +800,14 @@ var _ = Describe("Simulator", func() { // Note: OpenAI Go client doesn't directly expose the error code field, // but we can verify via status code and type }, - Entry("rate_limit", "rate_limit", 429, "rate_limit_exceeded", "rate_limit_exceeded"), - Entry("invalid_api_key", "invalid_api_key", 401, "invalid_request_error", "invalid_api_key"), - Entry("context_length", "context_length", 400, "invalid_request_error", "context_length_exceeded"), - Entry("server_error", "server_error", 503, "server_error", "server_error"), - Entry("invalid_request", "invalid_request", 400, "invalid_request_error", "invalid_request_error"), - Entry("model_not_found", "model_not_found", 404, "invalid_request_error", "model_not_found"), + Entry("rate_limit", common.FailureTypeRateLimit, 429, "rate_limit_exceeded", "rate_limit_exceeded"), + 
Entry("invalid_api_key", common.FailureTypeInvalidAPIKey, 401, "invalid_request_error", "invalid_api_key"), + Entry("context_length", common.FailureTypeContextLength, 400, "invalid_request_error", "context_length_exceeded"), + Entry("server_error", common.FailureTypeServerError, 503, "server_error", "server_error"), + Entry("invalid_request", common.FailureTypeInvalidRequest, 400, "invalid_request_error", "invalid_request_error"), + Entry("model_not_found", common.FailureTypeModelNotFound, 404, "invalid_request_error", "model_not_found"), ) }) - Context("configuration validation", func() { - It("should fail with invalid failure injection rate > 100", func() { - ctx := context.Background() - _, err := startServerWithArgs(ctx, "failure", []string{ - "cmd", "--model", model, - "--failure-injection-rate", "150", - }) - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("failure injection rate should be between 0 and 100")) - }) - - It("should fail with invalid failure injection rate < 0", func() { - ctx := context.Background() - _, err := startServerWithArgs(ctx, "failure", []string{ - "cmd", "--model", model, - "--failure-injection-rate", "-10", - }) - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("failure injection rate should be between 0 and 100")) - }) - - It("should fail with invalid failure type", func() { - ctx := context.Background() - _, err := startServerWithArgs(ctx, "failure", []string{ - "cmd", "--model", model, - "--failure-injection-rate", "50", - "--failure-types", "invalid_type", - }) - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("invalid failure type 'invalid_type'")) - }) - }) }) }) diff --git a/pkg/openai-server-api/response.go b/pkg/openai-server-api/response.go index f816b06d..9e8549b3 100644 --- a/pkg/openai-server-api/response.go +++ b/pkg/openai-server-api/response.go @@ -220,6 +220,5 @@ type CompletionError struct { // ErrorResponse wraps the error in the expected OpenAI format type ErrorResponse struct { - Object string `json:"object"` - Error CompletionError `json:"error"` + Error CompletionError `json:"error"` } From 700e36f76a16c0953b81bbc8ea24ad9bb90ff957 Mon Sep 17 00:00:00 2001 From: Sergey Marunich Date: Mon, 25 Aug 2025 18:50:28 -0400 Subject: [PATCH 39/46] Refactor failure type flag handling and code formatting Signed-off-by: Sergey Marunich --- pkg/common/config.go | 44 +++++++++++------------ pkg/llm-d-inference-sim/failures.go | 14 ++++---- pkg/llm-d-inference-sim/failures_test.go | 2 +- pkg/llm-d-inference-sim/simulator.go | 6 ++-- pkg/llm-d-inference-sim/simulator_test.go | 24 ++++++------- 5 files changed, 45 insertions(+), 45 deletions(-) diff --git a/pkg/common/config.go b/pkg/common/config.go index 8866da75..bf74dbba 100644 --- a/pkg/common/config.go +++ b/pkg/common/config.go @@ -34,14 +34,15 @@ const ( vLLMDefaultPort = 8000 ModeRandom = "random" ModeEcho = "echo" - + dummyFlagValue = "dummy" + // Failure type constants - FailureTypeRateLimit = "rate_limit" - FailureTypeInvalidAPIKey = "invalid_api_key" - FailureTypeContextLength = "context_length" - FailureTypeServerError = "server_error" - FailureTypeInvalidRequest = "invalid_request" - FailureTypeModelNotFound = "model_not_found" + FailureTypeRateLimit = "rate_limit" + FailureTypeInvalidAPIKey = "invalid_api_key" + FailureTypeContextLength = "context_length" + FailureTypeServerError = "server_error" + FailureTypeInvalidRequest = "invalid_request" + FailureTypeModelNotFound = "model_not_found" ) type Configuration 
struct { @@ -195,10 +196,10 @@ func newConfig() *Configuration { MinToolCallArrayParamLength: 1, ToolCallNotRequiredParamProbability: 50, ObjectToolCallNotRequiredParamProbability: 50, - KVCacheSize: 1024, - TokenBlockSize: 16, - ZMQEndpoint: "tcp://localhost:5557", - EventBatchSize: 16, + KVCacheSize: 1024, + TokenBlockSize: 16, + ZMQEndpoint: "tcp://localhost:5557", + EventBatchSize: 16, } } @@ -312,8 +313,8 @@ func (c *Configuration) validate() error { if c.EventBatchSize < 1 { return errors.New("event batch size cannot less than 1") } - - if c.FailureInjectionRate < 0 || c.FailureInjectionRate > 100 { + + if c.FailureInjectionRate < 0 || c.FailureInjectionRate > 100 { return errors.New("failure injection rate should be between 0 and 100") } @@ -327,8 +328,8 @@ func (c *Configuration) validate() error { } for _, failureType := range c.FailureTypes { if !validFailureTypes[failureType] { - return fmt.Errorf("invalid failure type '%s', valid types are: %s, %s, %s, %s, %s, %s", failureType, - FailureTypeRateLimit, FailureTypeInvalidAPIKey, FailureTypeContextLength, + return fmt.Errorf("invalid failure type '%s', valid types are: %s, %s, %s, %s, %s, %s", failureType, + FailureTypeRateLimit, FailureTypeInvalidAPIKey, FailureTypeContextLength, FailureTypeServerError, FailureTypeInvalidRequest, FailureTypeModelNotFound) } } @@ -384,14 +385,13 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) { f.StringVar(&config.HashSeed, "hash-seed", config.HashSeed, "Seed for hash generation (if not set, is read from PYTHONHASHSEED environment variable)") f.StringVar(&config.ZMQEndpoint, "zmq-endpoint", config.ZMQEndpoint, "ZMQ address to publish events") f.IntVar(&config.EventBatchSize, "event-batch-size", config.EventBatchSize, "Maximum number of kv-cache events to be sent together") - - f.IntVar(&config.FailureInjectionRate, "failure-injection-rate", config.FailureInjectionRate, "Probability (0-100) of injecting failures") + + f.IntVar(&config.FailureInjectionRate, "failure-injection-rate", config.FailureInjectionRate, "Probability (0-100) of injecting failures") failureTypes := getParamValueFromArgs("failure-types") var dummyFailureTypes multiString f.Var(&dummyFailureTypes, "failure-types", "List of specific failure types to inject (rate_limit, invalid_api_key, context_length, server_error, invalid_request, model_not_found)") - f.Lookup("failure-types").NoOptDefVal = "dummy" - + f.Lookup("failure-types").NoOptDefVal = dummyFlagValue // These values were manually parsed above in getParamValueFromArgs, we leave this in order to get these flags in --help var dummyString string @@ -400,8 +400,8 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) { f.Var(&dummyMultiString, "served-model-name", "Model names exposed by the API (a list of space-separated strings)") f.Var(&dummyMultiString, "lora-modules", "List of LoRA adapters (a list of space-separated JSON strings)") // In order to allow empty arguments, we set a dummy NoOptDefVal for these flags - f.Lookup("served-model-name").NoOptDefVal = "dummy" - f.Lookup("lora-modules").NoOptDefVal = "dummy" + f.Lookup("served-model-name").NoOptDefVal = dummyFlagValue + f.Lookup("lora-modules").NoOptDefVal = dummyFlagValue flagSet := flag.NewFlagSet("simFlagSet", flag.ExitOnError) klog.InitFlags(flagSet) @@ -480,4 +480,4 @@ func getParamValueFromArgs(param string) []string { } } return values -} \ No newline at end of file +} diff --git a/pkg/llm-d-inference-sim/failures.go b/pkg/llm-d-inference-sim/failures.go index 625c7785..7117ea12 
100644 --- a/pkg/llm-d-inference-sim/failures.go +++ b/pkg/llm-d-inference-sim/failures.go @@ -24,7 +24,7 @@ import ( const ( // Error message templates - rateLimitMessageTemplate = "Rate limit reached for %s in organization org-xxx on requests per min (RPM): Limit 3, Used 3, Requested 1." + rateLimitMessageTemplate = "Rate limit reached for %s in organization org-xxx on requests per min (RPM): Limit 3, Used 3, Requested 1." modelNotFoundMessageTemplate = "The model '%s-nonexistent' does not exist" ) @@ -86,7 +86,7 @@ func ShouldInjectFailure(config *common.Configuration) bool { if config.FailureInjectionRate == 0 { return false } - + return common.RandomInt(1, 100) <= config.FailureInjectionRate } @@ -101,15 +101,15 @@ func GetRandomFailure(config *common.Configuration) FailureSpec { } else { availableFailures = config.FailureTypes } - + if len(availableFailures) == 0 { // Fallback to server_error if no valid types return predefinedFailures[common.FailureTypeServerError] } - + randomIndex := common.RandomInt(0, len(availableFailures)-1) randomType := availableFailures[randomIndex] - + // Customize message with current model name failure := predefinedFailures[randomType] if randomType == common.FailureTypeRateLimit && config.Model != "" { @@ -117,10 +117,10 @@ func GetRandomFailure(config *common.Configuration) FailureSpec { } else if randomType == common.FailureTypeModelNotFound && config.Model != "" { failure.Message = fmt.Sprintf(modelNotFoundMessageTemplate, config.Model) } - + return failure } func stringPtr(s string) *string { return &s -} \ No newline at end of file +} diff --git a/pkg/llm-d-inference-sim/failures_test.go b/pkg/llm-d-inference-sim/failures_test.go index 9a2bd1de..4ff80eed 100644 --- a/pkg/llm-d-inference-sim/failures_test.go +++ b/pkg/llm-d-inference-sim/failures_test.go @@ -125,4 +125,4 @@ var _ = Describe("Failures", func() { Expect(failure.ErrorType).ToNot(BeEmpty()) }) }) -}) \ No newline at end of file +}) diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index 4e38b65b..1a27d9ef 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ -507,7 +507,7 @@ func (s *VllmSimulator) responseSentCallback(model string) { func (s *VllmSimulator) sendCompletionError(ctx *fasthttp.RequestCtx, errorInfo interface{}, isInjected bool) { var compErr openaiserverapi.CompletionError var statusCode int - + switch v := errorInfo.(type) { case string: // Legacy call with string message (backward compatibility) @@ -531,11 +531,11 @@ func (s *VllmSimulator) sendCompletionError(ctx *fasthttp.RequestCtx, errorInfo // For calls with msg, errType, and code - need to be updated in calling code panic("sendCompletionError called with unexpected type") } - + errorResp := openaiserverapi.ErrorResponse{ Error: compErr, } - + if isInjected { s.logger.Info("Injecting failure", "type", compErr.Type, "message", compErr.Message) } else { diff --git a/pkg/llm-d-inference-sim/simulator_test.go b/pkg/llm-d-inference-sim/simulator_test.go index a21d2670..0ca8795a 100644 --- a/pkg/llm-d-inference-sim/simulator_test.go +++ b/pkg/llm-d-inference-sim/simulator_test.go @@ -607,7 +607,7 @@ var _ = Describe("Simulator", func() { ctx = context.Background() var err error client, err = startServerWithArgs(ctx, "failure", []string{ - "cmd", "--model", model, + "cmd", "--model", model, "--failure-injection-rate", "100", }) Expect(err).ToNot(HaveOccurred()) @@ -627,7 +627,7 @@ var _ = Describe("Simulator", func() { }) 
Expect(err).To(HaveOccurred()) - + var openaiError *openai.Error ok := errors.As(err, &openaiError) Expect(ok).To(BeTrue()) @@ -650,7 +650,7 @@ var _ = Describe("Simulator", func() { }) Expect(err).To(HaveOccurred()) - + var openaiError *openai.Error ok := errors.As(err, &openaiError) Expect(ok).To(BeTrue()) @@ -665,7 +665,7 @@ var _ = Describe("Simulator", func() { ctx = context.Background() var err error client, err = startServerWithArgs(ctx, "failure", []string{ - "cmd", "--model", model, + "cmd", "--model", model, "--failure-injection-rate", "100", "--failure-types", common.FailureTypeRateLimit, }) @@ -686,7 +686,7 @@ var _ = Describe("Simulator", func() { }) Expect(err).To(HaveOccurred()) - + var openaiError *openai.Error ok := errors.As(err, &openaiError) Expect(ok).To(BeTrue()) @@ -701,7 +701,7 @@ var _ = Describe("Simulator", func() { ctx = context.Background() var err error client, err = startServerWithArgs(ctx, "failure", []string{ - "cmd", "--model", model, + "cmd", "--model", model, "--failure-injection-rate", "100", "--failure-types", common.FailureTypeInvalidAPIKey, common.FailureTypeServerError, }) @@ -724,11 +724,11 @@ var _ = Describe("Simulator", func() { }) Expect(err).To(HaveOccurred()) - + var openaiError *openai.Error ok := errors.As(err, &openaiError) Expect(ok).To(BeTrue()) - + // Should only be one of the specified types Expect(openaiError.StatusCode == 401 || openaiError.StatusCode == 503).To(BeTrue()) Expect(openaiError.Type == "invalid_request_error" || openaiError.Type == "server_error").To(BeTrue()) @@ -741,7 +741,7 @@ var _ = Describe("Simulator", func() { ctx = context.Background() var err error client, err = startServerWithArgs(ctx, "failure", []string{ - "cmd", "--model", model, + "cmd", "--model", model, "--failure-injection-rate", "0", }) Expect(err).ToNot(HaveOccurred()) @@ -772,7 +772,7 @@ var _ = Describe("Simulator", func() { func(failureType string, expectedStatusCode int, expectedErrorType string, expectedErrorCode string) { ctx := context.Background() client, err := startServerWithArgs(ctx, "failure", []string{ - "cmd", "--model", model, + "cmd", "--model", model, "--failure-injection-rate", "100", "--failure-types", failureType, }) @@ -791,13 +791,13 @@ var _ = Describe("Simulator", func() { }) Expect(err).To(HaveOccurred()) - + var openaiError *openai.Error ok := errors.As(err, &openaiError) Expect(ok).To(BeTrue()) Expect(openaiError.StatusCode).To(Equal(expectedStatusCode)) Expect(openaiError.Type).To(Equal(expectedErrorType)) - // Note: OpenAI Go client doesn't directly expose the error code field, + // Note: OpenAI Go client doesn't directly expose the error code field, // but we can verify via status code and type }, Entry("rate_limit", common.FailureTypeRateLimit, 429, "rate_limit_exceeded", "rate_limit_exceeded"), From 8f6d56c20e6ec6775e0c7951f33677c6bef50287 Mon Sep 17 00:00:00 2001 From: Sergey Marunich Date: Mon, 25 Aug 2025 19:50:02 -0400 Subject: [PATCH 40/46] Fix config validation and simulator test argument handling Signed-off-by: Sergey Marunich --- pkg/common/config.go | 8 +++++--- pkg/common/config_test.go | 3 ++- pkg/llm-d-inference-sim/simulator_test.go | 13 +++++++------ 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/pkg/common/config.go b/pkg/common/config.go index f8b13722..65d434a2 100644 --- a/pkg/common/config.go +++ b/pkg/common/config.go @@ -41,7 +41,7 @@ const ( FailureTypeServerError = "server_error" FailureTypeInvalidRequest = "invalid_request" FailureTypeModelNotFound = "model_not_found" - dummy = 
"dummy" + dummy = "dummy" ) type Configuration struct { @@ -136,12 +136,12 @@ type Configuration struct { // ZMQMaxConnectAttempts defines the maximum number (10) of retries when ZMQ connection fails ZMQMaxConnectAttempts uint `yaml:"zmq-max-connect-attempts" json:"zmq-max-connect-attempts"` - // EventBatchSize is the maximum number of kv-cache events to be sent together, defaults to 16 + // EventBatchSize is the maximum number of kv-cache events to be sent together, defaults to 16 EventBatchSize int `yaml:"event-batch-size" json:"event-batch-size"` // FakeMetrics is a set of metrics to send to Prometheus instead of the real data FakeMetrics *Metrics `yaml:"fake-metrics" json:"fake-metrics"` - + // FailureInjectionRate is the probability (0-100) of injecting failures FailureInjectionRate int `yaml:"failure-injection-rate"` // FailureTypes is a list of specific failure types to inject (empty means all types) @@ -387,6 +387,8 @@ func (c *Configuration) validate() error { return fmt.Errorf("invalid failure type '%s', valid types are: %s, %s, %s, %s, %s, %s", failureType, FailureTypeRateLimit, FailureTypeInvalidAPIKey, FailureTypeContextLength, FailureTypeServerError, FailureTypeInvalidRequest, FailureTypeModelNotFound) + } + } if c.ZMQMaxConnectAttempts > 10 { return errors.New("zmq retries times cannot be more than 10") diff --git a/pkg/common/config_test.go b/pkg/common/config_test.go index 9f4c3f6a..770716a6 100644 --- a/pkg/common/config_test.go +++ b/pkg/common/config_test.go @@ -382,7 +382,8 @@ var _ = Describe("Simulator configuration", func() { name: "invalid failure type", args: []string{"cmd", "--model", "test-model", "--failure-injection-rate", "50", "--failure-types", "invalid_type"}, - }, + }, + { name: "invalid fake metrics: negative running requests", args: []string{"cmd", "--fake-metrics", "{\"running-requests\":-10,\"waiting-requests\":30,\"kv-cache-usage\":0.4}", "--config", "../../manifests/config.yaml"}, diff --git a/pkg/llm-d-inference-sim/simulator_test.go b/pkg/llm-d-inference-sim/simulator_test.go index 4f018c3d..a412b725 100644 --- a/pkg/llm-d-inference-sim/simulator_test.go +++ b/pkg/llm-d-inference-sim/simulator_test.go @@ -843,7 +843,7 @@ var _ = Describe("Simulator", func() { client, err = startServerWithArgs(ctx, "failure", []string{ "cmd", "--model", model, "--failure-injection-rate", "100", - }) + }, nil) Expect(err).ToNot(HaveOccurred()) }) @@ -902,7 +902,7 @@ var _ = Describe("Simulator", func() { "cmd", "--model", model, "--failure-injection-rate", "100", "--failure-types", common.FailureTypeRateLimit, - }) + }, nil) Expect(err).ToNot(HaveOccurred()) }) @@ -938,7 +938,7 @@ var _ = Describe("Simulator", func() { "cmd", "--model", model, "--failure-injection-rate", "100", "--failure-types", common.FailureTypeInvalidAPIKey, common.FailureTypeServerError, - }) + }, nil) Expect(err).ToNot(HaveOccurred()) }) @@ -977,7 +977,7 @@ var _ = Describe("Simulator", func() { client, err = startServerWithArgs(ctx, "failure", []string{ "cmd", "--model", model, "--failure-injection-rate", "0", - }) + }, nil) Expect(err).ToNot(HaveOccurred()) }) @@ -1009,7 +1009,7 @@ var _ = Describe("Simulator", func() { "cmd", "--model", model, "--failure-injection-rate", "100", "--failure-types", failureType, - }) + }, nil) Expect(err).ToNot(HaveOccurred()) openaiClient := openai.NewClient( @@ -1042,7 +1042,8 @@ var _ = Describe("Simulator", func() { Entry("model_not_found", common.FailureTypeModelNotFound, 404, "invalid_request_error", "model_not_found"), ) }) - + }) + Context("fake 
metrics", func() { It("Should respond with fake metrics to /metrics", func() { ctx := context.TODO() From e0183b748110c1820c33b7ca2e728492183088bc Mon Sep 17 00:00:00 2001 From: Sergey Marunich Date: Mon, 25 Aug 2025 20:53:52 -0400 Subject: [PATCH 41/46] remove duplicate Signed-off-by: Sergey Marunich --- pkg/common/config.go | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/pkg/common/config.go b/pkg/common/config.go index 65d434a2..23b43f80 100644 --- a/pkg/common/config.go +++ b/pkg/common/config.go @@ -514,20 +514,6 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) { } } - if config.HashSeed == "" { - hashSeed := os.Getenv("PYTHONHASHSEED") - if hashSeed != "" { - config.HashSeed = hashSeed - } - } - - if config.HashSeed == "" { - hashSeed := os.Getenv("PYTHONHASHSEED") - if hashSeed != "" { - config.HashSeed = hashSeed - } - } - if err := config.validate(); err != nil { return nil, err } From 178a5948a54206ce8a78912a262a11d8595875be Mon Sep 17 00:00:00 2001 From: Sergey Marunich Date: Mon, 25 Aug 2025 21:12:26 -0400 Subject: [PATCH 42/46] Refactor failure handling to use CompletionError struct Failure handling in the simulator now uses the CompletionError struct from the openai-server-api package, replacing custom error fields with a unified structure. This improves consistency in error responses and simplifies error injection logic. Associated tests and error handling code have been updated to reflect this change. Signed-off-by: Sergey Marunich --- pkg/llm-d-inference-sim/failures.go | 80 +++++++++++++----------- pkg/llm-d-inference-sim/failures_test.go | 55 ++++++---------- pkg/llm-d-inference-sim/simulator.go | 43 +++++++------ pkg/openai-server-api/response.go | 4 +- 4 files changed, 87 insertions(+), 95 deletions(-) diff --git a/pkg/llm-d-inference-sim/failures.go b/pkg/llm-d-inference-sim/failures.go index 7117ea12..1ad7e0ed 100644 --- a/pkg/llm-d-inference-sim/failures.go +++ b/pkg/llm-d-inference-sim/failures.go @@ -20,6 +20,7 @@ import ( "fmt" "github.com/llm-d/llm-d-inference-sim/pkg/common" + openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api" ) const ( @@ -28,61 +29,70 @@ const ( modelNotFoundMessageTemplate = "The model '%s-nonexistent' does not exist" ) -type FailureSpec struct { +type FailureInfo struct { StatusCode int - ErrorType string - ErrorCode string - Message string - Param *string + Error openaiserverapi.CompletionError } -var predefinedFailures = map[string]FailureSpec{ +var predefinedFailures = map[string]FailureInfo{ common.FailureTypeRateLimit: { StatusCode: 429, - ErrorType: "rate_limit_exceeded", - ErrorCode: "rate_limit_exceeded", - Message: rateLimitMessageTemplate, - Param: nil, + Error: openaiserverapi.CompletionError{ + Message: rateLimitMessageTemplate, + Type: "rate_limit_exceeded", + Code: 429, + Param: nil, + }, }, common.FailureTypeInvalidAPIKey: { StatusCode: 401, - ErrorType: "invalid_request_error", - ErrorCode: "invalid_api_key", - Message: "Incorrect API key provided", - Param: nil, + Error: openaiserverapi.CompletionError{ + Message: "Incorrect API key provided", + Type: "invalid_request_error", + Code: 401, + Param: nil, + }, }, common.FailureTypeContextLength: { StatusCode: 400, - ErrorType: "invalid_request_error", - ErrorCode: "context_length_exceeded", - Message: "This model's maximum context length is 4096 tokens. 
However, your messages resulted in 4500 tokens.", - Param: stringPtr("messages"), + Error: openaiserverapi.CompletionError{ + Message: "This model's maximum context length is 4096 tokens. However, your messages resulted in 4500 tokens.", + Type: "invalid_request_error", + Code: 400, + Param: stringPtr("messages"), + }, }, common.FailureTypeServerError: { StatusCode: 503, - ErrorType: "server_error", - ErrorCode: "server_error", - Message: "The server is overloaded or not ready yet.", - Param: nil, + Error: openaiserverapi.CompletionError{ + Message: "The server is overloaded or not ready yet.", + Type: "server_error", + Code: 503, + Param: nil, + }, }, common.FailureTypeInvalidRequest: { StatusCode: 400, - ErrorType: "invalid_request_error", - ErrorCode: "invalid_request_error", - Message: "Invalid request: missing required parameter 'model'.", - Param: stringPtr("model"), + Error: openaiserverapi.CompletionError{ + Message: "Invalid request: missing required parameter 'model'.", + Type: "invalid_request_error", + Code: 400, + Param: stringPtr("model"), + }, }, common.FailureTypeModelNotFound: { StatusCode: 404, - ErrorType: "invalid_request_error", - ErrorCode: "model_not_found", - Message: modelNotFoundMessageTemplate, - Param: stringPtr("model"), + Error: openaiserverapi.CompletionError{ + Message: modelNotFoundMessageTemplate, + Type: "invalid_request_error", + Code: 404, + Param: stringPtr("model"), + }, }, } -// ShouldInjectFailure determines whether to inject a failure based on configuration -func ShouldInjectFailure(config *common.Configuration) bool { +// shouldInjectFailure determines whether to inject a failure based on configuration +func shouldInjectFailure(config *common.Configuration) bool { if config.FailureInjectionRate == 0 { return false } @@ -91,7 +101,7 @@ func ShouldInjectFailure(config *common.Configuration) bool { } // GetRandomFailure returns a random failure from configured types or all types if none specified -func GetRandomFailure(config *common.Configuration) FailureSpec { +func GetRandomFailure(config *common.Configuration) FailureInfo { var availableFailures []string if len(config.FailureTypes) == 0 { // Use all failure types if none specified @@ -113,9 +123,9 @@ func GetRandomFailure(config *common.Configuration) FailureSpec { // Customize message with current model name failure := predefinedFailures[randomType] if randomType == common.FailureTypeRateLimit && config.Model != "" { - failure.Message = fmt.Sprintf(rateLimitMessageTemplate, config.Model) + failure.Error.Message = fmt.Sprintf(rateLimitMessageTemplate, config.Model) } else if randomType == common.FailureTypeModelNotFound && config.Model != "" { - failure.Message = fmt.Sprintf(modelNotFoundMessageTemplate, config.Model) + failure.Error.Message = fmt.Sprintf(modelNotFoundMessageTemplate, config.Model) } return failure diff --git a/pkg/llm-d-inference-sim/failures_test.go b/pkg/llm-d-inference-sim/failures_test.go index 4ff80eed..8f1d176f 100644 --- a/pkg/llm-d-inference-sim/failures_test.go +++ b/pkg/llm-d-inference-sim/failures_test.go @@ -27,24 +27,7 @@ import ( ) var _ = Describe("Failures", func() { - Describe("ShouldInjectFailure", func() { - It("should not inject failure when injection rate is 0", func() { - config := &common.Configuration{ - Mode: common.ModeRandom, - FailureInjectionRate: 0, - } - Expect(llmdinferencesim.ShouldInjectFailure(config)).To(BeFalse()) - }) - - It("should inject failure when injection rate is 100", func() { - config := &common.Configuration{ - Mode: 
common.ModeRandom, - FailureInjectionRate: 100, - } - Expect(llmdinferencesim.ShouldInjectFailure(config)).To(BeTrue()) - }) - - }) + // Note: shouldInjectFailure is now private, so we test it through GetRandomFailure behavior Describe("GetRandomFailure", func() { It("should return a failure from all types when none specified", func() { @@ -54,8 +37,8 @@ var _ = Describe("Failures", func() { } failure := llmdinferencesim.GetRandomFailure(config) Expect(failure.StatusCode).To(BeNumerically(">=", 400)) - Expect(failure.Message).ToNot(BeEmpty()) - Expect(failure.ErrorType).ToNot(BeEmpty()) + Expect(failure.Error.Message).ToNot(BeEmpty()) + Expect(failure.Error.Type).ToNot(BeEmpty()) }) It("should return rate limit failure when specified", func() { @@ -65,9 +48,9 @@ var _ = Describe("Failures", func() { } failure := llmdinferencesim.GetRandomFailure(config) Expect(failure.StatusCode).To(Equal(429)) - Expect(failure.ErrorType).To(Equal("rate_limit_exceeded")) - Expect(failure.ErrorCode).To(Equal("rate_limit_exceeded")) - Expect(strings.Contains(failure.Message, "test-model")).To(BeTrue()) + Expect(failure.Error.Type).To(Equal("rate_limit_exceeded")) + Expect(failure.Error.Code).To(Equal(429)) + Expect(strings.Contains(failure.Error.Message, "test-model")).To(BeTrue()) }) It("should return invalid API key failure when specified", func() { @@ -76,9 +59,9 @@ var _ = Describe("Failures", func() { } failure := llmdinferencesim.GetRandomFailure(config) Expect(failure.StatusCode).To(Equal(401)) - Expect(failure.ErrorType).To(Equal("invalid_request_error")) - Expect(failure.ErrorCode).To(Equal("invalid_api_key")) - Expect(failure.Message).To(Equal("Incorrect API key provided")) + Expect(failure.Error.Type).To(Equal("invalid_request_error")) + Expect(failure.Error.Code).To(Equal(401)) + Expect(failure.Error.Message).To(Equal("Incorrect API key provided")) }) It("should return context length failure when specified", func() { @@ -87,10 +70,10 @@ var _ = Describe("Failures", func() { } failure := llmdinferencesim.GetRandomFailure(config) Expect(failure.StatusCode).To(Equal(400)) - Expect(failure.ErrorType).To(Equal("invalid_request_error")) - Expect(failure.ErrorCode).To(Equal("context_length_exceeded")) - Expect(failure.Param).ToNot(BeNil()) - Expect(*failure.Param).To(Equal("messages")) + Expect(failure.Error.Type).To(Equal("invalid_request_error")) + Expect(failure.Error.Code).To(Equal(400)) + Expect(failure.Error.Param).ToNot(BeNil()) + Expect(*failure.Error.Param).To(Equal("messages")) }) It("should return server error when specified", func() { @@ -99,8 +82,8 @@ var _ = Describe("Failures", func() { } failure := llmdinferencesim.GetRandomFailure(config) Expect(failure.StatusCode).To(Equal(503)) - Expect(failure.ErrorType).To(Equal("server_error")) - Expect(failure.ErrorCode).To(Equal("server_error")) + Expect(failure.Error.Type).To(Equal("server_error")) + Expect(failure.Error.Code).To(Equal(503)) }) It("should return model not found failure when specified", func() { @@ -110,9 +93,9 @@ var _ = Describe("Failures", func() { } failure := llmdinferencesim.GetRandomFailure(config) Expect(failure.StatusCode).To(Equal(404)) - Expect(failure.ErrorType).To(Equal("invalid_request_error")) - Expect(failure.ErrorCode).To(Equal("model_not_found")) - Expect(strings.Contains(failure.Message, "test-model-nonexistent")).To(BeTrue()) + Expect(failure.Error.Type).To(Equal("invalid_request_error")) + Expect(failure.Error.Code).To(Equal(404)) + Expect(strings.Contains(failure.Error.Message, 
"test-model-nonexistent")).To(BeTrue()) }) It("should return server error as fallback for empty types", func() { @@ -122,7 +105,7 @@ var _ = Describe("Failures", func() { // This test is probabilistic since it randomly selects, but we can test structure failure := llmdinferencesim.GetRandomFailure(config) Expect(failure.StatusCode).To(BeNumerically(">=", 400)) - Expect(failure.ErrorType).ToNot(BeEmpty()) + Expect(failure.Error.Type).ToNot(BeEmpty()) }) }) }) diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index 9aff6d07..ba418019 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ -330,7 +330,7 @@ func (s *VllmSimulator) isLora(model string) bool { // handleCompletions general completion requests handler, support both text and chat completion APIs func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatCompletion bool) { // Check if we should inject a failure - if ShouldInjectFailure(s.config) { + if shouldInjectFailure(s.config) { failure := GetRandomFailure(s.config) s.sendCompletionError(ctx, failure, true) return @@ -345,12 +345,14 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple errMsg, errType, errCode := s.validateRequest(vllmReq) if errMsg != "" { - s.sendCompletionError(ctx, FailureSpec{ + s.sendCompletionError(ctx, FailureInfo{ StatusCode: errCode, - ErrorType: errType, - ErrorCode: "", - Message: errMsg, - Param: nil, + Error: openaiserverapi.CompletionError{ + Message: errMsg, + Type: errType, + Code: errCode, + Param: nil, + }, }, false) return } @@ -378,13 +380,15 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple completionTokens := vllmReq.GetMaxCompletionTokens() isValid, actualCompletionTokens, totalTokens := common.ValidateContextWindow(promptTokens, completionTokens, s.config.MaxModelLen) if !isValid { - s.sendCompletionError(ctx, FailureSpec{ + s.sendCompletionError(ctx, FailureInfo{ StatusCode: fasthttp.StatusBadRequest, - ErrorType: "BadRequestError", - ErrorCode: "", - Message: fmt.Sprintf("This model's maximum context length is %d tokens. However, you requested %d tokens (%d in the messages, %d in the completion). Please reduce the length of the messages or completion", - s.config.MaxModelLen, totalTokens, promptTokens, actualCompletionTokens), - Param: nil, + Error: openaiserverapi.CompletionError{ + Message: fmt.Sprintf("This model's maximum context length is %d tokens. However, you requested %d tokens (%d in the messages, %d in the completion). 
Please reduce the length of the messages or completion", + s.config.MaxModelLen, totalTokens, promptTokens, actualCompletionTokens), + Type: "BadRequestError", + Code: fasthttp.StatusBadRequest, + Param: nil, + }, }, false) return } @@ -540,7 +544,7 @@ func (s *VllmSimulator) responseSentCallback(model string) { } // sendCompletionError sends an error response for the current completion request -// The first parameter can be either a string message or a FailureSpec +// The first parameter can be either a string message or a FailureInfo // isInjected indicates if this is an injected failure for logging purposes func (s *VllmSimulator) sendCompletionError(ctx *fasthttp.RequestCtx, errorInfo interface{}, isInjected bool) { var compErr openaiserverapi.CompletionError @@ -552,18 +556,13 @@ func (s *VllmSimulator) sendCompletionError(ctx *fasthttp.RequestCtx, errorInfo compErr = openaiserverapi.CompletionError{ Message: v, Type: "BadRequestError", - Code: "", + Code: 400, Param: nil, } statusCode = fasthttp.StatusBadRequest - case FailureSpec: - // New call with FailureSpec - compErr = openaiserverapi.CompletionError{ - Message: v.Message, - Type: v.ErrorType, - Code: v.ErrorCode, - Param: v.Param, - } + case FailureInfo: + // New call with FailureInfo + compErr = v.Error statusCode = v.StatusCode default: // For calls with msg, errType, and code - need to be updated in calling code diff --git a/pkg/openai-server-api/response.go b/pkg/openai-server-api/response.go index 9e8549b3..30865b57 100644 --- a/pkg/openai-server-api/response.go +++ b/pkg/openai-server-api/response.go @@ -214,8 +214,8 @@ type CompletionError struct { Type string `json:"type"` // Param is the error's parameter Param *string `json:"param"` - // Code is the error code string - Code string `json:"code,omitempty"` + // Code is the error code integer (same as HTTP status code) + Code int `json:"code,omitempty"` } // ErrorResponse wraps the error in the expected OpenAI format From 57657bf166fa70b426355149fa1b36645a7b15e8 Mon Sep 17 00:00:00 2001 From: Ira Rosen Date: Tue, 26 Aug 2025 11:54:50 +0300 Subject: [PATCH 43/46] Use channels for metrics updates, added metrics tests (#171) * Use channels for metrics updates. Metrics tests Signed-off-by: Ira * Review comments Signed-off-by: Ira --------- Signed-off-by: Ira --- pkg/llm-d-inference-sim/metrics.go | 47 ++- pkg/llm-d-inference-sim/metrics_test.go | 276 ++++++++++++++++ pkg/llm-d-inference-sim/simulator.go | 33 +- pkg/llm-d-inference-sim/simulator_test.go | 372 ++++++++++------------ 4 files changed, 514 insertions(+), 214 deletions(-) create mode 100644 pkg/llm-d-inference-sim/metrics_test.go diff --git a/pkg/llm-d-inference-sim/metrics.go b/pkg/llm-d-inference-sim/metrics.go index 3fa60a2e..5b065648 100644 --- a/pkg/llm-d-inference-sim/metrics.go +++ b/pkg/llm-d-inference-sim/metrics.go @@ -19,9 +19,9 @@ limitations under the License. 
package llmdinferencesim import ( + "context" "strconv" "strings" - "sync/atomic" "time" "github.com/prometheus/client_golang/prometheus" @@ -157,9 +157,8 @@ func (s *VllmSimulator) reportRunningRequests() { return } if s.runningRequests != nil { - nRunningReqs := atomic.LoadInt64(&(s.nRunningReqs)) s.runningRequests.WithLabelValues( - s.getDisplayedModelName(s.config.Model)).Set(float64(nRunningReqs)) + s.getDisplayedModelName(s.config.Model)).Set(float64(s.nRunningReqs)) } } @@ -169,8 +168,46 @@ func (s *VllmSimulator) reportWaitingRequests() { return } if s.waitingRequests != nil { - nWaitingReqs := atomic.LoadInt64(&(s.nWaitingReqs)) s.waitingRequests.WithLabelValues( - s.getDisplayedModelName(s.config.Model)).Set(float64(nWaitingReqs)) + s.getDisplayedModelName(s.config.Model)).Set(float64(s.nWaitingReqs)) + } +} + +func (s *VllmSimulator) unregisterPrometheus() { + prometheus.Unregister(s.loraInfo) + prometheus.Unregister(s.runningRequests) + prometheus.Unregister(s.waitingRequests) + prometheus.Unregister(s.kvCacheUsagePercentage) +} + +// startMetricsUpdaters starts the various metrics updaters +func (s *VllmSimulator) startMetricsUpdaters(ctx context.Context) { + go s.waitingRequestsUpdater(ctx) + go s.runningRequestsUpdater(ctx) +} + +// waitingRequestsUpdater updates the waiting requests metric by listening on the relevant channel +func (s *VllmSimulator) waitingRequestsUpdater(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case inc := <-s.waitingReqChan: + s.nWaitingReqs += inc + s.reportWaitingRequests() + } + } +} + +// runningRequestsUpdater updates the running requests metric by listening on the relevant channel +func (s *VllmSimulator) runningRequestsUpdater(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case inc := <-s.runReqChan: + s.nRunningReqs += inc + s.reportRunningRequests() + } } } diff --git a/pkg/llm-d-inference-sim/metrics_test.go b/pkg/llm-d-inference-sim/metrics_test.go new file mode 100644 index 00000000..0d4e1f3c --- /dev/null +++ b/pkg/llm-d-inference-sim/metrics_test.go @@ -0,0 +1,276 @@ +/* +Copyright 2025 The llm-d-inference-sim Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package llmdinferencesim + +import ( + "context" + "io" + "net/http" + "regexp" + "strconv" + "strings" + "sync" + "time" + + "github.com/llm-d/llm-d-inference-sim/pkg/common" + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" + "github.com/openai/openai-go" + "github.com/openai/openai-go/option" +) + +var _ = Describe("Simulator metrics", Ordered, func() { + It("Should send correct running and waiting requests metrics", func() { + modelName := "testmodel" + // Three requests, only two can run in parallel, we expect + // two running requests and one waiting request in the metrics + ctx := context.TODO() + args := []string{"cmd", "--model", modelName, "--mode", common.ModeRandom, + "--time-to-first-token", "3000", "--max-num-seqs", "2"} + + s, client, err := startServerWithArgsAndMetrics(ctx, common.ModeRandom, args, nil, true) + Expect(err).NotTo(HaveOccurred()) + defer s.unregisterPrometheus() + + openaiclient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client)) + + params := openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + Model: modelName, + } + + var wg sync.WaitGroup + wg.Add(1) + + for range 3 { + go func() { + defer GinkgoRecover() + _, err := openaiclient.Chat.Completions.New(ctx, params) + Expect(err).NotTo(HaveOccurred()) + }() + } + + go func() { + defer wg.Done() + defer GinkgoRecover() + + time.Sleep(300 * time.Millisecond) + metricsResp, err := client.Get("http://localhost/metrics") + Expect(err).NotTo(HaveOccurred()) + Expect(metricsResp.StatusCode).To(Equal(http.StatusOK)) + + data, err := io.ReadAll(metricsResp.Body) + Expect(err).NotTo(HaveOccurred()) + metrics := string(data) + Expect(metrics).To(ContainSubstring("vllm:num_requests_running{model_name=\"testmodel\"} 2")) + Expect(metrics).To(ContainSubstring("vllm:num_requests_waiting{model_name=\"testmodel\"} 1")) + }() + + wg.Wait() + }) + + It("Should send correct lora metrics", func() { + ctx := context.TODO() + args := []string{"cmd", "--model", model, "--mode", common.ModeRandom, + "--time-to-first-token", "3000", + "--lora-modules", "{\"name\":\"lora1\",\"path\":\"/path/to/lora1\"}", + "{\"name\":\"lora2\",\"path\":\"/path/to/lora2\"}"} + + s, client, err := startServerWithArgsAndMetrics(ctx, common.ModeRandom, args, nil, true) + Expect(err).NotTo(HaveOccurred()) + defer s.unregisterPrometheus() + + openaiclient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client)) + + params1 := openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + Model: "lora1", + } + + _, err = openaiclient.Chat.Completions.New(ctx, params1) + Expect(err).NotTo(HaveOccurred()) + + params2 := openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + Model: "lora2", + } + + _, err = openaiclient.Chat.Completions.New(ctx, params2) + Expect(err).NotTo(HaveOccurred()) + + metricsResp, err := client.Get("http://localhost/metrics") + Expect(err).NotTo(HaveOccurred()) + Expect(metricsResp.StatusCode).To(Equal(http.StatusOK)) + + data, err := io.ReadAll(metricsResp.Body) + Expect(err).NotTo(HaveOccurred()) + metrics := string(data) + + // We sent two sequentual requests to two different LoRAs, we expect to see (in this order) + // 1. running_lora_adapter = lora1 + // 2. running_lora_adapter = lora2 + // 3. 
running_lora_adapter = {} + lora1 := "vllm:lora_requests_info{max_lora=\"1\",running_lora_adapters=\"lora1\",waiting_lora_adapters=\"\"}" + lora2 := "vllm:lora_requests_info{max_lora=\"1\",running_lora_adapters=\"lora2\",waiting_lora_adapters=\"\"}" + empty := "vllm:lora_requests_info{max_lora=\"1\",running_lora_adapters=\"\",waiting_lora_adapters=\"\"}" + + Expect(metrics).To(ContainSubstring(lora1)) + Expect(metrics).To(ContainSubstring(lora2)) + Expect(metrics).To(ContainSubstring(empty)) + + // Check the order + lora1Timestamp := extractTimestamp(metrics, lora1) + lora2Timestamp := extractTimestamp(metrics, lora2) + noLorasTimestamp := extractTimestamp(metrics, empty) + + Expect(lora1Timestamp < lora2Timestamp).To(BeTrue()) + Expect(lora2Timestamp < noLorasTimestamp).To(BeTrue()) + }) + + It("Should send correct lora metrics for parallel requests", func() { + ctx := context.TODO() + args := []string{"cmd", "--model", model, "--mode", common.ModeRandom, + "--time-to-first-token", "2000", + "--lora-modules", "{\"name\":\"lora1\",\"path\":\"/path/to/lora1\"}", + "{\"name\":\"lora2\",\"path\":\"/path/to/lora2\"}"} + + s, client, err := startServerWithArgsAndMetrics(ctx, common.ModeRandom, args, nil, true) + Expect(err).NotTo(HaveOccurred()) + + defer s.unregisterPrometheus() + + openaiclient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client)) + + params1 := openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + Model: "lora1", + } + + params2 := openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + Model: "lora2", + } + + var wg sync.WaitGroup + wg.Add(1) + + go func() { + time.Sleep(1 * time.Second) + defer wg.Done() + defer GinkgoRecover() + _, err := openaiclient.Chat.Completions.New(ctx, params2) + Expect(err).NotTo(HaveOccurred()) + }() + + _, err = openaiclient.Chat.Completions.New(ctx, params1) + Expect(err).NotTo(HaveOccurred()) + + wg.Wait() + + metricsResp, err := client.Get("http://localhost/metrics") + Expect(err).NotTo(HaveOccurred()) + Expect(metricsResp.StatusCode).To(Equal(http.StatusOK)) + + data, err := io.ReadAll(metricsResp.Body) + Expect(err).NotTo(HaveOccurred()) + metrics := string(data) + + // We sent two parallel requests: first to lora1 and then to lora2 (with a delay), we expect + // to see (in this order) + // 1. running_lora_adapter = lora1 + // 2. running_lora_adapter = lora2,lora1 (the order of LoRAs doesn't matter here) + // 3. running_lora_adapter = lora2 + // 4. 
running_lora_adapter = {} + lora1 := "vllm:lora_requests_info{max_lora=\"1\",running_lora_adapters=\"lora1\",waiting_lora_adapters=\"\"}" + lora12 := "vllm:lora_requests_info{max_lora=\"1\",running_lora_adapters=\"lora1,lora2\",waiting_lora_adapters=\"\"}" + lora21 := "vllm:lora_requests_info{max_lora=\"1\",running_lora_adapters=\"lora2,lora1\",waiting_lora_adapters=\"\"}" + lora2 := "vllm:lora_requests_info{max_lora=\"1\",running_lora_adapters=\"lora2\",waiting_lora_adapters=\"\"}" + empty := "vllm:lora_requests_info{max_lora=\"1\",running_lora_adapters=\"\",waiting_lora_adapters=\"\"}" + + Expect(metrics).To(ContainSubstring(lora1)) + Expect(metrics).To(Or(ContainSubstring(lora12), ContainSubstring(lora21))) + Expect(metrics).To(ContainSubstring(lora2)) + Expect(metrics).To(ContainSubstring(empty)) + + // Check the order + lora1Timestamp := extractTimestamp(metrics, lora1) + lora2Timestamp := extractTimestamp(metrics, lora2) + noLorasTimestamp := extractTimestamp(metrics, empty) + var twoLorasTimestamp float64 + if strings.Contains(metrics, lora12) { + twoLorasTimestamp = extractTimestamp(metrics, lora12) + } else { + twoLorasTimestamp = extractTimestamp(metrics, lora21) + } + Expect(lora1Timestamp < twoLorasTimestamp).To(BeTrue()) + Expect(twoLorasTimestamp < lora2Timestamp).To(BeTrue()) + Expect(lora2Timestamp < noLorasTimestamp).To(BeTrue()) + }) + + Context("fake metrics", func() { + It("Should respond with fake metrics to /metrics", func() { + ctx := context.TODO() + args := []string{"cmd", "--model", model, "--mode", common.ModeRandom, + "--fake-metrics", + "{\"running-requests\":10,\"waiting-requests\":30,\"kv-cache-usage\":0.4,\"loras\":[{\"running\":\"lora4,lora2\",\"waiting\":\"lora3\",\"timestamp\":1257894567},{\"running\":\"lora4,lora3\",\"waiting\":\"\",\"timestamp\":1257894569}]}", + } + + s, client, err := startServerWithArgsAndMetrics(ctx, common.ModeRandom, args, nil, true) + Expect(err).NotTo(HaveOccurred()) + + defer s.unregisterPrometheus() + + resp, err := client.Get("http://localhost/metrics") + Expect(err).NotTo(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(http.StatusOK)) + + data, err := io.ReadAll(resp.Body) + Expect(err).NotTo(HaveOccurred()) + metrics := string(data) + Expect(metrics).To(ContainSubstring("vllm:num_requests_running{model_name=\"my_model\"} 10")) + Expect(metrics).To(ContainSubstring("vllm:num_requests_waiting{model_name=\"my_model\"} 30")) + Expect(metrics).To(ContainSubstring("vllm:gpu_cache_usage_perc{model_name=\"my_model\"} 0.4")) + Expect(metrics).To(ContainSubstring("vllm:lora_requests_info{max_lora=\"1\",running_lora_adapters=\"lora4,lora2\",waiting_lora_adapters=\"lora3\"} 1.257894567e+09")) + Expect(metrics).To(ContainSubstring("vllm:lora_requests_info{max_lora=\"1\",running_lora_adapters=\"lora4,lora3\",waiting_lora_adapters=\"\"} 1.257894569e+09")) + }) + }) +}) + +func extractTimestamp(metrics string, key string) float64 { + re := regexp.MustCompile(key + ` (\S+)`) + result := re.FindStringSubmatch(metrics) + Expect(len(result)).To(BeNumerically(">", 1)) + f, err := strconv.ParseFloat(result[1], 64) + Expect(err).NotTo(HaveOccurred()) + return f +} diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index d9813996..9f56f798 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ -26,7 +26,6 @@ import ( "os" "strings" "sync" - "sync/atomic" "time" "github.com/buaazp/fasthttprouter" @@ -52,6 +51,8 @@ const ( namespaceHeader = "x-inference-namespace" 
podNameEnv = "POD_NAME" podNsEnv = "POD_NAMESPACE" + + maxNumberOfRequests = 1000 ) // VllmSimulator simulates vLLM server supporting OpenAI API @@ -69,8 +70,12 @@ type VllmSimulator struct { waitingLoras sync.Map // nRunningReqs is the number of inference requests that are currently being processed nRunningReqs int64 + // runReqChan is a channel to update nRunningReqs + runReqChan chan int64 // nWaitingReqs is the number of inference requests that are waiting to be processed nWaitingReqs int64 + // waitingReqChan is a channel to update nWaitingReqs + waitingReqChan chan int64 // loraInfo is prometheus gauge loraInfo *prometheus.GaugeVec // runningRequests is prometheus gauge @@ -93,18 +98,20 @@ type VllmSimulator struct { // New creates a new VllmSimulator instance with the given logger func New(logger logr.Logger) (*VllmSimulator, error) { - toolsValidtor, err := openaiserverapi.CreateValidator() + toolsValidator, err := openaiserverapi.CreateValidator() if err != nil { return nil, fmt.Errorf("failed to create tools validator: %s", err) } return &VllmSimulator{ logger: logger, - reqChan: make(chan *openaiserverapi.CompletionReqCtx, 1000), - toolsValidator: toolsValidtor, + reqChan: make(chan *openaiserverapi.CompletionReqCtx, maxNumberOfRequests), + toolsValidator: toolsValidator, kvcacheHelper: nil, // kvcache helper will be created only if required after reading configuration namespace: os.Getenv(podNsEnv), pod: os.Getenv(podNameEnv), + runReqChan: make(chan int64, maxNumberOfRequests), + waitingReqChan: make(chan int64, maxNumberOfRequests), }, nil } @@ -148,6 +155,9 @@ func (s *VllmSimulator) Start(ctx context.Context) error { for i := 1; i <= s.config.MaxNumSeqs; i++ { go s.reqProcessingWorker(ctx, i) } + + s.startMetricsUpdaters(ctx) + listener, err := s.newListener() if err != nil { return err @@ -378,9 +388,8 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple IsChatCompletion: isChatCompletion, Wg: &wg, } + s.waitingReqChan <- 1 s.reqChan <- reqCtx - atomic.StoreInt64(&(s.nWaitingReqs), int64(len(s.reqChan))) - s.reportWaitingRequests() wg.Wait() } @@ -395,8 +404,8 @@ func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) { s.logger.Info("reqProcessingWorker worker exiting: reqChan closed") return } - atomic.StoreInt64(&(s.nWaitingReqs), int64(len(s.reqChan))) - s.reportWaitingRequests() + + s.waitingReqChan <- -1 req := reqCtx.CompletionReq model := req.GetModel() @@ -419,8 +428,8 @@ func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) { // TODO - check if this request went to the waiting queue - add it to waiting map s.reportLoras() } - atomic.AddInt64(&(s.nRunningReqs), 1) - s.reportRunningRequests() + + s.runReqChan <- 1 var responseTokens []string var finishReason string @@ -491,9 +500,7 @@ func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) { // decrease model usage reference number func (s *VllmSimulator) responseSentCallback(model string) { - - atomic.AddInt64(&(s.nRunningReqs), -1) - s.reportRunningRequests() + s.runReqChan <- -1 // Only LoRA models require reference-count handling. 
if !s.isLora(model) { diff --git a/pkg/llm-d-inference-sim/simulator_test.go b/pkg/llm-d-inference-sim/simulator_test.go index 88d87759..2641e5b9 100644 --- a/pkg/llm-d-inference-sim/simulator_test.go +++ b/pkg/llm-d-inference-sim/simulator_test.go @@ -49,10 +49,12 @@ func startServer(ctx context.Context, mode string) (*http.Client, error) { } func startServerWithArgs(ctx context.Context, mode string, args []string, envs map[string]string) (*http.Client, error) { - return startServerWithArgsAndMetrics(ctx, mode, args, envs, false) + _, client, err := startServerWithArgsAndMetrics(ctx, mode, args, envs, false) + return client, err } -func startServerWithArgsAndMetrics(ctx context.Context, mode string, args []string, envs map[string]string, setMetrics bool) (*http.Client, error) { +func startServerWithArgsAndMetrics(ctx context.Context, mode string, args []string, envs map[string]string, + setMetrics bool) (*VllmSimulator, *http.Client, error) { oldArgs := os.Args defer func() { os.Args = oldArgs @@ -82,11 +84,11 @@ func startServerWithArgsAndMetrics(ctx context.Context, mode string, args []stri s, err := New(logger) if err != nil { - return nil, err + return nil, nil, err } config, err := common.ParseCommandParamsAndLoadConfig() if err != nil { - return nil, err + return nil, nil, err } s.config = config @@ -99,7 +101,7 @@ func startServerWithArgsAndMetrics(ctx context.Context, mode string, args []stri if setMetrics { err = s.createAndRegisterPrometheus() if err != nil { - return nil, err + return nil, nil, err } } @@ -112,6 +114,8 @@ func startServerWithArgsAndMetrics(ctx context.Context, mode string, args []stri go s.reqProcessingWorker(ctx, i) } + s.startMetricsUpdaters(ctx) + listener := fasthttputil.NewInmemoryListener() // start the http server @@ -121,7 +125,7 @@ func startServerWithArgsAndMetrics(ctx context.Context, mode string, args []stri } }() - return &http.Client{ + return s, &http.Client{ Transport: &http.Transport{ DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { return listener.Dial() @@ -429,209 +433,211 @@ var _ = Describe("Simulator", func() { Expect(resp.StatusCode).To(Equal(http.StatusOK)) }) - It("Should not include namespace and pod headers in chat completion response when env is not set", func() { - ctx := context.TODO() + Context("namespace and pod headers", func() { + It("Should not include namespace and pod headers in chat completion response when env is not set", func() { + ctx := context.TODO() - client, err := startServer(ctx, common.ModeRandom) - Expect(err).NotTo(HaveOccurred()) + client, err := startServer(ctx, common.ModeRandom) + Expect(err).NotTo(HaveOccurred()) - openaiclient := openai.NewClient( - option.WithBaseURL(baseURL), - option.WithHTTPClient(client)) + openaiclient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client)) - params := openai.ChatCompletionNewParams{ - Messages: []openai.ChatCompletionMessageParamUnion{ - openai.UserMessage(userMessage), - }, - Model: model, - } + params := openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + Model: model, + } - var httpResp *http.Response - resp, err := openaiclient.Chat.Completions.New(ctx, params, option.WithResponseInto(&httpResp)) - Expect(err).NotTo(HaveOccurred()) - Expect(resp).NotTo(BeNil()) + var httpResp *http.Response + resp, err := openaiclient.Chat.Completions.New(ctx, params, option.WithResponseInto(&httpResp)) + Expect(err).NotTo(HaveOccurred()) + 
Expect(resp).NotTo(BeNil()) - // Check for namespace and pod headers - namespaceHeader := httpResp.Header.Get(namespaceHeader) - podHeader := httpResp.Header.Get(podHeader) + // Check for namespace and pod headers + namespaceHeader := httpResp.Header.Get(namespaceHeader) + podHeader := httpResp.Header.Get(podHeader) - Expect(namespaceHeader).To(BeEmpty(), "Expected namespace header not to be present") - Expect(podHeader).To(BeEmpty(), "Expected pod header not to be present") - }) + Expect(namespaceHeader).To(BeEmpty(), "Expected namespace header not to be present") + Expect(podHeader).To(BeEmpty(), "Expected pod header not to be present") + }) - It("Should include namespace and pod headers in chat completion response", func() { - ctx := context.TODO() + It("Should include namespace and pod headers in chat completion response", func() { + ctx := context.TODO() - testNamespace := "test-namespace" - testPod := "test-pod" - envs := map[string]string{ - podNameEnv: testPod, - podNsEnv: testNamespace, - } - client, err := startServerWithArgs(ctx, common.ModeRandom, nil, envs) - Expect(err).NotTo(HaveOccurred()) + testNamespace := "test-namespace" + testPod := "test-pod" + envs := map[string]string{ + podNameEnv: testPod, + podNsEnv: testNamespace, + } + client, err := startServerWithArgs(ctx, common.ModeRandom, nil, envs) + Expect(err).NotTo(HaveOccurred()) - openaiclient := openai.NewClient( - option.WithBaseURL(baseURL), - option.WithHTTPClient(client)) + openaiclient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client)) - params := openai.ChatCompletionNewParams{ - Messages: []openai.ChatCompletionMessageParamUnion{ - openai.UserMessage(userMessage), - }, - Model: model, - } + params := openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + Model: model, + } - var httpResp *http.Response - resp, err := openaiclient.Chat.Completions.New(ctx, params, option.WithResponseInto(&httpResp)) - Expect(err).NotTo(HaveOccurred()) - Expect(resp).NotTo(BeNil()) + var httpResp *http.Response + resp, err := openaiclient.Chat.Completions.New(ctx, params, option.WithResponseInto(&httpResp)) + Expect(err).NotTo(HaveOccurred()) + Expect(resp).NotTo(BeNil()) - // Check for namespace and pod headers - namespaceHeader := httpResp.Header.Get(namespaceHeader) - podHeader := httpResp.Header.Get(podHeader) + // Check for namespace and pod headers + namespaceHeader := httpResp.Header.Get(namespaceHeader) + podHeader := httpResp.Header.Get(podHeader) - Expect(namespaceHeader).To(Equal(testNamespace), "Expected namespace header to be present") - Expect(podHeader).To(Equal(testPod), "Expected pod header to be present") - }) + Expect(namespaceHeader).To(Equal(testNamespace), "Expected namespace header to be present") + Expect(podHeader).To(Equal(testPod), "Expected pod header to be present") + }) - It("Should include namespace and pod headers in chat completion streaming response", func() { - ctx := context.TODO() + It("Should include namespace and pod headers in chat completion streaming response", func() { + ctx := context.TODO() - testNamespace := "stream-test-namespace" - testPod := "stream-test-pod" - envs := map[string]string{ - podNameEnv: testPod, - podNsEnv: testNamespace, - } - client, err := startServerWithArgs(ctx, common.ModeRandom, nil, envs) - Expect(err).NotTo(HaveOccurred()) + testNamespace := "stream-test-namespace" + testPod := "stream-test-pod" + envs := map[string]string{ + podNameEnv: testPod, + 
podNsEnv: testNamespace, + } + client, err := startServerWithArgs(ctx, common.ModeRandom, nil, envs) + Expect(err).NotTo(HaveOccurred()) - openaiclient := openai.NewClient( - option.WithBaseURL(baseURL), - option.WithHTTPClient(client)) + openaiclient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client)) - params := openai.ChatCompletionNewParams{ - Messages: []openai.ChatCompletionMessageParamUnion{ - openai.UserMessage(userMessage), - }, - Model: model, - StreamOptions: openai.ChatCompletionStreamOptionsParam{IncludeUsage: param.NewOpt(true)}, - } + params := openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + Model: model, + StreamOptions: openai.ChatCompletionStreamOptionsParam{IncludeUsage: param.NewOpt(true)}, + } - var httpResp *http.Response - resp, err := openaiclient.Chat.Completions.New(ctx, params, option.WithResponseInto(&httpResp)) - Expect(err).NotTo(HaveOccurred()) - Expect(resp).NotTo(BeNil()) + var httpResp *http.Response + resp, err := openaiclient.Chat.Completions.New(ctx, params, option.WithResponseInto(&httpResp)) + Expect(err).NotTo(HaveOccurred()) + Expect(resp).NotTo(BeNil()) - // Check for namespace and pod headers - namespaceHeader := httpResp.Header.Get(namespaceHeader) - podHeader := httpResp.Header.Get(podHeader) + // Check for namespace and pod headers + namespaceHeader := httpResp.Header.Get(namespaceHeader) + podHeader := httpResp.Header.Get(podHeader) - Expect(namespaceHeader).To(Equal(testNamespace), "Expected namespace header to be present") - Expect(podHeader).To(Equal(testPod), "Expected pod header to be present") - }) + Expect(namespaceHeader).To(Equal(testNamespace), "Expected namespace header to be present") + Expect(podHeader).To(Equal(testPod), "Expected pod header to be present") + }) - It("Should not include namespace and pod headers in chat completion streaming response when env is not set", func() { - ctx := context.TODO() + It("Should not include namespace and pod headers in chat completion streaming response when env is not set", func() { + ctx := context.TODO() - client, err := startServer(ctx, common.ModeRandom) - Expect(err).NotTo(HaveOccurred()) + client, err := startServer(ctx, common.ModeRandom) + Expect(err).NotTo(HaveOccurred()) - openaiclient := openai.NewClient( - option.WithBaseURL(baseURL), - option.WithHTTPClient(client)) + openaiclient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client)) - params := openai.ChatCompletionNewParams{ - Messages: []openai.ChatCompletionMessageParamUnion{ - openai.UserMessage(userMessage), - }, - Model: model, - StreamOptions: openai.ChatCompletionStreamOptionsParam{IncludeUsage: param.NewOpt(true)}, - } + params := openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + Model: model, + StreamOptions: openai.ChatCompletionStreamOptionsParam{IncludeUsage: param.NewOpt(true)}, + } - var httpResp *http.Response - resp, err := openaiclient.Chat.Completions.New(ctx, params, option.WithResponseInto(&httpResp)) - Expect(err).NotTo(HaveOccurred()) - Expect(resp).NotTo(BeNil()) + var httpResp *http.Response + resp, err := openaiclient.Chat.Completions.New(ctx, params, option.WithResponseInto(&httpResp)) + Expect(err).NotTo(HaveOccurred()) + Expect(resp).NotTo(BeNil()) - // Check for namespace and pod headers - namespaceHeader := httpResp.Header.Get(namespaceHeader) - podHeader := 
httpResp.Header.Get(podHeader) + // Check for namespace and pod headers + namespaceHeader := httpResp.Header.Get(namespaceHeader) + podHeader := httpResp.Header.Get(podHeader) - Expect(namespaceHeader).To(BeEmpty(), "Expected namespace header not to be present") - Expect(podHeader).To(BeEmpty(), "Expected pod header not to be present") - }) + Expect(namespaceHeader).To(BeEmpty(), "Expected namespace header not to be present") + Expect(podHeader).To(BeEmpty(), "Expected pod header not to be present") + }) - It("Should include namespace and pod headers in completion response", func() { - ctx := context.TODO() + It("Should include namespace and pod headers in completion response", func() { + ctx := context.TODO() - testNamespace := "test-namespace" - testPod := "test-pod" - envs := map[string]string{ - podNameEnv: testPod, - podNsEnv: testNamespace, - } - client, err := startServerWithArgs(ctx, common.ModeRandom, nil, envs) - Expect(err).NotTo(HaveOccurred()) + testNamespace := "test-namespace" + testPod := "test-pod" + envs := map[string]string{ + podNameEnv: testPod, + podNsEnv: testNamespace, + } + client, err := startServerWithArgs(ctx, common.ModeRandom, nil, envs) + Expect(err).NotTo(HaveOccurred()) - openaiclient := openai.NewClient( - option.WithBaseURL(baseURL), - option.WithHTTPClient(client)) + openaiclient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client)) - params := openai.CompletionNewParams{ - Prompt: openai.CompletionNewParamsPromptUnion{ - OfString: openai.String(userMessage), - }, - Model: openai.CompletionNewParamsModel(model), - } - var httpResp *http.Response - resp, err := openaiclient.Completions.New(ctx, params, option.WithResponseInto(&httpResp)) - Expect(err).NotTo(HaveOccurred()) - Expect(resp).NotTo(BeNil()) + params := openai.CompletionNewParams{ + Prompt: openai.CompletionNewParamsPromptUnion{ + OfString: openai.String(userMessage), + }, + Model: openai.CompletionNewParamsModel(model), + } + var httpResp *http.Response + resp, err := openaiclient.Completions.New(ctx, params, option.WithResponseInto(&httpResp)) + Expect(err).NotTo(HaveOccurred()) + Expect(resp).NotTo(BeNil()) - // Check for namespace and pod headers - namespaceHeader := httpResp.Header.Get(namespaceHeader) - podHeader := httpResp.Header.Get(podHeader) + // Check for namespace and pod headers + namespaceHeader := httpResp.Header.Get(namespaceHeader) + podHeader := httpResp.Header.Get(podHeader) - Expect(namespaceHeader).To(Equal(testNamespace), "Expected namespace header to be present") - Expect(podHeader).To(Equal(testPod), "Expected pod header to be present") - }) + Expect(namespaceHeader).To(Equal(testNamespace), "Expected namespace header to be present") + Expect(podHeader).To(Equal(testPod), "Expected pod header to be present") + }) - It("Should include namespace and pod headers in completion streaming response", func() { - ctx := context.TODO() + It("Should include namespace and pod headers in completion streaming response", func() { + ctx := context.TODO() - testNamespace := "stream-test-namespace" - testPod := "stream-test-pod" - envs := map[string]string{ - podNameEnv: testPod, - podNsEnv: testNamespace, - } - client, err := startServerWithArgs(ctx, common.ModeRandom, nil, envs) - Expect(err).NotTo(HaveOccurred()) + testNamespace := "stream-test-namespace" + testPod := "stream-test-pod" + envs := map[string]string{ + podNameEnv: testPod, + podNsEnv: testNamespace, + } + client, err := startServerWithArgs(ctx, common.ModeRandom, nil, envs) + 
Expect(err).NotTo(HaveOccurred()) - openaiclient := openai.NewClient( - option.WithBaseURL(baseURL), - option.WithHTTPClient(client)) + openaiclient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client)) - params := openai.CompletionNewParams{ - Prompt: openai.CompletionNewParamsPromptUnion{ - OfString: openai.String(userMessage), - }, - Model: openai.CompletionNewParamsModel(model), - StreamOptions: openai.ChatCompletionStreamOptionsParam{IncludeUsage: param.NewOpt(true)}, - } - var httpResp *http.Response - resp, err := openaiclient.Completions.New(ctx, params, option.WithResponseInto(&httpResp)) - Expect(err).NotTo(HaveOccurred()) - Expect(resp).NotTo(BeNil()) + params := openai.CompletionNewParams{ + Prompt: openai.CompletionNewParamsPromptUnion{ + OfString: openai.String(userMessage), + }, + Model: openai.CompletionNewParamsModel(model), + StreamOptions: openai.ChatCompletionStreamOptionsParam{IncludeUsage: param.NewOpt(true)}, + } + var httpResp *http.Response + resp, err := openaiclient.Completions.New(ctx, params, option.WithResponseInto(&httpResp)) + Expect(err).NotTo(HaveOccurred()) + Expect(resp).NotTo(BeNil()) - // Check for namespace and pod headers - namespaceHeader := httpResp.Header.Get(namespaceHeader) - podHeader := httpResp.Header.Get(podHeader) + // Check for namespace and pod headers + namespaceHeader := httpResp.Header.Get(namespaceHeader) + podHeader := httpResp.Header.Get(podHeader) - Expect(namespaceHeader).To(Equal(testNamespace), "Expected namespace header to be present") - Expect(podHeader).To(Equal(testPod), "Expected pod header to be present") + Expect(namespaceHeader).To(Equal(testNamespace), "Expected namespace header to be present") + Expect(podHeader).To(Equal(testPod), "Expected pod header to be present") + }) }) Context("max-model-len context window validation", func() { @@ -824,30 +830,4 @@ var _ = Describe("Simulator", func() { ) }) - Context("fake metrics", func() { - It("Should respond with fake metrics to /metrics", func() { - ctx := context.TODO() - args := []string{"cmd", "--model", model, "--mode", common.ModeRandom, - "--fake-metrics", - "{\"running-requests\":10,\"waiting-requests\":30,\"kv-cache-usage\":0.4,\"loras\":[{\"running\":\"lora4,lora2\",\"waiting\":\"lora3\",\"timestamp\":1257894567},{\"running\":\"lora4,lora3\",\"waiting\":\"\",\"timestamp\":1257894569}]}", - } - - client, err := startServerWithArgsAndMetrics(ctx, common.ModeRandom, args, nil, true) - Expect(err).NotTo(HaveOccurred()) - - resp, err := client.Get("http://localhost/metrics") - Expect(err).NotTo(HaveOccurred()) - Expect(resp.StatusCode).To(Equal(http.StatusOK)) - - data, err := io.ReadAll(resp.Body) - Expect(err).NotTo(HaveOccurred()) - metrics := string(data) - Expect(metrics).To(ContainSubstring("vllm:num_requests_running{model_name=\"my_model\"} 10")) - Expect(metrics).To(ContainSubstring("vllm:num_requests_waiting{model_name=\"my_model\"} 30")) - Expect(metrics).To(ContainSubstring("vllm:gpu_cache_usage_perc{model_name=\"my_model\"} 0.4")) - Expect(metrics).To(ContainSubstring("vllm:lora_requests_info{max_lora=\"1\",running_lora_adapters=\"lora4,lora2\",waiting_lora_adapters=\"lora3\"} 1.257894567e+09")) - Expect(metrics).To(ContainSubstring("vllm:lora_requests_info{max_lora=\"1\",running_lora_adapters=\"lora4,lora3\",waiting_lora_adapters=\"\"} 1.257894569e+09")) - - }) - }) }) From 974b611fc3172279c431b1662a163dc8cce2209f Mon Sep 17 00:00:00 2001 From: Ira Rosen Date: Tue, 26 Aug 2025 12:25:04 +0300 Subject: [PATCH 44/46] Remove 
rerun on comment action (#174) Signed-off-by: Ira --- .github/workflows/re-run-action.yml | 16 ---------------- 1 file changed, 16 deletions(-) delete mode 100644 .github/workflows/re-run-action.yml diff --git a/.github/workflows/re-run-action.yml b/.github/workflows/re-run-action.yml deleted file mode 100644 index cb2914e8..00000000 --- a/.github/workflows/re-run-action.yml +++ /dev/null @@ -1,16 +0,0 @@ -name: Re-Run PR tests - -on: - issue_comment: - types: [created] - -jobs: - rerun_pr_tests: - name: rerun_pr_tests - if: ${{ github.event.issue.pull_request }} - runs-on: ubuntu-20.04 - steps: - - uses: estroz/rerun-actions@main - with: - repo_token: ${{ secrets.GITHUB_TOKEN }} - comment_id: ${{ github.event.comment.id }} From 72dde248acbef934c2f6f01abf1da8fd9a56816e Mon Sep 17 00:00:00 2001 From: Ira Date: Wed, 27 Aug 2025 10:58:06 +0300 Subject: [PATCH 45/46] Use one type for all errors. Map code to type Signed-off-by: Ira --- pkg/llm-d-inference-sim/failures.go | 80 +++++------------------ pkg/llm-d-inference-sim/failures_test.go | 70 ++++++++++---------- pkg/llm-d-inference-sim/simulator.go | 73 +++++---------------- pkg/llm-d-inference-sim/simulator_test.go | 20 +++--- pkg/openai-server-api/response.go | 38 +++++++++++ 5 files changed, 117 insertions(+), 164 deletions(-) diff --git a/pkg/llm-d-inference-sim/failures.go b/pkg/llm-d-inference-sim/failures.go index 1ad7e0ed..69daf36e 100644 --- a/pkg/llm-d-inference-sim/failures.go +++ b/pkg/llm-d-inference-sim/failures.go @@ -29,66 +29,18 @@ const ( modelNotFoundMessageTemplate = "The model '%s-nonexistent' does not exist" ) -type FailureInfo struct { - StatusCode int - Error openaiserverapi.CompletionError -} - -var predefinedFailures = map[string]FailureInfo{ - common.FailureTypeRateLimit: { - StatusCode: 429, - Error: openaiserverapi.CompletionError{ - Message: rateLimitMessageTemplate, - Type: "rate_limit_exceeded", - Code: 429, - Param: nil, - }, - }, - common.FailureTypeInvalidAPIKey: { - StatusCode: 401, - Error: openaiserverapi.CompletionError{ - Message: "Incorrect API key provided", - Type: "invalid_request_error", - Code: 401, - Param: nil, - }, - }, - common.FailureTypeContextLength: { - StatusCode: 400, - Error: openaiserverapi.CompletionError{ - Message: "This model's maximum context length is 4096 tokens. However, your messages resulted in 4500 tokens.", - Type: "invalid_request_error", - Code: 400, - Param: stringPtr("messages"), - }, - }, - common.FailureTypeServerError: { - StatusCode: 503, - Error: openaiserverapi.CompletionError{ - Message: "The server is overloaded or not ready yet.", - Type: "server_error", - Code: 503, - Param: nil, - }, - }, - common.FailureTypeInvalidRequest: { - StatusCode: 400, - Error: openaiserverapi.CompletionError{ - Message: "Invalid request: missing required parameter 'model'.", - Type: "invalid_request_error", - Code: 400, - Param: stringPtr("model"), - }, - }, - common.FailureTypeModelNotFound: { - StatusCode: 404, - Error: openaiserverapi.CompletionError{ - Message: modelNotFoundMessageTemplate, - Type: "invalid_request_error", - Code: 404, - Param: stringPtr("model"), - }, - }, +var predefinedFailures = map[string]openaiserverapi.CompletionError{ + common.FailureTypeRateLimit: openaiserverapi.NewCompletionError(rateLimitMessageTemplate, 429, nil), + common.FailureTypeInvalidAPIKey: openaiserverapi.NewCompletionError("Incorrect API key provided.", 401, nil), + common.FailureTypeContextLength: openaiserverapi.NewCompletionError( + "This model's maximum context length is 4096 tokens. 
However, your messages resulted in 4500 tokens.", + 400, stringPtr("messages")), + common.FailureTypeServerError: openaiserverapi.NewCompletionError( + "The server is overloaded or not ready yet.", 503, nil), + common.FailureTypeInvalidRequest: openaiserverapi.NewCompletionError( + "Invalid request: missing required parameter 'model'.", 400, stringPtr("model")), + common.FailureTypeModelNotFound: openaiserverapi.NewCompletionError(modelNotFoundMessageTemplate, + 404, stringPtr("model")), } // shouldInjectFailure determines whether to inject a failure based on configuration @@ -100,8 +52,8 @@ func shouldInjectFailure(config *common.Configuration) bool { return common.RandomInt(1, 100) <= config.FailureInjectionRate } -// GetRandomFailure returns a random failure from configured types or all types if none specified -func GetRandomFailure(config *common.Configuration) FailureInfo { +// getRandomFailure returns a random failure from configured types or all types if none specified +func getRandomFailure(config *common.Configuration) openaiserverapi.CompletionError { var availableFailures []string if len(config.FailureTypes) == 0 { // Use all failure types if none specified @@ -123,9 +75,9 @@ func GetRandomFailure(config *common.Configuration) FailureInfo { // Customize message with current model name failure := predefinedFailures[randomType] if randomType == common.FailureTypeRateLimit && config.Model != "" { - failure.Error.Message = fmt.Sprintf(rateLimitMessageTemplate, config.Model) + failure.Message = fmt.Sprintf(rateLimitMessageTemplate, config.Model) } else if randomType == common.FailureTypeModelNotFound && config.Model != "" { - failure.Error.Message = fmt.Sprintf(modelNotFoundMessageTemplate, config.Model) + failure.Message = fmt.Sprintf(modelNotFoundMessageTemplate, config.Model) } return failure diff --git a/pkg/llm-d-inference-sim/failures_test.go b/pkg/llm-d-inference-sim/failures_test.go index 8f1d176f..99259819 100644 --- a/pkg/llm-d-inference-sim/failures_test.go +++ b/pkg/llm-d-inference-sim/failures_test.go @@ -14,31 +14,34 @@ See the License for the specific language governing permissions and limitations under the License. */ -package llmdinferencesim_test +package llmdinferencesim import ( "strings" + "time" . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" "github.com/llm-d/llm-d-inference-sim/pkg/common" - llmdinferencesim "github.com/llm-d/llm-d-inference-sim/pkg/llm-d-inference-sim" + openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api" ) var _ = Describe("Failures", func() { - // Note: shouldInjectFailure is now private, so we test it through GetRandomFailure behavior + Describe("getRandomFailure", Ordered, func() { + BeforeAll(func() { + common.InitRandom(time.Now().UnixNano()) + }) - Describe("GetRandomFailure", func() { It("should return a failure from all types when none specified", func() { config := &common.Configuration{ Model: "test-model", FailureTypes: []string{}, } - failure := llmdinferencesim.GetRandomFailure(config) - Expect(failure.StatusCode).To(BeNumerically(">=", 400)) - Expect(failure.Error.Message).ToNot(BeEmpty()) - Expect(failure.Error.Type).ToNot(BeEmpty()) + failure := getRandomFailure(config) + Expect(failure.Code).To(BeNumerically(">=", 400)) + Expect(failure.Message).ToNot(BeEmpty()) + Expect(failure.Type).ToNot(BeEmpty()) }) It("should return rate limit failure when specified", func() { @@ -46,44 +49,40 @@ var _ = Describe("Failures", func() { Model: "test-model", FailureTypes: []string{common.FailureTypeRateLimit}, } - failure := llmdinferencesim.GetRandomFailure(config) - Expect(failure.StatusCode).To(Equal(429)) - Expect(failure.Error.Type).To(Equal("rate_limit_exceeded")) - Expect(failure.Error.Code).To(Equal(429)) - Expect(strings.Contains(failure.Error.Message, "test-model")).To(BeTrue()) + failure := getRandomFailure(config) + Expect(failure.Code).To(Equal(429)) + Expect(failure.Type).To(Equal(openaiserverapi.ErrorCodeToType(429))) + Expect(strings.Contains(failure.Message, "test-model")).To(BeTrue()) }) It("should return invalid API key failure when specified", func() { config := &common.Configuration{ FailureTypes: []string{common.FailureTypeInvalidAPIKey}, } - failure := llmdinferencesim.GetRandomFailure(config) - Expect(failure.StatusCode).To(Equal(401)) - Expect(failure.Error.Type).To(Equal("invalid_request_error")) - Expect(failure.Error.Code).To(Equal(401)) - Expect(failure.Error.Message).To(Equal("Incorrect API key provided")) + failure := getRandomFailure(config) + Expect(failure.Code).To(Equal(401)) + Expect(failure.Type).To(Equal(openaiserverapi.ErrorCodeToType(401))) + Expect(failure.Message).To(Equal("Incorrect API key provided.")) }) It("should return context length failure when specified", func() { config := &common.Configuration{ FailureTypes: []string{common.FailureTypeContextLength}, } - failure := llmdinferencesim.GetRandomFailure(config) - Expect(failure.StatusCode).To(Equal(400)) - Expect(failure.Error.Type).To(Equal("invalid_request_error")) - Expect(failure.Error.Code).To(Equal(400)) - Expect(failure.Error.Param).ToNot(BeNil()) - Expect(*failure.Error.Param).To(Equal("messages")) + failure := getRandomFailure(config) + Expect(failure.Code).To(Equal(400)) + Expect(failure.Type).To(Equal(openaiserverapi.ErrorCodeToType(400))) + Expect(failure.Param).ToNot(BeNil()) + Expect(*failure.Param).To(Equal("messages")) }) It("should return server error when specified", func() { config := &common.Configuration{ FailureTypes: []string{common.FailureTypeServerError}, } - failure := llmdinferencesim.GetRandomFailure(config) - Expect(failure.StatusCode).To(Equal(503)) - Expect(failure.Error.Type).To(Equal("server_error")) - Expect(failure.Error.Code).To(Equal(503)) + failure := getRandomFailure(config) + Expect(failure.Code).To(Equal(503)) + 
Expect(failure.Type).To(Equal(openaiserverapi.ErrorCodeToType(503))) }) It("should return model not found failure when specified", func() { @@ -91,11 +90,10 @@ var _ = Describe("Failures", func() { Model: "test-model", FailureTypes: []string{common.FailureTypeModelNotFound}, } - failure := llmdinferencesim.GetRandomFailure(config) - Expect(failure.StatusCode).To(Equal(404)) - Expect(failure.Error.Type).To(Equal("invalid_request_error")) - Expect(failure.Error.Code).To(Equal(404)) - Expect(strings.Contains(failure.Error.Message, "test-model-nonexistent")).To(BeTrue()) + failure := getRandomFailure(config) + Expect(failure.Code).To(Equal(404)) + Expect(failure.Type).To(Equal(openaiserverapi.ErrorCodeToType(404))) + Expect(strings.Contains(failure.Message, "test-model-nonexistent")).To(BeTrue()) }) It("should return server error as fallback for empty types", func() { @@ -103,9 +101,9 @@ var _ = Describe("Failures", func() { FailureTypes: []string{}, } // This test is probabilistic since it randomly selects, but we can test structure - failure := llmdinferencesim.GetRandomFailure(config) - Expect(failure.StatusCode).To(BeNumerically(">=", 400)) - Expect(failure.Error.Type).ToNot(BeEmpty()) + failure := getRandomFailure(config) + Expect(failure.Code).To(BeNumerically(">=", 400)) + Expect(failure.Type).ToNot(BeEmpty()) }) }) }) diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index ba418019..96f58e52 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ -284,20 +284,20 @@ func (s *VllmSimulator) HandleUnloadLora(ctx *fasthttp.RequestCtx) { s.unloadLora(ctx) } -func (s *VllmSimulator) validateRequest(req openaiserverapi.CompletionRequest) (string, string, int) { +func (s *VllmSimulator) validateRequest(req openaiserverapi.CompletionRequest) (string, int) { if !s.isValidModel(req.GetModel()) { - return fmt.Sprintf("The model `%s` does not exist.", req.GetModel()), "NotFoundError", fasthttp.StatusNotFound + return fmt.Sprintf("The model `%s` does not exist.", req.GetModel()), fasthttp.StatusNotFound } if req.GetMaxCompletionTokens() != nil && *req.GetMaxCompletionTokens() <= 0 { - return "Max completion tokens and max tokens should be positive", "Invalid request", fasthttp.StatusBadRequest + return "Max completion tokens and max tokens should be positive", fasthttp.StatusBadRequest } if req.IsDoRemoteDecode() && req.IsStream() { - return "Prefill does not support streaming", "Invalid request", fasthttp.StatusBadRequest + return "Prefill does not support streaming", fasthttp.StatusBadRequest } - return "", "", fasthttp.StatusOK + return "", fasthttp.StatusOK } // isValidModel checks if the given model is the base model or one of "loaded" LoRAs @@ -331,7 +331,7 @@ func (s *VllmSimulator) isLora(model string) bool { func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatCompletion bool) { // Check if we should inject a failure if shouldInjectFailure(s.config) { - failure := GetRandomFailure(s.config) + failure := getRandomFailure(s.config) s.sendCompletionError(ctx, failure, true) return } @@ -343,17 +343,9 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple return } - errMsg, errType, errCode := s.validateRequest(vllmReq) + errMsg, errCode := s.validateRequest(vllmReq) if errMsg != "" { - s.sendCompletionError(ctx, FailureInfo{ - StatusCode: errCode, - Error: openaiserverapi.CompletionError{ - Message: errMsg, - Type: errType, - Code: errCode, - Param: nil, - }, - }, 
false) + s.sendCompletionError(ctx, openaiserverapi.NewCompletionError(errMsg, errCode, nil), false) return } @@ -380,16 +372,9 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple completionTokens := vllmReq.GetMaxCompletionTokens() isValid, actualCompletionTokens, totalTokens := common.ValidateContextWindow(promptTokens, completionTokens, s.config.MaxModelLen) if !isValid { - s.sendCompletionError(ctx, FailureInfo{ - StatusCode: fasthttp.StatusBadRequest, - Error: openaiserverapi.CompletionError{ - Message: fmt.Sprintf("This model's maximum context length is %d tokens. However, you requested %d tokens (%d in the messages, %d in the completion). Please reduce the length of the messages or completion", - s.config.MaxModelLen, totalTokens, promptTokens, actualCompletionTokens), - Type: "BadRequestError", - Code: fasthttp.StatusBadRequest, - Param: nil, - }, - }, false) + message := fmt.Sprintf("This model's maximum context length is %d tokens. However, you requested %d tokens (%d in the messages, %d in the completion). Please reduce the length of the messages or completion", + s.config.MaxModelLen, totalTokens, promptTokens, actualCompletionTokens) + s.sendCompletionError(ctx, openaiserverapi.NewCompletionError(message, fasthttp.StatusBadRequest, nil), false) return } @@ -544,47 +529,25 @@ func (s *VllmSimulator) responseSentCallback(model string) { } // sendCompletionError sends an error response for the current completion request -// The first parameter can be either a string message or a FailureInfo // isInjected indicates if this is an injected failure for logging purposes -func (s *VllmSimulator) sendCompletionError(ctx *fasthttp.RequestCtx, errorInfo interface{}, isInjected bool) { - var compErr openaiserverapi.CompletionError - var statusCode int - - switch v := errorInfo.(type) { - case string: - // Legacy call with string message (backward compatibility) - compErr = openaiserverapi.CompletionError{ - Message: v, - Type: "BadRequestError", - Code: 400, - Param: nil, - } - statusCode = fasthttp.StatusBadRequest - case FailureInfo: - // New call with FailureInfo - compErr = v.Error - statusCode = v.StatusCode - default: - // For calls with msg, errType, and code - need to be updated in calling code - panic("sendCompletionError called with unexpected type") - } - - errorResp := openaiserverapi.ErrorResponse{ - Error: compErr, - } - +func (s *VllmSimulator) sendCompletionError(ctx *fasthttp.RequestCtx, + compErr openaiserverapi.CompletionError, isInjected bool) { if isInjected { s.logger.Info("Injecting failure", "type", compErr.Type, "message", compErr.Message) } else { s.logger.Error(nil, compErr.Message) } + errorResp := openaiserverapi.ErrorResponse{ + Error: compErr, + } + data, err := json.Marshal(errorResp) if err != nil { ctx.Error(err.Error(), fasthttp.StatusInternalServerError) } else { ctx.SetContentType("application/json") - ctx.SetStatusCode(statusCode) + ctx.SetStatusCode(compErr.Code) ctx.SetBody(data) } } diff --git a/pkg/llm-d-inference-sim/simulator_test.go b/pkg/llm-d-inference-sim/simulator_test.go index a412b725..c388f17a 100644 --- a/pkg/llm-d-inference-sim/simulator_test.go +++ b/pkg/llm-d-inference-sim/simulator_test.go @@ -28,6 +28,7 @@ import ( "time" "github.com/llm-d/llm-d-inference-sim/pkg/common" + openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api" . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" "github.com/openai/openai-go" @@ -925,7 +926,7 @@ var _ = Describe("Simulator", func() { ok := errors.As(err, &openaiError) Expect(ok).To(BeTrue()) Expect(openaiError.StatusCode).To(Equal(429)) - Expect(openaiError.Type).To(Equal("rate_limit_exceeded")) + Expect(openaiError.Type).To(Equal(openaiserverapi.ErrorCodeToType(429))) Expect(strings.Contains(openaiError.Message, model)).To(BeTrue()) }) }) @@ -965,7 +966,8 @@ var _ = Describe("Simulator", func() { // Should only be one of the specified types Expect(openaiError.StatusCode == 401 || openaiError.StatusCode == 503).To(BeTrue()) - Expect(openaiError.Type == "invalid_request_error" || openaiError.Type == "server_error").To(BeTrue()) + Expect(openaiError.Type == openaiserverapi.ErrorCodeToType(401) || + openaiError.Type == openaiserverapi.ErrorCodeToType(503)).To(BeTrue()) } }) }) @@ -1003,7 +1005,7 @@ var _ = Describe("Simulator", func() { Context("testing all predefined failure types", func() { DescribeTable("should return correct error for each failure type", - func(failureType string, expectedStatusCode int, expectedErrorType string, expectedErrorCode string) { + func(failureType string, expectedStatusCode int, expectedErrorType string) { ctx := context.Background() client, err := startServerWithArgs(ctx, "failure", []string{ "cmd", "--model", model, @@ -1034,12 +1036,12 @@ var _ = Describe("Simulator", func() { // Note: OpenAI Go client doesn't directly expose the error code field, // but we can verify via status code and type }, - Entry("rate_limit", common.FailureTypeRateLimit, 429, "rate_limit_exceeded", "rate_limit_exceeded"), - Entry("invalid_api_key", common.FailureTypeInvalidAPIKey, 401, "invalid_request_error", "invalid_api_key"), - Entry("context_length", common.FailureTypeContextLength, 400, "invalid_request_error", "context_length_exceeded"), - Entry("server_error", common.FailureTypeServerError, 503, "server_error", "server_error"), - Entry("invalid_request", common.FailureTypeInvalidRequest, 400, "invalid_request_error", "invalid_request_error"), - Entry("model_not_found", common.FailureTypeModelNotFound, 404, "invalid_request_error", "model_not_found"), + Entry("rate_limit", common.FailureTypeRateLimit, 429, openaiserverapi.ErrorCodeToType(429)), + Entry("invalid_api_key", common.FailureTypeInvalidAPIKey, 401, openaiserverapi.ErrorCodeToType(401)), + Entry("context_length", common.FailureTypeContextLength, 400, openaiserverapi.ErrorCodeToType(400)), + Entry("server_error", common.FailureTypeServerError, 503, openaiserverapi.ErrorCodeToType(503)), + Entry("invalid_request", common.FailureTypeInvalidRequest, 400, openaiserverapi.ErrorCodeToType(400)), + Entry("model_not_found", common.FailureTypeModelNotFound, 404, openaiserverapi.ErrorCodeToType(404)), ) }) }) diff --git a/pkg/openai-server-api/response.go b/pkg/openai-server-api/response.go index 30865b57..d32784e3 100644 --- a/pkg/openai-server-api/response.go +++ b/pkg/openai-server-api/response.go @@ -21,6 +21,8 @@ import ( "encoding/json" "errors" "strings" + + "github.com/valyala/fasthttp" ) // CompletionResponse interface representing both completion response types (text and chat) @@ -218,7 +220,43 @@ type CompletionError struct { Code int `json:"code,omitempty"` } +// NewCompletionError creates a new CompletionError +func NewCompletionError(message string, code int, param *string) CompletionError { + return CompletionError{ + Message: message, + Code: code, + Type: ErrorCodeToType(code), + Param: param, + } +} + // ErrorResponse wraps the 
error in the expected OpenAI format type ErrorResponse struct { Error CompletionError `json:"error"` } + +// ErrorCodeToType maps error code to error type according to https://www.npmjs.com/package/openai +func ErrorCodeToType(code int) string { + errorType := "" + switch code { + case fasthttp.StatusBadRequest: + errorType = "BadRequestError" + case fasthttp.StatusUnauthorized: + errorType = "AuthenticationError" + case fasthttp.StatusForbidden: + errorType = "PermissionDeniedError" + case fasthttp.StatusNotFound: + errorType = "NotFoundError" + case fasthttp.StatusUnprocessableEntity: + errorType = "UnprocessableEntityError" + case fasthttp.StatusTooManyRequests: + errorType = "RateLimitError" + default: + if code >= fasthttp.StatusInternalServerError { + errorType = "InternalServerError" + } else { + errorType = "APIConnectionError" + } + } + return errorType +} From 7994048d566f355b42bb8d34c0d865ba8c5e59fc Mon Sep 17 00:00:00 2001 From: Ira Date: Wed, 27 Aug 2025 13:22:04 +0300 Subject: [PATCH 46/46] Review comments Signed-off-by: Ira --- pkg/common/config.go | 4 +- pkg/llm-d-inference-sim/failures_test.go | 225 ++++++++++++++++++++++ pkg/llm-d-inference-sim/simulator_test.go | 222 --------------------- 3 files changed, 227 insertions(+), 224 deletions(-) diff --git a/pkg/common/config.go b/pkg/common/config.go index 23b43f80..1e8add97 100644 --- a/pkg/common/config.go +++ b/pkg/common/config.go @@ -143,9 +143,9 @@ type Configuration struct { FakeMetrics *Metrics `yaml:"fake-metrics" json:"fake-metrics"` // FailureInjectionRate is the probability (0-100) of injecting failures - FailureInjectionRate int `yaml:"failure-injection-rate"` + FailureInjectionRate int `yaml:"failure-injection-rate" json:"failure-injection-rate"` // FailureTypes is a list of specific failure types to inject (empty means all types) - FailureTypes []string `yaml:"failure-types"` + FailureTypes []string `yaml:"failure-types" json:"failure-types"` } type Metrics struct { diff --git a/pkg/llm-d-inference-sim/failures_test.go b/pkg/llm-d-inference-sim/failures_test.go index 99259819..5ff48034 100644 --- a/pkg/llm-d-inference-sim/failures_test.go +++ b/pkg/llm-d-inference-sim/failures_test.go @@ -17,11 +17,16 @@ limitations under the License. package llmdinferencesim import ( + "context" + "errors" + "net/http" "strings" "time" . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" + "github.com/openai/openai-go" + "github.com/openai/openai-go/option" "github.com/llm-d/llm-d-inference-sim/pkg/common" openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api" @@ -106,4 +111,224 @@ var _ = Describe("Failures", func() { Expect(failure.Type).ToNot(BeEmpty()) }) }) + Describe("Simulator with failure injection", func() { + var ( + client *http.Client + ctx context.Context + ) + + AfterEach(func() { + if ctx != nil { + ctx.Done() + } + }) + + Context("with 100% failure injection rate", func() { + BeforeEach(func() { + ctx = context.Background() + var err error + client, err = startServerWithArgs(ctx, "failure", []string{ + "cmd", "--model", model, + "--failure-injection-rate", "100", + }, nil) + Expect(err).ToNot(HaveOccurred()) + }) + + It("should always return an error response for chat completions", func() { + openaiClient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client), + ) + + _, err := openaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{ + Model: model, + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + }) + + Expect(err).To(HaveOccurred()) + + var openaiError *openai.Error + ok := errors.As(err, &openaiError) + Expect(ok).To(BeTrue()) + Expect(openaiError.StatusCode).To(BeNumerically(">=", 400)) + Expect(openaiError.Type).ToNot(BeEmpty()) + Expect(openaiError.Message).ToNot(BeEmpty()) + }) + + It("should always return an error response for text completions", func() { + openaiClient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client), + ) + + _, err := openaiClient.Completions.New(ctx, openai.CompletionNewParams{ + Model: openai.CompletionNewParamsModel(model), + Prompt: openai.CompletionNewParamsPromptUnion{ + OfString: openai.String(userMessage), + }, + }) + + Expect(err).To(HaveOccurred()) + + var openaiError *openai.Error + ok := errors.As(err, &openaiError) + Expect(ok).To(BeTrue()) + Expect(openaiError.StatusCode).To(BeNumerically(">=", 400)) + Expect(openaiError.Type).ToNot(BeEmpty()) + Expect(openaiError.Message).ToNot(BeEmpty()) + }) + }) + + Context("with specific failure types", func() { + BeforeEach(func() { + ctx = context.Background() + var err error + client, err = startServerWithArgs(ctx, "failure", []string{ + "cmd", "--model", model, + "--failure-injection-rate", "100", + "--failure-types", common.FailureTypeRateLimit, + }, nil) + Expect(err).ToNot(HaveOccurred()) + }) + + It("should return only rate limit errors", func() { + openaiClient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client), + ) + + _, err := openaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{ + Model: model, + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + }) + + Expect(err).To(HaveOccurred()) + + var openaiError *openai.Error + ok := errors.As(err, &openaiError) + Expect(ok).To(BeTrue()) + Expect(openaiError.StatusCode).To(Equal(429)) + Expect(openaiError.Type).To(Equal(openaiserverapi.ErrorCodeToType(429))) + Expect(strings.Contains(openaiError.Message, model)).To(BeTrue()) + }) + }) + + Context("with multiple specific failure types", func() { + BeforeEach(func() { + ctx = context.Background() + var err error + client, err = startServerWithArgs(ctx, "failure", []string{ + "cmd", "--model", model, + "--failure-injection-rate", "100", + "--failure-types", common.FailureTypeInvalidAPIKey, 
common.FailureTypeServerError, + }, nil) + Expect(err).ToNot(HaveOccurred()) + }) + + It("should return only specified error types", func() { + openaiClient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client), + ) + + // Make multiple requests to verify we get the expected error types + for i := 0; i < 10; i++ { + _, err := openaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{ + Model: model, + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + }) + + Expect(err).To(HaveOccurred()) + + var openaiError *openai.Error + ok := errors.As(err, &openaiError) + Expect(ok).To(BeTrue()) + + // Should only be one of the specified types + Expect(openaiError.StatusCode == 401 || openaiError.StatusCode == 503).To(BeTrue()) + Expect(openaiError.Type == openaiserverapi.ErrorCodeToType(401) || + openaiError.Type == openaiserverapi.ErrorCodeToType(503)).To(BeTrue()) + } + }) + }) + + Context("with 0% failure injection rate", func() { + BeforeEach(func() { + ctx = context.Background() + var err error + client, err = startServerWithArgs(ctx, "failure", []string{ + "cmd", "--model", model, + "--failure-injection-rate", "0", + }, nil) + Expect(err).ToNot(HaveOccurred()) + }) + + It("should never return errors and behave like random mode", func() { + openaiClient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client), + ) + + resp, err := openaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{ + Model: model, + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + }) + + Expect(err).ToNot(HaveOccurred()) + Expect(resp.Choices).To(HaveLen(1)) + Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty()) + Expect(resp.Model).To(Equal(model)) + }) + }) + + Context("testing all predefined failure types", func() { + DescribeTable("should return correct error for each failure type", + func(failureType string, expectedStatusCode int, expectedErrorType string) { + ctx := context.Background() + client, err := startServerWithArgs(ctx, "failure", []string{ + "cmd", "--model", model, + "--failure-injection-rate", "100", + "--failure-types", failureType, + }, nil) + Expect(err).ToNot(HaveOccurred()) + + openaiClient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client), + ) + + _, err = openaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{ + Model: model, + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + }) + + Expect(err).To(HaveOccurred()) + + var openaiError *openai.Error + ok := errors.As(err, &openaiError) + Expect(ok).To(BeTrue()) + Expect(openaiError.StatusCode).To(Equal(expectedStatusCode)) + Expect(openaiError.Type).To(Equal(expectedErrorType)) + // Note: OpenAI Go client doesn't directly expose the error code field, + // but we can verify via status code and type + }, + Entry("rate_limit", common.FailureTypeRateLimit, 429, openaiserverapi.ErrorCodeToType(429)), + Entry("invalid_api_key", common.FailureTypeInvalidAPIKey, 401, openaiserverapi.ErrorCodeToType(401)), + Entry("context_length", common.FailureTypeContextLength, 400, openaiserverapi.ErrorCodeToType(400)), + Entry("server_error", common.FailureTypeServerError, 503, openaiserverapi.ErrorCodeToType(503)), + Entry("invalid_request", common.FailureTypeInvalidRequest, 400, openaiserverapi.ErrorCodeToType(400)), + Entry("model_not_found", common.FailureTypeModelNotFound, 404, 
openaiserverapi.ErrorCodeToType(404)), + ) + }) + }) }) diff --git a/pkg/llm-d-inference-sim/simulator_test.go b/pkg/llm-d-inference-sim/simulator_test.go index 971d0bc9..9e4c882b 100644 --- a/pkg/llm-d-inference-sim/simulator_test.go +++ b/pkg/llm-d-inference-sim/simulator_test.go @@ -28,7 +28,6 @@ import ( "time" "github.com/llm-d/llm-d-inference-sim/pkg/common" - openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "github.com/openai/openai-go" @@ -830,225 +829,4 @@ var _ = Describe("Simulator", func() { Entry(nil, 10000, 0, 1000, 0, false), ) }) - - Describe("Failure injection mode", func() { - var ( - client *http.Client - ctx context.Context - ) - - AfterEach(func() { - if ctx != nil { - ctx.Done() - } - }) - - Context("with 100% failure injection rate", func() { - BeforeEach(func() { - ctx = context.Background() - var err error - client, err = startServerWithArgs(ctx, "failure", []string{ - "cmd", "--model", model, - "--failure-injection-rate", "100", - }, nil) - Expect(err).ToNot(HaveOccurred()) - }) - - It("should always return an error response for chat completions", func() { - openaiClient := openai.NewClient( - option.WithBaseURL(baseURL), - option.WithHTTPClient(client), - ) - - _, err := openaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{ - Model: model, - Messages: []openai.ChatCompletionMessageParamUnion{ - openai.UserMessage(userMessage), - }, - }) - - Expect(err).To(HaveOccurred()) - - var openaiError *openai.Error - ok := errors.As(err, &openaiError) - Expect(ok).To(BeTrue()) - Expect(openaiError.StatusCode).To(BeNumerically(">=", 400)) - Expect(openaiError.Type).ToNot(BeEmpty()) - Expect(openaiError.Message).ToNot(BeEmpty()) - }) - - It("should always return an error response for text completions", func() { - openaiClient := openai.NewClient( - option.WithBaseURL(baseURL), - option.WithHTTPClient(client), - ) - - _, err := openaiClient.Completions.New(ctx, openai.CompletionNewParams{ - Model: openai.CompletionNewParamsModel(model), - Prompt: openai.CompletionNewParamsPromptUnion{ - OfString: openai.String(userMessage), - }, - }) - - Expect(err).To(HaveOccurred()) - - var openaiError *openai.Error - ok := errors.As(err, &openaiError) - Expect(ok).To(BeTrue()) - Expect(openaiError.StatusCode).To(BeNumerically(">=", 400)) - Expect(openaiError.Type).ToNot(BeEmpty()) - Expect(openaiError.Message).ToNot(BeEmpty()) - }) - }) - - Context("with specific failure types", func() { - BeforeEach(func() { - ctx = context.Background() - var err error - client, err = startServerWithArgs(ctx, "failure", []string{ - "cmd", "--model", model, - "--failure-injection-rate", "100", - "--failure-types", common.FailureTypeRateLimit, - }, nil) - Expect(err).ToNot(HaveOccurred()) - }) - - It("should return only rate limit errors", func() { - openaiClient := openai.NewClient( - option.WithBaseURL(baseURL), - option.WithHTTPClient(client), - ) - - _, err := openaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{ - Model: model, - Messages: []openai.ChatCompletionMessageParamUnion{ - openai.UserMessage(userMessage), - }, - }) - - Expect(err).To(HaveOccurred()) - - var openaiError *openai.Error - ok := errors.As(err, &openaiError) - Expect(ok).To(BeTrue()) - Expect(openaiError.StatusCode).To(Equal(429)) - Expect(openaiError.Type).To(Equal(openaiserverapi.ErrorCodeToType(429))) - Expect(strings.Contains(openaiError.Message, model)).To(BeTrue()) - }) - }) - - Context("with multiple specific 
failure types", func() { - BeforeEach(func() { - ctx = context.Background() - var err error - client, err = startServerWithArgs(ctx, "failure", []string{ - "cmd", "--model", model, - "--failure-injection-rate", "100", - "--failure-types", common.FailureTypeInvalidAPIKey, common.FailureTypeServerError, - }, nil) - Expect(err).ToNot(HaveOccurred()) - }) - - It("should return only specified error types", func() { - openaiClient := openai.NewClient( - option.WithBaseURL(baseURL), - option.WithHTTPClient(client), - ) - - // Make multiple requests to verify we get the expected error types - for i := 0; i < 10; i++ { - _, err := openaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{ - Model: model, - Messages: []openai.ChatCompletionMessageParamUnion{ - openai.UserMessage(userMessage), - }, - }) - - Expect(err).To(HaveOccurred()) - - var openaiError *openai.Error - ok := errors.As(err, &openaiError) - Expect(ok).To(BeTrue()) - - // Should only be one of the specified types - Expect(openaiError.StatusCode == 401 || openaiError.StatusCode == 503).To(BeTrue()) - Expect(openaiError.Type == openaiserverapi.ErrorCodeToType(401) || - openaiError.Type == openaiserverapi.ErrorCodeToType(503)).To(BeTrue()) - } - }) - }) - - Context("with 0% failure injection rate", func() { - BeforeEach(func() { - ctx = context.Background() - var err error - client, err = startServerWithArgs(ctx, "failure", []string{ - "cmd", "--model", model, - "--failure-injection-rate", "0", - }, nil) - Expect(err).ToNot(HaveOccurred()) - }) - - It("should never return errors and behave like random mode", func() { - openaiClient := openai.NewClient( - option.WithBaseURL(baseURL), - option.WithHTTPClient(client), - ) - - resp, err := openaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{ - Model: model, - Messages: []openai.ChatCompletionMessageParamUnion{ - openai.UserMessage(userMessage), - }, - }) - - Expect(err).ToNot(HaveOccurred()) - Expect(resp.Choices).To(HaveLen(1)) - Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty()) - Expect(resp.Model).To(Equal(model)) - }) - }) - - Context("testing all predefined failure types", func() { - DescribeTable("should return correct error for each failure type", - func(failureType string, expectedStatusCode int, expectedErrorType string) { - ctx := context.Background() - client, err := startServerWithArgs(ctx, "failure", []string{ - "cmd", "--model", model, - "--failure-injection-rate", "100", - "--failure-types", failureType, - }, nil) - Expect(err).ToNot(HaveOccurred()) - - openaiClient := openai.NewClient( - option.WithBaseURL(baseURL), - option.WithHTTPClient(client), - ) - - _, err = openaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{ - Model: model, - Messages: []openai.ChatCompletionMessageParamUnion{ - openai.UserMessage(userMessage), - }, - }) - - Expect(err).To(HaveOccurred()) - - var openaiError *openai.Error - ok := errors.As(err, &openaiError) - Expect(ok).To(BeTrue()) - Expect(openaiError.StatusCode).To(Equal(expectedStatusCode)) - Expect(openaiError.Type).To(Equal(expectedErrorType)) - // Note: OpenAI Go client doesn't directly expose the error code field, - // but we can verify via status code and type - }, - Entry("rate_limit", common.FailureTypeRateLimit, 429, openaiserverapi.ErrorCodeToType(429)), - Entry("invalid_api_key", common.FailureTypeInvalidAPIKey, 401, openaiserverapi.ErrorCodeToType(401)), - Entry("context_length", common.FailureTypeContextLength, 400, openaiserverapi.ErrorCodeToType(400)), - Entry("server_error", 
common.FailureTypeServerError, 503, openaiserverapi.ErrorCodeToType(503)), - Entry("invalid_request", common.FailureTypeInvalidRequest, 400, openaiserverapi.ErrorCodeToType(400)), - Entry("model_not_found", common.FailureTypeModelNotFound, 404, openaiserverapi.ErrorCodeToType(404)), - ) - }) - }) })
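
For reference, the error-type mapping introduced in pkg/openai-server-api/response.go above can be exercised on its own. The following standalone Go sketch (not part of the patch series) mirrors the ErrorCodeToType switch and the CompletionError / ErrorResponse JSON shape; it substitutes net/http status constants for the fasthttp ones used in the patch (the numeric values are identical), and the message text and model name are illustrative only.

package main

import (
	"encoding/json"
	"fmt"
	"net/http"
)

// errorCodeToType mirrors the ErrorCodeToType helper added in response.go:
// an HTTP status code is mapped to the error class name used by the
// OpenAI Node client (https://www.npmjs.com/package/openai).
func errorCodeToType(code int) string {
	switch code {
	case http.StatusBadRequest:
		return "BadRequestError"
	case http.StatusUnauthorized:
		return "AuthenticationError"
	case http.StatusForbidden:
		return "PermissionDeniedError"
	case http.StatusNotFound:
		return "NotFoundError"
	case http.StatusUnprocessableEntity:
		return "UnprocessableEntityError"
	case http.StatusTooManyRequests:
		return "RateLimitError"
	default:
		if code >= http.StatusInternalServerError {
			return "InternalServerError"
		}
		return "APIConnectionError"
	}
}

// completionError and errorResponse mirror the JSON envelope built by
// NewCompletionError and ErrorResponse in the patch (field tags are
// illustrative, not copied verbatim).
type completionError struct {
	Message string  `json:"message"`
	Type    string  `json:"type"`
	Param   *string `json:"param,omitempty"`
	Code    int     `json:"code,omitempty"`
}

type errorResponse struct {
	Error completionError `json:"error"`
}

func main() {
	// Build the kind of envelope the simulator would return for an
	// injected rate-limit failure (hypothetical message text).
	resp := errorResponse{Error: completionError{
		Message: "rate limit exceeded for model my-model",
		Code:    429,
		Type:    errorCodeToType(429), // "RateLimitError"
	}}
	out, _ := json.MarshalIndent(resp, "", "  ")
	fmt.Println(string(out))
}

Deriving the type from the status code in one place is what lets the tests above assert on openaiserverapi.ErrorCodeToType(401), ErrorCodeToType(429), and so on instead of hard-coded strings, which is exactly the substitution made in failures_test.go and simulator_test.go in this series.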