|
1 | 1 | use pyo3::prelude::*; |
2 | | -use pyo3::buffer::PyBuffer; |
3 | | -use pyo3::types::{PyByteArray, PyBytes, PyDict, PyList}; |
| 2 | +use pyo3::types::{PyBytes, PyDict, PyList}; |
4 | 3 | use pyo3::IntoPyObjectExt; |
| 4 | +use std::borrow::Cow; |
5 | 5 |
|
6 | 6 | use json_prob_parser::beam; |
7 | 7 | use json_prob_parser::json::JsonValue; |
@@ -36,57 +36,23 @@ fn json_to_py(py: Python<'_>, v: &JsonValue) -> PyObject { |
36 | 36 |
|
37 | 37 | #[pyfunction] |
38 | 38 | fn strict_loads_py(py: Python<'_>, input: &Bound<'_, PyAny>) -> PyResult<PyObject> { |
39 | | - let parsed = if let Ok(s) = input.extract::<&str>() { |
40 | | - strict::strict_parse(s) |
41 | | - .map_err(|e| pyo3::exceptions::PyValueError::new_err((e.message, e.pos)))? |
42 | | - } else if let Ok(b) = input.downcast::<PyBytes>() { |
43 | | - let s = std::str::from_utf8(b.as_bytes()).map_err(|_| { |
| 39 | + let parsed = if let Ok(s) = input.extract::<Cow<str>>() { |
| 40 | + strict::strict_parse(s.as_ref()) |
| 41 | + } else if let Ok(b) = input.extract::<Cow<[u8]>>() { |
| 42 | + let s = std::str::from_utf8(b.as_ref()).map_err(|_| { |
44 | 43 | pyo3::exceptions::PyValueError::new_err(( |
45 | 44 | "str is not valid UTF-8: surrogates not allowed".to_string(), |
46 | 45 | 0_usize, |
47 | 46 | )) |
48 | 47 | })?; |
49 | 48 | strict::strict_parse(s) |
50 | | - .map_err(|e| pyo3::exceptions::PyValueError::new_err((e.message, e.pos)))? |
51 | | - } else if let Ok(ba) = input.downcast::<PyByteArray>() { |
52 | | - let parsed = { |
53 | | - // SAFETY: We do not call back into Python while using this slice. |
54 | | - let bytes = unsafe { ba.as_bytes() }; |
55 | | - let s = std::str::from_utf8(bytes).map_err(|_| { |
56 | | - pyo3::exceptions::PyValueError::new_err(( |
57 | | - "str is not valid UTF-8: surrogates not allowed".to_string(), |
58 | | - 0_usize, |
59 | | - )) |
60 | | - })?; |
61 | | - strict::strict_parse(s) |
62 | | - }; |
63 | | - parsed.map_err(|e| pyo3::exceptions::PyValueError::new_err((e.message, e.pos)))? |
64 | | - } else if let Ok(buf) = PyBuffer::<u8>::get(input) { |
65 | | - let parsed = { |
66 | | - let cells = buf.as_slice(py).ok_or_else(|| { |
67 | | - pyo3::exceptions::PyValueError::new_err(( |
68 | | - "input buffer must be C-contiguous".to_string(), |
69 | | - 0_usize, |
70 | | - )) |
71 | | - })?; |
72 | | - |
73 | | - // ReadOnlyCell<u8> is repr(transparent) over UnsafeCell<u8>, so this is safe. |
74 | | - let bytes = unsafe { std::slice::from_raw_parts(cells.as_ptr() as *const u8, cells.len()) }; |
75 | | - let s = std::str::from_utf8(bytes).map_err(|_| { |
76 | | - pyo3::exceptions::PyValueError::new_err(( |
77 | | - "str is not valid UTF-8: surrogates not allowed".to_string(), |
78 | | - 0_usize, |
79 | | - )) |
80 | | - })?; |
81 | | - strict::strict_parse(s) |
82 | | - }; |
83 | | - parsed.map_err(|e| pyo3::exceptions::PyValueError::new_err((e.message, e.pos)))? |
84 | 49 | } else { |
85 | 50 | return Err(pyo3::exceptions::PyValueError::new_err(( |
86 | 51 | "input must be bytes, bytearray, memoryview, or str".to_string(), |
87 | 52 | 0_usize, |
88 | 53 | ))); |
89 | | - }; |
| 54 | + } |
| 55 | + .map_err(|e| pyo3::exceptions::PyValueError::new_err((e.message, e.pos)))?; |
90 | 56 |
|
91 | 57 | Ok(json_to_py(py, &parsed)) |
92 | 58 | } |
|
0 commit comments