본문 바로가기

Zig

임의의 크기 행렬 구현

const std = @import("std");

/// Error set for matrix operations.
/// `not_exist_inverse_matrix`: returned by `inverse` when no inverse exists (singular matrix).
pub const matrix_error = error{not_exist_inverse_matrix};

/// Single-precision 3x3 matrix.
pub const matrix3x3: type = matrix(f32, 3, 3);
/// Single-precision 4x4 matrix (the size most commonly used in games).
pub const matrix4x4: type = matrix(f32, 4, 4);

/// Generic fixed-size matrix type with `row` rows (↕) and `col` columns (↔).
/// Elements are stored in `e` and addressed as `e[r][c]` (row `r`, column `c`).
/// `T` must be a numeric type (integer or float); anything else is a compile error.
pub fn matrix(comptime T: type, row: comptime_int, col: comptime_int) type {
    switch (@typeInfo(T)) {
        .Int, .Float, .ComptimeInt, .ComptimeFloat => {},
        else => @compileError("not a number type"),
    }

    return struct {
        const Self = @This();
        e: [row][col]T,

        /// Return a matrix whose elements are all zero.
        pub fn init() Self {
            return Self{
                .e = .{.{0} ** col} ** row,
            };
        }

        /// Return the identity matrix (ones on the main diagonal, zeros elsewhere).
        /// Only square matrices have an identity, so a non-square shape is a compile error.
        pub fn identity() Self {
            if (col != row) @compileError("identity : not a square matrix");
            var result = init();
            for (0..row) |i| result.e[i][i] = 1;
            return result;
        }

        /// Return the element-wise sum `self + _matrix`. Neither operand is modified.
        pub fn addition(self: *const Self, _matrix: *const Self) Self {
            // `self` is a pointer: copy the pointee, not the pointer.
            var result: Self = self.*;
            for (0..row) |r| {
                for (0..col) |c| result.e[r][c] += _matrix.e[r][c];
            }
            return result;
        }

        /// Return the element-wise difference `self - _matrix`. Neither operand is modified.
        pub fn subtract(self: *const Self, _matrix: *const Self) Self {
            // `self` is a pointer: copy the pointee, not the pointer.
            var result: Self = self.*;
            for (0..row) |r| {
                for (0..col) |c| result.e[r][c] -= _matrix.e[r][c];
            }
            return result;
        }

        /// Matrix product: [row x COL] = [row x col] * [col x COL].
        /// `COL` is the number of columns of `_matrix`.
        pub fn multiply(self: *const Self, COL: comptime_int, _matrix: *const matrix(T, col, COL)) matrix(T, row, COL) {
            var result = matrix(T, row, COL).init();
            for (0..row) |r| {
                for (0..COL) |c| {
                    // Dot product of row r of self with column c of _matrix.
                    // The shared (inner) dimension is `col`, not `COL`.
                    for (0..col) |n| result.e[r][c] += self.e[r][n] * _matrix.e[n][c];
                }
            }
            return result;
        }

        /// Swap rows `i` and `j` in place (each row holds `col` elements).
        fn swap_row(self: *Self, i: usize, j: usize) void {
            if (i == j) return;
            std.mem.swap([col]T, &self.e[i], &self.e[j]);
        }

        /// Return the transpose: a `col` x `row` matrix with `result.e[c][r] == self.e[r][c]`.
        pub fn transpose(self: *const Self) matrix(T, col, row) {
            var result: matrix(T, col, row) = undefined;
            for (0..row) |r| {
                for (0..col) |c| result.e[c][r] = self.e[r][c];
            }
            return result;
        }

        /// Determinant of an `n` x `n` array by cofactor (Laplace) expansion along row 0.
        /// Recurses on n minors of size n-1, so the cost grows factorially with n;
        /// only practical for small matrices.
        fn det(n: comptime_int, _matrix: [n][n]T) T {
            if (n == 0) return 1; // determinant of the empty matrix
            if (n == 1) return _matrix[0][0];

            var sum: T = 0;
            var sign: T = 1;
            for (0..n) |k| {
                // Minor for column k: drop row 0 and column k.
                var minor: [n - 1][n - 1]T = undefined;
                for (0..n - 1) |i| {
                    var dst: usize = 0;
                    for (0..n) |j| {
                        if (j == k) continue;
                        minor[i][dst] = _matrix[i + 1][j];
                        dst += 1;
                    }
                }
                sum += sign * _matrix[0][k] * det(n - 1, minor);
                sign = -sign;
            }
            return sum;
        }

        /// Return the determinant. Only defined for square matrices (compile error otherwise).
        /// Reference: https://nate9389.tistory.com/63
        pub fn determinant(self: *const Self) T {
            if (col != row) @compileError("determinant : not a square matrix");
            return det(row, self.e);
        }

        /// Return the inverse matrix using Gauss-Jordan elimination on [a | b].
        /// `a` starts as a copy of self and `b` starts as the identity.
        /// Returns `error.not_exist_inverse_matrix` (the member of `matrix_error`) when
        /// the matrix is singular. Only defined for square matrices (compile error otherwise).
        /// Rows are scaled with `/=`, so this is intended for floating-point `T`.
        /// Reference: https://blog.naver.com/lovebuthate/221153359469
        pub fn inverse(self: *const Self) !Self {
            if (col != row) @compileError("inverse : not a square matrix");

            const n = row; // square, so row == col
            var a: Self = self.*;
            var b: Self = identity();

            for (0..n) |k| {
                // Pivot: the first row p >= k whose entry in column k of the
                // *working* matrix `a` is non-zero. (Rows above k are already reduced.)
                var p: usize = k;
                while (p < n and a.e[p][k] == 0) : (p += 1) {}
                if (p == n) return error.not_exist_inverse_matrix;
                a.swap_row(k, p);
                b.swap_row(k, p);

                // Divide row k by the pivot so that a[k][k] becomes 1.
                const d = a.e[k][k];
                for (0..n) |j| {
                    a.e[k][j] /= d;
                    b.e[k][j] /= d;
                }

                // Eliminate column k from every other row.
                for (0..n) |i| {
                    if (i == k) continue;
                    const m = a.e[i][k];
                    // Columns < k of row k in `a` are already zero, so start at k.
                    for (k..n) |j| a.e[i][j] -= a.e[k][j] * m;
                    for (0..n) |j| b.e[i][j] -= b.e[k][j] * m;
                }
            }
            return b;
        }

        /// Write the type name on one line, then each row as `{a, b, c}` on its own line.
        /// `fmt` and `options` are ignored.
        pub fn format(self: *const Self, comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void {
            _ = fmt;
            _ = options;

            try writer.print("{s}\n", .{@typeName(Self)});
            for (self.e) |cells| {
                try writer.writeAll("{");
                for (cells, 0..) |cell, j| {
                    if (j > 0) try writer.writeAll(", ");
                    try writer.print("{d}", .{cell});
                }
                try writer.writeAll("}\n");
            }
        }
    };
}

모든 크기의 행렬에 대응되도록 구현해 봤는데... 실제로 게임에서 많이 쓰는 4x4 같은 고정 크기 행렬은 따로 최적화된 코드를 짜는 게 효율적이므로 별 의미는 없긴 합니다;; 역행렬과 행렬식 코드는 주석에 달린 블로그 링크의 코드를 참고해서 짰습니다.

'Zig' 카테고리의 다른 글

tag union으로 타입마다 함수 호출하기  (0) 2024.10.31
zig에서 C 문자열 다루기  (0) 2024.10.10