
Commit 4f91fab

language_models: Add xAI support to Zed Cloud provider (#38928)
This PR adds xAI support to the Zed Cloud provider.

Release Notes:

- N/A
1 parent 0e0f48d commit 4f91fab

File tree: 3 files changed (+67, −1 lines)

crates/cloud_llm_client/src/cloud_llm_client.rs

Lines changed: 1 addition & 0 deletions
@@ -144,6 +144,7 @@ pub enum LanguageModelProvider {
     Anthropic,
     OpenAi,
     Google,
+    XAi,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
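For readers skimming the diff: adding a variant to this enum is what forces the exhaustive `match` updates in the other two files. Below is a minimal sketch of that mechanic with an assumed enum shape; the real `LanguageModelProvider` in `cloud_llm_client.rs` carries serde derives and wire-format details not shown in this hunk.

```rust
// Sketch only: a simplified stand-in for the provider enum in cloud_llm_client.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LanguageModelProvider {
    Anthropic,
    OpenAi,
    Google,
    XAi,
}

// Hypothetical helper: because the match has no catch-all arm, forgetting to
// handle `XAi` here would be a compile error, not a silent fallthrough.
fn display_name(provider: LanguageModelProvider) -> &'static str {
    match provider {
        LanguageModelProvider::Anthropic => "Anthropic",
        LanguageModelProvider::OpenAi => "OpenAI",
        LanguageModelProvider::Google => "Google",
        LanguageModelProvider::XAi => "xAI",
    }
}

fn main() {
    println!("{}", display_name(LanguageModelProvider::XAi)); // prints "xAI"
}
```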

crates/language_model/src/language_model.rs

Lines changed: 3 additions & 0 deletions
@@ -50,6 +50,9 @@ pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId
 pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
     LanguageModelProviderName::new("OpenAI");
 
+pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
+pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");
+
 pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
 pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
     LanguageModelProviderName::new("Zed");

crates/language_models/src/provider/cloud.rs

Lines changed: 63 additions & 1 deletion
@@ -46,6 +46,7 @@ use util::{ResultExt as _, maybe};
 use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic};
 use crate::provider::google::{GoogleEventMapper, into_google};
 use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};
+use crate::provider::x_ai::count_xai_tokens;
 
 const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
 const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;
@@ -579,6 +580,7 @@ impl LanguageModel for CloudLanguageModel {
             Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
             OpenAi => language_model::OPEN_AI_PROVIDER_ID,
             Google => language_model::GOOGLE_PROVIDER_ID,
+            XAi => language_model::X_AI_PROVIDER_ID,
         }
     }
 
@@ -588,6 +590,7 @@ impl LanguageModel for CloudLanguageModel {
             Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
             OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
             Google => language_model::GOOGLE_PROVIDER_NAME,
+            XAi => language_model::X_AI_PROVIDER_NAME,
         }
     }
 
@@ -618,7 +621,8 @@ impl LanguageModel for CloudLanguageModel {
     fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
         match self.model.provider {
             cloud_llm_client::LanguageModelProvider::Anthropic
-            | cloud_llm_client::LanguageModelProvider::OpenAi => {
+            | cloud_llm_client::LanguageModelProvider::OpenAi
+            | cloud_llm_client::LanguageModelProvider::XAi => {
                 LanguageModelToolSchemaFormat::JsonSchema
             }
             cloud_llm_client::LanguageModelProvider::Google => {
@@ -648,6 +652,7 @@ impl LanguageModel for CloudLanguageModel {
                 })
             }
             cloud_llm_client::LanguageModelProvider::OpenAi
+            | cloud_llm_client::LanguageModelProvider::XAi
             | cloud_llm_client::LanguageModelProvider::Google => None,
         }
     }
@@ -668,6 +673,13 @@ impl LanguageModel for CloudLanguageModel {
                 };
                 count_open_ai_tokens(request, model, cx)
             }
+            cloud_llm_client::LanguageModelProvider::XAi => {
+                let model = match x_ai::Model::from_id(&self.model.id.0) {
+                    Ok(model) => model,
+                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
+                };
+                count_xai_tokens(request, model, cx)
+            }
             cloud_llm_client::LanguageModelProvider::Google => {
                 let client = self.client.clone();
                 let llm_api_token = self.llm_api_token.clone();
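The token-counting branch above (and the streaming branch in the next hunk) uses the same early-return idiom: when `x_ai::Model::from_id` fails, the error is wrapped in an already-completed boxed future rather than propagated with `?`, because the surrounding method returns a `BoxFuture`, not a `Result`. A minimal sketch of that idiom with simplified stand-in types follows; the real code delegates to `count_xai_tokens`, which is not part of this diff.

```rust
use futures::{future::BoxFuture, FutureExt as _};

// Stand-in for x_ai::Model::from_id: resolve a model id or fail.
fn lookup_model(id: &str) -> Result<String, String> {
    if id.starts_with("grok") {
        Ok(id.to_string())
    } else {
        Err(format!("unknown xAI model: {id}"))
    }
}

// Stand-in for the count-tokens branch: the function returns a future, so a
// lookup failure is returned as a ready, boxed future instead of with `?`.
fn count_tokens(model_id: &str) -> BoxFuture<'static, Result<usize, String>> {
    let model = match lookup_model(model_id) {
        Ok(model) => model,
        Err(err) => return async move { Err(err) }.boxed(),
    };
    // Placeholder estimate; the real branch calls count_xai_tokens(request, model, cx).
    async move { Ok(model.len() * 4) }.boxed()
}

fn main() {
    let result = futures::executor::block_on(count_tokens("grok-4"));
    println!("{result:?}");
}
```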
@@ -845,6 +857,56 @@ impl LanguageModel for CloudLanguageModel {
                 });
                 async move { Ok(future.await?.boxed()) }.boxed()
             }
+            cloud_llm_client::LanguageModelProvider::XAi => {
+                let client = self.client.clone();
+                let model = match x_ai::Model::from_id(&self.model.id.0) {
+                    Ok(model) => model,
+                    Err(err) => return async move { Err(anyhow!(err).into()) }.boxed(),
+                };
+                let request = into_open_ai(
+                    request,
+                    model.id(),
+                    model.supports_parallel_tool_calls(),
+                    model.supports_prompt_cache_key(),
+                    None,
+                    None,
+                );
+                let llm_api_token = self.llm_api_token.clone();
+                let future = self.request_limiter.stream(async move {
+                    let PerformLlmCompletionResponse {
+                        response,
+                        usage,
+                        includes_status_messages,
+                        tool_use_limit_reached,
+                    } = Self::perform_llm_completion(
+                        client.clone(),
+                        llm_api_token,
+                        app_version,
+                        CompletionBody {
+                            thread_id,
+                            prompt_id,
+                            intent,
+                            mode,
+                            provider: cloud_llm_client::LanguageModelProvider::XAi,
+                            model: request.model.clone(),
+                            provider_request: serde_json::to_value(&request)
+                                .map_err(|e| anyhow!(e))?,
+                        },
+                    )
+                    .await?;
+
+                    let mut mapper = OpenAiEventMapper::new();
+                    Ok(map_cloud_completion_events(
+                        Box::pin(
+                            response_lines(response, includes_status_messages)
+                                .chain(usage_updated_event(usage))
+                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
+                        ),
+                        move |event| mapper.map_event(event),
+                    ))
+                });
+                async move { Ok(future.await?.boxed()) }.boxed()
+            }
             cloud_llm_client::LanguageModelProvider::Google => {
                 let client = self.client.clone();
                 let request =
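The streaming branch treats xAI as wire-compatible with OpenAI: it builds the body with `into_open_ai` and maps the response through `OpenAiEventMapper`; only the `provider` field in `CompletionBody` differs. Below is the shape of that event pipeline reduced to a self-contained sketch with stand-in event types; the real code chains `usage_updated_event` and `tool_use_limit_reached_event` onto the response line stream before mapping.

```rust
use futures::{stream, Stream, StreamExt as _};

// Stand-in events; the real pipeline emits cloud-specific completion events.
#[derive(Debug)]
enum Event {
    Text(String),
    UsageUpdated { tokens: u32 },
    ToolUseLimitReached,
}

// Chain status events onto the provider's response stream, mirroring how the
// XAi branch appends usage and tool-use-limit events before mapping.
fn map_completion_events(
    response_lines: impl Stream<Item = String>,
    tokens_used: u32,
    tool_use_limit_reached: bool,
) -> impl Stream<Item = Event> {
    response_lines
        .map(Event::Text)
        .chain(stream::iter(Some(Event::UsageUpdated {
            tokens: tokens_used,
        })))
        .chain(stream::iter(
            tool_use_limit_reached.then_some(Event::ToolUseLimitReached),
        ))
}

fn main() {
    let lines = stream::iter(vec!["Hello".to_string(), " world".to_string()]);
    let events = map_completion_events(lines, 42, false);
    for event in futures::executor::block_on(events.collect::<Vec<_>>()) {
        println!("{event:?}");
    }
}
```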
