Skip to content

Commit f98c231

Browse files
committed
moving Variable from Python to Rust
1 parent 80402f7 commit f98c231

File tree

2 files changed

+81
-46
lines changed

2 files changed

+81
-46
lines changed

notebooks/experiments.py

Lines changed: 12 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,18 @@ def _(tribles):
9898
return (IdOwner,)
9999

100100

101+
@app.cell
102+
def _(tribles):
103+
Variable = tribles.Variable
104+
return (Variable,)
105+
106+
107+
@app.cell
108+
def _(tribles):
109+
VariableContext = tribles.VariableContext
110+
return (VariableContext,)
111+
112+
101113
@app.cell
102114
def _(IdOwner):
103115
owner = IdOwner()
@@ -236,35 +248,6 @@ def _(Id, Namespace, tribles):
236248
return (metadata_ns,)
237249

238250

239-
@app.cell
240-
def _(Variable):
241-
class VariableContext:
242-
def __init__(self):
243-
self.variables = []
244-
245-
def new(self, name=None):
246-
i = len(self.variables)
247-
assert i < 128
248-
v = Variable(i, name)
249-
self.variables.append(v)
250-
return v
251-
252-
def check_schemas(self):
253-
for v in self.variables:
254-
if not v.schema:
255-
if v.name:
256-
name = "'" + v.name + "'"
257-
else:
258-
name = "_"
259-
raise TypeError(
260-
"missing schema for variable "
261-
+ name
262-
+ "/"
263-
+ str(v.index)
264-
)
265-
return (VariableContext,)
266-
267-
268251
@app.cell
269252
def _(VariableContext, tribles):
270253
def find(query):

src/lib.rs

Lines changed: 69 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use std::{
66
};
77

