Files
DataStructureAndAlgorithm/2023210/main.cpp
unlockable 1eb2c4e303 80.
2024-01-22 22:04:29 +01:00

93 lines
2.5 KiB
C++

#include <stdio.h>
// One entry of the memo table `m`: the stored result for a sub-chain of tensors.
struct CalcResult {
// Per-dimension sizes written by setm(); at most the first k - 2 slots are filled.
int dim_sizes[32];
// Multiplication count stored for the sub-chain. find_min() treats 0 as
// "no stored result yet".
int min_multiply_times;
};
// Flat array of every tensor's dimension sizes. main() allocates n * k
// entries; tensor_pos() maps (tensor, dimension) to an index into it.
unsigned short *tensors;
// Earlier fixed-size alternative, kept for reference:
// unsigned short tensors[66000];
// n: number of tensors read by main(); k: number of dimension sizes per tensor.
int n = 0, k = 0;
// Memo table. main() allocates (n + 1) * n / 2 entries; setm()/getm() index
// it with a packed upper-triangular formula over (i, j).
CalcResult *m;
// Earlier alternatives, kept for reference:
// CalcResult m[10000];
// int * m;
// Converts a (tensor index, dimension index) pair into an offset in the flat
// `tensors` array, where each tensor occupies k consecutive entries.
int tensor_pos(int num, int dim) {
    const int first_entry = num * k;
    return first_entry + dim;
}
int setm(int i, int j, int num, int *dims) {
if (j < i) {
return 0;
}
m[(2 * n - i + 1) * i / 2 + j - i].min_multiply_times = num;
for (int i = 0; i < k - 2; i++) {
m[(2 * n - i + 1) * i / 2 + j - i].dim_sizes[i] = dims[i];
}
return m[(2 * n - i + 1) * i / 2 + j - i].min_multiply_times;
}
// Looks up the memo entry for the sub-chain [i, j].
// Returns the stored entry when i < j, and an all-zero result otherwise
// (single tensors and empty ranges are never read from the table).
CalcResult getm(int i, int j) {
    if (i < j) {
        // Row i of the packed upper triangle starts at (2n - i + 1) * i / 2.
        const int row_start = (2 * n - i + 1) * i / 2;
        return m[row_start + j - i];
    }
    CalcResult zero = {{0}, 0};
    return zero;
}
// Returns the result for multiplying tensors start..end (inclusive), taking it
// from the memo table when a non-zero count is already stored and otherwise
// computing it over every split point and storing it via setm().
CalcResult find_min(int start, int end) {
    const CalcResult cached = getm(start, end);
    if (start == end || cached.min_multiply_times != 0) {
        return cached;
    }

    // For each of the first k - 2 dimensions, take the first size in the range
    // that is not 1 (or 1 when every tensor has size 1 there) and fold it into
    // a common cost factor. NOTE(review): presumably these are broadcast batch
    // dimensions -- confirm against the problem statement.
    int batch_dims[32] = {0};
    int batch_factor = 1;
    for (int dim = 0; dim < k - 2; dim++) {
        int tens = start;
        while (tens <= end && tensors[tensor_pos(tens, dim)] == 1) {
            batch_dims[dim] = 1;
            tens++;
        }
        if (tens <= end) {
            const int extent = tensors[tensor_pos(tens, dim)];
            batch_factor *= extent;
            batch_dims[dim] = extent;
        }
    }

    int best = 2147483647;
    int split_pos = start;
    while (split_pos < end) {
        // Cost of joining the two halves: (dimension k - 2 of the first tensor)
        // x (dimension k - 1 at the split) x (dimension k - 1 of the last
        // tensor), scaled by the common factor from the leading dimensions.
        const int rows = tensors[tensor_pos(start, k - 2)];
        const int inner = tensors[tensor_pos(split_pos, k - 1)];
        const int cols = tensors[tensor_pos(end, k - 1)];
        int cost = rows * inner * cols * batch_factor;

        const CalcResult left = find_min(start, split_pos);
        const CalcResult right = find_min(split_pos + 1, end);
        cost += left.min_multiply_times + right.min_multiply_times;

        if (cost < best) {
            best = cost;
        }
        split_pos++;
    }

    setm(start, end, best, batch_dims);
    return getm(start, end);
}
int main() {
scanf("%d %d", &n, &k);
tensors = new unsigned short[n * k];
m = new CalcResult[(n + 1) * n / 2];
for (int tensor_count = 0; tensor_count < n; tensor_count++) {
for (int dim = 0; dim < k; dim++) {
scanf("%hu", &tensors[tensor_pos(tensor_count, dim)]);
}
}
printf("%d\n", find_min(0, n - 1).min_multiply_times);
return 0;
}