Skip to content

Commit a819128

Browse files
authored
Ctypes __mul__ (RustPython#6305)
1 parent 23ec5a5 commit a819128

File tree

2 files changed

+88
-13
lines changed

2 files changed

+88
-13
lines changed

crates/vm/src/stdlib/ctypes/base.rs

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@ use crate::builtins::PyType;
33
use crate::builtins::{PyBytes, PyFloat, PyInt, PyNone, PyStr, PyTypeRef};
44
use crate::convert::ToPyObject;
55
use crate::function::{Either, OptionalArg};
6+
use crate::protocol::PyNumberMethods;
67
use crate::stdlib::ctypes::_ctypes::new_simple_type;
7-
use crate::types::Constructor;
8+
use crate::types::{AsNumber, Constructor};
89
use crate::{AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine};
910
use crossbeam_utils::atomic::AtomicCell;
1011
use num_traits::ToPrimitive;
@@ -158,9 +159,10 @@ pub struct PyCData {
158159
impl PyCData {}
159160

160161
#[pyclass(module = "_ctypes", name = "PyCSimpleType", base = PyType)]
162+
#[derive(Debug, PyPayload)]
161163
pub struct PyCSimpleType {}
162164

163-
#[pyclass(flags(BASETYPE))]
165+
#[pyclass(flags(BASETYPE), with(AsNumber))]
164166
impl PyCSimpleType {
165167
#[allow(clippy::new_ret_no_self)]
166168
#[pymethod]
@@ -186,6 +188,33 @@ impl PyCSimpleType {
186188

187189
PyCSimpleType::from_param(cls, as_parameter, vm)
188190
}
191+
192+
#[pymethod]
193+
fn __mul__(cls: PyTypeRef, n: isize, vm: &VirtualMachine) -> PyResult {
194+
PyCSimple::repeat(cls, n, vm)
195+
}
196+
}
197+
198+
impl AsNumber for PyCSimpleType {
199+
fn as_number() -> &'static PyNumberMethods {
200+
static AS_NUMBER: PyNumberMethods = PyNumberMethods {
201+
multiply: Some(|a, b, vm| {
202+
// a is a PyCSimpleType instance (type object like c_char)
203+
// b is int (array size)
204+
let cls = a
205+
.downcast_ref::<PyType>()
206+
.ok_or_else(|| vm.new_type_error("expected type".to_owned()))?;
207+
let n = b
208+
.try_index(vm)?
209+
.as_bigint()
210+
.to_isize()
211+
.ok_or_else(|| vm.new_overflow_error("array size too large".to_owned()))?;
212+
PyCSimple::repeat(cls.to_owned(), n, vm)
213+
}),
214+
..PyNumberMethods::NOT_IMPLEMENTED
215+
};
216+
&AS_NUMBER
217+
}
189218
}
190219

191220
#[pyclass(
@@ -215,8 +244,18 @@ impl Constructor for PyCSimple {
215244
let attributes = cls.get_attributes();
216245
let _type_ = attributes
217246
.iter()
218-
.find(|(k, _)| k.to_object().str(vm).unwrap().to_string() == *"_type_")
219-
.unwrap()
247+
.find(|(k, _)| {
248+
k.to_object()
249+
.str(vm)
250+
.map(|s| s.to_string() == "_type_")
251+
.unwrap_or(false)
252+
})
253+
.ok_or_else(|| {
254+
vm.new_type_error(format!(
255+
"cannot create '{}' instances: no _type_ attribute",
256+
cls.name()
257+
))
258+
})?
220259
.1
221260
.str(vm)?
222261
.to_string();
@@ -276,11 +315,6 @@ impl PyCSimple {
276315
}
277316
.to_pyobject(vm))
278317
}
279-
280-
#[pyclassmethod]
281-
fn __mul__(cls: PyTypeRef, n: isize, vm: &VirtualMachine) -> PyResult {
282-
PyCSimple::repeat(cls, n, vm)
283-
}
284318
}
285319

286320
impl PyCSimple {

crates/vm/src/stdlib/ctypes/pointer.rs

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
1+
use crossbeam_utils::atomic::AtomicCell;
2+
use num_traits::ToPrimitive;
13
use rustpython_common::lock::PyRwLock;
24

3-
use crate::builtins::PyType;
5+
use crate::builtins::{PyType, PyTypeRef};
6+
use crate::convert::ToPyObject;
7+
use crate::protocol::PyNumberMethods;
48
use crate::stdlib::ctypes::PyCData;
5-
use crate::{PyObjectRef, PyResult};
9+
use crate::types::AsNumber;
10+
use crate::{PyObjectRef, PyResult, VirtualMachine};
611

712
#[pyclass(name = "PyCPointerType", base = PyType, module = "_ctypes")]
813
#[derive(PyPayload, Debug)]
@@ -11,8 +16,44 @@ pub struct PyCPointerType {
1116
pub(crate) inner: PyCPointer,
1217
}
1318

14-
#[pyclass]
15-
impl PyCPointerType {}
19+
#[pyclass(flags(IMMUTABLETYPE), with(AsNumber))]
20+
impl PyCPointerType {
21+
#[pymethod]
22+
fn __mul__(cls: PyTypeRef, n: isize, vm: &VirtualMachine) -> PyResult {
23+
use super::array::{PyCArray, PyCArrayType};
24+
if n < 0 {
25+
return Err(vm.new_value_error(format!("Array length must be >= 0, not {n}")));
26+
}
27+
Ok(PyCArrayType {
28+
inner: PyCArray {
29+
typ: PyRwLock::new(cls),
30+
length: AtomicCell::new(n as usize),
31+
value: PyRwLock::new(vm.ctx.none()),
32+
},
33+
}
34+
.to_pyobject(vm))
35+
}
36+
}
37+
38+
impl AsNumber for PyCPointerType {
39+
fn as_number() -> &'static PyNumberMethods {
40+
static AS_NUMBER: PyNumberMethods = PyNumberMethods {
41+
multiply: Some(|a, b, vm| {
42+
let cls = a
43+
.downcast_ref::<PyType>()
44+
.ok_or_else(|| vm.new_type_error("expected type".to_owned()))?;
45+
let n = b
46+
.try_index(vm)?
47+
.as_bigint()
48+
.to_isize()
49+
.ok_or_else(|| vm.new_overflow_error("array size too large".to_owned()))?;
50+
PyCPointerType::__mul__(cls.to_owned(), n, vm)
51+
}),
52+
..PyNumberMethods::NOT_IMPLEMENTED
53+
};
54+
&AS_NUMBER
55+
}
56+
}
1657

1758
#[pyclass(
1859
name = "_Pointer",

0 commit comments

Comments
 (0)