This commit is contained in:
unlockable
2024-01-24 20:52:17 +01:00
parent 6770ceac90
commit 212e810db7

View File

@@ -15,40 +15,40 @@ int dim_sizes[32] = {0};
// int precalc[2][1048576] = {0};
int tensor_pos(int num, int dim) {
inline int tensor_pos(int num, int dim) {
return k * num + dim;
}
int setm(int i, int j, int num) {
inline int setm(int i, int j, int num) {
if (j <= i) {
return 0;
}
return m[(2 * n - i + 1) * i / 2 + j - i].min_multiply_times = num;
return m[i * n + j].min_multiply_times = num;
}
int getm(int i, int j) {
inline int getm(int i, int j) {
if (j <= i) {
return 0;
}
return m[(2 * n - i + 1) * i / 2 + j - i].min_multiply_times;
return m[i * n + j].min_multiply_times;
}
int setmdim(int i, int j, int d) {
inline int setmdim(int i, int j, int d) {
if (j < i) {
return 0;
}
return m[(2 * n - i + 1) * i / 2 + j - i].dim_status = d;
return m[i * n + j].dim_status = d;
}
int getmdim(int i, int j) {
inline int getmdim(int i, int j) {
if (j < i) {
return 0;
}
return m[(2 * n - i + 1) * i / 2 + j - i].dim_status;
return m[i * n + j].dim_status;
}
int find_min(int start, int end) {
@@ -85,6 +85,40 @@ int find_min(int start, int end) {
return min;
}
int clac_min() {
for (int i = 0; i < n; i++) {
setm(i, i, 0);
}
for (int length = 1; length <= n; length++) {
for (int start = 0; start <= n - length; start++) {
int min = 2147483647;
int needed_multiply_by_d = 1;
int dim_status = getmdim(start, start + length);
// prepare the dims
for (int i = 0; i < k - 2; i++) {
if ((dim_status >> i) & 1) {
needed_multiply_by_d *= dim_sizes[i];
}
}
int matrix_partial_needed = tensors[tensor_pos(start, k - 2)] *
tensors[tensor_pos(start + length, k - 1)] * needed_multiply_by_d;
for (int split = 0; split < length; split++) {
int total_needed = matrix_partial_needed * tensors[tensor_pos(start + split, k - 1)] + getm(start, start + split) + getm(start + split + 1, start + length);
if (total_needed < min) {
min = total_needed;
}
}
setm(start, start + length, min);
}
}
return 0;
}
void preparedim(int length) {
if (length == 1) {
for (int i = 0; i < n; i++) {
@@ -108,7 +142,8 @@ void preparedim(int length) {
int main() {
scanf("%d %d", &n, &k);
tensors = new unsigned short[n * k];
m = new CalcResult[(n + 1) * n / 2];
// m = new CalcResult[(n + 1) * n / 2];
m = new CalcResult[n * n];
for (int i = 0; i < 32; i++) {
dim_sizes[i] = 1;
@@ -137,7 +172,10 @@ int main() {
preparedim(n);
printf("%d\n", find_min(0, n - 1));
// printf("%d\n", find_min(0, n - 1));
clac_min();
printf("%d\n", getm(0, n - 1));
return 0;
}