|
| 1 | +// Base64 encoding module |
| 2 | + |
| 3 | +pub(crate) use _base64::make_module; |
| 4 | + |
/// Byte used to fill out the final 4-character group of the encoded output.
const PAD_BYTE: u8 = b'=';

/// The 64-character encoding alphabet, indexed by 6-bit value:
/// `A`-`Z` (0..=25), `a`-`z` (26..=51), `0`-`9` (52..=61), then `+` and `/`.
const ENCODE_TABLE: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ\
                                  abcdefghijklmnopqrstuvwxyz\
                                  0123456789+/";
| 7 | + |
/// Returns the number of bytes needed to hold the padded encoding of
/// `input_len` input bytes.
///
/// The value is `((input_len + 2) / 3) * 4`: four output bytes for every
/// started group of three input bytes. Returns `None` when either the `+ 2`
/// or the `* 4` step would overflow `usize`.
#[inline]
fn encoded_output_len(input_len: usize) -> Option<usize> {
    // Adding 2 before the integer division rounds the group count up.
    let rounded = input_len.checked_add(2)?;
    let groups = rounded / 3;
    groups.checked_mul(4)
}
| 15 | + |
| 16 | +#[inline] |
| 17 | +fn encode_into(input: &[u8], output: &mut [u8]) -> usize { |
| 18 | + let mut src_index = 0; |
| 19 | + let mut dst_index = 0; |
| 20 | + let len = input.len(); |
| 21 | + |
| 22 | + // Process full 3-byte chunks |
| 23 | + while src_index + 3 <= len { |
| 24 | + let chunk = (u32::from(input[src_index]) << 16) |
| 25 | + | (u32::from(input[src_index + 1]) << 8) |
| 26 | + | u32::from(input[src_index + 2]); |
| 27 | + output[dst_index] = ENCODE_TABLE[((chunk >> 18) & 0x3f) as usize]; |
| 28 | + output[dst_index + 1] = ENCODE_TABLE[((chunk >> 12) & 0x3f) as usize]; |
| 29 | + output[dst_index + 2] = ENCODE_TABLE[((chunk >> 6) & 0x3f) as usize]; |
| 30 | + output[dst_index + 3] = ENCODE_TABLE[(chunk & 0x3f) as usize]; |
| 31 | + src_index += 3; |
| 32 | + dst_index += 4; |
| 33 | + } |
| 34 | + |
| 35 | + // Process remaining bytes (1 or 2 bytes) |
| 36 | + match len - src_index { |
| 37 | + 0 => {} |
| 38 | + 1 => { |
| 39 | + let chunk = u32::from(input[src_index]) << 16; |
| 40 | + output[dst_index] = ENCODE_TABLE[((chunk >> 18) & 0x3f) as usize]; |
| 41 | + output[dst_index + 1] = ENCODE_TABLE[((chunk >> 12) & 0x3f) as usize]; |
| 42 | + output[dst_index + 2] = PAD_BYTE; |
| 43 | + output[dst_index + 3] = PAD_BYTE; |
| 44 | + dst_index += 4; |
| 45 | + } |
| 46 | + 2 => { |
| 47 | + let chunk = (u32::from(input[src_index]) << 16) |
| 48 | + | (u32::from(input[src_index + 1]) << 8); |
| 49 | + output[dst_index] = ENCODE_TABLE[((chunk >> 18) & 0x3f) as usize]; |
| 50 | + output[dst_index + 1] = ENCODE_TABLE[((chunk >> 12) & 0x3f) as usize]; |
| 51 | + output[dst_index + 2] = ENCODE_TABLE[((chunk >> 6) & 0x3f) as usize]; |
| 52 | + output[dst_index + 3] = PAD_BYTE; |
| 53 | + dst_index += 4; |
| 54 | + } |
| 55 | + _ => unreachable!("len - src_index cannot exceed 2"), |
| 56 | + } |
| 57 | + |
| 58 | + dst_index |
| 59 | +} |
| 60 | + |
| 61 | +#[pymodule(name = "_base64")] |
| 62 | +mod _base64 { |
| 63 | + use crate::vm::{PyResult, VirtualMachine, function::ArgBytesLike}; |
| 64 | + |
| 65 | + #[pyfunction] |
| 66 | + fn standard_b64encode(data: ArgBytesLike, vm: &VirtualMachine) -> PyResult<Vec<u8>> { |
| 67 | + data.with_ref(|input| { |
| 68 | + let input_len = input.len(); |
| 69 | + |
| 70 | + let Some(output_len) = super::encoded_output_len(input_len) else { |
| 71 | + return Err(vm.new_memory_error("output length overflow".to_owned())); |
| 72 | + }; |
| 73 | + |
| 74 | + if output_len > isize::MAX as usize { |
| 75 | + return Err(vm.new_memory_error("output too large".to_owned())); |
| 76 | + } |
| 77 | + |
| 78 | + let mut output = vec![0u8; output_len]; |
| 79 | + let written = super::encode_into(input, &mut output); |
| 80 | + debug_assert_eq!(written, output_len); |
| 81 | + |
| 82 | + Ok(output) |
| 83 | + }) |
| 84 | + } |
| 85 | +} |
0 commit comments