diff --git a/2023210/main.cpp b/2023210/main.cpp index a62ab5b..158e683 100644 --- a/2023210/main.cpp +++ b/2023210/main.cpp @@ -15,40 +15,40 @@ int dim_sizes[32] = {0}; // int precalc[2][1048576] = {0}; -int tensor_pos(int num, int dim) { +inline int tensor_pos(int num, int dim) { return k * num + dim; } -int setm(int i, int j, int num) { +inline int setm(int i, int j, int num) { if (j <= i) { return 0; } - return m[(2 * n - i + 1) * i / 2 + j - i].min_multiply_times = num; + return m[i * n + j].min_multiply_times = num; } -int getm(int i, int j) { +inline int getm(int i, int j) { if (j <= i) { return 0; } - return m[(2 * n - i + 1) * i / 2 + j - i].min_multiply_times; + return m[i * n + j].min_multiply_times; } -int setmdim(int i, int j, int d) { +inline int setmdim(int i, int j, int d) { if (j < i) { return 0; } - return m[(2 * n - i + 1) * i / 2 + j - i].dim_status = d; + return m[i * n + j].dim_status = d; } -int getmdim(int i, int j) { +inline int getmdim(int i, int j) { if (j < i) { return 0; } - return m[(2 * n - i + 1) * i / 2 + j - i].dim_status; + return m[i * n + j].dim_status; } int find_min(int start, int end) { @@ -85,6 +85,40 @@ int find_min(int start, int end) { return min; } +int clac_min() { + for (int i = 0; i < n; i++) { + setm(i, i, 0); + } + + for (int length = 1; length <= n; length++) { + for (int start = 0; start <= n - length; start++) { + int min = 2147483647; + int needed_multiply_by_d = 1; + int dim_status = getmdim(start, start + length); + + // prepare the dims + for (int i = 0; i < k - 2; i++) { + if ((dim_status >> i) & 1) { + needed_multiply_by_d *= dim_sizes[i]; + } + } + + int matrix_partial_needed = tensors[tensor_pos(start, k - 2)] * + tensors[tensor_pos(start + length, k - 1)] * needed_multiply_by_d; + + for (int split = 0; split < length; split++) { + int total_needed = matrix_partial_needed * tensors[tensor_pos(start + split, k - 1)] + getm(start, start + split) + getm(start + split + 1, start + length); + if (total_needed < min) { + min = total_needed; + } + } + setm(start, start + length, min); + } + } + + return 0; +} + void preparedim(int length) { if (length == 1) { for (int i = 0; i < n; i++) { @@ -108,7 +142,8 @@ void preparedim(int length) { int main() { scanf("%d %d", &n, &k); tensors = new unsigned short[n * k]; - m = new CalcResult[(n + 1) * n / 2]; + // m = new CalcResult[(n + 1) * n / 2]; + m = new CalcResult[n * n]; for (int i = 0; i < 32; i++) { dim_sizes[i] = 1; @@ -137,7 +172,10 @@ int main() { preparedim(n); - printf("%d\n", find_min(0, n - 1)); + // printf("%d\n", find_min(0, n - 1)); + + clac_min(); + printf("%d\n", getm(0, n - 1)); return 0; } \ No newline at end of file