|
1 | | -use std::io::{Cursor, Read}; |
2 | | -use std::net::SocketAddr; |
3 | | -use std::sync::mpsc::{sync_channel, Receiver}; |
4 | | -use std::thread; |
| 1 | +use std::sync::mpsc::{sync_channel, Receiver, SyncSender}; |
| 2 | +use std::time::Duration; |
5 | 3 |
|
6 | | -use mime::Mime; |
7 | | -use multipart::server::Multipart; |
8 | | -use tokio::runtime::Builder; |
9 | | -use warp::Filter; |
| 4 | +use axum::extract::{DefaultBodyLimit, Multipart, State}; |
| 5 | +use axum::routing::post; |
| 6 | +use axum::Router; |
| 7 | +use bytes::Bytes; |
10 | 8 |
|
11 | | -fn start_server() -> (u16, Receiver<Option<String>>) { |
| 9 | +#[derive(Debug, PartialEq, Eq)] |
| 10 | +struct Part { |
| 11 | + name: Option<String>, |
| 12 | + file_name: Option<String>, |
| 13 | + content_type: Option<String>, |
| 14 | + data: Bytes, |
| 15 | +} |
| 16 | + |
| 17 | +async fn start_server() -> (u16, Receiver<Vec<Part>>) { |
12 | 18 | let (send, recv) = sync_channel(1); |
13 | | - let rt = Builder::new_multi_thread().enable_io().enable_time().build().unwrap(); |
14 | | - // ported from warp::multipart, which has a length limit (and we're generic over Read) |
15 | | - let filter = warp::path("multipart") |
16 | | - .and( |
17 | | - warp::header::<Mime>("content-type") |
18 | | - .and_then(|ct: Mime| async move { |
19 | | - ct.get_param("boundary") |
20 | | - .map(|mime| mime.to_string()) |
21 | | - .ok_or_else(warp::reject::reject) |
22 | | - }) |
23 | | - .and(warp::body::bytes()) |
24 | | - .map(|boundary, bytes| Multipart::with_body(Cursor::new(bytes), boundary)), |
25 | | - ) |
26 | | - .map(move |mut form: Multipart<_>| { |
27 | | - let mut found_text = false; |
28 | | - let mut found_file = false; |
29 | | - let mut err = false; |
30 | | - let mut buf = String::new(); |
31 | | - form.foreach_entry(|mut entry| { |
32 | | - if err { |
33 | | - return; |
34 | | - } |
35 | | - entry.data.read_to_string(&mut buf).unwrap(); |
36 | | - if !found_text && &*entry.headers.name == "Hello" && buf == "world!" { |
37 | | - found_text = true; |
38 | | - } else if !found_file |
39 | | - && &*entry.headers.name == "file" |
40 | | - && entry.headers.filename.as_deref() == Some("hello.txt") |
41 | | - && entry.headers.content_type.as_ref().map(|x| x.as_ref() == "text/plain") == Some(true) |
42 | | - && buf == "Hello, world!" |
43 | | - { |
44 | | - found_file = true; |
45 | | - } else { |
46 | | - send.send(Some(format!("Unexpected entry {:?} = {:?}", entry.headers, buf))) |
47 | | - .unwrap(); |
48 | | - err = true; |
49 | | - } |
50 | | - buf.clear(); |
51 | | - }) |
52 | | - .unwrap(); |
53 | | - if err { |
54 | | - return "ERR"; |
55 | | - } |
56 | | - send.send(Some( |
57 | | - match (found_text, found_file) { |
58 | | - (false, false) => "Missing both fields!", |
59 | | - (true, false) => "Missing file field!", |
60 | | - (false, true) => "Missing text field!", |
61 | | - (true, true) => { |
62 | | - send.send(None).unwrap(); |
63 | | - return "OK"; |
64 | | - } |
65 | | - } |
66 | | - .to_string(), |
67 | | - )) |
68 | | - .unwrap(); |
69 | | - "ERR" |
70 | | - }); |
71 | | - let (addr, fut) = |
72 | | - rt.block_on(async { warp::serve(filter).bind_ephemeral("0.0.0.0:0".parse::<SocketAddr>().unwrap()) }); |
73 | | - let port = addr.port(); |
74 | | - thread::spawn(move || { |
75 | | - rt.block_on(fut); |
| 19 | + |
| 20 | + async fn accept_form(State(send): State<SyncSender<Vec<Part>>>, mut multipart: Multipart) -> &'static str { |
| 21 | + let mut parts = Vec::new(); |
| 22 | + while let Some(field) = multipart.next_field().await.unwrap() { |
| 23 | + parts.push(Part { |
| 24 | + name: field.name().map(|s| s.to_string()), |
| 25 | + file_name: field.file_name().map(|s| s.to_string()), |
| 26 | + content_type: field.content_type().map(|s| s.to_string()), |
| 27 | + data: field.bytes().await.unwrap(), |
| 28 | + }); |
| 29 | + } |
| 30 | + send.send(parts).unwrap(); |
| 31 | + "OK" |
| 32 | + } |
| 33 | + |
| 34 | + let app = Router::new() |
| 35 | + .route("/multipart", post(accept_form)) |
| 36 | + .layer(DefaultBodyLimit::disable()) |
| 37 | + .with_state(send); |
| 38 | + |
| 39 | + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); |
| 40 | + let port = listener.local_addr().unwrap().port(); |
| 41 | + tokio::spawn(async move { |
| 42 | + axum::serve(listener, app).await.unwrap(); |
76 | 43 | }); |
77 | 44 | (port, recv) |
78 | 45 | } |
79 | 46 |
|
80 | | -#[test] |
81 | | -fn test_multipart_default() -> attohttpc::Result<()> { |
82 | | - let file = attohttpc::MultipartFile::new("file", b"Hello, world!") |
| 47 | +#[tokio::test(flavor = "multi_thread")] |
| 48 | +async fn test_multipart_default() -> attohttpc::Result<()> { |
| 49 | + let file = attohttpc::MultipartFile::new("file", b"abc123") |
83 | 50 | .with_type("text/plain")? |
84 | 51 | .with_filename("hello.txt"); |
85 | 52 | let form = attohttpc::MultipartBuilder::new() |
86 | 53 | .with_text("Hello", "world!") |
87 | 54 | .with_file(file) |
88 | 55 | .build()?; |
89 | 56 |
|
90 | | - let (port, recv) = start_server(); |
| 57 | + let (port, recv) = start_server().await; |
91 | 58 |
|
92 | 59 | attohttpc::post(format!("http://localhost:{port}/multipart")) |
93 | 60 | .body(form) |
94 | 61 | .send()? |
95 | 62 | .text()?; |
96 | 63 |
|
97 | | - if let Some(err) = recv.recv().unwrap() { |
98 | | - panic!("{}", err); |
99 | | - } |
| 64 | + let parts = recv.recv_timeout(Duration::from_secs(5)).unwrap(); |
| 65 | + assert_eq!(parts.len(), 2); |
| 66 | + assert_eq!( |
| 67 | + parts, |
| 68 | + vec![ |
| 69 | + Part { |
| 70 | + name: Some("Hello".to_string()), |
| 71 | + file_name: None, |
| 72 | + content_type: None, |
| 73 | + data: Bytes::from(&b"world!"[..]) |
| 74 | + }, |
| 75 | + Part { |
| 76 | + name: Some("file".to_string()), |
| 77 | + file_name: Some("hello.txt".to_string()), |
| 78 | + content_type: Some("text/plain".to_string()), |
| 79 | + data: Bytes::from(&b"abc123"[..]) |
| 80 | + } |
| 81 | + ] |
| 82 | + ); |
100 | 83 |
|
101 | 84 | Ok(()) |
102 | 85 | } |
0 commit comments