fix: refactor code, add test cases

This commit is contained in:
Charles
2024-11-26 22:08:06 -08:00
parent 5a34fe3d93
commit e246a313dd
5 changed files with 298 additions and 248 deletions
+54
View File
@@ -0,0 +1,54 @@
use crate::Roller;
#[derive(Debug)]
pub struct Buckets<const S: usize> {
_buckets: [u64; S],
_offsets: [i64; S],
}
impl<const S: usize> Buckets<S> {
/// Creates a new bucketer to store the values that can
/// be produced by roller in constant-spaced bucket. That
/// is, each bucket represents the count of a fixed range,
/// and all buckets have the same size range.
pub fn new<const R: usize>(roller: &Roller<R>) -> Self {
let min = roller.min();
let max = roller.max();
// Divide the number of buckets we have to work with by the range
// Store the step size
let step = ((max - min) as usize) / S + 1;
let mut _offsets = [0; S];
let mut cur = min;
for val in _offsets.iter_mut() {
*val = cur;
cur += step as i64;
}
Self {
_buckets: [0; S],
_offsets,
}
}
pub fn insert(&mut self, val: i64) {
// Binary search to find insertion point
let mut pp = self._offsets.partition_point(|p| *p < val);
if pp == self._buckets.len() {
// This value was beyond the last bucket; go ahead and store if in the
// last bucket
pp = self.buckets().len() - 1;
}
self._buckets[pp] += 1;
}
pub fn buckets(&self) -> &[u64] {
&self._buckets[..]
}
pub fn labels(&self) -> &[i64] {
&self._offsets[..]
}
pub fn max(&self) -> u64 {
*self._buckets.iter().max().unwrap()
}
}
+10 -247
View File
@@ -1,245 +1,8 @@
use nom::{
branch::alt,
character::complete::{self, multispace0},
combinator::opt,
IResult,
};
use rand::Rng;
mod bucket;
mod roller;
#[derive(Debug)]
pub struct Buckets<const S: usize> {
_buckets: [u64; S],
_offsets: [i64; S],
}
impl<const S: usize> Buckets<S> {
/// Creates a new bucketer to store the values that can
/// be produced by roller in constant-spaced bucket. That
/// is, each bucket represents the count of a fixed range,
/// and all buckets have the same size range.
pub fn new<const R: usize>(roller: &Roller<R>) -> Self {
let min = roller.min();
let max = roller.max();
// Divide the number of buckets we have to work with by the range
// Store the step size
let step = ((max - min) as usize) / S + 1;
let mut _offsets = [0; S];
let mut cur = min;
for val in _offsets.iter_mut() {
*val = cur;
cur += step as i64;
}
Self {
_buckets: [0; S],
_offsets,
}
}
pub fn insert(&mut self, val: i64) {
// Binary search to find insertion point
let mut pp = self._offsets.partition_point(|p| *p < val);
if pp == self._buckets.len() {
// This value was beyond the last bucket; go ahead and store if in the
// last bucket
pp = self.buckets().len() - 1;
}
self._buckets[pp] += 1;
}
pub fn buckets(&self) -> &[u64] {
&self._buckets[..]
}
pub fn labels(&self) -> &[i64] {
&self._offsets[..]
}
pub fn max(&self) -> u64 {
*self._buckets.iter().max().unwrap()
}
}
pub struct Roller<const S: usize> {
exprs: [Option<Cmd>; S],
}
impl<const S: usize> Roller<S> {
/// parse converts the str, of form:
/// 2d8+1d8+4
/// into a parsed expression.
pub fn parse(mut expr: &str) -> Result<Roller<S>, nom::Err<nom::error::Error<&str>>> {
let mut op = Oper::Add;
let mut exprs = [const { None }; S];
let mut i = 0;
while expr.len() > 0 {
let (e, term) = term(expr)?;
expr = e;
exprs[i] = Some(Cmd { term, oper: op });
i += 1;
if i == exprs.len() {
return Err(nom::Err::Incomplete(nom::Needed::new(S + 1)));
}
// Ignore trailing whitespace
let (e, _) = multispace0(expr)?;
expr = e;
// Get the next oper
if e.len() > 0 {
let (e, _op) = oper(expr)?;
op = _op;
expr = e;
}
}
Ok(Roller { exprs })
}
pub fn roll<R: Rng>(&self, rng: &mut R) -> i64 {
let mut sum = 0;
for expr in &self.exprs {
if expr.is_none() {
break;
}
let cmd = expr.as_ref().unwrap();
match cmd.oper {
Oper::Add => sum += cmd.term.val(rng) as i64,
Oper::Sub => sum -= cmd.term.val(rng) as i64,
};
}
sum
}
pub fn min(&self) -> i64 {
let mut sum = 0;
for expr in &self.exprs {
if expr.is_none() {
break;
}
let cmd = expr.as_ref().unwrap();
match cmd.oper {
Oper::Add => sum += cmd.term.min() as i64,
Oper::Sub => sum -= cmd.term.min() as i64,
};
}
sum
}
pub fn max(&self) -> i64 {
let mut sum = 0;
for expr in &self.exprs {
if expr.is_none() {
break;
}
let cmd = expr.as_ref().unwrap();
match cmd.oper {
Oper::Add => sum += cmd.term.max() as i64,
Oper::Sub => sum -= cmd.term.max() as i64,
};
}
sum
}
}
struct Cmd {
oper: Oper,
term: Term,
}
enum Term {
Roll(Roll),
Const(u64),
}
impl Term {
fn val<R: Rng>(&self, rng: &mut R) -> u64 {
match self {
Term::Const(c) => *c,
Term::Roll(r) => r.val(rng),
}
}
fn min(&self) -> u64 {
match self {
Term::Const(c) => *c,
Term::Roll(r) => r.min(),
}
}
fn max(&self) -> u64 {
match self {
Term::Const(c) => *c,
Term::Roll(r) => r.max(),
}
}
}
#[derive(Clone, Copy)]
enum Oper {
Add,
Sub,
}
pub struct Roll {
reps: u64,
dice: u64,
}
impl Roll {
fn val<R: Rng>(&self, rng: &mut R) -> u64 {
let mut total = 0;
for _ in 0..self.reps {
total += rng.gen::<u64>() % self.dice + 1;
}
total
}
fn min(&self) -> u64 {
// Roll a 1 on each dice
self.reps
}
fn max(&self) -> u64 {
// Roll max on each dice
self.reps * self.dice
}
}
fn term(e: &str) -> IResult<&str, Term> {
// Ignore whitespace
let (e, _) = multispace0(e)?;
alt((roll, cnst))(e)
}
fn roll(mut e: &str) -> IResult<&str, Term> {
let mut reps = 1;
if let (_e, Some(_reps)) = opt(complete::u64)(e)? {
e = _e;
reps = _reps;
}
let (e, _) = complete::char('d')(e)?;
let (e, dice) = complete::u64(e)?;
Ok((e, Term::Roll(Roll { reps, dice })))
}
fn cnst(e: &str) -> IResult<&str, Term> {
let (e, val) = complete::u64(e)?;
Ok((e, Term::Const(val)))
}
fn oper(e: &str) -> IResult<&str, Oper> {
// Ignore whitespace
let (e, _) = multispace0(e)?;
alt((add, sub))(e)
}
fn add(e: &str) -> IResult<&str, Oper> {
let (e, _) = complete::char('+')(e)?;
Ok((e, Oper::Add))
}
fn sub(e: &str) -> IResult<&str, Oper> {
let (e, _) = complete::char('-')(e)?;
Ok((e, Oper::Sub))
}
pub use bucket::Buckets;
pub use roller::Roller;
#[cfg(test)]
mod tests {
@@ -249,7 +12,7 @@ mod tests {
#[test]
fn roll_many_d6s() {
let mut rng = StdRng::seed_from_u64(1337);
let roller = Roller::<1024>::parse("1d6").unwrap();
let roller = Roller::<1024>::new_parse("1d6").unwrap();
let mut buckets = Buckets::<6>::new(&roller);
for _ in 0..1000 {
buckets.insert(roller.roll(&mut rng));
@@ -263,7 +26,7 @@ mod tests {
#[test]
fn roll_many_d6s_plus_2() {
let mut rng = StdRng::seed_from_u64(1337);
let roller = Roller::<1024>::parse("1d6+2").unwrap();
let roller = Roller::<1024>::new_parse("1d6+2").unwrap();
let mut buckets = Buckets::<6>::new(&roller);
for _ in 0..1000 {
buckets.insert(roller.roll(&mut rng));
@@ -282,7 +45,7 @@ mod tests {
#[test]
fn roll_many_d6s_plus_1d6() {
let mut rng = StdRng::seed_from_u64(1337);
let roller = Roller::<1024>::parse("1d6+1d6").unwrap();
let roller = Roller::<1024>::new_parse("1d6+1d6").unwrap();
let mut buckets = Buckets::<6>::new(&roller);
for _ in 0..1000 {
buckets.insert(roller.roll(&mut rng));
@@ -301,7 +64,7 @@ mod tests {
#[test]
fn roll_many_d6s_plus_1d6_minus_one() {
let mut rng = StdRng::seed_from_u64(1337);
let roller = Roller::<1024>::parse("1d6+1d6-1").unwrap();
let roller = Roller::<1024>::new_parse("1d6+1d6-1").unwrap();
let mut buckets = Buckets::<6>::new(&roller);
for _ in 0..1000 {
buckets.insert(roller.roll(&mut rng));
@@ -320,7 +83,7 @@ mod tests {
#[test]
fn negative_result() {
let mut rng = StdRng::seed_from_u64(1337);
let roller = Roller::<1024>::parse("1d6-5").unwrap();
let roller = Roller::<1024>::new_parse("1d6-5").unwrap();
for _ in 0..1000 {
roller.roll(&mut rng);
}
@@ -328,7 +91,7 @@ mod tests {
#[test]
fn min_max() {
let roller = Roller::<1024>::parse("2d6+1d8+1").unwrap();
let roller = Roller::<1024>::new_parse("2d6+1d8+1").unwrap();
assert_eq!(roller.min(), 4);
assert_eq!(roller.max(), 21);
}
+2 -1
View File
@@ -4,12 +4,13 @@ use std::io::{self, BufRead, Write};
use textplots::{Chart, Plot, Shape};
fn main() {
let mut roller = Roller::default();
let stdin = io::stdin();
print!("> ");
io::stdout().flush().unwrap();
for line in stdin.lock().lines() {
let line = line.unwrap();
let roller = match Roller::<1024>::parse(&line) {
match roller.parse(&line) {
Ok(r) => r,
Err(e) => {
println!("bad input; err: {}", e);
View File
+232
View File
@@ -0,0 +1,232 @@
use nom::{
branch::alt, character::complete::{self, multispace0}, combinator::opt, error::{Error, ErrorKind}, IResult
};
use rand::Rng;
pub struct Roller<const S: usize> {
exprs: [Option<Cmd>; S],
}
impl Default for Roller<256> {
fn default() -> Self {
Roller::new()
}
}
impl<const S: usize> Roller<S> {
pub fn new_parse(expr: &str) -> Result<Self, nom::Err<nom::error::Error<&str>>> {
let mut roller = Self::new();
roller.parse(expr)?;
Ok(roller)
}
/// new creates a new roller which does not contain an expression.
/// call 'parse' to load a new expression into the roller.
pub fn new() -> Self {
let exprs = [const { None }; S];
Roller { exprs }
}
/// parse converts the str, of form:
/// 2d8+1d8+4
/// into a parsed expression.
pub fn parse<'a, 'b>(&'a mut self, expr: &'b str) -> Result<(), nom::Err<nom::error::Error<&'b str>>> {
let limit = parse(expr, &mut self.exprs)?;
// Implementation detail of this struct; we don't zero out the
// expression between parses, instead, just set the n+1 one to
// None. This will cause roll to stop.
if limit+1 < self.exprs.len() {
self.exprs[limit+1] = None;
}
Ok(())
}
pub fn roll<R: Rng>(&self, rng: &mut R) -> i64 {
let mut sum = 0;
for expr in &self.exprs {
if expr.is_none() {
break;
}
let cmd = expr.as_ref().unwrap();
match cmd.oper {
Oper::Add => sum += cmd.term.val(rng) as i64,
Oper::Sub => sum -= cmd.term.val(rng) as i64,
};
}
sum
}
pub fn min(&self) -> i64 {
let mut sum = 0;
for expr in &self.exprs {
if expr.is_none() {
break;
}
let cmd = expr.as_ref().unwrap();
match cmd.oper {
Oper::Add => sum += cmd.term.min() as i64,
Oper::Sub => sum -= cmd.term.min() as i64,
};
}
sum
}
pub fn max(&self) -> i64 {
let mut sum = 0;
for expr in &self.exprs {
if expr.is_none() {
break;
}
let cmd = expr.as_ref().unwrap();
match cmd.oper {
Oper::Add => sum += cmd.term.max() as i64,
Oper::Sub => sum -= cmd.term.max() as i64,
};
}
sum
}
}
fn parse<'a, 'b>(mut expr: &'a str, exprs: &'b mut [Option<Cmd>]) -> Result<usize, nom::Err<nom::error::Error<&'a str>>> {
let mut op = Oper::Add;
let mut i = 0;
while expr.len() > 0 {
let (e, term) = term(expr)?;
expr = e;
exprs[i] = Some(Cmd { term, oper: op });
i += 1;
if i == exprs.len() {
return Err(nom::Err::Failure(Error::new(expr, ErrorKind::TooLarge)));
}
// Ignore trailing whitespace
let (e, _) = multispace0(expr)?;
expr = e;
// Get the next oper
if e.len() > 0 {
let (e, _op) = oper(expr)?;
op = _op;
expr = e;
}
}
Ok(i)
}
struct Cmd {
oper: Oper,
term: Term,
}
enum Term {
Roll(Roll),
Const(u64),
}
impl Term {
fn val<R: Rng>(&self, rng: &mut R) -> u64 {
match self {
Term::Const(c) => *c,
Term::Roll(r) => r.val(rng),
}
}
fn min(&self) -> u64 {
match self {
Term::Const(c) => *c,
Term::Roll(r) => r.min(),
}
}
fn max(&self) -> u64 {
match self {
Term::Const(c) => *c,
Term::Roll(r) => r.max(),
}
}
}
#[derive(Clone, Copy)]
enum Oper {
Add,
Sub,
}
struct Roll {
reps: u64,
dice: u64,
}
impl Roll {
fn val<R: Rng>(&self, rng: &mut R) -> u64 {
let mut total = 0;
for _ in 0..self.reps {
total += rng.gen::<u64>() % self.dice + 1;
}
total
}
fn min(&self) -> u64 {
// Roll a 1 on each dice
self.reps
}
fn max(&self) -> u64 {
// Roll max on each dice
self.reps * self.dice
}
}
fn term(e: &str) -> IResult<&str, Term> {
// Ignore whitespace
let (e, _) = multispace0(e)?;
alt((roll, cnst))(e)
}
fn roll(mut e: &str) -> IResult<&str, Term> {
let mut reps = 1;
if let (_e, Some(_reps)) = opt(complete::u64)(e)? {
e = _e;
reps = _reps;
}
let (e, _) = complete::char('d')(e)?;
let (e, dice) = complete::u64(e)?;
Ok((e, Term::Roll(Roll { reps, dice })))
}
fn cnst(e: &str) -> IResult<&str, Term> {
let (e, val) = complete::u64(e)?;
Ok((e, Term::Const(val)))
}
fn oper(e: &str) -> IResult<&str, Oper> {
// Ignore whitespace
let (e, _) = multispace0(e)?;
alt((add, sub))(e)
}
fn add(e: &str) -> IResult<&str, Oper> {
let (e, _) = complete::char('+')(e)?;
Ok((e, Oper::Add))
}
fn sub(e: &str) -> IResult<&str, Oper> {
let (e, _) = complete::char('-')(e)?;
Ok((e, Oper::Sub))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parser_tests() -> Result<(), Box<dyn std::error::Error>> {
let mut exprs = [const {None}; 256];
parse("1d8", &mut exprs)?;
parse(" 1d8 ", &mut exprs)?;
parse("1+1d8", &mut exprs)?;
parse("1d8+1", &mut exprs)?;
parse("d8", &mut exprs)?;
parse("1", &mut exprs)?;
parse("1d8 + 1d8 + 2", &mut exprs)?;
Ok(())
}
}