Rust Basics: the add Trait


On my journey toward a better understanding of anything, I tend to return to the basics.

Too often we base our assumptions on blind 'guess-timations': we don't understand why something actually happens, we have only observed that certain patterns lead to success.

That can be great for beginners, but the longer you do things, the less you want to rely on blind luck.

So for last week's Rust hacking session at Metalab Vienna we picked something simple: implementing the Add trait for a custom fixed-point datatype.

Sounds easy, right? Maybe you can hear the subtle giggle in the back where the mathematicians sit. Everything is easy ... as long as you only have to apply trained concepts and not build them.

So let's start with our datatype:

pub struct Fixed {
    integer: i32,
    decimal: u32,
}

We have:

  • integer: the part before the decimal point (signed, so it can be negative)
  • decimal: the part after the decimal point (unsigned, so it can't be negative)

Pretty easy so far, isn't it?

The first thing we want to do is print our datatype. Remember, this is about the basics, so for now I try not to use derive macros like #[derive(...)] and instead implement the traits by hand.

So we look at the documentation at doc.rust-lang.org/std/fmt/trait.Display.html

and can deduce:

use std::fmt;

impl fmt::Display for Fixed {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{}.{}", &self.integer, &self.decimal)
    }
}

Our code now looks like this:

use std::fmt;

pub struct Fixed {
    integer: i32,
    decimal: u32,
}

impl fmt::Display for Fixed {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{}.{}", &self.integer, &self.decimal)
    }
}

fn main() {
    println!("{}", Fixed {integer: 1, decimal: 0});
}

If we run it, it prints:

1.0
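Because Fixed now implements Display, the standard library's blanket `impl<T: Display> ToString for T` also gives us `to_string()`, and `format!` uses the same impl as `println!`. A small self-contained check, repeating the struct and impl from above:

```rust
use std::fmt;

pub struct Fixed {
    integer: i32,
    decimal: u32,
}

impl fmt::Display for Fixed {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{}.{}", self.integer, self.decimal)
    }
}

fn main() {
    let value = Fixed { integer: 1, decimal: 0 };
    // to_string() comes for free from the blanket ToString impl for Display types
    let text = value.to_string();
    // format! goes through the same Display impl as println!
    let sentence = format!("value = {}", value);
    println!("{} / {}", text, sentence);
}
```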


So we can now display our value :) Let's get to the next step and the real task: implementing the Add trait for our struct.

Why do we want the Add trait? So that we can write:

println!("{} + {} = {}", fixed1, fixed2, fixed1 + fixed2);

In Rust, the + operator is resolved through the Add trait from std::ops: a + b calls Add::add(a, b). The other overloadable operators follow the same pattern, each with its own trait in std::ops (Sub for -, Mul for *, and so on).
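To see the desugaring in action, here is a tiny sketch with plain i32 values: the operator and the fully qualified trait call give the same result.

```rust
use std::ops::{Add, Sub};

fn main() {
    let a = 40_i32;
    let b = 2_i32;

    // `a + b` is resolved through the Add trait: it calls Add::add(a, b)
    let via_operator = a + b;
    let via_trait = <i32 as Add>::add(a, b);

    // the same pattern holds for `-` and Sub::sub
    let difference = <i32 as Sub>::sub(a, b);

    println!("{} {} {}", via_operator, via_trait, difference);
}
```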

Because we're lazy (at least I am ;)), we don't want to write:

Fixed {
    integer: 1,
    decimal: 1,
}

all the time. Instead, let's use a common Rust pattern: an associated function that builds the struct for us, here called from:

impl Fixed {
    pub fn from(integer: i32, decimal: u32) -> Fixed {
        Fixed {
            integer,
            decimal
        }
    }
}

So we can now write:

let fixed1 = Fixed::from(1, 1);
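A side note on naming: an inherent function called `from` works, but idiomatic Rust usually calls a plain constructor `new` and reserves `from` for the standard `From` trait, which also gives you `.into()` for free. Having both an inherent `from(i32, u32)` and a `From<(i32, u32)>` impl on the same type would make `Fixed::from((1, 1))` resolve to the inherent function first and fail to compile, so this sketch renames the constructor to `new`. Both `new` and the tuple conversion are choices made for this example, not part of the post's code.

```rust
pub struct Fixed {
    integer: i32,
    decimal: u32,
}

impl Fixed {
    // conventional constructor name in Rust
    pub fn new(integer: i32, decimal: u32) -> Fixed {
        Fixed { integer, decimal }
    }
}

// hypothetical conversion from an (integer, decimal) tuple
impl From<(i32, u32)> for Fixed {
    fn from((integer, decimal): (i32, u32)) -> Fixed {
        Fixed::new(integer, decimal)
    }
}

fn main() {
    let a = Fixed::new(1, 1);
    let b = Fixed::from((1, 1));
    // implementing From<(i32, u32)> for Fixed also gives us Into<Fixed> for (i32, u32)
    let c: Fixed = (1, 1).into();
    println!(
        "{}.{} {}.{} {}.{}",
        a.integer, a.decimal, b.integer, b.decimal, c.integer, c.decimal
    );
}
```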


And then we implement the Add trait according to the documentation:

use std::ops::Add;

impl Add for Fixed {
    type Output = Fixed;

    fn add(self, rhs: Self) -> Self::Output {
        Fixed {
            integer: self.integer + rhs.integer,
            decimal: self.decimal + rhs.decimal
        }
    }
}

I purposely skipped the details of the Display trait, but the Add trait is important, so let's look at what it's built of to get a better understanding of this operator-trait pattern.

So what should add do?

a + b = c

So we have a on the left-hand side (LHS) of the plus symbol (+) and b on the right-hand side (RHS).

Adding the values of both then returns a third value, the output c.
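For reference, `std::ops::Add` is declared as `pub trait Add<Rhs = Self> { type Output; fn add(self, rhs: Rhs) -> Self::Output; }`: `self` is the LHS, `rhs` the RHS and `Output` the type of c. Because `Rhs` defaults to `Self`, `impl Add for Fixed` only covers Fixed + Fixed. The sketch below copies that shape into a made-up `MyAdd` trait (so it doesn't clash with std) and a hypothetical `Meters` type, to show how the three pieces fit together and that a different RHS type needs its own impl.

```rust
// Same shape as std::ops::Add, under a made-up name.
trait MyAdd<Rhs = Self> {
    // the type of c in `a + b = c`
    type Output;
    // `self` is the left-hand side (a), `rhs` the right-hand side (b)
    fn my_add(self, rhs: Rhs) -> Self::Output;
}

#[derive(Debug, PartialEq)]
struct Meters(u32);

// Rhs defaults to Self, so this is Meters + Meters -> Meters
impl MyAdd for Meters {
    type Output = Meters;
    fn my_add(self, rhs: Meters) -> Meters {
        Meters(self.0 + rhs.0)
    }
}

// A different right-hand side needs its own impl: Meters + u32 -> Meters
impl MyAdd<u32> for Meters {
    type Output = Meters;
    fn my_add(self, rhs: u32) -> Meters {
        Meters(self.0 + rhs)
    }
}

fn main() {
    println!("{:?}", Meters(1).my_add(Meters(2)));
    println!("{:?}", Meters(1).my_add(5u32));
}
```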

I will leave it at that. We could dip our toes into some set theory and type-coercion topics here (ℕ ⊂ ℤ ⊂ ℚ): what can be added to what, and in which direction a conversion should go.

Let's just assume that we only ever add a Fixed to a Fixed and always stay within our type.

To quote every math teacher ever ... let's assume there are only fixed-point numbers. ;)

So we're done? Nope ... not really.

The next step would be:

use std::fmt;
use std::ops::Add;

pub struct Fixed {
    integer: i32,
    decimal: u32,
}

impl fmt::Display for Fixed {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{}.{}", &self.integer, &self.decimal)
    }
}

impl Add for Fixed {
    type Output = Fixed;

    fn add(self, rhs: Self) -> Self::Output {
        return Fixed {
            integer: self.integer + rhs.integer,
            decimal: self.decimal + rhs.decimal
        }
    }
}


impl Fixed {
    pub fn from(integer: i32, decimal: u32) -> Fixed {
        Fixed {
            integer,
            decimal
        }
    }
}

fn main() {
    let fixed1 = Fixed::from(1, 1);
    let fixed2 = Fixed::from(1, 1);
    let result = fixed1 + fixed2;

    println!("{} + {} = {}", fixed1, fixed2, result)
}

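This fails to compile: add takes self and rhs by value, so fixed1 + fixed2 moves both variables, and the println! afterwards tries to use the moved values (rustc reports error E0382). Since Fixed only holds two small integers, we can implement the Copy trait so the values are copied instead of moved: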

impl Copy for Fixed { }

impl Clone for Fixed {
    fn clone(&self) -> Fixed {
        *self
    }
}

We need Clone as well, because Clone is a supertrait of Copy: every Copy type must also implement Clone.
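Making Fixed Copy is the simplest fix here, and `#[derive(Clone, Copy)]` would generate the same two impls. An alternative that keeps move semantics is to implement Add for references, so that `&fixed1 + &fixed2` only borrows the operands. A minimal sketch, using the same naive no-carry addition as above:

```rust
use std::ops::Add;

pub struct Fixed {
    integer: i32,
    decimal: u32,
}

// Add for borrowed values: neither operand is moved
impl<'a> Add<&'a Fixed> for &'a Fixed {
    type Output = Fixed;

    fn add(self, rhs: &'a Fixed) -> Fixed {
        // same naive addition as above (no carry, no negative handling)
        Fixed {
            integer: self.integer + rhs.integer,
            decimal: self.decimal + rhs.decimal,
        }
    }
}

fn main() {
    let fixed1 = Fixed { integer: 1, decimal: 1 };
    let fixed2 = Fixed { integer: 1, decimal: 1 };
    let result = &fixed1 + &fixed2;
    // fixed1 and fixed2 are still usable because they were only borrowed
    println!(
        "{}.{} + {}.{} = {}.{}",
        fixed1.integer, fixed1.decimal, fixed2.integer, fixed2.decimal, result.integer, result.decimal
    );
}
```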

With that we have a fully functional example for positive numbers, as long as nothing has to be carried.

Are we done? ... Nope ...

  • we are only done if we never have to carry a digit from the decimal realm to the integer realm
  • and if we only ever add two positive numbers

So let's tackle the negative numbers first.

As soon as one of the integer parts is negative, the decimals should be subtracted instead of added. Right now they are still added:

fn main() {
    let fixed1 = Fixed::from(-1, 1);
    let fixed2 = Fixed::from(1, 1);
    let result = fixed1 + fixed2;

    println!("{} + {} = {}", fixed1, fixed2, result)
}

Currently the result is

0.2

which is not really what we would expect ;D Or should we do old-school math and just define it as the intended behaviour, so it's not wrong? ;)

Jokes aside, one possible solution would be:

impl Add for Fixed {
    type Output = Fixed;

    fn add(self, rhs: Self) -> Self::Output {
        if rhs.integer < 0 || self.integer < 0 {
            return Fixed {
                integer: self.integer + rhs.integer,
                decimal: self.decimal - rhs.decimal
            }
        }

        Fixed {
            integer: self.integer + rhs.integer,
            decimal: self.decimal + rhs.decimal
        }
    }
}

As we know, this is only valid as long as there is no carry. And since it's an addition, the commutative law tells us the order of the operands doesn't matter, so we don't have to care which of the two is negative.

-1.5 + 1.5 = 0

The only issue is ... what if both are negative?

-1.1 + -1.1 = -2.2

and not

-1.1 + -1.1 = -2.0

So we need to be more specific: we only subtract the decimals from each other if exactly one of the numbers is negative.

if (rhs.integer < 0 && self.integer > 0) || (self.integer < 0 && rhs.integer > 0) {
  return Fixed {
    integer: self.integer + rhs.integer,
    decimal: self.decimal - rhs.decimal
  }
}
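The condition reads "exactly one of the two is negative". A compact way to write that is to compare the signs with `is_negative()`. One difference is a deliberate choice of this sketch: the `> 0` checks above treat an integer part of 0 as neither positive nor negative, so -1.5 + 0.5 would take the addition path, while the comparison below counts 0 as non-negative. The helper name `signs_differ` is hypothetical.

```rust
// Hypothetical helper: true if exactly one of the integer parts is negative.
// An integer part of 0 counts as non-negative here.
fn signs_differ(lhs: i32, rhs: i32) -> bool {
    lhs.is_negative() != rhs.is_negative()
}

fn main() {
    println!("{}", signs_differ(-1, 1)); // true: one negative
    println!("{}", signs_differ(-1, -1)); // false: both negative
    println!("{}", signs_differ(-1, 0)); // true: 0 counts as non-negative
}
```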

Okay, this works so far, but maybe some of you have already spotted the next issue?

What about

  1.14
- 1.16

this case? It fails: 14 - 16 does not fit into an unsigned integer. In a debug build Rust panics on the overflow; in a release build (without overflow checks) the subtraction silently wraps around.

fn main() {
    let fixed1 = Fixed::from(1, 14);
    let fixed2 = Fixed::from(-1, 16);
    let result = fixed1 + fixed2;

    println!("{} + {} = {}", fixed1, fixed2, result)
}

thread 'main' panicked at 'attempt to subtract with overflow' .....
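The standard library has a few tools for exactly this situation: `checked_sub` returns `None` instead of panicking, and `abs_diff` (Rust 1.60 or newer) returns the distance between two numbers without underflowing, which is the same idea as the max/min trick used next. A short sketch with the numbers from above:

```rust
use std::cmp::{max, min};

fn main() {
    let lhs: u32 = 14;
    let rhs: u32 = 16;

    // `lhs - rhs` would panic in a debug build; checked_sub reports the problem instead
    assert_eq!(lhs.checked_sub(rhs), None);

    // bigger minus smaller never underflows
    let distance = max(lhs, rhs) - min(lhs, rhs);

    // abs_diff computes the same distance in one call
    assert_eq!(lhs.abs_diff(rhs), distance);

    println!("distance = {}", distance); // distance = 2
}
```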

Since this is an addition, we can use the commutative law and simply flip the operands without changing the result, but we need to do this programmatically.

So we need to make sure the bigger decimal is always on the left:

use std::cmp::{max, min};
use std::ops::Add;

impl Add for Fixed {
    type Output = Fixed;

    fn add(self, rhs: Self) -> Self::Output {
        // we now take the bigger one on the left side of the plus 
        let lhs_decimal = max(self.decimal, rhs.decimal);
        // and the smaller one on the right side of the plus
        let rhs_decimal = min(self.decimal, rhs.decimal);

        if (rhs.integer < 0 && self.integer > 0) || (self.integer < 0 && rhs.integer > 0) {
            return Fixed {
                integer: self.integer + rhs.integer,
                decimal: lhs_decimal - rhs_decimal // applied here
            }
        }

        Fixed {
            integer: self.integer + rhs.integer,
            decimal: self.decimal + rhs.decimal
        }
    }
}

Okay, one problem less, but we're far from done. Let's get to the real problem: the interaction of the two realms. We want our decimal to carry over into our integer realm when necessary:

1.9 + 0.1 = 2

How do we detect that we need to carry something? Maybe there is a better way, but we used the magnitude, i.e. the power-of-ten exponent:

1.9 -> exp 10^-1
0.1 -> exp 10^-1
2.0 -> exp 10^0
-------------------------
in our data structure:

1.9 -> exp 10^-1
0.1 -> exp 10^-1
1.10 -> exp 10^-2, which has to become 2.0
-------------------------

Another way is to check whether the digit that remains after adding a column is smaller than the digits we added: if it is, a carry happened.

In our case:

1.9 
1.1 
----- 
3.0

9 + 1 = 1 [0] smaller than 9 
9 + 9 = 1 [8] smaller than 9
7 + 7 = 1 [4] smaller than 7
7 + 3 = 1 [0] smaller than 7

The same works in binary:

00010001
00000001
----------------
0001001[0] smaller than 1
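The rule from the two examples above can be sketched in a few lines: for single decimal digits, and for whole bytes with wrapping addition, a carry happened exactly when the remaining sum is smaller than the operand we started with. The function names `add_digits` and `add_bytes` are made up for this illustration; the byte version agrees with the standard `u8::overflowing_add`.

```rust
/// Adds two decimal digits (0..=9). A carry happened exactly when the
/// digit left over after dropping the carry is smaller than `a`.
fn add_digits(a: u8, b: u8) -> (u8, bool) {
    let digit = (a + b) % 10;
    (digit, digit < a)
}

/// The same rule in binary: a wrapping u8 addition overflowed
/// exactly when the wrapped sum is smaller than an operand.
fn add_bytes(a: u8, b: u8) -> (u8, bool) {
    let sum = a.wrapping_add(b);
    (sum, sum < a)
}

fn main() {
    println!("{:?}", add_digits(9, 1)); // (0, true)
    println!("{:?}", add_digits(7, 7)); // (4, true)
    println!("{:?}", add_digits(3, 4)); // (7, false)
    println!("{:?}", add_bytes(0b0001_0001, 0b0000_0001)); // (18, false)
}
```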

That leads to a nice little mechanic :) But first we need to do something else: we need to bring both numbers to the same magnitude.

1.1000
1.100

In our system the decimal parts of these two are stored as 1000 and 100, so a plain addition does not treat them as 0.1 + 0.1. We need to equalize the magnitudes, and to do that we need the log10 of the number :) because that's our power ;D

So first we need our magnitude calculated. I took this code and converted it to:

fn magnitude(decimal: u32) -> u32 {
    // floor(log10(decimal)): the power of ten of the leading digit
    match decimal {
        _ if decimal >= 1000000000 => 9,
        _ if decimal >= 100000000 => 8,
        _ if decimal >= 10000000 => 7,
        _ if decimal >= 1000000 => 6,
        _ if decimal >= 100000 => 5,
        _ if decimal >= 10000 => 4,
        _ if decimal >= 1000 => 3,
        _ if decimal >= 100 => 2,
        _ if decimal >= 10 => 1,
        // 0..=9 (log10(0) is undefined, we simply treat it as 0)
        _ => 0,
    }
}

and to save the function calls, we ask the compiler to inline it:

#[inline(always)]
fn magnitude(decimal: u32) -> u32 {

The Add implementation also changes a little:

impl Add for Fixed {
    type Output = Fixed;

    fn add(self, rhs: Self) -> Self::Output {
        let lhs_decimal = max(self.decimal, rhs.decimal);
        let rhs_decimal = min(self.decimal, rhs.decimal);

        // here are our magnitude calculations:
        // scale the smaller decimal up to the same number of digits as the bigger one
        let lhs_magnitude = magnitude(lhs_decimal);
        let rhs_magnitude = magnitude(rhs_decimal);
        let rhs_decimal_prepared = 10u32.pow(lhs_magnitude - rhs_magnitude) * rhs_decimal;

        if (rhs.integer < 0 && self.integer > 0) || (self.integer < 0 && rhs.integer > 0) {
            return Fixed {
                integer: self.integer + rhs.integer,
                // bigger minus smaller, so the u32 can't underflow
                decimal: max(lhs_decimal, rhs_decimal_prepared) - min(lhs_decimal, rhs_decimal_prepared)
            }
        }

        Fixed {
            integer: self.integer + rhs.integer,
            decimal: lhs_decimal + rhs_decimal_prepared
        }
    }
}
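On Rust 1.67 or newer, the standard library can compute this power of ten directly: `u32::checked_ilog10` returns the floored base-10 logarithm, or `None` for 0. A sketch of a magnitude helper on top of it, returning 0 for 0..=9, 1 for 10..=99 and so on; mapping 0 to magnitude 0 is our own choice, since log10(0) is undefined.

```rust
// floor(log10(decimal)), with 0 mapped to magnitude 0
fn magnitude(decimal: u32) -> u32 {
    decimal.checked_ilog10().unwrap_or(0)
}

fn main() {
    for decimal in [0, 9, 10, 99, 100, 940, u32::MAX] {
        println!("magnitude({}) = {}", decimal, magnitude(decimal));
    }
}
```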


So now we can add and subtract correctly, as long as we don't have ... a carry ;)

Now for the carry :) If the power-of-ten exponent of the result is bigger than the exponent of the original lhs_decimal, we carry a 1 into the integer part :)

We also need to stay at the power of ten of lhs_decimal, so we subtract the extra power of ten from the result again, because that power is exactly the carried 1:

let mut carry = 0;

let mut decimal_result = if (rhs.integer.is_negative() && self.integer > 0) || (self.integer.is_negative() && rhs.integer > 0) {
    // bigger minus smaller, so the u32 can't underflow
    max(lhs_decimal, rhs_decimal_prepared) - min(lhs_decimal, rhs_decimal_prepared)
} else {
    lhs_decimal + rhs_decimal_prepared
};

let result_magnitude = magnitude(decimal_result);
if result_magnitude > lhs_magnitude {
    // the extra power of ten is the carried 1; it points down if both numbers are negative
    carry = if self.integer.is_negative() { -1 } else { 1 };
    decimal_result -= 10u32.pow(result_magnitude);
}

This will lead us to:

use std::fmt;
use std::ops::Add;
use std::cmp::{max, min};

pub struct Fixed {
    integer: i32,
    decimal: u32,
}

impl fmt::Display for Fixed {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{}.{}", &self.integer, &self.decimal)
    }
}

impl Add for Fixed {
    type Output = Fixed;

    fn add(self, rhs: Self) -> Self::Output {
        let lhs_decimal = max(self.decimal, rhs.decimal);
        let rhs_decimal = min(self.decimal, rhs.decimal);
        let lhs_magnitude = magnitude(lhs_decimal);
        // scale the smaller decimal up to the same number of digits as the bigger one
        let rhs_decimal_prepared = 10u32.pow(lhs_magnitude - magnitude(rhs_decimal)) * rhs_decimal;

        let mut carry = 0;

        // subtract only if exactly one integer part is negative (0 counts as non-negative)
        let mut decimal_result = if self.integer.is_negative() != rhs.integer.is_negative() {
            // bigger minus smaller, so the u32 can't underflow
            max(lhs_decimal, rhs_decimal_prepared) - min(lhs_decimal, rhs_decimal_prepared)
        } else {
            lhs_decimal + rhs_decimal_prepared
        };

        // one power of ten more than lhs_decimal had: that power is the carry
        let result_magnitude = magnitude(decimal_result);
        if result_magnitude > lhs_magnitude {
            // both operands have the same sign here, so the carry follows that sign
            carry = if self.integer.is_negative() { -1 } else { 1 };
            decimal_result -= 10u32.pow(result_magnitude);
        }

        Fixed {
            integer: self.integer + rhs.integer + carry,
            decimal: decimal_result
        }
    }
}

#[inline(always)]
fn magnitude(decimal: u32) -> u32 {
    // floor(log10(decimal)), with 0..=9 (including 0) mapped to 0
    match decimal {
        _ if decimal >= 1000000000 => 9,
        _ if decimal >= 100000000 => 8,
        _ if decimal >= 10000000 => 7,
        _ if decimal >= 1000000 => 6,
        _ if decimal >= 100000 => 5,
        _ if decimal >= 10000 => 4,
        _ if decimal >= 1000 => 3,
        _ if decimal >= 100 => 2,
        _ if decimal >= 10 => 1,
        _ => 0,
    }
}


impl Copy for Fixed { }

impl Clone for Fixed {
    fn clone(&self) -> Fixed {
        *self
    }
}



impl Fixed {
    pub fn from(integer: i32, decimal: u32) -> Fixed {
        Fixed {
            integer,
            decimal
        }
    }
}

fn main() {
    let fixed1 = Fixed::from(1, 940);
    let fixed2 = Fixed::from(1, 16);
    let result = fixed1 + fixed2;

    println!("{} + {} = {}", fixed1, fixed2, result)
}

#[cfg(test)]
mod tests {
    // Note this useful idiom: importing names from outer (for mod tests) scope.
    use super::*;

    #[test]
    fn test_add_integer() {
        let f1 = Fixed::from(1, 0);
        let f2 = Fixed::from(1, 0);

        let result = f1 + f2;

        assert_eq!(2, result.integer)
    }

    #[test]
    fn test_add_negative_integer() {
        let f1 = Fixed::from(1, 0);
        let f2 = Fixed::from(-1, 0);

        let result = f1 + f2;

        assert_eq!(0, result.integer)
    }

    #[test]
    fn test_two_add_negative_integer() {
        let f1 = Fixed::from(-1, 0);
        let f2 = Fixed::from(-1, 0);

        let result = f1 + f2;

        assert_eq!(-2, result.integer)
    }

    #[test]
    fn test_add_decimal() {
        let f1 = Fixed::from(0, 10);
        let f2 = Fixed::from(1, 0);

        let result = f1 + f2;

        assert_eq!(10, result.decimal)
    }

    #[test]
    fn test_add_different_magnitude_decimal() {
        let f1 = Fixed::from(0, 10);
        let f2 = Fixed::from(1, 100);

        let result = f1 + f2;

        assert_eq!(200, result.decimal)
    }

    #[test]
    fn test_add_negative_magnitude_decimal() {
        let f1 = Fixed::from(-1, 10);
        let f2 = Fixed::from(1, 100);

        let result = f1 + f2;

        assert_eq!(0, result.decimal)
    }

    /* #[test] // currently failing (negative zero problem: -0 is just 0 for an i32)
    fn test_add_negative_zero_decimal() {
        let f1 = Fixed::from(-0, 10);
        let f2 = Fixed::from(1, 100);

        let result = f1 + f2;

        assert_eq!(0, result.decimal)
    }
    */

}

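And now we can add our custom Fixed fixed-point numbers more or less mathematically correctly. There are still gaps: a decimal that needs a borrow, such as -2.1 + 1.5, comes out as -1.4 instead of -0.6; Fixed::from(1, 5) means 1.5, so there is no way to write 1.05; and because -0 is just 0 for an i32, values between -1 and 0 can't be negative. I guess all of it can be optimized, and it's not very practical.

For example, instead of calculating the exponent on every addition, it would make sense to store it on initialization, and so on.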
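To take that idea one step further: a common way to build fixed-point numbers is to store a single scaled integer and fix the number of decimal places in the type itself. Addition then becomes plain integer addition, so carries, borrows and signs are all handled by the integer arithmetic. This is a different technique from the one built above, sketched with a hypothetical `ScaledFixed` type and an arbitrary choice of 4 decimal places:

```rust
use std::fmt;
use std::ops::Add;

// number of decimal places, fixed for the whole type (an arbitrary choice)
const DECIMAL_PLACES: u32 = 4;
const SCALE: i64 = 10_i64.pow(DECIMAL_PLACES);

#[derive(Clone, Copy, Debug, PartialEq)]
pub struct ScaledFixed {
    // the value multiplied by SCALE, e.g. 1.5 is stored as 15000
    raw: i64,
}

impl ScaledFixed {
    // build from the already scaled integer, e.g. from_raw(-6000) is -0.6
    pub fn from_raw(raw: i64) -> ScaledFixed {
        ScaledFixed { raw }
    }
}

impl Add for ScaledFixed {
    type Output = ScaledFixed;

    fn add(self, rhs: ScaledFixed) -> ScaledFixed {
        // carries, borrows and signs are handled by ordinary integer addition
        ScaledFixed { raw: self.raw + rhs.raw }
    }
}

impl fmt::Display for ScaledFixed {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let sign = if self.raw < 0 { "-" } else { "" };
        let abs = self.raw.unsigned_abs();
        let scale = SCALE as u64;
        // zero-pad the fractional part so 1.05 does not print as 1.5
        write!(
            f,
            "{}{}.{:0width$}",
            sign,
            abs / scale,
            abs % scale,
            width = DECIMAL_PLACES as usize
        )
    }
}

fn main() {
    let a = ScaledFixed::from_raw(19_000); // 1.9
    let b = ScaledFixed::from_raw(1_000); // 0.1
    println!("{} + {} = {}", a, b, a + b); // 1.9000 + 0.1000 = 2.0000
}
```

The price is a fixed precision chosen up front, but in exchange the decimal part never has to be re-aligned at runtime.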

But it is a fun exercise to build such a thing :)

Thanks for reading :)