Skip to content

Commit 59808a4

Browse files
committed
Fix revision dir logic when resolving tokenizer path in launcher
Also move tokenizer path resolution before launching shards so as to fail fast
1 parent 73d6db9 commit 59808a4

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

launcher/src/main.rs

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,11 @@ fn main() -> ExitCode {
9494
// Determine number of shards based on command line arg and env vars
9595
let num_shard = find_num_shards(args.num_shard);
9696

97+
// Resolve fast tokenizer path
98+
let tokenizer_path = resolve_tokenizer_path(
99+
&args.model_name, args.revision.as_deref()
100+
).expect("Could not find tokenizer for model");
101+
97102
// Signal handler
98103
let running = Arc::new(AtomicBool::new(true));
99104
let r = running.clone();
@@ -173,9 +178,6 @@ fn main() -> ExitCode {
173178
return ExitCode::SUCCESS;
174179
}
175180

176-
let tokenizer_path = resolve_tokenizer_path(args.model_name, args.revision)
177-
.expect("Could not find tokenizer for model");
178-
179181
// All shard started
180182
// Start webserver
181183
info!("Starting Router");
@@ -546,7 +548,7 @@ fn shutdown_shards(shutdown: Arc<Mutex<bool>>, shutdown_receiver: &mpsc::Receive
546548
}
547549

548550

549-
fn resolve_tokenizer_path(model_name: String, revision: Option<String>) -> Result<String, io::Error> {
551+
fn resolve_tokenizer_path(model_name: &str, revision: Option<&str>) -> Result<String, io::Error> {
550552
let cache = env::var("TRANSFORMERS_CACHE")
551553
.or_else(|_| env::var("HUGGINGFACE_HUB_CACHE")).ok();
552554
let mut model_dir = cache.as_ref().map(
@@ -558,18 +560,21 @@ fn resolve_tokenizer_path(model_name: String, revision: Option<String>) -> Resul
558560
}
559561
}
560562
if let Some(dir) = model_dir {
561-
let ref_name = revision.unwrap_or("main".into());
562-
let ref_path = dir.join("refs").join(&ref_name);
563-
let ref_contents = fs::read_to_string(ref_path)?;
563+
let revision = revision.unwrap_or("main");
564+
let ref_path = dir.join("refs").join(&revision);
565+
let revision = match ref_path.try_exists()? {
566+
true => fs::read_to_string(ref_path)?,
567+
false => revision.to_string(),
568+
};
564569
let tok_path = dir.join("snapshots")
565-
.join(ref_contents).join("tokenizer.json");
570+
.join(&revision).join("tokenizer.json");
566571
if tok_path.try_exists()? {
567572
Ok(tok_path.to_string_lossy().into())
568573
} else {
569574
Err(io::Error::new(
570575
ErrorKind::NotFound,
571576
format!(
572-
"Tokenizer file not found in local cache for model {model_name}, revision {ref_name}"
577+
"Tokenizer file not found in local cache for model {model_name}, revision {revision}"
573578
)
574579
))
575580
}

0 commit comments

Comments
 (0)