2738: 矩阵乘法
Description
给你一个N*N的矩阵,不用算矩阵乘法,但是每次询问一个子矩形的第K小数。
Input
第一行两个数N,Q,表示矩阵大小和询问组数;
接下来N行N列一共N*N个数,表示这个矩阵;
再接下来Q行每行5个数描述一个询问:x1,y1,x2,y2,k表示找到以(x1,y1)为左上角、以(x2,y2)为右下角的子矩形中的第K小数。
接下来N行N列一共N*N个数,表示这个矩阵;
再接下来Q行每行5个数描述一个询问:x1,y1,x2,y2,k表示找到以(x1,y1)为左上角、以(x2,y2)为右下角的子矩形中的第K小数。
Output
对于每组询问输出第K小的数。
Sample Input
2 2
2 1
3 4
1 2 1 2 1
1 1 2 2 3
2 1
3 4
1 2 1 2 1
1 1 2 2 3
Sample Output
1
3
3
HINT
矩阵中数字是109以内的非负整数;
20%的数据:N<=100,Q<=1000;
40%的数据:N<=300,Q<=10000;
60%的数据:N<=400,Q<=30000;
100%的数据:N<=500,Q<=60000。
整体二分
定义solve(l, r , S) 表答案落在[l, r]区间内的询问集合为S
二分答案mid, 在二维的树状数组中维护 值在[1, mid]区间的数的个数
我们将矩阵中满足条件的数的个数转化为
从(1,1)到(x, y)矩阵中满足条件的数的个数 <这是一个经典模型>
对于询问i, 设在该矩阵中满足条件的数的个数为res
若 res >= k, 则答案在区间[l, mid]内
若 res < k ,则答案在区间[mid + 1, r]内
递归求出答案即可
#include <cstdio> #include <iostream> #include <algorithm> using namespace std; const int MaxN = 510; const int MaxQ = 60010; int N, M, A, Q, T; int delta[MaxN][MaxN]; struct point{ int x, y, key; point (int a = 0, int b = 0, int k = 0) { x = a, y = b, key = k; } }p[MaxN * MaxN]; struct question{ int x1, y1, x2, y2, k, w, key; }que[MaxQ], q1[MaxQ], q2[MaxQ]; bool cmp(const point &x, const point &y) { return x.key < y.key; } int lowbit(int x){ return x & -x; } void ins(const int &x, const int &y, const int &k) { for (int i = x; i <= N; i += lowbit(i)) for (int j = y; j <= N; j += lowbit(j)) delta[i][j] += k; } int query(const int &x, const int &y){ int res = 0; for (int i = x; i; i -= lowbit(i)) for (int j = y; j; j -= lowbit(j)) res += delta[i][j]; return res; } void solve(int l, int r, int ql, int qr) { if (ql > qr) return; if (l == r) { for (int i = ql; i <= qr; ++i) que[i].key = l; return; } int mid = l + r >> 1, t1 = 0, t2 = 0; while (p[T + 1].key <= mid && T < M) ++T, ins(p[T].x, p[T].y, 1); while (p[T].key > mid && T > 0) ins(p[T].x, p[T].y, -1), --T; for (int i = ql; i <= qr; ++i) { int res = query(que[i].x2, que[i].y2) + query(que[i].x1 - 1, que[i].y1 - 1) - query(que[i].x1 - 1, que[i].y2) - query(que[i].x2, que[i].y1 - 1); if (res >= que[i].k) q1[++t1] = que[i]; else q2[++t2] = que[i]; } for (int i = 1; i <= t1; ++i) que[ql + i - 1] = q1[i]; for (int i = 1; i <= t2; ++i) que[ql + t1 + i - 1] = q2[i]; solve(l, mid, ql, ql + t1 - 1); solve(mid + 1, r, ql + t1, qr); } bool cmpQ(const question &x, const question &y){ return x.w < y.w; } int main() { scanf("%d%d", &N, &Q); for (int i = 1; i <= N; ++i) for (int j = 1; j <= N; ++j) { int x; scanf("%d", &x); A = max(A, x); p[++M] = point(i, j, x); } sort(p + 1, p + M + 1, cmp); for (int i = 1; i <= Q; ++i) { scanf("%d%d%d%d%d", &que[i].x1, &que[i].y1, &que[i].x2, &que[i].y2, &que[i].k); que[i].w = i, que[i].key = 0; } solve(0, A, 1, Q); sort(que + 1, que + Q + 1, cmpQ); for (int i = 1; i <= Q; ++i) printf("%d\n", que[i].key); }