Skip to content

Commit d8a4a09

Browse files
authored
Fix #[pyclass(base=...)] notation (RustPython#6242)
1 parent ed9a61d commit d8a4a09

File tree

14 files changed

+251
-224
lines changed

14 files changed

+251
-224
lines changed

derive-impl/src/pyclass.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ fn generate_class_def(
308308
ident: &Ident,
309309
name: &str,
310310
module_name: Option<&str>,
311-
base: Option<String>,
311+
base: Option<syn::Path>,
312312
metaclass: Option<String>,
313313
unhashable: bool,
314314
attrs: &[Attribute],
@@ -358,7 +358,6 @@ fn generate_class_def(
358358
Some(quote! { rustpython_vm::builtins::PyTuple })
359359
} else {
360360
base.as_ref().map(|typ| {
361-
let typ = Ident::new(typ, ident.span());
362361
quote_spanned! { ident.span() => #typ }
363362
})
364363
}
@@ -382,7 +381,6 @@ fn generate_class_def(
382381
});
383382

384383
let base_or_object = if let Some(base) = base {
385-
let base = Ident::new(&base, ident.span());
386384
quote! { #base }
387385
} else {
388386
quote! { ::rustpython_vm::builtins::PyBaseObject }

derive-impl/src/util.rs

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,35 @@ impl ItemMetaInner {
187187
Ok(value)
188188
}
189189

190+
pub fn _optional_path(&self, key: &str) -> Result<Option<syn::Path>> {
191+
let value = if let Some((_, meta)) = self.meta_map.get(key) {
192+
let Meta::NameValue(syn::MetaNameValue { value, .. }) = meta else {
193+
bail_span!(
194+
meta,
195+
"#[{}({} = ...)] must be a name-value pair",
196+
self.meta_name(),
197+
key
198+
)
199+
};
200+
201+
// Try to parse as a Path (identifier or path like Foo or foo::Bar)
202+
match syn::parse2::<syn::Path>(value.to_token_stream()) {
203+
Ok(path) => Some(path),
204+
Err(_) => {
205+
bail_span!(
206+
value,
207+
"#[{}({} = ...)] must be a valid type path (e.g., PyBaseException)",
208+
self.meta_name(),
209+
key
210+
)
211+
}
212+
}
213+
} else {
214+
None
215+
};
216+
Ok(value)
217+
}
218+
190219
pub fn _has_key(&self, key: &str) -> Result<bool> {
191220
Ok(matches!(self.meta_map.get(key), Some((_, _))))
192221
}
@@ -384,8 +413,8 @@ impl ClassItemMeta {
384413
self.inner()._optional_str("ctx")
385414
}
386415

387-
pub fn base(&self) -> Result<Option<String>> {
388-
self.inner()._optional_str("base")
416+
pub fn base(&self) -> Result<Option<syn::Path>> {
417+
self.inner()._optional_path("base")
389418
}
390419

391420
pub fn unhashable(&self) -> Result<bool> {

derive/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ pub fn derive_from_args(input: TokenStream) -> TokenStream {
3434
/// - `IMMUTABLETYPE`: class attributes are immutable.
3535
/// - `with`: which trait implementations are to be included in the python class.
3636
/// ```rust, ignore
37-
/// #[pyclass(module = "my_module", name = "MyClass", base = "BaseClass")]
37+
/// #[pyclass(module = "my_module", name = "MyClass", base = BaseClass)]
3838
/// struct MyStruct {
3939
/// x: i32,
4040
/// }

stdlib/src/ssl.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ mod _ssl {
246246

247247
/// An error occurred in the SSL implementation.
248248
#[pyattr]
249-
#[pyexception(name = "SSLError", base = "PyOSError")]
249+
#[pyexception(name = "SSLError", base = PyOSError)]
250250
#[derive(Debug)]
251251
pub struct PySslError {}
252252

@@ -269,7 +269,7 @@ mod _ssl {
269269

270270
/// A certificate could not be verified.
271271
#[pyattr]
272-
#[pyexception(name = "SSLCertVerificationError", base = "PySslError")]
272+
#[pyexception(name = "SSLCertVerificationError", base = PySslError)]
273273
#[derive(Debug)]
274274
pub struct PySslCertVerificationError {}
275275

@@ -278,7 +278,7 @@ mod _ssl {
278278

279279
/// SSL/TLS session closed cleanly.
280280
#[pyattr]
281-
#[pyexception(name = "SSLZeroReturnError", base = "PySslError")]
281+
#[pyexception(name = "SSLZeroReturnError", base = PySslError)]
282282
#[derive(Debug)]
283283
pub struct PySslZeroReturnError {}
284284

@@ -287,7 +287,7 @@ mod _ssl {
287287

288288
/// Non-blocking SSL socket needs to read more data.
289289
#[pyattr]
290-
#[pyexception(name = "SSLWantReadError", base = "PySslError")]
290+
#[pyexception(name = "SSLWantReadError", base = PySslError)]
291291
#[derive(Debug)]
292292
pub struct PySslWantReadError {}
293293

@@ -296,7 +296,7 @@ mod _ssl {
296296

297297
/// Non-blocking SSL socket needs to write more data.
298298
#[pyattr]
299-
#[pyexception(name = "SSLWantWriteError", base = "PySslError")]
299+
#[pyexception(name = "SSLWantWriteError", base = PySslError)]
300300
#[derive(Debug)]
301301
pub struct PySslWantWriteError {}
302302

@@ -305,7 +305,7 @@ mod _ssl {
305305

306306
/// System error when attempting SSL operation.
307307
#[pyattr]
308-
#[pyexception(name = "SSLSyscallError", base = "PySslError")]
308+
#[pyexception(name = "SSLSyscallError", base = PySslError)]
309309
#[derive(Debug)]
310310
pub struct PySslSyscallError {}
311311

@@ -314,7 +314,7 @@ mod _ssl {
314314

315315
/// SSL/TLS connection terminated abruptly.
316316
#[pyattr]
317-
#[pyexception(name = "SSLEOFError", base = "PySslError")]
317+
#[pyexception(name = "SSLEOFError", base = PySslError)]
318318
#[derive(Debug)]
319319
pub struct PySslEOFError {}
320320

vm/src/builtins/bool.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ impl PyObjectRef {
7777
}
7878
}
7979

80-
#[pyclass(name = "bool", module = false, base = "PyInt")]
80+
#[pyclass(name = "bool", module = false, base = PyInt)]
8181
pub struct PyBool;
8282

8383
impl PyPayload for PyBool {

vm/src/builtins/builtin_func.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ impl Representable for PyNativeFunction {
148148
impl Unconstructible for PyNativeFunction {}
149149

150150
// `PyCMethodObject` in CPython
151-
#[pyclass(name = "builtin_method", module = false, base = "PyNativeFunction")]
151+
#[pyclass(name = "builtin_method", module = false, base = PyNativeFunction)]
152152
pub struct PyNativeMethod {
153153
pub(crate) func: PyNativeFunction,
154154
pub(crate) class: &'static Py<PyType>, // TODO: the actual life is &'self

0 commit comments

Comments
 (0)