#34. 多项式乘法
这是一道模板题。
给你两个多项式,请输出乘起来后的多项式。
输入格式
第一行两个整数 n 和 m,分别表示两个多项式的次数。
第二行 n+1 个整数,分别表示第一个多项式的 0 到 n 次项前的系数。
第三行 m+1 个整数,分别表示第一个多项式的 0 到 m 次项前的系数。
输出格式
一行 n+m+1 个整数,分别表示乘起来后的多项式的 0 到 n+m 次项前的系数。
input
1 2
1 2
1 2 1
output
1 4 5 2
explanation
(1+2x)⋅(1+2x+x2)=1+4x+5x2+2x3。
限制与约定
0≤n,m≤105,保证输入中的系数大于等于 0 且小于等于 9。
时间限制:1s
空间限制:256MB
FFT模板题
#include <cmath> #include <cstdio> #include <algorithm> using namespace std; const int MaxN = 262144; const double Pi = acos(-1); int R[MaxN]; int N, M, L, T, l; struct complex { double real, imag; complex (const double &r = 0, const double &i = 0) { real = r, imag = i; } complex operator + (const complex &b) { return complex(real + b.real, imag + b.imag); } complex operator - (const complex &b) { return complex(real - b.real, imag - b.imag); } complex operator * (const complex &b) { complex c; c.real = real * b.real - imag * b.imag; c.imag = real * b.imag + imag * b.real; return c; } complex operator *= (const complex &b) { *this = *this * b; return *this; } }a[MaxN], b[MaxN]; #define real(x) a[x].real #define imag(x) a[x].imag void FFT (complex *a, const int &n, const int &res) { for (int i = 0; i < n; ++i) if (i < R[i]) swap(a[i], a[R[i]]); for (int k = 1; k < n; k <<= 1) { complex w = complex (cos(Pi / k), res * sin(Pi / k)); for (int s = 0; s < n; s += k << 1) { complex Wx = complex(1, 0); for (int i = s; i < s + k; ++i) { complex u1 = a[i], u2 = Wx * a[i + k]; a[i] = u1 + u2; a[i + k] = u1 - u2; Wx *= w; } } } if (!~res) { for (int i = 0; i < n; ++i) real(i) /= n; } } int main(){ scanf("%d%d", &N, &M); ++N, ++M; for (int i = 0; i < N; ++i) scanf("%lf", &a[i].real); for (int i = 0; i < M; ++i) scanf("%lf", &b[i].real); T = N + M - 1, L = 1, l = 1; while (L < T) L <<= 1, ++l; for (int i = 0; i < L; ++i) R[i] = (R[i >> 1] >> 1) | ((i & 1) << l - 2); FFT(a, L, 1), FFT(b, L, 1); for (int i = 0; i < L; ++i) a[i] *= b[i]; FFT(a, L, -1); for (int i = 0; i < T; ++i) printf("%d ", int(a[i].real + 0.5)); }
Ps:输出时需要四舍五入转int,不然可能会出现奇怪的‘-0’
NTT版本
#include <cmath> #include <cstdio> #include <algorithm> using namespace std; const int MaxN = 262144; const int Mod = 998244353; const int G = 3; int R[MaxN]; long long unit[2][MaxN]; int N, M, L, T, l; int a[MaxN], b[MaxN]; void FFT (int *a, const int &n, const int &res) { for (int i = 0; i < n; ++i) if (i < R[i]) swap(a[i], a[R[i]]); for (int k = 1, K = 1 ; K < n; K <<= 1, ++k) { int w = unit[res][1 << (l - k - 1)]; for (int s = 0; s < n; s += K << 1) { long long Wx = 1; for (int i = s; i < s + K; ++i) { int u1 = a[i], u2 = Wx * a[i + K] % Mod; a[i] = ((long long)u1 + u2) % Mod; a[i + K] = ((long long)u1 - u2 + Mod) % Mod; Wx = Wx * w % Mod; } } } } int pow(int x, int k) { if (!k) return 1; if (k == 1) return x; long long tmp = pow(x, k / 2); tmp = tmp * tmp % Mod; if (k % 2) tmp = tmp * x % Mod; return tmp; } int main(){ scanf("%d%d", &N, &M); ++N, ++M; for (int i = 0; i < N; ++i) scanf("%d", &a[i]); for (int i = 0; i < M; ++i) scanf("%d", &b[i]); T = N + M - 1, L = 1, l = 1; while (L < T) L <<= 1, ++l; int w = pow (G, (Mod - 1)/L); unit[1][0] = unit[0][0] = 1; for (int i = 1; i < L; ++i) { R[i] = (R[i >> 1] >> 1) | ((i & 1) << l - 2); unit[0][L - i] = unit[1][i] = unit[1][i - 1] * w % Mod; } FFT(a, L, 1), FFT(b, L, 1); for (int i = 0; i < L; ++i) a[i] = (long long)a[i] * b[i] % Mod; FFT(a, L, 0); int Q = pow(L, Mod - 2); for (int i = 0; i < T; ++i) printf("%d ", int((long long)a[i] * Q % Mod)); }