;;; matrix.ss from Schemathics -- Noel Welsh
;;;
;;; Dense matrix utilities. A matrix is a vector of row vectors, so
;;; (matrix-ref m r c) is (vector-ref (vector-ref m r) c). The
;;; procedures use generic arithmetic, so exact entries give exact
;;; results and inexact entries give floating-point results.

(module matrix mzscheme
  (provide copy-matrix
           matrix-ref
           matrix-set!
           lu-factor!
           lu-factor/pivoting!
           solve-lower-triangle
           solve-upper-triangle
           make-solver
           make-solver/pivoting
           invert-matrix
           row-reduce!
           row-echelon-form!
           reduced-row-echelon-form!)

  ;;; Loop with i running over the integers from n (inclusive) to
  ;;; m (exclusive). The bound m is evaluated once, before n.
  (define-syntax for
    (syntax-rules ()
      ((for (i n m) forms ...)
       (let ((fixed-m m))
         (let loop ((i n))
           (if (< i fixed-m)
               (begin forms ... (loop (+ i 1)))))))))

  ;;; Loop over exactly the same range as (for (i n m) ...), but
  ;;; backwards, from m - 1 down to n. The bound n is evaluated once.
  (define-syntax reverse-for
    (syntax-rules ()
      ((reverse-for (i n m) forms ...)
       (let ((fixed-n n))
         (let loop ((i (- m 1)))
           (if (>= i fixed-n)
               (begin forms ... (loop (- i 1)))))))))

  ;;; Return a fresh vector of length size whose i-th element is (proc i).
  (define (tabulate-vector proc size)
    (let ((vec (make-vector size)))
      (for (i 0 size)
        (vector-set! vec i (proc i)))
      vec))

  ;;; Vector analog of `map': a fresh vector of (proc x) for each
  ;;; element x of v.
  (define (map-vector proc v)
    (tabulate-vector (lambda (i) (proc (vector-ref v i)))
                     (vector-length v)))

  ;;; Shallow copy of a vector.
  (define (copy-vector v)
    (map-vector (lambda (x) x) v))

  ;;; Copy a matrix: the outer vector and every row vector are fresh,
  ;;; so mutating the copy never affects the original.
  (define (copy-matrix m)
    (map-vector copy-vector m))

  ;;; Entry of matrix at (row, col).
  (define (matrix-ref matrix row col)
    (vector-ref (vector-ref matrix row) col))

  ;;; Set the entry of matrix at (row, col) to value.
  (define (matrix-set! matrix row col value)
    (vector-set! (vector-ref matrix row) col value))

  ;;; Build a num-rows x num-cols matrix whose (i, j) entry is (proc i j).
  (define (tabulate-matrix proc num-rows num-cols)
    (tabulate-vector
     (lambda (i)
       (tabulate-vector (lambda (j) (proc i j)) num-cols))
     num-rows))

  ;;; Return the transpose of m as a fresh matrix. A matrix with no
  ;;; rows transposes to an empty matrix instead of failing on
  ;;; (vector-ref m 0).
  (define (transpose-matrix m)
    (if (zero? (vector-length m))
        (vector)
        (let ((num-rows (vector-length m))
              (num-cols (vector-length (vector-ref m 0))))
          (tabulate-matrix (lambda (i j) (matrix-ref m j i))
                           num-cols
                           num-rows))))

  ;;; Make the i-th unit vector of length n.
  (define (unit-vector i n)
    (tabulate-vector (lambda (j) (if (= i j) 1 0)) n))

  ;;; Make the identity matrix of dimension n.
  (define (identity-matrix n)
    (tabulate-vector (lambda (i) (unit-vector i n)) n))

  ;;; Make the identity permutation of length n: #(0 1 ... n-1).
  (define (identity-permutation n)
    (tabulate-vector (lambda (i) i) n))

  ;;; Apply a permutation to a vector: element i of the result is
  ;;; (vector-ref v (vector-ref perm i)). If v is itself a permutation,
  ;;; this is permutation composition.
  (define (apply-permutation perm v)
    (tabulate-vector (lambda (i) (vector-ref v (vector-ref perm i)))
                     (vector-length perm)))

  ;;; Swap elements i and j of v in place. Works when i = j.
  (define (vector-swap! v i j)
    (let ((vi (vector-ref v i))
          (vj (vector-ref v j)))
      (vector-set! v i vj)
      (vector-set! v j vi)))

  ;;; Partial pivoting: return the index of the row, from row down to
  ;;; the last row, with the largest absolute value in column col.
  ;;; Ties keep the earliest row.
  (define (find-pivot m row col)
    (let ((best-row row)
          (best-value (abs (matrix-ref m row col))))
      (for (r (+ row 1) (vector-length m))
        (let ((value (abs (matrix-ref m r col))))
          (if (> value best-value)
              (begin
                (set! best-row r)
                (set! best-value value)))))
      best-row))

  ;;; Find a pivot for (row, col) and swap it into place. The whole
  ;;; row vector is swapped in m, and the same swap is recorded in perm.
  (define (perform-pivot! perm m row col)
    (let ((pivot (find-pivot m row col)))
      (vector-swap! m row pivot)
      (vector-swap! perm row pivot)))

  ;;; Add alpha * row1 to row2, for columns col (inclusive) to
  ;;; col-end (exclusive).
  (define (add-row! m alpha row1 row2 col col-end)
    (for (c col col-end)
      (matrix-set! m row2 c
                   (+ (matrix-ref m row2 c)
                      (* alpha (matrix-ref m row1 c))))))

  ;;; Eliminate row2 using row1: subtract the multiple of row1 that
  ;;; makes (row2, col) zero, over columns col to col-end (exclusive).
  ;;; (row2, col) is left as the computed zero; this is what row
  ;;; reduction needs. LU factorisation uses lu-eliminate-row! instead.
  (define (eliminate-row! m row1 row2 col col-end)
    (let ((alpha (/ (matrix-ref m row2 col)
                    (matrix-ref m row1 col))))
      (add-row! m (- alpha) row1 row2 col col-end)))

  ;;; LU elimination step. Compute the multiplier alpha that would zero
  ;;; (row2, col), subtract alpha * row1 from row2 for the columns after
  ;;; col, and store alpha itself in (row2, col). Storing alpha keeps L
  ;;; (strictly below the diagonal, with an implicit unit diagonal) and
  ;;; U (on and above the diagonal) in the same matrix. That compact
  ;;; layout is what solve-lower-triangle and solve-upper-triangle
  ;;; read, so the solvers depend on it.
  (define (lu-eliminate-row! m row1 row2 col col-end)
    (let ((alpha (/ (matrix-ref m row2 col)
                    (matrix-ref m row1 col))))
      (add-row! m (- alpha) row1 row2 (+ col 1) col-end)
      (matrix-set! m row2 col alpha)))

  ;;; Eliminate rows row-from to row-to (exclusive) using row, over
  ;;; columns col to col-end (exclusive).
  (define (eliminate-rows! m row row-from row-to col col-end)
    (for (i row-from row-to)
      (eliminate-row! m row i col col-end)))

  ;;; In-place LU factorisation without pivoting, for a square matrix m.
  ;;; Afterwards m holds U on and above the diagonal and the multipliers
  ;;; of L (whose diagonal is implicitly 1) strictly below it.
  ;;; A zero pivot causes a division by zero; use lu-factor/pivoting!
  ;;; for matrices that may need row exchanges.
  (define (lu-factor! m)
    (let ((size (vector-length m)))
      (for (col 0 (- size 1))
        (for (row (+ col 1) size)
          (lu-eliminate-row! m col row col size)))))

  ;;; In-place LU factorisation with partial pivoting, for a square
  ;;; matrix m. Afterwards m holds L and U in the same compact layout as
  ;;; lu-factor!, but for the row-permuted matrix: row i of the result
  ;;; comes from row (vector-ref perm i) of the original.
  ;;; Returns perm.
  (define (lu-factor/pivoting! m)
    (let* ((size (vector-length m))
           (perm (identity-permutation size)))
      (for (col 0 (- size 1))
        (perform-pivot! perm m col col)
        ;; The pivot has the largest magnitude in its column, so a zero
        ;; pivot means every entry below it is zero as well. Then there
        ;; is nothing to eliminate, and skipping avoids dividing by zero
        ;; on singular matrices. The L multipliers stay 0.
        (unless (zero? (matrix-ref m col col))
          (for (row (+ col 1) size)
            (lu-eliminate-row! m col row col size))))
      perm))

  ;;; Multiply row by alpha, for columns col (inclusive) to
  ;;; col-end (exclusive).
  (define (scale-row! m alpha row col col-end)
    (for (c col col-end)
      (matrix-set! m row c (* alpha (matrix-ref m row c)))))

  ;;; Gaussian elimination with partial pivoting, in place.
  ;;; If echelon-form? is true, each pivot row is scaled so its pivot
  ;;; is 1. If reduced? is also true, entries above each pivot are
  ;;; eliminated too. Returns the list of pivot columns in increasing
  ;;; order. A matrix with no rows yields '() and is left unchanged.
  (define (perform-row-reduction! m echelon-form? reduced?)
    (let* ((row-end (vector-length m))
           (col-end (if (zero? row-end)
                        0
                        (vector-length (vector-ref m 0))))
           (perm (identity-permutation row-end))
           (pivot-row 0)
           (pivots '()))
      (for (col 0 col-end)
        (unless (= pivot-row row-end)
          (perform-pivot! perm m pivot-row col)
          ;; A zero pivot after pivoting means the column is already
          ;; zero from pivot-row down; move on to the next column.
          (unless (= (matrix-ref m pivot-row col) 0)
            (set! pivots (cons col pivots))
            (if echelon-form?
                ;; In echelon form every pivot is 1.
                (scale-row! m (/ (matrix-ref m pivot-row col))
                            pivot-row col col-end))
            (eliminate-rows! m pivot-row (+ pivot-row 1) row-end
                             col col-end)
            (if reduced?
                ;; In reduced echelon form the pivot is the only
                ;; non-zero entry in its column.
                (eliminate-rows! m pivot-row 0 pivot-row col col-end))
            (set! pivot-row (+ pivot-row 1)))))
      (reverse pivots)))

  ;;; Row-reduce m in place, without normalising pivots.
  ;;; Returns the pivot columns.
  (define (row-reduce! m)
    (perform-row-reduction! m #f #f))

  ;;; Bring m into row echelon form (pivots equal to 1), in place.
  ;;; Returns the pivot columns.
  (define (row-echelon-form! m)
    (perform-row-reduction! m #t #f))

  ;;; Bring m into reduced row echelon form, in place.
  ;;; Returns the pivot columns.
  (define (reduced-row-echelon-form! m)
    (perform-row-reduction! m #t #t))

  ;;; Compute a "partial inner product": the inner product of the
  ;;; slices [n, m) of v1 and v2.
  (define (partial-inproduct v1 v2 n m)
    (do ([i n (+ i 1)]
         [sum 0 (+ sum (* (vector-ref v1 i) (vector-ref v2 i)))])
      [(= i m) sum]))

  ;;; Solve L x = b by forward substitution, where L is lower-triangular
  ;;; with 1's on the diagonal. Only the part of l strictly below the
  ;;; diagonal is read, so the rest can hold arbitrary data (e.g. the U
  ;;; part of a compact LU factorisation). Returns a fresh vector x.
  (define (solve-lower-triangle l b)
    (let* ((x-len (vector-length l))
           (x (make-vector x-len)))
      (for (i 0 x-len)
        (vector-set! x i
                     (- (vector-ref b i)
                        (partial-inproduct (vector-ref l i) x 0 i))))
      x))

  ;;; Solve U x = b by back substitution, where U is upper-triangular.
  ;;; Only the part of u on and above the diagonal is read, so the rest
  ;;; can hold arbitrary data (e.g. the L part of a compact LU
  ;;; factorisation). Returns a fresh vector x.
  (define (solve-upper-triangle u b)
    (let* ((x-len (vector-length u))
           (x (make-vector x-len)))
      (reverse-for (i 0 x-len)
        (vector-set! x i
                     (/ (- (vector-ref b i)
                           (partial-inproduct (vector-ref u i) x
                                              (+ i 1) x-len))
                        (matrix-ref u i i))))
      x))

  ;;; Return a "solver procedure" for the square matrix m: applied to a
  ;;; vector b, it returns x with m x = b. m itself is not modified.
  ;;; Example:
  ;;;   (define solver (make-solver test-matrix))
  ;;;   (display (solver some-vector))
  (define (make-solver m)
    (let ((lu (copy-matrix m)))
      (lu-factor! lu)
      (lambda (v)
        (solve-upper-triangle lu (solve-lower-triangle lu v)))))

  ;;; Like make-solver, but factorises with partial pivoting, which is
  ;;; more robust (no failure on zero pivots of non-singular matrices).
  (define (make-solver/pivoting m)
    (let* ((lu (copy-matrix m))
           (perm (lu-factor/pivoting! lu)))
      (lambda (v)
        ;; Permute b the same way the rows of m were permuted.
        (let ((v (apply-permutation perm v)))
          (solve-upper-triangle lu (solve-lower-triangle lu v))))))

  ;;; Compute the inverse of a square matrix by solving for each unit
  ;;; vector. Do NOT use this procedure if all you want is to solve
  ;;; A x = b; use make-solver/pivoting for that.
  (define (invert-matrix m)
    (let ((solver (make-solver/pivoting m))
          (unit (identity-matrix (vector-length m))))
      ;; Each solved unit vector is a column of the inverse, so the
      ;; rows collected by map-vector must be transposed.
      (transpose-matrix (map-vector solver unit))))
  )