Skip to content

Commit

Permalink
identify operators by index instead of function-address
Browse files Browse the repository at this point in the history
  • Loading branch information
bertiqwerty committed Jul 30, 2024
1 parent 3a405fc commit 2069b62
Show file tree
Hide file tree
Showing 6 changed files with 289 additions and 178 deletions.
95 changes: 64 additions & 31 deletions src/expression/deep.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@ use crate::{
},
exerr,
expression::flat::ExprIdxVec,
operators::UnaryOp,
BinOp, ExError, ExResult, Express, FloatOpsFactory, MakeOperators, MatchLiteral, NumberMatcher,
operators::{BinOpWithIdx, OperateBinary, UnaryFuncWithIdx, UnaryOp},
ExError, ExResult, Express, FloatOpsFactory, MakeOperators, MatchLiteral, NumberMatcher,
Operator,
};

#[cfg(feature = "partial")]
use crate::{DiffDataType, Differentiate};

/// Container of binary operators of one expression.
pub type BinOpVec<T> = SmallVec<[BinOp<T>; N_NODES_ON_STACK]>;
pub type BinOpVec<T> = SmallVec<[BinOpWithIdx<T>; N_NODES_ON_STACK]>;

macro_rules! attach_unary_op {
($name:ident) => {
Expand Down Expand Up @@ -54,7 +54,7 @@ mod detail {
use crate::{
data_type::DataType,
definitions::N_BINOPS_OF_DEEPEX_ON_STACK,
operators::{UnaryOp, VecOfUnaryFuncs},
operators::{BinOpWithIdx, UnaryFuncWithIdx, UnaryOp, VecOfUnaryFuncs},
parser::{self, Paren, ParsedToken},
DeepEx, ExError, ExResult, MakeOperators, MatchLiteral,
};
Expand Down Expand Up @@ -176,7 +176,7 @@ mod detail {
/// operator can be a composition of multiple functions.
pub fn process_unary<'a, T, OF, LM>(
token_idx: usize,
unary_op: fn(T) -> T,
unary_op: UnaryFuncWithIdx<T>,
repr: &'a str,
parsed_tokens: &[ParsedToken<'a, T>],
parsed_vars: &[&'a str],
Expand All @@ -191,9 +191,9 @@ mod detail {
let iter_of_uops = iter::once(Ok((repr, unary_op))).chain(
(token_idx + 1..parsed_tokens.len())
.map(|j| match &parsed_tokens[j] {
ParsedToken::Op(op) => {
ParsedToken::Op((op_idx, op)) => {
if op.has_unary() {
Some(op)
Some((op_idx, op))
} else {
None
}
Expand All @@ -203,7 +203,14 @@ mod detail {
.take_while(|op| op.is_some())
.map(|op| {
let op = op.unwrap();
Ok((op.repr(), op.unary()?))
let (op_idx, op) = op;
Ok((
op.repr(),
UnaryFuncWithIdx {
idx: *op_idx,
f: op.unary()?,
},
))
}),
);
let vec_of_uops = iter_of_uops
Expand Down Expand Up @@ -278,7 +285,7 @@ mod detail {
let mut idx_tkn: usize = 0;
while idx_tkn < parsed_tokens.len() {
match &parsed_tokens[idx_tkn] {
ParsedToken::Op(op) => {
ParsedToken::Op((op_idx, op)) => {
if parser::is_operator_binary(
op,
if idx_tkn == 0 {
Expand All @@ -287,13 +294,19 @@ mod detail {
Some(&parsed_tokens[idx_tkn - 1])
},
)? {
bin_ops.push(op.bin()?);
bin_ops.push(BinOpWithIdx {
idx: *op_idx,
op: op.bin()?,
});
reprs_bin_ops.push(op.repr());
idx_tkn += 1;
} else {
let (node, idx_forward) = process_unary(
idx_tkn,
op.unary()?,
UnaryFuncWithIdx {
f: op.unary()?,
idx: *op_idx,
},
op.repr(),
parsed_tokens,
parsed_vars,
Expand Down Expand Up @@ -381,34 +394,45 @@ mod detail {
fn find_op<'a, T: Clone + Debug>(
repr: &'a str,
ops: &[Operator<'a, T>],
) -> Option<Operator<'a, T>> {
ops.iter().find(|op| op.repr() == repr).cloned()
) -> Option<(usize, Operator<'a, T>)> {
ops.iter()
.enumerate()
.find(|(_, op)| op.repr() == repr)
.map(|(i, op)| (i, op.clone()))
}

fn find_bin_op<'a, T: Clone + Debug>(
repr: &'a str,
ops: &[Operator<'a, T>],
) -> ExResult<BinOpsWithReprs<'a, T>> {
let op = find_op(repr, ops).ok_or_else(|| exerr!("did not find operator {}", repr))?;
let (op_idx, op) =
find_op(repr, ops).ok_or_else(|| exerr!("did not find operator {}", repr))?;
Ok(BinOpsWithReprs {
reprs: smallvec::smallvec![op.repr()],
ops: smallvec::smallvec![op.bin()?],
ops: smallvec::smallvec![BinOpWithIdx {
idx: op_idx,
op: op.bin()?
}],
})
}

fn find_unary_op<'a, T: Clone + Debug>(
repr: &'a str,
ops: &[Operator<'a, T>],
) -> ExResult<UnaryOpWithReprs<'a, T>> {
let op = find_op(repr, ops).ok_or_else(|| exerr!("did not find operator {}", repr))?;
let (op_idx, op) =
find_op(repr, ops).ok_or_else(|| exerr!("did not find operator {}", repr))?;
Ok(UnaryOpWithReprs {
reprs: smallvec::smallvec![op.repr()],
op: UnaryOp::from_vec(smallvec::smallvec![op.unary()?]),
op: UnaryOp::from_vec(smallvec::smallvec![UnaryFuncWithIdx {
idx: op_idx,
f: op.unary()?
}]),
})
}

pub fn prioritized_indices<T, OF, LM>(
bin_ops: &[BinOp<T>],
bin_ops: &[BinOpWithIdx<T>],
nodes: &[DeepNode<T, OF, LM>],
) -> ExprIdxVec
where
Expand All @@ -417,13 +441,14 @@ where
LM: MatchLiteral,
<T as FromStr>::Err: Debug,
{
let prio_increase = |bin_op_idx: usize| match (&nodes[bin_op_idx], &nodes[bin_op_idx + 1]) {
(DeepNode::Num(_), DeepNode::Num(_)) if bin_ops[bin_op_idx].is_commutative => {
let prio_inc = 5;
&bin_ops[bin_op_idx].prio * 10 + prio_inc
}
_ => &bin_ops[bin_op_idx].prio * 10,
};
let prio_increase =
|bin_op_node_idx: usize| match (&nodes[bin_op_node_idx], &nodes[bin_op_node_idx + 1]) {
(DeepNode::Num(_), DeepNode::Num(_)) if bin_ops[bin_op_node_idx].op.is_commutative => {
let prio_inc = 5;
&bin_ops[bin_op_node_idx].op.prio * 10 + prio_inc
}
_ => &bin_ops[bin_op_node_idx].op.prio * 10,
};

let mut indices: ExprIdxVec = (0..bin_ops.len()).collect();
indices.sort_by(|i1, i2| {
Expand Down Expand Up @@ -600,7 +625,7 @@ where
.bin_ops
.ops
.iter()
.map(|o| o.prio)
.map(|o| o.op.prio)
.collect::<SmallVec<[i64; N_NODES_ON_STACK]>>();
let mut used_prio_indices = ExprIdxVec::new();

Expand All @@ -614,7 +639,7 @@ where
if let (DeepNode::Num(num_1), DeepNode::Num(num_2)) = (node_1, node_2) {
if !(already_declined[num_idx] || already_declined[num_idx + 1]) {
let bin_op_result =
(self.bin_ops.ops[bin_op_idx].apply)(num_1.clone(), num_2.clone());
self.bin_ops.ops[bin_op_idx].apply(num_1.clone(), num_2.clone());
self.nodes[num_idx] = DeepNode::Num(bin_op_result);
self.nodes.remove(num_idx + 1);
already_declined.remove(num_idx + 1);
Expand Down Expand Up @@ -1387,28 +1412,36 @@ fn test_deep_compile() {
DeepNode::Num(0.5),
DeepNode::Num(1.4),
];
let make_bin_op = |idx: usize| BinOpWithIdx {
idx,
op: ops[idx].bin().unwrap(),
};
let make_unary_op = |idx: usize| UnaryFuncWithIdx {
idx,
f: ops[idx].unary().unwrap(),
};
let bin_ops = BinOpsWithReprs {
reprs: smallvec::smallvec![ops[1].repr(), ops[3].repr()],
ops: smallvec::smallvec![ops[1].bin().unwrap(), ops[3].bin().unwrap()],
ops: smallvec::smallvec![make_bin_op(1), make_bin_op(3)],
};
assert_eq!(bin_ops.reprs[0], "*");
assert_eq!(bin_ops.reprs[1], "+");
let unary_op = UnaryOpWithReprs {
reprs: smallvec::smallvec![ops[8].repr()],
op: UnaryOp::from_vec(smallvec::smallvec![ops[8].unary().unwrap()]),
op: UnaryOp::from_vec(smallvec::smallvec![make_unary_op(8)]),
};
assert_eq!(unary_op.reprs[0], "abs");
let deep_ex = DeepEx::new(nodes, bin_ops, unary_op).unwrap();

let bin_ops = BinOpsWithReprs {
reprs: smallvec::smallvec![ops[1].repr(), ops[3].repr()],
ops: smallvec::smallvec![ops[1].bin().unwrap(), ops[3].bin().unwrap()],
ops: smallvec::smallvec![make_bin_op(1), make_bin_op(3)],
};
assert_eq!(bin_ops.reprs[0], "*");
assert_eq!(bin_ops.reprs[1], "+");
let unary_op = UnaryOpWithReprs {
reprs: smallvec::smallvec![ops[8].repr()],
op: UnaryOp::from_vec(smallvec::smallvec![ops[8].unary().unwrap()]),
op: UnaryOp::from_vec(smallvec::smallvec![make_unary_op(8)]),
};
assert_eq!(unary_op.reprs[0], "abs");
let nodes = vec![
Expand Down
Loading

0 comments on commit 2069b62

Please # to comment.