diff --git a/src/structure/matrix.rs b/src/structure/matrix.rs index 72220e2..7e0241d 100644 --- a/src/structure/matrix.rs +++ b/src/structure/matrix.rs @@ -3293,6 +3293,9 @@ impl LinearAlgebra for Matrix { /// /// Implementation of [RosettaCode](https://rosettacode.org/wiki/Reduced_row_echelon_form) fn rref(&self) -> Matrix { + let max_abs = self.data.iter().fold(0f64, |acc, &x| acc.max(x.abs())); + let epsilon = (max_abs * 1e-12).max(1e-15); + let mut lead = 0usize; let mut result = self.clone(); 'outer: for r in 0..self.row { @@ -3300,7 +3303,7 @@ impl LinearAlgebra for Matrix { break; } let mut i = r; - while result[(i, lead)] == 0f64 { + while result[(i, lead)].abs() < epsilon { i += 1; if self.row == i { i = r; @@ -3314,7 +3317,7 @@ impl LinearAlgebra for Matrix { result.swap(i, r, Row); } let tmp = result[(r, lead)]; - if tmp != 0f64 { + if tmp.abs() > epsilon { unsafe { result.row_mut(r).iter_mut().for_each(|t| *(*t) /= tmp); } diff --git a/tests/matrix.rs b/tests/matrix.rs index defb914..3ce5d3c 100644 --- a/tests/matrix.rs +++ b/tests/matrix.rs @@ -111,3 +111,54 @@ fn test_kronecker() { let c1 = a1.kronecker(&b1); assert_eq!(c1, ml_matrix("0 5 0 10;6 7 12 14;0 15 0 20;18 21 24 28")); } + +#[test] +fn test_rref() { + let a = ml_matrix( + r#" + -3 2 -1 -1; + 6 -6 7 -7; + 3 -4 4 -6"#, + ); + let b = a.rref(); + + assert_eq!( + b, + ml_matrix( + r#" + 1 0 0 2; + 0 1 0 2; + 0 0 1 -1"# + ) + ); +} + +#[test] +fn test_rref_unstable() { + let epsilon = 1e-10; + + // this matrix used to become unstable during rref + let a = ml_matrix( + r#" + 1 1 0 0 0 1 0 1 31; + 1 1 1 1 0 0 1 1 185; + 0 0 1 0 0 1 1 1 165; + 1 0 1 0 1 1 0 1 32; + 1 0 1 0 0 0 1 1 174; + 0 0 1 0 1 1 1 1 171; + 0 1 1 0 1 1 0 1 27; + 1 0 0 1 0 1 0 0 20; + 1 0 1 1 0 1 0 0 23"#, + ); + + let b = a.rref(); + + // creating a row like "0 0 0 0 0 0 0 0 1" will "prove" 0 == 1 + // which is a tell of numeric instability + for row in 0..b.row { + let ends_in_1 = (b[(row, b.col - 1)] - 1.0).abs() < epsilon; + let rest_zeroes = (0..b.col - 1).all(|col| b[(row, col)].abs() < epsilon); + + assert!(!(ends_in_1 && rest_zeroes)); + } +}