88
use itertools::Itertools;
9-
use parking_lot::{ArcMutexGuard, Mutex, RawMutex};
9+
use parking_lot::{ArcMutexGuard, Mutex, RawMutex, RwLock};
1010
use pyo3::{
1111
exceptions::{PyKeyError, PyRuntimeError, PyValueError},
1212
prelude::*,
@@ -204,6 +204,7 @@ pub fn metadata_description() -> PyTribleSet {
204204
PyTribleSet(Mutex::new(tribles::metadata::metadata::description()))
205205
}
206206

207+
#[derive(Debug, Copy, Clone)]
207208
#[pyclass(frozen, name = "Id")]
208209
pub struct PyId(Id);
209210

@@ -274,7 +275,7 @@ impl PyIdOwner {
274275

275276
pub fn has(&self, v: PyRef<'_, PyVariable>) -> PyConstraint {
276277
PyConstraint {
277-
constraint: Arc::new(self.0.lock_arc().has(Variable::new(v.index))),
278+
constraint: Arc::new(self.0.lock_arc().has(Variable::new(v.0.read().index))),
278279
}
279280
}
280281

@@ -325,7 +326,7 @@ impl PyIdOwnerGuard {
325326
pub fn has(&self, v: PyRef<'_, PyVariable>) -> PyResult<PyConstraint> {
326327
if let Some(guard) = &mut *self.0.lock() {
327328
Ok(PyConstraint {
328-
constraint: Arc::new(guard.has(Variable::new(v.index))),
329+
constraint: Arc::new(guard.has(Variable::new(v.0.read().index))),
329330
})
330331
} else {
331332
Err(PyErr::new::<PyRuntimeError, _>("guard has been released"))
@@ -490,24 +491,25 @@ impl PyTribleSet {
490491
) -> PyConstraint {
491492
PyConstraint {
492493
constraint: Arc::new(self.0.lock().pattern(
493-
Variable::new(ev.index),
494-
Variable::new(av.index),
495-
Variable::<UnknownValue>::new(vv.index),
494+
Variable::new(ev.0.read().index),
495+
Variable::new(av.0.read().index),
496+
Variable::<UnknownValue>::new(vv.0.read().index),
496497
)),
497498
}
498499
}
499500
}
500501

501-
#[pyclass(frozen, name = "Variable")]
502-
pub struct PyVariable {
502+
pub struct InnerVariable {
503503
index: usize,
504+
name: String,
504505

505-
_value_schema: Id,
506+
_value_schema: Option<Id>,
506507
_blob_schema: Option<Id>,
507508
}
508509

509-
#[pymethods]
510-
impl PyVariable {}
510+
#[pyclass(frozen, name = "Variable")]
511+
pub struct PyVariable(RwLock<InnerVariable>);
512+
511513
// class Variable:
512514
// def __init__(self, index, name=None):
513515
// self.index = index
@@ -539,6 +541,51 @@ impl PyVariable {}
539541
// + str(schema)
540542
// )
541543

544+
#[pymethods]
545+
impl PyVariable {
546+
#[new]
547+
pub fn new(index: usize, name: String) -> Self {
548+
PyVariable(RwLock::new(InnerVariable {
549+
index,
550+
name,
551+
_value_schema: None,
552+
_blob_schema: None,
553+
}))
554+
}
555+
556+
#[pyo3(signature = (value_schema, blob_schema=None))]
557+
pub fn annotate_schemas(&self, value_schema: PyId, blob_schema: Option<&PyId>) {
558+
let mut variable = self.0.write();
559+
variable._value_schema = Some(value_schema.0);
560+
variable._blob_schema = blob_schema.map(|id| id.0);
561+
}
562+
}
563+
564+
//class VariableContext:
565+
//def __init__(self):
566+
//self.variables = []
567+
//
568+
//def new(self, name=None):
569+
//i = len(self.variables)
570+
//assert i < 128
571+
//v = Variable(i, name)
572+
//self.variables.append(v)
573+
//return v
574+
//
575+
//def check_schemas(self):
576+
//for v in self.variables:
577+
// if not v.schema:
578+
// if v.name:
579+
// name = "'" + v.name + "'"
580+
// else:
581+
// name = "_"
582+
// raise TypeError(
583+
// "missing schema for variable "
584+
// + name
585+
// + "/"
586+
// + str(v.index)
587+
// )
588+
542589
#[pyclass(frozen, name = "Query")]
543590
pub struct PyQuery {
544591
query: Mutex<
@@ -580,30 +627,35 @@ pub fn intersect(constraints: Vec<Py<PyConstraint>>) -> PyConstraint {
580627

581628
/// Find solutions for the provided constraint.
582629
#[pyfunction]
583-
pub fn solve(projected: Vec<Py<PyVariable>>, constraint: &PyConstraint) -> PyQuery {
630+
pub fn solve(projected: Vec<Py<PyVariable>>, constraint: &PyConstraint) -> PyResult<PyQuery> {
584631
let constraint = constraint.constraint.clone();
585632

586633
let postprocessing = Box::new(move |binding: &Binding| {
587634
let mut vec = vec![];
588635
for v in &projected {
589636
let v = v.get();
590637
let value = *binding
591-
.get(v.index)
638+
.get(v.0.read().index)
592639
.expect("constraint should contain projected variables");
640+
593641
vec.push(PyValue {
594642
value,
595-
_value_schema: v._value_schema,
596-
_blob_schema: v._blob_schema,
643+
_value_schema: v
644+
.0
645+
.read()
646+
._value_schema
647+
.expect("variable with uninitialized value schema"),
648+
_blob_schema: v.0.read()._blob_schema,
597649
});
598650
}
599651
vec
600652
}) as Box<dyn Fn(&Binding) -> Vec<PyValue> + Send>;
601653

602654
let query = tribles::query::Query::new(constraint, postprocessing);
603655

604-
PyQuery {
656+
Ok(PyQuery {
605657
query: Mutex::new(query),
606-
}
658+
})
607659
}
608660

609661
#[pymethods]

0 commit comments

Comments
 (0)